All downloads are free. Search and download functionality uses the official Maven repository.

edu.stanford.nlp.stats.TwoDimensionalCounter Maven / Gradle / Ivy

Go to download

Stanford Parser processes raw text in English, Chinese, German, Arabic, and French, and extracts constituency parse trees.

The newest version!
package edu.stanford.nlp.stats;

import java.io.Serializable;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;

import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.math.SloppyMath;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.MapFactory;
import edu.stanford.nlp.util.MutableDouble;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.StringUtils;

/**
 * A class representing a mapping between pairs of typed objects and double
 * values.
 *
 * @author Teg Grenager
 */
public class TwoDimensionalCounter implements TwoDimensionalCounterInterface, Serializable {

  private static final long serialVersionUID = 1L;

  // the outermost Map
  private Map> map;

  // the total of all counts
  private double total;

  // the MapFactory used to make new maps to counters
  private MapFactory> outerMF;

  // the MapFactory used to make new maps in the inner counter
  private MapFactory innerMF;

  private double defaultValue = 0.0;

  @Override
  public void defaultReturnValue(double rv) {
    defaultValue = rv;
  }

  @Override
  public double defaultReturnValue() {
    return defaultValue;
  }

  @Override
  public boolean equals(Object o) {
    if (o == this)
      return true;
    if (!(o instanceof TwoDimensionalCounter))
      return false;

    return ((TwoDimensionalCounter) o).map.equals(map);
  }

  @Override
  public int hashCode() {
    return map.hashCode() + 17;
  }

  /**
   * @return the inner Counter associated with key o
   */
  @Override
  public ClassicCounter getCounter(K1 o) {
    ClassicCounter c = map.get(o);
    if (c == null) {
      c = new ClassicCounter<>(innerMF);
      c.setDefaultReturnValue(defaultValue);
      map.put(o, c);
    }
    return c;
  }

  public Set>> entrySet() {
    return map.entrySet();
  }

  /**
   * @return total number of entries (key pairs)
   */
  @Override
  public int size() {
    int result = 0;
    for (K1 o : firstKeySet()) {
      ClassicCounter c = map.get(o);
      result += c.size();
    }
    return result;
  }

  /**
   * @return size of the outer map
   */
  public int sizeOuterMap(){
    return map.size();
  }

  @Override
  public boolean containsKey(K1 o1, K2 o2) {
    if (!map.containsKey(o1))
      return false;
    ClassicCounter c = map.get(o1);
    return c.containsKey(o2);
  }

  public boolean containsFirstKey(K1 o1) {
    return map.containsKey(o1);
  }

  /**
   */
  @Override
  public void incrementCount(K1 o1, K2 o2) {
    incrementCount(o1, o2, 1.0);
  }

  /**
   */
  @Override
  public void incrementCount(K1 o1, K2 o2, double count) {
    ClassicCounter c = getCounter(o1);
    c.incrementCount(o2, count);
    total += count;
  }

  /**
   */
  @Override
  public void decrementCount(K1 o1, K2 o2) {
    incrementCount(o1, o2, -1.0);
  }

  /**
   */
  @Override
  public void decrementCount(K1 o1, K2 o2, double count) {
    incrementCount(o1, o2, -count);
  }

  /**
   */
  @Override
  public void setCount(K1 o1, K2 o2, double count) {
    ClassicCounter c = getCounter(o1);
    double oldCount = getCount(o1, o2);
    total -= oldCount;
    c.setCount(o2, count);
    total += count;
  }

  @Override
  public double remove(K1 o1, K2 o2) {
    ClassicCounter c = getCounter(o1);
    double oldCount = getCount(o1, o2);
    total -= oldCount;
    c.remove(o2);
    if (c.size() == 0) {
      map.remove(o1);
    }
    return oldCount;
  }

  /**
   */
  @Override
  public double getCount(K1 o1, K2 o2) {
    ClassicCounter c = getCounter(o1);
    if (c.totalCount() == 0.0 && !c.keySet().contains(o2)) {
      return defaultReturnValue();
    }
    return c.getCount(o2);
  }

  /**
   * Takes linear time.
   *
   */
  @Override
  public double totalCount() {
    return total;
  }

  /**
   */
  @Override
  public double totalCount(K1 k1) {
    ClassicCounter c = getCounter(k1);
    return c.totalCount();
  }

  @Override
  public Set firstKeySet() {
    return map.keySet();
  }

  /**
   * replace the counter for K1-index o by new counter c
   */
  public ClassicCounter setCounter(K1 o, Counter c) {
    ClassicCounter old = getCounter(o);
    total -= old.totalCount();
    if (c instanceof ClassicCounter) {
      map.put(o, (ClassicCounter) c);
    } else {
      map.put(o, new ClassicCounter<>(c));
    }
    total += c.totalCount();
    return old;
  }

  /**
   * Produces a new ConditionalCounter.
   *
   * @return a new ConditionalCounter, where order of indices is reversed
   */
  @SuppressWarnings( { "unchecked" })
  public static  TwoDimensionalCounter reverseIndexOrder(TwoDimensionalCounter cc) {
    // they typing on the outerMF is violated a bit, but it'll work....
    TwoDimensionalCounter result = new TwoDimensionalCounter<>((MapFactory) cc.outerMF,
            (MapFactory) cc.innerMF);

    for (K1 key1 : cc.firstKeySet()) {
      ClassicCounter c = cc.getCounter(key1);
      for (K2 key2 : c.keySet()) {
        double count = c.getCount(key2);
        result.setCount(key2, key1, count);
      }
    }
    return result;
  }

  /**
   * A simple String representation of this TwoDimensionalCounter, which has the
   * String representation of each key pair on a separate line, followed by the
   * count for that pair. The items are tab separated, so the result is a
   * tab-separated value (TSV) file. Iff none of the keys contain spaces, it
   * will also be possible to treat this as whitespace separated fields.
   */
  @Override
  public String toString() {
    StringBuilder buff = new StringBuilder();
    for (K1 key1 : map.keySet()) {
      ClassicCounter c = getCounter(key1);
      for (K2 key2 : c.keySet()) {
        double score = c.getCount(key2);
        buff.append(key1).append('\t').append(key2).append('\t').append(score).append('\n');
      }
    }
    return buff.toString();
  }

  @Override
  @SuppressWarnings( { "unchecked" })
  public String toMatrixString(int cellSize) {
    return toMatrixString(cellSize, new DecimalFormat());
  }

  @SuppressWarnings( { "unchecked" })
  public String toMatrixString(int cellSize, NumberFormat nf) {
    List firstKeys = new ArrayList<>(firstKeySet());
    List secondKeys = new ArrayList<>(secondKeySet());
    Collections.sort((List) firstKeys);
    Collections.sort((List) secondKeys);
    double[][] counts = toMatrix(firstKeys, secondKeys);
    return ArrayMath.toString(counts, cellSize, firstKeys.toArray(), secondKeys.toArray(), nf, true);
  }

  /**
   * Given an ordering of the first (row) and second (column) keys, will produce
   * a double matrix.
   *
   */
  @Override
  public double[][] toMatrix(List firstKeys, List secondKeys) {
    double[][] counts = new double[firstKeys.size()][secondKeys.size()];
    for (int i = 0; i < firstKeys.size(); i++) {
      for (int j = 0; j < secondKeys.size(); j++) {
        counts[i][j] = getCount(firstKeys.get(i), secondKeys.get(j));
      }
    }
    return counts;
  }

  @Override
  @SuppressWarnings( { "unchecked" })
  public String toCSVString(NumberFormat nf) {
    List firstKeys = new ArrayList<>(firstKeySet());
    List secondKeys = new ArrayList<>(secondKeySet());
    Collections.sort((List) firstKeys);
    Collections.sort((List) secondKeys);
    StringBuilder b = new StringBuilder();
    String[] headerRow = new String[secondKeys.size() + 1];
    headerRow[0] = "";
    for (int j = 0; j < secondKeys.size(); j++) {
      headerRow[j + 1] = secondKeys.get(j).toString();
    }
    b.append(StringUtils.toCSVString(headerRow)).append('\n');
    for (K1 rowLabel : firstKeys) {
      String[] row = new String[secondKeys.size() + 1];
      row[0] = rowLabel.toString();
      for (int j = 0; j < secondKeys.size(); j++) {
        K2 colLabel = secondKeys.get(j);
        row[j + 1] = nf.format(getCount(rowLabel, colLabel));
      }
      b.append(StringUtils.toCSVString(row)).append('\n');
    }
    return b.toString();
  }

  @Override
  public Set secondKeySet() {
    Set result = Generics.newHashSet();
    for (K1 k1 : firstKeySet()) {
      for (K2 k2 : getCounter(k1).keySet()) {
        result.add(k2);
      }
    }
    return result;
  }

  @Override
  public boolean isEmpty() {
    return map.isEmpty();
  }

  public ClassicCounter> flatten() {
    ClassicCounter> result = new ClassicCounter<>();
    result.setDefaultReturnValue(defaultValue);
    for (K1 key1 : firstKeySet()) {
      ClassicCounter inner = getCounter(key1);
      for (K2 key2 : inner.keySet()) {
        result.setCount(new Pair<>(key1, key2), inner.getCount(key2));
      }
    }
    return result;
  }

  public void addAll(TwoDimensionalCounterInterface c) {
    for (K1 key : c.firstKeySet()) {
      Counter inner = c.getCounter(key);
      ClassicCounter myInner = getCounter(key);
      Counters.addInPlace(myInner, inner);
      total += inner.totalCount();
    }
  }

  public void addAll(K1 key, Counter c) {
    ClassicCounter myInner = getCounter(key);
    Counters.addInPlace(myInner, c);
    total += c.totalCount();
  }

  public void subtractAll(K1 key, Counter c) {
    ClassicCounter myInner = getCounter(key);
    Counters.subtractInPlace(myInner, c);
    total -= c.totalCount();
  }

  public void subtractAll(TwoDimensionalCounterInterface c, boolean removeKeys) {
    for (K1 key : c.firstKeySet()) {
      Counter inner = c.getCounter(key);
      ClassicCounter myInner = getCounter(key);
      Counters.subtractInPlace(myInner, inner);
      if (removeKeys)
        Counters.retainNonZeros(myInner);
      total -= inner.totalCount();
    }
  }

  /**
   * Returns the counters with keys as the first key and count as the
   * total count of the inner counter for that key
   *
   * @return counter of type K1
   */
  public Counter sumInnerCounter() {
    Counter summed = new ClassicCounter<>();
    for (K1 key : this.firstKeySet()) {
      summed.incrementCount(key, this.getCounter(key).totalCount());
    }
    return summed;
  }

  public void removeZeroCounts() {
    Set firstKeySet = Generics.newHashSet(firstKeySet());
    for (K1 k1 : firstKeySet) {
      ClassicCounter c = getCounter(k1);
      Counters.retainNonZeros(c);
      if (c.size() == 0)
        map.remove(k1); // it's empty, get rid of it!
    }
  }

  @Override
  public void remove(K1 key) {
    ClassicCounter counter = map.get(key);
    if (counter != null) {
      total -= counter.totalCount();
    }
    map.remove(key);
  }

  /**
   * clears the map, total and default value
   */
  public void clear(){
    map.clear();
    total = 0;
    defaultValue = 0;
  }


  public void clean() {
    for (K1 key1 : Generics.newHashSet(map.keySet())) {
      ClassicCounter c = map.get(key1);
      for (K2 key2 : Generics.newHashSet(c.keySet())) {
        if (SloppyMath.isCloseTo(0.0, c.getCount(key2))) {
          c.remove(key2);
        }
      }
      if (c.keySet().isEmpty()) {
        map.remove(key1);
      }
    }
  }

  public MapFactory> getOuterMapFactory() {
    return outerMF;
  }

  public MapFactory getInnerMapFactory() {
    return innerMF;
  }

  public TwoDimensionalCounter() {
    this(MapFactory.> hashMapFactory(), MapFactory. hashMapFactory());
  }

  public TwoDimensionalCounter(MapFactory> outerFactory,
      MapFactory innerFactory) {
    innerMF = innerFactory;
    outerMF = outerFactory;
    map = outerFactory.newMap();
    total = 0.0;
  }

  public static  TwoDimensionalCounter identityHashMapCounter() {
    return new TwoDimensionalCounter<>(MapFactory.>identityHashMapFactory(), MapFactory.identityHashMapFactory());
  }

  public void recomputeTotal(){
    total = 0;
    for(Entry> c: map.entrySet()){
      total += c.getValue().totalCount();
    }
  }

  public static void main(String[] args) {
    TwoDimensionalCounter cc = new TwoDimensionalCounter<>();
    cc.setCount("a", "c", 1.0);
    cc.setCount("b", "c", 1.0);
    cc.setCount("a", "d", 1.0);
    cc.setCount("a", "d", -1.0);
    cc.setCount("b", "d", 1.0);
    System.out.println(cc);
    cc.incrementCount("b", "d", 1.0);
    System.out.println(cc);
    TwoDimensionalCounter cc2 = TwoDimensionalCounter.reverseIndexOrder(cc);
    System.out.println(cc2);
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy