All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.models.nn.nfa.WeightsCalculator Maven / Gradle / Ivy

package com.expleague.ml.models.nn.nfa;

import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.impl.ThreadLocalArrayVec;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;

import java.util.function.Function;

/**
 * Maps a parameter vector {@code betta} to a {@code statesCount x statesCount} weight matrix.
 *
 * <p>The slice of {@code betta} starting at {@code wStart} with length {@code wLen} is viewed as a
 * matrix {@code b} with {@code statesCount - 1} columns. Rows {@code 0 .. statesCount - 2} are
 * read, so the slice needs at least {@code (statesCount - 1)^2} entries.
 *
 * <p>For every source state {@code i < statesCount - 1} that is not dropped out, entry
 * {@code (j, i)} of the result is a softmax weight over target states {@code j}:
 * <ul>
 *   <li>{@code exp(b(i, j))} for {@code j < statesCount - 1};</li>
 *   <li>an implicit logit of {@code 0} (numerator {@code 1}) for the last state {@code j = statesCount - 1}.</li>
 * </ul>
 * Dropped-out targets are excluded from the numerators and the denominator. The constant {@code 1}
 * of the last state, however, always stays in the denominator, even when that state is dropped out.
 * All entries that are not written (dropped states, and the whole column {@code statesCount - 1})
 * are {@code 0}.
 *
 * <p>User: solar
 * Date: 29.06.15
 * Time: 17:12
 */
class WeightsCalculator implements Function<Vec, Mx> {
  private final int statesCount;
  private final int wStart;
  private final int wLen;
  /** Per-state drop-out mask; must have at least {@code statesCount} entries. {@code null} means none. */
  private boolean[] dropOut;

  WeightsCalculator(int statesCount, int wStart, int wLen) {
    this.statesCount = statesCount;
    this.wStart = wStart;
    this.wLen = wLen;
  }

  /** Scratch storage for {@link #computeInner(Vec)}; presumably one reusable buffer per thread. */
  final ThreadLocalArrayVec w = new ThreadLocalArrayVec();

  /**
   * Computes the weight matrix for {@code betta} into the scratch storage obtained from {@link #w}.
   *
   * <p>NOTE(review): the result is backed by {@link #w}'s buffer, which presumably is reused by
   * later calls on the same thread. Consume or copy the returned matrix before calling again.
   *
   * @param betta parameter vector whose {@code [wStart, wStart + wLen)} slice holds the logits
   * @return the {@code statesCount x statesCount} weight matrix
   */
  public Mx computeInner(Vec betta) {
    return fill(betta, new VecBasedMx(statesCount, this.w.get(statesCount * statesCount)));
  }

  /**
   * Writes the weights for {@code betta} into {@code target} and returns it.
   *
   * @param betta  parameter vector holding the logits
   * @param target {@code statesCount x statesCount} destination; its previous contents are cleared
   * @return {@code target}
   */
  private Mx fill(Vec betta, VecBasedMx target) {
    // Read the mask once so a whole computation sees a single, consistent mask.
    final boolean[] drop = dropOut;
    final int n = statesCount;
    final VecBasedMx b = new VecBasedMx(n - 1, betta.sub(wStart, wLen));

    // The target may be a reused buffer. Clear it so that entries skipped because of drop-out,
    // and the never-written last column, cannot keep values from an earlier call.
    for (int r = 0; r < n; r++) {
      for (int c = 0; c < n; c++) {
        target.set(r, c, 0);
      }
    }

    for (int i = 0; i < n - 1; i++) {
      if (isDropped(drop, i)) {
        continue;
      }
      // The last state has an implicit logit of 0, so exp(0) = 1 is always in the denominator.
      double sum = 1;
      // First pass: store the unnormalised numerators and accumulate the denominator.
      for (int j = 0; j < n - 1; j++) {
        if (isDropped(drop, j)) {
          continue;
        }
        final double e = Math.exp(b.get(i, j));
        target.set(j, i, e);
        sum += e;
      }
      if (!isDropped(drop, n - 1)) {
        target.set(n - 1, i, 1);
      }
      // Second pass: normalise the non-dropped entries of column i.
      for (int j = 0; j < n; j++) {
        if (isDropped(drop, j)) {
          continue;
        }
        target.set(j, i, target.get(j, i) / sum);
      }
    }
    return target;
  }

  /** Returns whether {@code state} is dropped out; a {@code null} mask drops nothing. */
  private static boolean isDropped(boolean[] drop, int state) {
    return drop != null && drop[state];
  }

  // NOTE(review): plain (non-volatile, unsynchronized) fields. If one instance is shared between
  // threads, confirm that this single-entry cache is acceptable.
  private Vec cacheArg;
  private Mx cacheVal;

  /**
   * Returns the weight matrix for {@code betta}.
   *
   * <p>Mutable arguments are always recomputed, into the scratch storage of
   * {@link #computeInner(Vec)}. For an immutable argument the result is cached by reference
   * identity of {@code betta}. The cached matrix is privately owned, so later computations cannot
   * overwrite it.
   *
   * @param betta parameter vector holding the logits
   * @return the {@code statesCount x statesCount} weight matrix
   */
  @Override
  public Mx apply(Vec betta) {
    if (!betta.isImmutable()) {
      return computeInner(betta);
    }
    if (betta == cacheArg) {
      return cacheVal;
    }
    final Mx value = fill(betta, new VecBasedMx(statesCount, statesCount));
    // Publish only after a successful computation, value before key. An exception therefore
    // never leaves cacheArg paired with a stale cacheVal.
    cacheVal = value;
    cacheArg = betta;
    return value;
  }

  /**
   * Sets the drop-out mask and invalidates the cache.
   *
   * @param dropOut per-state mask with at least {@code statesCount} entries ({@code true} = dropped),
   *                or {@code null} for no drop-out. The array is stored by reference, not copied.
   */
  public void setDropOut(boolean[] dropOut) {
    cacheArg = null;
    cacheVal = null;
    this.dropOut = dropOut;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy