// hex.tree.gbm.GBM — page title residue from a Maven/Gradle/Ivy artifact browser, not part of the Java source
package hex.tree.gbm;
import hex.ModelCategory;
import hex.schemas.GBMV3;
import hex.tree.*;
import hex.tree.DTree.DecidedNode;
import hex.tree.DTree.LeafNode;
import hex.tree.DTree.UndecidedNode;
import water.*;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Chunk;
import water.util.Log;
import water.util.Timer;
import water.util.ArrayUtils;
/** Gradient Boosted Trees
*
* Based on "Elements of Statistical Learning, Second Edition, page 387"
*/
public class GBM extends SharedTree {
/** Model categories this builder can produce: plain regression plus
 *  two-class (binomial) and N-class (multinomial) classification. */
@Override public ModelCategory[] can_build() {
  ModelCategory[] supported = {
    ModelCategory.Regression,
    ModelCategory.Binomial,
    ModelCategory.Multinomial,
  };
  return supported;
}
@Override public BuilderVisibility builderVisibility() { return BuilderVisibility.Stable; }
// Called from an http request
public GBM( GBMModel.GBMParameters parms) { super("GBM",parms); init(false); }
@Override public GBMV3 schema() { return new GBMV3(); }
/** Start the GBM training Job on an F/J thread. */
/** Start the GBM training Job on an F/J thread.
 *  The total work for the progress bar is one unit per tree. */
@Override public Job trainModel() {
  final GBMDriver driver = new GBMDriver();
  return start(driver, _parms._ntrees); // _ntrees == units of work for the progress bar
}
/** Initialize the ModelBuilder, validating all arguments and preparing the
* training frame. This call is expected to be overridden in the subclasses
* and each subclass will start with "super.init();". This call is made
* by the front-end whenever the GUI is clicked, and needs to be fast;
* heavy-weight prep needs to wait for the trainModel() call.
*
* Validate the learning rate and distribution family. */
@Override public void init(boolean expensive) {
super.init(expensive);
// Initialize response based on given distribution family.
// Regression: initially predict the response mean
// Binomial: just class 0 (class 1 in the exact inverse prediction)
// Multinomial: Class distribution which is not a single value.
// However there is this weird tension on the initial value for
// classification: If you guess 0's (no class is favored over another),
// then with your first GBM tree you'll typically move towards the correct
// answer a little bit (assuming you have decent predictors) - and
// immediately the Confusion Matrix shows good results which gradually
// improve... BUT the Means Squared Error will suck for unbalanced sets,
// even as the CM is good. That's because we want the predictions for the
// common class to be large and positive, and the rare class to be negative
// and instead they start around 0. Guessing initial zero's means the MSE
// is so bad, that the R^2 metric is typically negative (usually it's
// between 0 and 1).
// If instead you guess the mean (reversed through the loss function), then
// the zero-tree GBM model reports an MSE equal to the response variance -
// and an initial R^2 of zero. More trees gradually improves the R^2 as
// expected. However, all the minority classes have large guesses in the
// wrong direction, and it takes a long time (lotsa trees) to correct that
// - so your CM sucks for a long time.
// 'mean' stays 0 on the cheap (!expensive) pass; it is only computed from
// the response column on the expensive pass below.
double mean = 0;
if (expensive) {
// Fail fast: surface accumulated validation errors before doing any work.
if (error_count() > 0) {
GBM.this.updateValidationMessages();
throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(GBM.this);
}
mean = _response.mean();
// Initial prediction: response mean for regression; half negative log-odds
// for 2-class; 0.0 for multinomial (no single scalar works there).
_initialPrediction = _nclass == 1 ? mean
: (_nclass == 2 ? -0.5 * Math.log(mean / (1.0 - mean))/*0.0*/ : 0.0/*not a single value*/);
// AUTO distribution resolves from the response cardinality.
if (_parms._distribution == GBMModel.GBMParameters.Family.AUTO) {
if (_nclass == 1) _parms._distribution = GBMModel.GBMParameters.Family.gaussian;
if (_nclass == 2) _parms._distribution = GBMModel.GBMParameters.Family.bernoulli;
if (_nclass >= 3) _parms._distribution = GBMModel.GBMParameters.Family.multinomial;
}
}
// Validate the chosen distribution against the response type.
switch( _parms._distribution) {
case bernoulli:
if( _nclass != 2 /*&& !couldBeBool(_response)*/)
error("_distribution", "Binomial requires the response to be a 2-class categorical");
else if( _response != null )
// Bernoulli: initial prediction is log( mean(y)/(1-mean(y)) )
// NOTE(review): this overwrites the -0.5*log-odds value set above; also,
// on the !expensive pass 'mean' is still 0 here, making this log(0) ==
// -Infinity if _response is non-null — confirm init(false) cannot reach
// this with a bound response.
_initialPrediction = Math.log(mean / (1.0 - mean));
break;
case multinomial:
if (!isClassifier()) error("_distribution", "Multinomial requires an enum response.");
break;
case gaussian:
if (isClassifier()) error("_distribution", "Gaussian requires the response to be numeric.");
break;
case AUTO:
break;
default:
error("_distribution","Invalid distribution: " + _parms._distribution);
}
// Learning rate must lie in (0, 1].
if( !(0. < _parms._learn_rate && _parms._learn_rate <= 1.0) )
error("_learn_rate", "learn_rate must be between 0 and 1");
}
// ----------------------
private class GBMDriver extends Driver {
// Main GBM driver loop: seed the tree column with the initial prediction,
// then build _ntrees rounds of K class-trees (ESL2 Algorithm 10.3 steps 2a-2b),
// scoring and checkpointing along the way.
@Override protected void buildModel() {
final double init = _initialPrediction;
if( init != 0.0 ) // Only non-zero for regression or bernoulli
new MRTask() {
// NOTE(review): this line appears garbled by HTML extraction — the MRTask
// generic parameter, the body of map() (presumably "for(int i=0; i<tree._len;
// i++) tree.set(i, init);" over tree-column 0), the per-tree loop header
// ("for(int tid=0; tid<_parms._ntrees; tid++)"), the intermediate
// doScoringAndSaveModel call, and the r2 early-stop condition were fused
// into one line. Recover the exact text from the original h2o-3 sources.
@Override public void map(Chunk tree) { for( int i=0; i= _parms._r2_stopping )
return; // Stop when approaching round-off error
}
// ESL2, page 387
// Step 2a: Compute prediction (prob distribution) from prior tree results:
// Work <== f(Tree)
new ComputeProb().doAll(_train);
// ESL2, page 387
// Step 2b i: Compute residuals from the prediction (probability distribution)
// Work <== f(Work)
new ComputeRes().doAll(_train);
// ESL2, page 387, Step 2b ii, iii, iv
Timer kb_timer = new Timer();
buildNextKTrees();
Log.info((tid+1) + ". tree was built in " + kb_timer.toString());
GBM.this.update(1); // bump the progress bar by one tree's worth of work
if( !isRunning() ) return; // If canceled during building, do not bulkscore
}
// Final scoring (skip if job was cancelled)
doScoringAndSaveModel(true, false, false);
}
// --------------------------------------------------------------------------
// Compute Prediction from prior tree results.
// Classification (multinomial): Probability Distribution of loglikelyhoods
// Prob_k = exp(Work_k)/sum_all_K exp(Work_k)
// Classification (bernoulli): Probability of y = 1 given logit link function
// Prob_0 = 1/(1 + exp(Work)), Prob_1 = 1/(1 + exp(-Work))
// Regression: Just prior tree results
// Work <== f(Tree)
// Map-reduce task turning the accumulated tree sums into per-row predictions:
// bernoulli -> Prob_0 via the logit link; multinomial -> softmax over classes;
// regression -> the raw tree sum. Writes into the 'work' columns.
class ComputeProb extends MRTask {
@Override public void map( Chunk chks[] ) {
Chunk ys = chk_resp(chks);
if( _parms._distribution == GBMModel.GBMParameters.Family.bernoulli ) {
Chunk tr = chk_tree(chks,0);
Chunk wk = chk_work(chks,0);
for( int row = 0; row < ys._len; row++)
// wk.set(row, 1.0f/(1f+Math.exp(-tr.atd(row))) ); // Prob_1
wk.set(row, 1.0f / (1f + Math.exp(tr.atd(row)))); // Prob_0
} else if( _nclass > 1 ) { // Classification
double fs[] = new double[_nclass+1];
// NOTE(review): garbled by HTML extraction — the multinomial softmax loop
// (score_indicator over fs), the regression fall-through branch, the
// closing braces of ComputeProb, and the "class ComputeRes extends
// MRTask<ComputeRes>" declaration were fused into this line. Restore from
// the original h2o-3 sources before compiling.
for( int row=0; row {
// ComputeRes.map: ESL2 step 2b i — turn predictions into pseudo-residuals,
// written back over the 'work' columns. For bernoulli, residual = y - Prob_1
// expressed through Prob_0 (y - 1 + Prob_0).
@Override public void map( Chunk chks[] ) {
Chunk ys = chk_resp(chks);
if( _parms._distribution == GBMModel.GBMParameters.Family.bernoulli ) {
for(int row = 0; row < ys._len; row++) {
if( ys.isNA(row) ) continue; // NA responses contribute no residual
int y = (int)ys.at8(row); // zero-based response variable
Chunk wk = chk_work(chks,0);
// wk.set(row, y-(float)wk.atd(row)); // wk.atd(row) is Prob_1
wk.set(row, y-1f+(float)wk.atd(row)); // wk.atd(row) is Prob_0
}
} else if( _nclass > 1 ) { // Classification
// NOTE(review): heavily garbled by HTML extraction from here through the
// gamma loop below — the multinomial residual loop, the regression branch,
// the close of ComputeRes, the "private void buildNextKTrees()" method
// header (tree construction, ESL2 step 2b ii), and the start of the
// leaf-gamma fit were all fused. Restore from the original h2o-3 sources.
for( int row=0; rowResiduals
for( int row=0; row 1 && _parms._distribution != GBMModel.GBMParameters.Family.bernoulli ? (double)(_nclass-1)/_nclass : 1.0; // K-1/K for multinomial
// ESL2 step 2b iii: set each new leaf's prediction to the fitted gamma
// (learn_rate * m1class * sum(res)/sum(gss) — see GammaPass for the sums).
for( int k=0; k<_nclass; k++ ) {
final DTree tree = ktrees[k];
if( tree == null ) continue; // empty class: no tree was built
// NOTE(review): garbled — the leaf loop header ("for(int i=0; i<leafs...")
// and the gamma computation preceding the overflow cap were fused here.
for( int i=0; i 1e4 ) gf = 1e4f; // Cap prediction, will already overflow during Math.exp(gf)
else if( gf < -1e4 ) gf = -1e4f;
}
assert !Float.isNaN(gf) && !Float.isInfinite(gf);
((LeafNode) tree.node(leafs[k] + i))._pred = gf;
}
}
// ----
// ESL2, page 387. Step 2b iv. Cache the sum of all the trees, plus the
// new tree, in the 'tree' columns. Also, zap the NIDs for next pass.
// Tree <== f(Tree)
// Nids <== 0
new MRTask() {
@Override public void map( Chunk chks[] ) {
// For all tree/klasses
for( int k=0; k<_nclass; k++ ) {
final DTree tree = ktrees[k];
if( tree == null ) continue;
final Chunk nids = chk_nids(chks,k);
final Chunk ct = chk_tree(chks,k);
// NOTE(review): garbled — the row loop header and the nid lookup
// ("int nid = (int)nids.at8(row); ...") were fused into this line.
for( int row=0; row fields from tree prediction
ct.set(row, (float)(ct.atd(row) + ((LeafNode)tree.node(nid))._pred));
nids.set(row, 0); // reset node-ids for the next tree-building pass
}
}
}
}.doAll(_train);
// Grow the model by K-trees
_model._output.addKTrees(ktrees);
}
// ---
// ESL2, page 387. Step 2b iii.
// Nids <== f(Nids)
// Map-reduce pass accumulating, per (tree, leaf), the numerator sum(res) in
// _rss and the distribution-specific denominator in _gss, from which the
// leaf gammas (ESL2 step 2b iii) are computed by the caller.
private class GammaPass extends MRTask {
final DTree _trees[]; // Read-only, shared (except at the histograms in the Nodes)
final int _leafs[]; // Number of active leaves (per tree)
final boolean _isBernoulli;
// Per leaf: sum(res);
double _rss[/*tree/klass*/][/*tree-relative node-id*/];
// Per leaf: multinomial: sum(|res|*1-|res|), gaussian: sum(1), bernoulli: sum((y-res)*(1-y+res))
double _gss[/*tree/klass*/][/*tree-relative node-id*/];
GammaPass(DTree trees[], int leafs[], boolean isBernoulli) { _leafs=leafs; _trees=trees; _isBernoulli = isBernoulli; }
@Override public void map( Chunk[] chks ) {
_gss = new double[_nclass][];
_rss = new double[_nclass][];
final Chunk resp = chk_resp(chks); // Response for this frame
// For all tree/klasses
for( int k=0; k<_nclass; k++ ) {
final DTree tree = _trees[k];
final int leaf = _leafs[k];
if( tree == null ) continue; // Empty class is ignored
// A leaf-biased array of all active Tree leaves.
final double gs[] = _gss[k] = new double[tree._len-leaf];
final double rs[] = _rss[k] = new double[tree._len-leaf];
final Chunk nids = chk_nids(chks,k); // Node-ids for this tree/class
final Chunk ress = chk_work(chks,k); // Residuals for this tree/class
// If we have all constant responses, then we do not split even the
// root and the residuals should be zero.
if( tree.root() instanceof LeafNode ) continue;
// NOTE(review): garbled by HTML extraction — the row loop header, the
// leaf-nid lookup, the NA skip, the residual read, and the bernoulli /
// multinomial denominator accumulation into gs[] were fused into this
// line. Restore from the original h2o-3 sources before compiling.
for( int row=0; row 1 ? ares*(1-ares) : 1;
rs[leafnid-leaf] += res;
}
}
}
// Merge per-chunk sums; ArrayUtils.add does element-wise accumulation.
@Override public void reduce( GammaPass gp ) {
ArrayUtils.add(_gss,gp._gss);
ArrayUtils.add(_rss,gp._rss);
}
}
/** Factory hook: wrap the training/validation MSEs in a fresh GBMOutput
 *  and construct the GBMModel under the given key. */
@Override protected GBMModel makeModel( Key modelKey, GBMModel.GBMParameters parms, double mse_train, double mse_valid ) {
  GBMModel.GBMOutput output = new GBMModel.GBMOutput(GBM.this, mse_train, mse_valid);
  return new GBMModel(modelKey, parms, output);
}
}
// Factory hook: turn an undecided node (with its column histograms) into the
// GBM-specific decided node, which scores splits on all columns.
@Override protected DecidedNode makeDecided( UndecidedNode udn, DHistogram hs[] ) {
return new GBMDecidedNode(udn,hs);
}
// ---
// GBM DTree decision node: same as the normal DecidedNode, but
// specifies a decision algorithm given complete histograms on all
// columns. GBM algo: find the lowest error amongst *all* columns.
// NOTE: class continues past the visible end of this excerpt (bestCol below
// is truncated by the extraction).
static class GBMDecidedNode extends DecidedNode {
GBMDecidedNode( UndecidedNode n, DHistogram[] hs ) { super(n,hs); }
// Children of a decided node start out undecided, carrying the histograms
// to be filled on the next pass.
@Override public UndecidedNode makeUndecidedNode(DHistogram[] hs ) {
return new GBMUndecidedNode(_tree,_nid,hs);
}
// Find the column with the best split (lowest score). Unlike RF, GBM
// scores on all columns and selects splits on all columns.
@Override public DTree.Split bestCol( UndecidedNode u, DHistogram[] hs ) {
DTree.Split best = new DTree.Split(-1,-1,null,(byte)0,Double.MAX_VALUE,Double.MAX_VALUE,Double.MAX_VALUE,0L,0L,0,0);
if( hs == null ) return best;
for( int i=0; i
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy — artifact-browser footer residue, not part of the Java source