
cc.mallet.classify.tui.Vectors2Classify
MALLET is a Java-based package for statistical natural language processing,
document classification, clustering, topic modeling, information extraction,
and other machine learning applications to text.
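The class below is the command-line tool MALLET uses to train and evaluate document classifiers from a saved vector file. As a rough illustration only (the command and option names follow MALLET's documented train-classifier usage rather than anything visible in this listing, and the file names are placeholders), a typical invocation looks like:

bin/mallet train-classifier --input train.vectors --trainer MaxEnt --training-portion 0.9 --report test:accuracy test:confusion --output-classifier my.classifier

The --report values are parsed into the ReportOptions matrix defined near the top of the class, and the trained classifier is serialized to the output file.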
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.classify.tui;
import java.io.*;
import java.util.*;
import java.util.logging.*;
import java.lang.reflect.*;
import cc.mallet.classify.*;
import cc.mallet.classify.evaluate.*;
import cc.mallet.types.*;
import cc.mallet.util.*;
/**
* Classify documents, run trials, print statistics from a vector file.
* @author Andrew McCallum [email protected]
*/
public abstract class Vectors2Classify
{
static BshInterpreter interpreter = new BshInterpreter();
private static Logger logger = MalletLogger.getLogger(Vectors2Classify.class.getName());
private static Logger progressLogger = MalletProgressMessageLogger.getLogger(Vectors2Classify.class.getName() + "-pl");
private static ArrayList<String> classifierTrainerStrings = new ArrayList<String>();
private static int dataOptionsSize = 3;
private static int reportOptionsSize = 6;
private static boolean[][] ReportOptions = new boolean[dataOptionsSize][reportOptionsSize];
// Essentially an enum mapping string names to enums to ints.
private static class ReportOption
{
static final String[] dataOptions = {"train", "test", "validation"};
static final String[] reportOptions = {"accuracy", "f1", "confusion", "raw", "precision", "recall"};
static final int train=0;
static final int test =1;
static final int validation=2;
static final int accuracy=0;
static final int f1=1;
static final int confusion=2;
static final int raw=3;
static final int precision=4;
static final int recall=5;
}
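// Illustrative note on how the matrix above is filled: an argument such as
// "--report test:f1=LABEL" is split on ':' and '=' in postParsing below, and results in
// ReportOptions[ReportOption.test][ReportOption.f1] = true, so per-label F1 values are
// gathered and printed for the test split.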
static CommandOption.SpacedStrings report = new CommandOption.SpacedStrings
(Vectors2Classify.class, "report", "[train|test|validation]:[accuracy|f1:label|precision:label|recall:label|confusion|raw]",
true, new String[] {"train:accuracy", "test:accuracy", "test:confusion", "test:precision", "test:recall", "test:f1"},
"", null)
{
public void postParsing (CommandOption.List list)
{
java.lang.String defaultRawFormatting = "siw";
for (int argi = 0; argi < this.value.length; argi++){
// split each --report argument (e.g. "test:f1=LABEL") on ':' and '='
java.lang.String arg = this.value[argi];
java.lang.String fields[] = arg.split("[:=]");
java.lang.String dataSet = fields[0];
java.lang.String reportOption = fields[1];
java.lang.String reportOptionArg = null;
if (fields.length >= 3){
reportOptionArg = fields[2];
}
//System.out.println("Report option arg " + reportOptionArg);
//find the datasource (test,train,validation)
boolean foundDataSource = false;
int i=0;
for (; i < ReportOption.dataOptions.length; i++){
if (dataSet.equals(ReportOption.dataOptions[i])){
foundDataSource = true;
break;
}
}
// ... [the remainder of postParsing, the other CommandOption declarations, and the setup
// portion of main() -- including building the train/test/validation instance lists -- are
// missing from this listing] ...
if (unlabeledProportionOption.value > 0)
unlabeledIndices = new cc.mallet.util.Randoms(r.nextInt())
.nextBitSet(ilists[0].size(), unlabeledProportionOption.value);
//InfoGain ig = new InfoGain (ilists[0]);
//int igl = Math.min (10, ig.numLocations());
//for (int i = 0; i < igl; i++)
//System.out.println ("InfoGain["+ig.getObjectAtRank(i)+"]="+ig.getValueAtRank(i));
//ig.print();
//FeatureSelection selectedFeatures = new FeatureSelection (ig, 8000);
//ilists[0].setFeatureSelection (selectedFeatures);
//OddsRatioFeatureInducer orfi = new OddsRatioFeatureInducer (ilists[0]);
//orfi.induceFeatures (ilists[0], false, true);
//System.out.println ("Training with "+ilists[0].size()+" instances");
long time[] = new long[numTrainers];
for (int c = 0; c < numTrainers; c++){
time[c] = System.currentTimeMillis();
ClassifierTrainer trainer = getTrainer(classifierTrainerStrings.get(c));
trainer.setValidationInstances(ilists[2]);
System.out.println ("Trial " + trialIndex + " Training " + trainer + " with "+ilists[0].size()+" instances");
if (unlabeledProportionOption.value > 0)
ilists[0].hideSomeLabels(unlabeledIndices);
Classifier classifier = trainer.train (ilists[0]);
if (unlabeledProportionOption.value > 0)
ilists[0].unhideAllLabels();
System.out.println ("Trial " + trialIndex + " Training " + trainer.toString() + " finished");
time[c] = System.currentTimeMillis() - time[c];
Trial trainTrial = new Trial (classifier, ilists[0]);
//assert (ilists[1].size() > 0);
Trial testTrial = new Trial (classifier, ilists[1]);
Trial validationTrial = new Trial(classifier, ilists[2]);
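// Note: a Trial runs the classifier over the given InstanceList once and stores the resulting
// Classifications, so the getAccuracy/getPrecision/getRecall/getF1 calls below are computed
// from those stored results rather than by re-classifying the data.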
// gdruck - only perform evaluation if requested in report options
if (ReportOptions[ReportOption.train][ReportOption.confusion] && ilists[0].size()>0)
trainConfusionMatrix[c][trialIndex] = new ConfusionMatrix (trainTrial).toString();
if (ReportOptions[ReportOption.test][ReportOption.confusion] && ilists[1].size()>0)
testConfusionMatrix[c][trialIndex] = new ConfusionMatrix (testTrial).toString();
if (ReportOptions[ReportOption.validation][ReportOption.confusion] && ilists[2].size()>0)
validationConfusionMatrix[c][trialIndex] = new ConfusionMatrix (validationTrial).toString();
// gdruck - only perform evaluation if requested in report options
if (ReportOptions[ReportOption.train][ReportOption.accuracy])
trainAccuracy[c][trialIndex] = trainTrial.getAccuracy();
if (ReportOptions[ReportOption.test][ReportOption.accuracy])
testAccuracy[c][trialIndex] = testTrial.getAccuracy();
if (ReportOptions[ReportOption.validation][ReportOption.accuracy])
validationAccuracy[c][trialIndex] = validationTrial.getAccuracy();
// gdruck - only perform evaluation if requested in report options
if (ReportOptions[ReportOption.train][ReportOption.precision]) {
for (int k =0 ; k < labels.length; k++)
trainPrecision[c][k][trialIndex] = trainTrial.getPrecision(labels[k]);
}
if (ReportOptions[ReportOption.test][ReportOption.precision]) {
for (int k =0 ; k < labels.length; k++)
testPrecision[c][k][trialIndex] = testTrial.getPrecision(labels[k]);
}
if (ReportOptions[ReportOption.validation][ReportOption.precision]) {
for (int k =0 ; k < labels.length; k++)
validationPrecision[c][k][trialIndex] = validationTrial.getPrecision(labels[k]);
}
// gdruck - only perform evaluation if requested in report options
if (ReportOptions[ReportOption.train][ReportOption.recall]) {
for (int k =0 ; k < labels.length; k++)
trainRecall[c][k][trialIndex] = trainTrial.getRecall(labels[k]);
}
if (ReportOptions[ReportOption.test][ReportOption.recall]) {
for (int k =0 ; k < labels.length; k++)
testRecall[c][k][trialIndex] = testTrial.getRecall(labels[k]);
}
if (ReportOptions[ReportOption.validation][ReportOption.recall]) {
for (int k =0 ; k < labels.length; k++)
validationRecall[c][k][trialIndex] = validationTrial.getRecall(labels[k]);
}
// gdruck - only perform evaluation if requested in report options
if (ReportOptions[ReportOption.train][ReportOption.f1]) {
for (int k =0 ; k < labels.length; k++)
trainF1[c][k][trialIndex] = trainTrial.getF1(labels[k]);
}
if (ReportOptions[ReportOption.test][ReportOption.f1]) {
for (int k =0 ; k < labels.length; k++)
testF1[c][k][trialIndex] = testTrial.getF1(labels[k]);
}
if (ReportOptions[ReportOption.validation][ReportOption.f1]) {
for (int k =0 ; k < labels.length; k++)
validationF1[c][k][trialIndex] = validationTrial.getF1(labels[k]);
}
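// The per-trainer / per-label / per-trial values gathered above are the inputs to the
// mean/stddev/stderr summary lines printed once all trials have finished.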
if (outputFile.wasInvoked()) {
String filename = outputFile.value;
if (numTrainers > 1) filename = filename+trainer.toString();
if (numTrials > 1) filename = filename+".trial"+trialIndex;
try {
ObjectOutputStream oos = new ObjectOutputStream
(new FileOutputStream (filename));
oos.writeObject (classifier);
oos.close();
} catch (Exception e) {
e.printStackTrace();
throw new IllegalArgumentException ("Couldn't write classifier to filename "+
filename);
}
}
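// The classifier serialized above can later be reloaded with plain Java deserialization;
// a minimal reader-side sketch (not part of this class, using the same filename as above):
//   ObjectInputStream ois = new ObjectInputStream(new FileInputStream(filename));
//   Classifier restored = (Classifier) ois.readObject();
//   ois.close();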
// New Reporting
// raw output
if (ReportOptions[ReportOption.train][ReportOption.raw]){
System.out.println("Trial " + trialIndex + " Trainer " + trainer.toString());
System.out.println(" Raw Training Data");
printTrialClassification(trainTrial);
}
if (ReportOptions[ReportOption.test][ReportOption.raw]){
System.out.println("Trial " + trialIndex + " Trainer " + trainer.toString());
System.out.println(" Raw Testing Data");
printTrialClassification(testTrial);
}
if (ReportOptions[ReportOption.validation][ReportOption.raw]){
System.out.println("Trial " + trialIndex + " Trainer " + trainer.toString());
System.out.println(" Raw Validation Data");
printTrialClassification(validationTrial);
}
//train
if (ReportOptions[ReportOption.train][ReportOption.confusion]){
System.out.println("Trial " + trialIndex + " Trainer " + trainer.toString() + " Training Data Confusion Matrix");
if (ilists[0].size()>0) System.out.println (trainConfusionMatrix[c][trialIndex]);
}
if (ReportOptions[ReportOption.train][ReportOption.precision]){
for (int k =0 ; k < labels.length; k++)
System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " training data Precision(" + labels[k] + ") = "+ trainTrial.getPrecision(labels[k]));
}
if (ReportOptions[ReportOption.train][ReportOption.recall]){
for (int k =0 ; k < labels.length; k++)
System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " training data Recall(" + labels[k] + ") = "+ trainTrial.getRecall(labels[k]));
}
if (ReportOptions[ReportOption.train][ReportOption.f1]){
for (int k =0 ; k < labels.length; k++)
System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " training data F1(" + labels[k] + ") = "+ trainTrial.getF1(labels[k]));
}
if (ReportOptions[ReportOption.train][ReportOption.accuracy]){
System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " training data accuracy = "+ trainAccuracy[c][trialIndex]);
}
//validation
if (ReportOptions[ReportOption.validation][ReportOption.confusion]){
System.out.println("Trial " + trialIndex + " Trainer " + trainer.toString() + " Validation Data Confusion Matrix");
if (ilists[2].size()>0) System.out.println (validationConfusionMatrix[c][trialIndex]);
}
if (ReportOptions[ReportOption.validation][ReportOption.precision]){
for (int k =0 ; k < labels.length; k++)
System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " validation data precision(" + labels[k] + ") = "+ validationTrial.getPrecision(labels[k]));
}
if (ReportOptions[ReportOption.validation][ReportOption.recall]){
for (int k =0 ; k < labels.length; k++)
System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " validation data recall(" + labels[k] + ") = "+ validationTrial.getRecall(labels[k]));
}
if (ReportOptions[ReportOption.validation][ReportOption.f1]){
for (int k =0 ; k < labels.length; k++)
System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " validation data F1(" + labels[k] + ") = "+ validationTrial.getF1(labels[k]));
}
if (ReportOptions[ReportOption.validation][ReportOption.accuracy]){
System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " validation data accuracy = "+ validationAccuracy[c][trialIndex]);
}
//test
if (ReportOptions[ReportOption.test][ReportOption.confusion]){
System.out.println("Trial " + trialIndex + " Trainer " + trainer.toString() + " Test Data Confusion Matrix");
if (ilists[1].size()>0) System.out.println (testConfusionMatrix[c][trialIndex]);
}
if (ReportOptions[ReportOption.test][ReportOption.precision]){
for (int k =0 ; k < labels.length; k++)
System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " test data precision(" + labels[k] + ") = "+ testTrial.getPrecision(labels[k]));
}
if (ReportOptions[ReportOption.test][ReportOption.recall]){
for (int k =0 ; k < labels.length; k++)
System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " test data recall(" + labels[k] + ") = "+ testTrial.getRecall(labels[k]));
}
if (ReportOptions[ReportOption.test][ReportOption.f1]){
for (int k =0 ; k < labels.length; k++)
System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " test data F1(" + labels[k] + ") = "+ testTrial.getF1(labels[k]));
}
if (ReportOptions[ReportOption.test][ReportOption.accuracy]){
System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " test data accuracy = "+ testAccuracy[c][trialIndex]);
}
if (trialIndex == 0) trainerNames[c] = trainer.toString();
} // end for each trainer
} // end for each trial
// New reporting
//"[train|test|validation]:[accuracy|f1|confusion|raw|precision|recall]"
for (int c=0; c < numTrainers; c++) {
System.out.println ("\n"+trainerNames[c].toString());
if (ReportOptions[ReportOption.train][ReportOption.accuracy])
System.out.println ("Summary. train accuracy mean = "+ MatrixOps.mean (trainAccuracy[c])+
" stddev = "+ MatrixOps.stddev (trainAccuracy[c])+
" stderr = "+ MatrixOps.stderr (trainAccuracy[c]));
if (ReportOptions[ReportOption.train][ReportOption.precision]) {
for (int k =0 ; k < labels.length; k++)
System.out.println ("Summary. train precision("+labels[k]+") mean = "+ MatrixOps.mean (trainPrecision[c][k])+
" stddev = "+ MatrixOps.stddev (trainPrecision[c][k])+
" stderr = "+ MatrixOps.stderr (trainPrecision[c][k]));
}
if (ReportOptions[ReportOption.train][ReportOption.recall]) {
for (int k =0 ; k < labels.length; k++)
System.out.println ("Summary. train recall("+labels[k]+") mean = "+ MatrixOps.mean (trainRecall[c][k])+
" stddev = "+ MatrixOps.stddev (trainRecall[c][k])+
" stderr = "+ MatrixOps.stderr (trainRecall[c][k]));
}
if (ReportOptions[ReportOption.train][ReportOption.f1]){
for (int k =0 ; k < labels.length; k++)
System.out.println ("Summary. train f1("+labels[k]+") mean = "+ MatrixOps.mean (trainF1[c][k])+
" stddev = "+ MatrixOps.stddev (trainF1[c][k])+
" stderr = "+ MatrixOps.stderr (trainF1[c][k]));
}
if (ReportOptions[ReportOption.validation][ReportOption.accuracy])
System.out.println ("Summary. validation accuracy mean = "+ MatrixOps.mean (validationAccuracy[c])+
" stddev = "+ MatrixOps.stddev (validationAccuracy[c])+
" stderr = "+ MatrixOps.stderr (validationAccuracy[c]));
if (ReportOptions[ReportOption.validation][ReportOption.precision]){
for (int k =0 ; k < labels.length; k++)
System.out.println ("Summary. validation precision("+labels[k]+") mean = "+ MatrixOps.mean (validationPrecision[c][k])+
" stddev = "+ MatrixOps.stddev (validationPrecision[c][k])+
" stderr = "+ MatrixOps.stderr (validationPrecision[c][k]));
}
if (ReportOptions[ReportOption.validation][ReportOption.recall]){
for (int k =0 ; k < labels.length; k++)
System.out.println ("Summary. validation recall("+labels[k]+") mean = "+ MatrixOps.mean (validationRecall[c][k])+
" stddev = "+ MatrixOps.stddev (validationRecall[c][k])+
" stderr = "+ MatrixOps.stderr (validationRecall[c][k]));
}
if (ReportOptions[ReportOption.validation][ReportOption.f1]){
for (int k =0 ; k < labels.length; k++)
System.out.println ("Summary. validation f1("+labels[k]+") mean = "+ MatrixOps.mean (validationF1[c][k])+
" stddev = "+ MatrixOps.stddev (validationF1[c][k])+
" stderr = "+ MatrixOps.stderr (validationF1[c][k]));
}
if (ReportOptions[ReportOption.test][ReportOption.accuracy])
System.out.println ("Summary. test accuracy mean = "+ MatrixOps.mean (testAccuracy[c])+
" stddev = "+ MatrixOps.stddev (testAccuracy[c])+
" stderr = "+ MatrixOps.stderr (testAccuracy[c]));
if (ReportOptions[ReportOption.test][ReportOption.precision]){
for (int k =0 ; k < labels.length; k++)
System.out.println ("Summary. test precision("+labels[k]+") mean = "+ MatrixOps.mean (testPrecision[c][k])+
" stddev = "+ MatrixOps.stddev (testPrecision[c][k])+
" stderr = "+ MatrixOps.stderr (testPrecision[c][k]));
}
if (ReportOptions[ReportOption.test][ReportOption.recall]){
for (int k =0 ; k < labels.length; k++)
System.out.println ("Summary. test recall("+labels[k]+") mean = "+ MatrixOps.mean (testRecall[c][k])+
" stddev = "+ MatrixOps.stddev (testRecall[c][k])+
" stderr = "+ MatrixOps.stderr (testRecall[c][k]));
}
if (ReportOptions[ReportOption.test][ReportOption.f1]){
for (int k =0 ; k < labels.length; k++)
System.out.println ("Summary. test f1("+labels[k]+") mean = "+ MatrixOps.mean (testF1[c][k])+
" stddev = "+ MatrixOps.stddev (testF1[c][k])+
" stderr = "+ MatrixOps.stderr (testF1[c][k]));
}
} // end for each trainer
}
private static void printTrialClassification(Trial trial)
{
for (Classification c : trial) {
Instance instance = c.getInstance();
System.out.print(instance.getName() + " " + instance.getTarget() + " ");
Labeling labeling = c.getLabeling();
for (int j = 0; j < labeling.numLocations(); j++){
System.out.print(labeling.getLabelAtRank(j).toString() + ":" + labeling.getValueAtRank(j) + " ");
}
System.out.println();
}
}
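// Output format of printTrialClassification: one line per instance, giving the instance name,
// its true target, and then label:score pairs in rank order (best-scoring label first).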
private static Object createTrainer(String arg) {
try {
return interpreter.eval (arg);
} catch (bsh.EvalError e) {
throw new IllegalArgumentException ("Java interpreter eval error\n"+e);
}
}
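// createTrainer hands the string to the embedded BeanShell interpreter, so a trainer argument
// may be any Java expression that evaluates to a ClassifierTrainer, e.g. (illustrative)
// "new MaxEntTrainer()" or "new NaiveBayesTrainer()".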
private static ClassifierTrainer getTrainer(String arg) {
// parse something like Maxent,gaussianPriorVariance=10,numIterations=20
// first, split the argument at commas.
java.lang.String fields[] = arg.split(",");
//Massage constructor name, so that MaxEnt, MaxEntTrainer, new MaxEntTrainer()
// all call new MaxEntTrainer()
java.lang.String constructorName = fields[0];
Object trainer;
if (constructorName.indexOf('(') != -1) // if it already contains "()", pass it through as-is
trainer = createTrainer(arg);
else {
if (constructorName.endsWith("Trainer")){
trainer = createTrainer("new " + constructorName + "()"); // add parens if they forgot
}else{
trainer = createTrainer("new "+constructorName+"Trainer()"); // make trainer name from classifier name
}
}
// find methods associated with the class we just built
Method methods[] = trainer.getClass().getMethods();
// find setters corresponding to parameter names.
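// For example (illustrative), "MaxEnt,gaussianPriorVariance=10" first builds a MaxEntTrainer
// and then looks for a matching setter on it, ending up in a reflective call roughly
// equivalent to trainer.setGaussianPriorVariance(10.0).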
for (int i=1; i < fields.length; i++){
// ... [the remainder of getTrainer -- splitting each "name=value" field and invoking the
// corresponding setter on the trainer via reflection -- is missing from this listing] ...