// Source listing for gov.sandia.cognition.statistics.method.ReceiverOperatingCharacteristic,
// taken from the cognitive-foundry artifact (Maven / Gradle / Ivy):
// a single jar with all the Cognitive Foundry components.
/*
* File: ReceiverOperatingCharacteristic.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright August 24, 2007, Sandia Corporation. Under the terms of Contract
* DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
* or on behalf of the U.S. Government. Export of this program may require a
* license from the United States Government. See CopyrightHistory.txt for
* complete details.
*
*/
package gov.sandia.cognition.statistics.method;
import gov.sandia.cognition.learning.performance.categorization.DefaultBinaryConfusionMatrix;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.ScalarThresholdBinaryCategorizer;
import gov.sandia.cognition.statistics.distribution.UnivariateGaussian;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.Pair;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedList;
/**
* Class that describes a Receiver Operating Characteristic (usually called an
* "ROC Curve"). This is a function that describes the performance of a
* classification system where the x-axis is the FalsePositiveRate and the
* y-axis is the TruePositiveRate. Both axes are on the interval [0,1]. A
* typical ROC curve has a logarithm-shaped plot, ideally it looks like a
* capital Gamma letter. An ROC curve also has an associated group of
* statistics with it from a Mann-Whitney U-test, which gives the probability
* that the classifier is essentially randomly "guessing." We create ROC curves
* by calling the method: ReceiverOperatingCharacteristic.create(data)
*
* @author Kevin R. Dixon
* @since 2.0
*
*/
@PublicationReference(
author="Wikipedia",
title="Receiver operating characteristic",
type=PublicationType.WebPage,
year=2009,
url="http://en.wikipedia.org/wiki/Receiver_operating_characteristic"
)
public class ReceiverOperatingCharacteristic
extends AbstractCloneableSerializable
implements Evaluator
{
/**
* Sorted data containing a ConfusionMatrix at each point, sorted in an
* ascending order along the abscissa (x-axis), which is FalsePositiveRate
*/
private ArrayList sortedROCData;
/**
* Results from conducting a U-test on the underlying classification data,
* the null hypothesis determines if the classifier can reliably separate
* the classes, not just chance
*/
private MannWhitneyUConfidence.Statistic Utest;
/**
* Creates a new instance of ReceiverOperatingCharacteristic
* @param rocData
* Sorted data containing a ConfusionMatrix at each point
* @param Utest
* Results from conducting a U-test on the underlying classification data,
* the null hypothesis determines if the classifier can reliably separate
* the classes, not just chance
*/
private ReceiverOperatingCharacteristic(
Collection rocData,
MannWhitneyUConfidence.Statistic Utest )
{
ArrayList sortedData =
CollectionUtil.asArrayList(rocData);
Collections.sort( sortedData, new DataPoint.Sorter() );
this.setSortedROCData( sortedData );
this.setUtest( Utest );
}
@Override
public ReceiverOperatingCharacteristic clone()
{
ReceiverOperatingCharacteristic clone =
(ReceiverOperatingCharacteristic) super.clone();
clone.setSortedROCData( ObjectUtil.cloneSmartElementsAsArrayList(
this.getSortedROCData() ) );
clone.setUtest( ObjectUtil.cloneSmart( this.getUtest() ) );
return clone;
}
/**
* Evaluates the "pessimistic" value of the truePositiveRate for a given
* falsePositiveRate. This evaluation is pessimistic in that it holds
* the truePositiveRate (y-value) until we receive a corresponding
* falsePositiveRate (x-value) that is greater than the given value
* @param input
* falsePositiveRate from which to estimate the truePositiveRate
* @return
* Pessimistic TruePositiveRate for the given FalsePositiveRate
*/
@Override
public Double evaluate(
Double input )
{
double falsePositiveRate = input;
double truePositiveRate = 0.0;
for( DataPoint rocData : this.getSortedROCData() )
{
if( rocData.getFalsePositiveRate() <= falsePositiveRate )
{
truePositiveRate = rocData.getTruePositiveRate();
}
else
{
break;
}
}
return truePositiveRate;
}
/**
* Getter for sortedROCData
* @return
* Sorted data containing a ConfusionMatrix at each point, sorted in an
* ascending order along the abscissa (x-axis), which is FalsePositiveRate
*/
public ArrayList getSortedROCData()
{
return this.sortedROCData;
}
/**
* Setter for srtedROCData
* @param sortedROCData
* Sorted data containing a ConfusionMatrix at each point, sorted in an
* ascending order along the abscissa (x-axis), which is FalsePositiveRate
*/
protected void setSortedROCData(
ArrayList sortedROCData)
{
this.sortedROCData = sortedROCData;
}
/**
* Creates an ROC curve based on the scored data with target information.
*
* @param data
* Collection of target/estimate-score pairs. The second element in
* the pair is an estimated score, the first is a flag to determine
* which group the score belongs to. For example:
* {(true, 1.0), (false, 0.9)}
* means that data1=1.0 and data2=0.9 and so forth. This is useful
* for computing that classified data partitions data better than
* chance.
* @return
* ROC Curve describing the scoring system versus the targets.
*/
public static ReceiverOperatingCharacteristic createFromTargetEstimatePairs(
final Collection extends Pair> data)
{
// Transform the data to input-output pairs.
final ArrayList> transformed =
new ArrayList>(data.size());
for (Pair entry : data)
{
transformed.add(DefaultInputOutputPair.create(
entry.getSecond().doubleValue(), entry.getFirst()));
}
return create(transformed);
}
/**
* Creates an ROC curve based on the scored data with target information
* @param data
* Collection of estimate-score/target pairs. The second element in the
* Pair is an estimated score, the first is a flag to determine which
* group the score belongs to. For example {<1.0,true>, <0.9,false> means
* that data1=1.0 and data2=0.9 and so forth. This is useful for computing
* that classified data partitions data better than chance.
* @return
* ROC Curve describing the scoring system versus the targets
*/
public static ReceiverOperatingCharacteristic create(
Collection extends InputOutputPair> data)
{
// First we need to sort the data in increasing order by value.
ArrayList> sortedData =
new ArrayList>(data);
Collections.sort(sortedData, new ROCScoreSorter());
// Next we need to count the total number of positive examples.
int totalPositives = 0;
for ( InputOutputPair pair : sortedData )
{
if ( pair.getOutput() == true )
{
totalPositives++;
}
}
// Now we compute the total and the number of negatives.
int total = sortedData.size();
int totalNegatives = total - totalPositives;
// We will be computing the confusion matrix iteratively, which means
// keeping track of how many examples we have counted so far and the
// number of positive examples counted so far.
int countSoFar = 0;
int positivesSoFar = 0;
LinkedList rocData =
new LinkedList();
double lastThreshold = Double.NEGATIVE_INFINITY;
for ( InputOutputPair pair : sortedData )
{
// Compute the confusion matrix based on the current counters.
final double trueNegatives = countSoFar - positivesSoFar;
final double falseNegatives = positivesSoFar;
final double truePositives = totalPositives - falseNegatives;
final double falsePositives = totalNegatives - trueNegatives;
final double threshold = pair.getInput();
if ( threshold > lastThreshold )
{
// Only add a data point if we have evaluated a new
// threshold.
final DefaultBinaryConfusionMatrix confusion =
new DefaultBinaryConfusionMatrix();
confusion.setFalsePositivesCount(falsePositives);
confusion.setFalseNegativesCount(falseNegatives);
confusion.setTruePositivesCount(truePositives);
confusion.setTrueNegativesCount(trueNegatives);
rocData.add(new DataPoint(
new ScalarThresholdBinaryCategorizer(threshold), confusion));
lastThreshold = threshold;
}
// Update the count so far and the positives so far. This is done
// after the computations of the counts because the threshold is
// x >= threshold, so the threshold puts the current point on the
// positive side.
countSoFar++;
final boolean target = pair.getOutput();
if ( target == true )
{
positivesSoFar++;
}
}
Collections.sort(rocData, new ReceiverOperatingCharacteristic.DataPoint.Sorter());
// Compute a statistical test on the data.
MannWhitneyUConfidence.Statistic uTest =
new MannWhitneyUConfidence().evaluateNullHypothesis(data);
// Return the ROC.
return new ReceiverOperatingCharacteristic(rocData, uTest);
}
/**
* Computes useful statistical information associated with the ROC curve
* @return ROC Statistics describing the ROC curve
*/
public ReceiverOperatingCharacteristic.Statistic computeStatistics()
{
return new ReceiverOperatingCharacteristic.Statistic( this );
}
/**
* Getter for Utest
* @return
* Results from conducting a U-test on the underlying classification data,
* the null hypothesis determines if the classifier can reliably separate
* the classes, not just chance
*/
public MannWhitneyUConfidence.Statistic getUtest()
{
return this.Utest;
}
/**
* Setter for Utest
* @param Utest
* Results from conducting a U-test on the underlying classification data,
* the null hypothesis determines if the classifier can reliably separate
* the classes, not just chance
*/
public void setUtest(
MannWhitneyUConfidence.Statistic Utest)
{
this.Utest = Utest;
}
/**
* Contains useful statistics derived from a ROC curve
*/
public static class Statistic
extends MannWhitneyUConfidence.Statistic
{
/**
* Estimated distance between the two classes to be split. Larger
* values of d' indicate that the classes are easier to split,
* d'=0 means that the classes overlap, and negative values mean
* that your classifier is doing worse than chance, chump. This
* appears to only be used by psychologists.
*/
private double dPrime;
/**
* Area underneath the ROC curve, on the interval [0,1]. A value of
* 0.5 means that the classifier is doing no better than chance and
* bigger is better
*/
private double areaUnderCurve;
/**
* DataPoint, with corresponding threshold, that maximizes the value
* of Area=TruePositiveRate*(1-FalsePositiveRate), usually the
* upper-left "knee" on the ROC curve
*/
private DataPoint optimalThreshold;
/**
* Creates a new instance of Statistic
* @param roc
* ROC Curve from which to pull statistics
*/
protected Statistic(
ReceiverOperatingCharacteristic roc )
{
super( roc.getUtest() );
this.setAreaUnderCurve( computeAreaUnderCurve( roc ) );
this.setOptimalThreshold( computeOptimalThreshold( roc ) );
this.setDPrime( computeDPrime( this.getOptimalThreshold() ) );
}
/**
* Computes the "pessimistic" area under the ROC curve using the
* top-left rectangle method for numerical integration.
* @param roc
* ROC Curve to compute the area under
* @return
* Area underneath the ROC curve, on the interval [0,1]. A value of
* 0.5 means that the classifier is doing no better than chance and
* bigger is better
*/
public static double computeAreaUnderCurve(
ReceiverOperatingCharacteristic roc )
{
return computeAreaUnderCurveTopLeft( roc.getSortedROCData() );
}
/**
* Computes the Area Under Curve for an x-axis sorted Collection
* of ROC points using the top-left rectangle method for numerical
* integration.
* @param points
* x-axis sorted collection of x-axis points
* @return
* Area underneath the ROC curve, on the interval [0,1]. A value of
* 0.5 means that the classifier is doing no better than chance and
* bigger is better
*/
@PublicationReference(
author="Wikipedia",
title="Rectangle method",
type=PublicationType.WebPage,
year=2011,
url="http://en.wikipedia.org/wiki/Rectangle_method"
)
public static double computeAreaUnderCurveTopLeft(
Collection points )
{
ReceiverOperatingCharacteristic.DataPoint current =
CollectionUtil.getFirst(points);
double auc = 0.0;
double xnm1 = 0.0;
double ynm1 = 0.0;
double xn = 0.0;
for( ReceiverOperatingCharacteristic.DataPoint point : points )
{
// Technically, this wastes the computation of the first point,
// but since the delta is 0.0, it doesn't effect the AUC.
ReceiverOperatingCharacteristic.DataPoint previous = current;
previous = current;
current = point;
xnm1 = previous.getFalsePositiveRate();
ynm1 = previous.getTruePositiveRate();
xn = current.getFalsePositiveRate();
final double area = ynm1*(xn-xnm1);
auc += area;
}
// Assume that the final point is at xn=1.0
xnm1 = xn;
xn = 1.0;
final double area = ynm1*(xn-xnm1);
auc += area;
return auc;
}
/**
* Computes the Area Under Curve for an x-axis sorted Collection
* of ROC points using the top-left rectangle method for numerical
* integration.
* @param points
* x-axis sorted collection of x-axis points
* @return
* Area underneath the ROC curve, on the interval [0,1]. A value of
* 0.5 means that the classifier is doing no better than chance and
* bigger is better
*/
@PublicationReference(
author="Wikipedia",
title="Trapezoidal rule",
type=PublicationType.WebPage,
year=2011,
url="http://en.wikipedia.org/wiki/Trapezoidal_rule"
)
public static double computeAreaUnderCurveTrapezoid(
Collection points )
{
ReceiverOperatingCharacteristic.DataPoint current =
CollectionUtil.getFirst(points);
double auc = 0.0;
double xnm1 = 0.0;
double ynm1 = 0.0;
double yn = 0.0;
double xn = 0.0;
for( ReceiverOperatingCharacteristic.DataPoint point : points )
{
// Technically, this wastes the computation of the first point,
// but since the delta is 0.0, it doesn't effect the AUC.
ReceiverOperatingCharacteristic.DataPoint previous = current;
previous = current;
current = point;
xnm1 = previous.getFalsePositiveRate();
ynm1 = previous.getTruePositiveRate();
xn = current.getFalsePositiveRate();
yn = current.getTruePositiveRate();
final double area = (xn-xnm1) * (yn+ynm1) / 2.0;
auc += area;
}
// Assume that the final point is at xn=1.0
xnm1 = xn;
xn = 1.0;
yn = 1.0;
final double area = (xn-xnm1) * (yn+ynm1) / 2.0;
auc += area;
return auc;
}
/**
* Determines the DataPoint, and associated threshold, that
* simultaneously maximizes the value of
* Area=TruePositiveRate+TrueNegativeRate, usually the
* upper-left "knee" on the ROC curve.
*
* @param roc
* ROC Curve to consider
* @return DataPoint, with corresponding threshold, that maximizes the value
* of Area=TruePositiveRate*(1-FalsePositiveRate), usually the
* upper-left "knee" on the ROC curve.
*/
public static DataPoint computeOptimalThreshold(
ReceiverOperatingCharacteristic roc )
{
return computeOptimalThreshold( roc, 1.0, 1.0 );
}
/**
* Determines the DataPoint, and associated threshold, that
* simultaneously maximizes the value of
* Area=TruePositiveRate+TrueNegativeRate, usually the
* upper-left "knee" on the ROC curve.
*
*
* @return DataPoint, with corresponding threshold, that maximizes the value
* of Area=TruePositiveRate*(1-FalsePositiveRate), usually the
* upper-left "knee" on the ROC curve.
* @param truePositiveWeight
* Amount to weight the TruePositiveRate
* @param trueNegativeWeight
* Amount to weight the TrueNegativeRate
* @param roc ROC Curve to consider
*/
public static DataPoint computeOptimalThreshold(
ReceiverOperatingCharacteristic roc,
double truePositiveWeight,
double trueNegativeWeight )
{
DataPoint bestData = null;
double bestValue = Double.NEGATIVE_INFINITY;
for( ReceiverOperatingCharacteristic.DataPoint data : roc.getSortedROCData() )
{
// Find the point that maximizes the perimeter "below and to the
// right" of the point
DefaultBinaryConfusionMatrix cm = data.getConfusionMatrix();
double y = truePositiveWeight * cm.getTruePositivesRate();
double x = trueNegativeWeight * cm.getTrueNegativesRate();
double value = x + y;
if( bestValue < value )
{
bestValue = value;
bestData = data;
}
}
return bestData;
}
/**
* Computes the value of d-prime given a datapoint
* @param data
* Datapoint from which to estimate d'
* @return
* Estimated distance between the two classes to be split. Larger
* values of d' indicate that the classes are easier to split,
* d'=0 means that the classes overlap, and negative values mean
* that your classifier is doing worse than chance, chump. This
* appears to only be used by psychologists.
*/
public static double computeDPrime(
DataPoint data )
{
double hitRate = data.getConfusionMatrix().getTruePositivesRate();
double faRate = data.getFalsePositiveRate();
double zhr = UnivariateGaussian.CDF.Inverse.evaluate( hitRate, 0.0, 1.0 );
double zfa = UnivariateGaussian.CDF.Inverse.evaluate( faRate, 0.0, 1.0 );
return zhr - zfa;
}
/**
* Getter for dPrime
* @return
* Estimated distance between the two classes to be split. Larger
* values of d' indicate that the classes are easier to split,
* d'=0 means that the classes overlap, and negative values mean
* that your classifier is doing worse than chance, chump. This
* appears to only be used by psychologists.
*/
public double getDPrime()
{
return this.dPrime;
}
/**
* Setter for dPrime
* @param dPrime
* Estimated distance between the two classes to be split. Larger
* values of d' indicate that the classes are easier to split,
* d'=0 means that the classes overlap, and negative values mean
* that your classifier is doing worse than chance, chump. This
* appears to only be used by psychologists.
*/
protected void setDPrime(
double dPrime)
{
this.dPrime = dPrime;
}
/**
* Getter for areaUnderCurve
* @return
* Area underneath the ROC curve, on the interval [0,1]. A value of
* 0.5 means that the classifier is doing no better than chance and
* bigger is better
*/
public double getAreaUnderCurve()
{
return this.areaUnderCurve;
}
/**
* Setter for areaUnderCurve
* @param areaUnderCurve
* Area underneath the ROC curve, on the interval [0,1]. A value of
* 0.5 means that the classifier is doing no better than chance and
* bigger is better
*/
protected void setAreaUnderCurve(
double areaUnderCurve)
{
this.areaUnderCurve = areaUnderCurve;
}
/**
* Getter for optimalThreshold
*
* @return DataPoint, with corresponding threshold, that maximizes the value
* of Area=TruePositiveRate*(1-FalsePositiveRate), usually the
* upper-left "knee" on the ROC curve.
*/
public DataPoint getOptimalThreshold()
{
return this.optimalThreshold;
}
/**
* Setter for optimalThreshold
*
* @param optimalThreshold
* DataPoint, with corresponding threshold, that maximizes the value
* of Area=TruePositiveRate*(1-FalsePositiveRate), usually the
* upper-left "knee" on the ROC curve.
*/
protected void setOptimalThreshold(
DataPoint optimalThreshold)
{
this.optimalThreshold = optimalThreshold;
}
}
/**
* Contains information about a datapoint on an ROC curve
*/
public static class DataPoint
extends AbstractCloneableSerializable
{
/**
* Binary classifier used to create the corresponding ConfusionMatrix,
* which is really a wrapper for the threshold
*/
private ScalarThresholdBinaryCategorizer classifier;
/**
* Corresponding ConfusionMatrix with this datapoint
*/
private DefaultBinaryConfusionMatrix confusionMatrix;
/**
* Creates a new instance of DataPoint
*
* @param classifier
* Binary classifier used to create the corresponding ConfusionMatrix,
* which is really a wrapper for the threshold
* @param confusionMatrix
* Corresponding ConfusionMatrix with this datapoint
*/
public DataPoint(
ScalarThresholdBinaryCategorizer classifier,
DefaultBinaryConfusionMatrix confusionMatrix )
{
this.setClassifier( classifier );
this.setConfusionMatrix( confusionMatrix );
}
/**
* Getter for classifier
* @return
* Binary classifier used to create the corresponding ConfusionMatrix,
* which is really a wrapper for the threshold
*/
public ScalarThresholdBinaryCategorizer getClassifier()
{
return this.classifier;
}
/**
* Setter for classifier
* @param classifier
* Binary classifier used to create the corresponding ConfusionMatrix,
* which is really a wrapper for the threshold
*/
public void setClassifier(
ScalarThresholdBinaryCategorizer classifier)
{
this.classifier = classifier;
}
/**
* Getter for confusionMatrix
* @return
* Corresponding ConfusionMatrix with this datapoint
*/
public DefaultBinaryConfusionMatrix getConfusionMatrix()
{
return this.confusionMatrix;
}
/**
* Setter for confusionMatrix
* @param confusionMatrix
* Corresponding ConfusionMatrix with this datapoint
*/
protected void setConfusionMatrix(
DefaultBinaryConfusionMatrix confusionMatrix)
{
this.confusionMatrix = confusionMatrix;
}
/**
* Gets the falsePositiveRate associated with this datapoint
* @return
* falsePositiveRate associated with this datapoint
*/
public double getFalsePositiveRate()
{
return this.getConfusionMatrix().getFalsePositivesRate();
}
/**
* Gets the truePositiveRate associated with this datapoint
* @return
* truePositiveRate associated with this datapoint
*/
public double getTruePositiveRate()
{
return this.getConfusionMatrix().getTruePositivesRate();
}
/**
* Sorts DataPoints in ascending order according to their
* falsePositiveRate (x-axis)
*/
public static class Sorter
extends AbstractCloneableSerializable
implements Comparator
{
/**
* Sorts ROCDataPoints in ascending order according to their
* falsePositiveRate (x-axis), used in Arrays.sort() method
*
* @param o1 First datapoint
* @param o2 Second datapoint
* @return
* -1 if o1o2, 0 if o1=o2
*/
@Override
public int compare(
ReceiverOperatingCharacteristic.DataPoint o1,
ReceiverOperatingCharacteristic.DataPoint o2)
{
double x1 = o1.getFalsePositiveRate();
double x2 = o2.getFalsePositiveRate();
if( x1 < x2 )
{
return -1;
}
else if( x1 > x2 )
{
return +1;
}
else
{
double y1 = o1.getTruePositiveRate();
double y2 = o2.getTruePositiveRate();
if( y1 < y2 )
{
return -1;
}
else if( y1 > y2 )
{
return +1;
}
else
{
return 0;
}
}
}
}
}
/**
* Sorts score estimates for the ROC create() method
*/
private static class ROCScoreSorter
extends AbstractCloneableSerializable
implements Comparator>
{
/**
* Sorts score estimates for the ROC create() method
*
* @param o1 First score
* @param o2 Second score
* @return -1 if o1o2, 0 if o1=02
*/
@Override
public int compare(
InputOutputPair o1,
InputOutputPair o2)
{
if( o1.getInput() < o2.getInput() )
{
return -1;
}
else if( o1.getInput() > o2.getInput() )
{
return +1;
}
else
{
return 0;
}
}
}
}