// Source listing: weka.classifiers.evaluation.ThresholdCurve (weka-dev artifact; Maven / Gradle / Ivy).
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* ThresholdCurve.java
* Copyright (C) 2002-2012 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.evaluation;
import java.util.ArrayList;
import weka.classifiers.Classifier;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;
import weka.core.Utils;
/**
* Generates points illustrating prediction tradeoffs that can be obtained by
* varying the threshold value between classes. For example, the typical
* threshold value of 0.5 means the predicted probability of "positive" must be
* higher than 0.5 for the instance to be predicted as "positive". The resulting
* dataset can be used to visualize precision/recall tradeoff, or for ROC curve
* analysis (true positive rate vs false positive rate). Weka just varies the
* threshold on the class probability estimates in each case. The Mann Whitney
* statistic is used to calculate the AUC.
*
* @author Len Trigg (len@reeltwo.com)
* @version $Revision: 10153 $
*/
public class ThresholdCurve implements RevisionHandler {
/**
 * The name of the relation used in threshold curve datasets. The static
 * helpers of this class (e.g. getROCArea) compare a dataset's relation name
 * against this value before reading it.
 */
public static final String RELATION_NAME = "ThresholdCurve";
/** attribute name: True Positives */
public static final String TRUE_POS_NAME = "True Positives";
/** attribute name: False Negatives */
public static final String FALSE_NEG_NAME = "False Negatives";
/** attribute name: False Positives */
public static final String FALSE_POS_NAME = "False Positives";
/** attribute name: True Negatives */
public static final String TRUE_NEG_NAME = "True Negatives";
/** attribute name: False Positive Rate */
public static final String FP_RATE_NAME = "False Positive Rate";
/** attribute name: True Positive Rate */
public static final String TP_RATE_NAME = "True Positive Rate";
/** attribute name: Precision */
public static final String PRECISION_NAME = "Precision";
/** attribute name: Recall */
public static final String RECALL_NAME = "Recall";
/** attribute name: Fallout */
public static final String FALLOUT_NAME = "Fallout";
/** attribute name: FMeasure */
public static final String FMEASURE_NAME = "FMeasure";
/** attribute name: Sample Size */
public static final String SAMPLE_SIZE_NAME = "Sample Size";
/** attribute name: Lift */
public static final String LIFT_NAME = "Lift";
/** attribute name: Threshold (added last when the header is built) */
public static final String THRESHOLD_NAME = "Threshold";
/**
 * Builds the threshold curve for the default class and returns the points as
 * a set of Instances. The default class is the last entry of the class
 * distribution carried by the first prediction. Each returned instance holds
 * the following attributes, in this order:
 * <ul>
 * <li>True Positives</li>
 * <li>False Negatives</li>
 * <li>False Positives</li>
 * <li>True Negatives</li>
 * <li>False Positive Rate</li>
 * <li>True Positive Rate</li>
 * <li>Precision</li>
 * <li>Recall</li>
 * <li>Fallout</li>
 * <li>FMeasure</li>
 * <li>Sample Size</li>
 * <li>Lift</li>
 * <li>Threshold: the probability threshold that gives rise to the preceding
 * performance values.</li>
 * </ul>
 * For the definitions of these measures, see TwoClassStats.
 *
 * @see TwoClassStats
 * @param predictions the predictions to base the curve on
 * @return datapoints as a set of instances, null if no predictions have been
 *         made.
 */
public Instances getCurve(ArrayList predictions) {
  if (predictions.isEmpty()) {
    return null;
  }
  // The class of interest defaults to the highest class index.
  NominalPrediction first = (NominalPrediction) predictions.get(0);
  int lastClassIndex = first.distribution().length - 1;
  return getCurve(predictions, lastClassIndex);
}
/**
 * Calculates the performance stats for the desired class and returns the
 * results as a set of Instances, one per distinct predicted probability
 * (used as threshold) in ascending order, followed by a final "zero point"
 * at which nothing is predicted positive (unless the last regular point
 * already is that point).
 * <p>
 * Predictions whose actual class is missing or whose weight is negative are
 * reported on System.err and do not contribute to any count.
 *
 * @param predictions the predictions to base the curve on; every element is
 *          cast to NominalPrediction
 * @param classIndex index of the class of interest.
 * @return datapoints as a set of instances, or null if there are no
 *         predictions or classIndex lies outside the class distribution of
 *         the first prediction.
 */
public Instances getCurve(ArrayList predictions, int classIndex) {
  // A negative index would otherwise fail with an array index error further
  // down; treat it like any other out-of-range class index.
  if (predictions.isEmpty() || (classIndex < 0)) {
    return null;
  }
  if (((NominalPrediction) predictions.get(0)).distribution().length <= classIndex) {
    return null;
  }
  double totPos = 0, totNeg = 0;
  double[] probs = getProbabilities(predictions, classIndex);
  // Get distribution of positive/negatives
  for (int i = 0; i < probs.length; i++) {
    NominalPrediction pred = (NominalPrediction) predictions.get(i);
    if (!isUsablePrediction(pred)) {
      continue;
    }
    if (pred.actual() == classIndex) {
      totPos += pred.weight();
    } else {
      totNeg += pred.weight();
    }
  }
  Instances insts = makeHeader();
  int[] sorted = Utils.sort(probs);
  // Start with everything predicted positive: the first two arguments carry
  // the positive and negative totals, the last two start at zero.
  TwoClassStats tc = new TwoClassStats(totPos, totNeg, 0, 0);
  double threshold = 0;
  // Weight seen since the last emitted point, not yet moved out of the
  // "predicted positive" counts.
  double cumulativePos = 0;
  double cumulativeNeg = 0;
  for (int i = 0; i < sorted.length; i++) {
    if ((i == 0) || (probs[sorted[i]] > threshold)) {
      // New distinct probability: flush the pending weight and emit a point.
      tc.setTruePositive(tc.getTruePositive() - cumulativePos);
      tc.setFalseNegative(tc.getFalseNegative() + cumulativePos);
      tc.setFalsePositive(tc.getFalsePositive() - cumulativeNeg);
      tc.setTrueNegative(tc.getTrueNegative() + cumulativeNeg);
      threshold = probs[sorted[i]];
      insts.add(makeInstance(tc, threshold));
      cumulativePos = 0;
      cumulativeNeg = 0;
      if (i == sorted.length - 1) {
        break;
      }
    }
    NominalPrediction pred = (NominalPrediction) predictions.get(sorted[i]);
    if (!isUsablePrediction(pred)) {
      continue;
    }
    if (pred.actual() == classIndex) {
      cumulativePos += pred.weight();
    } else {
      cumulativeNeg += pred.weight();
    }
  }
  // make sure a zero point gets into the curve
  if (tc.getFalseNegative() != totPos || tc.getTrueNegative() != totNeg) {
    tc = new TwoClassStats(0, 0, totNeg, totPos);
    threshold = probs[sorted[sorted.length - 1]] + 10e-6;
    insts.add(makeInstance(tc, threshold));
  }
  return insts;
}

/**
 * Tells whether a prediction may contribute to the curve, printing the reason
 * to System.err when it may not.
 *
 * @param pred the prediction to inspect
 * @return false if the actual class is missing or the weight is negative
 */
private boolean isUsablePrediction(NominalPrediction pred) {
  double actual = pred.actual();
  // A NaN-valued missing marker never compares equal with ==, so test for
  // NaN explicitly in addition to the original comparison.
  if ((actual == Prediction.MISSING_VALUE) || Double.isNaN(actual)) {
    System.err.println(getClass().getName()
      + " Skipping prediction with missing class value");
    return false;
  }
  if (pred.weight() < 0) {
    System.err.println(getClass().getName()
      + " Skipping prediction with negative weight");
    return false;
  }
  return true;
}
/**
 * Calculates the n point precision result, which is the precision averaged
 * over n evenly spaced (w.r.t recall) samples of the curve. The samples are
 * taken at recall i / (n - 1) for i = 0 .. n - 1; with n == 1 the single
 * sample is taken at recall 0.
 *
 * @param tcurve a previously extracted threshold curve Instances.
 * @param n the number of points to average over.
 * @return the n-point precision, or Double.NaN if tcurve is not a
 *         ThresholdCurve dataset, is empty, or n is less than 1.
 */
public static double getNPointPrecision(Instances tcurve, int n) {
  if (!RELATION_NAME.equals(tcurve.relationName())
    || (tcurve.numInstances() == 0) || (n < 1)) {
    return Double.NaN;
  }
  int recallInd = tcurve.attribute(RECALL_NAME).index();
  int precisInd = tcurve.attribute(PRECISION_NAME).index();
  double[] recallVals = tcurve.attributeToDoubleArray(recallInd);
  int[] sorted = Utils.sort(recallVals);
  // Spacing between sample recalls; guarded so that n == 1 does not divide
  // by zero and turn the search target into NaN.
  double isize = (n > 1) ? 1.0 / (n - 1) : 0.0;
  double psum = 0;
  for (int i = 0; i < n; i++) {
    double target = i * isize;
    int pos = binarySearch(sorted, recallVals, target);
    double recall = recallVals[sorted[pos]];
    double precis = tcurve.instance(sorted[pos]).value(precisInd);
    // interpolate figures for non-endpoints: walk forward to the next point
    // with a different recall and interpolate linearly at the target recall
    while ((pos != 0) && (pos < sorted.length - 1)) {
      pos++;
      double recall2 = recallVals[sorted[pos]];
      if (recall2 != recall) {
        double precis2 = tcurve.instance(sorted[pos]).value(precisInd);
        double slope = (precis2 - precis) / (recall2 - recall);
        double offset = precis - recall * slope;
        precis = target * slope + offset;
        break;
      }
    }
    psum += precis;
  }
  return psum / n;
}
/**
 * Calculates the area under the precision-recall curve (AUPRC) by summing,
 * from the last curve point backwards to the first, each point's precision
 * multiplied by the recall gained since the previously visited point.
 *
 * @param tcurve a previously extracted threshold curve Instances.
 * @return the PRC area; Double.NaN if you don't pass in a ThresholdCurve
 *         generated Instances (or it is empty); the missing value if the
 *         summed area is exactly zero.
 */
public static double getPRCArea(Instances tcurve) {
  final int numPoints = tcurve.numInstances();
  final boolean isThresholdCurve = RELATION_NAME.equals(tcurve.relationName());
  if (!isThresholdCurve || (numPoints == 0)) {
    return Double.NaN;
  }
  final double[] precision = tcurve.attributeToDoubleArray(tcurve.attribute(
    PRECISION_NAME).index());
  final double[] recall = tcurve.attributeToDoubleArray(tcurve.attribute(
    RECALL_NAME).index());
  double total = 0;
  // The last point only supplies the starting recall (the original comment
  // calls it the artificial zero point); accumulation starts one before it.
  double previousRecall = recall[numPoints - 1];
  int i = numPoints - 2;
  while (i >= 0) {
    total += precision[i] * (recall[i] - previousRecall);
    previousRecall = recall[i];
    i--;
  }
  return (total == 0) ? Utils.missingValue() : total;
}
/**
 * Calculates the area under the ROC curve as the Wilcoxon-Mann-Whitney
 * statistic. The true/false positive counts of the first curve point are
 * used as the total number of positives/negatives.
 *
 * @param tcurve a previously extracted threshold curve Instances.
 * @return the ROC area, or Double.NaN if you don't pass in a ThresholdCurve
 *         generated Instances.
 */
public static double getROCArea(Instances tcurve) {
  final int numPoints = tcurve.numInstances();
  final boolean isThresholdCurve = RELATION_NAME.equals(tcurve.relationName());
  if (!isThresholdCurve || (numPoints == 0)) {
    return Double.NaN;
  }
  final double[] truePos = tcurve.attributeToDoubleArray(tcurve.attribute(
    TRUE_POS_NAME).index());
  final double[] falsePos = tcurve.attributeToDoubleArray(tcurve.attribute(
    FALSE_POS_NAME).index());
  final double allPos = truePos[0];
  final double allNeg = falsePos[0];
  double sum = 0.0;
  double negSoFar = 0.0;
  for (int i = 0; i < numPoints; i++) {
    // Positives / negatives that drop out between this point and the next;
    // the final point contributes whatever it still holds.
    final boolean lastPoint = (i >= numPoints - 1);
    final double posHere = lastPoint ? truePos[numPoints - 1] : truePos[i]
      - truePos[i + 1];
    final double negHere = lastPoint ? falsePos[numPoints - 1] : falsePos[i]
      - falsePos[i + 1];
    // Negatives falling in the same step count half (ties).
    sum += posHere * (negSoFar + (0.5 * negHere));
    negSoFar += negHere;
  }
  sum /= (allNeg * allPos);
  return sum;
}
/**
 * Gets the index of the instance whose threshold value matches the desired
 * target. The thresholds (read from the last attribute) are sorted and
 * searched with binarySearch: an exact match is returned when one exists,
 * otherwise the lower bracket found by the search, which is not necessarily
 * the numerically nearest threshold.
 *
 * @param tcurve a set of instances that have been generated by this class
 * @param threshold the target threshold
 * @return the index (into tcurve) of the located instance, or -1 if this
 *         could not be found (i.e. no data, or bad threshold target: NaN or
 *         outside [0, 1])
 */
public static int getThresholdInstance(Instances tcurve, double threshold) {
  if (!RELATION_NAME.equals(tcurve.relationName())
    || (tcurve.numInstances() == 0) || Double.isNaN(threshold)
    || (threshold < 0) || (threshold > 1.0)) {
    return -1;
  }
  if (tcurve.numInstances() == 1) {
    return 0;
  }
  double[] tvals = tcurve.attributeToDoubleArray(tcurve.numAttributes() - 1);
  int[] sorted = Utils.sort(tvals);
  // binarySearch yields a position within the sorted order; translate it
  // back to the instance index so the result is right even when the curve
  // is not stored in ascending threshold order.
  return sorted[binarySearch(sorted, tvals, threshold)];
}
/**
 * Binary search over values that are accessed through an ordering array.
 *
 * @param index positions into vals, ordered so that vals[index[k]] is
 *          ascending
 * @param vals the values
 * @param target the target to look for
 * @return the first position k with vals[index[k]] equal to the target if the
 *         search hits one; otherwise the lower end of the final search
 *         interval (0 for arrays with fewer than three entries)
 */
private static int binarySearch(int[] index, double[] vals, double target) {
  int low = 0;
  int high = index.length - 1;
  while (low + 1 < high) {
    final int middle = (low + high) >>> 1;
    final double probe = vals[index[middle]];
    if (target > probe) {
      low = middle;
      continue;
    }
    if (target < probe) {
      high = middle;
      continue;
    }
    // Hit: step back to the first of a run of equal values.
    int first = middle;
    while ((first > 0) && (vals[index[first - 1]] == target)) {
      first--;
    }
    return first;
  }
  return low;
}
/**
 * Collects, for every prediction in list order, the probability assigned to
 * the given class. The caller sorts on these values afterwards.
 *
 * @param predictions the predictions to use; every element is cast to
 *          NominalPrediction
 * @param classIndex the class index
 * @return the probabilities
 */
private double[] getProbabilities(ArrayList predictions,
  int classIndex) {
  final double[] classProbs = new double[predictions.size()];
  int slot = 0;
  for (Object entry : predictions) {
    classProbs[slot++] = ((NominalPrediction) entry).distribution()[classIndex];
  }
  return classProbs;
}
/**
 * Generates the header: an empty dataset named RELATION_NAME with one
 * attribute per performance measure, in the order in which makeInstance
 * fills its value array.
 *
 * @return the header
 */
private Instances makeHeader() {
  final String[] attributeNames = { TRUE_POS_NAME, FALSE_NEG_NAME,
    FALSE_POS_NAME, TRUE_NEG_NAME, FP_RATE_NAME, TP_RATE_NAME, PRECISION_NAME,
    RECALL_NAME, FALLOUT_NAME, FMEASURE_NAME, SAMPLE_SIZE_NAME, LIFT_NAME,
    THRESHOLD_NAME };
  // Typed list instead of a raw ArrayList, sized for the known attribute count.
  ArrayList<Attribute> fv = new ArrayList<Attribute>(attributeNames.length);
  for (String attributeName : attributeNames) {
    fv.add(new Attribute(attributeName));
  }
  // 100 is the initial capacity handed to the dataset.
  return new Instances(RELATION_NAME, fv, 100);
}
/**
 * Generates one curve point from the given statistics. The values are laid
 * out in the attribute order of the header built by makeHeader.
 *
 * @param tc the statistics
 * @param prob the probability (threshold) the statistics belong to
 * @return the generated instance (weight 1.0)
 */
private Instance makeInstance(TwoClassStats tc, double prob) {
  final double tp = tc.getTruePositive();
  final double fn = tc.getFalseNegative();
  final double fp = tc.getFalsePositive();
  final double tn = tc.getTrueNegative();
  // Fraction of all weight that is predicted positive at this threshold.
  final double sampleSize = (tp + fp) / (tp + fp + tn + fn);
  // True positives expected if the same fraction were picked at random.
  final double expectedByChance = sampleSize * (tp + fn);
  final double lift;
  if (expectedByChance < 1) {
    lift = Utils.missingValue();
  } else {
    lift = tp / expectedByChance;
  }
  final double[] vals = { tp, fn, fp, tn, tc.getFalsePositiveRate(),
    tc.getTruePositiveRate(), tc.getPrecision(), tc.getRecall(),
    tc.getFallout(), tc.getFMeasure(), sampleSize, lift, prob };
  return new DenseInstance(1.0, vals);
}
/**
 * Returns the revision string, extracted from the version-control keyword
 * embedded in this class.
 *
 * @return the revision
 */
@Override
public String getRevision() {
  final String revisionKeyword = "$Revision: 10153 $";
  return RevisionUtils.extract(revisionKeyword);
}
/**
 * Tests the ThresholdCurve generation from the command line. The classifier
 * is currently hardcoded. Pipe in an arff file.
 * <p>
 * The two conditions below are fixed switches: the first is always false, so
 * the cross-validation branch is the one that runs.
 *
 * @param args currently ignored
 */
public static void main(String[] args) {
  try {
    Instances data = new Instances(new java.io.InputStreamReader(System.in));
    final boolean printNPointPrecision = (0 != Math.log(1)); // false
    final boolean runCrossValidation = (3 != 1 + 1); // true
    if (printNPointPrecision) {
      System.out.println(ThresholdCurve.getNPointPrecision(data, 11));
    } else if (runCrossValidation) {
      data.setClassIndex(data.numAttributes() - 1);
      ThresholdCurve curveBuilder = new ThresholdCurve();
      EvaluationUtils evalUtils = new EvaluationUtils();
      Classifier model = new weka.classifiers.functions.Logistic();
      ArrayList collected = new ArrayList();
      // Do two runs, each a 10-fold cross-validation seeded with the run number.
      for (int run = 0; run < 2; run++) {
        evalUtils.setSeed(run);
        collected.addAll(evalUtils.getCVPredictions(model, data, 10));
      }
      Instances curve = curveBuilder.getCurve(collected);
      System.out.println(curve);
    }
  } catch (Exception ex) {
    ex.printStackTrace();
  }
}
}