edu.berkeley.nlp.classify.MaximumEntropyClassifier (from the berkeleyparser artifact)
The Berkeley parser analyzes the grammatical structure of natural language using probabilistic context-free grammars (PCFGs).
package edu.berkeley.nlp.classify;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import edu.berkeley.nlp.math.DifferentiableFunction;
import edu.berkeley.nlp.math.DoubleArrays;
import edu.berkeley.nlp.math.LBFGSMinimizer;
import edu.berkeley.nlp.math.SloppyMath;
import edu.berkeley.nlp.util.Counter;
import edu.berkeley.nlp.util.Indexer;
import edu.berkeley.nlp.util.Logger;
import edu.berkeley.nlp.util.Pair;
/**
* Maximum entropy classifier for assignment 2.
*
* @author Dan Klein
*/
public class MaximumEntropyClassifier<I, F, L> implements ProbabilisticClassifier<I, L>,
Serializable {
private static final long serialVersionUID = 1L;
/**
* Factory for training MaximumEntropyClassifiers.
*/
public static class Factory<I, F, L> implements ProbabilisticClassifierFactory<I, L> {
double sigma;
int iterations;
FeatureExtractor<I, F> featureExtractor;
public ProbabilisticClassifier<I, L> trainClassifier(
List<LabeledInstance<I, L>> trainingData) {
return trainClassifier(trainingData, true);
}
public ProbabilisticClassifier<I, L> trainClassifier(
List<LabeledInstance<I, L>> trainingData, boolean verbose) {
// build data encodings so the inner loops can be efficient
if (verbose) Logger.i().startTrack("Building encoding");
Encoding<F, L> encoding = buildEncoding(trainingData);
IndexLinearizer indexLinearizer = buildIndexLinearizer(encoding);
double[] initialWeights = buildInitialWeights(indexLinearizer);
EncodedDatum[] data = encodeData(trainingData, encoding);
if (verbose) Logger.i().endTrack();
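// Note on the weight layout (an assumption about IndexLinearizer, which is
// defined elsewhere in this package): each (feature, label) pair maps to a
// single weight, so the weight vector has getNumFeatures() * getNumLabels()
// entries, conventionally linearized as featureIndex * numLabels + labelIndex.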
// build a minimizer object
LBFGSMinimizer minimizer = new LBFGSMinimizer(iterations);
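// The 1e-4 passed to minimize() below is presumably its convergence
// tolerance; `iterations` caps the number of L-BFGS updates.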
// build the objective function for this data
DifferentiableFunction objective = new ObjectiveFunction<F, L>(encoding, data,
indexLinearizer, sigma);
// learn our voting weights
if (verbose) Logger.i().startTrack("Training weights");
double[] weights = minimizer.minimize(objective, initialWeights, 1e-4, verbose);
if (verbose) Logger.i().endTrack();
// build a classifier using these weights (and the data encodings)
return new MaximumEntropyClassifier<I, F, L>(weights, encoding,
indexLinearizer, featureExtractor);
}
private double[] buildInitialWeights(IndexLinearizer indexLinearizer) {
return DoubleArrays.constantArray(0.0, indexLinearizer.getNumLinearIndexes());
}
private IndexLinearizer buildIndexLinearizer(Encoding encoding) {
return new IndexLinearizer(encoding.getNumFeatures(), encoding.getNumLabels());
}
private Encoding<F, L> buildEncoding(List<LabeledInstance<I, L>> data) {
Indexer<F> featureIndexer = new Indexer<F>();
Indexer<L> labelIndexer = new Indexer<L>();
for (LabeledInstance<I, L> labeledInstance : data) {
L label = labeledInstance.getLabel();
Counter<F> features = featureExtractor.extractFeatures(labeledInstance
.getInput());
LabeledFeatureVector<F, L> labeledDatum = new BasicLabeledFeatureVector<F, L>(
label, features);
labelIndexer.getIndex(labeledDatum.getLabel());
for (F feature : labeledDatum.getFeatures().keySet()) {
featureIndexer.getIndex(feature);
}
}
return new Encoding<F, L>(featureIndexer, labelIndexer);
}
private EncodedDatum[] encodeData(List<LabeledInstance<I, L>> data,
Encoding<F, L> encoding) {
EncodedDatum[] encodedData = new EncodedDatum[data.size()];
for (int i = 0; i < data.size(); i++) {
LabeledInstance<I, L> labeledInstance = data.get(i);
L label = labeledInstance.getLabel();
Counter<F> features = featureExtractor.extractFeatures(labeledInstance
.getInput());
LabeledFeatureVector<F, L> labeledFeatureVector = new BasicLabeledFeatureVector<F, L>(
label, features);
encodedData[i] = EncodedDatum.encodeLabeledDatum(labeledFeatureVector,
encoding);
}
return encodedData;
}
/**
* Sigma controls the variance on the prior / penalty term: 1.0 is a
* reasonable value for large problems, and a bigger sigma means LESS
* smoothing. A sigma of zero is a special indicator that no smoothing is
* to be done. Iterations determines the maximum number of iterations the
* optimization code can take before stopping.
*/
public Factory(double sigma, int iterations, FeatureExtractor<I, F> featureExtractor) {
this.sigma = sigma;
this.iterations = iterations;
this.featureExtractor = featureExtractor;
}
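// Worked example of the sigma scaling: with sigma = 1.0, each weight w adds
// w * w / 2 to the objective (see the L2 penalty in ObjectiveFunction.calculate),
// while sigma = 10.0 adds w * w / 200, i.e. a 100x weaker penalty and thus
// less smoothing, matching the javadoc above.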
}
/**
* This is the MaximumEntropy objective function: the (negative) log
* conditional likelihood of the training data, possibly with a penalty for
* large weights. Note that this objective gets MINIMIZED, so it is the
* negative of the objective we normally think of.
*/
public static class ObjectiveFunction<F, L> implements DifferentiableFunction {
IndexLinearizer indexLinearizer;
Encoding<F, L> encoding;
EncodedDatum[] data;
double sigma;
double lastValue;
double[] lastDerivative;
double[] lastX;
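// In symbols, for weights x (restating the javadoc above and calculate() below):
//   J(x) = -sum_d [ a(d, y_d) - log sum_y exp a(d, y) ] + sum_i x_i^2 / (2 sigma^2)
// where a(d, y) = sum_{f active in d} x[linear(f, y)] * count(f, d) is label y's
// activation on datum d, and y_d is d's observed label.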
public int dimension() {
return indexLinearizer.getNumLinearIndexes();
}
public double valueAt(double[] x) {
ensureCache(x);
return lastValue;
}
public double[] derivativeAt(double[] x) {
ensureCache(x);
return lastDerivative;
}
private void ensureCache(double[] x) {
if (requiresUpdate(lastX, x)) {
Pair<Double, double[]> currentValueAndDerivative = calculate(x);
lastValue = currentValueAndDerivative.getFirst();
lastDerivative = currentValueAndDerivative.getSecond();
lastX = x;
}
}
private boolean requiresUpdate(double[] lastX, double[] x) {
if (lastX == null) return true;
for (int i = 0; i < x.length; i++) {
if (lastX[i] != x[i]) return true;
}
return false;
}
/**
* The most important part of the classifier learning process! This
* method determines, for the given weight vector x, what the (negative)
* log conditional likelihood of the data is, as well as the derivatives
* of that likelihood wrt each weight parameter.
*/
private Pair<Double, double[]> calculate(double[] x) {
double objective = 0.0;
double[] derivatives = DoubleArrays.constantArray(0.0, dimension());
double[] classActivations = new double[encoding.getNumLabels()];
double[] classPosteriors = new double[encoding.getNumLabels()];
for (EncodedDatum datum : data) {
// For each datum we get the activation for each class
// and then the posteriors
int numActiveFeatures = datum.getNumActiveFeatures();
for (int labelIndex = 0; labelIndex < encoding.getNumLabels(); ++labelIndex) {
double activation = 0.0;
for (int num = 0; num < numActiveFeatures; ++num) {
int featureIndex = datum.getFeatureIndex(num);
double featureCount = datum.getFeatureCount(num);
int linearFeatureIndex = indexLinearizer.getLinearIndex(
featureIndex, labelIndex);
activation += x[linearFeatureIndex] * featureCount;
}
classActivations[labelIndex] = activation;
}
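// logAdd computes log(sum_y exp(classActivations[y])), i.e. a numerically
// stable log-sum-exp (SloppyMath presumably subtracts the max activation
// before exponentiating to avoid overflow).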
double logSumActivation = SloppyMath.logAdd(classActivations);
int correctLabelIndex = datum.getLabelIndex();
// Log Prob
objective += (classActivations[correctLabelIndex] - logSumActivation);
// Class Posteriors
for (int labelIndex = 0; labelIndex < encoding.getNumLabels(); ++labelIndex) {
classPosteriors[labelIndex] = SloppyMath
.exp(classActivations[labelIndex] - logSumActivation);
}
// Derivative: Feature Expectations
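// Per active feature f and label y, the gradient contribution (before the
// final -1 scaling) is count(f, d) * ([y == correct] - P(y | d)): observed
// minus expected counts. The two terms are accumulated separately below.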
for (int num = 0; num < numActiveFeatures; ++num) {
int featureIndex = datum.getFeatureIndex(num);
int correctLinearFeatureIndex = indexLinearizer.getLinearIndex(
featureIndex, correctLabelIndex);
double featureCount = datum.getFeatureCount(num);
derivatives[correctLinearFeatureIndex] += featureCount;
for (int labelIndex = 0; labelIndex < encoding.getNumLabels(); ++labelIndex) {
int linearFeatureIndex = indexLinearizer.getLinearIndex(
featureIndex, labelIndex);
double classProb = classPosteriors[labelIndex];
derivatives[linearFeatureIndex] -= classProb * featureCount;
}
}
}
// Scale by -1 since we are minimizing negative log-likelihood
objective *= -1;
DoubleArrays.scale(derivatives, -1);
// L2 Penalty
for (int i = 0; i < x.length; ++i) {
double weight = x[i];
objective += (weight * weight) / (2 * sigma * sigma);
derivatives[i] += (weight) / (sigma * sigma);
}
return new Pair<Double, double[]>(objective, derivatives);
}
public ObjectiveFunction(Encoding<F, L> encoding, EncodedDatum[] data,
IndexLinearizer indexLinearizer, double sigma) {
this.indexLinearizer = indexLinearizer;
this.encoding = encoding;
this.data = data;
this.sigma = sigma;
}
public double[] unregularizedDerivativeAt(double[] x) {
// Not implemented; this class only provides the regularized derivative
// (derivativeAt above, which includes the L2 penalty term).
return null;
}
}
/**
* EncodedDatums are sparse representations of (labeled) feature count
* vectors for a given data point. Use getNumActiveFeatures() to see how
* many features have non-zero count in a datum. Then, use getFeatureIndex()
* and getFeatureCount() to retrieve the index and count of each non-zero
* feature. Use getLabelIndex() to get the label's index.
*/
public static class EncodedDatum {
public static <F, L> EncodedDatum encodeDatum(FeatureVector<F> featureVector,
Encoding<F, L> encoding) {
Counter<F> features = featureVector.getFeatures();
Counter<F> knownFeatures = new Counter<F>();
for (F feature : features.keySet()) {
if (encoding.getFeatureIndex(feature) < 0) continue;
knownFeatures.incrementCount(feature, features.getCount(feature));
}
int numActiveFeatures = knownFeatures.keySet().size();
int[] featureIndexes = new int[numActiveFeatures];
double[] featureCounts = new double[numActiveFeatures];
int i = 0;
for (F feature : knownFeatures.keySet()) {
int index = encoding.getFeatureIndex(feature);
double count = knownFeatures.getCount(feature);
featureIndexes[i] = index;
featureCounts[i] = count;
i++;
}
EncodedDatum encodedDatum = new EncodedDatum(-1, featureIndexes, featureCounts);
return encodedDatum;
}
public static <F, L> EncodedDatum encodeLabeledDatum(
LabeledFeatureVector<F, L> labeledDatum, Encoding<F, L> encoding) {
EncodedDatum encodedDatum = encodeDatum(labeledDatum, encoding);
encodedDatum.labelIndex = encoding.getLabelIndex(labeledDatum.getLabel());
return encodedDatum;
}
int labelIndex;
int[] featureIndexes;
double[] featureCounts;
public int getLabelIndex() {
return labelIndex;
}
public int getNumActiveFeatures() {
return featureCounts.length;
}
public int getFeatureIndex(int num) {
return featureIndexes[num];
}
public double getFeatureCount(int num) {
return featureCounts[num];
}
public EncodedDatum(int labelIndex, int[] featureIndexes, double[] featureCounts) {
this.labelIndex = labelIndex;
this.featureIndexes = featureIndexes;
this.featureCounts = featureCounts;
}
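// Example with hypothetical indices: a datum with counts {"fuzzy": 1.0,
// "claws": 2.0}, where the encoding indexes "fuzzy" as 3, "claws" as 7,
// and the label as 1, becomes featureIndexes = {3, 7}, featureCounts =
// {1.0, 2.0}, labelIndex = 1, and getNumActiveFeatures() returns 2.
// Features the encoding has never seen are silently dropped by encodeDatum.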
}
private double[] weights;
private Encoding<F, L> encoding;
private IndexLinearizer indexLinearizer;
private transient FeatureExtractor<I, F> featureExtractor;
/**
* Sets the feature extractor. Needed after deserialization, since the
* featureExtractor field is transient and is not restored automatically.
*/
public void setFeatureExtractor(FeatureExtractor<I, F> featureExtractor) {
this.featureExtractor = featureExtractor;
}
/**
* Calculate the log probabilities of each class, for the given datum
* (feature bundle). Note that the weighted votes (referred to as
* activations) are *almost* log probabilities, but need to be normalized.
*/
private static <F, L> double[] getLogProbabilities(EncodedDatum datum,
double[] weights, Encoding<F, L> encoding, IndexLinearizer indexLinearizer) {
double[] logProbabilities = new double[encoding.getNumLabels()];
for (int labelIndex = 0; labelIndex < encoding.getNumLabels(); ++labelIndex) {
for (int num = 0; num < datum.getNumActiveFeatures(); ++num) {
int featureIndex = datum.getFeatureIndex(num);
double featureCount = datum.getFeatureCount(num);
int linearFeatureIndex = indexLinearizer.getLinearIndex(featureIndex,
labelIndex);
logProbabilities[labelIndex] += weights[linearFeatureIndex] * featureCount;
}
}
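// Normalize: log P(y | x) = activation_y - log sum_{y'} exp(activation_{y'}).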
double logSumProb = SloppyMath.logAdd(logProbabilities);
for (int labelIndex = 0; labelIndex < encoding.getNumLabels(); ++labelIndex) {
logProbabilities[labelIndex] -= logSumProb;
}
return logProbabilities;
}
public Counter<L> getProbabilities(I input) {
FeatureVector<F> featureVector = new BasicFeatureVector<F>(featureExtractor
.extractFeatures(input));
return getProbabilities(featureVector);
}
private Counter<L> getProbabilities(FeatureVector<F> featureVector) {
EncodedDatum encodedDatum = EncodedDatum.encodeDatum(featureVector, encoding);
double[] logProbabilities = getLogProbabilities(encodedDatum, weights, encoding,
indexLinearizer);
return logProbabilityArrayToProbabilityCounter(logProbabilities);
}
private Counter<L> logProbabilityArrayToProbabilityCounter(double[] logProbabilities) {
Counter<L> probabilityCounter = new Counter<L>();
for (int labelIndex = 0; labelIndex < logProbabilities.length; labelIndex++) {
double logProbability = logProbabilities[labelIndex];
double probability = Math.exp(logProbability);
L label = encoding.getLabel(labelIndex);
probabilityCounter.setCount(label, probability);
}
return probabilityCounter;
}
public L getLabel(I input) {
return getProbabilities(input).argMax();
}
public MaximumEntropyClassifier(double[] weights, Encoding<F, L> encoding,
IndexLinearizer indexLinearizer, FeatureExtractor<I, F> featureExtractor) {
this.weights = weights;
this.encoding = encoding;
this.indexLinearizer = indexLinearizer;
this.featureExtractor = featureExtractor;
}
public static void main(String[] args) {
// Execution.init(args);
// create datums
LabeledInstance<String[], String> datum1 = new LabeledInstance<String[], String>(
"cat", new String[] { "fuzzy", "claws", "small" });
LabeledInstance<String[], String> datum2 = new LabeledInstance<String[], String>(
"bear", new String[] { "fuzzy", "claws", "big" });
LabeledInstance<String[], String> datum3 = new LabeledInstance<String[], String>(
"cat", new String[] { "claws", "medium" });
LabeledInstance<String[], String> datum4 = new LabeledInstance<String[], String>(
"cat", new String[] { "claws", "small" });
// create training set
List<LabeledInstance<String[], String>> trainingData = new ArrayList<LabeledInstance<String[], String>>();
trainingData.add(datum1);
trainingData.add(datum2);
trainingData.add(datum3);
// create test set
List<LabeledInstance<String[], String>> testData = new ArrayList<LabeledInstance<String[], String>>();
testData.add(datum4);
// build classifier
FeatureExtractor<String[], String> featureExtractor = new FeatureExtractor<String[], String>() {
private static final long serialVersionUID = 8296036312980792350L;
public Counter<String> extractFeatures(String[] featureArray) {
return new Counter<String>(Arrays.asList(featureArray));
}
};
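// Assumption about Counter's collection constructor: it increments the count
// of each element once per occurrence, so extractFeatures yields occurrence
// counts of the strings in featureArray (each 1.0 here, since no feature
// repeats within a datum).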
MaximumEntropyClassifier.Factory<String[], String, String> maximumEntropyClassifierFactory = new MaximumEntropyClassifier.Factory<String[], String, String>(
1.0, 20, featureExtractor);
ProbabilisticClassifier<String[], String> maximumEntropyClassifier = maximumEntropyClassifierFactory
.trainClassifier(trainingData);
System.out.println("Probabilities on test instance: "
+ maximumEntropyClassifier.getProbabilities(datum4.getInput()));
// Execution.finish();
}
}