All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.methods.seq.nn.LSTMLayer Maven / Gradle / Ivy

package com.expleague.ml.methods.seq.nn;

import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.commons.random.FastRandom;

public class LSTMLayer implements NetworkLayer {
  /**
   * Number of components in a single node's output vector. {@link #value(Mx)} feeds the previous
   * output back to the node as the trailing {@code STATE_DIM} components of its input and exposes
   * component 0 as the layer's output signal.
   */
  private static final int STATE_DIM = 2;

  private final LSTMNode[] nodes;

  /**
   * {@code nodeInputs[node][moment]} is the vector fed to {@code node} at {@code moment} during the
   * most recent {@link #value(Mx)} call; {@code null} until value() has been called.
   */
  private Vec[][] nodeInputs;

  /**
   * Creates a layer of {@code nodeCount} LSTM nodes, each constructed with the same input
   * dimension and random source.
   *
   * @param nodeCount number of nodes in the layer; must be positive
   * @param inputDim dimension of one input signal (one row of the input matrix)
   * @param random random source passed to every node's constructor
   * @throws IllegalArgumentException if {@code nodeCount <= 0}
   */
  public LSTMLayer(int nodeCount, int inputDim, FastRandom random) {
    // Several methods read nodes[0] to learn the per-node parameter count, so an empty layer
    // would only fail later with an ArrayIndexOutOfBoundsException.
    if (nodeCount <= 0) {
      throw new IllegalArgumentException("nodeCount must be positive, got " + nodeCount);
    }
    nodes = new LSTMNode[nodeCount];
    for (int i = 0; i < nodeCount; i++) {
      nodes[i] = new LSTMNode(inputDim, random);
    }
  }

  /**
   * Runs every node over the whole input sequence, feeding each node's previous output back into
   * it.
   *
   * <p>At each moment the node input is {@code concat(x.row(moment), previousNodeOutput)}, where
   * the previous output of the first moment is a freshly allocated {@code ArrayVec(STATE_DIM)}.
   * These inputs are cached for a following {@link #gradient(Mx, Mx, boolean)} call.
   *
   * @param x i-th row of x is a signal value at the moment i
   * @return output.get(i, j) is a signal value of j-th node at the moment i
   */
  @Override
  public Mx value(Mx x) {
    final int moments = x.rows();
    final Mx result = new VecBasedMx(moments, nodes.length);
    nodeInputs = new Vec[nodes.length][moments];

    for (int node = 0; node < nodes.length; node++) {
      Vec lastNodeOutput = new ArrayVec(STATE_DIM);
      for (int moment = 0; moment < moments; moment++) {
        nodeInputs[node][moment] = VecTools.concat(x.row(moment), lastNodeOutput);
        lastNodeOutput = nodes[node].value(nodeInputs[node][moment]);
        // Only component 0 of the node output is exposed as the layer output.
        result.set(moment, node, lastNodeOutput.get(0));
      }
    }

    return result;
  }

  /**
   * Back-propagates {@code outputGrad} through time for every node and accumulates the gradient
   * by parameters.
   *
   * <p>FIXME: for now assuming that this is the first layer in the network — the returned gradient
   * by input is always a zero matrix of the same shape as {@code x}.
   *
   * @param x layer input, same layout as in {@link #value(Mx)}
   * @param outputGrad gradient by layer output; {@code outputGrad.get(moment, node)} is read for
   *     every moment of {@code x} and every node
   * @param isAfterValue {@code true} to reuse the node inputs cached by the last {@link
   *     #value(Mx)} call instead of recomputing them from {@code x}
   * @return gradient by parameters, laid out node by node (one {@code paramCount() / nodeCount}
   *     slice per node, matching {@link #adjustParams(Vec)}), and a zero gradient by input
   * @throws IllegalStateException if {@code isAfterValue} is {@code true} but no value() pass over
   *     an input with {@code x.rows()} moments is cached
   */
  @Override
  public LayerGrad gradient(Mx x, Mx outputGrad, boolean isAfterValue) {
    if (!isAfterValue) {
      value(x);
    } else if (nodeInputs == null || nodeInputs[0].length != x.rows()) {
      throw new IllegalStateException(
          "gradient(x, outputGrad, true) requires a preceding value() call on an input with "
              + x.rows() + " rows");
    }

    final int paramsPerNode = nodeParamCount();
    final Vec gradByParams = new ArrayVec(nodes.length * paramsPerNode);
    // Stays zero: the gradient by input is not propagated yet (see FIXME above).
    final Mx gradByInput = new VecBasedMx(x.rows(), x.columns());

    for (int node = 0; node < nodes.length; node++) {
      final Vec nodeGradByParams = gradByParams.sub(node * paramsPerNode, paramsPerNode);
      // Gradient arriving at the node output from later moments; nothing arrives after the last.
      Vec nodeOutputGrad = new ArrayVec(STATE_DIM);
      for (int moment = x.rows() - 1; moment >= 0; moment--) {
        // Component 0 of the node output is also the layer output at this moment (see value()).
        nodeOutputGrad.adjust(0, outputGrad.get(moment, node));
        final NetworkNode.NodeGrad grad =
            nodes[node].grad(nodeInputs[node][moment], nodeOutputGrad);
        // NOTE(review): the next iteration treats this vector as the STATE_DIM-sized gradient by
        // the previous node output and adds into its component 0. That is only right if
        // LSTMNode.grad returns the gradient by the recurrent part of the input; if it returns the
        // gradient by the full concat(x.row(moment), previousOutput), the trailing STATE_DIM
        // components should be carried instead — confirm against LSTMNode.grad.
        nodeOutputGrad = grad.gradByInput;
        VecTools.append(nodeGradByParams, grad.gradByParams);
      }
    }

    return new LayerGrad(gradByParams, gradByInput);
  }

  /**
   * Adds {@code dW} to the parameters of every node.
   *
   * @param dW parameter delta laid out node by node, {@code nodeParamCount()} entries per node
   */
  @Override
  public void adjustParams(Vec dW) {
    final int paramsPerNode = nodeParamCount();
    for (int i = 0; i < nodes.length; i++) {
      VecTools.append(nodes[i].params(), dW.sub(i * paramsPerNode, paramsPerNode));
    }
  }

  /**
   * Overwrites the parameters of every node with the corresponding slice of {@code newW}.
   *
   * @param newW new parameters laid out node by node, {@code nodeParamCount()} entries per node
   */
  @Override
  public void setParams(Vec newW) {
    final int paramsPerNode = nodeParamCount();
    for (int i = 0; i < nodes.length; i++) {
      VecTools.assign(nodes[i].params(), newW.sub(i * paramsPerNode, paramsPerNode));
    }
  }

  /** Returns the total number of parameters in the layer (node count times per-node count). */
  @Override
  public int paramCount() {
    return nodes.length * nodeParamCount();
  }

  /**
   * Returns the parameters of all nodes, node by node, combined with {@code VecTools.join}.
   * NOTE(review): whether writes to the result reach the nodes depends on VecTools.join returning
   * a view rather than a copy — confirm.
   */
  @Override
  public Vec paramsView() {
    final Vec[] params = new Vec[nodes.length];
    for (int i = 0; i < nodes.length; i++) {
      params[i] = nodes[i].params();
    }
    return VecTools.join(params);
  }

  /** Number of parameters of a single node; every node is built with the same input dimension. */
  private int nodeParamCount() {
    return nodes[0].params().dim();
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy