edu.stanford.nlp.stats.TwoDimensionalCounter Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of stanford-parser Show documentation
Stanford Parser processes raw text in English, Chinese, German, Arabic, and French, and extracts constituency parse trees.
The newest version!
package edu.stanford.nlp.stats;
import java.io.Serializable;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.math.SloppyMath;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.MapFactory;
import edu.stanford.nlp.util.MutableDouble;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.StringUtils;
/**
* A class representing a mapping between pairs of typed objects and double
* values.
*
* @author Teg Grenager
*/
public class TwoDimensionalCounter implements TwoDimensionalCounterInterface, Serializable {
private static final long serialVersionUID = 1L;
// the outermost Map
private Map> map;
// the total of all counts
private double total;
// the MapFactory used to make new maps to counters
private MapFactory> outerMF;
// the MapFactory used to make new maps in the inner counter
private MapFactory innerMF;
private double defaultValue = 0.0;
@Override
public void defaultReturnValue(double rv) {
defaultValue = rv;
}
@Override
public double defaultReturnValue() {
return defaultValue;
}
@Override
public boolean equals(Object o) {
if (o == this)
return true;
if (!(o instanceof TwoDimensionalCounter))
return false;
return ((TwoDimensionalCounter, ?>) o).map.equals(map);
}
@Override
public int hashCode() {
return map.hashCode() + 17;
}
/**
* @return the inner Counter associated with key o
*/
@Override
public ClassicCounter getCounter(K1 o) {
ClassicCounter c = map.get(o);
if (c == null) {
c = new ClassicCounter<>(innerMF);
c.setDefaultReturnValue(defaultValue);
map.put(o, c);
}
return c;
}
public Set>> entrySet() {
return map.entrySet();
}
/**
* @return total number of entries (key pairs)
*/
@Override
public int size() {
int result = 0;
for (K1 o : firstKeySet()) {
ClassicCounter c = map.get(o);
result += c.size();
}
return result;
}
/**
* @return size of the outer map
*/
public int sizeOuterMap(){
return map.size();
}
@Override
public boolean containsKey(K1 o1, K2 o2) {
if (!map.containsKey(o1))
return false;
ClassicCounter c = map.get(o1);
return c.containsKey(o2);
}
public boolean containsFirstKey(K1 o1) {
return map.containsKey(o1);
}
/**
*/
@Override
public void incrementCount(K1 o1, K2 o2) {
incrementCount(o1, o2, 1.0);
}
/**
*/
@Override
public void incrementCount(K1 o1, K2 o2, double count) {
ClassicCounter c = getCounter(o1);
c.incrementCount(o2, count);
total += count;
}
/**
*/
@Override
public void decrementCount(K1 o1, K2 o2) {
incrementCount(o1, o2, -1.0);
}
/**
*/
@Override
public void decrementCount(K1 o1, K2 o2, double count) {
incrementCount(o1, o2, -count);
}
/**
*/
@Override
public void setCount(K1 o1, K2 o2, double count) {
ClassicCounter c = getCounter(o1);
double oldCount = getCount(o1, o2);
total -= oldCount;
c.setCount(o2, count);
total += count;
}
@Override
public double remove(K1 o1, K2 o2) {
ClassicCounter c = getCounter(o1);
double oldCount = getCount(o1, o2);
total -= oldCount;
c.remove(o2);
if (c.size() == 0) {
map.remove(o1);
}
return oldCount;
}
/**
*/
@Override
public double getCount(K1 o1, K2 o2) {
ClassicCounter c = getCounter(o1);
if (c.totalCount() == 0.0 && !c.keySet().contains(o2)) {
return defaultReturnValue();
}
return c.getCount(o2);
}
/**
* Takes linear time.
*
*/
@Override
public double totalCount() {
return total;
}
/**
*/
@Override
public double totalCount(K1 k1) {
ClassicCounter c = getCounter(k1);
return c.totalCount();
}
@Override
public Set firstKeySet() {
return map.keySet();
}
/**
* replace the counter for K1-index o by new counter c
*/
public ClassicCounter setCounter(K1 o, Counter c) {
ClassicCounter old = getCounter(o);
total -= old.totalCount();
if (c instanceof ClassicCounter) {
map.put(o, (ClassicCounter) c);
} else {
map.put(o, new ClassicCounter<>(c));
}
total += c.totalCount();
return old;
}
/**
* Produces a new ConditionalCounter.
*
* @return a new ConditionalCounter, where order of indices is reversed
*/
@SuppressWarnings( { "unchecked" })
public static TwoDimensionalCounter reverseIndexOrder(TwoDimensionalCounter cc) {
// they typing on the outerMF is violated a bit, but it'll work....
TwoDimensionalCounter result = new TwoDimensionalCounter<>((MapFactory) cc.outerMF,
(MapFactory) cc.innerMF);
for (K1 key1 : cc.firstKeySet()) {
ClassicCounter c = cc.getCounter(key1);
for (K2 key2 : c.keySet()) {
double count = c.getCount(key2);
result.setCount(key2, key1, count);
}
}
return result;
}
/**
* A simple String representation of this TwoDimensionalCounter, which has the
* String representation of each key pair on a separate line, followed by the
* count for that pair. The items are tab separated, so the result is a
* tab-separated value (TSV) file. Iff none of the keys contain spaces, it
* will also be possible to treat this as whitespace separated fields.
*/
@Override
public String toString() {
StringBuilder buff = new StringBuilder();
for (K1 key1 : map.keySet()) {
ClassicCounter c = getCounter(key1);
for (K2 key2 : c.keySet()) {
double score = c.getCount(key2);
buff.append(key1).append('\t').append(key2).append('\t').append(score).append('\n');
}
}
return buff.toString();
}
@Override
@SuppressWarnings( { "unchecked" })
public String toMatrixString(int cellSize) {
return toMatrixString(cellSize, new DecimalFormat());
}
@SuppressWarnings( { "unchecked" })
public String toMatrixString(int cellSize, NumberFormat nf) {
List firstKeys = new ArrayList<>(firstKeySet());
List secondKeys = new ArrayList<>(secondKeySet());
Collections.sort((List extends Comparable>) firstKeys);
Collections.sort((List extends Comparable>) secondKeys);
double[][] counts = toMatrix(firstKeys, secondKeys);
return ArrayMath.toString(counts, cellSize, firstKeys.toArray(), secondKeys.toArray(), nf, true);
}
/**
* Given an ordering of the first (row) and second (column) keys, will produce
* a double matrix.
*
*/
@Override
public double[][] toMatrix(List firstKeys, List secondKeys) {
double[][] counts = new double[firstKeys.size()][secondKeys.size()];
for (int i = 0; i < firstKeys.size(); i++) {
for (int j = 0; j < secondKeys.size(); j++) {
counts[i][j] = getCount(firstKeys.get(i), secondKeys.get(j));
}
}
return counts;
}
@Override
@SuppressWarnings( { "unchecked" })
public String toCSVString(NumberFormat nf) {
List firstKeys = new ArrayList<>(firstKeySet());
List secondKeys = new ArrayList<>(secondKeySet());
Collections.sort((List extends Comparable>) firstKeys);
Collections.sort((List extends Comparable>) secondKeys);
StringBuilder b = new StringBuilder();
String[] headerRow = new String[secondKeys.size() + 1];
headerRow[0] = "";
for (int j = 0; j < secondKeys.size(); j++) {
headerRow[j + 1] = secondKeys.get(j).toString();
}
b.append(StringUtils.toCSVString(headerRow)).append('\n');
for (K1 rowLabel : firstKeys) {
String[] row = new String[secondKeys.size() + 1];
row[0] = rowLabel.toString();
for (int j = 0; j < secondKeys.size(); j++) {
K2 colLabel = secondKeys.get(j);
row[j + 1] = nf.format(getCount(rowLabel, colLabel));
}
b.append(StringUtils.toCSVString(row)).append('\n');
}
return b.toString();
}
@Override
public Set secondKeySet() {
Set result = Generics.newHashSet();
for (K1 k1 : firstKeySet()) {
for (K2 k2 : getCounter(k1).keySet()) {
result.add(k2);
}
}
return result;
}
@Override
public boolean isEmpty() {
return map.isEmpty();
}
public ClassicCounter> flatten() {
ClassicCounter> result = new ClassicCounter<>();
result.setDefaultReturnValue(defaultValue);
for (K1 key1 : firstKeySet()) {
ClassicCounter inner = getCounter(key1);
for (K2 key2 : inner.keySet()) {
result.setCount(new Pair<>(key1, key2), inner.getCount(key2));
}
}
return result;
}
public void addAll(TwoDimensionalCounterInterface c) {
for (K1 key : c.firstKeySet()) {
Counter inner = c.getCounter(key);
ClassicCounter myInner = getCounter(key);
Counters.addInPlace(myInner, inner);
total += inner.totalCount();
}
}
public void addAll(K1 key, Counter c) {
ClassicCounter myInner = getCounter(key);
Counters.addInPlace(myInner, c);
total += c.totalCount();
}
public void subtractAll(K1 key, Counter c) {
ClassicCounter myInner = getCounter(key);
Counters.subtractInPlace(myInner, c);
total -= c.totalCount();
}
public void subtractAll(TwoDimensionalCounterInterface c, boolean removeKeys) {
for (K1 key : c.firstKeySet()) {
Counter inner = c.getCounter(key);
ClassicCounter myInner = getCounter(key);
Counters.subtractInPlace(myInner, inner);
if (removeKeys)
Counters.retainNonZeros(myInner);
total -= inner.totalCount();
}
}
/**
* Returns the counters with keys as the first key and count as the
* total count of the inner counter for that key
*
* @return counter of type K1
*/
public Counter sumInnerCounter() {
Counter summed = new ClassicCounter<>();
for (K1 key : this.firstKeySet()) {
summed.incrementCount(key, this.getCounter(key).totalCount());
}
return summed;
}
public void removeZeroCounts() {
Set firstKeySet = Generics.newHashSet(firstKeySet());
for (K1 k1 : firstKeySet) {
ClassicCounter c = getCounter(k1);
Counters.retainNonZeros(c);
if (c.size() == 0)
map.remove(k1); // it's empty, get rid of it!
}
}
@Override
public void remove(K1 key) {
ClassicCounter counter = map.get(key);
if (counter != null) {
total -= counter.totalCount();
}
map.remove(key);
}
/**
* clears the map, total and default value
*/
public void clear(){
map.clear();
total = 0;
defaultValue = 0;
}
public void clean() {
for (K1 key1 : Generics.newHashSet(map.keySet())) {
ClassicCounter c = map.get(key1);
for (K2 key2 : Generics.newHashSet(c.keySet())) {
if (SloppyMath.isCloseTo(0.0, c.getCount(key2))) {
c.remove(key2);
}
}
if (c.keySet().isEmpty()) {
map.remove(key1);
}
}
}
public MapFactory> getOuterMapFactory() {
return outerMF;
}
public MapFactory getInnerMapFactory() {
return innerMF;
}
public TwoDimensionalCounter() {
this(MapFactory.> hashMapFactory(), MapFactory. hashMapFactory());
}
public TwoDimensionalCounter(MapFactory> outerFactory,
MapFactory innerFactory) {
innerMF = innerFactory;
outerMF = outerFactory;
map = outerFactory.newMap();
total = 0.0;
}
public static TwoDimensionalCounter identityHashMapCounter() {
return new TwoDimensionalCounter<>(MapFactory.>identityHashMapFactory(), MapFactory.identityHashMapFactory());
}
public void recomputeTotal(){
total = 0;
for(Entry> c: map.entrySet()){
total += c.getValue().totalCount();
}
}
public static void main(String[] args) {
TwoDimensionalCounter cc = new TwoDimensionalCounter<>();
cc.setCount("a", "c", 1.0);
cc.setCount("b", "c", 1.0);
cc.setCount("a", "d", 1.0);
cc.setCount("a", "d", -1.0);
cc.setCount("b", "d", 1.0);
System.out.println(cc);
cc.incrementCount("b", "d", 1.0);
System.out.println(cc);
TwoDimensionalCounter cc2 = TwoDimensionalCounter.reverseIndexOrder(cc);
System.out.println(cc2);
}
}