All downloads are free. The search and download functionality uses the official Maven repository.

cc.mallet.classify.MCMaxEnt Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

There is a newer version: 2.0.12
Show newest version
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org.  For further
information, see the file `LICENSE' included with this distribution. */





package cc.mallet.classify;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.DenseVector;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelVector;
import cc.mallet.types.MatrixOps;

/**
 * Maximum Entropy classifier.
 *
 * <p>All weights live in one flat {@code double[]} laid out row-major with one row per
 * label. Every row has {@code defaultFeatureIndex + 1} entries: one weight for each
 * feature index below {@code defaultFeatureIndex}, followed by the weight of the
 * "default" feature in the last slot. The default feature's weight is added to a label's
 * score for every instance.
 *
 * <p>{@code defaultFeatureIndex} is captured from the data alphabet's size when the
 * classifier is constructed (or restored from the serialized stream) and is never
 * re-read afterwards, so the row layout stays valid even if the alphabet grows later.
 *
 * @author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>
 */
public class MCMaxEnt extends Classifier implements Serializable
{
    /** Flat weight matrix, indexed by {@code label * (defaultFeatureIndex + 1) + feature}. */
    double [] parameters;
    /** Index of the default feature inside each row; always the highest index of the row. */
    int defaultFeatureIndex;
    /** Feature selection shared by all labels, or {@code null}. */
    FeatureSelection featureSelection;
    /** One feature selection per label, or {@code null}. */
    FeatureSelection[] perClassFeatureSelection;

    /**
     * Creates a classifier over the given pipe and weight array.
     *
     * <p>The default feature is always the feature with the highest index, i.e. the size
     * of the pipe's data alphabet at the time of this call.
     *
     * @param dataPipe pipe whose data alphabet defines the feature indices
     * @param parameters flat weight array; stored by reference, not copied
     * @param featureSelection selection applied to every label, or {@code null}
     * @param perClassFeatureSelection per-label selections, or {@code null}; at most one
     *        of the two selection arguments may be non-null (checked by assertion only)
     */
    public MCMaxEnt (Pipe dataPipe,
                   double[] parameters,
                   FeatureSelection featureSelection,
                   FeatureSelection[] perClassFeatureSelection)
    {
        super (dataPipe);
        assert (featureSelection == null || perClassFeatureSelection == null);
        this.parameters = parameters;
        this.featureSelection = featureSelection;
        this.perClassFeatureSelection = perClassFeatureSelection;
        this.defaultFeatureIndex = dataPipe.getDataAlphabet().size();
    }

    /** Creates a classifier with a single feature selection shared by all labels. */
    public MCMaxEnt (Pipe dataPipe,
                   double[] parameters,
                   FeatureSelection featureSelection)
    {
        this (dataPipe, parameters, featureSelection, null);
    }

    /** Creates a classifier with one feature selection per label. */
    public MCMaxEnt (Pipe dataPipe,
                   double[] parameters,
                   FeatureSelection[] perClassFeatureSelection)
    {
        this (dataPipe, parameters, null, perClassFeatureSelection);
    }

    /** Creates a classifier without any feature selection. */
    public MCMaxEnt (Pipe dataPipe,
                   double[] parameters)
    {
        this (dataPipe, parameters, null, null);
    }

    /**
     * Returns the live weight array (not a copy); writes to it change this classifier.
     *
     * @return the flat weight array
     */
    public double[] getParameters ()
    {
        return parameters;
    }

    /**
     * Overwrites a single weight.
     *
     * @param classIndex index of the label whose row is addressed
     * @param featureIndex column inside that row; {@code defaultFeatureIndex} addresses
     *        the default feature's weight
     * @param value new weight
     */
    public void setParameter (int classIndex, int featureIndex, double value)
    {
        // Use the frozen row length, not the live alphabet size: the alphabet may have
        // grown since construction, while the layout of the parameter array has not.
        parameters[classIndex*numParametersPerLabel() + featureIndex] = value;
    }

    /**
     * Fills {@code scores} with one raw (un-exponentiated, un-normalized) score per label.
     *
     * @param instance instance whose data is a {@link FeatureVector}
     * @param scores output array; expected to have one slot per label (checked by
     *        assertion only)
     */
    public void getUnnormalizedClassificationScores (Instance instance, double[] scores)
    {
        computeRawScores (instance, scores);
    }

    /**
     * Fills {@code scores} with the normalized score of each label: the raw scores are
     * exponentiated and divided by their sum.
     *
     * @param instance instance whose data is a {@link FeatureVector}
     * @param scores output array; expected to have one slot per label (checked by
     *        assertion only)
     */
    public void getClassificationScores (Instance instance, double[] scores)
    {
        computeRawScores (instance, scores);
        final int numLabels = getLabelAlphabet().size();

        // Move scores to a range where exp() is accurate, and normalize
        final double max = MatrixOps.max (scores);
        double sum = 0;
        for (int li = 0; li < numLabels; li++) {
            scores[li] = Math.exp (scores[li] - max);
            sum += scores[li];
        }
        for (int li = 0; li < numLabels; li++) {
            scores[li] /= sum;
        }
    }

    /**
     * Classifies an instance.
     *
     * @param instance instance whose data is a {@link FeatureVector}
     * @return a {@link Classification} wrapping a {@link LabelVector} of the normalized
     *         per-label scores
     */
    public Classification classify (Instance instance)
    {
        final LabelAlphabet labels = getLabelAlphabet();
        double[] scores = new double[labels.size()];
        getClassificationScores (instance, scores);
        return new Classification (instance, this, new LabelVector (labels, scores));
    }

    /**
     * Prints, for every label, the default feature's weight followed by one
     * "name weight" line per feature index below {@code defaultFeatureIndex}.
     */
    public void print ()
    {
        final Alphabet dict = getAlphabet();
        final LabelAlphabet labelDict = getLabelAlphabet();

        // Frozen row length, for the same reason as in setParameter: the alphabet may
        // have grown since the parameter array was laid out.
        final int numFeatures = numParametersPerLabel();
        final int numLabels = labelDict.size();

        for (int li = 0; li < numLabels; li++) {
            System.out.println ("FEATURES FOR CLASS "+labelDict.lookupObject (li));
            System.out.println ("  "+parameters [li*numFeatures + defaultFeatureIndex]);
            for (int i = 0; i < defaultFeatureIndex; i++) {
                Object name = dict.lookupObject (i);
                double weight = parameters [li*numFeatures + i];
                System.out.println (" "+name+" "+weight);
            }
        }
    }

    /** Number of weights stored per label: all regular features plus the default one. */
    private int numParametersPerLabel ()
    {
        return defaultFeatureIndex + 1;
    }

    /**
     * Shared scoring loop: for each label, the default feature's weight plus the dot
     * product of that label's weight row with the instance's feature vector.
     */
    private void computeRawScores (Instance instance, double[] scores)
    {
        final int numFeatures = numParametersPerLabel();
        final int numLabels = getLabelAlphabet().size();
        assert (scores.length == numLabels);
        FeatureVector fv = (FeatureVector) instance.getData ();
        // Make sure the feature vector's feature dictionary matches
        // what we are expecting from our data pipe (and thus our notion
        // of feature probabilities).
        assert (instancePipe == null || fv.getAlphabet () == this.instancePipe.getDataAlphabet ());

        for (int li = 0; li < numLabels; li++) {
            FeatureSelection selection = (perClassFeatureSelection == null)
                    ? featureSelection
                    : perClassFeatureSelection[li];
            // NOTE(review): rowDotProduct presumably ignores feature indices at or above
            // defaultFeatureIndex and features excluded by the selection; confirm
            // against MatrixOps.
            scores[li] = parameters[li*numFeatures + defaultFeatureIndex]
                    + MatrixOps.rowDotProduct (parameters, numFeatures,
                            li, fv,
                            defaultFeatureIndex,
                            selection);
        }
    }

    private static final long serialVersionUID = 1;
    private static final int CURRENT_SERIAL_VERSION = 1;
    static final int NULL_INTEGER = -1;

    /**
     * Writes the classifier in the hand-rolled version-1 format: version, pipe,
     * parameter count and values, default feature index, the optional shared feature
     * selection, then the optional per-class selections (length or
     * {@code NULL_INTEGER}, followed by each optional element).
     */
    private void writeObject(ObjectOutputStream out) throws IOException
    {
        out.writeInt(CURRENT_SERIAL_VERSION);
        out.writeObject(getInstancePipe());
        out.writeInt(parameters.length);
        for (double parameter : parameters) {
            out.writeDouble(parameter);
        }
        out.writeInt(defaultFeatureIndex);
        writeOptionalSelection(out, featureSelection);
        if (perClassFeatureSelection == null) {
            out.writeInt(NULL_INTEGER);
        } else {
            out.writeInt(perClassFeatureSelection.length);
            for (FeatureSelection selection : perClassFeatureSelection) {
                writeOptionalSelection(out, selection);
            }
        }
    }

    /** Writes {@code NULL_INTEGER} for a null selection, otherwise 1 followed by the object. */
    private static void writeOptionalSelection(ObjectOutputStream out, FeatureSelection selection)
            throws IOException
    {
        if (selection == null) {
            out.writeInt(NULL_INTEGER);
        } else {
            out.writeInt(1);
            out.writeObject(selection);
        }
    }

    /**
     * Restores the classifier from the format produced by {@code writeObject}.
     *
     * @throws ClassNotFoundException if the stream's version number differs from
     *         {@code CURRENT_SERIAL_VERSION}, or if a contained class cannot be resolved
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        int version = in.readInt();
        if (version != CURRENT_SERIAL_VERSION) {
            throw new ClassNotFoundException("Mismatched MCMaxEnt versions: wanted " +
                    CURRENT_SERIAL_VERSION + ", got " +
                    version);
        }
        instancePipe = (Pipe) in.readObject();
        int np = in.readInt();
        parameters = new double[np];
        for (int p = 0; p < np; p++) {
            parameters[p] = in.readDouble();
        }
        defaultFeatureIndex = in.readInt();
        featureSelection = readOptionalSelection(in);
        int nfs = in.readInt();
        // A negative count (NULL_INTEGER) means no per-class selections were stored.
        if (nfs >= 0) {
            perClassFeatureSelection = new FeatureSelection[nfs];
            for (int i = 0; i < nfs; i++) {
                perClassFeatureSelection[i] = readOptionalSelection(in);
            }
        }
    }

    /** Reads the presence flag and, only when it is 1, the selection object that follows. */
    private static FeatureSelection readOptionalSelection(ObjectInputStream in)
            throws IOException, ClassNotFoundException
    {
        int present = in.readInt();
        return (present == 1) ? (FeatureSelection) in.readObject() : null;
    }
}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy