All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.methods.trees.RegGreedyObliviousTree Maven / Gradle / Ivy

package com.expleague.ml.methods.trees;

import com.expleague.ml.data.impl.BinarizedDataSet;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.loss.StatBasedLoss;
import com.expleague.ml.loss.WeightedLoss;
import com.expleague.ml.methods.VecOptimization;
import com.expleague.ml.models.ObliviousTree;
import com.expleague.commons.util.ArrayTools;
import com.expleague.commons.util.Pair;
import com.expleague.ml.BFGrid;
import com.expleague.ml.Binarize;
import gnu.trove.list.array.TIntArrayList;

import java.util.*;

public class RegGreedyObliviousTree extends VecOptimization.Stub {
  private final int depth;
  public final BFGrid grid;
  private final double lambda;

  public RegGreedyObliviousTree(final BFGrid grid, final int depth, double lambda) {
    this.grid = grid;
    this.depth = depth;
    this.lambda = lambda;
  }

  private final Set knownSplits = new HashSet<>();
  @Override
  public ObliviousTree fit(final VecDataSet ds, final Loss loss) {
    Pair, List> result = findBestSubsets(ds,loss);
    List leaves = result.getFirst();
    List conditions = result.getSecond();
    final double[] step = new double[leaves.size()];
    final double[] based = new double[leaves.size()];
    for (int i = 0; i < step.length; i++) {
      step[i] = loss.bestIncrement(leaves.get(i).total());
      based[i] = leaves.get(i).size();
    }
    return new ObliviousTree(conditions, step, based);
  }

  //decomposition for oblivious tree with non-constant functions in leaves
  public final Pair,List> findBestSubsets(final VecDataSet ds, final Loss loss) {
    List leaves = new ArrayList(1 << depth);
    final List conditions = new ArrayList(depth);
    double currentScore = Double.POSITIVE_INFINITY;

    final BinarizedDataSet bds =  ds.cache().cache(Binarize.class, VecDataSet.class).binarize(grid);

    leaves.add(new BFOptimizationSubset(bds, loss, learnPoints(loss, ds)));

    final double[] scores = new double[grid.size()];
    for (int level = 0; level < depth; level++) {
      Arrays.fill(scores, 0.);
      for (final BFOptimizationSubset leaf : leaves) {
        leaf.visitAllSplits((bf, left, right) -> scores[bf.bfIndex] += loss.score(left) + loss.score(right));
      }
      final TIntArrayList currentConditions = new TIntArrayList();
      for (int i = 0; i < conditions.size(); i++) {
        currentConditions.add(conditions.get(i).bfIndex);
      }
      for (int i = 0; i < scores.length; i++) {
        if (!currentConditions.contains(i)) {
          currentConditions.add(i);
          currentConditions.sort();
          if (!knownSplits.contains(currentConditions))
            scores[i] += lambda * Math.abs(scores[i]);
          currentConditions.remove(i);
        }
        else {
          currentConditions.sort();
          scores[i] += knownSplits.contains(currentConditions) ? 0 : lambda * Math.abs(scores[i]);
        }
      }
      final int bestSplit = ArrayTools.min(scores);
      if (bestSplit < 0 || scores[bestSplit] >= currentScore)
        break;
      final BFGrid.BinaryFeature bestSplitBF = grid.bf(bestSplit);
      final List next = new ArrayList(leaves.size() * 2);
      final ListIterator iter = leaves.listIterator();
      while (iter.hasNext()) {
        final BFOptimizationSubset subset = iter.next();
        next.add(subset);
        next.add(subset.split(bestSplitBF));
      }
      conditions.add(bestSplitBF);
      if (!currentConditions.contains(bestSplitBF.bfIndex))
        currentConditions.add(bestSplitBF.bfIndex);
      currentConditions.sort();
      currentConditions.trimToSize();
      knownSplits.add(currentConditions);
      leaves = next;
      currentScore = scores[bestSplit];
    }
    return new Pair<>(leaves, conditions);
  }

  private int[] learnPoints(Loss loss, VecDataSet ds) {
    if (loss instanceof WeightedLoss) {
      return ((WeightedLoss) loss).points();
    } else return ArrayTools.sequence(0, ds.length());
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy