org.apache.ctakes.assertion.medfacts.cleartk.TrainAssertionModel Maven / Gradle / Ivy
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.ctakes.assertion.medfacts.cleartk;
import java.io.File;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.CommandLineParser;
import org.apache.commons.cli.GnuParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Option;
import org.apache.commons.cli.OptionBuilder;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.apache.ctakes.assertion.eval.AssertionEvaluation;
import org.apache.ctakes.assertion.eval.AssertionEvaluation.ReferenceAnnotationsSystemAssertionClearer;
import org.apache.ctakes.assertion.eval.AssertionEvaluation.ReferenceIdentifiedAnnotationsSystemToGoldCopier;
import org.apache.ctakes.core.cc.XmiWriterCasConsumerCtakes;
import org.apache.log4j.Logger;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.collection.CollectionReader;
import org.apache.uima.fit.factory.AggregateBuilder;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.fit.factory.CollectionReaderFactory;
import org.apache.uima.fit.factory.ConfigurationParameterFactory;
import org.apache.uima.fit.pipeline.SimplePipeline;
import org.apache.uima.fit.testing.util.HideOutput;
import org.cleartk.ml.CleartkAnnotator;
import org.cleartk.ml.DataWriterFactory;
import org.cleartk.ml.jar.DefaultDataWriterFactory;
import org.cleartk.ml.jar.DirectoryDataWriterFactory;
import org.cleartk.ml.jar.GenericJarClassifierFactory;
import org.cleartk.ml.opennlp.maxent.MaxentStringOutcomeDataWriter;
import org.cleartk.util.cr.XReader;
//import org.junit.Test;
//import edu.mayo.bmi.uima.core.type.textsem.EntityMention;
public class TrainAssertionModel {
public static final String PARAM_NAME_DECODING_OUTPUT_DIRECTORY = "decoding-output-directory";
public static final String PARAM_NAME_DECODING_INPUT_DIRECTORY = "decoding-input-directory";
public static final String PARAM_NAME_TRAINING_INPUT_DIRECTORY = "training-input-directory";
public static final String PARAM_NAME_MODEL_DIRECTORY = "model-directory";
protected static final Logger logger = Logger.getLogger(TrainAssertionModel.class.getName());
/**
* @param args
*/
/*
public static void main(String[] args) {
// TODO Auto-generated method stub
String trainDir = args[0];
String outputDir = args[1];
try {
CollectionReader reader = FilesCollectionReader.getCollectionReader(trainDir);
AggregateBuilder builder = new AggregateBuilder();
//builder.add(AnalysisEngineFactory.createEngineDescription("desc/AssertionMiniPipelineAnalysisEngine.xml", null));
//builder.add(AnalysisEngineFactory.createEngineDescription(IdentifiedAnnotation.class));
//builder.add(AnalysisEngineFactory.createEngineDescription("edu.mayo.bmi.uima.core.type.textsem.IdentifiedAnnotation"));
builder.add(AssertionCleartkAnalysisEngine.getWriterDescription(outputDir));
SimplePipeline.runPipeline(reader, builder.createEngineDescription());
org.cleartk.classifier.jar.Train.main(outputDir);
} catch (Exception e) {
System.err.println("Exception: " + e);
e.printStackTrace();
throw new RuntimeException(e);
}
}
*/
protected String modelOutputDirectory = "/work/medfacts/cleartk/data/train.model";
//@Test
public void testMaxent() throws Exception {
String trainingDataDirectory = "/work/medfacts/cleartk/data/train";
String evaluationDataDirectory = "/work/medfacts/cleartk/data/eval2.input";
String evaluationOutputDataDirectory = "/work/medfacts/cleartk/data/eval2.output";
String maxentModelOutputDirectory = modelOutputDirectory + "/maxent";
AnalysisEngineDescription dataWriter = AnalysisEngineFactory.createEngineDescription(
AssertionCleartkAnalysisEngine.class,
AssertionComponents.CTAKES_CTS_TYPE_SYSTEM_DESCRIPTION,
DefaultDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
MaxentStringOutcomeDataWriter.class.getName(),
DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
maxentModelOutputDirectory);
testClassifier(
dataWriter,
maxentModelOutputDirectory,
trainingDataDirectory,
evaluationDataDirectory,
evaluationOutputDataDirectory
);
// // Not sure why the _SPLIT is here, but we will throw it out for good measure
// String firstLine = FileUtil.loadListOfStrings(new File(maxentDirectoryName
// + "/2008_Sichuan_earthquake.txt.pos"))[0].trim().replace("_SPLIT", "");
// checkPOS(firstLine);
}
public static void main(String args[])
{
Options options = new Options();
Option modelDirectoryOption =
OptionBuilder
.withLongOpt(TrainAssertionModel.PARAM_NAME_MODEL_DIRECTORY)
.withArgName("DIR")
.hasArg()
.isRequired()
.withDescription("the directory where the model is written to for training, or read from for decoding")
.create();
options.addOption(modelDirectoryOption);
Option trainingInputDirectoryOption =
OptionBuilder
.withLongOpt(TrainAssertionModel.PARAM_NAME_TRAINING_INPUT_DIRECTORY)
.withArgName("DIR")
.hasArg()
.isRequired()
.withDescription("directory where input training xmi files are located")
.create();
options.addOption(trainingInputDirectoryOption);
Option decodingInputDirectoryOption =
OptionBuilder
.withLongOpt(TrainAssertionModel.PARAM_NAME_DECODING_INPUT_DIRECTORY)
.withArgName("DIR")
.hasArg()
.isRequired()
.withDescription("directory where input xmi files are located for decoding")
.create();
options.addOption(decodingInputDirectoryOption);
Option decodingOutputDirectoryOption =
OptionBuilder
.withLongOpt(TrainAssertionModel.PARAM_NAME_DECODING_OUTPUT_DIRECTORY)
.withArgName("DIR")
.hasArg()
.isRequired()
.withDescription("directory where output xmi files that are generated in decoding are placed")
.create();
options.addOption(decodingOutputDirectoryOption);
CommandLineParser parser = new GnuParser();
boolean invalidInput = false;
CommandLine commandLine = null;
String modelDirectory = null;
String trainingInputDirectory = null;
String decodingInputDirectory = null;
String decodingOutputDirectory = null;
try
{
commandLine = parser.parse(options, args);
modelDirectory = commandLine.getOptionValue(TrainAssertionModel.PARAM_NAME_MODEL_DIRECTORY);
trainingInputDirectory = commandLine.getOptionValue(TrainAssertionModel.PARAM_NAME_TRAINING_INPUT_DIRECTORY);
decodingInputDirectory = commandLine.getOptionValue(TrainAssertionModel.PARAM_NAME_DECODING_INPUT_DIRECTORY);
decodingOutputDirectory = commandLine.getOptionValue(TrainAssertionModel.PARAM_NAME_DECODING_OUTPUT_DIRECTORY);
} catch (ParseException e)
{
invalidInput = true;
logger.error("unable to parse command-line arguments", e);
}
if (modelDirectory == null || modelDirectory.isEmpty() ||
trainingInputDirectory == null || trainingInputDirectory.isEmpty() ||
decodingInputDirectory == null || decodingInputDirectory.isEmpty() ||
decodingOutputDirectory == null || decodingOutputDirectory.isEmpty()
)
{
logger.error("required parameters not supplied");
invalidInput = true;
}
if (invalidInput)
{
HelpFormatter formatter = new HelpFormatter();
formatter.printHelp(TrainAssertionModel.class.getName(), options, true);
return;
}
logger.info(String.format(
"%n" +
"model dir: \"%s\"%n" +
"training input dir: \"%s\"%n" +
"decoding input dir: \"%s\"%n" +
"decoding output dir: \"%s\"%n",
modelDirectory,
trainingInputDirectory,
decodingInputDirectory,
decodingOutputDirectory));
String maxentModelOutputDirectory = modelDirectory + "/maxent";
try
{
AnalysisEngineDescription dataWriter = AnalysisEngineFactory.createEngineDescription(
AssertionCleartkAnalysisEngine.class,
AssertionComponents.CTAKES_CTS_TYPE_SYSTEM_DESCRIPTION,
DefaultDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
MaxentStringOutcomeDataWriter.class.getName(),
DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
maxentModelOutputDirectory);
testClassifier(
dataWriter,
maxentModelOutputDirectory,
trainingInputDirectory,
decodingInputDirectory,
decodingOutputDirectory
);
} catch (Exception e)
{
logger.error("Some exception happened while training or decoding...", e);
return;
}
}
public static void testClassifier(
AnalysisEngineDescription dataWriter,
String modelOutputDirectory,
String trainingDataInputDirectory,
String decodingInputDirectory,
String decodingOutputDirectory,
String... trainingArgs) throws Exception
{
CollectionReader trainingCollectionReader = CollectionReaderFactory.createReader(
XReader.class,
XReader.PARAM_ROOT_FILE,
trainingDataInputDirectory,
XReader.PARAM_XML_SCHEME,
XReader.XMI);
CollectionReader evaluationCollectionReader = CollectionReaderFactory.createReader(
XReader.class,
XReader.PARAM_ROOT_FILE,
decodingInputDirectory,
XReader.PARAM_XML_SCHEME,
XReader.XMI);
AggregateBuilder trainingBuilder = new AggregateBuilder();
AnalysisEngineDescription goldCopierAnnotator = AnalysisEngineFactory.createEngineDescription(ReferenceIdentifiedAnnotationsSystemToGoldCopier.class);
trainingBuilder.add(goldCopierAnnotator);
AnalysisEngineDescription assertionAttributeClearerAnnotator = AnalysisEngineFactory.createEngineDescription(ReferenceAnnotationsSystemAssertionClearer.class);
trainingBuilder.add(assertionAttributeClearerAnnotator);
// Class extends DataWriterFactory> dataWriterFactoryClass = DefaultMaxentDataWriterFactory.class;
AnalysisEngineDescription trainingAssertionAnnotator =
AnalysisEngineFactory.createEngineDescription(
AssertionCleartkAnalysisEngine.class,
AssertionComponents.CTAKES_CTS_TYPE_SYSTEM_DESCRIPTION
);
ConfigurationParameterFactory.addConfigurationParameters(
trainingAssertionAnnotator,
AssertionCleartkAnalysisEngine.PARAM_GOLD_VIEW_NAME,
AssertionEvaluation.GOLD_VIEW_NAME,
// CleartkAnnotator.PARAM_DATA_WRITER_FACTORY_CLASS_NAME,
// dataWriterFactoryClass.getName(),
DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
modelOutputDirectory
);
trainingBuilder.add(trainingAssertionAnnotator);
// CollectionReader collectionReader = XReader.getCollectionReader(
// trainingDataDirectory);
// collectionReader.setConfigParameterValue(XReader.PARAM_XML_SCHEME, XReader.XMI);
// collectionReader.reconfigure();
logger.info("starting feature generation...");
SimplePipeline.runPipeline(
trainingCollectionReader,
// FilesCollectionReader.getCollectionReaderWithView(
// "src/test/resources/data/treebank/11597317.tree",
// TreebankConstants.TREEBANK_VIEW)
// ,
// TreebankGoldAnnotator.getDescriptionPOSTagsOnly(),
// DefaultSnowballStemmer.getDescription("English"),
// dataWriter);
trainingBuilder.createAggregateDescription());
logger.info("finished feature generation.");
String[] args;
if (trainingArgs != null && trainingArgs.length > 0) {
args = new String[trainingArgs.length + 1];
args[0] = modelOutputDirectory;
System.arraycopy(trainingArgs, 0, args, 1, trainingArgs.length);
} else {
args = new String[] { modelOutputDirectory };
}
HideOutput hider = new HideOutput();
logger.info("starting training...");
org.cleartk.ml.jar.Train.main(args);
logger.info("finished training.");
hider.restoreOutput();
AggregateBuilder decodingBuilder = new AggregateBuilder();
//AnalysisEngineDescription goldCopierAnnotator = AnalysisEngineFactory.createEngineDescription(ReferenceIdentifiedAnnotationsSystemToGoldCopier.class);
decodingBuilder.add(goldCopierAnnotator);
//AnalysisEngineDescription assertionAttributeClearerAnnotator = AnalysisEngineFactory.createEngineDescription(ReferenceAnnotationsSystemAssertionClearer.class);
decodingBuilder.add(assertionAttributeClearerAnnotator);
AnalysisEngineDescription decodingAssertionAnnotator =
AnalysisEngineFactory.createEngineDescription(
AssertionCleartkAnalysisEngine.class,
AssertionComponents.CTAKES_CTS_TYPE_SYSTEM_DESCRIPTION
);
ConfigurationParameterFactory.addConfigurationParameters(
decodingAssertionAnnotator,
AssertionCleartkAnalysisEngine.PARAM_GOLD_VIEW_NAME,
AssertionEvaluation.GOLD_VIEW_NAME,
GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
new File(modelOutputDirectory, "model.jar").getPath()
);
decodingBuilder.add(decodingAssertionAnnotator);
//SimplePipeline.runPipeline(collectionReader, builder.createEngineDescription());
AnalysisEngineDescription decodingAggregateDescription = decodingBuilder.createAggregateDescription();
// AnalysisEngineDescription taggerDescription = AnalysisEngineFactory.createEngineDescription(
// AssertionCleartkAnalysisEngine.class,
// GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
// //AssertionComponents.TYPE_SYSTEM_DESCRIPTION,
// modelOutputDirectory + "/model.jar");
logger.info("starting decoding...");
SimplePipeline.runPipeline(
evaluationCollectionReader,
// BreakIteratorAnnotatorFactory.createSentenceAnnotator(Locale.US),
// TokenAnnotator.getDescription(),
// DefaultSnowballStemmer.getDescription("English"),
//taggerDescription,
decodingAggregateDescription,
AnalysisEngineFactory.createEngineDescription(
XmiWriterCasConsumerCtakes.class,
AssertionComponents.CTAKES_CTS_TYPE_SYSTEM_DESCRIPTION,
XmiWriterCasConsumerCtakes.PARAM_OUTPUTDIR,
decodingOutputDirectory));
logger.info("finished decoding.");
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy