All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.stanford.nlp.stats.Counters Maven / Gradle / Ivy

Go to download

Stanford CoreNLP provides a set of natural language analysis tools which can take raw English language text input and give the base forms of words, their parts of speech, whether they are names of companies, people, etc., normalize dates, times, and numeric quantities, mark up the structure of sentences in terms of phrases and word dependencies, and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher level text understanding applications.

There is a newer version: 4.5.7
Show newest version
// Stanford JavaNLP support classes
// Copyright (c) 2004-2008 The Board of Trustees of
// The Leland Stanford Junior University. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
//
// For more information, bug reports, fixes, contact:
//    Christopher Manning
//    Dept of Computer Science, Gates 1A
//    Stanford CA 94305-9010
//    USA
//    [email protected]
//    http://nlp.stanford.edu/software/

package edu.stanford.nlp.stats;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.lang.reflect.Constructor;
import java.text.NumberFormat;
import java.util.AbstractCollection;
import java.util.AbstractMap;
import java.util.AbstractSet;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.Map.Entry;
import java.util.function.Function;
import java.util.regex.Pattern;

import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.math.SloppyMath;
import edu.stanford.nlp.util.*;
import edu.stanford.nlp.util.logging.Redwood;
import edu.stanford.nlp.util.logging.PrettyLogger;
import edu.stanford.nlp.util.logging.Redwood.RedwoodChannels;

/**
 * Static methods for operating on a {@link Counter}.
 * 

* All methods that change their arguments change the first argument * (only), and have "InPlace" in their name. This class also provides access to * Comparators that can be used to sort the keys or entries of this Counter by * the counts, in either ascending or descending order. * * @author Galen Andrew ([email protected]) * @author Jeff Michels ([email protected]) * @author dramage * @author daniel cer (http://dmcer.net) * @author Christopher Manning * @author stefank (Optimized dot product) */ public class Counters { /** A logger for this class */ private static Redwood.RedwoodChannels log = Redwood.channels(Counters.class); private static final double LOG_E_2 = Math.log(2.0); private Counters() {} // only static methods // // Log arithmetic operations // /** * Returns ArrayMath.logSum of the values in this counter. * * @param c Argument counter (which is not modified) * @return ArrayMath.logSum of the values in this counter. */ public static double logSum(Counter c) { return ArrayMath.logSum(ArrayMath.unbox(c.values())); } /** * Transform log space values into a probability distribution in place. On the * assumption that the values in the Counter are in log space, this method * calculates their sum, and then subtracts the log of their sum from each * element. That is, if a counter has keys c1, c2, c3 with values v1, v2, v3, * the value of c1 becomes v1 - log(e^v1 + e^v2 + e^v3). After this, e^v1 + * e^v2 + e^v3 = 1.0, so Counters.logSum(c) = 0.0 (approximately). * * @param c The Counter to log normalize in place */ @SuppressWarnings( { "UnnecessaryUnboxing" }) public static void logNormalizeInPlace(Counter c) { double logsum = logSum(c); // for (E key : c.keySet()) { // c.incrementCount(key, -logsum); // } // This should be faster for (Map.Entry e : c.entrySet()) { e.setValue(e.getValue().doubleValue() - logsum); } } // // Query operations // /** * Returns the value of the maximum entry in this counter. This is also the * L_infinity norm. 
An empty counter is given a max value of * Double.NEGATIVE_INFINITY. * * @param c The Counter to find the max of * @return The maximum value of the Counter */ public static double max(Counter c) { return max(c, Double.NEGATIVE_INFINITY); // note[gabor]: Should the default actually be 0 rather than negative_infinity? } /** * Returns the value of the maximum entry in this counter. This is also the * L_infinity norm. An empty counter is given a max value of * Double.NEGATIVE_INFINITY. * * @param c The Counter to find the max of * @param valueIfEmpty The value to return if this counter is empty (i.e., the maximum is not well defined. * @return The maximum value of the Counter */ public static double max(Counter c, double valueIfEmpty) { if (c.size() == 0) { return valueIfEmpty; } else { double max = Double.NEGATIVE_INFINITY; for (double v : c.values()) { max = Math.max(max, v); } return max; } } /** * Takes in a Collection of something and makes a counter, incrementing once * for each object in the collection. * * @param c The Collection to turn into a counter * @return The counter made out of the collection */ public static Counter asCounter(Collection c) { Counter count = new ClassicCounter<>(); for (E elem : c) { count.incrementCount(elem); } return count; } /** * Returns the value of the smallest entry in this counter. * * @param c The Counter (not modified) * @return The minimum value in the Counter */ public static double min(Counter c) { double min = Double.POSITIVE_INFINITY; for (double v : c.values()) { min = Math.min(min, v); } return min; } /** * Finds and returns the key in the Counter with the largest count. Returning * null if count is empty. * * @param c The Counter * @return The key in the Counter with the largest count. */ public static E argmax(Counter c) { return argmax(c, (x, y) -> 0, null); } /** * Finds and returns the key in this Counter with the smallest count. * * @param c The Counter * @return The key in the Counter with the smallest count. 
*/ public static E argmin(Counter c) { double min = Double.POSITIVE_INFINITY; E argmin = null; for (E key : c.keySet()) { double count = c.getCount(key); if (argmin == null || count < min) { // || (count == min && tieBreaker.compare(key, argmin) < 0) min = count; argmin = key; } } return argmin; } /** * Finds and returns the key in the Counter with the largest count. Returning * null if count is empty. * * @param c The Counter * @param tieBreaker the tie breaker for when elements have the same value. * @return The key in the Counter with the largest count. */ public static E argmax(Counter c, Comparator tieBreaker) { return argmax(c, tieBreaker, (E) null); } /** * Finds and returns the key in the Counter with the largest count. Returning * null if count is empty. * * @param c The Counter * @param tieBreaker the tie breaker for when elements have the same value. * @param defaultIfEmpty The value to return if the counter is empty. * @return The key in the Counter with the largest count. */ public static E argmax(Counter c, Comparator tieBreaker, E defaultIfEmpty) { if (Thread.interrupted()) { // A good place to check for interrupts -- called from many annotators throw new RuntimeInterruptedException(); } if (c.size() == 0) { return defaultIfEmpty; } double max = Double.NEGATIVE_INFINITY; E argmax = null; for (E key : c.keySet()) { double count = c.getCount(key); if (argmax == null || count > max || (count == max && tieBreaker.compare(key, argmax) < 0)) { max = count; argmax = key; } } return argmax; } /** * Finds and returns the key in this Counter with the smallest count. * * @param c The Counter * @return The key in the Counter with the smallest count. 
*/ public static E argmin(Counter c, Comparator tieBreaker) { double min = Double.POSITIVE_INFINITY; E argmin = null; for (E key : c.keySet()) { double count = c.getCount(key); if (argmin == null || count < min || (count == min && tieBreaker.compare(key, argmin) < 0)) { min = count; argmin = key; } } return argmin; } /** * Returns the mean of all the counts (totalCount/size). * * @param c The Counter to find the mean of. * @return The mean of all the counts (totalCount/size). */ public static double mean(Counter c) { return c.totalCount() / c.size(); } public static double standardDeviation(Counter c) { double std = 0; double mean = c.totalCount() / c.size(); for (Map.Entry en : c.entrySet()) { std += (en.getValue() - mean) * (en.getValue() - mean); } return Math.sqrt(std / c.size()); } // // In-place arithmetic // /** * Sets each value of target to be target[k]+scale*arg[k] for all keys k in * target. * * @param target A Counter that is modified * @param arg The Counter whose contents are added to target * @param scale How the arg Counter is scaled before being added */ // TODO: Rewrite to use arg.entrySet() public static void addInPlace(Counter target, Counter arg, double scale) { for (E key : arg.keySet()) { target.incrementCount(key, scale * arg.getCount(key)); } } /** * Sets each value of target to be target[k]+arg[k] for all keys k in arg. */ public static void addInPlace(Counter target, Counter arg) { for (Map.Entry entry : arg.entrySet()) { double count = entry.getValue(); if (count != 0) { target.incrementCount(entry.getKey(), count); } } } /** * Sets each value of double[] target to be * target[idx.indexOf(k)]+a.getCount(k) for all keys k in arg */ public static void addInPlace(double[] target, Counter arg, Index idx) { for (Map.Entry entry : arg.entrySet()) { target[idx.indexOf(entry.getKey())] += entry.getValue(); } } /** * For all keys (u,v) in arg1 and arg2, sets return[u,v] to be summation of both. 
* @param * @param */ public static TwoDimensionalCounter add(TwoDimensionalCounter arg1, TwoDimensionalCounter arg2) { TwoDimensionalCounter add = new TwoDimensionalCounter<>(); Counters.addInPlace(add , arg1); Counters.addInPlace(add , arg2); return add; } /** * For all keys (u,v) in arg, sets target[u,v] to be target[u,v] + scale * * arg[u,v]. * * @param * @param */ public static void addInPlace(TwoDimensionalCounter target, TwoDimensionalCounter arg, double scale) { for (T1 outer : arg.firstKeySet()) for (T2 inner : arg.secondKeySet()) { target.incrementCount(outer, inner, scale * arg.getCount(outer, inner)); } } /** * For all keys (u,v) in arg, sets target[u,v] to be target[u,v] + arg[u,v]. * * @param * @param */ public static void addInPlace(TwoDimensionalCounter target, TwoDimensionalCounter arg) { for (T1 outer : arg.firstKeySet()) for (T2 inner : arg.secondKeySet()) { target.incrementCount(outer, inner, arg.getCount(outer, inner)); } } /** * Sets each value of target to be target[k]+ * value*(num-of-times-it-occurs-in-collection) if the key is present in the arg * collection. */ public static void addInPlace(Counter target, Collection arg, double value) { for (E key : arg) { target.incrementCount(key, value); } } /** * For all keys (u,v) in target, sets target[u,v] to be target[u,v] + value * * @param * @param */ public static void addInPlace(TwoDimensionalCounter target, double value) { for (T1 outer : target.firstKeySet()){ addInPlace(target.getCounter(outer), value); } } /** * Sets each value of target to be target[k]+ * num-of-times-it-occurs-in-collection if the key is present in the arg * collection. */ public static void addInPlace(Counter target, Collection arg) { for (E key : arg) { target.incrementCount(key, 1); } } /** * Increments all keys in a Counter by a specific value. 
*/
  public static <E> void addInPlace(Counter<E> target, double value) {
    for (E key : target.keySet()) {
      target.incrementCount(key, value);
    }
  }

  /**
   * Sets each value of target to be target[k]-arg[k] for all keys k in arg.
   *
   * @param target A Counter that is modified
   * @param arg The Counter whose contents are subtracted from target
   */
  public static <E> void subtractInPlace(Counter<E> target, Counter<E> arg) {
    for (E key : arg.keySet()) {
      target.decrementCount(key, arg.getCount(key));
    }
  }

  /**
   * Sets each value of double[] target to be
   * target[idx.indexOf(k)]-a.getCount(k) for all keys k in arg. The result of
   * idx.indexOf is used as an array index without checking.
   *
   * @param target The array that is modified
   * @param arg The Counter whose contents are subtracted from target
   * @param idx Maps each key of arg to its position in target
   */
  public static <E> void subtractInPlace(double[] target, Counter<E> arg, Index<E> idx) {
    for (Map.Entry<E, Double> entry : arg.entrySet()) {
      target[idx.indexOf(entry.getKey())] -= entry.getValue();
    }
  }

  /**
   * Divides every count in target by the corresponding value in the
   * denominator Counter. Beware that this can give NaN values for zero counts
   * in the denominator counter!
   *
   * @param target A Counter that is modified
   * @param denominator The Counter supplying the divisor for each key
   */
  public static <E> void divideInPlace(Counter<E> target, Counter<E> denominator) {
    for (E key : target.keySet()) {
      target.setCount(key, target.getCount(key) / denominator.getCount(key));
    }
  }

  /**
   * Multiplies every count in target by the corresponding value in the term
   * Counter.
   *
   * @param target A Counter that is modified
   * @param term The Counter supplying the multiplier for each key
   */
  public static <E> void dotProductInPlace(Counter<E> target, Counter<E> term) {
    for (E key : target.keySet()) {
      target.setCount(key, target.getCount(key) * term.getCount(key));
    }
  }

  /**
   * Divides each value in target by the given divisor, in place.
   *
   * @param target The values in this Counter will be changed throughout by the
   *          divisor
   * @param divisor The number by which to divide each number in the Counter
   * @return The target Counter is returned (for easier method chaining)
   */
  public static <E> Counter<E> divideInPlace(Counter<E> target, double divisor) {
    for (Entry<E, Double> entry : target.entrySet()) {
      target.setCount(entry.getKey(), entry.getValue() / divisor);
    }
    return target;
  }

  /**
   * Multiplies each value in target by the given multiplier, in place.
* * @param target The values in this Counter will be multiplied by the * multiplier * @param multiplier The number by which to change each number in the Counter */ public static Counter multiplyInPlace(Counter target, double multiplier) { for (Entry entry : target.entrySet()) { target.setCount(entry.getKey(), entry.getValue() * multiplier); } return target; } /** * Multiplies each value in target by the count of the key in mult, in place. Returns non zero entries * * @param target The counter * @param mult The counter you want to multiply with target */ public static Counter multiplyInPlace(Counter target, Counter mult) { for (Entry entry : target.entrySet()) { target.setCount(entry.getKey(), entry.getValue() * mult.getCount(entry.getKey())); } Counters.retainNonZeros(target); return target; } /** * Normalizes the target counter in-place, so the sum of the resulting values * equals 1. * * @param Type of elements in Counter */ public static void normalize(Counter target) { divideInPlace(target, target.totalCount()); } /** * L1 normalize a counter. Return a counter that is a probability distribution, * so the sum of the resulting value equals 1. * * @param c The {@link Counter} to be L1 normalized. This counter is not * modified. * @return A new L1-normalized Counter based on c. */ public static > C asNormalizedCounter(C c) { return scale(c, 1.0 / c.totalCount()); } /** * Normalizes the target counter in-place, so the sum of the resulting values * equals 1. 
* * @param Type of elements in TwoDimensionalCounter * @param Type of elements in TwoDimensionalCounter */ public static void normalize(TwoDimensionalCounter target) { Counters.divideInPlace(target, target.totalCount()); } public static void logInPlace(Counter target) { for (E key : target.keySet()) { target.setCount(key, Math.log(target.getCount(key))); } } // // Selection Operators // /** * Delete 'top' and 'bottom' number of elements from the top and bottom * respectively */ public static List deleteOutofRange(Counter c, int top, int bottom) { List purgedItems = new ArrayList<>(); int numToPurge = top + bottom; if (numToPurge <= 0) { return purgedItems; } List l = Counters.toSortedList(c); for (int i = 0; i < top; i++) { E item = l.get(i); purgedItems.add(item); c.remove(item); } int size = c.size(); for (int i = c.size() - 1; i >= (size - bottom); i--) { E item = l.get(i); purgedItems.add(item); c.remove(item); } return purgedItems; } /** * Removes all entries from c except for the top {@code num}. */ public static void retainTop(Counter c, int num) { int numToPurge = c.size() - num; if (numToPurge <= 0) { return; } List l = Counters.toSortedList(c, true); for (int i = 0; i < numToPurge; i++) { c.remove(l.get(i)); } } /** * Removes all entries from c except for the top {@code num}. */ public static > void retainTopKeyComparable(Counter c, int num) { int numToPurge = c.size() - num; if (numToPurge <= 0) { return; } List l = Counters.toSortedListKeyComparable(c); Collections.reverse(l); for (int i = 0; i < numToPurge; i++) { c.remove(l.get(i)); } } /** * Removes all entries from c except for the bottom {@code num}. 
*/
  public static <E> List<E> retainBottom(Counter<E> c, int num) {
    int numToPurge = c.size() - num;
    if (numToPurge <= 0) {
      return Generics.newArrayList();
    }

    List<E> removed = new ArrayList<>();
    // Highest counts first, so the entries to drop lead the list.
    List<E> l = Counters.toSortedList(c);
    for (int i = 0; i < numToPurge; i++) {
      E rem = l.get(i);
      removed.add(rem);
      c.remove(rem);
    }
    return removed;
  }

  /**
   * Removes all entries with 0 count in the counter, returning the set of
   * removed entries.
   *
   * @param counter The counter (modified)
   * @return The set of removed keys
   */
  public static <E> Set<E> retainNonZeros(Counter<E> counter) {
    // Collect first, remove afterwards, so the key set is not changed while
    // it is being iterated.
    Set<E> removed = Generics.newHashSet();
    for (E key : counter.keySet()) {
      if (counter.getCount(key) == 0.0) {
        removed.add(key);
      }
    }
    for (E key : removed) {
      counter.remove(key);
    }
    return removed;
  }

  /**
   * Removes all entries with counts below the given threshold, returning the
   * set of removed entries.
   *
   * @param counter The counter.
   * @param countThreshold
   *          The minimum count for an entry to be kept. Entries (strictly) less
   *          than this threshold are discarded.
   * @return The set of discarded entries.
   */
  public static <E> Set<E> retainAbove(Counter<E> counter, double countThreshold) {
    Set<E> removed = Generics.newHashSet();
    for (E key : counter.keySet()) {
      if (counter.getCount(key) < countThreshold) {
        removed.add(key);
      }
    }
    for (E key : removed) {
      counter.remove(key);
    }
    return removed;
  }

  /**
   * Removes all entries with counts below the given threshold, returning the
   * set of removed entries.
   *
   * @param counter The counter.
   * @param countThreshold
   *          The minimum count for an entry to be kept. Entries (strictly) less
   *          than this threshold are discarded.
   * @return The set of discarded entries.
*/ public static Set> retainAbove( TwoDimensionalCounter counter, double countThreshold) { Set> removed = new HashSet<>(); for (Entry> en : counter.entrySet()) { for (Entry en2 : en.getValue().entrySet()) { if (counter.getCount(en.getKey(), en2.getKey()) < countThreshold) { removed.add(new Pair<>(en.getKey(), en2.getKey())); } } } for (Pair key : removed) { counter.remove(key.first(), key.second()); } return removed; } /** * Removes all entries with counts above the given threshold, returning the * set of removed entries. * * @param counter The counter. * @param countMaxThreshold * The maximum count for an entry to be kept. Entries (strictly) more * than this threshold are discarded. * @return The set of discarded entries. */ public static Counter retainBelow(Counter counter, double countMaxThreshold) { Counter removed = new ClassicCounter<>(); for (E key : counter.keySet()) { double count = counter.getCount(key); if (counter.getCount(key) > countMaxThreshold) { removed.setCount(key, count); } } for (Entry key : removed.entrySet()) { counter.remove(key.getKey()); } return removed; } /** * Removes all entries with keys that does not match one of the given patterns. * * @param counter The counter. * @param matchPatterns pattern for key to match * @return The set of discarded entries. */ public static Set retainMatchingKeys(Counter counter, List matchPatterns) { Set removed = Generics.newHashSet(); for (String key : counter.keySet()) { boolean matched = false; for (Pattern pattern : matchPatterns) { if (pattern.matcher(key).matches()) { matched = true; break; } } if (!matched) { removed.add(key); } } for (String key : removed) { counter.remove(key); } return removed; } /** * Removes all entries with keys that does not match the given set of keys. * * @param counter The counter * @param matchKeys Keys to match * @return The set of discarded entries. 
*/
  public static <E> Set<E> retainKeys(Counter<E> counter, Collection<E> matchKeys) {
    Set<E> removed = Generics.newHashSet();
    for (E key : counter.keySet()) {
      if (!matchKeys.contains(key)) {
        removed.add(key);
      }
    }
    for (E key : removed) {
      counter.remove(key);
    }
    return removed;
  }

  /**
   * Removes all entries with keys in the given collection
   *
   * @param <E> Type of the keys
   * @param counter The counter (modified)
   * @param removeKeysCollection The keys to remove
   */
  public static <E> void removeKeys(Counter<E> counter, Collection<E> removeKeysCollection) {
    for (E key : removeKeysCollection) {
      counter.remove(key);
    }
  }

  /**
   * Removes all entries with keys (first key set) in the given collection
   *
   * @param <E> Type of the first keys
   * @param <F> Type of the second keys
   * @param counter The counter (modified)
   * @param removeKeysCollection The first keys to remove
   */
  public static <E, F> void removeKeys(TwoDimensionalCounter<E, F> counter, Collection<E> removeKeysCollection) {
    for (E key : removeKeysCollection) {
      counter.remove(key);
    }
  }

  /**
   * Returns the set of keys whose counts are at or above the given threshold.
   * This set may have 0 elements but will not be null.
   *
   * @param c The Counter to examine
   * @param countThreshold
   *          Items equal to or above this number are kept
   * @return A (non-null) Set of keys whose counts are at or above the given
   *         threshold.
   */
  public static <E> Set<E> keysAbove(Counter<E> c, double countThreshold) {
    Set<E> keys = Generics.newHashSet();
    for (E key : c.keySet()) {
      if (c.getCount(key) >= countThreshold) {
        keys.add(key);
      }
    }
    return keys;
  }

  /**
   * Returns the set of keys whose counts are at or below the given threshold.
   * This set may have 0 elements but will not be null.
   *
   * @param c The Counter to examine
   * @param countThreshold
   *          Items equal to or below this number are kept
   * @return A (non-null) Set of keys whose counts are at or below the given
   *         threshold.
   */
  public static <E> Set<E> keysBelow(Counter<E> c, double countThreshold) {
    Set<E> keys = Generics.newHashSet();
    for (E key : c.keySet()) {
      if (c.getCount(key) <= countThreshold) {
        keys.add(key);
      }
    }
    return keys;
  }

  /**
   * Returns the set of keys that have exactly the given count. This set may
   * have 0 elements but will not be null.
*/ public static Set keysAt(Counter c, double count) { Set keys = Generics.newHashSet(); for (E key : c.keySet()) { if (c.getCount(key) == count) { keys.add(key); } } return (keys); } // // Transforms // /** * Returns the counter with keys modified according to function F. Eager * evaluation. If two keys are same after the transformation, one of the values is randomly chosen (depending on how the keyset is traversed) */ public static Counter transform(Counter c, Function f) { Counter c2 = new ClassicCounter<>(); for (T1 key : c.keySet()) { c2.setCount(f.apply(key), c.getCount(key)); } return c2; } /** * Returns the counter with keys modified according to function F. If two keys are same after the transformation, their values get added up. */ public static Counter transformWithValuesAdd(Counter c, Function f) { Counter c2 = new ClassicCounter<>(); for (T1 key : c.keySet()) { c2.incrementCount(f.apply(key), c.getCount(key)); } return c2; } // // Conversion to other types // /** * Returns a comparator backed by this counter: two objects are compared by * their associated values stored in the counter. This comparator returns keys * by ascending numeric value. Note that this ordering is not fixed, but * depends on the mutable values stored in the Counter. Doing this comparison * does not depend on the type of the key, since it uses the numeric value, * which is always Comparable. * * @param counter The Counter whose values are used for ordering the keys * @return A Comparator using this ordering */ public static Comparator toComparator(final Counter counter) { return (o1, o2) -> Double.compare(counter.getCount(o1), counter.getCount(o2)); } /** * Returns a comparator backed by this counter: two objects are compared by * their associated values stored in the counter. This comparator returns keys * by ascending numeric value. Note that this ordering is not fixed, but * depends on the mutable values stored in the Counter. 
Doing this comparison * does not depend on the type of the key, since it uses the numeric value, * which is always Comparable. * * @param counter The Counter whose values are used for ordering the keys * @return A Comparator using this ordering */ public static > Comparator toComparatorWithKeys(final Counter counter) { return (o1, o2) -> { int res = Double.compare(counter.getCount(o1), counter.getCount(o2)); if (res == 0) { return o1.compareTo(o2); } else { return res; } }; } /** * Returns a comparator backed by this counter: two objects are compared by * their associated values stored in the counter. This comparator returns keys * by descending numeric value. Note that this ordering is not fixed, but * depends on the mutable values stored in the Counter. Doing this comparison * does not depend on the type of the key, since it uses the numeric value, * which is always Comparable. * * @param counter The Counter whose values are used for ordering the keys * @return A Comparator using this ordering */ public static Comparator toComparatorDescending(final Counter counter) { return (o1, o2) -> Double.compare(counter.getCount(o2), counter.getCount(o1)); } /** * Returns a comparator suitable for sorting this Counter's keys or entries by * their respective value or magnitude (by absolute value). If * ascending is true, smaller magnitudes will be returned first, * otherwise higher magnitudes will be returned first. *

* Sample usage: * *

   * Counter c = new Counter();
   * // add to the counter...
   * List biggestAbsKeys = new ArrayList(c.keySet());
   * Collections.sort(biggestAbsKeys, Counters.comparator(c, false, true));
   * List smallestEntries = new ArrayList(c.entrySet());
   * Collections.sort(smallestEntries, Counters.comparator(c, true, false));
   * 
*/ public static Comparator toComparator(final Counter counter, final boolean ascending, final boolean useMagnitude) { return (o1, o2) -> { if (ascending) { if (useMagnitude) { return Double.compare(Math.abs(counter.getCount(o1)), Math.abs(counter.getCount(o2))); } else { return Double.compare(counter.getCount(o1), counter.getCount(o2)); } } else { // Descending if (useMagnitude) { return Double.compare(Math.abs(counter.getCount(o2)), Math.abs(counter.getCount(o1))); } else { return Double.compare(counter.getCount(o2), counter.getCount(o1)); } } }; } /** * A List of the keys in c, sorted from highest count to lowest. * So note that the default is descending! * * @return A List of the keys in c, sorted from highest count to lowest. */ public static List toSortedList(Counter c) { return toSortedList(c, false); } /** * A List of the keys in c, sorted from highest count to lowest. * * @return A List of the keys in c, sorted from highest count to lowest. */ public static List toSortedList(Counter c, boolean ascending) { List l = new ArrayList<>(c.keySet()); Comparator comp = ascending ? toComparator(c) : toComparatorDescending(c); Collections.sort(l, comp); return l; } /** * A List of the keys in c, sorted from highest count to lowest. * * @return A List of the keys in c, sorted from highest count to lowest. 
*/ public static > List toSortedListKeyComparable(Counter c) { List l = new ArrayList<>(c.keySet()); Comparator comp = toComparatorWithKeys(c); Collections.sort(l, comp); Collections.reverse(l); return l; } /** * Converts a counter to ranks; ranks start from 0 * * @return A counter where the count is the rank in the original counter */ public static IntCounter toRankCounter(Counter c) { IntCounter rankCounter = new IntCounter<>(); List sortedList = toSortedList(c); for (int i = 0; i < sortedList.size(); i++) { rankCounter.setCount(sortedList.get(i), i); } return rankCounter; } /** * Converts a counter to tied ranks; ranks start from 1 * * @return A counter where the count is the rank in the original counter; when values are tied, the rank is the average of the ranks of the tied values */ public static Counter toTiedRankCounter(Counter c) { Counter rankCounter = new ClassicCounter<>(); List> sortedList = toSortedListWithCounts(c); int i = 0; Iterator> it = sortedList.iterator(); while(it.hasNext()) { Pair iEn = it.next(); double icount = iEn.second(); E iKey = iEn.first(); List l = new ArrayList<>(); List keys = new ArrayList<>(); l.add(i+1); keys.add(iKey); for(int j = i +1; j < sortedList.size(); j++){ Pair jEn = sortedList.get(j); if( icount == jEn.second()){ l.add(j+1); keys.add(jEn.first()); }else break; } if(l.size() > 1){ double sum = 0; for(Integer d: l) sum += d; double avgRank = sum/l.size(); for(int k = 0; k < l.size(); k++){ rankCounter.setCount(keys.get(k), avgRank); if(k != l.size()-1 && it.hasNext()) it.next(); i++; } }else{ rankCounter.setCount(iKey, i+1); i++; } } return rankCounter; } public static List> toDescendingMagnitudeSortedListWithCounts(Counter c) { List keys = new ArrayList<>(c.keySet()); Collections.sort(keys, toComparator(c, false, true)); List> l = new ArrayList<>(keys.size()); for (E key : keys) { l.add(new Pair<>(key, c.getCount(key))); } return l; } /** * A List of the keys in c, sorted from highest count to lowest, paired with * 
counts * * @return A List of the keys in c, sorted from highest count to lowest. */ public static List> toSortedListWithCounts(Counter c) { List> l = new ArrayList<>(c.size()); for (E e : c.keySet()) { l.add(new Pair<>(e, c.getCount(e))); } // descending order Collections.sort(l, (a, b) -> Double.compare(b.second, a.second)); return l; } /** * A List of the keys in c, sorted by the given comparator, paired with * counts. * * @return A List of the keys in c, sorted from highest count to lowest. */ public static List> toSortedListWithCounts(Counter c, Comparator> comparator) { List> l = new ArrayList<>(c.size()); for (E e : c.keySet()) { l.add(new Pair<>(e, c.getCount(e))); } // descending order Collections.sort(l, comparator); return l; } /** * Returns a {@link edu.stanford.nlp.util.PriorityQueue} whose elements are * the keys of Counter c, and the score of each key in c becomes its priority. * * @param c Input Counter * @return A PriorityQueue where the count is a key's priority */ // TODO: rewrite to use entrySet() public static edu.stanford.nlp.util.PriorityQueue toPriorityQueue(Counter c) { edu.stanford.nlp.util.PriorityQueue queue = new BinaryHeapPriorityQueue<>(); for (E key : c.keySet()) { double count = c.getCount(key); queue.add(key, count); } return queue; } // // Other Utilities // /** * Returns a Counter that is the union of the two Counters passed in (counts * are added). * * @return A Counter that is the union of the two Counters passed in (counts * are added). */ @SuppressWarnings("unchecked") public static > C union(C c1, C c2) { C result = (C) c1.getFactory().create(); addInPlace(result, c1); addInPlace(result, c2); return result; } /** * Returns a counter that is the intersection of c1 and c2. If both c1 and c2 * contain a key, the min of the two counts is used. 
*
   * Only keys whose minimum count is strictly positive appear in the result.
   *
   * @return A counter that is the intersection of c1 and c2
   */
  public static <E> Counter<E> intersection(Counter<E> c1, Counter<E> c2) {
    Counter<E> result = c1.getFactory().create();
    for (E key : Sets.union(c1.keySet(), c2.keySet())) {
      double count1 = c1.getCount(key);
      double count2 = c2.getCount(key);
      double minCount = (count1 < count2 ? count1 : count2);
      if (minCount > 0) {
        result.setCount(key, minCount);
      }
    }
    return result;
  }

  /**
   * Returns the Jaccard Coefficient of the two counters. Calculated over the
   * union of their keys as sum(min(c1[k], c2[k])) / sum(max(c1[k], c2[k])).
   * There is no guard for the denominator being 0 (e.g. two empty counters).
   *
   * @return The Jaccard Coefficient of the two counters
   */
  public static <E> double jaccardCoefficient(Counter<E> c1, Counter<E> c2) {
    double minCount = 0.0, maxCount = 0.0;
    for (E key : Sets.union(c1.keySet(), c2.keySet())) {
      double count1 = c1.getCount(key);
      double count2 = c2.getCount(key);
      minCount += (count1 < count2 ? count1 : count2);
      maxCount += (count1 > count2 ? count1 : count2);
    }
    return minCount / maxCount;
  }

  /**
   * Returns the product of c1 and c2: for every key present in both, the
   * product of the two counts. The result is created by the factory of c1.
   *
   * @return The product of c1 and c2.
   */
  public static <E> Counter<E> product(Counter<E> c1, Counter<E> c2) {
    Counter<E> result = c1.getFactory().create();
    for (E key : Sets.intersection(c1.keySet(), c2.keySet())) {
      result.setCount(key, c1.getCount(key) * c2.getCount(key));
    }
    return result;
  }

  /**
   * Returns the dot (inner) product of c1 and c2, i.e. the sum over keys of
   * the product of the two counts.
   *
   * @return The product of c1 and c2.
*/
  public static <E> double dotProduct(Counter<E> c1, Counter<E> c2) {
    // Walk the keys of whichever counter has fewer of them.
    Counter<E> smaller = c1;
    Counter<E> larger = c2;
    if (c1.size() > c2.size()) {
      smaller = c2;
      larger = c1;
    }
    double dotProd = 0.0;
    for (E key : smaller.keySet()) {
      double count1 = smaller.getCount(key);
      if (Double.isNaN(count1) || Double.isInfinite(count1)) {
        throw new RuntimeException("Counters.dotProduct infinite or NaN value for key: " + key + '\t' + smaller.getCount(key) + '\t' + larger.getCount(key));
      }
      if (count1 == 0.0) {
        continue;
      }
      double count2 = larger.getCount(key);
      if (Double.isNaN(count2) || Double.isInfinite(count2)) {
        throw new RuntimeException("Counters.dotProduct infinite or NaN value for key: " + key + '\t' + smaller.getCount(key) + '\t' + larger.getCount(key));
      }
      if (count2 != 0.0) {
        // this is the inner product
        dotProd += (count1 * count2);
      }
    }
    return dotProd;
  }

  /**
   * Returns the product of Counter c and double[] a, using Index idx to map
   * entries in C onto a. Keys for which idx reports a negative index are
   * skipped.
   *
   * @return The product of c and a.
   */
  public static <E> double dotProduct(Counter<E> c, double[] a, Index<E> idx) {
    double dotProd = 0.0;
    for (Map.Entry<E, Double> entry : c.entrySet()) {
      int keyIdx = idx.indexOf(entry.getKey());
      if (keyIdx < 0) {
        continue;
      }
      dotProd += entry.getValue() * a[keyIdx];
    }
    return dotProd;
  }

  /**
   * Sums the counts that c1 holds for each element of entries. An element that
   * occurs several times in entries is added once per occurrence.
   *
   * @return The sum of c1's counts over entries
   */
  public static <E> double sumEntries(Counter<E> c1, Collection<E> entries) {
    double sum = 0.0;
    for (E entry : entries) {
      sum += c1.getCount(entry);
    }
    return sum;
  }

  /**
   * Returns a new Counter holding the counts of c1, with the count of each
   * element of c2 incremented by one per occurrence. c1 is not modified.
   */
  public static <E> Counter<E> add(Counter<E> c1, Collection<E> c2) {
    Counter<E> result = c1.getFactory().create();
    addInPlace(result, c1);
    for (E key : c2) {
      result.incrementCount(key, 1);
    }
    return result;
  }

  /**
   * Returns a new Counter with the per-key sum of c1 and c2 over the union of
   * their keys; retainNonZeros is applied to the result before it is returned.
   */
  public static <E> Counter<E> add(Counter<E> c1, Counter<E> c2) {
    Counter<E> result = c1.getFactory().create();
    for (E key : Sets.union(c1.keySet(), c2.keySet())) {
      double total = c1.getCount(key) + c2.getCount(key);
      result.setCount(key, total);
    }
    retainNonZeros(result);
    return result;
  }

  /**
   * increments every key in the counter by value
   */
  public static <E> Counter<E> add(Counter<E> c1, double value) {
    Counter<E> result = c1.getFactory().create();
    for (E key : c1.keySet()) {
      result.setCount(key, c1.getCount(key) +
value); } return result; } /** * This method does not check entries for NAN or INFINITY values in the * doubles returned. It also only iterates over the counter with the smallest * number of keys to help speed up computation. Pair this method with * normalizing your counters before hand and you have a reasonably quick * implementation of cosine. * * @param * @param c1 * @param c2 * @return The dot product of the two counter (as vectors) */ public static double optimizedDotProduct(Counter c1, Counter c2) { int size1 = c1.size(); int size2 = c2.size(); if (size1 < size2) { return getDotProd(c1, c2); } else { return getDotProd(c2, c1); } } private static double getDotProd(Counter c1, Counter c2) { double dotProd = 0.0; for (E key : c1.keySet()) { double count1 = c1.getCount(key); if (count1 != 0.0) { double count2 = c2.getCount(key); if (count2 != 0.0) dotProd += (count1 * count2); } } return dotProd; } /** * Returns |c1 - c2|. * * @return The difference between sets c1 and c2. */ public static Counter absoluteDifference(Counter c1, Counter c2) { Counter result = c1.getFactory().create(); for (E key : Sets.union(c1.keySet(), c2.keySet())) { double newCount = Math.abs(c1.getCount(key) - c2.getCount(key)); if (newCount > 0) { result.setCount(key, newCount); } } return result; } /** * Returns c1 divided by c2. Note that this can create NaN if c1 has non-zero * counts for keys that c2 has zero counts. * * @return c1 divided by c2. */ public static Counter division(Counter c1, Counter c2) { Counter result = c1.getFactory().create(); for (E key : Sets.union(c1.keySet(), c2.keySet())) { result.setCount(key, c1.getCount(key) / c2.getCount(key)); } return result; } /** * Returns c1 divided by c2. Safe - will not calculate scores for keys that are zero or that do not exist in c2 * * @return c1 divided by c2. 
*/ public static Counter divisionNonNaN(Counter c1, Counter c2) { Counter result = c1.getFactory().create(); for (E key : Sets.union(c1.keySet(), c2.keySet())) { if(c2.getCount(key) != 0) result.setCount(key, c1.getCount(key) / c2.getCount(key)); } return result; } /** * Calculates the entropy of the given counter (in bits). This method * internally uses normalized counts (so they sum to one), but the value * returned is meaningless if some of the counts are negative. * * @return The entropy of the given counter (in bits) */ public static double entropy(Counter c) { double entropy = 0.0; double total = c.totalCount(); for (E key : c.keySet()) { double count = c.getCount(key); if (count == 0) { continue; // 0.0 doesn't add entropy but may cause -Inf } count /= total; // use normalized count entropy -= count * (Math.log(count) / LOG_E_2); } return entropy; } /** * Note that this implementation doesn't normalize the "from" Counter. It * does, however, normalize the "to" Counter. Result is meaningless if any of * the counts are negative. * * @return The cross entropy of H(from, to) */ public static double crossEntropy(Counter from, Counter to) { double tot2 = to.totalCount(); double result = 0.0; for (E key : from.keySet()) { double count1 = from.getCount(key); if (count1 == 0.0) { continue; } double count2 = to.getCount(key); double logFract = Math.log(count2 / tot2); if (logFract == Double.NEGATIVE_INFINITY) { return Double.NEGATIVE_INFINITY; // can't recover } result += count1 * (logFract / LOG_E_2); // express it in log base 2 } return result; } /** * Calculates the KL divergence between the two counters. That is, it * calculates KL(from || to). This method internally uses normalized counts * (so they sum to one), but the value returned is meaningless if any of the * counts are negative. In other words, how well can c1 be represented by c2. * if there is some value in c1 that gets zero prob in c2, then return * positive infinity. 
* * @return The KL divergence between the distributions */ public static double klDivergence(Counter from, Counter to) { double result = 0.0; double tot = (from.totalCount()); double tot2 = (to.totalCount()); // System.out.println("tot is " + tot + " tot2 is " + tot2); for (E key : from.keySet()) { double num = (from.getCount(key)); if (num == 0) { continue; } num /= tot; double num2 = (to.getCount(key)); num2 /= tot2; // System.out.println("num is " + num + " num2 is " + num2); double logFract = Math.log(num / num2); if (logFract == Double.NEGATIVE_INFINITY) { return Double.NEGATIVE_INFINITY; // can't recover } result += num * (logFract / LOG_E_2); // express it in log base 2 } return result; } /** * Calculates the Jensen-Shannon divergence between the two counters. That is, * it calculates 1/2 [KL(c1 || avg(c1,c2)) + KL(c2 || avg(c1,c2))] . * This code assumes that the Counters have only non-negative values in them. * * @return The Jensen-Shannon divergence between the distributions */ public static double jensenShannonDivergence(Counter c1, Counter c2) { // need to normalize the counters first before averaging them! Else buggy if not a probability distribution Counter d1 = asNormalizedCounter(c1); Counter d2 = asNormalizedCounter(c2); Counter average = average(d1, d2); double kl1 = klDivergence(d1, average); double kl2 = klDivergence(d2, average); return (kl1 + kl2) / 2.0; } /** * Calculates the skew divergence between the two counters. That is, it * calculates KL(c1 || (c2*skew + c1*(1-skew))) . In other words, how well can * c1 be represented by a "smoothed" c2. * * @return The skew divergence between the distributions */ public static double skewDivergence(Counter c1, Counter c2, double skew) { Counter d1 = asNormalizedCounter(c1); Counter d2 = asNormalizedCounter(c2); Counter average = linearCombination(d2, skew, d1, (1.0 - skew)); return klDivergence(d1, average); } /** * Return the l2 norm (Euclidean vector length) of a Counter. 
* Implementation note: The method name favors legibility of the L over * the convention of using lowercase names for methods. * * @param c The Counter * @return Its length */ public static > double L2Norm(C c) { return Math.sqrt(Counters.sumSquares(c)); } /** * Return the sum of squares (squared L2 norm). * * @param c The Counter * @return the L2 norm of the values in c */ public static > double sumSquares(C c) { double lenSq = 0.0; for (E key : c.keySet()) { double count = c.getCount(key); lenSq += (count * count); } return lenSq; } /** * Return the L1 norm of a counter. Implementation note: The method * name favors legibility of the L over the convention of using lowercase * names for methods. * * @param c The Counter * @return Its length */ public static > double L1Norm(C c) { double sumAbs = 0.0; for (E key : c.keySet()) { double count = c.getCount(key); if (count != 0.0) { sumAbs += Math.abs(count); } } return sumAbs; } /** * L2 normalize a counter. * * @param c The {@link Counter} to be L2 normalized. This counter is not * modified. * @return A new l2-normalized Counter based on c. */ public static > C L2Normalize(C c) { return scale(c, 1.0 / L2Norm(c)); } /** * L2 normalize a counter in place. * * @param c The {@link Counter} to be L2 normalized. This counter is modified * @return the passed in counter l2-normalized */ public static Counter L2NormalizeInPlace(Counter c) { return multiplyInPlace(c, 1.0 / L2Norm(c)); } /** * For counters with large # of entries, this scales down each entry in the * sum, to prevent an extremely large sum from building up and overwhelming * the max double. This may also help reduce error by preventing loss of SD's * with extremely large values. 
* * @param * @param */ public static > double saferL2Norm(C c) { double maxVal = 0.0; for (E key : c.keySet()) { double value = Math.abs(c.getCount(key)); if (value > maxVal) maxVal = value; } double sqrSum = 0.0; for (E key : c.keySet()) { double count = c.getCount(key); sqrSum += Math.pow(count / maxVal, 2); } return maxVal * Math.sqrt(sqrSum); } /** * L2 normalize a counter, using the "safer" L2 normalizer. * * @param c The {@link Counter} to be L2 normalized. This counter is not * modified. * @return A new L2-normalized Counter based on c. */ public static > C saferL2Normalize(C c) { return scale(c, 1.0 / saferL2Norm(c)); } public static double cosine(Counter c1, Counter c2) { double dotProd = 0.0; double lsq1 = 0.0; double lsq2 = 0.0; for (E key : c1.keySet()) { double count1 = c1.getCount(key); if (count1 != 0.0) { lsq1 += (count1 * count1); double count2 = c2.getCount(key); if (count2 != 0.0) { // this is the inner product dotProd += (count1 * count2); } } } for (E key : c2.keySet()) { double count2 = c2.getCount(key); if (count2 != 0.0) { lsq2 += (count2 * count2); } } if (lsq1 != 0.0 && lsq2 != 0.0) { double denom = (Math.sqrt(lsq1) * Math.sqrt(lsq2)); return dotProd / denom; } return 0.0; } /** * Returns a new Counter with counts averaged from the two given Counters. The * average Counter will contain the union of keys in both source Counters, and * each count will be the average of the two source counts for that key, where * as usual a missing count in one Counter is treated as count 0. * * @return A new counter with counts that are the mean of the resp. counts in * the given counters. */ public static Counter average(Counter c1, Counter c2) { Counter average = c1.getFactory().create(); Set allKeys = Generics.newHashSet(c1.keySet()); allKeys.addAll(c2.keySet()); for (E key : allKeys) { average.setCount(key, (c1.getCount(key) + c2.getCount(key)) * 0.5); } return average; } /** * Returns a Counter which is a weighted average of c1 and c2. 
Counts from c1
   * are weighted with weight w1 and counts from c2 are weighted with w2.
   */
  public static <E> Counter<E> linearCombination(Counter<E> c1, double w1, Counter<E> c2, double w2) {
    Counter<E> combined = c1.getFactory().create();
    for (E key : c1.keySet()) {
      double weighted = c1.getCount(key) * w1;
      combined.incrementCount(key, weighted);
    }
    for (E key : c2.keySet()) {
      double weighted = c2.getCount(key) * w2;
      combined.incrementCount(key, weighted);
    }
    return combined;
  }

  /**
   * Computes, in bits, log(joint) - log(p1) - log(p2), where the three values
   * are the counts stored for values.first in var1Distribution, for
   * values.second in var2Distribution and for the pair itself in
   * jointDistribution.
   *
   * @return The pointwise mutual information of the pair, in log base 2
   */
  public static <T1, T2> double pointwiseMutualInformation(Counter<T1> var1Distribution, Counter<T2> var2Distribution, Counter<Pair<T1, T2>> jointDistribution, Pair<T1, T2> values) {
    double logVar1 = Math.log(var1Distribution.getCount(values.first));
    double logVar2 = Math.log(var2Distribution.getCount(values.second));
    double logJoint = Math.log(jointDistribution.getCount(values));
    double pmi = logJoint - logVar1 - logVar2;
    return pmi / LOG_E_2;
  }

  /**
   * Calculate h-Index (Hirsch, 2005) of an author.
   *
   * A scientist has index h if h of their Np papers have at least h citations
   * each, and the other (Np − h) papers have at most h citations each.
   *
   * @param citationCounts
   *          Citation counts for each of the articles written by the author.
   *          The keys can be anything, but the values should be integers.
   * @return The h-Index of the author.
*/ public static int hIndex(Counter citationCounts) { Counter countCounts = new ClassicCounter<>(); for (double value : citationCounts.values()) { for (int i = 0; i <= value; ++i) { countCounts.incrementCount(i); } } List citationCountValues = CollectionUtils.sorted(countCounts.keySet()); Collections.reverse(citationCountValues); for (int citationCount : citationCountValues) { double occurrences = countCounts.getCount(citationCount); if (occurrences >= citationCount) { return citationCount; } } return 0; } @SuppressWarnings("unchecked") public static > C perturbCounts(C c, Random random, double p) { C result = (C) c.getFactory().create(); for (E key : c.keySet()) { double count = c.getCount(key); double noise = -Math.log(1.0 - random.nextDouble()); // inverse of CDF for // exponential // distribution // log.info("noise=" + noise); double perturbedCount = count + noise * p; result.setCount(key, perturbedCount); } return result; } /** * Great for debugging. * */ public static void printCounterComparison(Counter a, Counter b) { printCounterComparison(a, b, System.err); } /** * Great for debugging. * */ public static void printCounterComparison(Counter a, Counter b, PrintStream out) { printCounterComparison(a, b, new PrintWriter(out, true)); } /** * Prints one or more lines (with a newline at the end) describing the * difference between the two Counters. Great for debugging. * */ public static void printCounterComparison(Counter a, Counter b, PrintWriter out) { if (a.equals(b)) { out.println("Counters are equal."); return; } for (E key : a.keySet()) { double aCount = a.getCount(key); double bCount = b.getCount(key); if (Math.abs(aCount - bCount) > 1e-5) { out.println("Counters differ on key " + key + '\t' + a.getCount(key) + " vs. 
" + b.getCount(key)); } } // left overs Set rest = Generics.newHashSet(b.keySet()); rest.removeAll(a.keySet()); for (E key : rest) { double aCount = a.getCount(key); double bCount = b.getCount(key); if (Math.abs(aCount - bCount) > 1e-5) { out.println("Counters differ on key " + key + '\t' + a.getCount(key) + " vs. " + b.getCount(key)); } } } public static Counter getCountCounts(Counter c) { Counter result = new ClassicCounter<>(); for (double v : c.values()) { result.incrementCount(v); } return result; } /** * Returns a new Counter which is scaled by the given scale factor. * * @param c The counter to scale. It is not changed * @param s The constant to scale the counter by * @return A new Counter which is the argument scaled by the given scale * factor. */ @SuppressWarnings("unchecked") public static > C scale(C c, double s) { C scaled = (C) c.getFactory().create(); for (E key : c.keySet()) { scaled.setCount(key, c.getCount(key) * s); } return scaled; } /** * Returns a new Counter which is the input counter with log tf scaling * * @param c The counter to scale. It is not changed * @param base The base of the logarithm used for tf scaling by 1 + log tf * @return A new Counter which is the argument scaled by the given scale * factor. */ @SuppressWarnings("unchecked") public static > C tfLogScale(C c, double base) { C scaled = (C) c.getFactory().create(); for (E key : c.keySet()) { double cnt = c.getCount(key); double scaledCnt = 0.0; if (cnt > 0) { scaledCnt = 1.0 + SloppyMath.log(cnt, base); } scaled.setCount(key, scaledCnt); } return scaled; } public static > void printCounterSortedByKeys(Counter c) { List keyList = new ArrayList<>(c.keySet()); Collections.sort(keyList); for (E o : keyList) { System.out.println(o + ":" + c.getCount(o)); } } /** * Loads a Counter from a text file. File must have the format of one * key/count pair per line, separated by whitespace. 
* * @param filename The path to the file to load the Counter from * @param c The Class to instantiate each member of the set. Must have a * String constructor. * @return The counter loaded from the file. */ public static ClassicCounter loadCounter(String filename, Class c) throws RuntimeException { ClassicCounter counter = new ClassicCounter<>(); loadIntoCounter(filename, c, counter); return counter; } /** * Loads a Counter from a text file. File must have the format of one * key/count pair per line, separated by whitespace. * * @param filename The path to the file to load the Counter from * @param c The Class to instantiate each member of the set. Must have a * String constructor. * @return The counter loaded from the file. */ public static IntCounter loadIntCounter(String filename, Class c) throws Exception { IntCounter counter = new IntCounter<>(); loadIntoCounter(filename, c, counter); return counter; } /** * Loads a file into an GenericCounter. */ private static void loadIntoCounter(String filename, Class c, Counter counter) throws RuntimeException { try { Constructor m = c.getConstructor(String.class); BufferedReader in = IOUtils.getBufferedFileReader(filename); for (String line; (line = in.readLine()) != null;) { String[] tokens = line.trim().split("\\s+"); if (tokens.length != 2) throw new RuntimeException(); double value = Double.parseDouble(tokens[1]); counter.setCount(m.newInstance(tokens[0]), value); } in.close(); } catch (Exception e) { throw new RuntimeException(e); } } /** * Saves a Counter as one key/count pair per line separated by white space to * the given OutputStream. Does not close the stream. */ public static void saveCounter(Counter c, OutputStream stream) { PrintStream out = new PrintStream(stream); for (E key : c.keySet()) { out.println(key + " " + c.getCount(key)); } } /** * Saves a Counter to a text file. Counter written as one key/count pair per * line, separated by whitespace. 
*/ public static void saveCounter(Counter c, String filename) throws IOException { FileOutputStream fos = new FileOutputStream(filename); saveCounter(c, fos); fos.close(); } public static TwoDimensionalCounter load2DCounter(String filename, Class t1, Class t2) throws RuntimeException { try { TwoDimensionalCounter tdc = new TwoDimensionalCounter<>(); loadInto2DCounter(filename, t1, t2, tdc); return tdc; } catch (Exception e) { throw new RuntimeException(e); } } public static void loadInto2DCounter(String filename, Class t1, Class t2, TwoDimensionalCounter tdc) throws RuntimeException { try { Constructor m1 = t1.getConstructor(String.class); Constructor m2 = t2.getConstructor(String.class); BufferedReader in = IOUtils.getBufferedFileReader(filename);// new // BufferedReader(new // FileReader(filename)); for (String line; (line = in.readLine()) != null;) { String[] tuple = line.trim().split("\t"); String outer = tuple[0]; String inner = tuple[1]; String valStr = tuple[2]; tdc.setCount(m1.newInstance(outer.trim()), m2.newInstance(inner.trim()), Double.parseDouble(valStr.trim())); } in.close(); } catch (Exception e) { throw new RuntimeException(e); } } public static void loadIncInto2DCounter(String filename, Class t1, Class t2, TwoDimensionalCounterInterface tdc) throws RuntimeException { try { Constructor m1 = t1.getConstructor(String.class); Constructor m2 = t2.getConstructor(String.class); BufferedReader in = IOUtils.getBufferedFileReader(filename);// new // BufferedReader(new // FileReader(filename)); for (String line; (line = in.readLine()) != null;) { String[] tuple = line.trim().split("\t"); String outer = tuple[0]; String inner = tuple[1]; String valStr = tuple[2]; tdc.incrementCount(m1.newInstance(outer.trim()), m2.newInstance(inner.trim()), Double.parseDouble(valStr.trim())); } in.close(); } catch (Exception e) { throw new RuntimeException(e); } } public static void save2DCounter(TwoDimensionalCounter tdc, String filename) throws IOException { PrintWriter out 
= new PrintWriter(new FileWriter(filename)); for (T1 outer : tdc.firstKeySet()) { for (T2 inner : tdc.secondKeySet()) { out.println(outer + "\t" + inner + '\t' + tdc.getCount(outer, inner)); } } out.close(); } public static void save2DCounterSorted(TwoDimensionalCounterInterface tdc, String filename) throws IOException { PrintWriter out = new PrintWriter(new FileWriter(filename)); for (T1 outer : tdc.firstKeySet()) { Counter c = tdc.getCounter(outer); List keys = Counters.toSortedList(c); for (T2 inner : keys) { out.println(outer + "\t" + inner + '\t' + c.getCount(inner)); } } out.close(); } /** * Serialize a counter into an efficient string TSV * @param c The counter to serialize * @param filename The file to serialize to * @param minMagnitude Ignore values under this magnitude * @throws IOException * * @see Counters#deserializeStringCounter(String) */ public static void serializeStringCounter(Counter c, String filename, double minMagnitude) throws IOException { PrintWriter writer = IOUtils.getPrintWriter(filename); for (Entry entry : c.entrySet()) { if (Math.abs(entry.getValue()) < minMagnitude) { continue; } Triple parts = SloppyMath.segmentDouble(entry.getValue()); writer.println( entry.getKey().replace('\t', 'ߝ') + "\t" + (parts.first ? 
'-' : '+') + "\t" + parts.second + "\t" + parts.third ); } writer.close(); } /** @see Counters#serializeStringCounter(Counter, String, double) */ public static void serializeStringCounter(Counter c, String filename) throws IOException { serializeStringCounter(c, filename, 0.0); } /** * Read a Counter from a serialized file * @param filename The file to read from * * @see Counters#serializeStringCounter(Counter, String, double) */ public static ClassicCounter deserializeStringCounter(String filename) throws IOException { String[] fields = new String[4]; BufferedReader reader = IOUtils.readerFromString(filename); String line; ClassicCounter counts = new ClassicCounter<>(1000000); while ( (line = reader.readLine()) != null) { StringUtils.splitOnChar(fields, line, '\t'); long mantissa = SloppyMath.parseInt(fields[2]); int exponent = (int) SloppyMath.parseInt(fields[3]); double value = SloppyMath.parseDouble(fields[1].equals("-"), mantissa, exponent); counts.setCount(fields[0], value); } return counts; } public static void serializeCounter(Counter c, String filename) throws IOException { // serialize to file ObjectOutputStream out = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(filename))); out.writeObject(c); out.close(); } public static ClassicCounter deserializeCounter(String filename) throws Exception { // reconstitute ObjectInputStream in = new ObjectInputStream(new BufferedInputStream(new FileInputStream(filename))); ClassicCounter c = ErasureUtils.uncheckedCast(in.readObject()); in.close(); return c; } /** * Returns a string representation of a Counter, displaying the keys and their * counts in decreasing order of count. At most k keys are displayed. * * Note that this method subsumes many of the other toString methods, e.g.: * * toString(c, k) and toBiggestValuesFirstString(c, k) => toSortedString(c, k, * "%s=%f", ", ", "[%s]") * * toVerticalString(c, k) => toSortedString(c, k, "%2$g\t%1$s", "\n", "%s\n") * * @param counter A Counter. 
* @param k The number of keys to include. Use Integer.MAX_VALUE to include * all keys. * @param itemFormat * The format string for key/count pairs, where the key is first and * the value is second. To display the value first, use argument * indices, e.g. "%2$f %1$s". * @param joiner The string used between pairs of key/value strings. * @param wrapperFormat * The format string for wrapping text around the joined items, where * the joined item string value is "%s". * @return The top k values from the Counter, formatted as specified. */ public static String toSortedString(Counter counter, int k, String itemFormat, String joiner, String wrapperFormat) { PriorityQueue queue = toPriorityQueue(counter); List strings = new ArrayList<>(); for (int rank = 0; rank < k && !queue.isEmpty(); ++rank) { T key = queue.removeFirst(); double value = counter.getCount(key); strings.add(String.format(itemFormat, key, value)); } return String.format(wrapperFormat, StringUtils.join(strings, joiner)); } /** * Returns a string representation of a Counter, displaying the keys and their * counts in decreasing order of count. At most k keys are displayed. * * @param counter A Counter. * @param k * The number of keys to include. Use Integer.MAX_VALUE to include * all keys. * @param itemFormat * The format string for key/count pairs, where the key is first and * the value is second. To display the value first, use argument * indices, e.g. "%2$f %1$s". * @param joiner * The string used between pairs of key/value strings. * @return The top k values from the Counter, formatted as specified. */ public static String toSortedString(Counter counter, int k, String itemFormat, String joiner) { return toSortedString(counter, k, itemFormat, joiner, "%s"); } /** * Returns a string representation of a Counter, where (key, value) pairs are * sorted by key, and formatted as specified. * * @param counter The Counter. 
* @param itemFormat * The format string for key/count pairs, where the key is first and * the value is second. To display the value first, use argument * indices, e.g. "%2$f %1$s". * @param joiner * The string used between pairs of key/value strings. * @param wrapperFormat * The format string for wrapping text around the joined items, where * the joined item string value is "%s". * @return The Counter, formatted as specified. */ public static > String toSortedByKeysString(Counter counter, String itemFormat, String joiner, String wrapperFormat) { List strings = new ArrayList<>(); for (T key : CollectionUtils.sorted(counter.keySet())) { strings.add(String.format(itemFormat, key, counter.getCount(key))); } return String.format(wrapperFormat, StringUtils.join(strings, joiner)); } /** * Returns a string representation which includes no more than the * maxKeysToPrint elements with largest counts. If maxKeysToPrint is * non-positive, all elements are printed. * * @param counter The Counter * @param maxKeysToPrint Max keys to print * @return A partial string representation */ public static String toString(Counter counter, int maxKeysToPrint) { return Counters.toPriorityQueue(counter).toString(maxKeysToPrint); } public static String toString(Counter counter, NumberFormat nf) { StringBuilder sb = new StringBuilder(); sb.append('{'); List list = ErasureUtils.sortedIfPossible(counter.keySet()); // */ for (Iterator iter = list.iterator(); iter.hasNext();) { E key = iter.next(); sb.append(key); sb.append('='); sb.append(nf.format(counter.getCount(key))); if (iter.hasNext()) { sb.append(", "); } } sb.append('}'); return sb.toString(); } /** * Pretty print a Counter. This one has more flexibility in formatting, and * doesn't sort the keys. 
*/
  public static <E> String toString(Counter<E> counter, NumberFormat nf, String preAppend, String postAppend, String keyValSeparator, String itemSeparator) {
    StringBuilder sb = new StringBuilder();
    sb.append(preAppend);
    // Keys are written in the key set's own iteration order (no sorting).
    boolean first = true;
    for (E key : counter.keySet()) {
      if (!first) {
        sb.append(itemSeparator);
      }
      first = false;
      double d = counter.getCount(key);
      sb.append(key);
      sb.append(keyValSeparator);
      sb.append(nf.format(d));
    }
    sb.append(postAppend);
    return sb.toString();
  }

  /** Returns the string form of the priority queue built from c by toPriorityQueue. */
  public static <E> String toBiggestValuesFirstString(Counter<E> c) {
    PriorityQueue<E> byCount = toPriorityQueue(c);
    return byCount.toString();
  }

  // TODO this method seems badly written. It should exploit topK printing of PriorityQueue
  public static <E> String toBiggestValuesFirstString(Counter<E> c, int k) {
    PriorityQueue<E> pq = toPriorityQueue(c);
    PriorityQueue<E> largestK = new BinaryHeapPriorityQueue<>();
    // TODO: Is there any reason the original (commented out) line is better
    // than the one replacing it?
// while (largestK.size() < k && ((Iterator)pq).hasNext()) { while (largestK.size() < k && !pq.isEmpty()) { double firstScore = pq.getPriority(pq.getFirst()); E first = pq.removeFirst(); largestK.changePriority(first, firstScore); } return largestK.toString(); } public static String toBiggestValuesFirstString(Counter c, int k, Index index) { PriorityQueue pq = toPriorityQueue(c); PriorityQueue largestK = new BinaryHeapPriorityQueue<>(); // while (largestK.size() < k && ((Iterator)pq).hasNext()) { //same as above while (largestK.size() < k && !pq.isEmpty()) { double firstScore = pq.getPriority(pq.getFirst()); int first = pq.removeFirst(); largestK.changePriority(index.get(first), firstScore); } return largestK.toString(); } public static String toVerticalString(Counter c) { return toVerticalString(c, Integer.MAX_VALUE); } public static String toVerticalString(Counter c, int k) { return toVerticalString(c, k, "%g\t%s", false); } public static String toVerticalString(Counter c, String fmt) { return toVerticalString(c, Integer.MAX_VALUE, fmt, false); } public static String toVerticalString(Counter c, int k, String fmt) { return toVerticalString(c, k, fmt, false); } /** * Returns a {@code String} representation of the {@code k} keys * with the largest counts in the given {@link Counter}, using the given * format string. * * @param c A Counter * @param k How many keys to print * @param fmt A format string, such as "%.0f\t%s" (do not include final "%n"). * If swap is false, you will get val, key as arguments, if true, key, val. 
* @param swap Whether the count should appear after the key */ public static String toVerticalString(Counter c, int k, String fmt, boolean swap) { PriorityQueue q = Counters.toPriorityQueue(c); List sortedKeys = q.toSortedList(); StringBuilder sb = new StringBuilder(); int i = 0; for (Iterator keyI = sortedKeys.iterator(); keyI.hasNext() && i < k; i++) { E key = keyI.next(); double val = q.getPriority(key); if (swap) { sb.append(String.format(fmt, key, val)); } else { sb.append(String.format(fmt, val, key)); } if (keyI.hasNext()) { sb.append('\n'); } } return sb.toString(); } /** * * @return Returns the maximum element of c that is within the restriction * Collection */ public static E restrictedArgMax(Counter c, Collection restriction) { E maxKey = null; double max = Double.NEGATIVE_INFINITY; for (E key : restriction) { double count = c.getCount(key); if (count > max) { max = count; maxKey = key; } } return maxKey; } public static Counter toCounter(double[] counts, Index index) { if (index.size() < counts.length) throw new IllegalArgumentException("Index not large enough to name all the array elements!"); Counter c = new ClassicCounter<>(); for (int i = 0; i < counts.length; i++) { if (counts[i] != 0.0) c.setCount(index.get(i), counts[i]); } return c; } /** * Turns the given map and index into a counter instance. For each entry in * counts, its key is converted to a counter key via lookup in the given * index. */ public static Counter toCounter(Map counts, Index index) { Counter counter = new ClassicCounter<>(); for (Map.Entry entry : counts.entrySet()) { counter.setCount(index.get(entry.getKey()), entry.getValue().doubleValue()); } return counter; } /** * Convert a counter to an array using a specified key index. Infer the dimension of * the returned vector from the index. */ public static double[] asArray(Counter counter, Index index) { return Counters.asArray(counter, index, index.size()); } /** * Convert a counter to an array using a specified key index. 
This method does *not* expand * the index, so all keys in the set keys(counter) - keys(index) are not added to the * output array. Also note that if counter is being used as a sparse array, the result * will be a dense array with zero entries. * * @return the values corresponding to the index */ public static double[] asArray(Counter counter, Index index, int dimension) { if (index.size() == 0) { throw new IllegalArgumentException("Empty index"); } Set keys = counter.keySet(); double[] array = new double[dimension]; for (E key : keys) { int i = index.indexOf(key); if (i >= 0) { array[i] = counter.getCount(key); } } return array; } /** * Convert a counter to an array, the order of the array is random */ public static double[] asArray(Counter counter) { Set keys = counter.keySet(); double[] array = new double[counter.size()]; int i = 0; for (E key : keys) { array[i] = counter.getCount(key); i++; } return array; } /** * Creates a new TwoDimensionalCounter where all the counts are scaled by d. * Internally, uses Counters.scale(); * * @return The TwoDimensionalCounter */ public static TwoDimensionalCounter scale(TwoDimensionalCounter c, double d) { TwoDimensionalCounter result = new TwoDimensionalCounter<>(c.getOuterMapFactory(), c.getInnerMapFactory()); for (T1 key : c.firstKeySet()) { ClassicCounter ctr = c.getCounter(key); result.setCounter(key, scale(ctr, d)); } return result; } static final Random RAND = new Random(); /** * Does not assumes c is normalized. 
* * @return A sample from c */ public static T sample(Counter c, Random rand) { // OMITTED: Seems like there should be a way to directly check if T is comparable // Set keySet = c.keySet(); // if (!keySet.isEmpty() && keySet.iterator().next() instanceof Comparable) { // List l = new ArrayList(keySet); // Collections.sort(l); // objects = l; // } else { // throw new RuntimeException("Results won't be stable since Counters keys are comparable."); // } if (rand == null) rand = RAND; double r = rand.nextDouble() * c.totalCount(); double total = 0.0; for (T t : c.keySet()) { // arbitrary ordering, but presumably stable total += c.getCount(t); if (total >= r) return t; } // only chance of reaching here is if c isn't properly normalized, or if // double math makes total<1.0 return c.keySet().iterator().next(); } /** * Does not assumes c is normalized. * * @return A sample from c */ public static T sample(Counter c) { return sample(c, null); } /** * Returns a counter where each element corresponds to the normalized count of * the corresponding element in c raised to the given power. 
*/
  public static <E> Counter<E> powNormalized(Counter<E> c, double temp) {
    Counter<E> d = c.getFactory().create();
    double total = c.totalCount();
    for (E e : c.keySet()) {
      d.setCount(e, Math.pow(c.getCount(e) / total, temp));
    }
    return d;
  }

  /** Returns a new counter (from c's factory) holding each count of c raised to the power temp. */
  public static <T> Counter<T> pow(Counter<T> c, double temp) {
    Counter<T> d = c.getFactory().create();
    for (T t : c.keySet()) {
      d.setCount(t, Math.pow(c.getCount(t), temp));
    }
    return d;
  }

  /** Replaces each count of c with that count raised to the power temp. */
  public static <T> void powInPlace(Counter<T> c, double temp) {
    for (T t : c.keySet()) {
      c.setCount(t, Math.pow(c.getCount(t), temp));
    }
  }

  /** Returns a new counter (from c's factory) holding e raised to each count of c. */
  public static <T> Counter<T> exp(Counter<T> c) {
    Counter<T> d = c.getFactory().create();
    for (T t : c.keySet()) {
      d.setCount(t, Math.exp(c.getCount(t)));
    }
    return d;
  }

  /** Replaces each count of c with e raised to that count. */
  public static <T> void expInPlace(Counter<T> c) {
    for (T t : c.keySet()) {
      c.setCount(t, Math.exp(c.getCount(t)));
    }
  }

  /**
   * Returns gold minus guessed, computed over the union of both key sets.
   * {@code retainNonZeros} is applied to the result before it is returned.
   */
  public static <T> Counter<T> diff(Counter<T> goldFeatures, Counter<T> guessedFeatures) {
    Counter<T> result = goldFeatures.getFactory().create();
    for (T key : Sets.union(goldFeatures.keySet(), guessedFeatures.keySet())) {
      result.setCount(key, goldFeatures.getCount(key) - guessedFeatures.getCount(key));
    }
    retainNonZeros(result);
    return result;
  }

  /**
   * Default equality comparison for two counters potentially backed by
   * alternative implementations.
   */
  public static <E> boolean equals(Counter<E> o1, Counter<E> o2) {
    return equals(o1, o2, 0.0);
  }

  /**
   * Equality comparison between two counters, allowing for a tolerance fudge factor.
   * The counters are equal when they are the same object, or when their total
   * counts differ by at most the tolerance, their key sets are equal, and every
   * per-key count differs by at most the tolerance.
   */
  public static <E> boolean equals(Counter<E> o1, Counter<E> o2, double tolerance) {
    if (o1 == o2) {
      return true;
    }
    if (Math.abs(o1.totalCount() - o2.totalCount()) > tolerance) {
      return false;
    }
    if (!o1.keySet().equals(o2.keySet())) {
      return false;
    }
    for (E key : o1.keySet()) {
      if (Math.abs(o1.getCount(key) - o2.getCount(key)) > tolerance) {
        return false;
      }
    }
    return true;
  }

  /**
   * Returns unmodifiable view of the counter. Changes to the underlying Counter
   * are written through to this Counter.
* * @param counter * The counter * @return unmodifiable view of the counter */ public static Counter unmodifiableCounter(final Counter counter) { return new AbstractCounter() { public void clear() { throw new UnsupportedOperationException(); } public boolean containsKey(T key) { return counter.containsKey(key); } public double getCount(Object key) { return counter.getCount(key); } public Factory> getFactory() { return counter.getFactory(); } public double remove(T key) { throw new UnsupportedOperationException(); } public void setCount(T key, double value) { throw new UnsupportedOperationException(); } @Override public double incrementCount(T key, double value) { throw new UnsupportedOperationException(); } @Override public double incrementCount(T key) { throw new UnsupportedOperationException(); } @Override public double logIncrementCount(T key, double value) { throw new UnsupportedOperationException(); } public int size() { return counter.size(); } public double totalCount() { return counter.totalCount(); } public Collection values() { return counter.values(); } public Set keySet() { return Collections.unmodifiableSet(counter.keySet()); } public Set> entrySet() { return Collections.unmodifiableSet(new AbstractSet>() { @Override public Iterator> iterator() { return new Iterator>() { final Iterator> inner = counter.entrySet().iterator(); public boolean hasNext() { return inner.hasNext(); } public Entry next() { return new Map.Entry() { final Entry e = inner.next(); @Override public T getKey() { return e.getKey(); } @Override @SuppressWarnings( { "UnnecessaryBoxing", "UnnecessaryUnboxing" }) public Double getValue() { return Double.valueOf(e.getValue().doubleValue()); } @Override public Double setValue(Double value) { throw new UnsupportedOperationException(); } }; } @Override public void remove() { throw new UnsupportedOperationException(); } }; } @Override public int size() { return counter.size(); } }); } @Override public void setDefaultReturnValue(double rv) { 
throw new UnsupportedOperationException(); } @Override public double defaultReturnValue() { return counter.defaultReturnValue(); } /** * {@inheritDoc} */ public void prettyLog(RedwoodChannels channels, String description) { PrettyLogger.log(channels, description, asMap(this)); } }; } // end unmodifiableCounter() /** * Returns a counter whose keys are the elements in this priority queue, and * whose counts are the priorities in this queue. In the event there are * multiple instances of the same element in the queue, the counter's count * will be the sum of the instances' priorities. * */ public static Counter asCounter(FixedPrioritiesPriorityQueue p) { FixedPrioritiesPriorityQueue pq = p.clone(); ClassicCounter counter = new ClassicCounter<>(); while (pq.hasNext()) { double priority = pq.getPriority(); E element = pq.next(); counter.incrementCount(element, priority); } return counter; } /** * Returns a counter view of the given map. Infers the numeric type of the * values from the first element in map.values(). */ @SuppressWarnings("unchecked") public static Counter fromMap(final Map map) { if (map.isEmpty()) { throw new IllegalArgumentException("Map must have at least one element" + " to infer numeric type; add an element first or use e.g." + " fromMap(map, Integer.class)"); } return fromMap(map, (Class) map.values().iterator().next().getClass()); } /** * Returns a counter view of the given map. The type parameter is the type of * the values in the map, which because of Java's generics type erasure, can't * be discovered by reflection if the map is currently empty. 
*/ public static Counter fromMap(final Map map, final Class type) { // get our initial total double initialTotal = 0.0; for (Map.Entry entry : map.entrySet()) { initialTotal += entry.getValue().doubleValue(); } // and pass it in to the returned inner class with a final variable final double initialTotalFinal = initialTotal; return new AbstractCounter() { double total = initialTotalFinal; double defRV = 0.0; @Override public void clear() { map.clear(); total = 0.0; } @Override public boolean containsKey(E key) { return map.containsKey(key); } @Override public void setDefaultReturnValue(double rv) { defRV = rv; } @Override public double defaultReturnValue() { return defRV; } @Override @SuppressWarnings("unchecked") public boolean equals(Object o) { if (this == o) { return true; } else if (!(o instanceof Counter)) { return false; } else { return Counters.equals(this, (Counter) o); } } @Override public int hashCode() { return map.hashCode(); } public Set> entrySet() { return new AbstractSet>() { Set> entries = map.entrySet(); @Override public Iterator> iterator() { return new Iterator>() { Iterator> it = entries.iterator(); Entry lastEntry; // = null; public boolean hasNext() { return it.hasNext(); } public Entry next() { final Entry entry = it.next(); lastEntry = entry; return new Entry() { public E getKey() { return entry.getKey(); } public Double getValue() { return entry.getValue().doubleValue(); } public Double setValue(Double value) { final double lastValue = entry.getValue().doubleValue(); double rv; if (type == Double.class) { rv = ErasureUtils.> uncheckedCast(entry).setValue(value); } else if (type == Integer.class) { rv = ErasureUtils.> uncheckedCast(entry).setValue(value.intValue()); } else if (type == Float.class) { rv = ErasureUtils.> uncheckedCast(entry).setValue(value.floatValue()); } else if (type == Long.class) { rv = ErasureUtils.> uncheckedCast(entry).setValue(value.longValue()); } else if (type == Short.class) { rv = ErasureUtils.> 
uncheckedCast(entry).setValue(value.shortValue()); } else { throw new RuntimeException("Unrecognized numeric type in wrapped counter"); } // need to call getValue().doubleValue() to make sure // we keep the same precision as the underlying map total += entry.getValue().doubleValue() - lastValue; return rv; } }; } public void remove() { total -= lastEntry.getValue().doubleValue(); it.remove(); } }; } @Override public int size() { return map.size(); } }; } public double getCount(Object key) { final Number value = map.get(key); return value != null ? value.doubleValue() : defRV; } public Factory> getFactory() { return new Factory>() { private static final long serialVersionUID = -4063129407369590522L; public Counter create() { // return a HashMap backed by the same numeric type to // keep the precision of the returned counter consistent with // this one's precision return fromMap(Generics.newHashMap(), type); } }; } public Set keySet() { return new AbstractSet() { @Override public Iterator iterator() { return new Iterator() { Iterator it = map.keySet().iterator(); public boolean hasNext() { return it.hasNext(); } public E next() { return it.next(); } public void remove() { throw new UnsupportedOperationException("Cannot remove from key set"); } }; } @Override public int size() { return map.size(); } }; } public double remove(E key) { final Number removed = map.remove(key); if (removed != null) { final double rv = removed.doubleValue(); total -= rv; return rv; } return defRV; } public void setCount(E key, double value) { final Double lastValue; double newValue; if (type == Double.class) { lastValue = ErasureUtils.> uncheckedCast(map).put(key, value); newValue = value; } else if (type == Integer.class) { final Integer last = ErasureUtils.> uncheckedCast(map).put(key, (int) value); lastValue = last != null ? 
last.doubleValue() : null; newValue = ((int) value); } else if (type == Float.class) { final Float last = ErasureUtils.> uncheckedCast(map).put(key, (float) value); lastValue = last != null ? last.doubleValue() : null; newValue = ((float) value); } else if (type == Long.class) { final Long last = ErasureUtils.> uncheckedCast(map).put(key, (long) value); lastValue = last != null ? last.doubleValue() : null; newValue = ((long) value); } else if (type == Short.class) { final Short last = ErasureUtils.> uncheckedCast(map).put(key, (short) value); lastValue = last != null ? last.doubleValue() : null; newValue = ((short) value); } else { throw new RuntimeException("Unrecognized numeric type in wrapped counter"); } // need to use newValue instead of value to make sure we // keep same precision as underlying map. total += newValue - (lastValue != null ? lastValue : 0); } public int size() { return map.size(); } public double totalCount() { return total; } public Collection values() { return new AbstractCollection() { @Override public Iterator iterator() { return new Iterator() { final Iterator it = map.values().iterator(); public boolean hasNext() { return it.hasNext(); } public Double next() { return it.next().doubleValue(); } public void remove() { throw new UnsupportedOperationException("Cannot remove from values collection"); } }; } @Override public int size() { return map.size(); } }; } /** * {@inheritDoc} */ public void prettyLog(RedwoodChannels channels, String description) { PrettyLogger.log(channels, description, map); } }; } // end fromMap() /** * Returns a map view of the given counter. 
*/ public static Map asMap(final Counter counter) { return new AbstractMap() { @Override public int size() { return counter.size(); } @Override public Set> entrySet() { return counter.entrySet(); } @Override @SuppressWarnings("unchecked") public boolean containsKey(Object key) { return counter.containsKey((E) key); } @Override @SuppressWarnings("unchecked") public Double get(Object key) { return counter.getCount((E) key); } @Override public Double put(E key, Double value) { double last = counter.getCount(key); counter.setCount(key, value); return last; } @Override @SuppressWarnings("unchecked") public Double remove(Object key) { return counter.remove((E) key); } @Override public Set keySet() { return counter.keySet(); } }; } /** * Check if this counter is a uniform distribution. * That is, it should sum to 1.0, and every value should be equal to every other value. * @param distribution The distribution to check. * @param tolerance The tolerance for floating point error, in both the equality and total count checks. * @param The type of the counter. * @return True if this counter is the uniform distribution over its domain. */ public static boolean isUniformDistribution(Counter distribution, double tolerance) { double value = Double.NaN; double totalCount = 0.0; for (double val : distribution.values()) { if (Double.isNaN(value)) { value = val; } if (Math.abs(val - value) > tolerance) { return false; } totalCount += val; } return Math.abs(totalCount - 1.0) < tolerance; } /** * Default comparator for breaking ties in argmin and argmax. * //TODO: What type should this be? * // Unused, so who cares? * private static final Comparator hashCodeComparator = * new Comparator() { * public int compare(Object o1, Object o2) { * return o1.hashCode() - o2.hashCode(); * } * * public boolean equals(Comparator comparator) { * return (comparator == this); * } * }; */ /** * Comparator that uses natural ordering. Returns 0 if o1 is not Comparable. 
*/ static class NaturalComparator implements Comparator { public NaturalComparator() { } @Override public String toString() { return "NaturalComparator"; } @SuppressWarnings("unchecked") public int compare(E o1, E o2) { if (o1 instanceof Comparable) { return (((Comparable) o1).compareTo(o2)); } return 0; // soft-fail } } /** * * @param * @param originalCounter * @return a copy of the original counter */ public static Counter getCopy(Counter originalCounter) { Counter copyCounter = new ClassicCounter<>(); copyCounter.addAll(originalCounter); return copyCounter; } /** * Places the maximum of first and second keys values in the first counter. * @param */ public static void maxInPlace(Counter target, Counter other) { for(E e: CollectionUtils.union(other.keySet(), target.keySet())){ target.setCount(e, Math.max(target.getCount(e), other.getCount(e))); } } /** * Places the minimum of first and second keys values in the first counter. * @param */ public static void minInPlace(Counter target, Counter other){ for(E e: CollectionUtils.union(other.keySet(), target.keySet())){ target.setCount(e, Math.min(target.getCount(e), other.getCount(e))); } } /** * Retains the minimal set of top keys such that their count sum is more than thresholdCount. 
* @param counter * @param thresholdCount */ public static void retainTopMass(Counter counter, double thresholdCount){ PriorityQueue queue = Counters.toPriorityQueue(counter); counter.clear(); double mass = 0; while (mass < thresholdCount && !queue.isEmpty()) { double value = queue.getPriority(); E key = queue.removeFirst(); counter.setCount(key, value); mass += value; } } public static void divideInPlace(TwoDimensionalCounter counter, double divisor) { for(Entry> c: counter.entrySet()){ Counters.divideInPlace(c.getValue(), divisor); } counter.recomputeTotal(); } public static double pearsonsCorrelationCoefficient(Counter x, Counter y){ double stddevX = Counters.standardDeviation(x); double stddevY = Counters.standardDeviation(y); double meanX = Counters.mean(x); double meanY = Counters.mean(y); Counter t1 = Counters.add(x, -meanX); Counter t2 = Counters.add(y, -meanY); Counters.divideInPlace(t1, stddevX); Counters.divideInPlace(t2, stddevY); return Counters.dotProduct(t1, t2)/ (double)(x.size() -1); } public static double spearmanRankCorrelation(Counter x, Counter y){ Counter xrank = Counters.toTiedRankCounter(x); Counter yrank = Counters.toTiedRankCounter(y); return Counters.pearsonsCorrelationCoefficient(xrank, yrank); } /** * ensures that counter t has all keys in keys. If the counter does not have the keys, then add the key with count value. 
* Note that it does not change counts that exist in the counter */ public static void ensureKeys(Counter t, Collection keys, double value){ for(E k: keys){ if(!t.containsKey(k)) t.setCount(k, value); } } public static List topKeys(Counter t, int topNum){ List list = new ArrayList<>(); PriorityQueue q = Counters.toPriorityQueue(t); int num = 0; while(!q.isEmpty() && num < topNum){ num++; list.add(q.removeFirst()); } return list; } public static List> topKeysWithCounts(Counter t, int topNum){ List> list = new ArrayList<>(); PriorityQueue q = Counters.toPriorityQueue(t); int num = 0; while(!q.isEmpty() && num < topNum){ num++; E k = q.removeFirst(); list.add(new Pair<>(k, t.getCount(k))); } return list; } public static Counter getFCounter(Counter precision, Counter recall, double beta){ Counter fscores = new ClassicCounter<>(); for(E k: precision.keySet()){ fscores.setCount(k, precision.getCount(k)*recall.getCount(k)*(1+beta*beta)/(beta*beta*precision.getCount(k) + recall.getCount(k))); } return fscores; } public static void transformValuesInPlace(Counter counter, Function func){ for(E key: counter.keySet()){ counter.setCount(key, func.apply(counter.getCount(key))); } } public static Counter getCounts(Counter c, Collection keys){ Counter newcounter = new ClassicCounter<>(); for(E k : keys) newcounter.setCount(k, c.getCount(k)); return newcounter; } public static void retainKeys(Counter counter, Function retainFunction) { Set remove = new HashSet<>(); for(Entry en: counter.entrySet()){ if(!retainFunction.apply(en.getKey())){ remove.add(en.getKey()); } } Counters.removeKeys(counter, remove); } public static Counter flatten(Map> hier){ Counter flat = new ClassicCounter<>(); for(Entry> en: hier.entrySet()){ flat.addAll(en.getValue()); } return flat; } /** * Returns true if the given counter contains only finite, non-NaN values. * @param counts The counter to validate. * @param The parameterized type of the counter. 
* @return True if the counter is finite and not NaN on every value.
   */
  public static <E> boolean isFinite(Counter<E> counts) {
    for (double value : counts.values()) {
      if (Double.isInfinite(value) || Double.isNaN(value)) {
        return false;
      }
    }
    return true;
  }

}