// com.expleague.ml.methods.trees.RidgeGreedyObliviousTree Maven / Gradle / Ivy
package com.expleague.ml.methods.trees;
import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.SingleValueVec;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.ml.data.impl.BinarizedDataSet;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.commons.util.ArrayTools;
import com.expleague.commons.util.Pair;
import com.expleague.ml.Binarize;
import com.expleague.ml.func.Ensemble;
import com.expleague.ml.loss.StatBasedLoss;
import com.expleague.ml.loss.WeightedLoss;
import com.expleague.ml.methods.VecOptimization;
import com.expleague.ml.methods.linearRegressionExperiments.RidgeRegression;
import com.expleague.ml.models.ModelTools;
import com.expleague.ml.models.ObliviousTree;
import java.util.ArrayList;
import java.util.List;
/**
* User: noxoomo
* Date: 02.12.2015
*/
public class RidgeGreedyObliviousTree extends VecOptimization.Stub {
private final GreedyObliviousTree base;
final double lambda;
public RidgeGreedyObliviousTree(GreedyObliviousTree base, double lambda) {
this.base = base;
this.lambda = lambda;
}
private int[] learnPoints(Loss loss, VecDataSet ds) {
if (loss instanceof WeightedLoss) {
return ((WeightedLoss) loss).points();
} else return ArrayTools.sequence(0, ds.length());
}
@SuppressWarnings("Duplicates")
protected Pair filter(final List entryList, final BinarizedDataSet bds, Vec sourceTarget, int[] points) {
final byte[] binary = new byte[base.grid.rows()];
Mx otData = new VecBasedMx(points.length, entryList.size());
Vec target = new ArrayVec(points.length);
for (int i=0;i < points.length;++i) {
final int ind = points[i];
for (int f=0; f (otData, target);
}
@Override
public ModelTools.CompiledOTEnsemble fit(final VecDataSet ds, final Loss loss) {
ObliviousTree tree = base.fit(ds, loss);
Ensemble ensemble = new Ensemble<>(new ObliviousTree[]{tree}, VecTools.fill(new SingleValueVec(1), 1.0));
ModelTools.CompiledOTEnsemble compiledOTEnsemble = ModelTools.compile(ensemble);
List entryList = compiledOTEnsemble.getEntries();
final BinarizedDataSet bds = ds.cache().cache(Binarize.class, VecDataSet.class).binarize(base.grid);
Pair compiledLearn = filter(entryList, bds, loss.target(), learnPoints(loss, ds));
RidgeRegression ridgeRegression = new RidgeRegression(lambda);
Vec weights = ridgeRegression.fit(compiledLearn.first, compiledLearn.second);
ArrayList newEntries = new ArrayList<>();
for (int i=0; i < weights.dim();++i) {
newEntries.add(new ModelTools.CompiledOTEnsemble.Entry(entryList.get(i).getBfIndices(), weights.get(i)));
}
return new ModelTools.CompiledOTEnsemble(newEntries, tree.grid());
}
}