All downloads are free. Search and download functionality uses the official Maven repository.

hex.genmodel.algos.tree.TreeSHAPEnsemble Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.genmodel.algos.tree;

import java.util.Collection;

public class TreeSHAPEnsemble implements TreeSHAPPredictor {

  private final TreeSHAPPredictor[] _predictors;
  private final float _initPred;
  private final int _wsMakerIndex;

  @SuppressWarnings("unchecked")
  public TreeSHAPEnsemble(Collection> predictors, float initPred) {
    _predictors = predictors.toArray(new TreeSHAPPredictor[0]);
    _initPred = initPred;
    _wsMakerIndex = findWorkspaceMaker(_predictors);
  }

  @Override
  public float[] calculateContributions(R feat, float[] out_contribs) {
    return calculateContributions(feat, out_contribs, 0, -1, makeWorkspace());
  }

  @Override
  public float[] calculateContributions(R feat, float[] out_contribs, int condition, int condition_feature, TreeSHAP.Workspace workspace) {
    if (condition == 0) {
      out_contribs[out_contribs.length - 1] += _initPred;
    }
    for (TreeSHAPPredictor predictor : _predictors) {
      predictor.calculateContributions(feat, out_contribs, condition, condition_feature, workspace);
    }
    return out_contribs; 
  }

  @Override
  public double[] calculateInterventionalContributions(R feat, R background, double[] out_contribs, int[] catOffsets, boolean expand) {
    out_contribs[out_contribs.length - 1] += _initPred;

    for (TreeSHAPPredictor predictor : _predictors) {
      predictor.calculateInterventionalContributions(feat, background, out_contribs, catOffsets, expand);
    }
    return out_contribs;
  }

  @Override
  public TreeSHAPPredictor.Workspace makeWorkspace() {
    return _wsMakerIndex >= 0 ? _predictors[_wsMakerIndex].makeWorkspace() : null;
  }

  @Override
  public int getWorkspaceSize() {
    return _wsMakerIndex >= 0 ? _predictors[_wsMakerIndex].getWorkspaceSize() : 0;
  }

  private static int findWorkspaceMaker(TreeSHAPPredictor[] predictors) {
    if (predictors.length == 0)
      return -1;
    int maxSize = 0;
    int wsMakerIndex = 0;
    for (int i = 0; i < predictors.length; i++) {
      int size = predictors[i].getWorkspaceSize();
      if (size > maxSize) {
        maxSize = size;
        wsMakerIndex = i;
      }
    }
    return wsMakerIndex;
  }
  
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy