All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.eval.ROCMultiClass Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.eval;

import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.deeplearning4j.eval.curves.PrecisionRecallCurve;
import org.deeplearning4j.eval.curves.RocCurve;
import org.deeplearning4j.eval.serde.ROCArraySerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;

import java.util.Arrays;
import java.util.List;

/**
 * ROC (Receiver Operating Characteristic) for multi-class classifiers.
  As per {@link ROC}, ROCBinary supports both exact (thersholdSteps == 0) and thresholded; see {@link ROC} for details.
 * 

* The ROC curves are produced by treating the predictions as a set of one-vs-all classifiers, and then calculating * ROC curves for each. In practice, this means for N classes, we get N ROC curves. * * @author Alex Black */ @Data @EqualsAndHashCode(callSuper = true) public class ROCMultiClass extends BaseEvaluation { public static final int DEFAULT_STATS_PRECISION = 4; private int thresholdSteps; private boolean rocRemoveRedundantPts; @JsonSerialize(using = ROCArraySerializer.class) private ROC[] underlying; private List labels; public ROCMultiClass() { //Default to exact this(0); } /** * @param thresholdSteps Number of threshold steps to use for the ROC calculation. Set to 0 for exact ROC calculation */ public ROCMultiClass(int thresholdSteps) { this(thresholdSteps, true); } /** * @param thresholdSteps Number of threshold steps to use for the ROC calculation. If set to 0: use exact calculation * @param rocRemoveRedundantPts Usually set to true. If true, remove any redundant points from ROC and P-R curves */ public ROCMultiClass(int thresholdSteps, boolean rocRemoveRedundantPts) { this.thresholdSteps = thresholdSteps; this.rocRemoveRedundantPts = rocRemoveRedundantPts; } @Override public void reset() { underlying = null; } @Override public String stats() { return stats(DEFAULT_STATS_PRECISION); } public String stats(int printPrecision) { StringBuilder sb = new StringBuilder(); int maxLabelsLength = 15; if (labels != null) { for (String s : labels) { maxLabelsLength = Math.max(s.length(), maxLabelsLength); } } String patternHeader = "%-" + (maxLabelsLength + 5) + "s%-12s%-10s%-10s"; String header = String.format(patternHeader, "Label", "AUC", "# Pos", "# Neg"); String pattern = "%-" + (maxLabelsLength + 5) + "s" //Label + "%-12." + printPrecision + "f" //AUC + "%-10d%-10d"; //Count pos, count neg sb.append(header); if (underlying != null) { for (int i = 0; i < underlying.length; i++) { double auc = calculateAUC(i); String label = (labels == null ? 
String.valueOf(i) : labels.get(i)); sb.append("\n").append(String.format(pattern, label, auc, getCountActualPositive(i), getCountActualNegative(i))); } sb.append("Average AUC: ").append(String.format("%-12." + printPrecision + "f", calculateAverageAUC())); } else { //Empty evaluation sb.append("\n-- No Data --\n"); } return sb.toString(); } /** * Evaluate (collect statistics for) the given minibatch of data. * For time series (3 dimensions) use {@link #evalTimeSeries(INDArray, INDArray)} or {@link #evalTimeSeries(INDArray, INDArray, INDArray)} * * @param labels Labels / true outcomes * @param predictions Predictions */ public void eval(INDArray labels, INDArray predictions) { if (labels.rank() == 3 && predictions.rank() == 3) { //Assume time series input -> reshape to 2d evalTimeSeries(labels, predictions); } if (labels.rank() > 2 || predictions.rank() > 2 || labels.size(1) != predictions.size(1)) { throw new IllegalArgumentException("Invalid input data shape: labels shape = " + Arrays.toString(labels.shape()) + ", predictions shape = " + Arrays.toString(predictions.shape()) + "; require rank 2 array with size(1) == 1 or 2"); } int n = labels.size(1); if (underlying == null) { underlying = new ROC[n]; for (int i = 0; i < n; i++) { underlying[i] = new ROC(thresholdSteps, rocRemoveRedundantPts); } } if (underlying.length != labels.size(1)) { throw new IllegalArgumentException( "Cannot evaluate data: number of label classes does not match previous call. " + "Got " + labels.size(1) + " labels (from array shape " + Arrays.toString(labels.shape()) + ")" + " vs. expected number of label classes = " + underlying.length); } for (int i = 0; i < n; i++) { INDArray prob = predictions.getColumn(i); //Probability of class i INDArray label = labels.getColumn(i); underlying[i].eval(label, prob); } } /** * Get the (one vs. 
all) ROC curve for the specified class * @param classIdx Class index to get the ROC curve for * @return ROC curve for the given class */ public RocCurve getRocCurve(int classIdx) { assertIndex(classIdx); return underlying[classIdx].getRocCurve(); } /** * Get the (one vs. all) Precision-Recall curve for the specified class * @param classIdx Class to get the P-R curve for * @return Precision recall curve for the given class */ public PrecisionRecallCurve getPrecisionRecallCurve(int classIdx) { assertIndex(classIdx); return underlying[classIdx].getPrecisionRecallCurve(); } /** * Calculate the AUC - Area Under ROC Curve
* Utilizes trapezoidal integration internally * * @return AUC */ public double calculateAUC(int classIdx) { assertIndex(classIdx); return underlying[classIdx].calculateAUC(); } /** * Calculate the AUPRC - Area Under Curve Precision Recall
* Utilizes trapezoidal integration internally * * @return AUC */ public double calculateAUCPR(int classIdx) { assertIndex(classIdx); return underlying[classIdx].calculateAUCPR(); } /** * Calculate the macro-average (one-vs-all) AUC for all classes */ public double calculateAverageAUC() { assertIndex(0); double sum = 0.0; for (int i = 0; i < underlying.length; i++) { sum += calculateAUC(i); } return sum / underlying.length; } /** * Get the actual positive count (accounting for any masking) for the specified class * * @param outputNum Index of the class */ public long getCountActualPositive(int outputNum) { assertIndex(outputNum); return underlying[outputNum].getCountActualPositive(); } /** * Get the actual negative count (accounting for any masking) for the specified output/column * * @param outputNum Index of the class */ public long getCountActualNegative(int outputNum) { assertIndex(outputNum); return underlying[outputNum].getCountActualNegative(); } /** * Merge this ROCMultiClass instance with another. * This ROCMultiClass instance is modified, by adding the stats from the other instance. 
* * @param other ROCMultiClass instance to combine with this one */ @Override public void merge(ROCMultiClass other) { if (this.underlying == null) { this.underlying = other.underlying; return; } else if (other.underlying == null) { return; } //Both have data if (underlying.length != other.underlying.length) { throw new UnsupportedOperationException("Cannot merge ROCBinary: this expects " + underlying.length + "outputs, other expects " + other.underlying.length + " outputs"); } for (int i = 0; i < underlying.length; i++) { this.underlying[i].merge(other.underlying[i]); } } public int getNumClasses() { if (underlying == null) { return -1; } return underlying.length; } private void assertIndex(int classIdx) { if (underlying == null) { throw new IllegalStateException("Cannot get results: no data has been collected"); } if (classIdx < 0 || classIdx >= underlying.length) { throw new IllegalArgumentException("Invalid class index (" + classIdx + "): must be in range 0 to numClasses = " + underlying.length); } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy