cc.mallet.classify.MaxEnt

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org.  For further
information, see the file `LICENSE' included with this distribution. */

package cc.mallet.classify;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintWriter;
import java.io.Serializable;
import java.io.PrintStream;

import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelVector;
import cc.mallet.types.MatrixOps;
import cc.mallet.types.RankedFeatureVector;

/**
 * Maximum Entropy (AKA Multivariate Logistic Regression) classifier.
 * @author Andrew McCallum [email protected]
 */

public class MaxEnt extends Classifier implements Serializable
{
	protected double [] parameters;										// indexed by <labelIndex, featureIndex>
	protected int defaultFeatureIndex;
	protected FeatureSelection featureSelection;
	protected FeatureSelection[] perClassFeatureSelection;

	// The default feature is always the feature with highest index
	public MaxEnt (Pipe dataPipe,
			double[] parameters,
			FeatureSelection featureSelection,
			FeatureSelection[] perClassFeatureSelection)
	{
		super (dataPipe);
		assert (featureSelection == null || perClassFeatureSelection == null);
		if (parameters != null)
			this.parameters = parameters;
		else
			this.parameters = new double[getNumParameters(dataPipe)];
		this.featureSelection = featureSelection;
		this.perClassFeatureSelection = perClassFeatureSelection;
		this.defaultFeatureIndex = dataPipe.getDataAlphabet().size();
//		assert (parameters.getNumCols() == defaultFeatureIndex+1);
	}

	public MaxEnt (Pipe dataPipe,	double[] parameters,
			FeatureSelection featureSelection)
	{
		this (dataPipe, parameters, featureSelection, null);
	}

	public MaxEnt (Pipe dataPipe, double[] parameters,
			FeatureSelection[] perClassFeatureSelection)
	{
		this (dataPipe, parameters, null, perClassFeatureSelection);
	}

	public MaxEnt (Pipe dataPipe,	double[] parameters)
	{
		this (dataPipe, parameters, null, null);
	}

	public double[] getParameters () {
		return parameters;
	}

	public int getNumParameters () {
		assert (this.instancePipe.getDataAlphabet() != null);
		assert (this.instancePipe.getTargetAlphabet() != null);
		return MaxEnt.getNumParameters(this.instancePipe);
	}

	public static int getNumParameters (Pipe instancePipe) {
		return (instancePipe.getDataAlphabet().size() + 1) * instancePipe.getTargetAlphabet().size();
	}
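
	// Parameter layout (as the indexing in setParameter and
	// getUnnormalizedClassificationScores shows): parameters is a flattened
	// numLabels x (numFeatures + 1) matrix in row-major order, so the weight
	// for label li and feature fi lives at li * (numFeatures + 1) + fi, and
	// each row's last column (the "default feature") holds that label's bias.
	// For example, 3 features and 2 labels give (3 + 1) * 2 = 8 parameters.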

	public void setParameters(double[] parameters){
		this.parameters = parameters;
	}

	public void setParameter (int classIndex, int featureIndex, double value)
	{
		parameters[classIndex*(getAlphabet().size()+1) + featureIndex] = value;
	}

	public FeatureSelection getFeatureSelection() {
		return featureSelection;
	}

	public MaxEnt setFeatureSelection (FeatureSelection fs) {
		featureSelection = fs;
		return this;
	}


	public FeatureSelection[] getPerClassFeatureSelection(){
		return perClassFeatureSelection;
	}

	public MaxEnt setPerClassFeatureSelection (FeatureSelection[] fss){
		this.perClassFeatureSelection = fss;
		return this;
	}


	public int getDefaultFeatureIndex(){
		return defaultFeatureIndex;
	}

	public void setDefaultFeatureIndex(int defaultFeatureIndex){
		this.defaultFeatureIndex = defaultFeatureIndex;
	}


	public void getUnnormalizedClassificationScores (Instance instance, double[] scores)
	{
		// Use the feature count fixed at training time; getAlphabet().size() + 1
		// would throw an ArrayIndexOutOfBoundsException if the alphabet has
		// grown since training.
		int numFeatures = this.defaultFeatureIndex + 1;

		int numLabels = getLabelAlphabet().size();
		assert (scores.length == numLabels);
		FeatureVector fv = (FeatureVector) instance.getData ();
		// Make sure the feature vector's feature dictionary matches
		// what we are expecting from our data pipe (and thus our notion
		// of feature probabilities).
		assert (fv.getAlphabet ()
				== this.instancePipe.getDataAlphabet ());
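
		// The loop below computes, for each label li,
		//   scores[li] = w[li][defaultFeatureIndex] + sum over active features f of w[li][f] * fv[f]
		// where w is the flattened parameter matrix; the dot product is
		// restricted to the global or per-class FeatureSelection when one is set.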

		// Include the feature weights according to each label
		for (int li = 0; li < numLabels; li++) {
			scores[li] = parameters[li*numFeatures + defaultFeatureIndex]
			                        + MatrixOps.rowDotProduct (parameters, numFeatures,
			                        		li, fv,
			                        		defaultFeatureIndex,
			                        		(perClassFeatureSelection == null
			                        				? featureSelection
			                        						: perClassFeatureSelection[li]));
		}
	}

	public void getClassificationScores (Instance instance, double[] scores)
	{
		getUnnormalizedClassificationScores(instance, scores);
		// Move scores to a range where exp() is accurate, and normalize
		int numLabels = getLabelAlphabet().size();
		double max = MatrixOps.max (scores);
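		// Subtracting the max before exponentiating is the standard softmax
		// stabilization: exp(score - max) cannot overflow, and the constant
		// factor exp(-max) cancels during normalization, leaving the
		// resulting probabilities unchanged.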
		double sum = 0;
		for (int li = 0; li < numLabels; li++)
			sum += (scores[li] = Math.exp (scores[li] - max));
		for (int li = 0; li < numLabels; li++) {
			scores[li] /= sum;
			// xxxNaN assert (!Double.isNaN(scores[li]));
		}
	}

	//modified by Limin Yao, to deal with decreasing the peak of some labels
	public void getClassificationScoresWithTemperature (Instance instance, double temperature, double[] scores)
	{
		getUnnormalizedClassificationScores(instance, scores);

		//scores should be divided by temperature, scores are sum of weighted features
		MatrixOps.timesEquals(scores, 1/temperature);
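		// A temperature above 1 flattens the resulting distribution (lowering
		// the peak, per the note above); a temperature below 1 sharpens it; at
		// temperature 1 this method reduces to getClassificationScores.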

		// Move scores to a range where exp() is accurate, and normalize
		int numLabels = getLabelAlphabet().size();
		double max = MatrixOps.max (scores);
		double sum = 0;
		for (int li = 0; li < numLabels; li++)
			sum += (scores[li] = Math.exp (scores[li] - max));
		for (int li = 0; li < numLabels; li++) {
			scores[li] /= sum;
			// xxxNaN assert (!Double.isNaN(scores[li]));
		}
	}

	// modified by Limin Yao; uses the standard classification scores
	// (see getClassificationScoresWithTemperature for a temperature-scaled variant)
	public Classification classify (Instance instance)
	{
		int numClasses = getLabelAlphabet().size();
		double[] scores = new double[numClasses];
		getClassificationScores (instance, scores);
		// Create and return a Classification object
		return new Classification (instance, this,
				new LabelVector (getLabelAlphabet(),
						scores));
	}

	public void print () {
		print(System.out);
	}

	@Override
	public void print(PrintWriter out) {
		final Alphabet dict = getAlphabet();
		final LabelAlphabet labelDict = getLabelAlphabet();

		int numFeatures = dict.size() + 1;
		int numLabels = labelDict.size();

		// Include the feature weights according to each label
		for (int li = 0; li < numLabels; li++) {
			out.println ("FEATURES FOR CLASS "+labelDict.lookupObject (li));
			out.println ("  "+parameters [li*numFeatures + defaultFeatureIndex]);
			for (int i = 0; i < defaultFeatureIndex; i++) {
				Object name = dict.lookupObject (i);
				double weight = parameters [li*numFeatures + i];
				out.println (" "+name+" "+weight);
			}
		}
	}

	public void print (PrintStream out) {
		// Flush explicitly: the wrapping PrintWriter buffers its output and
		// would otherwise drop it when the writer is discarded unflushed.
		PrintWriter writer = new PrintWriter (out);
		print (writer);
		writer.flush ();
	}

	//printRank, added by Limin Yao
	public void printRank (PrintWriter out)
	{
		final Alphabet dict = getAlphabet();
		final LabelAlphabet labelDict = getLabelAlphabet();

		int numFeatures = dict.size() + 1;
		int numLabels = labelDict.size();
		// Include the feature weights according to each label
		RankedFeatureVector rfv;
		double[] weights = new double[numFeatures-1]; // do not deal with the default feature
		for (int li = 0; li < numLabels; li++) {
			out.print ("FEATURES FOR CLASS "+labelDict.lookupObject (li) + " ");
			for (int i = 0; i < defaultFeatureIndex; i++) {
				double weight = parameters [li*numFeatures + i];
				weights[i] = weight;
			}
			rfv = new RankedFeatureVector(dict,weights);
			rfv.printByRank(out);
			out.println ("  "+parameters [li*numFeatures + defaultFeatureIndex] + " ");
		}
	}

	public void printExtremeFeatures (PrintWriter out,int num)
	{
		final Alphabet dict = getAlphabet();
		final LabelAlphabet labelDict = getLabelAlphabet();

		int numFeatures = dict.size() + 1;
		int numLabels = labelDict.size();

		// Include the feature weights according to each label
		RankedFeatureVector rfv;
		double[] weights = new double[numFeatures-1]; // do not deal with the default feature
		for (int li = 0; li < numLabels; li++) {
			out.print ("FEATURES FOR CLASS "+labelDict.lookupObject (li) + " ");
			for (int i = 0; i < defaultFeatureIndex; i++) {
				weights[i] = parameters [li*numFeatures + i];
			}
			rfv = new RankedFeatureVector(dict,weights);
			rfv.printTopK(out,num);
			out.print ("  "+parameters [li*numFeatures + defaultFeatureIndex] + " ");
			rfv.printLowerK(out, num);
			out.println();
		}
	}

	private static final long serialVersionUID = 1;
	private static final int CURRENT_SERIAL_VERSION = 1;
	static final int NULL_INTEGER = -1;
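
	// Serialization convention: NULL_INTEGER (-1) is written in place of a
	// count or presence flag to mark a null featureSelection or
	// perClassFeatureSelection; otherwise a 1 (present) or the array length is
	// written, followed by the object(s) themselves.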

	private void writeObject(ObjectOutputStream out) throws IOException
	{
		out.writeInt(CURRENT_SERIAL_VERSION);
		out.writeObject(getInstancePipe());
		int np = parameters.length;
		out.writeInt(np);
		for (int p = 0; p < np; p++)
			out.writeDouble(parameters[p]);
		out.writeInt(defaultFeatureIndex);
		if (featureSelection == null)
			out.writeInt(NULL_INTEGER);
		else
		{
			out.writeInt(1);
			out.writeObject(featureSelection);
		}
		if (perClassFeatureSelection == null)
			out.writeInt(NULL_INTEGER);
		else
		{
			out.writeInt(perClassFeatureSelection.length);
			for (int i = 0; i < perClassFeatureSelection.length; i++)
				if (perClassFeatureSelection[i] == null)
					out.writeInt(NULL_INTEGER);
				else
				{
					out.writeInt(1);
					out.writeObject(perClassFeatureSelection[i]);
				}
		}
	}

	private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
		int version = in.readInt();
		if (version != CURRENT_SERIAL_VERSION)
			throw new ClassNotFoundException("Mismatched MaxEnt versions: wanted " +
					CURRENT_SERIAL_VERSION + ", got " +
					version);
		instancePipe = (Pipe) in.readObject();
		int np = in.readInt();
		parameters = new double[np];
		for (int p = 0; p < np; p++)
			parameters[p] = in.readDouble();
		defaultFeatureIndex = in.readInt();
		int opt = in.readInt();
		if (opt == 1)
			featureSelection = (FeatureSelection)in.readObject();
		int nfs = in.readInt();
		if (nfs >= 0)
		{
			perClassFeatureSelection = new FeatureSelection[nfs];
			for (int i = 0; i < nfs; i++)
			{
				opt = in.readInt();
				if (opt == 1)
					perClassFeatureSelection[i] = (FeatureSelection)in.readObject();
			}
		}
	}
}
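
A minimal usage sketch, separate from the file above: it trains a MaxEnt model with MALLET's MaxEntTrainer on two toy documents and classifies a third. The pipe chain and the driver class MaxEntExample are simplifying assumptions; a real pipeline would add tokenization options, feature pruning, and so on.

import cc.mallet.classify.Classification;
import cc.mallet.classify.MaxEnt;
import cc.mallet.classify.MaxEntTrainer;
import cc.mallet.pipe.CharSequence2TokenSequence;
import cc.mallet.pipe.FeatureSequence2FeatureVector;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.Target2Label;
import cc.mallet.pipe.TokenSequence2FeatureSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;

// Hypothetical driver class, not part of MALLET.
public class MaxEntExample {
	public static void main (String[] args) {
		// Text -> tokens -> feature vector; targets are mapped to Labels.
		Pipe pipe = new SerialPipes (new Pipe[] {
				new Target2Label (),
				new CharSequence2TokenSequence (),
				new TokenSequence2FeatureSequence (),
				new FeatureSequence2FeatureVector ()
		});

		InstanceList training = new InstanceList (pipe);
		training.addThruPipe (new Instance ("rain wind cold front", "weather", "doc1", null));
		training.addThruPipe (new Instance ("goal referee penalty match", "sports", "doc2", null));

		MaxEnt classifier = new MaxEntTrainer ().train (training);

		// classify(Instance) expects an instance already passed through the pipe.
		Classification c = classifier.classify (
				pipe.instanceFrom (new Instance ("cold rain tonight", null, "query", null)));
		System.out.println (c.getLabeling ().getBestLabel ());
	}
}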