edu.stanford.nlp.util.ConfusionMatrix Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of stanford-parser Show documentation
Show all versions of stanford-parser Show documentation
Stanford Parser processes raw text in English, Chinese, German, Arabic, and French, and extracts constituency parse trees.
The newest version!
package edu.stanford.nlp.util;
import javax.swing.*;
import java.awt.*;
import java.awt.event.MouseAdapter;
import java.awt.event.MouseEvent;
import java.io.StringWriter;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.*;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
/**
* This implements a confusion table over arbitrary types of class labels. Main
* routines of interest:
* add(guess, gold), increments the guess/gold entry in this cell by 1
* get(guess, gold), returns the number of entries in this cell
* toString(), returns printed form of the table, with marginals and
* contingencies for each class label
*
* Example usage:
* Confusion myConf = new Confusion();
* myConf.add("l1", "l1");
* myConf.add("l1", "l2");
* myConf.add("l2", "l2");
* System.out.println(myConf.toString());
*
* NOTES: - This sorts by the toString() of the guess and gold labels. Thus the
* label.toString() values should be distinct!
*
* @author [email protected]
*
* @param the class label type
*/
public class ConfusionMatrix {
// classification placeholder prefix when drawing in table
private static final String CLASS_PREFIX = "C";
private static final String FORMAT = "#.#####";
protected DecimalFormat format;
private int leftPadSize = 16;
private int delimPadSize = 8;
private boolean useRealLabels = false;
public ConfusionMatrix() {
format = new DecimalFormat(FORMAT);
}
public ConfusionMatrix(Locale locale) {
format = new DecimalFormat(FORMAT, new DecimalFormatSymbols(locale));
}
@Override
public String toString() {
return printTable();
}
/**
* This sets the lefthand side pad width for displaying the text table.
* @param newPadSize
*/
public void setLeftPadSize(int newPadSize) {
this.leftPadSize = newPadSize;
}
/**
* Sets the width used to separate cells in the table.
*/
public void setDelimPadSize(int newPadSize) {
this.delimPadSize = newPadSize;
}
public void setUseRealLabels(boolean useRealLabels) {
this.useRealLabels = useRealLabels;
}
/**
* Contingency table, listing precision ,recall, specificity, and f1 given
* the number of true and false positives, true and false negatives.
*
* @author [email protected]
*
*/
public class Contingency {
private double tp = 0;
private double fp = 0;
private double tn = 0;
private double fn = 0;
private double prec = 0.0;
private double recall = 0.0;
private double spec = 0.0;
private double f1 = 0.0;
public Contingency(int tp_, int fp_, int tn_, int fn_) {
tp = tp_;
fp = fp_;
tn = tn_;
fn = fn_;
prec = tp / (tp + fp);
recall = tp / (tp + fn);
spec = tn / (fp + tn);
f1 = (2 * prec * recall) / (prec + recall);
}
public String toString() {
return StringUtils.join(Arrays.asList("prec=" + (((tp + fp) > 0) ? format.format(prec) : "n/a"),
"recall=" + (((tp + fn) > 0) ? format.format(recall) : "n/a"),
"spec=" + (((fp + tn) > 0) ? format.format(spec) : "n/a"), "f1="
+ (((prec + recall) > 0) ? format.format(f1) : "n/a")),
", ");
}
public double f1(){
return f1;
}
public double precision(){
return prec;
}
public double recall(){
return recall;
}
public double spec(){
return spec;
}
}
private ConcurrentHashMap, Integer> confTable = new ConcurrentHashMap<>();
/**
* Increments the entry for this guess and gold by 1.
*/
public void add(U guess, U gold) {
add(guess, gold, 1);
}
/**
* Increments the entry for this guess and gold by the given increment amount.
*/
public synchronized void add(U guess, U gold, int increment) {
Pair pair = new Pair<>(guess, gold);
if (confTable.containsKey(pair)) {
confTable.put(pair, confTable.get(pair) + increment);
} else {
confTable.put(pair, increment);
}
}
/**
* Retrieves the number of entries with this guess and gold.
*/
public Integer get(U guess, U gold) {
Pair pair = new Pair<>(guess, gold);
if (confTable.containsKey(pair)) {
return confTable.get(pair);
} else {
return 0;
}
}
/**
* Returns the set of distinct class labels
* entered into this confusion table.
*/
public Set uniqueLabels() {
HashSet ret = new HashSet<>();
for (Pair pair : confTable.keySet()) {
ret.add(pair.first());
ret.add(pair.second());
}
return ret;
}
/**
* Returns the contingency table for the given class label, where all other
* class labels are treated as negative.
*/
public Contingency getContingency(U positiveLabel) {
int tp = 0;
int fp = 0;
int tn = 0;
int fn = 0;
for (Pair pair : confTable.keySet()) {
int count = confTable.get(pair);
U guess = pair.first();
U gold = pair.second();
boolean guessP = guess.equals(positiveLabel);
boolean goldP = gold.equals(positiveLabel);
if (guessP && goldP) {
tp += count;
} else if (!guessP && goldP) {
fn += count;
} else if (guessP && !goldP) {
fp += count;
} else {
tn += count;
}
}
return new Contingency(tp, fp, tn, fn);
}
/**
* Returns the current set of unique labels, sorted by their string order.
*/
private List sortKeys() {
Set labels = uniqueLabels();
if (labels.size() == 0) {
return Collections.emptyList();
}
boolean comparable = true;
for (U label : labels) {
if (!(label instanceof Comparable)) {
comparable = false;
break;
}
}
if (comparable) {
List> sorted = Generics.newArrayList();
for (U label : labels) {
sorted.add(ErasureUtils.>uncheckedCast(label));
}
Collections.sort(sorted);
List ret = Generics.newArrayList();
for (Object o : sorted) {
ret.add(ErasureUtils.uncheckedCast(o));
}
return ret;
} else {
ArrayList names = new ArrayList<>();
HashMap lookup = new HashMap<>();
for (U label : labels) {
names.add(label.toString());
lookup.put(label.toString(), label);
}
Collections.sort(names);
ArrayList ret = new ArrayList<>();
for (String name : names) {
ret.add(lookup.get(name));
}
return ret;
}
}
/**
* Marginal over the given gold, or column sum
*/
private Integer goldMarginal(U gold) {
Integer sum = 0;
Set labels = uniqueLabels();
for (U guess : labels) {
sum += get(guess, gold);
}
return sum;
}
/**
* Marginal over given guess, or row sum
*/
private Integer guessMarginal(U guess) {
Integer sum = 0;
Set labels = uniqueLabels();
for (U gold : labels) {
sum += get(guess, gold);
}
return sum;
}
private String getPlaceHolder(int index, U label) {
if (useRealLabels) {
return label.toString();
} else {
return CLASS_PREFIX + (index + 1); // class name
}
}
/**
* Prints the current confusion in table form to a string, with contingency
*/
public String printTable() {
List sortedLabels = sortKeys();
if (confTable.size() == 0) {
return "Empty table!";
}
StringWriter ret = new StringWriter();
// header row (top)
ret.write(StringUtils.padLeft("Guess/Gold", leftPadSize));
for (int i = 0; i < sortedLabels.size(); i++) {
String placeHolder = getPlaceHolder(i, sortedLabels.get(i));
// placeholder
ret.write(StringUtils.padLeft(placeHolder, delimPadSize));
}
ret.write(" Marg. (Guess)");
ret.write("\n");
// Write out contents
for (int guessI = 0; guessI < sortedLabels.size(); guessI++) {
String placeHolder = getPlaceHolder(guessI, sortedLabels.get(guessI));
ret.write(StringUtils.padLeft(placeHolder, leftPadSize));
U guess = sortedLabels.get(guessI);
for (U gold : sortedLabels) {
Integer value = get(guess, gold);
ret.write(StringUtils.padLeft(value.toString(), delimPadSize));
}
ret.write(StringUtils.padLeft(guessMarginal(guess).toString(), delimPadSize));
ret.write("\n");
}
// Bottom row, write out marginals over golds
ret.write(StringUtils.padLeft("Marg. (Gold)", leftPadSize));
for (U gold : sortedLabels) {
ret.write(StringUtils.padLeft(goldMarginal(gold).toString(), delimPadSize));
}
// Print out key, along with contingencies
ret.write("\n\n");
for (int labelI = 0; labelI < sortedLabels.size(); labelI++) {
U classLabel = sortedLabels.get(labelI);
String placeHolder = getPlaceHolder(labelI, classLabel);
ret.write(StringUtils.padLeft(placeHolder, leftPadSize));
if (!useRealLabels) {
ret.write(" = ");
ret.write(classLabel.toString());
}
ret.write(StringUtils.padLeft("", delimPadSize));
Contingency contingency = getContingency(classLabel);
ret.write(contingency.toString());
ret.write("\n");
}
return ret.toString();
}
private class ConfusionGrid extends Canvas {
public class Grid extends JPanel {
private int columnCount = uniqueLabels().size() + 1;
private int rowCount = uniqueLabels().size() + 1;
private List cells;
private Point selectedCell;
public Grid() {
cells = new ArrayList<>(columnCount * rowCount);
MouseAdapter mouseHandler;
mouseHandler = new MouseAdapter() {
@Override
public void mouseMoved(MouseEvent e) {
int width = getWidth();
int height = getHeight();
int cellWidth = width / columnCount;
int cellHeight = height / rowCount;
int column = e.getX() / cellWidth;
int row = e.getY() / cellHeight;
selectedCell = new Point(column, row);
repaint();
}
};
addMouseMotionListener(mouseHandler);
}
public void onMouseOver(Graphics2D g2d, Rectangle cell, U guess, U gold) {
// Compute values
int x = (int) (cell.getLocation().x + cell.getWidth() / 5.0);
int y = (int) ( cell.getLocation().y + cell.getHeight() / 5.0);
// Compute the text
Integer value = confTable.get(Pair.makePair(guess, gold));
if (value == null) { value = 0; }
String text = "Guess: " + guess.toString() + "\n" +
"Gold: " + gold.toString() + "\n" +
"Value: " + value;
// Set the font
Font bak = g2d.getFont();
g2d.setFont(bak.deriveFont(bak.getSize() * 2.0f));
// Render
g2d.setColor(Color.WHITE);
g2d.fill(cell);
g2d.setColor(Color.BLACK);
for (String line : text.split("\n")) {
g2d.drawString(line, x, y += g2d.getFontMetrics().getHeight());
}
// Reset
g2d.setFont(bak);
}
@Override
public Dimension getPreferredSize() {
return new Dimension(800, 800);
}
@Override
public void invalidate() {
cells.clear();
super.invalidate();
}
@Override
protected void paintComponent(Graphics g) {
super.paintComponent(g);
// Dimensions
Graphics2D g2d = (Graphics2D) g.create();
g.setFont(new Font("Arial", Font.PLAIN, 10));
int width = getWidth();
int height = getHeight();
int cellWidth = width / columnCount;
int cellHeight = height / rowCount;
int xOffset = (width - (columnCount * cellWidth)) / 2;
int yOffset = (height - (rowCount * cellHeight)) / 2;
// Get label index
List labels = uniqueLabels().stream().collect(Collectors.toList());
// Get color gradient
int maxDiag = 0;
int maxOffdiag = 0;
for (Map.Entry, Integer> entry : confTable.entrySet()) {
if (entry.getKey().first == entry.getKey().second) {
maxDiag = Math.max(maxDiag, entry.getValue());
} else {
maxOffdiag = Math.max(maxOffdiag, entry.getValue());
}
}
// Render the grid
float[] hsb = new float[3];
for (int row = 0; row < rowCount; row++) {
for (int col = 0; col < columnCount; col++) {
// Position
int x = xOffset + (col * cellWidth);
int y = yOffset + (row * cellHeight);
float xCenter = xOffset + (col * cellWidth) + cellWidth / 3.0f;
float yCenter = yOffset + (row * cellHeight) + cellHeight / 2.0f;
// Get text + Color
String text;
Color bg = Color.WHITE;
if (row == 0 && col == 0) {
text = "V guess | gold >";
} else if (row == 0) {
text = labels.get(col - 1).toString();
} else if (col == 0) {
text = labels.get(row - 1).toString();
} else {
// Set value
Integer count = confTable.get(Pair.makePair(labels.get(row - 1), labels.get(col - 1)));
if (count == null) {
count = 0;
}
text = "" + count;
// Get color
if (row == col) {
double percentGood = ((double) count) / ((double) maxDiag);
hsb = Color.RGBtoHSB(
(int) (255 - (255.0 * percentGood)),
(int) (255 - (255.0 * percentGood / 2.0)),
(int) (255 - (255.0 * percentGood)),
hsb
);
bg = Color.getHSBColor(hsb[0], hsb[1], hsb[2]);
} else {
double percentBad = ((double) count) / ((double) maxOffdiag);
hsb = Color.RGBtoHSB(
(int) (255 - (255.0 * percentBad / 2.0)),
(int) (255 - (255.0 * percentBad)),
(int) (255 - (255.0 * percentBad)),
hsb
);
bg = Color.getHSBColor(hsb[0], hsb[1], hsb[2]);
}
}
// Draw
Rectangle cell = new Rectangle(x, y, cellWidth, cellHeight);
g2d.setColor(bg);
g2d.fill(cell);
g2d.setColor(Color.BLACK);
g2d.drawString(text, xCenter, yCenter);
cells.add(cell);
}
}
// Mouse over
if (selectedCell != null && selectedCell.x > 0 && selectedCell.y > 0) {
int index = selectedCell.x + (selectedCell.y * columnCount);
Rectangle cell = cells.get(index);
onMouseOver(g2d, cell, labels.get(selectedCell.y - 1), labels.get(selectedCell.x - 1));
}
// Clean up
g2d.dispose();
}
}
public ConfusionGrid() {
EventQueue.invokeLater(() -> {
try {
UIManager.setLookAndFeel(UIManager.getSystemLookAndFeelClassName());
} catch (ClassNotFoundException | InstantiationException | IllegalAccessException | UnsupportedLookAndFeelException ignored) {
}
JFrame frame = new JFrame("Confusion Matrix");
frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
frame.setLayout(new BorderLayout());
frame.add(new Grid());
frame.pack();
frame.setLocationRelativeTo(null);
frame.setVisible(true);
});
}
}
/**
* Show the confusion matrix in a GUI.
*/
public void gui() {
ConfusionGrid gui = new ConfusionGrid();
gui.setVisible(true);
}
public static void main(String[] args) {
ConfusionMatrix confusion = new ConfusionMatrix<>();
confusion.add("a", "a");
confusion.add("a", "b");
confusion.add("b", "a");
confusion.add("a", "a");
confusion.add("b", "b");
confusion.add("b", "b");
confusion.add("a", "b");
confusion.gui();
}
}