org.nd4j.evaluation.classification.ROC

/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.evaluation.classification;

import lombok.*;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.common.base.Preconditions;
import org.nd4j.evaluation.BaseEvaluation;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.IMetric;
import org.nd4j.evaluation.curves.PrecisionRecallCurve;
import org.nd4j.evaluation.curves.RocCurve;
import org.nd4j.evaluation.serde.ROCSerializer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.primitives.Triple;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;

import java.io.Serializable;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;

/**
 * ROC (Receiver Operating Characteristic) for binary classifiers.
 * ROC has 2 modes of operation:
 * (a) Thresholded (less memory)
 * (b) Exact (default; use numSteps == 0 to set. May not scale to very large datasets)
 *
 * Thresholded is an approximate method that (for large datasets) may use significantly less memory than exact mode.
 * Whereas the exact implementation automatically calculates the threshold points based on the data set, giving a
 * 'smoother' and more accurate ROC curve (or optimal cut points for diagnostic purposes), thresholded mode uses fixed steps
 * of size 1.0 / thresholdSteps, which allows easy implementation for batched and distributed evaluation scenarios (where the
 * full data set is not available in memory on any one machine at once).
 * Note that in some cases (very skewed probability predictions, for example) the thresholded approach can be inaccurate,
 * often underestimating the true area.
 *
* The data is assumed to be binary classification - nColumns == 1 (single binary output variable) or nColumns == 2 * (probability distribution over 2 classes, with column 1 being values for 'positive' examples) * * @author Alex Black */ @EqualsAndHashCode(callSuper = true, exclude = {"auc", "auprc", "probAndLabel", "exactAllocBlockSize", "rocCurve", "prCurve", "axis"}) @Data @JsonIgnoreProperties({"probAndLabel", "exactAllocBlockSize"}) @JsonSerialize(using = ROCSerializer.class) @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY) public class ROC extends BaseEvaluation { /** * AUROC: Area under ROC curve
* AUPRC: Area under Precision-Recall Curve */ public enum Metric implements IMetric { AUROC, AUPRC; @Override public Class getEvaluationClass() { return ROC.class; } @Override public boolean minimize() { return false; } } private static final int DEFAULT_EXACT_ALLOC_BLOCK_SIZE = 2048; private final Map counts = new LinkedHashMap<>(); private int thresholdSteps; private long countActualPositive; private long countActualNegative; private Double auc; private Double auprc; private RocCurve rocCurve; private PrecisionRecallCurve prCurve; private boolean isExact; private INDArray probAndLabel; private int exampleCount = 0; private boolean rocRemoveRedundantPts; private int exactAllocBlockSize; protected int axis = 1; public ROC(int thresholdSteps, boolean rocRemoveRedundantPts, int exactAllocBlockSize, int axis) { this(thresholdSteps, rocRemoveRedundantPts, exactAllocBlockSize); this.axis = axis; } public ROC() { //Default to exact this(0); } /** * @param thresholdSteps Number of threshold steps to use for the ROC calculation. If set to 0: use exact calculation */ public ROC(int thresholdSteps) { this(thresholdSteps, true); } /** * @param thresholdSteps Number of threshold steps to use for the ROC calculation. If set to 0: use exact calculation * @param rocRemoveRedundantPts Usually set to true. If true, remove any redundant points from ROC and P-R curves */ public ROC(int thresholdSteps, boolean rocRemoveRedundantPts) { this(thresholdSteps, rocRemoveRedundantPts, DEFAULT_EXACT_ALLOC_BLOCK_SIZE); } /** * @param thresholdSteps Number of threshold steps to use for the ROC calculation. If set to 0: use exact calculation * @param rocRemoveRedundantPts Usually set to true. If true, remove any redundant points from ROC and P-R curves * @param exactAllocBlockSize if using exact mode, the block size relocation. Users can likely use the default * setting in almost all cases */ public ROC(int thresholdSteps, boolean rocRemoveRedundantPts, int exactAllocBlockSize) { if (thresholdSteps > 0) { this.thresholdSteps = thresholdSteps; double step = 1.0 / thresholdSteps; for (int i = 0; i <= thresholdSteps; i++) { double currThreshold = i * step; counts.put(currThreshold, new CountsForThreshold(currThreshold)); } isExact = false; } else { //Exact isExact = true; } this.rocRemoveRedundantPts = rocRemoveRedundantPts; this.exactAllocBlockSize = exactAllocBlockSize; } public static ROC fromJson(String json) { return fromJson(json, ROC.class); } /** * Set the axis for evaluation - this should be a size 1 dimension * For DL4J, this can be left as the default setting (axis = 1).
* Axis should be set as follows:
* For 2D (OutputLayer), shape [minibatch, numClasses] - axis = 1
* For 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NCW format, shape [minibatch, numClasses, sequenceLength] - axis = 1
* For 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NWC format, shape [minibatch, sequenceLength, numClasses] - axis = 2
* For 4D, CNN2D (DL4J CnnLossLayer), NCHW format, shape [minibatch, channels, height, width] - axis = 1
* For 4D, CNN2D, NHWC format, shape [minibatch, height, width, channels] - axis = 3
* * @param axis Axis to use for evaluation */ public void setAxis(int axis){ this.axis = axis; } /** * Get the axis - see {@link #setAxis(int)} for details */ public int getAxis(){ return axis; } private double getAuc() { if (auc != null) { return auc; } auc = calculateAUC(); return auc; } /** * Calculate the AUROC - Area Under ROC Curve
* Utilizes trapezoidal integration internally * * @return AUC */ public double calculateAUC() { if (auc != null) { return auc; } Preconditions.checkState(exampleCount > 0, "Unable to calculate AUC: no evaluation has been performed (no examples)"); this.auc = getRocCurve().calculateAUC(); return auc; } /** * Get the ROC curve, as a set of (threshold, falsePositive, truePositive) points * * @return ROC curve */ public RocCurve getRocCurve() { if (rocCurve != null) { return rocCurve; } Preconditions.checkState(exampleCount > 0, "Unable to get ROC curve: no evaluation has been performed (no examples)"); if (isExact) { //Sort ascending. As we decrease threshold, more are predicted positive. //if(prob <= threshold> predict 0, otherwise predict 1 //So, as we iterate from i=0..length, first 0 to i (inclusive) are predicted class 1, all others are predicted class 0 INDArray pl = getProbAndLabelUsed(); INDArray sorted = Nd4j.sortRows(pl, 0, false); INDArray isPositive = sorted.getColumn(1,true); INDArray isNegative = sorted.getColumn(1,true).rsub(1.0); INDArray cumSumPos = isPositive.cumsum(-1); INDArray cumSumNeg = isNegative.cumsum(-1); val length = sorted.size(0); INDArray t = Nd4j.create(DataType.DOUBLE, length + 2, 1); t.put(new INDArrayIndex[]{interval(1, length + 1), all()}, sorted.getColumn(0,true)); INDArray fpr = Nd4j.create(DataType.DOUBLE, length + 2, 1); fpr.put(new INDArrayIndex[]{interval(1, length + 1), all()}, cumSumNeg.div(countActualNegative)); INDArray tpr = Nd4j.create(DataType.DOUBLE, length + 2, 1); tpr.put(new INDArrayIndex[]{interval(1, length + 1), all()}, cumSumPos.div(countActualPositive)); //Edge cases t.putScalar(0, 0, 1.0); fpr.putScalar(0, 0, 0.0); tpr.putScalar(0, 0, 0.0); fpr.putScalar(length + 1, 0, 1.0); tpr.putScalar(length + 1, 0, 1.0); double[] x_fpr_out = fpr.data().asDouble(); double[] y_tpr_out = tpr.data().asDouble(); double[] tOut = t.data().asDouble(); //Note: we can have multiple FPR for a given TPR, and multiple TPR for a given FPR //These can be omitted, without changing the area (as long as we keep the edge points) if (rocRemoveRedundantPts) { Pair p = removeRedundant(tOut, x_fpr_out, y_tpr_out, null, null, null); double[][] temp = p.getFirst(); tOut = temp[0]; x_fpr_out = temp[1]; y_tpr_out = temp[2]; } this.rocCurve = new RocCurve(tOut, x_fpr_out, y_tpr_out); return rocCurve; } else { double[][] out = new double[3][thresholdSteps + 1]; int i = 0; for (Map.Entry entry : counts.entrySet()) { CountsForThreshold c = entry.getValue(); double tpr = c.getCountTruePositive() / ((double) countActualPositive); double fpr = c.getCountFalsePositive() / ((double) countActualNegative); out[0][i] = c.getThreshold(); out[1][i] = fpr; out[2][i] = tpr; i++; } return new RocCurve(out[0], out[1], out[2]); } } protected INDArray getProbAndLabelUsed() { if (probAndLabel == null || exampleCount == 0) { return null; } return probAndLabel.get(interval(0, exampleCount), all()); } private static Pair removeRedundant(double[] threshold, double[] x, double[] y, int[] tpCount, int[] fpCount, int[] fnCount) { double[] t_compacted = new double[threshold.length]; double[] x_compacted = new double[x.length]; double[] y_compacted = new double[y.length]; int[] tp_compacted = null; int[] fp_compacted = null; int[] fn_compacted = null; boolean hasInts = false; if (tpCount != null) { tp_compacted = new int[tpCount.length]; fp_compacted = new int[fpCount.length]; fn_compacted = new int[fnCount.length]; hasInts = true; } int lastOutPos = -1; for (int i = 0; i < threshold.length; i++) { 
boolean keep; if (i == 0 || i == threshold.length - 1) { keep = true; } else { boolean ommitSameY = y[i - 1] == y[i] && y[i] == y[i + 1]; boolean ommitSameX = x[i - 1] == x[i] && x[i] == x[i + 1]; keep = !ommitSameX && !ommitSameY; } if (keep) { lastOutPos++; t_compacted[lastOutPos] = threshold[i]; y_compacted[lastOutPos] = y[i]; x_compacted[lastOutPos] = x[i]; if (hasInts) { tp_compacted[lastOutPos] = tpCount[i]; fp_compacted[lastOutPos] = fpCount[i]; fn_compacted[lastOutPos] = fnCount[i]; } } } if (lastOutPos < x.length - 1) { t_compacted = Arrays.copyOfRange(t_compacted, 0, lastOutPos + 1); x_compacted = Arrays.copyOfRange(x_compacted, 0, lastOutPos + 1); y_compacted = Arrays.copyOfRange(y_compacted, 0, lastOutPos + 1); if (hasInts) { tp_compacted = Arrays.copyOfRange(tp_compacted, 0, lastOutPos + 1); fp_compacted = Arrays.copyOfRange(fp_compacted, 0, lastOutPos + 1); fn_compacted = Arrays.copyOfRange(fn_compacted, 0, lastOutPos + 1); } } return new Pair<>(new double[][]{t_compacted, x_compacted, y_compacted}, hasInts ? new int[][]{tp_compacted, fp_compacted, fn_compacted} : null); } private double getAuprc() { if (auprc != null) { return auprc; } auprc = calculateAUCPR(); return auprc; } /** * Calculate the area under the precision/recall curve - aka AUCPR * * @return */ public double calculateAUCPR() { if (auprc != null) { return auprc; } Preconditions.checkState(exampleCount > 0, "Unable to calculate AUPRC: no evaluation has been performed (no examples)"); auprc = getPrecisionRecallCurve().calculateAUPRC(); return auprc; } /** * Get the precision recall curve as array. * return[0] = threshold array
* return[1] = precision array
* return[2] = recall array
* * @return */ public PrecisionRecallCurve getPrecisionRecallCurve() { if (prCurve != null) { return prCurve; } Preconditions.checkState(exampleCount > 0, "Unable to get PR curve: no evaluation has been performed (no examples)"); double[] thresholdOut; double[] precisionOut; double[] recallOut; int[] tpCountOut; int[] fpCountOut; int[] fnCountOut; if (isExact) { INDArray pl = getProbAndLabelUsed(); INDArray sorted = Nd4j.sortRows(pl, 0, false); INDArray isPositive = sorted.getColumn(1,true); INDArray cumSumPos = isPositive.cumsum(-1); val length = sorted.size(0); /* Sort descending. As we iterate: decrease probability threshold T... all values <= T are predicted as class 0, all others are predicted as class 1 Precision: sum(TP) / sum(predicted pos at threshold) Recall: sum(TP) / total actual positives predicted positive at threshold: # values <= threshold, i.e., just i */ INDArray t = Nd4j.create(DataType.DOUBLE, length + 2, 1); t.put(new INDArrayIndex[]{interval(1, length + 1), all()}, sorted.getColumn(0,true)); INDArray linspace = Nd4j.linspace(1, length, length, DataType.DOUBLE); INDArray precision = cumSumPos.castTo(DataType.DOUBLE).div(linspace.reshape(cumSumPos.shape())); INDArray prec = Nd4j.create(DataType.DOUBLE, length + 2, 1); prec.put(new INDArrayIndex[]{interval(1, length + 1), all()}, precision); //Recall/TPR INDArray rec = Nd4j.create(DataType.DOUBLE, length + 2, 1); rec.put(new INDArrayIndex[]{interval(1, length + 1), all()}, cumSumPos.div(countActualPositive)); //Edge cases t.putScalar(0, 0, 1.0); prec.putScalar(0, 0, 1.0); rec.putScalar(0, 0, 0.0); prec.putScalar(length + 1, 0, cumSumPos.getDouble(cumSumPos.length() - 1) / length); rec.putScalar(length + 1, 0, 1.0); thresholdOut = t.data().asDouble(); precisionOut = prec.data().asDouble(); recallOut = rec.data().asDouble(); //Counts. 
Note the edge cases tpCountOut = new int[thresholdOut.length]; fpCountOut = new int[thresholdOut.length]; fnCountOut = new int[thresholdOut.length]; for (int i = 1; i < tpCountOut.length - 1; i++) { tpCountOut[i] = cumSumPos.getInt(i - 1); fpCountOut[i] = i - tpCountOut[i]; //predicted positive - true positive fnCountOut[i] = (int) countActualPositive - tpCountOut[i]; } //Edge cases: last idx -> threshold of 0.0, all predicted positive tpCountOut[tpCountOut.length - 1] = (int) countActualPositive; fpCountOut[tpCountOut.length - 1] = (int) (exampleCount - countActualPositive); fnCountOut[tpCountOut.length - 1] = 0; //Edge case: first idx -> threshold of 1.0, all predictions negative tpCountOut[0] = 0; fpCountOut[0] = 0; //(int)(exampleCount - countActualPositive); //All negatives are predicted positive fnCountOut[0] = (int) countActualPositive; //Finally: 2 things to do //(a) Reverse order: lowest to highest threshold //(b) remove unnecessary/rendundant points (doesn't affect graph or AUPRC) ArrayUtils.reverse(thresholdOut); ArrayUtils.reverse(precisionOut); ArrayUtils.reverse(recallOut); ArrayUtils.reverse(tpCountOut); ArrayUtils.reverse(fpCountOut); ArrayUtils.reverse(fnCountOut); if (rocRemoveRedundantPts) { Pair pair = removeRedundant(thresholdOut, precisionOut, recallOut, tpCountOut, fpCountOut, fnCountOut); double[][] temp = pair.getFirst(); int[][] temp2 = pair.getSecond(); thresholdOut = temp[0]; precisionOut = temp[1]; recallOut = temp[2]; tpCountOut = temp2[0]; fpCountOut = temp2[1]; fnCountOut = temp2[2]; } } else { thresholdOut = new double[counts.size()]; precisionOut = new double[counts.size()]; recallOut = new double[counts.size()]; tpCountOut = new int[counts.size()]; fpCountOut = new int[counts.size()]; fnCountOut = new int[counts.size()]; int i = 0; for (Map.Entry entry : counts.entrySet()) { double t = entry.getKey(); CountsForThreshold c = entry.getValue(); long tpCount = c.getCountTruePositive(); long fpCount = c.getCountFalsePositive(); //For edge cases: http://stats.stackexchange.com/questions/1773/what-are-correct-values-for-precision-and-recall-in-edge-cases //precision == 1 when FP = 0 -> no incorrect positive predictions //recall == 1 when no dataset positives are present (got all 0 of 0 positives) double precision; if (tpCount == 0 && fpCount == 0) { //At this threshold: no predicted positive cases precision = 1.0; } else { precision = tpCount / (double) (tpCount + fpCount); } double recall; if (countActualPositive == 0) { recall = 1.0; } else { recall = tpCount / ((double) countActualPositive); } thresholdOut[i] = c.getThreshold(); precisionOut[i] = precision; recallOut[i] = recall; tpCountOut[i] = (int) tpCount; fpCountOut[i] = (int) fpCount; fnCountOut[i] = (int) (countActualPositive - tpCount); i++; } } prCurve = new PrecisionRecallCurve(thresholdOut, precisionOut, recallOut, tpCountOut, fpCountOut, fnCountOut, exampleCount); return prCurve; } @AllArgsConstructor @Data @NoArgsConstructor public static class CountsForThreshold implements Serializable, Cloneable { private double threshold; private long countTruePositive; private long countFalsePositive; public CountsForThreshold(double threshold) { this(threshold, 0, 0); } @Override public CountsForThreshold clone() { return new CountsForThreshold(threshold, countTruePositive, countFalsePositive); } public void incrementFalsePositive(long count) { countFalsePositive += count; } public void incrementTruePositive(long count) { countTruePositive += count; } } /** * Evaluate (collect statistics for) the given 
minibatch of data. * For time series (3 dimensions) use {@link #evalTimeSeries(INDArray, INDArray)} or {@link #evalTimeSeries(INDArray, INDArray, INDArray)} * * @param labels Labels / true outcomes * @param predictions Predictions */ @Override public void eval(INDArray labels, INDArray predictions, INDArray mask, List recordMetaData) { Triple p = BaseEvaluation.reshapeAndExtractNotMasked(labels, predictions, mask, axis); if (p == null) { //All values masked out; no-op return; } INDArray labels2d = p.getFirst(); INDArray predictions2d = p.getSecond(); if (labels2d.rank() == 3 && predictions2d.rank() == 3) { //Assume time series input -> reshape to 2d evalTimeSeries(labels2d, predictions2d); } if (labels2d.rank() > 2 || predictions2d.rank() > 2 || labels2d.size(1) != predictions2d.size(1) || labels2d.size(1) > 2) { throw new IllegalArgumentException("Invalid input data shape: labels shape = " + Arrays.toString(labels2d.shape()) + ", predictions shape = " + Arrays.toString(predictions2d.shape()) + "; require rank 2 array with size(1) == 1 or 2"); } if (labels2d.dataType() != predictions2d.dataType()) labels2d = labels2d.castTo(predictions2d.dataType()); //Check for NaNs in predictions - without this, evaulation could silently be intepreted as class 0 prediction due to argmax long count = Nd4j.getExecutioner().execAndReturn(new MatchCondition(predictions2d, Conditions.isNan())).getFinalResult().longValue(); Preconditions.checkState(count == 0, "Cannot perform evaluation with NaN(s) present:" + " %s NaN(s) present in predictions INDArray", count); double step = 1.0 / thresholdSteps; boolean singleOutput = labels2d.size(1) == 1; if (isExact) { //Exact approach: simply add them to the storage for later computation/use if (probAndLabel == null) { //Do initial allocation val initialSize = Math.max(labels2d.size(0), exactAllocBlockSize); probAndLabel = Nd4j.create(DataType.DOUBLE, new long[]{initialSize, 2}, 'c'); //First col: probability of class 1. 
Second col: "is class 1" } //Allocate a larger array if necessary if (exampleCount + labels2d.size(0) >= probAndLabel.size(0)) { val newSize = probAndLabel.size(0) + Math.max(exactAllocBlockSize, labels2d.size(0)); INDArray newProbAndLabel = Nd4j.create(DataType.DOUBLE, new long[]{newSize, 2}, 'c'); if (exampleCount > 0) { //If statement to handle edge case: no examples, but we need to re-allocate right away newProbAndLabel.get(interval(0, exampleCount), all()).assign( probAndLabel.get(interval(0, exampleCount), all())); } probAndLabel = newProbAndLabel; } //put values INDArray probClass1; INDArray labelClass1; if (singleOutput) { probClass1 = predictions2d; labelClass1 = labels2d; } else { probClass1 = predictions2d.getColumn(1,true); labelClass1 = labels2d.getColumn(1,true); } val currMinibatchSize = labels2d.size(0); probAndLabel.get(interval(exampleCount, exampleCount + currMinibatchSize), NDArrayIndex.point(0)).assign(probClass1); probAndLabel.get(interval(exampleCount, exampleCount + currMinibatchSize), NDArrayIndex.point(1)).assign(labelClass1); int countClass1CurrMinibatch = labelClass1.sumNumber().intValue(); countActualPositive += countClass1CurrMinibatch; countActualNegative += labels2d.size(0) - countClass1CurrMinibatch; } else { //Thresholded approach INDArray positivePredictedClassColumn; INDArray positiveActualClassColumn; INDArray negativeActualClassColumn; if (singleOutput) { //Single binary variable case positiveActualClassColumn = labels2d; negativeActualClassColumn = labels2d.rsub(1.0); //1.0 - label positivePredictedClassColumn = predictions2d; } else { //Standard case - 2 output variables (probability distribution) positiveActualClassColumn = labels2d.getColumn(1,true); negativeActualClassColumn = labels2d.getColumn(0,true); positivePredictedClassColumn = predictions2d.getColumn(1,true); } //Increment global counts - actual positive/negative observed countActualPositive += positiveActualClassColumn.sumNumber().intValue(); countActualNegative += negativeActualClassColumn.sumNumber().intValue(); //Here: calculate true positive rate (TPR) vs. false positive rate (FPR) at different threshold INDArray ppc = null; INDArray itp = null; INDArray ifp = null; for (int i = 0; i <= thresholdSteps; i++) { double currThreshold = i * step; //Work out true/false positives - do this by replacing probabilities (predictions) with 1 or 0 based on threshold Condition condGeq = Conditions.greaterThanOrEqual(currThreshold); Condition condLeq = Conditions.lessThanOrEqual(currThreshold); if (ppc == null) { ppc = positivePredictedClassColumn.dup(positiveActualClassColumn.ordering()); } else { ppc.assign(positivePredictedClassColumn); } Op op = new CompareAndSet(ppc, 1.0, condGeq); INDArray predictedClass1 = Nd4j.getExecutioner().exec(op); op = new CompareAndSet(predictedClass1, 0.0, condLeq); predictedClass1 = Nd4j.getExecutioner().exec(op); //True positives: occur when positive predicted class and actual positive actual class... //False positive occurs when positive predicted class, but negative actual class INDArray isTruePositive; // = predictedClass1.mul(positiveActualClassColumn); //If predicted == 1 and actual == 1 at this threshold: 1x1 = 1. 0 otherwise INDArray isFalsePositive; // = predictedClass1.mul(negativeActualClassColumn); //If predicted == 1 and actual == 0 at this threshold: 1x1 = 1. 
0 otherwise if (i == 0) { isTruePositive = predictedClass1.mul(positiveActualClassColumn); isFalsePositive = predictedClass1.mul(negativeActualClassColumn); itp = isTruePositive; ifp = isFalsePositive; } else { isTruePositive = Nd4j.getExecutioner().exec(new MulOp(predictedClass1, positiveActualClassColumn, itp))[0]; isFalsePositive = Nd4j.getExecutioner().exec(new MulOp(predictedClass1, negativeActualClassColumn, ifp))[0]; } //Counts for this batch: int truePositiveCount = isTruePositive.sumNumber().intValue(); int falsePositiveCount = isFalsePositive.sumNumber().intValue(); //Increment counts for this thold CountsForThreshold thresholdCounts = counts.get(currThreshold); thresholdCounts.incrementTruePositive(truePositiveCount); thresholdCounts.incrementFalsePositive(falsePositiveCount); } } exampleCount += labels2d.size(0); auc = null; auprc = null; rocCurve = null; prCurve = null; } /** * Merge this ROC instance with another. * This ROC instance is modified, by adding the stats from the other instance. * * @param other ROC instance to combine with this one */ @Override public void merge(ROC other) { if (this.thresholdSteps != other.thresholdSteps) { throw new UnsupportedOperationException( "Cannot merge ROC instances with different numbers of threshold steps (" + this.thresholdSteps + " vs. " + other.thresholdSteps + ")"); } this.countActualPositive += other.countActualPositive; this.countActualNegative += other.countActualNegative; this.auc = null; this.auprc = null; this.rocCurve = null; this.prCurve = null; if (isExact) { if (other.exampleCount == 0) { return; } if (this.exampleCount == 0) { this.exampleCount = other.exampleCount; this.probAndLabel = other.probAndLabel; return; } if (this.exampleCount + other.exampleCount > this.probAndLabel.size(0)) { //Allocate new array val newSize = this.probAndLabel.size(0) + Math.max(other.probAndLabel.size(0), exactAllocBlockSize); INDArray newProbAndLabel = Nd4j.create(DataType.DOUBLE, newSize, 2); newProbAndLabel.put(new INDArrayIndex[]{interval(0, exampleCount), all()}, probAndLabel.get(interval(0, exampleCount), all())); probAndLabel = newProbAndLabel; } INDArray toPut = other.probAndLabel.get(interval(0, other.exampleCount), all()); probAndLabel.put(new INDArrayIndex[]{ interval(exampleCount, exampleCount + other.exampleCount), all()}, toPut); } else { for (Double d : this.counts.keySet()) { CountsForThreshold cft = this.counts.get(d); CountsForThreshold otherCft = other.counts.get(d); cft.countTruePositive += otherCft.countTruePositive; cft.countFalsePositive += otherCft.countFalsePositive; } } this.exampleCount += other.exampleCount; } @Override public void reset() { countActualPositive = 0L; countActualNegative = 0L; counts.clear(); if (isExact) { probAndLabel = null; } else { double step = 1.0 / thresholdSteps; for (int i = 0; i <= thresholdSteps; i++) { double currThreshold = i * step; counts.put(currThreshold, new CountsForThreshold(currThreshold)); } } exampleCount = 0; auc = null; auprc = null; } @Override public String stats() { if(this.exampleCount == 0){ return "ROC: No data available (no data has been performed)"; } StringBuilder sb = new StringBuilder(); sb.append("AUC (Area under ROC Curve): ").append(calculateAUC()).append("\n"); sb.append("AUPRC (Area under Precision/Recall Curve): ").append(calculateAUCPR()); if (!isExact) { sb.append("\n"); sb.append("[Note: Thresholded AUC/AUPRC calculation used with ").append(thresholdSteps) .append(" steps); accuracy may reduced compared to exact mode]"); } return sb.toString(); } 
    @Override
    public String toString() {
        return stats();
    }

    public double scoreForMetric(Metric metric) {
        switch (metric) {
            case AUROC:
                return calculateAUC();
            case AUPRC:
                return calculateAUCPR();
            default:
                throw new IllegalStateException("Unknown metric: " + metric);
        }
    }

    @Override
    public double getValue(IMetric metric) {
        if (metric instanceof Metric) {
            return scoreForMetric((Metric) metric);
        } else
            throw new IllegalStateException("Can't get value for non-ROC Metric " + metric);
    }

    @Override
    public ROC newInstance() {
        return new ROC(thresholdSteps, rocRemoveRedundantPts, exactAllocBlockSize, axis);
    }
}
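Below is a minimal usage sketch, not part of the artifact source above. It assumes a reasonably recent ND4J release (for Nd4j.createFromArray) and uses made-up toy label/prediction arrays; it illustrates the two modes of operation described in the class Javadoc, exact (thresholdSteps == 0) and thresholded.

import org.nd4j.evaluation.classification.ROC;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class RocUsageSketch {
    public static void main(String[] args) {
        // Exact mode (thresholdSteps == 0): threshold points are derived from the data itself.
        ROC exactRoc = new ROC(0);
        // Thresholded mode: 30 fixed steps of size 1.0/30; uses less memory on large datasets.
        ROC thresholdedRoc = new ROC(30);

        // Toy data, nColumns == 2: column 1 holds the probability of the 'positive' class.
        INDArray labels = Nd4j.createFromArray(new double[][]{{1, 0}, {0, 1}, {0, 1}, {1, 0}});
        INDArray predictions = Nd4j.createFromArray(new double[][]{{0.9, 0.1}, {0.2, 0.8}, {0.4, 0.6}, {0.6, 0.4}});

        exactRoc.eval(labels, predictions);
        thresholdedRoc.eval(labels, predictions);

        System.out.println("Exact AUROC: " + exactRoc.calculateAUC());
        System.out.println("Exact AUPRC: " + exactRoc.calculateAUCPR());
        System.out.println(thresholdedRoc.stats());
    }
}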





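For the batched/distributed scenario that thresholded mode is designed for, separately built ROC instances can be combined with merge(). A rough sketch with hypothetical shard data (the worker names and arrays below are illustrative only, not from the source):

import org.nd4j.evaluation.classification.ROC;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class RocMergeSketch {
    public static void main(String[] args) {
        // One ROC per worker/shard; thresholdSteps must match across instances, otherwise merge() throws.
        ROC workerA = new ROC(100);
        ROC workerB = new ROC(100);

        INDArray labelsA = Nd4j.createFromArray(new double[][]{{0, 1}, {1, 0}});
        INDArray predsA  = Nd4j.createFromArray(new double[][]{{0.3, 0.7}, {0.8, 0.2}});
        INDArray labelsB = Nd4j.createFromArray(new double[][]{{0, 1}, {0, 1}});
        INDArray predsB  = Nd4j.createFromArray(new double[][]{{0.45, 0.55}, {0.1, 0.9}});

        workerA.eval(labelsA, predsA);
        workerB.eval(labelsB, predsB);

        workerA.merge(workerB); // workerA now holds the combined per-threshold counts
        System.out.println(workerA.stats());
    }
}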