
package hex.tree;
import java.util.Arrays;
import hex.AUC;
import hex.ConfusionMatrix2;
import hex.SupervisedModel;
import hex.VarImp;
import water.*;
public abstract class SharedTreeModel<M extends SharedTreeModel<M,P,O>, P extends SharedTreeModel.SharedTreeParameters, O extends SharedTreeModel.SharedTreeOutput> extends SupervisedModel<M,P,O> {
public abstract static class SharedTreeParameters extends SupervisedModel.SupervisedParameters {
/** Maximal number of supported levels in response. */
static final int MAX_SUPPORTED_LEVELS = 1000;
public int _ntrees=50; // Number of trees in the final model. Grid Search, comma sep values:50,100,150,200
public int _max_depth = 5; // Maximum tree depth. Grid Search, comma sep values:5,7
public int _min_rows = 10; // Fewest allowed observations in a leaf (in R called 'nodesize'). Grid Search, comma sep values
public int _nbins = 20; // Build a histogram of this many bins, then split at the best point
public boolean _importance = false; // compute variable importance
public long _seed; // Seed for pseudo-random redistribution
// TRUE: Continue extending an existing checkpointed model
// FALSE: Overwrite any prior model
public boolean _checkpoint;
}
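// Illustrative sketch only (MyTreeParameters is a hypothetical concrete
// subclass; the field names are the real ones defined above):
//
//   MyTreeParameters p = new MyTreeParameters();
//   p._ntrees    = 100;   // grow 100 trees instead of the default 50
//   p._max_depth = 7;     // allow deeper trees
//   p._min_rows  = 5;     // smaller leaves ('nodesize' in R)
//   p._nbins     = 32;    // finer histograms when searching for split points
//   p._seed      = 42L;   // fix the seed for reproducible sampling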
public abstract static class SharedTreeOutput extends SupervisedModel.SupervisedOutput {
/** Initially predicted value (for zero trees) */
public double _initialPrediction;
/** Number of trees actually in the model (as opposed to requested) */
public int _ntrees;
/** More in-depth tree stats */
final TreeStats _treeStats;
/** Trees get big, so store each one separately in the DKV. */
public Key[/*_ntrees*/][/*_nclass*/] _treeKeys;
/** r2 metric on validation set: 1-(MSE(model) / MSE(mean)) */
public double _r2;
/** Confusion Matrix for classification models, or null otherwise */
public ConfusionMatrix2 _cm;
/** AUC for binomial models, or null otherwise */
public AUC _auc;
/** Variable Importance, if asked for */
public VarImp _varimp;
public SharedTreeOutput( SharedTree b ) {
super(b);
_ntrees = 0; // No trees yet
_treeKeys = new Key[_ntrees][]; // No tree keys yet
_treeStats = new TreeStats();
}
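// The key layout is [tree][class]: _treeKeys[t][c] names the compressed tree
// built for class c in round t (null when no tree was built for that class,
// e.g. the unused 2nd tree of a binomial model). Illustrative sketch of
// fetching one tree back out of the DKV (assumes the stored objects are
// CompressedTree instances):
//
//   Key k = _treeKeys[t][c];
//   CompressedTree ct = (k == null) ? null : DKV.get(k).<CompressedTree>get();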
// Append next set of K trees
public void addKTrees( DTree[] trees ) {
assert nclasses()==trees.length;
_treeStats.updateBy(trees); // Update tree shape stats
// Compress trees and record tree-keys
_treeKeys = Arrays.copyOf(_treeKeys,_ntrees+1);
Key[] keys = _treeKeys[_ntrees] = new Key[trees.length];
Futures fs = new Futures();
for( int i=0; i<trees.length; i++ )
  if( trees[i] != null ) {                            // the unused 2nd tree of a binomial model stays null
    CompressedTree ct = trees[i].compress(_ntrees,i); // compress the DTree built for class i
    DKV.put(keys[i]=ct._key, ct, fs);                 // each tree is stored separately in the DKV
  }
_ntrees++;            // one more round of K trees in the model
fs.blockForPending(); // wait for all DKV puts to land
}
}
/** Score a single row: accumulate the per-class contribution of every tree. */
@Override protected float[] score0(double data[/*ncols*/], float preds[/*nclasses+1*/]) {
Arrays.fill(preds,0);
for( int tidx=0; tidx<_output._treeKeys.length; tidx++ ) // for all trees
  score0(data, preds, tidx);
return preds;
}
// Score per line per tree
private void score0(double data[], float preds[], int treeIdx) {
Key[] keys = _output._treeKeys[treeIdx];
for( int c=0; c<keys.length; c++ )
  if( keys[c] != null )
    preds[keys.length==1?0:c+1] += DKV.get(keys[c]).<CompressedTree>get().score(data);
}
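// Sketch of the preds[] convention used for tree-model scoring (the sizes
// below are a made-up example): for regression preds has a single entry,
// preds[0]; for a K-class classifier it has K+1 entries, with the per-class
// tree scores accumulated into preds[1..K] and preds[0] conventionally
// reserved for the final predicted label.
//
//   // hypothetical 3-class model scored on one raw row of 4 columns:
//   double[] row   = new double[]{ 5.1, 3.5, 1.4, 0.2 };
//   float[]  preds = new float[3+1];   // [label, class0, class1, class2]
//   score0(row, preds);                // fills preds[1..3]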
// Numeric type used in generated code to hold predicted value between the
// calls; i.e. the numerical precision of predictions.
static final String PRED_TYPE = "float";
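// Rough shape of the generated scoring code implied by PRED_TYPE (a sketch
// based on the commented-out generator further below; the Forest_* class
// names, the preds indices and the maxIters parameter are illustrative):
//
//   java.util.Arrays.fill(preds, 0f);                      // predictions accumulate as float
//   preds[1] += Forest_0_class_0.predict(data, maxIters);  // one forest per class
//   preds[2] += Forest_0_class_1.predict(data, maxIters);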
// TODO: once SpeeDRF inherits from SharedTreeModel, remove this and use a
// v-call for the default compare function (is "<" vs "<=").
boolean isFromSpeeDRF() { return false; }
@Override protected Futures remove_impl( Futures fs ) {
for( Key ks[] : _output._treeKeys)
for( Key k : ks )
if( k != null ) k.remove(fs);
return super.remove_impl(fs);
}
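// Usage note (illustrative): since every compressed tree is its own DKV
// entry, deleting the model must also delete all per-tree keys, which is
// what remove_impl() above does. A caller would simply do:
//
//   model.remove();   // cascades into remove_impl() and frees the tree keys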
}
// // --------------------------------------------------------------------------
// public static abstract class TreeModel extends hex.Model {
// @API(help="Expected max trees") public final int N;
// @API(help="MSE rate as trees are added") public final double [] errs;
// @API(help="Keys of actual trees built") public final Key [/*N*/][/*nclass*/] treeKeys; // Always filled, but 2-binary classifiers can contain null for 2nd class
// @API(help="Maximum tree depth") public final int max_depth;
// @API(help="Fewest allowed observations in a leaf") public final int min_rows;
// @API(help="Bins in the histograms") public final int nbins;
//
// // For classification models, we'll do a Confusion Matrix right in the
// // model (for now - really should be separate).
// @API(help="Testing key for cm and errs") public final Key testKey;
// // Confusion matrix per each generated tree or null
// @API(help="Confusion Matrix computed on training dataset, cm[actual][predicted]") public final ConfusionMatrix cms[/*CM-per-tree*/];
// @API(help="Confusion matrix domain.") public final String[] cmDomain;
// @API(help="Variable importance for individual input variables.") public final VarImp varimp; // NOTE: in future we can have an array of different variable importance measures (per method)
// @API(help="Tree statistics") public final TreeStats treeStats;
// @API(help="AUC for validation dataset") public final AUCData validAUC;
// @API(help="Whether this is transformed from speedrf") public boolean isFromSpeeDRF=false;
//
// private final int num_folds;
// private transient volatile CompressedTree[/*N*/][/*nclasses OR 1 for regression*/] _treeBitsCache;
//
// public TreeModel( Key key, Key dataKey, Key testKey, String names[], String domains[][], String[] cmDomain, int ntrees, int max_depth, int min_rows, int nbins, int num_folds, float[] priorClassDist, float[] classDist) {
// this(key, dataKey, testKey, names, domains, cmDomain, ntrees, max_depth, min_rows, nbins, num_folds,
// priorClassDist, classDist,
// new Key[0][], new ConfusionMatrix[0], new double[0], null, null, null);
// }
// private TreeModel( Key key, Key dataKey, Key testKey, String names[], String domains[][], String[] cmDomain, int ntrees, int max_depth, int min_rows, int nbins, int num_folds,
// float[] priorClassDist, float[] classDist,
// Key[][] treeKeys, ConfusionMatrix[] cms, double[] errs, TreeStats treeStats, VarImp varimp, AUCData validAUC) {
// super(key,dataKey,names,domains,priorClassDist, classDist);
// this.N = ntrees;
// this.max_depth = max_depth; this.min_rows = min_rows; this.nbins = nbins;
// this.num_folds = num_folds;
// this.treeKeys = treeKeys;
// this.treeStats = treeStats;
// this.cmDomain = cmDomain!=null ? cmDomain : new String[0];
// this.testKey = testKey;
// this.cms = cms;
// this.errs = errs;
// this.varimp = varimp;
// this.validAUC = validAUC;
// }
// // Simple copy ctor, null value of parameter means copy from prior-model
// protected TreeModel(TreeModel prior, Key[][] treeKeys, double[] errs, ConfusionMatrix[] cms, TreeStats tstats, VarImp varimp, AUCData validAUC) {
// super(prior._key,prior._dataKey,prior._names,prior._domains, prior._priorClassDist,prior._modelClassDist,prior.training_start_time,prior.training_duration_in_ms);
// this.N = prior.N;
// this.testKey = prior.testKey;
// this.max_depth = prior.max_depth;
// this.min_rows = prior.min_rows;
// this.nbins = prior.nbins;
// this.cmDomain = prior.cmDomain;
// this.num_folds = prior.num_folds;
//
// if (treeKeys != null) this.treeKeys = treeKeys; else this.treeKeys = prior.treeKeys;
// if (errs != null) this.errs = errs; else this.errs = prior.errs;
// if (cms != null) this.cms = cms; else this.cms = prior.cms;
// if (tstats != null) this.treeStats = tstats; else this.treeStats = prior.treeStats;
// if (varimp != null) this.varimp = varimp; else this.varimp = prior.varimp;
// if (validAUC != null) this.validAUC = validAUC; else this.validAUC = prior.validAUC;
// }
// // Additional copy ctors to update specific fields
// public TreeModel(TreeModel prior, DTree[] tree, double err, ConfusionMatrix cm, TreeStats tstats) {
// this(prior, append(prior.treeKeys, tree), Utils.append(prior.errs, err), Utils.append(prior.cms, cm), tstats, null, null);
// }
// public TreeModel(TreeModel prior, DTree[] tree, TreeStats tstats) {
// this(prior, append(prior.treeKeys, tree), null, null, tstats, null, null);
// }
// public TreeModel(TreeModel prior, double err, ConfusionMatrix cm, VarImp varimp, AUCData validAUC) {
// this(prior, null, Utils.append(prior.errs, err), Utils.append(prior.cms, cm), null, varimp, validAUC);
// }
//
// public enum TreeModelType {
// UNKNOWN,
// GBM,
// DRF,
// }
//
// protected TreeModelType getTreeModelType() { return TreeModelType.UNKNOWN; }
//
// /** Returns the producer's key if the model is under construction, else null.
//  * The implementation looks for a writer lock; if one is present, its holder is returned.
//  *
//  * WARNING: this method is strictly for UI use and does not provide any atomicity!
//  */
// private final Key getProducer() {
// return FetchProducer.fetch(_key);
// }
// private final boolean isProduced() {
// return getProducer()!=null;
// }
//
// private static final class FetchProducer extends DTask<FetchProducer> {
// final private Key _key;
// private Key _producer;
// public static Key fetch(Key key) {
// FetchProducer fp = new FetchProducer(key);
// if (key.home()) fp.compute2();
// else fp = RPC.call(key.home_node(), fp).get();
// return fp._producer;
// }
// private FetchProducer(Key k) { _key = k; }
// @Override public void compute2() {
// Lockable l = UKV.get(_key);
// _producer = l!=null && l._lockers!=null && l._lockers.length > 0 ? l._lockers[0] : null;
// tryComplete();
// }
// @Override public byte priority() { return H2O.ATOMIC_PRIORITY; }
// }
//
// private static final Key[][] append(Key[][] prior, DTree[] tree ) {
// if (tree==null) return prior;
// prior = Arrays.copyOf(prior, prior.length+1);
// Key ts[] = prior[prior.length-1] = new Key[tree.length];
// for( int c=0; c<tree.length; c++ ) { /* record each tree's key into ts[c] */ }
// return prior;
// }
//
// @Override public ConfusionMatrix cm() {
// if(cms != null && cms.length > 0){
// int n = cms.length-1;
// while(n > 0 && cms[n] == null)--n;
// return cms[n] == null?null:cms[n];
// } else return null;
// }
//
// @Override public VarImp varimp() { return varimp; }
// @Override public double mse() {
// if(errs != null && errs.length > 0){
// int n = errs.length-1;
// while(n > 0 && Double.isNaN(errs[n]))--n;
// return errs[n];
// } else return Double.NaN;
// }
// @Override protected float[] score0(double data[], float preds[]) {
// // Prefetch trees into the local cache if it is necessary
// // Invoke scoring
// Arrays.fill(preds,0);
// for( int tidx=0; tidx<treeKeys.length; tidx++ )
// score0(data, preds, tidx);
// return preds;
// }
//
// sb.append("").append("Actions: ");
// if (_dataKey != null)
// sb.append(Inspect2.link("Inspect training data ("+_dataKey.toString()+")", _dataKey)).append(", ");
// sb.append(Predict.link(_key,"Score on dataset")).append(", ");
// if (_dataKey != null)
// sb.append(UIUtils.builderModelLink(this.getClass(), _dataKey, responseName(), "Compute new model")).append(", ");
// sb.append(UIUtils.qlink(SaveModel.class, "model", _key, "Save model")).append(", ");
// if (isProduced()) { // looks at locker field and check W-locker guy
// sb.append(" ").append(Cancel.link(getProducer(), "Stop training this model"));
// } else {
// sb.append(" ").append(UIUtils.builderLink(this.getClass(), _dataKey, responseName(), this._key, "Continue training this model"));
// }
// sb.append("
Reported on ").append(num_folds).append("-fold cross-validated training data
");
// else {
// sb.append("Reported on ").append(title.contains("DRF") ? "out-of-bag" : "training").append(" data");
// if (num_folds > 0) sb.append(" (cross-validation results are being computed - please reload this page later)");
// sb.append(".");
// if (_priorClassDist!=null && _modelClassDist!=null) sb.append(" Data were resampled to balance class distribution.");
// }
// } else {
// RString rs = new RString("Reported on %key");
// rs.replace("key", testKey);
// DocGen.HTML.paragraph(sb,rs.toString());
// }
// if (validAUC == null) { //AUC shows the CM already
// // generate HTML for CM
// DocGen.HTML.section(sb, "Confusion Matrix");
// cm.toHTML(sb, domain);
// }
// }
//
// if( errs != null ) {
// if (!isClassifier() && num_folds > 0) {
// if (_have_cv_results)
// DocGen.HTML.section(sb, num_folds + "-fold cross-validated Mean Squared Error: " + String.format("%5.3f", errs[errs.length-1]));
// else
// DocGen.HTML.section(sb, num_folds + "-fold cross-validated Mean Squared Error is being computed - please reload this page later.");
// }
// DocGen.HTML.section(sb,"Mean Squared Error by Tree");
// DocGen.HTML.arrayHead(sb);
// sb.append("");
//
// boolean featureAllowed = isFeatureAllowed();
// if (! featureAllowed) {
// sb.append("
");
// sb.append("");
// }
//
// @Override protected SB toJavaInit(SB sb, SB fileContextSB) {
// sb = super.toJavaInit(sb, fileContextSB);
//
// String modelName = JCodeGen.toJavaId(_key.toString());
//
// // Generate main method with benchmark
// if (GEN_BENCHMARK_CODE) {
// sb.i().p("/**").nl();
// sb.i().p(" * Sample program harness providing an example of how to call predict().").nl();
// sb.i().p(" */").nl();
// sb.i().p("public static void main(String[] args) throws Exception {").nl();
// sb.i(1).p("int iters = args.length > 0 ? Integer.valueOf(args[0]) : DEFAULT_ITERATIONS;").nl();
// sb.i(1).p(modelName).p(" model = new ").p(modelName).p("();").nl();
// sb.i(1).p("model.bench(iters, DataSample.DATA, new float[NCLASSES+1], NTREES);").nl();
// sb.i().p("}").nl();
// sb.di(1);
// sb.p(TO_JAVA_BENCH_FUNC);
// }
//
// JCodeGen.toStaticVar(sb, "NTREES", ntrees(), "Number of trees in this model.");
// JCodeGen.toStaticVar(sb, "NTREES_INTERNAL", ntrees()*nclasses(), "Number of internal trees in this model (= NTREES*NCLASSES).");
// if (GEN_BENCHMARK_CODE) JCodeGen.toStaticVar(sb, "DEFAULT_ITERATIONS", 10000, "Default number of iterations.");
// // Generate a data in separated class since we do not want to influence size of constant pool of model class
// if (GEN_BENCHMARK_CODE) {
// if( _dataKey != null ) {
// Value dataval = DKV.get(_dataKey);
// if (dataval != null) {
// water.fvec.Frame frdata = dataval.get();
// water.fvec.Frame frsub = frdata.subframe(_names);
// JCodeGen.toClass(fileContextSB, "// Sample of data used by benchmark\nclass DataSample", "DATA", frsub, 10, "Sample test data.");
// }
// }
// }
// return sb;
// }
// // Convert Tree model to Java
// @Override protected void toJavaPredictBody( final SB bodySb, final SB classCtxSb, final SB fileCtxSb) {
// // AD-HOC maximal number of trees in forest - in fact constant pool size for Forest class (all UTF String + references to static classes).
// // TODO: in future this parameter can be a parameter for generator, as well as maxIters
// final int maxfsize = 4000;
// int fidx = 0; // forest index
// int treesInForest = 0;
// SB forest = new SB();
// // divide trees into small forests per 100 trees
// /* DEBUG line */ bodySb.i().p("// System.err.println(\"Row (gencode.predict): \" + java.util.Arrays.toString(data));").nl();
// bodySb.i().p("java.util.Arrays.fill(preds,0f);").nl();
// if (isFromSpeeDRF) {
// bodySb.i().p("// Call forest predicting class ").p(0).nl();
// bodySb.i().p("preds").p(" =").p(" Forest_").p(fidx).p("_class_").p(0).p(".predict(data, maxIters - " + fidx * maxfsize + ");").nl();
// }
// for( int c=0; c<nclasses(); c++ ) { /* emit the per-class forest predict calls */ }
// sb.append("You have requested a premium feature (> 10 trees) and your H2O software is unlicensed.
"); // sb.append("Please enter your email address below, and we will send you a trial license shortly.
"); // sb.append("This will also temporarily enable downloading Java models.
"); // sb.append("
");
// sb.append(""); // sb.append("Please enter your email address below, and we will send you a trial license shortly.
"); // sb.append("This will also temporarily enable downloading Java models.
"); // sb.append("
");
// }
// if( ntrees() * treeStats.meanLeaves > 5000 ) {
// String modelName = JCodeGen.toJavaId(_key.toString());
// sb.append("
"); // close license blog
// sb.append("");
// sb.append("/* Java code is too large to display, download it directly.\n");
// sb.append(" To obtain the code please invoke in your terminal:\n");
// sb.append(" curl http:/").append(H2O.SELF.toString()).append("/h2o-model.jar > h2o-model.jar\n");
// sb.append(" curl http:/").append(H2O.SELF.toString()).append("/2/").append(this.getClass().getSimpleName()).append("View.java?_modelKey=").append(_key).append(" > ").append(modelName).append(".java\n");
// sb.append(" javac -cp h2o-model.jar -J-Xmx2g -J-XX:MaxPermSize=128m ").append(modelName).append(".java\n");
// if (GEN_BENCHMARK_CODE)
// sb.append(" java -cp h2o-model.jar:. -Xmx2g -XX:MaxPermSize=256m -XX:ReservedCodeCacheSize=256m ").append(modelName).append('\n');
// sb.append("*/");
// sb.append("
");
// } else {
// sb.append("");
// DocGen.HTML.escape(sb, toJava());
// sb.append("
");
// }
// if (!featureAllowed) sb.append("