/*
* LingPipe v. 4.1.0
* Copyright (C) 2003-2011 Alias-i
*
* This program is licensed under the Alias-i Royalty Free License
* Version 1 WITHOUT ANY WARRANTY, without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the Alias-i
* Royalty Free License Version 1 for more details.
*
* You should have received a copy of the Alias-i Royalty Free License
* Version 1 along with this program; if not, visit
* http://alias-i.com/lingpipe/licenses/lingpipe-license-1.txt or contact
* Alias-i, Inc. at 181 North 11th Street, Suite 401, Brooklyn, NY 11211,
* +1 (718) 290-9170.
*/
package com.aliasi.classify;
import com.aliasi.corpus.Corpus;
import com.aliasi.corpus.ObjectHandler;
import com.aliasi.io.Reporter;
import com.aliasi.io.Reporters;
// com.aliasi.util.Math is referenced fully qualified below to avoid clashing with java.lang.Math
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Compilable;
import com.aliasi.util.Counter;
import com.aliasi.util.Exceptions;
import com.aliasi.util.Factory;
import com.aliasi.util.Iterators;
import com.aliasi.util.ObjectToCounterMap;
import com.aliasi.util.Strings;
import com.aliasi.stats.Statistics;
import com.aliasi.tokenizer.Tokenizer;
import com.aliasi.tokenizer.TokenizerFactory;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.ObjectStreamException;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Collections;
import java.util.Formatter;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
/**
* A {@code TradNaiveBayesClassifier} implements a traditional
* token-based approach to naive Bayes text classification. It wraps
* a tokenizer factory to convert character sequences into
* sequences of tokens. This implementation supports several
* enhancements to simple naive Bayes: priors, length normalization,
* and semi-supervised training with EM.
*
* It is the token counts (a.k.a. the "bag of words") that are
* actually being classified, not the raw character sequence input.
* Any character sequences that produce the same bags of tokens are
* considered equal; for instance, "Mary hit John" and "John hit
* Mary" tokenize to the same bag of words.
*
*
* Naive Bayes is trainable online, meaning that it can be given
* training instances one at a time, and at any point can be used as a
* classifier. Training cases consist of a character sequence and a
* classification, as dictated by the interface {@code
* ObjectHandler<Classified<CharSequence>>}.
*
* Given a character sequence, a naive Bayes classifier returns
* joint probability estimates of categories and tokens; this is
* reflected in its implementing the {@code
* JointClassifier<CharSequence>} interface. Note that
* this is the joint probability of the token counts, so sums of
* probabilities over all input character sequences will exceed 1.0.
* Typically, only the conditional probability estimates are used in
* practice.
*
* If there is length normalization, the joint probabilities will
* not sum to 1.0 over all inputs and outputs. The conditional
* probabilities will always sum to 1.0.
*
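* For example, here is a minimal usage sketch. It assumes LingPipe's
* {@code IndoEuropeanTokenizerFactory.INSTANCE} as the tokenizer
* factory and made-up categories and training texts; any tokenizer
* factory works the same way:
*
*   Set<String> cats
*       = new HashSet<String>(Arrays.asList("pos","neg"));
*   TradNaiveBayesClassifier classifier
*       = new TradNaiveBayesClassifier(cats,
*                                      IndoEuropeanTokenizerFactory.INSTANCE);
*   classifier.handle(new Classified<CharSequence>("good fun",
*                                                  new Classification("pos")));
*   classifier.handle(new Classified<CharSequence>("dull stuff",
*                                                  new Classification("neg")));
*   JointClassification jc = classifier.classify("good stuff");
*   String best = jc.bestCategory();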
*
* <b>Classification</b>
*
* A token-based naive Bayes classifier computes joint token count and
* category probabilities by factoring the joint into the marginal
* probability of a category times the conditional probability of the
* tokens given the category:
*
*
* p(tokens,cat) = p(tokens|cat) * p(cat)
*
* Conditional probabilities are derived by applying Bayes's rule to
* invert the probability calculation:
*
*
* p(cat|tokens) = p(tokens,cat) / p(tokens)
* = p(tokens|cat) * p(cat) / p(tokens)
*
* The tokens are assumed to be independent (this is the
* "naive" step):
*
*
*   p(tokens|cat) = p(tokens[0]|cat) * ... * p(tokens[tokens.length-1]|cat)
*                 = Π_{i < tokens.length} p(tokens[i]|cat)
*
* Finally, an explicit marginalization allows us to compute the
* marginal distribution of tokens:
*
*
*   p(tokens) = Σ_{cat'} p(tokens,cat')
*             = Σ_{cat'} p(tokens|cat') * p(cat')
*
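* For instance, with two categories and a two-token input {a,b}, the
* arithmetic runs as follows (the probabilities are made up purely
* for illustration):
*
*   p(cat1) = 0.4   p(a|cat1) = 0.2   p(b|cat1) = 0.1
*   p({a,b},cat1) = 0.2 * 0.1 * 0.4 = 0.008
*   p(cat1|{a,b}) = 0.008 / ( 0.008 + p({a,b},cat2) )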
*
* <b>Estimation with Priors</b>
*
* We have now defined the conditional probability {@code
* p(cat|tokens)} in terms of two distributions: the conditional
* probability of a token given a category, {@code p(token|cat)}, and the
* marginal probability of a category, {@code p(cat)} (sometimes called
* the category's prior probability, though this shouldn't be confused
* with the usual Bayesian prior on model parameters).
*
* Traditional naive Bayes uses a maximum a posteriori (MAP)
* estimate of the multinomial distributions: {@code p(cat)} over the
* set of categories, and for each category {@code cat}, the
* multinomial distribution {@code p(token|cat)} over the set of tokens.
* Traditional naive Bayes employs the Dirichlet conjugate prior for
* multinomials, which is straightforward to compute by adding a fixed
* "prior count" to each count in the training data, hence the
* traditional name "additive smoothing".
*
*
* Two sets of counts are sufficient for estimating a traditional
* naive Bayes classifier. The first is {@code tokenCount(w,c)}, the
* number of times token {@code w} appeared as a token in a training
* case for category {@code c}. The second is {@code caseCount(c)},
* which is the number of training cases for category {@code c}.
*
* We assume prior counts α for the case counts and β for the token
* counts. These values are supplied in the constructor for this
* class.
*
* The estimates for category and token probabilities p' are most
* easily understood as proportions:
*
*   p'(w|c) ∝ tokenCount(w,c) + β
*
*   p'(c) ∝ caseCount(c) + α
*
* The probability estimates p' are obtained through the usual
* normalization:
*
*   p'(w|c) = ( tokenCount(w,c) + β ) / Σ_w ( tokenCount(w,c) + β )
*
*   p'(c) = ( caseCount(c) + α ) / Σ_c ( caseCount(c) + α )
*
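* As a worked example of the smoothed estimate (a hypothetical
* illustration, not data from any corpus): if category {@code c} has
* seen 10 training tokens in total, token {@code w} was seen 3 times
* in {@code c}, there are 4 known tokens overall, and β = 0.5, then:
*
*   p'(w|c) = (3 + 0.5) / (10 + 4 * 0.5) = 3.5 / 12 ≈ 0.29
*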
*
* <b>Maximum Likelihood Estimates</b>
*
* Although not traditionally used for naive Bayes, maximum
* likelihood estimates arise from setting the prior counts to
* zero (α = β = 0). The prior counts drop out of the equations to
* yield the maximum likelihood estimates p*:
*
*   p*(w|c) = tokenCount(w,c) / Σ_w tokenCount(w,c)
*
*   p*(c) = caseCount(c) / Σ_c caseCount(c)
*
* <b>Weighted and Conditional Training</b>
*
* Unlike traditional naive Bayes implementations, this class
* allows weighted training, including training directly from a
* conditional classification. When training using a conditional
* classification, each category is weighted according to its
* conditional probability.
*
* Weights may be negative, allowing counts to be decremented
* (e.g., for Gibbs sampling).
*
*
* <b>Length Normalization</b>
*
* Because of the (almost always faulty) token-independence
* assumption underlying the naive Bayes classifier, the conditional
* probability estimates tend toward either 0.0 or 1.0 as the input
* grows longer. In practice, it sometimes helps to length normalize
* the documents. That is, each document is treated as if it were a
* fixed number of tokens long, lengthNorm.
*
* Length normalization can be computed directly on the linear
* scale:
*
*   pn(tokens|cat) = p(tokens|cat)^(lengthNorm/tokens.length)
*
* but is more easily understood on the log scale, where we multiply
* the length norm by the log probability normalized per token:
*
*   log2 pn(tokens|cat) = lengthNorm * log2 p(tokens|cat) / tokens.length
*
* The length normalization parameter is supplied in the
* constructor, with a {@code Double.NaN} value indicating
* that no length normalization should be done.
*
* Length normalization is applied during training, too. The length
* normalization may be changed using the {@link #setLengthNorm(double)}
* method. For instance, this allows training to skip length
* normalization while classification uses it.
*
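* For example, a sketch of training without length normalization and
* then classifying with documents normalized to a length of 10
* tokens ({@code cats}, {@code tokFactory}, and {@code text} are
* assumed to be defined elsewhere):
*
*   TradNaiveBayesClassifier classifier
*       = new TradNaiveBayesClassifier(cats,tokFactory); // lengthNorm=Double.NaN
*   // ... train here without length normalization ...
*   classifier.setLengthNorm(10.0);
*   JointClassification jc = classifier.classify(text);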
*
*
* <b>Semi-Supervised Training with Expectation Maximization (EM)</b>
*
* Naive Bayes is a common model to use in conjunction with the
* general semi-supervised or unsupervised training strategy known as
* expectation maximization (EM). The basic idea behind EM is
* that it starts with a classifier, applies it to unseen data,
* looks at the weighted output predictions, and then uses those
* output predictions as training data.
*
* EM is controlled by epoch. Each epoch consists of an
* expectation (E) step, followed by a maximization (M) step.
* The expectation step computes expectations which are then
* fed in as training weights to the maximization step.
*
* The version of EM implemented in this class allows a mixture of
* supervised and unsupervised data.
*
*
* The supervised training data is
* in the form of a corpus of classifications, implementing {@code
* Corpus<ObjectHandler<Classified<CharSequence>>>}.
*
* Unsupervised data is in the form of a corpus of texts, implementing
* {@code Corpus<ObjectHandler<CharSequence>>}.
*
* The method also requires a factory with which to produce a new
* classifier in each epoch, namely an implementation of {@code
* Factory<TradNaiveBayesClassifier>}. It also takes an initial
* classifier, which may be different than the classifiers generated
* by the factory.
*
* EM works by iteratively training better and better classifiers,
* using the previous classifier to label the unlabeled data for
* training:
*
*
*   set lastClassifier to initialClassifier
*   for (epoch = 0; epoch < maxEpochs; ++epoch) {
*       create classifier using factory
*       train classifier on supervised items
*       for (x in unsupervised items) {
*           compute p(c|x) with lastClassifier
*           for (c in categories)
*               train classifier on c weighted by p(c|x)
*       }
*       evaluate corpus and model probability under classifier
*       set lastClassifier to classifier
*       break if converged
*   }
*   return lastClassifier
*
* Note that in each round, the new classifier is trained on
* the supervised items.
*
*
* In general, we have found that EM training works best if the
* initial classifier does more smoothing than the classifiers
* returned by the factory.
*
* Annealing, of a sort, may be built in by having the factory
* return a sequence of classifiers with ever longer length
* normalizations and/or lower prior counts, both of which sharpen
* the posterior predictions of a naive Bayes classifier. With a
* short length normalization, classifications are driven closer to
* uniform; with longer length normalizations, they are more peaked.
*
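* For example, a sketch of a complete EM run (the initial
* classifier, factory, and corpora are assumed to be defined
* elsewhere, and the numeric settings are arbitrary; the method
* throws an {@code IOException} if visiting either corpus does):
*
*   TradNaiveBayesClassifier trained
*       = TradNaiveBayesClassifier.emTrain(initialClassifier,
*                                          classifierFactory,
*                                          labeledData,
*                                          unlabeledData,
*                                          1.0,    // minTokenCount
*                                          20,     // maxEpochs
*                                          0.0001, // minImprovement
*                                          Reporters.stdOut());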
*
*
* <b>Unsupervised Learning and EM Soft Clustering</b>
*
* It is possible to train a classifier in a completely
* unsupervised fashion by having the initial classifier assign
* categories at random. Only the number of categories must be fixed.
* The algorithm is exactly the same, and the result after
* convergence or the maximum number of epochs is a classifier.
*
* Now take the trained classifier and run it over the texts in the
* unsupervised text corpus. This will assign to each text a
* probability of belonging to each of the categories. This is known
* as a soft clustering, and the algorithm overall is known as EM
* clustering. If we assign each item to its most likely category,
* the result is then a hard clustering.
*
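* For example, a sketch of reducing the soft clustering to a hard
* one (assuming {@code trained} is the EM-trained classifier and
* {@code texts} holds the unlabeled inputs):
*
*   for (CharSequence text : texts) {
*       ConditionalClassification c = trained.classify(text);
*       String cluster = c.bestCategory(); // most likely category
*   }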
*
* <b>Serialization and Compilation</b>
*
* A naive Bayes classifier may be serialized. The object read
* back in will behave just as the naive Bayes classifier that was
* serialized. The tokenizer factory must be serializable in order
* to serialize the classifier.
*
* A naive Bayes classifier may be compiled. In order to be
* compiled, the tokenizer factory must be either serializable or
* compilable. The object read back in will implement {@code
* ConditionalClassifier<CharSequence>} if the compiled classifier is
* binary (i.e., has exactly two categories) and {@code
* JointClassifier<CharSequence>} if the compiled classifier has more
* than two categories. The ability to compute joint probabilities in
* the binary case is lost due to an optimization in the compiler, so
* the resulting class implements only the conditional classifier
* interface.
*
* A compiled classifier may not be trained.
*
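* For example, a sketch of compiling a binary classifier in memory
* with the {@link AbstractExternalizable#compile(Compilable)}
* utility (the unchecked cast is safe for two categories per the
* above):
*
*   ConditionalClassifier<CharSequence> compiled
*       = (ConditionalClassifier<CharSequence>)
*         AbstractExternalizable.compile(classifier); // unchecked cast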
*
* <b>Comparison to {@code NaiveBayesClassifier}</b>
*
* The naive Bayes classifier implemented in {@link
* NaiveBayesClassifier} differs from this version in smoothing its
* token estimates with character language model estimates.
*
*
* <b>Thread Safety</b>
*
* A {@code TradNaiveBayesClassifier} must be synchronized externally
* using read/write synchronization (e.g., with {@link
* java.util.concurrent.locks.ReadWriteLock}). The write methods
* include {@link #handle(Classified)}, {@link
* #train(CharSequence,Classification,double)}, {@link
* #trainConditional(CharSequence,ConditionalClassification,double,double)},
* and {@link #setLengthNorm(double)}. All other methods are read
* methods.
*
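* For example, a sketch of the locking discipline with a
* {@code ReentrantReadWriteLock} (the classifier and training data
* are assumed to be shared across threads):
*
*   ReadWriteLock lock = new ReentrantReadWriteLock();
*   // in a training (writer) thread:
*   lock.writeLock().lock();
*   try { classifier.handle(classified); }
*   finally { lock.writeLock().unlock(); }
*   // in a classifying (reader) thread:
*   lock.readLock().lock();
*   try { classifier.classify(text); }
*   finally { lock.readLock().unlock(); }
*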
* A compiled classifier is completely thread safe.
*
* @author Bob Carpenter
* @version 4.1.0
* @since LingPipe 3.8
*/
public class TradNaiveBayesClassifier
implements JointClassifier<CharSequence>,
ObjectHandler<Classified<CharSequence>>,
Serializable,
Compilable {
static final long serialVersionUID = -300327951207213311L;
private final Set<String> mCategorySet;
private final String[] mCategories;
private final TokenizerFactory mTokenizerFactory;
private final double mCategoryPrior;
private final double mTokenInCategoryPrior;
private Map<String,double[]> mTokenToCountsMap; // tokenCount(w,c)
private double[] mTotalCountsPerCategory; // SUM_w tokenCount(w,c); indexed by c
private double[] mCaseCounts; // caseCount(c)
private double mTotalCaseCount; // SUM_c caseCount(c)
private double mLengthNorm;
/**
* Return a string representation of this classifier.
*
* @return String representation of this classifier.
*/
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("categories=" + Arrays.asList(mCategories) + "\n");
sb.append("category Prior=" + mCategoryPrior + "\n");
sb.append("token in category prior=" + mTokenInCategoryPrior + "\n");
sb.append("total case count=" + mTotalCaseCount + "\n");
for (int i = 0; i < mCategories.length; ++i) {
sb.append("category count(" + mCategories[i] + ")=" + mCaseCounts[i] + "\n");
}
for (String token : mTokenToCountsMap.keySet()) {
sb.append("token=" + token + "\n");
double[] counts = mTokenToCountsMap.get(token);
for (int i = 0; i < mCategories.length; ++i) {
sb.append(" tokenCount(" + mCategories[i] + "," + token + ")=" + counts[i] + "\n");
}
}
return sb.toString();
}
private TradNaiveBayesClassifier(String[] categories,
TokenizerFactory tokenizerFactory,
double categoryPrior,
double tokenInCategoryPrior,
Map<String,double[]> tokenToCountsMap,
double[] totalCountsPerCategory,
double[] caseCounts,
double totalCaseCount,
double lengthNorm) {
mCategories = categories;
mCategorySet = new HashSet<String>(Arrays.asList(categories));
mTokenizerFactory = tokenizerFactory;
mCategoryPrior = categoryPrior;
mTokenInCategoryPrior = tokenInCategoryPrior;
mTokenToCountsMap = tokenToCountsMap;
mTotalCountsPerCategory = totalCountsPerCategory;
mCaseCounts = caseCounts;
mTotalCaseCount = totalCaseCount;
mLengthNorm = lengthNorm;
}
/**
* Constructs a naive Bayes classifier over the specified
* categories, using the specified tokenizer factory. The
* category and token-in-category priors will be set to a reasonable
* default value of 0.5, and there is no length normalization (the
* length normalization is set to {@code Double.NaN}).
*
* See the class documentation above for more information.
*
* @param categorySet Categories for classification.
* @param tokenizerFactory Factory to convert char sequences to
* tokens.
* @throws IllegalArgumentException If there are fewer than two
* categories.
*/
public TradNaiveBayesClassifier(Set<String> categorySet,
TokenizerFactory tokenizerFactory) {
this(categorySet,tokenizerFactory,0.5,0.5,Double.NaN);
}
/**
* Constructs a naive Bayes classifier over the specified
* categories, using the specified tokenizer factory, priors, and
* length normalization. See the class documentation for an
* explanation of the parameters' effect on classification.
*
* @param categorySet Categories for classification.
* @param tokenizerFactory Factory to convert char sequences to
* tokens.
* @param categoryPrior Prior count for categories.
* @param tokenInCategoryPrior Prior count for tokens per category.
* @param lengthNorm A positive, finite length norm, or {@code
* Double.NaN} if no length normalization is to be done.
* @throws IllegalArgumentException If either prior is negative or
* not finite, if there are fewer than two categories, or if the
* length normalization constant is negative, zero, or infinite.
*/
public TradNaiveBayesClassifier(Set<String> categorySet,
TokenizerFactory tokenizerFactory,
double categoryPrior,
double tokenInCategoryPrior,
double lengthNorm) {
if (categorySet.size() < 2) {
String msg = "Require at least two categories."
+ " Found categorySet.size()=" + categorySet.size();
throw new IllegalArgumentException(msg);
}
Exceptions.finiteNonNegative("categoryPrior",categoryPrior);
Exceptions.finiteNonNegative("tokenInCategoryPrior",
tokenInCategoryPrior);
setLengthNorm(lengthNorm);
mTotalCaseCount = 0.0;
mCategorySet = new HashSet<String>(categorySet);
mCategories = mCategorySet.toArray(Strings.EMPTY_STRING_ARRAY);
Arrays.sort(mCategories);
mTokenizerFactory = tokenizerFactory;
mCategoryPrior = categoryPrior;
mTokenInCategoryPrior = tokenInCategoryPrior;
mTokenToCountsMap = new HashMap<String,double[]>();
mTotalCountsPerCategory = new double[mCategories.length];
mCaseCounts = new double[mCategories.length];
}
/**
* Returns a set of categories for this classifier.
*
* @return The set of categories for this classifier.
*/
public Set<String> categorySet() {
return Collections.unmodifiableSet(mCategorySet);
}
/**
* Sets the length normalization factor to the specified value.
* See the class documentation above for details.
*
* @param lengthNorm Length normalization or {@code Double.NaN} to turn
* off normalization.
* @throws IllegalArgumentException If the length norm is
* infinite, zero, or negative.
*/
public void setLengthNorm(double lengthNorm) {
if (lengthNorm <= 0.0 || Double.isInfinite(lengthNorm)) {
String msg = "Length norm must be finite and positive, or Double.NaN."
+ " Found lengthNorm=" + lengthNorm;
throw new IllegalArgumentException(msg);
}
mLengthNorm = lengthNorm;
}
/**
* Return the classification of the specified character sequence.
*
* @param in Character sequence being classified.
* @return The classification of the char sequence.
*/
public JointClassification classify(CharSequence in) {
double[] logps = new double[mCategories.length];
char[] cs = Strings.toCharArray(in);
Tokenizer tokenizer = mTokenizerFactory.tokenizer(cs,0,cs.length);
int tokenCount = 0;
for (String token : tokenizer) {
double[] tokenCounts = mTokenToCountsMap.get(token);
++tokenCount;
if (tokenCounts == null)
continue;
for (int i = 0; i < mCategories.length; ++i)
logps[i] += com.aliasi.util.Math.log2(probTokenByIndexArray(i,tokenCounts));
}
if ((!Double.isNaN(mLengthNorm)) && (tokenCount > 0)) {
for (int i = 0; i < logps.length; ++i)
logps[i] *= mLengthNorm/tokenCount;
}
for (int i = 0; i < logps.length; ++i)
logps[i] += com.aliasi.util.Math.log2(probCatByIndex(i));
return JointClassification.create(mCategories,logps);
}
/**
* Returns the length normalization factor for this
* classifier. See the class documentation above for
* details.
*
* @return The length normalization for this classifier.
*/
public double lengthNorm() {
return mLengthNorm;
}
/**
* Returns {@code true} if the token has been seen in
* training data.
*
* @param token Token to test.
* @return {@code true} if the token has been seen in
* training data.
*/
public boolean isKnownToken(String token) {
return mTokenToCountsMap.containsKey(token);
}
/**
* Returns an unmodifiable view of the set of tokens.
* The set is not modifiable, but will change to reflect
* any tokens added during training.
*
* @return The set of known tokens.
*/
public Set<String> knownTokenSet() {
return Collections.unmodifiableSet(mTokenToCountsMap.keySet());
}
/**
* Returns the probability of the specified token
* in the specified category. See the class documentation
* above for definitions.
*
* @param token Token whose probability is returned.
* @param cat Category to condition on.
* @return The probability of the token in the category.
* @throws IllegalArgumentException If the category is not known
* or the token is not known.
*/
public double probToken(String token, String cat) {
int catIndex = getIndex(cat);
double[] tokenCounts = mTokenToCountsMap.get(token);
if (tokenCounts == null) {
String msg = "Requires known token."
+ " Found token=" + token;
throw new IllegalArgumentException(msg);
}
return probTokenByIndexArray(catIndex,tokenCounts);
}
/**
* Compile this classifier to the specified object output.
*
* @param out Object output to which this classifier is compiled.
* @throws IOException If there is an underlying I/O error
* during the write.
*/
public void compileTo(ObjectOutput out) throws IOException {
out.writeObject(new Compiler(this));
}
/**
* Returns the probability estimate for the specified
* category.
*
* @param cat Category whose probability is returned.
* @return Probability for category.
* @throws IllegalArgumentException If the category is not known.
*/
public double probCat(String cat) {
int catIndex = getIndex(cat);
return probCatByIndex(catIndex);
}
/**
* Trains the classifier with the specified classified character
* sequence. Only the first-best result is used from the
* classification; to train on conditional outputs, see {@link
* #trainConditional(CharSequence,ConditionalClassification,double,double)}.
*
* @param classifiedObject Classified character sequence.
*/
public void handle(Classified<CharSequence> classifiedObject) {
handle(classifiedObject.getObject(), classifiedObject.getClassification());
}
/**
* Trains the classifier with the specified case consisting of a
* character sequence and first-best classification. Only the
* first-best result is used from the classification; to train on
* conditional outputs, see {@link
* #trainConditional(CharSequence,ConditionalClassification,double,double)}.
*
* @param cSeq Character sequence being classified.
* @param classification Classification of character sequence.
*/
void handle(CharSequence cSeq, Classification classification) {
train(cSeq,classification,1.0);
}
/**
* Trains this classifier using tokens extracted from the
* specified character sequence, using category count multipliers
* derived by multiplying the specified count multiplier by the
* conditional probability of a category in the specified
* classification. A category is not trained for the sequence
* if its conditional probability times the count multiplier
* is less than the minimum count.
*
* @param cSeq Character sequence being trained.
* @param classification Conditional classification to train.
* @param countMultiplier Count multiplier of training instance.
* @param minCount Minimum count for which a category is trained for this character
* sequence.
* @throws IllegalArgumentException If the count multiplier is not
* finite and non-negative, or if the minimum count is negative or
* not a number.
*/
public void trainConditional(CharSequence cSeq,
ConditionalClassification classification,
double countMultiplier,
double minCount) {
if (countMultiplier < 0.0
|| Double.isNaN(countMultiplier)
|| Double.isInfinite(countMultiplier)) {
String msg = "Count multipliers must be finite and non-negative."
+ " Found countMultiplier=" + countMultiplier;
throw new IllegalArgumentException(msg);
}
if (minCount < 0.0 || Double.isNaN(minCount) || Double.isInfinite(minCount)) {
String msg = "Minimum count must be finite non-negative."
+ " Found minCount=" + minCount;
throw new IllegalArgumentException(msg);
}
int numCats = 0;
while (numCats < classification.size()
&& classification.conditionalProbability(numCats) * countMultiplier >= minCount)
++numCats;
ObjectToCounterMap<String> tokenCountMap = tokenCountMap(cSeq);
double lengthMultiplier = lengthMultiplier(tokenCountMap);
// cache results per cat
double[] lengthNormCatMultipliers = new double[numCats];
int[] catIndexes = new int[numCats];
for (int j = 0; j < numCats; ++j) {
catIndexes[j] = getIndex(classification.category(j));
double count = countMultiplier * classification.conditionalProbability(j);
mTotalCaseCount += count;
mCaseCounts[catIndexes[j]] += count;
lengthNormCatMultipliers[j] = lengthMultiplier * count;
}
for (Map.Entry<String,Counter> entry : tokenCountMap.entrySet()) {
String token = entry.getKey();
double tokenCount = entry.getValue().doubleValue();
double[] tokenCounts = mTokenToCountsMap.get(token);
if (tokenCounts == null) {
tokenCounts = new double[mCategories.length];
mTokenToCountsMap.put(token,tokenCounts);
}
for (int j = 0; j < numCats; ++j) {
double addend = tokenCount * lengthNormCatMultipliers[j];
tokenCounts[catIndexes[j]] += addend;
mTotalCountsPerCategory[catIndexes[j]] += addend;
}
}
}
/**
* Trains the classifier with the specified case consisting of
* a character sequence and conditional classification with the
* specified count.
*
* If the count value is negative, counts are subtracted rather
* than added. If any of the counts would fall below zero, an
* illegal argument exception is thrown and the classifier is
* reverted to the counts in place before the method was called.
* Cleanup after errors requires the tokenizer factory to return
* the same tokenizer given the same string, but no check is made
* that it does.
*
* @param cSeq Character sequence on which to train.
* @param classification Classification to train with character
* sequence.
* @param count How many instances the classification will count
* as for training purposes.
* @throws IllegalArgumentException If the count is negative and
* decrementing causes an accumulated count to fall below zero.
*/
public void train(CharSequence cSeq, Classification classification, double count) {
if (count == 0.0) return;
String cat = classification.bestCategory();
int catIndex = getIndex(cat);
// throw if underflow
if (mCaseCounts[catIndex] < -count) {
String msg = "Decrement caused negative case count."
+ " Reverting to previous state."
+ " cSeq=" + cSeq
+ " classification=" + cat
+ " count=" + count;
throw new IllegalArgumentException(msg);
}
mCaseCounts[catIndex] += count;
mTotalCaseCount += count;
ObjectToCounterMap<String> tokenCountMap = tokenCountMap(cSeq);
double lengthMultiplier = lengthMultiplier(tokenCountMap);
double lengthNormCount = lengthMultiplier * count;
char[] cs = Strings.toCharArray(cSeq);
Tokenizer tokenizer = mTokenizerFactory.tokenizer(cs,0,cs.length);
int pos = 0;
for (String token : tokenizer) {
double[] tokenCounts = mTokenToCountsMap.get(token);
// cleanup underflow and throw
if (lengthNormCount < 0 && ((tokenCounts == null)
|| (tokenCounts[catIndex] < -lengthNormCount))) {
// first two are unnormed
mCaseCounts[catIndex] -= count;
mTotalCaseCount -= count;
Tokenizer tokenizer2 = mTokenizerFactory.tokenizer(cs,0,cs.length);
int fixPos = 0;
for (String token2 : tokenizer2) {
if (fixPos >= pos) break;
++fixPos;
double[] tokenCounts2 = mTokenToCountsMap.get(token2);
tokenCounts2[catIndex] -= lengthNormCount;
mTotalCountsPerCategory[catIndex] -= lengthNormCount;
}
String msg = "Decrement caused negative token count."
+ "Revert to previous state."
+ " cSeq=" + cSeq
+ " classification=" + cat
+ " count=" + count;
throw new IllegalArgumentException(msg);
}
++pos;
if (tokenCounts == null) {
tokenCounts = new double[mCategories.length];
mTokenToCountsMap.put(token,tokenCounts);
}
tokenCounts[catIndex] += lengthNormCount;
mTotalCountsPerCategory[catIndex] += lengthNormCount;
}
}
/**
* Returns the log (base 2) marginal probability of the specified
* input. This value is calculated by:
*
*   p(x) = Σ_{c in cats} p(c,x)
*
* Note that this value is normalized by the number of tokens
* in the input, so that
*
*   Σ_{x : length(x) = n} p(x) = 1.0
*
*
* @param input Input character sequence.
* @return The log probability of the input under this joint
* model.
*/
public double log2CaseProb(CharSequence input) {
JointClassification c = classify(input);
double maxJointLog2P = Double.NEGATIVE_INFINITY;
for (int rank = 0; rank < c.size(); ++rank) {
double jointLog2P = c.jointLog2Probability(rank);
if (jointLog2P > maxJointLog2P)
maxJointLog2P = jointLog2P;
}
double sum = 0.0;
for (int rank = 0; rank < c.size(); ++rank)
sum += Math.pow(2.0,c.jointLog2Probability(rank) - maxJointLog2P);
return maxJointLog2P + com.aliasi.util.Math.log2(sum);
}
/**
* Returns the log (base 2) of the probability density of this
* model under the Dirichlet prior specified by this classifier.
* Note that the result is a log density, not technically a log
* probability, so it may take positive values.
*
* The result is the sum of the log density of the multinomial
* over categories and the log densities of the per-category
* multinomials over tokens.
*
*
* For a definition of the probability function for each
* category's multinomial and the overall category multinomial,
* see {@link Statistics#dirichletLog2Prob(double,double[])}.
*
* @return The log model density value.
*/
public double log2ModelProb() {
double[] catProbs = new double[mCategories.length];
for (int i = 0; i < mCategories.length; ++i) {
catProbs[i] = probCatByIndex(i);
}
double sum = Statistics.dirichletLog2Prob(mCategoryPrior,catProbs);
double[] wordProbs = new double[mTokenToCountsMap.size()];
for (int catIndex = 0; catIndex < mCategories.length; ++catIndex) {
int j = 0;
for (double[] counts : mTokenToCountsMap.values()) {
double totalCountForCat = mTotalCountsPerCategory[catIndex];
wordProbs[j++] = (counts[catIndex] + mTokenInCategoryPrior)
/ (totalCountForCat + mTokenToCountsMap.size() * mTokenInCategoryPrior); // match smoothing in probTokenByIndexArray()
}
sum += Statistics.dirichletLog2Prob(mTokenInCategoryPrior,wordProbs);
}
return sum;
}
private Object writeReplace() throws ObjectStreamException {
return new TradNaiveBayesClassifier.Serializer(this);
}
private double probTokenByIndexArray(int catIndex, double[] tokenCounts) {
double tokenCatCount = tokenCounts[catIndex];
double totalCatCount = mTotalCountsPerCategory[catIndex];
return (tokenCatCount + mTokenInCategoryPrior)
/ (totalCatCount + mTokenToCountsMap.size() * mTokenInCategoryPrior);
}
private double probCatByIndex(int catIndex) {
double caseCountCat = mCaseCounts[catIndex];
return (caseCountCat + mCategoryPrior)
/ (mTotalCaseCount + mCategories.length * mCategoryPrior);
}
private ObjectToCounterMap<String> tokenCountMap(CharSequence cSeq) {
ObjectToCounterMap<String> tokenCountMap = new ObjectToCounterMap<String>();
char[] cs = Strings.toCharArray(cSeq);
Tokenizer tokenizer = mTokenizerFactory.tokenizer(cs,0,cs.length);
for (String token : tokenizer)
tokenCountMap.increment(token);
return tokenCountMap;
}
private double lengthMultiplier(ObjectToCounterMap tokenCountMap) {
if (Double.isNaN(mLengthNorm)) return 1.0;
int length = 0;
for (Counter counter : tokenCountMap.values())
length += counter.intValue();
return length != 0.0
? mLengthNorm / length
: 1.0;
}
private int getIndex(String cat) {
int catIndex = java.util.Arrays.binarySearch(mCategories,cat);
if (catIndex < 0) {
String msg = "Unknown category. Require category in category set."
+ " Found category=" + cat
+ " category set=" + mCategorySet;
throw new IllegalArgumentException(msg);
}
return catIndex;
}
/**
* Applies the expectation maximization (EM) algorithm to train a traditional
* naive Bayes classifier using the specified labeled and unlabeled data,
* initial classifier, and factory for creating subsequent classifiers.
*
* This method lets the client take control over assessing convergence,
* so there are no convergence-related arguments.
*
* @param initialClassifier Initial classifier to bootstrap.
* @param classifierFactory Factory for creating subsequent classifiers.
* @param labeledData Labeled data for supervised training.
* @param unlabeledData Unlabeled data for unsupervised training.
* @param minTokenCount Minimum conditional count for a category to be
* trained on an unlabeled case.
* @return An iterator over classifiers that returns each epoch's
* classifier.
*/
public static Iterator<TradNaiveBayesClassifier>
emIterator(TradNaiveBayesClassifier initialClassifier,
Factory<TradNaiveBayesClassifier> classifierFactory,
Corpus<ObjectHandler<Classified<CharSequence>>> labeledData,
Corpus<ObjectHandler<CharSequence>> unlabeledData,
double minTokenCount) throws IOException {
return new EmIterator(initialClassifier,classifierFactory,
labeledData,unlabeledData,minTokenCount);
}
/**
* Applies the expectation maximization (EM) algorithm to train a traditional
* naive Bayes classifier using the specified labeled and unlabeled data,
* initial classifier, factory for creating subsequent classifiers,
* maximum number of epochs, minimum improvement per epoch, and reporter
* to which progress reports are sent.
*
* @param initialClassifier Initial classifier to bootstrap.
* @param classifierFactory Factory for creating subsequent classifiers.
* @param labeledData Labeled data for supervised training.
* @param unlabeledData Unlabeled data for unsupervised training.
* @param minTokenCount Minimum conditional count for a category to be
* trained on an unlabeled case.
* @param maxEpochs Maximum number of epochs to run training.
* @param minImprovement Minimum relative improvement per epoch.
* @param reporter Reporter to which intermediate results are reported,
* or {@code null} for no reporting.
* @return The trained classifier.
*/
public static TradNaiveBayesClassifier
emTrain(TradNaiveBayesClassifier initialClassifier,
Factory<TradNaiveBayesClassifier> classifierFactory,
Corpus<ObjectHandler<Classified<CharSequence>>> labeledData,
Corpus<ObjectHandler<CharSequence>> unlabeledData,
double minTokenCount,
int maxEpochs,
double minImprovement,
Reporter reporter) throws IOException {
if (reporter == null)
reporter = Reporters.silent();
long startTime = System.currentTimeMillis();
double lastLogProb = Double.NEGATIVE_INFINITY;
Iterator<TradNaiveBayesClassifier> it
= emIterator(initialClassifier,classifierFactory,labeledData,unlabeledData,minTokenCount);
TradNaiveBayesClassifier classifier = null;
for (int epoch = 0; it.hasNext() && epoch < maxEpochs; ++epoch) {
classifier = it.next();
double modelLogProb = classifier.log2ModelProb();
double dataLogProb = dataProb(classifier,labeledData,unlabeledData);
double logProb = modelLogProb + dataLogProb;
double relativeDiff = relativeDiff(lastLogProb,logProb);
if (reporter.isDebugEnabled()) {
Formatter formatter = new Formatter();
formatter.format("epoch=%4d dataLogProb=%15.2f modelLogProb=%15.2f logProb=%15.2f diff=%15.12f",
epoch, dataLogProb, modelLogProb, logProb, relativeDiff);
String msg = formatter.toString();
reporter.debug(msg);
}
if (!Double.isNaN(lastLogProb) && relativeDiff < minImprovement) {
reporter.info("Converged");
return classifier;
} else {
lastLogProb = logProb;
}
}
return classifier;
}
static double dataProb(TradNaiveBayesClassifier classifier,
Corpus<ObjectHandler<Classified<CharSequence>>> labeledData,
Corpus<ObjectHandler<CharSequence>> unlabeledData) throws IOException {
CaseProbAccumulator accum = new CaseProbAccumulator(classifier);
labeledData.visitTrain(accum.supHandler());
unlabeledData.visitTrain(accum);
return accum.mCaseProb;
}
static double relativeDiff(double x, double y) {
return 2.0 * Math.abs(x-y) / (Math.abs(x) + Math.abs(y));
}
static class CaseProbAccumulator
implements ObjectHandler<CharSequence> {
double mCaseProb = 0.0;
final TradNaiveBayesClassifier mClassifier;
CaseProbAccumulator(TradNaiveBayesClassifier classifier) {
mClassifier = classifier;
}
public void handle(CharSequence cSeq) {
mCaseProb += mClassifier.log2CaseProb(cSeq);
}
public ObjectHandler<Classified<CharSequence>> supHandler() {
final ObjectHandler<CharSequence> cSeqHandler = this;
return new ObjectHandler<Classified<CharSequence>>() {
public void handle(Classified<CharSequence> classified) {
cSeqHandler.handle(classified.getObject());
}
};
}
}
static class EmIterator extends Iterators.Buffered<TradNaiveBayesClassifier> {
private final Factory<TradNaiveBayesClassifier> mClassifierFactory;
private final Corpus<ObjectHandler<Classified<CharSequence>>> mLabeledData;
private final Corpus<ObjectHandler<CharSequence>> mUnlabeledData;
private final double mMinTokenCount;
private JointClassifier<CharSequence> mLastClassifier;
EmIterator(TradNaiveBayesClassifier initialClassifier,
Factory<TradNaiveBayesClassifier> classifierFactory,
Corpus<ObjectHandler<Classified<CharSequence>>> labeledData,
Corpus<ObjectHandler<CharSequence>> unlabeledData,
double minTokenCount) {
mClassifierFactory = classifierFactory;
mLabeledData = labeledData;
mUnlabeledData = unlabeledData;
mMinTokenCount = minTokenCount;
trainSup(labeledData,initialClassifier);
compile(initialClassifier);
}
@Override
public TradNaiveBayesClassifier bufferNext() {
TradNaiveBayesClassifier classifier = mClassifierFactory.create();
trainSup(mLabeledData,classifier);
trainUnsup(mUnlabeledData,classifier);
compile(classifier);
return classifier;
}
void trainSup(Corpus<ObjectHandler<Classified<CharSequence>>> labeledData,
TradNaiveBayesClassifier classifier) {
try {
labeledData.visitTrain(classifier);
} catch (IOException e) {
throw new IllegalStateException("Error during labeled training",e);
}
}
void trainUnsup(final Corpus<ObjectHandler<CharSequence>> unlabeledData,
final TradNaiveBayesClassifier classifier) {
try {
unlabeledData.visitTrain(new ObjectHandler<CharSequence>() {
public void handle(CharSequence cSeq) {
ConditionalClassification c = mLastClassifier.classify(cSeq);
classifier.trainConditional(cSeq,c,1.0,mMinTokenCount);
}
});
} catch (IOException e) {
throw new IllegalStateException("Error during unlabeled training",e);
}
}
void compile(TradNaiveBayesClassifier classifier) {
try {
@SuppressWarnings("unchecked") // know this is OK, assignment required to to scope
JointClassifier lastClassifier
= (JointClassifier)
AbstractExternalizable.compile(classifier);
mLastClassifier = lastClassifier;
} catch (IOException e) {
mLastClassifier = null;
throw new IllegalStateException("Error during compilation.",e);
} catch (ClassNotFoundException e) {
mLastClassifier = null;
throw new IllegalStateException("Error during compilation.",e);
}
}
}
static class Serializer extends AbstractExternalizable {
static final long serialVersionUID = -4786039228920809976L;
private final TradNaiveBayesClassifier mClassifier;
public Serializer(TradNaiveBayesClassifier classifier) {
mClassifier = classifier;
}
public Serializer() {
this(null);
}
@Override
public Object read(ObjectInput in) throws ClassNotFoundException, IOException {
int numCats = in.readInt();
String[] categories = new String[numCats];
for (int i = 0; i < numCats; ++i)
categories[i] = in.readUTF();
TokenizerFactory tokenizerFactory = (TokenizerFactory) in.readObject();
double catPrior = in.readDouble();
double tokenInCatPrior = in.readDouble();
int tokenToCountsMapSize = in.readInt();
Map<String,double[]> tokenToCountsMap
= new HashMap<String,double[]>((tokenToCountsMapSize*3)/2);
for (int k = 0; k < tokenToCountsMapSize; ++k) {
String key = in.readUTF();
double[] vals = new double[categories.length];
for (int i = 0; i < categories.length; ++i)
vals[i] = in.readDouble();
tokenToCountsMap.put(key,vals);
}
double[] totalCountsPerCategory = new double[categories.length];
for (int i = 0; i < categories.length; ++i)
totalCountsPerCategory[i] = in.readDouble();
double[] caseCounts = new double[categories.length];
for (int i = 0; i < categories.length; ++i)
caseCounts[i] = in.readDouble();
double totalCaseCount = in.readDouble();
double lengthNorm = in.readDouble();
return new TradNaiveBayesClassifier(categories,
tokenizerFactory,
catPrior,
tokenInCatPrior,
tokenToCountsMap,
totalCountsPerCategory,
caseCounts,
totalCaseCount,
lengthNorm);
}
@Override
public void writeExternal(ObjectOutput objOut) throws IOException {
objOut.writeInt(mClassifier.mCategories.length);
for (String category : mClassifier.mCategories)
objOut.writeUTF(category);
// may throw exception here if tokenizer factory not serializable
objOut.writeObject(mClassifier.mTokenizerFactory);
objOut.writeDouble(mClassifier.mCategoryPrior);
objOut.writeDouble(mClassifier.mTokenInCategoryPrior);
objOut.writeInt(mClassifier.mTokenToCountsMap.size());
for (Map.Entry<String,double[]> entry : mClassifier.mTokenToCountsMap.entrySet()) {
objOut.writeUTF(entry.getKey());
double[] vals = entry.getValue();
for (int i = 0; i < mClassifier.mCategories.length; ++i)
objOut.writeDouble(vals[i]);
}
for (int i = 0; i < mClassifier.mCategories.length; ++i)
objOut.writeDouble(mClassifier.mTotalCountsPerCategory[i]);
for (int i = 0; i < mClassifier.mCategories.length; ++i)
objOut.writeDouble(mClassifier.mCaseCounts[i]);
objOut.writeDouble(mClassifier.mTotalCaseCount);
objOut.writeDouble(mClassifier.mLengthNorm);
}
}
static class Compiler extends AbstractExternalizable {
static final long serialVersionUID = 5689464666886334529L;
private final TradNaiveBayesClassifier mClassifier;
public Compiler() {
this(null);
}
public Compiler(TradNaiveBayesClassifier classifier) {
mClassifier = classifier;
}
@Override
public void writeExternal(ObjectOutput objOut) throws IOException {
objOut.writeInt(mClassifier.mCategories.length);
for (int i = 0; i < mClassifier.mCategories.length; ++i)
objOut.writeUTF(mClassifier.mCategories[i]);
AbstractExternalizable.compileOrSerialize(mClassifier.mTokenizerFactory,objOut);
objOut.writeInt(mClassifier.mTokenToCountsMap.size());
for (Map.Entry<String,double[]> entry : mClassifier.mTokenToCountsMap.entrySet()) {
objOut.writeUTF(entry.getKey());
double[] tokenCounts = entry.getValue();
for (int i = 0; i < mClassifier.mCategories.length; ++i) {
double log2Prob = com.aliasi.util.Math.log2(mClassifier.probTokenByIndexArray(i,tokenCounts));
if (log2Prob > 0.0) {
String msg = "key=" + entry.getKey() +
" i=" + i
+ " log2Prob=" + log2Prob
+ " prob=" + mClassifier.probTokenByIndexArray(i,tokenCounts)
+ " token counts[" + i + "]=" + tokenCounts[i]
+ " totalCatCount=" + mClassifier.mTotalCountsPerCategory[i]
+ " mTokenToCountsMap.size()=" + mClassifier.mTokenToCountsMap.size();
throw new IllegalArgumentException(msg);
}
objOut.writeDouble(log2Prob);
}
}
for (int i = 0; i < mClassifier.mCategories.length; ++i)
objOut.writeDouble(com.aliasi.util.Math.log2(mClassifier.probCatByIndex(i)));
objOut.writeDouble(mClassifier.mLengthNorm);
}
@Override
public Object read(ObjectInput in) throws ClassNotFoundException, IOException {
int numCategories = in.readInt();
String[] categories = new String[numCategories];
for (int i = 0; i < numCategories; ++i)
categories[i] = in.readUTF();
TokenizerFactory tokenizerFactory = (TokenizerFactory) in.readObject();
int size = in.readInt();
Map<String,double[]> tokenToLog2ProbsInCats
= new HashMap<String,double[]>((size * 3)/2);
for (int k = 0; k < size; ++k) {
String token = in.readUTF();
double[] log2ProbsInCats = new double[numCategories];
for (int i = 0; i < numCategories; ++i)
log2ProbsInCats[i] = in.readDouble();
tokenToLog2ProbsInCats.put(token,log2ProbsInCats);
}
double[] log2CatProbs = new double[numCategories];
for (int i = 0; i < numCategories; ++i)
log2CatProbs[i] = in.readDouble();
double lengthNorm = in.readDouble();
return (categories.length == 2)
? new CompiledBinaryTradNaiveBayesClassifier(categories,
tokenizerFactory,
tokenToLog2ProbsInCats,
log2CatProbs,
lengthNorm)
: new CompiledTradNaiveBayesClassifier(categories,
tokenizerFactory,
tokenToLog2ProbsInCats,
log2CatProbs,
lengthNorm);
}
}
private static class CompiledBinaryTradNaiveBayesClassifier
implements ConditionalClassifier<CharSequence> {
private final TokenizerFactory mTokenizerFactory;
private final Map<String,Double> mTokenToLog2ProbDiff; // holds natural-log diffs; see constructor
private final double mLog2CatProbDiff;
private final double mLengthNorm;
private final String[] mCats01;
private final String[] mCats10;
CompiledBinaryTradNaiveBayesClassifier(String[] categories,
TokenizerFactory tokenizerFactory,
Map<String,double[]> tokenToLog2ProbsInCats,
double[] log2CatProbs,
double lengthNorm) {
mTokenizerFactory = tokenizerFactory;
mTokenToLog2ProbDiff = new HashMap<String,Double>();
for (Map.Entry<String,double[]> entry : tokenToLog2ProbsInCats.entrySet()) {
String token = entry.getKey();
double[] log2Probs = entry.getValue();
// dividing by log2(e) converts the log2 difference to a natural log,
// matching the use of Math.exp() in classify()
double log2ProbDiff = (log2Probs[0] - log2Probs[1]) / com.aliasi.util.Math.LOG2_E;
mTokenToLog2ProbDiff.put(token,log2ProbDiff);
}
mLog2CatProbDiff = (log2CatProbs[0] - log2CatProbs[1]) / com.aliasi.util.Math.LOG2_E;
mLengthNorm = lengthNorm;
mCats01 = new String[] { categories[0], categories[1] };
mCats10 = new String[] { categories[1], categories[0] };
}
public ConditionalClassification classify(CharSequence in) {
double logDiff = 0.0;
char[] cs = Strings.toCharArray(in);
Tokenizer tokenizer = mTokenizerFactory.tokenizer(cs,0,cs.length);
int tokenCount = 0;
for (String token : tokenizer) {
Double tokLogDiff = mTokenToLog2ProbDiff.get(token);
++tokenCount;
if (tokLogDiff == null) continue;
logDiff += tokLogDiff;
}
if ((!Double.isNaN(mLengthNorm)) && (tokenCount > 0))
logDiff *= mLengthNorm / tokenCount;
logDiff += mLog2CatProbDiff;
double expProd = Math.exp(logDiff);
double p0 = expProd / (1.0 + expProd);
double p1 = 1.0 - p0;
return (p0 > p1)
? new ConditionalClassification(mCats01, new double[] { p0, p1 })
: new ConditionalClassification(mCats10, new double[] { p1, p0 });
}
}
private static class CompiledTradNaiveBayesClassifier
implements JointClassifier<CharSequence> {
private final TokenizerFactory mTokenizerFactory;
private final String[] mCategories;
private final Map<String,double[]> mTokenToLog2ProbsInCats;
private final double[] mLog2CatProbs;
private final double mLengthNorm;
CompiledTradNaiveBayesClassifier(String[] categories,
TokenizerFactory tokenizerFactory,
Map<String,double[]> tokenToLog2ProbsInCats,
double[] log2CatProbs,
double lengthNorm) {
mCategories = categories;
mTokenizerFactory = tokenizerFactory;
mTokenToLog2ProbsInCats = tokenToLog2ProbsInCats;
mLog2CatProbs = log2CatProbs;
mLengthNorm = lengthNorm;
}
public JointClassification classify(CharSequence in) {
double[] logps = new double[mCategories.length];
char[] cs = Strings.toCharArray(in);
Tokenizer tokenizer = mTokenizerFactory.tokenizer(cs,0,cs.length);
int tokenCount = 0;
for (String token : tokenizer) {
double[] tokenLog2Probs = mTokenToLog2ProbsInCats.get(token);
++tokenCount;
if (tokenLog2Probs == null) continue;
for (int i = 0; i < logps.length; ++i) {
logps[i] += tokenLog2Probs[i];
}
}
if ((!Double.isNaN(mLengthNorm)) && (tokenCount > 0)) {
for (int i = 0; i < logps.length; ++i) {
logps[i] *= mLengthNorm / tokenCount;
}
}
for (int i = 0; i < logps.length; ++i) {
logps[i] += mLog2CatProbs[i];
}
return JointClassification.create(mCategories,logps);
}
}
}