All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.evaluation.classification.EvaluationCalibration Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.evaluation.classification;

import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.val;
import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.BaseEvaluation;
import org.nd4j.evaluation.IMetric;
import org.nd4j.evaluation.classification.Evaluation.Metric;
import org.nd4j.evaluation.curves.Histogram;
import org.nd4j.evaluation.curves.ReliabilityDiagram;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.lossfunctions.LossUtil;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Triple;
import org.nd4j.serde.jackson.shaded.NDArrayDeSerializer;
import org.nd4j.serde.jackson.shaded.NDArraySerializer;
import org.nd4j.serde.jackson.shaded.NDArrayTextDeSerializer;
import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;

import java.io.Serializable;
import java.util.List;

/**
 * EvaluationCalibration is an evaluation class designed to analyze the calibration of a classifier.
* It provides a number of tools for this purpose: * - Counts of the number of labels and predictions for each class
* - Reliability diagram (or reliability curve)
* - Residual plot (histogram)
* - Histograms of probabilities, including probabilities for each class separately
*
* References:
* - Reliability diagram: see for example Niculescu-Mizil and Caruana 2005, Predicting Good Probabilities With * Supervised Learning
* - Residual plot: see Wallace and Dahabreh 2012, Class Probability Estimates are Unreliable for Imbalanced Data * (and How to Fix Them)
* * * @author Alex Black */ @Getter @EqualsAndHashCode public class EvaluationCalibration extends BaseEvaluation { public static final int DEFAULT_RELIABILITY_DIAG_NUM_BINS = 10; public static final int DEFAULT_HISTOGRAM_NUM_BINS = 50; private final int reliabilityDiagNumBins; private final int histogramNumBins; private final boolean excludeEmptyBins; @EqualsAndHashCode.Exclude //Exclude axis: otherwise 2 Evaluation instances could contain identical stats and fail equality protected int axis = 1; @JsonSerialize(using = NDArraySerializer.class) @JsonDeserialize(using = NDArrayDeSerializer.class) private INDArray rDiagBinPosCount; @JsonSerialize(using = NDArraySerializer.class) @JsonDeserialize(using = NDArrayDeSerializer.class) private INDArray rDiagBinTotalCount; @JsonSerialize(using = NDArraySerializer.class) @JsonDeserialize(using = NDArrayDeSerializer.class) private INDArray rDiagBinSumPredictions; @JsonSerialize(using = NDArrayTextSerializer.class) @JsonDeserialize(using = NDArrayTextDeSerializer.class) private INDArray labelCountsEachClass; @JsonSerialize(using = NDArrayTextSerializer.class) @JsonDeserialize(using = NDArrayTextDeSerializer.class) private INDArray predictionCountsEachClass; @JsonSerialize(using = NDArrayTextSerializer.class) @JsonDeserialize(using = NDArrayTextDeSerializer.class) private INDArray residualPlotOverall; @JsonSerialize(using = NDArraySerializer.class) @JsonDeserialize(using = NDArrayDeSerializer.class) private INDArray residualPlotByLabelClass; @JsonSerialize(using = NDArrayTextSerializer.class) @JsonDeserialize(using = NDArrayTextDeSerializer.class) private INDArray probHistogramOverall; //Simple histogram over all probabilities @JsonSerialize(using = NDArraySerializer.class) @JsonDeserialize(using = NDArrayDeSerializer.class) private INDArray probHistogramByLabelClass; //Histogram - for each label class separately protected EvaluationCalibration(int axis, int reliabilityDiagNumBins, int histogramNumBins, boolean excludeEmptyBins) { this.axis = axis; this.reliabilityDiagNumBins = reliabilityDiagNumBins; this.histogramNumBins = histogramNumBins; this.excludeEmptyBins = excludeEmptyBins; } /** * Create an EvaluationCalibration instance with the default number of bins */ public EvaluationCalibration() { this(DEFAULT_RELIABILITY_DIAG_NUM_BINS, DEFAULT_HISTOGRAM_NUM_BINS, true); } /** * Create an EvaluationCalibration instance with the specified number of bins * * @param reliabilityDiagNumBins Number of bins for the reliability diagram (usually 10) * @param histogramNumBins Number of bins for the histograms */ public EvaluationCalibration(int reliabilityDiagNumBins, int histogramNumBins) { this(reliabilityDiagNumBins, histogramNumBins, true); } /** * Create an EvaluationCalibration instance with the specified number of bins * * @param reliabilityDiagNumBins Number of bins for the reliability diagram (usually 10) * @param histogramNumBins Number of bins for the histograms * @param excludeEmptyBins For the reliability diagram, whether empty bins should be excluded */ public EvaluationCalibration(@JsonProperty("reliabilityDiagNumBins") int reliabilityDiagNumBins, @JsonProperty("histogramNumBins") int histogramNumBins, @JsonProperty("excludeEmptyBins") boolean excludeEmptyBins) { this.reliabilityDiagNumBins = reliabilityDiagNumBins; this.histogramNumBins = histogramNumBins; this.excludeEmptyBins = excludeEmptyBins; } /** * Set the axis for evaluation - this is the dimension along which the probability (and label classes) are present.
* For DL4J, this can be left as the default setting (axis = 1).
* Axis should be set as follows:
* For 2D (OutputLayer), shape [minibatch, numClasses] - axis = 1
* For 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NCW format, shape [minibatch, numClasses, sequenceLength] - axis = 1
* For 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NWC format, shape [minibatch, sequenceLength, numClasses] - axis = 2
* For 4D, CNN2D (DL4J CnnLossLayer), NCHW format, shape [minibatch, channels, height, width] - axis = 1
* For 4D, CNN2D, NHWC format, shape [minibatch, height, width, channels] - axis = 3
* * @param axis Axis to use for evaluation */ public void setAxis(int axis){ this.axis = axis; } /** * Get the axis - see {@link #setAxis(int)} for details */ public int getAxis(){ return axis; } @Override public void eval(INDArray labels, INDArray predictions, INDArray mask) { Triple triple = BaseEvaluation.reshapeAndExtractNotMasked(labels, predictions, mask, axis); if(triple == null){ //All values masked out; no-op return; } INDArray labels2d = triple.getFirst(); INDArray predictions2d = triple.getSecond(); INDArray maskArray = triple.getThird(); Preconditions.checkState(maskArray == null, "Per-output masking for EvaluationCalibration is not supported"); //Stats for the reliability diagram: one reliability diagram for each class // For each bin, we need: (a) the number of positive cases AND total cases, (b) the average probability val nClasses = labels2d.size(1); if (rDiagBinPosCount == null) { DataType dt = DataType.DOUBLE; //Initialize rDiagBinPosCount = Nd4j.create(DataType.LONG, reliabilityDiagNumBins, nClasses); rDiagBinTotalCount = Nd4j.create(DataType.LONG, reliabilityDiagNumBins, nClasses); rDiagBinSumPredictions = Nd4j.create(dt, reliabilityDiagNumBins, nClasses); labelCountsEachClass = Nd4j.create(DataType.LONG, 1, nClasses); predictionCountsEachClass = Nd4j.create(DataType.LONG, 1, nClasses); residualPlotOverall = Nd4j.create(dt, 1, histogramNumBins); residualPlotByLabelClass = Nd4j.create(dt, histogramNumBins, nClasses); probHistogramOverall = Nd4j.create(dt, 1, histogramNumBins); probHistogramByLabelClass = Nd4j.create(dt, histogramNumBins, nClasses); } //First: loop over classes, determine positive count and total count - for each bin double histogramBinSize = 1.0 / histogramNumBins; double reliabilityBinSize = 1.0 / reliabilityDiagNumBins; INDArray p = predictions2d; INDArray l = labels2d; if (maskArray != null) { //2 options: per-output masking, or if (maskArray.isColumnVectorOrScalar()) { //Per-example masking l = l.mulColumnVector(maskArray); } else { l = l.mul(maskArray); } } for (int j = 0; j < reliabilityDiagNumBins; j++) { INDArray geqBinLower = p.gte(j * reliabilityBinSize).castTo(predictions2d.dataType()); INDArray ltBinUpper; if (j == reliabilityDiagNumBins - 1) { //Handle edge case ltBinUpper = p.lte(1.0).castTo(predictions2d.dataType()); } else { ltBinUpper = p.lt((j + 1) * reliabilityBinSize).castTo(predictions2d.dataType()); } //Calculate bit-mask over each entry - whether that entry is in the current bin or not INDArray currBinBitMask = geqBinLower.muli(ltBinUpper); if (maskArray != null) { if (maskArray.isColumnVectorOrScalar()) { currBinBitMask.muliColumnVector(maskArray); } else { currBinBitMask.muli(maskArray); } } INDArray isPosLabelForBin = l.mul(currBinBitMask); INDArray maskedProbs = predictions2d.mul(currBinBitMask); INDArray numPredictionsCurrBin = currBinBitMask.sum(0); rDiagBinSumPredictions.getRow(j).addi(maskedProbs.sum(0).castTo(rDiagBinSumPredictions.dataType())); rDiagBinPosCount.getRow(j).addi(isPosLabelForBin.sum(0).castTo(rDiagBinPosCount.dataType())); rDiagBinTotalCount.getRow(j).addi(numPredictionsCurrBin.castTo(rDiagBinTotalCount.dataType())); } //Second, we want histograms of: //(a) Distribution of label classes: label counts for each class //(b) Distribution of prediction classes: prediction counts for each class //(c) residual plots, for each class - (i) all instances, (ii) positive instances only, (iii) negative only //(d) Histograms of probabilities, for each class labelCountsEachClass.addi(labels2d.sum(0).castTo(labelCountsEachClass.dataType())); //For prediction counts: do an IsMax op, but we need to take masking into account... INDArray isPredictedClass = Nd4j.getExecutioner().exec(new IsMax(p, p.ulike(), 1))[0]; if (maskArray != null) { LossUtil.applyMask(isPredictedClass, maskArray); } predictionCountsEachClass.addi(isPredictedClass.sum(0).castTo(predictionCountsEachClass.dataType())); //Residual plots: want histogram of |labels - predicted prob| //ND4J's histogram op: dynamically calculates the bin positions, which is not what I want here... INDArray labelsSubPredicted = labels2d.sub(predictions2d); INDArray maskedProbs = predictions2d.dup(); Transforms.abs(labelsSubPredicted, false); //if masking: replace entries with < 0 to effectively remove them if (maskArray != null) { //Assume per-example masking INDArray newMask = maskArray.mul(-10); labelsSubPredicted.addiColumnVector(newMask); maskedProbs.addiColumnVector(newMask); } for (int j = 0; j < histogramNumBins; j++) { INDArray geqBinLower = labelsSubPredicted.gte(j * histogramBinSize).castTo(predictions2d.dataType()); INDArray ltBinUpper; INDArray geqBinLowerProbs = maskedProbs.gte(j * histogramBinSize).castTo(predictions2d.dataType()); INDArray ltBinUpperProbs; if (j == histogramNumBins - 1) { //Handle edge case ltBinUpper = labelsSubPredicted.lte(1.0).castTo(predictions2d.dataType()); ltBinUpperProbs = maskedProbs.lte(1.0).castTo(predictions2d.dataType()); } else { ltBinUpper = labelsSubPredicted.lt((j + 1) * histogramBinSize).castTo(predictions2d.dataType()); ltBinUpperProbs = maskedProbs.lt((j + 1) * histogramBinSize).castTo(predictions2d.dataType()); } INDArray currBinBitMask = geqBinLower.muli(ltBinUpper); INDArray currBinBitMaskProbs = geqBinLowerProbs.muli(ltBinUpperProbs); int newTotalCount = residualPlotOverall.getInt(0, j) + currBinBitMask.sumNumber().intValue(); residualPlotOverall.putScalar(0, j, newTotalCount); //Counts for positive class only: values are in the current bin AND it's a positive label INDArray isPosLabelForBin = l.mul(currBinBitMask); residualPlotByLabelClass.getRow(j).addi(isPosLabelForBin.sum(0).castTo(residualPlotByLabelClass.dataType())); int probNewTotalCount = probHistogramOverall.getInt(0, j) + currBinBitMaskProbs.sumNumber().intValue(); probHistogramOverall.putScalar(0, j, probNewTotalCount); INDArray isPosLabelForBinProbs = l.mul(currBinBitMaskProbs); INDArray temp = isPosLabelForBinProbs.sum(0); probHistogramByLabelClass.getRow(j).addi(temp.castTo(probHistogramByLabelClass.dataType())); } } @Override public void eval(INDArray labels, INDArray networkPredictions) { eval(labels, networkPredictions, (INDArray) null); } @Override public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List recordMetaData) { throw new UnsupportedOperationException("Not yet implemented"); } @Override public void merge(EvaluationCalibration other) { if (reliabilityDiagNumBins != other.reliabilityDiagNumBins) { throw new UnsupportedOperationException( "Cannot merge EvaluationCalibration instances with different numbers of bins"); } if (other.rDiagBinPosCount == null) { return; } if (rDiagBinPosCount == null) { this.rDiagBinPosCount = other.rDiagBinPosCount; this.rDiagBinTotalCount = other.rDiagBinTotalCount; this.rDiagBinSumPredictions = other.rDiagBinSumPredictions; } this.rDiagBinPosCount.addi(other.rDiagBinPosCount); this.rDiagBinTotalCount.addi(other.rDiagBinTotalCount); this.rDiagBinSumPredictions.addi(other.rDiagBinSumPredictions); } @Override public void reset() { rDiagBinPosCount = null; rDiagBinTotalCount = null; rDiagBinSumPredictions = null; } @Override public String stats() { return "EvaluationCalibration(nBins=" + reliabilityDiagNumBins + ")"; } public int numClasses() { if (rDiagBinTotalCount == null) { return -1; } return (int) rDiagBinTotalCount.size(1); } /** * Get the reliability diagram for the specified class * * @param classIdx Index of the class to get the reliability diagram for */ public ReliabilityDiagram getReliabilityDiagram(int classIdx) { Preconditions.checkState(rDiagBinPosCount != null, "Unable to get reliability diagram: no evaluation has been performed (no data)"); INDArray totalCountBins = rDiagBinTotalCount.getColumn(classIdx); INDArray countPositiveBins = rDiagBinPosCount.getColumn(classIdx); double[] meanPredictionBins = rDiagBinSumPredictions.getColumn(classIdx).castTo(DataType.DOUBLE) .div(totalCountBins.castTo(DataType.DOUBLE)).data().asDouble(); double[] fracPositives = countPositiveBins.castTo(DataType.DOUBLE).div(totalCountBins.castTo(DataType.DOUBLE)).data().asDouble(); if (excludeEmptyBins) { val condition = new MatchCondition(totalCountBins, Conditions.equals(0)); int numZeroBins = Nd4j.getExecutioner().exec(condition).getInt(0); if (numZeroBins != 0) { double[] mpb = meanPredictionBins; double[] fp = fracPositives; meanPredictionBins = new double[(int) (totalCountBins.length() - numZeroBins)]; fracPositives = new double[meanPredictionBins.length]; int j = 0; for (int i = 0; i < mpb.length; i++) { if (totalCountBins.getDouble(i) != 0) { meanPredictionBins[j] = mpb[i]; fracPositives[j] = fp[i]; j++; } } } } String title = "Reliability Diagram: Class " + classIdx; return new ReliabilityDiagram(title, meanPredictionBins, fracPositives); } /** * @return The number of observed labels for each class. For N classes, be returned array is of length N, with * out[i] being the number of labels of class i */ public int[] getLabelCountsEachClass() { return labelCountsEachClass == null ? null : labelCountsEachClass.data().asInt(); } /** * @return The number of network predictions for each class. For N classes, be returned array is of length N, with * out[i] being the number of predicted values (max probability) for class i */ public int[] getPredictionCountsEachClass() { return predictionCountsEachClass == null ? null : predictionCountsEachClass.data().asInt(); } /** * Get the residual plot for all classes combined. The residual plot is defined as a histogram of
* |label_i - prob(class_i | input)| for all classes i and examples.
* In general, small residuals indicate a superior classifier to large residuals. * * @return Residual plot (histogram) - all predictions/classes */ public Histogram getResidualPlotAllClasses() { String title = "Residual Plot - All Predictions and Classes"; int[] counts = residualPlotOverall.data().asInt(); return new Histogram(title, 0.0, 1.0, counts); } /** * Get the residual plot, only for examples of the specified class.. The residual plot is defined as a histogram of
* |label_i - prob(class_i | input)| for all and examples; for this particular method, only predictions where * i == labelClassIdx are included.
* In general, small residuals indicate a superior classifier to large residuals. * * @param labelClassIdx Index of the class to get the residual plot for * @return Residual plot (histogram) - all predictions/classes */ public Histogram getResidualPlot(int labelClassIdx) { Preconditions.checkState(rDiagBinPosCount != null, "Unable to get residual plot: no evaluation has been performed (no data)"); String title = "Residual Plot - Predictions for Label Class " + labelClassIdx; int[] counts = residualPlotByLabelClass.getColumn(labelClassIdx).dup().data().asInt(); return new Histogram(title, 0.0, 1.0, counts); } /** * Return a probability histogram for all predictions/classes. * * @return Probability histogram */ public Histogram getProbabilityHistogramAllClasses() { String title = "Network Probabilities Histogram - All Predictions and Classes"; int[] counts = probHistogramOverall.data().asInt(); return new Histogram(title, 0.0, 1.0, counts); } /** * Return a probability histogram of the specified label class index. That is, for label class index i, * a histogram of P(class_i | input) is returned, only for those examples that are labelled as class i. * * @param labelClassIdx Index of the label class to get the histogram for * @return Probability histogram */ public Histogram getProbabilityHistogram(int labelClassIdx) { Preconditions.checkState(rDiagBinPosCount != null, "Unable to get probability histogram: no evaluation has been performed (no data)"); String title = "Network Probabilities Histogram - P(class " + labelClassIdx + ") - Data Labelled Class " + labelClassIdx + " Only"; int[] counts = probHistogramByLabelClass.getColumn(labelClassIdx).dup().data().asInt(); return new Histogram(title, 0.0, 1.0, counts); } public static EvaluationCalibration fromJson(String json){ return fromJson(json, EvaluationCalibration.class); } @Override public double getValue(IMetric metric){ throw new IllegalStateException("Can't get value for non-calibration Metric " + metric); } @Override public EvaluationCalibration newInstance() { return new EvaluationCalibration(axis, reliabilityDiagNumBins, histogramNumBins, excludeEmptyBins); } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy