// org.deeplearning4j.berkeley.Counter Maven / Gradle / Ivy
package org.deeplearning4j.berkeley;
import java.io.Serializable;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeSet;
import java.util.Map.Entry;
/**
* A map from objects to doubles. Includes convenience methods for getting,
* setting, and incrementing element counts. Objects not in the counter will
* return a count of zero. The counter is backed by a HashMap (unless specified
* otherwise with the MapFactory constructor).
*
* @author Dan Klein
*/
public class Counter implements Serializable {
private static final long serialVersionUID = 1L;
Map entries;
boolean dirty = true;
double cacheTotal = 0.0;
MapFactory mf;
double deflt = 0.0;
public double getDeflt() {
return deflt;
}
public void setDeflt(double deflt) {
this.deflt = deflt;
}
/**
* The elements in the counter.
*
* @return set of keys
*/
public Set keySet() {
return entries.keySet();
}
public Set> entrySet() {
return entries.entrySet();
}
/**
* The number of entries in the counter (not the total count -- use
* totalCount() instead).
*/
public int size() {
return entries.size();
}
/**
* True if there are no entries in the counter (false does not mean
* totalCount > 0)
*/
public boolean isEmpty() {
return size() == 0;
}
/**
* Returns whether the counter contains the given key. Note that this is the
* way to distinguish keys which are in the counter with count zero, and
* those which are not in the counter (and will therefore return count zero
* from getCount().
*
* @param key
* @return whether the counter contains the key
*/
public boolean containsKey(E key) {
return entries.containsKey(key);
}
/**
* Get the count of the element, or zero if the element is not in the
* counter.
*
* @param key
* @return
*/
public double getCount(E key) {
Double value = entries.get(key);
if (value == null) return deflt;
return value;
}
/**
* I know, I know, this should be wrapped in a Distribution class, but it's
* such a common use...why not. Returns the MLE prob. Assumes all the counts
* are >= 0.0 and totalCount > 0.0. If the latter is false, return 0.0 (i.e.
* 0/0 == 0)
*
* @author Aria
* @param key
* @return MLE prob of the key
*/
public double getProbability(E key) {
double count = getCount(key);
double total = totalCount();
if (total < 0.0) {
throw new RuntimeException("Can't call getProbability() with totalCount < 0.0");
}
return total > 0.0 ? count / total : 0.0;
}
/**
* Destructively normalize this Counter in place.
*/
public void normalize() {
double totalCount = totalCount();
for (E key : keySet()) {
setCount(key, getCount(key) / totalCount);
}
dirty = true;
}
/**
* Set the count for the given key, clobbering any previous count.
*
* @param key
* @param count
*/
public void setCount(E key, double count) {
entries.put(key, count);
dirty = true;
}
/**
* Set the count for the given key if it is larger than the previous one;
*
* @param key
* @param count
*/
public void put(E key, double count, boolean keepHigher) {
if (keepHigher && entries.containsKey(key)) {
double oldCount = entries.get(key);
if (count > oldCount) {
entries.put(key, count);
}
} else {
entries.put(key, count);
}
dirty = true;
}
/**
* Will return a sample from the counter, will throw exception if any of the
* counts are < 0.0 or if the totalCount() <= 0.0
*
* @return
*
* @author aria42
*/
public E sample(Random rand) {
double total = totalCount();
if (total <= 0.0) {
throw new RuntimeException(String.format(
"Attempting to sample() with totalCount() %.3f\n", total));
}
double sum = 0.0;
double r = rand.nextDouble();
for (Map.Entry entry : entries.entrySet()) {
double count = entry.getValue();
double frac = count / total;
sum += frac;
if (r < sum) {
return entry.getKey();
}
}
throw new IllegalStateException("Shoudl've have returned a sample by now....");
}
/**
* Will return a sample from the counter, will throw exception if any of the
* counts are < 0.0 or if the totalCount() <= 0.0
*
* @return
*
* @author aria42
*/
public E sample() {
return sample(new Random());
}
public void removeKey(E key) {
setCount(key, 0.0);
dirty = true;
removeKeyFromEntries(key);
}
/**
* @param key
*/
protected void removeKeyFromEntries(E key) {
entries.remove(key);
}
/**
* Set's the key's count to the maximum of the current count and val. Always
* sets to val if key is not yet present.
*
* @param key
* @param increment
*/
public void setMaxCount(E key, double val) {
Double value = entries.get(key);
if (value == null || val > value.doubleValue()) {
setCount(key, val);
dirty = true;
}
}
/**
* Set's the key's count to the minimum of the current count and val. Always
* sets to val if key is not yet present.
*
* @param key
* @param increment
*/
public void setMinCount(E key, double val) {
Double value = entries.get(key);
if (value == null || val < value.doubleValue()) {
setCount(key, val);
dirty = true;
}
}
/**
* Increment a key's count by the given amount.
*
* @param key
* @param increment
*/
public double incrementCount(E key, double increment) {
double newVal = getCount(key) + increment;
setCount(key, newVal);
dirty = true;
return newVal;
}
/**
* Increment each element in a given collection by a given amount.
*/
public void incrementAll(Collection extends E> collection, double count) {
for (E key : collection) {
incrementCount(key, count);
}
dirty = true;
}
public void incrementAll(Counter counter) {
for (T key : counter.keySet()) {
double count = counter.getCount(key);
incrementCount(key, count);
}
dirty = true;
}
/**
* Finds the total of all counts in the counter. This implementation
* iterates through the entire counter every time this method is called.
*
* @return the counter's total
*/
public double totalCount() {
if (!dirty) {
return cacheTotal;
}
double total = 0.0;
for (Map.Entry entry : entries.entrySet()) {
total += entry.getValue();
}
cacheTotal = total;
dirty = false;
return total;
}
public List getSortedKeys() {
PriorityQueue pq = this.asPriorityQueue();
List keys = new ArrayList();
while (pq.hasNext()) {
keys.add(pq.next());
}
return keys;
}
/**
* Finds the key with maximum count. This is a linear operation, and ties
* are broken arbitrarily.
*
* @return a key with minumum count
*/
public E argMax() {
double maxCount = Double.NEGATIVE_INFINITY;
E maxKey = null;
for (Map.Entry entry : entries.entrySet()) {
if (entry.getValue() > maxCount || maxKey == null) {
maxKey = entry.getKey();
maxCount = entry.getValue();
}
}
return maxKey;
}
public double min() {
return maxMinHelp(false);
}
public double max() {
return maxMinHelp(true);
}
private double maxMinHelp(boolean max) {
double maxCount = max ? Double.NEGATIVE_INFINITY : Double.POSITIVE_INFINITY;
for (Map.Entry entry : entries.entrySet()) {
if ((max && entry.getValue() > maxCount)
|| (!max && entry.getValue() < maxCount)) {
maxCount = entry.getValue();
}
}
return maxCount;
}
/**
* Returns a string representation with the keys ordered by decreasing
* counts.
*
* @return string representation
*/
@Override
public String toString() {
return toString(keySet().size());
}
public String toStringSortedByKeys() {
StringBuilder sb = new StringBuilder("[");
NumberFormat f = NumberFormat.getInstance();
f.setMaximumFractionDigits(5);
int numKeysPrinted = 0;
for (E element : new TreeSet(keySet())) {
sb.append(element.toString());
sb.append(" : ");
sb.append(f.format(getCount(element)));
if (numKeysPrinted < size() - 1) sb.append(", ");
numKeysPrinted++;
}
if (numKeysPrinted < size()) sb.append("...");
sb.append("]");
return sb.toString();
}
/**
* Returns a string representation which includes no more than the
* maxKeysToPrint elements with largest counts.
*
* @param maxKeysToPrint
* @return partial string representation
*/
public String toString(int maxKeysToPrint) {
return asPriorityQueue().toString(maxKeysToPrint, false);
}
/**
* Returns a string representation which includes no more than the
* maxKeysToPrint elements with largest counts and optionally prints
* one element per line.
*
* @param maxKeysToPrint
* @return partial string representation
*/
public String toString(int maxKeysToPrint, boolean multiline) {
return asPriorityQueue().toString(maxKeysToPrint, multiline);
}
/**
* Builds a priority queue whose elements are the counter's elements, and
* whose priorities are those elements' counts in the counter.
*/
public PriorityQueue asPriorityQueue() {
PriorityQueue pq = new PriorityQueue(entries.size());
for (Map.Entry entry : entries.entrySet()) {
pq.add(entry.getKey(), entry.getValue());
}
return pq;
}
/**
* Warning: all priorities are the negative of their counts in the counter
* here
*
* @return
*/
public PriorityQueue asMinPriorityQueue() {
PriorityQueue pq = new PriorityQueue(entries.size());
for (Map.Entry entry : entries.entrySet()) {
pq.add(entry.getKey(), -entry.getValue());
}
return pq;
}
public Counter() {
this(false);
}
public Counter(boolean identityHashMap) {
this(identityHashMap ? new MapFactory.IdentityHashMapFactory()
: new MapFactory.HashMapFactory());
}
public Counter(MapFactory mf) {
this.mf = mf;
entries = mf.buildMap();
}
public Counter(Map extends E, Double> mapCounts) {
this(false);
this.entries = new HashMap();
for (Entry extends E, Double> entry : mapCounts.entrySet()) {
incrementCount(entry.getKey(), entry.getValue());
}
}
public Counter(Counter extends E> counter) {
this();
incrementAll(counter);
}
public Counter(Collection extends E> collection) {
this();
incrementAll(collection, 1.0);
}
public void pruneKeysBelowThreshold(double cutoff) {
Iterator it = entries.keySet().iterator();
while (it.hasNext()) {
E key = it.next();
double val = entries.get(key);
if (val < cutoff) {
it.remove();
}
}
dirty = true;
}
public Set> getEntrySet() {
return entries.entrySet();
}
public boolean isEqualTo(Counter counter) {
boolean tmp = true;
Counter bigger = counter.size() > size() ? counter : this;
for (E e : bigger.keySet()) {
tmp &= counter.getCount(e) == getCount(e);
}
return tmp;
}
public static void main(String[] args) {
Counter counter = new Counter();
System.out.println(counter);
counter.incrementCount("planets", 7);
System.out.println(counter);
counter.incrementCount("planets", 1);
System.out.println(counter);
counter.setCount("suns", 1);
System.out.println(counter);
counter.setCount("aliens", 0);
System.out.println(counter);
System.out.println(counter.toString(2));
System.out.println("Total: " + counter.totalCount());
}
public void clear() {
entries = mf.buildMap();
dirty = true;
}
public void keepTopNKeys(int keepN) {
keepKeysHelper(keepN, true);
}
public void keepBottomNKeys(int keepN) {
keepKeysHelper(keepN, false);
}
private void keepKeysHelper(int keepN, boolean top) {
Counter tmp = new Counter();
int n = 0;
for (E e : Iterators.able(top ? asPriorityQueue() : asMinPriorityQueue())) {
if (n <= keepN) tmp.setCount(e, getCount(e));
n++;
}
clear();
incrementAll(tmp);
dirty = true;
}
/**
* Sets all counts to the given value, but does not remove any keys
*/
public void setAllCounts(double val) {
for (E e : keySet()) {
setCount(e, val);
}
}
public double dotProduct(Counter other) {
double sum = 0.0;
for (Map.Entry entry : getEntrySet()) {
final double otherCount = other.getCount(entry.getKey());
if (otherCount == 0.0) continue;
final double value = entry.getValue();
if (value == 0.0) continue;
sum += value * otherCount;
}
return sum;
}
public void scale(double c) {
for (Map.Entry entry : getEntrySet()) {
entry.setValue(entry.getValue() * c);
}
}
public Counter scaledClone(double c) {
Counter newCounter = new Counter();
for (Map.Entry entry : getEntrySet()) {
newCounter.setCount(entry.getKey(), entry.getValue() * c);
}
return newCounter;
}
public Counter difference(Counter counter) {
Counter clone = new Counter(this);
for (E key : counter.keySet()) {
double count = counter.getCount(key);
clone.incrementCount(key, -1 * count);
}
return clone;
}
public Counter toLogSpace() {
Counter newCounter = new Counter(this);
for (E key : newCounter.keySet()) {
newCounter.setCount(key, Math.log(getCount(key)));
}
return newCounter;
}
public boolean approxEquals(Counter other, double tol) {
for (E key : keySet()) {
if (Math.abs(getCount(key) - other.getCount(key)) > tol) return false;
}
for (E key : other.keySet()) {
if (Math.abs(getCount(key) - other.getCount(key)) > tol) return false;
}
return true;
}
public void setDirty(boolean dirty) {
this.dirty = dirty;
}
public String toStringTabSeparated() {
StringBuilder sb = new StringBuilder();
for (E key : getSortedKeys()) {
sb.append(key.toString() + "\t" + getCount(key) + "\n");
}
return sb.toString();
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy