// com.bbn.bue.common.evaluation.ProvenancedConfusionMatrix (Maven / Gradle / Ivy)
// The newest version!
package com.bbn.bue.common.evaluation;
import com.bbn.bue.common.collections.CollectionUtils;
import com.bbn.bue.common.collections.MapUtils;
import com.bbn.bue.common.symbols.Symbol;
import com.bbn.bue.common.symbols.SymbolUtils;
import com.google.common.annotations.Beta;
import com.google.common.base.Function;
import com.google.common.base.Optional;
import com.google.common.base.Predicate;
import com.google.common.collect.FluentIterable;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableTable;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Ordering;
import com.google.common.collect.Table;
import com.google.common.collect.Table.Cell;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static com.google.common.base.Preconditions.checkNotNull;
/**
* A confusion matrix which tracks the actual identities of all its elements.
*
* Row and column labels may not be null. Cell fillers may not be null. Entries in each cell are
* stored in the order they were added to the builder.
*
* @param What sort of entry to keep in each cell.
* @author rgabbard
*/
@Beta
public final class ProvenancedConfusionMatrix {
private final Table> table;
/**
* The left-hand labels of the confusion matrix.
*/
public Set leftLabels() {
return table.rowKeySet();
}
/**
* The right hand labels of the confusion matrix.
*/
public Set rightLabels() {
return table.columnKeySet();
}
/**
* A list of all the entries occupying the cell {@code (left, right)} of this confusion matrix.
*/
public List cell(final Symbol left, final Symbol right) {
final List cell = table.get(left, right);
if (cell != null) {
return cell;
} else {
return ImmutableList.of();
}
}
/**
* Returns all provenance entries in this matrix, regardless of cell.
*/
public Set entries() {
return FluentIterable.from(table.cellSet())
.transformAndConcat(CollectionUtils.>TableCellValue())
.toSet();
}
/**
* Return a new {@code ProvenancedConfusionMatrix} containing only those provenance entries
* matching the provided predicate.
*/
public ProvenancedConfusionMatrix filteredCopy(Predicate predicate) {
final ImmutableTable.Builder> newTable =
ImmutableTable.builder();
for (final Cell> curCell : table.cellSet()) {
final List newFiller = FluentIterable.from(curCell.getValue())
.filter(predicate).toList();
if (!newFiller.isEmpty()) {
newTable.put(curCell.getRowKey(), curCell.getColumnKey(), newFiller);
}
}
return new ProvenancedConfusionMatrix(newTable.build());
}
/**
* Allows generating "breakdowns" of a provenanced confusion matrix according to some criteria.
* For example, a confusion matrix for an event detection task could be further broken down into
* separate confusion matrices for each event type.
*
* To do this, you specify a signature function mapping from each provenance to some signature
* (e.g. to event types, to genres, etc.). The output will be an {@link
* com.google.common.collect.ImmutableMap} from all observed signatures to {@link
* ProvenancedConfusionMatrix}es consisting of only those provenances with
* the corresponding signature under the provided function.
*
* The signature function may never return a signature of {@code null}.
*
* {@code keyOrder} is the order the keys should be in the iteration order of the resulting map.
*/
public BrokenDownProvenancedConfusionMatrix
breakdown(Function super CellFiller, SignatureType> signatureFunction,
Ordering keyOrdering) {
final Map> ret = Maps.newHashMap();
// a more efficient implementation should be used if the confusion matrix is
// large and sparse, but this is unlikely. ~ rgabbard
for (final Symbol leftLabel : leftLabels()) {
for (final Symbol rightLabel : rightLabels()) {
for (final CellFiller provenance : cell(leftLabel, rightLabel)) {
final SignatureType signature = signatureFunction.apply(provenance);
checkNotNull(signature, "Provenance function may never return null");
if (!ret.containsKey(signature)) {
ret.put(signature, ProvenancedConfusionMatrix.builder());
}
ret.get(signature).record(leftLabel, rightLabel, provenance);
}
}
}
final ImmutableMap.Builder> trueRet =
ImmutableMap.builder();
// to get consistent output, we make sure to sort by the keys
for (final Map.Entry> entry :
MapUtils.>byKeyOrdering(keyOrdering).sortedCopy(
ret.entrySet())) {
trueRet.put(entry.getKey(), entry.getValue().build());
}
return BrokenDownProvenancedConfusionMatrix.fromMap(trueRet.build());
}
public SummaryConfusionMatrix buildSummaryMatrix() {
final SummaryConfusionMatrices.Builder builder = SummaryConfusionMatrices.builder();
for (final Cell> cell : table.cellSet()) {
builder.accumulate(cell.getRowKey(), cell.getColumnKey(), cell.getValue().size());
}
return builder.build();
}
private String prettyPrint(Ordering labelOrdering, Optional extends Ordering super CellFiller>> fillerOrdering) {
final StringBuilder sb = new StringBuilder();
final List sortedColumns = labelOrdering.sortedCopy(table.columnKeySet());
for (final Symbol rowLabel : labelOrdering.sortedCopy(table.rowKeySet())) {
for (final Symbol colLabel : sortedColumns) {
if (table.contains(rowLabel, colLabel)) {
sb.append(String.format(" =============== %s / %s ==============\n", rowLabel, colLabel));
final Iterable orderedFillers;
if (fillerOrdering.isPresent()) {
orderedFillers = fillerOrdering.get().sortedCopy(table.get(rowLabel, colLabel));
} else {
orderedFillers = table.get(rowLabel, colLabel);
}
for (final CellFiller filler : orderedFillers) {
sb.append("\n\t").append(filler.toString());
}
sb.append("\n");
}
}
}
return sb.toString();
}
public String prettyPrint() {
return prettyPrint(SymbolUtils.byStringOrdering(), Optional.>absent());
}
public String prettyPrintWithFillerOrdering(Ordering super CellFiller> cellFillerOrdering) {
return prettyPrint(SymbolUtils.byStringOrdering(), Optional.of(cellFillerOrdering));
}
/**
* Generate an object which will let you create a confusion matrix.
*/
public static Builder builder() {
return new Builder();
}
private ProvenancedConfusionMatrix(final Table> table) {
final ImmutableTable.Builder> builder =
ImmutableTable.builder();
for (final Cell> cell : table.cellSet()) {
builder.put(cell.getRowKey(), cell.getColumnKey(), ImmutableList.copyOf(cell.getValue()));
}
this.table = builder.build();
}
public static Function,
SummaryConfusionMatrix> ToSummaryMatrix() {
return new Function, SummaryConfusionMatrix>() {
@Override
public SummaryConfusionMatrix apply(ProvenancedConfusionMatrix input) {
return input.buildSummaryMatrix();
}
};
}
public static class Builder {
private Builder() {
}
/**
* Add the specified {@code filler} to cell {@code (left, right)} of this confusion matrix being
* built.
*/
public void record(final Symbol left, final Symbol right, final CellFiller filler) {
if (!tableBuilder.contains(left, right)) {
tableBuilder.put(left, right, Lists.newArrayList());
}
tableBuilder.get(left, right).add(filler);
}
/**
* This is an alias for {@link #record(com.bbn.bue.common.symbols.Symbol,
* com.bbn.bue.common.symbols.Symbol, Object)} you can use to make your code clearer, since the
* predicted value is assumed to be on the rows for F-Measure calculations, etc.
*/
public void recordPredictedGold(final Symbol left, final Symbol right,
final CellFiller filler) {
record(left, right, filler);
}
public ProvenancedConfusionMatrix build() {
return new ProvenancedConfusionMatrix(tableBuilder);
}
private final Table> tableBuilder = HashBasedTable.create();
public void accumulate(ProvenancedConfusionMatrix matrix) {
for (final Table.Cell> cell : matrix.table.cellSet()) {
if (!tableBuilder.contains(cell.getRowKey(), cell.getColumnKey())) {
tableBuilder.put(cell.getRowKey(), cell.getColumnKey(), Lists.newArrayList());
}
tableBuilder.get(cell.getRowKey(), cell.getColumnKey()).addAll(cell.getValue());
}
}
}
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy