All downloads are free. Search and download functionality uses the official Maven repository.

cc.mallet.classify.tui.Calo2Classify Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

There is a newer version: 2.0.12
Show newest version
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */





package cc.mallet.classify.tui;



import java.io.*;
import java.util.*;
import java.util.logging.*;
import java.lang.reflect.*;

import cc.mallet.classify.*;
import cc.mallet.classify.evaluate.*;
import cc.mallet.types.*;
import cc.mallet.util.*;
import java.util.Random;
/**
 * Classify documents, run trials, print statistics from a vector file.
   @author Andrew McCallum mccallum@cs.umass.edu
 */

public abstract class Calo2Classify
{
	// Classifier deserialized from a file via ObjectInputStream in main; when the
	// load-model option was invoked, the trial loop uses it instead of training
	// a fresh classifier for each trainer.
	private static Classifier classifierL; // CPAL - added
    private static Logger logger = MalletLogger.getLogger(Calo2Classify.class.getName());
	// Separate "-pl" logger, presumably for progress messages (name-based; confirm).
	private static Logger progressLogger = MalletProgressMessageLogger.getLogger(Calo2Classify.class.getName() + "-pl");
	// NOTE(review): raw ArrayList; presumably holds the trainers configured on the
	// command line, but its uses are not visible here. Confirm before generifying.
	private static ArrayList classifierTrainers = new ArrayList();
    // Requested reports, indexed [data set][report kind] with the ReportOption
    // constants: {train, test, validation} x {accuracy, f1, confusion, raw}.
    private static boolean[][] ReportOptions = new boolean[3][4];
    // Optional argument for each requested report, same indexing as ReportOptions
    // (for f1 it is the label whose F1 is printed).
    private static String[][] ReportOptionArgs = new String[3][4];  //arg in dataset:reportOption=arg
	/**
	 * Integer codes for the two halves of a {@code report} option token
	 * ({@code dataset:report}). Each name's position in {@link #dataOptions} or
	 * {@link #reportOptions} equals its code. The codes can therefore be used
	 * directly as the first and second index of the {@code [3][4]} report tables.
	 */
	private static class ReportOption
	{
		// Data-set names and their codes (first table index).
		static final String[] dataOptions = {"train", "test", "validation"};
		static final int train = 0, test = 1, validation = 2;

		// Report-kind names and their codes (second table index).
		static final String[] reportOptions = {"accuracy", "f1", "confusion", "raw"};
		static final int accuracy = 0, f1 = 1, confusion = 2, raw = 3;
	}
	static CommandOption.SpacedStrings report = new CommandOption.SpacedStrings
	(Calo2Classify.class, "report", "[train|test|validation]:[accuracy|f1|confusion|raw]",
	 true, new String[] {"test:accuracy", "test:confusion",  "train:accuracy"},
	 "", null)
        {
		public void postParsing (CommandOption.List list)
		{
			java.lang.String defaultRawFormatting = "siw";

			for (int argi=0; argi=3){
					reportOptionArg = fields[2];
				}
				//System.out.println("Report option arg " + reportOptionArg);

				//find the datasource (test,train,validation)
				boolean foundDataSource = false;
				int i=0;
				for (; i 1) filename = filename+trainers[c].toString();
                //if (numTrials > 1) filename = filename+".trial"+trialIndex;
                try {
                    //ObjectOutputStream oos = new ObjectOutputStream
                    //        (new FileOutputStream (filename));
                    //oos.writeObject (classifier);
                    ObjectInputStream iis = new ObjectInputStream
                        (new FileInputStream (filename));
                    classifierL = (Classifier) iis.readObject();
                    iis.close();
                } catch (Exception e) {
                    e.printStackTrace();
                    throw new IllegalArgumentException ("Couldn't read classifier from filename "+
                            filename);
                }
            }

        // CPAL


		for (int trialIndex = 0; trialIndex < numTrials; trialIndex++) {
			System.out.println("\n-------------------- Trial " + trialIndex + "  --------------------\n");
      InstanceList[] ilists;
      BitSet unlabeledIndices = null;
      if (!separateIlists){
        ilists = ilist.split (r, new double[] {t, 1-t-v, v});
      } else {
        ilists = new InstanceList[3];
        ilists[0] = trainingFileIlist;
        ilists[1] = testFileIlist;
        ilists[2] = testFileIlist;
      }
      if (unlabeledProportionOption.value > 0)
        unlabeledIndices = new cc.mallet.util.Randoms(r.nextInt())
            .nextBitSet(ilists[0].size(),
                        unlabeledProportionOption.value);


      //InfoGain ig = new InfoGain (ilists[0]);
			//int igl = Math.min (10, ig.numLocations());
			//for (int i = 0; i < igl; i++)
			//System.out.println ("InfoGain["+ig.getObjectAtRank(i)+"]="+ig.getValueAtRank(i));
			//ig.print();

			//FeatureSelection selectedFeatures = new FeatureSelection (ig, 8000);
			//ilists[0].setFeatureSelection (selectedFeatures);
			//OddsRatioFeatureInducer orfi = new OddsRatioFeatureInducer (ilists[0]);
			//orfi.induceFeatures (ilists[0], false, true);


			//System.out.println ("Training with "+ilists[0].size()+" instances");
			long time[] = new long[trainers.length];
			for (int c = 0; c < trainers.length; c++){
				time[c] = System.currentTimeMillis();
				System.out.println ("Trial " + trialIndex + " Training " + trainers[c].toString() + " with "+ilists[0].size()+" instances");
        if (unlabeledProportionOption.value > 0)
          ilists[0].hideSomeLabels(unlabeledIndices);

        Classifier classifier;
        if(loadmodelFile.wasInvoked()) {
            classifier = classifierL;
        } else {
            classifier = trainers[c].train (ilists[0]);
        }
        if (unlabeledProportionOption.value > 0)
          ilists[0].unhideAllLabels();

        System.out.println ("Trial " + trialIndex + " Training " + trainers[c].toString() + " finished");
				time[c] = System.currentTimeMillis() - time[c];
				Trial trainTrial = new Trial (classifier, ilists[0]);
				assert (ilists[1].size() > 0);
				Trial testTrial = new Trial (classifier, ilists[1]);
				Trial validationTrial = new Trial(classifier, ilists[2]);

				if (ilists[0].size()>0) trainConfusionMatrix[c][trialIndex] = new ConfusionMatrix (trainTrial).toString();
				if (ilists[1].size()>0) testConfusionMatrix[c][trialIndex] = new ConfusionMatrix (testTrial).toString();
				if (ilists[2].size()>0) validationConfusionMatrix[c][trialIndex] = new ConfusionMatrix (validationTrial).toString();

				trainAccuracy[c][trialIndex] = trainTrial.getAccuracy();
				testAccuracy[c][trialIndex] = testTrial.getAccuracy();
				validationAccuracy[c][trialIndex] = validationTrial.getAccuracy();

				if (outputFile.wasInvoked()) {
					String filename = outputFile.value;
					if (trainers.length > 1) filename = filename+trainers[c].toString();
					if (numTrials > 1) filename = filename+".trial"+trialIndex;
					try {
						ObjectOutputStream oos = new ObjectOutputStream
																		 (new FileOutputStream (filename));
						oos.writeObject (classifier);
						oos.close();
					} catch (Exception e) {
						e.printStackTrace();
						throw new IllegalArgumentException ("Couldn't write classifier to filename "+
																								filename);
					}
				}

				// New Reporting

                // raw output
				if (ReportOptions[ReportOption.train][ReportOption.raw]){
					System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString());
					System.out.println(" Raw Training Data");
					printTrialClassification(trainTrial);
				}

				if (ReportOptions[ReportOption.test][ReportOption.raw]){
					System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString());
					System.out.println(" Raw Testing Data");
					printTrialClassification(testTrial);
				}

				if (ReportOptions[ReportOption.validation][ReportOption.raw]){
					System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString());
					System.out.println(" Raw Validation Data");
					printTrialClassification(validationTrial);
				}


				//train
				if (ReportOptions[ReportOption.train][ReportOption.confusion]){
					System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString() +  " Training Data Confusion Matrix");
					if (ilists[0].size()>0) System.out.println (trainConfusionMatrix[c][trialIndex]);
				}

				if (ReportOptions[ReportOption.train][ReportOption.accuracy]){
					System.out.println ("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " training data accuracy= "+ trainAccuracy[c][trialIndex]);
				}

				if (ReportOptions[ReportOption.train][ReportOption.f1]){
					String label = ReportOptionArgs[ReportOption.train][ReportOption.f1];
					System.out.println ("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " training data F1(" + label + ") = "+ trainTrial.getF1(label));
				}

				//validation
				if (ReportOptions[ReportOption.validation][ReportOption.confusion]){
					System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString() +  " Validation Data Confusion Matrix");
					if (ilists[2].size()>0) System.out.println (validationConfusionMatrix[c][trialIndex]);
				}

				if (ReportOptions[ReportOption.validation][ReportOption.accuracy]){
					System.out.println ("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " validation data accuracy= "+ validationAccuracy[c][trialIndex]);
				}

				if (ReportOptions[ReportOption.validation][ReportOption.f1]){
					String label = ReportOptionArgs[ReportOption.validation][ReportOption.f1];
					System.out.println ("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " validation data F1(" + label + ") = "+ validationTrial.getF1(label));
				}

				//test
				if (ReportOptions[ReportOption.test][ReportOption.confusion]){
					System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " Test Data Confusion Matrix");
					if (ilists[1].size()>0) System.out.println (testConfusionMatrix[c][trialIndex]);
				}

				if (ReportOptions[ReportOption.test][ReportOption.accuracy]){
					System.out.println ("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " test data accuracy= "+ testAccuracy[c][trialIndex]);
				}

				if (ReportOptions[ReportOption.test][ReportOption.f1]){
					String label = ReportOptionArgs[ReportOption.test][ReportOption.f1];
					System.out.println ("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " test data F1(" + label + ") = "+ testTrial.getF1(label));
				}


			}  // end for each trainer
		}  // end for each trial

        // New reporting
		//"[train|test|validation]:[accuracy|f1|confusion|raw]"
		for (int c=0; c < trainers.length; c++) {
			System.out.println ("\n"+trainers[c].toString());
			if (ReportOptions[ReportOption.train][ReportOption.accuracy])
				System.out.println ("Summary. train accuracy mean = "+ MatrixOps.mean (trainAccuracy[c])+
														" stddev = "+ MatrixOps.stddev (trainAccuracy[c])+
														" stderr = "+ MatrixOps.stderr (trainAccuracy[c]));

			if (ReportOptions[ReportOption.validation][ReportOption.accuracy])
				System.out.println ("Summary. validation accuracy mean = "+ MatrixOps.mean (validationAccuracy[c])+
														" stddev = "+ MatrixOps.stddev (validationAccuracy[c])+
														" stderr = "+ MatrixOps.stderr (validationAccuracy[c]));

			if (ReportOptions[ReportOption.test][ReportOption.accuracy])
				System.out.println ("Summary. test accuracy mean = "+ MatrixOps.mean (testAccuracy[c])+
														" stddev = "+ MatrixOps.stddev (testAccuracy[c])+
														" stderr = "+ MatrixOps.stderr (testAccuracy[c]));

		}   // end for each trainer
	}

	/**
	 * Dumps the per-instance results of a trial to standard output, one line per
	 * classified instance. Each line holds the instance name, its target, and then
	 * every label of the classifier's labeling in rank order as {@code label:score}.
	 * Every token, including the last, is followed by a single space.
	 *
	 * @param trial the classification results to print
	 */
	private static void printTrialClassification(Trial trial)
	{
		int row = 0;
		while (row < trial.size()) {
			Instance inst = trial.get(row).getInstance();
			StringBuilder line = new StringBuilder();
			line.append(inst.getName()).append(' ')
			    .append(inst.getTarget()).append(' ');

			Labeling ranking = trial.get(row).getLabeling();
			int ranks = ranking.numLocations();
			for (int rank = 0; rank < ranks; rank++) {
				// Explicit toString() keeps the original's failure on a null label.
				line.append(ranking.getLabelAtRank(rank).toString())
				    .append(':')
				    .append(ranking.getValueAtRank(rank))
				    .append(' ');
			}

			System.out.println(line);
			row++;
		}
	}


}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy