/**
* Copyright (c) 2007-2012, Regents of the University of Colorado
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
* Neither the name of the University of Colorado at Boulder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
*/
package org.cleartk.examples.documentclassification.advanced;
import java.io.File;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.filefilter.FileFilterUtils;
import org.apache.commons.io.filefilter.HiddenFileFilter;
import org.apache.commons.io.filefilter.IOFileFilter;
import org.apache.uima.analysis_engine.AnalysisEngine;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.cas.CAS;
import org.apache.uima.collection.CollectionReader;
import org.apache.uima.jcas.JCas;
import org.apache.uima.resource.ResourceInitializationException;
import org.cleartk.classifier.CleartkAnnotator;
import org.cleartk.classifier.Instance;
import org.cleartk.classifier.feature.transform.InstanceDataWriter;
import org.cleartk.classifier.feature.transform.InstanceStream;
import org.cleartk.classifier.feature.transform.extractor.CentroidTfidfSimilarityExtractor;
import org.cleartk.classifier.feature.transform.extractor.MinMaxNormalizationExtractor;
import org.cleartk.classifier.feature.transform.extractor.TfidfExtractor;
import org.cleartk.classifier.feature.transform.extractor.ZeroMeanUnitStddevExtractor;
import org.cleartk.classifier.jar.DefaultDataWriterFactory;
import org.cleartk.classifier.jar.DirectoryDataWriterFactory;
import org.cleartk.classifier.jar.GenericJarClassifierFactory;
import org.cleartk.classifier.jar.JarClassifierBuilder;
import org.cleartk.classifier.libsvm.LIBSVMStringOutcomeDataWriter;
import org.cleartk.eval.AnnotationStatistics;
import org.cleartk.eval.Evaluation_ImplBase;
import org.cleartk.examples.type.UsenetDocument;
import org.cleartk.syntax.opennlp.SentenceAnnotator;
import org.cleartk.token.stem.snowball.DefaultSnowballStemmer;
import org.cleartk.token.tokenizer.TokenAnnotator;
import org.cleartk.util.Options_ImplBase;
import org.cleartk.util.ae.UriToDocumentTextAnnotator;
import org.cleartk.util.cr.UriCollectionReader;
import org.kohsuke.args4j.Option;
import org.uimafit.component.ViewTextCopierAnnotator;
import org.uimafit.factory.AggregateBuilder;
import org.uimafit.factory.AnalysisEngineFactory;
import org.uimafit.factory.ConfigurationParameterFactory;
import org.uimafit.pipeline.JCasIterable;
import org.uimafit.pipeline.SimplePipeline;
import org.uimafit.testing.util.HideOutput;
import org.uimafit.util.JCasUtil;
import com.google.common.base.Function;
/**
*
* Copyright (c) 2012, Regents of the University of Colorado
* All rights reserved.
*
* This evaluation class provides a concrete example of how to train and evaluate classifiers.
* Specifically this class will train a document categorizer using a subset of the 20 newsgroups
* dataset. It evaluates performance using 2-fold cross validation as well as a holdout set.
*
*
* Key points:
*
* - Creating training and evaluation pipelines
* - Example of feature transformation / normalization
*
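* <p>
* A minimal sketch of programmatic use, mirroring {@link #main} (the file lists and the model
* directory here are placeholders the caller supplies):
*
* <pre>{@code
* DocumentClassificationEvaluation evaluation = new DocumentClassificationEvaluation(
*     modelsDirectory,
*     Arrays.asList("-t", "0"));
* List<AnnotationStatistics<String>> foldStats = evaluation.crossValidation(trainFiles, 2);
* AnnotationStatistics<String> holdoutStats = evaluation.trainAndTest(trainFiles, testFiles);
* }</pre>
*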
* @author Lee Becker
*/
public class DocumentClassificationEvaluation extends
Evaluation_ImplBase<File, AnnotationStatistics<String>> {
public static class Options extends Options_ImplBase {
@Option(
name = "--train-dir",
usage = "Specify the directory containing the training documents. This is used for cross-validation, and for training in a holdout set evaluation. "
+ "When we run this example we point to a directory containing training data from a subset of the 20 newsgroup corpus - i.e. a directory called '3news-bydate/train'")
public File trainDirectory = new File("src/main/resources/data/3news-bydate/train");
@Option(
name = "--test-dir",
usage = "Specify the directory containing the test (aka holdout/validation) documents. This is for holdout set evaluation. "
+ "When we run this example we point to a directory containing training data from a subset of the 20 newsgroup corpus - i.e. a directory called '3news-bydate/test'")
public File testDirectory = new File("src/main/resources/data/3news-bydate/test");
@Option(
name = "--models-dir",
usage = "specify the directory in which to write out the trained model files")
public File modelsDirectory = new File("target/document_classification/models");
@Option(
name = "--training-args",
usage = "specify training arguments to be passed to the learner. For multiple values specify -ta for each - e.g. '-ta -t -ta 0'")
public List<String> trainingArguments = Arrays.asList("-t", "0");
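// Background note (not enforced by this class): the default "-t 0" follows LIBSVM's documented
// option syntax, where -t selects the kernel type and 0 denotes a linear kernel.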
}
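/**
* Pipeline modes used throughout this example: TRAIN prepares gold-labeled training data, TEST
* classifies documents while also filling a separate gold view for scoring, and CLASSIFY simply
* labels documents without any gold data (see the switch statements in the aggregate builders below).
*/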
public static enum AnnotatorMode {
TRAIN, TEST, CLASSIFY
}
public static List<File> getFilesFromDirectory(File directory) {
IOFileFilter fileFilter = FileFilterUtils.makeSVNAware(HiddenFileFilter.VISIBLE);
IOFileFilter dirFilter = FileFilterUtils.makeSVNAware(FileFilterUtils.and(
FileFilterUtils.directoryFileFilter(),
HiddenFileFilter.VISIBLE));
return new ArrayList<File>(FileUtils.listFiles(directory, fileFilter, dirFilter));
}
public static void main(String[] args) throws Exception {
Options options = new Options();
options.parseOptions(args);
List<File> trainFiles = getFilesFromDirectory(options.trainDirectory);
List<File> testFiles = getFilesFromDirectory(options.testDirectory);
DocumentClassificationEvaluation evaluation = new DocumentClassificationEvaluation(
options.modelsDirectory,
options.trainingArguments);
// Run Cross Validation
List<AnnotationStatistics<String>> foldStats = evaluation.crossValidation(trainFiles, 2);
AnnotationStatistics<String> crossValidationStats = AnnotationStatistics.addAll(foldStats);
System.err.println("Cross Validation Results:");
System.err.print(crossValidationStats);
System.err.println();
System.err.println(crossValidationStats.confusions());
System.err.println();
// Run Holdout Set
AnnotationStatistics<String> holdoutStats = evaluation.trainAndTest(trainFiles, testFiles);
System.err.println("Holdout Set Results:");
System.err.print(holdoutStats);
System.err.println();
System.err.println(holdoutStats.confusions());
}
public static final String GOLD_VIEW_NAME = "DocumentClassificationGoldView";
public static final String SYSTEM_VIEW_NAME = CAS.NAME_DEFAULT_SOFA;
private List<String> trainingArguments;
public DocumentClassificationEvaluation(File baseDirectory) {
super(baseDirectory);
this.trainingArguments = Arrays.<String> asList();
}
public DocumentClassificationEvaluation(File baseDirectory, List<String> trainingArguments) {
super(baseDirectory);
this.trainingArguments = trainingArguments;
}
@Override
protected CollectionReader getCollectionReader(List<File> items) throws Exception {
return UriCollectionReader.getCollectionReaderFromFiles(items);
}
@Override
public void train(CollectionReader collectionReader, File outputDirectory) throws Exception {
// ////////////////////////////////////////////////////////////////////////////////
// Step 1: Extract features and serialize the raw instance objects
// Note: DocumentClassificationAnnotator sets the various extractor URI values to null by
// default. This signals to the feature extractors that they are being written out for training
// ////////////////////////////////////////////////////////////////////////////////
System.err.println("Step 1: Extracting features and writing raw instances data");
// Create and run the document classification training pipeline
AggregateBuilder builder = DocumentClassificationEvaluation.createDocumentClassificationAggregate(
outputDirectory,
AnnotatorMode.TRAIN);
SimplePipeline.runPipeline(collectionReader, builder.createAggregateDescription());
// Load the serialized instance data
Iterable<Instance<String>> instances = InstanceStream.loadFromDirectory(outputDirectory);
// ////////////////////////////////////////////////////////////////////////////////
// Step 2: Transform features and write training data
// In this phase, the normalization statistics are computed and the raw
// features are transformed into normalized features.
// Then the adjusted values are written with a DataWriter (libsvm in this case)
// for training
// ////////////////////////////////////////////////////////////////////////////////
System.err.println("Collection feature normalization statistics");
// Collect TF*IDF stats for computing tf*idf values on extracted tokens
URI tfIdfDataURI = DocumentClassificationAnnotator.createTokenTfIdfDataURI(outputDirectory);
TfidfExtractor<String> extractor = new TfidfExtractor<String>(
DocumentClassificationAnnotator.TFIDF_EXTRACTOR_KEY);
extractor.train(instances);
extractor.save(tfIdfDataURI);
// Collect TF*IDF Centroid stats for computing similarity to corpus centroid
URI tfIdfCentroidSimDataURI = DocumentClassificationAnnotator.createIdfCentroidSimilarityDataURI(outputDirectory);
CentroidTfidfSimilarityExtractor<String> simExtractor = new CentroidTfidfSimilarityExtractor<String>(
DocumentClassificationAnnotator.CENTROID_TFIDF_SIM_EXTRACTOR_KEY);
simExtractor.train(instances);
simExtractor.save(tfIdfCentroidSimDataURI);
// Collect ZMUS stats for feature normalization
URI zmusDataURI = DocumentClassificationAnnotator.createZmusDataURI(outputDirectory);
ZeroMeanUnitStddevExtractor<String> zmusExtractor = new ZeroMeanUnitStddevExtractor<String>(
DocumentClassificationAnnotator.ZMUS_EXTRACTOR_KEY);
zmusExtractor.train(instances);
zmusExtractor.save(zmusDataURI);
// Collect MinMax stats for feature normalization
URI minmaxDataURI = DocumentClassificationAnnotator.createMinMaxDataURI(outputDirectory);
MinMaxNormalizationExtractor<String> minmaxExtractor = new MinMaxNormalizationExtractor<String>(
DocumentClassificationAnnotator.MINMAX_EXTRACTOR_KEY);
minmaxExtractor.train(instances);
minmaxExtractor.save(minmaxDataURI);
// Rerun the training data writer pipeline to transform the extracted instances -- an
// alternative, more costly approach would be to reinitialize the
// DocumentClassificationAnnotator above with the URIs for the feature extractors.
//
// In this example, we now write in the libsvm format
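// (For orientation, not a literal excerpt of the output: LIBSVM-formatted training lines take the
// shape "<outcome-index> <feature-index>:<value> ...", e.g. "2 1:0.5 13:1.0"; the
// LIBSVMStringOutcomeDataWriter is responsible for encoding String outcomes and ClearTK features
// into those indices.)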
System.err.println("Write out model training data");
LIBSVMStringOutcomeDataWriter dataWriter = new LIBSVMStringOutcomeDataWriter(outputDirectory);
for (Instance<String> instance : instances) {
instance = extractor.transform(instance);
instance = simExtractor.transform(instance);
instance = zmusExtractor.transform(instance);
instance = minmaxExtractor.transform(instance);
dataWriter.write(instance);
}
dataWriter.finish();
// //////////////////////////////////////////////////////////////////////////////
// Stage 3: Train and write model
// Now that the features have been extracted and normalized, we can proceed
// in running machine learning to train and package a model
// //////////////////////////////////////////////////////////////////////////////
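// (The model.jar written under outputDirectory is what the TEST/CLASSIFY pipelines load below via
// GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH; it is assumed to package the trained
// LIBSVM model together with the encoders produced by JarClassifierBuilder.)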
System.err.println("Train model and write model.jar file.");
HideOutput hider = new HideOutput();
JarClassifierBuilder.trainAndPackage(
outputDirectory,
this.trainingArguments.toArray(new String[this.trainingArguments.size()]));
hider.restoreOutput();
}
/**
* Creates the preprocessing pipeline needed for document classification. Specifically this
* consists of:
*
* - Populating the default view with the document text (as specified in the URIView)
* - Sentence segmentation
* - Tokenization
* - Stemming
* - [optional] labeling the document with gold-standard document categories
*
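* <p>
* A minimal usage sketch (this mirrors how {@link #createDocumentClassificationAggregate} builds
* on this method; the model directory is a placeholder):
*
* <pre>{@code
* AggregateBuilder builder = DocumentClassificationEvaluation.createPreprocessingAggregate(
*     modelsDirectory, AnnotatorMode.TRAIN);
* AnalysisEngineDescription preprocessing = builder.createAggregateDescription();
* }</pre>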
*/
public static AggregateBuilder createPreprocessingAggregate(
File modelDirectory,
AnnotatorMode mode) throws ResourceInitializationException {
AggregateBuilder builder = new AggregateBuilder();
builder.add(UriToDocumentTextAnnotator.getDescription());
// NLP pre-processing components
builder.add(SentenceAnnotator.getDescription());
builder.add(TokenAnnotator.getDescription());
builder.add(DefaultSnowballStemmer.getDescription("English"));
// Now annotate documents with gold standard labels
switch (mode) {
case TRAIN:
// If this is training, put the label categories directly into the default view
builder.add(AnalysisEngineFactory.createPrimitiveDescription(GoldDocumentCategoryAnnotator.class));
break;
case TEST:
// Copies the text from the default view to a separate gold view
builder.add(AnalysisEngineFactory.createPrimitiveDescription(
ViewTextCopierAnnotator.class,
ViewTextCopierAnnotator.PARAM_SOURCE_VIEW_NAME,
CAS.NAME_DEFAULT_SOFA,
ViewTextCopierAnnotator.PARAM_DESTINATION_VIEW_NAME,
GOLD_VIEW_NAME));
// If this is testing, put the document categories in the gold view
// The extra parameters to add() map the default view to the gold view.
builder.add(
AnalysisEngineFactory.createPrimitiveDescription(GoldDocumentCategoryAnnotator.class),
CAS.NAME_DEFAULT_SOFA,
GOLD_VIEW_NAME);
break;
case CLASSIFY:
default:
// In normal mode don't deal with gold labels
break;
}
return builder;
}
/**
* Creates the aggregate builder for the document classification pipeline
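* <p>
* A minimal sketch of standalone (CLASSIFY-mode) use, assuming the model directory already holds
* the files written by {@link #train} and that filesToClassify is supplied by the caller:
*
* <pre>{@code
* AggregateBuilder builder = DocumentClassificationEvaluation.createDocumentClassificationAggregate(
*     modelsDirectory, AnnotatorMode.CLASSIFY);
* SimplePipeline.runPipeline(
*     UriCollectionReader.getCollectionReaderFromFiles(filesToClassify),
*     builder.createAggregateDescription());
* }</pre>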
*/
public static AggregateBuilder createDocumentClassificationAggregate(
File modelDirectory,
AnnotatorMode mode) throws ResourceInitializationException {
AggregateBuilder builder = DocumentClassificationEvaluation.createPreprocessingAggregate(
modelDirectory,
mode);
switch (mode) {
case TRAIN:
// For training we will create DocumentClassificationAnnotator that
// Extracts the features as is, and then writes out the data to
// a serialized instance file.
builder.add(AnalysisEngineFactory.createPrimitiveDescription(
DocumentClassificationAnnotator.class,
DefaultDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
InstanceDataWriter.class.getName(),
DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
modelDirectory.getPath()));
break;
case TEST:
case CLASSIFY:
default:
// For testing and standalone classification, we want to create a DocumentClassificationAnnotator
// using all of the model data computed during training. This includes the feature normalization
// data and the model jar file for the classifying algorithm.
AnalysisEngineDescription documentClassificationAnnotator = AnalysisEngineFactory.createPrimitiveDescription(
DocumentClassificationAnnotator.class,
CleartkAnnotator.PARAM_IS_TRAINING,
false,
GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
new File(modelDirectory, "model.jar").getPath());
ConfigurationParameterFactory.addConfigurationParameters(
documentClassificationAnnotator,
DocumentClassificationAnnotator.PARAM_TF_IDF_URI,
DocumentClassificationAnnotator.createTokenTfIdfDataURI(modelDirectory),
DocumentClassificationAnnotator.PARAM_TF_IDF_CENTROID_SIMILARITY_URI,
DocumentClassificationAnnotator.createIdfCentroidSimilarityDataURI(modelDirectory),
DocumentClassificationAnnotator.PARAM_MINMAX_URI,
DocumentClassificationAnnotator.createMinMaxDataURI(modelDirectory),
DocumentClassificationAnnotator.PARAM_ZMUS_URI,
DocumentClassificationAnnotator.createZmusDataURI(modelDirectory));
builder.add(documentClassificationAnnotator);
break;
}
return builder;
}
@Override
protected AnnotationStatistics<String> test(CollectionReader collectionReader, File directory)
throws Exception {
AnnotationStatistics<String> stats = new AnnotationStatistics<String>();
// Create the document classification pipeline
AggregateBuilder builder = DocumentClassificationEvaluation.createDocumentClassificationAggregate(
directory,
AnnotatorMode.TEST);
AnalysisEngine engine = builder.createAggregate();
// Run and evaluate
Function<UsenetDocument, ?> getSpan = AnnotationStatistics.annotationToSpan();
Function<UsenetDocument, String> getCategory = AnnotationStatistics.annotationToFeatureValue("category");
for (JCas jCas : new JCasIterable(collectionReader, engine)) {
JCas goldView = jCas.getView(GOLD_VIEW_NAME);
JCas systemView = jCas.getView(DocumentClassificationEvaluation.SYSTEM_VIEW_NAME);
// Get results from system and gold views, and update results accordingly
Collection<UsenetDocument> goldCategories = JCasUtil.select(goldView, UsenetDocument.class);
Collection<UsenetDocument> systemCategories = JCasUtil.select(
systemView,
UsenetDocument.class);
stats.add(goldCategories, systemCategories, getSpan, getCategory);
}
return stats;
}
}