All downloads are free. Search and download functionality uses the official Maven repository.

cc.mallet.classify.WinnowTrainer Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

The newest version!
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */



/** 
   @author Aron Culotta [email protected]
*/

package cc.mallet.classify;

import cc.mallet.classify.Winnow;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Labeling;

/**
 * An implementation of the training methods of a 
 * Winnow2 on-line classifier. Given an instance xi,
 * the algorithm computes Sum(xi*wi), where wi is 
 * the weight for that feature in the given class. 
 * If the Sum is greater than some threshold
 * {@link #theta theta}, then the classifier guesses
 * true for that class.
 * Only when the classifier makes a mistake are the 
 * weights updated in one of two steps:
 * Promote: guessed 0 and answer was 1. Multiply
 * all weights of present features by {@link #alpha alpha}.
 * Demote: guessed 1 and answer was 0. Divide
 * all weights of present features by {@link #beta beta}.
 *
 * Limitations: Winnow2 only considers binary feature
 * vectors (i.e. whether or not the feature is present,
 * not its value).
 */
public class WinnowTrainer extends ClassifierTrainer
{
	/** Value used for {@link #alpha alpha} when a constructor is not given one. */
	static final double DEFAULT_ALPHA = 2.0; 
	/** Value used for {@link #beta beta} when a constructor is not given one. */
	static final double DEFAULT_BETA = 2.0;   
	/** Value used for {@link #nfactor nfactor} when a constructor is not given one. */
	static final double DEFAULT_NFACTOR = .5;
	
	/**
	 * Promotion constant: the weights of the features present in a
	 * feature vector are multiplied by this value in the promotion step.
	 */
	double alpha;
	/**
	 * Demotion constant: the weights of the features present in a
	 * feature vector are divided by this value in the demotion step.
	 */
	double beta;
	/**
	 * Threshold that the sum of wi*xi is compared against when
	 * formulating a guess. Set in train to (number of features) * nfactor.
	 */
	double theta;
	/** 
	 * Fraction of the number of features n that theta is set to,
	 * i.e. theta = n * nfactor. E.g. if nfactor = 1/2, theta = n/2.
	 */
	double nfactor;
	/**
	 * Weight table, allocated in train as [number of labels][number of features]
	 * and indexed as weights[label][feature].
	 * NOTE(review): documented as initialized to 1; confirm against the
	 * initialization loop in train.
	 */
	double [][] weights;
	
	/**
	 * Classifier handed out by getClassifier().
	 * NOTE(review): presumably assigned at the end of train; verify.
	 */
	Winnow classifier;
	
	/**
	 * Default constructor. Sets all features to defaults.
	 */
	public WinnowTrainer ()
	{
		// Delegate to the three-argument constructor with the class-wide defaults.
		this (WinnowTrainer.DEFAULT_ALPHA,
		      WinnowTrainer.DEFAULT_BETA,
		      WinnowTrainer.DEFAULT_NFACTOR);
	}
	
	/**
	 * Sets alpha and beta and default value for theta
	 * @param a alpha value
	 * @param b beta value
	 */
	public WinnowTrainer (double a, double b)
	{
		// Caller chooses alpha and beta; nfactor falls back to its default.
		this (a, b, WinnowTrainer.DEFAULT_NFACTOR);
	}
	
	/**
	 * Sets alpha, beta, and nfactor
	 * @param a alpha value
	 * @param b beta value
	 * @param nfact nfactor value
	 */
	public WinnowTrainer (double a, double b, double nfact)
	{
		// Only the three hyper-parameters are recorded here; theta and the
		// weight table are (re)computed from the data each time train runs.
		nfactor = nfact;
		beta = b;
		alpha = a;
	}
	
	public Winnow getClassifier ()
	{
		// Hands back the stored reference as-is; no copy is made.
		return this.classifier;
	}
	
	/**
	 * Trains winnow on the instance list, updating 
	 * {@link #weights weights} according to errors
	 * @param trainingList Instance list to be trained on
	 * @return Classifier object containing learned weights
	 */
	public Winnow train (InstanceList trainingList)
	{
		FeatureSelection selectedFeatures = trainingList.getFeatureSelection();
		if (selectedFeatures != null)
			// xxx Attend to FeatureSelection!!!
			throw new UnsupportedOperationException ("FeatureSelection not yet implemented.");
		// if "train" is run more than once, 
		// we will be reinitializing the weights
		// TODO: provide method to save weights
		trainingList.getDataAlphabet().stopGrowth();
		trainingList.getTargetAlphabet().stopGrowth();
		Pipe dataPipe = trainingList.getPipe ();
		Alphabet dict = (Alphabet) trainingList.getDataAlphabet ();
		int numLabels = trainingList.getTargetAlphabet().size();
		int numFeats = dict.size(); 
		this.theta =  numFeats * this.nfactor;
		this.weights = new double [numLabels][numFeats];
		// init weights to 1
		for(int i=0; i this.theta){ // guess 1
					if(correctIndex != ri) // correct is 0
				    demote(ri, fv);
				}
				else{ // guess 0
					if(correctIndex == ri) // correct is 1
						promote(ri, fv);   
				}
			}
//			System.out.println("Results guessed:")
//		for(int x=0; x promotion
		for(int fvi=0; fvi < fvisize; fvi++){
			int fi = fv.indexAtLocation(fvi);
			this.weights[lpos][fi] *= this.alpha;
		}		
	}

  /**
   * Demotes (by {@link #beta beta}) the weights
   * responsible for the incorrect guess
   * @param lpos index of incorrectly guessed label
   * @param fv feature vector
   */
  private void demote(int lpos, FeatureVector fv){
		// Demotion step: the learner guessed 1 for label lpos but the correct
		// answer was 0, so every feature present in fv has its weight for
		// that label divided by beta.
		final int locationCount = fv.numLocations();
		int location = 0;
		while (location < locationCount) {
			// Map the position inside the feature vector to its feature index.
			final int featureIndex = fv.indexAtLocation(location);
			this.weights[lpos][featureIndex] /= this.beta;
			location++;
		}
	}
}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy