All Downloads are FREE. Search and download functionalities use the official Maven repository.

water.test.util.NaiveTreeSHAP Maven / Gradle / Ivy

The newest version!
package water.test.util;

import ai.h2o.algos.tree.INode;
import ai.h2o.algos.tree.INodeStat;
import org.junit.Ignore;

import java.util.*;

@Ignore
public class NaiveTreeSHAP, S extends INodeStat> {

  private final int rootNodeId;
  private final N[] nodes;
  private final S[] stats;

  public NaiveTreeSHAP(N[] nodes, S[] stats, int rootNodeId) {
    this.rootNodeId = rootNodeId;
    this.nodes = nodes;
    this.stats = stats;
  }

  public double calculateContributions(R row, double[] contribsNaive) {
    final Set usedFeatures = usedFeatures();
    final int M = usedFeatures.size();
    // last element is the bias term
    contribsNaive[contribsNaive.length - 1] += treeMeanValue() /* tree bias */;
    // pre-calculate expValue for each subset
    Map, Double> expVals = new HashMap<>();
    for (Set subset : allSubsets(usedFeatures)) {
      expVals.put(subset, expValue(row, subset));
    }
    // calculate contributions using pre-calculated expValues
    for (Integer feature : usedFeatures) {
      for (Set subset : expVals.keySet()) {
        if (subset.contains(feature)) {
          Set noFeature = new HashSet<>(subset);
          noFeature.remove(feature);
          double mult = fact(noFeature.size()) * (long) fact(M - subset.size()) / (double) fact(M);
          double contrib = mult * (expVals.get(subset) - expVals.get(noFeature));
          contribsNaive[feature] += contrib;
        }
      }
    }
    // expValue of a tree with all features marked as used should sum-up to the total prediction
    return expValue(row, usedFeatures);
  }
  
  private double expValue(R v, Set s) {
    return expValue(rootNodeId, v, s, 1.0);
  }

  private static int fact(int v) {
    int f = 1;
    for (int i = 1; i <= v; i++) {
      f *= i;
    }
    return f;
  }

  private static List> allSubsets(Set s) {
    List> result = new LinkedList<>();
    Integer[] ary = s.toArray(new Integer[0]);
    // Run a loop from 0 to 2^n
    for (int i = 0; i < (1< subset = new HashSet<>(s.size());
      int m = 1;
      for (Integer item : ary) {
        if ((i & m) > 0) {
          subset.add(item);
        }
        m = m << 1;
      }

      result.add(subset);
    }
    return result;
  }

  private Set usedFeatures() {
    Set features = new HashSet<>();
    for(N n : nodes) {
      features.add(n.getSplitIndex());
    }
    return features;
  }

  private double expValue(int node, R v, Set s, double w) {
    final INode n = nodes[node];
    if (n.isLeaf()) {
      return w * n.getLeafValue();
    } else {
      if (s.contains(n.getSplitIndex())) {
        return expValue(n.next(v), v, s, w);
      } else {
        double wP = stats[node].getWeight();
        double wL = stats[n.getLeftChildIndex()].getWeight();
        double wR = stats[n.getRightChildIndex()].getWeight();
        return expValue(n.getLeftChildIndex(), v, s, w * wL / wP) +
                expValue(n.getRightChildIndex(), v, s, w * wR / wP);
      }
    }
  }

  private double treeMeanValue() {
    return nodeMeanValue(rootNodeId);
  }

  private double nodeMeanValue(int node) {
    final INode n = nodes[node];
    if (n.isLeaf()) {
      return n.getLeafValue();
    } else {
      return (stats[n.getLeftChildIndex()].getWeight() * nodeMeanValue(n.getLeftChildIndex()) +
              stats[n.getRightChildIndex()].getWeight() * nodeMeanValue(n.getRightChildIndex())) / stats[node].getWeight();
    }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy