// hex.tree.gbm.GBMModel (source retrieved via Maven / Gradle / Ivy artifact browser)
package hex.tree.gbm;
import hex.Distribution;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.Score;
import hex.tree.SharedTreeModel;
import water.Key;
import water.fvec.Chunk;
import water.util.SBPrintStream;
import java.util.Arrays;
public class GBMModel extends SharedTreeModel {
public static class GBMParameters extends SharedTreeModel.SharedTreeParameters {
  /** Shrinkage applied to each tree's contribution. */
  public double _learn_rate;
  /** Multiplicative decay of the learning rate after each tree. */
  public double _learn_rate_annealing;
  /** Fraction of columns sampled per split. */
  public double _col_sample_rate;
  /** Cap on the absolute value of any leaf prediction. */
  public double _max_abs_leafnode_pred;
  /** Bandwidth of noise optionally added to predictions during training. */
  public double _pred_noise_bandwidth;

  /** Installs GBM-specific defaults on top of the shared-tree defaults. */
  public GBMParameters() {
    super();
    // GBM-specific defaults.
    _learn_rate = 0.1;
    _learn_rate_annealing = 1.0;
    _col_sample_rate = 1.0;
    _max_abs_leafnode_pred = Double.MAX_VALUE;
    _pred_noise_bandwidth = 0;
    // Overrides of inherited shared-tree defaults.
    _sample_rate = 1.0;
    _ntrees = 50;
    _max_depth = 5;
  }

  public String algoName() { return "GBM"; }
  public String fullName() { return "Gradient Boosting Machine"; }
  public String javaName() { return GBMModel.class.getName(); }
}
public static class GBMOutput extends SharedTreeModel.SharedTreeOutput {
  // True when trained with the quasibinomial distribution (synthetic labels).
  boolean _quasibinomial;
  // Number of response classes captured from the builder at training time.
  int _nclasses;

  public GBMOutput(GBM b) {
    super(b);
    _nclasses = b.nclasses();
    _quasibinomial = DistributionFamily.quasibinomial == b._parms._distribution;
  }

  public int nclasses() {
    return _nclasses;
  }

  /**
   * Class labels; quasibinomial models carry no real labels, so a synthetic
   * {"0", "1"} pair is returned in that case.
   */
  @Override
  public String[] classNames() {
    String[] labels = super.classNames();
    if (labels != null || !_quasibinomial)
      return labels;
    return new String[]{"0", "1"};
  }
}
/**
 * Constructs a GBM model stored under {@code selfKey} with the given
 * training parameters and training output.
 */
public GBMModel(Key selfKey, GBMParameters parms, GBMOutput output) {
  super(selfKey, parms, output);
}
/**
 * Incrementally scores one row: resumes from partial predictions persisted in
 * workspace Chunk columns, applies the trees from {@code sii._startTree}
 * onward, writes the updated partials back to the workspace columns, then
 * finishes with the distribution link inverse and supervised post-processing.
 * Returns {@code preds} filled with the final prediction for this row.
 */
@Override
protected final double[] score0Incremental(Score.ScoreIncInfo sii, Chunk[] chks, double offset, int row_in_chunk, double[] tmp, double[] preds) {
assert _output.nfeatures() == tmp.length;
// Load the row's feature values into the scratch array.
for (int i = 0; i < tmp.length; i++)
tmp[i] = chks[i].atd(row_in_chunk);
if (sii._startTree == 0)
// First batch of trees: start from zeroed predictions.
Arrays.fill(preds,0);
else
// Resuming: reload the partial per-class sums saved by the previous batch.
for (int i = 0; i < sii._workspaceColCnt; i++)
preds[sii._predsAryOffset + i] = chks[sii._workspaceColIdx + i].atd(row_in_chunk);
// Accumulate trees [startTree, total trees) on top of the partial sums.
score0(tmp, preds, offset, sii._startTree, _output._treeKeys.length);
// Persist the updated partial sums so the next batch can resume from them.
for (int i = 0; i < sii._workspaceColCnt; i++)
chks[sii._workspaceColIdx + i].set(row_in_chunk, preds[sii._predsAryOffset + i]);
// Convert raw sums to final probabilities / regression value.
score0Probabilities(preds, offset);
score0PostProcessSupervised(preds, tmp);
return preds;
}
/** Bulk scoring API for one row. Chunks are all compatible with the model,
 * and expect the last Chunks are for the final distribution and prediction.
 * Default method is to just load the data into the tmp array, then call
 * subclass scoring logic. Delegates the raw per-class tree sums to the
 * shared-tree scorer, then maps them through the distribution's inverse
 * link into final probabilities / prediction. */
@Override protected double[] score0(double data[/*ncols*/], double preds[/*nclasses+1*/], double offset, int ntrees) {
super.score0(data, preds, offset, ntrees); // These are f_k(x) in Algorithm 10.4
return score0Probabilities(preds, offset);
}
/**
 * Converts the raw per-class tree sums in {@code preds} into final
 * probabilities (classification) or the final prediction (regression),
 * according to the model's distribution family. Mutates and returns
 * {@code preds}.
 */
private double[] score0Probabilities(double preds[/*nclasses+1*/], double offset) {
  switch (_parms._distribution) {
    case bernoulli:
    case quasibinomial:
    case modified_huber: {
      // Class-1 score is stored in preds[1] (only one tree per iteration).
      double f = preds[1] + _output._init_f + offset;
      preds[2] = new Distribution(_parms).linkInv(f);
      preds[1] = 1.0 - preds[2];
      break;
    }
    case multinomial: {
      // Kept the initial prediction for binomial.
      if (_output.nclasses() == 2) { // 1-tree optimization for binomial
        preds[1] += _output._init_f + offset; // offset is not yet allowed, but added here to be future-proof
        preds[2] = -preds[1];
      }
      hex.genmodel.GenModel.GBM_rescale(preds);
      break;
    }
    default: {
      // Regression: invert the link on the single accumulated score.
      double f = preds[0] + _output._init_f + offset;
      preds[0] = new Distribution(_parms).linkInv(f);
      break;
    }
  }
  return preds;
}
// Note: POJO scoring code doesn't support per-row offsets (the scoring API would need to be changed to pass in offsets)
/**
 * Emits the POJO source that converts raw per-tree sums in {@code preds}
 * into final probabilities and a prediction, mirroring the runtime logic
 * of score0Probabilities() for each distribution family.
 */
@Override protected void toJavaUnifyPreds(SBPrintStream body) {
// Preds are filled in from the trees, but need to be adjusted according to
// the loss function.
if( _parms._distribution == DistributionFamily.bernoulli
|| _parms._distribution == DistributionFamily.quasibinomial
|| _parms._distribution == DistributionFamily.modified_huber
) {
// Binomial-style: add init, invert the link for class 1, complement for class 0.
body.ip("preds[2] = preds[1] + ").p(_output._init_f).p(";").nl();
body.ip("preds[2] = " + new Distribution(_parms).linkInvString("preds[2]") + ";").nl();
body.ip("preds[1] = 1.0-preds[2];").nl();
if (_parms._balance_classes)
body.ip("hex.genmodel.GenModel.correctProbabilities(preds, PRIOR_CLASS_DISTRIB, MODEL_CLASS_DISTRIB);").nl();
body.ip("preds[0] = hex.genmodel.GenModel.getPrediction(preds, PRIOR_CLASS_DISTRIB, data, " + defaultThreshold() + ");").nl();
return;
}
if( _output.nclasses() == 1 ) { // Regression
body.ip("preds[0] += ").p(_output._init_f).p(";").nl();
body.ip("preds[0] = " + new Distribution(_parms).linkInvString("preds[0]") + ";").nl();
return;
}
if( _output.nclasses()==2 ) { // Kept the initial prediction for binomial
body.ip("preds[1] += ").p(_output._init_f).p(";").nl();
body.ip("preds[2] = - preds[1];").nl();
}
// Multinomial (and 2-class via the branch above): softmax-style rescale.
body.ip("hex.genmodel.GenModel.GBM_rescale(preds);").nl();
if (_parms._balance_classes)
body.ip("hex.genmodel.GenModel.correctProbabilities(preds, PRIOR_CLASS_DISTRIB, MODEL_CLASS_DISTRIB);").nl();
body.ip("preds[0] = hex.genmodel.GenModel.getPrediction(preds, PRIOR_CLASS_DISTRIB, data, " + defaultThreshold() + ");").nl();
}
/** @return a MOJO writer capable of serializing this GBM model. */
@Override
public GbmMojoWriter getMojo() {
  final GbmMojoWriter mojoWriter = new GbmMojoWriter(this);
  return mojoWriter;
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy