hex.tree.SharedTreeModel Maven / Gradle / Ivy
package hex.tree;
import hex.*;
import water.*;
import water.util.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public abstract class SharedTreeModel, P extends SharedTreeModel.SharedTreeParameters, O extends SharedTreeModel.SharedTreeOutput> extends Model {
public abstract static class SharedTreeParameters extends Model.Parameters {
  /** Upper bound on the number of response levels (classes) accepted. */
  static final int MAX_SUPPORTED_LEVELS = 1000;

  /** Number of trees in the final model. Grid search: comma-separated values, e.g. 50,100,150,200. */
  public int _ntrees = 50;

  /** Maximum tree depth. Grid search: comma-separated values, e.g. 5,7. */
  public int _max_depth = 5;

  /** Fewest observations allowed in a leaf (R calls this {@code nodesize}). Grid search: comma-separated values. */
  public int _min_rows = 10;

  /** Numerical (real/int) columns: number of histogram bins built before picking the best split point. */
  public int _nbins = 20;

  /** Categorical (enum) columns: number of histogram bins built before picking the best split point. */
  public int _nbins_cats = 1024;

  /** Stop once the r^2 metric equals or exceeds this value. */
  public double _r2_stopping = 0.999999;

  /** Random seed; by default drawn from an RNG seeded with {@link System#nanoTime()}. */
  public long _seed = RandomUtils.getRNG(System.nanoTime()).nextLong();

  /**
   * {@code true}: continue extending an existing checkpointed model;
   * {@code false}: overwrite any prior model.
   */
  public boolean _checkpoint;

  /** Hard-coded minimum top-level number of bins for real-valued columns (not currently user-facing). */
  public int _nbins_top_level = 1 << 10;
}
/**
 * Returns the variable importances stored on this model's output.
 *
 * @return {@code _output._varimp}; null if it was never assigned
 */
public final VarImp varImp() {
  return _output._varimp;
}
/**
 * Creates the metric builder that matches this model's category.
 *
 * @param domain response domain handed to the binomial/multinomial builders
 * @return a binomial, multinomial or regression metric builder
 * @throws RuntimeException whatever {@code H2O.unimpl()} produces, for any other category
 */
@Override public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
  final ModelMetrics.MetricBuilder builder;
  switch (_output.getModelCategory()) {
    case Binomial:
      builder = new ModelMetricsBinomial.MetricBuilderBinomial(domain);
      break;
    case Multinomial:
      // Multinomial metrics also need the class count up front.
      builder = new ModelMetricsMultinomial.MetricBuilderMultinomial(_output.nclasses(), domain);
      break;
    case Regression:
      builder = new ModelMetricsRegression.MetricBuilderRegression();
      break;
    default:
      throw H2O.unimpl();
  }
  return builder;
}
/**
 * Training output shared by tree-based models. It holds:
 * - the zero-tree prediction ({@code _init_f});
 * - the DKV keys of all trees (one row per appended set of K trees, one entry per class);
 * - score keepers whose index 0 holds the zero-tree score;
 * - variable importances.
 */
public abstract static class SharedTreeOutput extends Model.Output {
/** InitF value: the prediction used when the model has zero trees.
* f0 = mean(yi) for gaussian
* f0 = log(yi/(1-yi)) for bernoulli (the logit, i.e. the inverse of the sigmoid below)
*
* For GBM bernoulli, the initial prediction for 0 trees is
* p = 1/(1+exp(-f0))
*
* From this, the mse for 0 trees can be computed as follows:
* mean((yi-p)^2)
* This is what is stored in _scored_train[0]
* */
public double _init_f;
/** Number of trees actually in the model (as opposed to requested) */
public int _ntrees;
/** More in-depth tree statistics */
final public TreeStats _treeStats;
/** Trees get big, so each one is stored separately in the DKV; this holds their keys. */
public Key[/*_ntrees*/][/*_nclass*/] _treeKeys;
/** Training-set scores; index 0 is the zero-tree score passed to the constructor. */
public ScoreKeeper _scored_train[/*ntrees+1*/];
/** Validation-set scores; index 0 is the zero-tree score passed to the constructor. */
public ScoreKeeper _scored_valid[/*ntrees+1*/];
/** Training time stamps; the single initial entry is the wall-clock time (ms) when this output was created.
 *  NOTE(review): presumably one entry is appended per scoring event -- confirm against the builder. */
public long _training_time_ms[/*ntrees+1*/] = new long[]{System.currentTimeMillis()};
/**
* Variable importances computed during training
*/
public TwoDimTable _variable_importances;
/** Raw variable importances; returned by the model's varImp(). */
public VarImp _varimp;
/**
 * Creates an output that contains no trees yet.
 * @param b the model builder, forwarded to the superclass constructor
 * @param mse_train zero-tree MSE on the training data; stored as _scored_train[0]
 * @param mse_valid zero-tree MSE on the validation data; stored as _scored_valid[0]
 */
public SharedTreeOutput( SharedTree b, double mse_train, double mse_valid ) {
super(b);
_ntrees = 0; // No trees yet
_treeKeys = new Key[_ntrees][]; // No tree keys yet
_treeStats = new TreeStats();
_scored_train = new ScoreKeeper[]{new ScoreKeeper(mse_train)};
_scored_valid = new ScoreKeeper[]{new ScoreKeeper(mse_valid)};
// The model class distribution starts out as the prior distribution (same array reference, not a copy).
_modelClassDist = _priorClassDist;
}
/**
 * Appends the next set of K trees (one per class) and records their DKV keys
 * as a new row of _treeKeys.
 * @param trees one tree per class; its length is asserted to equal nclasses()
 */
public void addKTrees( DTree[] trees) {
// DEBUG: Print the generated K trees
//SharedTree.printGenerateTrees(trees);
assert nclasses()==trees.length;
// Compress trees and record tree-keys: grow _treeKeys by one row, sized to the number of trees passed in
_treeKeys = Arrays.copyOf(_treeKeys ,_ntrees+1);
Key[] keys = _treeKeys[_ntrees] = new Key[trees.length];
Futures fs = new Futures();
// NOTE(review): the next line is corrupted in this copy. Text between '<' and '>' was lost, so the loop
// header, the rest of addKTrees, and the start of a scoring method appear to be merged into one line.
// The lines that follow look like that scoring method summing per-class tree predictions into preds.
// Restore this region from the upstream source.
for( int i=0; iget().score(data);
assert(!Double.isInfinite(pred));
preds[keys.length == 1 ? 0 : c + 1] += pred;
}
}
/**
 * Deletes every tree stored in the DKV, then lets the superclass remove the rest of the model.
 *
 * @param fs futures that collect the pending removals
 * @return whatever the superclass remove_impl returns
 */
@Override protected Futures remove_impl( Futures fs ) {
  for (Key[] treesOfIteration : _output._treeKeys) {
    for (Key treeKey : treesOfIteration) {
      // Slots may be empty; skip them rather than dereference null.
      if (treeKey == null) {
        continue;
      }
      treeKey.remove(fs);
    }
  }
  return super.remove_impl(fs);
}
/**
 * Reports whether the generated POJO would be too large to render in the browser.
 * Subclasses may override this hook.
 *
 * @return true when there is no output yet, or when the forest's total leaf count
 *         (number of trees times mean leaves per tree) exceeds 1,000,000
 */
@Override protected boolean toJavaCheckTooBig() {
  if (_output == null) {
    return true;
  }
  return _output._treeStats._num_trees * _output._treeStats._mean_leaves > 1000000;
}
/**
 * Overridable flag; this base implementation always returns {@code true}.
 * NOTE(review): presumably enables the single-tree-per-iteration optimization for binomial
 * models -- confirm against the callers.
 *
 * @return {@code true}
 */
protected boolean binomialOpt() {
  return true;
}
/**
 * Writes the POJO preamble: a blank line, then three accessor methods.
 * isSupervised() always returns true; nfeatures() and nclasses() return the counts
 * taken from this model's output.
 *
 * @param sb          buffer receiving the class-body code
 * @param fileContext file-level buffer (not used here)
 * @return the same {@code sb} that was passed in
 */
@Override protected SB toJavaInit(SB sb, SB fileContext) {
  final String featuresLine = "public int nfeatures() { return " + _output.nfeatures() + "; }";
  final String classesLine = "public int nclasses() { return " + _output.nclasses() + "; }";
  sb.nl();
  sb.ip("public boolean isSupervised() { return true; }").nl();
  sb.ip(featuresLine).nl();
  sb.ip(classesLine).nl();
  return sb;
}
/**
 * Emits the POJO prediction body. The generated code does three things:
 * 1. zero-fills preds;
 * 2. cleans the input row via hex.genmodel.GenModel.SharedTree_clean;
 * 3. calls one generated forest class per row of _treeKeys.
 * Each forest class is written to {@code file} with a static score0(double[] fdata, double[] preds) method.
 */
@Override protected void toJavaPredictBody(SB body, SB classCtx, SB file) {
final int nclass = _output.nclasses();
body.ip("java.util.Arrays.fill(preds,0);").nl();
body.ip("double[] fdata = hex.genmodel.GenModel.SharedTree_clean(data);").nl();
// Java identifier derived from the model key; used to name the generated forest classes
String mname = JCodeGen.toJavaId(_key.toString());
// One generated forest class per tree iteration (row of _treeKeys), containing one tree per class
for( int t=0; t < _output._treeKeys.length; t++ ) {
toJavaForestName(body.i(),mname,t).p(".score0(fdata,preds);").nl();
file.nl();
toJavaForestName(file.ip("class "),mname,t).p(" {").nl().ii(1);
file.ip("public static void score0(double[] fdata, double[] preds) {").nl().ii(1);
// NOTE(review): the next line is corrupted in this copy. Text between '<' and '>' was lost, so the
// per-class loop, the end of toJavaPredictBody, and the header of getPublishedKeys (presumably
// returning a List of Keys) are merged into one line. Restore it from the upstream source.
for( int c=0; c getPublishedKeys() {
// getPublishedKeys: returns all tree keys (row by row) followed by the superclass's published keys.
// Sanity check: exactly one row of tree keys per tree actually in the model.
assert _output._ntrees == _output._treeKeys.length :
"Tree model is inconsistent: number of trees do not match number of tree keys!";
List superP = super.getPublishedKeys();
// Presized for the tree keys only; the superclass keys are appended afterwards.
List p = new ArrayList(_output._ntrees * _output.nclasses());
// NOTE(review): remove_impl skips null tree keys, which suggests _treeKeys entries can be null;
// this loop adds them unfiltered. Confirm that callers of getPublishedKeys tolerate null entries.
for (int i = 0; i < _output._treeKeys.length; i++) {
for (int j = 0; j < _output._treeKeys[i].length; j++) {
p.add(_output._treeKeys[i][j]);
}
}
p.addAll(superP);
return p;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy