// stream.learner.evaluation.ConfusionMatrix (Maven / Gradle / Ivy)
/*
* streams library
*
* Copyright (C) 2011-2012 by Christian Bockermann, Hendrik Blom
*
* streams is a library, API and runtime environment for processing high
* volume data streams. It is composed of three submodules "stream-api",
* "stream-core" and "stream-runtime".
*
* The streams library (and its submodules) is free software: you can
* redistribute it and/or modify it under the terms of the
* GNU Affero General Public License as published by the Free Software
* Foundation, either version 3 of the License, or (at your option) any
* later version.
*
* The stream.ai library (and its submodules) is distributed in the hope
* that it will be useful, but WITHOUT ANY WARRANTY; without even the implied
* warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package stream.learner.evaluation;
import java.io.Serializable;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
*
* This is a data structure for dealing with label-pair-relations. For
* statistical information on a per-label-base there are methods which return
* instance(s) of {@link TableOfConfusion}. That class has methods to calculate
* values like Precision, Recall, [...].
*
*
*
* Look at http://en.wikipedia.org/wiki/Confusion_matrix for detailed
* descriptions.
*
*
*
* Although there is a method for adding a label or a list of labels to the
* existing labels, it is highly recommended to use it as rare as possible, as
* it requires that (internally) the array storing the label-combination-counts
* has to be copied to a new array which then replaces the old array. Instead,
* use the constructor with a list of labels as argument, where that list should
* contain as much labels as possible so that the need for adding labels is at a
* minimum frequency. However, if the labels are not known in advance, using
* {@link #addLabels(java.util.List)} is preferable to
* {@link #addLabel(java.lang.Object)} as there are fewer array copy operations
* needed with the first mentioned method.
*
*
* @param T
* the type of the labels
* @author Benedikt Kulmann, Lukas Kalabis, Christian Bockermann
* <[email protected]>
* @see TableOfConfusion
*/
public final class ConfusionMatrix {
/**
* The list of all labels. Indices refer to the {@link #confusionMatrix}.
*/
private List labels;
/**
*
* Array which stores the counts of classification instances (pairs of true
* and predicted labels)
*
*
*
* The first dimension represents the true labels while the second dimension
* represents the predicted labels.
*
*/
private long[][] confusionMatrix;
/**
* Creates a new ConfusionMatrix instance with an empty list of labels.
* Don't use this constructor if it is possible to construct a list of
* possible labels in advance.
*/
public ConfusionMatrix() {
this(new ArrayList());
}
/**
* Creates a new ConfusionMatrix instance.
*
* @param labels
* Labels to maintain a label-pair-combination-counter for.
*/
public ConfusionMatrix(List labels) {
this.labels = labels;
this.confusionMatrix = new long[labels.size()][labels.size()];
}
/**
*
* Adds the provided label to the list of labels. Afterwards it is necessary
* to create a new array for the internal counters so the usage of this
* method is expensive. Try to add as much labels as possible at a time by
* using {@link #addLabels(java.util.List)} or, which is even better, at
* object creation time.
*
*
*
* Duplicates are ignored.
*
*
* @param additionalLabel
* The label to add to the internal list of labels
*/
public void addLabel(T additionalLabel) {
final List additionalLabelAsList = new ArrayList();
additionalLabelAsList.add(additionalLabel);
addLabels(additionalLabelAsList);
}
/**
*
* Adds the provided list of labels to the internal list of labels.
* Afterwards it is necessary to create a new array for the internal
* counters so the usage of this method is expensive. Try to add as much
* labels as possible at object creation time.
*
*
*
* Duplicates are ignored.
*
*
* @param additionalLabels
* The labels to add to the internal list of labels
*/
public void addLabels(List additionalLabels) {
// construct new list of labels
final List modAdditionalLabels = new ArrayList(additionalLabels);
modAdditionalLabels.removeAll(labels);
labels.addAll(modAdditionalLabels);
// construct new confusion matrix
final long[][] newConfusionMatrix = new long[labels.size()][labels
.size()];
for (int i = 0; i < confusionMatrix.length; i++) {
System.arraycopy(confusionMatrix[i], 0, newConfusionMatrix[i], 0,
confusionMatrix.length);
}
confusionMatrix = newConfusionMatrix;
}
/**
* Returns the list of labels this {@link ConfusionMatrix} maintains
* counters for.
*
* @return The list of labels this {@link ConfusionMatrix} maintains
* counters for.
*/
public List getLabels() {
return labels;
}
/**
*
* Adds a classification instance (true and predicted label) to this
* {@link ConfusionMatrix}.
*
*
*
* Each label which didn't exist previously will be added within this method
* automatically (which is expensive and to be avoided).
*
*
* @param truth
* The true label
* @param prediction
* The predicted label
*/
public void add(T truth, T prediction) {
int indexOfTruth = labels.indexOf(truth);
if (indexOfTruth == -1) {
indexOfTruth = labels.size();
addLabel(truth);
}
int indexOfPrediction = labels.indexOf(prediction);
if (indexOfPrediction == -1) {
indexOfPrediction = labels.size();
addLabel(prediction);
}
confusionMatrix[indexOfTruth][indexOfPrediction]++;
}
/**
* Returns a map which contains a {@link TableOfConfusion} per label.
*
* @return a map of {@link TableOfConfusion} instances
* @see #getTableOfConfusion(java.lang.Object)
*/
public Map getTablesOfConfusion() {
final Map tablesOfConfusion = new HashMap();
for (T label : labels) {
tablesOfConfusion.put(label, getTableOfConfusion(label));
}
return tablesOfConfusion;
}
/**
* Constructs and returns the {@link TableOfConfusion} for the provided
* label.
*
* @param label
* The label to construct a {@link TableOfConfusion} for
* @return A {@link TableOfConfusion} instance for the provided label
*
* @see TableOfConfusion
*/
public TableOfConfusion getTableOfConfusion(T label) {
final TableOfConfusion tableOfConfusion = new TableOfConfusion();
tableOfConfusion.addTruePositive(getTruePositiveCount(label));
tableOfConfusion.addTrueNegative(getTrueNegativeCount(label));
tableOfConfusion.addFalsePositive(getFalsePositiveCount(label));
tableOfConfusion.addFalseNegative(getFalseNegativeCount(label));
return tableOfConfusion;
}
/**
* Calculates and returns the overall accuracy for this confusion matrix in
* range [0,1].
*
* @return the overall accuracy for this confusion matrix in range [0,1]
*/
public double calculateAccuracy() {
double correct = 0.0;
for (int i = 0; i < labels.size(); i++) {
correct += confusionMatrix[i][i];
}
double divisor = 0.0;
for (int i = 0; i < labels.size(); i++) {
for (int j = 0; j < labels.size(); j++) {
divisor += confusionMatrix[i][j];
}
}
if (divisor == 0) {
return Double.NaN;
} else {
return correct / divisor;
}
}
/**
* Returns the class size as a weight for per-class-calculations.
*
* @param label
* The label to get the weight for
* @return the class size as a weight for per-class-calculations.
*/
public long getWeightForLabel(T label) {
final int indexOfLabel = labels.indexOf(label);
long weight = 0L;
for (int i = 0; i < labels.size(); i++) {
if (i == indexOfLabel) {
continue;
}
weight += confusionMatrix[indexOfLabel][i];
}
return weight;
}
/**
* Returns the number of "true positive" instances for the provided label.
*
* @param label
* The label to return the counted number of "true positive"
* instances for
* @return The number of "true positive" instances for the provided label
*/
private long getTruePositiveCount(T label) {
final int indexOfLabel = labels.indexOf(label);
return confusionMatrix[indexOfLabel][indexOfLabel];
}
/**
* Returns the number of "true negative" instances for the provided label.
*
* @param label
* The label to return the counted number of "true negative"
* instances for
* @return The number of "true negative" instances for the provided label
*/
private long getTrueNegativeCount(T label) {
final int indexOfLabel = labels.indexOf(label);
long trueNegativeCount = 0L;
for (int i = 0; i < labels.size(); i++) {
if (i == indexOfLabel) {
continue;
}
for (int j = 0; j < labels.size(); j++) {
if (j == indexOfLabel) {
continue;
}
trueNegativeCount += confusionMatrix[i][j];
}
}
return trueNegativeCount;
}
/**
* Returns the number of "false positive" instances for the provided label.
*
* @param label
* The label to return the counted number of "false positive"
* instances for
* @return The number of "false positive" instances for the provided label
*/
private long getFalsePositiveCount(T label) {
final int indexOfLabel = labels.indexOf(label);
long falsePositiveCount = 0L;
for (int i = 0; i < labels.size(); i++) {
if (i == indexOfLabel) {
continue;
} else {
falsePositiveCount += confusionMatrix[indexOfLabel][i];
}
}
return falsePositiveCount;
}
/**
* Returns the number of "false negative" instances for the provided label.
*
* @param label
* The label to return the counted number of "false negative"
* instances for
* @return The number of "false negative" instances for the provided label
*/
private long getFalseNegativeCount(T label) {
final int indexOfLabel = labels.indexOf(label);
long falseNegativeCount = 0L;
for (int i = 0; i < labels.size(); i++) {
if (i == indexOfLabel) {
continue;
} else {
falseNegativeCount += confusionMatrix[i][indexOfLabel];
}
}
return falseNegativeCount;
}
/**
* {@inheritDoc}
*/
@Override
public String toString() {
final String lineSeparator = System.getProperty("line.separator");
StringBuilder sb = new StringBuilder(
"ConfusionMatrix (rows=truth,columns=prediction)")
.append(lineSeparator).append("values:").append(lineSeparator);
sb.append(lineSeparator);
sb.append(" ");
for (int i = 0; i < labels.size(); i++) {
sb.append(" | pred: " + labels.get(i));
}
sb.append(lineSeparator);
for (int i = 0; i < labels.size(); i++) {
sb.append("true:" + labels.get(i) + " | ");
for (int j = 0; j < labels.size(); j++) {
sb.append(" ").append(confusionMatrix[i][j]);
}
sb.append(lineSeparator);
}
sb.append(lineSeparator).append("results:").append(lineSeparator);
for (T label : labels) {
sb.append(label).append(lineSeparator)
.append(getTableOfConfusion(label));
}
return sb.toString();
}
public String toHtml() {
StringBuilder b = new StringBuilder("");
b.append("");
b.append("prediction ");
b.append(" ");
b.append("");
for (T l : labels) {
b.append("" + l.toString() + " ");
}
b.append("Precision ");
b.append(" ");
DecimalFormat fmt = new DecimalFormat("0.00 %");
for (int i = 0; i < labels.size(); i++) {
T cur = labels.get(i);
b.append("");
if (i == 0)
b.append("true ");
b.append("" + labels.get(i) + " ");
Double tp = 0.0d;
Double fp = 0.0d;
for (int j = 0; j < labels.size(); j++) {
T against = labels.get(j);
if (cur != against)
fp += confusionMatrix[i][j];
else
tp += confusionMatrix[i][j];
b.append(" ").append(confusionMatrix[i][j]).append(" ");
}
b.append("" + fmt.format(tp / (tp + fp)) + " ");
b.append(" \n");
}
b.append("
");
return b.toString();
}
}