package hex.tree;

import hex.Model;
import hex.genmodel.algos.tree.*;
import water.*;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.util.*;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

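/**
 * Base class for tree models (e.g. GBM, DRF) that can explain their predictions
 * via per-feature contributions (SHAP values), computed with the TreeSHAP
 * algorithm. The contributions frame holds one numeric column per feature plus
 * a trailing "BiasTerm" column; per row, the contributions and the bias term
 * sum to the model's raw prediction.
 */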
public abstract class SharedTreeModelWithContributions<
        M extends SharedTreeModel<M, P, O>,
        P extends SharedTreeModel.SharedTreeParameters,
        O extends SharedTreeModel.SharedTreeOutput
        > extends SharedTreeModel<M, P, O> implements Model.Contributions {

  public SharedTreeModelWithContributions(Key<M> selfKey, P parms, O output) {
    super(selfKey, parms, output);
  }

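  /**
   * Convenience overload: computes contributions without tracking progress in a {@link Job}.
   */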
  @Override
  public Frame scoreContributions(Frame frame, Key<Frame> destination_key) {
    return scoreContributions(frame, destination_key, null);
  }

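  /**
   * Adapts the scoring frame to the training frame and drops the non-feature
   * (response, fold, weights, offset) columns so only predictors remain.
   */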
  protected Frame removeSpecialColumns(Frame frame) {
    Frame adaptFrm = new Frame(frame);
    adaptTestForTrain(adaptFrm, true, false);
    // remove non-feature columns
    adaptFrm.remove(_parms._response_column);
    adaptFrm.remove(_parms._fold_column);
    adaptFrm.remove(_parms._weights_column);
    adaptFrm.remove(_parms._offset_column);
    return adaptFrm;
  }

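  /**
   * Computes TreeSHAP contributions for each row of {@code frame}. Only
   * regression and binomial models are supported; multinomial models are
   * rejected below. The output has one column per feature plus "BiasTerm".
   */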
  @Override
  public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job j) {
    if (_output.nclasses() > 2) {
      throw new UnsupportedOperationException(
              "Calculating contributions is currently not supported for multinomial models.");
    }

    Frame adaptFrm = removeSpecialColumns(frame);

    final String[] outputNames = ArrayUtils.append(adaptFrm.names(), "BiasTerm");
    
    return getScoreContributionsTask(this)
            .withPostMapAction(JobUpdatePostMap.forJob(j))
            .doAll(outputNames.length, Vec.T_NUM, adaptFrm)
            .outputFrame(destination_key, outputNames, null);
  }

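  /**
   * Computes contributions with optional per-row sorting. When sorting is
   * requested, the output interleaves (feature index, contribution) column
   * pairs for the top-N and bottom-N features, followed by the bias term;
   * otherwise this delegates to the unsorted overload above.
   *
   * <p>Usage sketch (illustrative only; {@code model}, {@code fr} and the
   * options setup are assumed, not defined in this file):
   * <pre>
   *   Frame contribs = model.scoreContributions(fr, Key.make("contribs"), null, options);
   * </pre>
   */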
  @Override
  public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job j, ContributionsOptions options) {
    if (_output.nclasses() > 2) {
      throw new UnsupportedOperationException(
              "Calculating contributions is currently not supported for multinomial models.");
    }
    if (options._outputFormat == ContributionsOutputFormat.Compact) {
      throw new UnsupportedOperationException(
              "Only output_format \"Original\" is supported for this model.");
    }
    if (!options.isSortingRequired()) {
      return scoreContributions(frame, destination_key, j);
    }

    Frame adaptFrm = removeSpecialColumns(frame);
    final String[] contribNames = ArrayUtils.append(adaptFrm.names(), "BiasTerm");

    final ContributionComposer contributionComposer = new ContributionComposer();
    int topNAdjusted = contributionComposer.checkAndAdjustInput(options._topN, adaptFrm.names().length);
    int bottomNAdjusted = contributionComposer.checkAndAdjustInput(options._bottomN, adaptFrm.names().length);

    int outputSize = Math.min((topNAdjusted+bottomNAdjusted)*2, adaptFrm.names().length*2);
    String[] names = new String[outputSize+1];
    byte[] types = new byte[outputSize+1];
    String[][] domains = new String[outputSize+1][contribNames.length];

    composeScoreContributionTaskMetadata(names, types, domains, adaptFrm.names(), options);

    return getScoreContributionsSortingTask(this, options)
            .withPostMapAction(JobUpdatePostMap.forJob(j))
            .doAll(types, adaptFrm)
            .outputFrame(destination_key, names, domains);
  }

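  // Factory methods: concrete subclasses (e.g. GBM, DRF) supply the task
  // implementation that post-processes the raw TreeSHAP values for their algo.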
  protected abstract ScoreContributionsTask getScoreContributionsTask(SharedTreeModel model);

  protected abstract ScoreContributionsTask getScoreContributionsSortingTask(SharedTreeModel model, ContributionsOptions options);

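  /**
   * MRTask that evaluates the TreeSHAP ensemble for every row of its input
   * chunks. The model is re-fetched from the distributed key-value store in
   * setupLocal() on each node, and one TreeSHAP predictor is built per tree.
   */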
  public class ScoreContributionsTask extends MRTask<ScoreContributionsTask> {
    protected final Key<SharedTreeModel> _modelKey;

    protected transient SharedTreeModel _model;
    protected transient SharedTreeOutput _output;
    protected transient TreeSHAPPredictor<double[]> _treeSHAP;

    public ScoreContributionsTask(SharedTreeModel model) {
      _modelKey = model._key;
    }

    @Override
    @SuppressWarnings("unchecked")
    protected void setupLocal() {
      _model = _modelKey.get();
      assert _model != null;
      _output = (SharedTreeOutput) _model._output; // cast to SharedTreeOutput to access _ntrees, _treeKeys & _init_f
      assert _output != null;
      List<TreeSHAPPredictor<double[]>> treeSHAPs = new ArrayList<>(_output._ntrees);
      for (int treeIdx = 0; treeIdx < _output._ntrees; treeIdx++) {
        for (int treeClass = 0; treeClass < _output._treeKeys[treeIdx].length; treeClass++) {
          if (_output._treeKeys[treeIdx][treeClass] == null) {
            continue;
          }
          SharedTreeSubgraph tree = _model.getSharedTreeSubgraph(treeIdx, treeClass);
          SharedTreeNode[] nodes = tree.getNodes();
          treeSHAPs.add(new TreeSHAP<>(nodes));
        }
      }
      assert treeSHAPs.size() == _output._ntrees; // for now only regression and binomial to keep the output sane
      _treeSHAP = new TreeSHAPEnsemble<>(treeSHAPs, (float) _output._init_f);
    }

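    /** Copies one row of the input chunks into {@code input} and zeroes {@code contribs}. */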
    protected void fillInput(Chunk[] chks, int row, double[] input, float[] contribs) {
      for (int i = 0; i < chks.length; i++) {
        input[i] = chks[i].atd(row);
      }
      Arrays.fill(contribs, 0);
    }

    @Override
    public void map(Chunk[] chks, NewChunk[] nc) {
      assert chks.length == nc.length - 1; // calculate contribution for each feature + the model bias
      double[] input = MemoryManager.malloc8d(chks.length);
      float[] contribs = MemoryManager.malloc4f(nc.length);

      TreeSHAPPredictor.Workspace workspace = _treeSHAP.makeWorkspace();

      for (int row = 0; row < chks[0]._len; row++) {
        fillInput(chks, row, input, contribs);
        // calculate Shapley values
        _treeSHAP.calculateContributions(input, contribs, 0, -1, workspace);
        doModelSpecificComputation(contribs);
        // Add contribs to new chunk
        addContribToNewChunk(contribs, nc);
      }
    }

    protected void doModelSpecificComputation(float[] contribs) { /* hook for subclasses to override */ }

    protected void addContribToNewChunk(float[] contribs, NewChunk[] nc) {
      for (int i = 0; i < nc.length; i++) {
        nc[i].addNum(contribs[i]);
      }
    }
  }

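  /**
   * Variant of ScoreContributionsTask that sorts each row's contributions and
   * emits only the top-N / bottom-N entries as (feature index, contribution)
   * pairs, followed by the bias term.
   */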
  public class ScoreContributionsSortingTask extends ScoreContributionsTask {

    private final int _topN;
    private final int _bottomN;
    private final boolean _compareAbs;

    public ScoreContributionsSortingTask(SharedTreeModel model, ContributionsOptions options) {
      super(model);
      _topN = options._topN;
      _bottomN = options._bottomN;
      _compareAbs = options._compareAbs;
    }

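    /** As in the base task, but also resets the feature-index permutation to the identity. */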
    protected void fillInput(Chunk[] chks, int row, double[] input, float[] contribs, int[] contribNameIds) {
      super.fillInput(chks, row, input, contribs);
      for (int i = 0; i < contribNameIds.length; i++) {
        contribNameIds[i] = i;
      }
    }

    @Override
    public void map(Chunk[] chks, NewChunk[] nc) {
      double[] input = MemoryManager.malloc8d(chks.length);
      float[] contribs = MemoryManager.malloc4f(chks.length+1);
      int[] contribNameIds = MemoryManager.malloc4(chks.length+1);

      TreeSHAPPredictor.Workspace workspace = _treeSHAP.makeWorkspace();

      for (int row = 0; row < chks[0]._len; row++) {
        fillInput(chks, row, input, contribs, contribNameIds);

        // calculate Shapley values
        _treeSHAP.calculateContributions(input, contribs, 0, -1, workspace);
        doModelSpecificComputation(contribs);
        ContributionComposer contributionComposer = new ContributionComposer();

        int[] contribNameIdsSorted = contributionComposer.composeContributions(
                contribNameIds, contribs, _topN, _bottomN, _compareAbs);

        // Add contribs to new chunk
        addContribToNewChunk(contribs, contribNameIdsSorted, nc);
      }
    }

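    /**
     * Output layout: even columns hold the sorted feature indices, odd columns
     * the matching contributions; the last column is the bias term.
     */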
    protected void addContribToNewChunk(float[] contribs, int[] contribNameIdsSorted, NewChunk[] nc) {
      for (int i = 0, inputPointer = 0; i < nc.length-1; i+=2, inputPointer++) {
        nc[i].addNum(contribNameIdsSorted[inputPointer]);
        nc[i+1].addNum(contribs[contribNameIdsSorted[inputPointer]]);
      }
      nc[nc.length-1].addNum(contribs[contribs.length-1]); // bias
    }
  }
}