All Downloads are FREE. Search and download functionalities are using the official Maven repository.
Please wait. This can take a few minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download/modify it how often you want.
com.expleague.ml.dynamicGrid.trees.GreedyObliviousTreeDynamic Maven / Gradle / Ivy
package com.expleague.ml.dynamicGrid.trees;
import com.expleague.commons.func.AdditiveStatistics;
import com.expleague.commons.util.ArrayTools;
import com.expleague.ml.Binarize;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.dynamicGrid.AggregateDynamic;
import com.expleague.ml.dynamicGrid.impl.BFDynamicGrid;
import com.expleague.ml.dynamicGrid.impl.BinarizedDynamicDataSet;
import com.expleague.ml.dynamicGrid.interfaces.BinaryFeature;
import com.expleague.ml.dynamicGrid.interfaces.DynamicGrid;
import com.expleague.ml.dynamicGrid.models.ObliviousTreeDynamicBin;
import com.expleague.ml.loss.StatBasedLoss;
import com.expleague.ml.methods.VecOptimization;
import gnu.trove.list.array.TIntArrayList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.ListIterator;
/**
 * Created by noxoomo on 22/07/14.
 *
 * <p>Greedy, level-by-level builder of an oblivious tree on top of a {@link DynamicGrid}.
 * On every level a single binary feature is chosen and applied to all current leaves;
 * the chosen feature is the one with the smallest sum of
 * {@code loss.score(left) + loss.score(right)} over all leaves.
 *
 * <p>While grid growing is enabled, binary features that are not yet active are considered
 * as well: a non-active feature is queued for activation when its score beats the best
 * score seen so far by more than {@code eps} after subtracting the
 * {@code lambda * regularization()} penalty. Whenever the data set reports that queued
 * splits were accepted, the tree is rebuilt from scratch on the updated grid.
 *
 * <p>NOTE(review): the generic type arguments in this class were reconstructed from how the
 * values are used (the otherwise undeclared {@code Loss} type, the typed for-each loops,
 * the {@code @Override} on the split visitor); confirm them against the declarations of
 * {@code VecOptimization.Stub} and {@code AggregateDynamic.SplitVisitor}.
 *
 * @param <Loss> loss type used to score splits and to compute leaf values
 */
public class GreedyObliviousTreeDynamic<Loss extends StatBasedLoss> extends VecOptimization.Stub<Loss> {
  /** Minimal regularized score improvement required to queue a non-active feature. */
  private static final double eps = 1e-4;

  private final int depth;
  private final DynamicGrid grid;
  private final double lambda;
  private boolean growGrid = true;

  /**
   * Creates a builder that works on an existing grid.
   *
   * @param grid   grid whose binary features are used as split candidates
   * @param depth  maximal number of levels (conditions) of the tree
   * @param lambda weight of the regularization penalty applied to non-active features
   */
  public GreedyObliviousTreeDynamic(final DynamicGrid grid, final int depth, final double lambda) {
    this.depth = depth;
    this.grid = grid;
    this.lambda = lambda;
  }

  /**
   * Creates a builder with a fresh grid for {@code ds}, {@code lambda = 0} and {@code minSplits = 1}.
   *
   * @param ds    data set the grid is built for
   * @param depth maximal number of levels (conditions) of the tree
   */
  public GreedyObliviousTreeDynamic(final VecDataSet ds, final int depth) {
    this(ds, depth, 0, 1);
  }

  /**
   * Creates a builder with a fresh grid for {@code ds} and {@code minSplits = 1}.
   *
   * @param ds     data set the grid is built for
   * @param depth  maximal number of levels (conditions) of the tree
   * @param lambda weight of the regularization penalty applied to non-active features
   */
  public GreedyObliviousTreeDynamic(final VecDataSet ds, final int depth, final double lambda) {
    this(ds, depth, lambda, 1);
  }

  /**
   * Creates a builder with a fresh {@link BFDynamicGrid} for {@code ds}.
   *
   * @param ds        data set the grid is built for
   * @param depth     maximal number of levels (conditions) of the tree
   * @param lambda    weight of the regularization penalty applied to non-active features
   * @param minSplits value forwarded to the {@link BFDynamicGrid} constructor
   */
  public GreedyObliviousTreeDynamic(final VecDataSet ds, final int depth, final double lambda, final int minSplits) {
    this.depth = depth;
    this.lambda = lambda;
    this.grid = new BFDynamicGrid(ds, minSplits);
  }

  /**
   * Disables grid growing: subsequent {@link #fit} calls neither queue non-active
   * features nor ask the data set to accept queued splits.
   */
  public void stopGrowing() {
    this.growGrid = false;
  }

  /**
   * Builds an oblivious tree of at most {@code depth} levels for the given loss.
   *
   * @param ds   data set to fit; it is binarized with this builder's grid
   * @param loss loss used to score candidate splits and to compute leaf values
   * @return the tree built during the first pass in which the grid was not updated
   */
  @Override
  public ObliviousTreeDynamicBin fit(final VecDataSet ds, final Loss loss) {
    final BinarizedDynamicDataSet bds = ds.cache().cache(Binarize.class, VecDataSet.class).binarize(grid);
    List<BFDynamicOptimizationSubset> leaves = new ArrayList<>(1 << depth);
    final TIntArrayList nonActiveF = new TIntArrayList(grid.rows());
    final TIntArrayList nonActiveBin = new TIntArrayList(grid.rows());
    final List<BinaryFeature> conditions = new ArrayList<>(depth);

    // scores[f][bin] accumulates the split score of binary feature (f, bin) over all leaves.
    final double[][] scores = new double[grid.rows()][];
    for (int i = 0; i < scores.length; ++i) {
      scores[i] = new double[0];
    }

    // Each iteration builds a whole tree; it is repeated while the grid keeps changing.
    while (true) {
      boolean updated = false;
      leaves.clear();
      conditions.clear();
      leaves.add(new BFDynamicOptimizationSubset(bds, loss, ArrayTools.sequence(0, ds.length())));
      double currentScore = Double.POSITIVE_INFINITY;

      for (int level = 0; level < depth; level++) {
        // Rows of the grid may have grown since the previous level: resize or reset the buffers.
        for (int f = 0; f < scores.length; ++f) {
          final int rowSize = grid.row(f).size();
          if (scores[f].length != rowSize) {
            scores[f] = new double[rowSize];
          } else {
            Arrays.fill(scores[f], 0);
          }
        }

        for (final BFDynamicOptimizationSubset leaf : leaves) {
          leaf.visitAllSplits(new AggregateDynamic.SplitVisitor<AdditiveStatistics>() {
            @Override
            public void accept(final BinaryFeature bf, final AdditiveStatistics left, final AdditiveStatistics right) {
              final double leftScore = loss.score(left);
              final double rightScore = loss.score(right);
              scores[bf.fIndex()][bf.binNo()] += leftScore + rightScore;
            }
          });
        }

        // Best split among active features; non-active ones are only collected here.
        int bestSplitF = -1;
        int bestSplitBin = -1;
        double bestSplitScore = Double.POSITIVE_INFINITY;
        nonActiveF.clear();
        nonActiveBin.clear();
        for (int f = 0; f < scores.length; ++f) {
          for (int bin = 0; bin < scores[f].length; ++bin) {
            final BinaryFeature bf = grid.bf(f, bin);
            if (bf.isActive()) {
              if (bestSplitScore > scores[f][bin]) {
                bestSplitF = f;
                bestSplitBin = bin;
                bestSplitScore = scores[f][bin];
              }
            } else {
              nonActiveF.add(f);
              nonActiveBin.add(bin);
            }
          }
        }

        // Best split among non-active features that are worth activating.
        int bestNonActiveSplitF = -1;
        int bestNonActiveSplitBin = -1;
        double bestNonActiveSplitScore = Double.POSITIVE_INFINITY;
        if (growGrid) {
          final double threshold = bestSplitScore < currentScore ? bestSplitScore : currentScore;
          for (int j = 0; j < nonActiveF.size(); ++j) {
            final int feature = nonActiveF.get(j);
            final int bin = nonActiveBin.get(j);
            final BinaryFeature bf = grid.bf(feature, bin);
            // regularization() is only evaluated when it can influence the result.
            final double reg = lambda != 0 ? bf.regularization() : 0;
            final double gain = threshold - scores[feature][bin] - lambda * reg;
            if (gain > eps) {
              bds.queueSplit(bf);
              if (bestNonActiveSplitScore > scores[feature][bin]) {
                bestNonActiveSplitF = feature;
                bestNonActiveSplitBin = bin;
                bestNonActiveSplitScore = scores[feature][bin];
              }
            }
          }
        }

        // A queued non-active feature wins ties against the best active one.
        if (bestNonActiveSplitF >= 0 && bestNonActiveSplitScore <= bestSplitScore) {
          bestSplitF = bestNonActiveSplitF;
          bestSplitBin = bestNonActiveSplitBin;
        }

        // Stop growing the tree when there is no candidate or no improvement of the score.
        if (bestSplitF < 0 || scores[bestSplitF][bestSplitBin] >= currentScore) {
          if (growGrid && bds.acceptQueue(leaves)) {
            updated = true;
          }
          break;
        }

        final BinaryFeature bestSplitBF = grid.bf(bestSplitF, bestSplitBin);
        final List<BFDynamicOptimizationSubset> next = new ArrayList<>(leaves.size() * 2);
        for (final BFDynamicOptimizationSubset subset : leaves) {
          // Order matters: the leaf itself first, then the part returned by split().
          next.add(subset);
          next.add(subset.split(bestSplitBF));
        }
        conditions.add(bestSplitBF);
        leaves = next;
        currentScore = scores[bestSplitF][bestSplitBin];
        if (growGrid && bds.acceptQueue(leaves)) {
          updated = true;
        }
      }

      // The tree is final only if the grid did not change while it was being built.
      if (!updated) {
        final double[] values = new double[leaves.size()];
        for (int i = 0; i < values.length; i++) {
          values[i] = loss.bestIncrement(leaves.get(i).total());
        }
        return new ObliviousTreeDynamicBin(conditions, values);
      }
    }
  }

  /**
   * @return the value of {@code hist()} of the underlying grid
   */
  public int[] hist() {
    return grid.hist();
  }
}