All downloads are free. Search and download functionality uses the official Maven repository.

eu.openminted.uc.socialsciences.ner.train.StanfordNERTrainer Maven / Gradle / Ivy

The newest version!
package eu.openminted.uc.socialsciences.ner.train;

import java.util.Properties;

import org.apache.log4j.LogManager;
import org.apache.log4j.Logger;
import org.kohsuke.args4j.Option;

import edu.stanford.nlp.ie.crf.CRFClassifier;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.sequences.SeqClassifierFlags;
import edu.stanford.nlp.util.StringUtils;
import eu.openminted.uc.socialsciences.common.CommandLineArgumentHandler;

/**
 * Command-line entry point that trains a Stanford CRF sequence classifier and
 * writes the resulting model to disk.
 *
 * <p>Options (bound through args4j {@code @Option} annotations):
 * <ul>
 *   <li>{@code -i} (required): path to labeled input data (.tsv file)</li>
 *   <li>{@code -t} (required): path to the training properties file</li>
 *   <li>{@code -o} (optional): path to save the model; defaults to
 *       {@code omtd-ner-model.ser.gz}</li>
 * </ul>
 */
public class StanfordNERTrainer {
	private static final Logger logger = LogManager.getLogger(StanfordNERTrainer.class);
	private static final String DEFAULT_OUTPUT_PATH = "omtd-ner-model.ser.gz";

	/** Property key under which the model output path is handed to the classifier flags. */
	private static final String SERIALIZE_TO_KEY = "serializeTo";

	/**
	 * Entity subclassification passed to the classifier.
	 * Options: IOB1, IOB2, IOE1, IOE2, SBIEO, IO, BIO, BILOU, noprefix
	 */
	private static final String ENTITY_SUBCLASSIFICATION = "noprefix";

	/** If false, representation will be mapped back to IOB1 on output. */
	private static final boolean RETAIN_ENTITY_SUBCLASSIFICATION = true;

	/** Reader/writer class name forced onto the classifier flags for the training data. */
	private static final String READER_AND_WRITER = "edu.stanford.nlp.sequences.CoNLLDocumentReaderAndWriter";

	@Option(name="-i", usage="path to labeled input data (.tsv file)", required = true)
	private String input = null;

	@Option(name="-t", usage="path to training properties file", required = true)
	private String trainingPropertiesFile = null;

	@Option(name="-o", usage="[optional] path to save the model")
	private String output = DEFAULT_OUTPUT_PATH;

	/**
	 * Runs the trainer with the given command-line arguments.
	 *
	 * @param args command-line arguments; see the class documentation for the options
	 */
	public static void main(String[] args) {
		new StanfordNERTrainer().run(args);
	}

	/**
	 * Parses the arguments, logs the effective configuration and trains the model.
	 *
	 * @param args raw command-line arguments
	 */
	private void run(String[] args) {
		// Presumably populates the @Option-annotated fields of this instance
		// (the helper is defined outside this file).
		new CommandLineArgumentHandler().parseInput(args, this);

		logger.info("Reading properties file from [" + trainingPropertiesFile + "]");
		logger.info("Reading training file from [" + input + "]");

		if (output.equals(DEFAULT_OUTPUT_PATH)) {
			logger.info("No path for saving the trained model was specified. Default value will be used.");
		}
		logger.info("Will write the trained model to [" + output + "]");

		// Train on this instance; a second trainer object is not needed.
		trainCrf(output, trainingPropertiesFile, input, ENTITY_SUBCLASSIFICATION,
				RETAIN_ENTITY_SUBCLASSIFICATION);
		logger.info("Model training finished.");
	}

	/**
	 * Trains a CRF classifier and serializes it.
	 *
	 * @param serializeFile path to write the model to; if {@code null}, the
	 *        {@code serializeTo} value from the properties file is used instead
	 * @param prop path to the training properties file
	 * @param trainFile path to the training data; if {@code null}, the
	 *        {@code trainFile} value from the properties file is left untouched
	 * @param classification value for the {@code entitySubclassification} property
	 * @param retainClassification value for the {@code retainEntitySubclassification} flag
	 * @throws IllegalArgumentException if no output path is available from either
	 *         {@code serializeFile} or the properties file
	 */
	private void trainCrf(String serializeFile, String prop, String trainFile, String classification,
			boolean retainClassification) {
		Properties props = StringUtils.propFileToProperties(prop);
		if (serializeFile != null) {
			props.setProperty(SERIALIZE_TO_KEY, serializeFile);
		}
		if (trainFile != null) {
			props.setProperty("trainFile", trainFile);
		}
		props.setProperty("entitySubclassification", classification);

		// Resolve the effective output path up front so that a missing path is
		// reported before training starts instead of after it has finished.
		String modelPath = props.getProperty(SERIALIZE_TO_KEY);
		if (modelPath == null) {
			throw new IllegalArgumentException(
					"No model output path given and none configured in [" + prop + "]");
		}

		SeqClassifierFlags flags = new SeqClassifierFlags(props);
		flags.retainEntitySubclassification = retainClassification;
		flags.readerAndWriter = READER_AND_WRITER;

		// Wildcard instead of a raw type: only train() and serializeClassifier()
		// are called, neither of which depends on the token type.
		CRFClassifier<?> crf = new CRFClassifier<>(flags);
		crf.train();
		crf.serializeClassifier(modelPath);
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy