All downloads are free. Search and download functionality uses the official Maven repository.

hex.genmodel.algos.tree.ContributionsPredictor Maven / Gradle / Ivy

There is a newer version: 3.46.0.5
Show newest version
package hex.genmodel.algos.tree;

import hex.genmodel.PredictContributions;
import hex.genmodel.attributes.parameters.FeatureContribution;
import hex.genmodel.utils.ArrayUtils;

import java.util.Arrays;

/**
 * Base class for predictors that compute per-feature contributions for a single input row by
 * delegating to a {@link TreeSHAPPredictor}.
 *
 * <p>Subclasses convert a raw {@code double[]} row into the representation {@code E} consumed by the
 * underlying {@link TreeSHAPPredictor} (see {@link #toInputRow(double[])}). They may also
 * post-process the raw contribution array (see {@link #getContribs(float[])}).
 *
 * <p>Scratch workspaces are cached per thread in a {@code static} {@link ThreadLocal}. All
 * instances therefore share one slot per thread. The cached workspace is replaced whenever its
 * size differs from the size this predictor needs.
 *
 * @param <E> the input-row type accepted by the underlying {@link TreeSHAPPredictor}
 */
public abstract class ContributionsPredictor<E> implements PredictContributions {
  /** Length of the contribution array allocated for every call to {@link #calculateContributions(double[])}. */
  private final int _ncontribs;
  /** The feature contribution names passed to the constructor, followed by {@code "BiasTerm"}. */
  private final String[] _contribution_names;
  /** Delegate that fills in the contribution array for a converted input row. */
  private final TreeSHAPPredictor<E> _treeSHAPPredictor;
  /** Workspace size reported by {@link #_treeSHAPPredictor} at construction time. */
  private final int _workspaceSize;

  // Per-thread scratch space. The field is static, so every ContributionsPredictor instance on a
  // given thread shares (and may replace) the same cached workspace; see getWorkspace().
  private static final ThreadLocal<TreeSHAPPredictor.Workspace> _workspace = new ThreadLocal<>();

  /**
   * Creates a predictor backed by the given {@link TreeSHAPPredictor}.
   *
   * @param ncontribs number of contribution slots to allocate for each prediction
   * @param featureContributionNames names of the feature contributions; {@code "BiasTerm"} is
   *     appended to them to form {@link #getContributionNames()}
   * @param treeSHAPPredictor delegate that computes the contributions; its workspace size is
   *     captured once here
   */
  public ContributionsPredictor(int ncontribs, String[] featureContributionNames, TreeSHAPPredictor<E> treeSHAPPredictor) {
    _ncontribs = ncontribs;
    _contribution_names = ArrayUtils.append(featureContributionNames, "BiasTerm");
    _treeSHAPPredictor = treeSHAPPredictor;
    _workspaceSize = _treeSHAPPredictor.getWorkspaceSize();
  }

  /**
   * Returns the contribution names: the feature names followed by {@code "BiasTerm"}.
   *
   * <p>NOTE(review): this returns the internal array itself, not a copy, so callers must not
   * modify it.
   *
   * @return the contribution names
   */
  @Override
  public final String[] getContributionNames() {
    return _contribution_names;
  }

  /**
   * Computes the raw contributions for a single input row.
   *
   * <p>This method:
   * <ol>
   *   <li>allocates a fresh array of {@code ncontribs} floats;</li>
   *   <li>converts {@code input} with {@link #toInputRow(double[])};</li>
   *   <li>lets the underlying predictor fill the array, using this thread's cached workspace;</li>
   *   <li>returns the result of {@link #getContribs(float[])}.</li>
   * </ol>
   *
   * @param input the raw input row
   * @return the (possibly post-processed) contribution array
   */
  public final float[] calculateContributions(double[] input) {
    float[] contribs = new float[_ncontribs];
    // NOTE(review): the fixed arguments 0 and -1 presumably disable conditioning on a feature;
    // confirm against TreeSHAPPredictor.calculateContributions.
    _treeSHAPPredictor.calculateContributions(toInputRow(input), contribs, 0, -1, getWorkspace());
    return getContribs(contribs);
  }

  /**
   * Converts a raw {@code double[]} row into the representation expected by the underlying
   * {@link TreeSHAPPredictor}.
   *
   * @param input the raw input row
   * @return the converted row
   */
  protected abstract E toInputRow(double[] input);

  /**
   * Post-processing hook applied to the freshly computed contribution array. The default
   * implementation returns {@code contribs} unchanged.
   *
   * @param contribs the contribution array filled in by the underlying predictor
   * @return the contribution array to hand back to the caller
   */
  public float[] getContribs(float[] contribs) {
    return contribs;
  }

  /**
   * Returns this thread's cached workspace. A fresh one is created and cached via
   * {@link TreeSHAPPredictor#makeWorkspace()} when either of these holds:
   * <ul>
   *   <li>no workspace is cached yet;</li>
   *   <li>the cached workspace's size differs from {@link #_workspaceSize}.</li>
   * </ul>
   *
   * <p>NOTE(review): only the size is compared. A workspace created by a different
   * {@link TreeSHAPPredictor} with an equal size is reused, which assumes such workspaces are
   * interchangeable. Please confirm.
   */
  private TreeSHAPPredictor.Workspace getWorkspace() {
    TreeSHAPPredictor.Workspace workspace = _workspace.get();
    if (workspace == null || workspace.getSize() != _workspaceSize) {
      workspace = _treeSHAPPredictor.makeWorkspace();
      assert workspace.getSize() == _workspaceSize;
      _workspace.set(workspace);
    }
    return workspace;
  }

  /**
   * Computes contributions for one row and returns them paired with their names. The returned
   * entries are selected and ordered by {@link ContributionComposer} according to
   * {@code topN}, {@code bottomN} and {@code compareAbs}.
   *
   * @param input the raw input row
   * @param topN forwarded to {@link ContributionComposer#composeContributions}
   * @param bottomN forwarded to {@link ContributionComposer#composeContributions}
   * @param compareAbs forwarded to {@link ContributionComposer#composeContributions}
   * @return one {@link FeatureContribution} per index chosen by the composer
   */
  @Override
  public FeatureContribution[] calculateContributions(double[] input, int topN, int bottomN, boolean compareAbs) {
    float[] contributions = calculateContributions(input);
    // Candidate ids are meant to cover every name (0 .. length-1). This assumes ArrayUtils.range
    // includes its upper bound; confirm.
    int[] contributionNameIds = ArrayUtils.range(0, _contribution_names.length - 1);
    int[] sorted = (new ContributionComposer()).composeContributions(contributionNameIds, contributions, topN, bottomN, compareAbs);
    FeatureContribution[] out = new FeatureContribution[sorted.length];
    for (int i = 0; i < sorted.length; i++) {
      out[i] = new FeatureContribution(_contribution_names[sorted[i]], contributions[sorted[i]]);
    }
    return out;
  }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy