package com.expleague.ml.methods.seq.nn;

import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
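
/**
 * A feed-forward stack of {@link NetworkLayer}s. The network records, for each
 * layer, the offset of that layer's parameters inside one flat parameter
 * vector, so gradients and updates can be exchanged as a single {@link Vec}.
 */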
public class NeuralNetwork {
  private final NetworkLayer[] layers;
  // prefixParamCount[i] is the total parameter count of layers 0..i-1,
  // i.e. the offset of layer i's parameters in the flat parameter vector.
  private final int[] prefixParamCount;

  public NeuralNetwork(NetworkLayer... layers) {
    this.layers = layers;
    prefixParamCount = new int[layers.length + 1];
    for (int i = 0; i < layers.length; i++) {
      prefixParamCount[i + 1] = prefixParamCount[i] + layers[i].paramCount();
    }
  }
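
  /**
   * Forward pass: feeds the input through every layer in order and returns
   * the output of the last layer.
   */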
  public Mx value(Mx input) {
    Mx cur = input;
    for (NetworkLayer layer : layers) {
      cur = layer.value(cur);
    }
    return cur;
  }
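
  /**
   * Backpropagation: given the gradient of the loss by the network output,
   * computes the gradient by all network parameters as one flat vector.
   * The {@code isAfterValue} flag is forwarded to each layer's gradient
   * computation (layers may use it to reuse state cached by a preceding
   * {@link #value} call).
   */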
  public Vec gradByParams(Mx input, Mx outputGrad, boolean isAfterValue) {
    // Forward pass, keeping each layer's input for the backward pass.
    final Mx[] inputs = new Mx[layers.length + 1];
    inputs[0] = input;
    for (int i = 0; i < layers.length; i++) {
      inputs[i + 1] = layers[i].value(inputs[i]);
    }

    // Backward pass: walk the layers from last to first, accumulating each
    // layer's parameter gradient into its slice of the flat gradient vector
    // and propagating the gradient to the layer's input.
    final Vec paramsGrad = new ArrayVec(prefixParamCount[layers.length]);
    Mx grad = outputGrad;
    for (int i = layers.length - 1; i >= 0; i--) {
      final NetworkLayer.LayerGrad layerGrad = layers[i].gradient(inputs[i], grad, isAfterValue);
      VecTools.append(paramsGrad.sub(prefixParamCount[i], layers[i].paramCount()), layerGrad.gradByParams);
      grad = layerGrad.gradByInput;
    }
    return paramsGrad;
  }
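
  /**
   * Applies a parameter update: each layer receives its own slice of
   * {@code dW} (e.g. a scaled negative gradient for gradient descent).
   */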
  public void adjustParams(Vec dW) {
    for (int i = 0; i < layers.length; i++) {
      layers[i].adjustParams(dW.sub(prefixParamCount[i], layers[i].paramCount()));
    }
  }
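
  /**
   * Replaces all network parameters: each layer's parameters are set to the
   * corresponding slice of {@code params}.
   */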
  public void setParams(Vec params) {
    for (int i = 0; i < layers.length; i++) {
      layers[i].setParams(params.sub(prefixParamCount[i], layers[i].paramCount()));
    }
  }
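
  /** Total number of parameters over all layers. */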
  public int paramCount() {
    return prefixParamCount[layers.length];
  }
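
  /** A single vector that joins every layer's parameter view. */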
  public Vec paramsView() {
    final Vec[] params = new Vec[layers.length];
    for (int i = 0; i < layers.length; i++) {
      params[i] = layers[i].paramsView();
    }
    return VecTools.join(params);
  }
}
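
// A minimal usage sketch (hypothetical: assumes concrete NetworkLayer
// implementations and a precomputed gradient of the loss by the output;
// `step` stands for a scaled negative gradient, e.g. -learningRate * grad):
//
//   NeuralNetwork nn = new NeuralNetwork(layer1, layer2);
//   Mx output = nn.value(input);                          // forward pass
//   Vec grad = nn.gradByParams(input, outputGrad, true);  // backward pass
//   nn.adjustParams(step);                                // gradient-descent update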