
cc.mallet.classify.Trial Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of jcore-mallet-2.0.9 Show documentation
Show all versions of jcore-mallet-2.0.9 Show documentation
MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.
The newest version!
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.classify;
import java.util.ArrayList;
import java.util.Collection;
import java.util.logging.Logger;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import cc.mallet.types.Labeling;
/**
* Stores the results of classifying a collection of Instances,
* and provides many methods for evaluating the results.
*
* If you just need one evaluation result, you may find it easier to use one
* of the corresponding methods in Classifier, which simply call the methods here.
*
* @see InstanceList
* @see Classifier
* @see Classification
*
* @author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>
*/
public class Trial extends ArrayList
{
// Class-level java.util.logging logger.
// NOTE(review): not referenced by any method visible in this chunk; confirm it is used further down.
private static Logger logger = Logger.getLogger(Trial.class.getName());
// The classifier that produced the Classifications held by this trial. The add(...) methods
// compare each incoming Classification's getClassifier() against this field by reference (!=)
// and reject mismatches. Package-private, so other classes in cc.mallet.classify can read it.
Classifier classifier;
/**
 * Builds a trial by classifying every instance of {@code ilist} with {@code c}.
 *
 * The resulting Classifications are stored in the same order the instances are
 * iterated. Each one is appended through {@link #add(Classification)}, so it is
 * subject to the same-classifier check.
 *
 * @param c     the classifier to evaluate; it becomes this trial's classifier
 * @param ilist the instances to classify; its size is used as the initial capacity
 */
public Trial (Classifier c, InstanceList ilist)
{
    super (ilist.size());
    this.classifier = c;
    for (Instance inst : ilist) {
        Classification result = c.classify (inst);
        this.add (result);
    }
}
/**
 * Appends a Classification to the end of this trial.
 *
 * @param c the classification to append; its {@code getClassifier()} must be the
 *          very same object (reference equality) as this trial's classifier
 * @return the value returned by the superclass {@code add}
 * @throws IllegalArgumentException if {@code c} was produced by a different Classifier
 */
public boolean add (Classification c)
{
    final boolean sameClassifier = c.getClassifier() == this.classifier;
    if (!sameClassifier) {
        throw new IllegalArgumentException ("Trying to add Classification from a different Classifier.");
    }
    return super.add (c);
}
/**
 * Inserts a Classification at the given position in this trial.
 *
 * The classifier check runs before the insertion. An invalid index is reported by
 * the superclass only after that check has passed.
 *
 * @param index the position at which to insert
 * @param c     the classification to insert; must come from this trial's classifier
 *              (reference equality)
 * @throws IllegalArgumentException if {@code c} was produced by a different Classifier
 */
public void add (int index, Classification c)
{
    Classifier source = c.getClassifier();
    if (source == this.classifier) {
        super.add (index, c);
        return;
    }
    throw new IllegalArgumentException ("Trying to add Classification from a different Classifier.");
}
public boolean addAll(Collection extends Classification> collection) {
boolean ret = true;
for (Classification c : collection)
if (!this.add(c))
ret = false;
return ret;
}
public boolean addAll (int index, Collection extends Classification> collection) {
throw new IllegalStateException ("Not implemented.");
}
/**
 * Returns the Classifier associated with this trial.
 *
 * The add methods reject any Classification whose {@code getClassifier()} is a
 * different object.
 */
public Classifier getClassifier ()
{
    return this.classifier;
}
/**
 * Returns the fraction of Classifications in this trial whose best predicted label is correct.
 *
 * @return the number of correct best labels divided by {@code size()}. For an empty
 *         trial this is {@code 0.0 / 0}, i.e. {@code Double.NaN}.
 */
public double getAccuracy ()
{
    int hits = 0;
    for (Classification cl : this) {
        if (cl.bestLabelIsCorrect()) {
            hits++;
        }
    }
    final int total = size();
    return (double) hits / total;
}
/**
 * Calculates this trial's precision for one target label, given as an arbitrary entry.
 *
 * If {@code labelEntry} is a {@link Labeling}, its best index is used. Otherwise the
 * entry is looked up, without being added, in the classifier's label alphabet.
 *
 * @param labelEntry a Labeling, or an entry of the classifier's label alphabet
 * @return the precision for the resolved label index
 * @throws IllegalArgumentException if the entry resolves to index -1 (not a known label)
 */
public double getPrecision (Object labelEntry)
{
    final int index = (labelEntry instanceof Labeling)
        ? ((Labeling) labelEntry).getBestIndex()
        : classifier.getLabelAlphabet().lookupIndex(labelEntry, false);
    if (index == -1) {
        throw new IllegalArgumentException ("Label "+labelEntry.toString()+" is not a valid label.");
    }
    return getPrecision (index);
}
/**
 * Calculates this trial's precision for the label at {@code label}'s best index.
 *
 * Equivalent to {@code getPrecision(label.getBestIndex())}.
 */
public double getPrecision (Labeling label)
{
    final int bestIndex = label.getBestIndex();
    return getPrecision (bestIndex);
}
/** Calculate the precision for a particular target index from an
array list of classifications */
public double getPrecision (int index)
{
int numCorrect = 0;
int numInstances = 0;
int trueLabel, classLabel;
for (int i = 0; i
© 2015 - 2025 Weber Informatics LLC | Privacy Policy