
// org.apache.ctakes.assertion.medfacts.cleartk.TrainAllAssertionModels.txt Maven / Gradle / Ivy
package org.apache.ctakes.assertion.medfacts.cleartk;
import java.util.Locale;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.CommandLineParser;
import org.apache.commons.cli.GnuParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Option;
import org.apache.commons.cli.OptionBuilder;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.apache.log4j.Logger;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.collection.CollectionReader;
import org.apache.uima.collection.CollectionReaderDescription;
import org.cleartk.classifier.CleartkAnnotatorDescriptionFactory;
import org.cleartk.classifier.jar.DirectoryDataWriterFactory;
import org.cleartk.classifier.jar.GenericJarClassifierFactory;
import org.cleartk.classifier.opennlp.MaxentDataWriter;
import org.cleartk.classifier.opennlp.MaxentStringOutcomeDataWriter;
import org.cleartk.util.cr.FilesCollectionReader;
import org.cleartk.util.cr.XReader;
import org.uimafit.component.xwriter.XWriter;
import org.uimafit.factory.AggregateBuilder;
import org.uimafit.factory.AnalysisEngineFactory;
import org.uimafit.factory.CollectionReaderFactory;
import org.uimafit.pipeline.SimplePipeline;
import org.uimafit.testing.util.HideOutput;
//import org.junit.Test;
import org.apache.ctakes.assertion.medfacts.AssertionAnalysisEngine;
import org.apache.ctakes.typesystem.type.syntax.BaseToken;
//import edu.mayo.bmi.uima.core.type.textsem.EntityMention;
import org.apache.ctakes.typesystem.type.textsem.IdentifiedAnnotation;
import org.apache.ctakes.typesystem.type.textspan.Sentence;
import org.cleartk.classifier.jar.DefaultDataWriterFactory;
import org.cleartk.examples.pos.ExamplePOSPlainTextWriter;
public class TrainAllAssertionModels {
public static final String PARAM_NAME_DECODING_OUTPUT_DIRECTORY = "decoding-output-directory";
public static final String PARAM_NAME_DECODING_INPUT_DIRECTORY = "decoding-input-directory";
public static final String PARAM_NAME_TRAINING_INPUT_DIRECTORY = "training-input-directory";
public static final String PARAM_NAME_MODEL_DIRECTORY = "model-directory";
protected static final Logger logger = Logger.getLogger(TrainAssertionModel.class.getName());
/**
* @param args
*/
public static void main(String[] args) {
Options options = new Options();
Option modelDirectoryOption =
OptionBuilder
.withLongOpt(TrainAssertionModel.PARAM_NAME_MODEL_DIRECTORY)
.withArgName("DIR")
.hasArg()
.isRequired()
.withDescription("the directory where the model is written to for training, or read from for decoding")
.create();
options.addOption(modelDirectoryOption);
Option trainingInputDirectoryOption =
OptionBuilder
.withLongOpt(TrainAssertionModel.PARAM_NAME_TRAINING_INPUT_DIRECTORY)
.withArgName("DIR")
.hasArg()
.isRequired()
.withDescription("directory where input training xmi files are located")
.create();
options.addOption(trainingInputDirectoryOption);
Option decodingInputDirectoryOption =
OptionBuilder
.withLongOpt(TrainAssertionModel.PARAM_NAME_DECODING_INPUT_DIRECTORY)
.withArgName("DIR")
.hasArg()
.isRequired()
.withDescription("directory where input xmi files are located for decoding")
.create();
options.addOption(decodingInputDirectoryOption);
Option decodingOutputDirectoryOption =
OptionBuilder
.withLongOpt(TrainAssertionModel.PARAM_NAME_DECODING_OUTPUT_DIRECTORY)
.withArgName("DIR")
.hasArg()
.isRequired()
.withDescription("directory where output xmi files that are generated in decoding are placed")
.create();
options.addOption(decodingOutputDirectoryOption);
CommandLineParser parser = new GnuParser();
boolean invalidInput = false;
CommandLine commandLine = null;
String modelDirectory = null;
String trainingInputDirectory = null;
String decodingInputDirectory = null;
String decodingOutputDirectory = null;
try
{
commandLine = parser.parse(options, args);
modelDirectory = commandLine.getOptionValue(TrainAssertionModel.PARAM_NAME_MODEL_DIRECTORY);
trainingInputDirectory = commandLine.getOptionValue(TrainAssertionModel.PARAM_NAME_TRAINING_INPUT_DIRECTORY);
decodingInputDirectory = commandLine.getOptionValue(TrainAssertionModel.PARAM_NAME_DECODING_INPUT_DIRECTORY);
decodingOutputDirectory = commandLine.getOptionValue(TrainAssertionModel.PARAM_NAME_DECODING_OUTPUT_DIRECTORY);
} catch (ParseException e)
{
invalidInput = true;
logger.error("unable to parse command-line arguments", e);
}
if (modelDirectory == null || modelDirectory.isEmpty() ||
trainingInputDirectory == null || trainingInputDirectory.isEmpty() ||
decodingInputDirectory == null || decodingInputDirectory.isEmpty() ||
decodingOutputDirectory == null || decodingOutputDirectory.isEmpty()
)
{
logger.error("required parameters not supplied");
invalidInput = true;
}
if (invalidInput)
{
HelpFormatter formatter = new HelpFormatter();
formatter.printHelp(TrainAssertionModel.class.getName(), options, true);
return;
}
logger.info(String.format(
"%n" +
"model dir: \"%s\"%n" +
"training input dir: \"%s\"%n" +
"decoding input dir: \"%s\"%n" +
"decoding output dir: \"%s\"%n",
modelDirectory,
trainingInputDirectory,
decodingInputDirectory,
decodingOutputDirectory));
String polarityModelOutputDirectory = modelDirectory + "/maxent-polarity";
String uncertaintyModelOutputDirectory = modelDirectory + "/maxent-uncertainty";
String conditionalModelOutputDirectory = modelDirectory + "/maxent-conditional";
String subjectModelOutputDirectory = modelDirectory + "/maxent-subject";
try
{
AnalysisEngineDescription uncertaintyDataWriter = AnalysisEngineFactory.createPrimitiveDescription(
UncertaintyCleartkAnalysisEngine.class,
AssertionComponents.CTAKES_CTS_TYPE_SYSTEM_DESCRIPTION,
DefaultDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
MaxentStringOutcomeDataWriter.class.getName(),
DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
uncertaintyModelOutputDirectory);
AnalysisEngineDescription polarityDataWriter = AnalysisEngineFactory.createPrimitiveDescription(
PolarityCleartkAnalysisEngine.class,
AssertionComponents.CTAKES_CTS_TYPE_SYSTEM_DESCRIPTION,
DefaultDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
MaxentStringOutcomeDataWriter.class.getName(),
DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
polarityModelOutputDirectory);
AnalysisEngineDescription conditionalDataWriter = AnalysisEngineFactory.createPrimitiveDescription(
ConditionalCleartkAnalysisEngine.class,
AssertionComponents.CTAKES_CTS_TYPE_SYSTEM_DESCRIPTION,
DefaultDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
MaxentStringOutcomeDataWriter.class.getName(),
DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
conditionalModelOutputDirectory);
AnalysisEngineDescription subjectDataWriter = AnalysisEngineFactory.createPrimitiveDescription(
SubjectCleartkAnalysisEngine.class,
AssertionComponents.CTAKES_CTS_TYPE_SYSTEM_DESCRIPTION,
DefaultDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
MaxentStringOutcomeDataWriter.class.getName(),
DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
subjectModelOutputDirectory);
testClassifierPipeline(
polarityDataWriter,
polarityModelOutputDirectory,
uncertaintyDataWriter,
uncertaintyModelOutputDirectory,
conditionalDataWriter,
conditionalModelOutputDirectory,
subjectDataWriter,
subjectModelOutputDirectory,
trainingInputDirectory,
decodingInputDirectory,
decodingOutputDirectory
);
} catch (Exception e)
{
logger.error("Some exception happened while training or decoding...", e);
return;
}
}
public static void testClassifierPipeline(
AnalysisEngineDescription polarityDataWriter,
String polarityModelOutputDirectory,
AnalysisEngineDescription uncertaintyDataWriter,
String uncertaintyModelOutputDirectory,
AnalysisEngineDescription conditionalDataWriter,
String conditionalModelOutputDirectory,
AnalysisEngineDescription subjectDataWriter,
String subjectModelOutputDirectory,
String trainingDataInputDirectory,
String decodingInputDirectory,
String decodingOutputDirectory) throws Exception {
CollectionReader trainingCollectionReader = CollectionReaderFactory.createCollectionReader(
XReader.class,
XReader.PARAM_ROOT_FILE,
trainingDataInputDirectory,
XReader.PARAM_XML_SCHEME,
XReader.XMI);
CollectionReader evaluationCollectionReader = CollectionReaderFactory.createCollectionReader(
XReader.class,
XReader.PARAM_ROOT_FILE,
decodingInputDirectory,
XReader.PARAM_XML_SCHEME,
XReader.XMI);
logger.info("starting feature generation... POLARITY");
SimplePipeline.runPipeline(
trainingCollectionReader,
polarityDataWriter);
logger.info("finished feature generation... POLARITY");
trainingCollectionReader.reconfigure();
logger.info("starting feature generation... UNCERTAINTY");
SimplePipeline.runPipeline(
trainingCollectionReader,
uncertaintyDataWriter);
logger.info("finished feature generation... UNCERTAINTY.");
trainingCollectionReader.reconfigure();
logger.info("starting feature generation... UNCERTAINTY");
SimplePipeline.runPipeline(
trainingCollectionReader,
conditionalDataWriter);
logger.info("finished feature generation... UNCERTAINTY.");
trainingCollectionReader.reconfigure();
logger.info("starting feature generation... UNCERTAINTY");
SimplePipeline.runPipeline(
trainingCollectionReader,
subjectDataWriter);
logger.info("finished feature generation... UNCERTAINTY.");
String[] args = new String[] {polarityModelOutputDirectory};
HideOutput hider = new HideOutput();
logger.info("starting training POLARITY...");
org.cleartk.classifier.jar.Train.main(args);
logger.info("finished training POLARITY .");
args = new String[] {uncertaintyModelOutputDirectory};
logger.info("starting training UNCERTAINTY...");
org.cleartk.classifier.jar.Train.main(args);
logger.info("finished training UNCERTAINTY .");
args = new String[] {conditionalModelOutputDirectory};
logger.info("starting training CONDITIONAL...");
org.cleartk.classifier.jar.Train.main(args);
logger.info("finished training CONDITIONAL .");
args = new String[] {subjectModelOutputDirectory};
logger.info("starting training SUBJECT...");
org.cleartk.classifier.jar.Train.main(args);
logger.info("finished training SUBJECT .");
hider.restoreOutput();
AnalysisEngineDescription polarityTaggerDescription = AnalysisEngineFactory.createPrimitiveDescription(
PolarityCleartkAnalysisEngine.class,
GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
polarityModelOutputDirectory + "/model.jar");
AnalysisEngineDescription uncertaintyTaggerDescription = AnalysisEngineFactory.createPrimitiveDescription(
UncertaintyCleartkAnalysisEngine.class,
GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
uncertaintyModelOutputDirectory + "/model.jar");
AnalysisEngineDescription conditionalTaggerDescription = AnalysisEngineFactory.createPrimitiveDescription(
ConditionalCleartkAnalysisEngine.class,
GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
conditionalModelOutputDirectory + "/model.jar");
AnalysisEngineDescription subjectTaggerDescription = AnalysisEngineFactory.createPrimitiveDescription(
SubjectCleartkAnalysisEngine.class,
GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
subjectModelOutputDirectory + "/model.jar");
logger.info("starting decoding...");
SimplePipeline.runPipeline(
evaluationCollectionReader,
// BreakIteratorAnnotatorFactory.createSentenceAnnotator(Locale.US),
// TokenAnnotator.getDescription(),
// DefaultSnowballStemmer.getDescription("English"),
polarityTaggerDescription,
uncertaintyTaggerDescription,
conditionalTaggerDescription,
subjectTaggerDescription,
AnalysisEngineFactory.createPrimitiveDescription(
XWriter.class,
AssertionComponents.CTAKES_CTS_TYPE_SYSTEM_DESCRIPTION,
XWriter.PARAM_OUTPUT_DIRECTORY_NAME,
decodingOutputDirectory,
XWriter.PARAM_XML_SCHEME_NAME,
XWriter.XMI));
logger.info("finished decoding.");
}
// TODO Auto-generated method stub
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy