package com.expleague.ml.methods.seq;

import com.expleague.commons.math.FuncC1;
import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.seq.IntSeq;
import com.expleague.commons.seq.Seq;
import com.expleague.commons.seq.regexp.Alphabet;
import com.expleague.ml.data.set.DataSet;
import com.expleague.ml.func.FuncEnsemble;
import com.expleague.ml.func.RegularizerFunc;
import com.expleague.ml.loss.WeightedL2;
import com.expleague.ml.methods.SeqOptimization;
import com.expleague.ml.methods.seq.param.BettaParametrization;
import com.expleague.ml.methods.seq.param.WeightParametrization;
import com.expleague.ml.optimization.Optimize;

import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
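
/**
 * Learns a probabilistic nondeterministic finite automaton (PNFA) regression
 * model over integer-coded sequences: each alphabet symbol owns a stochastic
 * transition matrix between hidden states, and each state carries a value
 * vector of dimension {@code stateDim}. Training starts from
 * {@code startStateCount} states and repeatedly adds one state, perturbing
 * the transition matrices with the lowest-entropy columns most strongly,
 * until {@code endStateCount} states are reached.
 *
 * <p>Hypothetical usage sketch (the concrete {@code Alphabet}, {@code Optimize}
 * and parametrization instances depend on the surrounding project):
 * <pre>{@code
 * PNFARegressor<Integer, WeightedL2> regressor = new PNFARegressor<>(
 *     5, 10, 1, alphabet, 0.01, 0.1, 0.001, 0.2,
 *     new Random(42), optimizer, bettaParametrization, weightParametrization);
 * PNFAModel<Integer> model = regressor.fit(dataSet, loss);
 * }</pre>
 */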
public class PNFARegressor<Type, Loss extends WeightedL2> implements SeqOptimization<Type, Loss> {
  private static final double W_SIGMA = 0.7;

  private final int startStateCount;
  private final int endStateCount;
  private final int stateDim;
  private final int alphabetSize;
  private final Alphabet<Type> alphabet;
  private final Random random;
  private final Optimize<FuncEnsemble<? extends FuncC1>> weightsOptimize;
  private final double alpha;
  private final double addToDiag;
  private final double betta;
  private final double expandMxPercent;
  private final BettaParametrization bettaParametrization;
  private final WeightParametrization weightParametrization;
  public PNFARegressor(int startStateCount,
                       int endStateCount,
                       int stateDim,
                       Alphabet<Type> alphabet,
                       double alpha,
                       double betta,
                       double addToDiag,
                       double expandMxPercent,
                       Random random,
                       Optimize<FuncEnsemble<? extends FuncC1>> weightsOptimize,
                       BettaParametrization bettaParametrization,
                       WeightParametrization weightParametrization) {
    this.startStateCount = startStateCount;
    this.endStateCount = endStateCount;
    this.betta = betta;
    this.stateDim = stateDim;
    this.alphabetSize = alphabet.size();
    this.alpha = alpha;
    this.addToDiag = addToDiag;
    this.expandMxPercent = expandMxPercent;
    this.random = random;
    this.weightsOptimize = weightsOptimize;
    this.alphabet = alphabet;
    this.bettaParametrization = bettaParametrization;
    this.weightParametrization = weightParametrization;
  }
  @Override
  public PNFAModel<Type> fit(final DataSet<Seq<Type>> learn, final Loss loss) {
    Vec params = init(loss.target(), startStateCount);
    PNFAItemVecRegression[] funcs = new PNFAItemVecRegression[learn.length()];
//    Vec wCacheVec = new ArrayVec(stateCount * stateCount * alphabetSize);
//    VecTools.fill(wCacheVec, -1);
//    Mx[] wCache = new Mx[alphabetSize];
//    for (int i = 0; i < wCache.length; i++) {
//      wCache[i] = new VecBasedMx(stateCount, wCacheVec.sub(stateCount * stateCount * i, stateCount * stateCount));
//    }
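    // Progressively grow the automaton: fit the model with the current state
    // count, then expand every transition matrix by one state and re-fit,
    // until endStateCount states are reached.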
    for (int stateCount = startStateCount; stateCount <= endStateCount; stateCount++) {
      for (int i = 0; i < learn.length(); i++) {
        final IntSeq seq = (IntSeq) learn.at(i);
        funcs[i] = new PNFAItemVecRegression(
            seq,
            loss.target().sub(i * stateDim, stateDim),
            stateCount,
            alphabetSize,
            stateDim,
            bettaParametrization,
            weightParametrization
        );
      }
      final RegularizerFunc regularizer = new MyRegularizer(funcs, null, stateCount, alpha, betta);
      final FuncEnsemble<PNFAItemVecRegression> func = new FuncEnsemble<>(funcs, loss.getWeights());
      params = weightsOptimize.optimize(func, regularizer, params);
      // Diagnostics: average entropy of the per-item state distributions.
      double totalEntropy = 0;
      for (int i = 0; i < learn.length(); i++) {
        totalEntropy += VecTools.entropy(funcs[i].distribution(params));
      }
      System.out.println("Entropy: " + (totalEntropy / learn.length()));

      if (stateCount == endStateCount) {
        break;
      }
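      // Rank alphabet symbols by the minimum column entropy of their transition
      // matrices; low-entropy (near-deterministic) matrices are the candidates
      // for the strongest perturbation when a new state is added.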
      List<Integer> wIdx = IntStream.range(0, alphabetSize).boxed().collect(Collectors.toList());
      final Vec paramsFinal = params;
      final int stateCountFinal = stateCount;
      wIdx.sort(Comparator.comparingDouble(i -> {
        Mx w = weightParametrization.getMx(paramsFinal, i, stateCountFinal);
        double entropy = Double.MAX_VALUE;
        for (int j = 0; j < w.columns(); j++) {
          entropy = Math.min(entropy, VecTools.entropy(w.col(j)));
        }
        return entropy;
      }));
      Set<Integer> badMxIdx = wIdx.stream().limit((int) (expandMxPercent * wIdx.size())).collect(Collectors.toSet());
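      // New parameter vector sized for stateCount + 1 states; the copy below
      // assumes the parametrization uses 2 * stateCount parameters per symbol
      // (hence the TODO).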
      Vec newParams = new ArrayVec(2 * (stateCount + 1) * alphabetSize + (stateCount + 1) * stateDim);
      int paramsPos = 0, newParamsPos = 0;
      //todo generify for different parametrizations
      for (int a = 0; a < alphabetSize; a++, paramsPos += 2 * stateCount, newParamsPos += 2 * (stateCount + 1)) {
        // Copy the old parameters and randomly initialize the new state's entries.
        VecTools.assign(newParams.sub(newParamsPos, stateCount), params.sub(paramsPos, stateCount));
        VecTools.assign(newParams.sub(newParamsPos + stateCount + 1, stateCount), params.sub(paramsPos + stateCount, stateCount));
        newParams.set(newParamsPos + 2 * stateCount + 1, W_SIGMA * random.nextGaussian());
        final double val;
        if (badMxIdx.contains(a)) {
          // Low-entropy matrices receive a much larger perturbation.
          val = W_SIGMA * random.nextGaussian() * 100;
        } else {
          val = W_SIGMA * random.nextGaussian();
        }
        newParams.set(newParamsPos + stateCount, val);
      }
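      // Initialize the value vector of the added state from the training
      // targets, weighted by each item's distribution under the new parameters.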
      Vec expected = newParams.sub(newParams.length() - stateDim, stateDim);
      for (int i = 0; i < learn.length(); i++) {
        final IntSeq seq = (IntSeq) learn.at(i);
        PNFAItemVecRegression f = new PNFAItemVecRegression(
            seq,
            loss.target().sub(i * stateDim, stateDim),
            stateCount + 1,
            alphabetSize,
            stateDim,
            bettaParametrization,
            weightParametrization
        );
        Vec distr = f.distribution(newParams);
        for (int j = 0; j < stateDim; j++) {
          expected.adjust(j, distr.get(j) * loss.target().get(i * stateDim + j) / learn.length());
        }
      }
      params = newParams;
    }
//    System.out.println("Value after: " + func.value(params) / func.size());
    return new PNFAModel<>(params, endStateCount, stateDim, addToDiag, alpha, alphabet, bettaParametrization, weightParametrization);
  }
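
  /**
   * Initializes the parameter vector: transition parameters are drawn as
   * |N(0, W_SIGMA^2)| and each state's values are seeded with shuffled
   * per-dimension quantiles of the training targets.
   */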
  private Vec init(Vec target, int stateCount) {
    int paramCount = bettaParametrization.paramCount(stateCount);
    final Vec params = new ArrayVec(
        paramCount * alphabetSize + stateCount * stateDim
    );
    { // u & v init
      for (int i = 0; i < paramCount * alphabetSize; i++) {
        params.set(i, W_SIGMA * Math.abs(random.nextGaussian()));
      }
    }
    { // values init
      final Mx values = new VecBasedMx(stateDim, params.sub(params.dim() - stateCount * stateDim, stateCount * stateDim));
      final Mx targetValuesMx = new VecBasedMx(stateDim, target);
      for (int col = 0; col < targetValuesMx.columns(); col++) {
        final double[] targetValues = targetValuesMx.col(col).toArray();
        Arrays.sort(targetValues);
        List<Double> medians = new ArrayList<>(stateCount);
        for (int i = 0; i < stateCount; i++) {
          medians.add(targetValues[(int) ((i + 0.5) * target.dim() / stateDim / stateCount)]);
        }
        Collections.shuffle(medians, random);
        for (int i = 0; i < stateCount; i++) {
          values.set(i, col, medians.get(i));
        }
      }
    }
    return params;
  }
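
  /**
   * Regularizer combining an L1 penalty on the transition parameters with an
   * L2 penalty on the state values; {@link #project(Vec)} applies the
   * corresponding proximal step.
   */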
  private class MyRegularizer extends RegularizerFunc.Stub {
    private final PNFAItemVecRegression[] funcs;
    private final int stateCount;
    private final double alpha;
    private final double betta;
    private final Vec wCacheVec;
    Vec prev;

    public MyRegularizer(PNFAItemVecRegression[] funcs, Vec wCacheVec, int stateCount, double alpha, double betta) {
      this.funcs = funcs;
      this.wCacheVec = wCacheVec;
      this.stateCount = stateCount;
      this.alpha = alpha;
      this.betta = betta;
      prev = new ArrayVec(funcs[0].dim());
    }
    @Override
    public double value(Vec x) {
      int paramCount = bettaParametrization.paramCount(stateCount);
      return alpha * VecTools.l1(x.sub(0, paramCount * alphabetSize))
          + betta * VecTools.l2(x.sub(paramCount * alphabetSize, stateCount * stateDim));
    }

    @Override
    public int dim() {
      return bettaParametrization.paramCount(stateCount) * alphabetSize + stateCount * stateDim;
    }
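
    /**
     * Proximal step: soft-thresholds the transition parameters at alpha
     * (skipping coordinates unchanged since the previous call) and shrinks
     * the state values toward zero.
     */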
    @Override
    public Vec project(Vec x) {
      Mx values = funcs[0].getValues(x);
      IntStream.range(0, x.length() - values.dim()).filter(idx -> prev.get(idx) != x.get(idx)).forEach(idx -> {
        final double val = x.get(idx);
        if (Math.abs(val) > alpha)
          x.adjust(idx, val > alpha ? -alpha : alpha);
        else
          x.set(idx, 0);
      });
      VecTools.assign(prev, x);
      VecTools.scale(values, values.dim() / (betta + values.dim()));
      if (wCacheVec != null) {
        VecTools.fill(wCacheVec, -1); // invalidate the cached transition matrices
      }
      return x;
    }
  }
}