All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.methods.seq.PNFAItemVecRegression Maven / Gradle / Ivy

There is a newer version: 1.4.9
Show newest version
package com.expleague.ml.methods.seq;

import com.expleague.commons.math.FuncC1;
import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.MxTools;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.math.vectors.impl.vectors.SparseVec;
import com.expleague.commons.seq.IntSeq;
import com.expleague.ml.methods.seq.param.BettaParametrization;
import com.expleague.ml.methods.seq.param.WeightParametrization;

import java.util.Arrays;

public class PNFAItemVecRegression extends FuncC1.Stub {
  private final IntSeq seq;
  private final Vec y;
  private final int stateCount;
  private final int alphabetSize;
  private final int stateDim;
  private final BettaParametrization bettaParametrization;
  private final WeightParametrization weightParametrization;

  public PNFAItemVecRegression(final IntSeq seq,
                               final Vec y,
                               int stateCount,
                               int alphabetSize,
                               int stateDim,
                               BettaParametrization bettaParametrization,
                               WeightParametrization weightParametrization) {
    this.seq = seq;
    this.y = y;

    this.stateCount = stateCount;
    this.alphabetSize = alphabetSize;
    this.stateDim = stateDim;

    this.bettaParametrization = bettaParametrization;
    this.weightParametrization = weightParametrization;

  }

  @Override
  public int dim() {
    return bettaParametrization.paramCount(stateCount) * alphabetSize + stateCount * stateDim;
  }

  @Override
  public double value(Vec betta) {
    return VecTools.sum2(VecTools.subtract(vecValue(betta), y));
  }

  @Override
  public Vec gradient(Vec x) {
    final Vec state = new ArrayVec(stateCount * (seq.length() + 1));
    VecTools.fill(state.sub(0, stateCount), 1.0 / stateCount);
    //System.out.println("CPU Distribution: " + Arrays.toString(distribution.toArray()));
    for (int i = 0; i < seq.length(); i++) {
      Mx weightMx = weightParametrization.getMx(x, seq.intAt(i), stateCount);
      MxTools.multiplyTo(weightMx, state.sub(i * stateCount, stateCount), state.sub((i + 1) * stateCount, stateCount));
    }
    final Mx V = getValues(x);
    final Vec[] dS = new Vec[]{new ArrayVec(stateCount), new ArrayVec((stateCount))};
    final Vec lastLayerGrad = dS[seq.length() % 2];
    final Vec lastLayerState = state.sub(seq.length() * stateCount, stateCount);
    Vec r = MxTools.multiply(V, lastLayerState);
    VecTools.incscale(r, y, -1);

    for (int s = 0; s < stateCount; s++) {
      double sum = 0;
      for (int d = 0; d < stateDim; d++) {
        sum += V.get(d, s) * r.get(d);
      }
      lastLayerGrad.set(s, 2 * sum);
    }

    final Mx dW = new VecBasedMx(stateCount, stateCount);
    final int paramCount = bettaParametrization.paramCount(stateCount);

    int[] chars = seq.stream().sorted().distinct().toArray();
    int[] index = new int[chars.length * paramCount + stateDim * stateCount];
    double[] values = new double[chars.length * paramCount + stateDim * stateCount];

    for (int i = seq.length() - 1; i >= 0; i--) {
      final int c = seq.intAt(i);
      Vec out = dS[(i + 1) % 2];
      Vec in = dS[i % 2];

      weightParametrization.gradientTo(x, state.sub(i * stateCount, stateCount), out, in, dW, c, stateCount);
      Vec grad = new ArrayVec(paramCount);
      bettaParametrization.gradientTo(x, dW, grad, c, stateCount);
      int pos = Arrays.binarySearch(chars, c);
      for (int j = 0; j < paramCount; j++) {
        index[pos * paramCount + j] = c * paramCount + j;
        values[pos * paramCount + j] += grad.at(j);
      }
    }
//    final double betta = 0.1 * 2 / stateCount / (stateCount - 1) / stateDim;
//    for (int i = 0; i < stateCount; i++) {
//      for (int j = i + 1; j < stateCount; j++) {
//        for (int c = 0; c < stateDim; c++) {
//          vGrad.adjust(c, i,-2 * betta * (V.get(i, c) - V.get(j, c)));
//          vGrad.adjust(c, j, -2 * betta * (V.get(j, c) - V.get(i, c)));
//        }
//      }
//    }

    for (int s = 0; s < stateCount; s++) {
      for (int i = 0; i < stateDim; i++) {
        index[chars.length * paramCount + s * stateDim + i] = paramCount * alphabetSize + s * stateDim + i;
        values[chars.length * paramCount + s * stateDim + i] = 2 * r.get(i) * lastLayerState.get(s);
      }
    }

    return new SparseVec(alphabetSize * paramCount + stateCount * stateDim, index, values);
  }


  public Mx getValues(final Vec params) {
    return new VecBasedMx(
        stateCount,
        params.sub(params.dim() - stateCount * stateDim, stateCount * stateDim)
    );
  }

  public Vec distribution(Vec betta) {
    Vec[] distribution = new Vec[] {new ArrayVec(stateCount), new ArrayVec(stateCount)};
    VecTools.fill(distribution[0], 1.0 / stateCount);
    //System.out.println("CPU Distribution: " + Arrays.toString(distribution.toArray()));
    for (int i = 0; i < seq.length(); i++) {
      Mx weightMx = weightParametrization.getMx(betta, seq.intAt(i), stateCount);
      //System.out.println(String.format("-- (%s) CPU WeightMx: %s", i,
      //    Arrays.toString(weightMx.toArray())));
      MxTools.multiplyTo(weightMx, distribution[i % 2], distribution[(i + 1) % 2]);
    }

    return distribution[seq.length() % 2];
  }

  public Vec vecValue(Vec betta) {
    return MxTools.multiply(
        getValues(betta),
        distribution(betta)
    );
  }

  public int stateCount() {
    return stateCount;
  }

  public int stateDim() {
    return stateDim;
  }

  public int alphabetSize() {
    return alphabetSize;
  }


}






© 2015 - 2024 Weber Informatics LLC | Privacy Policy