
cc.mallet.classify.WinnowTrainer (jcore-mallet-2.0.9)
MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
/**
@author Aron Culotta [email protected]
*/
package cc.mallet.classify;
import cc.mallet.classify.Winnow;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Labeling;
/**
* An implementation of the training methods of a
* Winnow2 on-line classifier. Given an instance xi,
* the algorithm computes Sum(xi*wi), where wi is
* the weight for that feature in the given class.
* If the Sum is greater than some threshold
* {@link #theta theta}, then the classifier guesses
* true for that class.
* Only when the classifier makes a mistake are the
* weights updated in one of two steps:
* Promote: guessed 0 and answer was 1. Multiply
* all weights of present features by {@link #alpha alpha}.
* Demote: guessed 1 and answer was 0. Divide
* all weights of present features by {@link #beta beta}.
*
* Limitations: Winnow2 only considers binary feature
* vectors (i.e. whether or not the feature is present,
* not its value).
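*
* Illustrative example (not part of the original documentation):
* with alpha = beta = 2 and theta = 2.0, an instance with two
* active features whose weights for a given label are 1.0 and 0.5
* sums to 1.5, below theta, so the guess for that label is 0;
* if that label was in fact the correct one, promotion doubles
* both weights to 2.0 and 1.0.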
*/
public class WinnowTrainer extends ClassifierTrainer<Winnow>
{
static final double DEFAULT_ALPHA = 2.0;
static final double DEFAULT_BETA = 2.0;
static final double DEFAULT_NFACTOR = .5;
/**
*constant by which the weights of active features are multiplied in the promotion step
*/
double alpha;
/**
*constant by which the weights of active features are divided in the demotion step
*/
double beta;
/**
*threshold for sum of wi*xi in formulating guess
*/
double theta;
/**
*factor of the number of features n used to set theta; e.g. if nfactor = 1/2, then theta = n/2.
*/
double nfactor;
/**
*array of weights, one for each feature, initialized to 1
*/
double [][] weights;
Winnow classifier;
/**
* Default constructor. Sets all features to defaults.
*/
public WinnowTrainer(){
this(DEFAULT_ALPHA, DEFAULT_BETA, DEFAULT_NFACTOR);
}
/**
* Sets alpha and beta and default value for theta
* @param a alpha value
* @param b beta value
*/
public WinnowTrainer(double a, double b){
this(a, b, DEFAULT_NFACTOR);
}
/**
* Sets alpha, beta, and nfactor
* @param a alpha value
* @param b beta value
* @param nfact nfactor value
*/
public WinnowTrainer(double a, double b, double nfact){
this.alpha = a;
this.beta = b;
this.nfactor = nfact;
}
public Winnow getClassifier () { return classifier; }
/**
* Trains winnow on the instance list, updating
* {@link #weights weights} according to errors
* @param trainingList Instance list to be trained on
* @return Classifier object containing learned weights
*/
public Winnow train (InstanceList trainingList)
{
FeatureSelection selectedFeatures = trainingList.getFeatureSelection();
if (selectedFeatures != null)
// xxx Attend to FeatureSelection!!!
throw new UnsupportedOperationException ("FeatureSelection not yet implemented.");
// if "train" is run more than once,
// we will be reinitializing the weights
// TODO: provide method to save weights
trainingList.getDataAlphabet().stopGrowth();
trainingList.getTargetAlphabet().stopGrowth();
Pipe dataPipe = trainingList.getPipe ();
Alphabet dict = (Alphabet) trainingList.getDataAlphabet ();
int numLabels = trainingList.getTargetAlphabet().size();
int numFeats = dict.size();
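// theta scales with the vocabulary size: with the default
// nfactor of 0.5, the threshold is half the number of features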
this.theta = numFeats * this.nfactor;
this.weights = new double [numLabels][numFeats];
// init weights to 1
for(int i=0; i<numLabels; i++)
for(int j=0; j<numFeats; j++)
this.weights[i][j] = 1.0;
// loop through all instances
for(int ii = 0; ii < trainingList.size(); ii++){
Instance inst = (Instance) trainingList.get(ii);
Labeling labeling = inst.getLabeling();
FeatureVector fv = (FeatureVector) inst.getData();
double[] results = new double[numLabels];
int fvisize = fv.numLocations();
int correctIndex = labeling.getBestIndex();
for(int rpos=0; rpos < numLabels; rpos++)
results[rpos] = 0;
// sum up xi*wi for each class
for(int fvi=0; fvi < fvisize; fvi++){
int fi = fv.indexAtLocation(fvi);
for(int lpos=0; lpos < numLabels; lpos++)
results[lpos] += this.weights[lpos][fi];
}
// make a guess for each label using the threshold;
// update weights only upon an incorrect guess
for(int ri=0; ri < numLabels; ri++){
if(results[ri] > this.theta){ // guess 1
if(correctIndex != ri) // correct is 0
demote(ri, fv);
}
else{ // guess 0
if(correctIndex == ri) // correct is 1
promote(ri, fv);
}
}
}
classifier = new Winnow (dataPipe, weights, theta, numLabels, numFeats);
return classifier;
}
/**
*Promotes (by {@link #alpha alpha}) the weights
* responsible for the incorrect guess
* @param lpos index of incorrectly guessed label
* @param fv feature vector
*/
private void promote(int lpos, FeatureVector fv){
int fvisize = fv.numLocations();
// learner predicted 0, correct is 1 -> promotion
for(int fvi=0; fvi < fvisize; fvi++){
int fi = fv.indexAtLocation(fvi);
this.weights[lpos][fi] *= this.alpha;
}
}
/**
*Demotes (by {@link #beta beta}) the weights
* responsible for the incorrect guess
* @param lpos index of incorrectly guessed label
* @param fv feature vector
*/
private void demote(int lpos, FeatureVector fv){
int fvisize = fv.numLocations();
// learner predicted 1, correct is 0 -> demotion
for(int fvi=0; fvi < fvisize; fvi++){
int fi = fv.indexAtLocation(fvi);
this.weights[lpos][fi] /= this.beta;
}
}
}
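
A minimal usage sketch, not part of the original file: it assumes a standard MALLET text-to-FeatureVector pipeline, two made-up toy documents, and the default alpha, beta, and nfactor; the class name WinnowExample and the document contents are hypothetical. Passing true to FeatureSequence2FeatureVector produces binary feature values, matching the Winnow2 binary-feature limitation noted above.

import cc.mallet.classify.Winnow;
import cc.mallet.classify.WinnowTrainer;
import cc.mallet.pipe.CharSequence2TokenSequence;
import cc.mallet.pipe.FeatureSequence2FeatureVector;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.Target2Label;
import cc.mallet.pipe.TokenSequence2FeatureSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;

public class WinnowExample {
public static void main(String[] args) {
// convert raw strings to feature vectors and string labels to Labels
Pipe pipe = new SerialPipes(new Pipe[] {
new Target2Label(),
new CharSequence2TokenSequence(),
new TokenSequence2FeatureSequence(),
new FeatureSequence2FeatureVector(true) // binary values, per Winnow2
});
InstanceList ilist = new InstanceList(pipe);
// hypothetical toy documents; contents and labels are illustrative only
ilist.addThruPipe(new Instance("good great fun", "pos", "doc1", null));
ilist.addThruPipe(new Instance("bad dull boring", "neg", "doc2", null));
// defaults: alpha = 2.0, beta = 2.0, nfactor = 0.5
Winnow classifier = new WinnowTrainer().train(ilist);
System.out.println(classifier.classify(ilist.get(0)).getLabeling().getBestLabel());
}
}

Any MALLET corpus iterator could replace the hand-built instances; the only Winnow-specific choice here is requesting binary feature values.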