package hex.tree.gbm;

import hex.VarImp;
import hex.schemas.GBMV2;
import hex.tree.*;
import hex.tree.DTree.DecidedNode;
import hex.tree.DTree.LeafNode;
import hex.tree.DTree.UndecidedNode;
import water.*;
import water.fvec.Chunk;
import water.util.Log;
import water.util.Timer;
import water.util.ArrayUtils;

/** Gradient Boosted Trees
 *
 *  Based on "Elements of Statistical Learning, Second Edition, page 387"
 */
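// In outline (mirroring the step comments in GBMDriver below), each boosting
// iteration: (2a) computes per-row class probabilities from the current tree
// sums, (2b i) turns them into residuals, (2b ii) grows one tree per class on
// those residuals, (2b iii) fits the leaf gammas, and (2b iv) folds the new
// trees into the running per-row tree sums.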
public class GBM extends SharedTree<GBMModel,GBMModel.GBMParameters,GBMModel.GBMOutput> {
  // Called from an http request
  public GBM( GBMModel.GBMParameters parms) { super("GBM",parms); init(false); }

  @Override public GBMV2 schema() { return new GBMV2(); }

  /** Start the GBM training Job on an F/J thread. */
  @Override public Job<GBMModel> trainModel() {
    return start(new GBMDriver(), _parms._ntrees/*work for progress bar*/);
  }

  /** Initialize the ModelBuilder, validating all arguments and preparing the
   *  training frame.  This call is expected to be overridden in the subclasses
   *  and each subclass will start with "super.init();".  This call is made
   *  by the front-end whenever the GUI is clicked, and needs to be fast;
   *  heavy-weight prep needs to wait for the trainModel() call.
   *
   *  Validate the learning rate and loss family. */
  @Override public void init(boolean expensive) {
    super.init(expensive);
    if( !(0. < _parms._learn_rate && _parms._learn_rate <= 1.0) )
      error("_learn_rate", "learn_rate must be between 0 and 1");
    if( _parms._loss == GBMModel.GBMParameters.Family.bernoulli ) {
      if( _nclass != 2 )
        error("_loss","Bernoulli requires the response to be a 2-class categorical");
      // Bernoulli: initial prediction is log( mean(y)/(1-mean(y)) )
      double mean = _response.mean();
      _initialPrediction = Math.log(mean/(1.0f-mean));
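      // For example (hypothetical numbers): if 25% of the training rows are
      // the positive class, mean = 0.25 and the initial prediction is
      // log(0.25/0.75) ~= -1.0986, the prior log-odds before any trees.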
    }
  }

  // ----------------------
  private class GBMDriver extends Driver {

    /** Sum of variable empirical improvement in squared-error. The value is not scaled! */
    private transient float[/*nfeatures*/] _improvPerVar;

    @Override protected void buildModel() {
      // Initialize gbm-specific data structures
      if( _parms._importance ) _improvPerVar = new float[_ncols]; // One slot per predictor column

      // Reconstruct the working tree state from the checkpoint
      if( _parms._checkpoint ) {
        Timer t = new Timer();
        new ResidualsCollector(_ncols, _nclass, _model._output._treeKeys).doAll(_train);
        Log.info("Reconstructing tree residuals stats from checkpointed model took " + t);
      }

      // Loop over the K trees
      for( int tid=0; tid<_parms._ntrees; tid++) {
        // During first iteration model contains 0 trees, then 1-tree, ...
        // No need to score a checkpoint with no extra trees added
        if( tid!=0 || !_parms._checkpoint ) // do not make initial scoring if model already exists
          doScoringAndSaveModel(false, false, false);

        // ESL2, page 387
        // Step 2a: Compute prediction (prob distribution) from prior tree results:
        //   Work <== f(Tree)
        new ComputeProb().doAll(_train);

        // ESL2, page 387
        // Step 2b i: Compute residuals from the prediction (probability distribution)
        //   Work <== f(Work)
        new ComputeRes().doAll(_train);

        // ESL2, page 387, Step 2b ii, iii, iv
        Timer kb_timer = new Timer();
        buildNextKTrees();
        Log.info((tid+1) + ". tree was built in " + kb_timer.toString());
        if( !isRunning() ) return; // If canceled during building, do not bulkscore
      }
      // Final scoring (skip if job was cancelled)
      doScoringAndSaveModel(true, false, false);
    }

    // --------------------------------------------------------------------------
    // Compute Prediction from prior tree results.
    // Classification (multinomial): Probability Distribution of log-likelihoods
    //   Prob_k = exp(Work_k)/sum_all_K exp(Work_k)
    // Classification (bernoulli): Probability of y = 1 given logit link function
    //   Prob_0 = 1/(1 + exp(Work)), Prob_1 = 1/(1 + exp(-Work))
    // Regression: Just prior tree results
    // Work <== f(Tree)
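    // Worked examples (hypothetical values):
    //   multinomial, 3 classes, Work = {0.0, 1.0, 2.0}: exp(Work) = {1.000,
    //   2.718, 7.389}, sum = 11.107, so Prob = {0.090, 0.245, 0.665};
    //   bernoulli, Work = 0: Prob_0 = 1/(1+exp(0)) = 0.5 = Prob_1.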
    class ComputeProb extends MRTask<ComputeProb> {
      @Override public void map( Chunk chks[] ) {
        Chunk ys = chk_resp(chks);
        if( _parms._loss == GBMModel.GBMParameters.Family.bernoulli ) {
          Chunk tr = chk_tree(chks,0);
          Chunk wk = chk_work(chks,0);
          for( int row = 0; row < ys._len; row++)
            // wk.set0(row, 1.0f/(1f+Math.exp(-tr.at0(row))) ); // Prob_1
            wk.set0(row, 1.0f/(1f+Math.exp(tr.at0(row))) );     // Prob_0
        } else if( _nclass > 1 ) {       // Classification
          float fs[] = new float[_nclass+1];
          for( int row=0; row<ys._len; row++ ) {
            float sum = score1(chks,fs,row);
            if( Float.isInfinite(sum) ) // Overflow (happens for constant responses)
              for( int k=0; k<_nclass; k++ )
                chk_work(chks,k).set0(row, Float.isInfinite(fs[k+1]) ? 1.0f : 0.0f);
            else
              for( int k=0; k<_nclass; k++ ) // Save as a probability distribution
                chk_work(chks,k).set0(row, fs[k+1]/sum);
          }
        } else {                  // Regression: just prior tree results
          Chunk tr = chk_tree(chks,0); // Prior tree sums
          Chunk wk = chk_work(chks,0); // Predictions
          for( int row=0; row<ys._len; row++ )
            wk.set0(row, (float)tr.at0(row));
        }
      }
    }

    // --------------------------------------------------------------------------
    // Compute Residuals from Actuals
    // Work <== f(Work)
    class ComputeRes extends MRTask<ComputeRes> {
      @Override public void map( Chunk chks[] ) {
        Chunk ys = chk_resp(chks);
        if( _parms._loss == GBMModel.GBMParameters.Family.bernoulli ) {
          for(int row = 0; row < ys._len; row++) {
            if( ys.isNA0(row) ) continue;
            int y = (int)ys.at80(row); // zero-based response variable
            Chunk wk = chk_work(chks,0);
            // wk.set0(row, y-(float)wk.at0(row));  // wk.at0(row) is Prob_1
            wk.set0(row, y-1f+(float)wk.at0(row));  // wk.at0(row) is Prob_0
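            // i.e. y - (1-Prob_0) == y - Prob_1: the bernoulli gradient
            // residual, written against the stored Prob_0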
          }
        } else if( _nclass > 1 ) {       // Classification

          for( int row=0; row<ys._len; row++ ) {
            if( ys.isNA0(row) ) continue;
            int y = (int)ys.at80(row); // zero-based response variable
            // Actual is '1' for class 'y' and '0' for all other classes
            for( int k=0; k<_nclass; k++ ) {
              if( _model._output._distribution[k] != 0 ) {
                Chunk wk = chk_work(chks,k);
                wk.set0(row, (y==k?1f:0f)-(float)wk.at0(row) );
              }
            }
          }

        } else {                  // Regression
          Chunk wk = chk_work(chks,0); // Prediction==>Residuals
          for( int row=0; row<ys._len; row++ )
            wk.set0(row, (float)(ys.at0(row)-wk.at0(row)) );
        }
      }
    }

    // --------------------------------------------------------------------------
    // Build the next k-trees, correcting the residual error from the prior
    // trees.  From ESL2, page 387, Step 2b ii, iii, iv.
    private void buildNextKTrees() {
      // We're going to build K (nclass) trees - each focused on correcting
      // errors for a single class.
      final DTree[] ktrees = new DTree[_nclass];

      // Initial set of histograms.  All trees; one leaf per tree (the root
      // leaf); all columns.
      DHistogram hcs[][][] = new DHistogram[_nclass][1/*just root leaf*/][_ncols];

      for( int k=0; k<_nclass; k++ ) {
        // Initially setup as-if an empty-split had just happened
        if( _model._output._distribution[k] != 0 ) {
          // The Boolean Optimization: the 2nd tree of a 2-class system is the
          // inverse of the first, so do not build it
          if( k==1 && _nclass==2 ) continue;
          ktrees[k] = new DTree(_train._names,_ncols,(char)_parms._nbins,(char)_nclass,_parms._min_rows);
          new GBMUndecidedNode(ktrees[k],-1,DHistogram.initialHist(_train,_ncols,(char)_parms._nbins,hcs[k][0])); // The "root" node
        }
      }
      int[] leafs = new int[_nclass]; // Define a "working set" of leaf splits, from here to tree._len

      // ESL2, page 387.  Step 2b ii.
      // One Big Loop till the ktrees are of proper depth.
      // Adds a layer to the trees each pass.
      for( int depth=0; depth<_parms._max_depth; depth++ ) {
        if( !isRunning() ) return;
        hcs = buildLayer(_train, ktrees, leafs, hcs, false, false);
        // If we did not make any new splits, then the tree is split-to-death
        if( hcs == null ) break;
      }

      // Each tree bottomed-out in a DecidedNode; go 1 more level and insert
      // LeafNodes to hold the final prediction.
      for( int k=0; k<_nclass; k++ ) {
        DTree tree = ktrees[k];
        if( tree == null ) continue;
        int leaf = leafs[k] = tree._len;
        for( int nid=0; nid<leaf; nid++ ) {
          if( tree.node(nid) instanceof DecidedNode ) {
            DecidedNode dn = tree.decided(nid);
            for( int i=0; i<dn._nids.length; i++ ) {
              int cnid = dn._nids[i];
              if( cnid == -1 || // Bottomed out (predictors or responses known constant)
                  tree.node(cnid) instanceof UndecidedNode || // Or chopped off for depth
                  (tree.node(cnid) instanceof DecidedNode &&  // Or not possible to split
                   ((DecidedNode)tree.node(cnid))._split.col()==-1) )
                dn._nids[i] = new GBMLeafNode(tree,nid).nid(); // Mark a leaf here
            }
          }
        }
      } // -- k-trees are done

      // ESL2, page 387.  Step 2b iii.  Compute the gammas, and store them back
      // into the tree leaves.  Includes learn_rate.
      //   gamma_k = ((K-1)/K) * sum(res) / sum(|res|*(1-|res|))   (multinomial)
      //   gamma   = sum(res) / sum(p*(1-p))                       (bernoulli)
      //   gamma   = sum(res) / count(res)                         (gaussian)
      GammaPass gp = new GammaPass(ktrees,leafs,_parms._loss == GBMModel.GBMParameters.Family.bernoulli).doAll(_train);
      double m1class = _nclass > 1 && _parms._loss != GBMModel.GBMParameters.Family.bernoulli ? (double)(_nclass-1)/_nclass : 1.0; // K-1/K for multinomial
      for( int k=0; k<_nclass; k++ ) {
        final DTree tree = ktrees[k];
        if( tree == null ) continue;
        int leaf = leafs[k];
        for( int i=leaf; i<tree._len; i++ ) {
          double g = gp._gss[k][i-leaf] == 0 // Constant response?
            ? (gp._rss[k][i-leaf]==0 ? 0 : 1000) // Cap the learn, instead of dealing with Inf
            : _parms._learn_rate*m1class*gp._rss[k][i-leaf]/gp._gss[k][i-leaf];
          ((LeafNode)tree.node(i))._pred = g;
        }
      }

      // ESL2, page 387.  Step 2b iv.  Cache the sum of all the trees, plus the
      // new tree, in the 'tree' columns.  Also, zap the NIDs for next pass.
      // Tree <== f(Tree); Nids <== 0
      new MRTask() {
        @Override public void map( Chunk chks[] ) {
          // For all tree/klasses
          for( int k=0; k<_nclass; k++ ) {
            final DTree tree = ktrees[k];
            if( tree == null ) continue;
            final Chunk nids = chk_nids(chks,k);
            final Chunk ct   = chk_tree(chks,k);
            for( int row=0; row<nids._len; row++ ) {
              int nid = (int)nids.at80(row);
              if( nid < 0 ) continue;
              // Prediction stored in Leaf is cut to float to be deterministic
              // in reconstructing <tree,klass> fields from tree prediction
              ct.set0(row, (float)(ct.at0(row) + (float) ((LeafNode)tree.node(nid))._pred));
              nids.set0(row,0);
            }
          }
        }
      }.doAll(_train);

      // Collect leaves stats
      for (int i=0; i<ktrees.length; i++)
        if( ktrees[i] != null )
          ktrees[i].leaves = ktrees[i]._len - leafs[i];

      // Grow the model by K-trees
      _model._output.addKTrees(ktrees);
    }

    // --------------------------------------------------------------------------
    // ESL2, page 387.  Step 2b iii.
    // Pass over the tree/nids to sum up, per leaf, the residuals (numerator)
    // and the gradient terms (denominator) of the gamma fit.
    private class GammaPass extends MRTask<GammaPass> {
      final DTree _trees[]; // Read-only, shared (except at the histograms in the Nodes)
      final int   _leafs[]; // Number of active leaves (per tree)
      final boolean _isBernoulli;
      // Per leaf: sum(res);
      double _rss[/*tree/klass*/][/*tree-relative node-id*/];
      // Per leaf: multinomial: sum(|res|*(1-|res|)), gaussian: sum(1), bernoulli: sum((y-res)*(1-y+res))
      double _gss[/*tree/klass*/][/*tree-relative node-id*/];
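      // The leaf prediction fitted from these sums is learn_rate * _rss/_gss,
      // scaled by (K-1)/K for multinomial.  E.g. (hypothetical numbers): K=3,
      // sum(res)=4.0 and sum(|res|*(1-|res|))=8.0 give a leaf gamma of
      // learn_rate * (2/3) * 0.5 = learn_rate/3.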
      GammaPass(DTree trees[], int leafs[], boolean isBernoulli) { _leafs=leafs; _trees=trees; _isBernoulli = isBernoulli; }
      @Override public void map( Chunk[] chks ) {
        _gss = new double[_nclass][];
        _rss = new double[_nclass][];
        final Chunk resp = chk_resp(chks); // Response for this frame

        // For all tree/klasses
        for( int k=0; k<_nclass; k++ ) {
          final DTree tree = _trees[k];
          final int   leaf = _leafs[k];
          if( tree == null ) continue; // Empty class is ignored

          // A leaf-biased array of all active Tree leaves.
          final double gs[] = _gss[k] = new double[tree._len-leaf];
          final double rs[] = _rss[k] = new double[tree._len-leaf];
          final Chunk nids = chk_nids(chks,k); // Node-ids  for this tree/class
          final Chunk ress = chk_work(chks,k); // Residuals for this tree/class

          // If we have all constant responses, then we do not split even the
          // root and the residuals should be zero.
          if( tree.root() instanceof LeafNode ) continue;
          for( int row=0; row<nids._len; row++ ) { // For all rows
            int nid = (int)nids.at80(row);         // Get Node to decide from
            if( nid < 0 ) continue;                // Missing response
            DecidedNode dn = tree.decided(nid);    // Must have a decision point
            if( dn._split.col() == -1 )            // Unable to decide?
              dn = tree.decided(dn._pid);          // Then take the parent's decision
            int leafnid = dn.ns(chks,row);         // Decide down to a leafnode
            assert leaf <= leafnid && leafnid < tree._len;
            assert tree.node(leafnid) instanceof LeafNode;
            // Rows in this leaf contribute to the gamma numerator (rs) and
            // denominator (gs) sums; record the leaf in the nids column.
            nids.set0(row, leafnid);
            assert !ress.isNA0(row);
            double res  = ress.at0(row);
            double ares = Math.abs(res);
            if( _isBernoulli ) {
              double prob = resp.at0(row) - res; // recover p from residual y-p
              gs[leafnid-leaf] += prob*(1-prob);
            } else
              gs[leafnid-leaf] += _nclass > 1 ? ares*(1-ares) : 1;
            rs[leafnid-leaf] += res;
          }
        }
      }
      @Override public void reduce( GammaPass gp ) {
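        // Element-wise merge of the per-chunk numerator/denominator sums
        // computed by map() over other chunks of the frame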
        ArrayUtils.add(_gss,gp._gss);
        ArrayUtils.add(_rss,gp._rss);
      }
    }

    @Override protected GBMModel makeModel( Key modelKey, GBMModel.GBMParameters parms ) {
      return new GBMModel(modelKey,parms,new GBMModel.GBMOutput(GBM.this));
    }

  }

  @Override protected DecidedNode makeDecided( UndecidedNode udn, DHistogram hs[] ) {
    return new GBMDecidedNode(udn,hs);
  }
  
  // ---
  // GBM DTree decision node: same as the normal DecidedNode, but
  // specifies a decision algorithm given complete histograms on all
  // columns.  GBM algo: find the lowest error amongst *all* columns.
  static class GBMDecidedNode extends DecidedNode {
    GBMDecidedNode( UndecidedNode n, DHistogram[] hs ) { super(n,hs); }
    @Override public UndecidedNode makeUndecidedNode(DHistogram[] hs ) {
      return new GBMUndecidedNode(_tree,_nid,hs);
    }
  
    // Find the column with the best split (lowest score).  Unlike RF, GBM
    // scores on all columns and selects splits on all columns.
    @Override public DTree.Split bestCol( UndecidedNode u, DHistogram[] hs ) {
      DTree.Split best = new DTree.Split(-1,-1,null,(byte)0,Double.MAX_VALUE,Double.MAX_VALUE,0L,0L,0,0);
      if( hs == null ) return best;
      for( int i=0; i<hs.length; i++ ) {
        if( hs[i]==null || hs[i].nbins() <= 1 ) continue;
        DTree.Split s = hs[i].scoreMSE(i);
        if( s == null ) continue;
        if( s.se() < best.se() ) best = s;
        if( s.se() <= 0 ) break; // No point in looking further!
      }
      return best;
    }
  }

  // ---
  // GBM DTree undecided node: same as the normal UndecidedNode, but specifies
  // a list of columns to score on now, and then decide over later.
  // GBM algo: use all columns.
  static class GBMUndecidedNode extends UndecidedNode {
    GBMUndecidedNode( DTree tree, int pid, DHistogram hs[] ) { super(tree,pid,hs); }
    // Returning null means "score on all columns" (unlike RF, which randomly
    // selects a subset of columns per node).
    @Override public int[] scoreCols( DHistogram[] hs ) { return null; }
  }

  // ---
  static class GBMLeafNode extends LeafNode {
    GBMLeafNode( DTree tree, int pid ) { super(tree,pid); }
    GBMLeafNode( DTree tree, int pid, int nid ) { super(tree,pid,nid); }
    // Insert just the predictions: a single float per leaf
    @Override protected AutoBuffer compress(AutoBuffer ab) { return ab.put4f((float)_pred); }
    @Override protected int size() { return 4; }
  }

  // Read the 'tree' columns, do model-specific math and put the results in the
  // fs[] array, and return the sum.  Dividing any fs[] element by the sum
  // turns the results into a probability distribution.
  @Override protected float score1( Chunk chks[], float fs[/*nclass*/], int row ) {
    if( _parms._loss == GBMModel.GBMParameters.Family.bernoulli ) {
      fs[1] = 1.0f/(float)(1f+Math.exp(chk_tree(chks,0).at0(row))); // Prob_0
      fs[2] = 1f-fs[1];                                             // Prob_1
      return fs[1]+fs[2];
    }
    if( _nclass == 1 )          // Regression
      return (float)chk_tree(chks,0).at0(row);
    if( _nclass == 2 ) {        // The Boolean Optimization
      // The 2nd tree of a 2-class system is the inverse of the first, and was
      // never built; fill in the missing tree from the first
      fs[1] = (float)Math.exp(chk_tree(chks,0).at0(row));
      fs[2] = 1.0f/fs[1]; // exp(-d) == 1/exp(d)
      return fs[1]+fs[2];
    }
    float fsum = 0;
    for( int k=0; k<_nclass; k++ )
      fsum += (fs[k+1] = (float)Math.exp(chk_tree(chks,k).at0(row)));
    return fsum;
  }
}