package hex.tree.drf;
import hex.ModelCategory;
import hex.schemas.DRFV3;
import hex.tree.DHistogram;
import hex.tree.DTree;
import hex.tree.DTree.DecidedNode;
import hex.tree.DTree.LeafNode;
import hex.tree.DTree.UndecidedNode;
import hex.tree.ScoreBuildHistogram;
import hex.tree.SharedTree;
import water.AutoBuffer;
import water.Job;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.util.Log;
import water.util.Timer;
import java.util.Arrays;
import java.util.Random;
import static hex.genmodel.GenModel.getPrediction;
import static hex.tree.drf.TreeMeasuresCollector.asSSE;
import static hex.tree.drf.TreeMeasuresCollector.asVotes;
/** Distributed Random Forest (DRF)
*
* Based on "Elements of Statistical Learning, Second Edition", Chapter 15 (Random Forests)
*/
public class DRF extends SharedTree<hex.tree.drf.DRFModel, hex.tree.drf.DRFModel.DRFParameters, hex.tree.drf.DRFModel.DRFOutput> {
protected int _mtry;
@Override public ModelCategory[] can_build() {
return new ModelCategory[]{
ModelCategory.Regression,
ModelCategory.Binomial,
ModelCategory.Multinomial,
};
}
@Override public BuilderVisibility builderVisibility() { return BuilderVisibility.Stable; }
// Called from an http request
public DRF( hex.tree.drf.DRFModel.DRFParameters parms) { super("DRF",parms); init(false); }
@Override public DRFV3 schema() { return new DRFV3(); }
/** Start the DRF training Job on an F/J thread. */
@Override public Job trainModel() {
return start(new DRFDriver(), _parms._ntrees/*work for progress bar*/);
}
/** Initialize the ModelBuilder, validating all arguments and preparing the
* training frame. This call is expected to be overridden in the subclasses
* and each subclass will start with "super.init();". This call is made
* by the front-end whenever the GUI is clicked, and needs to be fast;
* heavy-weight prep needs to wait for the trainModel() call.
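*
* A minimal sketch of the expected subclass pattern (the check shown is
* hypothetical; the super.init() call and error() helper are the ones used
* in this class):
* <pre>{@code
* @Override public void init(boolean expensive) {
*   super.init(expensive);            // always validate shared arguments first
*   if (_parms._sample_rate <= 0)     // cheap, GUI-speed checks only
*     error("_sample_rate", "must be positive");
*   // heavy-weight preparation belongs in trainModel(), not here
* }
* }</pre>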
*/
@Override public void init(boolean expensive) {
super.init(expensive);
// Initialize local variables
if (!(0.0 < _parms._sample_rate && _parms._sample_rate <= 1.0))
throw new IllegalArgumentException("Sample rate must be in the interval (0,1] but it is " + _parms._sample_rate);
if( _parms._mtries < 1 && _parms._mtries != -1 ) error("_mtries", "mtries must be -1 (converted to sqrt(#features)) or >= 1, but it is " + _parms._mtries);
if( _train != null ) {
int ncols = _train.numCols();
if( _parms._mtries != -1 && !(1 <= _parms._mtries && _parms._mtries < ncols))
error("_mtries","Computed mtries should be -1 or in the interval [1,#cols) but it is " + _parms._mtries);
}
if (_parms._sample_rate == 1f && _valid == null)
error("_sample_rate", "Sample rate is 100% and no validation dataset is specified: there is no OOB data to compute the out-of-bag error estimate!");
}
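// A minimal usage sketch (hypothetical parameter values; the field names are
// the ones validated in init() above, and the DRFParameters constructor is
// assumed to follow the usual no-arg pattern):
//   DRFModel.DRFParameters p = new DRFModel.DRFParameters();
//   p._ntrees = 50;                // number of trees to build
//   p._mtries = -1;                // -1 => sqrt(#features) for classification
//   p._sample_rate = 0.632f;       // must lie in (0,1]
//   Job j = new DRF(p).trainModel();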
// A standard DTree with a few more bits. Support for sampling during
// training, and replaying the sample later on the identical dataset to
// e.g. compute OOBEE.
static class DRFTree extends DTree {
final int _mtrys; // Number of columns to choose amongst in splits
final long _seeds[]; // One seed for each chunk, for sampling
final transient Random _rand; // RNG for split decisions & sampling
DRFTree( Frame fr, int ncols, char nbins, char nbins_cats, char nclass, int min_rows, int mtrys, long seed ) {
super(fr._names, ncols, nbins, nbins_cats, nclass, min_rows, seed);
_mtrys = mtrys;
_rand = createRNG(seed);
_seeds = new long[fr.vecs()[0].nChunks()];
for( int i=0; i<_seeds.length; i++ )
_seeds[i] = _rand.nextLong();
}
// Return a deterministic chunk-local RNG. Can be kinda expensive.
public Random rngForChunk( int cidx ) {
long seed = _seeds[cidx];
return createRNG(seed);
}
}
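// A sketch of the determinism this buys (hypothetical chunk index): two RNGs
// obtained for the same chunk replay an identical stream, which is how the
// sample taken during training can be re-derived later for OOB scoring.
//   Random r1 = tree.rngForChunk(42);
//   Random r2 = tree.rngForChunk(42);
//   assert r1.nextLong() == r2.nextLong();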
/** Fill work columns:
* - classification: set 1 in the corresponding wrk col according to row response
* - regression: copy response into work column (there is only 1 work column)
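*
* A worked example (hypothetical values): with nclass=3 and a row whose
* response is class 1, that row's work columns become {0, 1, 0}; for
* regression, wrk[0] is simply the row's response value.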
*/
private class SetWrkTask extends MRTask<SetWrkTask> {
@Override public void map( Chunk chks[] ) {
Chunk cy = chk_resp(chks);
for( int i=0; i<cy._len; i++ ) {
if( cy.isNA(i) ) continue;
if( isClassifier() ) chk_work(chks,(int)cy.at8(i)).set(i,1L); // Mark the response class
else chk_work(chks,0).set(i,(float)cy.atd(i)); // Copy response for regression
}
}
}
// ----------------------
private class DRFDriver extends Driver {
// --- Private data handled only on the master node
// Tree votes (classification) or SSE (regression) of individual trees on OOB rows
public transient TreeMeasuresCollector.TreeMeasures _treeMeasuresOnOOB;
private void initTreeMeasurements() {
_treeMeasuresOnOOB = isClassifier()
? new TreeMeasuresCollector.TreeVotes(_parms._ntrees)
: new TreeMeasuresCollector.TreeSSE(_parms._ntrees);
}
@Override protected void buildModel() {
// Pick mtry: default (-1) is sqrt(#cols) for classification, #cols/3 for regression
_mtry = (_parms._mtries==-1)
? (isClassifier() ? Math.max((int)Math.sqrt(_ncols),1) : Math.max(_ncols/3,1))
: _parms._mtries;
if( !(1 <= _mtry && _mtry <= _ncols) )
throw new IllegalArgumentException("Computed mtry must be in the interval [1,#cols] but it is " + _mtry);
// Initialize TreeVotes for classification, MSE arrays for regression
initTreeMeasurements();
// Append number of trees participating in on-the-fly scoring
_train.add("OUT_BAG_TREES", _response.makeZero());
// Prepare working columns
new SetWrkTask().doAll(_train);
// If there was a checkpoint, recompute the tree_* work columns and OOB counts
// from the previous trees' predictions, but only if OOB validation is requested.
if (_parms._checkpoint) {
Timer t = new Timer();
// Compute oob votes for each output level
new OOBScorer(_ncols, _nclass, _parms._sample_rate, _model._output._treeKeys).doAll(_train);
Log.info("Reconstructing oob stats from checkpointed model took " + t);
}
// The RNG used to pick split columns
Random rand = createRNG(_parms._seed);
// To be deterministic get random numbers for previous trees and
// put random generator to the same state
for (int i=0; i<_ntreesFromCheckpoint; i++) rand.nextLong();
int tid;
DTree[] ktrees = null;
// Prepare tree statistics
// Build trees until we hit the limit
for( tid=0; tid<_parms._ntrees; tid++) { // Building tid-tree
if (tid!=0 || !_parms._checkpoint) { // do not make initial scoring if model already exists
double training_r2 = doScoringAndSaveModel(false, true, _parms._build_tree_one_node);
if( training_r2 >= _parms._r2_stopping )
return; // Stop when approaching round-off error
}
// At each iteration build K trees (K = nclass = response column domain size)
// TODO: parallelize more? build more than k trees at each time, we need to care about temporary data
// Idea: launch more DRF at once.
Timer kb_timer = new Timer();
buildNextKTrees(_train,_mtry,_parms._sample_rate,rand,tid);
Log.info((tid+1) + ". tree was built in " + kb_timer.toString());
DRF.this.update(1);
if( !isRunning() ) return; // If canceled during building, do not bulkscore
}
doScoringAndSaveModel(true, true, _parms._build_tree_one_node);
}
// --------------------------------------------------------------------------
// Build the next K random trees (one per class) for the tid-th iteration
private void buildNextKTrees(Frame fr, int mtrys, float sample_rate, Random rand, int tid) {
// We're going to build K (nclass) trees - each one predicting
// membership of a single response class.
final DTree[] ktrees = new DTree[_nclass];
// Initial set of histograms. All trees; one leaf per tree (the root
// leaf); all columns
DHistogram hcs[][][] = new DHistogram[_nclass][1/*just root leaf*/][_ncols];
// Adjust real bins for the top-levels
int adj_nbins = Math.max(_parms._nbins_top_level,_parms._nbins);
// Use the same seed for all k-trees. NOTE: this is only to give all
// k-trees the same (fair) view of the sampled rows.
final long[] _distribution = _model._output._distribution;
long rseed = rand.nextLong();
// Initially setup as-if an empty-split had just happened
for (int k = 0; k < _nclass; k++) {
if (_distribution[k] != 0) { // Ignore missing classes
// The Boolean Optimization
// This optimization assumes the 2nd tree of a 2-class system is the
// inverse of the first (and that the same columns were picked)
if( k==1 && _nclass==2 && !_parms._binomial_double_trees) continue;
ktrees[k] = new DRFTree(fr, _ncols, (char)_parms._nbins, (char)_parms._nbins_cats, (char)_nclass, _parms._min_rows, mtrys, rseed);
boolean isBinom = isClassifier();
new DRFUndecidedNode(ktrees[k], -1, DHistogram.initialHist(fr, _ncols, adj_nbins, _parms._nbins_cats, hcs[k][0], isBinom)); // The "root" node
}
}
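// Effect of the optimization above (sketch, hypothetical case): with
// _nclass==2 and !_binomial_double_trees only ktrees[0] is built, and
// score1() below recovers the class-1 probability as fs[2] = 1 - fs[1].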
// Sample - mark the lines by putting 'OUT_OF_BAG' into nid() vector
Timer t_1 = new Timer();
Sample ss[] = new Sample[_nclass];
for( int k=0; k<_nclass; k++)
if (ktrees[k] != null) ss[k] = new Sample((DRFTree)ktrees[k], sample_rate).dfork(0,new Frame(vec_nids(fr,k),vec_resp(fr)), _parms._build_tree_one_node);
for( int k=0; k<_nclass; k++)
if( ss[k] != null ) ss[k].getResult();
Log.debug("Sampling took: + " + t_1);
int[] leafs = new int[_nclass]; // Define a "working set" of leaf splits, from leafs[i] to tree._len for each tree i
// ----
// One Big Loop till the ktrees are of proper depth.
// Adds a layer to the trees each pass.
Timer t_2 = new Timer();
int depth=0;
for( ; depth<_parms._max_depth; depth++ ) {
if( !isRunning() ) return;
hcs = buildLayer(fr, _parms._nbins, _parms._nbins_cats, ktrees, leafs, hcs, true, _parms._build_tree_one_node);
// If we did not make any new splits, then the tree is split-to-death
if( hcs == null ) break;
}
Log.debug("Tree build took: " + t_2);
// Each tree bottomed-out in a DecidedNode; go 1 more level and insert
// LeafNodes to hold predictions.
Timer t_3 = new Timer();
for( int k=0; k<_nclass; k++ ) {
DTree tree = ktrees[k];
if( tree == null ) continue;
int leaf = leafs[k] = tree.len();
for( int nid=0; nid<leaf; nid++ ) {
if( tree.node(nid) instanceof DecidedNode ) {
// ... (elided: a DRFLeafNode holding this subtree's prediction is
// inserted under each decided node)
}
}
}
Log.debug("Nodes to leaves took: " + t_3);
// ... (elided: rows are moved to their final leaves via CollectPreds and
// the model is grown by the K new trees)
}
// --------------------------------------------------------------------------
// Collect current predictions/votes of the K trees over their OOB rows.
private class CollectPreds extends MRTask<CollectPreds> {
/* @IN */ final DTree _trees[]; // Read-only, shared (except at the histograms in the Nodes)
/* @IN */ double _threshold; // Classification threshold used when picking the predicted class
/* @OUT */ long rightVotes; // Number of correct votes over OOB rows cast by the trees in _trees
/* @OUT */ long allRows; // number of all OOB rows (sampled by this tree)
/* @OUT */ float sse; // Sum of squares for this tree only
CollectPreds(DTree trees[], int leafs[], double threshold) { _trees=trees; _threshold = threshold; }
final boolean importance = true; // Always collect the per-tree measures used for variable importance
@Override public void map( Chunk[] chks ) {
final Chunk y = importance ? chk_resp(chks) : null; // Response
final double[] rpred = importance ? new double[1+_nclass] : null; // Row prediction
final double[] rowdata = importance ? new double[_ncols] : null; // Pre-allocated row data
final Chunk oobt = chk_oobt(chks); // Out-of-bag rows counter over all trees
// Iterate over all rows
for( int row=0; row<oobt._len; row++ ) {
// ... (elided: per-row OOB prediction, vote counting and SSE accumulation)
}
}
@Override public void reduce( CollectPreds mrt ) {
rightVotes += mrt.rightVotes;
allRows += mrt.allRows;
sse += mrt.sse;
}
}
} // end of DRFDriver
// ---
// DRF decided node: creates the DRF-specific undecided nodes under it.
static class DRFDecidedNode extends DecidedNode {
DRFDecidedNode( UndecidedNode n, DHistogram[] hs ) { super(n,hs); }
@Override public DRFUndecidedNode makeUndecidedNode( DHistogram[] hs ) {
return new DRFUndecidedNode(_tree,_nid,hs);
}
}
// ---
// DRF undecided node: randomly select mtry columns to 'score' in the
// following pass over the data.
static class DRFUndecidedNode extends UndecidedNode {
DRFUndecidedNode( DTree tree, int pid, DHistogram[] hs ) { super(tree,pid,hs); }
@Override public int[] scoreCols( DHistogram[] hs ) {
DRFTree tree = (DRFTree)_tree;
int[] cols = new int[hs.length];
int len=0;
// Gather all active columns to choose from.
for( int i=0; i<hs.length; i++ ) {
if( hs[i]==null ) continue; // Ignore not-tracked cols
assert hs[i]._nbins > 1 : "broken histo range "+hs[i];
cols[len++] = i; // Gather active column
}
int choices = len; // Number of columns I can choose from
assert choices > 0;
// Draw up to mtry columns at random without replacement.
for( int i=0; i<tree._mtrys; i++ ) {
if( len == 0 ) break; // Out of choices!
int idx2 = tree._rand.nextInt(len);
int col = cols[idx2]; // The chosen column
cols[idx2] = cols[--len]; // Compress the chosen column out of 'cols'
cols[len] = col; // ...and swap it in just after 'len'
}
assert choices - len > 0;
return Arrays.copyOfRange(cols, len, choices);
}
}
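// A worked example of the draw above (hypothetical picks): with active cols
// {0,1,2,3} (len=4, choices=4) and mtry=2, picking idx2=2 then idx2=0 swaps
// the chosen columns to the tail, leaving cols={3,1,0,2} and len=2, so
// Arrays.copyOfRange(cols,2,4) returns the chosen pair {0,2}.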
// ---
static class DRFLeafNode extends LeafNode {
DRFLeafNode( DTree tree, int pid ) { super(tree,pid); }
DRFLeafNode( DTree tree, int pid, int nid ) { super(tree, pid, nid); }
// Insert just the predictions: a single byte/short if we are predicting a
// single class, or else the full distribution.
@Override protected AutoBuffer compress(AutoBuffer ab) { assert !Double.isNaN(_pred); return ab.put4f(_pred); }
@Override protected int size() { return 4; }
}
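// Example of the compressed form (hypothetical value): a leaf with
// _pred = 0.5f serializes via ab.put4f(0.5f) into exactly the 4 bytes
// that size() reports.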
// Deterministic sampling
static class Sample extends MRTask<Sample> {
final DRFTree _tree;
final float _rate;
Sample( DRFTree tree, float rate ) { _tree = tree; _rate = rate; }
@Override public void map( Chunk nids, Chunk ys ) {
Random rand = _tree.rngForChunk(nids.cidx());
for( int row=0; row<nids._len; row++ )
if( rand.nextFloat() >= _rate || Double.isNaN(ys.atd(row)) ) {
nids.set(row, ScoreBuildHistogram.OUT_OF_BAG); // Flag row as being ignored by sampling
}
}
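// Expected split of the sampler above (sketch, hypothetical rate): with
// _rate = 0.632f roughly 63% of a chunk's rows stay in-bag for tree
// building, and the remaining ~37% are flagged OUT_OF_BAG and later feed
// the out-of-bag error estimate.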
}
// Read the 'tree' columns, do model-specific math and put the results in the
// fs[] array, and return the sum. Dividing any fs[] element by the sum
// turns the results into a probability distribution.
@Override protected double score1( Chunk chks[], double fs[/*nclass*/], int row ) {
double sum = 0;
if (_nclass > 2 || (_nclass == 2 && _parms._binomial_double_trees) ) { //multinomial or binomial with 1 tree per class
for (int k = 0; k < _nclass; k++)
sum += (fs[k+1] = chk_tree(chks, k).atd(row));
}
else if (_nclass==2 && !_parms._binomial_double_trees) { //binomial optimization
fs[1] = chk_tree(chks, 0).atd(row);
assert(fs[1] >= 0 && fs[1] <= 1);
fs[2] = 1. - fs[1];
}
else { //regression
// Average the per-tree predictions for this row (only over trees that had this row out-of-bag)
sum += (fs[0] = chk_tree(chks, 0).atd(row) / chk_oobt(chks).atd(row) );
fs[1] = 0;
}
return sum;
}
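// Worked example of score1 (hypothetical values): in the binomial shortcut,
// if chk_tree(chks,0) holds 0.3 for a row then fs[1]=0.3 and fs[2]=0.7,
// already a distribution; in the regression branch, a tree-column sum of
// 12.0 over an oobt count of 4 yields fs[0] = 12.0/4 = 3.0.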
}