All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.methods.seq.param.BettaTwoVecParametrization Maven / Gradle / Ivy

package com.expleague.ml.methods.seq.param;

import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;

public class BettaTwoVecParametrization implements BettaParametrization {
  private final double addToDiag;

  public BettaTwoVecParametrization(double addToDiag) {
    this.addToDiag = addToDiag;
  }

  @Override
  public int paramCount(int stateCount) {
    return 2 * stateCount;
  }

  @Override
  public Mx getBettaMx(Vec params, int c, int stateCount) {
    int bettaSize = paramCount(stateCount);
    Mx betta = new VecBasedMx(stateCount, stateCount);

    for (int i = 0; i < stateCount; i++) {
      for (int j = 0; j < stateCount; j++) {
        double value = Math.min(1e9, Math.max(-1e9, params.get(c * bettaSize + i) * params.get(c * bettaSize + stateCount + j)));
        if (i == j) {
          value += addToDiag;
        }
        betta.set(i, j, value);
      }
    }

    return betta;
  }

  @Override
  public void gradientTo(Vec params, Mx dBetta, Vec dParam, int c, int stateCount) {
    int bettaSize = paramCount(stateCount);

    final Vec v = params.sub(c * bettaSize, stateCount);
    final Vec u = params.sub(c * bettaSize + stateCount, stateCount);

    final Vec dv = dParam.sub(0, stateCount);
    final Vec du = dParam.sub(stateCount, stateCount);

    for (int k = 0; k < dBetta.rows(); k++) {
      for (int s = 0; s < dBetta.columns(); s++) {
        double m = dBetta.get(k, s);
        du.adjust(s, v.get(k) * m);
        dv.adjust(k, u.get(s) * m);
      }
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy