All downloads are free. Search and download functionality uses the official Maven repository.

hex.genmodel.algos.tree.SharedTreeMojoModel Maven / Gradle / Ivy

There is a newer version: 3.46.0.5
Show newest version
package hex.genmodel.algos.tree;

import hex.genmodel.CategoricalEncoding;
import hex.genmodel.MojoModel;
import hex.genmodel.algos.drf.DrfMojoModel;
import hex.genmodel.algos.gbm.GbmMojoModel;
import hex.genmodel.algos.isotonic.IsotonicCalibrator;
import hex.genmodel.utils.ByteBufferWrapper;
import hex.genmodel.utils.GenmodelBitSet;
import water.logging.Logger;
import water.logging.LoggerFactory;

import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

/**
 * Common ancestor for {@link DrfMojoModel} and {@link GbmMojoModel}.
 * See also: `hex.tree.SharedTreeModel` and `hex.tree.TreeVisitor` classes.
 */
public abstract class SharedTreeMojoModel extends MojoModel implements TreeBackedMojoModel, CalibrationMojoHelper.MojoModelWithCalibration {
    
    // Integer codes of the NaSplitDir variants, cached once so that the tree-walking code can
    // compare them directly against the NA-split-direction byte read from a compressed tree.
    private static final int NsdNaVsRest = NaSplitDir.NAvsREST.value();
    private static final int NsdNaLeft = NaSplitDir.NALeft.value();
    private static final int NsdLeft = NaSplitDir.Left.value();

    // Version-specific tree scorer; assigned in postInit() according to _mojo_version.
    private ScoreTree _scoreTree;
    
    // Used by checkConsistency() to report mismatches between a tree and its auxiliary info.
    private static Logger logger = LoggerFactory.getLogger(SharedTreeMojoModel.class);

    /**
     * {@code _ntree_groups} is the number of trees requested by the user. For
     * binomial case or regression this is also the total number of trees
     * trained; however in multinomial case each requested "tree" is actually
     * represented as a group of trees, with {@code _ntrees_per_group} trees
     * in each group. Each of these individual trees assesses the likelihood
     * that a given observation belongs to class A, B, C, etc. of a
     * multiclass response.
     */
    protected int _ntree_groups;
    protected int _ntrees_per_group;

    /**
     * Array of binary tree data, each tree being a {@code byte[]} array. The
     * trees are logically grouped into a rectangular grid of dimensions
     * {@link #_ntree_groups} x {@link #_ntrees_per_group}, however physically
     * they are stored as 1-dimensional list, and an {@code [i, j]} logical
     * tree is mapped to the index {@link #treeIndex(int, int)}.
     */
    protected byte[][] _compressed_trees;

    /**
     * Array of auxiliary binary tree data, each being a {@code byte[]} array.
     */
    protected byte[][] _compressed_trees_aux;

    /**
     * GLM's beta used for calibrating output probabilities using Platt Scaling.
     */
    protected double[] _calib_glm_beta;

    /**
     * For calibrating using Isotonic Regression
     */
    protected IsotonicCalibrator _isotonic_calibrator;

    // Name of the categorical encoding; translated to a CategoricalEncoding constant by getCategoricalEncoding().
    protected String _genmodel_encoding;
    
    // Returned by getOrigNames(); presumably the column names before categorical encoding - TODO confirm.
    protected String[] _orig_names;

    // Returned by getOrigDomainValues(); presumably the per-column domains before categorical encoding - TODO confirm.
    protected String[][] _orig_domain_values;

    // Returned by getOrigProjectionArray().
    protected double[] _orig_projection_array;


    /**
     * Selects the tree-scoring implementation that matches this model's MOJO version.
     * Versions 1.0 and 1.1 get their dedicated scorers; every other version uses the current one.
     */
    protected void postInit() {
      final ScoreTree scorer;
      if (_mojo_version == 1.0) {
        scorer = new ScoreTree0(); // First version
      } else if (_mojo_version == 1.1) {
        scorer = new ScoreTree1(); // Second version
      } else {
        scorer = new ScoreTree2(); // Current version
      }
      _scoreTree = scorer;
    }

    /**
     * Returns the number of tree groups, i.e. the value of {@link #_ntree_groups}.
     */
    @Override
    public final int getNTreeGroups() {
      return _ntree_groups;
    }

    /**
     * Returns the number of trees in each group, i.e. the value of {@link #_ntrees_per_group}.
     */
    @Override
    public final int getNTreesPerGroup() {
      return _ntrees_per_group;
    }

    /**
     * Legacy overload that accepts an {@code nclasses} argument and ignores it.
     *
     * @deprecated use {@link #scoreTree0(byte[], double[], boolean)} instead.
     */
    @Deprecated
    public static double scoreTree0(byte[] tree, double[] row, int nclasses, boolean computeLeafAssignment) {
      // note that nclasses is ignored (and in fact, always was)
      return scoreTree0(tree, row, computeLeafAssignment);
    }

    /**
     * Legacy overload that accepts an {@code nclasses} argument and ignores it.
     *
     * @deprecated use {@link #scoreTree1(byte[], double[], boolean)} instead.
     */
    @Deprecated
    public static double scoreTree1(byte[] tree, double[] row, int nclasses, boolean computeLeafAssignment) {
      // note that nclasses is ignored (and in fact, always was)
      return scoreTree1(tree, row, computeLeafAssignment);
    }

    /**
     * Legacy overload that accepts an {@code nclasses} argument and ignores it.
     *
     * @deprecated use {@link #scoreTree(byte[], double[], boolean, String[][])} instead.
     */
    @Deprecated
    public static double scoreTree(byte[] tree, double[] row, int nclasses, boolean computeLeafAssignment, String[][] domains) {
      // note that nclasses is ignored (and in fact, always was)
      return scoreTree(tree, row, computeLeafAssignment, domains);
    }

  // Maximum number of levels a decision path may have; the path is packed into the 64 bits of a long.
  public static final int __INTERNAL_MAX_TREE_DEPTH = 64;
  /**
   * Highly efficient (critical path) tree scoring
   *
   * Given a tree (in the form of a byte array) and the row of input data, compute either this tree's
   * predicted value when `computeLeafAssignment` is false, or the decision path within the tree (but no more
   * than 64 levels) when `computeLeafAssignment` is true. If path has 64 levels or more, Double.NaN is returned.
   *
   * The decision path is returned as the raw bits of the resulting double: bit {@code i} is set when the
   * step taken at level {@code i} went right, and one additional set bit marks the end of the path.
   *
   * Note: this function is also used from the `hex.tree.CompressedTree` class in `h2o-algos` project.
   *
   * @param tree compressed tree to walk
   * @param row input values, indexed by the column ids stored in the tree
   * @param computeLeafAssignment {@code true} to return the encoded decision path instead of the prediction
   * @param domains optional per-column domains (may be {@code null}); a value at or beyond the length of
   *                its column's domain is routed like a missing value
   * @return the leaf prediction, the encoded decision path, or {@code Double.NaN} for a path that is too deep
   */
  @SuppressWarnings("ConstantConditions")  // Complains that the code is too complex. Well duh!
    public static double scoreTree(byte[] tree, double[] row, boolean computeLeafAssignment, String[][] domains) {
        ByteBufferWrapper ab = new ByteBufferWrapper(tree);
        GenmodelBitSet bs = null;
        long bitsRight = 0;
        int level = 0;
        while (true) {
            int nodeType = ab.get1U();
            int colId = ab.get2();
            // Column id 65535 marks a node that carries a prediction instead of a split
            if (colId == 65535) {
              if (computeLeafAssignment) {
                if (level >= __INTERNAL_MAX_TREE_DEPTH)
                  return Double.NaN;
                bitsRight |= 1L << level;  // mark the end of the tree
                return Double.longBitsToDouble(bitsRight);
              } else {
                return ab.get4f();
              }
            }
            int naSplitDir = ab.get1U();
            boolean naVsRest = naSplitDir == NsdNaVsRest;
            boolean leftward = naSplitDir == NsdNaLeft || naSplitDir == NsdLeft;
            // Bits 0-1 and 4-5 of the node type describe the left child (see the switch on lmask below)
            int lmask = (nodeType & 51);
            int equal = (nodeType & 12);  // Can be one of 0, 8, 12
            assert equal != 4;  // no longer supported

            float splitVal = -1;
            if (!naVsRest) {
                // Extract value or group to split on
                if (equal == 0) {
                    // Standard float-compare test (either < or ==)
                    splitVal = ab.get4f();  // Get the float to compare
                } else {
                    // Bitset test
                    if (bs == null) bs = new GenmodelBitSet(0);
                    if (equal == 8)
                        bs.fill2(tree, ab);
                    else
                        bs.fill3(tree, ab);
                }
            }

          // This logic:
          //
          //        double d = row[colId];
          //        if (Double.isNaN(d) || ( equal != 0 && bs != null && !bs.isInRange((int)d) ) || (domains != null && domains[colId] != null && domains[colId].length <= (int)d)
          //              ? !leftward : !naVsRest && (equal == 0? d >= splitVal : bs.contains((int)d))) {

          // Really does this:
          //
          //        if (value is NaN or value is not in the range of the bitset or is outside the domain map length (but an integer) ) {
          //            if (leftward) {
          //                go left
          //            }
          //            else {
          //                go right
          //            }
          //        }
          //        else {
          //            if (naVsRest) {
          //                go left
          //            }
          //            else {
          //                if (numeric) {
          //                    if (value < split value) {
          //                        go left
          //                    }
          //                    else {
          //                        go right
          //                    }
          //                }
          //                else {
          //                    if (value not in bitset) {
          //                        go left
          //                    }
          //                    else {
          //                        go right
          //                    }
          //                }
          //            }
          //        }

            double d = row[colId];
            if (Double.isNaN(d) || ( equal != 0 && bs != null && !bs.isInRange((int)d) ) || (domains != null && domains[colId] != null && domains[colId].length <= (int)d)
                    ? !leftward : !naVsRest && (equal == 0? d >= splitVal : bs.contains((int)d))) {
                // go RIGHT
                // Jump over the left subtree: lmask 0..3 selects the width of the stored skip distance,
                // 48 means the left child is a bare 4-byte prediction
                switch (lmask) {
                    case 0:  ab.skip(ab.get1U());  break;
                    case 1:  ab.skip(ab.get2());  break;
                    case 2:  ab.skip(ab.get3());  break;
                    case 3:  ab.skip(ab.get4());  break;
                    case 48: ab.skip(4);  break;  // skip the prediction
                    default:
                        assert false : "illegal lmask value " + lmask + " in tree " + Arrays.toString(tree);
                }
                if (computeLeafAssignment) {
                    if (level >= __INTERNAL_MAX_TREE_DEPTH)
                      return Double.NaN;
                    bitsRight |= 1L << level;
                }
                lmask = (nodeType & 0xC0) >> 2;  // Replace leftmask with the rightmask
            } else {
                // go LEFT
                // Step over the skip-distance field (lmask + 1 bytes) to land on the left child
                if (lmask <= 3)
                    ab.skip(lmask + 1);
            }

            level++;
            // Bit 16 of the (left or right) mask means the chosen child is a leaf holding a prediction
            if ((lmask & 16) != 0) {
                if (computeLeafAssignment) {
                    if (level >= __INTERNAL_MAX_TREE_DEPTH)
                      return Double.NaN;
                    bitsRight |= 1L << level;  // mark the end of the tree
                    return Double.longBitsToDouble(bitsRight);
                } else {
                    return ab.get4f();
                }
            }
        }
    }
    

    /**
     * Translates the stored encoding name ({@code _genmodel_encoding}) into a
     * {@link CategoricalEncoding} constant. The names {@code "AUTO"}, {@code "Enum"} and
     * {@code "SortByResponse"} all map to {@link CategoricalEncoding#AUTO}.
     *
     * @return the matching encoding, or {@code null} when the name is not recognized
     */
    @Override
    public CategoricalEncoding getCategoricalEncoding() {
        final CategoricalEncoding encoding;
        switch (_genmodel_encoding) {
            case "OneHotExplicit":
                encoding = CategoricalEncoding.OneHotExplicit;
                break;
            case "Binary":
                encoding = CategoricalEncoding.Binary;
                break;
            case "EnumLimited":
                encoding = CategoricalEncoding.EnumLimited;
                break;
            case "Eigen":
                encoding = CategoricalEncoding.Eigen;
                break;
            case "LabelEncoder":
                encoding = CategoricalEncoding.LabelEncoder;
                break;
            case "AUTO":
            case "Enum":
            case "SortByResponse":
                encoding = CategoricalEncoding.AUTO;
                break;
            default:
                encoding = null;
                break;
        }
        return encoding;
    }

    /**
     * Returns the {@code _orig_names} array as stored (no copy is made).
     */
    @Override
    public String[] getOrigNames() {
      return _orig_names;
    }

    /**
     * Returns the {@code _orig_projection_array} array as stored (no copy is made).
     */
    @Override
    public double[] getOrigProjectionArray() {
        return _orig_projection_array;
    }

    /**
     * Returns the {@code _orig_domain_values} array as stored (no copy is made).
     */
    @Override
    public String[][] getOrigDomainValues() {
        return _orig_domain_values;
    }

    /**
     * Callback that consumes an encoded decision path one step at a time.
     *
     * @param <T> type of the result produced once the path has been consumed
     */
    public interface DecisionPathTracker<T> {
        /**
         * Consumes one step of the path.
         *
         * @param depth zero-based level of the step
         * @param right {@code true} when the bit for this level is set
         * @return {@code false} to stop the traversal early
         */
        boolean go(int depth, boolean right);

        /** Produces the result after the traversal finished or was stopped. */
        T terminate();

        /** Produces the result for a leaf assignment that does not encode a valid path (NaN). */
        T invalidPath();
    }

    /**
     * Renders a decision path as a string of {@code 'L'} and {@code 'R'} characters.
     * The last set bit is only the end-of-path marker, so the path is cut right before it.
     */
    public static class StringDecisionPathTracker implements DecisionPathTracker<String> {
        private final char[] _sb = new char[64];
        private int _pos = 0; // index of the last 'R' seen so far == length of the real path

        @Override
        public boolean go(int depth, boolean right) {
            _sb[depth] = right ? 'R' : 'L';
            if (right) _pos = depth;
            return true;
        }

        @Override
        public String terminate() {
            String path = new String(_sb, 0, _pos);
            _pos = 0; // make the tracker reusable
            return path;
        }

        @Override
        public String invalidPath() {
            return null;
        }
    }

    /**
     * Follows a decision path through the auxiliary tree info to find the id of the node
     * the path ends in. The id is {@code -1} for an invalid path.
     */
    public static class LeafDecisionPathTracker implements DecisionPathTracker<LeafDecisionPathTracker> {
        private final AuxInfoLightReader _auxInfo;
        private boolean _wentRight = false; // Was the last step _right_?

        // OUT
        private int _nodeId = 0; // Returned when the tree is empty (consistent with SharedTreeNode of an empty tree)

        private LeafDecisionPathTracker(byte[] auxTree) {
          _auxInfo = new AuxInfoLightReader(new ByteBufferWrapper(auxTree));
        }

        @Override
        public boolean go(int depth, boolean right) {
          if (!_auxInfo.hasNext()) {
            assert _wentRight || depth == 0; // this can only happen if previous step was _right_ or the tree has no nodes
            return false;
          }
          _auxInfo.readNext();
          if (right) {
            // After a right step the next record must be the node we stepped into; otherwise that node
            // has no record of its own and the traversal stops here.
            if (_wentRight && _nodeId != _auxInfo._nid)
              return false;
            _nodeId = _auxInfo.getRightNodeIdAndSkipNode();
            _auxInfo.skipNodes(_auxInfo._numLeftChildren);
            _wentRight = true;
          } else { // left
            _wentRight = false;
            if (_auxInfo._numLeftChildren == 0) {
              _nodeId = _auxInfo.getLeftNodeIdAndSkipNode();
              return false;
            } else {
              _auxInfo.skipNode(); // proceed to next _left_ node
            }
          }
          return true;
        }

        @Override
        public LeafDecisionPathTracker terminate() {
          return this;
        }

        final int getLeafNodeId() {
          return _nodeId;
        }

        @Override
        public LeafDecisionPathTracker invalidPath() {
          _nodeId = -1;
          return this;
        }
    }

    /**
     * Replays an encoded leaf assignment through the given tracker.
     *
     * @param leafAssignment decision path encoded in the raw bits of a double; NaN means "invalid path"
     * @param tr tracker receiving one {@code go} call per bit (lowest bit first) until it returns {@code false}
     * @param <T> result type of the tracker
     * @return {@code tr.invalidPath()} for NaN, otherwise {@code tr.terminate()}
     */
    public static <T> T getDecisionPath(double leafAssignment, DecisionPathTracker<T> tr) {
        if (Double.isNaN(leafAssignment)) {
          return tr.invalidPath();
        }
        long l = Double.doubleToRawLongBits(leafAssignment);
        for (int i = 0; i < 64; ++i) {
            boolean right = ((l>>i) & 0x1L) == 1;
            if (! tr.go(i, right)) break;
        }
        return tr.terminate();
    }

    /**
     * Decodes a leaf assignment into an {@code L}/{@code R} string.
     *
     * @return the path, or {@code null} when {@code leafAssignment} is NaN
     */
    public static String getDecisionPath(double leafAssignment) {
        return getDecisionPath(leafAssignment, new StringDecisionPathTracker());
    }

    /**
     * Resolves a leaf assignment to a node id using the tree's auxiliary info.
     *
     * @return the node id, or {@code -1} when {@code leafAssignment} is NaN
     */
    public static int getLeafNodeId(double leafAssignment, byte[] auxTree) {
        LeafDecisionPathTracker tr = new LeafDecisionPathTracker(auxTree);
        return getDecisionPath(leafAssignment, tr).getLeafNodeId();
    }

    //------------------------------------------------------------------------------------------------------------------
    // Computing a Tree Graph
    //------------------------------------------------------------------------------------------------------------------

    /**
     * Recursively decodes the compressed (sub)tree starting at the current position of {@code ab}
     * and attaches the decoded split and both children to {@code node}.
     *
     * @param sg subgraph that creates the child nodes
     * @param node node to fill in; its node number is used to look up the auxiliary info
     * @param tree the whole compressed tree (children are decoded through fresh wrappers over it)
     * @param ab reader positioned at the start of {@code node}'s encoding
     * @param auxMap auxiliary info keyed by node id
     * @param names column names indexed by column id
     * @param domains per-column domains indexed by column id (entries may be {@code null})
     * @param options conversion options; controls the optional consistency check
     * @throws IllegalStateException if a split node has no entry in {@code auxMap}
     */
    private static void computeTreeGraph(SharedTreeSubgraph sg, SharedTreeNode node, byte[] tree, ByteBufferWrapper ab,
                                         HashMap<Integer, AuxInfo> auxMap, String[] names, String[][] domains,
                                         ConvertTreeOptions options) {
        int nodeType = ab.get1U();
        int colId = ab.get2();
        // Column id 65535 marks a node that carries a prediction instead of a split
        if (colId == 65535) {
            float leafValue = ab.get4f();
            node.setPredValue(leafValue);
            return;
        }
        String colName = names[colId];
        node.setCol(colId, colName);

        int naSplitDir = ab.get1U();
        boolean naVsRest = naSplitDir == NsdNaVsRest;
        boolean leftward = naSplitDir == NsdNaLeft || naSplitDir == NsdLeft;
        node.setLeftward(leftward);
        node.setNaVsRest(naVsRest);

        int lmask = (nodeType & 51);
        int equal = (nodeType & 12);  // Can be one of 0, 8, 12
        assert equal != 4;  // no longer supported

        if (!naVsRest) {
            // Extract value or group to split on
            if (equal == 0) {
              float splitVal = ab.get4f();

              if (domains[colId] != null) {
                node.setDomainValues(domains[colId]);
              }
              // Standard float-compare test (either < or ==)
              node.setSplitValue(splitVal);
            } else {
                // Bitset test
                GenmodelBitSet bs = new GenmodelBitSet(0);
                if (equal == 8)
                    bs.fill2(tree, ab);
                else
                    bs.fill3(tree, ab);
                node.setBitset(domains[colId], bs);
            }
        }

        AuxInfo auxInfo = auxMap.get(node.getNodeNumber());
        if (auxInfo == null) {
            // Fail with a descriptive message instead of a NullPointerException further down
            throw new IllegalStateException("Auxiliary info for nodeId=" + node.getNodeNumber() + " not found.");
        }

        // go RIGHT
        {
            ByteBufferWrapper ab2 = new ByteBufferWrapper(tree);
            ab2.skip(ab.position());

            // Jump over the encoding of the left subtree
            switch (lmask) {
                case 0:
                    ab2.skip(ab2.get1U());
                    break;
                case 1:
                    ab2.skip(ab2.get2());
                    break;
                case 2:
                    ab2.skip(ab2.get3());
                    break;
                case 3:
                    ab2.skip(ab2.get4());
                    break;
                case 48:
                    ab2.skip(4);
                    break;  // skip the prediction
                default:
                    assert false : "illegal lmask value " + lmask + " in tree " + Arrays.toString(tree);
            }
            int lmask2 = (nodeType & 0xC0) >> 2;  // Replace leftmask with the rightmask

            SharedTreeNode newNode = sg.makeRightChildNode(node);
            newNode.setWeight(auxInfo.weightR);
            newNode.setNodeNumber(auxInfo.nidR);
            newNode.setPredValue(auxInfo.predR);
            newNode.setSquaredError(auxInfo.sqErrR);
            if ((lmask2 & 16) != 0) {
                // Right child is a leaf: the value stored in the tree overrides the auxiliary one
                float leafValue = ab2.get4f();
                newNode.setPredValue(leafValue);
                auxInfo.predR = leafValue;
            }
            else {
                computeTreeGraph(sg, newNode, tree, ab2, auxMap, names, domains, options);
            }
        }

        // go LEFT
        {
            ByteBufferWrapper ab2 = new ByteBufferWrapper(tree);
            ab2.skip(ab.position());

            // Step over the skip-distance field to land on the left child
            if (lmask <= 3)
                ab2.skip(lmask + 1);

            SharedTreeNode newNode = sg.makeLeftChildNode(node);
            newNode.setWeight(auxInfo.weightL);
            newNode.setNodeNumber(auxInfo.nidL);
            newNode.setPredValue(auxInfo.predL);
            newNode.setSquaredError(auxInfo.sqErrL);
            if ((lmask & 16) != 0) {
                // Left child is a leaf: the value stored in the tree overrides the auxiliary one
                float leafValue = ab2.get4f();
                newNode.setPredValue(leafValue);
                auxInfo.predL = leafValue;
            }
            else {
                computeTreeGraph(sg, newNode, tree, ab2, auxMap, names, domains, options);
            }
        }
        if (node.getNodeNumber() == 0) {
          // The root has no parent record describing it: derive its statistics from its two children
          float p = (float)(((double)auxInfo.predL*(double)auxInfo.weightL + (double)auxInfo.predR*(double)auxInfo.weightR)/((double)auxInfo.weightL + (double)auxInfo.weightR));
          if (Math.abs(p) < 1e-7) p = 0;
          node.setPredValue(p);
          node.setSquaredError(auxInfo.sqErrR + auxInfo.sqErrL);
          node.setWeight(auxInfo.weightL + auxInfo.weightR);
        }
        if (options._checkTreeConsistency) {
          checkConsistency(auxInfo, node);
        }
    }

    /**
     * Compute a graph of the forest.
     *
     * @param treeToPrint index of the single tree group to convert, or a negative value to convert all groups
     * @param options conversion options handed to the per-tree conversion
     * @return A graph of the forest.
     * @throws IllegalArgumentException if {@code treeToPrint} is not smaller than the number of tree groups
     */
    public SharedTreeGraph computeGraph(int treeToPrint, ConvertTreeOptions options) {
        SharedTreeGraph g = new SharedTreeGraph();

        if (treeToPrint >= _ntree_groups) {
            throw new IllegalArgumentException("Tree " + treeToPrint + " does not exist (max " + _ntree_groups + ")");
        }

        // A non-negative index selects exactly one group, a negative one selects all of them
        final boolean singleGroup = treeToPrint >= 0;
        final int firstGroup = singleGroup ? treeToPrint : 0;
        final int endGroup = singleGroup ? treeToPrint + 1 : _ntree_groups;

        for (int group = firstGroup; group < endGroup; group++) {
            for (int cls = 0; cls < _ntrees_per_group; cls++) {
                final int treeIdx = treeIndex(group, cls);
                final String[] responseDomain = isSupervised() ? getDomainValues(getResponseIdx()) : null;
                final SharedTreeSubgraph subgraph = g.makeSubgraph(treeName(group, cls, responseDomain));
                computeTreeGraph(subgraph, _compressed_trees[treeIdx], _compressed_trees_aux[treeIdx],
                        getNames(), getDomainValues(), options);
            }
        }

        return g;
    }

    /**
     * Computes the graph for {@code treeId} using {@code ConvertTreeOptions.DEFAULT}.
     */
    public SharedTreeGraph computeGraph(int treeId) {
      return computeGraph(treeId, ConvertTreeOptions.DEFAULT);
    }

    /**
     * Legacy alias that simply delegates to {@link #computeGraph(int)}.
     *
     * @deprecated use {@link #computeGraph(int)} instead.
     */
    @Deprecated
    @SuppressWarnings("unused")
    public SharedTreeGraph _computeGraph(int treeId) {
      return computeGraph(treeId);
    }

    /**
     * Converts a single compressed tree into a subgraph using {@code ConvertTreeOptions.DEFAULT}.
     */
    public static SharedTreeSubgraph computeTreeGraph(int treeNum, String treeName, byte[] tree, byte[] auxTreeInfo,
                                                      String names[], String[][] domains) {
      return computeTreeGraph(treeNum, treeName, tree, auxTreeInfo, names, domains, ConvertTreeOptions.DEFAULT);
    }

    /**
     * Converts a single compressed tree into a newly created subgraph.
     *
     * @param treeNum number handed to the new subgraph
     * @param treeName name handed to the new subgraph
     * @param tree compressed tree
     * @param auxTreeInfo serialized auxiliary info of the tree
     * @param names column names indexed by column id
     * @param domains per-column domains indexed by column id
     * @param options conversion options
     * @return the populated subgraph
     */
    public static SharedTreeSubgraph computeTreeGraph(int treeNum, String treeName, byte[] tree, byte[] auxTreeInfo,
                                                      String names[], String[][] domains, ConvertTreeOptions options) {
      SharedTreeSubgraph sg = new SharedTreeSubgraph(treeNum, treeName);
      computeTreeGraph(sg, tree, auxTreeInfo, names, domains, options);
      return sg;
    }
  
    /**
     * Creates the root node of {@code sg}, parses the auxiliary info and starts the recursive
     * conversion of the compressed tree. The root's squared error and prediction start out as NaN.
     */
    private static void computeTreeGraph(SharedTreeSubgraph sg, byte[] tree, byte[] auxTreeInfo,
                                         String[] names, String[][] domains, ConvertTreeOptions options) {
      SharedTreeNode node = sg.makeRootNode();
      node.setSquaredError(Float.NaN);
      node.setPredValue(Float.NaN);
      ByteBufferWrapper ab = new ByteBufferWrapper(tree);
      ByteBufferWrapper abAux = new ByteBufferWrapper(auxTreeInfo);
      // Typed map (node id -> aux info) instead of a raw HashMap
      HashMap<Integer, AuxInfo> auxMap = readAuxInfos(abAux);
      computeTreeGraph(sg, node, tree, ab, auxMap, names, domains, options);
    }

    /**
     * Parses the serialized auxiliary tree info.
     *
     * @param auxTreeInfo serialized auxiliary info of one tree
     * @return the parsed records keyed by node id
     */
    public static Map<Integer, AuxInfo> readAuxInfos(byte[] auxTreeInfo) {
      ByteBufferWrapper abAux = new ByteBufferWrapper(auxTreeInfo);
      return readAuxInfos(abAux);
    }

    /**
     * Scans the serialized auxiliary info and returns the largest child node id found in any record.
     *
     * @param auxTreeInfo serialized auxiliary info of one tree
     * @return the largest left/right child id, or 0 when there is no record (or no id above 0)
     */
    public static int findMaxNodeId(byte[] auxTreeInfo) {
      final AuxInfoLightReader records = new AuxInfoLightReader(auxTreeInfo);
      int largest = 0;
      while (records.hasNext()) {
        largest = Math.max(largest, records.readMaxChildNodeIdAndSkip());
      }
      return largest;
    }
    
    /**
     * Reads all auxiliary records from {@code abAux} and links every record to its parent.
     * A record's parent is the record that listed its node id as a left or right child;
     * node 0 gets a placeholder parent with id -1.
     *
     * @param abAux reader over the serialized auxiliary info
     * @return the parsed records keyed by node id
     * @throws IllegalStateException if a record appears before any record that names it as a child
     */
    private static HashMap<Integer, AuxInfo> readAuxInfos(ByteBufferWrapper abAux) {
      HashMap<Integer, AuxInfo> auxMap = new HashMap<>();
      Map<Integer, AuxInfo> nodeIdToParent = new HashMap<>();
      nodeIdToParent.put(0, new AuxInfo());
      boolean reservedFieldIsParentId = false; // In older H2O versions `reserved` field was used for parent id
      while (abAux.hasRemaining()) {
        AuxInfo auxInfo = new AuxInfo(abAux);
        if (auxMap.isEmpty()) {
          reservedFieldIsParentId = auxInfo.reserved < 0; // `-1` indicates No Parent, reserved >= 0 indicates reserved is not used for parent ids!
        }
        AuxInfo parent = nodeIdToParent.get(auxInfo.nid);
        if (parent == null)
          throw new IllegalStateException("Parent for nodeId=" + auxInfo.nid + " not found.");
        assert !reservedFieldIsParentId || parent.nid == auxInfo.reserved : "Corrupted Tree Info: parent nodes do not correspond (pid: " +
                parent.nid + ", reserved: " + auxInfo.reserved + ")";
        auxInfo.setPid(parent.nid);
        nodeIdToParent.put(auxInfo.nidL, auxInfo);
        nodeIdToParent.put(auxInfo.nidR, auxInfo);
        auxMap.put(auxInfo.nid, auxInfo);
      }
      return auxMap;
    }

    /**
     * Writes the updated auxiliary records to {@code bb} in the order in which their node ids
     * appear in the original serialized info.
     *
     * @param origAux original serialized auxiliary info; determines the record order
     * @param updatedAuxInfos updated records keyed by node id
     * @param bb destination buffer
     * @throws IllegalStateException if a node id of {@code origAux} has no updated record
     */
    public static void writeUpdatedAuxInfos(byte[] origAux, Map<Integer, AuxInfo> updatedAuxInfos, ByteBuffer bb) {
      AuxInfoLightReader reader = new AuxInfoLightReader(origAux);
      int count = 0;
      while (reader.hasNext()) {
          count++;
          int nid = reader.readNodeIdAndSkip();
          AuxInfo auxInfo = updatedAuxInfos.get(nid);
          if (auxInfo == null)
              throw new IllegalStateException("Updated AuxInfo for nodeId=" + nid + " doesn't exist. " +
                      "All AuxInfos need to be represented in the updated structure.");
          auxInfo.writeTo(bb);
      }
      assert count == updatedAuxInfos.size();
    }
    
    /**
     * Builds the display name of a tree: {@code "Tree <groupIndex>"}, followed by
     * {@code ", Class <domainValues[classIndex]>"} when domain values are available.
     *
     * @param groupIndex index of the tree group
     * @param classIndex index into {@code domainValues}; only used when {@code domainValues} is not null
     * @param domainValues class labels, or {@code null} to omit the class part
     * @return the tree name
     */
    public static String treeName(int groupIndex, int classIndex, String[] domainValues) {
      final StringBuilder label = new StringBuilder("Tree ").append(groupIndex);
      if (domainValues != null) {
        label.append(", Class ").append(domainValues[classIndex]);
      }
      return label.toString();
    }

    // Please see AuxInfo for details of the serialized format
    /**
     * Sequential reader over serialized auxiliary records that decodes only the few ints it
     * is asked for and skips over the rest of each record.
     */
    private static class AuxInfoLightReader {
      // Bytes consumed by readNext(): the node id and the second int of the record
      private static final int HEADER_BYTES = 8;
      // Bytes of a record that precede the two trailing child ids
      private static final int CHILD_IDS_OFFSET = AuxInfo.SIZE - 8;

      private final ByteBufferWrapper _abAux;
      int _nid;
      int _numLeftChildren;

      private AuxInfoLightReader(byte[] auxInfo) {
        this(new ByteBufferWrapper(auxInfo));
      }

      private AuxInfoLightReader(ByteBufferWrapper abAux) {
        this._abAux = abAux;
      }

      /** Tells whether the underlying buffer still has unread bytes. */
      private boolean hasNext() {
        return _abAux.hasRemaining();
      }

      /** Reads the first two ints of the current record into {@code _nid} and {@code _numLeftChildren}. */
      private void readNext() {
        this._nid = _abAux.get4();
        this._numLeftChildren = _abAux.get4();
      }

      /** After {@link #readNext()}: skips the remainder of the current record. */
      private void skipNode() {
        _abAux.skip(AuxInfo.SIZE - HEADER_BYTES);
      }

      /** Skips {@code num} whole records. */
      private void skipNodes(int num) {
        _abAux.skip(AuxInfo.SIZE * num);
      }

      /** Reads a whole record and returns its node id. */
      private int readNodeIdAndSkip() {
        readNext();
        skipNode();
        return _nid;
      }

      /** From the start of a record: consumes it and returns the larger of its two child ids. */
      private int readMaxChildNodeIdAndSkip() {
        _abAux.skip(CHILD_IDS_OFFSET);
        final int left = _abAux.get4();
        final int right = _abAux.get4();
        return Math.max(left, right);
      }

      /** After {@link #readNext()}: consumes the rest of the record and returns its left child id. */
      private int getLeftNodeIdAndSkipNode() {
        _abAux.skip(6 * 4);               // the six float statistics
        final int leftId = _abAux.get4();
        _abAux.skip(4);                   // the right child id
        return leftId;
      }

      /** After {@link #readNext()}: consumes the rest of the record and returns its right child id. */
      private int getRightNodeIdAndSkipNode() {
        _abAux.skip(7 * 4);               // the six float statistics and the left child id
        return _abAux.get4();
      }

    }

    /**
     * Auxiliary information about one split node: its id, statistics of its left and right
     * children (weight, prediction, squared error) and the children's ids. A record is
     * serialized as ten 4-byte values.
     */
    public static class AuxInfo {
      private static final int SIZE = 10 * 4;

      public int nid, pid, nidL, nidR;
      private final int reserved;
      public float weightL, weightR, predL, predR, sqErrL, sqErrR;

      /** Creates a placeholder with node id and reserved field set to -1. */
      private AuxInfo() {
        this.nid = -1;
        this.reserved = -1;
      }

      // Warning: any changes in this structure need to be reflected also in AuxInfoLightReader!!!
      /** Reads one record from the current position of {@code abAux}. */
      AuxInfo(ByteBufferWrapper abAux) {
        this.nid = abAux.get4();        // node ID
        // ignored - can contain either parent id or number of children (depending on a MOJO version)
        this.reserved = abAux.get4();

        // sum of observation weights (typically, that's just the count of observations)
        this.weightL = abAux.get4f();
        this.weightR = abAux.get4f();

        // predicted values
        this.predL = abAux.get4f();
        this.predR = abAux.get4f();

        // squared error
        this.sqErrL = abAux.get4f();
        this.sqErrR = abAux.get4f();

        // node IDs (consistent with tree construction)
        this.nidL = abAux.get4();
        this.nidR = abAux.get4();
      }

      /** Writes this record to {@code bb} in the same field order in which it is read. */
      void writeTo(ByteBuffer bb) {
        bb.putInt(nid);                               // node ID
        bb.putInt(reserved);                          // reserved
        bb.putFloat(weightL).putFloat(weightR);       // sum of observation weights
        bb.putFloat(predL).putFloat(predR);           // predicted values
        bb.putFloat(sqErrL).putFloat(sqErrR);         // squared error
        bb.putInt(nidL).putInt(nidR);                 // node IDs
      }

      final void setPid(int parentId) {
        this.pid = parentId;
      }

      @Override public String toString() {
        return new StringBuilder()
                .append("nid: ").append(nid).append("\n")
                .append("pid: ").append(pid).append("\n")
                .append("nidL: ").append(nidL).append("\n")
                .append("nidR: ").append(nidR).append("\n")
                .append("weightL: ").append(weightL).append("\n")
                .append("weightR: ").append(weightR).append("\n")
                .append("predL: ").append(predL).append("\n")
                .append("predR: ").append(predR).append("\n")
                .append("sqErrL: ").append(sqErrL).append("\n")
                .append("sqErrR: ").append(sqErrR).append("\n")
                .append("reserved: ").append(reserved).append("\n")
                .toString();
      }
    }

    /**
     * Compares the auxiliary record of a node with the node (and its children and parent) in
     * the decoded tree and logs an error describing both when they disagree. Never throws
     * because of an inconsistency; it only reports it.
     *
     * @param auxInfo auxiliary record belonging to {@code node}
     * @param node decoded tree node to verify
     */
    static void checkConsistency(AuxInfo auxInfo, SharedTreeNode node) {
      boolean ok = true;
      boolean weight_ok = true;
      ok &= (auxInfo.nid == node.getNodeNumber());
      double sum = 0;
      if (node.leftChild!=null) {
        ok &= (auxInfo.nidL == node.leftChild.getNodeNumber());
        ok &= (auxInfo.weightL == node.leftChild.getWeight());
        ok &= (auxInfo.predL == node.leftChild.predValue);
        ok &= (auxInfo.sqErrL == node.leftChild.squaredError);
        sum += node.leftChild.getWeight();
      }
      if (node.rightChild!=null) {
        ok &= (auxInfo.nidR == node.rightChild.getNodeNumber());
        ok &= (auxInfo.weightR == node.rightChild.getWeight());
        ok &= (auxInfo.predR == node.rightChild.predValue);
        ok &= (auxInfo.sqErrR == node.rightChild.squaredError);
        sum += node.rightChild.getWeight();
      }
      if (node.parent!=null) {
        ok &= (auxInfo.pid == node.parent.getNodeNumber());
        // The node's weight must match the sum of its children's weights (relative tolerance)
        weight_ok = (Math.abs(node.getWeight() - sum) < 1e-5 * (node.getWeight() + sum));
        ok &= weight_ok;
      }
      if (!ok && logger.isErrorEnabled()) {
        logger.error("\nTree inconsistency found:");
        if (node.depth == 1 && !weight_ok) { 
            logger.error("Note: this is a known issue for DRF and Isolation Forest models, " +
                  "please refer to https://github.com/h2oai/h2o-3/issues/12971");
        }
        logger.error(node.getPrintString("parent"));
        // The children are optional in the checks above, so they must not be dereferenced blindly here
        if (node.leftChild != null) {
          logger.error(node.leftChild.getPrintString("left child"));
        }
        if (node.rightChild != null) {
          logger.error(node.rightChild.getPrintString("right child"));
        }
        logger.error("Auxiliary tree info:");
        logger.error(auxInfo.toString());
      }
    }

    //------------------------------------------------------------------------------------------------------------------
    // Private
    //------------------------------------------------------------------------------------------------------------------

    /**
     * Passes the arguments unchanged to the matching {@link MojoModel} constructor.
     */
    protected SharedTreeMojoModel(String[] columns, String[][] domains, String responseColumn) {
        super(columns, domains, responseColumn);
    }

    /**
     * Passes the arguments, including the treatment column, unchanged to the matching
     * {@link MojoModel} constructor.
     */
    protected SharedTreeMojoModel(String[] columns, String[][] domains, String responseColumn, String treatmentColumn) {
        super(columns, domains, responseColumn, treatmentColumn);
    }

    /**
     * Resets {@code preds} to zeros and then accumulates the predictions of all tree groups into it.
     *
     * @param row input row
     * @param preds output array of partial predictions; overwritten
     */
    protected void scoreAllTrees(double[] row, double[] preds) {
        Arrays.fill(preds, 0.0);
        final int allGroups = _ntree_groups;
        scoreTreeRange(row, 0, allGroups, preds);
    }

    /**
     * Transforms tree predictions into the final model predictions.
     * For classification: converts tree preds into probability distribution and picks predicted class.
     * For regression: projects tree prediction from link-space into the original space.
     * Implemented by the concrete model classes.
     * @param row input row.
     * @param offset offset.
     * @param preds partial predictions on input (e.g. as accumulated by {@link #scoreTreeRange});
     *              final output, same structure as of {@link SharedTreeMojoModel#score0}.
     * @return preds array.
     */
    public abstract double[] unifyPreds(double[] row, double offset, double[] preds);

  /**
   * Generates a (per-class) prediction using only a single tree.
   * Equivalent to {@link #scoreTreeRange} over the range {@code [index, index + 1)}; the result is
   * added to whatever {@code preds} already contains.
   * @param row input row
   * @param index index of the tree (0..N-1)
   * @param preds array of partial predictions.
   */
    public final void scoreSingleTree(double[] row, int index, double preds[]) {
      scoreTreeRange(row, index, index + 1, preds);
    }

    /**
     * Generates (partial, per-class) predictions using only trees from a given range.
     * Each tree's score is added to the entry of its class in {@code preds}; empty
     * ({@code null}) trees contribute nothing.
     * @param row input row
     * @param fromIndex low endpoint (inclusive) of the tree range
     * @param toIndex high endpoint (exclusive) of the tree range
     * @param preds array of partial predictions.
     *              To get final predictions pass the result to {@link SharedTreeMojoModel#unifyPreds}.
     */
    public final void scoreTreeRange(double[] row, int fromIndex, int toIndex, double[] preds) {
        // With a single class the value goes to preds[0]; otherwise the per-class slots start at index 1
        final int firstClassSlot = (_nclasses == 1) ? 0 : 1;
        for (int cls = 0; cls < _ntrees_per_group; cls++) {
            final int slot = firstClassSlot + cls;
            int treeIdx = treeIndex(fromIndex, cls);
            for (int group = fromIndex; group < toIndex; group++, treeIdx++) {
                final byte[] compressed = _compressed_trees[treeIdx];
                if (compressed == null) {
                    continue; // Skip all empty trees
                }
                preds[slot] += _scoreTree.scoreTree(compressed, row, false, _domains);
            }
        }
    }

    // note that _ntree_group = _treekeys.length
    // ntrees_per_group = _treeKeys[0].length
    /**
     * Builds the labels of the decision path outputs, one {@code "T<group>.C<class>"} label
     * (both 1-based) per non-empty tree, stored at that tree's {@code treeIndex} position.
     *
     * NOTE(review): the array length is derived from the number of non-empty trees in the first
     * group, while entries are written at {@code treeIndex(group, class)}; these only agree when
     * every class has a tree - confirm that empty class trees cannot occur here.
     *
     * @return array of decision path labels
     */
    public String[] getDecisionPathNames() {
      // Count the classes that have a tree in the first group.
      int populatedClasses = 0;
      for (int cls = 0; cls < _ntrees_per_group; cls++) {
        if (_compressed_trees[treeIndex(0, cls)] != null) {
          populatedClasses++;
        }
      }
      final String[] names = new String[_ntree_groups * populatedClasses];
      for (int cls = 0; cls < _ntrees_per_group; cls++) {
        for (int group = 0; group < _ntree_groups; group++) {
          final int slot = treeIndex(group, cls);
          if (_compressed_trees[slot] == null) {
            continue;
          }
          names[slot] = "T" + (group + 1) + ".C" + (cls + 1);
        }
      }
      return names;
    }

    /**
     * Plain holder for the per-tree results produced by {@code getLeafNodeAssignments}.
     */
    public static class LeafNodeAssignments {
      // Decision path strings; getLeafNodeAssignments allocates one slot per compressed tree.
      public String[] _paths;
      // Leaf node ids; left null by getLeafNodeAssignments unless the MOJO version is >= 1.3
      // and the auxiliary tree data is present.
      public int[] _nodeIds;
    }

    /**
     * Traces the given row through all trees and collects the decision paths and, when available,
     * the leaf node ids.
     *
     * @param row input row
     * @return holder with one path per compressed tree; {@code _nodeIds} stays null unless the MOJO
     *         version is at least 1.3 and auxiliary tree data is present
     */
    public LeafNodeAssignments getLeafNodeAssignments(final double[] row) {
      final LeafNodeAssignments result = new LeafNodeAssignments();
      result._paths = new String[_compressed_trees.length];
      // node ids are only available for compatible MOJOs
      final boolean nodeIdsSupported = _mojo_version >= 1.3 && _compressed_trees_aux != null;
      if (nodeIdsSupported) {
        result._nodeIds = new int[_compressed_trees_aux.length];
      }
      traceDecisions(row, result._paths, result._nodeIds);
      return result;
    }

    /**
     * Traces the given row through all trees and returns the decision path of each tree.
     *
     * @param row input row
     * @return array with one slot per compressed tree, filled by {@code traceDecisions}
     */
    public String[] getDecisionPath(final double[] row) {
      final String[] decisionPaths = new String[_compressed_trees.length];
      traceDecisions(row, decisionPaths, null);
      return decisionPaths;
    }

    /**
     * Scores every tree in leaf-assignment mode and decodes the result into the requested outputs.
     *
     * @param row input row
     * @param paths receives the decision path of each tree at its {@code treeIndex}; may be null
     * @param nodeIds receives the leaf node id of each tree at its {@code treeIndex}; may be null
     * @throws IllegalArgumentException if the MOJO version is lower than 1.2
     */
    private void traceDecisions(final double[] row, String[] paths, int[] nodeIds) {
      if (_mojo_version < 1.2) {
        throw new IllegalArgumentException("You can only obtain decision tree path with mojo versions 1.2 or higher");
      }
      final boolean wantPaths = paths != null;
      final boolean wantNodeIds = nodeIds != null;
      for (int group = 0; group < _ntree_groups; group++) {
        for (int cls = 0; cls < _ntrees_per_group; cls++) {
          final int t = treeIndex(group, cls);
          // With the leaf-assignment flag set, scoreTree yields the encoded path rather than a prediction.
          final double encodedPath = scoreTree(_compressed_trees[t], row, true, _domains);
          if (wantPaths) {
            paths[t] = SharedTreeMojoModel.getDecisionPath(encodedPath);
          }
          if (wantNodeIds) {
            assert _mojo_version >= 1.3;
            nodeIds[t] = SharedTreeMojoModel.getLeafNodeId(encodedPath, _compressed_trees_aux[t]);
          }
        }
      }
    }

    /**
     * Computes the position of a tree in the array of compressed trees.
     * Trees are laid out class by class: all groups of class 0 first, then all groups of class 1, etc.
     * @param groupIndex index of the tree group
     * @param classIndex index of the class
     * @return index of the tree in _compressed_trees.
     */
    final int treeIndex(int groupIndex, int classIndex) {
        final int classBase = classIndex * _ntree_groups;
        return classBase + groupIndex;
    }

    /**
     * Returns the compressed bytes of the tree for the given group and class.
     * The internal array element is returned directly (no copy); it may be null for an empty tree.
     *
     * @param groupIndex index of the tree group
     * @param classIndex index of the class
     * @return the stored tree bytes
     */
    public final byte[] treeBytes(int groupIndex, int classIndex) {
        final int position = treeIndex(groupIndex, classIndex);
        return _compressed_trees[position];
    }

  // DO NOT CHANGE THE CODE BELOW THIS LINE
  // DO NOT CHANGE THE CODE BELOW THIS LINE
  // DO NOT CHANGE THE CODE BELOW THIS LINE
  // DO NOT CHANGE THE CODE BELOW THIS LINE
  // DO NOT CHANGE THE CODE BELOW THIS LINE
  // DO NOT CHANGE THE CODE BELOW THIS LINE
  // DO NOT CHANGE THE CODE BELOW THIS LINE
  /////////////////////////////////////////////////////
  /**
   * SET IN STONE FOR MOJO VERSION "1.00" - DO NOT CHANGE
   *
   * Walks one compressed tree from its first node, choosing the left or right child at each node
   * from the value of the node's split column in {@code row}, until a leaf is reached.
   *
   * @param tree compressed tree bytes
   * @param row input row; indexed by the column id stored in each node
   * @param computeLeafAssignment if true, return the encoded path taken through the tree instead of
   *                              the leaf value
   * @return the float value stored at the reached leaf, or - when {@code computeLeafAssignment} is
   *         set - the path bit mask reinterpreted as a double via {@link Double#longBitsToDouble(long)}
   */
  @SuppressWarnings("ConstantConditions")  // Complains that the code is too complex. Well duh!
  public static double scoreTree0(byte[] tree, double[] row, boolean computeLeafAssignment) {
    ByteBufferWrapper ab = new ByteBufferWrapper(tree);
    GenmodelBitSet bs = null;  // Lazily set on hitting first group test
    // Bit i is set when the walk went right at depth i; one extra bit marks the end of the path.
    long bitsRight = 0;
    int level = 0;
    while (true) {
      int nodeType = ab.get1U();
      int colId = ab.get2();
      // Column id 65535: the value that follows is returned immediately, even when
      // computeLeafAssignment is set (presumably a tree consisting of a single leaf).
      if (colId == 65535) return ab.get4f();
      int naSplitDir = ab.get1U();
      boolean naVsRest = naSplitDir == NsdNaVsRest;
      boolean leftward = naSplitDir == NsdNaLeft || naSplitDir == NsdLeft;
      // 51 = 0b00110011: keeps bits 0-1 and 4-5 of nodeType, which describe the left child
      // (see the switch and the leaf test below).
      int lmask = (nodeType & 51);
      int equal = (nodeType & 12);  // Can be one of 0, 8, 12
      assert equal != 4;  // no longer supported

      float splitVal = -1;
      if (!naVsRest) {
        // Extract value or group to split on
        if (equal == 0) {
          // Standard float-compare test (either < or ==)
          splitVal = ab.get4f();  // Get the float to compare
        } else {
          // Bitset test
          if (bs == null) bs = new GenmodelBitSet(0);
          if (equal == 8)
            bs.fill2(tree, ab);
          else
            bs.fill3_1(tree, ab);
        }
      }

      double d = row[colId];
      // NaN follows the node's NA direction (right unless 'leftward'). Otherwise an NA-vs-rest node
      // sends the value left, a float-compare node goes right when d >= splitVal, and a bitset node
      // goes right when the bitset contains (int) d.
      if (Double.isNaN(d)? !leftward : !naVsRest && (equal == 0? d >= splitVal : bs.contains0((int)d))) {
        // go RIGHT
        // lmask 0..3 selects the width (1-4 bytes) of the stored skip distance (presumably the size
        // of the left subtree); 48 means the left child is a leaf.
        switch (lmask) {
          case 0:  ab.skip(ab.get1U());  break;
          case 1:  ab.skip(ab.get2());  break;
          case 2:  ab.skip(ab.get3());  break;
          case 3:  ab.skip(ab.get4());  break;
          case 48: ab.skip(4);  break;  // skip the prediction
          default:
            assert false : "illegal lmask value " + lmask + " in tree " + Arrays.toString(tree);
        }
        // NOTE: `1 << level` is an int shift (scoreTree1 uses 1L): only the low 5 bits of `level`
        // are used and the int result is sign-extended when OR-ed into the long, so levels >= 31
        // do not map to distinct bits. Left as is because this method is frozen for MOJO 1.00.
        if (computeLeafAssignment && level < 64) bitsRight |= 1 << level;
        lmask = (nodeType & 0xC0) >> 2;  // Replace leftmask with the rightmask
      } else {
        // go LEFT
        // For lmask 0..3, step over the (lmask + 1)-byte skip-distance field itself.
        if (lmask <= 3)
          ab.skip(lmask + 1);
      }

      level++;
      // Bit 4 (value 16) of the current mask set: the chosen child is a leaf.
      if ((lmask & 16) != 0) {
        if (computeLeafAssignment) {
          // Same int-shift caveat as above applies to the end marker.
          bitsRight |= 1 << level;  // mark the end of the tree
          return Double.longBitsToDouble(bitsRight);
        } else {
          return ab.get4f();
        }
      }
    }
  }

  /**
   * SET IN STONE FOR MOJO VERSION "1.10" - DO NOT CHANGE
   *
   * Walks one compressed tree from its first node until a leaf is reached. Compared to
   * {@code scoreTree0} this version: treats a value outside the bitset's range like a missing value
   * (it follows the node's NA direction), tests bitset membership with {@code contains} instead of
   * {@code contains0}, and builds the path bit mask with long shifts ({@code 1L << level}).
   *
   * @param tree compressed tree bytes
   * @param row input row; indexed by the column id stored in each node
   * @param computeLeafAssignment if true, return the encoded path taken through the tree instead of
   *                              the leaf value
   * @return the float value stored at the reached leaf, or - when {@code computeLeafAssignment} is
   *         set - the path bit mask reinterpreted as a double via {@link Double#longBitsToDouble(long)}
   */
  @SuppressWarnings("ConstantConditions")  // Complains that the code is too complex. Well duh!
  public static double scoreTree1(byte[] tree, double[] row, boolean computeLeafAssignment) {
    ByteBufferWrapper ab = new ByteBufferWrapper(tree);
    // Allocated on the first bitset split and reused for later ones.
    GenmodelBitSet bs = null;
    // Bit i is set when the walk went right at depth i; one extra bit marks the end of the path.
    long bitsRight = 0;
    int level = 0;
    while (true) {
      int nodeType = ab.get1U();
      int colId = ab.get2();
      // Column id 65535: the value that follows is returned immediately, even when
      // computeLeafAssignment is set (presumably a tree consisting of a single leaf).
      if (colId == 65535) return ab.get4f();
      int naSplitDir = ab.get1U();
      boolean naVsRest = naSplitDir == NsdNaVsRest;
      boolean leftward = naSplitDir == NsdNaLeft || naSplitDir == NsdLeft;
      // 51 = 0b00110011: keeps bits 0-1 and 4-5 of nodeType, which describe the left child
      // (see the switch and the leaf test below).
      int lmask = (nodeType & 51);
      int equal = (nodeType & 12);  // Can be one of 0, 8, 12
      assert equal != 4;  // no longer supported

      float splitVal = -1;
      if (!naVsRest) {
        // Extract value or group to split on
        if (equal == 0) {
          // Standard float-compare test (either < or ==)
          splitVal = ab.get4f();  // Get the float to compare
        } else {
          // Bitset test
          if (bs == null) bs = new GenmodelBitSet(0);
          if (equal == 8)
            bs.fill2(tree, ab);
          else
            bs.fill3_1(tree, ab);
        }
      }

      double d = row[colId];
      // `||` binds tighter than `?:`: a NaN value, or a bitset-node value outside the bitset's range,
      // follows the node's NA direction (right unless 'leftward'). Otherwise an NA-vs-rest node sends
      // the value left, a float-compare node goes right when d >= splitVal, and a bitset node goes
      // right when the bitset contains (int) d.
      if (Double.isNaN(d) || ( equal != 0 && bs != null && !bs.isInRange((int)d) )
              ? !leftward : !naVsRest && (equal == 0? d >= splitVal : bs.contains((int)d))) {
        // go RIGHT
        // lmask 0..3 selects the width (1-4 bytes) of the stored skip distance (presumably the size
        // of the left subtree); 48 means the left child is a leaf.
        switch (lmask) {
          case 0:  ab.skip(ab.get1U());  break;
          case 1:  ab.skip(ab.get2());  break;
          case 2:  ab.skip(ab.get3());  break;
          case 3:  ab.skip(ab.get4());  break;
          case 48: ab.skip(4);  break;  // skip the prediction
          default:
            assert false : "illegal lmask value " + lmask + " in tree " + Arrays.toString(tree);
        }
        // Record the right turn; turns at depth 64 or deeper are not recorded.
        if (computeLeafAssignment && level < 64) bitsRight |= 1L << level;
        lmask = (nodeType & 0xC0) >> 2;  // Replace leftmask with the rightmask
      } else {
        // go LEFT
        // For lmask 0..3, step over the (lmask + 1)-byte skip-distance field itself.
        if (lmask <= 3)
          ab.skip(lmask + 1);
      }

      level++;
      // Bit 4 (value 16) of the current mask set: the chosen child is a leaf.
      if ((lmask & 16) != 0) {
        if (computeLeafAssignment) {
          bitsRight |= 1L << level;  // mark the end of the tree
          return Double.longBitsToDouble(bitsRight);
        } else {
          return ab.get4f();
        }
      }
    }
  }

    /**
     * Delegates to {@code CalibrationMojoHelper.calibrateClassProbabilities}, passing this model
     * and the given predictions.
     *
     * @param preds predictions array handed to the helper
     * @return the value returned by the helper
     */
    @Override
    public boolean calibrateClassProbabilities(double[] preds) {
      return CalibrationMojoHelper.calibrateClassProbabilities(this, preds);
    }

    /** Returns the {@code _calib_glm_beta} field as is (no defensive copy). */
    @Override
    public double[] getCalibGlmBeta() {
        return _calib_glm_beta;
    }

    /** Returns the {@code _isotonic_calibrator} field as is. */
    @Override
    public IsotonicCalibrator getIsotonicCalibrator() {
        return _isotonic_calibrator;
    }

    /**
     * Returns the result of {@code computeGraph(treeNumber)}.
     * The {@code treeClass} argument is not used by this implementation.
     */
    @Override
    public SharedTreeGraph convert(final int treeNumber, final String treeClass) {
        return computeGraph(treeNumber);
    }

    /**
     * Returns the result of {@code computeGraph(treeNumber, options)}.
     * The {@code treeClass} argument is not used by this implementation.
     */
    @Override
    public SharedTreeGraph convert(final int treeNumber, final String treeClass, ConvertTreeOptions options) {
      return computeGraph(treeNumber, options);
    }

    /**
     * Returns staged predictions of tree algorithms (prediction probabilities of trees per iteration).
     * The output structure is for tree Tt and class Cc:
     * Binomial models: [probability T1.C1, probability T2.C1, ..., Tt.C1] where Tt.C1 corresponds to the probability p0
     * Multinomial models: [probability T1.C1, probability T1.C2, ..., Tt.Cc]
     *
     * Each tree group is scored exactly once: its raw scores are added to a running accumulator, and
     * a copy of that accumulator is passed through {@link SharedTreeMojoModel#unifyPreds} to obtain
     * the predictions of the stage. The accumulation order is the same as when scoring the range
     * {@code [0, stage]} from scratch, so the results are identical.
     *
     * @param row Input row.
     * @param predsLength Length of prediction result.
     * @return array of staged prediction probabilities
     */
    public double[] scoreStagedPredictions(double[] row, int predsLength) {
        final int contribOffset = nclasses() == 1 ? 0 : 1;
        final double[] stagedResult = new double[_ntree_groups * _ntrees_per_group];

        // Running raw (pre-unifyPreds) sums over tree groups 0..groupIndex.
        final double[] rawPreds = new double[predsLength];
        for (int groupIndex = 0; groupIndex < _ntree_groups; groupIndex++) {
            // Add only the newly included group instead of re-scoring all previous groups.
            scoreTreeRange(row, groupIndex, groupIndex + 1, rawPreds);
            // unifyPreds rewrites its argument, so it must work on a copy to keep the raw sums intact.
            final double[] stagePreds = rawPreds.clone();
            unifyPreds(row, 0, stagePreds);
            final int stageBase = groupIndex * _ntrees_per_group;
            for (int classIndex = 0; classIndex < _ntrees_per_group; classIndex++) {
                stagedResult[stageBase + classIndex] = stagePreds[contribOffset + classIndex];
            }
        }
        return stagedResult;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy