All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.tree.TreeJCodeGen Maven / Gradle / Ivy

package hex.tree;

import java.util.Arrays;

import water.util.IcedBitSet;
import water.util.SB;

/** A tree code generator producing Java code representation of the tree:
 *
 *  - A generated class contains score0 method
 *  - if score0 method is too long,
 *  - too long methods are
 */
class TreeJCodeGen extends TreeVisitor {
  public static final int MAX_NODES = (1 << 12) / 4; // limit for a number decision nodes
  //public static final int MAX_NODES = 5; // limit for a number decision nodes
  private static final int MAX_DEPTH = 70;
  final byte  _bits[]  = new byte [MAX_DEPTH];
  final float _fs  []  = new float[MAX_DEPTH];
  final SB    _sbs []  = new SB   [MAX_DEPTH];
  final int   _nodesCnt[]= new int[MAX_DEPTH];
  final SB    _grpSplits[] = new SB[MAX_DEPTH];
  final int   _grpSplitsCnt[] = new int[MAX_DEPTH];
  final String _javaClassName;
  final SharedTreeModel _tm;
  SB _sb;
  SB _csb;
  SB _grpsplit;

  int _subtrees = 0;
  int _grpCnt = 0;

  final private boolean _verboseCode;

  public TreeJCodeGen(SharedTreeModel tm, CompressedTree ct, SB sb, String javaClassName, boolean verboseCode) {
    super(ct);
    _tm = tm;
    _sb = sb;
    _csb = new SB();
    _grpsplit = new SB();
    _verboseCode = verboseCode;
    _javaClassName = javaClassName;
  }

  // code preamble
  protected void preamble(SB sb, int subtree) throws RuntimeException {
    String subt = subtree > 0 ? "_" + String.valueOf(subtree) : "";
    sb.p("class ").p(_javaClassName).p(subt).p(" {").nl().ii(1);
    sb.ip("static final double score0").p("(double[] data) {").nl().ii(1); // predict method for one tree
    sb.ip("double pred = ");
  }

  // close the code
  protected void closure(SB sb) throws RuntimeException {
    sb.p(";").nl();
    sb.ip("return pred;").nl().di(1);
    sb.ip("}").nl(); // close the method
    // Append actual group splits
    _sb.p(_grpsplit);
    sb.di(1).ip("}").nl().nl(); // close the class
  }

  @Override protected void pre( int col, float fcmp, IcedBitSet gcmp, int equal ) {
    if( _depth > 0 ) {
      int b = _bits[_depth-1];
      assert b > 0 : Arrays.toString(_bits)+"\n"+_sb.toString();
      if( b==1         ) _bits[_depth-1]=3;
      if( b==1 || b==2 ) _sb.p('\n').i(_depth).p("?");
      if( b==2         ) _sb.p(' ').pj(_fs[_depth-1]); // Dump the leaf containing float value
      if( b==2 || b==3 ) _sb.p('\n').i(_depth).p(":");
    }
    // Switch to a new generator
    if (_nodes>MAX_NODES) {
      _sb.p(_javaClassName).p('_').p(_subtrees).p(".score0").p("(data)");
      _nodesCnt[_depth] = _nodes;
      _sbs[_depth] = _sb;
      _grpSplits[_depth] = _grpsplit;
      _grpSplitsCnt[_depth] = _grpCnt;
      _sb = new SB();
      _nodes = 0;
      _grpsplit = new SB();
      _grpCnt = 0;
      preamble(_sb, _subtrees);
      _subtrees++;
    }
    // Generates array for group splits
    if(equal == 2 || equal == 3 && gcmp != null) {
      _grpsplit.i(1).p("// ").p(gcmp.toString()).nl();
      _grpsplit.i(1).p("public static final byte[] GRPSPLIT").p(_grpCnt).p(" = new byte[] ").p(gcmp.toStrArray()).p(";").nl();
    }
    // Generates decision
    _sb.p(" (");
    if(equal == 0 || equal == 1) {
      _sb.p("data[").p(col);
      // Generate column names only if necessary
      if (_verboseCode) {
        _sb.p(" /* ").p(_tm._output._names[col]).p(" */");
      }
      _sb.p("] ").p(equal == 1 ? "!= " : "<").pj(fcmp); // then left and then right (left is !=)
    } else {
      gcmp.toJava(_sb, "GRPSPLIT"+_grpCnt,col,_tm._output._names[col]);
      _grpCnt++;
    }
    assert _bits[_depth]==0;
    _bits[_depth]=1;
  }
  @Override protected void leaf( float pred  ) {
    assert _depth==0 || _bits[_depth-1] > 0 : Arrays.toString(_bits); // it can be degenerated tree
    if( _depth==0) { // it is de-generated tree
      _sb.pj(pred);
    } else if( _bits[_depth-1] == 1 ) { // No prior leaf; just memorize this leaf
      _bits[_depth-1]=2; _fs[_depth-1]=pred;
    } else {          // Else==2 (prior leaf) or 3 (prior tree)
      if( _bits[_depth-1] == 2 ) _sb.p(" ? ").pj(_fs[_depth-1]).p(" ");
      else                       _sb.p('\n').i(_depth);
      _sb.p(": ").pj(pred);
    }
  }
  @Override protected void post( int col, float fcmp, int equal ) {
    _sb.p(')');
    _bits[_depth]=0;
    if (_sbs[_depth]!=null) {
      closure(_sb);
      _csb.p(_sb);
      _sb = _sbs[_depth];
      _nodes = _nodesCnt[_depth];
      _sbs[_depth] = null;
      _grpsplit = _grpSplits[_depth];
      _grpCnt = _grpSplitsCnt[_depth];
      _grpSplits[_depth] = null;
    }
  }
  public void generate() {
    preamble(_sb, _subtrees++);   // TODO: Need to pass along group split BitSet
    visit();
    closure(_sb);
    _sb.p(_csb);
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy