All Downloads are FREE. Search and download functionalities are using the official Maven repository.

cc.mallet.classify.BalancedWinnow Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

The newest version!
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */


package cc.mallet.classify;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

import cc.mallet.pipe.Pipe;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.LabelVector;
import cc.mallet.types.MatrixOps;


/** 
 * Classification methods of BalancedWinnow algorithm.
 *
 * @see BalancedWinnowTrainer
 * @author Gary Huang [email protected]
 */
public class BalancedWinnow extends Classifier implements Serializable
{
    double [][] m_weights;
	
    /**
     * Passes along data pipe and weights from 
     * {@link #BalancedWinnowTrainer BalancedWinnowTrainer}
     * @param dataPipe needed for dictionary, labels, feature vectors, etc
     * @param weights weights calculated during training phase
     */
    public BalancedWinnow (Pipe dataPipe, double [][] weights)
    {
        super (dataPipe);
        m_weights = new double[weights.length][weights[0].length];
        for (int i = 0; i < weights.length; i++)
			for (int j = 0; j < weights[0].length; j++)
				m_weights[i][j] = weights[i][j];
    }
	
    /**
     * @return a copy of the weight vectors
     */
    public double[][] getWeights()
    {
        int numCols = m_weights[0].length;
        double[][] ret = new double[m_weights.length][numCols];
        for (int i = 0; i < ret.length; i++)
			System.arraycopy(m_weights[i], 0, ret[i], 0, numCols);
        return ret;
    }
	
    /**
     * Classifies an instance using BalancedWinnow's weights
     *
     * 

Returns a Classification containing the normalized * dot products between class weight vectors and the instance * feature vector. * *

One can obtain the confidence of the classification by * calculating weight(j')/weight(j), where j' is the * highest weight prediction and j is the 2nd-highest. * Another possibility is to calculate *

e^{dot(w_j', x} / sum_j[e^{dot(w_j, x)}]
*/ public Classification classify (Instance instance) { int numClasses = getLabelAlphabet().size(); int numFeats = getAlphabet().size(); double[] scores = new double[numClasses]; FeatureVector fv = (FeatureVector) instance.getData (); // Make sure the feature vector's feature dictionary matches // what we are expecting from our data pipe (and thus our notion // of feature probabilities. assert (instancePipe == null || fv.getAlphabet () == this.instancePipe.getDataAlphabet ()); int fvisize = fv.numLocations(); // Take dot products double sum = 0; for (int ci = 0; ci < numClasses; ci++) { for (int fvi = 0; fvi < fvisize; fvi++) { int fi = fv.indexAtLocation (fvi); double vi = fv.valueAtLocation(fvi); if ( m_weights[ci].length > fi ) { scores[ci] += vi * m_weights[ci][fi]; sum += vi * m_weights[ci][fi]; } } scores[ci] += m_weights[ci][numFeats]; sum += m_weights[ci][numFeats]; } MatrixOps.timesEquals(scores, 1.0 / sum); // Create and return a Classification object return new Classification (instance, this, new LabelVector (getLabelAlphabet(), scores)); } // Serialization // serialVersionUID is overriden to prevent innocuous changes in this // class from making the serialization mechanism think the external // format has changed. private static final long serialVersionUID = 1; private static final int CURRENT_SERIAL_VERSION = 1; private void writeObject(ObjectOutputStream out) throws IOException { out.writeInt(CURRENT_SERIAL_VERSION); out.writeObject(getInstancePipe()); // write weight vector for each class out.writeObject(m_weights); } private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { int version = in.readInt(); if (version != CURRENT_SERIAL_VERSION) throw new ClassNotFoundException("Mismatched BalancedWinnow versions: wanted " + CURRENT_SERIAL_VERSION + ", got " + version); instancePipe = (Pipe) in.readObject(); m_weights = (double[][]) in.readObject(); } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy