
package hex.tree.gbm;
import hex.VarImp;
import hex.schemas.GBMV2;
import hex.tree.*;
import hex.tree.DTree.DecidedNode;
import hex.tree.DTree.LeafNode;
import hex.tree.DTree.UndecidedNode;
import water.*;
import water.fvec.Chunk;
import water.util.Log;
import water.util.Timer;
import water.util.ArrayUtils;
/** Gradient Boosted Trees
*
* Based on "Elements of Statistical Learning, Second Edition, page 387"
*/
public class GBM extends SharedTree<GBMModel,GBMModel.GBMParameters,GBMModel.GBMOutput> {
// Called from an http request
public GBM( GBMModel.GBMParameters parms) { super("GBM",parms); init(false); }
@Override public GBMV2 schema() { return new GBMV2(); }
/** Start the GBM training Job on an F/J thread. */
@Override public Job<GBMModel> trainModel() {
return start(new GBMDriver(), _parms._ntrees/*work for progress bar*/);
}
/** Initialize the ModelBuilder, validating all arguments and preparing the
* training frame. This call is expected to be overridden in the subclasses
* and each subclass will start with "super.init();". This call is made
* by the front-end whenever the GUI is clicked, and needs to be fast;
* heavy-weight prep needs to wait for the trainModel() call.
*
* Validate the learning rate and loss family. */
@Override public void init(boolean expensive) {
super.init(expensive);
if( !(0. < _parms._learn_rate && _parms._learn_rate <= 1.0) )
error("_learn_rate", "learn_rate must be between 0 and 1");
if( _parms._loss == GBMModel.GBMParameters.Family.bernoulli ) {
if( _nclass != 2 )
error("_loss","Bernoulli requires the response to be a 2-class categorical");
// Bernoulli: initial prediction is log( mean(y)/(1-mean(y)) )
double mean = _response.mean();
_initialPrediction = Math.log(mean/(1.0f-mean));
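// Sanity example (illustrative numbers): mean(y)=0.25 gives
// _initialPrediction = log(0.25/0.75) ~= -1.0986, the prior log-odds of class 1.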
}
}
// ----------------------
private class GBMDriver extends Driver {
/** Sum of variable empirical improvement in squared-error. The value is not scaled! */
private transient float[/*nfeatures*/] _improvPerVar;
@Override protected void buildModel() {
// Initialize gbm-specific data structures
if( _parms._importance ) _improvPerVar = new float[_ncols];
// Reconstruct the working tree state from the checkpoint
if( _parms._checkpoint ) {
Timer t = new Timer();
new ResidualsCollector(_ncols, _nclass, _model._output._treeKeys).doAll(_train);
Log.info("Reconstructing tree residuals stats from checkpointed model took " + t);
}
// Loop over the K trees
for( int tid=0; tid<_parms._ntrees; tid++) {
// During first iteration model contains 0 trees, then 1-tree, ...
// No need to score a checkpoint with no extra trees added
if( tid!=0 || !_parms._checkpoint ) // do not make initial scoring if model already exists
doScoringAndSaveModel(false, false, false);
// ESL2, page 387
// Step 2a: Compute prediction (prob distribution) from prior tree results:
// Work <== f(Tree)
new ComputeProb().doAll(_train);
// ESL2, page 387
// Step 2b i: Compute residuals from the prediction (probability distribution)
// Work <== f(Work)
new ComputeRes().doAll(_train);
// ESL2, page 387, Step 2b ii, iii, iv
Timer kb_timer = new Timer();
buildNextKTrees();
Log.info((tid+1) + ". tree was built in " + kb_timer.toString());
if( !isRunning() ) return; // If canceled during building, do not bulkscore
}
// Final scoring (skip if job was cancelled)
doScoringAndSaveModel(true, false, false);
}
// --------------------------------------------------------------------------
// Compute Prediction from prior tree results.
// Classification (multinomial): Probability Distribution of log-likelihoods
// Prob_k = exp(Work_k)/sum_all_K exp(Work_k)
// Classification (bernoulli): Probability of y = 1 given logit link function
// Prob_0 = 1/(1 + exp(Work)), Prob_1 = 1/(1 + exp(-Work))
// Regression: Just prior tree results
// Work <== f(Tree)
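// Illustrative numbers: multinomial with K=3 and Work=(0,1,2) gives
//   Prob = (e^0,e^1,e^2)/(e^0+e^1+e^2) ~= (0.090, 0.245, 0.665);
// bernoulli with Work=0 gives Prob_0 = Prob_1 = 0.5.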
class ComputeProb extends MRTask<ComputeProb> {
@Override public void map( Chunk chks[] ) {
Chunk ys = chk_resp(chks);
if( _parms._loss == GBMModel.GBMParameters.Family.bernoulli ) {
Chunk tr = chk_tree(chks,0);
Chunk wk = chk_work(chks,0);
for( int row = 0; row < ys._len; row++)
// wk.set0(row, 1.0f/(1f+Math.exp(-tr.at0(row))) ); // Prob_1
wk.set0(row, 1.0f/(1f+Math.exp(tr.at0(row))) ); // Prob_0
} else if( _nclass > 1 ) { // Classification
float fs[] = new float[_nclass+1];
for( int row=0; row<ys._len; row++ ) {
float sum = score1(chks,fs,row);
if( Float.isInfinite(sum) ) // Overflow (can happen for constant responses)
for( int k=0; k<_nclass; k++ )
chk_work(chks,k).set0(row, Float.isInfinite(fs[k+1]) ? 1.0f : 0.0f);
else
for( int k=0; k<_nclass; k++ ) // Save as a probability distribution
chk_work(chks,k).set0(row, fs[k+1]/sum);
}
} else { // Regression
Chunk tr = chk_tree(chks,0); // Prior tree sums
Chunk wk = chk_work(chks,0); // Predictions
for( int row=0; row<ys._len; row++ )
wk.set0(row, (float)tr.at0(row));
}
}
}
// --------------------------------------------------------------------------
// Compute Residuals from Actuals
// Work <== f(Work)
class ComputeRes extends MRTask<ComputeRes> {
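// Illustrative numbers: a 3-class row with true class y=1 and current
// probabilities p=(0.2,0.5,0.3) gets residuals
// r = (0-0.2, 1-0.5, 0-0.3) = (-0.2, 0.5, -0.3).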
@Override public void map( Chunk chks[] ) {
Chunk ys = chk_resp(chks);
if( _parms._loss == GBMModel.GBMParameters.Family.bernoulli ) {
for(int row = 0; row < ys._len; row++) {
if( ys.isNA0(row) ) continue;
int y = (int)ys.at80(row); // zero-based response variable
Chunk wk = chk_work(chks,0);
// wk.set0(row, y-(float)wk.at0(row)); // wk.at0(row) is Prob_1
wk.set0(row, y-1f+(float)wk.at0(row)); // wk.at0(row) is Prob_0
}
} else if( _nclass > 1 ) { // Classification
for( int row=0; row<ys._len; row++ ) {
if( ys.isNA0(row) ) continue;
int y = (int)ys.at80(row); // zero-based response variable
// Actual is '1' for class 'y' and '0' for all other classes
for( int k=0; k<_nclass; k++ ) {
if( _model._output._distribution[k] != 0 ) {
Chunk wk = chk_work(chks,k);
wk.set0(row, (y==k?1f:0f)-(float)wk.at0(row) );
}
}
}
} else { // Regression
Chunk wk = chk_work(chks,0); // Prediction==>Residuals
for( int row=0; row<ys._len; row++ )
wk.set0(row, (float)(ys.at0(row)-wk.at0(row)) );
}
}
}
// --------------------------------------------------------------------------
// Build the next k-trees, trying to correct the residual error from the
// prior trees. ESL2, page 387, Step 2b ii, iii, iv.
private void buildNextKTrees() {
// We're going to build K (nclass) trees - each focused on correcting
// errors for a single class.
final DTree[] ktrees = new DTree[_nclass];
int[] leafs = new int[_nclass]; // Index of the first working leaf, per tree/class
// ... (tree-growing passes elided in this listing: root histograms,
// layer-by-layer splitting up to max_depth, insertion of LeafNodes) ...
// ESL2, page 387, Step 2b iii: compute the gammas and store them back
// into the tree leaves. Includes learn_rate.
GammaPass gp = new GammaPass(ktrees, leafs, _parms._loss == GBMModel.GBMParameters.Family.bernoulli).doAll(_train);
double m1class = _nclass > 1 && _parms._loss != GBMModel.GBMParameters.Family.bernoulli ? (double)(_nclass-1)/_nclass : 1.0; // K-1/K for multinomial
for( int k=0; k<_nclass; k++ ) {
final DTree tree = ktrees[k];
if( tree == null ) continue;
for( int i=0; i<tree._len-leafs[k]; i++ ) {
float gf = (float)(_parms._learn_rate * m1class * gp._rss[k][i] / gp._gss[k][i]);
((LeafNode)tree.node(leafs[k]+i))._pred = gf;
}
}
// ESL2, page 387, Step 2b iv: cache the sum of all the trees, plus the
// new tree, in the 'tree' columns. Also, zap the NIDs for the next pass.
// Tree <== f(Tree)
// Nids <== 0
new MRTask() {
@Override public void map( Chunk chks[] ) {
// For all tree/klasses
for( int k=0; k<_nclass; k++ ) {
final DTree tree = ktrees[k];
if( tree == null ) continue;
final Chunk nids = chk_nids(chks,k);
final Chunk ct = chk_tree(chks,k);
for( int row=0; row<nids._len; row++ ) {
int nid = (int)nids.at80(row);
if( nid < 0 ) continue;
// Prediction stored in Leaf is cut to float to be deterministic
// in reconstructing <se> fields from tree prediction
ct.set0(row, (float)(ct.at0(row) + (float) ((LeafNode)tree.node(nid))._pred));
nids.set0(row,0);
}
}
}
}.doAll(_train);
// Collect leaves stats
for( int i=0; i<ktrees.length; i++ )
if( ktrees[i] != null )
ktrees[i].leaves = ktrees[i].len() - leafs[i];
}
// ---
// ESL2, page 387, Step 2b iii.
// Nids <== f(Nids)
private class GammaPass extends MRTask<GammaPass> {
final DTree _trees[]; // Read-only, shared (except at the histograms in the Nodes)
final int _leafs[]; // Number of active leaves (per tree)
final boolean _isBernoulli;
// Per leaf: sum(res);
double _rss[/*tree/klass*/][/*tree-relative node-id*/];
// Per leaf: multinomial: sum(|res|*(1-|res|)), gaussian: sum(1), bernoulli: sum((y-res)*(1-y+res))
double _gss[/*tree/klass*/][/*tree-relative node-id*/];
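// Per-leaf gamma is then rs/gs, scaled by learn_rate (and by (K-1)/K for
// multinomial). Illustrative numbers: rs=2.0, gs=4.0, learn_rate=0.1
// gives a leaf prediction of 0.1*2.0/4.0 = 0.05.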
GammaPass(DTree trees[], int leafs[], boolean isBernoulli) { _leafs=leafs; _trees=trees; _isBernoulli = isBernoulli; }
@Override public void map( Chunk[] chks ) {
_gss = new double[_nclass][];
_rss = new double[_nclass][];
final Chunk resp = chk_resp(chks); // Response for this frame
// For all tree/klasses
for( int k=0; k<_nclass; k++ ) {
final DTree tree = _trees[k];
final int leaf = _leafs[k];
if( tree == null ) continue; // Empty class is ignored
// A leaf-biased array of all active Tree leaves.
final double gs[] = _gss[k] = new double[tree._len-leaf];
final double rs[] = _rss[k] = new double[tree._len-leaf];
final Chunk nids = chk_nids(chks,k); // Node-ids for this tree/class
final Chunk ress = chk_work(chks,k); // Residuals for this tree/class
// If we have all constant responses, then we do not split even the
// root and the residuals should be zero.
if( tree.root() instanceof LeafNode ) continue;
for( int row=0; row<nids._len; row++ ) { // For all rows
int nid = (int)nids.at80(row); // Get node to decide from
if( nid < 0 ) continue; // Missing response
if( tree.node(nid) instanceof UndecidedNode ) // If we bottomed out the tree
nid = tree.node(nid)._pid; // Then take the parent's decision
DecidedNode dn = tree.decided(nid); // Must have a decision point
if( dn._split._col == -1 ) // Unable to decide?
dn = tree.decided(nid = dn._pid); // Then take the parent's decision
int leafnid = dn.ns(chks,row); // Decide down to a leaf node
assert leaf <= leafnid && leafnid < tree._len;
assert tree.node(leafnid) instanceof LeafNode;
// We only care which leaf/region each row lands in, not the tree's own
// prediction: the gamma for the leaf is computed from the residual sums.
nids.set0(row,leafnid);
assert !ress.isNA0(row);
double res = ress.at0(row);
double ares = Math.abs(res);
if( _isBernoulli ) {
double prob = resp.at0(row) - res; // res = y - p, so p = y - res
gs[leafnid-leaf] += prob*(1-prob); // sum((y-res)*(1-y+res))
} else
gs[leafnid-leaf] += _nclass > 1 ? ares*(1-ares) : 1;
rs[leafnid-leaf] += res;
}
}
}
@Override public void reduce( GammaPass gp ) {
ArrayUtils.add(_gss,gp._gss);
ArrayUtils.add(_rss,gp._rss);
}
}
@Override protected GBMModel makeModel( Key modelKey, GBMModel.GBMParameters parms ) {
return new GBMModel(modelKey,parms,new GBMModel.GBMOutput(GBM.this));
}
}
@Override protected DecidedNode makeDecided( UndecidedNode udn, DHistogram hs[] ) {
return new GBMDecidedNode(udn,hs);
}
// ---
// GBM DTree decision node: same as the normal DecidedNode, but
// specifies a decision algorithm given complete histograms on all
// columns. GBM algo: find the lowest error amongst *all* columns.
static class GBMDecidedNode extends DecidedNode {
GBMDecidedNode( UndecidedNode n, DHistogram[] hs ) { super(n,hs); }
@Override public UndecidedNode makeUndecidedNode(DHistogram[] hs ) {
return new GBMUndecidedNode(_tree,_nid,hs);
}
// Find the column with the best split (lowest score). Unlike RF, GBM
// scores on all columns and selects splits on all columns.
@Override public DTree.Split bestCol( UndecidedNode u, DHistogram[] hs ) {
DTree.Split best = new DTree.Split(-1,-1,null,(byte)0,Double.MAX_VALUE,Double.MAX_VALUE,0L,0L,0,0);
if( hs == null ) return best;
for( int i=0; i<hs.length; i++ ) {
if( hs[i]==null || hs[i].nbins() <= 1 ) continue; // Skip untracked or constant columns
DTree.Split s = hs[i].scoreMSE(i);
if( s == null ) continue;
if( s.se() < best.se() ) best = s; // Keep the lowest-error split
}
return best;
}
}
// ---
// GBM DTree undecided node: same as the normal UndecidedNode, but specifies
// a list of columns to score on now, and then decide over later.
// GBM algo: use all columns.
static class GBMUndecidedNode extends UndecidedNode {
GBMUndecidedNode( DTree tree, int pid, DHistogram hs[] ) { super(tree,pid,hs); }
// Select the columns to 'score' in the following pass. GBM scores over all
// columns (returning null means "use all"), unlike RF's random subset.
@Override public int[] scoreCols( DHistogram[] hs ) { return null; }
}
// ---
static class GBMLeafNode extends LeafNode {
GBMLeafNode( DTree tree, int pid ) { super(tree,pid); }
GBMLeafNode( DTree tree, int pid, int nid ) { super(tree,pid,nid); }
// Insert just the prediction: a single float per leaf.
@Override protected int size() { return 4; } // Just the pred
}
// Read the 'tree' columns, do model-specific math, put the results in fs[]
// (1-based, one slot per class) and return the sum; dividing each fs[] element
// by the sum turns the results into a probability distribution.
@Override protected float score1( Chunk chks[], float fs[/*nclass*/], int row ) {
if( _parms._loss == GBMModel.GBMParameters.Family.bernoulli ) {
fs[1] = 1.0f/(float)(1f+Math.exp(chk_tree(chks,0).at0(row))); // Prob_0
fs[2] = 1f-fs[1]; // Prob_1
return fs[1]+fs[2]; // == 1
}
if( _nclass == 1 ) // Regression
return fs[1] = (float)chk_tree(chks,0).at0(row);
if( _nclass == 2 ) { // The Boolean Optimization
// This optimization assumes the 2nd tree of a 2-class system is the
// inverse of the first. Fill in the missing tree.
fs[1] = (float)Math.exp(chk_tree(chks,0).at0(row));
fs[2] = 1.0f/fs[1]; // exp(-d) == 1/exp(d)
return fs[1]+fs[2];
}
float fsum = 0;
for( int k=0; k<_nclass; k++ )
fsum += (fs[k+1] = (float)Math.exp(chk_tree(chks,k).at0(row)));
return fsum;
}
}
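// Usage sketch (illustrative; not from this file). Only _ntrees, _learn_rate
// and _loss are parameter fields confirmed above; _train, _response_column and
// the blocking Job.get() are assumptions about the surrounding h2o-dev API:
//   GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
//   parms._train = trainFrame._key;                         // assumed field
//   parms._response_column = "label";                       // assumed field
//   parms._ntrees = 50;
//   parms._learn_rate = 0.1f;
//   parms._loss = GBMModel.GBMParameters.Family.bernoulli;
//   GBMModel model = new GBM(parms).trainModel().get();     // assumed blocking get()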