
eu.fbk.twm.utils.Evaluator Maven / Gradle / Ivy
The newest version!
/*
* Copyright 2005 ITC-irst (http://www.itc.it/)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package eu.fbk.twm.utils;
import org.apache.log4j.Logger;
import org.apache.log4j.PropertyConfigurator;
import java.io.*;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.List;
/**
 * Scores predicted labels against reference labels using confusion-matrix counts
 * (true/false positives and negatives) and derives precision, recall, F1 and accuracy.
 * <p>
 * Counts are kept in aggregate ("micro", see {@link #getTP()} and friends) and per integer
 * class index {@code 0 .. MAX_NUMBER_OF_CLASSES - 1} (see {@link #getTP(int)} and friends).
 * In every evaluation mode one label acts as the negative class: the string {@code "null"}
 * in {@link #eval3(List, List)} and the value {@code 0} in {@link #eval(List, List)} and
 * {@link #eval2(List, List)}. A positive prediction that differs from a positive reference
 * counts as both a false positive and a false negative.
 * <p>
 * Per-class counts are filled only by {@link #eval2(List, List)}, {@link #read(Reader)},
 * {@link #add(Evaluator)} and the four-count constructor. NOTE(review): string-label
 * evaluation ({@link #eval3(List, List)}, used by the file/list constructors) updates only
 * the micro counts, so {@link #write(Writer)} / {@link #toString()} print just the header
 * line for such evaluators.
 *
 * @author Claudio Giuliano
 * @version %I%, %G%
 * @since 1.0
 */
public class Evaluator {
    /**
     * Logger named after this class.
     */
    static Logger logger = Logger.getLogger(Evaluator.class.getName());

    /** Size of the per-class arrays; valid class indices are {@code 0 .. MAX_NUMBER_OF_CLASSES - 1}. */
    public static final int MAX_NUMBER_OF_CLASSES = 20;

    /** Per-class true positives. */
    private final double[] tp = new double[MAX_NUMBER_OF_CLASSES];

    /** Per-class false positives. */
    private final double[] fp = new double[MAX_NUMBER_OF_CLASSES];

    /** Per-class true negatives. */
    private final double[] tn = new double[MAX_NUMBER_OF_CLASSES];

    /** Per-class false negatives. */
    private final double[] fn = new double[MAX_NUMBER_OF_CLASSES];

    /** Aggregate (micro) true positives. */
    private double microTP = 0;

    /** Aggregate (micro) false positives. */
    private double microFP = 0;

    /** Aggregate (micro) false negatives. */
    private double microFN = 0;

    /** Aggregate (micro) true negatives. */
    private double microTN = 0;

    /** Total number of examples. */
    private int total = 0;

    /** Number of examples whose prediction equals the reference. */
    private int correct = 0;

    /** Formatter ({@code "0.000"}) for the ratios printed by {@link #write(Writer)} and {@link #toString()}. */
    DecimalFormat decFormatter;

    /**
     * Creates an evaluator with all counts at zero.
     */
    public Evaluator() {
        decFormatter = new DecimalFormat("0.000");
    }

    /**
     * Creates an evaluator for a single positive class (index 1) from raw counts.
     * The micro counts are set to the same values.
     *
     * @param tp    true positives
     * @param fp    false positives
     * @param fn    false negatives
     * @param total total number of examples; true negatives are derived as
     *              {@code total - tp - fp - fn}
     */
    public Evaluator(int tp, int fp, int fn, int total) {
        this();
        logger.debug("Evaluator");
        this.tp[1] = tp;
        this.fp[1] = fp;
        this.fn[1] = fn;
        this.total = total;
        this.tn[1] = total - tp - fp - fn;
        microTP = tp;
        microFP = fp;
        microFN = fn;
        microTN = this.tn[1];
        this.correct = (int) (tp + this.tn[1]);
    }

    /**
     * Creates an evaluator from a tab-separated report in the format written by
     * {@link #write(Writer)}.
     *
     * @param f the report file
     * @throws IOException           if the file cannot be opened or read
     * @throws NumberFormatException if a count column is not an integer
     */
    public Evaluator(File f) throws IOException {
        this();
        Reader reader = new BufferedReader(new FileReader(f));
        try {
            read(reader);
        } finally {
            // read(Reader) does not close its argument; the file is opened here, so close it here.
            reader.close();
        }
    }

    /**
     * Creates an evaluator by comparing a reference file with a prediction file.
     *
     * @param ref  path of the reference-label file
     * @param pred path of the predicted-label file
     * @throws IOException               if either file cannot be read
     * @throws IndexOutOfBoundsException if the two files have a different number of lines
     * @see #Evaluator(File, File)
     */
    public Evaluator(String ref, String pred) throws IOException, IndexOutOfBoundsException {
        this(new File(ref), new File(pred));
    }

    /**
     * Creates an evaluator by comparing a reference file with a prediction file, line by line.
     * The label on each line is the text before its first tab; labels are compared as strings
     * by {@link #eval3(List, List)}, with {@code "null"} as the negative class.
     * ({@link #readRef(File)}, {@link #eval(List, List)} and {@link #eval2(List, List)} are
     * numeric alternatives available to subclasses.)
     *
     * @param refFile  reference-label file
     * @param predFile predicted-label file
     * @throws IOException               if either file cannot be read
     * @throws IndexOutOfBoundsException if the two files have a different number of lines
     */
    public Evaluator(File refFile, File predFile) throws IOException, IndexOutOfBoundsException {
        this();
        List<String> ref = read(refFile);
        List<String> pred = read(predFile);
        checkSameSize(ref, pred);
        eval3(ref, pred);
    }

    /**
     * Creates an evaluator by comparing two parallel lists of string labels
     * (see {@link #eval3(List, List)}).
     *
     * @param ref  reference labels; elements must be {@code String}
     * @param pred predicted labels; elements must be {@code String}
     * @throws IndexOutOfBoundsException if the lists have different sizes
     * @throws ClassCastException        if an element is not a {@code String}
     */
    public Evaluator(List<?> ref, List<?> pred) throws IndexOutOfBoundsException {
        this();
        checkSameSize(ref, pred);
        eval3(ref, pred);
    }

    /**
     * Throws if the reference and prediction lists are not the same length.
     */
    private static void checkSameSize(List<?> ref, List<?> pred) {
        if (ref.size() != pred.size()) {
            throw new IndexOutOfBoundsException(ref.size() + " != " + pred.size());
        }
    }

    /** @return aggregate true negatives */
    public int getTN() {
        return (int) microTN;
    }

    /**
     * @param c class index
     * @return true negatives of class {@code c}
     */
    public int getTN(int c) {
        return (int) tn[c];
    }

    /** @return aggregate true positives */
    public int getTP() {
        return (int) microTP;
    }

    /**
     * @param c class index
     * @return true positives of class {@code c}
     */
    public int getTP(int c) {
        return (int) tp[c];
    }

    /** @return aggregate false positives */
    public int getFP() {
        return (int) microFP;
    }

    /**
     * @param c class index
     * @return false positives of class {@code c}
     */
    public int getFP(int c) {
        return (int) fp[c];
    }

    /** @return aggregate false negatives */
    public int getFN() {
        return (int) microFN;
    }

    /**
     * @param c class index
     * @return false negatives of class {@code c}
     */
    public int getFN(int c) {
        return (int) fn[c];
    }

    /**
     * @return micro precision {@code TP / (TP + FP)}, or 0 when there are no positive predictions
     */
    public double getPrecision() {
        if (microTP + microFP == 0) {
            return 0;
        }
        return microTP / (microTP + microFP);
    }

    /**
     * @param i class index
     * @return precision of class {@code i}, or 0 when it has no positive predictions
     */
    public double getPrecision(int i) {
        if (tp[i] + fp[i] == 0) {
            return 0;
        }
        return tp[i] / (tp[i] + fp[i]);
    }

    /**
     * @return micro recall {@code TP / (TP + FN)}, or 0 when there are no positive references
     */
    public double getRecall() {
        if (microTP + microFN == 0) {
            return 0;
        }
        return microTP / (microTP + microFN);
    }

    /**
     * @param i class index
     * @return recall of class {@code i}, or 0 when it has no positive references
     */
    public double getRecall(int i) {
        if (tp[i] + fn[i] == 0) {
            return 0;
        }
        return tp[i] / (tp[i] + fn[i]);
    }

    /**
     * @return harmonic mean of micro precision and recall, or 0 when both are 0
     */
    public double getF1() {
        double prec = getPrecision();
        double recall = getRecall();
        if (prec + recall == 0) {
            return 0;
        }
        return (2 * prec * recall) / (prec + recall);
    }

    /**
     * Returns {@code correct / total} and logs the ratio at INFO level.
     *
     * @return overall accuracy, or 0 when no examples have been counted
     */
    public double getAccuracy() {
        double accuracy = microAccuracy();
        logger.info("getAccuracy " + correct + "/" + total + "=" + accuracy);
        return accuracy;
    }

    /**
     * Overall accuracy without logging; 0 when {@code total} is 0 (instead of NaN).
     */
    private double microAccuracy() {
        return total == 0 ? 0 : (double) correct / total;
    }

    /**
     * @param i class index
     * @return {@code (tp + tn) / (tp + tn + fp + fn)} for class {@code i}, or 0 when all four counts are 0
     */
    public double getAccuracy(int i) {
        double denominator = tp[i] + tn[i] + fp[i] + fn[i];
        if (denominator == 0) {
            return 0;
        }
        return (tp[i] + tn[i]) / denominator;
    }

    /**
     * @param i class index
     * @return harmonic mean of precision and recall of class {@code i}, or 0 when both are 0
     */
    public double getF1(int i) {
        double prec = getPrecision(i);
        double recall = getRecall(i);
        if (prec + recall == 0) {
            return 0;
        }
        return (2 * prec * recall) / (prec + recall);
    }

    /** @return total number of examples */
    public int getTotal() {
        return total;
    }

    /** @return number of examples whose prediction equals the reference */
    public int getCorrect() {
        return correct;
    }

    /**
     * Returns the text before the first tab of {@code line}, or the whole line if it has no tab.
     * Unlike {@code line.split("\t")[0]}, this does not throw on lines made only of tabs.
     */
    private static String firstField(String line) {
        int tab = line.indexOf('\t');
        return tab < 0 ? line : line.substring(0, tab);
    }

    /**
     * Reads binary reference labels: a first field of {@code "0"} becomes {@code -1.0},
     * {@code "1"} becomes {@code 1.0}; lines with any other first field are skipped.
     *
     * @param f label file
     * @return the mapped labels, in file order
     * @throws IOException if the file cannot be read
     */
    protected List<Double> readRef(File f) throws IOException {
        List<Double> labels = new ArrayList<Double>();
        BufferedReader reader = new BufferedReader(new FileReader(f));
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                String label = firstField(line);
                if (label.equals("0")) {
                    labels.add(Double.valueOf(-1.0));
                } else if (label.equals("1")) {
                    labels.add(Double.valueOf(1.0));
                }
            }
        } finally {
            reader.close();
        }
        return labels;
    }

    /**
     * Reads one label per line: the text before the first tab.
     *
     * @param f label file
     * @return labels in file order (one entry per line, including empty lines)
     * @throws IOException if the file cannot be read
     */
    protected List<String> read(File f) throws IOException {
        List<String> labels = new ArrayList<String>();
        BufferedReader reader = new BufferedReader(new FileReader(f));
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                labels.add(firstField(line));
            }
        } finally {
            reader.close();
        }
        return labels;
    }

    /**
     * Accumulates micro counts for string labels, with {@code "null"} as the negative class.
     * Per-class counts are not updated.
     *
     * @param ref  reference labels ({@code String} elements)
     * @param pred predicted labels ({@code String} elements), same size as {@code ref}
     * @throws ClassCastException if an element is not a {@code String}
     */
    protected void eval3(List<?> ref, List<?> pred) {
        for (int i = 0; i < ref.size(); i++) {
            String target = (String) ref.get(i);
            String v = (String) pred.get(i);
            if (v.equals(target)) {
                ++correct;
            }
            if (v.equals("null")) {
                if (target.equals(v)) {
                    microTN++;
                } else {
                    microFN++;
                }
            } else if (target.equals("null")) {
                microFP++;
            } else if (target.equals(v)) {
                microTP++;
            } else {
                // Wrong positive label: a spurious prediction and a missed reference.
                microFP++;
                microFN++;
            }
            ++total;
        }
    }

    /**
     * Accumulates micro counts for numeric labels, with {@code 0} as the negative class.
     * Per-class counts are not updated.
     *
     * @param ref  reference labels ({@code Double} elements)
     * @param pred predicted labels ({@code Double} elements), same size as {@code ref}
     * @throws ClassCastException if an element is not a {@code Double}
     */
    protected void eval(List<?> ref, List<?> pred) {
        for (int i = 0; i < ref.size(); i++) {
            double target = ((Double) ref.get(i)).doubleValue();
            double v = ((Double) pred.get(i)).doubleValue();
            if (v == target) {
                ++correct;
            }
            if (v == 0) {
                if (target == v) {
                    microTN++;
                } else {
                    microFN++;
                }
            } else if (target == 0) {
                microFP++;
            } else if (target == v) {
                microTP++;
            } else {
                microFP++;
                microFN++;
            }
            ++total;
        }
    }

    /**
     * Accumulates micro and per-class counts for numeric labels {@code 0, 1, 2, ...}, with
     * {@code 0} as the negative class. Labels are used (truncated to int) as indices into the
     * per-class arrays; per-class true negatives are not updated. Logs the resulting
     * accuracy at INFO level.
     *
     * @param ref  reference labels ({@code Double} elements)
     * @param pred predicted labels ({@code Double} elements), same size as {@code ref}
     * @throws ClassCastException             if an element is not a {@code Double}
     * @throws ArrayIndexOutOfBoundsException if a label is negative or not below
     *                                        {@link #MAX_NUMBER_OF_CLASSES}
     */
    protected void eval2(List<?> ref, List<?> pred) {
        logger.info("Evaluator.eval2");
        for (int i = 0; i < ref.size(); i++) {
            double y = ((Double) ref.get(i)).doubleValue();
            double v = ((Double) pred.get(i)).doubleValue();
            if (v == y) {
                ++correct;
            }
            if (v == 0) {
                if (y == v) {
                    microTN++;
                } else {
                    microFN++;
                    fn[(int) y]++;
                }
            } else if (y == 0) {
                microFP++;
                fp[(int) v]++;
            } else if (y == v) {
                microTP++;
                tp[(int) v]++;
            } else {
                microFP++;
                microFN++;
                fp[(int) v]++;
                fn[(int) y]++;
            }
            ++total;
        }
        logger.info(correct + "/" + total + "=" + (double) correct / total);
    }

    /**
     * Adds all counts (micro, per-class, total and correct) of {@code eval} to this evaluator.
     *
     * @param eval evaluator whose counts are added
     */
    public void add(Evaluator eval) {
        microTP += eval.getTP();
        microTN += eval.getTN();
        microFP += eval.getFP();
        microFN += eval.getFN();
        total += eval.getTotal();
        correct += eval.getCorrect();
        for (int i = 0; i < MAX_NUMBER_OF_CLASSES; i++) {
            tp[i] += eval.getTP(i);
            tn[i] += eval.getTN(i);
            fp[i] += eval.getFP(i);
            fn[i] += eval.getFN(i);
        }
    }

    /**
     * Returns a single-class evaluator built from class {@code i}'s TP/FP/FN and this
     * evaluator's total (see {@link #Evaluator(int, int, int, int)}).
     *
     * @param i class index
     * @return a new evaluator for class {@code i}
     */
    public Evaluator get(int i) {
        return new Evaluator(getTP(i), getFP(i), getFN(i), total);
    }

    /**
     * Loads counts from a tab-separated report in the format written by {@link #write(Writer)}:
     * a header line, then rows {@code class tp fp fn total ...} and optionally a
     * {@code micro tp fp fn total ...} row. Blank lines are ignored. If no {@code micro} row is
     * present, the micro counts are the sums of the class rows (as in a single-class report,
     * where {@code write} omits the micro row). True negatives are derived from {@code total},
     * and {@code correct} is set to {@code microTP + microTN}. The reader is not closed.
     *
     * @param r source of the report
     * @throws IOException                    if reading fails
     * @throws NumberFormatException          if a count column is not an integer
     * @throws ArrayIndexOutOfBoundsException if a row has too few columns or a class index
     *                                        outside the per-class arrays
     */
    public void read(Reader r) throws IOException {
        BufferedReader lr = new BufferedReader(r);
        // First line is the column header.
        if (lr.readLine() == null) {
            return;
        }
        boolean microSeen = false;
        boolean classSeen = false;
        double sumTP = 0;
        double sumFP = 0;
        double sumFN = 0;
        String line;
        while ((line = lr.readLine()) != null) {
            if (line.trim().length() == 0) {
                continue;
            }
            String[] s = line.split("\t");
            if (s[0].equals("micro")) {
                microTP = Integer.parseInt(s[1]);
                microFP = Integer.parseInt(s[2]);
                microFN = Integer.parseInt(s[3]);
                total = Integer.parseInt(s[4]);
                microSeen = true;
            } else {
                int i = Integer.parseInt(s[0]);
                tp[i] = Integer.parseInt(s[1]);
                fp[i] = Integer.parseInt(s[2]);
                fn[i] = Integer.parseInt(s[3]);
                total = Integer.parseInt(s[4]);
                tn[i] = total - tp[i] - fp[i] - fn[i];
                sumTP += tp[i];
                sumFP += fp[i];
                sumFN += fn[i];
                classSeen = true;
            }
        }
        if (!microSeen && !classSeen) {
            return;
        }
        if (!microSeen) {
            microTP = sumTP;
            microFP = sumFP;
            microFN = sumFN;
        }
        microTN = total - microTP - microFP - microFN;
        // A prediction is correct exactly when it is a TP or a TN; summing per-class
        // (tp + tn) would count each example once per class.
        correct = (int) (microTP + microTN);
    }

    /**
     * Builds the tab-separated report shared by {@link #write(Writer)} and {@link #toString()}:
     * a header, one row per class index {@code >= 1} with any non-zero TP/FP/FN, and a
     * {@code micro} row only when more than one class row was emitted.
     */
    private String formatReport() {
        StringBuilder sb = new StringBuilder();
        sb.append("c\ttp\tfp\tfn\ttotal\tprec\trecall\tF1\tacc\n");
        int rows = 0;
        // Class index 0 is never reported on its own row.
        for (int i = 1; i < MAX_NUMBER_OF_CLASSES; i++) {
            if (tp[i] == 0 && fp[i] == 0 && fn[i] == 0) {
                continue;
            }
            sb.append(i + "\t" + (int) tp[i] + "\t" + (int) fp[i] + "\t" + (int) fn[i] + "\t" + total
                    + "\t" + decFormatter.format(getPrecision(i))
                    + "\t" + decFormatter.format(getRecall(i))
                    + "\t" + decFormatter.format(getF1(i)) + "\n");
            rows++;
        }
        if (rows > 1) {
            sb.append("micro\t" + getTP() + "\t" + getFP() + "\t" + getFN() + "\t" + getTotal()
                    + "\t" + decFormatter.format(getPrecision())
                    + "\t" + decFormatter.format(getRecall())
                    + "\t" + decFormatter.format(getF1())
                    + "\t" + decFormatter.format(microAccuracy()) + "\n");
        }
        return sb.toString();
    }

    /**
     * Writes the tab-separated report (see {@link #toString()}) to {@code w} and flushes it.
     * The writer is not closed.
     *
     * @param w destination
     * @throws IOException if writing fails
     */
    public void write(Writer w) throws IOException {
        w.write(formatReport());
        w.flush();
    }

    /**
     * Returns the tab-separated report: a header, one row per reported class with
     * tp/fp/fn/total/precision/recall/F1, and a micro row (with accuracy) when more
     * than one class is reported.
     */
    @Override
    public String toString() {
        return formatReport();
    }

    /**
     * Command-line entry point: compares a reference file with an answer file and prints
     * micro precision, recall and F1, first unformatted and then with two decimals.
     * The log4j configuration file is taken from the {@code log-config} system property
     * (default {@code log-config.txt}).
     *
     * @param args reference file and answer file
     * @throws Exception on I/O or evaluation errors
     */
    public static void main(String[] args) throws Exception {
        String logConfig = System.getProperty("log-config");
        if (logConfig == null) {
            logConfig = "log-config.txt";
        }
        PropertyConfigurator.configure(logConfig);
        if (args.length != 2) {
            System.err.println("java -mx512M eu.fbk.twm.utils.Evaluator reference-file answer-file");
            // Non-zero status so scripts can detect the usage error.
            System.exit(1);
        }
        File ref = new File(args[0]);
        File ans = new File(args[1]);
        Evaluator eval = new Evaluator(ref, ans);
        DecimalFormat formatter = new DecimalFormat("0.00");
        System.out.println(eval.getPrecision() + "\t" + eval.getRecall() + "\t" + eval.getF1());
        System.out.println(formatter.format(eval.getPrecision()) + "\t" + formatter.format(eval.getRecall()) + "\t" + formatter.format(eval.getF1()));
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy