All downloads are free. Search and download functionality uses the official Maven repository.

hex.tree.DTree Maven / Gradle / Ivy

package hex.tree;

import water.*;
import water.fvec.Chunk;
import water.util.*;

import java.util.*;

/** A Decision Tree, laid over a Frame of Vecs, and built distributed.
 *
 *  

 *  <p>This class defines an explicit Tree structure, as a collection of
 *  {@code DTree} {@code Node}s.  The Nodes are numbered with a unique
 *  {@code _nid}.  Users need to maintain their own mapping from their data to
 *  a {@code _nid}, where the obvious technique is to have a Vec of
 *  {@code _nid}s (ints), one per element of the data Vecs.
 *

Each {@code Node} has a {@code DHistogram}, describing summary data * about the rows. The DHistogram requires a pass over the data to be filled * in, and we expect to fill in all rows for Nodes at the same depth at the * same time. i.e., a single pass over the data will fill in all leaf Nodes' * DHistograms at once. * * @author Cliff Click */ public class DTree extends Iced { final String[] _names; // Column names final int _ncols; // Active training columns final char _nbins; // Max number of bins to split over final char _nclass; // #classes, or 1 for regression trees final int _min_rows; // Fewest allowed rows in any split final long _seed; // RNG seed; drives sampling seeds if necessary private Node[] _ns; // All the nodes in the tree. Node 0 is the root. public int _len; // Resizable array // Public stats about tree public int _leaves; public int _depth; public DTree( String[] names, int ncols, char nbins, char nclass, int min_rows ) { this(names,ncols,nbins,nclass,min_rows,-1); } public DTree( String[] names, int ncols, char nbins, char nclass, int min_rows, long seed ) { _names = names; _ncols = ncols; _nbins=nbins; _nclass=nclass; _min_rows = min_rows; _ns = new Node[1]; _seed = seed; } public final Node root() { return _ns[0]; } // One-time local init after wire transfer void init_tree( ) { for( int j=0; j<_len; j++ ) _ns[j]._tree = this; } // Return Node i public final Node node( int i ) { if( i >= _len ) throw new ArrayIndexOutOfBoundsException(i); return _ns[i]; } public final UndecidedNode undecided( int i ) { return (UndecidedNode)node(i); } public final DecidedNode decided( int i ) { return ( DecidedNode)node(i); } // Get a new node index, growing innards on demand private synchronized int newIdx(Node n) { if( _len == _ns.length ) _ns = Arrays.copyOf(_ns,_len<<1); _ns[_len] = n; return _len++; } public final int len() { return _len; } // -------------------------------------------------------------------------- // Abstract node flavor public static 
abstract class Node extends Iced { transient protected DTree _tree; // Make transient, lest we clone the whole tree final public int _pid; // Parent node id, root has no parent and uses -1 final protected int _nid; // My node-ID, 0 is root Node( DTree tree, int pid, int nid ) { _tree = tree; _pid=pid; tree._ns[_nid=nid] = this; } Node( DTree tree, int pid ) { _tree = tree; _pid=pid; _nid = tree.newIdx(this); } // Recursively print the decision-line from tree root to this child. StringBuilder printLine(StringBuilder sb ) { if( _pid==-1 ) return sb.append("[root]"); DecidedNode parent = _tree.decided(_pid); parent.printLine(sb).append(" to "); return parent.printChild(sb,_nid); } abstract public StringBuilder toString2(StringBuilder sb, int depth); abstract protected AutoBuffer compress(AutoBuffer ab); abstract protected int size(); public final int nid() { return _nid; } } // -------------------------------------------------------------------------- // Records a column, a bin to split at within the column, and the MSE. public static class Split extends Iced { final public int _col, _bin;// Column to split, bin where being split final IcedBitSet _bs; // For binary y and categorical x (with >= 4 levels), split into 2 non-contiguous groups final byte _equal; // Split is 0: <, 1: == with single split point, 2: == with group split (<= 32 levels), 3: == with group split (> 32 levels) final double _se0, _se1; // Squared error of each subsplit final long _n0, _n1; // Rows in each final split final double _p0, _p1; // Predicted value for each split public Split( int col, int bin, IcedBitSet bs, byte equal, double se0, double se1, long n0, long n1, double p0, double p1 ) { _col = col; _bin = bin; _bs = bs; _equal = equal; _n0 = n0; _n1 = n1; _se0 = se0; _se1 = se1; _p0 = p0; _p1 = p1; } public final double se() { return _se0+_se1; } public final int col() { return _col; } public final int bin() { return _bin; } // Split-at dividing point. 
Don't use the step*bin+bmin, due to roundoff // error we can have that point be slightly higher or lower than the bin // min/max - which would allow values outside the stated bin-range into the // split sub-bins. Always go for a value which splits the nearest two // elements. float splat(DHistogram hs[]) { DHistogram h = hs[_col]; assert _bin > 0 && _bin < h.nbins(); if( _equal == 1 ) { assert h.bins(_bin)!=0; return h.binAt(_bin); } // Find highest non-empty bin below the split int x=_bin-1; while( x >= 0 && h.bins(x)==0 ) x--; // Find lowest non-empty bin above the split int n=_bin; while( n < h.nbins() && h.bins(n)==0 ) n++; // Lo is the high-side of the low non-empty bin, rounded to int for int columns // Hi is the low -side of the hi non-empty bin, rounded to int for int columns // Example: Suppose there are no empty bins, and we are splitting an // integer column at 48.4 (more than nbins, so step != 1.0, perhaps // step==1.8). The next lowest non-empty bin is from 46.6 to 48.4, and // we set lo=48.4. The next highest non-empty bin is from 48.4 to 50.2 // and we set hi=48.4. Since this is an integer column, we round lo to // 48 (largest integer below the split) and hi to 49 (smallest integer // above the split). Finally we average them, and split at 48.5. float lo = h.binAt(x+1); float hi = h.binAt(n ); if( h._isInt > 0 ) lo = h._step==1 ? lo-1 : (float)Math.floor(lo); if( h._isInt > 0 ) hi = h._step==1 ? hi : (float)Math.ceil (hi); return (lo+hi)/2.0f; } // Split a DHistogram. Return null if there is no point in splitting // this bin further (such as there's fewer than min_row elements, or zero // error in the response column). Return an array of DHistograms (one // per column), which are bounded by the split bin-limits. If the column // has constant data, or was not being tracked by a prior DHistogram // (for being constant data from a prior split), then that column will be // null in the returned array. 
public DHistogram[] split( int way, char nbins, int min_rows, DHistogram hs[], float splat ) { long n = way==0 ? _n0 : _n1; if( n < min_rows || n <= 1 ) return null; // Too few elements double se = way==0 ? _se0 : _se1; if( se <= 1e-30 ) return null; // No point in splitting a perfect prediction // Build a next-gen split point from the splitting bin int cnt=0; // Count of possible splits DHistogram nhists[] = new DHistogram[hs.length]; // A new histogram set for( int j=0; j>1,nbins); // min & max come from the original column data, since splitting on an // unrelated column will not change the j'th columns min/max. // Tighten min/max based on actual observed data for tracked columns float min, maxEx; if( h._bins == null ) { // Not tracked this last pass? min = h._min; // Then no improvement over last go maxEx = h._maxEx; } else { // Else pick up tighter observed bounds min = h.find_min(); // Tracked inclusive lower bound if( h.find_maxIn() == min ) continue; // This column will not split again maxEx = h.find_maxEx(); // Exclusive max } // Tighter bounds on the column getting split: exactly each new // DHistogram's bound are the bins' min & max. if( _col==j ) { if( _equal != 0 ) { // Equality split; no change on unequals-side if( way == 1 ) continue; // but know exact bounds on equals-side - and this col will not split again } else { // Less-than split if( h._bins[_bin]==0 ) throw H2O.unimpl(); // Here I should walk up & down same as split() above. 
float split = splat; if( h._isInt > 0 ) split = (float)Math.ceil(split); if( way == 0 ) maxEx= split; else min = split; } } if( MathUtils.equalsWithinOneSmallUlp(min, maxEx) ) continue; // This column will not split again if( h._isInt > 0 && !(min+1 < maxEx ) ) continue; // This column will not split again if( min > maxEx ) continue; // Happens for all-NA subsplits assert min < maxEx && n > 1 : ""+min+"<"+maxEx+" n="+n; nhists[j] = DHistogram.make(h._name,adj_nbins,h._isInt,min,maxEx,n,h._doGrpSplit,h.isBinom()); cnt++; // At least some chance of splitting } return cnt == 0 ? null : nhists; } public static StringBuilder ary2str( StringBuilder sb, int w, long xs[] ) { sb.append('['); for( long x : xs ) UndecidedNode.p(sb,x,w).append(","); return sb.append(']'); } public static StringBuilder ary2str( StringBuilder sb, int w, float xs[] ) { sb.append('['); for( float x : xs ) UndecidedNode.p(sb,x,w).append(","); return sb.append(']'); } public static StringBuilder ary2str( StringBuilder sb, int w, double xs[] ) { sb.append('['); for( double x : xs ) UndecidedNode.p(sb,(float)x,w).append(","); return sb.append(']'); } @Override public String toString() { StringBuilder sb = new StringBuilder(); sb.append("{").append(_col).append("/"); UndecidedNode.p(sb,_bin,2); sb.append(", se0=").append(_se0); sb.append(", se1=").append(_se1); sb.append(", n0=" ).append(_n0 ); sb.append(", n1=" ).append(_n1 ); return sb.append("}").toString(); } } // -------------------------------------------------------------------------- // An UndecidedNode: Has a DHistogram which is filled in (in parallel // with other histograms) in a single pass over the data. Does not contain // any split-decision. 
public static abstract class UndecidedNode extends Node { public transient DHistogram[] _hs; public final int _scoreCols[]; // A list of columns to score; could be null for all public UndecidedNode( DTree tree, int pid, DHistogram[] hs ) { super(tree,pid); assert hs.length==tree._ncols; _scoreCols = scoreCols(_hs=hs); } // Pick a random selection of columns to compute best score. // Can return null for 'all columns'. abstract public int[] scoreCols( DHistogram[] hs ); // Make the parent of this Node use a -1 NID to prevent the split that this // node otherwise induces. Happens if we find out too-late that we have a // perfect prediction here, and we want to turn into a leaf. public void do_not_split( ) { if( _pid == -1 ) return; // skip root DecidedNode dn = _tree.decided(_pid); for( int i=0; i nbins ) nbins = _hs[j].nbins(); for( int i=0; i w ) s = String.format("%4.1f",d); if( s.length() > w ) s = String.format("%4.0f",d); return p(sb,s,w); } @Override public StringBuilder toString2(StringBuilder sb, int depth) { for( int d=0; d= // T | != == public final int _nids[]; // Children NIDS for the split LEFT, RIGHT transient byte _nodeType; // Complex encoding: see the compressed struct comments transient int _size = 0; // Compressed byte size of this subtree // Make a correctly flavored Undecided public abstract UndecidedNode makeUndecidedNode(DHistogram hs[]); // Pick the best column from the given histograms public abstract Split bestCol( UndecidedNode u, DHistogram hs[] ); public DecidedNode( UndecidedNode n, DHistogram hs[] ) { super(n._tree,n._pid,n._nid); // Replace Undecided with this DecidedNode _nids = new int[2]; // Split into 2 subsets _split = bestCol(n,hs); // Best split-point for this tree if( _split._col == -1 ) { // No good split? // Happens because the predictor columns cannot split the responses - // which might be because all predictor columns are now constant, or // because all responses are now constant. 
_splat = Float.NaN; Arrays.fill(_nids,-1); return; } _splat = (_split._equal == 0 || _split._equal == 1) ? _split.splat(hs) : -1; // Split-at value (-1 for group-wise splits) final char nbins = _tree._nbins; final int min_rows = _tree._min_rows; for( int b=0; b<2; b++ ) { // For all split-points // Setup for children splits DHistogram nhists[] = _split.split(b,nbins,min_rows,hs,_splat); assert nhists==null || nhists.length==_tree._ncols; _nids[b] = nhists == null ? -1 : makeUndecidedNode(nhists)._nid; } } // Bin #. public int bin( Chunk chks[], int row ) { float d = (float)chks[_split._col].at0(row); // Value to split on for this row if( Float.isNaN(d) ) // Missing data? return 0; // NAs always to bin 0 // Note that during *scoring* (as opposed to training), we can be exposed // to data which is outside the bin limits. if(_split._equal == 0) return d < _splat ? 0 : 1; else if(_split._equal == 1) return d != _splat ? 0 : 1; else return _split._bs.contains((int)d) ? 1 : 0; // return _split._equal ? (d != _splat ? 0 : 1) : (d < _splat ? 0 : 1); } public int ns( Chunk chks[], int row ) { return _nids[bin(chks,row)]; } @Override public String toString() { if( _split._col == -1 ) return "Decided has col = -1"; int col = _split._col; if( _split._equal == 1 ) return _tree._names[col]+" != "+_splat+"\n"+ _tree._names[col]+" == "+_splat+"\n"; else if( _split._equal == 2 || _split._equal == 3 ) return _tree._names[col]+" != "+_split._bs.toString()+"\n"+ _tree._names[col]+" == "+_split._bs.toString()+"\n"; return _tree._names[col]+" < "+_splat+"\n"+ _splat+" <="+_tree._names[col]+"\n"; } StringBuilder printChild( StringBuilder sb, int nid ) { int i = _nids[0]==nid ? 0 : 1; assert _nids[i]==nid : "No child nid "+nid+"? " +Arrays.toString(_nids); sb.append("[").append(_tree._names[_split._col]); sb.append(_split._equal != 0 ? (i==0 ? " != " : " == ") : (i==0 ? " < " : " >= ")); sb.append((_split._equal == 2 || _split._equal == 3) ? 
_split._bs.toString() : _splat).append("]"); return sb; } @Override public StringBuilder toString2(StringBuilder sb, int depth) { for( int i=0; i<_nids.length; i++ ) { for( int d=0; d= ")); sb.append((_split._equal == 2 || _split._equal == 3) ? _split._bs.toString() : _splat).append("\n"); } if( _nids[i] >= 0 && _nids[i] < _tree._len ) _tree.node(_nids[i]).toString2(sb,depth+1); } return sb; } // Size of this subtree; sets _nodeType also @Override public final int size(){ if( _size != 0 ) return _size; // Cached size assert _nodeType == 0:"unexpected node type: " + _nodeType; if(_split._equal != 0) _nodeType |= _split._equal == 1 ? 4 : (_split._equal == 2 ? 8 : 12); // int res = 7; // 1B node type + flags, 2B colId, 4B float split val // 1B node type + flags, 2B colId, 4B split val/small group or (2B offset + 2B size) + large group int res = _split._equal == 3 ? 7 + _split._bs.numBytes() : 7; Node left = _tree.node(_nids[0]); int lsz = left.size(); res += lsz; if( left instanceof LeafNode ) _nodeType |= (byte)(48 << 0*2); else { int slen = lsz < 256 ? 0 : (lsz < 65535 ? 1 : (lsz<(1<<24) ? 2 : 3)); _nodeType |= slen; // Set the size-skip bits res += (slen+1); // } Node rite = _tree.node(_nids[1]); if( rite instanceof LeafNode ) _nodeType |= (byte)(48 << 1*2); res += rite.size(); assert (_nodeType&0x33) != 51; assert res != 0; return (_size = res); } // Compress this tree into the AutoBuffer @Override public AutoBuffer compress(AutoBuffer ab) { int pos = ab.position(); if( _nodeType == 0 ) size(); // Sets _nodeType & _size both ab.put1(_nodeType); // Includes left-child skip-size bits assert _split._col != -1; // Not a broken root non-decision? 
ab.put2((short)_split._col); // Save split-at-value or group if(_split._equal == 0 || _split._equal == 1) ab.put4f(_splat); else if(_split._equal == 2) { /* byte[] ary = MemoryManager.malloc1(4); for(int i = 0; i < 4; i++) ary[i] = _split._bs._val[i]; ab.putA1(ary, 4); */ //ab.putA1(_split._bs._val, 4); throw H2O.unimpl(); // TODO: fold offset into IcedBitSet } else { assert _split._equal == 3; //ab.put2((char)_split._bs._offset); //ab.put2((char)_split._bs.numBytes()); //ab.putA1(_split._bs._val, _split._bs.numBytes()); throw H2O.unimpl(); // TODO: fold offset into IcedBitSet } Node left = _tree.node(_nids[0]); if( (_nodeType&48) == 0 ) { // Size bits are optional for left leaves ! int sz = left.size(); if(sz < 256) ab.put1( sz); else if (sz < 65535) ab.put2((short)sz); else if (sz < (1<<24)) ab.put3( sz); else ab.put4( sz); // 1<<31-1 } // now write the subtree in left.compress(ab); Node rite = _tree.node(_nids[1]); rite.compress(ab); assert _size == ab.position()-pos:"reported size = " + _size + " , real size = " + (ab.position()-pos); return ab; } } public static abstract class LeafNode extends Node { public double _pred; public LeafNode( DTree tree, int pid ) { super(tree,pid); } public LeafNode( DTree tree, int pid, int nid ) { super(tree,pid,nid); } @Override public String toString() { return "Leaf#"+_nid+" = "+_pred; } @Override public final StringBuilder toString2(StringBuilder sb, int depth) { for( int d=0; d { public static final int MAX_NODES = (1 << 12) / 4; // limit for a number decision nodes final byte _bits[] = new byte [100]; final float _fs [] = new float[100]; final SB _sbs [] = new SB [100]; final int _nodesCnt[] = new int [100]; SB _sb; SB _csb; SB _grpsplit; int _subtrees = 0; int _grpcnt = 0; public TreeJCodeGen(SharedTreeModel tm, CompressedTree ct, SB sb) { super(tm, ct); _sb = sb; _csb = new SB(); _grpsplit = new SB(); } // code preamble protected void preamble(SB sb, int subtree) throws RuntimeException { String subt = 
subtree>0?String.valueOf(subtree):""; sb.i().p("static final ").p(SharedTreeModel.PRED_TYPE).p(" predict").p(subt).p("(double[] data) {").nl().ii(1); // predict method for one tree sb.i().p(SharedTreeModel.PRED_TYPE).p(" pred = "); } // close the code protected void closure(SB sb) throws RuntimeException { sb.p(";").nl(); sb.i(1).p("return pred;").nl().di(1); sb.i().p("}").nl(); // sb.p(_grpsplit).di(1); } @Override protected void pre( int col, float fcmp, IcedBitSet gcmp, int equal ) { if(equal == 2 || equal == 3 && gcmp != null) { _grpsplit.i(1).p("// ").p(gcmp.toString()).nl(); _grpsplit.i(1).p("public static final byte[] GRPSPLIT").p(_grpcnt).p(" = new byte[] ").p(gcmp.toStrArray()).p(";").nl(); } if( _depth > 0 ) { int b = _bits[_depth-1]; assert b > 0 : Arrays.toString(_bits)+"\n"+_sb.toString(); if( b==1 ) _bits[_depth-1]=3; if( b==1 || b==2 ) _sb.p('\n').i(_depth).p("?"); if( b==2 ) _sb.p(' ').pj(_fs[_depth-1]); // Dump the leaf containing float value if( b==2 || b==3 ) _sb.p('\n').i(_depth).p(":"); } if (_nodes>MAX_NODES) { _sb.p("predict").p(_subtrees).p("(data)"); _nodesCnt[_depth] = _nodes; _sbs[_depth] = _sb; _sb = new SB(); _nodes = 0; preamble(_sb, _subtrees); _subtrees++; } // All NAs are going always to the left _sb.p(" (Double.isNaN(data[").p(col).p("]) || "); if(equal == 0 || equal == 1) { String scmp = _tm.isFromSpeeDRF() ? "<= " : "< "; _sb.p("(float) data[").p(col).p(" /* ").p(_tm._output._names[col]).p(" */").p("] ").p(equal == 1 ? 
"!= " : scmp).pj(fcmp); // then left and then right (left is !=) } else { //_sb.p("!water.genmodel.GeneratedModel.grpContains(GRPSPLIT").p(_grpcnt).p(", ").p(gcmp._offset).p(", (int) data[").p(col).p(" /* ").p(_tm._names[col]).p(" */").p("])"); _grpcnt++; throw H2O.unimpl(); // TODO: fold offset into IcedBitSet } assert _bits[_depth]==0; _bits[_depth]=1; } @Override protected void leaf( float pred ) { assert _depth==0 || _bits[_depth-1] > 0 : Arrays.toString(_bits); // it can be degenerated tree if( _depth==0) { // it is de-generated tree _sb.pj(pred); } else if( _bits[_depth-1] == 1 ) { // No prior leaf; just memorize this leaf _bits[_depth-1]=2; _fs[_depth-1]=pred; } else { // Else==2 (prior leaf) or 3 (prior tree) if( _bits[_depth-1] == 2 ) _sb.p(" ? ").pj(_fs[_depth-1]).p(" "); else _sb.p('\n').i(_depth); _sb.p(": ").pj(pred); } } @Override protected void post( int col, float fcmp, int equal ) { _sb.p(')'); _bits[_depth]=0; if (_sbs[_depth]!=null) { closure(_sb); _csb.p(_sb); _sb = _sbs[_depth]; _nodes = _nodesCnt[_depth]; _sbs[_depth] = null; } } public void generate() { preamble(_sb, _subtrees++); // TODO: Need to pass along group split BitSet visit(); closure(_sb); _sb.p(_grpsplit).di(1); _sb.p(_csb); } } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy