All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.evaluation.classification.Evaluation Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.evaluation.classification;

import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.common.base.Preconditions;
import org.nd4j.evaluation.BaseEvaluation;
import org.nd4j.evaluation.EvaluationAveraging;
import org.nd4j.evaluation.EvaluationUtils;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.IMetric;
import org.nd4j.evaluation.meta.Prediction;
import org.nd4j.evaluation.serde.ConfusionMatrixDeserializer;
import org.nd4j.evaluation.serde.ConfusionMatrixSerializer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.common.primitives.Counter;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.primitives.Triple;
import org.nd4j.serde.jackson.shaded.NDArrayTextDeSerializer;
import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;

import java.io.Serializable;
import java.text.DecimalFormat;
import java.util.*;

@Slf4j
@EqualsAndHashCode(callSuper = true)
@Getter
@Setter
@JsonIgnoreProperties({"confusionMatrixMetaData"})
public class Evaluation extends BaseEvaluation {

    /**
     * Metrics that can be requested from an {@link Evaluation} instance.
     * {@link #minimize()} returns {@code false} for every constant.
     */
    public enum Metric implements IMetric {
        ACCURACY,
        F1,
        PRECISION,
        RECALL,
        GMEASURE,
        MCC;

        /**
         * @return {@code Evaluation.class}, the evaluation class these metrics belong to
         */
        @Override
        public Class getEvaluationClass() {
            return Evaluation.class;
        }

        /**
         * @return always {@code false}
         */
        @Override
        public boolean minimize() {
            return false;
        }
    }

    //What to output from the precision/recall function when we encounter an edge case
    //(the per-class methods document the edge case as "0/0"). Used as the default by the
    //single-argument precision/recall/falsePositiveRate overloads and by micro-averaging.
    protected static final double DEFAULT_EDGE_VALUE = 0.0;

    //stats(boolean) includes the confusion matrix only when numClasses() is at most this value;
    //above it, a note explaining how to print the matrix anyway is emitted instead.
    protected static final int CONFUSION_PRINT_MAX_CLASSES = 20;

    @EqualsAndHashCode.Exclude      //Exclude axis: otherwise 2 Evaluation instances could contain identical stats and fail equality
    protected int axis = 1;
    protected Integer binaryPositiveClass = 1;  //Used *only* for binary classification; default value here to 1 for legacy JSON loading
    protected final int topN;
    protected int topNCorrectCount = 0;
    protected int topNTotalCount = 0; //Could use topNCountCorrect / (double)getNumRowCounter() - except for eval(int,int), hence separate counters
    protected Counter truePositives = new Counter<>();
    protected Counter falsePositives = new Counter<>();
    protected Counter trueNegatives = new Counter<>();
    protected Counter falseNegatives = new Counter<>();
    @JsonSerialize(using = ConfusionMatrixSerializer.class)
    @JsonDeserialize(using = ConfusionMatrixDeserializer.class)
    protected ConfusionMatrix confusion;
    protected int numRowCounter = 0;
    @Getter
    @Setter
    protected List labelsList = new ArrayList<>();

    protected Double binaryDecisionThreshold;
    @JsonSerialize(using = NDArrayTextSerializer.class)
    @JsonDeserialize(using = NDArrayTextDeSerializer.class)
    protected INDArray costArray;

    protected Map, List> confusionMatrixMetaData; //Pair: (Actual,Predicted)

    /**
     * For stats(): When classes are excluded from precision/recall, what is the maximum number we should print?
     * If this is set to a high value, the output (potentially thousands of classes) can become unreadable.
     */
    @Getter @Setter
    protected int maxWarningClassesToPrint = 16;

    /**
     * Configuration-only constructor: copies every argument into the corresponding field without validation.
     * No confusion matrix is created here, even when {@code labelsList} is non-null.
     *
     * @param axis                     Dimension along which the class probabilities are present
     * @param binaryPositiveClass      Positive class for binary classification; may be null
     * @param topN                     Value to use for top N accuracy
     * @param labelsList               Class labels; may be null
     * @param binaryDecisionThreshold  Decision threshold for binary predictions; may be null
     * @param costArray                Cost array; may be null
     * @param maxWarningClassesToPrint Maximum number of excluded classes to list in stats() warnings
     */
    protected Evaluation(int axis, Integer binaryPositiveClass, int topN, List<String> labelsList,
            Double binaryDecisionThreshold, INDArray costArray, int maxWarningClassesToPrint){
        this.axis = axis;
        this.binaryPositiveClass = binaryPositiveClass;
        this.topN = topN;
        this.labelsList = labelsList;
        this.binaryDecisionThreshold = binaryDecisionThreshold;
        this.costArray = costArray;
        this.maxWarningClassesToPrint = maxWarningClassesToPrint;
    }

    /**
     * Creates an evaluation with standard (top-1) accuracy and class 1 as the binary positive class.
     * No confusion matrix is created by this constructor.
     */
    public Evaluation() {
        this.binaryPositiveClass = 1;
        this.topN = 1;
    }

    /**
     * Creates an evaluation for a fixed number of classes.
     * When {@code numClasses} is 2 the positive class is set to 1; otherwise no positive class is set.
     *
     * @param numClasses the number of classes to account for in the evaluation
     */
    public Evaluation(int numClasses) {
        this(numClasses, numClasses == 2 ? Integer.valueOf(1) : null);
    }

    /**
     * Constructor for specifying the number of classes, and optionally the positive class for binary classification.
     * See Evaluation javadoc for more details on evaluation in the binary case
     *
     * @param numClasses          The number of classes for the evaluation. Must be 2, if binaryPositiveClass is non-null
     * @param binaryPositiveClass If non-null, the positive class (0 or 1).
     */
    public Evaluation(int numClasses, Integer binaryPositiveClass){
        this(createLabels(numClasses), 1);
        if(binaryPositiveClass != null){
            Preconditions.checkArgument(binaryPositiveClass == 0 || binaryPositiveClass == 1,
                    "Only 0 and 1 are valid inputs for binaryPositiveClass; got " + binaryPositiveClass);
            Preconditions.checkArgument(numClasses == 2, "Cannot set binaryPositiveClass argument " +
                    "when number of classes is not equal to 2 (got: numClasses=" + numClasses + ")");
        }
        this.binaryPositiveClass = binaryPositiveClass;
    }


    /**
     * The labels to include with the evaluation.
     * This constructor can be used for
     * generating labeled output rather than just
     * numbers for the labels. Standard (top-1) accuracy is used.
     *
     * @param labels the labels to use
     *               for the output
     */
    public Evaluation(List<String> labels) {
        this(labels, 1);
    }

    /**
     * Use a map to generate labels.
     * Pass in a label index with the actual label
     * you want to use for output. The map must contain every index from 0 to {@code labels.size() - 1}.
     *
     * @param labels a map of label index to label value
     * @throws IllegalArgumentException if an index in that range has no label
     */
    public Evaluation(Map<Integer, String> labels) {
        this(createLabelsFromMap(labels), 1);
    }

    /**
     * Constructor to use for top N accuracy.
     * When labels are given, the confusion matrix is created immediately with one class per label, and
     * class 1 is used as the positive class if there are exactly two labels.
     *
     * @param labels Labels for the classes (may be null)
     * @param topN   Value to use for top N accuracy calculation (<=1: standard accuracy). Note that with top N
     *               accuracy, an example is considered 'correct' if the probability for the true class is one of the
     *               highest N values
     */
    public Evaluation(List<String> labels, int topN) {
        this.labelsList = labels;
        this.topN = topN;

        if (labels != null) {
            createConfusion(labels.size());
            if (labels.size() == 2) {
                this.binaryPositiveClass = 1;
            }
        }
    }

    /**
     * Create an evaluation instance with a custom binary decision threshold. Note that binary decision thresholds can
     * only be used with binary classifiers.
* Defaults to class 1 for the positive class - see class javadoc, and use {@link #Evaluation(double, Integer)} to * change this. * * @param binaryDecisionThreshold Decision threshold to use for binary predictions */ public Evaluation(double binaryDecisionThreshold) { this(binaryDecisionThreshold, 1); } /** * Create an evaluation instance with a custom binary decision threshold. Note that binary decision thresholds can * only be used with binary classifiers.
* This constructor also allows the user to specify the positive class for binary classification. See class javadoc * for more details. * * @param binaryDecisionThreshold Decision threshold to use for binary predictions */ public Evaluation(double binaryDecisionThreshold, @NonNull Integer binaryPositiveClass) { if(binaryPositiveClass != null){ Preconditions.checkArgument(binaryPositiveClass == 0 || binaryPositiveClass == 1, "Only 0 and 1 are valid inputs for binaryPositiveClass; got " + binaryPositiveClass); } this.binaryDecisionThreshold = binaryDecisionThreshold; this.topN = 1; this.binaryPositiveClass = binaryPositiveClass; } /** * Created evaluation instance with the specified cost array. A cost array can be used to bias the multi class * predictions towards or away from certain classes. The predicted class is determined using argMax(cost * probability) * instead of argMax(probability) when no cost array is present. * * @param costArray Row vector cost array. May be null */ public Evaluation(INDArray costArray) { this(null, costArray); } /** * Created evaluation instance with the specified cost array. A cost array can be used to bias the multi class * predictions towards or away from certain classes. The predicted class is determined using argMax(cost * probability) * instead of argMax(probability) when no cost array is present. * * @param labels Labels for the output classes. May be null * @param costArray Row vector cost array. May be null */ public Evaluation(List labels, INDArray costArray) { if (costArray != null && !costArray.isRowVectorOrScalar()) { throw new IllegalArgumentException("Invalid cost array: must be a row vector (got shape: " + Arrays.toString(costArray.shape()) + ")"); } if (costArray != null && costArray.minNumber().doubleValue() < 0.0) { throw new IllegalArgumentException("Invalid cost array: Cost array values must be positive"); } this.labelsList = labels; this.costArray = costArray == null ? 
null : costArray.castTo(DataType.FLOAT); this.topN = 1; } protected int numClasses(){ if(labelsList != null){ return labelsList.size(); } return confusion().getClasses().size(); } @Override public void reset() { confusion = null; truePositives = new Counter<>(); falsePositives = new Counter<>(); trueNegatives = new Counter<>(); falseNegatives = new Counter<>(); topNCorrectCount = 0; topNTotalCount = 0; numRowCounter = 0; } private ConfusionMatrix confusion() { return confusion; } private static List createLabels(int numClasses) { if (numClasses == 1) numClasses = 2; //Binary (single output variable) case... List list = new ArrayList<>(numClasses); for (int i = 0; i < numClasses; i++) { list.add(String.valueOf(i)); } return list; } private static List createLabelsFromMap(Map labels) { int size = labels.size(); List labelsList = new ArrayList<>(size); for (int i = 0; i < size; i++) { String str = labels.get(i); if (str == null) throw new IllegalArgumentException("Invalid labels map: missing key for class " + i + " (expect integers 0 to " + (size - 1) + ")"); labelsList.add(str); } return labelsList; } private void createConfusion(int nClasses) { List classes = new ArrayList<>(); for (int i = 0; i < nClasses; i++) { classes.add(i); } confusion = new ConfusionMatrix<>(classes); } /** * Set the axis for evaluation - this is the dimension along which the probability (and label classes) are present.
* For DL4J, this can be left as the default setting (axis = 1).<br>
     * Axis should be set as follows:<br>
     * For 2D (OutputLayer), shape [minibatch, numClasses] - axis = 1<br>
     * For 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NCW format, shape [minibatch, numClasses, sequenceLength] - axis = 1<br>
     * For 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NWC format, shape [minibatch, sequenceLength, numClasses] - axis = 2<br>
     * For 4D, CNN2D (DL4J CnnLossLayer), NCHW format, shape [minibatch, channels, height, width] - axis = 1<br>
     * For 4D, CNN2D, NHWC format, shape [minibatch, height, width, channels] - axis = 3<br>
     *
     * @param axis Axis to use for evaluation
     */
    public void setAxis(int axis){
        this.axis = axis;
    }

    /**
     * Get the axis - see {@link #setAxis(int)} for details
     *
     * @return the axis currently used for evaluation
     */
    public int getAxis(){
        return this.axis;
    }

    /**
     * Collects statistics on the real outcomes vs the
     * guesses. This is for logistic outcome matrices.
     *

* Note that an IllegalArgumentException is thrown if the two passed in * matrices aren't the same length. * * @param realOutcomes the real outcomes (labels - usually binary) * @param guesses the guesses/prediction (usually a probability vector) */ public void eval(INDArray realOutcomes, INDArray guesses) { eval(realOutcomes, guesses, (List) null); } /** * Evaluate the network, with optional metadata * * @param labels Data labels * @param predictions Network predictions * @param recordMetaData Optional; may be null. If not null, should have size equal to the number of outcomes/guesses * */ @Override public void eval(INDArray labels, INDArray predictions, INDArray mask, final List recordMetaData) { Triple p = BaseEvaluation.reshapeAndExtractNotMasked(labels, predictions, mask, axis); if(p == null){ //All values masked out; no-op return; } INDArray labels2d = p.getFirst(); INDArray predictions2d = p.getSecond(); INDArray maskArray = p.getThird(); Preconditions.checkState(maskArray == null, "Per-output masking for Evaluation is not supported"); //Check for NaNs in predictions - without this, evaulation could silently be intepreted as class 0 prediction due to argmax long count = Nd4j.getExecutioner().execAndReturn(new MatchCondition(predictions2d, Conditions.isNan())).getFinalResult().longValue(); Preconditions.checkState(count == 0, "Cannot perform evaluation with NaNs present in predictions:" + " %s NaNs present in predictions INDArray", count); // Add the number of rows to numRowCounter numRowCounter += labels2d.size(0); if(labels2d.dataType() != predictions2d.dataType()) labels2d = labels2d.castTo(predictions2d.dataType()); // If confusion is null, then Evaluation was instantiated without providing the classes -> infer # classes from if (confusion == null) { int nClasses = labels2d.columns(); if (nClasses == 1) nClasses = 2; //Binary (single output variable) case if(labelsList == null || labelsList.isEmpty()) { labelsList = new ArrayList<>(nClasses); for (int i = 
0; i < nClasses; i++) labelsList.add(String.valueOf(i)); } createConfusion(nClasses); } // Length of real labels must be same as length of predicted labels if (!Arrays.equals(labels2d.shape(),predictions2d.shape())) { throw new IllegalArgumentException("Unable to evaluate. Predictions and labels arrays are not same shape." + " Predictions shape: " + Arrays.toString(predictions2d.shape()) + ", Labels shape: " + Arrays.toString(labels2d.shape())); } // For each row get the most probable label (column) from prediction and assign as guessMax // For each row get the column of the true label and assign as currMax final int nCols = labels2d.columns(); final int nRows = labels2d.rows(); if (nCols == 1) { INDArray binaryGuesses = predictions2d.gt(binaryDecisionThreshold == null ? 0.5 : binaryDecisionThreshold).castTo(predictions.dataType()); INDArray notLabel = labels2d.rsub(1.0); //Invert entries (assuming 1 and 0) INDArray notGuess = binaryGuesses.rsub(1.0); //tp: predicted = 1, actual = 1 int tp = labels2d.mul(binaryGuesses).castTo(DataType.INT).sumNumber().intValue(); //fp: predicted = 1, actual = 0 int fp = notLabel.mul(binaryGuesses).castTo(DataType.INT).sumNumber().intValue(); //fn: predicted = 0, actual = 1 int fn = notGuess.mul(labels2d).castTo(DataType.INT).sumNumber().intValue(); int tn = nRows - tp - fp - fn; confusion().add(1, 1, tp); confusion().add(1, 0, fn); confusion().add(0, 1, fp); confusion().add(0, 0, tn); truePositives.incrementCount(1, tp); falsePositives.incrementCount(1, fp); falseNegatives.incrementCount(1, fn); trueNegatives.incrementCount(1, tn); truePositives.incrementCount(0, tn); falsePositives.incrementCount(0, fn); falseNegatives.incrementCount(0, fp); trueNegatives.incrementCount(0, tp); if (recordMetaData != null) { for (int i = 0; i < binaryGuesses.size(0); i++) { if (i >= recordMetaData.size()) break; int actual = labels2d.getDouble(0) == 0.0 ? 0 : 1; int predicted = binaryGuesses.getDouble(0) == 0.0 ? 
0 : 1; addToMetaConfusionMatrix(actual, predicted, recordMetaData.get(i)); } } } else { INDArray guessIndex; if (binaryDecisionThreshold != null) { if (nCols != 2) { throw new IllegalStateException("Binary decision threshold is set, but number of columns for " + "predictions is " + nCols + ". Binary decision threshold can only be used for binary " + "prediction cases"); } INDArray pClass1 = predictions2d.getColumn(1); guessIndex = pClass1.gt(binaryDecisionThreshold); } else if (costArray != null) { //With a cost array: do argmax(cost * probability) instead of just argmax(probability) guessIndex = Nd4j.argMax(predictions2d.mulRowVector(costArray.castTo(predictions2d.dataType())), 1); } else { //Standard case: argmax guessIndex = Nd4j.argMax(predictions2d, 1); } INDArray realOutcomeIndex = Nd4j.argMax(labels2d, 1); val nExamples = guessIndex.length(); for (int i = 0; i < nExamples; i++) { int actual = (int) realOutcomeIndex.getDouble(i); int predicted = (int) guessIndex.getDouble(i); confusion().add(actual, predicted); if (recordMetaData != null && recordMetaData.size() > i) { Object m = recordMetaData.get(i); addToMetaConfusionMatrix(actual, predicted, m); } // instead of looping through each label for confusion // matrix, instead infer those values by determining if true/false negative/positive, // then just add across matrix // if actual == predicted, then it's a true positive, assign true negative to every other label if (actual == predicted) { truePositives.incrementCount(actual, 1); for (int col = 0; col < nCols; col++) { if (col == actual) { continue; } trueNegatives.incrementCount(col, 1); // all cols prior } } else { falsePositives.incrementCount(predicted, 1); falseNegatives.incrementCount(actual, 1); // first determine intervals for adding true negatives int lesserIndex, greaterIndex; if (actual < predicted) { lesserIndex = actual; greaterIndex = predicted; } else { lesserIndex = predicted; greaterIndex = actual; } // now loop through intervals for (int 
col = 0; col < lesserIndex; col++) { trueNegatives.incrementCount(col, 1); // all cols prior } for (int col = lesserIndex + 1; col < greaterIndex; col++) { trueNegatives.incrementCount(col, 1); // all cols after } for (int col = greaterIndex + 1; col < nCols; col++) { trueNegatives.incrementCount(col, 1); // all cols after } } } } if (nCols > 1 && topN > 1) { //Calculate top N accuracy //TODO: this could be more efficient INDArray realOutcomeIndex = Nd4j.argMax(labels2d, 1); val nExamples = realOutcomeIndex.length(); for (int i = 0; i < nExamples; i++) { int labelIdx = (int) realOutcomeIndex.getDouble(i); double prob = predictions2d.getDouble(i, labelIdx); INDArray row = predictions2d.getRow(i); int countGreaterThan = (int) Nd4j.getExecutioner() .exec(new MatchCondition(row, Conditions.greaterThan(prob))) .getDouble(0); if (countGreaterThan < topN) { //For example, for top 3 accuracy: can have at most 2 other probabilities larger topNCorrectCount++; } topNTotalCount++; } } } /** * Evaluate a single prediction (one prediction at a time) * * @param predictedIdx Index of class predicted by the network * @param actualIdx Index of actual class */ public void eval(int predictedIdx, int actualIdx) { // Add the number of rows to numRowCounter numRowCounter++; // If confusion is null, then Evaluation is instantiated without providing the classes if (confusion == null) { throw new UnsupportedOperationException( "Cannot evaluate single example without initializing confusion matrix first"); } addToConfusion(actualIdx, predictedIdx); // If they are equal if (predictedIdx == actualIdx) { // Then add 1 to True Positive // (For a particular label) incrementTruePositives(predictedIdx); // And add 1 for each negative class that is accurately predicted (True Negative) //(For a particular label) for (Integer clazz : confusion().getClasses()) { if (clazz != predictedIdx) trueNegatives.incrementCount(clazz, 1.0f); } } else { // Otherwise the real label is predicted as negative (False 
Negative) incrementFalseNegatives(actualIdx); // Otherwise the prediction is predicted as falsely positive (False Positive) incrementFalsePositives(predictedIdx); // Otherwise true negatives for (Integer clazz : confusion().getClasses()) { if (clazz != predictedIdx && clazz != actualIdx) trueNegatives.incrementCount(clazz, 1.0f); } } } /** * Report the classification statistics as a String * @return Classification statistics as a String */ public String stats() { return stats(false); } /** * Method to obtain the classification report as a String * * @param suppressWarnings whether or not to output warnings related to the evaluation results * @return A (multi-line) String with accuracy, precision, recall, f1 score etc */ public String stats(boolean suppressWarnings) { return stats(suppressWarnings, numClasses() <= CONFUSION_PRINT_MAX_CLASSES, numClasses() > CONFUSION_PRINT_MAX_CLASSES); } /** * Method to obtain the classification report as a String * * @param suppressWarnings whether or not to output warnings related to the evaluation results * @param includeConfusion whether the confusion matrix should be included it the returned stats or not * @return A (multi-line) String with accuracy, precision, recall, f1 score etc */ public String stats(boolean suppressWarnings, boolean includeConfusion){ return stats(suppressWarnings, includeConfusion, false); } private String stats(boolean suppressWarnings, boolean includeConfusion, boolean logConfusionSizeWarning){ if(numRowCounter == 0){ return "Evaluation: No data available (no evaluation has been performed)"; } StringBuilder builder = new StringBuilder().append("\n"); StringBuilder warnings = new StringBuilder(); ConfusionMatrix confusion = confusion(); if(confusion == null){ confusion = new ConfusionMatrix<>(); //Empty } List classes = confusion.getClasses(); List falsePositivesWarningClasses = new ArrayList<>(); List falseNegativesWarningClasses = new ArrayList<>(); for (Integer clazz : classes) { //Output possible 
warnings regarding precision/recall calculation if (!suppressWarnings && truePositives.getCount(clazz) == 0) { if (falsePositives.getCount(clazz) == 0) { falsePositivesWarningClasses.add(clazz); } if (falseNegatives.getCount(clazz) == 0) { falseNegativesWarningClasses.add(clazz); } } } if (!falsePositivesWarningClasses.isEmpty()) { warningHelper(warnings, falsePositivesWarningClasses, "precision"); } if (!falseNegativesWarningClasses.isEmpty()) { warningHelper(warnings, falseNegativesWarningClasses, "recall"); } int nClasses = confusion.getClasses().size(); DecimalFormat df = new DecimalFormat("0.0000"); double acc = accuracy(); double precision = precision(); //Macro precision for N>2, or binary class 1 (only) precision by default double recall = recall(); //Macro recall for N>2, or binary class 1 (only) precision by default double f1 = f1(); //Macro F1 for N>2, or binary class 1 (only) precision by default builder.append("\n========================Evaluation Metrics========================"); builder.append("\n # of classes: ").append(nClasses); builder.append("\n Accuracy: ").append(format(df, acc)); if (topN > 1) { double topNAcc = topNAccuracy(); builder.append("\n Top ").append(topN).append(" Accuracy: ").append(format(df, topNAcc)); } builder.append("\n Precision: ").append(format(df, precision)); if (nClasses > 2 && averagePrecisionNumClassesExcluded() > 0) { int ex = averagePrecisionNumClassesExcluded(); builder.append("\t(").append(ex).append(" class"); if (ex > 1) builder.append("es"); builder.append(" excluded from average)"); } builder.append("\n Recall: ").append(format(df, recall)); if (nClasses > 2 && averageRecallNumClassesExcluded() > 0) { int ex = averageRecallNumClassesExcluded(); builder.append("\t(").append(ex).append(" class"); if (ex > 1) builder.append("es"); builder.append(" excluded from average)"); } builder.append("\n F1 Score: ").append(format(df, f1)); if (nClasses > 2 && averageF1NumClassesExcluded() > 0) { int ex = 
averageF1NumClassesExcluded(); builder.append("\t(").append(ex).append(" class"); if (ex > 1) builder.append("es"); builder.append(" excluded from average)"); } if (nClasses > 2 || binaryPositiveClass == null) { builder.append("\nPrecision, recall & F1: macro-averaged (equally weighted avg. of ").append(nClasses) .append(" classes)"); } if(nClasses == 2 && binaryPositiveClass != null){ builder.append("\nPrecision, recall & F1: reported for positive class (class ").append(binaryPositiveClass); if(labelsList != null){ builder.append(" - \"").append(labelsList.get(binaryPositiveClass)).append("\""); } builder.append(") only"); } if (binaryDecisionThreshold != null) { builder.append("\nBinary decision threshold: ").append(binaryDecisionThreshold); } if (costArray != null) { builder.append("\nCost array: ").append(Arrays.toString(costArray.dup().data().asFloat())); } //Note that we could report micro-averaged too - but these are the same as accuracy //"Note that for “micro-averaging in a multiclass setting with all labels included will produce equal precision, recall and F," //http://scikit-learn.org/stable/modules/model_evaluation.html builder.append("\n\n"); builder.append(warnings); if(includeConfusion){ builder.append("\n=========================Confusion Matrix=========================\n"); builder.append(confusionMatrix()); } else if(logConfusionSizeWarning){ builder.append("\n\nNote: Confusion matrix not generated due to space requirements for ") .append(nClasses).append(" classes.\n") .append("Use stats(false,true) to generate anyway"); } builder.append("\n=================================================================="); return builder.toString(); } /** * Get the confusion matrix as a String * @return Confusion matrix as a String */ public String confusionMatrix(){ int nClasses = numClasses(); if(confusion == null){ return "Confusion matrix: "; } //First: work out the maximum count List classes = confusion.getClasses(); int maxCount = 1; for (Integer i : 
classes) { for (Integer j : classes) { int count = confusion().getCount(i, j); maxCount = Math.max(maxCount, count); } } maxCount = Math.max(maxCount, nClasses); //Include this as header might be bigger than actual values int numDigits = (int)Math.ceil(Math.log10(maxCount)); if(numDigits < 1) numDigits = 1; String digitFormat = "%" + (numDigits+1) + "d"; StringBuilder sb = new StringBuilder(); //Build header: for( int i=0; i clazz) return labelsList.get(clazz); return clazz.toString(); } private void warningHelper(StringBuilder warnings, List list, String metric) { warnings.append("Warning: ").append(list.size()).append(" class"); String wasWere; if (list.size() == 1) { wasWere = "was"; } else { wasWere = "were"; warnings.append("es"); } warnings.append(" ").append(wasWere); warnings.append(" never predicted by the model and ").append(wasWere).append(" excluded from average ") .append(metric); if(list.size() <= maxWarningClassesToPrint) { warnings.append("\nClasses excluded from average ").append(metric).append(": ") .append(list).append("\n"); } } /** * Returns the precision for a given class label * * @param classLabel the label * @return the precision for the label */ public double precision(Integer classLabel) { return precision(classLabel, DEFAULT_EDGE_VALUE); } /** * Returns the precision for a given label * * @param classLabel the label * @param edgeCase What to output in case of 0/0 * @return the precision for the label */ public double precision(Integer classLabel, double edgeCase) { Preconditions.checkState(numRowCounter > 0, "Cannot get precision: no evaluation has been performed"); double tpCount = truePositives.getCount(classLabel); double fpCount = falsePositives.getCount(classLabel); return EvaluationUtils.precision((long) tpCount, (long) fpCount, edgeCase); } /** * Precision based on guesses so far.
* Note: value returned will differ depending on number of classes and settings.
* 1. For binary classification, if the positive class is set (via default value of 1, via constructor, * or via {@link #setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class * only.
* 2. For the multi-class case, or when {@link #getBinaryPositiveClass()} is null, the returned value is macro-averaged * across all classes. i.e., is macro-averaged precision, equivalent to {@code precision(EvaluationAveraging.Macro)}
* * @return the total precision based on guesses so far */ public double precision() { if(binaryPositiveClass != null && numClasses() == 2){ return precision(binaryPositiveClass); } return precision(EvaluationAveraging.Macro); } /** * Calculate the average precision for all classes. Can specify whether macro or micro averaging should be used * NOTE: if any classes have tp=0 and fp=0, (precision=0/0) these are excluded from the average * * @param averaging Averaging method - macro or micro * @return Average precision */ public double precision(EvaluationAveraging averaging) { Preconditions.checkState(numRowCounter > 0, "Cannot get precision: no evaluation has been performed"); int nClasses = confusion().getClasses().size(); if (averaging == EvaluationAveraging.Macro) { double macroPrecision = 0.0; int count = 0; for (int i = 0; i < nClasses; i++) { double thisClassPrec = precision(i, -1); if (thisClassPrec != -1) { macroPrecision += thisClassPrec; count++; } } macroPrecision /= count; return macroPrecision; } else if (averaging == EvaluationAveraging.Micro) { long tpCount = 0; long fpCount = 0; for (int i = 0; i < nClasses; i++) { tpCount += truePositives.getCount(i); fpCount += falsePositives.getCount(i); } return EvaluationUtils.precision(tpCount, fpCount, DEFAULT_EDGE_VALUE); } else { throw new UnsupportedOperationException("Unknown averaging approach: " + averaging); } } /** * When calculating the (macro) average precision, how many classes are excluded from the average due to * no predictions - i.e., precision would be the edge case of 0/0 * * @return Number of classes excluded from the average precision */ public int averagePrecisionNumClassesExcluded() { return numClassesExcluded("precision"); } /** * When calculating the (macro) average Recall, how many classes are excluded from the average due to * no predictions - i.e., recall would be the edge case of 0/0 * * @return Number of classes excluded from the average recall */ public int 
averageRecallNumClassesExcluded() { return numClassesExcluded("recall"); } /** * When calculating the (macro) average F1, how many classes are excluded from the average due to * no predictions - i.e., F1 would be calculated from a precision or recall of 0/0 * * @return Number of classes excluded from the average F1 */ public int averageF1NumClassesExcluded() { return numClassesExcluded("f1"); } /** * When calculating the (macro) average FBeta, how many classes are excluded from the average due to * no predictions - i.e., FBeta would be calculated from a precision or recall of 0/0 * * @return Number of classes excluded from the average FBeta */ public int averageFBetaNumClassesExcluded() { return numClassesExcluded("fbeta"); } private int numClassesExcluded(String metric) { int countExcluded = 0; int nClasses = confusion().getClasses().size(); for (int i = 0; i < nClasses; i++) { double d; switch (metric.toLowerCase()) { case "precision": d = precision(i, -1); break; case "recall": d = recall(i, -1); break; case "f1": case "fbeta": d = fBeta(1.0, i, -1); break; default: throw new RuntimeException("Unknown metric: " + metric); } if (d == -1) { countExcluded++; } } return countExcluded; } /** * Returns the recall for a given label * * @param classLabel the label * @return Recall rate as a double */ public double recall(int classLabel) { return recall(classLabel, DEFAULT_EDGE_VALUE); } /** * Returns the recall for a given label * * @param classLabel the label * @param edgeCase What to output in case of 0/0 * @return Recall rate as a double */ public double recall(int classLabel, double edgeCase) { Preconditions.checkState(numRowCounter > 0, "Cannot get recall: no evaluation has been performed"); double tpCount = truePositives.getCount(classLabel); double fnCount = falseNegatives.getCount(classLabel); return EvaluationUtils.recall((long) tpCount, (long) fnCount, edgeCase); } /** * Recall based on guesses so far
* Note: value returned will differ depending on number of classes and settings.
* 1. For binary classification, if the positive class is set (via default value of 1, via constructor, * or via {@link #setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class * only.
* 2. For the multi-class case, or when {@link #getBinaryPositiveClass()} is null, the returned value is macro-averaged * across all classes. i.e., is macro-averaged recall, equivalent to {@code recall(EvaluationAveraging.Macro)}
*
     * @return the recall for the outcomes
     */
    public double recall() {
        final boolean positiveClassOnly = binaryPositiveClass != null && numClasses() == 2;
        return positiveClassOnly ? recall(binaryPositiveClass) : recall(EvaluationAveraging.Macro);
    }

    /**
     * Average recall over all classes, using either macro or micro averaging.
     * For macro averaging, classes whose recall is undefined (tp=0 and fn=0, i.e. 0/0) are left out
     * of the average.
     *
     * @param averaging Averaging method - macro or micro
     * @return Average recall
     */
    public double recall(EvaluationAveraging averaging) {
        Preconditions.checkState(numRowCounter > 0, "Cannot get recall: no evaluation has been performed");
        final int classTotal = confusion().getClasses().size();

        if (averaging == EvaluationAveraging.Micro) {
            // Micro: pool the raw counts of every class, then compute a single recall value
            long pooledTp = 0;
            long pooledFn = 0;
            for (int cls = 0; cls < classTotal; cls++) {
                pooledTp += truePositives.getCount(cls);
                pooledFn += falseNegatives.getCount(cls);
            }
            return EvaluationUtils.recall(pooledTp, pooledFn, DEFAULT_EDGE_VALUE);
        }
        if (averaging != EvaluationAveraging.Macro) {
            throw new UnsupportedOperationException("Unknown averaging approach: " + averaging);
        }

        // Macro: mean of the per-class values; -1 is the marker requested for the undefined (0/0) case
        double sum = 0.0;
        int included = 0;
        for (int cls = 0; cls < classTotal; cls++) {
            final double classRecall = recall(cls, -1);
            if (classRecall == -1) {
                continue;
            }
            sum += classRecall;
            included++;
        }
        return sum / included;
    }

    /**
     * False positive rate for a single class, using the default edge-case value for 0/0.
     *
     * @param classLabel the label
     * @return fpr as a double
     */
    public double falsePositiveRate(int classLabel) {
        return falsePositiveRate(classLabel, DEFAULT_EDGE_VALUE);
    }

    /**
     * False positive rate for a single class, computed from that class's false positive and
     * true negative counts.
     *
     * @param classLabel the label
     * @param edgeCase   What to output in case of 0/0
     * @return fpr as a double
     */
    public double falsePositiveRate(int classLabel, double edgeCase) {
        Preconditions.checkState(numRowCounter > 0, "Cannot get false positive rate: no evaluation has been performed");
        final long fp = (long) falsePositives.getCount(classLabel);
        final long tn = (long) trueNegatives.getCount(classLabel);
        return EvaluationUtils.falsePositiveRate(fp, tn, edgeCase);
    }

    /**
     * False positive rate based on guesses so far
* Note: value returned will differ depending on number of classes and settings.
* 1. For binary classification, if the positive class is set (via default value of 1, via constructor, * or via {@link #setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class * only.
* 2. For the multi-class case, or when {@link #getBinaryPositiveClass()} is null, the returned value is macro-averaged * across all classes. i.e., is macro-averaged false positive rate, equivalent to * {@code falsePositiveRate(EvaluationAveraging.Macro)}
*
     * @return the fpr for the outcomes
     */
    public double falsePositiveRate() {
        final boolean positiveClassOnly = binaryPositiveClass != null && numClasses() == 2;
        return positiveClassOnly
                ? falsePositiveRate(binaryPositiveClass)
                : falsePositiveRate(EvaluationAveraging.Macro);
    }

    /**
     * Average false positive rate over all classes, using either macro or micro averaging.
     * Unlike precision/recall, the macro average here always divides by the full number of classes.
     *
     * @param averaging Averaging method - macro or micro
     * @return Average false positive rate
     */
    public double falsePositiveRate(EvaluationAveraging averaging) {
        Preconditions.checkState(numRowCounter > 0, "Cannot get false positive rate: no evaluation has been performed");
        final int classTotal = confusion().getClasses().size();

        if (averaging == EvaluationAveraging.Micro) {
            // Micro: pool the raw counts of every class, then compute a single rate
            long pooledFp = 0;
            long pooledTn = 0;
            for (int cls = 0; cls < classTotal; cls++) {
                pooledFp += falsePositives.getCount(cls);
                pooledTn += trueNegatives.getCount(cls);
            }
            return EvaluationUtils.falsePositiveRate(pooledFp, pooledTn, DEFAULT_EDGE_VALUE);
        }
        if (averaging != EvaluationAveraging.Macro) {
            throw new UnsupportedOperationException("Unknown averaging approach: " + averaging);
        }

        // Macro: plain mean of the per-class rates
        double sum = 0.0;
        for (int cls = 0; cls < classTotal; cls++) {
            sum += falsePositiveRate(cls);
        }
        return sum / classTotal;
    }

    /**
     * False negative rate for a single class, using the default edge-case value for 0/0.
     *
     * @param classLabel the label
     * @return fnr as a double
     */
    public double falseNegativeRate(Integer classLabel) {
        return falseNegativeRate(classLabel, DEFAULT_EDGE_VALUE);
    }

    /**
     * False negative rate for a single class, computed from that class's false negative and
     * true positive counts.
     *
     * @param classLabel the label
     * @param edgeCase   What to output in case of 0/0
     * @return fnr as a double
     */
    public double falseNegativeRate(Integer classLabel, double edgeCase) {
        Preconditions.checkState(numRowCounter > 0, "Cannot get false negative rate: no evaluation has been performed");
        final long fn = (long) falseNegatives.getCount(classLabel);
        final long tp = (long) truePositives.getCount(classLabel);
        return EvaluationUtils.falseNegativeRate(fn, tp, edgeCase);
    }

    /**
     * False negative rate based on guesses so far
     * Note: value returned will differ depending on number of classes and settings.
* 1. For binary classification, if the positive class is set (via default value of 1, via constructor, * or via {@link #setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class * only.
* 2. For the multi-class case, or when {@link #getBinaryPositiveClass()} is null, the returned value is macro-averaged * across all classes. i.e., is macro-averaged false negative rate, equivalent to * {@code falseNegativeRate(EvaluationAveraging.Macro)}
*
     * @return the fnr for the outcomes
     */
    public double falseNegativeRate() {
        if (binaryPositiveClass != null && numClasses() == 2) {
            return falseNegativeRate(binaryPositiveClass);
        }
        return falseNegativeRate(EvaluationAveraging.Macro);
    }

    /**
     * Calculate the average false negative rate for all classes - can specify whether macro or micro averaging
     * should be used.
     * Macro: mean of the per-class false negative rates (divided by the full number of classes).
     * Micro: false negative and true positive counts are summed over all classes and a single rate,
     * FN / (FN + TP), is computed from the totals.
     *
     * @param averaging Averaging method - macro or micro
     * @return Average false negative rate
     * @throws UnsupportedOperationException if the averaging method is neither macro nor micro
     */
    public double falseNegativeRate(EvaluationAveraging averaging) {
        Preconditions.checkState(numRowCounter > 0, "Cannot get false negative rate: no evaluation has been performed");
        int nClasses = confusion().getClasses().size();
        if (averaging == EvaluationAveraging.Macro) {
            double macroFNR = 0.0;
            for (int i = 0; i < nClasses; i++) {
                macroFNR += falseNegativeRate(i);
            }
            macroFNR /= nClasses;
            return macroFNR;
        } else if (averaging == EvaluationAveraging.Micro) {
            long fnCount = 0;
            long tpCount = 0;
            for (int i = 0; i < nClasses; i++) {
                fnCount += falseNegatives.getCount(i);
                // The denominator of FNR is FN + TP: accumulate true POSITIVES here, matching the
                // per-class falseNegativeRate(classLabel, edgeCase). Summing true negatives (as was
                // done previously) yields FN / (FN + TN), which is not the false negative rate.
                tpCount += truePositives.getCount(i);
            }
            return EvaluationUtils.falseNegativeRate(fnCount, tpCount, DEFAULT_EDGE_VALUE);
        } else {
            throw new UnsupportedOperationException("Unknown averaging approach: " + averaging);
        }
    }

    /**
     * False Alarm Rate (FAR) reflects rate of misclassified to classified records
     * http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw
* Note: value returned will differ depending on number of classes and settings.
* 1. For binary classification, if the positive class is set (via default value of 1, via constructor, * or via {@link #setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class * only.
* 2. For the multi-class case, or when {@link #getBinaryPositiveClass()} is null, the returned value is macro-averaged * across all classes. i.e., is macro-averaged false alarm rate) * * @return the fpr for the outcomes */ public double falseAlarmRate() { if(binaryPositiveClass != null && numClasses() == 2){ return (falsePositiveRate(binaryPositiveClass) + falseNegativeRate(binaryPositiveClass)) / 2.0; } return (falsePositiveRate() + falseNegativeRate()) / 2.0; } /** * Calculate f1 score for a given class * * @param classLabel the label to calculate f1 for * @return the f1 score for the given label */ public double f1(int classLabel) { return fBeta(1.0, classLabel); } /** * Calculate the f_beta for a given class, where f_beta is defined as:
* (1+beta^2) * (precision * recall) / (beta^2 * precision + recall).
* F1 is a special case of f_beta, with beta=1.0 * * @param beta Beta value to use * @param classLabel Class label * @return F_beta */ public double fBeta(double beta, int classLabel) { return fBeta(beta, classLabel, 0.0); } /** * Calculate the f_beta for a given class, where f_beta is defined as:
* (1+beta^2) * (precision * recall) / (beta^2 * precision + recall).
* F1 is a special case of f_beta, with beta=1.0 * * @param beta Beta value to use * @param classLabel Class label * @param defaultValue Default value to use when precision or recall is undefined (0/0 for prec. or recall) * @return F_beta */ public double fBeta(double beta, int classLabel, double defaultValue) { Preconditions.checkState(numRowCounter > 0, "Cannot get fBeta score: no evaluation has been performed"); double precision = precision(classLabel, -1); double recall = recall(classLabel, -1); if (precision == -1 || recall == -1) { return defaultValue; } return EvaluationUtils.fBeta(beta, precision, recall); } /** * Calculate the F1 score
* F1 score is defined as:
* TP: true positive
* FP: False Positive
* FN: False Negative
* F1 score: 2 * TP / (2TP + FP + FN)
*
* Note: value returned will differ depending on number of classes and settings.
* 1. For binary classification, if the positive class is set (via default value of 1, via constructor, * or via {@link #setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class * only.
* 2. For the multi-class case, or when {@link #getBinaryPositiveClass()} is null, the returned value is macro-averaged * across all classes. i.e., is macro-averaged f1, equivalent to {@code f1(EvaluationAveraging.Macro)}
* * @return the f1 score or harmonic mean of precision and recall based on current guesses */ public double f1() { if(binaryPositiveClass != null && numClasses() == 2){ return f1(binaryPositiveClass); } return f1(EvaluationAveraging.Macro); } /** * Calculate the average F1 score across all classes, using macro or micro averaging * * @param averaging Averaging method to use */ public double f1(EvaluationAveraging averaging) { return fBeta(1.0, averaging); } /** * Calculate the average F_beta score across all classes, using macro or micro averaging * * @param beta Beta value to use * @param averaging Averaging method to use */ public double fBeta(double beta, EvaluationAveraging averaging) { Preconditions.checkState(numRowCounter > 0, "Cannot get fBeta score: no evaluation has been performed"); int nClasses = confusion().getClasses().size(); if (nClasses == 2) { return EvaluationUtils.fBeta(beta, (long) truePositives.getCount(1), (long) falsePositives.getCount(1), (long) falseNegatives.getCount(1)); } if (averaging == EvaluationAveraging.Macro) { double macroFBeta = 0.0; int count = 0; for (int i = 0; i < nClasses; i++) { double thisFBeta = fBeta(beta, i, -1); if (thisFBeta != -1) { macroFBeta += thisFBeta; count++; } } macroFBeta /= count; return macroFBeta; } else if (averaging == EvaluationAveraging.Micro) { long tpCount = 0; long fpCount = 0; long fnCount = 0; for (int i = 0; i < nClasses; i++) { tpCount += truePositives.getCount(i); fpCount += falsePositives.getCount(i); fnCount += falseNegatives.getCount(i); } return EvaluationUtils.fBeta(beta, tpCount, fpCount, fnCount); } else { throw new UnsupportedOperationException("Unknown averaging approach: " + averaging); } } /** * Calculate the G-measure for the given output * * @param output The specified output * @return The G-measure for the specified output */ public double gMeasure(int output) { Preconditions.checkState(numRowCounter > 0, "Cannot get gMeasure: no evaluation has been performed"); double precision 
= precision(output); double recall = recall(output); return EvaluationUtils.gMeasure(precision, recall); } /** * Calculates the average G measure for all outputs using micro or macro averaging * * @param averaging Averaging method to use * @return Average G measure */ public double gMeasure(EvaluationAveraging averaging) { Preconditions.checkState(numRowCounter > 0, "Cannot get gMeasure: no evaluation has been performed"); int nClasses = confusion().getClasses().size(); if (averaging == EvaluationAveraging.Macro) { double macroGMeasure = 0.0; for (int i = 0; i < nClasses; i++) { macroGMeasure += gMeasure(i); } macroGMeasure /= nClasses; return macroGMeasure; } else if (averaging == EvaluationAveraging.Micro) { long tpCount = 0; long fpCount = 0; long fnCount = 0; for (int i = 0; i < nClasses; i++) { tpCount += truePositives.getCount(i); fpCount += falsePositives.getCount(i); fnCount += falseNegatives.getCount(i); } double precision = EvaluationUtils.precision(tpCount, fpCount, DEFAULT_EDGE_VALUE); double recall = EvaluationUtils.recall(tpCount, fnCount, DEFAULT_EDGE_VALUE); return EvaluationUtils.gMeasure(precision, recall); } else { throw new UnsupportedOperationException("Unknown averaging approach: " + averaging); } } /** * Accuracy: * (TP + TN) / (P + N) * * @return the accuracy of the guesses so far */ public double accuracy() { Preconditions.checkState(numRowCounter > 0, "Cannot get accuracy: no evaluation has been performed"); //Accuracy: sum the counts on the diagonal of the confusion matrix, divide by total int nClasses = confusion().getClasses().size(); int countCorrect = 0; for (int i = 0; i < nClasses; i++) { countCorrect += confusion().getCount(i, i); } return countCorrect / (double) getNumRowCounter(); } /** * Top N accuracy of the predictions so far. 
For top N = 1 (default), equivalent to {@link #accuracy()}
     *
     * @return Top N accuracy; 0.0 when top N is greater than 1 and no top N evaluations have been counted
     */
    public double topNAccuracy() {
        if (topN <= 1) {
            // Top-1 accuracy is ordinary accuracy
            return accuracy();
        }
        return topNTotalCount == 0 ? 0.0 : (double) topNCorrectCount / topNTotalCount;
    }

    /**
     * Calculate the binary Mathews correlation coefficient, for the specified class.
* MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))
* * @param classIdx Class index to calculate Matthews correlation coefficient for */ public double matthewsCorrelation(int classIdx) { Preconditions.checkState(numRowCounter > 0, "Cannot get Matthews correlation: no evaluation has been performed"); return EvaluationUtils.matthewsCorrelation((long) truePositives.getCount(classIdx), (long) falsePositives.getCount(classIdx), (long) falseNegatives.getCount(classIdx), (long) trueNegatives.getCount(classIdx)); } /** * Calculate the average binary Mathews correlation coefficient, using macro or micro averaging.
* MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))
* Note: This is NOT the same as the multi-class Matthews correlation coefficient * * @param averaging Averaging approach * @return Average */ public double matthewsCorrelation(EvaluationAveraging averaging) { Preconditions.checkState(numRowCounter > 0, "Cannot get Matthews correlation: no evaluation has been performed"); int nClasses = confusion().getClasses().size(); if (averaging == EvaluationAveraging.Macro) { double macroMatthewsCorrelation = 0.0; for (int i = 0; i < nClasses; i++) { macroMatthewsCorrelation += matthewsCorrelation(i); } macroMatthewsCorrelation /= nClasses; return macroMatthewsCorrelation; } else if (averaging == EvaluationAveraging.Micro) { long tpCount = 0; long fpCount = 0; long fnCount = 0; long tnCount = 0; for (int i = 0; i < nClasses; i++) { tpCount += truePositives.getCount(i); fpCount += falsePositives.getCount(i); fnCount += falseNegatives.getCount(i); tnCount += trueNegatives.getCount(i); } return EvaluationUtils.matthewsCorrelation(tpCount, fpCount, fnCount, tnCount); } else { throw new UnsupportedOperationException("Unknown averaging approach: " + averaging); } } /** * True positives: correctly rejected * * @return the total true positives so far */ public Map truePositives() { return convertToMap(truePositives, confusion().getClasses().size()); } /** * True negatives: correctly rejected * * @return the total true negatives so far */ public Map trueNegatives() { return convertToMap(trueNegatives, confusion().getClasses().size()); } /** * False positive: wrong guess * * @return the count of the false positives */ public Map falsePositives() { return convertToMap(falsePositives, confusion().getClasses().size()); } /** * False negatives: correctly rejected * * @return the total false negatives so far */ public Map falseNegatives() { return convertToMap(falseNegatives, confusion().getClasses().size()); } /** * Total negatives true negatives + false negatives * * @return the overall negative count */ public Map negative() { return 
addMapsByKey(trueNegatives(), falsePositives()); } /** * Returns all of the positive guesses: * true positive + false negative */ public Map positive() { return addMapsByKey(truePositives(), falseNegatives()); } private Map convertToMap(Counter counter, int maxCount) { Map map = new HashMap<>(); for (int i = 0; i < maxCount; i++) { map.put(i, (int) counter.getCount(i)); } return map; } private Map addMapsByKey(Map first, Map second) { Map out = new HashMap<>(); Set keys = new HashSet<>(first.keySet()); keys.addAll(second.keySet()); for (Integer i : keys) { Integer f = first.get(i); Integer s = second.get(i); if (f == null) f = 0; if (s == null) s = 0; out.put(i, f + s); } return out; } // Incrementing counters public void incrementTruePositives(Integer classLabel) { truePositives.incrementCount(classLabel, 1.0f); } public void incrementTrueNegatives(Integer classLabel) { trueNegatives.incrementCount(classLabel, 1.0f); } public void incrementFalseNegatives(Integer classLabel) { falseNegatives.incrementCount(classLabel, 1.0f); } public void incrementFalsePositives(Integer classLabel) { falsePositives.incrementCount(classLabel, 1.0f); } // Other misc methods /** * Adds to the confusion matrix * * @param real the actual guess * @param guess the system guess */ public void addToConfusion(Integer real, Integer guess) { confusion().add(real, guess); } /** * Returns the number of times the given label * has actually occurred * * @param clazz the label * @return the number of times the label * actually occurred */ public int classCount(Integer clazz) { return confusion().getActualTotal(clazz); } public int getNumRowCounter() { return numRowCounter; } /** * Return the number of correct predictions according to top N value. 
For top N = 1 (default) this is equivalent to
     * the number of correct predictions
     *
     * @return Number of correct top N predictions; 0 if no confusion matrix exists yet
     */
    public int getTopNCorrectCount() {
        if (confusion == null)
            return 0;
        if (topN <= 1) {
            // Top-1: correct predictions are the diagonal of the confusion matrix
            int nClasses = confusion().getClasses().size();
            int countCorrect = 0;
            for (int i = 0; i < nClasses; i++) {
                countCorrect += confusion().getCount(i, i);
            }
            return countCorrect;
        }
        return topNCorrectCount;
    }

    /**
     * Return the total number of top N evaluations. Most of the time, this is exactly equal to {@link #getNumRowCounter()},
     * but may differ in the case of using {@link #eval(int, int)} as top N accuracy cannot be calculated in that case
     * (i.e., requires the full probability distribution, not just predicted/actual indices)
     *
     * @return Total number of top N predictions
     */
    public int getTopNTotalCount() {
        if (topN <= 1) {
            return getNumRowCounter();
        }
        return topNTotalCount;
    }

    /**
     * Returns the label used for the given class index, as produced by resolveLabelForClass.
     *
     * @param clazz class index
     * @return label for the class
     */
    public String getClassLabel(Integer clazz) {
        return resolveLabelForClass(clazz);
    }

    /**
     * Returns the confusion matrix variable
     *
     * @return confusion matrix variable for this evaluation (the field is returned as-is; it may be null
     * if no confusion matrix has been created yet)
     */
    public ConfusionMatrix<Integer> getConfusionMatrix() {
        return confusion;
    }

    /**
     * Merge the other evaluation object into this one. The result is that this Evaluation instance contains the counts
     * etc from both
     *
     * @param other Evaluation object to merge into this one.
*/ @Override public void merge(Evaluation other) { if (other == null) return; truePositives.incrementAll(other.truePositives); falsePositives.incrementAll(other.falsePositives); trueNegatives.incrementAll(other.trueNegatives); falseNegatives.incrementAll(other.falseNegatives); if (confusion == null) { if (other.confusion != null) confusion = new ConfusionMatrix<>(other.confusion); } else { if (other.confusion != null) confusion().add(other.confusion); } numRowCounter += other.numRowCounter; if (labelsList.isEmpty()) labelsList.addAll(other.labelsList); if (topN != other.topN) { log.warn("Different topN values ({} vs {}) detected during Evaluation merging. Top N accuracy may not be accurate.", topN, other.topN); } this.topNCorrectCount += other.topNCorrectCount; this.topNTotalCount += other.topNTotalCount; } /** * Get a String representation of the confusion matrix */ public String confusionToString() { int nClasses = confusion().getClasses().size(); //First: work out the longest label size int maxLabelSize = 0; for (String s : labelsList) { maxLabelSize = Math.max(maxLabelSize, s.length()); } //Build the formatting for the rows: int labelSize = Math.max(maxLabelSize + 5, 10); StringBuilder sb = new StringBuilder(); sb.append("%-3d"); sb.append("%-"); sb.append(labelSize); sb.append("s | "); StringBuilder headerFormat = new StringBuilder(); headerFormat.append(" %-").append(labelSize).append("s "); for (int i = 0; i < nClasses; i++) { sb.append("%7d"); headerFormat.append("%7d"); } String rowFormat = sb.toString(); StringBuilder out = new StringBuilder(); //First: header row Object[] headerArgs = new Object[nClasses + 1]; headerArgs[0] = "Predicted:"; for (int i = 0; i < nClasses; i++) headerArgs[i + 1] = i; out.append(String.format(headerFormat.toString(), headerArgs)).append("\n"); //Second: divider rows out.append(" Actual:\n"); //Finally: data rows for (int i = 0; i < nClasses; i++) { Object[] args = new Object[nClasses + 2]; args[0] = i; args[1] = 
labelsList.get(i); for (int j = 0; j < nClasses; j++) { args[j + 2] = confusion().getCount(i, j); } out.append(String.format(rowFormat, args)); out.append("\n"); } return out.toString(); } private void addToMetaConfusionMatrix(int actual, int predicted, Object metaData) { if (confusionMatrixMetaData == null) { confusionMatrixMetaData = new HashMap<>(); } Pair p = new Pair<>(actual, predicted); List list = confusionMatrixMetaData.get(p); if (list == null) { list = new ArrayList<>(); confusionMatrixMetaData.put(p, list); } list.add(metaData); } /** * Get a list of prediction errors, on a per-record basis
*

* Note: Prediction errors are ONLY available if the "evaluate with metadata" method is used: {@link #eval(INDArray, INDArray, List)} * Otherwise (if the metadata hasn't been recorded via that previously mentioned eval method), there is no value in * splitting each prediction out into a separate Prediction object - instead, use the confusion matrix to get the counts, * via {@link #getConfusionMatrix()} * * @return A list of prediction errors, or null if no metadata has been recorded */ public List getPredictionErrors() { if (this.confusionMatrixMetaData == null) return null; List list = new ArrayList<>(); List, List>> sorted = new ArrayList<>(confusionMatrixMetaData.entrySet()); Collections.sort(sorted, new Comparator, List>>() { @Override public int compare(Map.Entry, List> o1, Map.Entry, List> o2) { Pair p1 = o1.getKey(); Pair p2 = o2.getKey(); int order = Integer.compare(p1.getFirst(), p2.getFirst()); if (order != 0) return order; order = Integer.compare(p1.getSecond(), p2.getSecond()); return order; } }); for (Map.Entry, List> entry : sorted) { Pair p = entry.getKey(); if (p.getFirst().equals(p.getSecond())) { //predicted = actual -> not an error -> skip continue; } for (Object m : entry.getValue()) { list.add(new Prediction(p.getFirst(), p.getSecond(), m)); } } return list; } /** * Get a list of predictions, for all data with the specified actual class, regardless of the predicted * class. *

* Note: Prediction errors are ONLY available if the "evaluate with metadata" method is used: {@link #eval(INDArray, INDArray, List)} * Otherwise (if the metadata hasn't been recorded via that previously mentioned eval method), there is no value in * splitting each prediction out into a separate Prediction object - instead, use the confusion matrix to get the counts, * via {@link #getConfusionMatrix()} * * @param actualClass Actual class to get predictions for * @return List of predictions, or null if the "evaluate with metadata" method was not used */ public List getPredictionsByActualClass(int actualClass) { if (confusionMatrixMetaData == null) return null; List out = new ArrayList<>(); for (Map.Entry, List> entry : confusionMatrixMetaData.entrySet()) { //Entry Pair: (Actual,Predicted) if (entry.getKey().getFirst() == actualClass) { int actual = entry.getKey().getFirst(); int predicted = entry.getKey().getSecond(); for (Object m : entry.getValue()) { out.add(new Prediction(actual, predicted, m)); } } } return out; } /** * Get a list of predictions, for all data with the specified predicted class, regardless of the actual data * class. *

* Note: Prediction errors are ONLY available if the "evaluate with metadata" method is used: {@link #eval(INDArray, INDArray, List)} * Otherwise (if the metadata hasn't been recorded via that previously mentioned eval method), there is no value in * splitting each prediction out into a separate Prediction object - instead, use the confusion matrix to get the counts, * via {@link #getConfusionMatrix()} * * @param predictedClass Actual class to get predictions for * @return List of predictions, or null if the "evaluate with metadata" method was not used */ public List getPredictionByPredictedClass(int predictedClass) { if (confusionMatrixMetaData == null) return null; List out = new ArrayList<>(); for (Map.Entry, List> entry : confusionMatrixMetaData.entrySet()) { //Entry Pair: (Actual,Predicted) if (entry.getKey().getSecond() == predictedClass) { int actual = entry.getKey().getFirst(); int predicted = entry.getKey().getSecond(); for (Object m : entry.getValue()) { out.add(new Prediction(actual, predicted, m)); } } } return out; } /** * Get a list of predictions in the specified confusion matrix entry (i.e., for the given actua/predicted class pair) * * @param actualClass Actual class * @param predictedClass Predicted class * @return List of predictions that match the specified actual/predicted classes, or null if the "evaluate with metadata" method was not used */ public List getPredictions(int actualClass, int predictedClass) { if (confusionMatrixMetaData == null) return null; List out = new ArrayList<>(); List list = confusionMatrixMetaData.get(new Pair<>(actualClass, predictedClass)); if (list == null) return out; for (Object meta : list) { out.add(new Prediction(actualClass, predictedClass, meta)); } return out; } public double scoreForMetric(Metric metric){ switch (metric){ case ACCURACY: return accuracy(); case F1: return f1(); case PRECISION: return precision(); case RECALL: return recall(); case GMEASURE: return gMeasure(EvaluationAveraging.Macro); case MCC: 
return matthewsCorrelation(EvaluationAveraging.Macro); default: throw new IllegalStateException("Unknown metric: " + metric); } } public static Evaluation fromJson(String json) { return fromJson(json, Evaluation.class); } public static Evaluation fromYaml(String yaml) { return fromYaml(yaml, Evaluation.class); } @Override public double getValue(IMetric metric){ if(metric instanceof Metric){ return scoreForMetric((Metric) metric); } else throw new IllegalStateException("Can't get value for non-evaluation Metric " + metric); } @Override public Evaluation newInstance() { return new Evaluation(axis, binaryPositiveClass, topN, labelsList, binaryDecisionThreshold, costArray, maxWarningClassesToPrint); } }