
de.tudarmstadt.ukp.dkpro.tc.mallet.util.TaskUtils Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of dkpro-tc-mallet Show documentation
Interface to the Mallet Machine Learning Toolkit
The newest version!
/*******************************************************************************
* Copyright 2014
* Ubiquitous Knowledge Processing (UKP) Lab
* Technische Universität Darmstadt
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
******************************************************************************/
package de.tudarmstadt.ukp.dkpro.tc.mallet.util;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStreamWriter;
import java.io.Reader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.regex.Pattern;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import org.apache.commons.lang.StringUtils;
import cc.mallet.fst.CRF;
import cc.mallet.fst.CRFTrainerByLabelLikelihood;
import cc.mallet.fst.NoopTransducerTrainer;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.TransducerEvaluator;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.iterator.LineGroupIterator;
import cc.mallet.types.Alphabet;
import cc.mallet.types.InstanceList;
import de.tudarmstadt.ukp.dkpro.tc.api.exception.TextClassificationException;
import de.tudarmstadt.ukp.dkpro.tc.mallet.report.ReportConstants;
public class TaskUtils {
private static ArrayList precisionValues;
private static ArrayList recallValues;
private static ArrayList f1Values;
private static ArrayList labels;
public static CRF trainCRF(InstanceList training, CRF crf, double gaussianPriorVariance, int iterations, String defaultLabel,
boolean fullyConnected, int[] orders) {
if (crf == null) {
crf = new CRF(training.getPipe(), (Pipe)null);
String startName =
crf.addOrderNStates(training, orders, null,
defaultLabel, null, null,
fullyConnected);
for (int i = 0; i < crf.numStates(); i++) {
crf.getState(i).setInitialWeight (Transducer.IMPOSSIBLE_WEIGHT);
}
crf.getState(startName).setInitialWeight(0.0);
}
// logger.info("Training on " + training.size() + " instances");
CRFTrainerByLabelLikelihood crft = new CRFTrainerByLabelLikelihood (crf);
crft.setGaussianPriorVariance(gaussianPriorVariance);
boolean converged;
for (int i = 1; i <= iterations; i++) {
converged = crft.train (training, 1);
if (converged) {
break;
}
}
return crf;
}
public static void runTrainCRF(File trainingFile, File modelFile, double var, int iterations, String defaultLabel,
boolean fullyConnected, int[] orders, boolean denseFeatureValues) throws FileNotFoundException, IOException, ClassNotFoundException {
Reader trainingFileReader = null;
InstanceList trainingData = null;
//trainingFileReader = new FileReader(trainingFile);
trainingFileReader = new InputStreamReader(new GZIPInputStream(new FileInputStream(trainingFile)), "UTF-8");
Pipe p = null;
CRF crf = null;
p = new ConversionToFeatureVectorSequence(denseFeatureValues); //uses first line of file to identify DKProInstanceID feature and discard
p.getTargetAlphabet().lookupIndex(defaultLabel);
p.setTargetProcessing(true);
trainingData = new InstanceList(p);
trainingData.addThruPipe(new LineGroupIterator(trainingFileReader,
Pattern.compile("^\\s*$"), true)); //if you want to skip the line containing feature names, add "|^[A-Za-z]+.*$"
// logger.info
// ("Number of features in training data: "+p.getDataAlphabet().size());
// logger.info ("Number of predicates: "+p.getDataAlphabet().size());
if (p.isTargetProcessing())
{
Alphabet targets = p.getTargetAlphabet();
StringBuffer buf = new StringBuffer("Labels:");
for (int i = 0; i < targets.size(); i++)
{
buf.append(" ").append(targets.lookupObject(i).toString());
// logger.info(buf.toString());
}
}
crf = trainCRF(trainingData, crf, var, iterations, defaultLabel, fullyConnected, orders);
ObjectOutputStream s =
new ObjectOutputStream(new FileOutputStream(modelFile));
s.writeObject(crf);
s.close();
}
public static void test(TransducerTrainer tt, TransducerEvaluator eval,
InstanceList testing)
{
eval.evaluateInstanceList(tt, testing, "Testing");
}
public static TransducerEvaluator runTestCRF(File testFile, File modelFile) throws FileNotFoundException, IOException, ClassNotFoundException {
Reader testFileReader = null;
InstanceList testData = null;
//testFileReader = new FileReader(testFile);
testFileReader = new InputStreamReader(new GZIPInputStream(new FileInputStream(testFile)), "UTF-8");
Pipe p = null;
CRF crf = null;
TransducerEvaluator eval = null;
ObjectInputStream s =
new ObjectInputStream(new FileInputStream(modelFile));
crf = (CRF) s.readObject();
s.close();
p = crf.getInputPipe();
p.setTargetProcessing(true);
testData = new InstanceList(p);
testData.addThruPipe(
new LineGroupIterator(testFileReader,
Pattern.compile("^\\s*$"), true));
// logger.info ("Number of predicates: "+p.getDataAlphabet().size());
eval = new PerClassEvaluator(new InstanceList[] {testData}, new String[] {"Testing"});
if (p.isTargetProcessing())
{
Alphabet targets = p.getTargetAlphabet();
StringBuffer buf = new StringBuffer("Labels:");
for (int i = 0; i < targets.size(); i++)
{
buf.append(" ").append(targets.lookupObject(i).toString());
// logger.info(buf.toString());
}
}
test(new NoopTransducerTrainer(crf), eval, testData);
labels = ((PerClassEvaluator) eval).getLabelNames();
precisionValues = ((PerClassEvaluator) eval).getPrecisionValues();
recallValues = ((PerClassEvaluator) eval).getRecallValues();
f1Values = ((PerClassEvaluator) eval).getF1Values();
return eval;
}
public static TransducerEvaluator runTrainTest(File trainFile, File testFile, File modelFile,
double var, int iterations, String defaultLabel,
boolean fullyConnected, int[] orders, String tagger, boolean denseFeatureValues) throws FileNotFoundException, ClassNotFoundException, IOException, TextClassificationException {
TransducerEvaluator eval = null;
if (tagger.equals("CRF")) {
runTrainCRF(trainFile,modelFile, var, iterations, defaultLabel, fullyConnected, orders, denseFeatureValues);
eval = runTestCRF(testFile, modelFile);
printEvaluationMeasures();
}
else if (tagger.equals("HMM")){
throw new TextClassificationException("'HMM' is not currently supported.");
//runTrainHMM(trainFile,modelFile, defaultLabel, iterations, denseFeatureValues);
//eval = runTestHMM(testFile, modelFile);
}
else {
throw new TextClassificationException("Unsupported tagger name for sequence tagging. Supported taggers are 'CRF' and 'HMM'.");
}
return eval;
}
//FIXME HMM is not currently supported (uncomment and use a different vector sequence compatible to HMM utilities
//in Mallet) @author krishperumal11
// public static void runTrainHMM(File trainingFile, File modelFile, String defaultLabel, int iterations, boolean denseFeatureValues) throws FileNotFoundException, IOException {
// Reader trainingFileReader = null;
// InstanceList trainingData = null;
// //trainingFileReader = new FileReader(trainingFile);
// trainingFileReader = new InputStreamReader(new GZIPInputStream(new FileInputStream(trainingFile)));
// Pipe p = null;
// p = new ConversionToFeatureVectorSequence(denseFeatureValues); //uses first line of file to identify DKProInstanceID feature and discard
// p.getTargetAlphabet().lookupIndex(defaultLabel);
// p.setTargetProcessing(true);
// trainingData = new InstanceList(p);
// trainingData.addThruPipe(new LineGroupIterator(trainingFileReader,
// Pattern.compile("^\\s*$"), true)); //if you want to skip the line containing feature names, add "|^[A-Za-z]+.*$"
// // logger.info
// // ("Number of features in training data: "+p.getDataAlphabet().size());
//
// // logger.info ("Number of predicates: "+p.getDataAlphabet().size());
//
// if (p.isTargetProcessing())
// {
// Alphabet targets = p.getTargetAlphabet();
// StringBuffer buf = new StringBuffer("Labels:");
// for (int i = 0; i < targets.size(); i++)
// buf.append(" ").append(targets.lookupObject(i).toString());
// // logger.info(buf.toString());
// }
//
// HMM hmm = null;
// hmm = trainHMM(trainingData, hmm, iterations);
// ObjectOutputStream s =
// new ObjectOutputStream(new FileOutputStream(modelFile));
// s.writeObject(hmm);
// s.close();
// }
//
// public static HMM trainHMM(InstanceList training, HMM hmm, int numIterations) throws IOException {
// if (hmm == null) {
// hmm = new HMM(training.getPipe(), null);
// hmm.addStatesForLabelsConnectedAsIn(training);
// //hmm.addStatesForBiLabelsConnectedAsIn(trainingInstances);
//
// HMMTrainerByLikelihood trainer =
// new HMMTrainerByLikelihood(hmm);
//
// trainer.train(training, numIterations);
//
// //trainingEvaluator.evaluate(trainer);
// }
// return hmm;
// }
//
// public static TransducerEvaluator runTestHMM(File testFile, File modelFile) throws FileNotFoundException, IOException, ClassNotFoundException {
// ArrayList pipes = new ArrayList();
//
// pipes.add(new SimpleTaggerSentence2TokenSequence());
// pipes.add(new TokenSequence2FeatureSequence());
//
// Pipe pipe = new SerialPipes(pipes);
//
// InstanceList testData = new InstanceList(pipe);
//
// testData.addThruPipe(new LineGroupIterator(new BufferedReader(new InputStreamReader(new GZIPInputStream(new FileInputStream(testFile)))), Pattern.compile("^\\s*$"), true));
//
// TransducerEvaluator eval =
// new PerClassEvaluator(testData, "testing");
//
// ObjectInputStream s =
// new ObjectInputStream(new FileInputStream(modelFile));
// HMM hmm = (HMM) s.readObject();
//
// test(new NoopTransducerTrainer(hmm), eval, testData);
// labels = ((PerClassEvaluator) eval).getLabelNames();
// precisionValues = ((PerClassEvaluator) eval).getPrecisionValues();
// recallValues = ((PerClassEvaluator) eval).getRecallValues();
// f1Values = ((PerClassEvaluator) eval).getF1Values();
// return eval;
// }
public static void printEvaluationMeasures() {
double values[][] = new double[labels.size()][3];
Iterator itPrecision = precisionValues.iterator();
Iterator itRecall = recallValues.iterator();
Iterator itF1 = f1Values.iterator();
int i = 0;
while(itPrecision.hasNext()) {
values[i++][0] = itPrecision.next();
}
i = 0;
while(itRecall.hasNext()) {
values[i++][1] = itRecall.next();
}
i = 0;
while(itF1.hasNext()) {
values[i++][2] = itF1.next();
}
Iterator itLabels = labels.iterator();
for(i=0; i predictedLabels = ((PerClassEvaluator) eval).getPredictedLabels();
BufferedReader br = new BufferedReader(new InputStreamReader(new GZIPInputStream(new FileInputStream(fileTest)), "UTF-8"));
BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(new GZIPOutputStream(new FileOutputStream(filePredictions)), "UTF-8"));
String line;
boolean header = false;
int i = 0;
while ((line = br.readLine()) != null) {
if (!header) {
bw.write(line + " " + predictionClassLabelName);
bw.flush();
header = true;
continue;
}
if (!line.isEmpty()) {
bw.write("\n" + line + " " + predictedLabels.get(i++));
bw.flush();
}
else {
bw.write("\n");
bw.flush();
}
}
br.close();
bw.close();
}
public static void outputEvaluation(TransducerEvaluator eval, File fileEvaluation) throws IOException {
BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(fileEvaluation), "UTF-8"));
ArrayList labelNames = ((PerClassEvaluator) eval).getLabelNames();
ArrayList precisionValues = ((PerClassEvaluator) eval).getPrecisionValues();
ArrayList recallValues = ((PerClassEvaluator) eval).getRecallValues();
ArrayList f1Values = ((PerClassEvaluator) eval).getF1Values();
int numLabels = labelNames.size();
bw.write("Measure,Value");
bw.write("\n" + ReportConstants.CORRECT + "," + ((PerClassEvaluator) eval).getNumberOfCorrectPredictions());
bw.write("\n" + ReportConstants.INCORRECT + "," + ((PerClassEvaluator) eval).getNumberOfIncorrectPredictions());
bw.write("\n" + ReportConstants.NUMBER_EXAMPLES + "," + ((PerClassEvaluator) eval).getNumberOfExamples());
bw.write("\n" + ReportConstants.PCT_CORRECT + "," + ((PerClassEvaluator) eval).getPercentageOfCorrectPredictions());
bw.write("\n" + ReportConstants.PCT_INCORRECT + "," + ((PerClassEvaluator) eval).getPercentageOfIncorrectPredictions());
for (int i = 0; i < numLabels; i++) {
String label = labelNames.get(i);
bw.write("\n" + ReportConstants.PRECISION + "_" + label + "," + precisionValues.get(i));
bw.write("\n" + ReportConstants.RECALL + "_" + label + "," + recallValues.get(i));
bw.write("\n" + ReportConstants.FMEASURE + "_" + label + "," + f1Values.get(i));
bw.flush();
}
bw.write("\n" + ReportConstants.MACRO_AVERAGE_FMEASURE + "," + ((PerClassEvaluator) eval).getMacroAverage());
bw.flush();
bw.close();
}
public static void outputConfusionMatrix(TransducerEvaluator eval, File fileConfusionMatrix) throws IOException {
BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(fileConfusionMatrix), "UTF-8"));
ArrayList labelNames = ((PerClassEvaluator) eval).getLabelNames();
int numLabels = labelNames.size();
HashMap labelNameToIndexMap = new HashMap();
for (int i = 0; i < numLabels; i++) {
labelNameToIndexMap.put(labelNames.get(i), i);
}
ArrayList goldLabels = ((PerClassEvaluator) eval).getGoldLabels();
ArrayList predictedLabels = ((PerClassEvaluator) eval).getPredictedLabels();
Integer[][] confusionMatrix = new Integer[numLabels][numLabels];
//initialize to 0
for (int i = 0; i < confusionMatrix.length; i++) {
for (int j = 0; j < confusionMatrix.length; j++) {
confusionMatrix[i][j] = 0;
}
}
for (int i = 0; i < goldLabels.size(); i++) {
confusionMatrix[labelNameToIndexMap.get(goldLabels.get(i))][labelNameToIndexMap.get(predictedLabels.get(i))]++;
}
String[][] confusionMatrixString = new String[numLabels + 1][numLabels + 1];
confusionMatrixString[0][0] = " ";
for (int i = 1; i < numLabels + 1; i++) {
confusionMatrixString[i][0] = labelNames.get(i-1) + "_actual";
confusionMatrixString[0][i] = labelNames.get(i-1) + "_predicted";
}
for (int i = 1; i < numLabels + 1; i++) {
for (int j = 1; j < numLabels + 1; j++) {
confusionMatrixString[i][j] = confusionMatrix[i-1][j-1].toString();
}
}
for (String[] element : confusionMatrixString) {
bw.write(StringUtils.join(element, ","));
bw.write("\n");
bw.flush();
}
bw.close();
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy