All downloads are free. Search and download functionality uses the official Maven repository.

hex.tree.SharedTreeModel Maven / Gradle / Ivy

package hex.tree;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import hex.Distribution;
import hex.Model;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsMultinomial;
import hex.ModelMetricsRegression;
import hex.ScoreKeeper;
import hex.VarImp;
import water.DKV;
import water.Futures;
import water.H2O;
import water.Key;
import water.codegen.CodeGenerator;
import water.codegen.CodeGeneratorPipeline;
import water.exceptions.H2OIllegalArgumentException;
import water.exceptions.JCodeSB;
import water.util.ArrayUtils;
import water.util.JCodeGen;
import water.util.PojoUtils;
import water.util.RandomUtils;
import water.util.SB;
import water.util.SBPrintStream;
import water.util.TwoDimTable;

public abstract class SharedTreeModel, P extends SharedTreeModel.SharedTreeParameters, O extends SharedTreeModel.SharedTreeOutput> extends Model {

  public abstract static class SharedTreeParameters extends Model.Parameters {

    public int _ntrees=50; // Number of trees in the final model. Grid Search, comma sep values:50,100,150,200

    public int _max_depth = 5; // Maximum tree depth. Grid Search, comma sep values:5,7

    public double _min_rows = 10; // Fewest allowed observations in a leaf (in R called 'nodesize'). Grid Search, comma sep values

    public int _nbins = 20; // Numerical (real/int) cols: Build a histogram of this many bins, then split at the best point

    public int _nbins_cats = 1024; // Categorical (factor) cols: Build a histogram of this many bins, then split at the best point

    public double _r2_stopping = 0.999999; // Stop when the r^2 metric equals or exceeds this value

    public long _seed = RandomUtils.getRNG(System.nanoTime()).nextLong();

    public int _nbins_top_level = 1<<10; //hardcoded maximum top-level number of bins for real-valued columns

    public boolean _build_tree_one_node = false;

    public int _initial_score_interval = 4000; //Adding this parameter to take away the hard coded value of 4000 for scoring the first  4 secs

    public int _score_interval = 4000; //Adding this parameter to take away the hard coded value of 4000 for scoring each iteration every 4 secs

    public float _sample_rate = 0.632f; //fraction of rows to sample for each tree

    /** Fields which can NOT be modified if checkpoint is specified.
     * FIXME: should be defined in Schema API annotation
     */
    private static final String[] CHECKPOINT_NON_MODIFIABLE_FIELDS = new String[] { "_build_tree_one_node", "_sample_rate", "_max_depth", "_min_rows", "_nbins", "_nbins_cats", "_nbins_top_level"};

    /**
     * Returns the names of the public parameter fields whose values must match the
     * checkpointed model's parameters. This method is protected and non-final, so a
     * subclass can return a different list.
     *
     * @return names of fields that cannot change when a checkpoint is given
     */
    protected String[] getCheckpointNonModifiableFields() {
      return CHECKPOINT_NON_MODIFIABLE_FIELDS;
    }

    /** This method will take actual parameters and validate them with parameters of
     * requested checkpoint. In case of problem, it throws an API exception.
     *
     * For each public field of this object whose name is in
     * {@link #getCheckpointNonModifiableFields()}, the same-named public field is looked
     * up on the checkpoint parameters' class. Its value is then compared with
     * {@code PojoUtils.equals}.
     *
     * @param checkpointParameters checkpoint parameters
     * @throws H2OIllegalArgumentException if the checkpoint parameters lack such a field
     *         ("is not supported by checkpoint"), or if its value differs
     *         ("cannot be modified if checkpoint is specified")
     */
    public void validateWithCheckpoint(SharedTreeParameters checkpointParameters) {
      final String[] nonModifiable = getCheckpointNonModifiableFields();
      final Class<?> checkpointClass = checkpointParameters.getClass();
      for (Field fAfter : this.getClass().getFields()) {
        final String name = fAfter.getName();
        // only look at non-modifiable fields
        if (!ArrayUtils.contains(nonModifiable, name)) continue;
        // Look the field up by name rather than scanning for a Field.equals() match.
        // Field.equals also requires the same declaring class, which silently skipped
        // validation (and made the "not supported" branch unreachable) when the
        // checkpoint parameters were of a different class.
        final Field fBefore;
        try {
          fBefore = checkpointClass.getField(name);
        } catch (NoSuchFieldException e) {
          throw new H2OIllegalArgumentException(name, "TreeBuilder", "Field " + name + " is not supported by checkpoint!");
        }
        if (!PojoUtils.equals(this, fAfter, checkpointParameters, fBefore)) {
          throw new H2OIllegalArgumentException(name, "TreeBuilder", "Field " + name + " cannot be modified if checkpoint is specified!");
        }
      }
    }
  }

  /**
   * Computes the deviance of prediction {@code f} against response {@code y} with weight
   * {@code w}. The computation is delegated to a {@code Distribution} built from this
   * model's distribution family and Tweedie power parameters.
   */
  @Override
  public double deviance(double w, double y, double f) {
    final Distribution dist = new Distribution(_parms._distribution, _parms._tweedie_power);
    return dist.deviance(w, y, f);
  }

  /** Returns the variable importances held in this model's output ({@code _output._varimp}). */
  public final VarImp varImp() {
    return _output._varimp;
  }

  /**
   * Creates the metric builder that matches this model's category.
   *
   * @param domain response domain, passed to the binomial and multinomial builders
   * @return a binomial, multinomial or regression metric builder
   * @throws RuntimeException from {@code H2O.unimpl()} for any other model category
   */
  @Override
  public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
    ModelMetrics.MetricBuilder builder;
    switch (_output.getModelCategory()) {
      case Binomial:
        builder = new ModelMetricsBinomial.MetricBuilderBinomial(domain);
        break;
      case Multinomial:
        builder = new ModelMetricsMultinomial.MetricBuilderMultinomial(_output.nclasses(), domain);
        break;
      case Regression:
        builder = new ModelMetricsRegression.MetricBuilderRegression();
        break;
      default:
        throw H2O.unimpl();
    }
    return builder;
  }

  public abstract static class SharedTreeOutput extends Model.Output {
    /** InitF value (for zero trees)
     *  f0 = mean(yi) for gaussian
     *  f0 = log(yi/1-yi) for bernoulli
     *
     *  For GBM bernoulli, the initial prediction for 0 trees is
     *  p = 1/(1+exp(-f0))
     *
     *  From this, the mse for 0 trees (null model) can be computed as follows:
     *  mean((yi-p)^2)
     * */
    public double _init_f;

    /** Number of trees actually in the model (as opposed to requested) */
    public int _ntrees;

    /** More in-depth tree statistics. */
    final public TreeStats _treeStats;

    /** Trees get big, so store each one separately in the DKV. */
    public Key[/*_ntrees*/][/*_nclass*/] _treeKeys;

    public ScoreKeeper _scored_train[/*ntrees+1*/];
    public ScoreKeeper _scored_valid[/*ntrees+1*/];
    public ScoreKeeper[] scoreKeepers() {
      List sk = new ArrayList<>();
      ScoreKeeper[] ska = _validation_metrics != null ? _scored_valid : _scored_train;
      for (int i=0;iget().score(data);
        assert (!Double.isInfinite(pred));
        preds[keys.length == 1 ? 0 : c + 1] += pred;
      }
    }
  }

  /**
   * Removes every non-null tree key stored in {@code _output._treeKeys}, then delegates
   * to the superclass for the rest of the model's cleanup.
   *
   * @param fs futures collector passed to each key removal
   * @return the result of {@code super.remove_impl(fs)}
   */
  @Override
  protected Futures remove_impl(Futures fs) {
    final Key[][] allTreeKeys = _output._treeKeys;
    for (int tree = 0; tree < allTreeKeys.length; tree++) {
      final Key[] perClass = allTreeKeys[tree];
      for (int cls = 0; cls < perClass.length; cls++) {
        final Key key = perClass[cls];
        if (key == null) {
          continue;  // slot without a stored tree
        }
        key.remove(fs);
      }
    }
    return super.remove_impl(fs);
  }

  /**
   * Reports whether the POJO for this model is too large to render in the browser.
   * The answer is {@code true} when there is no output yet, or when the number of trees
   * times the mean leaf count exceeds one million.
   */
  @Override
  protected boolean toJavaCheckTooBig() {
    if (_output == null) {
      return true;
    }
    return _output._treeStats._num_trees * _output._treeStats._mean_leaves > 1000000;
  }
  /**
   * Hook that subclasses may override; this base implementation always returns {@code true}.
   * NOTE(review): presumably toggles the single-tree-per-iteration binomial shortcut;
   * confirm against the callers.
   */
  protected boolean binomialOpt() {
    return true;
  }
  /**
   * Emits the POJO's model-shape accessors: {@code isSupervised()} (always true),
   * {@code nfeatures()} and {@code nclasses()}, with values taken from the model output.
   *
   * @param sb      stream receiving the generated code
   * @param fileCtx file-level code generation pipeline (unused here)
   * @return {@code sb}
   */
  @Override
  protected SBPrintStream toJavaInit(SBPrintStream sb, CodeGeneratorPipeline fileCtx) {
    sb.nl();
    sb.ip("public boolean isSupervised() { return true; }").nl();
    final String featuresLine = "public int nfeatures() { return " + _output.nfeatures() + "; }";
    sb.ip(featuresLine).nl();
    final String classesLine = "public int nclasses() { return " + _output.nclasses() + "; }";
    sb.ip(classesLine).nl();
    return sb;
  }
  @Override protected void toJavaPredictBody(SBPrintStream body,
                                             CodeGeneratorPipeline classCtx,
                                             CodeGeneratorPipeline fileCtx,
                                             final boolean verboseCode) {
    final int nclass = _output.nclasses();
    body.ip("java.util.Arrays.fill(preds,0);").nl();
    body.ip("double[] fdata = hex.genmodel.GenModel.SharedTree_clean(data);").nl();
    final String mname = JCodeGen.toJavaId(_key.toString());

    // One forest-per-GBM-tree, with a real-tree-per-class
    for (int t=0; t < _output._treeKeys.length; t++) {
      // Generate score method for given tree
      toJavaForestName(body.i(),mname,t).p(".score0(fdata,preds);").nl();

      final int treeIdx = t;

      fileCtx.add(new CodeGenerator() {
        @Override
        public void generate(JCodeSB out) {
          // Generate a class implementing a tree
          out.nl();
          toJavaForestName(out.ip("class "), mname, treeIdx).p(" {").nl().ii(1);
          out.ip("public static void score0(double[] fdata, double[] preds) {").nl().ii(1);
          for( int c=0; c T toJavaTreeName(final T sb, String mname, int t, int c ) {
    return (T) sb.p(mname).p("_Tree_").p(t).p("_class_").p(c);
  }
  /**
   * Appends the generated class name of forest {@code t}, i.e.
   * {@code mname + "_Forest_" + t}, to {@code sb}.
   *
   * The type parameter {@code T} was used but never declared, so it is declared here.
   *
   * @param sb    code buffer to append to
   * @param mname Java-identifier form of the model key
   * @param t     tree (forest) index
   * @return {@code sb}'s append chain, cast back to the caller's buffer type
   */
  @SuppressWarnings("unchecked") // p(...) is typed as JCodeSB; cast back to the caller's T
  protected <T extends JCodeSB> T toJavaForestName(final T sb, String mname, int t) {
    return (T) sb.p(mname).p("_Forest_").p(t);
  }

  /**
   * Returns the keys published by this model. These are all non-null per-tree, per-class
   * keys from {@code _output._treeKeys} in tree-major order, followed by the keys returned
   * by the superclass.
   *
   * Null slots are skipped. {@code remove_impl} already treats tree keys as possibly
   * null, and publishing a null key would hand {@code null} to downstream consumers.
   */
  @Override
  public List getPublishedKeys() {
    assert _output._ntrees == _output._treeKeys.length :
            "Tree model is inconsistent: number of trees do not match number of tree keys!";
    List superP = super.getPublishedKeys();
    // Capacity hint: all tree slots plus the superclass keys.
    List<Key> p = new ArrayList<>(_output._ntrees * _output.nclasses() + superP.size());
    for (Key[] perClass : _output._treeKeys) {
      for (Key k : perClass) {
        if (k != null) {
          p.add(k);
        }
      }
    }
    p.addAll(superP);
    return p;
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy