All Downloads are FREE. The search and download functionality uses the official Maven repository.

org.apache.ctakes.assertion.medfacts.cleartk.TrainAssertionModel Maven / Gradle / Ivy

There is a newer version: 6.0.0
Show newest version
/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.ctakes.assertion.medfacts.cleartk;

import java.io.File;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.CommandLineParser;
import org.apache.commons.cli.GnuParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Option;
import org.apache.commons.cli.OptionBuilder;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.apache.ctakes.assertion.eval.AssertionEvaluation;
import org.apache.ctakes.assertion.eval.AssertionEvaluation.ReferenceAnnotationsSystemAssertionClearer;
import org.apache.ctakes.assertion.eval.AssertionEvaluation.ReferenceIdentifiedAnnotationsSystemToGoldCopier;
import org.apache.ctakes.core.cc.XmiWriterCasConsumerCtakes;
import org.apache.log4j.Logger;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.collection.CollectionReader;
import org.apache.uima.fit.factory.AggregateBuilder;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.fit.factory.CollectionReaderFactory;
import org.apache.uima.fit.factory.ConfigurationParameterFactory;
import org.apache.uima.fit.pipeline.SimplePipeline;
import org.apache.uima.fit.testing.util.HideOutput;
import org.cleartk.ml.CleartkAnnotator;
import org.cleartk.ml.DataWriterFactory;
import org.cleartk.ml.jar.DefaultDataWriterFactory;
import org.cleartk.ml.jar.DirectoryDataWriterFactory;
import org.cleartk.ml.jar.GenericJarClassifierFactory;
import org.cleartk.ml.opennlp.maxent.MaxentStringOutcomeDataWriter;
import org.cleartk.util.cr.XReader;
//import org.junit.Test;
//import edu.mayo.bmi.uima.core.type.textsem.EntityMention;


public class TrainAssertionModel {
  
  public static final String PARAM_NAME_DECODING_OUTPUT_DIRECTORY = "decoding-output-directory";

  public static final String PARAM_NAME_DECODING_INPUT_DIRECTORY = "decoding-input-directory";

  public static final String PARAM_NAME_TRAINING_INPUT_DIRECTORY = "training-input-directory";

  public static final String PARAM_NAME_MODEL_DIRECTORY = "model-directory";

  protected static final Logger logger = Logger.getLogger(TrainAssertionModel.class.getName());

	/**
	 * @param args
	 */
  /*
	public static void main(String[] args) {
		// TODO Auto-generated method stub
		String trainDir = args[0];
		String outputDir = args[1];
		
		try {
		CollectionReader reader = FilesCollectionReader.getCollectionReader(trainDir);
		AggregateBuilder builder = new AggregateBuilder();
		//builder.add(AnalysisEngineFactory.createEngineDescription("desc/AssertionMiniPipelineAnalysisEngine.xml", null));
		//builder.add(AnalysisEngineFactory.createEngineDescription(IdentifiedAnnotation.class));
		//builder.add(AnalysisEngineFactory.createEngineDescription("edu.mayo.bmi.uima.core.type.textsem.IdentifiedAnnotation"));
	    builder.add(AssertionCleartkAnalysisEngine.getWriterDescription(outputDir));
	    SimplePipeline.runPipeline(reader, builder.createEngineDescription());
	    org.cleartk.classifier.jar.Train.main(outputDir);
		} catch (Exception e) {
			System.err.println("Exception: " + e);
			e.printStackTrace();
			throw new RuntimeException(e);
		}
		
	    
	}
	*/
	
	protected String modelOutputDirectory = "/work/medfacts/cleartk/data/train.model";

  //@Test
  public void testMaxent() throws Exception {
    
    String trainingDataDirectory = "/work/medfacts/cleartk/data/train";
    String evaluationDataDirectory = "/work/medfacts/cleartk/data/eval2.input";
    String evaluationOutputDataDirectory = "/work/medfacts/cleartk/data/eval2.output";

    String maxentModelOutputDirectory = modelOutputDirectory + "/maxent";
    AnalysisEngineDescription dataWriter = AnalysisEngineFactory.createEngineDescription(
        AssertionCleartkAnalysisEngine.class,
        AssertionComponents.CTAKES_CTS_TYPE_SYSTEM_DESCRIPTION,
        DefaultDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
        MaxentStringOutcomeDataWriter.class.getName(),
        DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
        maxentModelOutputDirectory);
    testClassifier(
        dataWriter,
        maxentModelOutputDirectory,
        trainingDataDirectory,
        evaluationDataDirectory,
        evaluationOutputDataDirectory
    );

//    // Not sure why the _SPLIT is here, but we will throw it out for good measure
//    String firstLine = FileUtil.loadListOfStrings(new File(maxentDirectoryName
//        + "/2008_Sichuan_earthquake.txt.pos"))[0].trim().replace("_SPLIT", "");
//    checkPOS(firstLine);
  }
  
  public static void main(String args[])
  {
    
    Options options = new Options();
    
    Option modelDirectoryOption =
        OptionBuilder
          .withLongOpt(TrainAssertionModel.PARAM_NAME_MODEL_DIRECTORY)
          .withArgName("DIR")
          .hasArg()
          .isRequired()
          .withDescription("the directory where the model is written to for training, or read from for decoding")
          .create();
    options.addOption(modelDirectoryOption);
    
    Option trainingInputDirectoryOption =
        OptionBuilder
          .withLongOpt(TrainAssertionModel.PARAM_NAME_TRAINING_INPUT_DIRECTORY)
          .withArgName("DIR")
          .hasArg()
          .isRequired()
          .withDescription("directory where input training xmi files are located")
          .create();
    options.addOption(trainingInputDirectoryOption);
    
    Option decodingInputDirectoryOption =
        OptionBuilder
          .withLongOpt(TrainAssertionModel.PARAM_NAME_DECODING_INPUT_DIRECTORY)
          .withArgName("DIR")
          .hasArg()
          .isRequired()
          .withDescription("directory where input xmi files are located for decoding")
          .create();
    options.addOption(decodingInputDirectoryOption);
    
    Option decodingOutputDirectoryOption =
        OptionBuilder
          .withLongOpt(TrainAssertionModel.PARAM_NAME_DECODING_OUTPUT_DIRECTORY)
          .withArgName("DIR")
          .hasArg()
          .isRequired()
          .withDescription("directory where output xmi files that are generated in decoding are placed")
          .create();
    options.addOption(decodingOutputDirectoryOption);
    
    CommandLineParser parser = new GnuParser();
    
    boolean invalidInput = false;
    
    CommandLine commandLine = null;
    String modelDirectory = null;
    String trainingInputDirectory = null;
    String decodingInputDirectory = null;
    String decodingOutputDirectory = null;
    try
    {
      commandLine = parser.parse(options, args);
      
      modelDirectory = commandLine.getOptionValue(TrainAssertionModel.PARAM_NAME_MODEL_DIRECTORY);
      trainingInputDirectory = commandLine.getOptionValue(TrainAssertionModel.PARAM_NAME_TRAINING_INPUT_DIRECTORY);
      decodingInputDirectory = commandLine.getOptionValue(TrainAssertionModel.PARAM_NAME_DECODING_INPUT_DIRECTORY);
      decodingOutputDirectory = commandLine.getOptionValue(TrainAssertionModel.PARAM_NAME_DECODING_OUTPUT_DIRECTORY);
    } catch (ParseException e)
    {
      invalidInput = true;
      logger.error("unable to parse command-line arguments", e);
    }
    
    if (modelDirectory == null || modelDirectory.isEmpty() ||
        trainingInputDirectory == null || trainingInputDirectory.isEmpty() ||
        decodingInputDirectory == null || decodingInputDirectory.isEmpty() ||
        decodingOutputDirectory == null || decodingOutputDirectory.isEmpty()
        )
    {
      logger.error("required parameters not supplied");
      invalidInput = true;
    }
    
    if (invalidInput)
    {
      HelpFormatter formatter = new HelpFormatter();
      formatter.printHelp(TrainAssertionModel.class.getName(), options, true);
      return;
    }
    
    logger.info(String.format(
        "%n" +
        "model dir:           \"%s\"%n" +
        "training input dir:  \"%s\"%n" +
        "decoding input dir:  \"%s\"%n" +
        "decoding output dir: \"%s\"%n",
        modelDirectory,
        trainingInputDirectory,
        decodingInputDirectory,
        decodingOutputDirectory));
    
    String maxentModelOutputDirectory = modelDirectory + "/maxent";
    try
    {
      AnalysisEngineDescription dataWriter = AnalysisEngineFactory.createEngineDescription(
          AssertionCleartkAnalysisEngine.class,
          AssertionComponents.CTAKES_CTS_TYPE_SYSTEM_DESCRIPTION,
          DefaultDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
          MaxentStringOutcomeDataWriter.class.getName(),
          DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
          maxentModelOutputDirectory);
      testClassifier(
          dataWriter,
          maxentModelOutputDirectory,
          trainingInputDirectory,
          decodingInputDirectory,
          decodingOutputDirectory
      );
    } catch (Exception e)
    {
      logger.error("Some exception happened while training or decoding...", e);
      return;
    }
    
 }

	
	
  public static void testClassifier(
      AnalysisEngineDescription dataWriter,
      String modelOutputDirectory,
      String trainingDataInputDirectory,
      String decodingInputDirectory,
      String decodingOutputDirectory,
      String... trainingArgs) throws Exception
  {
    
    CollectionReader trainingCollectionReader = CollectionReaderFactory.createReader(
        XReader.class,
        XReader.PARAM_ROOT_FILE,
        trainingDataInputDirectory,
        XReader.PARAM_XML_SCHEME,
        XReader.XMI);
    CollectionReader evaluationCollectionReader = CollectionReaderFactory.createReader(
        XReader.class,
        XReader.PARAM_ROOT_FILE,
        decodingInputDirectory,
        XReader.PARAM_XML_SCHEME,
        XReader.XMI);
    
    
    AggregateBuilder trainingBuilder = new AggregateBuilder();
    
    AnalysisEngineDescription goldCopierAnnotator = AnalysisEngineFactory.createEngineDescription(ReferenceIdentifiedAnnotationsSystemToGoldCopier.class);
    trainingBuilder.add(goldCopierAnnotator);
    
    AnalysisEngineDescription assertionAttributeClearerAnnotator = AnalysisEngineFactory.createEngineDescription(ReferenceAnnotationsSystemAssertionClearer.class);
    trainingBuilder.add(assertionAttributeClearerAnnotator);
    
//    Class> dataWriterFactoryClass = DefaultMaxentDataWriterFactory.class;
    AnalysisEngineDescription trainingAssertionAnnotator = 
        AnalysisEngineFactory.createEngineDescription(
            AssertionCleartkAnalysisEngine.class,
            AssertionComponents.CTAKES_CTS_TYPE_SYSTEM_DESCRIPTION
            );
    ConfigurationParameterFactory.addConfigurationParameters(
        trainingAssertionAnnotator,
        AssertionCleartkAnalysisEngine.PARAM_GOLD_VIEW_NAME,
        AssertionEvaluation.GOLD_VIEW_NAME,
//        CleartkAnnotator.PARAM_DATA_WRITER_FACTORY_CLASS_NAME,
//        dataWriterFactoryClass.getName(),
        DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
        modelOutputDirectory
        );
    trainingBuilder.add(trainingAssertionAnnotator);
    
    
    
//    CollectionReader collectionReader = XReader.getCollectionReader(
//        trainingDataDirectory);
//    collectionReader.setConfigParameterValue(XReader.PARAM_XML_SCHEME, XReader.XMI);
//    collectionReader.reconfigure();

    logger.info("starting feature generation...");
    SimplePipeline.runPipeline(
        trainingCollectionReader,
//        FilesCollectionReader.getCollectionReaderWithView(
//            "src/test/resources/data/treebank/11597317.tree",
//            TreebankConstants.TREEBANK_VIEW)
//        ,
//        TreebankGoldAnnotator.getDescriptionPOSTagsOnly(),
//        DefaultSnowballStemmer.getDescription("English"),
//        dataWriter);
        trainingBuilder.createAggregateDescription());
    logger.info("finished feature generation.");

    String[] args;
    if (trainingArgs != null && trainingArgs.length > 0) {
      args = new String[trainingArgs.length + 1];
      args[0] = modelOutputDirectory;
      System.arraycopy(trainingArgs, 0, args, 1, trainingArgs.length);
    } else {
      args = new String[] { modelOutputDirectory };
    }

    HideOutput hider = new HideOutput();
    logger.info("starting training...");
    org.cleartk.ml.jar.Train.main(args);
    logger.info("finished training.");
    hider.restoreOutput();

    AggregateBuilder decodingBuilder = new AggregateBuilder();
    
    //AnalysisEngineDescription goldCopierAnnotator = AnalysisEngineFactory.createEngineDescription(ReferenceIdentifiedAnnotationsSystemToGoldCopier.class);
    decodingBuilder.add(goldCopierAnnotator);
    
    //AnalysisEngineDescription assertionAttributeClearerAnnotator = AnalysisEngineFactory.createEngineDescription(ReferenceAnnotationsSystemAssertionClearer.class);
    decodingBuilder.add(assertionAttributeClearerAnnotator);
    
    AnalysisEngineDescription decodingAssertionAnnotator =
      AnalysisEngineFactory.createEngineDescription(
        AssertionCleartkAnalysisEngine.class,
        AssertionComponents.CTAKES_CTS_TYPE_SYSTEM_DESCRIPTION
        );
    ConfigurationParameterFactory.addConfigurationParameters(
        decodingAssertionAnnotator,
        AssertionCleartkAnalysisEngine.PARAM_GOLD_VIEW_NAME,
        AssertionEvaluation.GOLD_VIEW_NAME,
        GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
        new File(modelOutputDirectory, "model.jar").getPath()
        );
    decodingBuilder.add(decodingAssertionAnnotator);
    
    //SimplePipeline.runPipeline(collectionReader,  builder.createEngineDescription());
    AnalysisEngineDescription decodingAggregateDescription = decodingBuilder.createAggregateDescription();
    
    
    
//    AnalysisEngineDescription taggerDescription = AnalysisEngineFactory.createEngineDescription(
//        AssertionCleartkAnalysisEngine.class,
//        GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
//        //AssertionComponents.TYPE_SYSTEM_DESCRIPTION,
//        modelOutputDirectory + "/model.jar");

    logger.info("starting decoding...");
    SimplePipeline.runPipeline(
        evaluationCollectionReader,
//        BreakIteratorAnnotatorFactory.createSentenceAnnotator(Locale.US),
//        TokenAnnotator.getDescription(),
//        DefaultSnowballStemmer.getDescription("English"),
        //taggerDescription,
        decodingAggregateDescription,
        AnalysisEngineFactory.createEngineDescription(
            XmiWriterCasConsumerCtakes.class,
            AssertionComponents.CTAKES_CTS_TYPE_SYSTEM_DESCRIPTION,
            XmiWriterCasConsumerCtakes.PARAM_OUTPUTDIR,
            decodingOutputDirectory));
    logger.info("finished decoding.");

  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy