// Source: com.expleague.ml.methods.trees.GreedyObliviousTreeValidationRidgeLeaves (Maven / Gradle / Ivy)
package com.expleague.ml.methods.trees;
import com.expleague.commons.math.Trans;
import com.expleague.commons.math.vectors.Mx;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.random.FastRandom;
import com.expleague.commons.util.Pair;
import com.expleague.commons.util.ThreadTools;
import com.expleague.ml.BFGrid;
import com.expleague.ml.Binarize;
import com.expleague.ml.data.impl.BinarizedDataSet;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.data.set.impl.VecDataSetImpl;
import com.expleague.ml.data.tools.DataTools;
import com.expleague.ml.impl.BinaryFeatureImpl;
import com.expleague.ml.loss.L2;
import com.expleague.ml.loss.StatBasedLoss;
import com.expleague.ml.loss.WeightedLoss;
import com.expleague.ml.methods.VecOptimization;
import com.expleague.ml.methods.linearRegressionExperiments.MultipleValidationRidgeRegression;
import com.expleague.ml.models.TransObliviousTree;
import gnu.trove.list.array.TIntArrayList;

import java.util.*;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;

import static com.expleague.commons.math.vectors.VecTools.adjust;
/**
 * Oblivious-tree learner whose leaves hold ridge-regression models instead of constants.
 *
 * <p>The tree structure (split conditions and leaf subsets) is found by {@code base} on a
 * bootstrap-reweighted copy of the loss. Points that received zero bootstrap weight are routed
 * through the same splits and passed, per leaf, as validation data to
 * {@link MultipleValidationRidgeRegression} (presumably to pick the regularization — confirm).
 *
 * <p>NOTE(review): generic type arguments appear to have been stripped from this copy of the
 * file (e.g. {@code GreedyObliviousTree> base}, {@code Pair, List> tree}); those declarations are
 * left exactly as found rather than guessed.
 *
 * <p>User: noxoomo
 */
public class GreedyObliviousTreeValidationRidgeLeaves extends VecOptimization.Stub {
  /** Number of features the leaf models are topped up to on large enough data sets. */
  private static final int MIN_LEAF_FEATURES = 6;
  /** Random extra features are only added when the data set has more rows than this. */
  private static final int MIN_ROWS_FOR_EXTRA_FEATURES = 20;

  private static final ThreadPoolExecutor exec = ThreadTools.createBGExecutor("Leaves executor", -1);

  private final GreedyObliviousTree> base;
  private final FastRandom rand;

  /**
   * @param base greedy learner used to find the tree structure (splits and leaf subsets)
   * @param rand randomness source for the bootstrap and for random feature selection
   */
  public GreedyObliviousTreeValidationRidgeLeaves(
      final GreedyObliviousTree> base,
      final FastRandom rand) {
    this.base = base;
    this.rand = rand;
  }

  /**
   * Collects the indices of points whose weight in {@code loss} is exactly zero (for a bootstrap
   * loss, presumably the out-of-bag points).
   */
  private static int[] oobPoints(WeightedLoss loss) {
    final TIntArrayList result = new TIntArrayList(loss.dim() + 1000);
    for (int i = 0; i < loss.dim(); i++) {
      if (loss.weight(i) == 0)
        result.add(i);
    }
    return result.toArray();
  }

  /**
   * Fits the tree structure on a bootstrap sample, then a ridge model in every leaf.
   *
   * <p>Each leaf model is trained on the leaf's points and validated on the zero-weight points
   * that fall into the same leaf. In both cases only the columns chosen by
   * {@link #selectFeatures} are used, and the target is shifted by the leaf's bootstrap best
   * increment, which is added back at prediction time.
   *
   * @param ds   training features
   * @param loss training loss; its target is what the leaf models approximate
   * @return oblivious tree with the split conditions found by {@code base} and one
   *     {@link MappedTrans} per leaf
   * @throws IllegalStateException if preparing a leaf's data fails, or the calling thread is
   *     interrupted while waiting for that work (the interrupt flag is restored)
   */
  @Override
  public TransObliviousTree fit(final VecDataSet ds, final Loss loss) {
    final WeightedLoss bsLoss = DataTools.bootstrap(loss, rand);
    final Pair, List> tree = base.findBestSubsets(ds, bsLoss);
    final List conditions = tree.getSecond();
    final List subsets = tree.getFirst();

    final int[] features = selectFeatures(ds, conditions);

    // Replay the splits on the zero-weight points, using the original (unweighted) loss.
    // NOTE(review): leaves are matched to base's subsets by index below, which assumes this
    // replay yields leaves in the same order as base.findBestSubsets — confirm.
    final int[] oobPoints = oobPoints(bsLoss);
    final List oobSubsets;
    {
      final BinarizedDataSet bds = ds.cache().cache(Binarize.class, VecDataSet.class).binarize(base.grid);
      List leaves = new ArrayList<>(1);
      leaves.add(new BFOptimizationSubset(bds, loss, oobPoints));
      for (int i = 0; i < conditions.size(); ++i) {
        final List next = new ArrayList<>(leaves.size() * 2);
        final ListIterator iter = leaves.listIterator();
        while (iter.hasNext()) {
          final BFOptimizationSubset subset = iter.next();
          next.add(subset);
          next.add(subset.split(conditions.get(i)));
        }
        leaves = next;
      }
      oobSubsets = leaves;
    }

    final int leafCount = subsets.size();
    final VecDataSet[] datas = new VecDataSet[leafCount];
    final VecDataSet[] valDatas = new VecDataSet[leafCount];
    final L2[] losses = new L2[leafCount];
    final L2[] valLosses = new L2[leafCount];
    final double[] biases = new double[leafCount];
    final Vec target = loss.target();

    // Build every leaf's training/validation data sets in parallel. Each task writes only its
    // own array slot; Future.get() in awaitAll makes those writes visible to this thread.
    final List<Future<?>> tasks = new ArrayList<>(leafCount);
    for (int i = 0; i < leafCount; ++i) {
      final int ind = i;
      tasks.add(exec.submit(() -> {
        final BFOptimizationSubset subset = subsets.get(ind);
        // The ridge models fit the residual left after the leaf's constant bootstrap increment.
        final double bias = bsLoss.bestIncrement((WeightedLoss.Stat) subset.total());
        biases[ind] = bias;

        final int[] points = subset.getPoints();
        final Vec localTarget = subVec(target, points);
        adjust(localTarget, -bias);
        final VecDataSetImpl subDataSet = new VecDataSetImpl(subMx(ds.data(), points, features), ds);
        datas[ind] = subDataSet;
        losses[ind] = DataTools.newTarget(L2.class, localTarget, subDataSet);

        final BFOptimizationSubset valSubset = oobSubsets.get(ind);
        final int[] valPoints = valSubset.getPoints();
        final Vec valTarget = subVec(target, valPoints);
        adjust(valTarget, -bias);
        valDatas[ind] = new VecDataSetImpl(subMx(ds.data(), valPoints, features), ds);
        valLosses[ind] = DataTools.newTarget(L2.class, valTarget, valDatas[ind]);
      }));
    }
    awaitAll(tasks);

    final Trans[] ridgeModels = new MultipleValidationRidgeRegression().fit(datas, losses, valDatas, valLosses);
    final Trans[] leafTrans = new Trans[leafCount];
    for (int i = 0; i < leafCount; ++i) {
      leafTrans[i] = new MappedTrans(ridgeModels[i], features, biases[i]);
    }
    return new TransObliviousTree(conditions, leafTrans);
  }

  /**
   * Chooses the feature columns the leaf models are trained on.
   *
   * <p>Every split feature whose grid row has more than two entries is selected. When {@code ds}
   * has more than {@link #MIN_ROWS_FOR_EXTRA_FEATURES} rows, random columns meeting the same
   * criterion are added until {@link #MIN_LEAF_FEATURES} are selected, or until every eligible
   * column is, whichever comes first.
   *
   * @return selected feature indices in ascending order
   */
  private int[] selectFeatures(final VecDataSet ds, final List<? extends BFGrid.Feature> conditions) {
    final Set<Integer> selected = new TreeSet<>();
    for (final BFGrid.Feature bf : conditions) {
      if (bf.row().size() > 2)
        selected.add(bf.findex());
    }

    if (ds.data().rows() > MIN_ROWS_FOR_EXTRA_FEATURES) {
      final int columns = ds.data().columns();
      int eligible = 0;
      for (int f = 0; f < columns; ++f) {
        if (base.grid.row(f).size() > 2)
          eligible++;
      }
      // Without this cap the rejection-sampling loop below never terminates when the grid has
      // fewer than MIN_LEAF_FEATURES eligible columns.
      final int wanted = Math.min(MIN_LEAF_FEATURES, eligible);
      while (selected.size() < wanted) {
        final int candidate = rand.nextInt(columns);
        if (base.grid.row(candidate).size() > 2)
          selected.add(candidate);
      }
    }

    final int[] features = new int[selected.size()];
    int j = 0;
    for (final int f : selected)
      features[j++] = f;
    return features;
  }

  /**
   * Blocks until every task has completed, rethrowing the first task failure.
   *
   * @throws IllegalStateException wrapping the task's failure, or the interrupt (after restoring
   *     the thread's interrupt flag)
   */
  private static void awaitAll(final List<Future<?>> tasks) {
    for (final Future<?> task : tasks) {
      try {
        task.get();
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        throw new IllegalStateException("Interrupted while preparing leaf data sets", e);
      } catch (ExecutionException e) {
        throw new IllegalStateException("Failed to prepare leaf data set", e.getCause());
      }
    }
  }

  /** Copies {@code source[points[i]]} into a new vector of length {@code points.length}. */
  private static Vec subVec(Vec source, int[] points) {
    final Vec result = new ArrayVec(points.length);
    for (int i = 0; i < points.length; ++i) {
      result.set(i, source.get(points[i]));
    }
    return result;
  }

  /** Copies the given rows ({@code points}) and columns ({@code features}) of {@code source}. */
  private static Mx subMx(Mx source, int[] points, int[] features) {
    final Mx result = new VecBasedMx(points.length, features.length);
    for (int i = 0; i < points.length; ++i) {
      for (int j = 0; j < features.length; ++j) {
        result.set(i, j, source.get(points[i], features[j]));
      }
    }
    return result;
  }

  /**
   * Leaf model: evaluates {@code trans} on the selected feature columns of the input vector and
   * adds a constant bias.
   */
  private static class MappedTrans extends Trans.Stub {
    final Trans trans;
    final double bias;
    /** Column indices of the full input vector fed to {@code trans}, in order. */
    final int[] map;

    public MappedTrans(Trans trans, int[] features, double bias) {
      this.trans = trans;
      this.map = features;
      this.bias = bias;
    }

    // NOTE(review): trans() reads x at the original column indices stored in map, so its input
    // lives in the full feature space; returning map.length here looks inconsistent — confirm
    // how Trans.Stub / TransObliviousTree use xdim() before changing it.
    @Override
    public int xdim() {
      return map.length;
    }

    @Override
    public int ydim() {
      return trans.ydim();
    }

    @Override
    public Vec trans(Vec x) {
      final Vec inner = new ArrayVec(map.length);
      for (int i = 0; i < map.length; ++i) {
        inner.set(i, x.get(map[i]));
      }
      return adjust(trans.trans(inner), bias);
    }
  }
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy