
cc.mallet.classify.tui.Vectors2Classify


MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.
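The listing below is MALLET's command-line front end for training and evaluating document classifiers. As a rough sketch only (the vectors file name, the split proportions, and the choice of MaxEnt are illustrative assumptions, not taken from this file), the same train/evaluate cycle can be driven directly from the MALLET API:

import java.io.File;

import cc.mallet.classify.Classifier;
import cc.mallet.classify.MaxEntTrainer;
import cc.mallet.classify.Trial;
import cc.mallet.types.InstanceList;
import cc.mallet.util.Randoms;

// Illustrative only: "docs.vectors" stands in for a serialized InstanceList
// such as the one this tool reads via its --input option.
public class TrainAndEvaluateSketch {
	public static void main (String[] args) throws Exception {
		InstanceList instances = InstanceList.load (new File ("docs.vectors"));
		// 90% training, 10% testing, no validation set
		InstanceList[] splits = instances.split (new Randoms (1), new double[] {0.9, 0.1, 0.0});
		Classifier classifier = new MaxEntTrainer ().train (splits[0]);
		Trial testTrial = new Trial (classifier, splits[1]);
		System.out.println ("test accuracy = " + testTrial.getAccuracy ());
	}
}

The class below layers option parsing, repeated trials, and the --report summaries on top of this basic loop.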

/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */





package cc.mallet.classify.tui;



import java.io.*;
import java.util.*;
import java.util.logging.*;
import java.lang.reflect.*;

import cc.mallet.classify.*;
import cc.mallet.classify.evaluate.*;
import cc.mallet.types.*;
import cc.mallet.util.*;
/**
 * Classify documents, run trials, print statistics from a vector file.
 @author Andrew McCallum [email protected]
 */

public abstract class Vectors2Classify
{
	static BshInterpreter interpreter = new BshInterpreter();

	private static Logger logger = MalletLogger.getLogger(Vectors2Classify.class.getName());
	private static Logger progressLogger = MalletProgressMessageLogger.getLogger(Vectors2Classify.class.getName() + "-pl");
	private static ArrayList<String> classifierTrainerStrings = new ArrayList<String>();
	private static int dataOptionsSize = 3;
	private static int reportOptionsSize = 6;

	private static boolean[][] ReportOptions = new boolean[dataOptionsSize][reportOptionsSize];

	// Essentially an enum mapping string names to enums to ints.
	private static class ReportOption
	{
		static final String[] dataOptions = {"train", "test", "validation"};
		static final String[] reportOptions = {"accuracy", "f1", "confusion", "raw", "precision", "recall"};
		static final int train=0;
		static final int test =1;
		static final int validation=2;
		static final int accuracy=0;
		static final int f1=1;
		static final int confusion=2;
		static final int raw=3;
		static final int precision=4;
		static final int recall=5;
	}
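
	// For illustration: a command-line argument such as "--report test:f1:sports"
	// ("sports" being a hypothetical label) sets
	//   ReportOptions[ReportOption.test][ReportOption.f1] = true   // i.e. ReportOptions[1][1]
	// with the trailing field parsed into reportOptionArg (see postParsing below).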

	static CommandOption.SpacedStrings report = new CommandOption.SpacedStrings
			(Vectors2Classify.class, "report", "[train|test|validation]:[accuracy|f1:label|precision:label|recall:label|confusion|raw]",
					true, new String[] {"train:accuracy",  "test:accuracy", "test:confusion",  "test:precision",  "test:recall", "test:f1"},
					"", null) 
	{
		public void postParsing (CommandOption.List list)
		{
			java.lang.String defaultRawFormatting = "siw";

			for (int argi = 0; argi < this.value.length; argi++) {
				// split an argument like "test:f1:labelA" into data set, report option, and optional argument
				java.lang.String fields[] = this.value[argi].split("[:=]");
				java.lang.String dataSet = fields[0], reportOption = fields[1], reportOptionArg = null;
				if (fields.length >= 3){
					reportOptionArg = fields[2];
				}
				//System.out.println("Report option arg " + reportOptionArg);

				//find the datasource (test,train,validation)
				boolean foundDataSource = false;
				int i=0;
				for (; i < ReportOption.dataOptions.length; i++)
					if (dataSet.equals(ReportOption.dataOptions[i])) { foundDataSource = true; break; }

				if (unlabeledProportionOption.value > 0)
					unlabeledIndices = new cc.mallet.util.Randoms(r.nextInt())
				.nextBitSet(ilists[0].size(), unlabeledProportionOption.value);

				//InfoGain ig = new InfoGain (ilists[0]);
				//int igl = Math.min (10, ig.numLocations());
				//for (int i = 0; i < igl; i++)
				//System.out.println ("InfoGain["+ig.getObjectAtRank(i)+"]="+ig.getValueAtRank(i));
				//ig.print();

				//FeatureSelection selectedFeatures = new FeatureSelection (ig, 8000);
				//ilists[0].setFeatureSelection (selectedFeatures);
				//OddsRatioFeatureInducer orfi = new OddsRatioFeatureInducer (ilists[0]);
				//orfi.induceFeatures (ilists[0], false, true);

				//System.out.println ("Training with "+ilists[0].size()+" instances");
				long time[] = new long[numTrainers];
				for (int c = 0; c < numTrainers; c++){
					time[c] = System.currentTimeMillis();
					ClassifierTrainer trainer = getTrainer(classifierTrainerStrings.get(c));
					trainer.setValidationInstances(ilists[2]);
					System.out.println ("Trial " + trialIndex + " Training " + trainer + " with "+ilists[0].size()+" instances");
					if (unlabeledProportionOption.value > 0)
						ilists[0].hideSomeLabels(unlabeledIndices);
					Classifier classifier = trainer.train (ilists[0]);
					if (unlabeledProportionOption.value > 0)
						ilists[0].unhideAllLabels();

					System.out.println ("Trial " + trialIndex + " Training " + trainer.toString() + " finished");
					time[c] = System.currentTimeMillis() - time[c];
					Trial trainTrial = new Trial (classifier, ilists[0]);
					//assert (ilists[1].size() > 0);
					Trial testTrial = new Trial (classifier, ilists[1]);
					Trial validationTrial = new Trial(classifier, ilists[2]);

					// gdruck - only perform evaluation if requested in report options
					if (ReportOptions[ReportOption.train][ReportOption.confusion] && ilists[0].size()>0) 
						trainConfusionMatrix[c][trialIndex] = new ConfusionMatrix (trainTrial).toString();
					if (ReportOptions[ReportOption.test][ReportOption.confusion] && ilists[1].size()>0) 
						testConfusionMatrix[c][trialIndex] = new ConfusionMatrix (testTrial).toString();
					if (ReportOptions[ReportOption.validation][ReportOption.confusion] && ilists[2].size()>0) 
						validationConfusionMatrix[c][trialIndex] = new ConfusionMatrix (validationTrial).toString();

					// gdruck - only perform evaluation if requested in report options
					if (ReportOptions[ReportOption.train][ReportOption.accuracy]) 
						trainAccuracy[c][trialIndex] = trainTrial.getAccuracy();
					if (ReportOptions[ReportOption.test][ReportOption.accuracy]) 
						testAccuracy[c][trialIndex] = testTrial.getAccuracy();
					if (ReportOptions[ReportOption.validation][ReportOption.accuracy]) 
						validationAccuracy[c][trialIndex] = validationTrial.getAccuracy();

					// gdruck - only perform evaluation if requested in report options
					if (ReportOptions[ReportOption.train][ReportOption.precision]) {
						for (int k =0 ; k < labels.length; k++) 
							trainPrecision[c][k][trialIndex] = trainTrial.getPrecision(labels[k]);
					}
					if (ReportOptions[ReportOption.test][ReportOption.precision]) {
						for (int k =0 ; k < labels.length; k++) 
							testPrecision[c][k][trialIndex] = testTrial.getPrecision(labels[k]);
					}
					if (ReportOptions[ReportOption.validation][ReportOption.precision]) {
						for (int k =0 ; k < labels.length; k++) 
							validationPrecision[c][k][trialIndex] = validationTrial.getPrecision(labels[k]);
					}

					// gdruck - only perform evaluation if requested in report options
					if (ReportOptions[ReportOption.train][ReportOption.recall]) {
						for (int k =0 ; k < labels.length; k++) 
							trainRecall[c][k][trialIndex] = trainTrial.getRecall(labels[k]);
					}	
					if (ReportOptions[ReportOption.test][ReportOption.recall]) {
						for (int k =0 ; k < labels.length; k++) 
							testRecall[c][k][trialIndex] = testTrial.getRecall(labels[k]);
					}					
					if (ReportOptions[ReportOption.validation][ReportOption.recall]) {
						for (int k =0 ; k < labels.length; k++) 
							validationRecall[c][k][trialIndex] = validationTrial.getRecall(labels[k]);
					}	

					// gdruck - only perform evaluation if requested in report options
					if (ReportOptions[ReportOption.train][ReportOption.f1]) {
						for (int k =0 ; k < labels.length; k++) 
							trainF1[c][k][trialIndex] = trainTrial.getF1(labels[k]);
					}					
					if (ReportOptions[ReportOption.test][ReportOption.f1]) {
						for (int k =0 ; k < labels.length; k++) 
							testF1[c][k][trialIndex] = testTrial.getF1(labels[k]);
					}						
					if (ReportOptions[ReportOption.validation][ReportOption.f1]) {
						for (int k =0 ; k < labels.length; k++) 
							validationF1[c][k][trialIndex] = validationTrial.getF1(labels[k]);
					}	

					if (outputFile.wasInvoked()) {
						String filename = outputFile.value;
						if (numTrainers > 1) filename = filename+trainer.toString();
						if (numTrials > 1) filename = filename+".trial"+trialIndex;
						try {
							ObjectOutputStream oos = new ObjectOutputStream
									(new FileOutputStream (filename));
							oos.writeObject (classifier);
							oos.close();
						} catch (Exception e) {
							e.printStackTrace();
							throw new IllegalArgumentException ("Couldn't write classifier to filename "+
									filename);
						}
					}
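
					// The classifier serialized above can be restored later with plain Java
					// serialization; an illustrative sketch (not part of this tool):
					//   ObjectInputStream ois = new ObjectInputStream (new FileInputStream (filename));
					//   Classifier restored = (Classifier) ois.readObject ();
					//   ois.close ();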

					// New Reporting

					// raw output
					if (ReportOptions[ReportOption.train][ReportOption.raw]){
						System.out.println("Trial " + trialIndex + " Trainer " + trainer.toString());
						System.out.println(" Raw Training Data");
						printTrialClassification(trainTrial);
					}

					if (ReportOptions[ReportOption.test][ReportOption.raw]){
						System.out.println("Trial " + trialIndex + " Trainer " + trainer.toString());
						System.out.println(" Raw Testing Data");
						printTrialClassification(testTrial);
					}

					if (ReportOptions[ReportOption.validation][ReportOption.raw]){
						System.out.println("Trial " + trialIndex + " Trainer " + trainer.toString());
						System.out.println(" Raw Validation Data");
						printTrialClassification(validationTrial);
					}


					//train
					if (ReportOptions[ReportOption.train][ReportOption.confusion]){
						System.out.println("Trial " + trialIndex + " Trainer " + trainer.toString() +  " Training Data Confusion Matrix");
						if (ilists[0].size()>0) System.out.println (trainConfusionMatrix[c][trialIndex]);
					}

					if (ReportOptions[ReportOption.train][ReportOption.precision]){
						for (int k =0 ; k < labels.length; k++) 
							System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " training data Precision(" + labels[k] + ") = "+ trainTrial.getPrecision(labels[k]));
					}

					if (ReportOptions[ReportOption.train][ReportOption.recall]){
						for (int k =0 ; k < labels.length; k++) 
							System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " training data Recall(" + labels[k] + ") = "+ trainTrial.getRecall(labels[k]));
					}

					if (ReportOptions[ReportOption.train][ReportOption.f1]){
						for (int k =0 ; k < labels.length; k++) 
							System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " training data F1(" + labels[k] + ") = "+ trainTrial.getF1(labels[k]));
					}

					if (ReportOptions[ReportOption.train][ReportOption.accuracy]){
						System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " training data accuracy = "+ trainAccuracy[c][trialIndex]);
					}

					//validation
					if (ReportOptions[ReportOption.validation][ReportOption.confusion]){
						System.out.println("Trial " + trialIndex + " Trainer " + trainer.toString() +  " Validation Data Confusion Matrix");
						if (ilists[2].size()>0) System.out.println (validationConfusionMatrix[c][trialIndex]);
					}

					if (ReportOptions[ReportOption.validation][ReportOption.precision]){
						for (int k =0 ; k < labels.length; k++) 
							System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " validation data precision(" + labels[k] + ") = "+ validationTrial.getPrecision(labels[k]));
					}

					if (ReportOptions[ReportOption.validation][ReportOption.recall]){
						for (int k =0 ; k < labels.length; k++) 
							System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " validation data recall(" + labels[k] + ") = "+ validationTrial.getRecall(labels[k]));
					}

					if (ReportOptions[ReportOption.validation][ReportOption.f1]){
						for (int k =0 ; k < labels.length; k++) 
							System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " validation data F1(" + labels[k] + ") = "+ validationTrial.getF1(labels[k]));
					}
					if (ReportOptions[ReportOption.validation][ReportOption.accuracy]){
						System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " validation data accuracy = "+ validationAccuracy[c][trialIndex]);
					}

					//test
					if (ReportOptions[ReportOption.test][ReportOption.confusion]){
						System.out.println("Trial " + trialIndex + " Trainer " + trainer.toString() + " Test Data Confusion Matrix");
						if (ilists[1].size()>0) System.out.println (testConfusionMatrix[c][trialIndex]);
					}

					if (ReportOptions[ReportOption.test][ReportOption.precision]){
						for (int k =0 ; k < labels.length; k++) 
							System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " test data precision(" + labels[k] + ") = "+ testTrial.getPrecision(labels[k]));
					}

					if (ReportOptions[ReportOption.test][ReportOption.recall]){
						for (int k =0 ; k < labels.length; k++) 
							System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " test data recall(" + labels[k] + ") = "+ testTrial.getRecall(labels[k]));
					}

					if (ReportOptions[ReportOption.test][ReportOption.f1]){
						for (int k =0 ; k < labels.length; k++) 
							System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " test data F1(" + labels[k] + ") = "+ testTrial.getF1(labels[k]));
					}

					if (ReportOptions[ReportOption.test][ReportOption.accuracy]){
						System.out.println ("Trial " + trialIndex + " Trainer " + trainer.toString() + " test data accuracy = "+ testAccuracy[c][trialIndex]);
					}

					if (trialIndex == 0) trainerNames[c] = trainer.toString();

				}  // end for each trainer
			}  // end for each trial

			// New reporting
			//"[train|test|validation]:[accuracy|f1|confusion|raw|precision|recall]"
			for (int c=0; c < numTrainers; c++) {
				System.out.println ("\n"+trainerNames[c].toString());

				if (ReportOptions[ReportOption.train][ReportOption.accuracy])
					System.out.println ("Summary. train accuracy mean = "+ MatrixOps.mean (trainAccuracy[c])+
							" stddev = "+ MatrixOps.stddev (trainAccuracy[c])+
							" stderr = "+ MatrixOps.stderr (trainAccuracy[c]));
				if (ReportOptions[ReportOption.train][ReportOption.precision]) {
					for (int k =0 ; k < labels.length; k++) 
						System.out.println ("Summary. train precision("+labels[k]+") mean = "+ MatrixOps.mean (trainPrecision[c][k])+
								" stddev = "+ MatrixOps.stddev (trainPrecision[c][k])+
								" stderr = "+ MatrixOps.stderr (trainPrecision[c][k]));
				}
				if (ReportOptions[ReportOption.train][ReportOption.recall]) {
					for (int k =0 ; k < labels.length; k++) 
						System.out.println ("Summary. train recall("+labels[k]+") mean = "+ MatrixOps.mean (trainRecall[c][k])+
								" stddev = "+ MatrixOps.stddev (trainRecall[c][k])+
								" stderr = "+ MatrixOps.stderr (trainRecall[c][k]));
				}
				if (ReportOptions[ReportOption.train][ReportOption.f1]){
					for (int k =0 ; k < labels.length; k++) 
						System.out.println ("Summary. train f1("+labels[k]+") mean = "+ MatrixOps.mean (trainF1[c][k])+
								" stddev = "+ MatrixOps.stddev (trainF1[c][k])+
								" stderr = "+ MatrixOps.stderr (trainF1[c][k]));

				}
				if (ReportOptions[ReportOption.validation][ReportOption.accuracy])
					System.out.println ("Summary. validation accuracy mean = "+ MatrixOps.mean (validationAccuracy[c])+
							" stddev = "+ MatrixOps.stddev (validationAccuracy[c])+
							" stderr = "+ MatrixOps.stderr (validationAccuracy[c]));
				if (ReportOptions[ReportOption.validation][ReportOption.precision]){
					for (int k =0 ; k < labels.length; k++) 
						System.out.println ("Summary. validation precision("+labels[k]+") mean = "+ MatrixOps.mean (validationPrecision[c][k])+
								" stddev = "+ MatrixOps.stddev (validationPrecision[c][k])+
								" stderr = "+ MatrixOps.stderr (validationPrecision[c][k]));
				}
				if (ReportOptions[ReportOption.validation][ReportOption.recall]){
					for (int k =0 ; k < labels.length; k++) 
						System.out.println ("Summary. validation recall("+labels[k]+") mean = "+ MatrixOps.mean (validationRecall[c][k])+
								" stddev = "+ MatrixOps.stddev (validationRecall[c][k])+
								" stderr = "+ MatrixOps.stderr (validationRecall[c][k]));
				}
				if (ReportOptions[ReportOption.validation][ReportOption.f1]){
					for (int k =0 ; k < labels.length; k++) 
						System.out.println ("Summary. validation f1("+labels[k]+") mean = "+ MatrixOps.mean (validationF1[c][k])+
								" stddev = "+ MatrixOps.stddev (validationF1[c][k])+
								" stderr = "+ MatrixOps.stderr (validationF1[c][k]));
				}

				if (ReportOptions[ReportOption.test][ReportOption.accuracy])
					System.out.println ("Summary. test accuracy mean = "+ MatrixOps.mean (testAccuracy[c])+
							" stddev = "+ MatrixOps.stddev (testAccuracy[c])+
							" stderr = "+ MatrixOps.stderr (testAccuracy[c]));
				if (ReportOptions[ReportOption.test][ReportOption.precision]){
					for (int k =0 ; k < labels.length; k++) 
						System.out.println ("Summary. test precision("+labels[k]+") mean = "+ MatrixOps.mean (testPrecision[c][k])+
								" stddev = "+ MatrixOps.stddev (testPrecision[c][k])+
								" stderr = "+ MatrixOps.stderr (testPrecision[c][k]));
				}
				if (ReportOptions[ReportOption.test][ReportOption.recall]){
					for (int k =0 ; k < labels.length; k++) 
						System.out.println ("Summary. test recall("+labels[k]+") mean = "+ MatrixOps.mean (testRecall[c][k])+
								" stddev = "+ MatrixOps.stddev (testRecall[c][k])+
								" stderr = "+ MatrixOps.stderr (testRecall[c][k]));
				}
				if (ReportOptions[ReportOption.test][ReportOption.f1]){
					for (int k =0 ; k < labels.length; k++) 
						System.out.println ("Summary. test f1("+labels[k]+") mean = "+ MatrixOps.mean (testF1[c][k])+
								" stddev = "+ MatrixOps.stddev (testF1[c][k])+
								" stderr = "+ MatrixOps.stderr (testF1[c][k]));
				}
			}   // end for each trainer
		}

		private static void printTrialClassification(Trial trial)
		{
			for (Classification c : trial) {
				Instance instance = c.getInstance();
				System.out.print(instance.getName() + " " + instance.getTarget() + " ");
				Labeling labeling = c.getLabeling();
				for (int j = 0; j < labeling.numLocations(); j++){
					System.out.print(labeling.getLabelAtRank(j).toString() + ":" + labeling.getValueAtRank(j) + " ");
				}
				System.out.println();
			}
		}
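
		// printTrialClassification writes one line per instance: the instance name, its
		// true target, then label:score pairs in rank order. An illustrative line
		// (names, labels, and scores are hypothetical):
		//   doc0042 politics politics:0.91 sports:0.06 tech:0.03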

		private static Object createTrainer(String arg) {
			try {
				return interpreter.eval (arg);
			} catch (bsh.EvalError e) {
				throw new IllegalArgumentException ("Java interpreter eval error\n"+e);
			}
		}

		private static ClassifierTrainer getTrainer(String arg) {
			// parse something like Maxent,gaussianPriorVariance=10,numIterations=20

			// first, split the argument at commas.
			java.lang.String fields[] = arg.split(",");

			//Massage constructor name, so that MaxEnt, MaxEntTrainer, new MaxEntTrainer()
			// all call new MaxEntTrainer()
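			// For example (illustrative mappings, assuming MaxEnt as the classifier name):
			//   "MaxEnt"                          -> evaluates "new MaxEntTrainer()"
			//   "MaxEntTrainer"                   -> evaluates "new MaxEntTrainer()"
			//   "new MaxEntTrainer()"             -> passed to the interpreter unchanged
			//   "MaxEnt,gaussianPriorVariance=10" -> new MaxEntTrainer(), after which the
			//        reflection code below looks for a matching setter, here
			//        setGaussianPriorVariance(10)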
			java.lang.String constructorName = fields[0];
			Object trainer;
			if (constructorName.indexOf('(') != -1) // if it contains (), pass it through
				trainer = createTrainer(arg);
			else {
				if (constructorName.endsWith("Trainer")){
					trainer = createTrainer("new " + constructorName + "()"); // add parens if they forgot
				}else{
					trainer = createTrainer("new "+constructorName+"Trainer()"); // make trainer name from classifier name
				}
			}

			// find methods associated with the class we just built
			Method methods[] =  trainer.getClass().getMethods();

			// find setters corresponding to parameter names.
			for (int i = 1; i < fields.length; i++) {



