
hex.tree.SharedTreeModel Maven / Gradle / Ivy
package hex.tree;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import hex.Distribution;
import hex.Model;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsMultinomial;
import hex.ModelMetricsRegression;
import hex.ScoreKeeper;
import hex.VarImp;
import water.DKV;
import water.Futures;
import water.H2O;
import water.Key;
import water.codegen.CodeGenerator;
import water.codegen.CodeGeneratorPipeline;
import water.exceptions.H2OIllegalArgumentException;
import water.exceptions.JCodeSB;
import water.util.ArrayUtils;
import water.util.JCodeGen;
import water.util.PojoUtils;
import water.util.RandomUtils;
import water.util.SB;
import water.util.SBPrintStream;
import water.util.TwoDimTable;
public abstract class SharedTreeModel, P extends SharedTreeModel.SharedTreeParameters, O extends SharedTreeModel.SharedTreeOutput> extends Model {
/**
 * Parameters common to the tree-based model builders.
 *
 * <p>Besides holding the tunable fields, this class knows which of them must stay
 * unchanged when training is resumed from a checkpoint (see
 * {@link #validateWithCheckpoint(SharedTreeParameters)}).
 */
public abstract static class SharedTreeParameters extends Model.Parameters {
  /** Number of trees in the final model. Grid Search, comma sep values:50,100,150,200 */
  public int _ntrees=50;
  /** Maximum tree depth. Grid Search, comma sep values:5,7 */
  public int _max_depth = 5;
  /** Fewest allowed observations in a leaf (in R called 'nodesize'). Grid Search, comma sep values */
  public double _min_rows = 10;
  /** Numerical (real/int) cols: Build a histogram of this many bins, then split at the best point */
  public int _nbins = 20;
  /** Categorical (factor) cols: Build a histogram of this many bins, then split at the best point */
  public int _nbins_cats = 1024;
  /** Stop when the r^2 metric equals or exceeds this value */
  public double _r2_stopping = 0.999999;
  /** Random seed; by default drawn from an RNG that is itself seeded with {@code System.nanoTime()}. */
  public long _seed = RandomUtils.getRNG(System.nanoTime()).nextLong();
  /** Hardcoded maximum top-level number of bins for real-valued columns */
  public int _nbins_top_level = 1<<10;
  public boolean _build_tree_one_node = false;
  /** Replaces the formerly hard-coded value of 4000 used for scoring during the first 4 secs */
  public int _initial_score_interval = 4000;
  /** Replaces the formerly hard-coded value of 4000 used for scoring each iteration every 4 secs */
  public int _score_interval = 4000;
  /** Fraction of rows to sample for each tree */
  public float _sample_rate = 0.632f;

  /** Fields which can NOT be modified if checkpoint is specified.
   * FIXME: should be defined in Schema API annotation
   */
  private static final String[] CHECKPOINT_NON_MODIFIABLE_FIELDS = new String[] { "_build_tree_one_node", "_sample_rate", "_max_depth", "_min_rows", "_nbins", "_nbins_cats", "_nbins_top_level"};

  /**
   * Names of the public fields that must keep their checkpoint value.
   *
   * @return the names checked by {@link #validateWithCheckpoint(SharedTreeParameters)}
   */
  protected String[] getCheckpointNonModifiableFields() {
    return CHECKPOINT_NON_MODIFIABLE_FIELDS;
  }

  /** This method will take actual parameters and validate them with parameters of
   * requested checkpoint. In case of problem, it throws an API exception.
   *
   * <p>Only public fields listed by {@link #getCheckpointNonModifiableFields()} are compared,
   * and only when the very same {@link Field} is also exposed by the checkpoint's class.
   *
   * @param checkpointParameters checkpoint parameters
   * @throws H2OIllegalArgumentException if a non-modifiable field differs from the checkpoint
   */
  public void validateWithCheckpoint(SharedTreeParameters checkpointParameters) {
    // The list of frozen names does not depend on the field being inspected, so fetch it once.
    final String[] frozenNames = getCheckpointNonModifiableFields();
    for (Field fAfter : this.getClass().getFields()) {
      // only look at non-modifiable fields
      if (!ArrayUtils.contains(frozenNames, fAfter.getName())) {
        continue;
      }
      final Class<?> checkpointClass = checkpointParameters.getClass();
      for (Field fBefore : checkpointClass.getFields()) {
        if (!fBefore.equals(fAfter)) {
          continue;
        }
        try {
          if (!PojoUtils.equals(this, fAfter, checkpointParameters, checkpointClass.getField(fAfter.getName()))) {
            throw new H2OIllegalArgumentException(fAfter.getName(), "TreeBuilder", "Field " + fAfter.getName() + " cannot be modified if checkpoint is specified!");
          }
        } catch (NoSuchFieldException e) {
          // Not expected in practice: fBefore is a public field of checkpointClass with this name,
          // so the lookup by name should succeed. Kept so the checked exception is still translated.
          throw new H2OIllegalArgumentException(fAfter.getName(), "TreeBuilder", "Field " + fAfter.getName() + " is not supported by checkpoint!");
        }
      }
    }
  }
}
/**
 * Evaluates the deviance using a {@link Distribution} built from this model's
 * {@code _distribution} and {@code _tweedie_power} parameters.
 *
 * <p>The three arguments are forwarded unchanged to {@code Distribution.deviance}.
 * NOTE(review): presumably w = observation weight, y = actual response, f = prediction — confirm.
 */
@Override
public double deviance(double w, double y, double f) {
  final Distribution dist = new Distribution(_parms._distribution, _parms._tweedie_power);
  return dist.deviance(w, y, f);
}
/**
 * Accessor for the variable importances stored on the model output.
 *
 * @return the output's {@code _varimp} field, as is
 */
public final VarImp varImp() {
  return _output._varimp;
}
/**
 * Creates the metric builder matching the model category reported by the output.
 *
 * <p>Binomial, Multinomial and Regression are handled; any other category results in
 * {@code H2O.unimpl()} being thrown.
 *
 * @param domain response domain handed to the binomial and multinomial builders
 * @return a new metric builder for the model's category
 */
@Override
public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
  final ModelMetrics.MetricBuilder builder;
  switch (_output.getModelCategory()) {
    case Binomial:
      builder = new ModelMetricsBinomial.MetricBuilderBinomial(domain);
      break;
    case Multinomial:
      builder = new ModelMetricsMultinomial.MetricBuilderMultinomial(_output.nclasses(), domain);
      break;
    case Regression:
      builder = new ModelMetricsRegression.MetricBuilderRegression();
      break;
    default:
      throw H2O.unimpl();
  }
  return builder;
}
public abstract static class SharedTreeOutput extends Model.Output {
/** InitF value (for zero trees)
* f0 = mean(yi) for gaussian
* f0 = log(yi/1-yi) for bernoulli
*
* For GBM bernoulli, the initial prediction for 0 trees is
* p = 1/(1+exp(-f0))
*
* From this, the mse for 0 trees (null model) can be computed as follows:
* mean((yi-p)^2)
* */
public double _init_f;
/** Number of trees actually in the model (as opposed to requested) */
public int _ntrees;
/** More in-depth tree stats */
final public TreeStats _treeStats;
/** Trees get big, so store each one separately in the DKV. */
public Key[/*_ntrees*/][/*_nclass*/] _treeKeys;
public ScoreKeeper _scored_train[/*ntrees+1*/];
public ScoreKeeper _scored_valid[/*ntrees+1*/];
public ScoreKeeper[] scoreKeepers() {
List sk = new ArrayList<>();
ScoreKeeper[] ska = _validation_metrics != null ? _scored_valid : _scored_train;
for (int i=0;iget().score(data);
assert (!Double.isInfinite(pred));
preds[keys.length == 1 ? 0 : c + 1] += pred;
}
}
}
/**
 * Removes every non-null tree key referenced by the output, then lets the superclass
 * finish the removal.
 *
 * @param fs futures object passed to each key removal and to the superclass
 * @return whatever the superclass implementation returns
 */
@Override
protected Futures remove_impl(Futures fs) {
  for (Key[] keysOfOneTree : _output._treeKeys) {
    for (Key treeKey : keysOfOneTree) {
      if (treeKey == null) {
        continue; // empty slot, nothing to remove
      }
      treeKey.remove(fs);
    }
  }
  return super.remove_impl(fs);
}
// Override in subclasses to provide some top-level model-specific goodness
/**
 * Decides whether the model is too large to be rendered as POJO code.
 *
 * @return {@code true} when there is no output, or when the number of trees times the
 *         mean number of leaves exceeds 1000000
 */
@Override
protected boolean toJavaCheckTooBig() {
  if (_output == null) {
    return true;
  }
  // If the number of leaves in a forest is more than N, don't try to render it in the browser as POJO code.
  final TreeStats stats = _output._treeStats;
  return stats._num_trees * stats._mean_leaves > 1000000;
}
/**
 * Switch that subclasses may override; this base implementation always answers {@code true}.
 * NOTE(review): presumably enables a binomial-specific optimization — confirm against callers.
 *
 * @return {@code true}
 */
protected boolean binomialOpt() {
  return true;
}
/**
 * Writes the fixed accessor methods of the generated model class: {@code isSupervised()},
 * {@code nfeatures()} and {@code nclasses()}, the latter two with values read from the output.
 *
 * @param sb      stream receiving the generated source
 * @param fileCtx file-level code generator pipeline (not used here)
 * @return the same {@code sb} that was passed in
 */
@Override
protected SBPrintStream toJavaInit(SBPrintStream sb, CodeGeneratorPipeline fileCtx) {
  final String supervisedLine = "public boolean isSupervised() { return true; }";
  final String nfeaturesLine = "public int nfeatures() { return " + _output.nfeatures() + "; }";
  final String nclassesLine = "public int nclasses() { return " + _output.nclasses() + "; }";

  sb.nl();
  sb.ip(supervisedLine).nl();
  sb.ip(nfeaturesLine).nl();
  sb.ip(nclassesLine).nl();
  return sb;
}
@Override protected void toJavaPredictBody(SBPrintStream body,
CodeGeneratorPipeline classCtx,
CodeGeneratorPipeline fileCtx,
final boolean verboseCode) {
final int nclass = _output.nclasses();
body.ip("java.util.Arrays.fill(preds,0);").nl();
body.ip("double[] fdata = hex.genmodel.GenModel.SharedTree_clean(data);").nl();
final String mname = JCodeGen.toJavaId(_key.toString());
// One forest-per-GBM-tree, with a real-tree-per-class
for (int t=0; t < _output._treeKeys.length; t++) {
// Generate score method for given tree
toJavaForestName(body.i(),mname,t).p(".score0(fdata,preds);").nl();
final int treeIdx = t;
fileCtx.add(new CodeGenerator() {
@Override
public void generate(JCodeSB out) {
// Generate a class implementing a tree
out.nl();
toJavaForestName(out.ip("class "), mname, treeIdx).p(" {").nl().ii(1);
out.ip("public static void score0(double[] fdata, double[] preds) {").nl().ii(1);
for( int c=0; c T toJavaTreeName(final T sb, String mname, int t, int c ) {
return (T) sb.p(mname).p("_Tree_").p(t).p("_class_").p(c);
}
/**
 * Appends the name of the generated forest class for tree index {@code t} to {@code sb}:
 * {@code mname}, followed by the literal {@code _Forest_}, followed by {@code t}.
 *
 * NOTE(review): {@code T} is not declared anywhere in this signature as it appears here;
 * a method-level type parameter seems to be missing — confirm against the upstream source.
 *
 * @param sb    builder the name is appended to
 * @param mname Java identifier of the model
 * @param t     index of the tree/forest
 * @return the result of the chained {@code p(...)} calls, cast to {@code T}
 */
protected T toJavaForestName(final T sb, String mname, int t ) {
return (T) sb.p(mname).p("_Forest_").p(t);
}
/**
 * Lists the keys published by this model: all tree keys first, followed by the keys
 * reported by the superclass.
 *
 * <p>Slots of {@code _treeKeys} that hold {@code null} are skipped, so the returned list
 * never contains {@code null} on account of a missing tree.
 *
 * @return a new list with the non-null tree keys followed by the superclass' published keys
 */
@Override
public List getPublishedKeys() {
  assert _output._ntrees == _output._treeKeys.length :
      "Tree model is inconsistent: number of trees do not match number of tree keys!";
  List superP = super.getPublishedKeys();
  // Capacity is only an upper bound: null slots are not added.
  List p = new ArrayList(_output._ntrees * _output.nclasses());
  for (Key[] keysOfOneTree : _output._treeKeys) {
    for (Key treeKey : keysOfOneTree) {
      if (treeKey != null) {
        p.add(treeKey);
      }
    }
  }
  p.addAll(superP);
  return p;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy