All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.methods.seq.nn.LSTMNode Maven / Gradle / Ivy

package com.expleague.ml.methods.seq.nn;

import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.ml.func.generic.Sigmoid;
import com.expleague.ml.func.generic.Tanh;
import com.expleague.commons.math.vectors.SingleValueVec;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.random.FastRandom;

/**
 * A single LSTM cell with a scalar output and a scalar cell state.
 *
 * <p>All trainable parameters live in one flat vector that is split into four consecutive
 * gate segments (forget, input, candidate, output). Every segment holds the gate weights
 * followed by exactly one bias value; the {@code w*}/{@code b*} fields are sub-vectors of
 * {@link #params}, obtained through {@code Vec.sub}.
 *
 * <p>The node consumes a vector {@code (input vector, previous output value, previous cell value)}
 * and produces {@code (cell output, cell value)}.
 */
public class LSTMNode implements NetworkNode {
  private final Vec params;
  private final Vec wForget;
  private final Vec bForget;
  private final Vec wInput;
  private final Vec bInput;
  private final Vec wCandidate;
  private final Vec bCandidate;
  private final Vec wOutput;
  private final Vec bOutput;

  // Offsets of the gate segments inside params (and inside the gradient by params).
  private final int forgetStart;
  private final int inputStart;
  private final int candidateStart;
  private final int outputStart;

  /**
   * Creates a node with randomly initialised gate weights.
   *
   * <p>Weights are drawn as {@code random.nextGaussian() / (inputDim + 1)}. Bias slots are not
   * written here and keep whatever value {@code ArrayVec} gives freshly allocated elements
   * (presumably zero).
   *
   * @param inputDim dimension of the external input only; the previous output value and the
   *                 previous cell value are not counted
   * @param random   source of the gaussian samples used for the weights
   */
  public LSTMNode(int inputDim, FastRandom random) {
    // The gates see the input concatenated with the previous output value; the previous cell
    // value is stripped from x before the dot products and therefore has no weight.
    inputDim += 1;

    final int gateSize = inputDim + 1; // weights followed by one bias
    params = new ArrayVec(4 * gateSize);
    for (int i = 0; i < params.dim(); i++) {
      // The bias is the LAST slot of each gate segment (offset inputDim), so that is the slot
      // to leave untouched. The previous check (offset 0) skipped the first weight of every
      // gate and randomised the biases instead.
      if (i % gateSize != inputDim) {
        params.set(i, random.nextGaussian() / inputDim);
      }
    }

    forgetStart = 0;
    wForget = params.sub(forgetStart, inputDim);
    bForget = params.sub(forgetStart + inputDim, 1);

    inputStart = gateSize;
    wInput = params.sub(inputStart, inputDim);
    bInput = params.sub(inputStart + inputDim, 1);

    candidateStart = 2 * gateSize;
    wCandidate = params.sub(candidateStart, inputDim);
    bCandidate = params.sub(candidateStart + inputDim, 1);

    outputStart = 3 * gateSize;
    wOutput = params.sub(outputStart, inputDim);
    bOutput = params.sub(outputStart + inputDim, 1);
  }

  /**
   * @return the flat parameter vector (forget, input, candidate, output segments in that order)
   */
  @Override
  public Vec params() {
    return params;
  }

  /**
   * Back-propagates a gradient through the cell.
   *
   * @param x              vector of (input vector, previous output value, previous cell value)
   * @param nodeOutputGrad gradient by the node outputs: element 0 is used as the gradient by
   *                       the cell output, element 1 as the gradient by the cell value
   * @return gradient by params (same layout as {@link #params()}) and gradient by input, the
   *         latter being the 2-element vector (grad by prev output, grad by prev cell value)
   */
  //Fixme: not calculating gradient by input, only by previous node values
  @Override
  public NetworkNode.NodeGrad grad(Vec x, Vec nodeOutputGrad) {
    final FunctionValues values = new FunctionValues(x);
    final Vec gradByParams = new ArrayVec(params.dim());
    final Vec gradByInput = new ArrayVec(2);
    final Tanh tanh = new Tanh();

    final Vec xAndOutput = x.sub(0, x.dim() - 1);
    final double prevCellValue = x.get(x.dim() - 1);

    final double byCellOutput = nodeOutputGrad.get(0);
    final double byCellValue = nodeOutputGrad.get(1);

    // cellOutput = output * tanh(cellValue), hence d cellOutput / d output = tanh(cellValue)
    final double outputGrad = tanh.value(new SingleValueVec(values.cellValue));
    adjustGrad(gradByParams, outputStart, xAndOutput, outputGrad * values.outputGrad * byCellOutput);

    // d cellOutput / d cellValue
    final double cellValueGrad = values.output * tanh.gradient(new SingleValueVec(values.cellValue)).get(0);

    // Each of the remaining gates influences the result twice: through cellOutput (first call)
    // and directly through cellValue (second call).
    final double forgetGrad = cellValueGrad * prevCellValue; //d cellOutput / d forget
    adjustGrad(gradByParams, forgetStart, xAndOutput, forgetGrad * values.forgetGrad * byCellOutput);
    adjustGrad(gradByParams, forgetStart, xAndOutput, prevCellValue * values.forgetGrad * byCellValue);

    final double inputGrad = cellValueGrad * values.candidate; //d cellOutput / d input
    final double candidateGrad = cellValueGrad * values.input; //d cellOutput / d candidate

    adjustGrad(gradByParams, inputStart, xAndOutput, inputGrad * values.inputGrad * byCellOutput);
    adjustGrad(gradByParams, inputStart, xAndOutput, values.candidate * values.inputGrad * byCellValue);

    adjustGrad(gradByParams, candidateStart, xAndOutput, candidateGrad * values.candidateGrad * byCellOutput);
    adjustGrad(gradByParams, candidateStart, xAndOutput, values.input * values.candidateGrad * byCellValue);

    // The previous output value is the last element of xAndOutput, so its weight is the last
    // weight of every gate.
    final int prevOutputIdx = xAndOutput.dim() - 1;
    final double wOutputPrev = wOutput.get(prevOutputIdx);
    final double wForgetPrev = wForget.get(prevOutputIdx);
    final double wInputPrev = wInput.get(prevOutputIdx);
    final double wCandidatePrev = wCandidate.get(prevOutputIdx);

    // Gradient by previous output: contributions that flow through cellOutput ...
    gradByInput.adjust(0, outputGrad * values.outputGrad * wOutputPrev * byCellOutput);
    gradByInput.adjust(0, forgetGrad * values.forgetGrad * wForgetPrev * byCellOutput);
    gradByInput.adjust(0, inputGrad * values.inputGrad * wInputPrev * byCellOutput);
    gradByInput.adjust(0, candidateGrad * values.candidateGrad * wCandidatePrev * byCellOutput);

    // ... and contributions that flow through cellValue.
    gradByInput.adjust(0, prevCellValue * values.forgetGrad * wForgetPrev * byCellValue);
    gradByInput.adjust(0, values.candidate * values.inputGrad * wInputPrev * byCellValue);
    gradByInput.adjust(0, values.input * values.candidateGrad * wCandidatePrev * byCellValue);

    // Gradient by previous cell value: cellValue = forget * prevCellValue + input * candidate
    gradByInput.set(1, values.forget * byCellValue + cellValueGrad * values.forget * byCellOutput);

    return new NodeGrad(gradByParams, gradByInput);
  }

  /**
   * Adds the contribution of one gate to the gradient by params.
   *
   * @param grad  gradient vector laid out like {@link #params()}
   * @param start offset of the gate segment inside {@code grad}
   * @param vec   the gate input; added (times {@code scale}) to the weight slots
   * @param scale gradient by the gate pre-activation; added as is to the bias slot
   */
  private void adjustGrad(Vec grad, int start, Vec vec, double scale) {
    VecTools.incscale(grad.sub(start, vec.dim()), vec, scale);
    grad.adjust(start + vec.dim(), scale);
  }

  /**
   * Runs the forward pass of the cell.
   *
   * @param x vector of (input vector, previous output value, previous cell value)
   * @return vector of (cell output, cell value)
   */
  @Override
  public Vec value(Vec x) {
    final FunctionValues values = new FunctionValues(x);
    return new ArrayVec(values.cellOutput, values.cellValue);
  }

  /**
   * Forward-pass snapshot: gate activations, their derivatives by the gate pre-activation,
   * the new cell value and the cell output. Kept as an inner (non-static) class because it
   * reads the gate weights of the enclosing node.
   */
  private class FunctionValues {
    final double forget;
    final double forgetGrad;
    final double input;
    final double inputGrad;
    final double candidate;
    final double candidateGrad;
    final double cellValue;
    final double output;
    final double outputGrad;
    final double cellOutput;

    /**
     * @param x vector of (input vector, previous output value, previous cell value)
     */
    FunctionValues(Vec x) {
      final Sigmoid sigmoid = new Sigmoid();
      final Tanh tanh = new Tanh();
      // Everything except the last element feeds the gates; the last one is the old cell state.
      final Vec xAndOutput = x.sub(0, x.dim() - 1);
      final double prevCellValue = x.get(x.dim() - 1);

      final Vec forgetX = new SingleValueVec(VecTools.multiply(wForget, xAndOutput) + bForget.get(0));
      forget = sigmoid.value(forgetX);
      forgetGrad = sigmoid.gradient(forgetX).get(0);

      final Vec inputX = new SingleValueVec(VecTools.multiply(wInput, xAndOutput) + bInput.get(0));
      input = sigmoid.value(inputX);
      inputGrad = sigmoid.gradient(inputX).get(0);

      final Vec candidateX = new SingleValueVec(VecTools.multiply(wCandidate, xAndOutput) + bCandidate.get(0));
      candidate = tanh.value(candidateX);
      candidateGrad = tanh.gradient(candidateX).get(0);

      cellValue = forget * prevCellValue + input * candidate;

      final Vec outputX = new SingleValueVec(VecTools.multiply(wOutput, xAndOutput) + bOutput.get(0));
      output = sigmoid.value(outputX);
      outputGrad = sigmoid.gradient(outputX).get(0);

      cellOutput = output * tanh.value(new SingleValueVec(cellValue));
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy