![JAR search and dependency download from the Maven repository](/logo.png)
edu.berkeley.nlp.util.Counter Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of berkeleyparser Show documentation
Show all versions of berkeleyparser Show documentation
The Berkeley parser analyzes the grammatical structure of natural language using probabilistic context-free grammars (PCFGs).
The newest version!
package edu.berkeley.nlp.util;
import java.io.Serializable;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeSet;
import java.util.Map.Entry;
/**
* A map from objects to doubles. Includes convenience methods for getting,
* setting, and incrementing element counts. Objects not in the counter will
* return a count of zero. The counter is backed by a HashMap (unless specified
* otherwise with the MapFactory constructor).
*
* @author Dan Klein
*/
public class Counter implements Serializable {
private static final long serialVersionUID = 1L;
Map entries;
boolean dirty = true;
double cacheTotal = 0.0;
MapFactory mf;
double deflt = 0.0;
public double getDeflt() {
return deflt;
}
public void setDeflt(double deflt) {
this.deflt = deflt;
}
/**
* The elements in the counter.
*
* @return set of keys
*/
public Set keySet() {
return entries.keySet();
}
public Set> entrySet() {
return entries.entrySet();
}
/**
* The number of entries in the counter (not the total count -- use
* totalCount() instead).
*/
public int size() {
return entries.size();
}
/**
* True if there are no entries in the counter (false does not mean
* totalCount > 0)
*/
public boolean isEmpty() {
return size() == 0;
}
/**
* Returns whether the counter contains the given key. Note that this is the
* way to distinguish keys which are in the counter with count zero, and
* those which are not in the counter (and will therefore return count zero
* from getCount().
*
* @param key
* @return whether the counter contains the key
*/
public boolean containsKey(E key) {
return entries.containsKey(key);
}
/**
* Get the count of the element, or zero if the element is not in the
* counter.
*
* @param key
* @return
*/
public double getCount(E key) {
Double value = entries.get(key);
if (value == null) return deflt;
return value;
}
/**
* I know, I know, this should be wrapped in a Distribution class, but it's
* such a common use...why not. Returns the MLE prob. Assumes all the counts
* are >= 0.0 and totalCount > 0.0. If the latter is false, return 0.0 (i.e.
* 0/0 == 0)
*
* @author Aria
* @param key
* @return MLE prob of the key
*/
public double getProbability(E key) {
double count = getCount(key);
double total = totalCount();
if (total < 0.0) {
throw new RuntimeException("Can't call getProbability() with totalCount < 0.0");
}
return total > 0.0 ? count / total : 0.0;
}
/**
* Destructively normalize this Counter in place.
*/
public void normalize() {
double totalCount = totalCount();
for (E key : keySet()) {
setCount(key, getCount(key) / totalCount);
}
dirty = true;
}
/**
* Set the count for the given key, clobbering any previous count.
*
* @param key
* @param count
*/
public void setCount(E key, double count) {
entries.put(key, count);
dirty = true;
}
/**
* Set the count for the given key if it is larger than the previous one;
*
* @param key
* @param count
*/
public void put(E key, double count, boolean keepHigher) {
if (keepHigher && entries.containsKey(key)) {
double oldCount = entries.get(key);
if (count > oldCount) {
entries.put(key, count);
}
} else {
entries.put(key, count);
}
dirty = true;
}
/**
* Will return a sample from the counter, will throw exception if any of the
* counts are < 0.0 or if the totalCount() <= 0.0
*
* @return
*
* @author aria42
*/
public E sample(Random rand) {
double total = totalCount();
if (total <= 0.0) {
throw new RuntimeException(String.format(
"Attempting to sample() with totalCount() %.3f\n", total));
}
double sum = 0.0;
double r = rand.nextDouble();
for (Map.Entry entry : entries.entrySet()) {
double count = entry.getValue();
double frac = count / total;
sum += frac;
if (r < sum) {
return entry.getKey();
}
}
throw new IllegalStateException("Shoudl've have returned a sample by now....");
}
/**
* Will return a sample from the counter, will throw exception if any of the
* counts are < 0.0 or if the totalCount() <= 0.0
*
* @return
*
* @author aria42
*/
public E sample() {
return sample(new Random());
}
public void removeKey(E key) {
setCount(key, 0.0);
dirty = true;
removeKeyFromEntries(key);
}
/**
* @param key
*/
protected void removeKeyFromEntries(E key) {
entries.remove(key);
}
/**
* Set's the key's count to the maximum of the current count and val. Always
* sets to val if key is not yet present.
*
* @param key
* @param increment
*/
public void setMaxCount(E key, double val) {
Double value = entries.get(key);
if (value == null || val > value.doubleValue()) {
setCount(key, val);
dirty = true;
}
}
/**
* Set's the key's count to the minimum of the current count and val. Always
* sets to val if key is not yet present.
*
* @param key
* @param increment
*/
public void setMinCount(E key, double val) {
Double value = entries.get(key);
if (value == null || val < value.doubleValue()) {
setCount(key, val);
dirty = true;
}
}
/**
* Increment a key's count by the given amount.
*
* @param key
* @param increment
*/
public double incrementCount(E key, double increment) {
double newVal = getCount(key) + increment;
setCount(key, newVal);
dirty = true;
return newVal;
}
/**
* Increment each element in a given collection by a given amount.
*/
public void incrementAll(Collection extends E> collection, double count) {
for (E key : collection) {
incrementCount(key, count);
}
dirty = true;
}
public void incrementAll(Counter counter) {
for (T key : counter.keySet()) {
double count = counter.getCount(key);
incrementCount(key, count);
}
dirty = true;
}
/**
* Finds the total of all counts in the counter. This implementation
* iterates through the entire counter every time this method is called.
*
* @return the counter's total
*/
public double totalCount() {
if (!dirty) {
return cacheTotal;
}
double total = 0.0;
for (Map.Entry entry : entries.entrySet()) {
total += entry.getValue();
}
cacheTotal = total;
dirty = false;
return total;
}
public List getSortedKeys() {
PriorityQueue pq = this.asPriorityQueue();
List keys = new ArrayList();
while (pq.hasNext()) {
keys.add(pq.next());
}
return keys;
}
/**
* Finds the key with maximum count. This is a linear operation, and ties
* are broken arbitrarily.
*
* @return a key with minumum count
*/
public E argMax() {
double maxCount = Double.NEGATIVE_INFINITY;
E maxKey = null;
for (Map.Entry entry : entries.entrySet()) {
if (entry.getValue() > maxCount || maxKey == null) {
maxKey = entry.getKey();
maxCount = entry.getValue();
}
}
return maxKey;
}
public double min() {
return maxMinHelp(false);
}
public double max() {
return maxMinHelp(true);
}
private double maxMinHelp(boolean max) {
double maxCount = max ? Double.NEGATIVE_INFINITY : Double.POSITIVE_INFINITY;
for (Map.Entry entry : entries.entrySet()) {
if ((max && entry.getValue() > maxCount)
|| (!max && entry.getValue() < maxCount)) {
maxCount = entry.getValue();
}
}
return maxCount;
}
/**
* Returns a string representation with the keys ordered by decreasing
* counts.
*
* @return string representation
*/
@Override
public String toString() {
return toString(keySet().size());
}
public String toStringSortedByKeys() {
StringBuilder sb = new StringBuilder("[");
NumberFormat f = NumberFormat.getInstance();
f.setMaximumFractionDigits(5);
int numKeysPrinted = 0;
for (E element : new TreeSet(keySet())) {
sb.append(element.toString());
sb.append(" : ");
sb.append(f.format(getCount(element)));
if (numKeysPrinted < size() - 1) sb.append(", ");
numKeysPrinted++;
}
if (numKeysPrinted < size()) sb.append("...");
sb.append("]");
return sb.toString();
}
/**
* Returns a string representation which includes no more than the
* maxKeysToPrint elements with largest counts.
*
* @param maxKeysToPrint
* @return partial string representation
*/
public String toString(int maxKeysToPrint) {
return asPriorityQueue().toString(maxKeysToPrint, false);
}
/**
* Returns a string representation which includes no more than the
* maxKeysToPrint elements with largest counts and optionally prints
* one element per line.
*
* @param maxKeysToPrint
* @return partial string representation
*/
public String toString(int maxKeysToPrint, boolean multiline) {
return asPriorityQueue().toString(maxKeysToPrint, multiline);
}
/**
* Builds a priority queue whose elements are the counter's elements, and
* whose priorities are those elements' counts in the counter.
*/
public PriorityQueue asPriorityQueue() {
PriorityQueue pq = new PriorityQueue(entries.size());
for (Map.Entry entry : entries.entrySet()) {
pq.add(entry.getKey(), entry.getValue());
}
return pq;
}
/**
* Warning: all priorities are the negative of their counts in the counter
* here
*
* @return
*/
public PriorityQueue asMinPriorityQueue() {
PriorityQueue pq = new PriorityQueue(entries.size());
for (Map.Entry entry : entries.entrySet()) {
pq.add(entry.getKey(), -entry.getValue());
}
return pq;
}
public Counter() {
this(false);
}
public Counter(boolean identityHashMap) {
this(identityHashMap ? new MapFactory.IdentityHashMapFactory()
: new MapFactory.HashMapFactory());
}
public Counter(MapFactory mf) {
this.mf = mf;
entries = mf.buildMap();
}
public Counter(Map extends E, Double> mapCounts) {
this(false);
this.entries = new HashMap();
for (Entry extends E, Double> entry : mapCounts.entrySet()) {
incrementCount(entry.getKey(), entry.getValue());
}
}
public Counter(Counter extends E> counter) {
this();
incrementAll(counter);
}
public Counter(Collection extends E> collection) {
this();
incrementAll(collection, 1.0);
}
public void pruneKeysBelowThreshold(double cutoff) {
Iterator it = entries.keySet().iterator();
while (it.hasNext()) {
E key = it.next();
double val = entries.get(key);
if (val < cutoff) {
it.remove();
}
}
dirty = true;
}
public Set> getEntrySet() {
return entries.entrySet();
}
public boolean isEqualTo(Counter counter) {
boolean tmp = true;
Counter bigger = counter.size() > size() ? counter : this;
for (E e : bigger.keySet()) {
tmp &= counter.getCount(e) == getCount(e);
}
return tmp;
}
public static void main(String[] args) {
Counter counter = new Counter();
System.out.println(counter);
counter.incrementCount("planets", 7);
System.out.println(counter);
counter.incrementCount("planets", 1);
System.out.println(counter);
counter.setCount("suns", 1);
System.out.println(counter);
counter.setCount("aliens", 0);
System.out.println(counter);
System.out.println(counter.toString(2));
System.out.println("Total: " + counter.totalCount());
}
public void clear() {
entries = mf.buildMap();
dirty = true;
}
public void keepTopNKeys(int keepN) {
keepKeysHelper(keepN, true);
}
public void keepBottomNKeys(int keepN) {
keepKeysHelper(keepN, false);
}
private void keepKeysHelper(int keepN, boolean top) {
Counter tmp = new Counter();
int n = 0;
for (E e : Iterators.able(top ? asPriorityQueue() : asMinPriorityQueue())) {
if (n <= keepN) tmp.setCount(e, getCount(e));
n++;
}
clear();
incrementAll(tmp);
dirty = true;
}
/**
* Sets all counts to the given value, but does not remove any keys
*/
public void setAllCounts(double val) {
for (E e : keySet()) {
setCount(e, val);
}
}
public double dotProduct(Counter other) {
double sum = 0.0;
for (Map.Entry entry : getEntrySet()) {
final double otherCount = other.getCount(entry.getKey());
if (otherCount == 0.0) continue;
final double value = entry.getValue();
if (value == 0.0) continue;
sum += value * otherCount;
}
return sum;
}
public void scale(double c) {
for (Map.Entry entry : getEntrySet()) {
entry.setValue(entry.getValue() * c);
}
}
public Counter scaledClone(double c) {
Counter newCounter = new Counter();
for (Map.Entry entry : getEntrySet()) {
newCounter.setCount(entry.getKey(), entry.getValue() * c);
}
return newCounter;
}
public Counter difference(Counter counter) {
Counter clone = new Counter(this);
for (E key : counter.keySet()) {
double count = counter.getCount(key);
clone.incrementCount(key, -1 * count);
}
return clone;
}
public Counter toLogSpace() {
Counter newCounter = new Counter(this);
for (E key : newCounter.keySet()) {
newCounter.setCount(key, Math.log(getCount(key)));
}
return newCounter;
}
public boolean approxEquals(Counter other, double tol) {
for (E key : keySet()) {
if (Math.abs(getCount(key) - other.getCount(key)) > tol) return false;
}
for (E key : other.keySet()) {
if (Math.abs(getCount(key) - other.getCount(key)) > tol) return false;
}
return true;
}
public void setDirty(boolean dirty) {
this.dirty = dirty;
}
public String toStringTabSeparated() {
StringBuilder sb = new StringBuilder();
for (E key : getSortedKeys()) {
sb.append(key.toString() + "\t" + getCount(key) + "\n");
}
return sb.toString();
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy