All downloads are FREE. Search and download functionality uses the official Maven repository.

hex.tree.drf.DRF Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.tree.drf;

import hex.Model;
import hex.ModelCategory;
import hex.PojoWriter;
import hex.genmodel.MojoModel;
import hex.genmodel.algos.drf.DrfMojoModel;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.*;
import hex.tree.DTree.DecidedNode;
import hex.tree.DTree.LeafNode;
import hex.tree.DTree.UndecidedNode;
import water.Job;
import water.Key;
import water.MRTask;
import water.fvec.C0DChunk;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.util.ArrayUtils;

import java.util.Arrays;
import java.util.Random;

import static hex.genmodel.GenModel.getPrediction;
import static hex.tree.drf.TreeMeasuresCollector.asSSE;
import static hex.tree.drf.TreeMeasuresCollector.asVotes;

/** Distributed Random Forest
 */
public class DRF extends SharedTree {
  // Floating-point round-off tolerances used when scoring with the binomial optimization:
  // a class-1 probability in (1, ONEBOUND] is snapped to 1.0 and one in [ZEROBOUND, 0) is
  // snapped to 0.0 before the assertion that it lies in [0, 1].
  private static final double ONEBOUND=1+1e-12;    // due to fixed precision
  private static final double ZEROBOUND=-1e-12;    // due to fixed precision
  /**
   * Lists the kinds of models this builder can produce.
   *
   * @return a newly allocated array holding Regression, Binomial and Multinomial
   */
  @Override public ModelCategory[] can_build() {
    final ModelCategory[] supported = {
        ModelCategory.Regression,
        ModelCategory.Binomial,
        ModelCategory.Multinomial
    };
    return supported;
  }

  /**
   * Creates a DRF builder for the given parameters and runs the cheap validation pass
   * ({@code init(false)}). Called from an http request.
   *
   * @param parms DRF training parameters
   */
  public DRF(hex.tree.drf.DRFModel.DRFParameters parms) {
    super(parms);
    init(false);
  }

  /**
   * Creates a DRF builder with an explicit key and runs the cheap validation pass.
   *
   * @param parms DRF training parameters
   * @param key   key forwarded unchanged to the superclass constructor
   */
  public DRF(hex.tree.drf.DRFModel.DRFParameters parms, Key key) {
    super(parms, key);
    init(false);
  }

  /**
   * Creates a DRF builder bound to an existing job and runs the cheap validation pass.
   *
   * @param parms DRF training parameters
   * @param job   job forwarded unchanged to the superclass constructor
   */
  public DRF(hex.tree.drf.DRFModel.DRFParameters parms, Job job) {
    super(parms, job);
    init(false);
  }

  /**
   * Creates a builder with default DRF parameters; unlike the other constructors this one
   * does not call {@code init}.
   *
   * @param startup_once flag forwarded unchanged to the superclass constructor
   */
  public DRF(boolean startup_once) {
    super(new hex.tree.drf.DRFModel.DRFParameters(), startup_once);
  }

  /**
   * Starts the DRF training Job on an F/J thread.
   *
   * @return a new {@code DRFDriver} that carries out the training
   */
  @Override
  protected Driver trainModelImpl() {
    final Driver driver = new DRFDriver();
    return driver;
  }

  /**
   * DRF opts out of scoring when zero trees have been built.
   *
   * @return always {@code false}
   */
  @Override
  public boolean scoreZeroTrees() {
    return false;
  }

  /**
   * Validates DRF-specific parameters after the generic checks done by {@code super.init}.
   * The front-end calls this whenever the GUI is clicked, so it must stay fast; heavy-weight
   * preparation is deferred to the trainModel() call.
   *
   * <p>Checks, in this order: {@code _mtries} sentinel/range, supported distribution
   * (resolving AUTO to gaussian for one class or multinomial for two or more), a warning when
   * every row is sampled and there is no validation frame or cross-validation, and rejection
   * of offset columns.
   *
   * @param expensive forwarded to {@code super.init}
   */
  @Override public void init(boolean expensive) {
    super.init(expensive);

    // -1 and -2 are sentinel values; any other value must be a positive column count.
    final int mtries = _parms._mtries;
    final boolean mtriesIsSentinel = mtries == -1 || mtries == -2;
    if (!mtriesIsSentinel && mtries < 1) {
      error("_mtries", "mtries must be -1 (converted to sqrt(features)) or -2 (All features) or >= 1 but it is " + mtries);
    }
    if (_train != null) {
      final int ncols = _train.numCols(); // counts the response column as well
      final boolean withinFrame = mtries >= 1 && mtries < ncols;
      if (!mtriesIsSentinel && !withinFrame) {
        error("_mtries","Computed mtries should be -1 or -2 or in interval [1,"+ncols+"[ but it is " + mtries);
      }
    }

    final DistributionFamily[] supportedDistributions = {
        DistributionFamily.AUTO,
        DistributionFamily.bernoulli,
        DistributionFamily.multinomial,
        DistributionFamily.gaussian
    };
    if (!ArrayUtils.contains(supportedDistributions, _parms._distribution)) {
      error("_distribution", _parms._distribution.name() + " distribution is not supported for DRF in current H2O.");
    }
    if (_parms._distribution == DistributionFamily.AUTO) {
      if (_nclass == 1) {
        _parms._distribution = DistributionFamily.gaussian;
      } else if (_nclass >= 2) {
        _parms._distribution = DistributionFamily.multinomial;
      }
    }

    // Every row is sampled and there is neither a validation frame nor cross-validation.
    final boolean nothingHeldOut = _parms._sample_rate == 1f && _valid == null && _parms._nfolds == 0;
    if (nothingHeldOut) {
      warn("_sample_rate", "Sample rate is 100% and no validation dataset and no cross-validation. There are no out-of-bag data to compute error estimates on the training data!");
    }
    if (hasOffsetCol()) {
      error("_offset_column", "Offsets are not yet supported for DRF.");
    }
  }

  // ----------------------
  private class DRFDriver extends Driver {
    /**
     * DRF always requests scoring on out-of-bag rows.
     *
     * @return always {@code true}
     */
    @Override
    protected boolean doOOBScoring() {
      return true;
    }

    // --- State handled only on the master node ---

    /** Per-tree measure on OOB rows: votes for classification, SSE for regression. */
    public transient TreeMeasuresCollector.TreeMeasures _treeMeasuresOnOOB;
    /** Same per-tree measures, one entry per feature, computed on permuted OOB rows. */
    public transient TreeMeasuresCollector.TreeMeasures[/*features*/] _treeMeasuresOnSOOB;
    /** Variable importance derived from tree split decisions, one slot per feature. */
    private transient float[/*nfeatures*/] _improvPerVar;

    /**
     * Allocates a fresh split-importance array of {@code _ncols} entries, plus the OOB
     * measure holders sized for {@code _parms._ntrees} trees: a single holder for the
     * plain OOB rows and one holder per feature for the permuted OOB rows. Classifiers
     * get vote holders, regression gets SSE holders.
     */
    private void initTreeMeasurements() {
      _improvPerVar = new float[_ncols];
      final int ntrees = _parms._ntrees;

      if (!_model._output.isClassifier()) {
        _treeMeasuresOnOOB = new TreeMeasuresCollector.TreeSSE(ntrees);
        final TreeMeasuresCollector.TreeSSE[] perFeature = new TreeMeasuresCollector.TreeSSE[_ncols];
        for (int f = 0; f < perFeature.length; f++) {
          perFeature[f] = new TreeMeasuresCollector.TreeSSE(ntrees);
        }
        _treeMeasuresOnSOOB = perFeature;
        return;
      }

      _treeMeasuresOnOOB = new TreeMeasuresCollector.TreeVotes(ntrees);
      final TreeMeasuresCollector.TreeVotes[] perFeature = new TreeMeasuresCollector.TreeVotes[_ncols];
      for (int f = 0; f < perFeature.length; f++) {
        perFeature[f] = new TreeMeasuresCollector.TreeVotes(ntrees);
      }
      _treeMeasuresOnSOOB = perFeature;
    }

    /**
     * Prepares DRF-specific training state before any tree is grown. It resolves the
     * per-tree and per-split column counts and evaluates auto parameters when enabled. It
     * sets the initial prediction (0 for classifiers), allocates the OOB measures, and
     * seeds the work columns from the response.
     *
     * @throws IllegalArgumentException if the computed mtry_per_tree or mtry is outside [1, _ncols]
     */
    @Override protected void initializeModelSpecifics() {
      // Columns kept per tree: a fraction of the predictors, but never fewer than one.
      _mtry_per_tree = Math.max(1, (int) (_parms._col_sample_rate_per_tree * _ncols));
      if (_mtry_per_tree < 1 || _mtry_per_tree > _ncols) {
        throw new IllegalArgumentException("Computed mtry_per_tree should be in interval <1,"+_ncols+"> but it is " + _mtry_per_tree);
      }

      // Columns considered at each split.
      switch (_parms._mtries) {
        case -2: // all columns, regardless of which ones were dropped for this tree
          _mtry = _ncols;
          break;
        case -1: // classification: sqrt(_ncols); regression: _ncols / 3 (both at least 1)
          _mtry = isClassifier() ? Math.max((int) Math.sqrt(_ncols), 1) : Math.max(_ncols / 3, 1);
          break;
        default:
          _mtry = _parms._mtries;
      }
      if (_mtry < 1 || _mtry > _ncols) {
        throw new IllegalArgumentException("Computed mtry should be in interval <1," + _ncols + "> but it is " + _mtry);
      }

      if (_model != null && _model.evalAutoParamsEnabled) {
        _model.initActualParamValuesAfterOutputSetup(isClassifier());
      }

      _initialPrediction = isClassifier() ? 0 : getInitialValue();
      initTreeMeasurements(); // TreeVotes for classification, SSE holders for regression

      // Seed the work columns, skipping rows whose response is NA:
      //  - classification: write 1 into the work column of the row's class
      //  - regression:     copy the response (as float) into the single work column
      new MRTask() {
        @Override public void map(Chunk[] chunks) {
          final Chunk response = chk_resp(chunks);
          for (int row = 0; row < response._len; row++) {
            if (!response.isNA(row)) {
              if (!isClassifier()) {
                final float target = (float) response.atd(row);
                chk_work(chunks, 0).set(row, target);
              } else {
                final int klass = (int) response.at8(row);
                chk_work(chunks, klass).set(row, 1L);
              }
            }
          }
        }
      }.doAll(_train);
    }

    // --------------------------------------------------------------------------
    /**
     * Builds the next random set of K trees ({@code _nclass} slots) that together form one
     * forest iteration. It grows the trees, collects their out-of-bag predictions, records
     * per-tree OOB votes (classification) or SSE (regression), and appends the trees to the
     * model output.
     *
     * @return always {@code false}; DRF never stops early from here
     */
    @Override protected boolean buildNextKTrees() {
      final DTree[] trees = new DTree[_nclass];
      // Per tree, the start index of the "working set" of leaf splits (up to tree._len).
      final int[] leafStart = new int[_nclass];

      // Assign rows to nodes, i.e. fill the "NIDs" column(s).
      growTrees(trees, leafStart, _rand);

      // Move rows into their final leaves: fills the "Tree" and OUT_BAG_TREES columns
      // and clears the NIDs column.
      final CollectPreds preds =
          new CollectPreds(trees, leafStart, _model.defaultThreshold()).doAll(_train, _parms._build_tree_one_node);

      if (isClassifier()) {
        // right votes over the OOB rows of this tree
        asVotes(_treeMeasuresOnOOB).append(preds.rightVotes, preds.allRows);
      } else {
        asSSE(_treeMeasuresOnOOB).append(preds.sse, preds.allRows);
      }

      _model._output.addKTrees(trees);
      return false;
    }

    // Assumes that the "Work" column are filled with horizontalized (0/1) class memberships per row (or copy of regression response)
    private void growTrees(DTree[] ktrees, int[] leafs, Random rand) {
      // Initial set of histograms.  All trees; one leaf per tree (the root
      // leaf); all columns
      DHistogram hcs[][][] = new DHistogram[_nclass][1/*just root leaf*/][_ncols];

      // Adjust real bins for the top-levels
      int adj_nbins = Math.max(_parms._nbins_top_level,_parms._nbins);

      // Use for all k-trees the same seed. NOTE: this is only to make a fair
      // view for all k-trees
      long rseed = rand.nextLong();
      // Initially setup as-if an empty-split had just happened
      for (int k = 0; k < _nclass; k++) {
        if (_model._output._distribution[k] != 0) { // Ignore missing classes
          // The Boolean Optimization
          // This optimization assumes the 2nd tree of a 2-class system is the
          // inverse of the first (and that the same columns were picked)
          if( k==1 && _nclass==2 && _model.binomialOpt()) continue;
          ktrees[k] = new DTree(_train, _ncols, _mtry, _mtry_per_tree, rseed, _parms);
          new UndecidedNode(ktrees[k], -1, DHistogram.initialHist(_train, _ncols, adj_nbins, hcs[k][0], rseed, _parms, getGlobalSplitPointsKeys(), null, true, null), null, null); // The "root" node
        }
      }

      // Sample - mark the lines by putting 'OUT_OF_BAG' into nid() vector
      Sample ss[] = new Sample[_nclass];
      for( int k=0; k<_nclass; k++)
        if (ktrees[k] != null) ss[k] = new Sample(ktrees[k], _parms._sample_rate, _parms._sample_rate_per_class).dfork(null,new Frame(vec_nids(_train,k),vec_resp(_train)), _parms._build_tree_one_node);
      for( int k=0; k<_nclass; k++)
        if( ss[k] != null ) ss[k].getResult();

      // ----
      // One Big Loop till the ktrees are of proper depth.
      // Adds a layer to the trees each pass.
      int depth=0;
      for( ; depth<_parms._max_depth; depth++ ) {
        hcs = buildLayer(_train, _parms._nbins, ktrees, leafs, hcs, _parms._build_tree_one_node);
        // If we did not make any new splits, then the tree is split-to-death
        if( hcs == null ) break;
      }

      // Each tree bottomed-out in a DecidedNode; go 1 more level and insert
      // LeafNodes to hold predictions.
      for( int k=0; k<_nclass; k++ ) {
        DTree tree = ktrees[k];
        if( tree == null ) continue;
        int leaf = leafs[k] = tree.len();
        for( int nid=0; nid {
      /* @IN  */ final DTree _trees[]; // Read-only, shared (except at the histograms in the Nodes)
      /* @IN */  double _threshold;      // Sum of squares for this tree only
      /* @OUT */ double rightVotes; // number of right votes over OOB rows (performed by this tree) represented by DTree[] _trees
      /* @OUT */ double allRows;    // number of all OOB rows (sampled by this tree)
      /* @OUT */ float sse;      // Sum of squares for this tree only
      CollectPreds(DTree trees[], int leafs[], double threshold) { _trees=trees; _threshold = threshold; }
      final boolean importance = true;
      @Override public void map( Chunk[] chks ) {
        final Chunk    y       = importance ? chk_resp(chks) : null; // Response
        final double[] rpred   = importance ? new double[1+_nclass] : null; // Row prediction
        final double[] rowdata = importance ? new double[_ncols] : null; // Pre-allocated row data
        final Chunk   oobt  = chk_oobt(chks); // Out-of-bag rows counter over all trees
        final Chunk   weights  = hasWeightCol() ? chk_weight(chks) : new C0DChunk(1, chks[0]._len); // Out-of-bag rows counter over all trees
        // Iterate over all rows
        for( int row=0; row 2 || (_nclass == 2 && !_model.binomialOpt())) {
      for (int k = 0; k < _nclass; k++)
        sum += (fs[k+1] = weight * chk_tree(chks, k).atd(row) / chk_oobt(chks).atd(row));
    }
    else if (_nclass==2 && _model.binomialOpt()) {
      fs[1] = weight * chk_tree(chks, 0).atd(row) / chk_oobt(chks).atd(row);
      if (fs[1]>1 && fs[1]<=ONEBOUND)
        fs[1] = 1.0;
      else if (fs[1]<0 && fs[1]>=ZEROBOUND)
        fs[1] = 0.0;
      assert(fs[1] >= 0 && fs[1] <= 1);
      fs[2] = 1. - fs[1];
    }
    else { //regression
      // average per trees voted for this row (only trees which have row in "out-of-bag"
      sum += (fs[0] = weight * chk_tree(chks, 0).atd(row) / chk_oobt(chks).atd(row) );
      fs[1] = 0;
    }
    return sum;
  }

  /**
   * Builds a POJO writer from a DRF MOJO. The compressed trees are extracted from the MOJO
   * and checked to see whether they were built with the binomial optimization.
   *
   * @param genericModel the model the POJO is generated for
   * @param mojoModel    must be a {@link DrfMojoModel}
   * @return a writer configured with the MOJO's categorical encoding and class balancing
   * @throws ClassCastException if {@code mojoModel} is not a {@code DrfMojoModel}
   */
  @Override
  public PojoWriter makePojoWriter(Model genericModel, MojoModel mojoModel) {
    final DrfMojoModel drf = (DrfMojoModel) mojoModel;
    final CompressedTree[][] compressedTrees = MojoUtils.extractCompressedTrees(drf);
    final boolean usesBinomialOpt = MojoUtils.isUsingBinomialOpt(drf, compressedTrees);
    return new DrfPojoWriter(
        genericModel,
        drf.getCategoricalEncoding(),
        usesBinomialOpt,
        compressedTrees,
        drf._balanceClasses);
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy