All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ml.dmlc.xgboost4j.java.XGBoostRegTree Maven / Gradle / Ivy

package ml.dmlc.xgboost4j.java;

import biz.k11i.xgboost.tree.RegTree;
import biz.k11i.xgboost.tree.RegTreeImpl;
import biz.k11i.xgboost.tree.RegTreeNode;
import biz.k11i.xgboost.util.FVec;
import biz.k11i.xgboost.util.ModelReader;
import water.util.UnsafeUtils;

import java.io.IOException;

/**
 * Regression tree.
 */
/**
 * Regression tree that keeps the raw, packed node table read from an XGBoost model.
 *
 * <p>Each node occupies {@code NODE_SIZE} (20) bytes. Relative to the start of a node record,
 * this class reads:
 * <ul>
 *   <li>offset 4: left child id ({@code -1} marks a leaf);</li>
 *   <li>offset 8: right child id;</li>
 *   <li>offset 12: split word. The low 31 bits hold the feature index. The high bit set means
 *       a missing (NaN) feature value goes to the left child.</li>
 *   <li>offset 16: split condition for inner nodes, leaf value for leaves.</li>
 * </ul>
 * Offset 0 is not read by this class.
 */
public class XGBoostRegTree implements RegTree {

  /** Size in bytes of one serialized node record. */
  private static final int NODE_SIZE = 20;
  /** Size in bytes of one serialized per-node statistics record (skipped on load). */
  private static final int STATS_SIZE = 16;

  // Byte offsets of the fields read from a node record.
  private static final int LEFT_CHILD_OFFSET = 4;
  private static final int RIGHT_CHILD_OFFSET = 8;
  private static final int SPLIT_INDEX_OFFSET = 12;
  private static final int VALUE_OFFSET = 16;

  /** Left-child id that marks a node as a leaf. */
  private static final int LEAF_MARKER = -1;
  /** Extracts the feature index from the split word (drops the default-left flag bit). */
  private static final int FEATURE_INDEX_MASK = 0x7FFFFFFF;

  private final byte[] _nodes;

  /**
   * Loads model from stream.
   *
   * @param reader input stream
   * @throws IOException If an I/O error occurs or the tree header holds an invalid node count
   */
  XGBoostRegTree(ModelReader reader) throws IOException {
    final int numNodes = readNumNodes(reader);
    _nodes = reader.readByteArray(numNodes * NODE_SIZE);
    // Per-node statistics records are not used by this class, so they are skipped.
    reader.skip(numNodes * STATS_SIZE);
  }

  /**
   * Retrieves nodes from root to leaf and returns leaf index.
   *
   * @param feat    feature vector
   * @param root_id starting root index
   * @return id of the leaf node reached from {@code root_id}
   */
  @Override
  public int getLeafIndex(FVec feat, int root_id) {
    return findLeaf(feat, root_id);
  }

  /**
   * Retrieves nodes from root to leaf and returns leaf value.
   *
   * @param feat    feature vector
   * @param root_id starting root index
   * @return leaf value
   */
  @Override
  public final float getLeafValue(FVec feat, int root_id) {
    final int leaf = findLeaf(feat, root_id);
    return UnsafeUtils.get4f(_nodes, leaf * NODE_SIZE + VALUE_OFFSET);
  }

  @Override
  public RegTreeNode[] getNodes() {
    throw new UnsupportedOperationException();
  }

  /**
   * Walks from {@code root_id} down to a leaf, following the split decisions for {@code feat}.
   *
   * <p>A present feature value goes left when it is strictly below the split condition and
   * right otherwise. A missing (NaN) value follows the default direction in the split word's
   * high bit.
   *
   * @param feat    feature vector
   * @param root_id starting node id
   * @return id of the leaf node reached
   */
  private int findLeaf(FVec feat, int root_id) {
    int nid = root_id;
    int left = UnsafeUtils.get4(_nodes, nid * NODE_SIZE + LEFT_CHILD_OFFSET);

    while (left != LEAF_MARKER) {
      final int base = nid * NODE_SIZE;
      final int splitIndex = UnsafeUtils.get4(_nodes, base + SPLIT_INDEX_OFFSET);
      final float fvalue = feat.fvalue(splitIndex & FEATURE_INDEX_MASK);

      final boolean goLeft;
      if (Float.isNaN(fvalue)) {
        goLeft = (splitIndex >>> 31) != 0;
      } else {
        goLeft = fvalue < UnsafeUtils.get4f(_nodes, base + VALUE_OFFSET);
      }

      nid = goLeft ? left : UnsafeUtils.get4(_nodes, base + RIGHT_CHILD_OFFSET);
      left = UnsafeUtils.get4(_nodes, nid * NODE_SIZE + LEFT_CHILD_OFFSET);
    }
    return nid;
  }

  /**
   * Reads the tree header and returns the number of nodes.
   *
   * @param reader input stream positioned at the start of the tree header
   * @return number of nodes in the tree
   * @throws IOException If an I/O error occurs, or the node count is negative or too large
   *                     to address in a single byte array
   */
  private static int readNumNodes(ModelReader reader) throws IOException {
    final int numRoots = reader.readInt();
    // Only single-root trees are expected; checked only when assertions are enabled.
    assert numRoots == 1;
    final int numNodes = reader.readInt();
    // Guard against corrupted headers: a negative count or one whose byte size overflows int.
    if (numNodes < 0 || numNodes > Integer.MAX_VALUE / NODE_SIZE) {
      throw new IOException("Invalid number of tree nodes: " + numNodes);
    }
    reader.skip(4 * 4 + 31 * 4); // skip {int num_deleted, int max_depth, int num_feature, size_leaf_vector, 31 * reserved int}
    return numNodes;
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy