All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.etsy.conjecture.evaluation.MulticlassReceiverOperatingCharacteristic Maven / Gradle / Ivy

There is a newer version: 0.2.3
Show newest version
package com.etsy.conjecture.evaluation;

import java.io.Serializable;
import java.util.Collection;
import java.util.Map;
import java.util.HashMap;

import com.etsy.conjecture.GenericPair;
import com.etsy.conjecture.data.MulticlassPrediction;
import static com.google.common.base.Preconditions.checkArgument;

public class MulticlassReceiverOperatingCharacteristic implements Serializable {

    private static final long serialVersionUID = 1L;

    /** Num examples in each class. */
    private Map classCounts;

    /** Num total examples */
    private int numExamples;

    /** Binary ROCs for each class */
    private Map classROC;

    /**
     * Instantiates a new receiver operating characteristic.
     */
    public MulticlassReceiverOperatingCharacteristic(String[] categories) {
        classROC = new HashMap();
        classCounts = new HashMap();
        for (String category : categories) {
            classROC.put(category, new ReceiverOperatingCharacteristic());
            classCounts.put(category, 0);
        }
    }

    public void add(MulticlassReceiverOperatingCharacteristic other) {
        numExamples += other.numExamples;
        for(Map.Entry entry : other.classCounts.entrySet()) {
            String category = entry.getKey();
            Integer count = entry.getValue();
            classCounts.put(category, classCounts.get(category)+count);
        }

        for(Map.Entry entry : other.classROC.entrySet()) {
            String category = entry.getKey();
            ReceiverOperatingCharacteristic update = entry.getValue();
            ReceiverOperatingCharacteristic roc = classROC.get(category);
            roc.add(update);
            classROC.put(category, roc);
        }
    }

    public void add(GenericPair labelPrediction) {
        add(labelPrediction.first, labelPrediction.second);
    }

    public void add(String label, MulticlassPrediction prediction) {
        checkArgument(classCounts.containsKey(label),
                "label is of unknown category: %s", label);
        checkArgument(classROC.containsKey(label),
                "label is of unknown category: %s", label);

        // accum class counts
        int count = classCounts.get(label);
        classCounts.put(label, count + 1);

        // accum total counts;
        numExamples++;

        // add to individual binary ROC classes
        for (String category : classCounts.keySet()) {
            double binaryPrediction = prediction.getMap().get(category);
            double classLabel = category.equals(label) ? 1d : 0d;
            classROC.get(category).add(classLabel, binaryPrediction);
        }
    }

    public double multiclassAUC() {
        double weightedAverageAUC = 0d;
        for (String label : classCounts.keySet()) {
            double classInfluence = (double)classCounts.get(label)
                    / numExamples;
            ReceiverOperatingCharacteristic roc = classROC.get(label);
            double classAUC = roc.binaryAUC();
            weightedAverageAUC += classInfluence * classAUC;
        }
        return weightedAverageAUC;
    }

    public double singleClassAUC(String category) {
        return classROC.get(category).binaryAUC();
    }

    public double multiclassBrierScore() {
        double brierScore = 0d;
        int numClasses = classCounts.keySet().size();
        for (String label : classCounts.keySet()) {
            brierScore += (classROC.get(label)).brierScore();
        }
        return brierScore / numClasses;
    }

    public double computePercent(String category) {
        return classCounts.get(category) / (double) numExamples;
    }

    public static double computeAUC(
            Collection> labelsAndPredictions,
            String[] categories) {
        MulticlassReceiverOperatingCharacteristic roc = new MulticlassReceiverOperatingCharacteristic(
                categories);
        for (GenericPair p : labelsAndPredictions)
            roc.add((String)p.first, (MulticlassPrediction)p.second);
        return roc.multiclassAUC();
    }

    public static double computeBrierScore(
            Collection> labelsAndPredictions,
            String[] categories) {
        MulticlassReceiverOperatingCharacteristic roc = new MulticlassReceiverOperatingCharacteristic(
                categories);
        for (GenericPair p : labelsAndPredictions)
            roc.add(p.first, p.second);
        return roc.multiclassBrierScore();
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy