
cc.mallet.classify.tui.Vectors2Classify Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of jcore-mallet-2.0.9 Show documentation
MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.
The newest version!
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.classify.tui;
import java.io.*;
import java.util.*;
import java.util.logging.*;
import java.lang.reflect.*;
import cc.mallet.classify.*;
import cc.mallet.classify.evaluate.*;
import cc.mallet.types.*;
import cc.mallet.util.*;
/**
* Classify documents, run trials, print statistics from a vector file.
@author Andrew McCallum [email protected]
*/
public abstract class Vectors2Classify
{
// BeanShell interpreter used by createTrainer/getTrainer to eval trainer-constructor expressions.
static BshInterpreter interpreter = new BshInterpreter();
private static Logger logger = MalletLogger.getLogger(Vectors2Classify.class.getName());
// Separate logger for progress messages (suffix "-pl" per MalletProgressMessageLogger convention).
private static Logger progressLogger = MalletProgressMessageLogger.getLogger(Vectors2Classify.class.getName() + "-pl");
// Raw trainer-spec strings (e.g. "MaxEnt,gaussianPriorVariance=10"); one trainer is built per entry.
private static ArrayList classifierTrainerStrings = new ArrayList();
// Dimensions of the ReportOptions matrix: 3 data sources x 6 report kinds (see ReportOption below).
private static int dataOptionsSize = 3;
private static int reportOptionsSize = 6;
// ReportOptions[dataSource][reportKind] == true means that report was requested on the command line.
private static boolean[][] ReportOptions = new boolean[dataOptionsSize][reportOptionsSize];
// Essentially an enum mapping string names to enums to ints.
// Poor-man's enum: maps report-option string names to the int indices used to
// address the ReportOptions[dataSource][reportKind] boolean matrix.
// The positions in dataOptions/reportOptions must stay in sync with the int
// constants below, since command-line parsing looks names up by array index.
private static class ReportOption
{
// Data sources a report can be computed on (indices 0..2 of the first matrix axis).
static final String[] dataOptions = {"train", "test", "validation"};
// Report kinds (indices 0..5 of the second matrix axis).
static final String[] reportOptions = {"accuracy", "f1", "confusion", "raw", "precision", "recall"};
// First-axis indices; must match positions in dataOptions above.
static final int train=0;
static final int test =1;
static final int validation=2;
// Second-axis indices; must match positions in reportOptions above.
static final int accuracy=0;
static final int f1=1;
static final int confusion=2;
static final int raw=3;
static final int precision=4;
static final int recall=5;
}
static CommandOption.SpacedStrings report = new CommandOption.SpacedStrings
(Vectors2Classify.class, "report", "[train|test|validation]:[accuracy|f1:label|precision:label|recall:label|confusion|raw]",
true, new String[] {"train:accuracy", "test:accuracy", "test:confusion", "test:precision", "test:recall", "test:f1"},
"", null)
{
public void postParsing (CommandOption.List list)
{
java.lang.String defaultRawFormatting = "siw";
for (int argi=0; argi=3){
reportOptionArg = fields[2];
}
//System.out.println("Report option arg " + reportOptionArg);
//find the datasource (test,train,validation)
boolean foundDataSource = false;
int i=0;
for (; i 0)
unlabeledIndices = new cc.mallet.util.Randoms(r.nextInt())
.nextBitSet(ilists[0].size(), unlabeledProportionOption.value);
//InfoGain ig = new InfoGain (ilists[0]);
//int igl = Math.min (10, ig.numLocations());
//for (int i = 0; i < igl; i++)
//System.out.println ("InfoGain["+ig.getObjectAtRank(i)+"]="+ig.getValueAtRank(i));
//ig.print();
//FeatureSelection selectedFeatures = new FeatureSelection (ig, 8000);
//ilists[0].setFeatureSelection (selectedFeatures);
//OddsRatioFeatureInducer orfi = new OddsRatioFeatureInducer (ilists[0]);
//orfi.induceFeatures (ilists[0], false, true);
//System.out.println ("Training with "+ilists[0].size()+" instances");
long time[] = new long[numTrainers];
for (int c = 0; c < numTrainers; c++){
time[c] = System.currentTimeMillis();
ClassifierTrainer trainer = getTrainer(classifierTrainerStrings.get(c));
trainer.setValidationInstances(ilists[2]);
System.out.println ("Trial " + trialIndex + " Training " + trainer + " with "+ilists[0].size()+" instances");
if (unlabeledProportionOption.value > 0)
ilists[0].hideSomeLabels(unlabeledIndices);
Classifier classifier = trainer.train (ilists[0]);
if (unlabeledProportionOption.value > 0)
ilists[0].unhideAllLabels();
System.out.println ("Trial " + trialIndex + " Training " + trainer.toString() + " finished");
time[c] = System.currentTimeMillis() - time[c];
Trial trainTrial = new Trial (classifier, ilists[0]);
//assert (ilists[1].size() > 0);
Trial testTrial = new Trial (classifier, ilists[1]);
Trial validationTrial = new Trial(classifier, ilists[2]);
// gdruck - only perform evaluation if requested in report options
if (ReportOptions[ReportOption.train][ReportOption.confusion] && ilists[0].size()>0)
trainConfusionMatrix[c][trialIndex] = new ConfusionMatrix (trainTrial).toString();
if (ReportOptions[ReportOption.test][ReportOption.confusion] && ilists[1].size()>0)
testConfusionMatrix[c][trialIndex] = new ConfusionMatrix (testTrial).toString();
if (ReportOptions[ReportOption.validation][ReportOption.confusion] && ilists[2].size()>0)
validationConfusionMatrix[c][trialIndex] = new ConfusionMatrix (validationTrial).toString();
// gdruck - only perform evaluation if requested in report options
if (ReportOptions[ReportOption.train][ReportOption.accuracy])
trainAccuracy[c][trialIndex] = trainTrial.getAccuracy();
if (ReportOptions[ReportOption.test][ReportOption.accuracy])
testAccuracy[c][trialIndex] = testTrial.getAccuracy();
if (ReportOptions[ReportOption.validation][ReportOption.accuracy])
validationAccuracy[c][trialIndex] = validationTrial.getAccuracy();
// gdruck - only perform evaluation if requested in report options
if (ReportOptions[ReportOption.train][ReportOption.precision]) {
for (int k =0 ; k < labels.length; k++)
trainPrecision[c][k][trialIndex] = trainTrial.getPrecision(labels[k]);
}
if (ReportOptions[ReportOption.test][ReportOption.precision]) {
for (int k =0 ; k < labels.length; k++)
testPrecision[c][k][trialIndex] = testTrial.getPrecision(labels[k]);
}
if (ReportOptions[ReportOption.validation][ReportOption.precision]) {
for (int k =0 ; k < labels.length; k++)
validationPrecision[c][k][trialIndex] = validationTrial.getPrecision(labels[k]);
}
// gdruck - only perform evaluation if requested in report options
if (ReportOptions[ReportOption.train][ReportOption.recall]) {
for (int k =0 ; k < labels.length; k++)
trainRecall[c][k][trialIndex] = trainTrial.getRecall(labels[k]);
}
if (ReportOptions[ReportOption.test][ReportOption.recall]) {
for (int k =0 ; k < labels.length; k++)
testRecall[c][k][trialIndex] = testTrial.getRecall(labels[k]);
}
if (ReportOptions[ReportOption.validation][ReportOption.recall]) {
for (int k =0 ; k < labels.length; k++)
validationRecall[c][k][trialIndex] = validationTrial.getRecall(labels[k]);
}
// gdruck - only perform evaluation if requested in report options
if (ReportOptions[ReportOption.train][ReportOption.f1]) {
for (int k =0 ; k < labels.length; k++)
trainF1[c][k][trialIndex] = trainTrial.getF1(labels[k]);
}
if (ReportOptions[ReportOption.test][ReportOption.f1]) {
for (int k =0 ; k < labels.length; k++)
testF1[c][k][trialIndex] = testTrial.getF1(labels[k]);
}
if (ReportOptions[ReportOption.validation][ReportOption.f1]) {
for (int k =0 ; k < labels.length; k++)
validationF1[c][k][trialIndex] = validationTrial.getF1(labels[k]);
}
if (outputFile.wasInvoked()) {
String filename = outputFile.value;
if (numTrainers > 1) filename = filename+trainer.toString();
if (numTrials > 1) filename = filename+".trial"+trialIndex;
try {
ObjectOutputStream oos = new ObjectOutputStream
(new FileOutputStream (filename));
oos.writeObject (classifier);
oos.close();
} catch (Exception e) {
e.printStackTrace();
throw new IllegalArgumentException ("Couldn't write classifier to filename "+
filename);
}
}
// New Reporting
// raw output
if (ReportOptions[ReportOption.train][ReportOption.raw]){
System.out.println("Trial " + trialIndex + " Trainer " + trainer.toString());
System.out.println(" Raw Training Data");
printTrialClassification(trainTrial);
}
if (ReportOptions[ReportOption.test][ReportOption.raw]){
System.out.println("Trial " + trialIndex + " Trainer " + trainer.toString());
System.out.println(" Raw Testing Data");
printTrialClassification(testTrial);
}
if (ReportOptions[ReportOption.validation][ReportOption.raw]){
System.out.println("Trial " + trialIndex + " Trainer " + trainer.toString());
System.out.println(" Raw Validation Data");
printTrialClassification(validationTrial);
}
//train
if (ReportOptions[ReportOption.train][ReportOption.confusion]){
System.out.println("Trial " + trialIndex + " Trainer " + trainer.toString() + " Training Data Confusion Matrix");
if (ilists[0].size()>0) System.out.println (trainConfusionMatrix[c][trialIndex]);
}
if (ReportOptions[ReportOption.train][ReportOption.precision]){
for (int k =0 ; k < labels.length; k++)
System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " training data Precision(" + labels[k] + ") = "+ trainTrial.getPrecision(labels[k]));
}
if (ReportOptions[ReportOption.train][ReportOption.recall]){
for (int k =0 ; k < labels.length; k++)
System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " training data Recall(" + labels[k] + ") = "+ trainTrial.getRecall(labels[k]));
}
if (ReportOptions[ReportOption.train][ReportOption.f1]){
for (int k =0 ; k < labels.length; k++)
System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " training data F1(" + labels[k] + ") = "+ trainTrial.getF1(labels[k]));
}
if (ReportOptions[ReportOption.train][ReportOption.accuracy]){
System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " training data accuracy = "+ trainAccuracy[c][trialIndex]);
}
//validation
if (ReportOptions[ReportOption.validation][ReportOption.confusion]){
System.out.println("Trial " + trialIndex + " Trainer " + trainer.toString() + " Validation Data Confusion Matrix");
if (ilists[2].size()>0) System.out.println (validationConfusionMatrix[c][trialIndex]);
}
if (ReportOptions[ReportOption.validation][ReportOption.precision]){
for (int k =0 ; k < labels.length; k++)
System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " validation data precision(" + labels[k] + ") = "+ validationTrial.getPrecision(labels[k]));
}
if (ReportOptions[ReportOption.validation][ReportOption.recall]){
for (int k =0 ; k < labels.length; k++)
System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " validation data recall(" + labels[k] + ") = "+ validationTrial.getRecall(labels[k]));
}
if (ReportOptions[ReportOption.validation][ReportOption.f1]){
for (int k =0 ; k < labels.length; k++)
System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " validation data F1(" + labels[k] + ") = "+ validationTrial.getF1(labels[k]));
}
if (ReportOptions[ReportOption.validation][ReportOption.accuracy]){
System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " validation data accuracy = "+ validationAccuracy[c][trialIndex]);
}
//test
if (ReportOptions[ReportOption.test][ReportOption.confusion]){
System.out.println("Trial " + trialIndex + " Trainer " + trainer.toString() + " Test Data Confusion Matrix");
if (ilists[1].size()>0) System.out.println (testConfusionMatrix[c][trialIndex]);
}
if (ReportOptions[ReportOption.test][ReportOption.precision]){
for (int k =0 ; k < labels.length; k++)
System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " test data precision(" + labels[k] + ") = "+ testTrial.getPrecision(labels[k]));
}
if (ReportOptions[ReportOption.test][ReportOption.recall]){
for (int k =0 ; k < labels.length; k++)
System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " test data recall(" + labels[k] + ") = "+ testTrial.getRecall(labels[k]));
}
if (ReportOptions[ReportOption.test][ReportOption.f1]){
for (int k =0 ; k < labels.length; k++)
System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " test data F1(" + labels[k] + ") = "+ testTrial.getF1(labels[k]));
}
if (ReportOptions[ReportOption.test][ReportOption.accuracy]){
System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " test data accuracy = "+ testAccuracy[c][trialIndex]);
}
if (trialIndex == 0) trainerNames[c] = trainer.toString();
} // end for each trainer
} // end for each trial
// New reporting
//"[train|test|validation]:[accuracy|f1|confusion|raw|precision|recall]"
for (int c=0; c < numTrainers; c++) {
System.out.println ("\n"+trainerNames[c].toString());
if (ReportOptions[ReportOption.train][ReportOption.accuracy])
System.out.println ("Summary. train accuracy mean = "+ MatrixOps.mean (trainAccuracy[c])+
" stddev = "+ MatrixOps.stddev (trainAccuracy[c])+
" stderr = "+ MatrixOps.stderr (trainAccuracy[c]));
if (ReportOptions[ReportOption.train][ReportOption.precision]) {
for (int k =0 ; k < labels.length; k++)
System.out.println ("Summary. train precision("+labels[k]+") mean = "+ MatrixOps.mean (trainPrecision[c][k])+
" stddev = "+ MatrixOps.stddev (trainPrecision[c][k])+
" stderr = "+ MatrixOps.stderr (trainPrecision[c][k]));
}
if (ReportOptions[ReportOption.train][ReportOption.recall]) {
for (int k =0 ; k < labels.length; k++)
System.out.println ("Summary. train recall("+labels[k]+") mean = "+ MatrixOps.mean (trainRecall[c][k])+
" stddev = "+ MatrixOps.stddev (trainRecall[c][k])+
" stderr = "+ MatrixOps.stderr (trainRecall[c][k]));
}
if (ReportOptions[ReportOption.train][ReportOption.f1]){
for (int k =0 ; k < labels.length; k++)
System.out.println ("Summary. train f1("+labels[k]+") mean = "+ MatrixOps.mean (trainF1[c][k])+
" stddev = "+ MatrixOps.stddev (trainF1[c][k])+
" stderr = "+ MatrixOps.stderr (trainF1[c][k]));
}
if (ReportOptions[ReportOption.validation][ReportOption.accuracy])
System.out.println ("Summary. validation accuracy mean = "+ MatrixOps.mean (validationAccuracy[c])+
" stddev = "+ MatrixOps.stddev (validationAccuracy[c])+
" stderr = "+ MatrixOps.stderr (validationAccuracy[c]));
if (ReportOptions[ReportOption.validation][ReportOption.precision]){
for (int k =0 ; k < labels.length; k++)
System.out.println ("Summary. validation precision("+labels[k]+") mean = "+ MatrixOps.mean (validationPrecision[c][k])+
" stddev = "+ MatrixOps.stddev (validationPrecision[c][k])+
" stderr = "+ MatrixOps.stderr (validationPrecision[c][k]));
}
if (ReportOptions[ReportOption.validation][ReportOption.recall]){
for (int k =0 ; k < labels.length; k++)
System.out.println ("Summary. validation recall("+labels[k]+") mean = "+ MatrixOps.mean (validationRecall[c][k])+
" stddev = "+ MatrixOps.stddev (validationRecall[c][k])+
" stderr = "+ MatrixOps.stderr (validationRecall[c][k]));
}
if (ReportOptions[ReportOption.validation][ReportOption.f1]){
for (int k =0 ; k < labels.length; k++)
System.out.println ("Summary. validation f1("+labels[k]+") mean = "+ MatrixOps.mean (validationF1[c][k])+
" stddev = "+ MatrixOps.stddev (validationF1[c][k])+
" stderr = "+ MatrixOps.stderr (validationF1[c][k]));
}
if (ReportOptions[ReportOption.test][ReportOption.accuracy])
System.out.println ("Summary. test accuracy mean = "+ MatrixOps.mean (testAccuracy[c])+
" stddev = "+ MatrixOps.stddev (testAccuracy[c])+
" stderr = "+ MatrixOps.stderr (testAccuracy[c]));
if (ReportOptions[ReportOption.test][ReportOption.precision]){
for (int k =0 ; k < labels.length; k++)
System.out.println ("Summary. test precision("+labels[k]+") mean = "+ MatrixOps.mean (testPrecision[c][k])+
" stddev = "+ MatrixOps.stddev (testPrecision[c][k])+
" stderr = "+ MatrixOps.stderr (testPrecision[c][k]));
}
if (ReportOptions[ReportOption.test][ReportOption.recall]){
for (int k =0 ; k < labels.length; k++)
System.out.println ("Summary. test recall("+labels[k]+") mean = "+ MatrixOps.mean (testRecall[c][k])+
" stddev = "+ MatrixOps.stddev (testRecall[c][k])+
" stderr = "+ MatrixOps.stderr (testRecall[c][k]));
}
if (ReportOptions[ReportOption.test][ReportOption.f1]){
for (int k =0 ; k < labels.length; k++)
System.out.println ("Summary. test f1("+labels[k]+") mean = "+ MatrixOps.mean (testF1[c][k])+
" stddev = "+ MatrixOps.stddev (testF1[c][k])+
" stderr = "+ MatrixOps.stderr (testF1[c][k]));
}
} // end for each trainer
}
/**
 * Prints the raw per-instance classification results of a trial, one line per
 * instance: the instance name, its true target, then every label in rank order
 * as "label:score" pairs separated by spaces.
 *
 * @param trial the completed trial whose classifications are printed to stdout
 */
private static void printTrialClassification(Trial trial)
{
    for (Classification classification : trial) {
        Instance inst = classification.getInstance();
        // Assemble the whole output line before printing it in one call.
        StringBuilder line = new StringBuilder();
        line.append(inst.getName()).append(" ").append(inst.getTarget()).append(" ");
        Labeling labeling = classification.getLabeling();
        int numRanks = labeling.numLocations();
        for (int rank = 0; rank < numRanks; rank++) {
            line.append(labeling.getLabelAtRank(rank).toString())
                .append(":")
                .append(labeling.getValueAtRank(rank))
                .append(" ");
        }
        System.out.println(line);
    }
}
/**
 * Evaluates the given string as a BeanShell/Java expression (e.g.
 * "new MaxEntTrainer()") and returns the resulting object.
 *
 * @param arg a Java expression that constructs a trainer
 * @return whatever the interpreter evaluates {@code arg} to
 * @throws IllegalArgumentException if the expression fails to evaluate
 */
private static Object createTrainer(String arg) {
    Object evaluated;
    try {
        evaluated = interpreter.eval(arg);
    } catch (bsh.EvalError e) {
        // Surface the interpreter failure as an argument error to the caller.
        throw new IllegalArgumentException("Java interpreter eval error\n" + e);
    }
    return evaluated;
}
private static ClassifierTrainer getTrainer(String arg) {
// parse something like Maxent,gaussianPriorVariance=10,numIterations=20
// first, split the argument at commas.
java.lang.String fields[] = arg.split(",");
//Massage constructor name, so that MaxEnt, MaxEntTrainer, new MaxEntTrainer()
// all call new MaxEntTrainer()
java.lang.String constructorName = fields[0];
Object trainer;
if (constructorName.indexOf('(') != -1) // if contains (), pass it though
trainer = createTrainer(arg);
else {
if (constructorName.endsWith("Trainer")){
trainer = createTrainer("new " + constructorName + "()"); // add parens if they forgot
}else{
trainer = createTrainer("new "+constructorName+"Trainer()"); // make trainer name from classifier name
}
}
// find methods associated with the class we just built
Method methods[] = trainer.getClass().getMethods();
// find setters corresponding to parameter names.
for (int i=1; i
© 2015 - 2025 Weber Informatics LLC | Privacy Policy