cc.mallet.pipe.AddClassifierTokenPredictions

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */


package cc.mallet.pipe;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.HashMap;
import java.util.logging.Logger;

import cc.mallet.classify.BalancedWinnowTrainer;
import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.classify.Trial;
import cc.mallet.types.Alphabet;
import cc.mallet.types.AugmentableFeatureVector;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import cc.mallet.types.LabelSequence;
import cc.mallet.types.LabelVector;
import cc.mallet.types.Labeling;
import cc.mallet.util.MalletLogger;


/**
 * This pipe uses a Classifier to label each token (i.e., under a 0th-order Markov assumption),
 * then adds the predictions as features to each token.
 * 
 * This pipe assumes the input Instance's data is of type FeatureVectorSequence
 * (each element an AugmentableFeatureVector).
 * 
 * Example usage:
 * 		1) Create and serialize a featurePipe that converts raw input to FeatureVectorSequences
 * 		2) Pipe input data through featurePipe, train TokenClassifiers via cross-validation, then serialize the classifiers
 * 		3) Pipe input data through featurePipe and this pipe (using the saved classifiers), and train a Transducer
 * 		4) Serialize the trained Transducer
 * (A hedged code sketch of steps 2 and 3 follows this class.)
 * 
 * @author ghuang
 */
public class AddClassifierTokenPredictions extends Pipe implements Serializable {

	private static Logger logger = MalletLogger.getLogger(AddClassifierTokenPredictions.class.getName());

	// Specify which predictions are to be added as features.
	// E.g., { 1, 2 } = add labels of the top 2 highest-scoring predictions as features.
	int[] m_predRanks2add;

	// The trained token classifier
	TokenClassifiers m_tokenClassifiers;

	// Whether to treat each instance's feature values as binary
	boolean m_binary;

	// Whether the pipe is currently being used at production time
	// (i.e., not being used as pipeline for training a transducer)
	boolean m_inProduction;

	// Augmented data alphabet that includes the class predictions
	Alphabet m_dataAlphabet;

	public AddClassifierTokenPredictions(InstanceList trainList) {
		this(trainList, null);
	}

	public AddClassifierTokenPredictions(InstanceList trainList, InstanceList testList) {
		this(new TokenClassifiers(convert(trainList, (Noop) trainList.getPipe())),
				new int[] { 1 }, true, convert(testList, (Noop) trainList.getPipe()));
	}

	public AddClassifierTokenPredictions(TokenClassifiers tokenClassifiers, int[] predRanks2add,
			boolean binary, InstanceList testList) {
		m_predRanks2add = predRanks2add;
		m_binary = binary;
		m_tokenClassifiers = tokenClassifiers;
		m_inProduction = false;
		m_dataAlphabet = (Alphabet) tokenClassifiers.getAlphabet().clone();

		Alphabet labelAlphabet = tokenClassifiers.getLabelAlphabet();

		// add the token prediction features to the alphabet
		for (int i = 0; i < m_predRanks2add.length; i++) {
			for (int j = 0; j < labelAlphabet.size(); j++) {
				String featName = "TOK_PRED=" + labelAlphabet.lookupObject(j).toString()
						+ "_@_RANK_" + m_predRanks2add[i];
				m_dataAlphabet.lookupIndex(featName, true);
			}
		}

		// evaluate token classifier
		if (testList != null) {
			Trial trial = new Trial(m_tokenClassifiers, testList);
			logger.info("Token classifier accuracy on test set = " + trial.getAccuracy());
		}
	}

	public void setInProduction(boolean inProduction) { m_inProduction = inProduction; }

	public boolean getInProduction() { return m_inProduction; }

	public static void setInProduction(Pipe p, boolean value) {
		if (p instanceof AddClassifierTokenPredictions)
			((AddClassifierTokenPredictions) p).setInProduction(value);
		else if (p instanceof SerialPipes) {
			SerialPipes sp = (SerialPipes) p;
			for (int i = 0; i < sp.size(); i++)
				setInProduction(sp.getPipe(i), value);
		}
	}

	public Alphabet getDataAlphabet() { return m_dataAlphabet; }

	/**
	 * Add the token classifier's predictions as features to the instance.
	 * This method assumes the input instance contains FeatureVectorSequence as data
	 */
	public Instance pipe(Instance carrier) {
		FeatureVectorSequence fvs = (FeatureVectorSequence) carrier.getData();
		InstanceList ilist = convert(carrier, (Noop) m_tokenClassifiers.getInstancePipe());
		assert (fvs.size() == ilist.size());

		// For passing instances to the token classifier, each instance's data alphabet needs to
		// match that used by the token classifier at training time.  For the resulting piped
		// instance, each instance's data alphabet needs to contain token classifier's prediction
		// as features
		FeatureVector[] fva = new FeatureVector[fvs.size()];

		for (int i = 0; i < ilist.size(); i++) {
			Instance inst = ilist.get(i);
			Classification c = m_tokenClassifiers.classify(inst, ! m_inProduction);
			LabelVector lv = c.getLabelVector();
			AugmentableFeatureVector afv1 = (AugmentableFeatureVector) inst.getData();
			int[] indices = afv1.getIndices();
			AugmentableFeatureVector afv2 = new AugmentableFeatureVector(m_dataAlphabet,
					indices, afv1.getValues(), indices.length + m_predRanks2add.length);

			for (int j = 0; j < m_predRanks2add.length; j++) {
				Label label = lv.getLabelAtRank(m_predRanks2add[j]);
				int idx = m_dataAlphabet.lookupIndex("TOK_PRED=" + label.toString()
						+ "_@_RANK_" + m_predRanks2add[j]);
				assert(idx >= 0);
				afv2.add(idx, 1);
			}
			fva[i] = afv2;
		}
		carrier.setData(new FeatureVectorSequence(fva));
		return carrier;
	}

	/**
	 * Converts each instance containing a FeatureVectorSequence to multiple instances,
	 * each containing an AugmentableFeatureVector as data.
	 *
	 * @param ilist Instances with FeatureVectorSequence as data field
	 * @param alphabetsPipe a Noop pipe containing the data and target alphabets for the resulting InstanceList
	 * @return an InstanceList where each Instance contains one Token's AugmentableFeatureVector as data
	 */
	public static InstanceList convert(InstanceList ilist, Noop alphabetsPipe) {
		if (ilist == null)
			return null;

		// This monstrosity is necessary b/c Classifiers obtain the data/target alphabets via pipes
		InstanceList ret = new InstanceList(alphabetsPipe);

		for (Instance inst : ilist)
			ret.add(inst);
		//for (int i = 0; i < ilist.size(); i++) ret.add(convert(ilist.get(i), alphabetsPipe));

		return ret;
	}

	/**
	 *
	 * @param inst input instance, with FeatureVectorSequence as data.
	 * @param alphabetsPipe a Noop pipe containing the data and target alphabets for
	 *        the resulting InstanceList and AugmentableFeatureVectors
	 * @return list of instances, each with one AugmentableFeatureVector as data
	 */
	public static InstanceList convert(Instance inst, Noop alphabetsPipe) {
		InstanceList ret = new InstanceList(alphabetsPipe);
		Object obj = inst.getData();
		assert(obj instanceof FeatureVectorSequence);

		FeatureVectorSequence fvs = (FeatureVectorSequence) obj;
		LabelSequence ls = (LabelSequence) inst.getTarget();
		assert(fvs.size() == ls.size());

		Object instName = (inst.getName() == null ? "NONAME" : inst.getName());

		for (int j = 0; j < fvs.size(); j++) {
			FeatureVector fv = fvs.getFeatureVector(j);
			int[] indices = fv.getIndices();
			FeatureVector data = new AugmentableFeatureVector(alphabetsPipe.getDataAlphabet(),
					indices, fv.getValues(), indices.length);
			Labeling target = ls.getLabelAtPosition(j);
			String name = instName.toString() + "_@_POS_" + (j + 1);
			Object source = inst.getSource();
			Instance toAdd = alphabetsPipe.pipe(new Instance(data, target, name, source));

			ret.add(toAdd);
		}
		return ret;
	}

	// Serialization
	private static final long serialVersionUID = 1;

	/**
	 * This inner class represents the trained token classifiers.
	 * @author ghuang
	 */
	public static class TokenClassifiers extends Classifier implements Serializable {

		// number of folds in cross-validation training
		int m_numCV;

		// random seed to split training data for cross-validation
		int m_randSeed;

		// trainer for token classifier
		ClassifierTrainer m_trainer;

		// token classifier trained on the entirety of the training set
		Classifier m_tokenClassifier;

		// table storing instance name --> out-of-fold classifier
		// Used to prevent overfitting to the token classifier's predictions
		HashMap m_table;

		/**
		 * Train a token classifier using the given Instances with 5-fold cross validation
		 * @param trainList training instances
		 */
		public TokenClassifiers(InstanceList trainList) {
			this(trainList, 0, 5);
		}

		public TokenClassifiers(InstanceList trainList, int randSeed, int numCV) {
			// this(new AdaBoostM2Trainer(new DecisionTreeTrainer(2), 10), trainList, randSeed, numCV);
			// this(new NaiveBayesTrainer(), trainList, randSeed, numCV);
			this(new BalancedWinnowTrainer(), trainList, randSeed, numCV);
			// this(new SVMTrainer(), trainList, randSeed, numCV);
		}

		public TokenClassifiers(ClassifierTrainer trainer, InstanceList trainList, int randSeed, int numCV) {
			super(trainList.getPipe());

			m_trainer = trainer;
			m_randSeed = randSeed;
			m_numCV = numCV;
			m_table = new HashMap();

			doTraining(trainList);
		}

		// train the token classifier
		private void doTraining(InstanceList trainList) {
			// train a classifier on the entire training set
			logger.info("Training token classifier on entire data set (size=" + trainList.size() + ")...");
			m_tokenClassifier = m_trainer.train(trainList);

			Trial t = new Trial(m_tokenClassifier, trainList);
			logger.info("Training set accuracy = " + t.getAccuracy());

			if (m_numCV == 0)
				return;

			// train classifiers using cross validation
			InstanceList.CrossValidationIterator cvIter = trainList.new CrossValidationIterator(m_numCV, m_randSeed);
			int f = 1;

			while (cvIter.hasNext()) {
				f++;
				InstanceList[] fold = cvIter.nextSplit();
				logger.info("Training token classifier on cv fold " + f + " / " + m_numCV
						+ " (size=" + fold[0].size() + ")...");

				Classifier foldClassifier = m_trainer.train(fold[0]);
				Trial t1 = new Trial(foldClassifier, fold[0]);
				Trial t2 = new Trial(foldClassifier, fold[1]);
				logger.info("Within-fold accuracy = " + t1.getAccuracy());
				logger.info("Out-of-fold accuracy = " + t2.getAccuracy());

				/*for (int x = 0; x < t2.size(); x++) {
					logger.info("xxx pred:" + t2.getClassification(x).getLabeling().getBestLabel()
							+ " true:" + t2.getClassification(x).getInstance().getLabeling());
				}*/

				for (int i = 0; i < fold[1].size(); i++) {
					Instance inst = fold[1].get(i);
					m_table.put(inst.getName(), foldClassifier);
				}
			}
		}

		public Classification classify(Instance instance) {
			return classify(instance, false);
		}

		/**
		 *
		 * @param instance the instance to classify
		 * @param useOutOfFold whether to check the instance name and use the out-of-fold classifier
		 *        if the instance name matches one in the training data
		 * @return the token classifier's output
		 */
		public Classification classify(Instance instance, boolean useOutOfFold) {
			Object instName = instance.getName();

			if (! useOutOfFold || ! m_table.containsKey(instName))
				return m_tokenClassifier.classify(instance);

			Classifier classifier = (Classifier) m_table.get(instName);

			return classifier.classify(instance);
		}

		// serialization
		private static final long serialVersionUID = 1;
		private static final int CURRENT_SERIAL_VERSION = 1;

		private void writeObject(ObjectOutputStream out) throws IOException {
			out.writeInt(CURRENT_SERIAL_VERSION);
			out.writeObject(getInstancePipe());
			out.writeInt(m_numCV);
			out.writeInt(m_randSeed);
			out.writeObject(m_table);
			out.writeObject(m_tokenClassifier);
			out.writeObject(m_trainer);
		}

		private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
			int version = in.readInt();
			if (version != CURRENT_SERIAL_VERSION)
				throw new ClassNotFoundException("Mismatched TokenClassifiers versions: wanted "
						+ CURRENT_SERIAL_VERSION + ", got " + version);
			instancePipe = (Pipe) in.readObject();
			m_numCV = in.readInt();
			m_randSeed = in.readInt();
			m_table = (HashMap) in.readObject();
			m_tokenClassifier = (Classifier) in.readObject();
			m_trainer = (ClassifierTrainer) in.readObject();
		}
	}
}
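
For orientation, below is a minimal sketch of steps 2 and 3 from the Javadoc above, assuming the caller has already built featurePipe (step 1) and piped the training data through it into trainSeqs, an InstanceList whose pipe is a Noop carrying featurePipe's data/target alphabets and whose instances hold FeatureVectorSequence data with LabelSequence targets. The class and method names (TokenPredictionPipelineSketch, buildTrainingPipe, switchToProduction) and the parameters trainSeqs and featurePipe are hypothetical; only AddClassifierTokenPredictions, SerialPipes, Pipe, and InstanceList are MALLET types, used as shown in the source above. This is a sketch under those assumptions, not MALLET's documented recipe.

import cc.mallet.pipe.AddClassifierTokenPredictions;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.types.InstanceList;

public class TokenPredictionPipelineSketch {

	// trainSeqs: output of step 1, wrapped on a Noop pipe (assumption, see above).
	// featurePipe: the feature pipe from step 1 (assumption, see above).
	public static Pipe buildTrainingPipe(InstanceList trainSeqs, Pipe featurePipe) {
		// Step 2: train the per-token classifiers (5-fold cross-validation inside) and wrap
		// them in the pipe; by default the rank-1 prediction is added to each token as a
		// "TOK_PRED=<label>_@_RANK_1" feature.
		AddClassifierTokenPredictions predPipe = new AddClassifierTokenPredictions(trainSeqs);

		// Step 3: raw input piped through this SerialPipes yields sequences whose token
		// feature vectors include the classifier's predictions; a Transducer (e.g., a
		// cc.mallet.fst.CRF) can then be trained on an InstanceList built with this pipe.
		return new SerialPipes(new Pipe[] { featurePipe, predPipe });
	}

	// Before labeling new data, switch the pipeline to production mode so the out-of-fold
	// classifiers stored for training instances are no longer consulted.
	public static void switchToProduction(Pipe trainingPipe) {
		AddClassifierTokenPredictions.setInProduction(trainingPipe, true);
	}
}

Keeping the token-prediction pipe inside the same SerialPipes that is serialized with the trained Transducer is what lets setInProduction(Pipe, boolean) find and flip every copy of the pipe at prediction time.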



