All Downloads are FREE. The search and download functionality uses the official Maven repository.

com.expleague.ml.dynamicGrid.trees.GreedyObliviousTreeDynamic2 Maven / Gradle / Ivy

There is a newer version: 1.4.9
Show newest version
package com.expleague.ml.dynamicGrid.trees;

import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.dynamicGrid.impl.BFDynamicGrid;
import com.expleague.ml.dynamicGrid.interfaces.BinaryFeature;
import com.expleague.commons.func.AdditiveStatistics;
import com.expleague.commons.util.ArrayTools;
import com.expleague.ml.Binarize;
import com.expleague.ml.dynamicGrid.AggregateDynamic;
import com.expleague.ml.dynamicGrid.impl.BinarizedDynamicDataSet;
import com.expleague.ml.dynamicGrid.interfaces.DynamicGrid;
import com.expleague.ml.dynamicGrid.models.ObliviousTreeDynamicBin;
import com.expleague.ml.loss.StatBasedLoss;
import com.expleague.ml.methods.VecOptimization;
import gnu.trove.list.array.TIntArrayList;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.ListIterator;

/**
 * Greedy learner that grows a tree level by level over a {@link DynamicGrid}: on every level a
 * single binary feature is selected and applied to all current leaves, so the number of leaves
 * doubles per level. Lower scores are treated as better.
 *
 * <p>While grid growing is enabled, one inactive binary feature per grid row may be queued for
 * activation on each level if it beats the best active feature of that row by more than
 * {@code lambda * regularization}.
 *
 * <p>Created by noxoomo on 22/07/14.
 *
 * @param <Loss> loss type used both for scoring candidate splits and for computing leaf values
 */
public class GreedyObliviousTreeDynamic2<Loss extends StatBasedLoss> extends VecOptimization.Stub<Loss> {
  private final int depth;
  private final DynamicGrid grid;
  private boolean growGrid;
  private final int minSplits;
  private final double lambda;

  /**
   * Creates a learner over an existing grid with grid growing enabled.
   *
   * @param grid   grid that supplies the binary features
   * @param depth  maximum number of tree levels
   * @param lambda weight of the regularization penalty applied to inactive candidate features
   */
  public GreedyObliviousTreeDynamic2(final DynamicGrid grid, final int depth, final double lambda) {
    this(grid, depth, lambda, true);
  }

  /**
   * Creates a learner with a fresh {@link BFDynamicGrid} built from {@code ds}, using
   * {@code lambda = 0} and {@code minSplits = 1}.
   *
   * @param ds    data set the grid is built from
   * @param depth maximum number of tree levels
   */
  public GreedyObliviousTreeDynamic2(final VecDataSet ds, final int depth) {
    this(ds, depth, 0, 1);
  }

  /**
   * Creates a learner with a fresh {@link BFDynamicGrid} built from {@code ds}, using
   * {@code minSplits = 1}.
   *
   * @param ds     data set the grid is built from
   * @param depth  maximum number of tree levels
   * @param lambda weight of the regularization penalty applied to inactive candidate features
   */
  public GreedyObliviousTreeDynamic2(final VecDataSet ds, final int depth, final double lambda) {
    this(ds, depth, lambda, 1);
  }

  /**
   * Creates a learner with a fresh {@link BFDynamicGrid} built from {@code ds}; grid growing is
   * enabled.
   *
   * @param ds        data set the grid is built from
   * @param depth     maximum number of tree levels
   * @param lambda    weight of the regularization penalty applied to inactive candidate features
   * @param minSplits value forwarded to the {@link BFDynamicGrid} constructor
   */
  public GreedyObliviousTreeDynamic2(final VecDataSet ds, final int depth, final double lambda, final int minSplits) {
    this(ds, depth, lambda, minSplits, true);
  }

  /**
   * Creates a learner with a fresh {@link BFDynamicGrid} built from {@code ds}.
   *
   * @param ds        data set the grid is built from
   * @param depth     maximum number of tree levels
   * @param lambda    weight of the regularization penalty applied to inactive candidate features
   * @param minSplits value forwarded to the {@link BFDynamicGrid} constructor
   * @param grow      whether inactive features may be queued for activation during fitting
   */
  public GreedyObliviousTreeDynamic2(final VecDataSet ds, final int depth, final double lambda, final int minSplits, final boolean grow) {
    this(new BFDynamicGrid(ds, minSplits), depth, lambda, minSplits, grow);
  }

  /**
   * Creates a learner over an existing grid; {@code minSplits} is recorded as 1.
   *
   * @param grid   grid that supplies the binary features
   * @param depth  maximum number of tree levels
   * @param lambda weight of the regularization penalty applied to inactive candidate features
   * @param grow   whether inactive features may be queued for activation during fitting
   */
  public GreedyObliviousTreeDynamic2(final DynamicGrid grid, final int depth, final double lambda, final boolean grow) {
    this(grid, depth, lambda, 1, grow);
  }

  /** Canonical constructor every public constructor delegates to. */
  private GreedyObliviousTreeDynamic2(final DynamicGrid grid, final int depth, final double lambda, final int minSplits, final boolean grow) {
    this.grid = grid;
    this.depth = depth;
    this.lambda = lambda;
    this.minSplits = minSplits;
    this.growGrid = grow;
  }

  /** Disables queuing of inactive features for all subsequent {@link #fit} calls. */
  public void stopGrowing() {
    this.growGrid = false;
  }

  /**
   * Fits a tree of at most {@code depth} levels to {@code ds}.
   *
   * <p>Each level: (1) the per-feature score table is reset and resized to the current grid,
   * (2) every leaf adds {@code loss.score(left) + loss.score(right)} for every split it visits,
   * (3) the lowest-scoring bin per row and then over all rows is selected, (4) if no bin was
   * found or the best score does not improve on the previous level's score, growing stops;
   * otherwise every leaf is split by the chosen feature. Queued grid splits are accepted at the
   * end of every level, including the one that stops the loop.
   *
   * @param ds   data set to fit; binarized against this learner's grid
   * @param loss loss used to score splits and to compute the value of every final leaf
   * @return tree built from the chosen features and one {@code loss.bestIncrement} value per leaf
   */
  @Override
  public ObliviousTreeDynamicBin fit(final VecDataSet ds, final Loss loss) {
    final BinarizedDynamicDataSet bds = ds.cache().cache(Binarize.class, VecDataSet.class).binarize(grid);

    List<BFDynamicOptimizationSubset> leaves = new ArrayList<>(1 << depth);
    final List<BinaryFeature> conditions = new ArrayList<>(depth);
    final double[][] scores = new double[grid.rows()][];
    for (int i = 0; i < scores.length; ++i) {
      scores[i] = new double[0];
    }

    leaves.add(new BFDynamicOptimizationSubset(bds, loss, ArrayTools.sequence(0, ds.length())));
    double currentScore = Double.POSITIVE_INFINITY;

    for (int level = 0; level < depth; level++) {
      // Row sizes are re-read from the grid on every level because accepting queued splits
      // (see the acceptQueue calls below) may have changed them.
      resetScores(scores);

      for (final BFDynamicOptimizationSubset leaf : leaves) {
        // NOTE(review): the type argument is restored on the assumption that SplitVisitor is
        // generic in the statistics type, matching accept()'s parameter types - confirm.
        leaf.visitAllSplits(new AggregateDynamic.SplitVisitor<AdditiveStatistics>() {
          @Override
          public void accept(final BinaryFeature bf, final AdditiveStatistics left, final AdditiveStatistics right) {
            final double leftScore = loss.score(left);
            final double rightScore = loss.score(right);
            scores[bf.fIndex()][bf.binNo()] += leftScore + rightScore;
          }
        });
      }

      int bestSplitF = -1;
      int bestSplitBin = -1;
      double bestSplitScore = Double.POSITIVE_INFINITY;

      for (int f = 0; f < scores.length; ++f) {
        // Index of the last inactive bin seen in this row, or -1 when every bin is active.
        int candidateIndex = -1;
        double bestFeatureScore = Double.POSITIVE_INFINITY;
        int bestBinIndex = -1;

        for (int bin = 0; bin < scores[f].length; ++bin) {
          final BinaryFeature bf = grid.bf(f, bin);
          if (bf.isActive()) {
            if (scores[f][bin] < bestFeatureScore) {
              bestBinIndex = bin;
              bestFeatureScore = scores[f][bin];
            }
          } else {
            candidateIndex = bin;
          }
        }

        if (candidateIndex != -1 && growGrid) {
          final BinaryFeature candidate = grid.bf(f, candidateIndex);
          // Gain of the inactive candidate over the best active bin, minus the penalty.
          final double gain = bestFeatureScore - scores[f][candidateIndex] - lambda * candidate.regularization();
          if (gain > 0) {
            bds.queueSplit(candidate);
            bestBinIndex = candidateIndex;
          }
        }
        if (bestBinIndex == -1) {
          continue;
        }

        if (scores[f][bestBinIndex] < bestSplitScore) {
          bestSplitBin = bestBinIndex;
          bestSplitF = f;
          bestSplitScore = scores[f][bestBinIndex];
        }
      }

      // Stop when nothing was found or the best split does not improve on the previous level.
      if (bestSplitF < 0 || bestSplitScore >= currentScore) {
        if (growGrid) {
          bds.acceptQueue(leaves);
        }
        break;
      }

      final BinaryFeature bestSplitBF = grid.bf(bestSplitF, bestSplitBin);
      leaves = splitLeaves(leaves, bestSplitBF);
      conditions.add(bestSplitBF);
      currentScore = bestSplitScore;
      if (growGrid) {
        bds.acceptQueue(leaves);
      }
    }

    final double[] values = new double[leaves.size()];
    for (int i = 0; i < values.length; i++) {
      values[i] = loss.bestIncrement(leaves.get(i).total());
    }
    return new ObliviousTreeDynamicBin(conditions, values);
  }

  /** Zeroes every score row, reallocating rows whose length no longer matches the grid row. */
  private void resetScores(final double[][] scores) {
    for (int f = 0; f < scores.length; ++f) {
      final int rowSize = grid.row(f).size();
      if (scores[f].length != rowSize) {
        scores[f] = new double[rowSize];
      } else {
        Arrays.fill(scores[f], 0);
      }
    }
  }

  /**
   * Splits every leaf by {@code feature}; each original leaf is immediately followed in the
   * result by the subset returned from its {@code split} call.
   */
  private static List<BFDynamicOptimizationSubset> splitLeaves(final List<BFDynamicOptimizationSubset> leaves,
                                                               final BinaryFeature feature) {
    final List<BFDynamicOptimizationSubset> next = new ArrayList<>(leaves.size() * 2);
    for (final BFDynamicOptimizationSubset subset : leaves) {
      next.add(subset);
      next.add(subset.split(feature));
    }
    return next;
  }

  /**
   * Returns the histogram reported by the underlying grid.
   *
   * @return result of {@code grid.hist()}
   */
  public int[] hist() {
    return grid.hist();
  }
}






© 2015 - 2024 Weber Informatics LLC | Privacy Policy