
package hex.tree.drf;

import hex.Distribution;
import hex.ModelCategory;
import hex.schemas.DRFV3;
import hex.tree.*;
import hex.tree.DTree.DecidedNode;
import hex.tree.DTree.LeafNode;
import hex.tree.DTree.UndecidedNode;
import water.AutoBuffer;
import water.Job;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.util.Log;
import water.util.Timer;

import java.util.Arrays;
import java.util.Random;

import static hex.genmodel.GenModel.getPrediction;
import static hex.tree.drf.TreeMeasuresCollector.asSSE;
import static hex.tree.drf.TreeMeasuresCollector.asVotes;

/** Distributed Random Forest (DRF)
 *
 *  Based on "The Elements of Statistical Learning, Second Edition", Chapter 15 (Random Forests)
 */
public class DRF extends SharedTree {

  @Override public ModelCategory[] can_build() {
    return new ModelCategory[]{
      ModelCategory.Regression,
      ModelCategory.Binomial,
      ModelCategory.Multinomial,
    };
  }

  @Override public BuilderVisibility builderVisibility() { return BuilderVisibility.Stable; }

  // Called from an http request
  public DRF( hex.tree.drf.DRFModel.DRFParameters parms) { super("DRF", parms); init(false); }

  @Override public DRFV3 schema() { return new DRFV3(); }

  /** Start the DRF training Job on an F/J thread.
   * @param work         total units of work, used for progress reporting
   * @param restartTimer whether to restart the job's timer */
  @Override protected Job trainModelImpl(long work, boolean restartTimer) {
    return start(new DRFDriver(), work, restartTimer);
  }


  /** Initialize the ModelBuilder, validating all arguments and preparing the
   *  training frame.  This call is expected to be overridden in the subclasses
   *  and each subclass will start with "super.init();".  This call is made
   *  by the front-end whenever the GUI is clicked, and needs to be fast;
   *  heavy-weight prep needs to wait for the trainModel() call.
   */
  @Override public void init(boolean expensive) {
    super.init(expensive);
    // Initialize local variables
    if( _parms._mtries < 1 && _parms._mtries != -1 )
      error("_mtries", "mtries must be -1 (converted to sqrt(#features)) or >= 1, but it is " + _parms._mtries);
    if( _train != null ) {
      int ncols = _train.numCols();
      if( _parms._mtries != -1 && !(1 <= _parms._mtries && _parms._mtries < ncols /*ncols includes the response*/))
        error("_mtries","mtries must be -1 or in the interval [1,"+ncols+") but it is " + _parms._mtries);
    }
    if (_parms._distribution == Distribution.Family.AUTO) {
      if (_nclass == 1) _parms._distribution = Distribution.Family.gaussian;
      if (_nclass >= 2) _parms._distribution = Distribution.Family.multinomial;
    }
    if (_parms._sample_rate == 1f && _valid == null)
      error("_sample_rate", "Sample rate is 100% and no validation dataset is specified: there are no out-of-bag rows left to compute the OOB error estimate!");
    if (hasOffsetCol())
      error("_offset_column", "Offsets are not yet supported for DRF.");
  }
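
  // A minimal usage sketch (not part of this class): parameters that satisfy the
  // checks above, assuming a parsed training Frame "train" with a categorical
  // response column named "species" (both hypothetical). The sample rate is kept
  // below 1.0 so out-of-bag rows exist:
  //
  //   DRFModel.DRFParameters p = new DRFModel.DRFParameters();
  //   p._train           = train._key;
  //   p._response_column = "species";
  //   p._mtries          = -1;     // default: sqrt(#features) for classification
  //   p._sample_rate     = 0.632f; // < 1.0 => OOB rows for error estimation
  //   DRFModel model = new DRF(p).trainModel().get();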

  // ----------------------
  private class DRFDriver extends Driver {
    @Override protected boolean doOOBScoring() { return true; }

    // --- Private data handled only on master node
    // Classification or Regression:
    // Tree votes/SSE of individual trees on OOB rows
    public transient TreeMeasuresCollector.TreeMeasures _treeMeasuresOnOOB;
    // Tree votes/SSE per individual feature on permuted OOB rows
    public transient TreeMeasuresCollector.TreeMeasures[/*features*/] _treeMeasuresOnSOOB;
    // Variable importance based on tree split decisions
    private transient float[/*nfeatures*/] _improvPerVar;
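    // (Inferred from the fields above: each tree's votes/SSE on its own OOB rows
    // land in _treeMeasuresOnOOB, while _treeMeasuresOnSOOB[f] holds the same
    // measure recomputed with feature f's values permuted; the resulting drop in
    // right votes (or rise in SSE) estimates feature f's importance.)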

    private void initTreeMeasurements() {
      _improvPerVar = new float[_ncols];
      final int ntrees = _parms._ntrees;
      // Preallocate tree votes
      if (_model._output.isClassifier()) {
        _treeMeasuresOnOOB  = new TreeMeasuresCollector.TreeVotes(ntrees);
        _treeMeasuresOnSOOB = new TreeMeasuresCollector.TreeVotes[_ncols];
        for (int i=0; i<_ncols; i++) _treeMeasuresOnSOOB[i] = new TreeMeasuresCollector.TreeVotes(ntrees);
      } else {
        _treeMeasuresOnOOB  = new TreeMeasuresCollector.TreeSSE(ntrees);
        _treeMeasuresOnSOOB = new TreeMeasuresCollector.TreeSSE[_ncols];
        for (int i=0; i<_ncols; i++) _treeMeasuresOnSOOB[i] = new TreeMeasuresCollector.TreeSSE(ntrees);
      }
    }

    @Override protected void initializeModelSpecifics() {
      _mtry = (_parms._mtries==-1) ? // classification: mtry=sqrt(_ncols), regression: mtry=_ncols/3
              ( isClassifier() ? Math.max((int)Math.sqrt(_ncols),1) : Math.max(_ncols/3,1))  : _parms._mtries;
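      // E.g., with _ncols = 100 predictors this yields max((int)sqrt(100),1) = 10
      // for classification and max(100/3,1) = 33 for regression.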
      if (!(1 <= _mtry && _mtry <= _ncols)) throw new IllegalArgumentException("Computed mtry should be in the interval [1,"+_ncols+"] but it is " + _mtry);
      _initialPrediction = isClassifier() ? 0 : getInitialValue();
      // Initialize TreeVotes for classification, MSE arrays for regression
      initTreeMeasurements();

      /* Fill work columns:
       *   - classification: set 1 in the corresponding wrk col according to row response
       *   - regression:     copy response into work column (there is only 1 work column)
       */
      new MRTask() {
        @Override public void map(Chunk chks[]) {
          Chunk cy = chk_resp(chks);
          for (int i = 0; i < cy._len; i++) {
            if (cy.isNA(i)) continue;
            if (isClassifier()) {
              int cls = (int) cy.at8(i);
              chk_work(chks, cls).set(i, 1L);
            } else {
              float pred = (float) cy.atd(i);
              chk_work(chks, 0).set(i, pred);
            }
          }
        }
      }.doAll(_train);
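      // E.g., for a 3-class response, a row with class 2 gets work columns
      // (0,0,1); for regression the single work column simply holds the
      // (float-cast) response value.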
    }

    // --------------------------------------------------------------------------
    // Build the next random k-trees representing tid-th tree
    @Override protected void buildNextKTrees() {
      // We're going to build K (nclass) trees - each focused on correcting
      // errors for a single class.
      final DTree[] ktrees = new DTree[_nclass];

      // Define a "working set" of leaf splits, from leafs[i] to tree._len for each tree i
      int[] leafs = new int[_nclass];

      // Assign rows to nodes - fill the "NIDs" column(s)
      growTrees(ktrees, leafs, _rand);

      // Move rows into the final leaf rows - fill "Tree" and OUT_BAG_TREES columns and zap the NIDs column
      CollectPreds cp = new CollectPreds(ktrees,leafs,_model.defaultThreshold()).doAll(_train,_parms._build_tree_one_node);

      if (isClassifier())   asVotes(_treeMeasuresOnOOB).append(cp.rightVotes, cp.allRows); // Track right votes over OOB rows for this tree
      else /* regression */ asSSE  (_treeMeasuresOnOOB).append(cp.sse, cp.allRows);

      // Grow the model by K-trees
      _model._output.addKTrees(ktrees);
    }

    // Assumes that the "Work" columns are filled with horizontalized (0/1) class memberships per row (or a copy of the regression response)
    private void growTrees(DTree[] ktrees, int[] leafs, Random rand) {
      // Initial set of histograms.  All trees; one leaf per tree (the root
      // leaf); all columns
      DHistogram hcs[][][] = new DHistogram[_nclass][1/*just root leaf*/][_ncols];

      // Adjust real bins for the top-levels
      int adj_nbins = Math.max(_parms._nbins_top_level,_parms._nbins);

      // Use for all k-trees the same seed. NOTE: this is only to make a fair
      // view for all k-trees
      long rseed = rand.nextLong();
      // Initially setup as-if an empty-split had just happened
      for (int k = 0; k < _nclass; k++) {
        if (_model._output._distribution[k] != 0) { // Ignore missing classes
          // The Boolean Optimization
          // This optimization assumes the 2nd tree of a 2-class system is the
          // inverse of the first (and that the same columns were picked)
          if( k==1 && _nclass==2 && _model.binomialOpt()) continue;
          ktrees[k] = new DTree(_train, _ncols, (char)_parms._nbins, (char)_parms._nbins_cats, (char)_nclass, _parms._min_rows, _mtry, rseed);
          new UndecidedNode(ktrees[k], -1, DHistogram.initialHist(_train, _ncols, adj_nbins, _parms._nbins_cats, hcs[k][0])); // The "root" node
        }
      }

      // Sample - mark the rows by putting 'OUT_OF_BAG' into the nid() vector
      Sample ss[] = new Sample[_nclass];
      for( int k=0; k<_nclass; k++)
        if (ktrees[k] != null) ss[k] = new Sample(ktrees[k], _parms._sample_rate).dfork(null,new Frame(vec_nids(_train,k),vec_resp(_train)), _parms._build_tree_one_node);
      for( int k=0; k<_nclass; k++)
        if( ss[k] != null ) ss[k].getResult();
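      // Each tree thus trains on roughly a _sample_rate fraction of the rows; the
      // rows sampled away keep an out-of-bag marker in their nid column and are
      // scored later by CollectPreds to form the OOB error estimate.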

      // ----
      // One Big Loop till the ktrees are of proper depth.
      // Adds a layer to the trees each pass.
      int depth=0;
      for( ; depth<_parms._max_depth; depth++ ) {
        if( !isRunning() ) return;
        hcs = buildLayer(_train, _parms._nbins, _parms._nbins_cats, ktrees, leafs, hcs, _mtry < _model._output.nfeatures(), _parms._build_tree_one_node);
        // If we did not make any new splits, then the tree is split-to-death
        if( hcs == null ) break;
      }

      // Each tree bottomed-out in a DecidedNode; go 1 more level and insert
      // LeafNodes to hold predictions.
      for( int k=0; k<_nclass; k++ ) {
        DTree tree = ktrees[k];
        if( tree == null ) continue;
        int leaf = leafs[k] = tree.len();
        for( int nid=0; nid<leaf; nid++ ) {
          if( tree.node(nid) instanceof DecidedNode ) {
            DecidedNode dn = tree.decided(nid);
            if( dn._split._col == -1 ) { // No decision here, no row should have this NID now
              if( nid==0 ) {             // Handle the trivial non-splitting tree
                LeafNode ln = new LeafNode(tree, -1, 0);
                ln._pred = (float)(isClassifier() ? _model._output._priorClassDist[k] : _initialPrediction);
              }
              continue;
            }
            for( int i=0; i<dn._nids.length; i++ ) {
              int cnid = dn._nids[i];
              if( cnid == -1 || // Bottomed out (predictors or responses known constant)
                  tree.node(cnid) instanceof UndecidedNode || // Or chopped off for depth
                  (tree.node(cnid) instanceof DecidedNode &&  // Or not possible to split
                   ((DecidedNode)tree.node(cnid))._split._col == -1) ) {
                LeafNode ln = new LeafNode(tree,nid);
                ln._pred = (float)dn.pred(i);  // Set the prediction into the leaf
                dn._nids[i] = ln.nid();        // Mark a leaf here
              }
            }
          }
        }
      } // -- k-trees are done
    }

    // Collect the OOB votes (classification) or SSE (regression) plus on-the-fly
    // predictions for the freshly built trees.
    private class CollectPreds extends MRTask<CollectPreds> {
      /* @IN  */ final DTree _trees[]; // Read-only, shared (except at the histograms in the Nodes)
      /* @IN */  double _threshold;      // Decision threshold for binomial classification
      /* @OUT */ long rightVotes; // number of right votes over OOB rows (performed by this tree) represented by DTree[] _trees
      /* @OUT */ long allRows;    // number of all OOB rows (sampled by this tree)
      /* @OUT */ float sse;      // Sum of squares for this tree only
      CollectPreds(DTree trees[], int leafs[], double threshold) { _trees=trees; _threshold = threshold; }
      final boolean importance = true;
      @Override public void map( Chunk[] chks ) {
        final Chunk    y       = importance ? chk_resp(chks) : null; // Response
        final double[] rpred   = importance ? new double[1+_nclass] : null; // Row prediction
        final double[] rowdata = importance ? new double[_ncols] : null; // Pre-allocated row data
        final Chunk   oobt  = chk_oobt(chks); // Out-of-bag rows counter over all trees
        // Iterate over all rows
        for( int row=0; row<oobt._len; row++ ) {
          final boolean wasOOBRow = ScoreBuildHistogram.isOOBRow((int)chk_nids(chks,0).at8(row));
          // For all trees (i.e., k-classes)
          for( int k=0; k<_nclass; k++ ) {
            final DTree tree = _trees[k];
            if( tree == null ) continue;          // Empty class is ignored
            final Chunk nids = chk_nids(chks, k); // Node-ids for this tree/class
            int nid = (int)nids.at8(row);         // Get node to decide from
            if( wasOOBRow ) { // Track the on-the-fly prediction for OOB rows
              nid = ScoreBuildHistogram.oob2Nid(nid);
              if( tree.node(nid) instanceof UndecidedNode ) // If we bottomed out the tree
                nid = tree.node(nid).pid();                 // Then take the parent's decision
              int leafnid;
              if( tree.root() instanceof LeafNode ) {
                leafnid = 0;
              } else {
                DecidedNode dn = tree.decided(nid);         // Must have a decision point
                if( dn._split._col == -1 )                  // Unable to decide?
                  dn = tree.decided(tree.node(nid).pid());  // Then take the parent's decision
                leafnid = dn.ns(chks,row);                  // Decide down to a leaf node
              }
              // Tree(k)'s prediction for this row:
              //   - classification: cumulative number of votes for this row
              //   - regression: cumulative sum of per-tree predictions, normalized later by the tree count
              double prediction = ((LeafNode)tree.node(leafnid)).pred();
              if (importance) rpred[1+k] = (float) prediction;
              chk_tree(chks, k).set(row, (float)(chk_tree(chks, k).atd(row) + prediction));
            }
            nids.set(row, 0); // Reset the help column for this row and k-class
          }
          // This row was out-of-bag for this tree - i.e., the tree voted on it
          if (wasOOBRow) oobt.set(row, oobt.atd(row) + 1); // Count the trees which saw this row as OOB
          if (importance && wasOOBRow && !y.isNA(row)) {
            if (isClassifier()) {
              int treePred = getPrediction(rpred, _model._output._priorClassDist, data_row(chks, row, rowdata), _threshold);
              int actuPred = (int) y.at8(row);
              if (treePred == actuPred) rightVotes++; // A correct vote
              allRows++;
            } else {
              float treePred = rpred[1];
              float actuPred = (float) y.atd(row);
              sse += (actuPred - treePred) * (actuPred - treePred);
              allRows++;
            }
          }
        }
      }
      @Override public void reduce(CollectPreds mrt) {
        rightVotes += mrt.rightVotes;
        allRows    += mrt.allRows;
        sse        += mrt.sse;
      }
    }

    @Override protected DRFModel makeModel( Key modelKey, DRFModel.DRFParameters parms, double mse_train, double mse_valid ) {
      return new DRFModel(modelKey, parms, new DRFModel.DRFOutput(DRF.this, mse_train, mse_valid));
    }
  }

  // Read the per-row prediction out of the accumulated per-tree work columns,
  // normalizing by the number of trees for which the row was out-of-bag.
  @Override protected double score1( Chunk chks[], double weight, double offset, double fs[/*nclass*/], int row ) {
    double sum = 0;
    if (_nclass > 2 || (_nclass == 2 && !_model.binomialOpt())) {
      for (int k = 0; k < _nclass; k++)
        sum += (fs[k+1] = chk_tree(chks, k).atd(row) / chk_oobt(chks).atd(row));
    }
    else if (_nclass==2 && _model.binomialOpt()) {
      fs[1] = chk_tree(chks, 0).atd(row) / chk_oobt(chks).atd(row);
      assert(fs[1] >= 0 && fs[1] <= 1);
      fs[2] = 1. - fs[1];
    }
    else { //regression
      // Average of the per-tree predictions for this row (over only the trees which had this row out-of-bag)
      sum += (fs[0] = chk_tree(chks, 0).atd(row) / chk_oobt(chks).atd(row) );
      fs[1] = 0;
    }
    return sum;
  }
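
  // Worked example for the binomial-optimization branch above: if a row has been
  // out-of-bag in 20 trees so far and tree 0 accumulated 14 votes for it, then
  // fs[1] = 14/20 = 0.7 and fs[2] = 1 - 0.7 = 0.3.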


}