All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.isi.nlp.evaluation.SummaryConfusionMatrices Maven / Gradle / Ivy

The newest version!
package edu.isi.nlp.evaluation;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.collect.Iterables.all;

import com.google.common.base.Function;
import com.google.common.base.Joiner;
import com.google.common.base.Optional;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableTable;
import com.google.common.collect.Iterables;
import com.google.common.collect.Ordering;
import com.google.common.collect.Sets;
import com.google.common.collect.Table;
import edu.isi.nlp.StringUtils;
import edu.isi.nlp.primitives.DoubleUtils;
import edu.isi.nlp.symbols.Symbol;
import edu.isi.nlp.symbols.SymbolUtils;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

/**
 * Utilities for working with {@link SummaryConfusionMatrix}es. In particular, to build a {@link
 * SummaryConfusionMatrix}, use {@link #builder()}.
 *
 * 

Other useful things: computing F-measures ({@link #FMeasureVsAllOthers(SummaryConfusionMatrix, * Symbol)}) and pretty-printing ({@link #prettyPrint(SummaryConfusionMatrix)}. * * @author rgabbard */ public final class SummaryConfusionMatrices { private SummaryConfusionMatrices() { throw new UnsupportedOperationException(); } public static String prettyPrint(SummaryConfusionMatrix m, Ordering labelOrdering) { final StringBuilder sb = new StringBuilder(); for (final Symbol key1 : labelOrdering.sortedCopy(m.leftLabels())) { for (final Symbol key2 : labelOrdering.sortedCopy(m.rightLabels())) { sb.append(String.format("%s / %s: %6.2f\n", key1, key2, m.cell(key1, key2))); } } return sb.toString(); } public static String prettyDelimPrint(final SummaryConfusionMatrix m, final String delimiter) { return prettyDelimPrint(m, delimiter, SymbolUtils.byStringOrdering()); } public static String prettyDelimPrint( final SummaryConfusionMatrix m, final String delimiter, final Ordering labelOrdering) { final Joiner delimJoiner = Joiner.on(delimiter); final ImmutableList.Builder lines = ImmutableList.builder(); final List rowLabels = labelOrdering.sortedCopy(m.leftLabels()); final List columnLabels = labelOrdering.sortedCopy(m.rightLabels()); // Create header final ImmutableList.Builder header = ImmutableList.builder(); header.add("Predicted"); header.addAll(Iterables.transform(columnLabels, SymbolUtils.desymbolizeFunction())); lines.add(delimJoiner.join(header.build())); // Output each line for (final Symbol rowLabel : rowLabels) { final ImmutableList.Builder row = ImmutableList.builder(); row.add(rowLabel.asString()); for (final Symbol columnLabel : columnLabels) { row.add(String.format("%.2f", m.cell(rowLabel, columnLabel))); } lines.add(delimJoiner.join(row.build())); } // Return all lines return StringUtils.unixNewlineJoiner().join(lines.build()); } public static String prettyPrint(SummaryConfusionMatrix m) { return prettyPrint(m, SymbolUtils.byStringOrdering()); } public static final FMeasureCounts FMeasureVsAllOthers( SummaryConfusionMatrix m, final Symbol positiveSymbol) { return FMeasureVsAllOthers(m, ImmutableSet.of(positiveSymbol)); } public static final FMeasureCounts FMeasureVsAllOthers( SummaryConfusionMatrix m, final Set positiveSymbols) { double truePositives = 0; for (final Symbol goodSymbol : positiveSymbols) { for (final Symbol goodSymbol2 : positiveSymbols) { truePositives += m.cell(goodSymbol, goodSymbol2); } } double falsePositives = -truePositives; double falseNegatives = -truePositives; for (final Symbol goodSymbol : positiveSymbols) { falsePositives += m.rowSum(goodSymbol); falseNegatives += m.columnSum(goodSymbol); } return FMeasureCounts.fromTPFPFN(truePositives, falsePositives, falseNegatives); } /** * Returns accuracy, which is defined as the sum of the cells of the form (X,X) over the sum of * all cells. If the sum is 0, 0 is returned. To pretty-print this you probably want to multiply * by 100. */ public static final double accuracy(SummaryConfusionMatrix m) { final double total = m.sumOfallCells(); double matching = 0.0; for (final Symbol key : Sets.intersection(m.leftLabels(), m.rightLabels())) { matching += m.cell(key, key); } return DoubleUtils.XOverYOrZero(matching, total); } /** * Returns the maximum accuracy that would be achieved if a single classification were selected * for all instances. */ public static final double chooseMostCommonRightHandClassAccuracy(SummaryConfusionMatrix m) { final double total = m.sumOfallCells(); double max = 0.0; for (final Symbol right : m.rightLabels()) { max = Math.max(max, m.columnSum(right)); } return DoubleUtils.XOverYOrZero(max, total); } public static final double chooseMostCommonLeftHandClassAccuracy(SummaryConfusionMatrix m) { final double total = m.sumOfallCells(); double max = 0.0; for (final Symbol left : m.leftLabels()) { max = Math.max(max, m.rowSum(left)); } return DoubleUtils.XOverYOrZero(max, total); } public static Builder builder() { return new Builder(); } /** * To build a {@link SummaryConfusionMatrix}, call {@link SummaryConfusionMatrices#builder()}. On * the returned object, call {@link #accumulatePredictedGold(Symbol, Symbol, double)} to record * the number of times a system response corresponds to a gold standard responses for some item. * Typically the double value will be 1.0 unless you are using fractional counts for some reason. * *

When done, call {@link #build()} to get a {@link SummaryConfusionMatrix}. */ public static class Builder { private final Table table = HashBasedTable.create(); public Builder accumulate(final SummaryConfusionMatrix matrix) { matrix.accumulateTo(this); return this; } public Builder accumulate(final Symbol row, final Symbol col, final double val) { final Double cur = table.get(row, col); final double setVal; if (cur != null) { setVal = cur + val; } else { setVal = val; } table.put(row, col, setVal); return this; } /** * This is just an alias for accumulate. However, since the F-measure functions assume the * predictions are on the rows and the gold-standard on the columns, using this method in such * cases and make the code clearer and reduce errors. */ public Builder accumulatePredictedGold( final Symbol prediction, final Symbol gold, final double val) { accumulate(prediction, gold, val); return this; } public SummaryConfusionMatrix build() { // first attemtp the more efficient implementation for the common binary case final Optional binaryImp = BinarySummaryConfusionMatrix.attemptCreate(table); if (binaryImp.isPresent()) { return binaryImp.get(); } else { return new TableBasedSummaryConfusionMatrix(table); } } public static final Function Build = new Function() { @Override public SummaryConfusionMatrix apply(Builder input) { return input.build(); } }; private Builder() {} } } // here be implementation details users don't need to be concerned with class TableBasedSummaryConfusionMatrix implements SummaryConfusionMatrix { private final Table table; @Override public double cell(final Symbol row, final Symbol col) { final Double ret = table.get(row, col); if (ret != null) { return ret; } else { return 0.0; } } /** The left-hand labels of the confusion matrix. */ @Override public Set leftLabels() { return table.rowKeySet(); } /** The right hand labels of the confusion matrix. */ @Override public Set rightLabels() { return table.columnKeySet(); } TableBasedSummaryConfusionMatrix(final Table table) { this.table = ImmutableTable.copyOf(table); checkArgument(all(table.values(), x -> x >= 0)); } @Override public double sumOfallCells() { return DoubleUtils.sum(table.values()); } @Override public double rowSum(Symbol rowSymbol) { return DoubleUtils.sum(table.row(rowSymbol).values()); } @Override public double columnSum(Symbol columnSymbol) { return DoubleUtils.sum(table.column(columnSymbol).values()); } @Override public SummaryConfusionMatrix filteredCopy(CellFilter filter) { final SummaryConfusionMatrices.Builder ret = SummaryConfusionMatrices.builder(); for (final Table.Cell cell : table.cellSet()) { if (filter.keepCell(cell.getRowKey(), cell.getColumnKey())) { ret.accumulate(cell.getRowKey(), cell.getColumnKey(), cell.getValue()); } } return ret.build(); } @Override public SummaryConfusionMatrix copyWithTransformedLabels(Function f) { final SummaryConfusionMatrices.Builder ret = SummaryConfusionMatrices.builder(); for (final Table.Cell cell : table.cellSet()) { ret.accumulate(f.apply(cell.getRowKey()), f.apply(cell.getColumnKey()), cell.getValue()); } return ret.build(); } @Override public void accumulateTo(SummaryConfusionMatrices.Builder builder) { for (final Table.Cell cell : table.cellSet()) { builder.accumulate(cell.getRowKey(), cell.getColumnKey(), cell.getValue()); } } } /** * The special case where there are only two labels is very common, so we provide a much more * efficient implementation for it. This makes a noticeable difference when e.g. doing bootstrap * sampling with many different score breakdowns. */ class BinarySummaryConfusionMatrix implements SummaryConfusionMatrix { private final Symbol key0; private final Symbol key1; private final double[] data; private static final int NOT_PRESENT = -1; BinarySummaryConfusionMatrix(Symbol key0, Symbol key1, double[] data) { checkArgument(key0 != key1); checkArgument(data.length == 4); this.key0 = checkNotNull(key0); this.key1 = checkNotNull(key1); // no defensive copy because we control where this comes from this.data = checkNotNull(data); } public static boolean canUseFor(Table table) { return table.rowKeySet().size() == 2 && table.rowKeySet().equals(table.columnKeySet()); } public static Optional attemptCreate( Table table) { if (canUseFor(table)) { final Iterator keyIt = table.rowKeySet().iterator(); final Symbol key0 = keyIt.next(); final Symbol key1 = keyIt.next(); return Optional.of( new BinarySummaryConfusionMatrix( key0, key1, new double[] { cell(table, key0, key0), cell(table, key0, key1), cell(table, key1, key0), cell(table, key1, key1) })); } else { return Optional.absent(); } } private static double cell(Table table, Symbol row, Symbol col) { final Double val = table.get(row, col); if (val != null) { return val; } else { return 0.0; } } @Override public double cell(Symbol row, Symbol col) { int rowIdx = keyIndex(row); int colIdx = keyIndex(col); if (rowIdx == NOT_PRESENT || colIdx == NOT_PRESENT) { return 0.0; } return data[2 * rowIdx + colIdx]; } @Override public void accumulateTo(SummaryConfusionMatrices.Builder builder) { builder.accumulate(key0, key0, data[0]); builder.accumulate(key0, key1, data[1]); builder.accumulate(key1, key0, data[2]); builder.accumulate(key1, key1, data[3]); } private int keyIndex(Symbol sym) { if (sym == key0) { return 0; } else if (sym == key1) { return 1; } else { return NOT_PRESENT; } } @Override public Set leftLabels() { return ImmutableSet.of(key0, key1); } @Override public Set rightLabels() { return ImmutableSet.of(key0, key1); } @Override public double sumOfallCells() { return DoubleUtils.sum(data); } @Override public double rowSum(Symbol row) { int rowIdx = keyIndex(row); if (NOT_PRESENT == rowIdx) { return 0.0; } return data[2 * rowIdx] + data[2 * rowIdx + 1]; } @Override public double columnSum(Symbol column) { int colIdx = keyIndex(column); if (NOT_PRESENT == colIdx) { return 0.0; } return data[colIdx] + data[colIdx + 2]; } @Override public SummaryConfusionMatrix filteredCopy(CellFilter filter) { final SummaryConfusionMatrices.Builder builder = SummaryConfusionMatrices.builder(); for (final Symbol left : leftLabels()) { for (final Symbol right : rightLabels()) { if (filter.keepCell(left, right)) { builder.accumulate(left, right, cell(left, right)); } } } return builder.build(); } @Override public SummaryConfusionMatrix copyWithTransformedLabels(Function f) { final SummaryConfusionMatrices.Builder builder = SummaryConfusionMatrices.builder(); for (final Symbol left : leftLabels()) { for (final Symbol right : rightLabels()) { builder.accumulate(f.apply(left), f.apply(right), cell(left, right)); } } return builder.build(); } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy