package hex.genmodel.algos.xgboost;

import biz.k11i.xgboost.Predictor;
import biz.k11i.xgboost.gbm.GBTree;
import biz.k11i.xgboost.gbm.GradBooster;
import biz.k11i.xgboost.learner.ObjFunction;
import biz.k11i.xgboost.tree.RegTree;
import biz.k11i.xgboost.tree.TreeSHAPHelper;
import biz.k11i.xgboost.util.FVec;
import hex.genmodel.PredictContributionsFactory;
import hex.genmodel.algos.tree.*;
import hex.genmodel.PredictContributions;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;

/**
 * Implementation of XGBoostMojoModel that scores with the pure-Java predictor
 * (xgboost-predictor), see https://github.com/h2oai/xgboost-predictor
 */
public final class XGBoostJavaMojoModel extends XGBoostMojoModel implements PredictContributionsFactory {

  private Predictor _predictor;
  private TreeSHAPPredictor<FVec> _treeSHAPPredictor;
  private OneHotEncoderFactory _1hotFactory;

  @Deprecated
  public XGBoostJavaMojoModel(byte[] boosterBytes, String[] columns, String[][] domains, String responseColumn) {
    this(boosterBytes, null, columns, domains, responseColumn, false);
  }

  @Deprecated
  public XGBoostJavaMojoModel(byte[] boosterBytes,
                              String[] columns, String[][] domains, String responseColumn,
                              boolean enableTreeSHAP) {
    this(boosterBytes, null, columns, domains, responseColumn, enableTreeSHAP);
  }

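  /**
   * Creates a MOJO model backed by the pure-Java predictor.
   *
   * <p>Illustrative sketch only (the variable names below are placeholders; in practice the
   * booster bytes and column metadata come from a parsed MOJO archive):
   * <pre>{@code
   * XGBoostJavaMojoModel model = new XGBoostJavaMojoModel(
   *     boosterBytes, auxNodeWeightBytes, columns, domains, responseColumn, true);
   * model.postReadInit(); // builds the one-hot encoder used by score0
   * }</pre>
   *
   * @param boosterBytes       serialized XGBoost booster
   * @param auxNodeWeightBytes optional auxiliary node-weight data (may be null)
   * @param columns            input column names
   * @param domains            categorical domains per column (null entries for numeric columns)
   * @param responseColumn     name of the response column
   * @param enableTreeSHAP     whether to eagerly build the TreeSHAP predictor used for contributions
   */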
  public XGBoostJavaMojoModel(byte[] boosterBytes, byte[] auxNodeWeightBytes, 
                              String[] columns, String[][] domains, String responseColumn, 
                              boolean enableTreeSHAP) {
    super(columns, domains, responseColumn);
    _predictor = makePredictor(boosterBytes, auxNodeWeightBytes);
    _treeSHAPPredictor = enableTreeSHAP ? makeTreeSHAPPredictor(_predictor) : null;
  }

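  /**
   * Invoked after the MOJO fields have been read; builds the one-hot encoder factory that
   * turns raw input rows into the feature vectors expected by the booster.
   */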
  @Override
  public void postReadInit() {
    _1hotFactory = new OneHotEncoderFactory(
        backwardsCompatibility10(), _sparse, _catOffsets, _cats, _nums, _useAllFactorLevels
    );
  }
  
  private boolean backwardsCompatibility10() {
    return _mojo_version == 1.0 && !"gbtree".equals(_boosterType);
  }

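  /**
   * Deserializes a {@link Predictor} from the raw booster bytes and, when auxiliary
   * node-weight data is supplied, patches the tree node weights accordingly.
   */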
  public static Predictor makePredictor(byte[] boosterBytes, byte[] auxNodeWeightBytes) {
    try (InputStream is = new ByteArrayInputStream(boosterBytes)) {
      Predictor p = new Predictor(is);
      updateNodeWeights(p, auxNodeWeightBytes);
      return p;
    } catch (IOException e) {
      throw new IllegalStateException("Failed to load predictor.", e);
    }
  }
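
  /**
   * Overwrites the node weights of the booster's trees with the auxiliary weights, if any.
   * Only single-group boosters (at most two classes) are supported.
   */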
  public static void updateNodeWeights(Predictor predictor, byte[] auxNodeWeightBytes) {
    if (auxNodeWeightBytes == null)
      return;
    assert predictor.getNumClass() <= 2;
    GBTree gbTree = (GBTree) predictor.getBooster();
    RegTree[] trees = gbTree.getGroupedTrees()[0];
    double[][] weights = AuxNodeWeightsHelper.fromBytes(auxNodeWeightBytes);
    assert trees.length == weights.length;
    AuxNodeWeightsHelper.updateNodeWeights(trees, weights);
  }
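
  /**
   * Builds a TreeSHAP ensemble over the booster's trees; multinomial models are rejected.
   */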
  private static TreeSHAPPredictor<FVec> makeTreeSHAPPredictor(Predictor predictor) {
    if (predictor.getNumClass() > 2) {
      throw new UnsupportedOperationException("Calculating contributions is currently not supported for multinomial models.");
    }
    GBTree gbTree = (GBTree) predictor.getBooster();
    RegTree[] trees = gbTree.getGroupedTrees()[0];
    List<TreeSHAPPredictor<FVec>> predictors = new ArrayList<>(trees.length);
    for (RegTree tree : trees) {
      predictors.add(TreeSHAPHelper.makePredictor(tree));
    }
    float initPred = predictor.getBaseScore();
    return new TreeSHAPEnsemble<>(predictors, initPred);
  }

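  /**
   * Scores a single row. For MOJOs in the 1.0 backwards-compatibility mode, input vectors
   * shorter than {@code _cats + _nums} are zero-padded; a non-zero offset is only honored
   * when the model was trained with an offset column.
   */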
  public final double[] score0(double[] doubles, double offset, double[] preds) {
    if (backwardsCompatibility10()) {
      // throw an exception for unexpectedly long input vector
      if (doubles.length > _cats + _nums) {
        throw new ArrayIndexOutOfBoundsException("Too many input values.");
      }
      // for unexpectedly short input vector handle the situation gracefully
      if (doubles.length < _cats + _nums) {
        double[] tmp = new double[_cats + _nums];
        System.arraycopy(doubles, 0, tmp, 0, doubles.length);
        doubles = tmp;
      }
    }
    FVec row = _1hotFactory.fromArray(doubles);
    float[] out;
    if (_hasOffset) {
      out = _predictor.predict(row, (float) offset);
    } else if (offset != 0) {
      throw new UnsupportedOperationException("Unsupported: offset != 0");
    } else {
      out = _predictor.predict(row);
    }
    return toPreds(doubles, out, preds, _nclasses, _priorClassDistrib, _defaultThreshold);
  }

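  /**
   * Creates a reusable workspace for repeated TreeSHAP contribution calculations
   * (requires TreeSHAP to have been enabled when the model was constructed).
   */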
  public final TreeSHAPPredictor.Workspace makeContributionsWorkspace() {
    return _treeSHAPPredictor.makeWorkspace();
  }

  public final float[] calculateContributions(FVec row, float[] out_contribs, TreeSHAPPredictor.Workspace workspace) {
    _treeSHAPPredictor.calculateContributions(row, out_contribs, 0, -1, workspace);
    return out_contribs;
  }

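  /**
   * Returns a contributions predictor, reusing the eagerly built TreeSHAP predictor when
   * available and constructing one on demand otherwise.
   */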
  @Override
  public final PredictContributions makeContributionsPredictor() {
    TreeSHAPPredictor<FVec> treeSHAPPredictor = _treeSHAPPredictor != null ?
            _treeSHAPPredictor : makeTreeSHAPPredictor(_predictor);
    return new XGBoostContributionsPredictor(this, treeSHAPPredictor);
  }

  static ObjFunction getObjFunction(String name) {
    return ObjFunction.fromName(name);
  }

  @Override
  public void close() {
    _predictor = null;
    _treeSHAPPredictor = null;
    _1hotFactory = null;
  }

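  /**
   * Converts the requested tree of the underlying booster into a {@link SharedTreeGraph};
   * the {@code treeClass} argument is currently unused.
   */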
  @Override
  public SharedTreeGraph convert(final int treeNumber, final String treeClass) {
    GradBooster booster = _predictor.getBooster();
    return computeGraph(booster, treeNumber);
  }

  @Override
  public SharedTreeGraph convert(final int treeNumber, final String treeClass, final ConvertTreeOptions options) {
    return convert(treeNumber, treeClass); // options currently do not apply to the conversion of XGBoost trees
  }

  @Override
  public double getInitF() {
    return _predictor.getBaseScore();
  }

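  /**
   * Returns both the decision paths and the leaf node ids assigned to the given row.
   */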
  @Override
  public SharedTreeMojoModel.LeafNodeAssignments getLeafNodeAssignments(double[] doubles) {
    FVec row = _1hotFactory.fromArray(doubles);
    final SharedTreeMojoModel.LeafNodeAssignments result = new SharedTreeMojoModel.LeafNodeAssignments();
    result._paths = _predictor.predictLeafPath(row);
    result._nodeIds = _predictor.predictLeaf(row);
    return result;
  }

  @Override
  public String[] getDecisionPath(double[] doubles) {
    FVec row = _1hotFactory.fromArray(doubles);
    return _predictor.predictLeafPath(row);
  }

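  /**
   * Contributions predictor that feeds one-hot encoded rows into the shared TreeSHAP logic.
   */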
  private final class XGBoostContributionsPredictor extends ContributionsPredictor<FVec> {
    private XGBoostContributionsPredictor(XGBoostMojoModel model, TreeSHAPPredictor<FVec> treeSHAPPredictor) {
      super(_nums + _catOffsets[_cats] + 1, makeFeatureContributionNames(model), treeSHAPPredictor);
    }

    @Override
    protected FVec toInputRow(double[] input) {
      return _1hotFactory.fromArray(input);
    }
  }

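  /**
   * Expands the model features into per-contribution names: numeric columns keep their own
   * name, while each categorical column contributes one name per domain level plus a
   * ".missing(NA)" entry for the missing-value indicator.
   */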
  private static String[] makeFeatureContributionNames(XGBoostMojoModel m) {
    final String[] names = new String[m._nums + m._catOffsets[m._cats]];
    final String[] features = m.features();
    int i = 0;
    for (int c = 0; c < features.length; c++) {
      if (m._domains[c] == null) {
        names[i++] = features[c];
      } else {
        for (String d : m._domains[c])
          names[i++] = features[c] + "." + d;
        names[i++] = features[c] + ".missing(NA)";
      }
    }
    assert names.length == i;
    return names;
  }

}