/*
 * edu.stanford.nlp.stats.TwoDimensionalIntCounter (Maven / Gradle / Ivy).
 * Page navigation: Go to download; Show more of this group; Show more artifacts with this name;
 * Show all versions of stanford-parser; Show documentation.
 * Stanford Parser processes raw text in English, Chinese, German, Arabic, and French,
 * and extracts constituency parse trees.
 * This listing is from the newest version of the artifact.
 */
package edu.stanford.nlp.stats;
import java.io.Serializable;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.*;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.MapFactory;
import edu.stanford.nlp.util.MutableInteger;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.StringUtils;
/**
* A class representing a mapping between pairs of typed objects and int values.
* (Copied from TwoDimensionalCounter)
*
* @author Teg Grenager
* @author Angel Chang
*/
public class TwoDimensionalIntCounter implements Serializable {
private static final long serialVersionUID = 1L;
// the outermost Map
private Map> map;
// the total of all counts
private int total;
// the MapFactory used to make new maps to counters
private MapFactory> outerMF;
// the MapFactory used to make new maps in the inner counter
private MapFactory innerMF;
private int defaultValue = 0;
public void defaultReturnValue(double rv) { defaultValue = (int) rv; }
public void defaultReturnValue(int rv) { defaultValue = rv; }
public int defaultReturnValue() { return defaultValue; }
@Override
public boolean equals(Object o) {
if (o == this) return true;
if (!(o instanceof TwoDimensionalIntCounter)) return false;
return ((TwoDimensionalIntCounter,?>) o).map.equals(map);
}
@Override
public int hashCode() {
return map.hashCode() + 17;
}
/**
* @return the inner Counter associated with key o
*/
public IntCounter getCounter(K1 o) {
IntCounter c = map.get(o);
if (c == null) {
c = new IntCounter<>(innerMF);
c.setDefaultReturnValue(defaultValue);
map.put(o, c);
}
return c;
}
public Set>> entrySet(){
return map.entrySet();
}
/**
* @return total number of entries (key pairs)
*/
public int size() {
int result = 0;
for (K1 o : firstKeySet()) {
IntCounter c = map.get(o);
result += c.size();
}
return result;
}
public boolean containsKey(K1 o1, K2 o2) {
if (!map.containsKey(o1)) return false;
IntCounter c = map.get(o1);
return c.containsKey(o2);
}
/**
*/
public void incrementCount(K1 o1, K2 o2) {
incrementCount(o1, o2, 1);
}
/**
*/
public void incrementCount(K1 o1, K2 o2, double count) {
incrementCount(o1, o2, (int) count);
}
/**
*/
public void incrementCount(K1 o1, K2 o2, int count) {
IntCounter c = getCounter(o1);
c.incrementCount(o2, count);
total += count;
}
/**
*/
public void decrementCount(K1 o1, K2 o2) {
incrementCount(o1, o2, -1);
}
/**
*/
public void decrementCount(K1 o1, K2 o2, double count) {
incrementCount(o1, o2, -count);
}
/**
*/
public void decrementCount(K1 o1, K2 o2, int count) {
incrementCount(o1, o2, -count);
}
/**
*/
public void setCount(K1 o1, K2 o2, double count) {
setCount(o1, o2, (int) count);
}
/**
*/
public void setCount(K1 o1, K2 o2, int count) {
IntCounter c = getCounter(o1);
int oldCount = getCount(o1, o2);
total -= oldCount;
c.setCount(o2, count);
total += count;
}
public int remove(K1 o1, K2 o2) {
IntCounter c = getCounter(o1);
int oldCount = getCount(o1, o2);
total -= oldCount;
c.remove(o2);
if (c.isEmpty()) {
map.remove(o1);
}
return oldCount;
}
/**
*/
public int getCount(K1 o1, K2 o2) {
IntCounter c = getCounter(o1);
if (c.totalCount() == 0 && !c.keySet().contains(o2)) { return defaultReturnValue(); }
return c.getIntCount(o2);
}
/**
* Takes linear time.
*
*/
public int totalCount() {
return total;
}
/**
*/
public int totalCount(K1 k1) {
IntCounter c = getCounter(k1);
return c.totalIntCount();
}
public IntCounter totalCounts() {
IntCounter tc = new IntCounter<>();
for (K1 k1:map.keySet()) {
tc.setCount(k1, map.get(k1).totalCount());
}
return tc;
}
public Set firstKeySet() {
return map.keySet();
}
/**
* replace the counter for K1-index o by new counter c
*/
public IntCounter setCounter(K1 o, IntCounter c) {
IntCounter old = getCounter(o);
total -= old.totalIntCount();
map.put(o, c);
total += c.totalIntCount();
return old;
}
/**
* Produces a new ConditionalCounter.
*
* @return a new ConditionalCounter, where order of indices is reversed
*/
@SuppressWarnings({"unchecked"})
public static TwoDimensionalIntCounter reverseIndexOrder(TwoDimensionalIntCounter cc) {
// the typing on the outerMF is violated a bit, but it'll work....
TwoDimensionalIntCounter result = new TwoDimensionalIntCounter<>(
(MapFactory) cc.outerMF, (MapFactory) cc.innerMF);
for (K1 key1 : cc.firstKeySet()) {
IntCounter c = cc.getCounter(key1);
for (K2 key2 : c.keySet()) {
int count = c.getIntCount(key2);
result.setCount(key2, key1, count);
}
}
return result;
}
/**
* A simple String representation of this TwoDimensionalCounter, which has
* the String representation of each key pair
* on a separate line, followed by the count for that pair.
* The items are tab separated, so the result is a tab-separated value (TSV)
* file. Iff none of the keys contain spaces, it will also be possible to
* treat this as whitespace separated fields.
*/
@Override
public String toString() {
StringBuilder buff = new StringBuilder();
for (K1 key1 : map.keySet()) {
IntCounter c = getCounter(key1);
for (K2 key2 : c.keySet()) {
double score = c.getCount(key2);
buff.append(key1).append("\t").append(key2).append("\t").append(score).append("\n");
}
}
return buff.toString();
}
@SuppressWarnings({"unchecked"})
public String toMatrixString(int cellSize) {
List firstKeys = new ArrayList<>(firstKeySet());
List secondKeys = new ArrayList<>(secondKeySet());
Collections.sort((List extends Comparable>)firstKeys);
Collections.sort((List extends Comparable>)secondKeys);
int[][] counts = toMatrix(firstKeys, secondKeys);
return ArrayMath.toString(counts, firstKeys.toArray(), secondKeys.toArray(), cellSize, cellSize, new DecimalFormat(), true);
}
/**
* Given an ordering of the first (row) and second (column) keys, will produce a double matrix.
*
*/
public int[][] toMatrix(List firstKeys, List secondKeys) {
int[][] counts = new int[firstKeys.size()][secondKeys.size()];
for (int i = 0; i < firstKeys.size(); i++) {
for (int j = 0; j < secondKeys.size(); j++) {
counts[i][j] = getCount(firstKeys.get(i), secondKeys.get(j));
}
}
return counts;
}
@SuppressWarnings({"unchecked"})
public String toCSVString(NumberFormat nf) {
List firstKeys = new ArrayList<>(firstKeySet());
List secondKeys = new ArrayList<>(secondKeySet());
Collections.sort((List extends Comparable>)firstKeys);
Collections.sort((List extends Comparable>)secondKeys);
StringBuilder b = new StringBuilder();
String[] headerRow = new String[secondKeys.size() + 1];
headerRow[0] = "";
for (int j = 0; j < secondKeys.size(); j++) {
headerRow[j + 1] = secondKeys.get(j).toString();
}
b.append(StringUtils.toCSVString(headerRow)).append("\n");
for (K1 rowLabel : firstKeys) {
String[] row = new String[secondKeys.size() + 1];
row[0] = rowLabel.toString();
for (int j = 0; j < secondKeys.size(); j++) {
K2 colLabel = secondKeys.get(j);
row[j + 1] = nf.format(getCount(rowLabel, colLabel));
}
b.append(StringUtils.toCSVString(row)).append("\n");
}
return b.toString();
}
public static , CK2 extends Comparable> String toCSVString(
TwoDimensionalIntCounter counter,
NumberFormat nf, Comparator key1Comparator, Comparator key2Comparator) {
List firstKeys = new ArrayList<>(counter.firstKeySet());
List secondKeys = new ArrayList<>(counter.secondKeySet());
Collections.sort(firstKeys, key1Comparator);
Collections.sort(secondKeys, key2Comparator);
StringBuilder b = new StringBuilder();
int secondKeysSize = secondKeys.size();
String[] headerRow = new String[secondKeysSize + 1];
headerRow[0] = "";
for (int j = 0; j < secondKeysSize; j++) {
headerRow[j + 1] = secondKeys.get(j).toString();
}
b.append(StringUtils.toCSVString(headerRow)).append('\n');
for (CK1 rowLabel : firstKeys) {
String[] row = new String[secondKeysSize + 1];
row[0] = rowLabel.toString();
for (int j = 0; j < secondKeysSize; j++) {
CK2 colLabel = secondKeys.get(j);
row[j + 1] = nf.format(counter.getCount(rowLabel, colLabel));
}
b.append(StringUtils.toCSVString(row)).append('\n');
}
return b.toString();
}
public Set secondKeySet() {
Set result = Generics.newHashSet();
for (K1 k1 : firstKeySet()) {
for (K2 k2 : getCounter(k1).keySet()) {
result.add(k2);
}
}
return result;
}
public boolean isEmpty() {
return map.isEmpty();
}
public IntCounter> flatten() {
IntCounter> result = new IntCounter<>();
result.setDefaultReturnValue(defaultValue);
for (K1 key1 : firstKeySet()) {
IntCounter inner = getCounter(key1);
for (K2 key2 : inner.keySet()) {
result.setCount(new Pair<>(key1, key2), inner.getIntCount(key2));
}
}
return result;
}
public void addAll(TwoDimensionalIntCounter c) {
for (K1 key : c.firstKeySet()) {
IntCounter inner = c.getCounter(key);
IntCounter myInner = getCounter(key);
Counters.addInPlace(myInner, inner);
total += inner.totalIntCount();
}
}
public void addAll(K1 key, IntCounter c) {
IntCounter myInner = getCounter(key);
Counters.addInPlace(myInner, c);
total += c.totalIntCount();
}
public void subtractAll(K1 key, IntCounter c) {
IntCounter myInner = getCounter(key);
Counters.subtractInPlace(myInner, c);
total -= c.totalIntCount();
}
public void subtractAll(TwoDimensionalIntCounter c, boolean removeKeys) {
for (K1 key : c.firstKeySet()) {
IntCounter inner = c.getCounter(key);
IntCounter myInner = getCounter(key);
Counters.subtractInPlace(myInner, inner);
if (removeKeys) {
Counters.retainNonZeros(myInner);
}
total -= inner.totalIntCount();
}
}
public void removeZeroCounts() {
Set firstKeySet = Generics.newHashSet(firstKeySet());
for (K1 k1 : firstKeySet) {
IntCounter c = getCounter(k1);
Counters.retainNonZeros(c);
if (c.isEmpty()) {
map.remove(k1); // it's empty, get rid of it!
}
}
}
public void remove(K1 key) {
IntCounter counter = map.get(key);
if (counter != null) { total -= counter.totalIntCount(); }
map.remove(key);
}
public void clean() {
for (K1 key1 : Generics.newHashSet(map.keySet())) {
IntCounter c = map.get(key1);
for (K2 key2 : Generics.newHashSet(c.keySet())) {
if (c.getIntCount(key2) == 0) {
c.remove(key2);
}
}
if (c.keySet().isEmpty()) {
map.remove(key1);
}
}
}
public MapFactory> getOuterMapFactory() {
return outerMF;
}
public MapFactory getInnerMapFactory() {
return innerMF;
}
public TwoDimensionalIntCounter() {
this(MapFactory.>hashMapFactory(), MapFactory.hashMapFactory());
}
public TwoDimensionalIntCounter(int initialCapacity) {
this(MapFactory.>hashMapFactory(), MapFactory.hashMapFactory(), initialCapacity);
}
public TwoDimensionalIntCounter(MapFactory> outerFactory, MapFactory innerFactory) {
this(outerFactory, innerFactory, 100);
}
public TwoDimensionalIntCounter(MapFactory> outerFactory, MapFactory innerFactory, int initialCapacity) {
innerMF = innerFactory;
outerMF = outerFactory;
map = outerFactory.newMap(initialCapacity);
total = 0;
}
}