com.etsy.conjecture.evaluation.MulticlassReceiverOperatingCharacteristic Maven / Gradle / Ivy
package com.etsy.conjecture.evaluation;
import java.io.Serializable;
import java.util.Collection;
import java.util.Map;
import java.util.HashMap;
import com.etsy.conjecture.GenericPair;
import com.etsy.conjecture.data.MulticlassPrediction;
import static com.google.common.base.Preconditions.checkArgument;
public class MulticlassReceiverOperatingCharacteristic implements Serializable {
private static final long serialVersionUID = 1L;
/** Num examples in each class. */
private Map classCounts;
/** Num total examples */
private int numExamples;
/** Binary ROCs for each class */
private Map classROC;
/**
* Instantiates a new receiver operating characteristic.
*/
public MulticlassReceiverOperatingCharacteristic(String[] categories) {
classROC = new HashMap();
classCounts = new HashMap();
for (String category : categories) {
classROC.put(category, new ReceiverOperatingCharacteristic());
classCounts.put(category, 0);
}
}
public void add(MulticlassReceiverOperatingCharacteristic other) {
numExamples += other.numExamples;
for(Map.Entry entry : other.classCounts.entrySet()) {
String category = entry.getKey();
Integer count = entry.getValue();
classCounts.put(category, classCounts.get(category)+count);
}
for(Map.Entry entry : other.classROC.entrySet()) {
String category = entry.getKey();
ReceiverOperatingCharacteristic update = entry.getValue();
ReceiverOperatingCharacteristic roc = classROC.get(category);
roc.add(update);
classROC.put(category, roc);
}
}
public void add(GenericPair labelPrediction) {
add(labelPrediction.first, labelPrediction.second);
}
public void add(String label, MulticlassPrediction prediction) {
checkArgument(classCounts.containsKey(label),
"label is of unknown category: %s", label);
checkArgument(classROC.containsKey(label),
"label is of unknown category: %s", label);
// accum class counts
int count = classCounts.get(label);
classCounts.put(label, count + 1);
// accum total counts;
numExamples++;
// add to individual binary ROC classes
for (String category : classCounts.keySet()) {
double binaryPrediction = prediction.getMap().get(category);
double classLabel = category.equals(label) ? 1d : 0d;
classROC.get(category).add(classLabel, binaryPrediction);
}
}
public double multiclassAUC() {
double weightedAverageAUC = 0d;
for (String label : classCounts.keySet()) {
double classInfluence = (double)classCounts.get(label)
/ numExamples;
ReceiverOperatingCharacteristic roc = classROC.get(label);
double classAUC = roc.binaryAUC();
weightedAverageAUC += classInfluence * classAUC;
}
return weightedAverageAUC;
}
public double singleClassAUC(String category) {
return classROC.get(category).binaryAUC();
}
public double multiclassBrierScore() {
double brierScore = 0d;
int numClasses = classCounts.keySet().size();
for (String label : classCounts.keySet()) {
brierScore += (classROC.get(label)).brierScore();
}
return brierScore / numClasses;
}
public double computePercent(String category) {
return classCounts.get(category) / (double) numExamples;
}
public static double computeAUC(
Collection> labelsAndPredictions,
String[] categories) {
MulticlassReceiverOperatingCharacteristic roc = new MulticlassReceiverOperatingCharacteristic(
categories);
for (GenericPair p : labelsAndPredictions)
roc.add((String)p.first, (MulticlassPrediction)p.second);
return roc.multiclassAUC();
}
public static double computeBrierScore(
Collection> labelsAndPredictions,
String[] categories) {
MulticlassReceiverOperatingCharacteristic roc = new MulticlassReceiverOperatingCharacteristic(
categories);
for (GenericPair p : labelsAndPredictions)
roc.add(p.first, p.second);
return roc.multiclassBrierScore();
}
}