
cc.mallet.classify.tui.Calo2Classify Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mallet Show documentation
MALLET is a Java-based package for statistical natural language processing,
document classification, clustering, topic modeling, information extraction,
and other machine learning applications to text.
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.classify.tui;
import java.io.*;
import java.util.*;
import java.util.logging.*;
import java.lang.reflect.*;
import cc.mallet.classify.*;
import cc.mallet.classify.evaluate.*;
import cc.mallet.types.*;
import cc.mallet.util.*;
import java.util.Random;
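/**
* Classify documents, run trials, and print statistics from a vector file.
* When a serialized classifier is loaded from file (the load-model option),
* that classifier is evaluated in every trial in place of training a new one.
* @author Andrew McCallum [email protected]
*/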
public abstract class Calo2Classify
{
/**
 * Classifier deserialized from the load-model file; when that option was
 * invoked it is used for every trial instead of training (CPAL - added).
 */
private static Classifier classifierL;
private static Logger logger =
	MalletLogger.getLogger(Calo2Classify.class.getName());
private static Logger progressLogger =
	MalletProgressMessageLogger.getLogger(Calo2Classify.class.getName() + "-pl");
// Raw list; element type is not visible here (presumably ClassifierTrainer -- confirm).
private static ArrayList classifierTrainers = new ArrayList();
/**
 * Enabled reports, indexed as [dataset code][report code] using ReportOption
 * codes; dimensions are taken from the ReportOption name tables (3 x 4).
 */
private static boolean[][] ReportOptions =
	new boolean[ReportOption.dataOptions.length][ReportOption.reportOptions.length];
/**
 * Argument attached to a report request ("dataset:reportOption=arg"), indexed
 * like ReportOptions; e.g. the f1 entry holds the label whose F1 is printed.
 */
private static String[][] ReportOptionArgs =
	new String[ReportOption.dataOptions.length][ReportOption.reportOptions.length];
/**
 * Integer codes for the {@code report} command option, used as indices into
 * the {@code ReportOptions} and {@code ReportOptionArgs} tables.
 *
 * <p>Works like a pair of enums backed by plain ints so the codes can index
 * arrays directly: every code equals the position of its name in
 * {@link #dataOptions} or {@link #reportOptions}.
 */
private static class ReportOption
{
	/** Dataset names, in code order. */
	static final String[] dataOptions = new String[] {"train", "test", "validation"};
	static final int train = 0, test = 1, validation = 2;

	/** Report-kind names, in code order. */
	static final String[] reportOptions = new String[] {"accuracy", "f1", "confusion", "raw"};
	static final int accuracy = 0, f1 = 1, confusion = 2, raw = 3;

	// Constant holder only; never instantiated.
	private ReportOption() {}
}
static CommandOption.SpacedStrings report = new CommandOption.SpacedStrings
(Calo2Classify.class, "report", "[train|test|validation]:[accuracy|f1|confusion|raw]",
true, new String[] {"test:accuracy", "test:confusion", "train:accuracy"},
"", null)
{
public void postParsing (CommandOption.List list)
{
java.lang.String defaultRawFormatting = "siw";
for (int argi=0; argi=3){
reportOptionArg = fields[2];
}
//System.out.println("Report option arg " + reportOptionArg);
//find the datasource (test,train,validation)
boolean foundDataSource = false;
int i=0;
for (; i 1) filename = filename+trainers[c].toString();
//if (numTrials > 1) filename = filename+".trial"+trialIndex;
try {
//ObjectOutputStream oos = new ObjectOutputStream
// (new FileOutputStream (filename));
//oos.writeObject (classifier);
ObjectInputStream iis = new ObjectInputStream
(new FileInputStream (filename));
classifierL = (Classifier) iis.readObject();
iis.close();
} catch (Exception e) {
e.printStackTrace();
throw new IllegalArgumentException ("Couldn't read classifier from filename "+
filename);
}
}
// CPAL
for (int trialIndex = 0; trialIndex < numTrials; trialIndex++) {
System.out.println("\n-------------------- Trial " + trialIndex + " --------------------\n");
InstanceList[] ilists;
BitSet unlabeledIndices = null;
if (!separateIlists){
ilists = ilist.split (r, new double[] {t, 1-t-v, v});
} else {
ilists = new InstanceList[3];
ilists[0] = trainingFileIlist;
ilists[1] = testFileIlist;
ilists[2] = testFileIlist;
}
if (unlabeledProportionOption.value > 0)
unlabeledIndices = new cc.mallet.util.Randoms(r.nextInt())
.nextBitSet(ilists[0].size(),
unlabeledProportionOption.value);
//InfoGain ig = new InfoGain (ilists[0]);
//int igl = Math.min (10, ig.numLocations());
//for (int i = 0; i < igl; i++)
//System.out.println ("InfoGain["+ig.getObjectAtRank(i)+"]="+ig.getValueAtRank(i));
//ig.print();
//FeatureSelection selectedFeatures = new FeatureSelection (ig, 8000);
//ilists[0].setFeatureSelection (selectedFeatures);
//OddsRatioFeatureInducer orfi = new OddsRatioFeatureInducer (ilists[0]);
//orfi.induceFeatures (ilists[0], false, true);
//System.out.println ("Training with "+ilists[0].size()+" instances");
long time[] = new long[trainers.length];
for (int c = 0; c < trainers.length; c++){
time[c] = System.currentTimeMillis();
System.out.println ("Trial " + trialIndex + " Training " + trainers[c].toString() + " with "+ilists[0].size()+" instances");
if (unlabeledProportionOption.value > 0)
ilists[0].hideSomeLabels(unlabeledIndices);
Classifier classifier;
if(loadmodelFile.wasInvoked()) {
classifier = classifierL;
} else {
classifier = trainers[c].train (ilists[0]);
}
if (unlabeledProportionOption.value > 0)
ilists[0].unhideAllLabels();
System.out.println ("Trial " + trialIndex + " Training " + trainers[c].toString() + " finished");
time[c] = System.currentTimeMillis() - time[c];
Trial trainTrial = new Trial (classifier, ilists[0]);
assert (ilists[1].size() > 0);
Trial testTrial = new Trial (classifier, ilists[1]);
Trial validationTrial = new Trial(classifier, ilists[2]);
if (ilists[0].size()>0) trainConfusionMatrix[c][trialIndex] = new ConfusionMatrix (trainTrial).toString();
if (ilists[1].size()>0) testConfusionMatrix[c][trialIndex] = new ConfusionMatrix (testTrial).toString();
if (ilists[2].size()>0) validationConfusionMatrix[c][trialIndex] = new ConfusionMatrix (validationTrial).toString();
trainAccuracy[c][trialIndex] = trainTrial.getAccuracy();
testAccuracy[c][trialIndex] = testTrial.getAccuracy();
validationAccuracy[c][trialIndex] = validationTrial.getAccuracy();
if (outputFile.wasInvoked()) {
String filename = outputFile.value;
if (trainers.length > 1) filename = filename+trainers[c].toString();
if (numTrials > 1) filename = filename+".trial"+trialIndex;
try {
ObjectOutputStream oos = new ObjectOutputStream
(new FileOutputStream (filename));
oos.writeObject (classifier);
oos.close();
} catch (Exception e) {
e.printStackTrace();
throw new IllegalArgumentException ("Couldn't write classifier to filename "+
filename);
}
}
// New Reporting
// raw output
if (ReportOptions[ReportOption.train][ReportOption.raw]){
System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString());
System.out.println(" Raw Training Data");
printTrialClassification(trainTrial);
}
if (ReportOptions[ReportOption.test][ReportOption.raw]){
System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString());
System.out.println(" Raw Testing Data");
printTrialClassification(testTrial);
}
if (ReportOptions[ReportOption.validation][ReportOption.raw]){
System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString());
System.out.println(" Raw Validation Data");
printTrialClassification(validationTrial);
}
//train
if (ReportOptions[ReportOption.train][ReportOption.confusion]){
System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " Training Data Confusion Matrix");
if (ilists[0].size()>0) System.out.println (trainConfusionMatrix[c][trialIndex]);
}
if (ReportOptions[ReportOption.train][ReportOption.accuracy]){
System.out.println ("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " training data accuracy= "+ trainAccuracy[c][trialIndex]);
}
if (ReportOptions[ReportOption.train][ReportOption.f1]){
String label = ReportOptionArgs[ReportOption.train][ReportOption.f1];
System.out.println ("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " training data F1(" + label + ") = "+ trainTrial.getF1(label));
}
//validation
if (ReportOptions[ReportOption.validation][ReportOption.confusion]){
System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " Validation Data Confusion Matrix");
if (ilists[2].size()>0) System.out.println (validationConfusionMatrix[c][trialIndex]);
}
if (ReportOptions[ReportOption.validation][ReportOption.accuracy]){
System.out.println ("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " validation data accuracy= "+ validationAccuracy[c][trialIndex]);
}
if (ReportOptions[ReportOption.validation][ReportOption.f1]){
String label = ReportOptionArgs[ReportOption.validation][ReportOption.f1];
System.out.println ("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " validation data F1(" + label + ") = "+ validationTrial.getF1(label));
}
//test
if (ReportOptions[ReportOption.test][ReportOption.confusion]){
System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " Test Data Confusion Matrix");
if (ilists[1].size()>0) System.out.println (testConfusionMatrix[c][trialIndex]);
}
if (ReportOptions[ReportOption.test][ReportOption.accuracy]){
System.out.println ("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " test data accuracy= "+ testAccuracy[c][trialIndex]);
}
if (ReportOptions[ReportOption.test][ReportOption.f1]){
String label = ReportOptionArgs[ReportOption.test][ReportOption.f1];
System.out.println ("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " test data F1(" + label + ") = "+ testTrial.getF1(label));
}
} // end for each trainer
} // end for each trial
// New reporting
//"[train|test|validation]:[accuracy|f1|confusion|raw]"
for (int c=0; c < trainers.length; c++) {
System.out.println ("\n"+trainers[c].toString());
if (ReportOptions[ReportOption.train][ReportOption.accuracy])
System.out.println ("Summary. train accuracy mean = "+ MatrixOps.mean (trainAccuracy[c])+
" stddev = "+ MatrixOps.stddev (trainAccuracy[c])+
" stderr = "+ MatrixOps.stderr (trainAccuracy[c]));
if (ReportOptions[ReportOption.validation][ReportOption.accuracy])
System.out.println ("Summary. validation accuracy mean = "+ MatrixOps.mean (validationAccuracy[c])+
" stddev = "+ MatrixOps.stddev (validationAccuracy[c])+
" stderr = "+ MatrixOps.stderr (validationAccuracy[c]));
if (ReportOptions[ReportOption.test][ReportOption.accuracy])
System.out.println ("Summary. test accuracy mean = "+ MatrixOps.mean (testAccuracy[c])+
" stddev = "+ MatrixOps.stddev (testAccuracy[c])+
" stderr = "+ MatrixOps.stderr (testAccuracy[c]));
} // end for each trainer
}
/**
 * Writes the raw results of {@code trial} to standard output, one line per
 * entry: the instance name, its target, and then, for each rank from 0 up to
 * {@code numLocations() - 1}, the label at that rank and its value written as
 * {@code label:value}. Every field is followed by a single space.
 */
private static void printTrialClassification(Trial trial)
{
	int row = 0;
	while (row < trial.size()) {
		Instance inst = trial.get(row).getInstance();
		System.out.print(inst.getName() + " " + inst.getTarget() + " ");
		Labeling ranked = trial.get(row).getLabeling();
		int rank = 0;
		while (rank < ranked.numLocations()) {
			System.out.print(ranked.getLabelAtRank(rank).toString() + ":" + ranked.getValueAtRank(rank) + " ");
			rank++;
		}
		// Terminate this entry's line.
		System.out.println();
		row++;
	}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy