All downloads are FREE. Search and download functionality uses the official Maven repository.

org.cleartk.examples.chunking.EvaluateNamedEntityChunker Maven / Gradle / Ivy

The newest version!
/** 
 * Copyright (c) 2007-2012, Regents of the University of Colorado 
 * All rights reserved.
 * 
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 * 
 * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 
 * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 
 * Neither the name of the University of Colorado at Boulder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 
 * 
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE. 
 */
package org.cleartk.examples.chunking;

import java.io.File;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;

import org.apache.commons.io.FileUtils;
import org.apache.commons.io.filefilter.FileFilterUtils;
import org.apache.uima.cas.CAS;
import org.apache.uima.collection.CollectionReader;
import org.apache.uima.jcas.JCas;
import org.cleartk.classifier.CleartkSequenceAnnotator;
import org.cleartk.classifier.jar.DefaultSequenceDataWriterFactory;
import org.cleartk.classifier.jar.DirectoryDataWriterFactory;
import org.cleartk.classifier.jar.GenericJarClassifierFactory;
import org.cleartk.classifier.jar.Train;
import org.cleartk.classifier.mallet.MalletCRFStringOutcomeDataWriter;
import org.cleartk.eval.AnnotationStatistics;
import org.cleartk.eval.Evaluation_ImplBase;
import org.cleartk.examples.chunking.TrainNamedEntityChunker.MASCTextFileFilter;
import org.cleartk.examples.chunking.util.MASCGoldAnnotator;
import org.cleartk.ne.type.NamedEntityMention;
import org.cleartk.syntax.opennlp.PosTaggerAnnotator;
import org.cleartk.syntax.opennlp.SentenceAnnotator;
import org.cleartk.token.tokenizer.TokenAnnotator;
import org.cleartk.util.Options_ImplBase;
import org.cleartk.util.ae.UriToDocumentTextAnnotator;
import org.cleartk.util.cr.UriCollectionReader;
import org.kohsuke.args4j.Option;
import org.uimafit.component.ViewCreatorAnnotator;
import org.uimafit.factory.AggregateBuilder;
import org.uimafit.factory.AnalysisEngineFactory;
import org.uimafit.factory.CollectionReaderFactory;
import org.uimafit.pipeline.JCasIterable;
import org.uimafit.pipeline.SimplePipeline;
import org.uimafit.util.JCasUtil;

import com.google.common.base.Function;

/**
 * 

* This class can be used to train and test the named entity chunker. It can be used to either * perform 2-fold cross-validation, or training and testing on a holdout test set, or just to * evaluate/test a pre-trained model. * *
* Copyright (c) 2012, Regents of the University of Colorado
* All rights reserved. * * @author Himanshu Gahlot * @author Steven Bethard */ public class EvaluateNamedEntityChunker extends Evaluation_ImplBase> { public static class Options extends Options_ImplBase { @Option( name = "--train-dir", usage = "Specify the directory containing the training documents. This is used for cross-validation and for training in a holdout set evaluator. " + "When we run this example we point to a directory containing training data from the MASC-1.0.3 corpus - i.e. a directory called 'MASC-1.0.3/data/written'") public File trainDirectory = new File("src/main/resources/data/MASC-1.0.3/data/written"); @Option( name = "--models-dir", usage = "specify the directory in which to write out the trained model files") public File modelsDirectory = new File("target/chunking/ne-model"); } public static void main(String[] args) throws Exception { Options options = new Options(); options.parseOptions(args); // find training files List trainFiles = new ArrayList(FileUtils.listFiles( options.trainDirectory, new MASCTextFileFilter(), FileFilterUtils.falseFileFilter())); // run cross validation EvaluateNamedEntityChunker evaluator = new EvaluateNamedEntityChunker(options.modelsDirectory); List> foldStats = evaluator.crossValidation(trainFiles, 2); AnnotationStatistics crossValidationStats = AnnotationStatistics.addAll(foldStats); System.err.println("Cross Validation Results:"); System.err.print(crossValidationStats); System.err.println(); System.err.println(crossValidationStats.confusions()); System.err.println(); // train and save a model using all the data evaluator.trainAndTest(trainFiles, Collections. 
emptyList()); } public EvaluateNamedEntityChunker(File baseDirectory) { super(baseDirectory); } @Override protected CollectionReader getCollectionReader(List files) throws Exception { return CollectionReaderFactory.createCollectionReader(UriCollectionReader.getDescriptionFromFiles(files)); } @Override public void train(CollectionReader collectionReader, File outputDirectory) throws Exception { // assemble the training pipeline AggregateBuilder aggregate = new AggregateBuilder(); // an annotator that loads the text from the training file URIs aggregate.add(UriToDocumentTextAnnotator.getDescription()); // an annotator that parses and loads MASC named entity annotations (and tokens) aggregate.add(MASCGoldAnnotator.getDescription()); // an annotator that adds part-of-speech tags aggregate.add(PosTaggerAnnotator.getDescription()); // our NamedEntityChunker annotator, configured to write Mallet CRF training data aggregate.add(AnalysisEngineFactory.createPrimitiveDescription( NamedEntityChunker.class, CleartkSequenceAnnotator.PARAM_IS_TRAINING, true, DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY, outputDirectory, DefaultSequenceDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME, MalletCRFStringOutcomeDataWriter.class)); // run the pipeline over the training corpus SimplePipeline.runPipeline(collectionReader, aggregate.createAggregateDescription()); // quiet Mallet down a bit (but still leave likelihoods so you can see progress) Logger malletLogger = Logger.getLogger("cc.mallet"); malletLogger.setLevel(Level.WARNING); Logger likelihoodLogger = Logger.getLogger("cc.mallet.fst.CRFOptimizableByLabelLikelihood"); likelihoodLogger.setLevel(Level.INFO); // train a Mallet CRF model on the training data Train.main(outputDirectory); } @Override protected AnnotationStatistics test(CollectionReader collectionReader, File modelDirectory) throws Exception { final String defaultViewName = CAS.NAME_DEFAULT_SOFA; final String goldViewName = "GoldView"; // define the pipeline 
AggregateBuilder aggregate = new AggregateBuilder(); // Annotators processing the gold view: // * create the gold view // * load the text // * load the MASC annotations aggregate.add(AnalysisEngineFactory.createPrimitiveDescription( ViewCreatorAnnotator.class, ViewCreatorAnnotator.PARAM_VIEW_NAME, goldViewName)); aggregate.add(UriToDocumentTextAnnotator.getDescription(), defaultViewName, goldViewName); aggregate.add(MASCGoldAnnotator.getDescription(), defaultViewName, goldViewName); // Annotators processing the default (system) view: // * load the text // * parse sentences, tokens, part-of-speech tags // * run the named entity chunker aggregate.add(UriToDocumentTextAnnotator.getDescription()); aggregate.add(SentenceAnnotator.getDescription()); aggregate.add(TokenAnnotator.getDescription()); aggregate.add(PosTaggerAnnotator.getDescription()); aggregate.add(AnalysisEngineFactory.createPrimitiveDescription( NamedEntityChunker.class, CleartkSequenceAnnotator.PARAM_IS_TRAINING, false, GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH, new File(modelDirectory, "model.jar"))); // prepare the evaluation statistics AnnotationStatistics stats = new AnnotationStatistics(); Function getSpan = AnnotationStatistics.annotationToSpan(); Function getCategory = AnnotationStatistics.annotationToFeatureValue("mentionType"); // iterate over each JCas to be evaluated for (JCas jCas : new JCasIterable(collectionReader, aggregate.createAggregate())) { JCas goldView = jCas.getView(goldViewName); JCas systemView = jCas.getView(defaultViewName); // extract the named entity mentions from both gold and system views Collection goldMentions, systemMentions; goldMentions = JCasUtil.select(goldView, NamedEntityMention.class); systemMentions = JCasUtil.select(systemView, NamedEntityMention.class); // compare the system mentions to the gold mentions stats.add(goldMentions, systemMentions, getSpan, getCategory); } return stats; } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy