
cc.mallet.fst.PerClassAccuracyEvaluator Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of jcore-mallet-2.0.9 Show documentation
Show all versions of jcore-mallet-2.0.9 Show documentation
MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.
The newest version!
/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.fst;
import java.util.logging.Logger;
import java.text.DecimalFormat;
import cc.mallet.types.*;
import cc.mallet.util.MalletLogger;
/**
* Determines the precision, recall and F1 on a per-class basis.
*
* @author Charles Sutton
* @version $Id: PerClassAccuracyEvaluator.java,v 1.1 2007/10/22 21:37:48 mccallum Exp $
*/
public class PerClassAccuracyEvaluator extends TransducerEvaluator {
private static Logger logger = MalletLogger.getLogger(TokenAccuracyEvaluator.class.getName());
public PerClassAccuracyEvaluator (InstanceList[] instanceLists, String[] descriptions) {
super (instanceLists, descriptions);
}
public PerClassAccuracyEvaluator (InstanceList i1, String d1) {
this (new InstanceList[] {i1}, new String[] {d1});
}
public PerClassAccuracyEvaluator (InstanceList i1, String d1, InstanceList i2, String d2) {
this (new InstanceList[] {i1, i2}, new String[] {d1, d2});
}
public void evaluateInstanceList (TransducerTrainer tt, InstanceList data, String description)
{
Transducer model = tt.getTransducer();
Alphabet dict = model.getInputPipe().getTargetAlphabet();
int numLabels = dict.size();
int[] numCorrectTokens = new int [numLabels];
int[] numPredTokens = new int [numLabels];
int[] numTrueTokens = new int [numLabels];
logger.info("Per-token results for " + description);
for (int i = 0; i < data.size(); i++) {
Instance instance = data.get(i);
Sequence input = (Sequence) instance.getData();
Sequence trueOutput = (Sequence) instance.getTarget();
assert (input.size() == trueOutput.size());
Sequence predOutput = model.transduce (input);
assert (predOutput.size() == trueOutput.size());
for (int j = 0; j < trueOutput.size(); j++) {
int idx = dict.lookupIndex(trueOutput.get(j));
numTrueTokens[idx]++;
numPredTokens[dict.lookupIndex(predOutput.get(j))]++;
if (trueOutput.get(j).equals(predOutput.get(j)))
numCorrectTokens[idx]++;
}
}
DecimalFormat f = new DecimalFormat ("0.####");
double[] allf = new double [numLabels];
for (int i = 0; i < numLabels; i++) {
Object label = dict.lookupObject(i);
double precision = ((double) numCorrectTokens[i]) / numPredTokens[i];
double recall = ((double) numCorrectTokens[i]) / numTrueTokens[i];
double f1 = (2 * precision * recall) / (precision + recall);
if (!Double.isNaN (f1)) allf [i] = f1;
logger.info(description +" label " + label + " P " + f.format (precision)
+ " R " + f.format(recall) + " F1 "+ f.format (f1));
}
logger.info ("Macro-average F1 "+f.format (MatrixOps.mean (allf)));
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy