All downloads are free. Search and download functionality uses the official Maven repository.

hex.tree.drf.DRFModel Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.tree.drf;

import hex.genmodel.GenModel;
import hex.tree.SharedTreeModel;
import water.Key;
import water.fvec.Chunk;
import water.util.MathUtils;
import water.util.SB;

public class DRFModel extends SharedTreeModel {

  /**
   * User-settable parameters for a Distributed Random Forest model.
   *
   * <p>Extends the shared tree parameters. The constructor overrides some inherited defaults
   * with DRF-specific values.
   */
  public static class DRFParameters extends SharedTreeModel.SharedTreeParameters {
    /**
     * Column-sampling parameter; defaults to the sentinel {@code -1}.
     *
     * <p>NOTE(review): presumably {@code -1} asks the builder to pick a data-dependent value;
     * confirm against the DRF builder. Declared {@code public}, like every other field of this
     * class, so callers outside {@code hex.tree.drf} can set it.
     */
    public int _mtries = -1;

    /**
     * Sampling rate, default {@code 0.632f}.
     *
     * <p>NOTE(review): presumably the per-tree row-sampling fraction; confirm in the builder.
     * Declared {@code public} for the same reason as {@link #_mtries}.
     */
    public float _sample_rate = 0.632f;

    /** Defaults to {@code false}; NOTE(review): semantics are implemented in the builder. */
    public boolean _build_tree_one_node = false;

    /**
     * Defaults to {@code true}. When {@code false}, {@code binomialOpt()} returns {@code true}.
     * Two-class models then take the single-prediction binomial path.
     */
    public boolean _binomial_double_trees = true;

    /** Creates parameters with DRF-specific defaults applied over the shared-tree defaults. */
    public DRFParameters() {
      super();
      // Set DRF-specific defaults (these can differ from SharedTreeModel's defaults).
      _ntrees = 50;
      _max_depth = 20;
      _min_rows = 1;
    }
  }

  /**
   * Model output for DRF. It adds no state of its own; everything is handled by the shared
   * tree output superclass.
   */
  public static class DRFOutput extends SharedTreeModel.SharedTreeOutput {
    /**
     * Builds the output object.
     *
     * @param b the DRF instance this output belongs to (presumably the model builder)
     * @param mse_train training error value, forwarded unchanged to the superclass
     * @param mse_valid validation error value, forwarded unchanged to the superclass
     */
    public DRFOutput(DRF b, double mse_train, double mse_valid) {
      super(b, mse_train, mse_valid);
    }
  }

  public DRFModel(Key selfKey, DRFParameters parms, DRFOutput output ) { super(selfKey,parms,output); }

  /**
   * Reports whether the binomial optimization is on.
   *
   * @return the logical negation of {@code _parms._binomial_double_trees}
   */
  @Override
  protected boolean binomialOpt() {
    final boolean doubleTrees = _parms._binomial_double_trees;
    return !doubleTrees;
  }

  /** Bulk scoring API for one row.  Chunks are all compatible with the model,
   *  and expect the last Chunks are for the final distribution and prediction.
   *  Default method is to just load the data into the tmp array, then call
   *  subclass scoring logic. */
  @Override public double[] score0( Chunk chks[], int row_in_chunk, double[] tmp, double[] preds ) {
    assert chks.length>=tmp.length;
    for( int i=0; i 0) MathUtils.div(preds, sum);
      }
      if (_parms._balance_classes)
        GenModel.correctProbabilities(preds, _output._priorClassDist, _output._modelClassDist);
      preds[0] = hex.genmodel.GenModel.getPrediction(preds, data, defaultThreshold());
    }
    return preds;
  }

  @Override protected void toJavaUnifyPreds(SB body, SB file) {
    if (_output.nclasses() == 1) { // Regression
      body.ip("preds[0] /= " + _output._ntrees + ";").nl();
    } else { // Classification
      if( _output.nclasses()==2 && !_parms._binomial_double_trees) { // Kept the initial prediction for binomial
        body.ip("preds[1] /= " + _output._ntrees + ";").nl();
        body.ip("preds[2] = 1.0 - preds[1];").nl();
      } else {
        body.ip("double sum = 0;").nl();
        body.ip("for(int i=1; i0) for(int i=1; i




© 2015 - 2025 Weber Informatics LLC | Privacy Policy