edu.stanford.nlp.util.ConfusionMatrix Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of stanford-parser Show documentation
Show all versions of stanford-parser Show documentation
Stanford Parser processes raw text in English, Chinese, German, Arabic, and French, and extracts constituency parse trees.
package edu.stanford.nlp.util;
import java.io.StringWriter;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
/**
* This implements a confusion table over arbitrary types of class labels. Main
* routines of interest:
* add(guess, gold), increments the guess/gold entry in this cell by 1
* get(guess, gold), returns the number of entries in this cell
* toString(), returns printed form of the table, with marginals and
* contingencies for each class label
*
* Example usage:
* Confusion myConf = new Confusion();
* myConf.add("l1", "l1");
* myConf.add("l1", "l2");
* myConf.add("l2", "l2");
* System.out.println(myConf.toString());
*
* NOTES: - This sorts by the toString() of the guess and gold labels. Thus the
* label.toString() values should be distinct!
*
* @author [email protected]
*
* @param the class label type
*/
public class ConfusionMatrix {
// classification placeholder prefix when drawing in table
private static final String CLASS_PREFIX = "C";
private static final String FORMAT = "#.#####";
protected DecimalFormat format;
private int leftPadSize = 16;
private int delimPadSize = 8;
private boolean useRealLabels = false;
public ConfusionMatrix() {
format = new DecimalFormat(FORMAT);
}
public ConfusionMatrix(Locale locale) {
format = new DecimalFormat(FORMAT, new DecimalFormatSymbols(locale));
}
@Override
public String toString() {
return printTable();
}
/**
* This sets the lefthand side pad width for displaying the text table.
* @param newPadSize
*/
public void setLeftPadSize(int newPadSize) {
this.leftPadSize = newPadSize;
}
/**
* Sets the width used to separate cells in the table.
*/
public void setDelimPadSize(int newPadSize) {
this.delimPadSize = newPadSize;
}
public void setUseRealLabels(boolean useRealLabels) {
this.useRealLabels = useRealLabels;
}
/**
* Contingency table, listing precision ,recall, specificity, and f1 given
* the number of true and false positives, true and false negatives.
*
* @author [email protected]
*
*/
public class Contingency {
private double tp = 0;
private double fp = 0;
private double tn = 0;
private double fn = 0;
private double prec = 0.0;
private double recall = 0.0;
private double spec = 0.0;
private double f1 = 0.0;
public Contingency(int tp_, int fp_, int tn_, int fn_) {
tp = tp_;
fp = fp_;
tn = tn_;
fn = fn_;
prec = tp / (tp + fp);
recall = tp / (tp + fn);
spec = tn / (fp + tn);
f1 = (2 * prec * recall) / (prec + recall);
}
public String toString() {
return StringUtils.join(Arrays.asList("prec=" + (((tp + fp) > 0) ? format.format(prec) : "n/a"),
"recall=" + (((tp + fn) > 0) ? format.format(recall) : "n/a"),
"spec=" + (((fp + tn) > 0) ? format.format(spec) : "n/a"), "f1="
+ (((prec + recall) > 0) ? format.format(f1) : "n/a")),
", ");
}
public double f1(){
return f1;
}
public double precision(){
return prec;
}
public double recall(){
return recall;
}
public double spec(){
return spec;
}
}
private ConcurrentHashMap, Integer> confTable = new ConcurrentHashMap, Integer>();
/**
* Increments the entry for this guess and gold by 1.
*/
public void add(U guess, U gold) {
add(guess, gold, 1);
}
/**
* Increments the entry for this guess and gold by the given increment amount.
*/
public synchronized void add(U guess, U gold, int increment) {
Pair pair = new Pair(guess, gold);
if (confTable.containsKey(pair)) {
confTable.put(pair, confTable.get(pair) + increment);
} else {
confTable.put(pair, increment);
}
}
/**
* Retrieves the number of entries with this guess and gold.
*/
public Integer get(U guess, U gold) {
Pair pair = new Pair(guess, gold);
if (confTable.containsKey(pair)) {
return confTable.get(pair);
} else {
return 0;
}
}
/**
* Returns the set of distinct class labels
* entered into this confusion table.
*/
public Set uniqueLabels() {
HashSet ret = new HashSet();
for (Pair pair : confTable.keySet()) {
ret.add(pair.first());
ret.add(pair.second());
}
return ret;
}
/**
* Returns the contingency table for the given class label, where all other
* class labels are treated as negative.
*/
public Contingency getContingency(U positiveLabel) {
int tp = 0;
int fp = 0;
int tn = 0;
int fn = 0;
for (Pair pair : confTable.keySet()) {
int count = confTable.get(pair);
U guess = pair.first();
U gold = pair.second();
boolean guessP = guess.equals(positiveLabel);
boolean goldP = gold.equals(positiveLabel);
if (guessP && goldP) {
tp += count;
} else if (!guessP && goldP) {
fn += count;
} else if (guessP && !goldP) {
fp += count;
} else {
tn += count;
}
}
return new Contingency(tp, fp, tn, fn);
}
/**
* Returns the current set of unique labels, sorted by their string order.
*/
private List sortKeys() {
Set labels = uniqueLabels();
if (labels.size() == 0) {
return Collections.emptyList();
}
boolean comparable = true;
for (U label : labels) {
if (!(label instanceof Comparable)) {
comparable = false;
break;
}
}
if (comparable) {
List> sorted = Generics.newArrayList();
for (U label : labels) {
sorted.add(ErasureUtils.>uncheckedCast(label));
}
Collections.sort(sorted);
List ret = Generics.newArrayList();
for (Object o : sorted) {
ret.add(ErasureUtils.uncheckedCast(o));
}
return ret;
} else {
ArrayList names = new ArrayList();
HashMap lookup = new HashMap();
for (U label : labels) {
names.add(label.toString());
lookup.put(label.toString(), label);
}
Collections.sort(names);
ArrayList ret = new ArrayList();
for (String name : names) {
ret.add(lookup.get(name));
}
return ret;
}
}
/**
* Marginal over the given gold, or column sum
*/
private Integer goldMarginal(U gold) {
Integer sum = 0;
Set labels = uniqueLabels();
for (U guess : labels) {
sum += get(guess, gold);
}
return sum;
}
/**
* Marginal over given guess, or row sum
*/
private Integer guessMarginal(U guess) {
Integer sum = 0;
Set labels = uniqueLabels();
for (U gold : labels) {
sum += get(guess, gold);
}
return sum;
}
public String getPlaceHolder(int index, U label) {
if (useRealLabels) {
return label.toString();
} else {
return CLASS_PREFIX + (index + 1); // class name
}
}
/**
* Prints the current confusion in table form to a string, with contingency
*/
public String printTable() {
List sortedLabels = sortKeys();
if (confTable.size() == 0) {
return "Empty table!";
}
StringWriter ret = new StringWriter();
// header row (top)
ret.write(StringUtils.padLeft("Guess/Gold", leftPadSize));
for (int i = 0; i < sortedLabels.size(); i++) {
String placeHolder = getPlaceHolder(i, sortedLabels.get(i));
// placeholder
ret.write(StringUtils.padLeft(placeHolder, delimPadSize));
}
ret.write(" Marg. (Guess)");
ret.write("\n");
// Write out contents
for (int guessI = 0; guessI < sortedLabels.size(); guessI++) {
String placeHolder = getPlaceHolder(guessI, sortedLabels.get(guessI));
ret.write(StringUtils.padLeft(placeHolder, leftPadSize));
U guess = sortedLabels.get(guessI);
for (int goldI = 0; goldI < sortedLabels.size(); goldI++) {
U gold = sortedLabels.get(goldI);
Integer value = get(guess, gold);
ret.write(StringUtils.padLeft(value.toString(), delimPadSize));
}
ret.write(StringUtils.padLeft(guessMarginal(guess).toString(), delimPadSize));
ret.write("\n");
}
// Bottom row, write out marginals over golds
ret.write(StringUtils.padLeft("Marg. (Gold)", leftPadSize));
for (int goldI = 0; goldI < sortedLabels.size(); goldI++) {
U gold = sortedLabels.get(goldI);
ret.write(StringUtils.padLeft(goldMarginal(gold).toString(), delimPadSize));
}
// Print out key, along with contingencies
ret.write("\n\n");
for (int labelI = 0; labelI < sortedLabels.size(); labelI++) {
U classLabel = sortedLabels.get(labelI);
String placeHolder = getPlaceHolder(labelI, classLabel);
ret.write(StringUtils.padLeft(placeHolder, leftPadSize));
if (!useRealLabels) {
ret.write(" = ");
ret.write(classLabel.toString());
}
ret.write(StringUtils.padLeft("", delimPadSize));
Contingency contingency = getContingency(classLabel);
ret.write(contingency.toString());
ret.write("\n");
}
return ret.toString();
}
}