All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.models.nn.nodes.FCNode Maven / Gradle / Ivy

package com.expleague.ml.models.nn.nodes;

import com.expleague.commons.math.AnalyticFunc;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.ml.models.nn.layers.Layer;

import static com.expleague.ml.models.nn.NeuralSpider.BackwardNode;
import static com.expleague.ml.models.nn.NeuralSpider.ForwardNode;

public class FCNode implements Layer.Node {
  private final int prevStateStart;
  private final int prevStateLength;
  private final int weightStart;
  private final int biasStart;
  private final int weightPerState;
  private final int layerStart;
  private final int layerLength;
  private final AnalyticFunc activation;

  public FCNode(int layerStart, int layerLength,
                int prevStateStart, int prevStateLength,
                int weightStart, int weightLength,
                AnalyticFunc activation) {
    this.layerStart = layerStart;
    this.layerLength = layerLength;

    this.prevStateStart = prevStateStart;
    this.prevStateLength = prevStateLength;
    this.weightStart = weightStart;

    final int mxWeightLength = weightLength - layerLength;
    this.weightPerState = mxWeightLength / layerLength;
    this.biasStart = weightStart + mxWeightLength;

    this.activation = activation;
  }

  @Override
  public ForwardNode forward() {
    return new ForwardUnit();
  }

  @Override
  public BackwardNode backward() {
    return new BackwardCalcer();
  }

  @Override
  public BackwardNode gradient() {
    return new GradCalcer();
  }

  private class ForwardUnit implements ForwardNode {
    @Override
    public double apply(Vec state, Vec betta, int nodeIdx) {
      final int localIdx = nodeIdx - layerStart;
      final int wStart = weightStart + localIdx * weightPerState;
      return VecTools.multiply(state.sub(prevStateStart, prevStateLength),
          betta.sub(wStart, prevStateLength)) + betta.get(biasStart + localIdx);
    }

    @Override
    public double activate(double value) {
      return activation.value(value);
    }

    @Override
    public double grad(double value) {
      return activation.gradient(value);
    }

    @Override
    public int start(int nodeIdx) {
      return prevStateStart;
    }

    @Override
    public int end(int nodeIdx) {
      return prevStateStart + prevStateLength;
    }
  }

  private class BackwardCalcer implements BackwardNode {
    /** It is assumed that state is already computed by forward pass. */

    @Override
    public double apply(Vec state, Vec gradState, Vec gradAct, Vec betta, int nodeIdx) {
      final int localIdx = nodeIdx - prevStateStart;
      double result = 0.;
      for (int i = 0; i < layerLength; i++) {
        result += gradState.get(layerStart + i)
            * gradAct.get(layerStart + i)
            * betta.get(weightStart + weightPerState * i + localIdx);
      }
      return result;
    }

    @Override
    public int start(int nodeIdx) {
      return layerStart;
    }

    @Override
    public int end(int nodeIdx) {
      return layerStart + layerLength;
    }
  }

  private class GradCalcer implements BackwardNode {
    @Override
    public double apply(Vec state, Vec gradState, Vec gradAct, Vec betta, int wIdx) {
      if (wIdx < biasStart) {
        final int localIdx = wIdx - weightStart;
        final int stateAfterIdx = localIdx / weightPerState;
        final int stateBeforeIdx = localIdx % weightPerState;

        final double gradS = gradState.get(layerStart + stateAfterIdx);
        final double gradA = gradAct.get(layerStart + stateAfterIdx);
        final double s = state.get(prevStateStart + stateBeforeIdx);
        return gradS * gradA * s;
      }

      final int localIdx = wIdx - biasStart;
      assert(localIdx < layerLength);
      final double dTds = gradState.get(layerStart + localIdx);
      final double dAct = gradAct.get(layerStart + localIdx);
      return dTds * dAct;
    }

    @Override
    public int start(int nodeIdx) {
      return -1;
    }

    @Override
    public int end(int nodeIdx) {
      return -1;
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy