All downloads are FREE. Search and download functionality uses the official Maven repository.

de.citec.scie.Training Maven / Gradle / Ivy

Go to download

Contains the SCIE main application and its command line interface (CLI). This project integrates named entity recognition (NER), PDF import and classification, and connects them to the UIMA framework. The CLI can be used to produce a set of UIMA XCAS files.

The newest version!
/*
 * SCIE -- Spinal Cord Injury Information Extraction
 * Copyright (C) 2013, 2014
 * Raphael Dickfelder, Jan Göpfert, Benjamin Paaßen, Andreas Stöckel
 * 
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 * 
 * You should have received a copy of the GNU Affero General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

package de.citec.scie;

import de.bwaldvogel.liblinear.SolverType;
import de.citec.scie.classifiers.Classifier;
import de.citec.scie.classifiers.ClassifierEvaluation;
import de.citec.scie.classifiers.LibLinearClassifier;
import de.citec.scie.classifiers.TrainingUtils;
import de.citec.scie.classifiers.data.LabeledDataPoint;
import de.citec.scie.classifiers.data.RawRelation;
import de.citec.scie.classifiers.data.RelationDataPoint;
import de.citec.scie.classifiers.data.TrainingDataReader;
import de.citec.scie.classifiers.data.impl.AnimalRelation;
import de.citec.scie.classifiers.data.impl.DrugCore;
import de.citec.scie.classifiers.data.impl.DrugDeliveryCombination;
import de.citec.scie.classifiers.data.impl.DrugDoseCombination;
import de.citec.scie.classifiers.data.impl.InjuryRelation;
import de.citec.scie.classifiers.data.impl.InjuryTypeCore;
import de.citec.scie.classifiers.data.impl.InjuryTypeDurationCombination;
import de.citec.scie.classifiers.data.impl.InjuryTypeInjuryDeviceCombination;
import de.citec.scie.classifiers.data.impl.InjuryTypeInjuryHeightCombination;
import de.citec.scie.classifiers.data.impl.InvestigationMethodCore;
import de.citec.scie.classifiers.data.impl.InvestigationMethodPValueCombination;
import de.citec.scie.classifiers.data.impl.InvestigationMethodSignificanceCombination;
import de.citec.scie.classifiers.data.impl.InvestigationMethodTrendCombination;
import de.citec.scie.classifiers.data.impl.OrganismAgeCombination;
import de.citec.scie.classifiers.data.impl.OrganismCore;
import de.citec.scie.classifiers.data.impl.OrganismGenderCombination;
import de.citec.scie.classifiers.data.impl.OrganismNumberCombination;
import de.citec.scie.classifiers.data.impl.OrganismWeightCombination;
import de.citec.scie.classifiers.data.impl.ResultRelation;
import de.citec.scie.classifiers.data.impl.TreatmentRelation;
import de.citec.scie.typesystem.Typesystem;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.uima.fit.util.CasIOUtil;
import org.apache.uima.jcas.JCas;

/**
 * This classes purpose is to manage the training of SCIE classifiers.
 *
 * @author Benjamin Paassen - [email protected]
 */
public class Training {

	public static void train(final File[] relFiles, final String classifierSpec,
			final File outputFolder, final boolean force) {
		final ArrayList input = new ArrayList<>();
		for (final File relFile : relFiles) {
			System.out.println("Processing " + relFile.getAbsolutePath());
			try {
				final File XCASFile = new File(relFile.getAbsolutePath().replace(".rel", ".xml"));
				if (!XCASFile.exists() || !XCASFile.isFile()) {
					throw new FileNotFoundException("No fitting XCAS file was found!");
				}
				final ArrayList readRawRelations = RawRelation.readRawRelations(relFile);
				final JCas jcas = Typesystem.getJCas(Constants.TYPESYSTEM);
				CasIOUtil.readXCas(jcas, XCASFile);
				input.add(new XCASRelationTuple(jcas, readRawRelations,
						relFile.getName().replace(".rel", "")));
			} catch (IOException ex) {
				System.err.println("The file could not be processed because of "
						+ "an exception during parsing:");
				ex.printStackTrace(System.err);
			}
		}

		final Pattern classifierPattern = Pattern.compile(classifierSpec);

		//find out which classifiers should be trained.
		boolean found = false;
		for (final ClassifierTrainingConfiguration config : ClassifierTrainingConfiguration.values()) {
			final Matcher matcher = classifierPattern.matcher(config.classifierName);
			if (matcher.matches()) {
				found = true;
				/*
				 * Load dependencies for relation classifiers and construct the
				 * respective TrainingDataReader manually.
				 */
				TrainingDataReader reader = null;
				switch (config) {
					case ANIMAL_RELATION:
						final Classifier animalCore = Annotator.loadClassifier("AnimalCore", SolverType.L1R_LR);
						final Classifier animalAge = Annotator.loadClassifier("AnimalAge", SolverType.L1R_LR);
						final Classifier animalGender = Annotator.loadClassifier("AnimalGender", SolverType.L1R_LR);
						final Classifier animalNumber = Annotator.loadClassifier("AnimalNumber", SolverType.L1R_LR);
						final Classifier animalWeight = Annotator.loadClassifier("AnimalWeight", SolverType.L1R_LR);
						final Classifier[] animalSlotClassifiers
								= {animalAge, animalGender, animalNumber, animalWeight};
						reader = new AnimalRelation.TrainingDataReader(
								animalCore, animalSlotClassifiers);
						break;
					case INJURY_RELATION:
						final Classifier injuryCore = Annotator.loadClassifier(
								"InjuryCore", SolverType.L1R_LR);
						final Classifier injuryDuration = Annotator.loadClassifier(
								"InjuryDuration", SolverType.L1R_LR);
						final Classifier injuryDevice = Annotator.loadClassifier(
								"InjuryDevice", SolverType.L1R_LR);
						final Classifier injuryHeight = Annotator.loadClassifier(
								"InjuryHeight", SolverType.L1R_LR);
						final Classifier[] injurySlotClassifiers
								= {injuryDuration, injuryDevice, injuryHeight};
						reader = new InjuryRelation.TrainingDataReader(
								injuryCore, injurySlotClassifiers);
						break;
					case TREATMENT_RELATION:
						final Classifier treatmentCore = Annotator.loadClassifier(
								"TreatmentCore", SolverType.L1R_LR);
						final Classifier treatmentDelivery = Annotator.loadClassifier(
								"TreatmentDelivery", SolverType.L1R_LR);
						final Classifier treatmentDose = Annotator.loadClassifier(
								"TreatmentDose", SolverType.L1R_LR);
						final Classifier[] treatmentSlotClassifiers
								= {treatmentDelivery, treatmentDose};
						reader = new TreatmentRelation.TrainingDataReader(
								treatmentCore, treatmentSlotClassifiers);
						break;
					case RESULT_RELATION:
						final Classifier resultCore = Annotator.loadClassifier(
								"ResultCore", SolverType.L1R_LR);
						final Classifier resultPValue = Annotator.loadClassifier(
								"ResultPValue", SolverType.L1R_LR);
						final Classifier resultSignificance = Annotator.loadClassifier(
								"ResultSignificance", SolverType.L1R_LR);
						final Classifier resultTrend = Annotator.loadClassifier(
								"ResultTrend", SolverType.L1R_LR);
						final Classifier[] resultSlotClassifiers
								= {resultPValue, resultSignificance, resultTrend};
						reader = new ResultRelation.TrainingDataReader(
								resultCore, resultSlotClassifiers);
						break;
				}
				System.out.println("Training " + config.classifierName);
				/*
				 * If all pre-requisites are constructed, start the training.
				 */
				try {
					setUpClassifier(reader, input, config, outputFolder);
				} catch (IOException ex) {
					System.err.println("Classifier could not be trained because of exception:");
					ex.printStackTrace(System.err);
				}
			}
		}
		if (!found) {
			System.err.println("The classifier(s) " + classifierSpec + " was not in the list! Please select one from these:");
			for (final ClassifierTrainingConfiguration config : ClassifierTrainingConfiguration.values()) {
				System.out.println(config.classifierName);
			}
		}
	}
	/**
	 * The optimality function for Core and Slot classifiers: We try to get a
	 * high recall. Area under curve and accuracy are less important. Precision
	 * not at all.
	 */
	private static final Comparator slotOptimality = new Comparator() {

		@Override
		public int compare(ClassifierEvaluation o1, ClassifierEvaluation o2) {
			return Double.compare(getScore(o1), getScore(o2));
		}

		private double getScore(ClassifierEvaluation eval) {
			return eval.getTestEvaluation().getRecall() * 10
					+ eval.getTestEvaluation().getArea_under_roc() * 3
					+ eval.getTestEvaluation().getAccuracy();
		}

	};

	/**
	 * The optimality function for Relation classifiers: We try to get a high
	 * area under curve. F1 is also important, accuracy less.
	 */
	private static final Comparator relationOptimality = new Comparator() {

		@Override
		public int compare(ClassifierEvaluation o1, ClassifierEvaluation o2) {
			/*
			 * We calculate a score taking into account F1 measure, area
			 * under curve und accuracy (least important).
			 */
			return Double.compare(getScore(o1), getScore(o2));
		}

		private double getScore(ClassifierEvaluation eval) {
			return eval.getTestEvaluation().getF1() * 3
					+ eval.getTestEvaluation().getArea_under_roc() * 5
					+ eval.getTrainEvaluation().getAccuracy();
		}
	};

	/**
	 * This enum models the standard training configuration for each classifier.
	 */
	private static enum ClassifierTrainingConfiguration {

		/*
		 * Animal Relation
		 */
		ANIMAL_CORE(new OrganismCore.TrainingDataReader(), SolverType.L1R_LR,
				false, 5, 3, slotOptimality, 0.003, 0.01, false, "AnimalCore"),
		ANIMAL_AGE(new OrganismAgeCombination.TrainingDataReader(),
				SolverType.L1R_LR,
				false, 5, 3, slotOptimality, 0.05, 0.01, false, "AnimalAge"),
		ANIMAL_GENDER(new OrganismGenderCombination.TrainingDataReader(),
				SolverType.L1R_LR, false, 5, 3, slotOptimality, 0.05, 0.01,
				false, "AnimalGender"),
		ANIMAL_NUMBER(new OrganismNumberCombination.TrainingDataReader(),
				SolverType.L1R_LR, false, 5, 3, slotOptimality, 0.05, 0.01,
				false, "AnimalNumber"),
		ANIMAL_WEIGHT(new OrganismWeightCombination.TrainingDataReader(),
				SolverType.L1R_LR, false, 5, 3, slotOptimality, 0.05, 0.01,
				false, "AnimalWeight"),
		ANIMAL_RELATION(null, SolverType.L1R_LR,
				false, 5, 3, relationOptimality, 0.6, 0.01,
				false, "AnimalRelation"),
		/*
		 * Injury Relation
		 */
		INJURY_CORE(
				new InjuryTypeCore.TrainingDataReader(),
				SolverType.L1R_LR, false, 5, 3, slotOptimality, 0.1, 0.01,
				false, "InjuryCore"),
		INJURY_DURATION(
				new InjuryTypeDurationCombination.TrainingDataReader(),
				SolverType.L1R_LR, false, 5, 3, slotOptimality, 0.5, 0.01,
				false, "InjuryDuration"),
		INJURY_DEVICE(
				new InjuryTypeInjuryDeviceCombination.TrainingDataReader(),
				SolverType.L1R_LR, false, 5, 3, slotOptimality, 0.6, 0.01,
				false, "InjuryDevice"),
		INJURY_HEIGHT(
				new InjuryTypeInjuryHeightCombination.TrainingDataReader(),
				SolverType.L1R_LR, false, 5, 3, slotOptimality, 0.1, 0.01,
				false, "InjuryHeight"),
		INJURY_RELATION(null,
				SolverType.L1R_LR, false, 5, 3, relationOptimality, 0.6, 0.01,
				false, "InjuryRelation"),
		/*
		 * Treatment Relation
		 */
		TREATMENT_CORE(
				new DrugCore.TrainingDataReader(),
				SolverType.L1R_LR, false, 5, 3, slotOptimality, 0.2, 0.01,
				false, "TreatmentCore"),
		TREATMENT_DELIVERY(
				new DrugDeliveryCombination.TrainingDataReader(),
				SolverType.L1R_LR, false, 5, 3, slotOptimality, 0.4, 0.01,
				false, "TreatmentDelivery"),
		TREATMENT_DOSE(
				new DrugDoseCombination.TrainingDataReader(),
				SolverType.L1R_LR, false, 5, 3, slotOptimality, 0.4, 0.01,
				false, "TreatmentDose"),
		TREATMENT_RELATION(null,
				SolverType.L1R_LR, false, 5, 3, relationOptimality, 0.8, 0.01,
				false, "TreatmentRelation"),
		/*
		 * Result Relation
		 */
		RESULT_CORE(
				new InvestigationMethodCore.TrainingDataReader(),
				SolverType.L1R_LR, false, 5, 3, slotOptimality, 0.9, 0.01,
				false, "ResultCore"),
		RESULT_P_VALUE(
				new InvestigationMethodPValueCombination.TrainingDataReader(),
				SolverType.L1R_LR, false, 5, 3, slotOptimality, 0.8, 0.01,
				false, "ResultPValue"),
		RESULT_SIGNFICANCE(
				new InvestigationMethodSignificanceCombination.TrainingDataReader(),
				SolverType.L1R_LR, false, 5, 3, slotOptimality, 0.6, 0.01,
				false, "ResultSignificance"),
		RESULT_TREND(
				new InvestigationMethodTrendCombination.TrainingDataReader(),
				SolverType.L1R_LR, false, 5, 3, slotOptimality, 0.6, 0.01,
				false, "ResultTrend"),
		RESULT_RELATION(null,
				SolverType.L1R_LR, false, 5, 3, relationOptimality, 0.6, 0.01,
				false, "ResultRelation");

		public final TrainingDataReader reader;
		public final SolverType type;
		public final boolean cSweep;
		public final int folds;
		public final int repeats;
		public final Comparator optimality;
		public final double C;
		public final double eps;
		public final boolean verbose;
		public final String classifierName;

		private ClassifierTrainingConfiguration(TrainingDataReader reader,
				SolverType type, boolean cSweep, int folds, int repeats,
				Comparator optimality, double C,
				double eps, boolean verbose, String classifierName) {
			this.reader = reader;
			this.type = type;
			this.cSweep = cSweep;
			this.folds = folds;
			this.repeats = repeats;
			this.optimality = optimality;
			this.C = C;
			this.eps = eps;
			this.verbose = verbose;
			this.classifierName = classifierName;
		}
	}

	private static class XCASRelationTuple {

		public final JCas jcas;
		public final ArrayList rawRelations;
		public final String filename;

		public XCASRelationTuple(JCas jcas, ArrayList rawRelations, String filename) {
			this.jcas = jcas;
			this.rawRelations = rawRelations;
			this.filename = filename;
		}

	}

	/**
	 * An hand-coded sweep of the C meta-parameter of LibLinear.
	 */
	private static class CSweep implements TrainingUtils.ParameterSweep {

		private static final double[] values = {0.001, 0.005, 0.01, 0.05, 0.1, 0.3, 0.5, 0.7, 1};
		private int step;
		private double optimalC;

		public CSweep() {
			this.step = 0;
			this.optimalC = values[0];
		}

		@Override
		public void initializeParameter(LibLinearClassifier classifier) {
			//nothing necessary
		}

		@Override
		public void setNextParameter(LibLinearClassifier classifier) {
			classifier.setC(values[step]);
			step++;
		}

		@Override
		public boolean hasNextParameter(LibLinearClassifier classifier) {
			return step < values.length - 1;
		}

		@Override
		public void noteOptimal(LibLinearClassifier classifier) {
			optimalC = classifier.getC();
		}

		public double getOptimalC() {
			return optimalC;
		}

	}

	private static Classifier setUpClassifier(TrainingDataReader reader,
			final ArrayList input,
			ClassifierTrainingConfiguration config,
			final File outputFolder) throws IOException {
		if (config.reader != null) {
			reader = config.reader;
		}
		if (reader == null) {
			throw new UnsupportedOperationException("No TrainingDataReader was given!");
		}
		//Read training data
		System.out.println("Start creating training data");
		final ArrayList trainingData = new ArrayList<>();
		for (final XCASRelationTuple tuple : input) {
			if (reader instanceof RelationDataPoint.RelationTrainingDataReader) {
				System.out.println("Processing " + tuple.filename);
			}
			reader.readTrainingData(tuple.jcas, tuple.rawRelations, trainingData);
		}
		int N = trainingData.size();
		System.out.println("Training " + config.classifierName + " using " + N
				+ " data points (" + config.folds + "-fold cross validation).");
		//balance data.
		TrainingUtils.balanceDataSet(trainingData);
		if (N < trainingData.size()) {
			System.out.println("Balancing was necessary: Now using "
					+ trainingData.size() + " data points.");
		}
		//Training
		final LibLinearClassifier out = new LibLinearClassifier(config.type);
		final ClassifierEvaluation eval;
		try {
			if (config.cSweep) {
				final CSweep cSweep = new CSweep();
				eval = TrainingUtils.crossValidationSweep(trainingData, out,
						config.folds, config.repeats, cSweep,
						config.optimality, config.verbose);
				out.setC(cSweep.getOptimalC());
				System.out.println("Optimum C value found at " + out.getC());
			} else {
				out.setC(config.C);
				eval = TrainingUtils.crossValidation(trainingData,
						out, config.folds, config.repeats,
						config.optimality, config.verbose);
			}
		} catch (UnsupportedOperationException ex) {
			System.err.println("Warning: " + ex.getMessage());
			return out;
		}

		final File modelFile = new File(outputFolder, config.classifierName + ".model");
		System.out.println("Storing model in " + modelFile.getAbsolutePath());
		final BufferedWriter modelOut = new BufferedWriter(new FileWriter(modelFile));
		final File featureFile = new File(outputFolder, config.classifierName + ".features");
		System.out.println("Storing features in " + featureFile.getAbsolutePath());
		final BufferedWriter featuresOut = new BufferedWriter(new FileWriter(featureFile));
		out.writeParamters(modelOut, featuresOut);
		modelOut.close();
		featuresOut.close();

		final File evalFile = new File(outputFolder, config.classifierName + ".eval");
		System.out.println("Storing evaluation in " + evalFile.getAbsolutePath());
		final BufferedWriter evalOut = new BufferedWriter(new FileWriter(evalFile));
		eval.store(evalOut);
		evalOut.close();

		return out;
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy