// All downloads are free. Search and download functionality uses the official Maven repository.
//
// com.bbn.bue.common.evaluation.ProvenancedConfusionMatrix (Maven / Gradle / Ivy)
//
// The newest version!
package com.bbn.bue.common.evaluation;

import com.bbn.bue.common.collections.CollectionUtils;
import com.bbn.bue.common.collections.MapUtils;
import com.bbn.bue.common.symbols.Symbol;
import com.bbn.bue.common.symbols.SymbolUtils;

import com.google.common.annotations.Beta;
import com.google.common.base.Function;
import com.google.common.base.Optional;
import com.google.common.base.Predicate;
import com.google.common.collect.FluentIterable;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableTable;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Ordering;
import com.google.common.collect.Table;
import com.google.common.collect.Table.Cell;

import java.util.List;
import java.util.Map;
import java.util.Set;

import static com.google.common.base.Preconditions.checkNotNull;

/**
 * A confusion matrix which tracks the actual identities of all its elements.
 *
 * Row and column labels may not be null.  Cell fillers may not be null. Entries in each cell are
 * stored in the order they were added to the builder.
 *
 * @param  What sort of entry to keep in each cell.
 * @author rgabbard
 */
@Beta
public final class ProvenancedConfusionMatrix {

  private final Table> table;

  /**
   * The left-hand labels of the confusion matrix.
   */
  public Set leftLabels() {
    return table.rowKeySet();
  }

  /**
   * The right hand labels of the confusion matrix.
   */
  public Set rightLabels() {
    return table.columnKeySet();
  }

  /**
   * A list of all the entries occupying the cell {@code (left, right)} of this confusion matrix.
   */
  public List cell(final Symbol left, final Symbol right) {
    final List cell = table.get(left, right);
    if (cell != null) {
      return cell;
    } else {
      return ImmutableList.of();
    }
  }

  /**
   * Returns all provenance entries in this matrix, regardless of cell.
   */
  public Set entries() {
    return FluentIterable.from(table.cellSet())
        .transformAndConcat(CollectionUtils.>TableCellValue())
        .toSet();
  }

  /**
   * Return a new {@code ProvenancedConfusionMatrix} containing only those provenance entries
   * matching the provided predicate.
   */
  public ProvenancedConfusionMatrix filteredCopy(Predicate predicate) {
    final ImmutableTable.Builder> newTable =
        ImmutableTable.builder();

    for (final Cell> curCell : table.cellSet()) {
      final List newFiller = FluentIterable.from(curCell.getValue())
          .filter(predicate).toList();
      if (!newFiller.isEmpty()) {
        newTable.put(curCell.getRowKey(), curCell.getColumnKey(), newFiller);
      }

    }

    return new ProvenancedConfusionMatrix(newTable.build());
  }

  /**
   * Allows generating "breakdowns" of a provenanced confusion matrix according to some criteria.
   * For example, a confusion matrix for an event detection task could be further broken down into
   * separate confusion matrices for each event type.
   *
   * To do this, you specify a signature function mapping from each provenance to some signature
   * (e.g. to event types, to genres, etc.).  The output will be an {@link
   * com.google.common.collect.ImmutableMap} from all observed signatures to {@link
   * ProvenancedConfusionMatrix}es consisting of only those provenances with
   * the corresponding signature under the provided function.
   *
   * The signature function may never return a signature of {@code null}.
   *
   * {@code keyOrder} is the order the keys should be in the iteration order of the resulting map.
   */
  public  BrokenDownProvenancedConfusionMatrix
  breakdown(Function signatureFunction,
      Ordering keyOrdering) {
    final Map> ret = Maps.newHashMap();

    // a more efficient implementation should be used if the confusion matrix is
    // large and sparse, but this is unlikely. ~ rgabbard
    for (final Symbol leftLabel : leftLabels()) {
      for (final Symbol rightLabel : rightLabels()) {
        for (final CellFiller provenance : cell(leftLabel, rightLabel)) {
          final SignatureType signature = signatureFunction.apply(provenance);
          checkNotNull(signature, "Provenance function may never return null");
          if (!ret.containsKey(signature)) {
            ret.put(signature, ProvenancedConfusionMatrix.builder());
          }
          ret.get(signature).record(leftLabel, rightLabel, provenance);
        }
      }
    }

    final ImmutableMap.Builder> trueRet =
        ImmutableMap.builder();
    // to get consistent output, we make sure to sort by the keys
    for (final Map.Entry> entry :
        MapUtils.>byKeyOrdering(keyOrdering).sortedCopy(
            ret.entrySet())) {
      trueRet.put(entry.getKey(), entry.getValue().build());
    }
    return BrokenDownProvenancedConfusionMatrix.fromMap(trueRet.build());
  }

  public SummaryConfusionMatrix buildSummaryMatrix() {
    final SummaryConfusionMatrices.Builder builder = SummaryConfusionMatrices.builder();

    for (final Cell> cell : table.cellSet()) {
      builder.accumulate(cell.getRowKey(), cell.getColumnKey(), cell.getValue().size());
    }

    return builder.build();
  }

  private String prettyPrint(Ordering labelOrdering, Optional> fillerOrdering) {
    final StringBuilder sb = new StringBuilder();

    final List sortedColumns = labelOrdering.sortedCopy(table.columnKeySet());
    for (final Symbol rowLabel : labelOrdering.sortedCopy(table.rowKeySet())) {
      for (final Symbol colLabel : sortedColumns) {
        if (table.contains(rowLabel, colLabel)) {
          sb.append(String.format(" =============== %s / %s ==============\n", rowLabel, colLabel));
          final Iterable orderedFillers;
          if (fillerOrdering.isPresent()) {
            orderedFillers = fillerOrdering.get().sortedCopy(table.get(rowLabel, colLabel));
          } else {
            orderedFillers = table.get(rowLabel, colLabel);
          }

          for (final CellFiller filler : orderedFillers) {
            sb.append("\n\t").append(filler.toString());
          }
          sb.append("\n");
        }
      }
    }

    return sb.toString();
  }

  public String prettyPrint() {
    return prettyPrint(SymbolUtils.byStringOrdering(), Optional.>absent());
  }

  public String prettyPrintWithFillerOrdering(Ordering cellFillerOrdering) {
    return prettyPrint(SymbolUtils.byStringOrdering(), Optional.of(cellFillerOrdering));
  }


  /**
   * Generate an object which will let you create a confusion matrix.
   */
  public static  Builder builder() {
    return new Builder();
  }

  private ProvenancedConfusionMatrix(final Table> table) {
    final ImmutableTable.Builder> builder =
        ImmutableTable.builder();

    for (final Cell> cell : table.cellSet()) {
      builder.put(cell.getRowKey(), cell.getColumnKey(), ImmutableList.copyOf(cell.getValue()));
    }
    this.table = builder.build();
  }

  public static  Function,
      SummaryConfusionMatrix> ToSummaryMatrix() {
    return new Function, SummaryConfusionMatrix>() {
      @Override
      public SummaryConfusionMatrix apply(ProvenancedConfusionMatrix input) {
        return input.buildSummaryMatrix();
      }
    };
  }

  public static class Builder {

    private Builder() {
    }

    /**
     * Add the specified {@code filler} to cell {@code (left, right)} of this confusion matrix being
     * built.
     */
    public void record(final Symbol left, final Symbol right, final CellFiller filler) {
      if (!tableBuilder.contains(left, right)) {
        tableBuilder.put(left, right, Lists.newArrayList());
      }
      tableBuilder.get(left, right).add(filler);
    }

    /**
     * This is an alias for {@link #record(com.bbn.bue.common.symbols.Symbol,
     * com.bbn.bue.common.symbols.Symbol, Object)} you can use to make your code clearer, since the
     * predicted value is assumed to be on the rows for F-Measure calculations, etc.
     */
    public void recordPredictedGold(final Symbol left, final Symbol right,
        final CellFiller filler) {
      record(left, right, filler);
    }

    public ProvenancedConfusionMatrix build() {
      return new ProvenancedConfusionMatrix(tableBuilder);
    }

    private final Table> tableBuilder = HashBasedTable.create();

    public void accumulate(ProvenancedConfusionMatrix matrix) {
      for (final Table.Cell> cell : matrix.table.cellSet()) {
        if (!tableBuilder.contains(cell.getRowKey(), cell.getColumnKey())) {
          tableBuilder.put(cell.getRowKey(), cell.getColumnKey(), Lists.newArrayList());
        }
        tableBuilder.get(cell.getRowKey(), cell.getColumnKey()).addAll(cell.getValue());
      }
    }
  }


}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy