All downloads are free. The search and download functionality uses the official Maven repository.

hex.tree.xgboost.predict.PredictTreeSHAPTask Maven / Gradle / Ivy

package hex.tree.xgboost.predict;

import hex.DataInfo;
import hex.genmodel.algos.xgboost.XGBoostJavaMojoModel;
import hex.tree.xgboost.XGBoostModel;
import hex.tree.xgboost.XGBoostOutput;
import ml.dmlc.xgboost4j.java.XGBoostModelInfo;
import water.MRTask;
import water.MemoryManager;
import water.fvec.Chunk;
import water.fvec.NewChunk;

/**
 * {@link MRTask} that scores every row of the input chunks with an
 * {@link XGBoostJavaMojoModel} and writes the resulting per-row contribution
 * values (one value per output column) into the supplied {@link NewChunk}s.
 *
 * <p>The task carries the raw booster bytes and model metadata; the scoring
 * model itself is held in a {@code transient} field and is rebuilt in
 * {@link #setupLocal()}.
 */
public class PredictTreeSHAPTask extends MRTask<PredictTreeSHAPTask> {

  private final DataInfo _di;
  private final XGBoostModelInfo _modelInfo;
  private final XGBoostOutput _output;

  // transient: not part of the task's serialized state; created in setupLocal().
  private transient XGBoostJavaMojoModel _mojo;

  /**
   * @param di        data info handed to the row encoder used in {@link #map}
   * @param modelInfo source of the booster bytes the scoring model is built from
   * @param output    model output providing column names, domains, the response
   *                  name and the sparsity flag
   */
  public PredictTreeSHAPTask(DataInfo di, XGBoostModelInfo modelInfo, XGBoostOutput output) {
    _di = di;
    _modelInfo = modelInfo;
    _output = output;
  }

  /**
   * Builds the scoring model from the booster bytes and the output metadata
   * before any chunk is mapped.
   */
  @Override
  protected void setupLocal() {
    // NOTE(review): the meaning of the trailing `true` flag is defined by the
    // XGBoostJavaMojoModel constructor and is not visible here - confirm.
    _mojo = new XGBoostJavaMojoModel(
        _modelInfo._boosterBytes, _output._names, _output._domains, _output.responseName(), true
    );
  }

  /**
   * Computes contributions for each row of the given chunks.
   *
   * @param chks input chunks, one per input column; the row count is taken
   *             from the first chunk
   * @param nc   output chunks; exactly one value is appended to each of them
   *             for every input row
   */
  @Override
  public void map(Chunk[] chks, NewChunk[] nc) {
    MutableOneHotEncoderFVec rowFVec = new MutableOneHotEncoderFVec(_di, _output._sparse);

    // Buffers are allocated once per map() call and reused for every row.
    double[] input = MemoryManager.malloc8d(chks.length);
    float[] contribs = MemoryManager.malloc4f(nc.length);

    Object workspace = _mojo.makeContributionsWorkspace();

    final int rowCount = chks[0]._len;
    for (int row = 0; row < rowCount; row++) {
      // Gather the raw column values of the current row.
      for (int i = 0; i < chks.length; i++) {
        input[i] = chks[i].atd(row);
      }
      // Clear the previous row's results before the buffer is reused
      // (presumably calculateContributions accumulates into it - confirm).
      for (int i = 0; i < contribs.length; i++) {
        contribs[i] = 0;
      }
      rowFVec.setInput(input);

      // calculate Shapley values
      _mojo.calculateContributions(rowFVec, contribs, workspace);

      for (int i = 0; i < nc.length; i++) {
        nc[i].addNum(contribs[i]);
      }
    }
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy