All downloads are free. Search and download functionality uses the official Maven repository.

hex.tree.xgboost.predict.PredictTreeSHAPTask Maven / Gradle / Ivy

package hex.tree.xgboost.predict;

import hex.DataInfo;
import hex.genmodel.algos.tree.TreeSHAPPredictor;
import hex.genmodel.algos.xgboost.XGBoostJavaMojoModel;
import hex.tree.xgboost.XGBoostModelInfo;
import hex.tree.xgboost.XGBoostOutput;
import water.MRTask;
import water.MemoryManager;
import water.fvec.Chunk;
import water.fvec.NewChunk;

import java.util.Arrays;

import static hex.Model.Contributions.ContributionsOptions;
import static hex.Model.Contributions.ContributionsOutputFormat;

/**
 * Map task that computes per-row feature contributions (Shapley values) for an
 * XGBoost model and appends them to the output chunks.
 *
 * <p>For every row the task copies the chunk values into a dense input buffer,
 * feeds it through a {@link MutableOneHotEncoderFVec}, asks the
 * {@link XGBoostJavaMojoModel} for the contributions and writes the result
 * (optionally aggregated, see {@link #_outputAggregated}) to the {@link NewChunk}s.
 *
 * <p>The class is parameterized with itself ({@code MRTask<PredictTreeSHAPTask>})
 * rather than extending the raw {@code MRTask} type, so the fluent methods inherited
 * from {@code MRTask} keep their static type.
 */
public class PredictTreeSHAPTask extends MRTask<PredictTreeSHAPTask> {

  /** Data layout passed to the row encoder; also determines the size of the contributions buffer. */
  protected final DataInfo _di;
  /** Source of the booster bytes and auxiliary node-weight bytes used to build the scoring model. */
  protected final XGBoostModelInfo _modelInfo;
  /** Model output; supplies names, domains, response name and the sparse flag. */
  protected final XGBoostOutput _output;
  /** {@code true} when the requested output format is {@code ContributionsOutputFormat.Compact}. */
  protected final boolean _outputAggregated;


  /** Scoring model; transient, instantiated in {@link #setupLocal()}. */
  protected transient XGBoostJavaMojoModel _mojo;

  /**
   * Creates the task.
   *
   * @param di        data layout of the scored frame
   * @param modelInfo holder of the serialized booster
   * @param output    output of the model being explained
   * @param options   contribution options; only {@code _outputFormat} is read here
   */
  public PredictTreeSHAPTask(DataInfo di, XGBoostModelInfo modelInfo, XGBoostOutput output,
                             ContributionsOptions options) {
    _di = di;
    _modelInfo = modelInfo;
    _output = output;
    _outputAggregated = ContributionsOutputFormat.Compact.equals(options._outputFormat);
  }

  /**
   * Builds the Java scoring model from the booster bytes and the model output metadata.
   */
  @Override
  protected void setupLocal() {
    _mojo = new XGBoostJavaMojoModel(
            _modelInfo._boosterBytes, _modelInfo.auxNodeWeightBytes(),
            _output._names, _output._domains, _output.responseName(),
            true // NOTE(review): meaning of this flag is defined by XGBoostJavaMojoModel - confirm there
    );
  }

  /**
   * Prepares the per-row buffers: reads one value per chunk into {@code input}
   * and resets {@code contribs} to zeros.
   *
   * @param chks     input chunks, one per column
   * @param row      chunk-relative row index
   * @param input    destination buffer, written at indices {@code 0..chks.length-1}
   * @param contribs contributions buffer to be cleared
   */
  protected void fillInput(Chunk[] chks, int row, double[] input, float[] contribs) {
    for (int i = 0; i < chks.length; i++) {
      input[i] = chks[i].atd(row);
    }
    Arrays.fill(contribs, 0);
  }

  /**
   * Computes the contributions of every row of the given chunks and appends them to {@code nc}.
   *
   * @param chks input chunks; the row count is taken from the first chunk
   * @param nc   output chunks receiving one value per row each
   */
  @Override
  public void map(Chunk[] chks, NewChunk[] nc) {
    MutableOneHotEncoderFVec rowFVec = new MutableOneHotEncoderFVec(_di, _output._sparse);

    double[] input = MemoryManager.malloc8d(chks.length);
    // one slot per expanded feature plus a trailing slot for the bias term
    float[] contribs = MemoryManager.malloc4f(_di.fullN() + 1);
    // aggregated output needs its own buffer; otherwise the raw contributions are written out directly
    float[] output = _outputAggregated ? MemoryManager.malloc4f(nc.length) : contribs;

    TreeSHAPPredictor.Workspace workspace = _mojo.makeContributionsWorkspace();

    for (int row = 0; row < chks[0]._len; row++) {
      fillInput(chks, row, input, contribs);
      rowFVec.setInput(input);

      // calculate Shapley values
      _mojo.calculateContributions(rowFVec, contribs, workspace);

      handleOutputFormat(rowFVec, contribs, output);

      addContribToNewChunk(output, nc);
    }
  }

  /**
   * Converts the raw contributions into the requested output format.
   * When aggregation is enabled, {@code contribs} is decoded into {@code output} and the
   * last element (bias term) is copied over; otherwise nothing is done, since the caller
   * passes the same array as both {@code contribs} and {@code output} in that case.
   *
   * @param rowFVec  encoder of the current row, used to decode/aggregate the contributions
   * @param contribs raw contributions, bias term last
   * @param output   destination buffer for the aggregated contributions
   */
  protected void handleOutputFormat(final MutableOneHotEncoderFVec rowFVec, final float[] contribs, final float[] output) {
    if (_outputAggregated) {
      rowFVec.decodeAggregate(contribs, output);
      output[output.length - 1] = contribs[contribs.length - 1]; // bias term
    }
  }

  /**
   * Appends one row of contributions to the output chunks.
   *
   * @param contribs values to write; the first {@code nc.length} elements are used
   * @param nc       output chunks, {@code nc[i]} receives {@code contribs[i]}
   */
  protected void addContribToNewChunk(final float[] contribs, final NewChunk[] nc) {
    for (int i = 0; i < nc.length; i++) {
      nc[i].addNum(contribs[i]);
    }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy