edu.berkeley.nlp.classify.MaximumEntropyClassifier Maven / Gradle / Ivy

The Berkeley parser analyzes the grammatical structure of natural language using probabilistic context-free grammars (PCFGs).
package edu.berkeley.nlp.classify;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import edu.berkeley.nlp.math.DifferentiableFunction;
import edu.berkeley.nlp.math.DoubleArrays;
import edu.berkeley.nlp.math.LBFGSMinimizer;
import edu.berkeley.nlp.math.SloppyMath;
import edu.berkeley.nlp.util.Counter;
import edu.berkeley.nlp.util.Indexer;
import edu.berkeley.nlp.util.Logger;
import edu.berkeley.nlp.util.Pair;

/**
 * Maximum entropy classifier for assignment 2.
 * 
 * @author Dan Klein
 */
public class MaximumEntropyClassifier<I, F, L> implements
		ProbabilisticClassifier<I, L>, Serializable {

	private static final long serialVersionUID = 1L;

	/**
	 * Factory for training MaximumEntropyClassifiers.
	 */
	public static class Factory<I, F, L> implements
			ProbabilisticClassifierFactory<I, L> {

		double sigma;
		int iterations;
		FeatureExtractor<I, F> featureExtractor;

		public ProbabilisticClassifier<I, L> trainClassifier(
				List<LabeledInstance<I, L>> trainingData) {
			return trainClassifier(trainingData, true);
		}

		public ProbabilisticClassifier<I, L> trainClassifier(
				List<LabeledInstance<I, L>> trainingData, boolean verbose) {
			// build data encodings so the inner loops can be efficient
			if (verbose) Logger.i().startTrack("Building encoding");
			Encoding<F, L> encoding = buildEncoding(trainingData);
			IndexLinearizer indexLinearizer = buildIndexLinearizer(encoding);
			double[] initialWeights = buildInitialWeights(indexLinearizer);
			EncodedDatum[] data = encodeData(trainingData, encoding);
			if (verbose) Logger.i().endTrack();

			// build a minimizer object
			LBFGSMinimizer minimizer = new LBFGSMinimizer(iterations);
			// build the objective function for this data
			DifferentiableFunction objective = new ObjectiveFunction<F, L>(encoding,
					data, indexLinearizer, sigma);

			// learn our voting weights
			if (verbose) Logger.i().startTrack("Training weights");
			double[] weights = minimizer.minimize(objective, initialWeights, 1e-4, verbose);
			if (verbose) Logger.i().endTrack();

			// build a classifier using these weights (and the data encodings)
			return new MaximumEntropyClassifier<I, F, L>(weights, encoding,
					indexLinearizer, featureExtractor);
		}

		private double[] buildInitialWeights(IndexLinearizer indexLinearizer) {
			return DoubleArrays.constantArray(0.0, indexLinearizer.getNumLinearIndexes());
		}

		private IndexLinearizer buildIndexLinearizer(Encoding<F, L> encoding) {
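			// one weight per (feature, label) pair: the IndexLinearizer maps
			// each (featureIndex, labelIndex) pair to a single position in the
			// flat weight vector used by the optimizer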
			return new IndexLinearizer(encoding.getNumFeatures(), encoding.getNumLabels());
		}

		private Encoding<F, L> buildEncoding(List<LabeledInstance<I, L>> data) {
			Indexer<F> featureIndexer = new Indexer<F>();
			Indexer<L> labelIndexer = new Indexer<L>();
			for (LabeledInstance<I, L> labeledInstance : data) {
				L label = labeledInstance.getLabel();
				Counter<F> features = featureExtractor.extractFeatures(labeledInstance
						.getInput());
				LabeledFeatureVector<F, L> labeledDatum = new BasicLabeledFeatureVector<F, L>(
						label, features);
				labelIndexer.getIndex(labeledDatum.getLabel());
				for (F feature : labeledDatum.getFeatures().keySet()) {
					featureIndexer.getIndex(feature);
				}
			}
			return new Encoding<F, L>(featureIndexer, labelIndexer);
		}

		private EncodedDatum[] encodeData(List<LabeledInstance<I, L>> data,
				Encoding<F, L> encoding) {
			EncodedDatum[] encodedData = new EncodedDatum[data.size()];
			for (int i = 0; i < data.size(); i++) {
				LabeledInstance<I, L> labeledInstance = data.get(i);
				L label = labeledInstance.getLabel();
				Counter<F> features = featureExtractor.extractFeatures(labeledInstance
						.getInput());
				LabeledFeatureVector<F, L> labeledFeatureVector = new BasicLabeledFeatureVector<F, L>(
						label, features);
				encodedData[i] = EncodedDatum.encodeLabeledDatum(labeledFeatureVector,
						encoding);
			}
			return encodedData;
		}

		/**
		 * Sigma controls the variance on the prior / penalty term. 1.0 is a
		 * reasonable value for large problems; bigger sigma means LESS
		 * smoothing. Zero sigma is a special indicator that no smoothing is to
		 * be done. 
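		 * (Worked example: with sigma = 1.0, a single weight of 2.0 adds
		 * 2.0 * 2.0 / (2 * 1.0 * 1.0) = 2.0 to the minimized objective; see
		 * the L2 penalty in ObjectiveFunction.calculate below.)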

		 * Iterations determines the maximum number of iterations the
		 * optimization code can take before stopping.
		 */
		public Factory(double sigma, int iterations,
				FeatureExtractor<I, F> featureExtractor) {
			this.sigma = sigma;
			this.iterations = iterations;
			this.featureExtractor = featureExtractor;
		}
	}

	/**
	 * This is the MaximumEntropy objective function: the (negative) log
	 * conditional likelihood of the training data, possibly with a penalty for
	 * large weights. Note that this objective gets MINIMIZED, so it is the
	 * negative of the objective we normally think of.
	 */
	public static class ObjectiveFunction<F, L> implements DifferentiableFunction {

		IndexLinearizer indexLinearizer;
		Encoding<F, L> encoding;
		EncodedDatum[] data;

		double sigma;

		double lastValue;
		double[] lastDerivative;
		double[] lastX;

		public int dimension() {
			return indexLinearizer.getNumLinearIndexes();
		}

		public double valueAt(double[] x) {
			ensureCache(x);
			return lastValue;
		}

		public double[] derivativeAt(double[] x) {
			ensureCache(x);
			return lastDerivative;
		}

		private void ensureCache(double[] x) {
			if (requiresUpdate(lastX, x)) {
				Pair<Double, double[]> currentValueAndDerivative = calculate(x);
				lastValue = currentValueAndDerivative.getFirst();
				lastDerivative = currentValueAndDerivative.getSecond();
				lastX = x;
			}
		}

		private boolean requiresUpdate(double[] lastX, double[] x) {
			if (lastX == null) return true;
			for (int i = 0; i < x.length; i++) {
				if (lastX[i] != x[i]) return true;
			}
			return false;
		}

		/**
		 * The most important part of the classifier learning process! This
		 * method determines, for the given weight vector x, what the (negative)
		 * log conditional likelihood of the data is, as well as the derivatives
		 * of that likelihood wrt each weight parameter.
		 */
		private Pair<Double, double[]> calculate(double[] x) {
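			// In symbols: with activations a_y = sum_f x[f,y] * count(f), each
			// datum contributes log P(y* | datum) = a_{y*} - logSumExp_y(a_y)
			// to the conditional log-likelihood. After negation, the derivative
			// wrt x[f,y] is the expected count of f under the model posterior
			// minus the empirical count of f when y is the correct label, plus
			// the L2 term x[f,y] / sigma^2 added at the end.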
			double objective = 0.0;
			double[] derivatives = DoubleArrays.constantArray(0.0, dimension());
			double[] classActivations = new double[encoding.getNumLabels()];
			double[] classPosteriors = new double[encoding.getNumLabels()];
			for (EncodedDatum datum : data) {
				// For each datum we get the activation for each class
				// and then the posteriors
				int numActiveFeatures = datum.getNumActiveFeatures();
				for (int labelIndex = 0; labelIndex < encoding.getNumLabels(); ++labelIndex) {
					double activation = 0.0;
					for (int num = 0; num < numActiveFeatures; ++num) {
						int featureIndex = datum.getFeatureIndex(num);
						double featureCount = datum.getFeatureCount(num);
						int linearFeatureIndex = indexLinearizer.getLinearIndex(
								featureIndex, labelIndex);
						activation += x[linearFeatureIndex] * featureCount;
					}
					classActivations[labelIndex] = activation;
				}
				double logSumActivation = SloppyMath.logAdd(classActivations);
				int correctLabelIndex = datum.getLabelIndex();
				// Log Prob
				objective += (classActivations[correctLabelIndex] - logSumActivation);
				// Class Posteriors
				for (int labelIndex = 0; labelIndex < encoding.getNumLabels(); ++labelIndex) {
					classPosteriors[labelIndex] = SloppyMath
							.exp(classActivations[labelIndex] - logSumActivation);
				}
				// Derivative: Feature Expectations
				for (int num = 0; num < numActiveFeatures; ++num) {
					int featureIndex = datum.getFeatureIndex(num);
					int correctLinearFeatureIndex = indexLinearizer.getLinearIndex(
							featureIndex, correctLabelIndex);
					double featureCount = datum.getFeatureCount(num);
					derivatives[correctLinearFeatureIndex] += featureCount;
					for (int labelIndex = 0; labelIndex < encoding.getNumLabels(); ++labelIndex) {
						int linearFeatureIndex = indexLinearizer.getLinearIndex(
								featureIndex, labelIndex);
						double classProb = classPosteriors[labelIndex];
						derivatives[linearFeatureIndex] -= classProb * featureCount;
					}
				}
			}
			// Scale by -1 since we are minimizing negative log-likelihood
			objective *= -1;
			DoubleArrays.scale(derivatives, -1);
			// L2 Penalty
			for (int i = 0; i < x.length; ++i) {
				double weight = x[i];
				objective += (weight * weight) / (2 * sigma * sigma);
				derivatives[i] += (weight) / (sigma * sigma);
			}
			return new Pair<Double, double[]>(objective, derivatives);
		}

		public ObjectiveFunction(Encoding<F, L> encoding, EncodedDatum[] data,
				IndexLinearizer indexLinearizer, double sigma) {
			this.indexLinearizer = indexLinearizer;
			this.encoding = encoding;
			this.data = data;
			this.sigma = sigma;
		}

		public double[] unregularizedDerivativeAt(double[] x) {
			// TODO Auto-generated method stub
			return null;
		}

	}

	/**
	 * EncodedDatums are sparse representations of (labeled) feature count
	 * vectors for a given data point. Use getNumActiveFeatures() to see how
	 * many features have non-zero count in a datum. Then, use getFeatureIndex()
	 * and getFeatureCount() to retrieve the index and count of each non-zero
	 * feature. Use getLabelIndex() to get the label's number.
	 */
	public static class EncodedDatum {

		public static <F, L> EncodedDatum encodeDatum(FeatureVector<F> featureVector,
				Encoding<F, L> encoding) {
			Counter<F> features = featureVector.getFeatures();
			Counter<F> knownFeatures = new Counter<F>();
			for (F feature : features.keySet()) {
				// drop features not seen at training time
				if (encoding.getFeatureIndex(feature) < 0) continue;
				knownFeatures.incrementCount(feature, features.getCount(feature));
			}
			int numActiveFeatures = knownFeatures.keySet().size();
			int[] featureIndexes = new int[numActiveFeatures];
			double[] featureCounts = new double[numActiveFeatures];
			int i = 0;
			for (F feature : knownFeatures.keySet()) {
				int index = encoding.getFeatureIndex(feature);
				double count = knownFeatures.getCount(feature);
				featureIndexes[i] = index;
				featureCounts[i] = count;
				i++;
			}
			EncodedDatum encodedDatum = new EncodedDatum(-1, featureIndexes,
					featureCounts);
			return encodedDatum;
		}

		public static <F, L> EncodedDatum encodeLabeledDatum(
				LabeledFeatureVector<F, L> labeledDatum, Encoding<F, L> encoding) {
			EncodedDatum encodedDatum = encodeDatum(labeledDatum, encoding);
			encodedDatum.labelIndex = encoding.getLabelIndex(labeledDatum.getLabel());
			return encodedDatum;
		}

		int labelIndex;
		int[] featureIndexes;
		double[] featureCounts;

		public int getLabelIndex() {
			return labelIndex;
		}

		public int getNumActiveFeatures() {
			return featureCounts.length;
		}

		public int getFeatureIndex(int num) {
			return featureIndexes[num];
		}

		public double getFeatureCount(int num) {
			return featureCounts[num];
		}

		public EncodedDatum(int labelIndex, int[] featureIndexes,
				double[] featureCounts) {
			this.labelIndex = labelIndex;
			this.featureIndexes = featureIndexes;
			this.featureCounts = featureCounts;
		}
	}

	private double[] weights;
	private Encoding<F, L> encoding;
	private IndexLinearizer indexLinearizer;
	private transient FeatureExtractor<I, F> featureExtractor;

	public void setFeatureExtractor(FeatureExtractor<I, F> featureExtractor) {
		this.featureExtractor = featureExtractor;
	}

	/**
	 * Calculate the log probabilities of each class, for the given datum
	 * (feature bundle). Note that the weighted votes (referred to as
	 * activations) are *almost* log probabilities, but need to be normalized.
	 */
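	// Concretely: logProb[y] = activation(y) - logAdd over all labels'
	// activations, i.e. a softmax in log space; SloppyMath.logAdd supplies
	// the log-sum-exp normalizer.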
	private static <F, L> double[] getLogProbabilities(EncodedDatum datum,
			double[] weights, Encoding<F, L> encoding,
			IndexLinearizer indexLinearizer) {
		double[] logProbabilities = new double[encoding.getNumLabels()];
		for (int labelIndex = 0; labelIndex < encoding.getNumLabels(); ++labelIndex) {
			for (int num = 0; num < datum.getNumActiveFeatures(); ++num) {
				int featureIndex = datum.getFeatureIndex(num);
				double featureCount = datum.getFeatureCount(num);
				int linearFeatureIndex = indexLinearizer.getLinearIndex(featureIndex,
						labelIndex);
				logProbabilities[labelIndex] += weights[linearFeatureIndex]
						* featureCount;
			}
		}
		double logSumProb = SloppyMath.logAdd(logProbabilities);
		for (int labelIndex = 0; labelIndex < encoding.getNumLabels(); ++labelIndex) {
			logProbabilities[labelIndex] -= logSumProb;
		}
		return logProbabilities;
	}

	public Counter<L> getProbabilities(I input) {
		FeatureVector<F> featureVector = new BasicFeatureVector<F>(featureExtractor
				.extractFeatures(input));
		return getProbabilities(featureVector);
	}

	private Counter<L> getProbabilities(FeatureVector<F> featureVector) {
		EncodedDatum encodedDatum = EncodedDatum.encodeDatum(featureVector,
				encoding);
		double[] logProbabilities = getLogProbabilities(encodedDatum, weights,
				encoding, indexLinearizer);
		return logProbabiltyArrayToProbabiltyCounter(logProbabilities);
	}

	private Counter<L> logProbabiltyArrayToProbabiltyCounter(
			double[] logProbabilities) {
		Counter<L> probabiltyCounter = new Counter<L>();
		for (int labelIndex = 0; labelIndex < logProbabilities.length; labelIndex++) {
			double logProbability = logProbabilities[labelIndex];
			double probability = Math.exp(logProbability);
			L label = encoding.getLabel(labelIndex);
			probabiltyCounter.setCount(label, probability);
		}
		return probabiltyCounter;
	}

	public L getLabel(I input) {
		return getProbabilities(input).argMax();
	}

	public MaximumEntropyClassifier(double[] weights, Encoding<F, L> encoding,
			IndexLinearizer indexLinearizer, FeatureExtractor<I, F> featureExtractor) {
		this.weights = weights;
		this.encoding = encoding;
		this.indexLinearizer = indexLinearizer;
		this.featureExtractor = featureExtractor;
	}

	public static void main(String[] args) {
		// Execution.init(args);

		// create datums
		LabeledInstance<String[], String> datum1 = new LabeledInstance<String[], String>(
				"cat", new String[] { "fuzzy", "claws", "small" });
		LabeledInstance<String[], String> datum2 = new LabeledInstance<String[], String>(
				"bear", new String[] { "fuzzy", "claws", "big" });
		LabeledInstance<String[], String> datum3 = new LabeledInstance<String[], String>(
				"cat", new String[] { "claws", "medium" });
		LabeledInstance<String[], String> datum4 = new LabeledInstance<String[], String>(
				"cat", new String[] { "claws", "small" });

		// create training set
		List<LabeledInstance<String[], String>> trainingData = new ArrayList<LabeledInstance<String[], String>>();
		trainingData.add(datum1);
		trainingData.add(datum2);
		trainingData.add(datum3);

		// create test set
		List<LabeledInstance<String[], String>> testData = new ArrayList<LabeledInstance<String[], String>>();
		testData.add(datum4);

		// build classifier
		FeatureExtractor<String[], String> featureExtractor = new FeatureExtractor<String[], String>() {

			private static final long serialVersionUID = 8296036312980792350L;

			public Counter<String> extractFeatures(String[] featureArray) {
				return new Counter<String>(Arrays.asList(featureArray));
			}
		};
		MaximumEntropyClassifier.Factory<String[], String, String> maximumEntropyClassifierFactory = new MaximumEntropyClassifier.Factory<String[], String, String>(
				1.0, 20, featureExtractor);
		ProbabilisticClassifier<String[], String> maximumEntropyClassifier = maximumEntropyClassifierFactory
				.trainClassifier(trainingData);
		System.out.println("Probabilities on test instance: "
				+ maximumEntropyClassifier.getProbabilities(datum4.getInput()));
		// Execution.finish();
	}
}
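For reference, here is a minimal sketch of how a different feature extractor could be plugged into this classifier. The SuffixDemo class below is hypothetical (it is not part of the Berkeley sources); it assumes only the FeatureExtractor, LabeledInstance, Counter, and Factory APIs exercised in main() above.

import java.util.ArrayList;
import java.util.List;

import edu.berkeley.nlp.classify.*;
import edu.berkeley.nlp.util.Counter;

// Hypothetical demo class, not part of the Berkeley code.
class SuffixDemo {
	public static void main(String[] args) {
		// fire one feature per word suffix of length 1..3
		FeatureExtractor<String, String> suffixExtractor = new FeatureExtractor<String, String>() {
			private static final long serialVersionUID = 1L;

			public Counter<String> extractFeatures(String word) {
				Counter<String> features = new Counter<String>();
				for (int len = 1; len <= 3 && len <= word.length(); len++) {
					features.incrementCount("SUFFIX=" + word.substring(word.length() - len), 1.0);
				}
				return features;
			}
		};
		List<LabeledInstance<String, String>> train = new ArrayList<LabeledInstance<String, String>>();
		train.add(new LabeledInstance<String, String>("verb", "running"));
		train.add(new LabeledInstance<String, String>("verb", "walking"));
		train.add(new LabeledInstance<String, String>("noun", "table"));
		ProbabilisticClassifier<String, String> classifier = new MaximumEntropyClassifier.Factory<String, String, String>(
				1.0, 20, suffixExtractor).trainClassifier(train);
		System.out.println(classifier.getProbabilities("jumping"));
	}
}

Nothing about the classifier itself changes here; only the FeatureExtractor varies, which is the point of the I/F/L generic design.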




