
cc.mallet.classify.tui.Calo2Classify Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of jcore-mallet-2.0.9 Show documentation
Show all versions of jcore-mallet-2.0.9 Show documentation
MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.
The newest version!
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.classify.tui;
import java.io.*;
import java.util.*;
import java.util.logging.*;
import java.lang.reflect.*;
import cc.mallet.classify.*;
import cc.mallet.classify.evaluate.*;
import cc.mallet.types.*;
import cc.mallet.util.*;
import java.util.Random;
/**
 * Classify documents, run trials, and print statistics from a vector file.
 * @author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>
 */
public abstract class Calo2Classify
{
// CPAL - added. Holds the classifier deserialized from a model file; when the
// load-model option was invoked it is used in place of training a new classifier.
private static Classifier classifierL; // CPAL - added
// Loggers named after this class; the progress logger's name carries a "-pl" suffix.
private static Logger logger = MalletLogger.getLogger(Calo2Classify.class.getName());
private static Logger progressLogger = MalletProgressMessageLogger.getLogger(Calo2Classify.class.getName() + "-pl");
// Raw list of classifier trainers.
// NOTE(review): presumably filled while parsing trainer options -- confirm against the option handlers.
private static ArrayList classifierTrainers = new ArrayList();
// Report switches indexed as [data set][report type] using the int constants in
// ReportOption (3 data sets x 4 report types); true means that report is printed.
private static boolean[][] ReportOptions = new boolean[3][4];
// Optional argument for each [data set][report type] cell, i.e. the "arg" in
// dataset:reportOption=arg; the f1 reports read their label from here.
private static String[][] ReportOptionArgs = new String[3][4]; //arg in dataset:reportOption=arg
/**
 * Poor-man's enum: ties the textual names of data sets and report types to
 * small int codes. Each int constant equals the position of its name in the
 * matching string array (e.g. {@code dataOptions[test]} is {@code "test"}),
 * and the codes are used as [data set][report type] array indices.
 */
private static class ReportOption
{
    // Names in code order: position in the array == value of the int constant.
    static final String[] dataOptions = new String[] {"train", "test", "validation"};
    static final String[] reportOptions = new String[] {"accuracy", "f1", "confusion", "raw"};

    // Data set codes (first array dimension).
    static final int train = 0, test = 1, validation = 2;

    // Report type codes (second array dimension).
    static final int accuracy = 0, f1 = 1, confusion = 2, raw = 3;
}
static CommandOption.SpacedStrings report = new CommandOption.SpacedStrings
(Calo2Classify.class, "report", "[train|test|validation]:[accuracy|f1|confusion|raw]",
true, new String[] {"test:accuracy", "test:confusion", "train:accuracy"},
"", null)
{
public void postParsing (CommandOption.List list)
{
java.lang.String defaultRawFormatting = "siw";
for (int argi=0; argi=3){
reportOptionArg = fields[2];
}
//System.out.println("Report option arg " + reportOptionArg);
//find the datasource (test,train,validation)
boolean foundDataSource = false;
int i=0;
for (; i 1) filename = filename+trainers[c].toString();
//if (numTrials > 1) filename = filename+".trial"+trialIndex;
try {
//ObjectOutputStream oos = new ObjectOutputStream
// (new FileOutputStream (filename));
//oos.writeObject (classifier);
ObjectInputStream iis = new ObjectInputStream
(new FileInputStream (filename));
classifierL = (Classifier) iis.readObject();
iis.close();
} catch (Exception e) {
e.printStackTrace();
throw new IllegalArgumentException ("Couldn't read classifier from filename "+
filename);
}
}
// CPAL
for (int trialIndex = 0; trialIndex < numTrials; trialIndex++) {
System.out.println("\n-------------------- Trial " + trialIndex + " --------------------\n");
InstanceList[] ilists;
BitSet unlabeledIndices = null;
if (!separateIlists){
ilists = ilist.split (r, new double[] {t, 1-t-v, v});
} else {
ilists = new InstanceList[3];
ilists[0] = trainingFileIlist;
ilists[1] = testFileIlist;
ilists[2] = testFileIlist;
}
if (unlabeledProportionOption.value > 0)
unlabeledIndices = new cc.mallet.util.Randoms(r.nextInt())
.nextBitSet(ilists[0].size(),
unlabeledProportionOption.value);
//InfoGain ig = new InfoGain (ilists[0]);
//int igl = Math.min (10, ig.numLocations());
//for (int i = 0; i < igl; i++)
//System.out.println ("InfoGain["+ig.getObjectAtRank(i)+"]="+ig.getValueAtRank(i));
//ig.print();
//FeatureSelection selectedFeatures = new FeatureSelection (ig, 8000);
//ilists[0].setFeatureSelection (selectedFeatures);
//OddsRatioFeatureInducer orfi = new OddsRatioFeatureInducer (ilists[0]);
//orfi.induceFeatures (ilists[0], false, true);
//System.out.println ("Training with "+ilists[0].size()+" instances");
long time[] = new long[trainers.length];
for (int c = 0; c < trainers.length; c++){
time[c] = System.currentTimeMillis();
System.out.println ("Trial " + trialIndex + " Training " + trainers[c].toString() + " with "+ilists[0].size()+" instances");
if (unlabeledProportionOption.value > 0)
ilists[0].hideSomeLabels(unlabeledIndices);
Classifier classifier;
if(loadmodelFile.wasInvoked()) {
classifier = classifierL;
} else {
classifier = trainers[c].train (ilists[0]);
}
if (unlabeledProportionOption.value > 0)
ilists[0].unhideAllLabels();
System.out.println ("Trial " + trialIndex + " Training " + trainers[c].toString() + " finished");
time[c] = System.currentTimeMillis() - time[c];
Trial trainTrial = new Trial (classifier, ilists[0]);
assert (ilists[1].size() > 0);
Trial testTrial = new Trial (classifier, ilists[1]);
Trial validationTrial = new Trial(classifier, ilists[2]);
if (ilists[0].size()>0) trainConfusionMatrix[c][trialIndex] = new ConfusionMatrix (trainTrial).toString();
if (ilists[1].size()>0) testConfusionMatrix[c][trialIndex] = new ConfusionMatrix (testTrial).toString();
if (ilists[2].size()>0) validationConfusionMatrix[c][trialIndex] = new ConfusionMatrix (validationTrial).toString();
trainAccuracy[c][trialIndex] = trainTrial.getAccuracy();
testAccuracy[c][trialIndex] = testTrial.getAccuracy();
validationAccuracy[c][trialIndex] = validationTrial.getAccuracy();
if (outputFile.wasInvoked()) {
String filename = outputFile.value;
if (trainers.length > 1) filename = filename+trainers[c].toString();
if (numTrials > 1) filename = filename+".trial"+trialIndex;
try {
ObjectOutputStream oos = new ObjectOutputStream
(new FileOutputStream (filename));
oos.writeObject (classifier);
oos.close();
} catch (Exception e) {
e.printStackTrace();
throw new IllegalArgumentException ("Couldn't write classifier to filename "+
filename);
}
}
// New Reporting
// raw output
if (ReportOptions[ReportOption.train][ReportOption.raw]){
System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString());
System.out.println(" Raw Training Data");
printTrialClassification(trainTrial);
}
if (ReportOptions[ReportOption.test][ReportOption.raw]){
System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString());
System.out.println(" Raw Testing Data");
printTrialClassification(testTrial);
}
if (ReportOptions[ReportOption.validation][ReportOption.raw]){
System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString());
System.out.println(" Raw Validation Data");
printTrialClassification(validationTrial);
}
//train
if (ReportOptions[ReportOption.train][ReportOption.confusion]){
System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " Training Data Confusion Matrix");
if (ilists[0].size()>0) System.out.println (trainConfusionMatrix[c][trialIndex]);
}
if (ReportOptions[ReportOption.train][ReportOption.accuracy]){
System.out.println ("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " training data accuracy= "+ trainAccuracy[c][trialIndex]);
}
if (ReportOptions[ReportOption.train][ReportOption.f1]){
String label = ReportOptionArgs[ReportOption.train][ReportOption.f1];
System.out.println ("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " training data F1(" + label + ") = "+ trainTrial.getF1(label));
}
//validation
if (ReportOptions[ReportOption.validation][ReportOption.confusion]){
System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " Validation Data Confusion Matrix");
if (ilists[2].size()>0) System.out.println (validationConfusionMatrix[c][trialIndex]);
}
if (ReportOptions[ReportOption.validation][ReportOption.accuracy]){
System.out.println ("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " validation data accuracy= "+ validationAccuracy[c][trialIndex]);
}
if (ReportOptions[ReportOption.validation][ReportOption.f1]){
String label = ReportOptionArgs[ReportOption.validation][ReportOption.f1];
System.out.println ("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " validation data F1(" + label + ") = "+ validationTrial.getF1(label));
}
//test
if (ReportOptions[ReportOption.test][ReportOption.confusion]){
System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " Test Data Confusion Matrix");
if (ilists[1].size()>0) System.out.println (testConfusionMatrix[c][trialIndex]);
}
if (ReportOptions[ReportOption.test][ReportOption.accuracy]){
System.out.println ("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " test data accuracy= "+ testAccuracy[c][trialIndex]);
}
if (ReportOptions[ReportOption.test][ReportOption.f1]){
String label = ReportOptionArgs[ReportOption.test][ReportOption.f1];
System.out.println ("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " test data F1(" + label + ") = "+ testTrial.getF1(label));
}
} // end for each trainer
} // end for each trial
// New reporting
//"[train|test|validation]:[accuracy|f1|confusion|raw]"
for (int c=0; c < trainers.length; c++) {
System.out.println ("\n"+trainers[c].toString());
if (ReportOptions[ReportOption.train][ReportOption.accuracy])
System.out.println ("Summary. train accuracy mean = "+ MatrixOps.mean (trainAccuracy[c])+
" stddev = "+ MatrixOps.stddev (trainAccuracy[c])+
" stderr = "+ MatrixOps.stderr (trainAccuracy[c]));
if (ReportOptions[ReportOption.validation][ReportOption.accuracy])
System.out.println ("Summary. validation accuracy mean = "+ MatrixOps.mean (validationAccuracy[c])+
" stddev = "+ MatrixOps.stddev (validationAccuracy[c])+
" stderr = "+ MatrixOps.stderr (validationAccuracy[c]));
if (ReportOptions[ReportOption.test][ReportOption.accuracy])
System.out.println ("Summary. test accuracy mean = "+ MatrixOps.mean (testAccuracy[c])+
" stddev = "+ MatrixOps.stddev (testAccuracy[c])+
" stderr = "+ MatrixOps.stderr (testAccuracy[c]));
} // end for each trainer
}
/**
 * Writes the raw classification results of a trial to standard output, one
 * line per entry of the trial.
 * <p>
 * Each line starts with the instance name and the instance target, followed
 * by one {@code label:value} pair for every rank from 0 up to
 * {@code numLocations() - 1} of that entry's labeling. All fields are
 * followed by a single space, so every line ends with a trailing space.
 *
 * @param trial the trial whose entries are printed
 */
private static void printTrialClassification(Trial trial)
{
    int row = 0;
    while (row < trial.size()) {
        // Header part of the line: "<name> <target> "
        Instance current = trial.get(row).getInstance();
        Object name = current.getName();
        Object target = current.getTarget();
        System.out.print(name + " " + target + " ");

        // One "<label>:<value> " pair per rank of the labeling.
        Labeling scores = trial.get(row).getLabeling();
        for (int rank = 0; rank < scores.numLocations(); rank++) {
            String label = scores.getLabelAtRank(rank).toString();
            System.out.print(label + ":" + scores.getValueAtRank(rank) + " ");
        }

        System.out.println();
        row++;
    }
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy