All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.tree.TreeJCodeGen Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.tree;

import hex.Model;
import water.util.IcedBitSet;
import water.util.SB;

/** A tree code generator producing Java code representation of the tree:
 *
 *  - A generated class contains score0 method
 *  - if score0 method is too long, it redirects prediction to a new subclass's score0 method
 *
 * <p>The generator is a {@link TreeVisitor}: {@link #pre}, {@link #mid}, {@link #post} and
 * {@link #leaf} are invoked while the tree is walked and together emit one nested ternary
 * expression {@code (cond ? left : right)} per decision node. Whenever the class currently
 * being generated exceeds one of the size limits below, the expression for the current subtree
 * is replaced by a call to {@code <className>_<n>.score0(data)} and the subtree is generated
 * into a separate class. The state of the interrupted class is saved in per-depth stacks and
 * restored when the subtree is finished.
 */
class TreeJCodeGen extends TreeVisitor {
  // NOTE(review): TreeVisitor is used as a raw type here while several overrides declare
  // `throws RuntimeException`; if TreeVisitor is generic over its exception type, a type
  // argument may have been lost - confirm against the TreeVisitor declaration.

  /** Limit for a number of decision nodes visited per generated class (use a tiny value, e.g. 5, to debug splitting). */
  public static final int MAX_NODES = (1 << 12) / 4;
  /** Initial capacity of the per-depth save stacks; the stacks grow on demand for deeper trees. */
  private static final int MAX_DEPTH = 70;

  /** Budget for the estimated constant pool usage of one generated class; keeps some space for method and string constants. */
  private static final int MAX_CONSTANT_POOL_SIZE = (1 << 16) - 4096;
  /** Budget for the estimated size of the static initializer of one generated class. */
  private static final int MAX_METHOD_SIZE = (1 << 16) - 4096;

  // Simulate stack since we need to preserve each info per generated class.
  // All six arrays are indexed by _depth and always have the same length.
  SB[]  _sbs          = new SB [MAX_DEPTH]; // code buffer of the interrupted class (non-null marks a saved frame)
  int[] _nodesCnt     = new int[MAX_DEPTH]; // visited-node counter of the interrupted class
  SB[]  _grpSplits    = new SB [MAX_DEPTH]; // group-split array declarations of the interrupted class
  int[] _grpSplitsCnt = new int[MAX_DEPTH]; // number of group-split arrays of the interrupted class
  int[] _constantPool = new int[MAX_DEPTH]; // estimated constant pool size of the interrupted class
  int[] _staticInit   = new int[MAX_DEPTH]; // estimated static initializer size of the interrupted class

  final String _javaClassName;
  final Model.Output _output;
  SB _sb;       // buffer of the class currently being generated
  SB _csb;      // accumulated code of all finished subtree classes
  SB _grpsplit; // group-split array declarations of the class currently being generated

  int _subtrees = 0;         // number of classes started so far; also the suffix of the next subtree class
  int _grpCnt = 0;           // number of group-split arrays in the current class
  int _constantPoolSize = 0; // estimated constant pool usage of the current class
  int _staticInitSize = 0;   // estimated static initializer size of the current class

  final private boolean _verboseCode;

  /**
   * Creates a generator for a single tree.
   *
   * @param output        model output; its {@code _names} and {@code _domains} are consulted while generating decisions
   * @param ct            the compressed tree to walk (handed to {@link TreeVisitor})
   * @param sb            buffer receiving the generated code
   * @param javaClassName name of the top-level generated class; subtree classes get a {@code _<n>} suffix
   * @param verboseCode   if true, column names are emitted as comments inside the generated conditions
   */
  public TreeJCodeGen(Model.Output output, CompressedTree ct, SB sb, String javaClassName, boolean verboseCode) {
    super(ct);
    _output = output;
    _sb = sb;
    _csb = new SB();
    _grpsplit = new SB();
    _verboseCode = verboseCode;
    _javaClassName = javaClassName;
  }

  /**
   * Emits the class header and the start of its {@code score0} method, up to {@code double pred = }.
   *
   * @param sb      buffer to write to
   * @param subtree index of the class; 0 yields the plain class name, any positive value appends {@code _<subtree>}
   */
  protected void preamble(SB sb, int subtree) throws RuntimeException {
    String subt = subtree > 0 ? "_" + String.valueOf(subtree) : "";
    sb.p("class ").p(_javaClassName).p(subt).p(" {").nl().ii(1);
    sb.ip("static final double score0").p("(double[] data) {").nl().ii(1); // predict method for one tree
    sb.ip("double pred = ");
  }

  /**
   * Terminates the prediction expression, closes the {@code score0} method, appends the group-split
   * array declarations collected for the current class and closes the class.
   *
   * @param sb buffer holding the class to be closed
   */
  protected void closure(SB sb) throws RuntimeException {
    sb.p(";").nl();
    sb.ip("return pred;").nl().di(1);
    sb.ip("}").p(" // constant pool size = ").p(_constantPoolSize).p("B, number of visited nodes = ").p(_nodes).p(", static init size = ").p(_staticInitSize).p("B");
    sb.nl(); // close the method
    // Append actual group splits to the class being closed
    sb.p(_grpsplit);
    sb.di(1).ip("}").nl().nl(); // close the class
  }

  /**
   * Called before the children of a decision node are visited: optionally starts a new subtree
   * class, declares the group-split array if the node needs one, and emits the node's condition
   * followed by {@code " ? "}.
   */
  @Override protected void pre(int col, float fcmp, IcedBitSet gcmp, int equal, int naSplitDirInt) {
    // Check for method size and number of constants generated in constant pool
    if (_nodes > MAX_NODES || _constantPoolSize > MAX_CONSTANT_POOL_SIZE || _staticInitSize > MAX_METHOD_SIZE) {
      // The current class is full: replace this subtree by a call into a new class ...
      _sb.p(_javaClassName).p('_').p(_subtrees).p(".score0").p("(data)");
      // ... save the state of the current class at this depth ...
      ensureStackCapacity(_depth);
      _nodesCnt[_depth] = _nodes;
      _sbs[_depth] = _sb;
      _grpSplits[_depth] = _grpsplit;
      _grpSplitsCnt[_depth] = _grpCnt;
      _constantPool[_depth] = _constantPoolSize;
      _staticInit[_depth] = _staticInitSize;
      // ... and start the new class with fresh buffers and counters.
      _sb = new SB();
      _nodes = 0;
      _grpsplit = new SB();
      _grpCnt = 0;
      _constantPoolSize = 0;
      _staticInitSize = 0;
      preamble(_sb, _subtrees);
      _subtrees++;
    }
    // Generates array for group splits
    if ((equal == 2 || equal == 3) && gcmp != null) {
      _grpsplit.i(1).p("// ").p(gcmp.toString()).nl();
      _grpsplit.i(1).p("public static final byte[] GRPSPLIT").p(_grpCnt).p(" = new byte[] ").p(gcmp.toStrArray()).p(";").nl();
      _constantPoolSize += gcmp.numBytes() + 3; // Each byte stored in split (NOT TRUE) and field reference and field name (Utf8) and NameAndType
      _staticInitSize += 6 + gcmp.numBytes() * 6; // byte size of instructions to create an array and load all byte values (upper bound = dup, bipush, bipush, bastore = 5bytes)
    }
    // Generates decision
    _sb.ip(" (");

    // Generate column names only if necessary
    String colName = _verboseCode ? " /* " + _output._names[col] + " */" : "";

    String[][] domains = _output._domains;
    // size of the training domains (i.e., one larger than the max number of "seen" categorical IDs);
    // Integer.MAX_VALUE means "no domain known", in which case no out-of-bound check is generated
    int limit = (domains != null && domains[col] != null) ? domains[col].length : Integer.MAX_VALUE;

    assert(equal!=1);
    if (equal == 0) {
      // for the special case of a split of a categorical column if there's not enough bins to resolve the levels,
      // we treat the categorical levels as ordinal integer levels, and split at a certain point (<=, not using a bitset)
      // => need to add the out-of-bound check explicitly here to handle unseen categoricals
      if (naSplitDirInt == DhnasdNaVsRest) {
        _sb.p("!Double.isNaN(data[").p(col).p("])");
        if (limit != Integer.MAX_VALUE) {
          _sb.p(" && (data[").p(col).p("] < " + limit + ") ");
        }
      }
      else if (naSplitDirInt == DhnasdNaLeft || naSplitDirInt == DhnasdLeft) {
        _sb.p("Double.isNaN(data[").p(col).p("]) ");
        if (limit != Integer.MAX_VALUE) {
          _sb.p("|| (data[").p(col).p("] >= " + limit + ") ");
        }
        _sb.p("|| ");
      }
      if (naSplitDirInt != DhnasdNaVsRest) {
        _sb.p("data[").p(col);
        _sb.p(colName);
        _sb.p("] < ").pj(fcmp);
        _constantPoolSize += 2; // * bytes for generated float which is represented as double because of cast (Double occupies 2 slots in constant pool)
      }
    } else {
      assert naSplitDirInt != DhnasdNaVsRest : "NAvsREST splits are expected to be represented with equal==0";
      boolean leftward = naSplitDirInt == DhnasdNaLeft || naSplitDirInt == DhnasdLeft;
      if (leftward) {
        _sb.p("Double.isNaN(data[").p(col).p(colName).p("]) || !"); //NAs (or out of range) go left
        gcmp.toJavaRangeCheck(_sb, col);
        if (limit != Integer.MAX_VALUE) {
          _sb.p(" || (data[").p(col).p("] >= " + limit + ")");
        }
        _sb.p(" || ");
      } else {
        _sb.p("!Double.isNaN(data[").p(col).p(colName).p("]) && ");
      }
      _sb.p("(");
      gcmp.toJavaRangeCheck(_sb, col);
      _sb.p(" && ");
      if (limit != Integer.MAX_VALUE) {
        // The connective belongs to the limit check: emitting it unconditionally would
        // produce "&&  &&" (invalid Java) whenever no training domain is available.
        _sb.p("(data[").p(col).p("] < " + limit + ")");
        _sb.p(" && ");
      }
      gcmp.toJava(_sb, "GRPSPLIT" + _grpCnt, col);
      _sb.p(")");
      _grpCnt++;
    }
    _sb.p(" ? ").ii(2).nl();
  }

  /** Emits the prediction value of a leaf. */
  @Override protected void leaf( float pred  ) {
    _sb.i().pj(pred);
    // We are generating float which occupies single slot in constant pool, however
    // left side of final expression is double, hence javac directly stores double in constant pool (2places)
    _constantPoolSize += 2;
  }

  /** Called between the two children of a decision node: emits the {@code " : "} of the ternary. */
  @Override
  protected void mid(int col, float fcmp, int equal) throws RuntimeException {
    _sb.p(" : ").nl();
  }

  /**
   * Called after both children of a decision node were visited: closes the node's expression and,
   * if a subtree class was started at this depth, finalizes it and restores the interrupted class.
   */
  @Override protected void post(int col, float fcmp, int equal ) {
    _sb.p(')').di(2);
    // A saved frame can only exist at depths the stacks were grown to, hence the bounds check.
    if (_depth < _sbs.length && _sbs[_depth] != null) { // Top of stack  - finalize the class generate into _sb
      closure(_sb);
      _csb.p(_sb);
      _sb = _sbs[_depth];
      _nodes = _nodesCnt[_depth];
      _sbs[_depth] = null;
      _grpsplit = _grpSplits[_depth];
      _grpCnt = _grpSplitsCnt[_depth];
      _grpSplits[_depth] = null;
      _constantPoolSize = _constantPool[_depth];
      _staticInitSize = _staticInit[_depth];
    }
  }

  /**
   * Generates the code for the whole tree: the top-level class is written into the buffer given
   * to the constructor, followed by all subtree classes produced along the way.
   */
  public void generate() {
    preamble(_sb, _subtrees++);   // TODO: Need to pass along group split BitSet
    visit();
    closure(_sb);
    _sb.p(_csb);
  }

  /** Grows all per-depth save stacks (keeping their contents) so that index {@code depth} is valid. */
  private void ensureStackCapacity(int depth) {
    if (depth < _sbs.length) {
      return;
    }
    int newLen = Math.max(depth + 1, _sbs.length * 2);
    _sbs = grow(_sbs, newLen);
    _nodesCnt = grow(_nodesCnt, newLen);
    _grpSplits = grow(_grpSplits, newLen);
    _grpSplitsCnt = grow(_grpSplitsCnt, newLen);
    _constantPool = grow(_constantPool, newLen);
    _staticInit = grow(_staticInit, newLen);
  }

  /** Returns a copy of {@code src} extended to {@code newLen} entries. */
  private static SB[] grow(SB[] src, int newLen) {
    SB[] dst = new SB[newLen];
    System.arraycopy(src, 0, dst, 0, src.length);
    return dst;
  }

  /** Returns a copy of {@code src} extended to {@code newLen} entries. */
  private static int[] grow(int[] src, int newLen) {
    int[] dst = new int[newLen];
    System.arraycopy(src, 0, dst, 0, src.length);
    return dst;
  }
}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy