
// Source listing header: org.deeplearning4j.eval.Evaluation (extracted from a Maven / Gradle / Ivy artifact browser)
/*
*
* * Copyright 2015 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.deeplearning4j.eval;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.Serializable;
import java.text.DecimalFormat;
import java.util.*;
/**
* Evaluation metrics:
* precision, recall, f1
*
* @author Adam Gibson
*/
public class Evaluation implements Serializable {

    // NOTE(review): the generic type parameters on these declarations appear to have been
    // stripped by the HTML extraction that produced this listing (raw Counter/List/
    // ConfusionMatrix). They are restored here to match how the rest of the class uses
    // them: Integer class indices and String labels.

    /** Per-class true positive counts, keyed by class index. */
    protected Counter<Integer> truePositives = new Counter<>();
    /** Per-class false positive counts, keyed by class index. */
    protected Counter<Integer> falsePositives = new Counter<>();
    /** Per-class true negative counts, keyed by class index. */
    protected Counter<Integer> trueNegatives = new Counter<>();
    /** Per-class false negative counts, keyed by class index. */
    protected Counter<Integer> falseNegatives = new Counter<>();
    /** Confusion matrix of actual vs. predicted class indices; null until the class count is known. */
    protected ConfusionMatrix<Integer> confusion;
    /** Total number of examples (label rows) evaluated so far. */
    protected int numRowCounter = 0;
    /** Display label for each class index; may stay empty until eval() infers the class count. */
    protected List<String> labelsList = new ArrayList<>();
    protected static Logger log = LoggerFactory.getLogger(Evaluation.class);
    //What to output from the precision/recall function when we encounter an edge case
    protected static final double DEFAULT_EDGE_VALUE = 0.0;

    /** Empty constructor: the number of classes is inferred on the first call to eval(). */
    public Evaluation() {
    }
// Constructor that takes number of output classes
/**
 * Creates an evaluation for a fixed number of output classes, using the string
 * form of each class index ("0", "1", ...) as its display label.
 * Note: numClasses == 1 is treated as the binary (single output variable) case
 * and produces two classes.
 * @param numClasses the number of classes to account for in the evaluation
 */
public Evaluation(int numClasses) {
this(createLabels(numClasses));
}
/**
 * Creates an evaluation using the given human-readable class labels; label i
 * names class index i, so the list must be ordered by class index.
 * NOTE(review): raw List restored to List&lt;String&gt; -- the generics appear to
 * have been stripped by the source extraction.
 *
 * @param labels the labels to use for the output, ordered by class index; may be
 *               null, in which case no confusion matrix is created here and the
 *               class count is inferred on the first eval() call
 */
public Evaluation(List<String> labels) {
    this.labelsList = labels;
    if (labels != null) {
        createConfusion(labels.size());
    }
}
/**
 * Creates an evaluation from a map of class index to display label.
 * The map must contain a label for every key 0 .. size-1.
 * NOTE(review): raw Map restored to Map&lt;Integer, String&gt; (generics stripped by extraction).
 *
 * @param labels a map of class index to the label to use for that class
 * @throws IllegalArgumentException if any index in 0..size-1 has no label
 */
public Evaluation(Map<Integer, String> labels) {
    this(createLabelsFromMap(labels));
}
/**
 * Builds the default label list "0".."numClasses-1". A single output variable
 * is treated as the binary case and yields two labels.
 * NOTE(review): raw List restored to List&lt;String&gt; (generics stripped by extraction).
 */
private static List<String> createLabels(int numClasses) {
    if (numClasses == 1) numClasses = 2; //Binary (single output variable) case...
    List<String> list = new ArrayList<>(numClasses);
    for (int i = 0; i < numClasses; i++) {
        list.add(String.valueOf(i));
    }
    return list;
}
/**
 * Converts a map of class index -&gt; label into an index-ordered label list.
 * NOTE(review): raw types restored to Map&lt;Integer, String&gt; / List&lt;String&gt;.
 *
 * @throws IllegalArgumentException if any index in 0..size-1 is missing from the map
 */
private static List<String> createLabelsFromMap(Map<Integer, String> labels) {
    int size = labels.size();
    List<String> labelsList = new ArrayList<>(size);
    for (int i = 0; i < size; i++) {
        String str = labels.get(i);
        if (str == null)
            throw new IllegalArgumentException("Invalid labels map: missing key for class " + i + " (expect integers 0 to " + (size-1) + ")");
        labelsList.add(str);
    }
    return labelsList;
}
/**
 * Allocates the confusion matrix over class indices 0..nClasses-1.
 * NOTE(review): raw List restored to List&lt;Integer&gt; (generics stripped by extraction).
 */
private void createConfusion(int nClasses) {
    List<Integer> classes = new ArrayList<>();
    for (int i = 0; i < nClasses; i++) {
        classes.add(i);
    }
    confusion = new ConfusionMatrix<>(classes);
}
/**
 * Evaluate the output of a ComputationGraph: runs the network forward on the
 * given input and scores the predictions against the given true labels.
 * @param trueLabels the true labels to use for evaluation
 * @param input the input to the network to use
 * for evaluation
 * @param network the network to use for output
 */
public void eval(INDArray trueLabels,INDArray input,ComputationGraph network) {
// Only the first output array of the graph is scored here.
// NOTE(review): the boolean argument presumably selects non-training (test) mode --
// confirm against ComputationGraph.output(boolean, INDArray...).
eval(trueLabels,network.output(false,input)[0]);
}
/**
 * Evaluate the output of a MultiLayerNetwork: runs the network forward on the
 * given input in TEST (inference) mode and scores the predictions against the
 * given true labels.
 * @param trueLabels the true labels to use for evaluation
 * @param input the input to the network to use
 * for evaluation
 * @param network the network to use for output
 */
public void eval(INDArray trueLabels,INDArray input,MultiLayerNetwork network) {
eval(trueLabels,network.output(input, Layer.TrainingMode.TEST));
}
/**
 * Collects statistics on the real outcomes vs the
 * guesses. This is for logistic outcome matrices.
 *
 * Note that an IllegalArgumentException is thrown if the two passed in
 * matrices aren't the same length.
 *
 * NOTE(review): this region of the listing is corrupted -- see the comment at
 * the garbled "for" line below. This file will not compile as-is.
 *
 * @param realOutcomes the real outcomes (labels - usually binary)
 * @param guesses the guesses/prediction (usually a probability vector)
 */
public void eval(INDArray realOutcomes, INDArray guesses) {
// Add the number of rows to numRowCounter
numRowCounter += realOutcomes.shape()[0];
// If confusion is null, then Evaluation was instantiated without providing the classes -> infer # classes from
if (confusion == null) {
int nClasses = realOutcomes.columns();
if(nClasses == 1) nClasses = 2; //Binary (single output variable) case
labelsList = new ArrayList<>(nClasses);
// NOTE(review): the source listing is truncated here. The next line fuses the remainder
// of eval() with the middle of what appears to be a stats() method: the loop filling
// labelsList, the confusion-matrix creation/update logic that closes eval(), and the
// method header plus the declarations of "builder", "warnings", "suppressWarnings",
// "actual" and "expected" that the statements below rely on were all swallowed -- most
// likely an HTML-stripping pass ate everything between a '<' and the next '>'.
// Recover the missing code from the original source before compiling.
for( int i = 0; i classes = confusion.getClasses();
for (Integer clazz : classes) {
actual = resolveLabelForClass(clazz);
//Output confusion matrix
for (Integer clazz2 : classes) {
int count = confusion.getCount(clazz, clazz2);
if (count != 0) {
expected = resolveLabelForClass(clazz2);
builder.append(String.format("Examples labeled as %s classified by model as %s: %d times%n", actual, expected, count));
}
}
//Output possible warnings regarding precision/recall calculation
if (!suppressWarnings && truePositives.getCount(clazz) == 0) {
if (falsePositives.getCount(clazz) == 0) {
warnings.append(String.format("Warning: class %s was never predicted by the model. This class was excluded from the average precision%n", actual));
}
if (falseNegatives.getCount(clazz) == 0) {
warnings.append(String.format("Warning: class %s has never appeared as a true label. This class was excluded from the average recall%n", actual));
}
}
}
builder.append("\n");
builder.append(warnings);
DecimalFormat df = new DecimalFormat("#.####");
builder.append("\n==========================Scores========================================");
builder.append("\n Accuracy: ").append(df.format(accuracy()));
builder.append("\n Precision: ").append(df.format(precision()));
builder.append("\n Recall: ").append(df.format(recall()));
builder.append("\n F1 Score: ").append(df.format(f1()));
builder.append("\n========================================================================");
return builder.toString();
}
/**
 * Maps a class index to its display label, falling back to the numeric index
 * rendered as a string when no label is known for it.
 */
private String resolveLabelForClass(Integer clazz) {
    boolean hasLabel = labelsList != null && clazz < labelsList.size();
    return hasLabel ? labelsList.get(clazz) : clazz.toString();
}
/**
 * Returns the precision for a given label, delegating to the two-argument
 * overload with DEFAULT_EDGE_VALUE (0.0) returned for the 0/0 edge case.
 *
 * @param classLabel the label
 * @return the precision for the label
 */
public double precision(Integer classLabel) {
return precision(classLabel, DEFAULT_EDGE_VALUE);
}
/**
 * Returns the precision for a given label: TP / (TP + FP).
 *
 * @param classLabel the label
 * @param edgeCase what to output in case of 0/0
 * @return the precision for the label
 */
public double precision(Integer classLabel, double edgeCase) {
    double tp = truePositives.getCount(classLabel);
    double fp = falsePositives.getCount(classLabel);
    // 0/0 edge case: this class was never predicted at all
    return (tp == 0 && fp == 0) ? edgeCase : tp / (tp + fp);
}
/**
 * Macro-averaged precision across all known classes. Classes for which
 * precision is undefined (never predicted: TP == FP == 0) are excluded from
 * the average. If every class is excluded the result is 0/0 == NaN, matching
 * the original behavior.
 *
 * @return the total precision based on guesses so far
 */
public double precision() {
    double precisionAcc = 0.0;
    int classCount = 0;
    for (Integer classLabel : confusion.getClasses()) {
        // -1 is a safe sentinel: a genuine precision is always in [0, 1]
        double precision = precision(classLabel, -1);
        if (precision != -1) {
            precisionAcc += precision; // fix: reuse the computed value instead of recomputing precision(classLabel)
            classCount++;
        }
    }
    return precisionAcc / (double) classCount;
}
/**
 * Returns the recall for a given label, delegating to the two-argument
 * overload with DEFAULT_EDGE_VALUE (0.0) returned for the 0/0 edge case.
 *
 * @param classLabel the label
 * @return Recall rate as a double
 */
public double recall(Integer classLabel) {
return recall(classLabel, DEFAULT_EDGE_VALUE);
}
/**
 * Returns the recall for a given label: TP / (TP + FN).
 *
 * @param classLabel the label
 * @param edgeCase what to output in case of 0/0
 * @return Recall rate as a double
 */
public double recall(Integer classLabel, double edgeCase) {
    double tp = truePositives.getCount(classLabel);
    double fn = falseNegatives.getCount(classLabel);
    // 0/0 edge case: this class never appeared as a true label
    return (tp == 0 && fn == 0) ? edgeCase : tp / (tp + fn);
}
/**
 * Macro-averaged recall across all known classes. Classes with no true
 * examples (TP == FN == 0) are excluded from the average. If every class is
 * excluded the result is 0/0 == NaN, matching the original behavior.
 *
 * @return the recall for the outcomes
 */
public double recall() {
    double recallAcc = 0.0;
    int classCount = 0;
    for (Integer classLabel : confusion.getClasses()) {
        // -1 is a safe sentinel: a genuine recall is always in [0, 1]
        double recall = recall(classLabel, -1.0);
        if (recall != -1.0) {
            recallAcc += recall; // fix: reuse the computed value instead of recomputing recall(classLabel)
            classCount++;
        }
    }
    return recallAcc / (double) classCount;
}
/**
 * Returns the false positive rate for a given label: FP / (FP + TN), with
 * DEFAULT_EDGE_VALUE (0.0) returned for the 0/0 edge case.
 * Bug fix: this previously delegated to recall(classLabel, ...), which computes
 * TP / (TP + FN) -- the wrong statistic entirely. It now delegates to the
 * two-argument falsePositiveRate overload.
 *
 * @param classLabel the label
 * @return fpr as a double
 */
public double falsePositiveRate(Integer classLabel) {
    return falsePositiveRate(classLabel, DEFAULT_EDGE_VALUE);
}
/**
 * Returns the false positive rate for a given label: FP / (FP + TN).
 *
 * @param classLabel the label
 * @param edgeCase what to output in case of 0/0
 * @return fpr as a double
 */
public double falsePositiveRate(Integer classLabel, double edgeCase) {
    double fp = falsePositives.getCount(classLabel);
    double tn = trueNegatives.getCount(classLabel);
    // 0/0 edge case: no negative examples of this class observed yet
    return (fp == 0 && tn == 0) ? edgeCase : fp / (fp + tn);
}
/**
 * Macro-averaged false positive rate across all known classes. Classes where
 * the rate is undefined (FP == TN == 0) are excluded from the average. If
 * every class is excluded the result is 0/0 == NaN, matching the original
 * behavior.
 *
 * @return the fpr for the outcomes
 */
public double falsePositiveRate() {
    double fprAcc = 0.0;
    int classCount = 0;
    for (Integer classLabel : confusion.getClasses()) {
        // -1 is a safe sentinel: a genuine rate is always in [0, 1]
        double fpr = falsePositiveRate(classLabel, -1.0);
        if (fpr != -1.0) {
            fprAcc += fpr; // fix: reuse the computed value instead of recomputing falsePositiveRate(classLabel)
            classCount++;
        }
    }
    return fprAcc / (double) classCount;
}
/**
 * Returns the false negative rate for a given label: FN / (FN + TP), with
 * DEFAULT_EDGE_VALUE (0.0) returned for the 0/0 edge case.
 * Bug fix: this previously delegated to recall(classLabel, ...), which computes
 * TP / (TP + FN) -- the complement of the intended statistic. It now delegates
 * to the two-argument falseNegativeRate overload.
 *
 * @param classLabel the label
 * @return fnr as a double
 */
public double falseNegativeRate(Integer classLabel) {
    return falseNegativeRate(classLabel, DEFAULT_EDGE_VALUE);
}
/**
 * Returns the false negative rate for a given label: FN / (FN + TP).
 *
 * @param classLabel the label
 * @param edgeCase what to output in case of 0/0
 * @return fnr as a double
 */
public double falseNegativeRate(Integer classLabel, double edgeCase) {
    double fn = falseNegatives.getCount(classLabel);
    double tp = truePositives.getCount(classLabel);
    // 0/0 edge case: this class never appeared as a true label
    return (fn == 0 && tp == 0) ? edgeCase : fn / (fn + tp);
}
/**
 * Macro-averaged false negative rate across all known classes. Classes where
 * the rate is undefined (FN == TP == 0) are excluded from the average. If
 * every class is excluded the result is 0/0 == NaN, matching the original
 * behavior.
 *
 * @return the fnr for the outcomes
 */
public double falseNegativeRate() {
    double fnrAcc = 0.0;
    int classCount = 0;
    for (Integer classLabel : confusion.getClasses()) {
        // -1 is a safe sentinel: a genuine rate is always in [0, 1]
        double fnr = falseNegativeRate(classLabel, -1.0);
        if (fnr != -1.0) {
            fnrAcc += fnr; // fix: reuse the computed value instead of recomputing falseNegativeRate(classLabel)
            classCount++;
        }
    }
    return fnrAcc / (double) classCount;
}
/**
 * False Alarm Rate (FAR) reflects rate of misclassified to classified records
 * http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw
 *
 * Computed as the arithmetic mean of the macro-averaged false positive rate
 * and the macro-averaged false negative rate.
 *
 * @return the false alarm rate for the outcomes
 */
public double falseAlarmRate() {
return (falsePositiveRate() + falseNegativeRate()) / 2.0;
}
/**
 * Calculates the F1 score (harmonic mean of precision and recall) for one class.
 * Returns 0 when either precision or recall is 0, avoiding a 0/0 division.
 *
 * @param classLabel the label to calculate f1 for
 * @return the f1 score for the given label
 */
public double f1(Integer classLabel) {
    double p = precision(classLabel);
    double r = recall(classLabel);
    if (p == 0 || r == 0) {
        return 0;
    }
    return 2.0 * ((p * r / (p + r)));
}
/**
 * TP: true positive
 * FP: False Positive
 * FN: False Negative
 * F1 score: 2 * TP / (2TP + FP + FN) -- here computed as the harmonic mean of
 * the macro-averaged precision and recall. Returns 0 when either is 0.
 *
 * @return the f1 score or harmonic mean based on current guesses
 */
public double f1() {
    double p = precision();
    double r = recall();
    if (p == 0 || r == 0) {
        return 0;
    }
    return 2.0 * ((p * r / (p + r)));
}
/**
 * Accuracy:
 * (TP + TN) / (P + N) -- computed as the sum of the confusion-matrix diagonal
 * (correct classifications) divided by the total number of rows evaluated.
 * NOTE(review): yields NaN (0/0) if called before any eval().
 *
 * @return the accuracy of the guesses so far
 */
public double accuracy() {
    int diagonalSum = 0;
    int nClasses = confusion.getClasses().size();
    for (int clazz = 0; clazz < nClasses; clazz++) {
        diagonalSum += confusion.getCount(clazz, clazz);
    }
    return diagonalSum / (double) getNumRowCounter();
}
// Access counter methods
/**
 * True positives: positive examples correctly classified as positive.
 * (Javadoc fix: the previous text said "correctly rejected", which describes
 * true negatives.)
 * NOTE(review): raw Map restored to Map&lt;Integer, Integer&gt; (generics stripped by extraction).
 *
 * @return the per-class true positive counts so far
 */
public Map<Integer, Integer> truePositives() {
    return convertToMap(truePositives, confusion.getClasses().size());
}
/**
 * True negatives: negative examples correctly rejected.
 * NOTE(review): raw Map restored to Map&lt;Integer, Integer&gt; (generics stripped by extraction).
 *
 * @return the per-class true negative counts so far
 */
public Map<Integer, Integer> trueNegatives() {
    return convertToMap(trueNegatives, confusion.getClasses().size());
}
/**
 * False positives: negative examples incorrectly classified as positive.
 * NOTE(review): raw Map restored to Map&lt;Integer, Integer&gt; (generics stripped by extraction).
 *
 * @return the per-class false positive counts so far
 */
public Map<Integer, Integer> falsePositives() {
    return convertToMap(falsePositives, confusion.getClasses().size());
}
/**
 * False negatives: positive examples incorrectly classified as negative (missed).
 * (Javadoc fix: the previous text said "correctly rejected", which describes
 * true negatives.)
 * NOTE(review): raw Map restored to Map&lt;Integer, Integer&gt; (generics stripped by extraction).
 *
 * @return the per-class false negative counts so far
 */
public Map<Integer, Integer> falseNegatives() {
    return convertToMap(falseNegatives, confusion.getClasses().size());
}
/**
 * Total actual negatives per class: true negatives + false positives, i.e.
 * every example whose true class is NOT the given class.
 * (Javadoc fix: the previous text said "true negatives + false negatives",
 * but the code sums true negatives and false positives.)
 * NOTE(review): raw Map restored to Map&lt;Integer, Integer&gt; (generics stripped by extraction).
 *
 * @return the overall negative count per class
 */
public Map<Integer, Integer> negative() {
    return addMapsByKey(trueNegatives(), falsePositives());
}
/**
 * Total actual positives per class: true positives + false negatives, i.e.
 * every example whose true class IS the given class.
 * (Javadoc fix: the previous text said "positive guesses"; TP + FN counts
 * actual positives, not predictions.)
 * NOTE(review): raw Map restored to Map&lt;Integer, Integer&gt; (generics stripped by extraction).
 *
 * @return the overall positive count per class
 */
public Map<Integer, Integer> positive() {
    return addMapsByKey(truePositives(), falseNegatives());
}
/**
 * Copies the first maxCount per-class counts out of a Counter into a Map keyed
 * by class index (counts are truncated to int).
 * NOTE(review): raw types restored to Counter&lt;Integer&gt; / Map&lt;Integer, Integer&gt;.
 */
private Map<Integer, Integer> convertToMap(Counter<Integer> counter, int maxCount) {
    Map<Integer, Integer> map = new HashMap<>();
    for (int i = 0; i < maxCount; i++) {
        map.put(i, (int) counter.getCount(i));
    }
    return map;
}
/**
 * Adds two count maps key-wise over the union of their keys; a key missing
 * from one map contributes 0.
 * NOTE(review): raw types restored to Map&lt;Integer, Integer&gt; / Set&lt;Integer&gt;.
 */
private Map<Integer, Integer> addMapsByKey(Map<Integer, Integer> first, Map<Integer, Integer> second) {
    Map<Integer, Integer> out = new HashMap<>();
    Set<Integer> keys = new HashSet<>(first.keySet());
    keys.addAll(second.keySet());
    for (Integer i : keys) {
        Integer f = first.get(i);
        Integer s = second.get(i);
        if (f == null) f = 0;
        if (s == null) s = 0;
        out.put(i, f + s);
    }
    return out;
}
// Incrementing counters
/** Increments the true positive count for the given class by one. */
public void incrementTruePositives(Integer classLabel) {
truePositives.incrementCount(classLabel, 1.0);
}
/** Increments the true negative count for the given class by one. */
public void incrementTrueNegatives(Integer classLabel) {
trueNegatives.incrementCount(classLabel, 1.0);
}
/** Increments the false negative count for the given class by one. */
public void incrementFalseNegatives(Integer classLabel) {
falseNegatives.incrementCount(classLabel, 1.0);
}
/** Increments the false positive count for the given class by one. */
public void incrementFalsePositives(Integer classLabel) {
falsePositives.incrementCount(classLabel, 1.0);
}
// Other misc methods
/**
 * Adds to the confusion matrix
 *
 * NOTE(review): throws a NullPointerException if called before the confusion
 * matrix exists (i.e. before eval() ran or a class-count-aware constructor was used).
 *
 * @param real the actual (ground-truth) class index
 * @param guess the predicted class index
 */
public void addToConfusion(Integer real, Integer guess) {
confusion.add(real, guess);
}
/**
 * Returns the number of times the given label
 * has actually occurred (the actual-class row total of the confusion matrix)
 *
 * @param clazz the label
 * @return the number of times the label
 * actually occurred
 */
public int classCount(Integer clazz) {
return confusion.getActualTotal(clazz);
}
/** @return the total number of examples (label rows) passed to eval() so far */
public int getNumRowCounter() {
return numRowCounter;
}
/** @return the display label for the given class index, or the index as a string when unlabeled */
public String getClassLabel(Integer clazz) {
return resolveLabelForClass(clazz);
}
/**
 * Returns the confusion matrix variable
 * NOTE(review): raw ConfusionMatrix restored to ConfusionMatrix&lt;Integer&gt;
 * (generics stripped by the source extraction).
 *
 * @return confusion matrix variable for this evaluation; null if the class
 *         count is not yet known
 */
public ConfusionMatrix<Integer> getConfusionMatrix() {
    return confusion;
}
/**
 * Merge the other evaluation object into this one. The result is that this Evaluation instance contains the counts
 * etc from both
 *
 * @param other Evaluation object to merge into this one; null is a no-op
 */
public void merge(Evaluation other) {
    if (other == null) return;
    truePositives.incrementAll(other.truePositives);
    falsePositives.incrementAll(other.falsePositives);
    trueNegatives.incrementAll(other.trueNegatives);
    falseNegatives.incrementAll(other.falseNegatives);
    if (confusion == null) {
        if (other.confusion != null) confusion = new ConfusionMatrix<>(other.confusion);
    } else {
        if (other.confusion != null) confusion.add(other.confusion);
    }
    numRowCounter += other.numRowCounter;
    // Fix: labelsList can be null on either side -- Evaluation(List) stores a null
    // label list as-is -- so guard against NPEs from isEmpty()/addAll().
    if (labelsList == null) {
        if (other.labelsList != null) labelsList = new ArrayList<>(other.labelsList);
    } else if (labelsList.isEmpty() && other.labelsList != null) {
        labelsList.addAll(other.labelsList);
    }
}
/**
 * Get a String representation of the confusion matrix: one row per actual
 * class, one column per predicted class, with a header row of class indices.
 */
public String confusionToString() {
int nClasses = confusion.getClasses().size();
//First: work out the longest label size
int maxLabelSize = 0;
for (String s : labelsList) {
maxLabelSize = Math.max(maxLabelSize, s.length());
}
//Build the formatting for the rows:
// Row format ends up as "%-3d%-<labelSize>s | " followed by one "%7d" cell per class.
int labelSize = Math.max(maxLabelSize + 5, 10);
StringBuilder sb = new StringBuilder();
sb.append("%-3d");
sb.append("%-");
sb.append(labelSize);
sb.append("s | ");
StringBuilder headerFormat = new StringBuilder();
headerFormat.append(" %-").append(labelSize).append("s ");
for (int i = 0; i < nClasses; i++) {
sb.append("%7d");
headerFormat.append("%7d");
}
String rowFormat = sb.toString();
StringBuilder out = new StringBuilder();
//First: header row
// Header args: the literal "Predicted:" then each class index as a column heading.
Object[] headerArgs = new Object[nClasses + 1];
headerArgs[0] = "Predicted:";
for (int i = 0; i < nClasses; i++) headerArgs[i + 1] = i;
out.append(String.format(headerFormat.toString(), headerArgs)).append("\n");
//Second: divider rows
out.append(" Actual:\n");
//Finally: data rows
// Each row: actual class index, its label, then the count for every predicted class.
// NOTE(review): labelsList.get(i) assumes a label exists for every class index; an
// unlabeled class would throw IndexOutOfBoundsException here.
for (int i = 0; i < nClasses; i++) {
Object[] args = new Object[nClasses + 2];
args[0] = i;
args[1] = labelsList.get(i);
for (int j = 0; j < nClasses; j++) {
args[j + 2] = confusion.getCount(i, j);
}
out.append(String.format(rowFormat, args));
out.append("\n");
}
return out.toString();
}
}