org.apache.mahout.classifier.ConfusionMatrix Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mahout-mr Show documentation
Show all versions of mahout-mr Show documentation
Scalable machine learning libraries
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.classifier;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import com.google.common.base.Preconditions;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.math3.stat.descriptive.moment.Mean;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev;
import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Matrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * The ConfusionMatrix Class stores the result of Classification of a Test Dataset.
 *
 * <p>Cell {@code [r][c]} counts the test instances whose correct label has index {@code r} and
 * that were classified as the label with index {@code c}; label indices come from an internal
 * label-to-index map whose iteration order is index order.</p>
 *
 * <p>The fact of whether there is a default is not stored. A row of zeros is the only indicator
 * that there is no default.</p>
 *
 * <p>Instances are mutable and perform no synchronization.</p>
 *
 * <p>See http://en.wikipedia.org/wiki/Confusion_matrix for background</p>
 */
public class ConfusionMatrix {

  private static final Logger LOG = LoggerFactory.getLogger(ConfusionMatrix.class);

  /** Label to row/column index; entries are inserted in increasing index order. */
  private final Map<String, Integer> labelMap = new LinkedHashMap<>();
  /** Counts indexed as {@code [correct label index][classified label index]}. */
  private final int[][] confusionMatrix;
  private String defaultLabel = "unknown";

  /**
   * Creates an empty confusion matrix over {@code labels} plus {@code defaultLabel}.
   *
   * <p>Labels receive indices in iteration order and {@code defaultLabel} receives the next free
   * index. A label that occurs more than once (including a {@code defaultLabel} that already
   * appears in {@code labels}) is registered only once, so every row/column of the matrix belongs
   * to exactly one label.</p>
   *
   * @param labels the class labels
   * @param defaultLabel the label reserved for instances assigned to the default category
   */
  public ConfusionMatrix(Collection<String> labels, String defaultLabel) {
    this.defaultLabel = defaultLabel;
    for (String label : labels) {
      // Keep the first index of a duplicate so indices stay dense (0..n-1) and match the matrix.
      if (!labelMap.containsKey(label)) {
        labelMap.put(label, labelMap.size());
      }
    }
    if (!labelMap.containsKey(defaultLabel)) {
      labelMap.put(defaultLabel, labelMap.size());
    }
    int size = labelMap.size();
    confusionMatrix = new int[size][size];
  }

  /**
   * Creates a confusion matrix initialised from {@code m}; see {@link #setMatrix(Matrix)}.
   * The default label stays {@code "unknown"}.
   *
   * @param m square matrix of counts, optionally carrying row or column label bindings
   */
  public ConfusionMatrix(Matrix m) {
    confusionMatrix = new int[m.numRows()][m.numRows()];
    setMatrix(m);
  }

  /**
   * Returns the backing count array itself (not a copy); writes to it change this matrix.
   */
  public int[][] getConfusionMatrix() {
    return confusionMatrix;
  }

  /** Returns an unmodifiable view of the labels, in index order. */
  public Collection<String> getLabels() {
    return Collections.unmodifiableCollection(labelMap.keySet());
  }

  private int numLabels() {
    return labelMap.size();
  }

  /** Resolves a label to its index, failing with the same message {@link #getCount} uses. */
  private int indexOf(String label) {
    Integer index = labelMap.get(label);
    if (index == null) {
      throw new IllegalArgumentException("Label not found: " + label);
    }
    return index;
  }

  /** Sum of row {@code labelId}: instances whose correct label is that label (TP + FN). */
  private int rowSum(int labelId) {
    int sum = 0;
    for (int count : confusionMatrix[labelId]) {
      sum += count;
    }
    return sum;
  }

  /** Sum of column {@code labelId}: instances classified as that label (TP + FP). */
  private int columnSum(int labelId) {
    int sum = 0;
    for (int[] row : confusionMatrix) {
      sum += row[labelId];
    }
    return sum;
  }

  /**
   * Returns the percentage of instances of {@code label} that were classified as {@code label}
   * (diagonal cell over row total, times 100). Whenever the row is non-empty this equals
   * {@code 100 * getRecall(label)}.
   *
   * @param label a known label
   * @return accuracy in [0, 100], or NaN if no instance of {@code label} was recorded
   * @throws IllegalArgumentException if {@code label} is unknown
   */
  public double getAccuracy(String label) {
    int labelId = indexOf(label);
    return 100.0 * confusionMatrix[labelId][labelId] / rowSum(labelId);
  }

  /**
   * Returns the overall accuracy: the percentage of all recorded instances that lie on the
   * diagonal.
   *
   * @return accuracy in [0, 100], or NaN if the matrix is empty
   */
  public double getAccuracy() {
    long total = 0;
    long correct = 0;
    // Iterate the matrix itself so an unlabeled matrix (no label bindings) still works.
    for (int i = 0; i < confusionMatrix.length; i++) {
      correct += confusionMatrix[i][i];
      total += rowSum(i);
    }
    return 100.0 * correct / total;
  }

  /**
   * Returns TP / (TP + FP) for {@code label}, i.e. the diagonal cell over its column total.
   *
   * @return precision in [0, 1]; 0 if nothing was classified as {@code label}
   * @throws IllegalArgumentException if {@code label} is unknown
   */
  public double getPrecision(String label) {
    int labelId = indexOf(label);
    int truePositives = confusionMatrix[labelId][labelId];
    int classifiedAsLabel = columnSum(labelId);
    if (classifiedAsLabel == 0) {
      return 0;
    }
    return ((double) truePositives) / classifiedAsLabel;
  }

  /**
   * Returns the mean of {@link #getPrecision(String)} over all labels, weighted by each label's
   * row total (number of instances whose correct label it is).
   */
  public double getWeightedPrecision() {
    double[] precisions = new double[numLabels()];
    int index = 0;
    for (String label : labelMap.keySet()) {
      precisions[index++] = getPrecision(label);
    }
    return new Mean().evaluate(precisions, classWeights());
  }

  /**
   * Returns TP / (TP + FN) for {@code label}, i.e. the diagonal cell over its row total.
   *
   * @return recall in [0, 1]; 0 if no instance of {@code label} was recorded
   * @throws IllegalArgumentException if {@code label} is unknown
   */
  public double getRecall(String label) {
    int labelId = indexOf(label);
    int truePositives = confusionMatrix[labelId][labelId];
    int actualOfLabel = rowSum(labelId);
    if (actualOfLabel == 0) {
      return 0;
    }
    return ((double) truePositives) / actualOfLabel;
  }

  /** Returns the mean of {@link #getRecall(String)} over all labels, weighted by row totals. */
  public double getWeightedRecall() {
    double[] recalls = new double[numLabels()];
    int index = 0;
    for (String label : labelMap.keySet()) {
      recalls[index++] = getRecall(label);
    }
    return new Mean().evaluate(recalls, classWeights());
  }

  /**
   * Returns the harmonic mean of precision and recall for {@code label}; 0 if both are 0.
   *
   * @throws IllegalArgumentException if {@code label} is unknown
   */
  public double getF1score(String label) {
    double precision = getPrecision(label);
    double recall = getRecall(label);
    if (precision + recall == 0) {
      return 0;
    }
    return 2 * precision * recall / (precision + recall);
  }

  /** Returns the mean of {@link #getF1score(String)} over all labels, weighted by row totals. */
  public double getWeightedF1score() {
    double[] f1Scores = new double[numLabels()];
    int index = 0;
    for (String label : labelMap.keySet()) {
      f1Scores[index++] = getF1score(label);
    }
    return new Mean().evaluate(f1Scores, classWeights());
  }

  /** Row totals of every label, in the same (index) order as {@code labelMap.keySet()}. */
  private double[] classWeights() {
    double[] weights = new double[numLabels()];
    int index = 0;
    for (int labelId : labelMap.values()) {
      weights[index++] = rowSum(labelId);
    }
    return weights;
  }

  /**
   * Returns the unweighted mean of {@link #getAccuracy(String)} over every label except the
   * default label. A label with no recorded instances contributes NaN.
   */
  public double getReliability() {
    int count = 0;
    double accuracy = 0;
    for (String label : labelMap.keySet()) {
      if (!label.equals(defaultLabel)) {
        accuracy += getAccuracy(label);
        // Count only labels that contribute to the sum, so the default label does not
        // dilute the mean.
        count++;
      }
    }
    return accuracy / count;
  }

  /**
   * Accuracy v.s. randomly classifying all samples (Cohen's kappa).
   * kappa() = (totalAccuracy() - randomAccuracy()) / (1 - randomAccuracy())
   * Cohen, Jacob. 1960. A coefficient of agreement for nominal scales.
   * Educational And Psychological Measurement 20:37-46.
   *
   * <p>With {@code N} the total of all cells, {@code a} the sum of the diagonal and {@code b} the
   * sum over classes of (row total * column total), this returns
   * {@code (N*a - b) / (N*N - b)}.</p>
   *
   * Formula and variable names from:
   * http://www.yale.edu/ceo/OEFS/Accuracy.pdf
   *
   * @return kappa; NaN when the matrix is empty
   */
  public double getKappa() {
    double a = 0.0;
    double b = 0.0;
    // N is taken from the matrix itself so it is right however the counts were filled in
    // (addInstance, putCount, merge, setMatrix); keeping it a double avoids int overflow of N*N.
    double n = 0.0;
    for (int i = 0; i < confusionMatrix.length; i++) {
      a += confusionMatrix[i][i];
      double rowTotal = rowSum(i);
      n += rowTotal;
      b += rowTotal * columnSum(i);
    }
    return (n * a - b) / (n * n - b);
  }

  /**
   * Feeds, for every row, the diagonal count divided by the row total into a
   * {@link FullRunningAverageAndStdDev} and returns it. A small epsilon in the denominator keeps
   * empty rows at 0 instead of dividing by zero. Not a standard score.
   *
   * @return the accumulator holding the per-row normalized diagonal values
   */
  public RunningAverageAndStdDev getNormalizedStats() {
    RunningAverageAndStdDev summer = new FullRunningAverageAndStdDev();
    for (int d = 0; d < confusionMatrix.length; d++) {
      double total = rowSum(d);
      summer.addDatum(confusionMatrix[d][d] / (total + 0.000001));
    }
    return summer;
  }

  /**
   * Returns the number of instances of {@code label} classified as {@code label}.
   *
   * @throws IllegalArgumentException if {@code label} is unknown
   */
  public int getCorrect(String label) {
    int labelId = indexOf(label);
    return confusionMatrix[labelId][labelId];
  }

  /**
   * Returns the number of instances whose correct label is {@code label} (its row total).
   *
   * @throws IllegalArgumentException if {@code label} is unknown
   */
  public int getTotal(String label) {
    return rowSum(indexOf(label));
  }

  /** Records one instance of {@code correctLabel} classified as {@code classifiedResult}'s label. */
  public void addInstance(String correctLabel, ClassifierResult classifiedResult) {
    incrementCount(correctLabel, classifiedResult.getLabel());
  }

  /** Records one instance of {@code correctLabel} classified as {@code classifiedLabel}. */
  public void addInstance(String correctLabel, String classifiedLabel) {
    incrementCount(correctLabel, classifiedLabel);
  }

  /**
   * Returns the count stored for ({@code correctLabel}, {@code classifiedLabel}).
   *
   * @return the count, or 0 (with a warning logged) if {@code correctLabel} is unknown
   * @throws IllegalArgumentException if {@code classifiedLabel} is unknown
   */
  public int getCount(String correctLabel, String classifiedLabel) {
    if (!labelMap.containsKey(correctLabel)) {
      LOG.warn("Label {} did not appear in the training examples", correctLabel);
      return 0;
    }
    int classifiedId = indexOf(classifiedLabel);
    int correctId = labelMap.get(correctLabel);
    return confusionMatrix[correctId][classifiedId];
  }

  /**
   * Overwrites the count for ({@code correctLabel}, {@code classifiedLabel}). An unknown
   * {@code correctLabel} is logged and ignored.
   *
   * @throws IllegalArgumentException if {@code classifiedLabel} is unknown
   */
  public void putCount(String correctLabel, String classifiedLabel, int count) {
    if (!labelMap.containsKey(correctLabel)) {
      LOG.warn("Label {} did not appear in the training examples", correctLabel);
      return;
    }
    int classifiedId = indexOf(classifiedLabel);
    int correctId = labelMap.get(correctLabel);
    confusionMatrix[correctId][classifiedId] = count;
  }

  public String getDefaultLabel() {
    return defaultLabel;
  }

  /** Adds {@code count} to the cell for ({@code correctLabel}, {@code classifiedLabel}). */
  public void incrementCount(String correctLabel, String classifiedLabel, int count) {
    putCount(correctLabel, classifiedLabel, count + getCount(correctLabel, classifiedLabel));
  }

  /** Adds one to the cell for ({@code correctLabel}, {@code classifiedLabel}). */
  public void incrementCount(String correctLabel, String classifiedLabel) {
    incrementCount(correctLabel, classifiedLabel, 1);
  }

  /**
   * Adds every count of {@code b} to this matrix, matching cells by label name.
   *
   * @return this matrix
   * @throws IllegalArgumentException if the label counts differ, or if {@code b} lacks a
   *     classified label of this matrix
   */
  public ConfusionMatrix merge(ConfusionMatrix b) {
    Preconditions.checkArgument(labelMap.size() == b.getLabels().size(), "The label sizes do not match");
    for (String correctLabel : this.labelMap.keySet()) {
      for (String classifiedLabel : this.labelMap.keySet()) {
        incrementCount(correctLabel, classifiedLabel, b.getCount(correctLabel, classifiedLabel));
      }
    }
    return this;
  }

  /**
   * Returns a new {@link DenseMatrix} copy of the counts whose row and column label bindings are
   * both set to a copy of the label-to-index map.
   */
  public Matrix getMatrix() {
    int length = confusionMatrix.length;
    Matrix m = new DenseMatrix(length, length);
    for (int r = 0; r < length; r++) {
      for (int c = 0; c < length; c++) {
        m.set(r, c, confusionMatrix[r][c]);
      }
    }
    Map<String, Integer> labels = new HashMap<>(labelMap);
    m.setRowLabelBindings(labels);
    m.setColumnLabelBindings(labels);
    return m;
  }

  /**
   * Replaces the counts with the rounded values of {@code m}. If {@code m} has row label bindings
   * (or, failing that, column label bindings) they replace the current labels; they must map
   * exactly one label to each index {@code 0..length-1}.
   *
   * <p>All validation happens before anything is modified, so a rejected matrix leaves this
   * object unchanged.</p>
   *
   * @throws IllegalArgumentException if {@code m} is not square, does not have this matrix's
   *     dimension, or has inconsistent label bindings
   */
  public void setMatrix(Matrix m) {
    int length = confusionMatrix.length;
    if (m.numRows() != m.numCols()) {
      throw new IllegalArgumentException(
          "ConfusionMatrix: matrix(" + m.numRows() + ',' + m.numCols() + ") must be square");
    }
    if (m.numRows() != length) {
      throw new IllegalArgumentException(
          "ConfusionMatrix: matrix(" + m.numRows() + ',' + m.numCols() + ") must be "
              + length + 'x' + length);
    }
    Map<String, Integer> labels = m.getRowLabelBindings();
    if (labels == null) {
      labels = m.getColumnLabelBindings();
    }
    String[] sorted = null;
    if (labels != null) {
      sorted = sortLabels(labels);
      verifyLabels(length, sorted);
    }
    for (int r = 0; r < length; r++) {
      for (int c = 0; c < length; c++) {
        confusionMatrix[r][c] = (int) Math.round(m.get(r, c));
      }
    }
    if (sorted != null) {
      labelMap.clear();
      for (int i = 0; i < length; i++) {
        labelMap.put(sorted[i], i);
      }
    }
  }

  /** Inverts a label-to-index binding into an index-ordered label array. */
  private static String[] sortLabels(Map<String, Integer> labels) {
    String[] sorted = new String[labels.size()];
    for (Map.Entry<String, Integer> entry : labels.entrySet()) {
      int index = entry.getValue();
      // Indices outside [0, size) cannot form a one-to-one label/row mapping.
      Preconditions.checkArgument(index >= 0 && index < sorted.length, "One label, one row");
      sorted[index] = entry.getKey();
    }
    return sorted;
  }

  /** Requires exactly {@code length} labels with no gaps (a gap means two labels shared an index). */
  private static void verifyLabels(int length, String[] sorted) {
    Preconditions.checkArgument(sorted.length == length, "One label, one row");
    for (String label : sorted) {
      Preconditions.checkArgument(label != null, "One label, one row");
    }
  }

  /**
   * This is overloaded. toString() is not a formatted report you print for a manager :)
   * Assume that if there are no default assignments, the default feature was not used
   */
  @Override
  public String toString() {
    StringBuilder returnString = new StringBuilder(200);
    returnString.append("=======================================================").append('\n');
    returnString.append("Confusion Matrix\n");
    returnString.append("-------------------------------------------------------").append('\n');
    // A matrix built from label bindings need not contain the default label; treat it as unused.
    int unclassified = labelMap.containsKey(defaultLabel) ? getTotal(defaultLabel) : 0;
    boolean hideDefault = unclassified == 0;
    for (Map.Entry<String, Integer> entry : this.labelMap.entrySet()) {
      if (hideDefault && entry.getKey().equals(defaultLabel)) {
        continue;
      }
      returnString.append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5)).append('\t');
    }
    returnString.append("<--Classified as").append('\n');
    for (Map.Entry<String, Integer> entry : this.labelMap.entrySet()) {
      String correctLabel = entry.getKey();
      if (hideDefault && correctLabel.equals(defaultLabel)) {
        continue;
      }
      int correctId = entry.getValue();
      int labelTotal = 0;
      for (Map.Entry<String, Integer> column : this.labelMap.entrySet()) {
        if (hideDefault && column.getKey().equals(defaultLabel)) {
          continue;
        }
        int count = confusionMatrix[correctId][column.getValue()];
        returnString.append(StringUtils.rightPad(Integer.toString(count), 5)).append('\t');
        labelTotal += count;
      }
      returnString.append(" | ").append(StringUtils.rightPad(String.valueOf(labelTotal), 6)).append('\t')
          .append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5))
          .append(" = ").append(correctLabel).append('\n');
    }
    if (unclassified > 0) {
      returnString.append("Default Category: ").append(defaultLabel).append(": ").append(unclassified).append('\n');
    }
    returnString.append('\n');
    return returnString.toString();
  }

  /**
   * Encodes a non-negative index as lowercase base-26 digits ('a' = 0), e.g. 0 -> "a",
   * 25 -> "z", 26 -> "ba"; used as a short column/row tag in {@link #toString()}.
   */
  static String getSmallLabel(int i) {
    int val = i;
    StringBuilder returnString = new StringBuilder();
    do {
      int n = val % 26;
      returnString.insert(0, (char) ('a' + n));
      val /= 26;
    } while (val > 0);
    return returnString.toString();
  }
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy