All downloads are FREE. Search and download functionality uses the official Maven repository.

org.deeplearning4j.eval.Evaluation Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 */

package org.deeplearning4j.eval;

import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.Serializable;
import java.text.DecimalFormat;
import java.util.*;

/**
 * Evaluation metrics:
 * precision, recall, f1
 *
 * @author Adam Gibson
 */
public class Evaluation implements Serializable {

    /** Per-class true-positive counts, keyed by class index. */
    protected Counter<Integer> truePositives = new Counter<>();
    /** Per-class false-positive counts, keyed by class index. */
    protected Counter<Integer> falsePositives = new Counter<>();
    /** Per-class true-negative counts, keyed by class index. */
    protected Counter<Integer> trueNegatives = new Counter<>();
    /** Per-class false-negative counts, keyed by class index. */
    protected Counter<Integer> falseNegatives = new Counter<>();
    /**
     * Confusion matrix over class indices. Null until the classes are known: the label-list
     * constructor creates it up front; otherwise it is created lazily on the first eval call.
     */
    protected ConfusionMatrix<Integer> confusion;
    /** Total number of example rows passed to eval so far. */
    protected int numRowCounter = 0;
    /** Display labels: element {@code i} is the name used for class index {@code i}. */
    protected List<String> labelsList = new ArrayList<>();
    protected static Logger log = LoggerFactory.getLogger(Evaluation.class);
    // Value the one-argument per-class metrics (e.g. precision(Integer)) return when their
    // ratio is 0/0.
    protected static final double DEFAULT_EDGE_VALUE = 0.0;

    /**
     * Creates an evaluation with no classes configured yet. The class labels and the
     * confusion matrix are inferred from the number of label columns on the first call to
     * {@link #eval(INDArray, INDArray)}.
     */
    public Evaluation() {
        super();
    }

    /**
     * Creates an evaluation for a fixed number of output classes, labelled
     * {@code "0"}, {@code "1"}, ... in order. A value of 1 is treated as the binary
     * (single output variable) case and produces two classes.
     *
     * @param numClasses the number of classes to account for in the evaluation
     */
    public Evaluation(int numClasses) {
        this(Evaluation.createLabels(numClasses));
    }

    /**
     * The labels to include with the evaluation.
     * This constructor can be used for generating labeled output rather than just
     * numbers for the labels; element {@code i} names class index {@code i}.
     *
     * <p>The list is copied, so later changes to {@code labels} by the caller do not affect
     * this evaluation, and {@link #merge(Evaluation)} never writes into the caller's list.
     * Passing {@code null} leaves the classes unconfigured (as with the no-argument
     * constructor) so they are inferred on the first call to {@link #eval(INDArray, INDArray)}.
     *
     * @param labels the labels to use for the output; may be {@code null}
     */
    public Evaluation(List<String> labels) {
        if (labels == null) {
            // Keep labelsList non-null so merge() and label lookups never dereference null.
            this.labelsList = new ArrayList<>();
        } else {
            // Defensive copy: merge() may append to labelsList, which would otherwise mutate
            // (or throw on an immutable) caller-supplied list.
            this.labelsList = new ArrayList<>(labels);
            createConfusion(labels.size());
        }
    }

    /**
     * Use a map to generate labels.
     * Pass in a label index with the actual label you want to use for output.
     *
     * @param labels a map of label index to label value; it must contain exactly the keys
     *               {@code 0 .. labels.size()-1}
     * @throws IllegalArgumentException if a key in that range is missing or maps to null
     */
    public Evaluation(Map<Integer, String> labels) {
        this(createLabelsFromMap(labels));
    }

    /**
     * Builds the default labels {@code "0"}, {@code "1"}, ... for the given class count.
     * A count of 1 is promoted to 2, because a single output column is the binary case.
     *
     * @param numClasses number of classes; must not be negative
     * @return a new, mutable list of labels
     * @throws IllegalArgumentException if {@code numClasses} is negative
     */
    private static List<String> createLabels(int numClasses) {
        if (numClasses < 0) {
            throw new IllegalArgumentException(
                    "Invalid number of classes: " + numClasses + " (must be >= 0)");
        }
        int n = (numClasses == 1) ? 2 : numClasses; // binary (single output variable) case
        List<String> labels = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            labels.add(String.valueOf(i));
        }
        return labels;
    }

    /**
     * Converts an index-to-label map into a label list ordered by class index.
     *
     * @param labels map whose keys must be exactly {@code 0 .. labels.size()-1}
     * @return a new, mutable list where element {@code i} is {@code labels.get(i)}
     * @throws NullPointerException     if {@code labels} is null
     * @throws IllegalArgumentException if some key in {@code 0 .. size-1} is absent or maps to null
     */
    private static List<String> createLabelsFromMap(Map<Integer, String> labels) {
        Objects.requireNonNull(labels, "labels");
        int size = labels.size();
        List<String> labelsList = new ArrayList<>(size);
        for (int i = 0; i < size; i++) {
            String str = labels.get(i);
            if (str == null) {
                throw new IllegalArgumentException("Invalid labels map: missing key for class " + i
                        + " (expect integers 0 to " + (size - 1) + ")");
            }
            labelsList.add(str);
        }
        return labelsList;
    }

    /**
     * Creates a new confusion matrix over class indices {@code 0 .. nClasses-1} and stores it
     * in {@code confusion}, replacing any existing matrix.
     *
     * @param nClasses number of classes
     */
    private void createConfusion(int nClasses) {
        List<Integer> classes = new ArrayList<>();
        for (int i = 0; i < nClasses; i++) {
            classes.add(i);
        }
        confusion = new ConfusionMatrix<>(classes);
    }


    /**
     * Evaluates a {@link ComputationGraph} on the given input against the given true labels.
     * The graph is run via {@code output(false, input)} (the {@code false} is presumably the
     * training flag, i.e. inference mode) and only the first output array is evaluated.
     *
     * @param trueLabels the labels to use
     * @param input      the input to the network to use for evaluation
     * @param network    the network to use for output
     */
    public void eval(INDArray trueLabels, INDArray input, ComputationGraph network) {
        INDArray[] outputs = network.output(false, input);
        eval(trueLabels, outputs[0]);
    }


    /**
     * Evaluates a {@link MultiLayerNetwork} on the given input against the given true labels.
     * The network's output is computed with {@code Layer.TrainingMode.TEST}.
     *
     * @param trueLabels the labels to use
     * @param input      the input to the network to use for evaluation
     * @param network    the network to use for output
     */
    public void eval(INDArray trueLabels, INDArray input, MultiLayerNetwork network) {
        INDArray predictions = network.output(input, Layer.TrainingMode.TEST);
        eval(trueLabels, predictions);
    }


    // NOTE(review): from here to the end of the class this copy of the file is damaged by
    // extraction. The remaining members were collapsed onto a handful of physical lines, so each
    // "//" comment now swallows the code after it on the same line. Text between a '<' and the
    // next '>' was also dropped: "for( int i = 0; i classes = confusion.getClasses();" fuses the
    // lazy-init loop of eval(INDArray, INDArray) with the local "classes" of a later
    // report-building method whose signature is lost. Restore this region from the upstream
    // source rather than hand-editing it. Issues visible in the intact text, to carry over:
    //  - falsePositiveRate(Integer) and falseNegativeRate(Integer) both return
    //    recall(classLabel, DEFAULT_EDGE_VALUE). They look like copy-paste slips for the
    //    two-argument falsePositiveRate/falseNegativeRate overloads. The averaged
    //    falsePositiveRate()/falseNegativeRate() (and hence falseAlarmRate()) call the
    //    one-argument forms and inherit the error.
    //  - precision(), recall(), falsePositiveRate() and falseNegativeRate() divide by the
    //    number of classes that have a defined value, which can be 0 (result NaN).
    //  - f1()'s javadoc gives 2 * TP / (2TP + FP + FN), but the code returns the harmonic mean
    //    of the averaged precision() and recall().
    //  - negative()'s javadoc says "true negatives + false negatives", but it adds
    //    trueNegatives() and falsePositives(). truePositives() and falseNegatives() are both
    //    described as "correctly rejected".
    //  - merge() calls labelsList.isEmpty() without a null check, and confusionToString()
    //    calls labelsList.get(i) for every class index without checking the list's size.
    /**
     * Collects statistics on the real outcomes vs the
     * guesses. This is for logistic outcome matrices.
     * 

* Note that an IllegalArgumentException is thrown if the two passed in * matrices aren't the same length. * * @param realOutcomes the real outcomes (labels - usually binary) * @param guesses the guesses/prediction (usually a probability vector) */ public void eval(INDArray realOutcomes, INDArray guesses) { // Add the number of rows to numRowCounter numRowCounter += realOutcomes.shape()[0]; // If confusion is null, then Evaluation was instantiated without providing the classes -> infer # classes from if (confusion == null) { int nClasses = realOutcomes.columns(); if(nClasses == 1) nClasses = 2; //Binary (single output variable) case labelsList = new ArrayList<>(nClasses); for( int i = 0; i classes = confusion.getClasses(); for (Integer clazz : classes) { actual = resolveLabelForClass(clazz); //Output confusion matrix for (Integer clazz2 : classes) { int count = confusion.getCount(clazz, clazz2); if (count != 0) { expected = resolveLabelForClass(clazz2); builder.append(String.format("Examples labeled as %s classified by model as %s: %d times%n", actual, expected, count)); } } //Output possible warnings regarding precision/recall calculation if (!suppressWarnings && truePositives.getCount(clazz) == 0) { if (falsePositives.getCount(clazz) == 0) { warnings.append(String.format("Warning: class %s was never predicted by the model. This class was excluded from the average precision%n", actual)); } if (falseNegatives.getCount(clazz) == 0) { warnings.append(String.format("Warning: class %s has never appeared as a true label. 
This class was excluded from the average recall%n", actual)); } } } builder.append("\n"); builder.append(warnings); DecimalFormat df = new DecimalFormat("#.####"); builder.append("\n==========================Scores========================================"); builder.append("\n Accuracy: ").append(df.format(accuracy())); builder.append("\n Precision: ").append(df.format(precision())); builder.append("\n Recall: ").append(df.format(recall())); builder.append("\n F1 Score: ").append(df.format(f1())); builder.append("\n========================================================================"); return builder.toString(); } private String resolveLabelForClass(Integer clazz) { if(labelsList != null && labelsList.size() > clazz ) return labelsList.get(clazz); return clazz.toString(); } /** * Returns the precision for a given label * * @param classLabel the label * @return the precision for the label */ public double precision(Integer classLabel) { return precision(classLabel, DEFAULT_EDGE_VALUE); } /** * Returns the precision for a given label * * @param classLabel the label * @param edgeCase What to output in case of 0/0 * @return the precision for the label */ public double precision(Integer classLabel, double edgeCase) { double tpCount = truePositives.getCount(classLabel); double fpCount = falsePositives.getCount(classLabel); //Edge case if (tpCount == 0 && fpCount == 0) { return edgeCase; } return tpCount / (tpCount + fpCount); } /** * Precision based on guesses so far * Takes into account all known classes and outputs average precision across all of them * * @return the total precision based on guesses so far */ public double precision() { double precisionAcc = 0.0; int classCount = 0; for (Integer classLabel : confusion.getClasses()) { double precision = precision(classLabel, -1); if (precision != -1) { precisionAcc += precision(classLabel); classCount++; } } return precisionAcc / (double) classCount; } /** * Returns the recall for a given label * * @param classLabel 
the label * @return Recall rate as a double */ public double recall(Integer classLabel) { return recall(classLabel, DEFAULT_EDGE_VALUE); } /** * Returns the recall for a given label * * @param classLabel the label * @param edgeCase What to output in case of 0/0 * @return Recall rate as a double */ public double recall(Integer classLabel, double edgeCase) { double tpCount = truePositives.getCount(classLabel); double fnCount = falseNegatives.getCount(classLabel); //Edge case if (tpCount == 0 && fnCount == 0) { return edgeCase; } return tpCount / (tpCount + fnCount); } /** * Recall based on guesses so far * Takes into account all known classes and outputs average recall across all of them * * @return the recall for the outcomes */ public double recall() { double recallAcc = 0.0; int classCount = 0; for (Integer classLabel : confusion.getClasses()) { double recall = recall(classLabel, -1.0); if (recall != -1.0) { recallAcc += recall(classLabel); classCount++; } } return recallAcc / (double) classCount; } /** * Returns the false positive rate for a given label * * @param classLabel the label * @return fpr as a double */ public double falsePositiveRate(Integer classLabel) { return recall(classLabel, DEFAULT_EDGE_VALUE); } /** * Returns the false positive rate for a given label * * @param classLabel the label * @param edgeCase What to output in case of 0/0 * @return fpr as a double */ public double falsePositiveRate(Integer classLabel, double edgeCase) { double fpCount = falsePositives.getCount(classLabel); double tnCount = trueNegatives.getCount(classLabel); //Edge case if (fpCount == 0 && tnCount == 0) { return edgeCase; } return fpCount / (fpCount + tnCount); } /** * False positive rate based on guesses so far * Takes into account all known classes and outputs average fpr across all of them * * @return the fpr for the outcomes */ public double falsePositiveRate() { double fprAlloc = 0.0; int classCount = 0; for (Integer classLabel : confusion.getClasses()) { double fpr 
= falsePositiveRate(classLabel, -1.0); if (fpr != -1.0) { fprAlloc += falsePositiveRate(classLabel); classCount++; } } return fprAlloc / (double) classCount; } /** * Returns the false negative rate for a given label * * @param classLabel the label * @return fnr as a double */ public double falseNegativeRate(Integer classLabel) { return recall(classLabel, DEFAULT_EDGE_VALUE); } /** * Returns the false negative rate for a given label * * @param classLabel the label * @param edgeCase What to output in case of 0/0 * @return fnr as a double */ public double falseNegativeRate(Integer classLabel, double edgeCase) { double fnCount = falseNegatives.getCount(classLabel); double tpCount = truePositives.getCount(classLabel); //Edge case if (fnCount == 0 && tpCount == 0) { return edgeCase; } return fnCount / (fnCount + tpCount); } /** * False negative rate based on guesses so far * Takes into account all known classes and outputs average fnr across all of them * * @return the fnr for the outcomes */ public double falseNegativeRate() { double fnrAlloc = 0.0; int classCount = 0; for (Integer classLabel : confusion.getClasses()) { double fnr = falseNegativeRate(classLabel, -1.0); if (fnr != -1.0) { fnrAlloc += falseNegativeRate(classLabel); classCount++; } } return fnrAlloc / (double) classCount; } /** * False Alarm Rate (FAR) reflects rate of misclassified to classified records * http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw * * @return the fpr for the outcomes */ public double falseAlarmRate() { return (falsePositiveRate() + falseNegativeRate()) / 2.0; } /** * Calculate f1 score for a given class * * @param classLabel the label to calculate f1 for * @return the f1 score for the given label */ public double f1(Integer classLabel) { double precision = precision(classLabel); double recall = recall(classLabel); if (precision == 0 || recall == 0) return 0; return 2.0 * ((precision * recall / (precision + recall))); } /** * TP: true positive * FP: False Positive * 
FN: False Negative * F1 score: 2 * TP / (2TP + FP + FN) * * @return the f1 score or harmonic mean based on current guesses */ public double f1() { double precision = precision(); double recall = recall(); if (precision == 0 || recall == 0) return 0; return 2.0 * ((precision * recall / (precision + recall))); } /** * Accuracy: * (TP + TN) / (P + N) * * @return the accuracy of the guesses so far */ public double accuracy() { //Accuracy: sum the counts on the diagonal of the confusion matrix, divide by total int nClasses = confusion.getClasses().size(); int countCorrect = 0; for (int i = 0; i < nClasses; i++) { countCorrect += confusion.getCount(i, i); } return countCorrect / (double)getNumRowCounter(); } // Access counter methods /** * True positives: correctly rejected * * @return the total true positives so far */ public Map truePositives() { return convertToMap(truePositives, confusion.getClasses().size()); } /** * True negatives: correctly rejected * * @return the total true negatives so far */ public Map trueNegatives() { return convertToMap(trueNegatives, confusion.getClasses().size()); } /** * False positive: wrong guess * * @return the count of the false positives */ public Map falsePositives() { return convertToMap(falsePositives, confusion.getClasses().size()); } /** * False negatives: correctly rejected * * @return the total false negatives so far */ public Map falseNegatives() { return convertToMap(falseNegatives, confusion.getClasses().size()); } /** * Total negatives true negatives + false negatives * * @return the overall negative count */ public Map negative() { return addMapsByKey(trueNegatives(), falsePositives()); } /** * Returns all of the positive guesses: * true positive + false negative */ public Map positive() { return addMapsByKey(truePositives(), falseNegatives()); } private Map convertToMap(Counter counter, int maxCount) { Map map = new HashMap<>(); for (int i = 0; i < maxCount; i++) { map.put(i, (int) counter.getCount(i)); } return map; } 
private Map addMapsByKey(Map first, Map second) { Map out = new HashMap<>(); Set keys = new HashSet<>(first.keySet()); keys.addAll(second.keySet()); for (Integer i : keys) { Integer f = first.get(i); Integer s = second.get(i); if (f == null) f = 0; if (s == null) s = 0; out.put(i, f + s); } return out; } // Incrementing counters public void incrementTruePositives(Integer classLabel) { truePositives.incrementCount(classLabel, 1.0); } public void incrementTrueNegatives(Integer classLabel) { trueNegatives.incrementCount(classLabel, 1.0); } public void incrementFalseNegatives(Integer classLabel) { falseNegatives.incrementCount(classLabel, 1.0); } public void incrementFalsePositives(Integer classLabel) { falsePositives.incrementCount(classLabel, 1.0); } // Other misc methods /** * Adds to the confusion matrix * * @param real the actual guess * @param guess the system guess */ public void addToConfusion(Integer real, Integer guess) { confusion.add(real, guess); } /** * Returns the number of times the given label * has actually occurred * * @param clazz the label * @return the number of times the label * actually occurred */ public int classCount(Integer clazz) { return confusion.getActualTotal(clazz); } public int getNumRowCounter() { return numRowCounter; } public String getClassLabel(Integer clazz) { return resolveLabelForClass(clazz); } /** * Returns the confusion matrix variable * * @return confusion matrix variable for this evaluation */ public ConfusionMatrix getConfusionMatrix() { return confusion; } /** * Merge the other evaluation object into this one. The result is that this Evaluation instance contains the counts * etc from both * * @param other Evaluation object to merge into this one. 
*/ public void merge(Evaluation other) { if (other == null) return; truePositives.incrementAll(other.truePositives); falsePositives.incrementAll(other.falsePositives); trueNegatives.incrementAll(other.trueNegatives); falseNegatives.incrementAll(other.falseNegatives); if (confusion == null) { if (other.confusion != null) confusion = new ConfusionMatrix<>(other.confusion); } else { if (other.confusion != null) confusion.add(other.confusion); } numRowCounter += other.numRowCounter; if (labelsList.isEmpty()) labelsList.addAll(other.labelsList); } /** * Get a String representation of the confusion matrix */ public String confusionToString() { int nClasses = confusion.getClasses().size(); //First: work out the longest label size int maxLabelSize = 0; for (String s : labelsList) { maxLabelSize = Math.max(maxLabelSize, s.length()); } //Build the formatting for the rows: int labelSize = Math.max(maxLabelSize + 5, 10); StringBuilder sb = new StringBuilder(); sb.append("%-3d"); sb.append("%-"); sb.append(labelSize); sb.append("s | "); StringBuilder headerFormat = new StringBuilder(); headerFormat.append(" %-").append(labelSize).append("s "); for (int i = 0; i < nClasses; i++) { sb.append("%7d"); headerFormat.append("%7d"); } String rowFormat = sb.toString(); StringBuilder out = new StringBuilder(); //First: header row Object[] headerArgs = new Object[nClasses + 1]; headerArgs[0] = "Predicted:"; for (int i = 0; i < nClasses; i++) headerArgs[i + 1] = i; out.append(String.format(headerFormat.toString(), headerArgs)).append("\n"); //Second: divider rows out.append(" Actual:\n"); //Finally: data rows for (int i = 0; i < nClasses; i++) { Object[] args = new Object[nClasses + 2]; args[0] = i; args[1] = labelsList.get(i); for (int j = 0; j < nClasses; j++) { args[j + 2] = confusion.getCount(i, j); } out.append(String.format(rowFormat, args)); out.append("\n"); } return out.toString(); } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy