
// CMMClassifier -- a probabilistic (CMM) Named Entity Recognizer
// Copyright (c) 2002-2006 The Board of Trustees of
// The Leland Stanford Junior University. All Rights Reserved.
//
// This program has been made available for research purposes only.
// Please do not further distribute it.
// Commercial development of the software is not to be undertaken without
// prior agreement from Stanford University.
// This program is not open source nor is it in the public domain.
//
// For information contact:
//    Christopher Manning
//    Dept of Computer Science, Gates 1A
//    Stanford CA 94305-9010
//    USA
//    [email protected]

package edu.stanford.nlp.ie.ner;

import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Properties;
import java.util.Set;
import java.util.regex.Pattern;

import edu.stanford.nlp.classify.Dataset;
import edu.stanford.nlp.classify.LinearClassifier;
import edu.stanford.nlp.classify.LinearClassifierFactory;
import edu.stanford.nlp.classify.LogPrior;
import edu.stanford.nlp.classify.NBLinearClassifierFactory;
import edu.stanford.nlp.classify.ProbabilisticClassifier;
import edu.stanford.nlp.classify.SVMLightClassifierFactory;
import edu.stanford.nlp.ie.AbstractSequenceClassifier;
import edu.stanford.nlp.ie.NERFeatureFactory;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.ling.BasicDatum;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.ling.Document;
import edu.stanford.nlp.ling.HasTag;
import edu.stanford.nlp.ling.HasWord;
import edu.stanford.nlp.ling.WordTag;
import edu.stanford.nlp.ling.CoreAnnotations.AnswerAnnotation;
import edu.stanford.nlp.ling.CoreAnnotations.GazAnnotation;
import edu.stanford.nlp.ling.CoreAnnotations.IDAnnotation;
import edu.stanford.nlp.ling.CoreAnnotations.PositionAnnotation;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.math.SloppyMath;
import edu.stanford.nlp.objectbank.ObjectBank;
import edu.stanford.nlp.process.DocumentProcessor;
import edu.stanford.nlp.process.ListProcessor;
import edu.stanford.nlp.sequences.BeamBestSequenceFinder;
import edu.stanford.nlp.sequences.Clique;
import edu.stanford.nlp.sequences.DocumentReaderAndWriter;
import edu.stanford.nlp.sequences.ExactBestSequenceFinder;
import edu.stanford.nlp.sequences.FeatureFactory;
import edu.stanford.nlp.sequences.PlainTextDocumentReaderAndWriter;
import edu.stanford.nlp.sequences.SeqClassifierFlags;
import edu.stanford.nlp.sequences.SequenceModel;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.PaddedList;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.StringUtils;


/**
 * Does Sequence Classification using a Conditional Markov Model.
 * It could be used for other purposes, but the provided features
 * are aimed at doing Named Entity Recognition.
 * The code has functionality for different document encodings, but when
 * using the standard ColumnDocumentReader,
 * input files are expected to
 * be one word per line with the columns indicating things like the word,
 * POS, chunk, and class.
 * <p>
 * <b>Typical usage</b>
 * <p>
 * For running a trained model with a provided serialized classifier:
 * <p>
 * <code>
 * java -server -mx1000m edu.stanford.nlp.ie.ner.CMMClassifier -loadClassifier
 * conll.ner.gz -textFile samplesentences.txt
 * </code>
 * <p>
 * When specifying all parameters in a properties file (train, test, or
 * runtime):
 * <p>
 * <code>
 * java -mx1000m edu.stanford.nlp.ie.ner.CMMClassifier -prop propFile
 * </code>
 * <p>
 * To train and test a model from the command line:
 * <p>
 * <code>
 * java -mx1000m edu.stanford.nlp.ie.ner.CMMClassifier
 * -trainFile trainFile -testFile testFile -goodCoNLL &gt; output
 * </code>
 * <p>
 * Features are defined by a {@link FeatureFactory}; the
 * {@link FeatureFactory} which is used by default is
 * {@link NERFeatureFactory}, and you should look there for feature templates.
 * Features are specified either by a Properties file (which is the
 * recommended method) or on the command line.  The features are read into
 * a {@link SeqClassifierFlags} object, which the
 * user need not know much about, unless one wishes to add new features.
 * <p>
 * CMMClassifier may also be used programmatically.  When creating a new instance, you
 * must specify a properties file.  The other way to get a CMMClassifier is to
 * deserialize one via {@link CMMClassifier#getClassifier(String)}, which returns a
 * deserialized classifier.  You may then tag sentences using either the assorted
 * <code>test</code> or <code>testSentence</code> methods.
 *
 * @author Dan Klein
 * @author Jenny Finkel
 * @author Christopher Manning
 * @author Shipra Dingare
 * @author Huy Nguyen
 * @author Sarah Spikes ([email protected]) - cleanup and filling in types
 */
@SuppressWarnings("rawtypes")
public class CMMClassifier<IN extends CoreLabel> extends AbstractSequenceClassifier<IN> implements DocumentProcessor, ListProcessor {

  private ProbabilisticClassifier<String, String> classifier;

  /** The set of empirically legal label sequences (of length (order) at most
   *  flags.maxLeft).  Used to filter valid class sequences if
   *  useObservedSequencesOnly is set.
   */
  Set<List<String>> answerArrays;

  /** Default place to look in Jar file for classifier. */
  public static final String DEFAULT_CLASSIFIER = "/classifiers/ner-eng-ie.cmm-3-all2006.ser.gz";

  protected CMMClassifier() {
    super(new SeqClassifierFlags());
  }

  public CMMClassifier(Properties props) {
    super(props);
  }

  /**
   * Returns the Set of entities recognized by this Classifier.
   *
   * @return The Set of entities recognized by this Classifier.
   */
  public Set<String> getTags() {
    Set<String> tags = new HashSet<String>(classIndex.objectsList());
    tags.remove(flags.backgroundSymbol);
    return tags;
  }

  /**
   * Classify a {@link List} of {@link CoreLabel}s.
   *
   * @param document A {@link List} of {@link CoreLabel}s
   *                 to be classified.
   */
  @Override
  public List<IN> classify(List<IN> document) {
    if (flags.useSequences) {
      classifySeq(document);
    } else {
      classifyNoSeq(document);
    }
    return document;
  }
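
  /*
   * Illustrative usage sketch (added commentary, not part of the original
   * source): tagging a pre-tokenized sentence with a deserialized classifier,
   * via the static factory and the ListProcessor implementation defined near
   * the bottom of this class.  The model path is hypothetical.
   *
   *   CMMClassifier cmm = CMMClassifier.getClassifier("/path/to/ner-model.ser.gz");
   *   List<WordTag> tagged = cmm.process(Arrays.asList("Smith", "went", "to", "Paris", "."));
   *   // tagged pairs each token with its NER label, e.g. Paris/LOCATION
   */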
  /**
   * Classify a List of {@link CoreLabel}s without using sequence information
   * (i.e. no Viterbi algorithm, just distribution over next class).
   *
   * @param document a List of {@link CoreLabel}s to be classified
   */
  private void classifyNoSeq(List<IN> document) {
    if (flags.useReverse) {
      Collections.reverse(document);
    }

    if (flags.lowerNewgeneThreshold) {
      // Used to raise recall for task 1B
      System.err.println("Using NEWGENE threshold: " + flags.newgeneThreshold);
      for (int i = 0, docSize = document.size(); i < docSize; i++) {
        CoreLabel wordInfo = document.get(i);
        Datum<String, String> d = makeDatum(document, i, featureFactory);
        Counter<String> scores = classifier.scoresOf(d);
        //String answer = BACKGROUND;
        String answer = flags.backgroundSymbol;
        // HN: The evaluation of scoresOf seems to result in some
        // kind of side effect.  Specifically, the symptom is that
        // if scoresOf is not evaluated at every position, the
        // answers are different
        if ("NEWGENE".equals(wordInfo.get(GazAnnotation.class))) {
          for (String label : scores.keySet()) {
            if ("G".equals(label)) {
              System.err.println(wordInfo.word() + ':' + scores.getCount(label));
              if (scores.getCount(label) > flags.newgeneThreshold) {
                answer = label;
              }
            }
          }
        }
        wordInfo.set(AnswerAnnotation.class, answer);
      }
    } else {
      for (int i = 0, listSize = document.size(); i < listSize; i++) {
        String answer = classOf(document, i);
        CoreLabel wordInfo = document.get(i);
        //System.err.println("XXX answer for " +
        //        wordInfo.word() + " is " + answer);
        wordInfo.set(AnswerAnnotation.class, answer);
      }
      if (flags.justify && (classifier instanceof LinearClassifier)) {
        LinearClassifier lc = (LinearClassifier) classifier;
        for (int i = 0, lsize = document.size(); i < lsize; i++) {
          CoreLabel lineInfo = document.get(i);
          System.err.print("@@ Position " + i + ": ");
          System.err.println(lineInfo.word() + " chose " + lineInfo.get(AnswerAnnotation.class));
          lc.justificationOf(makeDatum(document, i, featureFactory));
        }
      }
    }
    if (flags.useReverse) {
      Collections.reverse(document);
    }
  }

  /**
   * Returns the most likely class for the word at the given position.
   */
  protected String classOf(List<IN> lineInfos, int pos) {
    Datum<String, String> d = makeDatum(lineInfos, pos, featureFactory);
    return classifier.classOf(d);
  }

  /**
   * Returns the log conditional likelihood of the given dataset.
   *
   * @return The log conditional likelihood of the given dataset.
   */
  public double loglikelihood(List<IN> lineInfos) {
    double cll = 0.0;

    for (int i = 0; i < lineInfos.size(); i++) {
      Datum<String, String> d = makeDatum(lineInfos, i, featureFactory);
      Counter<String> c = classifier.logProbabilityOf(d);

      double total = Double.NEGATIVE_INFINITY;
      for (String s : c.keySet()) {
        total = SloppyMath.logAdd(total, c.getCount(s));
      }
      cll -= c.getCount(d.label()) - total;
    }
    // quadratic prior
    // HN: TODO: add other priors
    if (classifier instanceof LinearClassifier) {
      double sigmaSq = flags.sigma * flags.sigma;
      LinearClassifier<String, String> lc = (LinearClassifier<String, String>) classifier;
      for (String feature : lc.features()) {
        for (String classLabel : classIndex) {
          double w = lc.weight(feature, classLabel);
          cll += w * w / 2.0 / sigmaSq;
        }
      }
    }
    return cll;
  }
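
  /*
   * Added note: the quantity accumulated above is the negative log conditional
   * likelihood plus the quadratic (Gaussian) prior penalty, i.e.
   *
   *   cll = - sum_i log P(y_i | x_i)  +  sum_{f,c} w_{f,c}^2 / (2 * sigma^2)
   *
   * where log P(y_i | x_i) is obtained by log-adding the per-class scores at
   * position i and subtracting that total from the score of the gold label.
   */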
  @Override
  public SequenceModel getSequenceModel(List<IN> document) {
    //System.err.println(flags.useReverse);
    if (flags.useReverse) {
      Collections.reverse(document);
    }

    // cdm Aug 2005: why is this next line needed?  Seems really ugly!!!  [2006: it broke things! removed]
    // document.add(0, new CoreLabel());

    SequenceModel ts = new Scorer(document, classIndex, this,
        (!flags.useTaggySequences ? (flags.usePrevSequences ? 1 : 0) : flags.maxLeft),
        (flags.useNextSequences ? 1 : 0),
        answerArrays);

    return ts;
  }

  /**
   * Classify a List of {@link CoreLabel}s using sequence information
   * (i.e. Viterbi or Beam Search).
   *
   * @param document A List of {@link CoreLabel}s to be classified
   */
  private void classifySeq(List<IN> document) {
    if (document.isEmpty()) {
      return;
    }

    SequenceModel ts = getSequenceModel(document);
    // TagScorer ts = new PrevOnlyScorer(document, tagIndex, this, (!flags.useTaggySequences ? (flags.usePrevSequences ? 1 : 0) : flags.maxLeft), 0, answerArrays);

    int[] tags;
    //System.err.println("***begin test***");
    if (flags.useViterbi) {
      ExactBestSequenceFinder ti = new ExactBestSequenceFinder();
      tags = ti.bestSequence(ts);
    } else {
      BeamBestSequenceFinder ti = new BeamBestSequenceFinder(flags.beamSize, true, true);
      tags = ti.bestSequence(ts, document.size());
    }
    //System.err.println("***end test***");

    // used to improve recall in task 1b
    if (flags.lowerNewgeneThreshold) {
      System.err.println("Using NEWGENE threshold: " + flags.newgeneThreshold);

      int[] copy = new int[tags.length];
      System.arraycopy(tags, 0, copy, 0, tags.length);

      // for each sequence marked as NEWGENE in the gazette
      // tag the entire sequence as NEWGENE and sum the score
      // if the score is greater than newgeneThreshold, accept
      int ngTag = classIndex.indexOf("G");
      //int bgTag = classIndex.indexOf(BACKGROUND);
      int bgTag = classIndex.indexOf(flags.backgroundSymbol);

      for (int i = 0, dSize = document.size(); i < dSize; i++) {
        CoreLabel wordInfo = document.get(i);

        if ("NEWGENE".equals(wordInfo.get(GazAnnotation.class))) {
          int start = i;
          int j;
          for (j = i; j < document.size(); j++) {
            wordInfo = document.get(j);
            if (!"NEWGENE".equals(wordInfo.get(GazAnnotation.class))) {
              break;
            }
          }
          int end = j;
          //int end = i + 1;

          int winStart = Math.max(0, start - 4);
          int winEnd = Math.min(tags.length, end + 4);
          // clear a window around the sequences
          for (j = winStart; j < winEnd; j++) {
            copy[j] = bgTag;
          }

          // score as nongene
          double bgScore = 0.0;
          for (j = start; j < end; j++) {
            double[] scores = ts.scoresOf(copy, j);
            scores = Scorer.recenter(scores);
            bgScore += scores[bgTag];
          }

          // first pass, compute all of the scores
          ClassicCounter<Pair<Integer, Integer>> prevScores = new ClassicCounter<Pair<Integer, Integer>>();
          for (j = start; j < end; j++) {
            // clear the sequence
            for (int k = start; k < end; k++) {
              copy[k] = bgTag;
            }
            // grow the sequence from j until the end
            for (int k = j; k < end; k++) {
              copy[k] = ngTag;
              // score the sequence
              double ngScore = 0.0;
              for (int m = start; m < end; m++) {
                double[] scores = ts.scoresOf(copy, m);
                scores = Scorer.recenter(scores);
                ngScore += scores[tags[m]];
              }
              prevScores.incrementCount(new Pair<Integer, Integer>(Integer.valueOf(j), Integer.valueOf(k)), ngScore - bgScore);
            }
          }
          for (j = start; j < end; j++) {
            // grow the sequence from j until the end
            for (int k = j; k < end; k++) {
              double score = prevScores.getCount(new Pair<Integer, Integer>(Integer.valueOf(j), Integer.valueOf(k)));
              // adding a word to the left
              Pair<Integer, Integer> al = new Pair<Integer, Integer>(Integer.valueOf(j - 1), Integer.valueOf(k));
              // adding a word to the right
              Pair<Integer, Integer> ar = new Pair<Integer, Integer>(Integer.valueOf(j), Integer.valueOf(k + 1));
              // subtracting word from left
              Pair<Integer, Integer> sl = new Pair<Integer, Integer>(Integer.valueOf(j + 1), Integer.valueOf(k));
              // subtracting word from right
              Pair<Integer, Integer> sr = new Pair<Integer, Integer>(Integer.valueOf(j), Integer.valueOf(k - 1));
              // make sure the score is greater than all its neighbors (one add or subtract)
              if (score >= flags.newgeneThreshold &&
                  (!prevScores.containsKey(al) || score > prevScores.getCount(al)) &&
                  (!prevScores.containsKey(ar) || score > prevScores.getCount(ar)) &&
                  (!prevScores.containsKey(sl) || score > prevScores.getCount(sl)) &&
                  (!prevScores.containsKey(sr) || score > prevScores.getCount(sr))) {
                StringBuilder sb = new StringBuilder();
                wordInfo = document.get(j);
                String docId = wordInfo.get(IDAnnotation.class);
                String startIndex = wordInfo.get(PositionAnnotation.class);
                wordInfo = document.get(k);
                String endIndex = wordInfo.get(PositionAnnotation.class);
                for (int m = j; m <= k; m++) {
                  wordInfo = document.get(m);
                  sb.append(wordInfo.word());
                  sb.append(' ');
                }
                /*System.err.println(sb.toString()+"score:"+score+
                  " al:"+prevScores.getCount(al)+
                  " ar:"+prevScores.getCount(ar)+
                  " sl:"+prevScores.getCount(sl)+" sr:"+ prevScores.getCount(sr));*/
                System.out.println(docId + '|' + startIndex + ' ' + endIndex + '|' + sb.toString().trim());
              }
            }
          }
          // restore the original tags
          for (j = winStart; j < winEnd; j++) {
            copy[j] = tags[j];
          }
          i = end;
        }
      }
    }

    for (int i = 0, docSize = document.size(); i < docSize; i++) {
      CoreLabel lineInfo = document.get(i);
      String answer = classIndex.get(tags[i]);
      lineInfo.set(AnswerAnnotation.class, answer);
    }

    if (flags.justify && classifier instanceof LinearClassifier) {
      LinearClassifier lc = (LinearClassifier) classifier;
      if (flags.dump) {
        lc.dump();
      }
      for (int i = 0, docSize = document.size(); i < docSize; i++) {
        CoreLabel lineInfo = document.get(i);
        System.err.print("@@ Position is: " + i + ": ");
        System.err.println(lineInfo.word() + ' ' + lineInfo.get(AnswerAnnotation.class));
        lc.justificationOf(makeDatum(document, i, featureFactory));
      }
    }
    // document.remove(0);
    if (flags.useReverse) {
      Collections.reverse(document);
    }
  } // end classifySeq
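
  /*
   * Added note: classifySeq() above decodes with exact Viterbi search when
   * flags.useViterbi is set, and otherwise with a beam search of width
   * flags.beamSize.  In a properties file these would be set with lines like
   * the following (the property spellings are assumed to mirror the
   * SeqClassifierFlags field names used above):
   *
   *   useViterbi = true
   *   # or: useViterbi = false and beamSize = 30 for beam search
   */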
  /**
   * @param filename     adaptation file
   * @param trainDataset original dataset (used in training)
   */
  public void adapt(String filename, Dataset trainDataset, DocumentReaderAndWriter readerWriter) {
    flags.ocrTrain = false;  // ?? Do we need this? (Pi-Chuan Sat Nov  5 15:42:49 2005)
    ObjectBank<List<IN>> docs = makeObjectBankFromFile(filename, readerWriter);
    adapt(docs, trainDataset);
  }

  /**
   * @param featureLabels adaptation docs
   * @param trainDataset  original dataset (used in training)
   */
  public void adapt(ObjectBank<List<IN>> featureLabels, Dataset trainDataset) {
    Dataset adapt = getDataset(featureLabels, trainDataset);
    adapt(adapt);
  }

  /**
   * @param featureLabels retrain docs
   * @param featureIndex  featureIndex of original dataset (used in training)
   * @param labelIndex    labelIndex of original dataset (used in training)
   */
  public void retrain(ObjectBank<List<IN>> featureLabels, Index<String> featureIndex, Index<String> labelIndex) {
    int fs = featureIndex.size(); // old dim
    int ls = labelIndex.size();   // old dim

    Dataset adapt = getDataset(featureLabels, featureIndex, labelIndex);

    int prior = LogPrior.LogPriorType.QUADRATIC.ordinal();
    LinearClassifier lc = (LinearClassifier) classifier;
    LinearClassifierFactory lcf = new LinearClassifierFactory(flags.tolerance, flags.useSum, prior, flags.sigma, flags.epsilon, flags.QNsize);

    double[][] weights = lc.weights(); // old dim
    Index<String> newF = adapt.featureIndex;
    Index<String> newL = adapt.labelIndex;
    int newFS = newF.size();
    int newLS = newL.size();
    // copy the old weights into a flat vector laid out according to the new indices
    double[] x = new double[newFS * newLS]; // new dim
    //System.err.println("old ["+fs+"]"+"["+ls+"]");
    //System.err.println("new ["+newFS+"]"+"["+newLS+"]");
    //System.err.println("new ["+newFS*newLS+"]");
    for (int i = 0; i < fs; i++) {
      for (int j = 0; j < ls; j++) {
        String f = featureIndex.get(i);
        String l = labelIndex.get(j);
        int newi = newF.indexOf(f) * newLS + newL.indexOf(l);
        x[newi] = weights[i][j];
        //if (newi == 144745*2) {
        //  System.err.println("What??"+i+"\t"+j);
        //}
      }
    }
    //System.err.println("x[144745*2]"+x[144745*2]);
    weights = lcf.trainWeights(adapt, x);
    //System.err.println("x[144745*2]"+x[144745*2]);
    //System.err.println("weights[144745]"+"[0]="+weights[144745][0]);
    lc.setWeights(weights);
    /*
    int delme = 0;
    if (true) {
      for (double[] dd : weights) {
        delme++;
        for (double d : dd) {
        }
      }
    }
    System.err.println(weights[delme-1][0]);
    System.err.println("size of weights: "+delme);
    */
  }
before you train!"); System.exit(-1); } Index findex = ((LinearClassifier)classifier).featureIndex(); Index lindex = ((LinearClassifier)classifier).labelIndex(); System.err.println("Starting retrain:\t# of original features"+findex.size()+", # of original labels"+lindex.size()); retrain(doc, findex, lindex); } @Override public void train(Collection> wordInfos, DocumentReaderAndWriter readerAndWriter) { Dataset train = getDataset(wordInfos); //train.summaryStatistics(); //train.printSVMLightFormat(); // wordInfos = null; // cdm: I think this does no good as ptr exists in caller (could empty the list or better refactor so conversion done earlier?) train(train); for (int i = 0; i < flags.numTimesPruneFeatures; i++) { Index featuresAboveThreshhold = getFeaturesAboveThreshhold(train, flags.featureDiffThresh); System.err.println("Removing features with weight below " + flags.featureDiffThresh + " and retraining..."); train = getDataset(train, featuresAboveThreshhold); int tmp = flags.QNsize; flags.QNsize = flags.QNsize2; train(train); flags.QNsize = tmp; } if (flags.doAdaptation && flags.adaptFile != null) { adapt(flags.adaptFile,train,readerAndWriter); } System.err.print("Built this classifier: "); if (classifier instanceof LinearClassifier) { String classString = ((LinearClassifier)classifier).toString(flags.printClassifier, flags.printClassifierParam); System.err.println(classString); } else { String classString = classifier.toString(); System.err.println(classString); } } public Index getFeaturesAboveThreshhold(Dataset dataset, double thresh) { if (!(classifier instanceof LinearClassifier)) { throw new RuntimeException("Attempting to remove features based on weight from a non-linear classifier"); } Index featureIndex = dataset.featureIndex; Index labelIndex = dataset.labelIndex; Index features = new HashIndex(); Iterator featureIt = featureIndex.iterator(); LinearClassifier lc = (LinearClassifier)classifier; LOOP: while (featureIt.hasNext()) { String f = featureIt.next(); Iterator labelIt = labelIndex.iterator(); double smallest = Double.POSITIVE_INFINITY; double biggest = Double.NEGATIVE_INFINITY; while (labelIt.hasNext()) { String l = labelIt.next(); double weight = lc.weight(f, l); if (weight < smallest) { smallest = weight; } if (weight > biggest) { biggest = weight; } if (biggest - smallest > thresh) { features.add(f); continue LOOP; } } } return features; } /** * Build a Dataset from some data. Used for training a classifier. * * @param data This variable is a list of lists of CoreLabel. That is, * it is a collection of documents, each of which is represented * as a sequence of CoreLabel objects. * @return The Dataset which is an efficient encoding of the information * in a List of Datums */ public Dataset getDataset(Collection> data) { return getDataset(data, null, null); } /** * Build a Dataset from some data. Used for training a classifier. * * By passing in extra featureIndex and classIndex, you can get a Dataset based on featureIndex and * classIndex * * @param data This variable is a list of lists of CoreLabel. That is, * it is a collection of documents, each of which is represented * as a sequence of CoreLabel objects. 
  /**
   * Build a Dataset from some data. Used for training a classifier.
   *
   * @param data This variable is a list of lists of CoreLabel.  That is,
   *             it is a collection of documents, each of which is represented
   *             as a sequence of CoreLabel objects.
   * @return The Dataset which is an efficient encoding of the information
   *         in a List of Datums
   */
  public Dataset getDataset(Collection<List<IN>> data) {
    return getDataset(data, null, null);
  }

  /**
   * Build a Dataset from some data. Used for training a classifier.
   *
   * By passing in extra featureIndex and classIndex, you can get a Dataset
   * based on that featureIndex and classIndex.
   *
   * @param data This variable is a list of lists of CoreLabel.  That is,
   *             it is a collection of documents, each of which is represented
   *             as a sequence of CoreLabel objects.
   * @param featureIndex if you want to get a Dataset based on the featureIndex
   *                     and classIndex in an existing origDataset
   * @param classIndex   if you want to get a Dataset based on the featureIndex
   *                     and classIndex in an existing origDataset
   * @return The Dataset which is an efficient encoding of the information
   *         in a List of Datums
   */
  public Dataset getDataset(Collection<List<IN>> data, Index<String> featureIndex, Index<String> classIndex) {
    makeAnswerArraysAndTagIndex(data);

    int size = 0;
    for (List<IN> doc : data) {
      size += doc.size();
    }

    System.err.println("Making Dataset...");
    Dataset train;
    if (featureIndex != null && classIndex != null) {
      System.err.println("Using feature/class Index from existing Dataset...");
      System.err.println("(This is used when getting Dataset from adaptation set. We want to make the index consistent.)"); //pichuan
      train = new Dataset(size, featureIndex, classIndex);
    } else {
      train = new Dataset(size);
    }

    for (List<IN> doc : data) {
      if (flags.useReverse) {
        Collections.reverse(doc);
      }

      for (int i = 0, dsize = doc.size(); i < dsize; i++) {
        Datum<String, String> d = makeDatum(doc, i, featureFactory);
        //CoreLabel fl = doc.get(i);
        train.add(d);
      }

      if (flags.useReverse) {
        Collections.reverse(doc);
      }
    }

    System.err.println("done.");

    // reset printing before test data
    // what is this???? -JRF
    // if (featureFactory instanceof FeatureFactory) {
    //   ((FeatureFactory) featureFactory).resetPrintFeatures();
    // }

    if (flags.featThreshFile != null) {
      System.err.println("applying thresholds...");
      List<Pair<Pattern, Integer>> thresh = getThresholds(flags.featThreshFile);
      train.applyFeatureCountThreshold(thresh);
    } else if (flags.featureThreshold > 1) {
      System.err.println("Removing Features with counts < " + flags.featureThreshold);
      train.applyFeatureCountThreshold(flags.featureThreshold);
    }
    train.summaryStatistics();
    return train;
  }

  public Dataset getBiasedDataset(ObjectBank<List<IN>> data, Index<String> featureIndex, Index<String> classIndex) {
    makeAnswerArraysAndTagIndex(data);

    Index<String> origFeatIndex = new HashIndex<String>(featureIndex.objectsList()); // mg2009: TODO: check

    int size = 0;
    for (List<IN> doc : data) {
      size += doc.size();
    }

    System.err.println("Making Dataset...");
    Dataset train = new Dataset(size, featureIndex, classIndex);

    for (List<IN> doc : data) {
      if (flags.useReverse) {
        Collections.reverse(doc);
      }

      for (int i = 0, dsize = doc.size(); i < dsize; i++) {
        Datum<String, String> d = makeDatum(doc, i, featureFactory);
        Collection<String> newFeats = new ArrayList<String>();
        for (String f : d.asFeatures()) {
          if ( ! origFeatIndex.contains(f)) {
            newFeats.add(f);
          }
        }
        // System.err.println(d.label()+"\t"+d.asFeatures()+"\n\t"+newFeats);
        // d = new BasicDatum(newFeats, d.label());
        train.add(d);
      }

      if (flags.useReverse) {
        Collections.reverse(doc);
      }
    }

    System.err.println("done.");

    // reset printing before test data
    // what is this???? -JRF
    // if (featureFactory instanceof FeatureFactory) {
    //   ((FeatureFactory) featureFactory).resetPrintFeatures();
    // }

    if (flags.featThreshFile != null) {
      System.err.println("applying thresholds...");
      List<Pair<Pattern, Integer>> thresh = getThresholds(flags.featThreshFile);
      train.applyFeatureCountThreshold(thresh);
    } else if (flags.featureThreshold > 1) {
      System.err.println("Removing Features with counts < " + flags.featureThreshold);
      train.applyFeatureCountThreshold(flags.featureThreshold);
    }
    train.summaryStatistics();
    return train;
  }
  /**
   * Build a Dataset from some data. Used for training a classifier.
   *
   * By passing in an extra origDataset, you can get a Dataset based on the
   * featureIndex and classIndex in an existing origDataset.
   *
   * @param data This variable is a list of lists of CoreLabel.  That is,
   *             it is a collection of documents, each of which is represented
   *             as a sequence of CoreLabel objects.
   * @param origDataset if you want to get a Dataset based on the featureIndex
   *                    and classIndex in an existing origDataset
   * @return The Dataset which is an efficient encoding of the information
   *         in a List of Datums
   */
  public Dataset getDataset(ObjectBank<List<IN>> data, Dataset origDataset) {
    if (origDataset == null) {
      return getDataset(data);
    }
    return getDataset(data, origDataset.featureIndex, origDataset.labelIndex);
  }

  /**
   * Build a Dataset from some data.
   *
   * @param oldData This {@link Dataset} represents data for which we wish to
   *                remove some features, specifically those features not in the
   *                {@link edu.stanford.nlp.util.Index} goodFeatures.
   * @param goodFeatures An {@link edu.stanford.nlp.util.Index} of features we wish to retain.
   * @return A new {@link Dataset} where each datapoint contains only features
   *         which were in goodFeatures.
   */
  public Dataset getDataset(Dataset oldData, Index<String> goodFeatures) {
    //public Dataset getDataset(List data, Collection goodFeatures) {
    //makeAnswerArraysAndTagIndex(data);

    int[][] oldDataArray = oldData.getDataArray();
    int[] oldLabelArray = oldData.getLabelsArray();
    Index<String> oldFeatureIndex = oldData.featureIndex;

    int[] oldToNewFeatureMap = new int[oldFeatureIndex.size()];

    int[][] newDataArray = new int[oldDataArray.length][];

    System.err.print("Building reduced dataset...");

    int size = oldFeatureIndex.size();
    int max = 0;
    for (int i = 0; i < size; i++) {
      oldToNewFeatureMap[i] = goodFeatures.indexOf(oldFeatureIndex.get(i));
      if (oldToNewFeatureMap[i] > max) {
        max = oldToNewFeatureMap[i];
      }
    }

    for (int i = 0; i < oldDataArray.length; i++) {
      int[] data = oldDataArray[i];
      size = 0;
      for (int j = 0; j < data.length; j++) {
        if (oldToNewFeatureMap[data[j]] > 0) {
          size++;
        }
      }
      int[] newData = new int[size];
      int index = 0;
      for (int j = 0; j < data.length; j++) {
        int f = oldToNewFeatureMap[data[j]];
        if (f > 0) {
          newData[index++] = f;
        }
      }
      newDataArray[i] = newData;
    }

    Dataset train = new Dataset(oldData.labelIndex, oldLabelArray, goodFeatures, newDataArray, newDataArray.length);

    System.err.println("done.");

    if (flags.featThreshFile != null) {
      System.err.println("applying thresholds...");
      List<Pair<Pattern, Integer>> thresh = getThresholds(flags.featThreshFile);
      train.applyFeatureCountThreshold(thresh);
    } else if (flags.featureThreshold > 1) {
      System.err.println("Removing Features with counts < " + flags.featureThreshold);
      train.applyFeatureCountThreshold(flags.featureThreshold);
    }
    train.summaryStatistics();
    return train;
  }

  private void adapt(Dataset adapt) {
    if (flags.classifierType.equalsIgnoreCase("SVM")) {
      throw new UnsupportedOperationException();
    }
    adaptMaxEnt(adapt);
  }

  private void adaptMaxEnt(Dataset adapt) {
    if (classifier instanceof LinearClassifier) {
      // So far the adaptation is only done on Gaussian Prior. Haven't checked how it'll work on other kinds of priors. -pichuan
      int prior = LogPrior.LogPriorType.QUADRATIC.ordinal();
      if (flags.useHuber) {
        throw new UnsupportedOperationException();
      } else if (flags.useQuartic) {
        throw new UnsupportedOperationException();
      }
      LinearClassifierFactory lcf = new LinearClassifierFactory(flags.tolerance, flags.useSum, prior, flags.adaptSigma, flags.epsilon, flags.QNsize);
      ((LinearClassifier) classifier).adaptWeights(adapt, lcf);
    } else {
      throw new UnsupportedOperationException();
    }
  }
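
  /*
   * Added note: adaptation is triggered from train(Collection, ...) above when
   * flags.doAdaptation is true and flags.adaptFile is set; adaptMaxEnt() then
   * retrains under a quadratic prior whose sigma comes from flags.adaptSigma
   * rather than flags.sigma.  As a properties sketch (spellings assumed to
   * mirror the flag fields used here):
   *
   *   doAdaptation = true
   *   adaptFile = adapt.tsv
   *   adaptSigma = 1.0
   */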
  private void train(Dataset train) {
    if (flags.classifierType.equalsIgnoreCase("SVM")) {
      trainSVM(train);
    } else {
      trainMaxEnt(train);
    }
  }

  private void trainSVM(Dataset train) {
    SVMLightClassifierFactory fact = new SVMLightClassifierFactory();
    classifier = fact.trainClassifier(train);
  }

  private void trainMaxEnt(Dataset train) {
    int prior = LogPrior.LogPriorType.QUADRATIC.ordinal();
    if (flags.useHuber) {
      prior = LogPrior.LogPriorType.HUBER.ordinal();
    } else if (flags.useQuartic) {
      prior = LogPrior.LogPriorType.QUARTIC.ordinal();
    }

    LinearClassifier lc;
    if (flags.useNB) {
      lc = new NBLinearClassifierFactory(flags.sigma).trainClassifier(train);
    } else {
      LinearClassifierFactory lcf = new LinearClassifierFactory(flags.tolerance, flags.useSum, prior, flags.sigma, flags.epsilon, flags.QNsize);
      if (flags.useQN) {
        lcf.useQuasiNewton(flags.useRobustQN);
      } else if (flags.useStochasticQN) {
        lcf.useStochasticQN(flags.initialGain, flags.stochasticBatchSize);
      } else if (flags.useSMD) {
        lcf.useStochasticMetaDescent(flags.initialGain, flags.stochasticBatchSize, flags.stochasticMethod, flags.SGDPasses);
      } else if (flags.useSGD) {
        lcf.useStochasticGradientDescent(flags.gainSGD, flags.stochasticBatchSize);
      } else if (flags.useSGDtoQN) {
        lcf.useStochasticGradientDescentToQuasiNewton(flags.initialGain, flags.stochasticBatchSize,
            flags.SGDPasses, flags.QNPasses, flags.SGD2QNhessSamples,
            flags.QNsize, flags.outputIterationsToFile);
      } else if (flags.useHybrid) {
        lcf.useHybridMinimizer(flags.initialGain, flags.stochasticBatchSize, flags.stochasticMethod, flags.hybridCutoffIteration);
      } else {
        lcf.useConjugateGradientAscent();
      }
      lc = lcf.trainClassifier(train);
    }
    this.classifier = lc;
  }

  private void trainSemiSup(Dataset data, Dataset biasedData, double[][] confusionMatrix) {
    int prior = LogPrior.LogPriorType.QUADRATIC.ordinal();
    if (flags.useHuber) {
      prior = LogPrior.LogPriorType.HUBER.ordinal();
    } else if (flags.useQuartic) {
      prior = LogPrior.LogPriorType.QUARTIC.ordinal();
    }

    LinearClassifierFactory lcf;
    lcf = new LinearClassifierFactory(flags.tolerance, flags.useSum, prior, flags.sigma, flags.epsilon, flags.QNsize);

    if (flags.useQN) {
      lcf.useQuasiNewton();
    } else {
      lcf.useConjugateGradientAscent();
    }

    this.classifier = (LinearClassifier) lcf.trainClassifierSemiSup(data, biasedData, confusionMatrix, null);
  }
  //  public void crossValidateTrainAndTest() throws Exception {
  //    crossValidateTrainAndTest(flags.trainFile);
  //  }
  //
  //  public void crossValidateTrainAndTest(String filename) throws Exception {
  //    // wordshapes
  //    for (int fold = flags.startFold; fold <= flags.endFold; fold++) {
  //      System.err.println("fold " + fold + " of " + flags.endFold);
  //      // train
  //      List data = makeObjectBank(filename);
  //      List folds = split(data, flags.numFolds);
  //      data = null;
  //      List train = new ArrayList();
  //      for (int i = 0; i < flags.numFolds; i++) {
  //        List docs = (List) folds.get(i);
  //        if (i != fold) {
  //          train.addAll(docs);
  //        }
  //      }
  //      folds = null;
  //      train(train);
  //      train = null;
  //      List test = new ArrayList();
  //      data = makeObjectBank(filename);
  //      folds = split(data, flags.numFolds);
  //      data = null;
  //      for (int i = 0; i < flags.numFolds; i++) {
  //        List docs = (List) folds.get(i);
  //        if (i == fold) {
  //          test.addAll(docs);
  //        }
  //      }
  //      folds = null;
  //      // test
  //      test(test);
  //      writeAnswers(test);
  //    }
  //  }

  //  /**
  //   * Splits the given train corpus into a train and a test corpus based on the fold number.
  //   * 1 / numFolds documents are held out for test, with the offset determined by the fold number.
  //   *
  //   * @param data The original data
  //   * @param numFolds The number of folds to split the data into
  //   * @return A list of folds giving the new training set
  //   */
  //  private List split(List data, int numFolds) {
  //    List folds = new ArrayList();
  //    int foldSize = data.size() / numFolds;
  //    int r = data.size() - (numFolds * foldSize);
  //    int index = 0;
  //    for (int i = 0; i < numFolds; i++) {
  //      List fold = new ArrayList();
  //      int end = (i < r ? foldSize + 1 : foldSize);
  //      for (int j = 0; j < end; j++) {
  //        fold.add(data.get(index++));
  //      }
  //      folds.add(fold);
  //    }
  //    return folds;
  //  }

  @Override
  public void serializeClassifier(String serializePath) {
    System.err.print("Serializing classifier to " + serializePath + "...");

    try {
      ObjectOutputStream oos = IOUtils.writeStreamFromString(serializePath);

      oos.writeObject(classifier);
      oos.writeObject(flags);
      oos.writeObject(featureFactory);
      oos.writeObject(classIndex);
      oos.writeObject(answerArrays);
      //oos.writeObject(WordShapeClassifier.getKnownLowerCaseWords());
      oos.writeObject(knownLCWords);

      oos.close();
      System.err.println("Done.");
    } catch (Exception e) {
      System.err.println("Error serializing to " + serializePath);
      e.printStackTrace();
      // don't actually exit in case they're testing too
      //System.exit(1);
    }
  }

  /**
   * Used to load the default supplied classifier.  **THIS FUNCTION
   * WILL ONLY WORK IF RUN INSIDE A JAR FILE**
   */
  public void loadDefaultClassifier() {
    loadJarClassifier(DEFAULT_CLASSIFIER, null);
  }

  /**
   * Used to obtain the default classifier which is
   * stored inside a jar file.  THIS FUNCTION
   * WILL ONLY WORK IF RUN INSIDE A JAR FILE.
   *
   * @return A Default CMMClassifier from a jar file
   */
  public static CMMClassifier getDefaultClassifier() {
    CMMClassifier cmm = new CMMClassifier();
    cmm.loadDefaultClassifier();
    return cmm;
  }
  /** Load a classifier from the given Stream.
   *  Implementation note: This method does not close the
   *  Stream that it reads from.
   *
   *  @param ois The ObjectInputStream to load the serialized classifier from
   *
   *  @throws IOException If there are problems accessing the input stream
   *  @throws ClassCastException If there are problems interpreting the serialized data
   *  @throws ClassNotFoundException If there are problems interpreting the serialized data
   */
  @SuppressWarnings("unchecked")
  @Override
  public void loadClassifier(ObjectInputStream ois, Properties props) throws ClassCastException, IOException, ClassNotFoundException {
    // NB: fields are read in the same order in which serializeClassifier() wrote them
    classifier = (LinearClassifier) ois.readObject();
    flags = (SeqClassifierFlags) ois.readObject();
    featureFactory = (FeatureFactory) ois.readObject();

    if (props != null) {
      flags.setProperties(props);
    }
    reinit();

    classIndex = (Index) ois.readObject();
    answerArrays = (Set<List<String>>) ois.readObject();
    knownLCWords = (Set) ois.readObject();
  }

  public static CMMClassifier getClassifierNoExceptions(File file) {
    CMMClassifier cmm = new CMMClassifier();
    cmm.loadClassifierNoExceptions(file);
    return cmm;
  }

  public static CMMClassifier getClassifier(File file) throws IOException, ClassCastException, ClassNotFoundException {
    CMMClassifier cmm = new CMMClassifier();
    cmm.loadClassifier(file);
    return cmm;
  }

  public static CMMClassifier getClassifierNoExceptions(String loadPath) {
    CMMClassifier cmm = new CMMClassifier();
    cmm.loadClassifierNoExceptions(loadPath);
    return cmm;
  }

  public static CMMClassifier getClassifier(String loadPath) throws IOException, ClassCastException, ClassNotFoundException {
    CMMClassifier cmm = new CMMClassifier();
    cmm.loadClassifier(loadPath);
    return cmm;
  }

  public static CMMClassifier getClassifierNoExceptions(InputStream in) {
    CMMClassifier cmm = new CMMClassifier();
    cmm.loadClassifierNoExceptions(new BufferedInputStream(in), null);
    return cmm;
  }

  public static CMMClassifier getClassifier(InputStream in) throws IOException, ClassCastException, ClassNotFoundException {
    CMMClassifier cmm = new CMMClassifier();
    cmm.loadClassifier(new BufferedInputStream(in));
    return cmm;
  }

  /** This routine builds the answerArrays which give the
   *  empirically legal label sequences (of length (order) at most
   *  flags.maxLeft) and the classIndex,
   *  which indexes known answer classes.
   *
   *  @param docs The training data: A List of List of CoreLabel
   */
  private void makeAnswerArraysAndTagIndex(Collection<List<IN>> docs) {
    if (answerArrays == null) {
      answerArrays = new HashSet<List<String>>();
    }
    if (classIndex == null) {
      classIndex = new HashIndex<String>();
    }

    for (List<IN> doc : docs) {
      if (flags.useReverse) {
        Collections.reverse(doc);
      }

      int leng = doc.size();
      for (int start = 0; start < leng; start++) {
        for (int diff = 1; diff <= flags.maxLeft && start + diff <= leng; diff++) {
          String[] seq = new String[diff];
          for (int i = start; i < start + diff; i++) {
            seq[i - start] = doc.get(i).get(AnswerAnnotation.class);
          }
          answerArrays.add(Arrays.asList(seq));
        }
      }
      for (int i = 0; i < leng; i++) {
        CoreLabel wordInfo = doc.get(i);
        classIndex.add(wordInfo.get(AnswerAnnotation.class));
      }

      if (flags.useReverse) {
        Collections.reverse(doc);
      }
    }
  }
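
  /*
   * Illustrative loading sketch (added commentary, not part of the original
   * source): the static factories above accept a path, a File, or an
   * InputStream.  For a model bundled on the classpath one might write
   * (resource and path names hypothetical):
   *
   *   InputStream in = CMMClassifier.class.getResourceAsStream(DEFAULT_CLASSIFIER);
   *   CMMClassifier fromStream = CMMClassifier.getClassifier(in);
   *   CMMClassifier fromPath = CMMClassifier.getClassifierNoExceptions("/tmp/ner-model.ser.gz");
   */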
  /** Make an individual Datum out of the data list info, focused at position
   *  loc.
   *
   *  @param info A List of WordInfo objects
   *  @param loc  The position in the info list to focus feature creation on
   *  @param featureFactory The factory that constructs features out of the item
   *  @return A Datum (BasicDatum) representing this data instance
   */
  public Datum<String, String> makeDatum(List<IN> info, int loc, FeatureFactory featureFactory) {
    PaddedList<IN> pInfo = new PaddedList<IN>(info, pad);

    Collection<String> features = new ArrayList<String>();
    List<Clique> cliques = featureFactory.getCliques();
    for (Clique c : cliques) {
      Collection<String> feats = featureFactory.getCliqueFeatures(pInfo, loc, c);
      feats = addOtherClasses(feats, pInfo, loc, c);
      features.addAll(feats);
    }

    printFeatures(pInfo.get(loc), features);
    CoreLabel c = info.get(loc);
    return new BasicDatum<String, String>(features, c.get(AnswerAnnotation.class));
  }

  /** This adds to the feature name the name of classes that are other than
   *  the current class that are involved in the clique.  In the CMM, these
   *  other classes become part of the conditioning feature, and only the
   *  class of the current position is being predicted.
   *
   *  @return A collection of features with extra class information put
   *          into the feature name.
   */
  private static Collection<String> addOtherClasses(Collection<String> feats, List<? extends CoreLabel> info, int loc, Clique c) {
    String addend = null;
    String pAnswer = info.get(loc - 1).get(AnswerAnnotation.class);
    String p2Answer = info.get(loc - 2).get(AnswerAnnotation.class);
    String p3Answer = info.get(loc - 3).get(AnswerAnnotation.class);
    String p4Answer = info.get(loc - 4).get(AnswerAnnotation.class);
    String p5Answer = info.get(loc - 5).get(AnswerAnnotation.class);
    String nAnswer = info.get(loc + 1).get(AnswerAnnotation.class);
    // cdm 2009: Is this really right? Do we not need to differentiate names that would collide???
    if (c == FeatureFactory.cliqueCpC) {
      addend = '|' + pAnswer;
    } else if (c == FeatureFactory.cliqueCp2C) {
      addend = '|' + p2Answer;
    } else if (c == FeatureFactory.cliqueCp3C) {
      addend = '|' + p3Answer;
    } else if (c == FeatureFactory.cliqueCp4C) {
      addend = '|' + p4Answer;
    } else if (c == FeatureFactory.cliqueCp5C) {
      addend = '|' + p5Answer;
    } else if (c == FeatureFactory.cliqueCpCp2C) {
      addend = '|' + pAnswer + '-' + p2Answer;
    } else if (c == FeatureFactory.cliqueCpCp2Cp3C) {
      addend = '|' + pAnswer + '-' + p2Answer + '-' + p3Answer;
    } else if (c == FeatureFactory.cliqueCpCp2Cp3Cp4C) {
      addend = '|' + pAnswer + '-' + p2Answer + '-' + p3Answer + '-' + p4Answer;
    } else if (c == FeatureFactory.cliqueCpCp2Cp3Cp4Cp5C) {
      addend = '|' + pAnswer + '-' + p2Answer + '-' + p3Answer + '-' + p4Answer + '-' + p5Answer;
    } else if (c == FeatureFactory.cliqueCnC) {
      addend = '|' + nAnswer;
    } else if (c == FeatureFactory.cliqueCpCnC) {
      addend = '|' + pAnswer + '-' + nAnswer;
    }
    if (addend == null) {
      return feats;
    }
    Collection<String> newFeats = new HashSet<String>();
    for (String feat : feats) {
      String newFeat = feat + addend;
      newFeats.add(newFeat);
    }
    return newFeats;
  }

  private static List<Pair<Pattern, Integer>> getThresholds(String filename) {
    try {
      BufferedReader in = new BufferedReader(new FileReader(filename));
      List<Pair<Pattern, Integer>> thresholds = new ArrayList<Pair<Pattern, Integer>>();
      String line;
      while ((line = in.readLine()) != null) {
        int i = line.lastIndexOf(' ');
        Pattern p = Pattern.compile(line.substring(0, i));
        //System.err.println(":"+line.substring(0,i)+":");
        Integer t = Integer.valueOf(line.substring(i + 1));
        Pair<Pattern, Integer> pair = new Pair<Pattern, Integer>(p, t);
        thresholds.add(pair);
      }
      in.close();
      return thresholds;
    } catch (Exception e) {
      throw new RuntimeException("Error reading threshold file", e);
    }
  }
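
  /*
   * Added note: getThresholds() above splits each line of flags.featThreshFile
   * at its last space, compiles the prefix as a regular expression over
   * feature names, and parses the suffix as an integer count threshold.  A
   * threshold file might therefore contain lines like (contents hypothetical):
   *
   *   .*WORD.* 2
   *   .*TAG.* 5
   */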
  public void trainSemiSup() {
    DocumentReaderAndWriter readerAndWriter = makeReaderAndWriter();

    String filename = flags.trainFile;
    String biasedFilename = flags.biasedTrainFile;

    ObjectBank<List<IN>> data = makeObjectBankFromFile(filename, readerAndWriter);
    ObjectBank<List<IN>> biasedData = makeObjectBankFromFile(biasedFilename, readerAndWriter);

    Index<String> featureIndex = new HashIndex<String>();
    Index<String> classIndex = new HashIndex<String>();

    Dataset dataset = getDataset(data, featureIndex, classIndex);
    Dataset biasedDataset = getBiasedDataset(biasedData, featureIndex, classIndex);

    double[][] confusionMatrix = new double[classIndex.size()][classIndex.size()];

    for (int i = 0; i < confusionMatrix.length; i++) {
      Arrays.fill(confusionMatrix[i], 0.0);
      confusionMatrix[i][i] = 1.0;
    }

    String cm = flags.confusionMatrix;
    String[] bits = cm.split(":");
    for (String bit : bits) {
      String[] bits1 = bit.split("\\|");
      int i1 = classIndex.indexOf(bits1[0]);
      int i2 = classIndex.indexOf(bits1[1]);
      double d = Double.parseDouble(bits1[2]);
      confusionMatrix[i2][i1] = d;
    }

    for (int i = 0; i < confusionMatrix.length; i++) {
      ArrayMath.normalize(confusionMatrix[i]);
    }

    // transpose the matrix
    for (int i = 0; i < confusionMatrix.length; i++) {
      for (int j = 0; j < i; j++) {
        double d = confusionMatrix[i][j];
        confusionMatrix[i][j] = confusionMatrix[j][i];
        confusionMatrix[j][i] = d;
      }
    }

    for (int i = 0; i < confusionMatrix.length; i++) {
      for (int j = 0; j < confusionMatrix.length; j++) {
        System.err.println("P(" + classIndex.get(j) + '|' + classIndex.get(i) + ") = " + confusionMatrix[j][i]);
      }
    }

    trainSemiSup(dataset, biasedDataset, confusionMatrix);
  }
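
  /*
   * Added note: trainSemiSup() above parses flags.confusionMatrix as
   * colon-separated triples of the form label1|label2|weight, e.g.
   * (labels hypothetical):
   *
   *   confusionMatrix = O|PERS|0.2:PERS|PERS|0.8
   *
   * Each weight d is stored at confusionMatrix[i2][i1], rows are normalized
   * with ArrayMath.normalize, and the matrix is transposed before being
   * handed to trainSemiSup(Dataset, Dataset, double[][]).
   */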
  static class Scorer implements SequenceModel {

    private CMMClassifier classifier = null;

    private int[] tagArray = null;
    private int[] backgroundTags = null;
    private Index<String> tagIndex = null;
    private List<? extends CoreLabel> lineInfos = null;
    private int pre = 0;
    private int post = 0;
    private Set<List<String>> legalTags = null;

    private static final boolean VERBOSE = false;

    void buildTagArray() {
      int sz = tagIndex.size();
      tagArray = new int[sz];
      for (int i = 0; i < sz; i++) {
        tagArray[i] = i;
      }
    }

    public int length() {
      return lineInfos.size() - pre - post;
    }

    public int leftWindow() {
      return pre;
    }

    public int rightWindow() {
      return post;
    }

    public int[] getPossibleValues(int position) {
      //       if (position == 0 || position == lineInfos.size() - 1) {
      //         int[] a = new int[1];
      //         a[0] = tagIndex.indexOf(BACKGROUND);
      //         return a;
      //       }
      if (tagArray == null) {
        buildTagArray();
      }
      if (position < pre) {
        return backgroundTags;
      }
      return tagArray;
    }

    public double scoreOf(int[] sequence) {
      throw new UnsupportedOperationException();
    }

    private double[] scoreCache = null;
    private int[] lastWindow = null;
    //private int lastPos = -1;

    public double scoreOf(int[] tags, int pos) {
      if (false) {
        return scoresOf(tags, pos)[tags[pos]];
      }
      if (lastWindow == null) {
        lastWindow = new int[leftWindow() + rightWindow() + 1];
        Arrays.fill(lastWindow, -1);
      }
      boolean match = (pos == lastPos);
      for (int i = pos - leftWindow(); i <= pos + rightWindow(); i++) {
        if (i == pos || i < 0) {
          continue;
        }
        /*System.err.println("p:"+pos);
        System.err.println("lw:"+leftWindow());
        System.err.println("i:"+i);*/
        match &= tags[i] == lastWindow[i - pos + leftWindow()];
      }
      if (!match) {
        scoreCache = scoresOf(tags, pos);
        for (int i = pos - leftWindow(); i <= pos + rightWindow(); i++) {
          if (i < 0) {
            continue;
          }
          lastWindow[i - pos + leftWindow()] = tags[i];
        }
        lastPos = pos;
      }
      return scoreCache[tags[pos]];
    }

    private int percent = -1;
    private int num = 0;
    private long secs = System.currentTimeMillis();
    private long hit = 0;
    private long tot = 0;

    public double[] scoresOf(int[] tags, int pos) {
      if (VERBOSE) {
        int p = (100 * pos) / length();
        if (p > percent) {
          long secs2 = System.currentTimeMillis();
          System.err.println(StringUtils.padLeft(p, 3) + "%   " + ((secs2 - secs == 0) ? 0 : (num * 1000 / (secs2 - secs))) + " hits per sec, position=" + pos + ", legal=" + ((tot == 0) ? 100 : ((100 * hit) / tot)));
          //        + "% [hit=" + hit + ", tot=" + tot + "]");
          percent = p;
          num = 0;
          secs = secs2;
        }
        tot++;
      }
      String[] answers = new String[1 + leftWindow() + rightWindow()];
      String[] pre = new String[leftWindow()];
      for (int i = 0; i < 1 + leftWindow() + rightWindow(); i++) {
        int absPos = pos - leftWindow() + i;
        if (absPos < 0) {
          continue;
        }
        answers[i] = tagIndex.get(tags[absPos]);
        CoreLabel li = lineInfos.get(absPos);
        li.set(AnswerAnnotation.class, answers[i]);
        if (i < leftWindow()) {
          pre[i] = answers[i];
        }
      }
      double[] scores = new double[tagIndex.size()];
      //System.out.println("Considering: "+Arrays.asList(pre));
      if (!legalTags.contains(Arrays.asList(pre)) && classifier.flags.useObservedSequencesOnly) {
        // System.out.println("Rejecting: " + Arrays.asList(pre));
        // System.out.println(legalTags);
        Arrays.fill(scores, -1000); // Double.NEGATIVE_INFINITY;
        return scores;
      }
      num++;
      hit++;
      Counter<String> c = classifier.scoresOf(lineInfos, pos);
      //System.out.println("Pos "+pos+" hist "+Arrays.asList(pre)+" result "+c);
      //System.out.println(c);
      //if (false && flags.justify) {
      //  System.out.println("Considering position " + pos + ", word is " + ((CoreLabel) lineInfos.get(pos)).word());
      //  //System.out.println("Datum is "+d.asFeatures());
      //  System.out.println("History: " + Arrays.asList(pre));
      //}
      for (String s : c.keySet()) {
        int t = tagIndex.indexOf(s);
        if (t > -1) {
          int[] tA = getPossibleValues(pos);
          for (int j = 0; j < tA.length; j++) {
            if (tA[j] == t) {
              scores[j] = c.getCount(s);
              //if (false && flags.justify) {
              //  System.out.println("Label " + s + " got score " + scores[j]);
              //}
            }
          }
        }
      }
      // normalize?
      if (classifier.normalize()) {
        ArrayMath.logNormalize(scores);
      }
      return scores;
    }

    static double[] recenter(double[] x) {
      double[] r = new double[x.length];
      // double logTotal = Double.NEGATIVE_INFINITY;
      // for (int i = 0; i < x.length; i++)
      //   logTotal = SloppyMath.logAdd(logTotal, x[i]);
      double logTotal = ArrayMath.logSum(x);
      for (int i = 0; i < x.length; i++) {
        r[i] = x[i] - logTotal;
      }
      return r;
    }

    /**
     * Build a Scorer.
     *
     * @param lineInfos  List of WordInfo data items to classify
     * @param classifier The trained Classifier
     * @param pre        Number of previous tags that condition current tag
     * @param post       Number of following tags that condition previous tag
     *                   (if pre and post are both nonzero, then you have a
     *                   dependency network tagger)
     */
    Scorer(List<? extends CoreLabel> lineInfos, Index<String> tagIndex, CMMClassifier classifier, int pre, int post, Set<List<String>> legalTags) {
      if (VERBOSE) {
        System.err.println("Built Scorer for " + lineInfos.size() + " words, clique pre=" + pre + " post=" + post);
      }
      this.pre = pre;
      this.post = post;
      this.lineInfos = lineInfos;
      this.tagIndex = tagIndex;
      this.classifier = classifier;
      this.legalTags = legalTags;
      backgroundTags = new int[]{tagIndex.indexOf(classifier.flags.backgroundSymbol)};
    }

  } // end class Scorer

  private boolean normalize() {
    return flags.normalize;
  }

  static int lastPos = -1;

  public Counter<String> scoresOf(List<IN> lineInfos, int pos) {
    //    if (pos != lastPos) {
    //      System.err.print(pos+".");
    //      lastPos = pos;
    //    }
    //    System.err.print("!");
    Datum<String, String> d = makeDatum(lineInfos, pos, featureFactory);
    return classifier.logProbabilityOf(d);
  }
  /**
   * Takes a {@link List} of {@link CoreLabel}s and prints the likelihood
   * of each possible label at each point.
   * TODO: Finish or delete this method!
   *
   * @param document A {@link List} of {@link CoreLabel}s.
   */
  @Override
  public void printProbsDocument(List<IN> document) {
    //ClassicCounter c = scoresOf(document, 0);
  }

  /** Command-line version of the classifier.  See the class
   *  comments for examples of use, and SeqClassifierFlags
   *  for more information on supported flags.
   */
  public static void main(String[] args) throws Exception {
    StringUtils.printErrInvocationString("CMMClassifier", args);

    Properties props = StringUtils.argsToProperties(args);
    CMMClassifier cmm = new CMMClassifier(props);
    String testFile = cmm.flags.testFile;
    String textFile = cmm.flags.textFile;
    String loadPath = cmm.flags.loadClassifier;
    String serializeTo = cmm.flags.serializeTo;

    // cmm.crossValidateTrainAndTest(trainFile);
    if (loadPath != null) {
      cmm.loadClassifierNoExceptions(loadPath, props);
    } else if (cmm.flags.loadJarClassifier != null) {
      cmm.loadJarClassifier(cmm.flags.loadJarClassifier, props);
    } else if (cmm.flags.trainFile != null) {
      if (cmm.flags.biasedTrainFile != null) {
        cmm.trainSemiSup();
      } else {
        cmm.train();
      }
    } else {
      cmm.loadDefaultClassifier();
    }

    if (serializeTo != null) {
      cmm.serializeClassifier(serializeTo);
    }

    if (testFile != null) {
      cmm.classifyAndWriteAnswers(testFile, cmm.makeReaderAndWriter());
    } else if (cmm.flags.testFiles != null) {
      cmm.classifyAndWriteAnswers(cmm.flags.baseTestDir, cmm.flags.testFiles, cmm.makeReaderAndWriter());
    }

    if (textFile != null) {
      DocumentReaderAndWriter readerAndWriter = new PlainTextDocumentReaderAndWriter();
      cmm.classifyAndWriteAnswers(textFile, readerAndWriter);
    }
  } // end main

  /**
   * Assigns NER labels to the words in the given Document.
   * Implements the {@link DocumentProcessor} interface.  Outputs a new document
   * with the same meta-data as the old one, but whose contents are a
   * List of {@link WordTag}s, where the tags are the NER labels assigned
   * to the word.
   */
  public Document processDocument(Document in) {
    Document d = in.blankDocument();
    d.addAll(process(in));
    return d;
  }

  /**
   * Assigns NER labels to the words in the given List.
   * Implements the {@link ListProcessor} interface.  Checks the input
   * for instances of {@link HasWord} and {@link HasTag}, or uses the
   * toString() method if the HasWord check fails.  Outputs a list of
   * {@link WordTag}s, where the tag is the NER label assigned to the word.
   */
  //TODO: require this list to have things that are instanceof HasWord?
  public List<WordTag> process(List list) {
    List<IN> featureLabels = new ArrayList<IN>();
    for (Object o : list) {
      CoreLabel wi = new CoreLabel();
      if (o instanceof HasWord) {
        wi.setWord(((HasWord) o).word());
        if (o instanceof HasTag) {
          wi.setTag(((HasTag) o).tag());
        }
      } else {
        wi.setWord(o.toString());
      }
      featureLabels.add((IN) wi);
    }
    List<IN> tagged = classify(featureLabels);
    List<WordTag> out = new ArrayList<WordTag>();
    for (CoreLabel wi : tagged) {
      out.add(new WordTag(wi.word(), wi.get(AnswerAnnotation.class)));
    }
    return out;
  }

  public double weight(String feature, String label) {
    return ((LinearClassifier) classifier).weight(feature, label);
  }

  public double[][] weights() {
    return ((LinearClassifier) classifier).weights();
  }

  @Override
  public List<IN> classifyWithGlobalInformation(List<IN> tokenSeq, final CoreMap doc, final CoreMap sent) {
    return classify(tokenSeq);
  }

} // end class CMMClassifier




