/*
* LingPipe v. 4.1.0
* Copyright (C) 2003-2011 Alias-i
*
* This program is licensed under the Alias-i Royalty Free License
* Version 1 WITHOUT ANY WARRANTY, without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the Alias-i
* Royalty Free License Version 1 for more details.
*
* You should have received a copy of the Alias-i Royalty Free License
* Version 1 along with this program; if not, visit
* http://alias-i.com/lingpipe/licenses/lingpipe-license-1.txt or contact
* Alias-i, Inc. at 181 North 11th Street, Suite 401, Brooklyn, NY 11211,
* +1 (718) 290-9170.
*/
package com.aliasi.classify;
import com.aliasi.corpus.ObjectHandler;
import com.aliasi.stats.MultivariateEstimator;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Counter;
import com.aliasi.util.FeatureExtractor;
import com.aliasi.util.ObjectToCounterMap;
import com.aliasi.util.ObjectToDoubleMap;
import com.aliasi.util.ScoredObject;
import com.aliasi.util.Strings;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.IOException;
import java.io.Serializable;
/**
* A BernoulliClassifier provides a feature-based
* classifier where feature values are reduced to booleans based on a
* specified threshold. Training events are supplied in the usual
* way through the {@link #handle(Classified)} method.
*
* Given a feature threshold of t, any feature with
* value strictly greater than the threshold t for a
* given input is activated, and all other features are not activated
* for that input.
*
* The likelihood of a feature in a category is estimated with the
* training sample counts using add-one smoothing (also known as
* Laplace smoothing, or a uniform Dirichlet prior). There is also
* a term for the category distribution. Suppose F is
* the complete set of features seen during training. Further suppose
* that count(cat) is the number of training samples
* for category cat, and count(cat,feat)
* is the number of training instances of the specified category that
* had the specified feature activated. Thus the contribution of
* each feature is computed by:
*
*     p(+feat|cat) = (count(cat,feat) + 1) / (count(cat) + 2)
*     p(-feat|cat) = 1.0 - p(+feat|cat)
*
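* For example, with purely illustrative counts count(spam) = 10 and
* count(spam,cash) = 4, the smoothed estimates would be:
*
*     p(+cash|spam) = (4 + 1) / (10 + 2) = 5/12 ≈ 0.417
*     p(-cash|spam) = 1 - 5/12           = 7/12 ≈ 0.583
*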
* Assuming the total number of training instances is totalCount,
* we use a simple maximum-likelihood estimate for the category probability:
*
*     p(cat) = count(cat) / totalCount
*
* With these two definitions, the joint probability estimate for
* a category cat given activated features {f[0],...,f[n-1]}
* and unactivated features {g[0],...,g[m-1]} is:
*
*     p(cat,{f[0],...,f[n-1]})
*         = p(cat)
*         * Π[i < n] p(+f[i]|cat)
*         * Π[j < m] p(-g[j]|cat)
*
* The {@link JointClassification} class requires log (base 2) estimates,
* and is responsible for converting these to conditional estimates.
* The scores in this case are just the log2 joint estimates.
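*
* Continuing the illustrative counts above, if totalCount = 25 (so
* p(spam) = 10/25 = 0.4) and the only other feature seen in training
* is meeting with count(spam,meeting) = 1, then an input whose single
* active feature is cash would receive the joint estimate
*
*     p(spam,{cash}) = 0.4 * p(+cash|spam) * p(-meeting|spam)
*                    = 0.4 * (5/12) * (10/12)
*                    ≈ 0.139
*
* and a score of log2(0.139) ≈ -2.85.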
*
* The dynamic form of the estimator may be used for classification,
* but it is not very efficient. It loops over every feature for every
* category.
*
* Serialization and Compilation
*
* The serialized version of a Bernoulli classifier will deserialize
* as an equivalent instance of BernoulliClassifier. In order to
* serialize a Bernoulli classifier, the feature extractor must be
* serializable. Otherwise an exception will be raised during
* serialization.
*
* Compilation is not yet implemented.
*
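* Usage Sketch
*
* The following sketch illustrates typical training and classification
* calls; the feature extractor and the training strings are placeholders
* for this documentation, not part of the class itself:
*
*     FeatureExtractor<CharSequence> extractor = ...; // any (serializable) extractor
*     BernoulliClassifier<CharSequence> classifier
*         = new BernoulliClassifier<CharSequence>(extractor); // default threshold 0.0
*     classifier.handle(new Classified<CharSequence>("win cash now",
*                                                    new Classification("spam")));
*     classifier.handle(new Classified<CharSequence>("meeting at noon",
*                                                    new Classification("ham")));
*     JointClassification c = classifier.classify("cash prize");
*     String best = c.bestCategory();
*
* Because serialization goes through a serialization proxy, a trained
* classifier whose feature extractor is serializable may be written
* with an ordinary java.io.ObjectOutputStream.
*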
* @author Bob Carpenter
* @version 4.0.0
* @since LingPipe3.1
* @param <E> the type of object classified
*/
public class BernoulliClassifier<E>
implements JointClassifier<E>,
ObjectHandler<Classified<E>>,
Serializable {
static final long serialVersionUID = -7761909693358968780L;
private final MultivariateEstimator mCategoryDistribution;
private final FeatureExtractor<? super E> mFeatureExtractor;
private final double mActivationThreshold;
private final Set<String> mFeatureSet;
private final Map<String,ObjectToCounterMap<String>> mFeatureDistributionMap;
/**
* Construct a Bernoulli classifier with the specified feature
* extractor and the default feature activation threshold of 0.0.
*
* @param featureExtractor Feature extractor for classification.
*/
public BernoulliClassifier(FeatureExtractor<? super E> featureExtractor) {
this(featureExtractor,0.0);
}
/**
* Construct a Bernoulli classifier with the specified feature
* extractor and specified feature activation threshold.
*
* @param featureExtractor Feature extractor for classification.
* @param featureActivationThreshold The threshold for feature
* activation (see the class documentation).
*/
public BernoulliClassifier(FeatureExtractor<? super E> featureExtractor,
double featureActivationThreshold) {
this(new MultivariateEstimator(),
featureExtractor,
featureActivationThreshold,
new HashSet<String>(),
new HashMap<String,ObjectToCounterMap<String>>());
}
BernoulliClassifier(MultivariateEstimator catDistro,
FeatureExtractor<? super E> featureExtractor,
double activationThreshold,
Set<String> featureSet,
Map<String,ObjectToCounterMap<String>> featureDistributionMap) {
mCategoryDistribution = catDistro;
mFeatureExtractor = featureExtractor;
mActivationThreshold = activationThreshold;
mFeatureSet = featureSet;
mFeatureDistributionMap = featureDistributionMap;
}
/**
* Returns the feature activation threshold.
*
* @return The feature activation threshold for this classifier.
*/
public double featureActivationThreshold() {
return mActivationThreshold;
}
/**
* Return the feature extractor for this classifier.
*
* @return The feature extractor for this classifier.
*/
public FeatureExtractor<? super E> featureExtractor() {
return mFeatureExtractor;
}
/**
* Returns a copy of the list of categories for this classifier.
*
* @return The categories for this classifier.
*/
public String[] categories() {
String[] categories = new String[mCategoryDistribution.numDimensions()];
for (int i = 0; i < mCategoryDistribution.numDimensions(); ++i)
categories[i] = mCategoryDistribution.label(i);
return categories;
}
/**
* Handle the specified training classified object.
*
* @param classified Classified object to add to handle
* as training data.
*/
public void handle(Classified<E> classified) {
handle(classified.getObject(),
classified.getClassification());
}
/**
* Handle the specified training event, consisting of an input
* and its first-best classification.
*
* @param input Object whose classification result is being
* trained.
* @param classification Classification result for object.
*/
void handle(E input, Classification classification) {
String category = classification.bestCategory();
mCategoryDistribution.train(category,1L);
ObjectToCounterMap<String> categoryCounter
= mFeatureDistributionMap.get(category);
if (categoryCounter == null) {
categoryCounter = new ObjectToCounterMap<String>();
mFeatureDistributionMap.put(category,categoryCounter);
}
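// count each activated feature for this category and add it to the
// global set of known features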
for (String feature : activeFeatureSet(input)) {
categoryCounter.increment(feature);
mFeatureSet.add(feature);
}
}
/**
* Classify the specified input using this Bernoulli classifier.
* See the class documentation above for mathematical details.
*
* @param input Input to classify.
* @return Classification of the specified input.
*/
public JointClassification classify(E input) {
Set<String> activeFeatureSet = activeFeatureSet(input);
Set<String> inactiveFeatureSet = new HashSet<String>(mFeatureSet);
inactiveFeatureSet.removeAll(activeFeatureSet);
String[] activeFeatures
= activeFeatureSet.toArray(Strings.EMPTY_STRING_ARRAY);
String[] inactiveFeatures
= inactiveFeatureSet.toArray(Strings.EMPTY_STRING_ARRAY);
ObjectToDoubleMap<String> categoryToLog2P
= new ObjectToDoubleMap<String>();
int numCategories = mCategoryDistribution.numDimensions();
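// accumulate a log (base 2) joint estimate for each category in turn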
for (long i = 0; i < numCategories; ++i) {
String category = mCategoryDistribution.label(i);
double log2P = com.aliasi.util.Math.log2(mCategoryDistribution.probability(i));
double categoryCount
= mCategoryDistribution.getCount(i);
ObjectToCounterMap<String> categoryFeatureCounts
= mFeatureDistributionMap.get(category);
for (String activeFeature : activeFeatures) {
double featureCount = categoryFeatureCounts.getCount(activeFeature);
if (featureCount == 0.0) continue; // ignore unknown features
log2P += com.aliasi.util.Math.log2((featureCount+1.0) / (categoryCount+2.0));
}
for (String inactiveFeature : inactiveFeatures) {
double notFeatureCount
= categoryCount
- categoryFeatureCounts.getCount(inactiveFeature);
log2P += com.aliasi.util.Math.log2((notFeatureCount + 1.0) / (categoryCount + 2.0));
}
categoryToLog2P.set(category,log2P);
}
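// rank categories by score for the joint classification result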
String[] categories = new String[numCategories];
double[] log2Ps = new double[numCategories];
List<ScoredObject<String>> scoredObjectList
= categoryToLog2P.scoredObjectsOrderedByValueList();
for (int i = 0; i < numCategories; ++i) {
ScoredObject<String> so = scoredObjectList.get(i);
categories[i] = so.getObject();
log2Ps[i] = so.score();
}
return new JointClassification(categories,log2Ps);
}
Object writeReplace() {
return new Serializer<E>(this);
}
private Set<String> activeFeatureSet(E input) {
Set<String> activeFeatureSet = new HashSet<String>();
Map<String,? extends Number> featureMap
= mFeatureExtractor.features(input);
for (Map.Entry<String,? extends Number> entry : featureMap.entrySet()) {
String feature = entry.getKey();
Number val = entry.getValue();
if (val.doubleValue() > mActivationThreshold)
activeFeatureSet.add(feature);
}
return activeFeatureSet;
}
static class Serializer<F> extends AbstractExternalizable {
static final long serialVersionUID = 4803666611627400222L;
final BernoulliClassifier<F> mClassifier;
public Serializer(BernoulliClassifier<F> classifier) {
mClassifier = classifier;
}
public Serializer() {
this(null);
}
@Override
public void writeExternal(ObjectOutput objOut) throws IOException {
objOut.writeObject(mClassifier.mCategoryDistribution);
objOut.writeObject(mClassifier.mFeatureExtractor);
objOut.writeDouble(mClassifier.mActivationThreshold);
objOut.writeInt(mClassifier.mFeatureSet.size());
for (String feature : mClassifier.mFeatureSet)
objOut.writeUTF(feature);
objOut.writeInt(mClassifier.mFeatureDistributionMap.size());
for (Map.Entry<String,ObjectToCounterMap<String>> entry : mClassifier.mFeatureDistributionMap.entrySet()) {
objOut.writeUTF(entry.getKey());
ObjectToCounterMap<String> map = entry.getValue();
objOut.writeInt(map.size());
for (Map.Entry<String,Counter> entry2 : map.entrySet()) {
objOut.writeUTF(entry2.getKey());
objOut.writeInt(entry2.getValue().intValue());
}
}
}
@Override
public Object read(ObjectInput objIn)
throws ClassNotFoundException, IOException {
MultivariateEstimator estimator
= (MultivariateEstimator) objIn.readObject();
@SuppressWarnings("unchecked")
FeatureExtractor featureExtractor
= (FeatureExtractor) objIn.readObject();
double activationThreshold = objIn.readDouble();
int featureSetSize = objIn.readInt();
Set<String> featureSet = new HashSet<String>(2 * featureSetSize);
for (int i = 0; i < featureSetSize; ++i)
featureSet.add(objIn.readUTF());
int featureDistributionMapSize = objIn.readInt();
Map<String,ObjectToCounterMap<String>> featureDistributionMap
= new HashMap<String,ObjectToCounterMap<String>>(2*featureDistributionMapSize);
for (int i = 0; i < featureDistributionMapSize; ++i) {
String key = objIn.readUTF();
int mapSize = objIn.readInt();
ObjectToCounterMap<String> otc = new ObjectToCounterMap<String>();
featureDistributionMap.put(key,otc);
for (int j = 0; j < mapSize; ++j) {
String key2 = objIn.readUTF();
int count = objIn.readInt();
otc.set(key2,count);
}
}
return new BernoulliClassifier<F>(estimator,
featureExtractor,
activationThreshold,
featureSet,
featureDistributionMap);
}
}
}