weka.classifiers.evaluation.ConfusionMatrix Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of weka-dev Show documentation
Show all versions of weka-dev Show documentation
The Waikato Environment for Knowledge Analysis (WEKA), a machine
learning workbench. This version represents the developer version, the
"bleeding edge" of development, you could say. New functionality gets added
to this version.
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* ConfusionMatrix.java
* Copyright (C) 2002-2012 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.evaluation;
import java.util.ArrayList;
import weka.classifiers.CostMatrix;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.matrix.Matrix;
/**
 * Cells of this matrix correspond to counts of the number (or weight) of
 * predictions for each actual value / predicted value combination. Rows are
 * indexed by the actual class value and columns by the predicted class value.
 *
 * @author Len Trigg ([email protected])
 * @version $Revision: 10169 $
 */
public class ConfusionMatrix extends Matrix {

  /** for serialization */
  private static final long serialVersionUID = -181789981401504090L;

  /** Symbols used to build the short row/column labels in toString(String). */
  private static final char[] ID_CHARS = { 'a', 'b', 'c', 'd', 'e', 'f', 'g',
    'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v',
    'w', 'x', 'y', 'z' };

  /**
   * Smallest fractional part (in absolute value) that makes toString(String)
   * print cell values with two decimals.
   */
  private static final double FRACTION_THRESHOLD = 0.01;

  /** Stores the names of the classes */
  protected String[] m_ClassNames;

  /**
   * Creates a square confusion matrix with one row and one column per class.
   *
   * @param classNames an array containing the names of the classes; the array
   *          is copied, so later changes to it do not affect this matrix
   */
  public ConfusionMatrix(String[] classNames) {
    super(classNames.length, classNames.length);
    m_ClassNames = classNames.clone();
  }

  /**
   * Makes a copy of this ConfusionMatrix after applying the supplied CostMatrix
   * to the cells (each cell is multiplied by the corresponding cost). The
   * resulting ConfusionMatrix can be used to get cost-weighted statistics.
   *
   * @param costs the CostMatrix.
   * @return a new ConfusionMatrix that has had costs applied; this matrix is
   *         left unchanged.
   * @exception Exception if the CostMatrix is not of the same size as this
   *              ConfusionMatrix.
   */
  public ConfusionMatrix makeWeighted(CostMatrix costs) throws Exception {
    if (costs.size() != size()) {
      throw new Exception("Cost and confusion matrices must be the same size");
    }
    ConfusionMatrix weighted = new ConfusionMatrix(m_ClassNames);
    for (int row = 0; row < size(); row++) {
      for (int col = 0; col < size(); col++) {
        weighted.set(row, col, get(row, col) * costs.getElement(row, col));
      }
    }
    return weighted;
  }

  /**
   * Creates and returns a clone of this object, including its own copy of the
   * class-name array.
   *
   * @return a clone of this instance.
   */
  @Override
  public Object clone() {
    // NOTE(review): this cast assumes Matrix.clone() returns an instance of the
    // runtime class (a ConfusionMatrix). If Matrix.clone() builds a plain
    // Matrix instead, the cast throws ClassCastException -- confirm against
    // weka.core.matrix.Matrix.
    ConfusionMatrix m = (ConfusionMatrix) super.clone();
    m.m_ClassNames = m_ClassNames.clone();
    return m;
  }

  /**
   * Gets the number of classes.
   *
   * @return the number of classes
   */
  public int size() {
    return m_ClassNames.length;
  }

  /**
   * Gets the name of one of the classes.
   *
   * @param index the index of the class.
   * @return the class name.
   */
  public String className(int index) {
    return m_ClassNames[index];
  }

  /**
   * Includes a prediction in the confusion matrix by adding its weight to the
   * cell (actual class, predicted class).
   *
   * @param pred the NominalPrediction to include
   * @exception Exception if no valid prediction was made (i.e. unclassified),
   *              or the actual class value is missing.
   */
  public void addPrediction(NominalPrediction pred) throws Exception {
    if (pred.predicted() == NominalPrediction.MISSING_VALUE) {
      throw new Exception("No predicted value given.");
    }
    if (pred.actual() == NominalPrediction.MISSING_VALUE) {
      throw new Exception("No actual value given.");
    }
    // Class values are stored as doubles; truncate them to row/column indices.
    int actual = (int) pred.actual();
    int predicted = (int) pred.predicted();
    set(actual, predicted, get(actual, predicted) + pred.weight());
  }

  /**
   * Includes a whole bunch of predictions in the confusion matrix.
   *
   * @param predictions a list whose elements must all be NominalPredictions
   * @exception Exception if no valid prediction was made (i.e. unclassified).
   * @throws ClassCastException if an element is not a NominalPrediction
   */
  public void addPredictions(ArrayList<?> predictions) throws Exception {
    for (Object prediction : predictions) {
      addPrediction((NominalPrediction) prediction);
    }
  }

  /**
   * Gets the performance with respect to one of the classes as a TwoClassStats
   * object, treating {@code classIndex} as the positive class and all other
   * classes as negative.
   *
   * @param classIndex the index of the class of interest.
   * @return the generated TwoClassStats object.
   */
  public TwoClassStats getTwoClassStats(int classIndex) {
    double fp = 0, tp = 0, fn = 0, tn = 0;
    for (int row = 0; row < size(); row++) {
      for (int col = 0; col < size(); col++) {
        if (row == classIndex) {
          // actually positive
          if (col == classIndex) {
            tp += get(row, col);
          } else {
            fn += get(row, col);
          }
        } else {
          // actually negative
          if (col == classIndex) {
            fp += get(row, col);
          } else {
            tn += get(row, col);
          }
        }
      }
    }
    return new TwoClassStats(tp, fp, tn, fn);
  }

  /**
   * Gets the number of correct classifications (that is, for which a correct
   * prediction was made). (Actually the sum of the weights of these
   * classifications, i.e. the sum of the diagonal.)
   *
   * @return the number of correct classifications
   */
  public double correct() {
    double correct = 0;
    for (int i = 0; i < size(); i++) {
      correct += get(i, i);
    }
    return correct;
  }

  /**
   * Gets the number of incorrect classifications (that is, for which an
   * incorrect prediction was made). (Actually the sum of the weights of these
   * classifications, i.e. the sum of the off-diagonal cells.)
   *
   * @return the number of incorrect classifications
   */
  public double incorrect() {
    double incorrect = 0;
    for (int row = 0; row < size(); row++) {
      for (int col = 0; col < size(); col++) {
        if (row != col) {
          incorrect += get(row, col);
        }
      }
    }
    return incorrect;
  }

  /**
   * Gets the number of predictions that were made (actually the sum of the
   * weights of predictions where the class value was known, i.e. the sum of
   * all cells).
   *
   * @return the number of predictions with known class
   */
  public double total() {
    double total = 0;
    for (int row = 0; row < size(); row++) {
      for (int col = 0; col < size(); col++) {
        total += get(row, col);
      }
    }
    return total;
  }

  /**
   * Returns the estimated error rate.
   *
   * @return the estimated error rate (between 0 and 1); NaN when the matrix
   *         is empty (total weight of zero).
   */
  public double errorRate() {
    return incorrect() / total();
  }

  /**
   * Calls toString() with a default title.
   *
   * @return the confusion matrix as a string
   */
  @Override
  public String toString() {
    return toString("=== Confusion Matrix ===\n");
  }

  /**
   * Outputs the performance statistics as a classification confusion matrix.
   * For each class value, shows the distribution of predicted class values.
   * Values are printed with two decimals if any cell has a fractional part of
   * at least 0.01, and as integers otherwise.
   *
   * @param title the title for the confusion matrix
   * @return the confusion matrix as a String
   */
  public String toString(String title) {
    StringBuilder text = new StringBuilder();
    boolean fractional = false;

    // Find the maximum value in the matrix
    // and check for fractional display requirement
    double maxval = 0;
    for (int i = 0; i < size(); i++) {
      for (int j = 0; j < size(); j++) {
        double current = get(i, j);
        // Inspect the unscaled magnitude, and take the absolute difference:
        // rint() may round up, which makes (x - rint(x)) negative.
        double magnitude = Math.abs(current);
        if (!fractional
          && Math.abs(magnitude - Math.rint(magnitude)) >= FRACTION_THRESHOLD) {
          fractional = true;
        }
        // Negative values are scaled by 10 so that the width computed below
        // leaves an extra column for the minus sign.
        double widthValue = current < 0 ? current * -10 : current;
        if (widthValue > maxval) {
          maxval = widthValue;
        }
      }
    }

    // Characters for the integer part: at least one, so values below 1 still
    // get room for their leading "0". Math.log10 is exact for powers of ten,
    // unlike log(x)/log(10).
    int intDigits = 1 + Math.max(0, (int) Math.log10(maxval));
    // ".xx" needs three more characters when decimals are shown.
    int valueWidth = intDigits + (fractional ? 3 : 0);
    // Characters needed to label every class with ID_CHARS.
    int labelWidth = 1 + (int) (Math.log(size()) / Math.log(ID_CHARS.length));
    int idWidth = Math.max(valueWidth, labelWidth);

    text.append(title).append("\n");
    for (int i = 0; i < size(); i++) {
      if (fractional) {
        // Narrow the label by the three ".xx" characters and pad them back on
        // the right, so the header cell is as wide as a data cell (1 + idWidth).
        text.append(" ").append(num2ShortID(i, ID_CHARS, idWidth - 3))
          .append("   ");
      } else {
        text.append(" ").append(num2ShortID(i, ID_CHARS, idWidth));
      }
    }
    text.append(" actual class\n");

    for (int i = 0; i < size(); i++) {
      for (int j = 0; j < size(); j++) {
        text.append(" ").append(
          Utils.doubleToString(get(i, j), idWidth, (fractional ? 2 : 0)));
      }
      text.append(" | ").append(num2ShortID(i, ID_CHARS, idWidth)).append(" = ")
        .append(m_ClassNames[i]).append("\n");
    }
    return text.toString();
  }

  /**
   * Converts a class index into a short alphabetic label ("a" ... "z", then
   * "aa", "ab", ...), right-aligned and left-padded with spaces to the given
   * width. If the label needs more characters than {@code idWidth}, only its
   * trailing characters are kept.
   *
   * @param num the (non-negative) index to format
   * @param idChars the symbols to build the label from
   * @param idWidth the width of the returned string; non-positive widths
   *          yield the empty string
   * @return the formatted index as a string
   */
  private static String num2ShortID(int num, char[] idChars, int idWidth) {
    if (idWidth <= 0) {
      return "";
    }
    char[] id = new char[idWidth];
    int pos;
    for (pos = idWidth - 1; pos >= 0; pos--) {
      id[pos] = idChars[num % idChars.length];
      // Bijective base-N step: after "z" comes "aa", not "ba".
      num = num / idChars.length - 1;
      if (num < 0) {
        break;
      }
    }
    for (pos--; pos >= 0; pos--) {
      id[pos] = ' ';
    }
    return new String(id);
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  @Override
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 10169 $");
  }
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy