
cc.mallet.classify.Trial Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mallet Show documentation
Show all versions of mallet Show documentation
MALLET is a Java-based package for statistical natural language processing,
document classification, clustering, topic modeling, information extraction,
and other machine learning applications to text.
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.classify;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.logging.Logger;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import cc.mallet.types.Labeling;
/**
* Stores the results of classifying a collection of Instances,
* and provides many methods for evaluating the results.
*
* If you just need one evaluation result, you may find it easier to use one
* of the corresponding methods in Classifier, which simply call the methods here.
*
* @see InstanceList
* @see Classifier
* @see Classification
*
* @author Andrew McCallum [email protected]
*/
public class Trial extends ArrayList
{
/** Logger for this class. */
private static Logger logger = Logger.getLogger(Trial.class.getName());
/**
 * The classifier whose Classifications this trial holds. The add methods
 * compare each incoming Classification's classifier against this field by
 * reference and reject any mismatch.
 */
Classifier classifier;
/**
 * Builds a trial by running {@code c} over every instance of {@code ilist}
 * and recording each resulting Classification, in iteration order.
 *
 * @param c     the classifier to evaluate; it also becomes this trial's owner,
 *              so only its Classifications may be added later
 * @param ilist the instances to classify
 */
public Trial (Classifier c, InstanceList ilist)
{
    // Presize the backing list: exactly one Classification per instance.
    super (ilist.size());
    classifier = c;
    for (Instance inst : ilist) {
        Classification result = c.classify (inst);
        add (result);
    }
}
/**
 * Appends a Classification to the end of this trial.
 *
 * @param c the Classification to append; it must have been produced by this
 *          trial's classifier (compared by reference)
 * @return the result of {@link ArrayList#add(Object)}
 * @throws IllegalArgumentException if {@code c} came from a different classifier
 */
public boolean add (Classification c)
{
    Classifier source = c.getClassifier();
    if (source == this.classifier) {
        return super.add (c);
    }
    throw new IllegalArgumentException ("Trying to add Classification from a different Classifier.");
}
/**
 * Inserts a Classification at the given position of this trial.
 *
 * @param index position at which to insert, as for {@link ArrayList#add(int, Object)}
 * @param c     the Classification to insert; it must have been produced by this
 *              trial's classifier (compared by reference)
 * @throws IllegalArgumentException if {@code c} came from a different classifier
 */
public void add (int index, Classification c)
{
    boolean sameSource = c.getClassifier() == classifier;
    if (!sameSource) {
        throw new IllegalArgumentException ("Trying to add Classification from a different Classifier.");
    }
    super.add (index, c);
}
public boolean addAll(Collection extends Classification> collection) {
boolean ret = true;
for (Classification c : collection)
if (!this.add(c))
ret = false;
return ret;
}
public boolean addAll (int index, Collection extends Classification> collection) {
throw new IllegalStateException ("Not implemented.");
}
/**
 * Returns the classifier whose Classifications this trial holds.
 *
 * @return the classifier supplied at construction
 */
public Classifier getClassifier () {
    return this.classifier;
}
/**
 * Returns the fraction of Classifications in this trial whose best
 * predicted label is the correct one.
 *
 * @return number correct divided by {@code size()}; this is NaN for an empty
 *         trial, since the division is {@code 0.0 / 0}
 */
public double getAccuracy ()
{
    int correct = 0;
    for (Classification result : this) {
        if (result.bestLabelIsCorrect()) {
            correct++;
        }
    }
    return (double) correct / this.size();
}
/**
 * Calculates the classifier's precision for one target label over this trial.
 *
 * @param labelEntry either a {@link Labeling}, whose best index is used, or an
 *                   entry looked up in the classifier's label alphabet
 *                   (via {@code lookupIndex(entry, false)})
 * @return the precision computed by {@code getPrecision(int)} for the resolved index
 * @throws IllegalArgumentException if the resolved index is -1 (label not found)
 */
public double getPrecision (Object labelEntry)
{
    int index = (labelEntry instanceof Labeling)
        ? ((Labeling) labelEntry).getBestIndex()
        : classifier.getLabelAlphabet().lookupIndex(labelEntry, false);
    if (index != -1) {
        return getPrecision (index);
    }
    throw new IllegalArgumentException ("Label "+labelEntry.toString()+" is not a valid label.");
}
/**
 * Calculates the precision for the label at {@code label}'s best index.
 *
 * @param label the labeling whose best index selects the target label
 * @return the precision computed by {@code getPrecision(int)} for that index
 */
public double getPrecision (Labeling label) {
    int bestIndex = label.getBestIndex();
    return getPrecision (bestIndex);
}
/** Calculate the precision for a particular target index from an
array list of classifications */
public double getPrecision (int index)
{
int numCorrect = 0;
int numInstances = 0;
int trueLabel, classLabel;
for (int i = 0; i
© 2015 - 2025 Weber Informatics LLC | Privacy Policy