All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.methods.seq.nn.MeanPoolLayer Maven / Gradle / Ivy

There is a newer version: 1.4.9
Show newest version
package com.expleague.ml.methods.seq.nn;

import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;

/**
 * Parameter-free layer that averages its input matrix over its rows.
 *
 * <p>For an input with {@code r} rows and {@code c} columns, {@link #value(Mx)} produces a
 * {@code 1 x c} matrix whose column {@code j} is the arithmetic mean of column {@code j} of the
 * input. (Presumably rows are sequence positions, given the {@code seq.nn} package — confirm
 * against callers.)
 *
 * <p>The layer owns no trainable parameters. {@link #paramCount()} is 0,
 * {@link #paramsView()} returns an empty vector, and the parameter-mutating methods do nothing.
 */
public class MeanPoolLayer implements NetworkLayer {
  /**
   * Computes the column-wise mean of {@code x}.
   *
   * @param x input matrix; must have at least one row
   * @return a new {@code 1 x x.columns()} matrix holding the mean of each column of {@code x}
   * @throws IllegalArgumentException if {@code x} has zero rows. Otherwise the mean would be
   *     computed as {@code 0 * (1.0 / 0) = NaN} and silently propagate through the network.
   */
  @Override
  public Mx value(Mx x) {
    final int rows = x.rows();
    if (rows == 0) {
      throw new IllegalArgumentException(
          "MeanPoolLayer cannot average an input with zero rows (columns=" + x.columns() + ")");
    }
    final Mx result = new VecBasedMx(1, x.columns());
    // Accumulate column sums into the single output row...
    for (int i = 0; i < rows; i++) {
      for (int j = 0; j < x.columns(); j++) {
        result.adjust(0, j, x.get(i, j));
      }
    }
    // ...then turn the sums into means.
    VecTools.scale(result, 1.0 / rows);
    return result;
  }

  /**
   * Back-propagates {@code outputGrad} through the mean.
   *
   * <p>Output column {@code j} is {@code (1/r) * sum_i x[i][j]}, so the derivative of output
   * {@code j} with respect to input {@code x[i][j]} is {@code 1/r}. Each input row therefore
   * receives a copy of the output gradient divided by {@code r}. For an input with zero rows
   * the returned input gradient is an empty {@code 0 x c} matrix.
   *
   * @param x the input that was passed to {@link #value(Mx)}
   * @param outputGrad gradient with respect to this layer's output; row 0 is read for each
   *     column of {@code x}
   * @param isAfterValue unused: this layer keeps no cached state from {@link #value(Mx)}
   * @return a {@link LayerGrad} with an empty parameter gradient (the layer has no parameters)
   *     and an input gradient shaped like {@code x}
   */
  @Override
  public LayerGrad gradient(Mx x, Mx outputGrad, boolean isAfterValue) {
    final int rows = x.rows();
    final Mx grad = new VecBasedMx(rows, x.columns());
    for (int i = 0; i < rows; i++) {
      for (int j = 0; j < x.columns(); j++) {
        grad.set(i, j, outputGrad.get(0, j) / rows);
      }
    }
    return new LayerGrad(new ArrayVec(), grad);
  }

  /** No-op: this layer has no parameters to adjust. */
  @Override
  public void adjustParams(Vec dW) {
    // Intentionally empty; paramCount() is 0.
  }

  /** No-op: this layer has no parameters to set. */
  @Override
  public void setParams(Vec newW) {
    // Intentionally empty; paramCount() is 0.
  }

  /**
   * Returns the number of trainable parameters.
   *
   * @return always 0
   */
  @Override
  public int paramCount() {
    return 0;
  }

  /**
   * Returns a view of the layer's parameters.
   *
   * @return a fresh empty vector, since the layer has no parameters
   */
  @Override
  public Vec paramsView() {
    return new ArrayVec();
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy