hex.tree.drf.DRFModel Maven / Gradle / Ivy
package hex.tree.drf;
import hex.genmodel.GenModel;
import hex.tree.SharedTreeModel;
import water.Key;
import water.fvec.Chunk;
import water.util.MathUtils;
import water.util.SB;
public class DRFModel extends SharedTreeModel {
/**
 * DRF-specific training parameters, layered on top of the shared tree parameters.
 * The no-arg constructor replaces a few inherited defaults with DRF's own values.
 */
public static class DRFParameters extends SharedTreeModel.SharedTreeParameters {
  // Defaults to -1. NOTE(review): presumably a sentinel for "choose automatically"; confirm in DRF.
  int _mtries = -1;
  // Sampling rate, default 0.632.
  float _sample_rate = 0.632f;
  public boolean _build_tree_one_node = false;
  public boolean _binomial_double_trees = true;

  /** Creates parameters using DRF defaults (which may differ from the shared tree defaults). */
  public DRFParameters() {
    // The superclass constructor runs implicitly first; these assignments then override its defaults.
    this._min_rows = 1;
    this._max_depth = 20;
    this._ntrees = 50;
  }
}
/** Output of a DRF model. It declares no state of its own and defers entirely to the shared tree output. */
public static class DRFOutput extends SharedTreeModel.SharedTreeOutput {
  /**
   * Builds the output by delegating to the shared tree output constructor.
   *
   * @param b         the DRF instance producing this output
   * @param mse_train MSE on the training data, forwarded unchanged
   * @param mse_valid MSE on the validation data, forwarded unchanged
   */
  public DRFOutput(DRF b, double mse_train, double mse_valid) {
    super(b, mse_train, mse_valid);
  }
}
/**
 * Creates a DRF model by delegating directly to the shared tree model constructor.
 *
 * @param selfKey this model's own key
 * @param parms   the DRF parameters for this model
 * @param output  the DRF output for this model
 */
public DRFModel(Key selfKey, DRFParameters parms, DRFOutput output) {
  super(selfKey, parms, output);
}
/**
 * Reports whether the binomial optimization is enabled.
 * This is true exactly when {@code _binomial_double_trees} is switched off in the parameters.
 */
@Override
protected boolean binomialOpt() {
  final boolean doubleTrees = _parms._binomial_double_trees;
  return !doubleTrees;
}
/** Bulk scoring API for one row. Chunks are all compatible with the model,
 * and the last Chunks are expected to hold the final distribution and prediction.
 * Default method is to just load the data into the tmp array, then call
 * subclass scoring logic.
 *
 * NOTE(review): this copy of the method body appears corrupted.
 * - The line beginning "for( int i=0; i 0)" looks as if the text between a
 *   less-than and a greater-than sign was stripped. This seems to have fused
 *   the row-loading loop with part of a separate scoring routine.
 * - The identifiers sum and data used further down are not declared anywhere
 *   in the visible text.
 * - The closing braces below do not match visible openers.
 * Restore this region from the upstream source before compiling or relying on it. */
@Override public double[] score0( Chunk chks[], int row_in_chunk, double[] tmp, double[] preds ) {
// Requires at least one chunk per tmp slot (checked only when assertions are enabled).
assert chks.length>=tmp.length;
// NOTE(review): corrupted line, see the method Javadoc; it does not compile as shown.
for( int i=0; i 0) MathUtils.div(preds, sum);
}
// With class balancing on, adjust preds using the prior and model class
// distributions recorded in the output.
if (_parms._balance_classes)
GenModel.correctProbabilities(preds, _output._priorClassDist, _output._modelClassDist);
// Overwrite slot 0 with GenModel.getPrediction(...) over preds and the default
// threshold; presumably this is the predicted class label - TODO confirm.
preds[0] = hex.genmodel.GenModel.getPrediction(preds, data, defaultThreshold());
}
// Returns the preds array that was written to above.
return preds;
}
@Override protected void toJavaUnifyPreds(SB body, SB file) {
if (_output.nclasses() == 1) { // Regression
body.ip("preds[0] /= " + _output._ntrees + ";").nl();
} else { // Classification
if( _output.nclasses()==2 && !_parms._binomial_double_trees) { // Kept the initial prediction for binomial
body.ip("preds[1] /= " + _output._ntrees + ";").nl();
body.ip("preds[2] = 1.0 - preds[1];").nl();
} else {
body.ip("double sum = 0;").nl();
body.ip("for(int i=1; i0) for(int i=1; i
© 2015 - 2025 Weber Informatics LLC | Privacy Policy