All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.eval.ROC Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.eval;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.Getter;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.transforms.IsMax;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;

import java.io.Serializable;
import java.util.*;

/**
 * ROC (Receiver Operating Characteristic) for binary classifiers, using the specified number of threshold steps.
 * 

* Some ROC implementations will automatically calculate the threshold points based on the data set to give a 'smoother' * ROC curve (or optimal cut points for diagnostic purposes). This implementation currently uses fixed steps of size * 1.0 / thresholdSteps, as this allows easy implementation for batched and distributed evaluation scenarios (where the * full data set is not available in memory on any one machine at once). *

* The data is assumed to be binary classification - nColumns == 1 (single binary output variable) or nColumns == 2 * (probability distribution over 2 classes, with column 1 being values for 'positive' examples) * * @author Alex Black */ @Getter public class ROC extends BaseEvaluation { private final int thresholdSteps; private long countActualPositive; private long countActualNegative; private final Map counts = new LinkedHashMap<>(); /** * @param thresholdSteps Number of threshold steps to use for the ROC calculation */ public ROC(int thresholdSteps) { this.thresholdSteps = thresholdSteps; double step = 1.0 / thresholdSteps; for (int i = 0; i <= thresholdSteps; i++) { double currThreshold = i * step; counts.put(currThreshold, new CountsForThreshold(currThreshold)); } } /** * Evaluate (collect statistics for) the given minibatch of data. * For time series (3 dimensions) use {@link #evalTimeSeries(INDArray, INDArray)} or {@link #evalTimeSeries(INDArray, INDArray, INDArray)} * * @param labels Labels / true outcomes * @param predictions Predictions */ public void eval(INDArray labels, INDArray predictions) { if (labels.rank() == 3 && predictions.rank() == 3) { //Assume time series input -> reshape to 2d evalTimeSeries(labels, predictions); } if (labels.rank() > 2 || predictions.rank() > 2 || labels.size(1) != predictions.size(1) || labels.size(1) > 2) { throw new IllegalArgumentException("Invalid input data shape: labels shape = " + Arrays.toString(labels.shape()) + ", predictions shape = " + Arrays.toString(predictions.shape()) + "; require rank 2 array with size(1) == 1 or 2"); } double step = 1.0 / thresholdSteps; boolean singleOutput = labels.size(1) == 1; INDArray positivePredictedClassColumn; INDArray positiveActualClassColumn; INDArray negativeActualClassColumn; if (singleOutput) { //Single binary variable case positiveActualClassColumn = labels; negativeActualClassColumn = labels.rsub(1.0); //1.0 - label positivePredictedClassColumn = predictions; } else { 
//Standard case - 2 output variables (probability distribution) positiveActualClassColumn = labels.getColumn(1); negativeActualClassColumn = labels.getColumn(0); positivePredictedClassColumn = predictions.getColumn(1); } //Increment global counts - actual positive/negative observed countActualPositive += positiveActualClassColumn.sumNumber().intValue(); countActualNegative += negativeActualClassColumn.sumNumber().intValue(); //Here: calculate true positive rate (TPR) vs. false positive rate (FPR) at different threshold for (int i = 0; i <= thresholdSteps; i++) { double currThreshold = i * step; //Work out true/false positives - do this by replacing probabilities (predictions) with 1 or 0 based on threshold Condition condGeq = Conditions.greaterThanOrEqual(currThreshold); Condition condLeq = Conditions.lessThanOrEqual(currThreshold); Op op = new CompareAndSet(positivePredictedClassColumn.dup(), 1.0, condGeq); INDArray predictedClass1 = Nd4j.getExecutioner().execAndReturn(op); op = new CompareAndSet(predictedClass1, 0.0, condLeq); predictedClass1 = Nd4j.getExecutioner().execAndReturn(op); //True positives: occur when positive predicted class and actual positive actual class... //False positive occurs when positive predicted class, but negative actual class INDArray isTruePositive = predictedClass1.mul(positiveActualClassColumn); //If predicted == 1 and actual == 1 at this threshold: 1x1 = 1. 0 otherwise INDArray isFalsePositive = predictedClass1.mul(negativeActualClassColumn); //If predicted == 1 and actual == 0 at this threshold: 1x1 = 1. 
0 otherwise //Counts for this batch: int truePositiveCount = isTruePositive.sumNumber().intValue(); int falsePositiveCount = isFalsePositive.sumNumber().intValue(); //Increment counts for this thold CountsForThreshold thresholdCounts = counts.get(currThreshold); thresholdCounts.incrementTruePositive(truePositiveCount); thresholdCounts.incrementFalsePositive(falsePositiveCount); } } /** * Get the ROC curve, as a set of points * * @return ROC curve, as a list of points */ public List getResults() { List out = new ArrayList<>(counts.size()); for (Map.Entry entry : counts.entrySet()) { double t = entry.getKey(); CountsForThreshold c = entry.getValue(); double tpr = c.getCountTruePositive() / ((double) countActualPositive); double fpr = c.getCountFalsePositive() / ((double) countActualNegative); out.add(new ROCValue(t, tpr, fpr)); } return out; } public List getPrecisionRecallCurve() { //Precision: (true positive count) / (true positive count + false positive count) == true positive rate //Recall: (true positive count) / (true positive count + false negative count) = (TP count) / (total dataset positives) List out = new ArrayList<>(counts.size()); for (Map.Entry entry : counts.entrySet()) { double t = entry.getKey(); CountsForThreshold c = entry.getValue(); long tpCount = c.getCountTruePositive(); long fpCount = c.getCountFalsePositive(); //For edge cases: http://stats.stackexchange.com/questions/1773/what-are-correct-values-for-precision-and-recall-in-edge-cases //precision == 1 when FP = 0 -> no incorrect positive predictions //recall == 1 when no dataset positives are present (got all 0 of 0 positives) double precision; if (tpCount == 0 && fpCount == 0) { //At this threshold: no predicted positive cases precision = 1.0; } else { precision = tpCount / (double) (tpCount + fpCount); } double recall; if (countActualPositive == 0) { recall = 1.0; } else { recall = tpCount / ((double) countActualPositive); } out.add(new PrecisionRecallPoint(c.getThreshold(), precision, 
recall)); } return out; } /** * Get the ROC curve, as a set of (falsePositive, truePositive) points *

* Returns a 2d array of {falsePositive, truePositive values}.
* Size is [2][thresholdSteps], with out[0][.] being false positives, and out[1][.] being true positives * * @return ROC curve as double[][] */ public double[][] getResultsAsArray() { double[][] out = new double[2][thresholdSteps + 1]; int i = 0; for (Map.Entry entry : counts.entrySet()) { CountsForThreshold c = entry.getValue(); double tpr = c.getCountTruePositive() / ((double) countActualPositive); double fpr = c.getCountFalsePositive() / ((double) countActualNegative); out[0][i] = fpr; out[1][i] = tpr; i++; } return out; } /** * Calculate the AUC - Area Under Curve
* Utilizes trapezoidal integration internally * * @return AUC */ public double calculateAUC() { //Calculate AUC using trapezoidal rule List list = getResults(); //Given the points double auc = 0.0; for (int i = 0; i < list.size() - 1; i++) { ROCValue left = list.get(i); ROCValue right = list.get(i + 1); //y axis: TPR //x axis: FPR double deltaX = Math.abs(right.getFalsePositiveRate() - left.getFalsePositiveRate()); //Iterating in threshold order, so FPR decreases as threshold increases double avg = (left.getTruePositiveRate() + right.getTruePositiveRate()) / 2.0; auc += deltaX * avg; } return auc; } /** * Merge this ROC instance with another. * This ROC instance is modified, by adding the stats from the other instance. * * @param other ROC instance to combine with this one */ @Override public void merge(ROC other) { if (this.thresholdSteps != other.thresholdSteps) { throw new UnsupportedOperationException( "Cannot merge ROC instances with different numbers of threshold steps (" + this.thresholdSteps + " vs. 
" + other.thresholdSteps + ")"); } this.countActualPositive += other.countActualPositive; this.countActualNegative += other.countActualNegative; for (Double d : this.counts.keySet()) { CountsForThreshold cft = this.counts.get(d); CountsForThreshold otherCft = other.counts.get(d); cft.countTruePositive += otherCft.countTruePositive; cft.countFalsePositive += otherCft.countFalsePositive; } } @AllArgsConstructor @Data public static class ROCValue { private final double threshold; private final double truePositiveRate; private final double falsePositiveRate; } @AllArgsConstructor @Data public static class PrecisionRecallPoint { private final double classiferThreshold; private final double precision; private final double recall; } @AllArgsConstructor @Data public static class CountsForThreshold implements Serializable, Cloneable { private double threshold; private long countTruePositive; private long countFalsePositive; public CountsForThreshold(double threshold) { this(threshold, 0, 0); } public void incrementTruePositive(long count) { countTruePositive += count; } public void incrementFalsePositive(long count) { countFalsePositive += count; } @Override public CountsForThreshold clone() { return new CountsForThreshold(threshold, countTruePositive, countFalsePositive); } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy