gate.plugin.learningframework.engines.EngineMBMalletClass Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of learningframework Show documentation
A GATE plugin that provides many different machine learning
algorithms for a wide range of NLP-related machine learning tasks like
text classification, tagging, or chunking.
/*
* Copyright (c) 2015-2016 The University Of Sheffield.
*
* This file is part of gateplugin-LearningFramework
* (see https://github.com/GateNLP/gateplugin-LearningFramework).
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, either version 2.1 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this software. If not, see <https://www.gnu.org/licenses/>.
*/
package gate.plugin.learningframework.engines;
import static gate.plugin.learningframework.LFUtils.newURL;
import static gate.plugin.learningframework.engines.Engine.FILENAME_MODEL;

import cc.mallet.classify.BalancedWinnowTrainer;
import cc.mallet.classify.C45Trainer;
import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.classify.DecisionTreeTrainer;
import cc.mallet.classify.MaxEntTrainer;
import cc.mallet.classify.WinnowTrainer;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.InstanceList.CrossValidationIterator;
import cc.mallet.types.LabelVector;
import cc.mallet.types.Labeling;
import gate.Annotation;
import gate.AnnotationSet;
import gate.plugin.learningframework.EvaluationMethod;
import gate.plugin.learningframework.ModelApplication;
import gate.plugin.learningframework.data.CorpusRepresentationMalletTarget;
import gate.plugin.learningframework.mallet.LFPipe;
import gate.util.GateRuntimeException;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.net.URL;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Random;
import org.apache.log4j.Logger;
/**
 * Mallet-backed engine for classification tasks: trains, applies, evaluates,
 * and loads serialized Mallet {@link cc.mallet.classify.Classifier} models.
 *
 * @author Johann Petrak
 */
public class EngineMBMalletClass extends EngineMBMallet {
// Per-class log4j logger; used e.g. to warn when algorithm parameters are ignored.
private static final Logger LOGGER = Logger.getLogger(EngineMBMalletClass.class);
// Default constructor: no state to initialize here.
public EngineMBMalletClass() { }
/**
 * Train a Mallet classifier from the current corpus representation and
 * store it as this engine's model, then persist the engine metadata.
 *
 * @param dataDirectory directory where the info and feature metadata get saved
 * @param instanceType annotation type of the training instances (not used here)
 * @param parms algorithm parameter string, recorded in the model info
 */
@Override
public void trainModel(File dataDirectory, String instanceType, String parms) {
  InstanceList trainingData = corpusRepresentation.getRepresentationMallet();
  model = ((ClassifierTrainer) trainer).train(trainingData);
  updateInfo();
  // Record when and with which parameters this model was built.
  String trainedAt = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date());
  info.modelWhenTrained = trainedAt;
  info.algorithmParameters = parms;
  // Persist the engine info and the feature specification alongside the model.
  info.save(dataDirectory);
  featureInfo.save(dataDirectory);
}
@Override
// NOTE(review): generic type parameters appear to have been stripped from this
// file (text inside '<...>' was eaten, e.g. by an HTML extractor). The raw
// "List" below was presumably "List<ModelApplication>" originally; recover the
// authoritative text from the upstream repository before editing this region.
public List applyModel(
AnnotationSet instanceAS, AnnotationSet inputAS, AnnotationSet sequenceAS, String parms) {
// NOTE: the crm should be of type CorpusRepresentationMalletClass for this to work!
if(!(corpusRepresentation instanceof CorpusRepresentationMalletTarget)) {
throw new GateRuntimeException("Cannot perform classification with data from "+corpusRepresentation.getClass());
}
CorpusRepresentationMalletTarget data = (CorpusRepresentationMalletTarget)corpusRepresentation;
// Freeze the representation so unseen features at apply time do not grow the alphabet.
data.stopGrowth();
List gcs = new ArrayList<>();
LFPipe pipe = (LFPipe)data.getRepresentationMallet().getPipe();
Classifier classifier = (Classifier)model;
// iterate over the instance annotations and create mallet instances
for(Annotation instAnn : instanceAS.inDocumentOrder()) {
Instance inst = data.extractIndependentFeatures(instAnn, inputAS);
inst = pipe.instanceFrom(inst);
Classification classification = classifier.classify(inst);
Labeling labeling = classification.getLabeling();
LabelVector labelvec = labeling.toLabelVector();
List classes = new ArrayList<>(labelvec.numLocations());
List confidences = new ArrayList<>(labelvec.numLocations());
// NOTE(review): the source is garbled from here on — the body of the inner
// loop, the remainder of applyModel(), and the header of the following method
// (presumably "protected void initializeAlgorithm(Algorithm algorithm, String parms)")
// were lost with the stripped '<...>' spans; the next line fuses the loop
// header with a much later statement ("Class<? extends ...> trainerClass = ...").
// Do NOT hand-edit this region; restore it from upstream
// (GateNLP/gateplugin-LearningFramework).
for(int i=0; i trainerClass = algorithm.getTrainerClass();
// No parameters given: instantiate the trainer reflectively via its
// no-argument constructor.
try {
@SuppressWarnings("unchecked")
Constructor tmpc = trainerClass.getDeclaredConstructor();
trainer = tmpc.newInstance();
} catch (IllegalAccessException | IllegalArgumentException |
InstantiationException | NoSuchMethodException |
SecurityException | InvocationTargetException ex) {
throw new GateRuntimeException("Could not create trainer instance for " + trainerClass, ex);
}
} else {
// there are parameters, so if it is one of the algorithms where we support setting
// a parameter, do this
if (algorithm.equals(AlgorithmClassification.MalletC45_CL_MR)) {
// C4.5 decision tree: optional maxDepth, prune and minNumInsts parameters.
Parms ps = new Parms(parms, "m:maxDepth:i", "p:prune:B","n:minNumInsts:i");
int maxDepth = (int)ps.getValueOrElse("maxDepth", 0);
int minNumInsts = (int)ps.getValueOrElse("minNumInsts", 2);
boolean prune = (boolean)ps.getValueOrElse("prune",true);
C45Trainer c45trainer;
// Pick the C45Trainer constructor matching which options were supplied:
// with a positive maxDepth use the (depth, prune) form, otherwise prune-only.
if(maxDepth > 0) {
if(!prune) {
c45trainer = new C45Trainer(maxDepth,false);
} else {
c45trainer = new C45Trainer(maxDepth,true);
}
} else {
c45trainer = new C45Trainer(prune);
}
c45trainer.setMinNumInsts(minNumInsts);
trainer = c45trainer;
} else if(algorithm.equals(AlgorithmClassification.MalletDecisionTree_CL_MR)) {
DecisionTreeTrainer dtTrainer = new DecisionTreeTrainer();
Parms ps = new Parms(parms, "m:maxDepth:i", "i:minInfoGainSplit:d");
int maxDepth = (int)ps.getValueOrElse("maxDepth", DecisionTreeTrainer.DEFAULT_MAX_DEPTH);
double minIGS = (double)ps.getValueOrElse("minInfoGainSplit",DecisionTreeTrainer.DEFAULT_MIN_INFO_GAIN_SPLIT);
dtTrainer.setMaxDepth(maxDepth);
dtTrainer.setMinInfoGainSplit(minIGS);
trainer = dtTrainer;
} else if(algorithm.equals(AlgorithmClassification.MalletMaxEnt_CL_MR)) {
MaxEntTrainer tr = new MaxEntTrainer();
Parms ps = new Parms(parms, "v:gaussianPriorVariance:d",
"l:l1Weight:d", "i:numIterations:i");
// TODO: the default values cannot be taken from MaxEntTrainer because
// they are not public there
double gaussianPriorVariance = (double)ps.getValueOrElse("gaussianPriorVariance", 1.0);
tr.setGaussianPriorVariance(gaussianPriorVariance);
double l1Weight = (double)ps.getValueOrElse("l1Weight", 0.0);
tr.setL1Weight(l1Weight);
int iters = (int)ps.getValueOrElse("numIterations", Integer.MAX_VALUE);
tr.setNumIterations(iters);
trainer = tr;
// NOTE: for AdaBoost, use this method recursively to first initialize
// the trainer with the base trainer. The parameters should contain
// something like -A ALGNAME -N numRounds -a -b ...
// where ALGNAME is an AlgorithmClassification constant and N is the
// numRounds parameter for AdaBoost[M2] and all the other parameters
// are for the base algorithm initialization
} else if(algorithm.equals(AlgorithmClassification.MalletBalancedWinnow_CL_MR)) {
Parms ps = new Parms(parms, "e:epsilon:d",
"d:delta:d", "i:maxIterations:i", "c:coolingRate:d");
double epsilon = (double)ps.getValueOrElse("epsilon", BalancedWinnowTrainer.DEFAULT_EPSILON);
double delta = (double)ps.getValueOrElse("delta", BalancedWinnowTrainer.DEFAULT_DELTA);
// NOTE(review): looks like a bug — the lookup key "int" should probably be
// "maxIterations" (declared above as "i:maxIterations:i"), so a user-supplied
// maxIterations is currently ignored. Confirm and fix upstream.
int iters = (int)ps.getValueOrElse("int", BalancedWinnowTrainer.DEFAULT_MAX_ITERATIONS);
double cr = (double)ps.getValueOrElse("coolingRate", BalancedWinnowTrainer.DEFAULT_COOLING_RATE);
trainer = new BalancedWinnowTrainer(epsilon,delta,iters,cr);
} else if(algorithm.equals(AlgorithmClassification.MalletWinnow_CL_MR)) {
Parms ps = new Parms(parms, "a:alpha:d",
"b:beta:d", "n:nfact:d");
double alpha = (double)ps.getValueOrElse("alpha", 2.0);
double beta = (double)ps.getValueOrElse("beta", 2.0);
double nfact = (double)ps.getValueOrElse("nfact", 0.5);
trainer = new WinnowTrainer(alpha, beta, nfact);
} else {
// all other algorithms are still just instantiated from the class name, we ignore
// the parameters
LOGGER.warn("IMPORTANT: parameters ignored when creating Mallet trainer " + algorithm.getTrainerClass());
Class trainerClass = algorithm.getTrainerClass();
try {
@SuppressWarnings("unchecked")
Constructor tmpc = trainerClass.getDeclaredConstructor();
trainer = tmpc.newInstance();
} catch (IllegalAccessException | IllegalArgumentException | InstantiationException | NoSuchMethodException | SecurityException | InvocationTargetException ex) {
throw new GateRuntimeException("Could not create trainer instance for " + trainerClass, ex);
}
}
}
}
/**
 * Load a previously saved Mallet classifier from the given model directory
 * and store it as this engine's model.
 *
 * @param directory URL of the directory that contains the saved model file
 * @param parms algorithm parameter string (not used here)
 * @throws GateRuntimeException if the model file cannot be read or deserialized
 */
@Override
protected void loadModel(URL directory, String parms) {
  URL modelFile = newURL(directory, FILENAME_MODEL);
  try (InputStream is = modelFile.openStream();
      ObjectInputStream ois = new ObjectInputStream(is)) {
    // NOTE: Java native deserialization — the model file must come from a
    // trusted source (it is produced by this plugin's own trainModel()).
    model = (Classifier) ois.readObject();
  } catch (IOException | ClassNotFoundException | ClassCastException ex) {
    // Narrow multi-catch instead of the former catch(Exception): these are the
    // failures openStream/readObject/the cast can actually produce; the message
    // now also names the offending file.
    throw new GateRuntimeException("Could not load Mallet model from " + modelFile, ex);
  }
}
@Override
public EvaluationResult evaluate(String algorithmParameters, EvaluationMethod evaluationMethod, int numberOfFolds, double trainingFraction, int numberOfRepeats) {
EvaluationResult ret = null;
Parms parms = new Parms(algorithmParameters,"s:seed:i");
int seed = (Integer)parms.getValueOrElse("seed", 1);
if(evaluationMethod == EvaluationMethod.CROSSVALIDATION) {
CrossValidationIterator cvi = corpusRepresentation.getRepresentationMallet().crossValidationIterator(numberOfFolds, seed);
if(algorithm instanceof AlgorithmClassification) {
double sumOfAccs = 0.0;
while(cvi.hasNext()) {
InstanceList[] il = cvi.nextSplit();
InstanceList trainSet = il[0];
InstanceList testSet = il[1];
Classifier cl = ((ClassifierTrainer) trainer).train(trainSet);
sumOfAccs += cl.getAccuracy(testSet);
}
EvaluationResultClXval e = new EvaluationResultClXval();
//e.internalEvaluationResult = null;
e.accuracyEstimate = sumOfAccs/numberOfFolds;
e.nrFolds = numberOfFolds;
ret = e;
} else {
throw new GateRuntimeException("Mallet evaluation: not available for regression!");
}
} else {
if(algorithm instanceof AlgorithmClassification) {
Random rnd = new Random(seed);
double sumOfAccs = 0.0;
for(int i = 0; i