All downloads are free. Search and download functionality uses the official Maven repository.

hex.tree.drf.DRFModel Maven / Gradle / Ivy

package hex.tree.drf;

import hex.tree.SharedTreeModel;
import water.Key;
import water.util.MathUtils;
import water.util.SBPrintStream;

import java.util.Arrays;

public class DRFModel extends SharedTreeModel {

  /**
   * Hyper-parameters for Distributed Random Forest. Extends the shared tree parameters and
   * replaces several of the inherited defaults with DRF-specific values.
   */
  public static class DRFParameters extends SharedTreeModel.SharedTreeParameters {
    /**
     * When {@code true}, turns off the binomial single-probability shortcut:
     * {@code DRFModel.binomialOpt()} returns the negation of this flag.
     */
    public boolean _binomial_double_trees = false;

    /**
     * Number of columns to use per split. Initialized to -1; the effective default depends on
     * the algorithm and the problem type (classification/regression).
     */
    public int _mtries = -1;

    /** Builds a parameter set whose DRF-specific defaults differ from the shared-tree ones. */
    public DRFParameters() {
      super(); // establishes the SharedTreeParameters defaults first
      // DRF-specific overrides (independent assignments; order does not matter).
      _min_rows = 1;
      _max_depth = 20;
      _sample_rate = 0.632f;
      _mtries = -1;
    }
  }

  /** Output container for a DRF model; declares no fields of its own beyond SharedTreeOutput. */
  public static class DRFOutput extends SharedTreeModel.SharedTreeOutput {
    /**
     * Creates the output holder.
     *
     * @param b         the DRF instance this output belongs to (forwarded unchanged)
     * @param mse_train forwarded unchanged to {@code SharedTreeOutput}
     * @param mse_valid forwarded unchanged to {@code SharedTreeOutput}
     */
    public DRFOutput(DRF b, double mse_train, double mse_valid) {
      super(b, mse_train, mse_valid);
    }
  }

  /**
   * Creates a DRF model. All arguments are handed straight to the {@code SharedTreeModel}
   * constructor; this class adds no construction logic of its own.
   */
  public DRFModel(Key selfKey, DRFParameters parms, DRFOutput output) {
    super(selfKey, parms, output);
  }

  /**
   * Reports whether the binomial shortcut is enabled. With the shortcut, {@code score0} averages
   * only {@code preds[1]} and derives {@code preds[2]} as its complement. The shortcut is on
   * unless {@code _binomial_double_trees} was requested.
   */
  @Override
  protected boolean binomialOpt() {
    final boolean doubleTrees = _parms._binomial_double_trees;
    return !doubleTrees;
  }

  /**
   * Scores a single row. The superclass first accumulates the raw per-tree contributions into
   * {@code preds}; this method then converts them into final DRF predictions.
   *
   * <ul>
   *   <li>Regression ({@code nclasses() == 1}): {@code preds[0]} is divided by the tree count
   *       to give the mean over trees.</li>
   *   <li>Binomial with {@link #binomialOpt()}: {@code preds[1]} is averaged over trees and
   *       {@code preds[2]} is set to {@code 1 - preds[1]}.</li>
   *   <li>Otherwise: the class slots {@code preds[1..length-1]} are rescaled to sum to 1 (when
   *       their sum is positive). Slot 0 is excluded, which matches the normalization loop that
   *       {@code toJavaUnifyPreds} emits for the generated POJO.</li>
   * </ul>
   *
   * <p>NOTE(review): slot 0 is presumably the predicted-label slot in classification; confirm
   * against SharedTreeModel.
   *
   * @param data   feature values for the row
   * @param preds  prediction buffer, updated in place
   * @param weight row weight, forwarded to the superclass
   * @param offset row offset, forwarded to the superclass
   * @return {@code preds}, after post-processing
   */
  @Override protected double[] score0(double data[], double preds[], double weight, double offset) {
    super.score0(data, preds, weight, offset);
    final int N = _output._ntrees;
    if (_output.nclasses() == 1) {
      // Regression: the forest prediction is the average over all trees.
      if (N >= 1) preds[0] /= N;
    } else if (_output.nclasses() == 2 && binomialOpt()) {
      // Binomial shortcut: average one probability; the other is its complement.
      if (N >= 1) preds[1] /= N;
      preds[2] = 1. - preds[1];
    } else {
      // Normalize class votes to probabilities. Start at index 1 so slot 0 never takes part
      // in the sum or the division, keeping Java scoring consistent with the POJO code.
      double sum = 0;
      for (int i = 1; i < preds.length; i++) sum += preds[i];
      if (sum > 0) {
        for (int i = 1; i < preds.length; i++) preds[i] /= sum;
      }
    }
    return preds;
  }

  @Override protected void toJavaUnifyPreds(SBPrintStream body) {
    if (_output.nclasses() == 1) { // Regression
      body.ip("preds[0] /= " + _output._ntrees + ";").nl();
    } else { // Classification
      if( _output.nclasses()==2 && binomialOpt()) { // Kept the initial prediction for binomial
        body.ip("preds[1] /= " + _output._ntrees + ";").nl();
        body.ip("preds[2] = 1.0 - preds[1];").nl();
      } else {
        body.ip("double sum = 0;").nl();
        body.ip("for(int i=1; i0) for(int i=1; i




© 2015 - 2025 Weber Informatics LLC | Privacy Policy