All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.ctakes.assertion.medfacts.cleartk.TrainAllAssertionModels.txt Maven / Gradle / Ivy

package org.apache.ctakes.assertion.medfacts.cleartk;

import java.util.Locale;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.CommandLineParser;
import org.apache.commons.cli.GnuParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Option;
import org.apache.commons.cli.OptionBuilder;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.apache.log4j.Logger;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;

import org.apache.uima.collection.CollectionReader;
import org.apache.uima.collection.CollectionReaderDescription;
import org.cleartk.classifier.CleartkAnnotatorDescriptionFactory;
import org.cleartk.classifier.jar.DirectoryDataWriterFactory;
import org.cleartk.classifier.jar.GenericJarClassifierFactory;
import org.cleartk.classifier.opennlp.MaxentDataWriter;
import org.cleartk.classifier.opennlp.MaxentStringOutcomeDataWriter;
import org.cleartk.util.cr.FilesCollectionReader;
import org.cleartk.util.cr.XReader;
import org.uimafit.component.xwriter.XWriter;
import org.uimafit.factory.AggregateBuilder;
import org.uimafit.factory.AnalysisEngineFactory;
import org.uimafit.factory.CollectionReaderFactory;
import org.uimafit.pipeline.SimplePipeline;
import org.uimafit.testing.util.HideOutput;
//import org.junit.Test;
import org.apache.ctakes.assertion.medfacts.AssertionAnalysisEngine;
import org.apache.ctakes.typesystem.type.syntax.BaseToken;
//import edu.mayo.bmi.uima.core.type.textsem.EntityMention;
import org.apache.ctakes.typesystem.type.textsem.IdentifiedAnnotation;
import org.apache.ctakes.typesystem.type.textspan.Sentence;
import org.cleartk.classifier.jar.DefaultDataWriterFactory;
import org.cleartk.examples.pos.ExamplePOSPlainTextWriter;


/**
 * Command-line driver that generates training data for the polarity, uncertainty,
 * conditional and subject assertion annotators, trains one model per annotator, and
 * then runs all four trained annotators over a decoding input directory, writing the
 * results as XMI.
 *
 * <p>Each model is written to its own sub-directory of the model directory:
 * {@code maxent-polarity}, {@code maxent-uncertainty}, {@code maxent-conditional}
 * and {@code maxent-subject}.
 */
public class TrainAllAssertionModels {

  public static final String PARAM_NAME_DECODING_OUTPUT_DIRECTORY = "decoding-output-directory";

  public static final String PARAM_NAME_DECODING_INPUT_DIRECTORY = "decoding-input-directory";

  public static final String PARAM_NAME_TRAINING_INPUT_DIRECTORY = "training-input-directory";

  public static final String PARAM_NAME_MODEL_DIRECTORY = "model-directory";

  // Named after this class (previously it was named after TrainAssertionModel).
  protected static final Logger logger = Logger.getLogger(TrainAllAssertionModels.class.getName());

  /**
   * Parses the four required directory options, builds the data-writer descriptions
   * for the four assertion annotators and hands them to
   * {@link #testClassifierPipeline}.
   *
   * <p>If the arguments cannot be parsed, or any required value is missing or empty,
   * the usage text is printed and the method returns without doing any work. Any
   * exception thrown while training or decoding is logged and not rethrown.
   *
   * @param args command-line arguments; all four directory options are required
   */
  public static void main(String[] args) {

    // NOTE(review): the option names are read from TrainAssertionModel rather than from
    // the identically named constants declared on this class; presumably the values
    // match -- confirm before switching to the local constants.
    Options options = new Options();
    options.addOption(requiredDirectoryOption(
        TrainAssertionModel.PARAM_NAME_MODEL_DIRECTORY,
        "the directory where the model is written to for training, or read from for decoding"));
    options.addOption(requiredDirectoryOption(
        TrainAssertionModel.PARAM_NAME_TRAINING_INPUT_DIRECTORY,
        "directory where input training xmi files are located"));
    options.addOption(requiredDirectoryOption(
        TrainAssertionModel.PARAM_NAME_DECODING_INPUT_DIRECTORY,
        "directory where input xmi files are located for decoding"));
    options.addOption(requiredDirectoryOption(
        TrainAssertionModel.PARAM_NAME_DECODING_OUTPUT_DIRECTORY,
        "directory where output xmi files that are generated in decoding are placed"));

    CommandLineParser parser = new GnuParser();

    boolean invalidInput = false;

    String modelDirectory = null;
    String trainingInputDirectory = null;
    String decodingInputDirectory = null;
    String decodingOutputDirectory = null;
    try {
      CommandLine commandLine = parser.parse(options, args);

      modelDirectory = commandLine.getOptionValue(TrainAssertionModel.PARAM_NAME_MODEL_DIRECTORY);
      trainingInputDirectory = commandLine.getOptionValue(TrainAssertionModel.PARAM_NAME_TRAINING_INPUT_DIRECTORY);
      decodingInputDirectory = commandLine.getOptionValue(TrainAssertionModel.PARAM_NAME_DECODING_INPUT_DIRECTORY);
      decodingOutputDirectory = commandLine.getOptionValue(TrainAssertionModel.PARAM_NAME_DECODING_OUTPUT_DIRECTORY);
    } catch (ParseException e) {
      invalidInput = true;
      logger.error("unable to parse command-line arguments", e);
    }

    if (modelDirectory == null || modelDirectory.isEmpty()
        || trainingInputDirectory == null || trainingInputDirectory.isEmpty()
        || decodingInputDirectory == null || decodingInputDirectory.isEmpty()
        || decodingOutputDirectory == null || decodingOutputDirectory.isEmpty()) {
      logger.error("required parameters not supplied");
      invalidInput = true;
    }

    if (invalidInput) {
      // Show this class's name in the usage banner (previously TrainAssertionModel's).
      HelpFormatter formatter = new HelpFormatter();
      formatter.printHelp(TrainAllAssertionModels.class.getName(), options, true);
      return;
    }

    logger.info(String.format(
        "%n" +
        "model dir:           \"%s\"%n" +
        "training input dir:  \"%s\"%n" +
        "decoding input dir:  \"%s\"%n" +
        "decoding output dir: \"%s\"%n",
        modelDirectory,
        trainingInputDirectory,
        decodingInputDirectory,
        decodingOutputDirectory));

    String polarityModelOutputDirectory = modelDirectory + "/maxent-polarity";
    String uncertaintyModelOutputDirectory = modelDirectory + "/maxent-uncertainty";
    String conditionalModelOutputDirectory = modelDirectory + "/maxent-conditional";
    String subjectModelOutputDirectory = modelDirectory + "/maxent-subject";

    try {
      AnalysisEngineDescription uncertaintyDataWriter = AnalysisEngineFactory.createPrimitiveDescription(
          UncertaintyCleartkAnalysisEngine.class,
          AssertionComponents.CTAKES_CTS_TYPE_SYSTEM_DESCRIPTION,
          DefaultDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
          MaxentStringOutcomeDataWriter.class.getName(),
          DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
          uncertaintyModelOutputDirectory);

      AnalysisEngineDescription polarityDataWriter = AnalysisEngineFactory.createPrimitiveDescription(
          PolarityCleartkAnalysisEngine.class,
          AssertionComponents.CTAKES_CTS_TYPE_SYSTEM_DESCRIPTION,
          DefaultDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
          MaxentStringOutcomeDataWriter.class.getName(),
          DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
          polarityModelOutputDirectory);

      AnalysisEngineDescription conditionalDataWriter = AnalysisEngineFactory.createPrimitiveDescription(
          ConditionalCleartkAnalysisEngine.class,
          AssertionComponents.CTAKES_CTS_TYPE_SYSTEM_DESCRIPTION,
          DefaultDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
          MaxentStringOutcomeDataWriter.class.getName(),
          DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
          conditionalModelOutputDirectory);

      AnalysisEngineDescription subjectDataWriter = AnalysisEngineFactory.createPrimitiveDescription(
          SubjectCleartkAnalysisEngine.class,
          AssertionComponents.CTAKES_CTS_TYPE_SYSTEM_DESCRIPTION,
          DefaultDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
          MaxentStringOutcomeDataWriter.class.getName(),
          DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
          subjectModelOutputDirectory);

      testClassifierPipeline(
          polarityDataWriter,
          polarityModelOutputDirectory,
          uncertaintyDataWriter,
          uncertaintyModelOutputDirectory,
          conditionalDataWriter,
          conditionalModelOutputDirectory,
          subjectDataWriter,
          subjectModelOutputDirectory,
          trainingInputDirectory,
          decodingInputDirectory,
          decodingOutputDirectory);
    } catch (Exception e) {
      logger.error("Some exception happened while training or decoding...", e);
    }
  }

  /**
   * Builds a required long-only option that takes a single {@code DIR} argument.
   *
   * @param longOpt the long option name
   * @param description the help text shown for the option
   * @return the configured option
   */
  private static Option requiredDirectoryOption(String longOpt, String description) {
    OptionBuilder.withLongOpt(longOpt);
    OptionBuilder.withArgName("DIR");
    OptionBuilder.hasArg();
    OptionBuilder.isRequired();
    OptionBuilder.withDescription(description);
    return OptionBuilder.create();
  }

  /**
   * Runs feature generation, training and decoding for all four assertion attributes.
   *
   * <ol>
   *   <li>Runs each data-writer description over the training reader (the reader is
   *       reconfigured before each pass after the first).</li>
   *   <li>Invokes {@code org.cleartk.classifier.jar.Train} once per model directory
   *       while console output is hidden.</li>
   *   <li>Runs the four annotators, each configured with {@code model.jar} from its
   *       model directory, over the decoding reader and writes XMI with
   *       {@code XWriter}.</li>
   * </ol>
   *
   * @param polarityDataWriter data-writer description for the polarity annotator
   * @param polarityModelOutputDirectory directory holding the polarity training data and model
   * @param uncertaintyDataWriter data-writer description for the uncertainty annotator
   * @param uncertaintyModelOutputDirectory directory holding the uncertainty training data and model
   * @param conditionalDataWriter data-writer description for the conditional annotator
   * @param conditionalModelOutputDirectory directory holding the conditional training data and model
   * @param subjectDataWriter data-writer description for the subject annotator
   * @param subjectModelOutputDirectory directory holding the subject training data and model
   * @param trainingDataInputDirectory root passed to the XMI reader used for training
   * @param decodingInputDirectory root passed to the XMI reader used for decoding
   * @param decodingOutputDirectory directory the decoded XMI files are written to
   * @throws Exception if reader creation, a pipeline run, or training fails
   */
  public static void testClassifierPipeline(
      AnalysisEngineDescription polarityDataWriter,
      String polarityModelOutputDirectory,
      AnalysisEngineDescription uncertaintyDataWriter,
      String uncertaintyModelOutputDirectory,
      AnalysisEngineDescription conditionalDataWriter,
      String conditionalModelOutputDirectory,
      AnalysisEngineDescription subjectDataWriter,
      String subjectModelOutputDirectory,
      String trainingDataInputDirectory,
      String decodingInputDirectory,
      String decodingOutputDirectory) throws Exception {

    CollectionReader trainingCollectionReader = CollectionReaderFactory.createCollectionReader(
        XReader.class,
        XReader.PARAM_ROOT_FILE,
        trainingDataInputDirectory,
        XReader.PARAM_XML_SCHEME,
        XReader.XMI);

    CollectionReader evaluationCollectionReader = CollectionReaderFactory.createCollectionReader(
        XReader.class,
        XReader.PARAM_ROOT_FILE,
        decodingInputDirectory,
        XReader.PARAM_XML_SCHEME,
        XReader.XMI);

    logger.info("starting feature generation... POLARITY");
    SimplePipeline.runPipeline(
        trainingCollectionReader,
        polarityDataWriter);
    logger.info("finished feature generation... POLARITY");

    // The same reader instance is reused for every pass, so it is reconfigured first.
    trainingCollectionReader.reconfigure();
    logger.info("starting feature generation... UNCERTAINTY");
    SimplePipeline.runPipeline(
        trainingCollectionReader,
        uncertaintyDataWriter);
    logger.info("finished feature generation... UNCERTAINTY.");

    trainingCollectionReader.reconfigure();
    // Label corrected: this pass runs the conditional writer (was logged as UNCERTAINTY).
    logger.info("starting feature generation... CONDITIONAL");
    SimplePipeline.runPipeline(
        trainingCollectionReader,
        conditionalDataWriter);
    logger.info("finished feature generation... CONDITIONAL.");

    trainingCollectionReader.reconfigure();
    // Label corrected: this pass runs the subject writer (was logged as UNCERTAINTY).
    logger.info("starting feature generation... SUBJECT");
    SimplePipeline.runPipeline(
        trainingCollectionReader,
        subjectDataWriter);
    logger.info("finished feature generation... SUBJECT.");

    HideOutput hider = new HideOutput();
    try {
      String[] args = new String[] {polarityModelOutputDirectory};
      logger.info("starting training POLARITY...");
      org.cleartk.classifier.jar.Train.main(args);
      logger.info("finished training POLARITY .");

      args = new String[] {uncertaintyModelOutputDirectory};
      logger.info("starting training UNCERTAINTY...");
      org.cleartk.classifier.jar.Train.main(args);
      logger.info("finished training UNCERTAINTY .");

      args = new String[] {conditionalModelOutputDirectory};
      logger.info("starting training CONDITIONAL...");
      org.cleartk.classifier.jar.Train.main(args);
      logger.info("finished training CONDITIONAL .");

      args = new String[] {subjectModelOutputDirectory};
      logger.info("starting training SUBJECT...");
      org.cleartk.classifier.jar.Train.main(args);
      logger.info("finished training SUBJECT .");
    } finally {
      // Restore console output even when training throws; otherwise it would stay hidden.
      hider.restoreOutput();
    }

    AnalysisEngineDescription polarityTaggerDescription = AnalysisEngineFactory.createPrimitiveDescription(
        PolarityCleartkAnalysisEngine.class,
        GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
        polarityModelOutputDirectory + "/model.jar");

    AnalysisEngineDescription uncertaintyTaggerDescription = AnalysisEngineFactory.createPrimitiveDescription(
        UncertaintyCleartkAnalysisEngine.class,
        GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
        uncertaintyModelOutputDirectory + "/model.jar");

    AnalysisEngineDescription conditionalTaggerDescription = AnalysisEngineFactory.createPrimitiveDescription(
        ConditionalCleartkAnalysisEngine.class,
        GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
        conditionalModelOutputDirectory + "/model.jar");

    AnalysisEngineDescription subjectTaggerDescription = AnalysisEngineFactory.createPrimitiveDescription(
        SubjectCleartkAnalysisEngine.class,
        GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
        subjectModelOutputDirectory + "/model.jar");

    logger.info("starting decoding...");
    SimplePipeline.runPipeline(
        evaluationCollectionReader,
        polarityTaggerDescription,
        uncertaintyTaggerDescription,
        conditionalTaggerDescription,
        subjectTaggerDescription,
        AnalysisEngineFactory.createPrimitiveDescription(
            XWriter.class,
            AssertionComponents.CTAKES_CTS_TYPE_SYSTEM_DESCRIPTION,
            XWriter.PARAM_OUTPUT_DIRECTORY_NAME,
            decodingOutputDirectory,
            XWriter.PARAM_XML_SCHEME_NAME,
            XWriter.XMI));
    logger.info("finished decoding.");
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy