All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.eval.ConfusionMatrix Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.eval;
import java.util.HashMap;
import java.util.Map;
import java.util.SortedSet;
import java.util.TreeSet;

import com.google.common.collect.HashMultiset;
import com.google.common.collect.Multiset;
import com.google.common.collect.Ordering;

/**
 * This data structure provides an easy way to build and output a confusion matrix. A confusion
 * matrix is a two dimensional table with a row and table for each class. Each element in the matrix
 * shows the number of test examples for which the actual class is the row and the predicted class
 * is the column. Display of this matrix is useful for identifying when a system is confusing two
 * classes
 * 
 * 
* For more info @see The wikipedia page on * Confusion Matrices * *
* Copyright (c) 2011, Regents of the University of Colorado
* All rights reserved. * * @author Lee Becker * * @param * The data type used to represent the class labels */ public class ConfusionMatrix> { private Map> matrix; private SortedSet classes; /** * Creates an empty confusion Matrix */ public ConfusionMatrix() { this.matrix = new HashMap>(); this.classes = new TreeSet(Ordering.natural().nullsFirst()); } /** * Creates a new ConfusionMatrix initialized with the contents of another ConfusionMatrix. */ public ConfusionMatrix(ConfusionMatrix other) { this(); this.add(other); } /** * Increments the entry specified by actual and predicted by one. */ public void add(T actual, T predicted) { add(actual, predicted, 1); } /** * Increments the entry specified by actual and predicted by count. */ public void add(T actual, T predicted, int count) { if (matrix.containsKey(actual)) { matrix.get(actual).add(predicted, count); } else { Multiset counts = HashMultiset.create(); counts.add(predicted, count); matrix.put(actual, counts); } classes.add(actual); classes.add(predicted); } /** * Adds the entries from another confusion matrix to this one. */ public void add(ConfusionMatrix other) { for (T actual : other.matrix.keySet()) { Multiset counts = other.matrix.get(actual); for (T predicted : counts.elementSet()) { int count = counts.count(predicted); this.add(actual, predicted, count); } } } /** * Gives the set of all classes in the confusion matrix. */ public SortedSet getClasses() { return classes; } /** * Gives the count of the number of times the "predicted" class was predicted for the "actual" * class. */ public int getCount(T actual, T predicted) { if (!matrix.containsKey(actual)) { return 0; } else { return matrix.get(actual).count(predicted); } } /** * Computes the total number of times the class was predicted by the classifier. 
*/ public int getPredictedTotal(T predicted) { int total = 0; for (T actual : classes) { total += getCount(actual, predicted); } return total; } /** * Computes the total number of times the class actually appeared in the data. */ public int getActualTotal(T actual) { if (!matrix.containsKey(actual)) { return 0; } else { int total = 0; for (T elem : matrix.get(actual).elementSet()) { total += matrix.get(actual).count(elem); } return total; } } @Override public String toString() { return matrix.toString(); } /** * Outputs the ConfusionMatrix as comma-separated values for easy import into spreadsheets */ public String toCSV() { StringBuilder builder = new StringBuilder(); // Header Row builder.append(",,Predicted Class,\n"); // Predicted Classes Header Row builder.append(",,"); for (T predicted : classes) { builder.append(String.format("%s,", predicted)); } builder.append("Total\n"); // Data Rows String firstColumnLabel = "Actual Class,"; for (T actual : classes) { builder.append(firstColumnLabel); firstColumnLabel = ","; builder.append(String.format("%s,", actual)); for (T predicted : classes) { builder.append(getCount(actual, predicted)); builder.append(","); } // Actual Class Totals Column builder.append(getActualTotal(actual)); builder.append("\n"); } // Predicted Class Totals Row builder.append(",Total,"); for (T predicted : classes) { builder.append(getPredictedTotal(predicted)); builder.append(","); } builder.append("\n"); return builder.toString(); } /** * Outputs Confusion Matrix in an HTML table. Cascading Style Sheets (CSS) can control the table's * appearance by defining the empty-space, actual-count-header, predicted-class-header, and * count-element classes. 
For example * * @return html string */ public String toHTML() { StringBuilder builder = new StringBuilder(); int numClasses = classes.size(); // Header Row builder.append("\n"); builder.append("\n", numClasses + 1)); // Predicted Classes Header Row builder.append(""); // builder.append(""); for (T predicted : classes) { builder.append(""); } builder.append(""); builder.append("\n"); // Data Rows String firstColumnLabel = String.format( "", numClasses + 1); for (T actual : classes) { builder.append(firstColumnLabel); firstColumnLabel = ""; builder.append(String.format("", actual)); for (T predicted : classes) { builder.append(""); } // Actual Class Totals Column builder.append(""); builder.append("\n"); } // Predicted Class Totals Row builder.append(""); for (T predicted : classes) { builder.append(""); } builder.append("\n"); builder.append("\n"); builder.append("
"); builder.append(String.format( "Predicted Class
"); builder.append(predicted); builder.append("Total
Actual Class
%s"); builder.append(getCount(actual, predicted)); builder.append(""); builder.append(getActualTotal(actual)); builder.append("
Total"); builder.append(getPredictedTotal(predicted)); builder.append("
\n"); return builder.toString(); } public static void main(String[] args) { ConfusionMatrix confusionMatrix = new ConfusionMatrix(); confusionMatrix.add("a", "a", 88); confusionMatrix.add("a", "b", 10); // confusionMatrix.add("a", "c", 2); confusionMatrix.add("b", "a", 14); confusionMatrix.add("b", "b", 40); confusionMatrix.add("b", "c", 6); confusionMatrix.add("c", "a", 18); confusionMatrix.add("c", "b", 10); confusionMatrix.add("c", "c", 12); ConfusionMatrix confusionMatrix2 = new ConfusionMatrix(confusionMatrix); confusionMatrix2.add(confusionMatrix); System.out.println(confusionMatrix2.toHTML()); System.out.println(confusionMatrix2.toCSV()); } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy