All downloads are free. Search and download functionality uses the official Maven repository.

hex.tree.gbm.GBMModel Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.tree.gbm;

import hex.Distribution;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.Score;
import hex.tree.SharedTreeModel;
import water.Key;
import water.fvec.Chunk;
import water.util.SBPrintStream;

import java.util.Arrays;

/**
 * Gradient Boosting Machine model: its parameters, its training output, and the logic that
 * turns summed tree predictions into final predictions (in-JVM scoring and generated POJO code).
 */
public class GBMModel extends SharedTreeModel {

  /** GBM-specific hyper-parameters; the defaults are assigned in the constructor. */
  public static class GBMParameters extends SharedTreeModel.SharedTreeParameters {
    /** Learning rate; defaults to 0.1. */
    public double _learn_rate;
    /** Learning-rate annealing factor; defaults to 1.0. */
    public double _learn_rate_annealing;
    /** Column sample rate; defaults to 1.0. */
    public double _col_sample_rate;
    /** Maximum absolute leaf-node prediction; defaults to Double.MAX_VALUE (presumably "no cap" — confirm). */
    public double _max_abs_leafnode_pred;
    /** Prediction-noise bandwidth; defaults to 0. */
    public double _pred_noise_bandwidth;

    /** Creates parameters with GBM defaults, overriding several inherited tree defaults. */
    public GBMParameters() {
      super();
      _learn_rate = 0.1;
      _learn_rate_annealing = 1.0;
      _col_sample_rate = 1.0;
      _sample_rate = 1.0;
      _ntrees = 50;
      _max_depth = 5;
      _max_abs_leafnode_pred = Double.MAX_VALUE;
      _pred_noise_bandwidth = 0;
    }

    public String algoName() { return "GBM"; }
    public String fullName() { return "Gradient Boosting Machine"; }
    public String javaName() { return GBMModel.class.getName(); }
  }

  /** Training output for a GBM model. It records the class count and whether the quasibinomial distribution was used. */
  public static class GBMOutput extends SharedTreeModel.SharedTreeOutput {
    /** True when the builder's distribution was {@code quasibinomial}. */
    boolean _quasibinomial;
    /** Number of classes reported by the builder at construction time. */
    int _nclasses;

    /** @return the number of classes captured from the builder */
    public int nclasses() {
      return _nclasses;
    }

    /**
     * Captures output state from the builder {@code b}.
     *
     * @param b the GBM builder whose distribution and class count are recorded
     */
    public GBMOutput(GBM b) {
      super(b);
      _quasibinomial = b._parms._distribution == DistributionFamily.quasibinomial;
      _nclasses = b.nclasses();
    }

    /**
     * Returns the inherited class names. If those are null and the model was trained with the
     * quasibinomial distribution, returns the synthetic labels {@code "0"} and {@code "1"} instead.
     */
    @Override
    public String[] classNames() {
      String[] res = super.classNames();
      if (res == null && _quasibinomial) {
        return new String[]{"0", "1"};
      }
      return res;
    }
  }


  public GBMModel(Key selfKey, GBMParameters parms, GBMOutput output) {
    super(selfKey,parms,output);
  }

  /**
   * Scores one row incrementally, resuming from per-row state kept in workspace columns.
   *
   * <p>The method copies the row's feature values into {@code tmp}. If scoring starts at tree 0, it
   * zeroes {@code preds}; otherwise it reloads the previously stored raw sums from the workspace
   * chunks. It then calls {@code score0} for trees starting at {@code sii._startTree} (presumably the
   * range [_startTree, _treeKeys.length) — confirm in SharedTreeModel). The updated raw sums are
   * written back to the workspace chunks before they are converted into final predictions.
   *
   * @return {@code preds}, filled in place
   */
  @Override
  protected final double[] score0Incremental(Score.ScoreIncInfo sii, Chunk[] chks, double offset, int row_in_chunk, double[] tmp, double[] preds) {
    assert _output.nfeatures() == tmp.length;
    for (int i = 0; i < tmp.length; i++) {
      tmp[i] = chks[i].atd(row_in_chunk);
    }

    if (sii._startTree == 0) {
      // Nothing accumulated yet: start from zero.
      Arrays.fill(preds, 0);
    } else {
      // Resume from the raw sums stored by the previous incremental pass.
      for (int i = 0; i < sii._workspaceColCnt; i++) {
        preds[sii._predsAryOffset + i] = chks[sii._workspaceColIdx + i].atd(row_in_chunk);
      }
    }

    score0(tmp, preds, offset, sii._startTree, _output._treeKeys.length);

    // Store the raw sums BEFORE the link/rescale step so that the next pass can keep adding trees.
    for (int i = 0; i < sii._workspaceColCnt; i++) {
      chks[sii._workspaceColIdx + i].set(row_in_chunk, preds[sii._predsAryOffset + i]);
    }

    score0Probabilities(preds, offset);
    score0PostProcessSupervised(preds, tmp);
    return preds;
  }

  /**
   * Scores one row. The superclass sums the tree predictions into {@code preds}; this method then
   * converts those raw sums into the final distribution-specific output.
   *
   * @param data   feature values of the row (one entry per column)
   * @param preds  output buffer of nclasses + 1 entries, filled in place
   * @param offset per-row offset, also added to the raw sums during the final conversion
   * @param ntrees number of trees, passed through to the superclass
   * @return {@code preds}
   */
  @Override protected double[] score0(double[] data /*ncols*/, double[] preds /*nclasses+1*/, double offset, int ntrees) {
    super.score0(data, preds, offset, ntrees);    // These are f_k(x) in Algorithm 10.4
    return score0Probabilities(preds, offset);
  }

  /**
   * Returns true for the distribution families that this class scores from a single tree sum
   * stored in preds[1] and maps to a class-1 probability: bernoulli, quasibinomial and
   * modified_huber.
   */
  private static boolean isSingleTreeBinaryDistribution(DistributionFamily family) {
    return family == DistributionFamily.bernoulli
        || family == DistributionFamily.quasibinomial
        || family == DistributionFamily.modified_huber;
  }

  /**
   * Converts the raw tree sums in {@code preds} into final output, in place.
   *
   * <ul>
   *   <li>bernoulli / quasibinomial / modified_huber: preds[1] + init_f + offset is passed through
   *       the link inverse into preds[2], and preds[1] becomes 1 - preds[2].</li>
   *   <li>multinomial: a 2-class model first expands its single sum into preds[1] = f and
   *       preds[2] = -f; then GenModel.GBM_rescale is applied to preds.</li>
   *   <li>otherwise (regression): preds[0] + init_f + offset is passed through the link inverse.</li>
   * </ul>
   */
  private double[] score0Probabilities(double[] preds /*nclasses+1*/, double offset) {
    if (isSingleTreeBinaryDistribution(_parms._distribution)) {
      // Class 1 score is stored in preds[1] (since we have only one tree).
      double f = preds[1] + _output._init_f + offset;
      preds[2] = new Distribution(_parms).linkInv(f);
      preds[1] = 1.0 - preds[2];
    } else if (_parms._distribution == DistributionFamily.multinomial) {
      if (_output.nclasses() == 2) {
        // 1-tree optimization for binomial. The initial prediction is kept here.
        // The offset is not yet allowed, but is added here to be future-proof.
        preds[1] += _output._init_f + offset;
        preds[2] = -preds[1];
      }
      hex.genmodel.GenModel.GBM_rescale(preds);
    } else {
      // Regression
      double f = preds[0] + _output._init_f + offset;
      preds[0] = new Distribution(_parms).linkInv(f);
    }
    return preds;
  }

  /**
   * Emits POJO code that adjusts the tree sums in {@code preds} according to the loss function.
   * It follows the same steps as score0Probabilities without the offset term, because POJO
   * scoring code doesn't support per-row offsets (the scoring API would need to be changed to
   * pass in offsets).
   *
   * <p>NOTE(review): this method selects the regression branch with {@code nclasses() == 1}, whereas
   * score0Probabilities uses {@code _distribution != multinomial}. The two agree only if those
   * conditions always coincide for non-binary families — confirm.
   */
  @Override protected void toJavaUnifyPreds(SBPrintStream body) {
    if (isSingleTreeBinaryDistribution(_parms._distribution)) {
      body.ip("preds[2] = preds[1] + ").p(_output._init_f).p(";").nl();
      body.ip("preds[2] = " + new Distribution(_parms).linkInvString("preds[2]") + ";").nl();
      body.ip("preds[1] = 1.0-preds[2];").nl();
      emitGbmClassPrediction(body);
      return;
    }
    if (_output.nclasses() == 1) { // Regression
      body.ip("preds[0] += ").p(_output._init_f).p(";").nl();
      body.ip("preds[0] = " + new Distribution(_parms).linkInvString("preds[0]") + ";").nl();
      return;
    }
    if (_output.nclasses() == 2) { // Kept the initial prediction for binomial
      body.ip("preds[1] += ").p(_output._init_f).p(";").nl();
      body.ip("preds[2] = - preds[1];").nl();
    }
    body.ip("hex.genmodel.GenModel.GBM_rescale(preds);").nl();
    emitGbmClassPrediction(body);
  }

  /**
   * Emits the POJO lines shared by every classification branch. When class balancing was
   * requested, the probabilities are corrected first; then the predicted label is written to preds[0].
   */
  private void emitGbmClassPrediction(SBPrintStream body) {
    if (_parms._balance_classes) {
      body.ip("hex.genmodel.GenModel.correctProbabilities(preds, PRIOR_CLASS_DISTRIB, MODEL_CLASS_DISTRIB);").nl();
    }
    body.ip("preds[0] = hex.genmodel.GenModel.getPrediction(preds, PRIOR_CLASS_DISTRIB, data, " + defaultThreshold() + ");").nl();
  }


  /** @return a MOJO writer for this model */
  @Override
  public GbmMojoWriter getMojo() {
    return new GbmMojoWriter(this);
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy