All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.tree.drf.DRF Maven / Gradle / Ivy

package hex.tree.drf;

import hex.Model;
import hex.ModelCategory;
import hex.schemas.DRFV3;
import hex.tree.DHistogram;
import hex.tree.DTree;
import hex.tree.DTree.DecidedNode;
import hex.tree.DTree.LeafNode;
import hex.tree.DTree.UndecidedNode;
import hex.tree.ScoreBuildHistogram;
import hex.tree.SharedTree;
import water.AutoBuffer;
import water.Job;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.Log;
import water.util.Timer;

import java.util.Arrays;
import java.util.Random;

import static hex.genmodel.GenModel.getPrediction;
import static hex.tree.drf.TreeMeasuresCollector.asSSE;
import static hex.tree.drf.TreeMeasuresCollector.asVotes;

/** Distributed Random Forest (DRF)
 *
 *  Builds an ensemble of randomized decision trees; see "Elements of
 *  Statistical Learning, Second Edition", Chapter 15 (Random Forests).
 */
public class DRF extends SharedTree {
  // Number of columns to choose amongst at each split; handed to
  // buildNextKTrees(...) as its 'mtrys' argument when each tree is built.
  protected int _mtry;

  /** Model categories DRF can produce: regression, binomial and multinomial. */
  @Override public ModelCategory[] can_build() {
    ModelCategory[] supported = {
        ModelCategory.Regression, ModelCategory.Binomial, ModelCategory.Multinomial
    };
    return supported;
  }

  /** DRF is exposed as a stable model builder. */
  @Override public BuilderVisibility builderVisibility() {
    return BuilderVisibility.Stable;
  }

  /** Creates a DRF builder for the given parameters and immediately runs the
   *  cheap (non-expensive) validation pass.  Called from an http request. */
  public DRF( hex.tree.drf.DRFModel.DRFParameters parms) {
    super("DRF", parms);
    init(false);
  }

  /** Returns a new DRFV3 schema instance for this builder. */
  @Override public DRFV3 schema() {
    return new DRFV3();
  }

  /** Start the DRF training Job on an F/J thread.  The requested number of
   *  trees is handed to start(...) as the amount of work (progress bar). */
  @Override public Job trainModel() {
    DRFDriver driver = new DRFDriver();
    return start(driver, _parms._ntrees);
  }

  /** Uses the superclass's vresponse when it has one, otherwise response(). */
  @Override public Vec vresponse() {
    if (super.vresponse() != null) return super.vresponse();
    return response();
  }

  /** Initialize the ModelBuilder, validating all arguments and preparing the
   *  training frame.  This call is expected to be overridden in the subclasses
   *  and each subclass will start with "super.init();".  This call is made
   *  by the front-end whenever the GUI is clicked, and needs to be fast;
   *  heavy-weight prep needs to wait for the trainModel() call.
   *
   *  DRF-specific checks, all recorded through {@code error(field, msg)} so
   *  that every problem is reported together:
   *  <ul>
   *    <li>{@code _sample_rate} must lie in (0,1] (NaN is rejected too);</li>
   *    <li>{@code _mtries} must be -1 or at least 1 and, when a training frame
   *        is present, smaller than its number of columns;</li>
   *    <li>a sample rate of exactly 1 requires a validation frame, because no
   *        out-of-bag rows remain for error estimation.</li>
   *  </ul>
   *
   *  @param expensive passed through to {@code super.init}
   */
  @Override public void init(boolean expensive) {
    super.init(expensive);
    // Record (rather than throw) an invalid sample rate, like every other check
    // here, so validation keeps collecting messages.  The negated form also
    // rejects NaN.
    if (!(0.0 < _parms._sample_rate && _parms._sample_rate <= 1.0))
      error("_sample_rate", "Sample rate should be interval (0,1> but it is " + _parms._sample_rate);
    if( _parms._mtries < 1 && _parms._mtries != -1 ) {
      error("_mtries", "mtries must be -1 (converted to sqrt(features)), or >= 1 but it is " + _parms._mtries);
    } else if( _train != null && _parms._mtries != -1 && _parms._mtries >= _train.numCols() ) {
      // Upper bound is only checked once the lower bound holds, so a single
      // bad value produces a single message.
      error("_mtries","Computed mtries should be -1 or in interval <1,#cols> but it is " + _parms._mtries);
    }
    if (_parms._sample_rate == 1f && _valid == null)
      error("_sample_rate", "Sample rate is 100% and no validation dataset is specified.  There are no OOB data to compute out-of-bag error estimation!");
  }

  // A standard DTree with a few more bits.  Support for sampling during
  // training, and replaying the sample later on the identical dataset to
  // e.g. compute OOBEE.
  static class DRFTree extends DTree {
    /** Number of columns to choose amongst in splits. */
    final int _mtrys;
    /** One seed per chunk of the frame's first vec, used for sampling. */
    final long[] _seeds;
    /** RNG for split decisions & sampling; also generates the per-chunk seeds. Not serialized. */
    final transient Random _rand;

    DRFTree( Frame fr, int ncols, char nbins, char nclass, int min_rows, int mtrys, long seed ) {
      super(fr._names, ncols, nbins, nclass, min_rows, seed);
      _mtrys = mtrys;
      _rand = createRNG(seed);
      final int nchunks = fr.vecs()[0].nChunks();
      _seeds = new long[nchunks];
      int c = 0;
      while (c < nchunks) {
        _seeds[c++] = _rand.nextLong();   // drawn in chunk order for reproducibility
      }
    }

    /** Return a deterministic chunk-local RNG, freshly created from the
     *  chunk's stored seed on every call (so it can be somewhat expensive). */
    public Random rngForChunk( int cidx ) {
      return createRNG(_seeds[cidx]);
    }
  }

  /** Fill work columns:
   *   - classification: set 1 in the corresponding wrk col according to row response
   *   - regression:     copy response into work column (there is only 1 work column)
   */
  private class SetWrkTask extends MRTask {
    @Override public void map( Chunk chks[] ) {
      Chunk cy = chk_resp(chks);
      for( int i=0; i= 2 ) {
//        for( int c=0; c<_nclass; c++ ) {
//          final double init = _model._output._priorClassDist[c];
//          new MRTask() {
//            @Override public void map(Chunk tree) { for( int i=0; i but it is " + _mtry);
      // Initialize TreeVotes for classification, MSE arrays for regression
      initTreeMeasurements();
      // Append number of trees participating in on-the-fly scoring
      _train.add("OUT_BAG_TREES", _response.makeZero());
      // Prepare working columns
      new SetWrkTask().doAll(_train);
      // If there was a check point recompute tree_<_> and oob columns based on predictions from previous trees
      // but only if OOB validation is requested.
      if (_parms._checkpoint) {
        Timer t = new Timer();
        // Compute oob votes for each output level
        new OOBScorer(_ncols, _nclass, _parms._sample_rate, _model._output._treeKeys).doAll(_train);
        Log.info("Reconstructing oob stats from checkpointed model took " + t);
      }

      // The RNG used to pick split columns
      Random rand = createRNG(_parms._seed);
      // To be deterministic get random numbers for previous trees and
      // put random generator to the same state
      for (int i=0; i<_ntreesFromCheckpoint; i++) rand.nextLong();

      int tid;
      DTree[] ktrees = null;
      // Prepare tree statistics
      // Build trees until we hit the limit
      for( tid=0; tid<_parms._ntrees; tid++) { // Building tid-tree
        if (tid!=0 || !_parms._checkpoint) { // do not make initial scoring if model already exist
          double training_r2 = doScoringAndSaveModel(false, true, _parms._build_tree_one_node);
          if( training_r2 >= _parms._r2_stopping )
            return;             // Stop when approaching round-off error
        }
        // At each iteration build K trees (K = nclass = response column domain size)

        // TODO: parallelize more? build more than k trees at each time, we need to care about temporary data
        // Idea: launch more DRF at once.
        Timer kb_timer = new Timer();
        buildNextKTrees(_train,_mtry,_parms._sample_rate,rand,tid);
        Log.info((tid+1) + ". tree was built " + kb_timer.toString());
        DRF.this.update(1);
        if( !isRunning() ) return; // If canceled during building, do not bulkscore

      }
      doScoringAndSaveModel(true, true, _parms._build_tree_one_node);
    }



    // --------------------------------------------------------------------------
    // Build the next random k-trees representing tid-th tree
    private void buildNextKTrees(Frame fr, int mtrys, float sample_rate, Random rand, int tid) {
      // We're going to build K (nclass) trees - each focused on correcting
      // errors for a single class.
      final DTree[] ktrees = new DTree[_nclass];

      // Initial set of histograms.  All trees; one leaf per tree (the root
      // leaf); all columns
      DHistogram hcs[][][] = new DHistogram[_nclass][1/*just root leaf*/][_ncols];

      // Adjust nbins for the top-levels
      int adj_nbins = Math.max((1<<(10-0)),_parms._nbins);

      // Use for all k-trees the same seed. NOTE: this is only to make a fair
      // view for all k-trees
      final long[] _distribution = _model._output._distribution;
      long rseed = rand.nextLong();
        // Initially setup as-if an empty-split had just happened
      for (int k = 0; k < _nclass; k++) {
        if (_distribution[k] != 0) { // Ignore missing classes
          // The Boolean Optimization
          // This optimization assumes the 2nd tree of a 2-class system is the
          // inverse of the first (and that the same columns were picked)
          if( k==1 && _nclass==2 ) continue;
          ktrees[k] = new DRFTree(fr, _ncols, (char) _parms._nbins, (char) _nclass, _parms._min_rows, mtrys, rseed);
          boolean isBinom = isClassifier();
          new DRFUndecidedNode(ktrees[k], -1, DHistogram.initialHist(fr, _ncols, adj_nbins, hcs[k][0], isBinom)); // The "root" node
        }
      }

      // Sample - mark the lines by putting 'OUT_OF_BAG' into nid() vector
      Timer t_1 = new Timer();
      Sample ss[] = new Sample[_nclass];
      for( int k=0; k<_nclass; k++)
        if (ktrees[k] != null) ss[k] = new Sample((DRFTree)ktrees[k], sample_rate).dfork(0,new Frame(vec_nids(fr,k),vec_resp(fr)), _parms._build_tree_one_node);
      for( int k=0; k<_nclass; k++)
        if( ss[k] != null ) ss[k].getResult();
      Log.debug("Sampling took: + " + t_1);

      int[] leafs = new int[_nclass]; // Define a "working set" of leaf splits, from leafs[i] to tree._len for each tree i

      // ----
      // One Big Loop till the ktrees are of proper depth.
      // Adds a layer to the trees each pass.
      Timer t_2 = new Timer();
      int depth=0;
      for( ; depth<_parms._max_depth; depth++ ) {
        if( !isRunning() ) return;
        hcs = buildLayer(fr, _parms._nbins, ktrees, leafs, hcs, true, _parms._build_tree_one_node);
        // If we did not make any new splits, then the tree is split-to-death
        if( hcs == null ) break;
      }
      Log.debug("Tree build took: " + t_2);

      // Each tree bottomed-out in a DecidedNode; go 1 more level and insert
      // LeafNodes to hold predictions.
      Timer t_3 = new Timer();
      for( int k=0; k<_nclass; k++ ) {
        DTree tree = ktrees[k];
        if( tree == null ) continue;
        int leaf = leafs[k] = tree.len();
        for( int nid=0; nid {
      /* @IN  */ final DTree _trees[]; // Read-only, shared (except at the histograms in the Nodes)
      /* @IN */  double _threshold;      // Sum of squares for this tree only
      /* @OUT */ long rightVotes; // number of right votes over OOB rows (performed by this tree) represented by DTree[] _trees
      /* @OUT */ long allRows;    // number of all OOB rows (sampled by this tree)
      /* @OUT */ float sse;      // Sum of squares for this tree only
      CollectPreds(DTree trees[], int leafs[], double threshold) { _trees=trees; _threshold = threshold; }
      final boolean importance = true;
      @Override public void map( Chunk[] chks ) {
        final Chunk    y       = importance ? chk_resp(chks) : null; // Response
        final double[] rpred   = importance ? new double[1+_nclass] : null; // Row prediction
        final double[] rowdata = importance ? new double[_ncols] : null; // Pre-allocated row data
        final Chunk   oobt  = chk_oobt(chks); // Out-of-bag rows counter over all trees
        // Iterate over all rows
        for( int row=0; row 1 : "broken histo range "+hs[i];
        cols[len++] = i;        // Gather active column
      }
      int choices = len;        // Number of columns I can choose from
      assert choices > 0;

      // Draw up to mtry columns at random without replacement.
      for( int i=0; i 0;
      return Arrays.copyOfRange(cols, len, choices);
    }
  }
  
  // ---
  /** Leaf node of a DRF tree; its serialized form is just the prediction,
   *  always written as one 4-byte float. */
  static class DRFLeafNode extends LeafNode {
    DRFLeafNode( DTree tree, int pid ) {
      super(tree, pid);
    }

    DRFLeafNode( DTree tree, int pid, int nid ) {
      super(tree, pid, nid);
    }

    /** Appends the prediction as a 4-byte float; the prediction must not be NaN. */
    @Override protected AutoBuffer compress(AutoBuffer ab) {
      assert !Double.isNaN(_pred);
      return ab.put4f(_pred);
    }

    /** Serialized size in bytes: exactly one float. */
    @Override protected int size() {
      return 4;
    }
  }

  // Deterministic sampling
  static class Sample extends MRTask {
    final DRFTree _tree;
    final float _rate;
    Sample( DRFTree tree, float rate ) { _tree = tree; _rate = rate; }
    @Override public void map( Chunk nids, Chunk ys ) {
      Random rand = _tree.rngForChunk(nids.cidx());
      for( int row=0; row= _rate || Double.isNaN(ys.atd(row)) ) {
          nids.set(row, ScoreBuildHistogram.OUT_OF_BAG);     // Flag row as being ignored by sampling
        }
    }
  }

  // Read the 'tree' columns, do model-specific math and put the results in the
  // fs[] array, and return the sum.  Dividing any fs[] element by the sum
  // turns the results into a probability distribution.
  @Override protected double score1( Chunk chks[], double fs[/*nclass*/], int row ) {
    double total = 0;
    if (_nclass == 2) {
      // Binomial optimization: only tree 0 is stored, class 2 is its complement.
      // The returned total stays 0 on this path.
      fs[1] = chk_tree(chks, 0).atd(row);
      assert(fs[1] >= 0 && fs[1] <= 1);
      fs[2] = 1. - fs[1];
    } else if (_nclass < 2) {
      // Regression: average over the trees that voted for this row
      // (only trees which had the row "out-of-bag").
      fs[0] = chk_tree(chks, 0).atd(row) / chk_oobt(chks).atd(row);
      total += fs[0];
      fs[1] = 0;
    } else {
      // Multinomial: copy each class's accumulated value and sum them.
      for (int k = 0; k < _nclass; k++) {
        fs[k+1] = chk_tree(chks, k).atd(row);
        total += fs[k+1];
      }
    }
    return total;
  }


}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy