All Downloads are FREE. Search and download functionalities are using the official Maven repository.

cc.mallet.classify.Trial Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

There is a newer version: 2.0.12
Show newest version
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */




package cc.mallet.classify;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.logging.Logger;

import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import cc.mallet.types.Labeling;

/**
 * Stores the results of classifying a collection of Instances,
 * and provides many methods for evaluating the results.
 *
 * If you just need one evaluation result, you may find it easier to use one
 * of the corresponding methods in Classifier, which simply call the methods here.
 * 
 * @see InstanceList
 * @see Classifier
 * @see Classification
 *
 * @author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>
 */
public class Trial extends ArrayList
{
	// Logger named after this class.
	private static Logger logger = Logger.getLogger(Trial.class.getName());

	/** The Classifier this Trial evaluates. The add methods reject any
	 *  Classification whose getClassifier() is not this same instance
	 *  (compared with ==). Set once by the constructor. */
	Classifier classifier;

	/**
	 * Classifies every instance in {@code ilist} with {@code c} and stores the
	 * resulting Classifications in the order the list yields its instances.
	 *
	 * @param c     the classifier being evaluated
	 * @param ilist the instances to classify
	 */
	public Trial (Classifier c, InstanceList ilist)
	{
		super (ilist.size());  // exactly one Classification per instance
		// Must be assigned before the first add(), which compares against it.
		this.classifier = c;
		for (Instance inst : ilist) {
			Classification result = c.classify (inst);
			this.add (result);
		}
	}
	
	/**
	 * Appends {@code c} after checking that it was produced by this Trial's classifier.
	 *
	 * @param c the Classification to append
	 * @return the result of ArrayList.add (always true)
	 * @throws IllegalArgumentException if {@code c} came from a different Classifier
	 */
	public boolean add (Classification c)
	{
		checkSameClassifier (c);
		return super.add (c);
	}

	/**
	 * Inserts {@code c} at {@code index} after checking that it was produced by
	 * this Trial's classifier.
	 *
	 * @param index position at which to insert
	 * @param c     the Classification to insert
	 * @throws IllegalArgumentException  if {@code c} came from a different Classifier
	 * @throws IndexOutOfBoundsException if {@code index} is out of range (raised by ArrayList)
	 */
	public void add (int index, Classification c) 
	{
		checkSameClassifier (c);
		super.add (index, c);
	}

	/**
	 * Appends every Classification in {@code collection}, in iteration order.
	 *
	 * All elements are validated before any is added, so a collection containing
	 * a Classification from another Classifier leaves this Trial unchanged.
	 *
	 * @param collection Classifications produced by this Trial's classifier
	 * @return true if this Trial changed (i.e. {@code collection} was non-empty),
	 *         per the {@code Collection.addAll} contract
	 * @throws IllegalArgumentException if any element came from a different Classifier
	 * @throws ClassCastException       if any element is not a Classification
	 */
	@Override
	public boolean addAll(Collection collection) {
		// Snapshot once: validating and adding from the same array makes the call
		// all-or-nothing and avoids a ConcurrentModificationException when
		// collection == this. Raw element type so this compiles whether or not
		// the superclass is parameterized.
		Object[] items = collection.toArray();
		for (Object item : items)
			checkSameClassifier ((Classification) item);
		boolean changed = false;
		for (Object item : items)
			if (this.add ((Classification) item))
				changed = true;
		return changed;
	}

	/**
	 * Inserts every Classification in {@code collection} at {@code index},
	 * preserving the collection's iteration order.
	 *
	 * All elements are validated before anything is inserted.
	 *
	 * @param index      position at which to insert the first element
	 * @param collection Classifications produced by this Trial's classifier
	 * @return true if this Trial changed (i.e. {@code collection} was non-empty)
	 * @throws IllegalArgumentException  if any element came from a different Classifier
	 * @throws ClassCastException        if any element is not a Classification
	 * @throws IndexOutOfBoundsException if {@code index} is out of range (raised by ArrayList)
	 */
	@Override
	public boolean addAll (int index, Collection collection) {
		for (Object item : collection.toArray())
			checkSameClassifier ((Classification) item);
		// Every element is valid; ArrayList performs the bulk insert and index check.
		return super.addAll (index, collection);
	}

	/** Throws unless {@code c} was produced by this Trial's own classifier instance. */
	private void checkSameClassifier (Classification c)
	{
		// Reference comparison: an equal-but-distinct Classifier instance is rejected.
		if (c.getClassifier() != this.classifier)
			throw new IllegalArgumentException ("Trying to add Classification from a different Classifier.");
	}

	
	/**
	 * Returns the Classifier whose Classifications this Trial holds.
	 *
	 * @return the classifier this Trial was constructed with
	 */
	public Classifier getClassifier ()
	{
		return this.classifier;
	}

	/**
	 * Returns the fraction of Classifications in this Trial whose best predicted
	 * label is the correct one (those for which bestLabelIsCorrect() is true).
	 *
	 * @return accuracy in [0, 1]; NaN when this Trial is empty (0/0)
	 */
	public double getAccuracy ()
	{
		int numCorrect = 0;
		// Explicit cast so this compiles whether the list is raw or
		// parameterized with Classification.
		for (Object entry : this)
			if (((Classification) entry).bestLabelIsCorrect())
				numCorrect++;
		return (double) numCorrect / this.size();
	}

	
	/**
	 * Returns the precision of this Trial's classifier for one target label.
	 *
	 * @param labelEntry either a Labeling (its best index is used) or an entry
	 *        to look up, without insertion, in the classifier's label alphabet
	 * @return the result of {@code getPrecision(int)} for the resolved index
	 * @throws IllegalArgumentException if the resolved index is -1
	 *         (e.g. the entry is not in the label alphabet)
	 */
	public double getPrecision (Object labelEntry)
	{
		final int index = (labelEntry instanceof Labeling)
				? ((Labeling) labelEntry).getBestIndex()
				: classifier.getLabelAlphabet().lookupIndex(labelEntry, false);
		if (index != -1)
			return getPrecision (index);
		throw new IllegalArgumentException ("Label "+labelEntry.toString()+" is not a valid label.");
	}
	
	/**
	 * Returns the precision for the label that {@code label} ranks highest.
	 *
	 * @param label a Labeling whose best index selects the target label
	 * @return the result of {@code getPrecision(int)} for that index
	 */
	public double getPrecision (Labeling label)
	{
		final int bestIndex = label.getBestIndex();
		return getPrecision (bestIndex);
	}

	/** Calculate the precision for a particular target index from an 
	    array list of classifications */
	public double getPrecision (int index)
	{
		int numCorrect = 0;
		int numInstances = 0;
		int trueLabel, classLabel;
		for (int i = 0; i




© 2015 - 2025 Weber Informatics LLC | Privacy Policy