All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.methods.seq.nn.LogisticLayer Maven / Gradle / Ivy

package com.expleague.ml.methods.seq.nn;

import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.SingleValueVec;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.commons.random.FastRandom;

public class LogisticLayer implements NetworkLayer {
  private final NetworkNode[] nodes;

  public LogisticLayer(int nodeCount, int inputDim, FastRandom random) {
    nodes = new NetworkNode[nodeCount];
    for (int i = 0; i < nodeCount; i++) {
      nodes[i] = new LogisticNode(inputDim, random);
    }
  }
  @Override
  public Mx value(Mx x) {
    Mx result = new VecBasedMx(x.rows(), nodes.length);
    for (int i = 0; i < x.rows(); i++) {
      for (int j = 0; j < nodes.length; j++) {
        result.set(i, j, nodes[j].value(x.row(i)).get(0));
      }
    }
    return result;
  }

  @Override
  public LayerGrad gradient(Mx x, Mx outputGrad, boolean isAfterValue) {
    final Vec gradByParam = new ArrayVec(paramCount());
    final Mx gradByInput = new VecBasedMx(x.rows(), x.columns());
    final int nodeParamCount = nodes[0].params().dim();
    for (int i = 0; i < x.rows(); i++) {
      for (int j = 0; j < nodes.length; j++) {
        NetworkNode.NodeGrad grad = nodes[j].grad(x.row(i), new SingleValueVec(outputGrad.get(i, j)));
        VecTools.append(gradByParam.sub(nodeParamCount * j, nodeParamCount), grad.gradByParams);
        VecTools.append(gradByInput.row(i), grad.gradByInput);
      }
    }
    return new LayerGrad(gradByParam, gradByInput);
  }

  @Override
  public void adjustParams(Vec dW) {
    final int nodeParamCount = nodes[0].params().dim();
    for (int i = 0; i < nodes.length; i++) {
      VecTools.append(nodes[i].params(), dW.sub(i * nodeParamCount, nodeParamCount));
    }
  }

  @Override
  public void setParams(Vec newW) {
    final int nodeParamCount = nodes[0].params().dim();
    for (int i = 0; i < nodes.length; i++) {
      VecTools.assign(nodes[i].params(), newW.sub(i * nodeParamCount, nodeParamCount));
    }
  }

  @Override
  public int paramCount() {
    return nodes[0].params().dim() * nodes.length;
  }

  @Override
  public Vec paramsView() {
    Vec[] params = new Vec[nodes.length];
    for (int i =0; i < nodes.length; i++) {
      params[i] = nodes[i].params();
    }
    return VecTools.join(params);
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy