All downloads are free. Search and download functionality uses the official Maven repository.

hex.tree.gbm.GBM Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.tree.gbm;

import hex.ModelCategory;
import hex.schemas.GBMV3;
import hex.tree.*;
import hex.tree.DTree.DecidedNode;
import hex.tree.DTree.LeafNode;
import hex.tree.DTree.UndecidedNode;
import water.*;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Chunk;
import water.util.Log;
import water.util.Timer;
import water.util.ArrayUtils;

/** Gradient Boosted Trees
 *
 *  Based on "Elements of Statistical Learning, Second Edition, page 387"
 */
public class GBM extends SharedTree {
  /** Lists the model categories this builder can produce.
   *  @return a new array holding Regression, Binomial and Multinomial. */
  @Override
  public ModelCategory[] can_build() {
    final ModelCategory[] categories = {
      ModelCategory.Regression,
      ModelCategory.Binomial,
      ModelCategory.Multinomial
    };
    return categories;
  }

  /** Reports this builder's visibility level as {@code Stable}. */
  @Override
  public BuilderVisibility builderVisibility() {
    return BuilderVisibility.Stable;
  }

  /** Creates a GBM builder (called from an http request) and immediately
   *  runs the cheap, non-expensive validation pass.
   *  @param parms the user-supplied GBM parameters */
  public GBM(GBMModel.GBMParameters parms) {
    super("GBM", parms);
    init(false);
  }

  /** @return a newly created {@code GBMV3} schema instance. */
  @Override
  public GBMV3 schema() {
    final GBMV3 fresh = new GBMV3();
    return fresh;
  }

  /** Start the GBM training Job on an F/J thread.
   *  The work reported to the progress bar is the requested number of trees. */
  @Override
  public Job trainModel() {
    GBMDriver driver = new GBMDriver();
    return start(driver, _parms._ntrees /* work units for the progress bar */);
  }

  /** Initialize the ModelBuilder, validating all arguments and preparing the
   *  training frame.  This call is expected to be overridden in the subclasses
   *  and each subclass will start with "super.init();".  This call is made
   *  by the front-end whenever the GUI is clicked, and needs to be fast;
   *  heavy-weight prep needs to wait for the trainModel() call.
   *
   *  Validate the learning rate and distribution family.
   *
   *  @param expensive when true, also compute the response mean, derive the
   *         initial prediction and resolve an AUTO distribution; when false,
   *         only the cheap argument checks run and {@code _initialPrediction}
   *         is left untouched. */
  @Override public void init(boolean expensive) {
    super.init(expensive);

    // Initialize response based on given distribution family.
    // Regression: initially predict the response mean
    // Binomial: just class 0 (class 1 in the exact inverse prediction)
    // Multinomial: Class distribution which is not a single value.

    // However there is this weird tension on the initial value for
    // classification: If you guess 0s (no class is favored over another),
    // then with your first GBM tree you'll typically move towards the correct
    // answer a little bit (assuming you have decent predictors) - and
    // immediately the Confusion Matrix shows good results which gradually
    // improve... BUT the Mean Squared Error will be poor for unbalanced sets,
    // even as the CM is good.  That's because we want the predictions for the
    // common class to be large and positive, and the rare class to be negative
    // and instead they start around 0.  Guessing initial zeros means the MSE
    // is so bad that the R^2 metric is typically negative (usually it's
    // between 0 and 1).

    // If instead you guess the mean (reversed through the loss function), then
    // the zero-tree GBM model reports an MSE equal to the response variance -
    // and an initial R^2 of zero.  More trees gradually improve the R^2 as
    // expected.  However, all the minority classes have large guesses in the
    // wrong direction, and it takes a long time (many trees) to correct that
    // - so your CM is poor for a long time.
    double mean = 0; // placeholder; only a real response mean in expensive mode
    if (expensive) {
      if (error_count() > 0) {
        GBM.this.updateValidationMessages();
        throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(GBM.this);
      }

      mean = _response.mean();
      // NOTE(review): for a bernoulli distribution this value is overwritten
      // by the switch below.
      _initialPrediction = _nclass == 1 ? mean
              : (_nclass == 2 ? -0.5 * Math.log(mean / (1.0 - mean))/*0.0*/ : 0.0/*not a single value*/);

      if (_parms._distribution == GBMModel.GBMParameters.Family.AUTO) {
        if (_nclass == 1) _parms._distribution = GBMModel.GBMParameters.Family.gaussian;
        if (_nclass == 2) _parms._distribution = GBMModel.GBMParameters.Family.bernoulli;
        if (_nclass >= 3) _parms._distribution = GBMModel.GBMParameters.Family.multinomial;
      }
    }

    switch( _parms._distribution) {
    case bernoulli:
      if( _nclass != 2 /*&& !couldBeBool(_response)*/)
        error("_distribution", "Binomial requires the response to be a 2-class categorical");
      else if( expensive && _response != null )
        // Bernoulli: initial prediction is log( mean(y)/(1-mean(y)) ).
        // Only computed in expensive mode: in the cheap pass 'mean' is still
        // the 0 placeholder and Math.log(0) would store -Infinity.
        _initialPrediction = Math.log(mean / (1.0 - mean));
      break;
    case multinomial:
      if (!isClassifier()) error("_distribution", "Multinomial requires an enum response.");
      break;
    case gaussian:
      if (isClassifier()) error("_distribution", "Gaussian requires the response to be numeric.");
      break;
    case AUTO:
      break;
    default:
      error("_distribution","Invalid distribution: " + _parms._distribution);
    }

    // Accepted range is (0, 1]: zero is rejected, exactly 1.0 is allowed.
    if( !(0. < _parms._learn_rate && _parms._learn_rate <= 1.0) )
      error("_learn_rate", "learn_rate must be between 0 and 1");
  }

  // ----------------------
  private class GBMDriver extends Driver {

    @Override protected void buildModel() {
      final double init = _initialPrediction;
      if( init != 0.0 )       // Only non-zero for regression or bernoulli
        new MRTask() {
          @Override public void map(Chunk tree) { for( int i=0; i= _parms._r2_stopping )
            return;             // Stop when approaching round-off error
        }

        // ESL2, page 387
        // Step 2a: Compute prediction (prob distribution) from prior tree results:
        //   Work <== f(Tree)
        new ComputeProb().doAll(_train);

        // ESL2, page 387
        // Step 2b i: Compute residuals from the prediction (probability distribution)
        //   Work <== f(Work)
        new ComputeRes().doAll(_train);

        // ESL2, page 387, Step 2b ii, iii, iv
        Timer kb_timer = new Timer();
        buildNextKTrees();
        Log.info((tid+1) + ". tree was built in " + kb_timer.toString());
        GBM.this.update(1);
        if( !isRunning() ) return; // If canceled during building, do not bulkscore
      }
      // Final scoring (skip if job was cancelled)
      doScoringAndSaveModel(true, false, false);
    }

    // --------------------------------------------------------------------------
    // Compute Prediction from prior tree results.
    // Classification (multinomial): Probability Distribution of loglikelyhoods
    //   Prob_k = exp(Work_k)/sum_all_K exp(Work_k)
    // Classification (bernoulli): Probability of y = 1 given logit link function
    //   Prob_0 = 1/(1 + exp(Work)), Prob_1 = 1/(1 + exp(-Work))
    // Regression: Just prior tree results
    // Work <== f(Tree)
    class ComputeProb extends MRTask {
      @Override public void map( Chunk chks[] ) {
        Chunk ys = chk_resp(chks);
        if( _parms._distribution == GBMModel.GBMParameters.Family.bernoulli ) {
          Chunk tr = chk_tree(chks,0);
          Chunk wk = chk_work(chks,0);
          for( int row = 0; row < ys._len; row++)
            // wk.set(row, 1.0f/(1f+Math.exp(-tr.atd(row))) ); // Prob_1
            wk.set(row, 1.0f / (1f + Math.exp(tr.atd(row))));     // Prob_0
        } else if( _nclass > 1 ) {       // Classification
          double fs[] = new double[_nclass+1];
          for( int row=0; row {
      @Override public void map( Chunk chks[] ) {
        Chunk ys = chk_resp(chks);
        if( _parms._distribution == GBMModel.GBMParameters.Family.bernoulli ) {
          for(int row = 0; row < ys._len; row++) {
            if( ys.isNA(row) ) continue;
            int y = (int)ys.at8(row); // zero-based response variable
            Chunk wk = chk_work(chks,0);
            // wk.set(row, y-(float)wk.atd(row));  // wk.atd(row) is Prob_1
            wk.set(row, y-1f+(float)wk.atd(row));  // wk.atd(row) is Prob_0
          }
        } else if( _nclass > 1 ) {       // Classification

          for( int row=0; rowResiduals
          for( int row=0; row 1 && _parms._distribution != GBMModel.GBMParameters.Family.bernoulli ? (double)(_nclass-1)/_nclass : 1.0; // K-1/K for multinomial
      for( int k=0; k<_nclass; k++ ) {
        final DTree tree = ktrees[k];
        if( tree == null ) continue;
        for( int i=0; i  1e4 ) gf =  1e4f; // Cap prediction, will already overflow during Math.exp(gf)
            else if( gf < -1e4 ) gf = -1e4f;
          }
          assert !Float.isNaN(gf) && !Float.isInfinite(gf);
          ((LeafNode) tree.node(leafs[k] + i))._pred = gf;
        }
      }

      // ----
      // ESL2, page 387.  Step 2b iv.  Cache the sum of all the trees, plus the
      // new tree, in the 'tree' columns.  Also, zap the NIDs for next pass.
      // Tree <== f(Tree)
      // Nids <== 0
      new MRTask() {
        @Override public void map( Chunk chks[] ) {
          // For all tree/klasses
          for( int k=0; k<_nclass; k++ ) {
            final DTree tree = ktrees[k];
            if( tree == null ) continue;
            final Chunk nids = chk_nids(chks,k);
            final Chunk ct   = chk_tree(chks,k);
            for( int row=0; row fields from tree prediction
              ct.set(row, (float)(ct.atd(row) + ((LeafNode)tree.node(nid))._pred));
              nids.set(row, 0);
            }
          }
        }
      }.doAll(_train);

      // Grow the model by K-trees
      _model._output.addKTrees(ktrees);
    }

    // ---
    // ESL2, page 387.  Step 2b iii.
    // Nids <== f(Nids)
    private class GammaPass extends MRTask {
      final DTree _trees[]; // Read-only, shared (except at the histograms in the Nodes)
      final int   _leafs[]; // Number of active leaves (per tree)
      final boolean _isBernoulli;
      // Per leaf: sum(res);
      double _rss[/*tree/klass*/][/*tree-relative node-id*/];
      // Per leaf: multinomial: sum(|res|*1-|res|), gaussian: sum(1), bernoulli: sum((y-res)*(1-y+res))
      double _gss[/*tree/klass*/][/*tree-relative node-id*/];
      GammaPass(DTree trees[], int leafs[], boolean isBernoulli) { _leafs=leafs; _trees=trees; _isBernoulli = isBernoulli; }
      @Override public void map( Chunk[] chks ) {
        _gss = new double[_nclass][];
        _rss = new double[_nclass][];
        final Chunk resp = chk_resp(chks); // Response for this frame

        // For all tree/klasses
        for( int k=0; k<_nclass; k++ ) {
          final DTree tree = _trees[k];
          final int   leaf = _leafs[k];
          if( tree == null ) continue; // Empty class is ignored

          // A leaf-biased array of all active Tree leaves.
          final double gs[] = _gss[k] = new double[tree._len-leaf];
          final double rs[] = _rss[k] = new double[tree._len-leaf];
          final Chunk nids = chk_nids(chks,k); // Node-ids  for this tree/class
          final Chunk ress = chk_work(chks,k); // Residuals for this tree/class

          // If we have all constant responses, then we do not split even the
          // root and the residuals should be zero.
          if( tree.root() instanceof LeafNode ) continue;
          for( int row=0; row 1 ? ares*(1-ares) : 1;
            rs[leafnid-leaf] += res;
          }
        }
      }
      @Override public void reduce( GammaPass gp ) {
        ArrayUtils.add(_gss,gp._gss);
        ArrayUtils.add(_rss,gp._rss);
      }
    }

    /** Builds a new GBMModel from this builder's output and the given
     *  training/validation MSE values.
     *  @param modelKey  key for the new model
     *  @param parms     parameters the model was trained with
     *  @param mse_train training-set MSE
     *  @param mse_valid validation-set MSE
     *  @return the newly constructed model */
    @Override
    protected GBMModel makeModel(Key modelKey, GBMModel.GBMParameters parms, double mse_train, double mse_valid) {
      GBMModel.GBMOutput output = new GBMModel.GBMOutput(GBM.this, mse_train, mse_valid);
      return new GBMModel(modelKey, parms, output);
    }

  }

  /** Turns an undecided node plus its histograms into a GBM-specific decided node.
   *  @param udn the node being decided
   *  @param hs  histograms for the node's columns
   *  @return a new {@code GBMDecidedNode} */
  @Override
  protected DecidedNode makeDecided(UndecidedNode udn, DHistogram[] hs) {
    final DecidedNode decided = new GBMDecidedNode(udn, hs);
    return decided;
  }
  
  // ---
  // GBM DTree decision node: same as the normal DecidedNode, but
  // specifies a decision algorithm given complete histograms on all
  // columns.  GBM algo: find the lowest error amongst *all* columns.
  static class GBMDecidedNode extends DecidedNode {
    /** Delegates straight to the {@code DecidedNode} constructor. */
    GBMDecidedNode(UndecidedNode n, DHistogram[] hs) {
      super(n, hs);
    }
    /** Creates a {@code GBMUndecidedNode} bound to this node's tree and node id.
     *  @param hs histograms to attach to the new node
     *  @return the new undecided node */
    @Override
    public UndecidedNode makeUndecidedNode(DHistogram[] hs) {
      final UndecidedNode child = new GBMUndecidedNode(_tree, _nid, hs);
      return child;
    }
  
    // Find the column with the best split (lowest score).  Unlike RF, GBM
    // scores on all columns and selects splits on all columns.
    @Override public DTree.Split bestCol( UndecidedNode u, DHistogram[] hs ) {
      DTree.Split best = new DTree.Split(-1,-1,null,(byte)0,Double.MAX_VALUE,Double.MAX_VALUE,Double.MAX_VALUE,0L,0L,0,0);
      if( hs == null ) return best;
      for( int i=0; i




© 2015 - 2025 Weber Informatics LLC | Privacy Policy