com.expleague.ml.methods.seq.GradientSeqBoosting Maven / Gradle / Ivy
package com.expleague.ml.methods.seq;
import com.expleague.commons.func.Computable;
import com.expleague.commons.func.impl.WeakListenerHolderImpl;
import com.expleague.commons.math.Trans;
import com.expleague.commons.math.vectors.SingleValueVec;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.seq.Seq;
import com.expleague.ml.TargetFunc;
import com.expleague.ml.data.set.DataSet;
import com.expleague.ml.data.tools.DataTools;
import com.expleague.ml.loss.L2;
import com.expleague.ml.methods.SeqOptimization;
import java.util.ArrayList;
import java.util.List;
public class GradientSeqBoosting extends WeakListenerHolderImpl, Vec>> implements SeqOptimization {
protected final SeqOptimization weak;
private final Class extends L2> factory;
int iterationsCount;
double step;
public GradientSeqBoosting(final SeqOptimization weak, final int iterationsCount, final double step) {
this(weak, L2.class, iterationsCount, step);
}
public GradientSeqBoosting(final SeqOptimization weak, final Class extends L2> factory, final int iterationsCount, final double step) {
this.weak = weak;
this.factory = factory;
this.iterationsCount = iterationsCount;
this.step = step;
}
@Override
public Computable, Vec> fit(final DataSet> learn, final GlobalLoss globalLoss) {
final Vec cursor = new ArrayVec(globalLoss.xdim());
final List, Vec>> weakModels = new ArrayList<>(iterationsCount);
final Trans gradient = globalLoss.gradient();
for (int t = 0; t < iterationsCount; t++) {
final Vec gradientValueAtCursor = gradient.trans(cursor);
final L2 localLoss = DataTools.newTarget(factory, gradientValueAtCursor, learn);
System.out.println("Iteration " + t + ". Gradient norm: " + VecTools.norm(localLoss.target));
final Computable, Vec> weakModel = weak.fit(learn, localLoss);
weakModels.add(weakModel);
final Computable, Vec> curRes = getResult(new ArrayList<>(weakModels));
invoke(curRes);
for (int i = 0; i < learn.length(); i++) {
cursor.adjust(i, weakModel.compute(learn.at(i)).get(0) * -step);
}
}
return getResult(weakModels);
}
private Computable, Vec> getResult(final List, Vec>> weakModels) {
return argument -> {
double result = 0;
for (Computable, Vec> model: weakModels) {
result += model.compute(argument).get(0) * -step;
}
return new SingleValueVec(result);
};
}
}