com.expleague.ml.methods.seq.BootstrapSeqOptimization Maven / Gradle / Ivy
package com.expleague.ml.methods.seq;
import com.expleague.commons.func.Computable;
import com.expleague.commons.func.impl.WeakListenerHolderImpl;
import com.expleague.commons.math.Trans;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.random.FastRandom;
import com.expleague.commons.seq.Seq;
import com.expleague.ml.data.set.DataSet;
import com.expleague.ml.loss.L2;
import com.expleague.ml.loss.WeightedL2;
import com.expleague.ml.methods.SeqOptimization;

import java.util.Objects;
public class BootstrapSeqOptimization extends WeakListenerHolderImpl implements SeqOptimization {
protected final FastRandom rnd;
private final SeqOptimization weak;
public BootstrapSeqOptimization(final SeqOptimization weak, final FastRandom rnd) {
this.weak = weak;
this.rnd = rnd;
}
@Override
public Computable, Vec> fit(final DataSet> learn, final Loss globalLoss) {
return weak.fit(learn, bootstrap(globalLoss, rnd));
}
private WeightedL2 bootstrap(Loss loss, FastRandom rnd) {
final double[] poissonWeights = new double[loss.xdim()];
for (int i = 0; i < loss.xdim(); i++) {
poissonWeights[i] = rnd.nextPoisson(1.);
}
return new WeightedL2(loss.target, loss.owner(), new ArrayVec(poissonWeights));
}
}