All Downloads are FREE. Search and download functionalities are using the official Maven repository.

eu.fbk.twm.utils.Evaluator Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2005 ITC-irst (http://www.itc.it/)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package eu.fbk.twm.utils;

import org.apache.log4j.Logger;
import org.apache.log4j.PropertyConfigurator;

import java.io.*;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.List;

/**
 * Computes classification metrics (true/false positives and negatives,
 * precision, recall, F1 and accuracy) by comparing reference labels with
 * predicted labels, position by position.
 * <p>
 * Micro-averaged counters are filled by {@link #eval(List, List)},
 * {@link #eval2(List, List)} and {@link #eval3(List, List)}. Per-class
 * counters (class indices {@code 0 .. MAX_NUMBER_OF_CLASSES - 1}) are filled
 * only by {@link #eval2(List, List)}, {@link #read(Reader)},
 * {@link #add(Evaluator)} and the {@code (tp, fp, fn, total)} constructor,
 * which stores its counts as class 1.
 * The negative label is the string {@code "null"} in
 * {@link #eval3(List, List)} and the value {@code 0} in the numeric
 * evaluations.
 * <p>
 * Instances are mutable and perform no synchronization.
 *
 * @author Claudio Giuliano
 * @version %I%, %G%
 * @since 1.0
 */
public class Evaluator {
	/**
	 * Logger named after this class.
	 */
	static Logger logger = Logger.getLogger(Evaluator.class.getName());

	/** Length of the per-class counter arrays; valid class indices are {@code 0 .. MAX_NUMBER_OF_CLASSES - 1}. */
	public static final int MAX_NUMBER_OF_CLASSES = 20;

	/** Per-class true positives. */
	private double[] tp = new double[MAX_NUMBER_OF_CLASSES];

	/** Per-class false positives. */
	private double[] fp = new double[MAX_NUMBER_OF_CLASSES];

	/** Per-class true negatives. */
	private double[] tn = new double[MAX_NUMBER_OF_CLASSES];

	/** Per-class false negatives. */
	private double[] fn = new double[MAX_NUMBER_OF_CLASSES];

	/** Micro-averaged true positives. */
	private double microTP = 0;

	/** Micro-averaged false positives. */
	private double microFP = 0;

	/** Micro-averaged false negatives. */
	private double microFN = 0;

	/** Micro-averaged true negatives. */
	private double microTN = 0;

	/** Number of evaluated examples. */
	private int total = 0;

	/** Number of examples whose prediction equals the reference label. */
	private int correct = 0;

	/** Formats the metric columns of {@link #write(Writer)} and {@link #toString()}. */
	DecimalFormat decFormatter;

	/**
	 * Creates an empty evaluator with all counters set to zero.
	 */
	public Evaluator() {
		decFormatter = new DecimalFormat("0.000");
	} // end constructor

	/**
	 * Creates an evaluator from precomputed counts for a single positive class.
	 * The counts are stored both as class 1 and as the micro-averaged counters;
	 * true negatives are derived as {@code total - tp - fp - fn}.
	 *
	 * @param tp    true positives
	 * @param fp    false positives
	 * @param fn    false negatives
	 * @param total total number of examples
	 */
	public Evaluator(int tp, int fp, int fn, int total) {
		this();
		logger.debug("Evaluator");

		this.tp[1] = tp;
		this.fp[1] = fp;
		this.fn[1] = fn;
		this.total = total;
		this.tn[1] = total - tp - fp - fn;
		microTP = tp;
		microFP = fp;
		microFN = fn;
		microTN = tn[1];
		this.correct = (int) (tp + tn[1]);
	} // end constructor

	/**
	 * Creates an evaluator by loading a table previously produced by
	 * {@link #write(Writer)}; see {@link #read(Reader)} for the format.
	 *
	 * @param f the file to load
	 * @throws IOException if the file cannot be read
	 */
	public Evaluator(File f) throws IOException {
		this();
		Reader reader = new BufferedReader(new FileReader(f));
		try {
			read(reader);
		} finally {
			reader.close();
		}
	} // end constructor

	/**
	 * Creates an evaluator by comparing a reference file with a prediction file.
	 *
	 * @param ref  path of the reference file
	 * @param pred path of the prediction file
	 * @throws IOException               if either file cannot be read
	 * @throws IndexOutOfBoundsException if the files have a different number of lines
	 * @see #Evaluator(File, File)
	 */
	public Evaluator(String ref, String pred) throws IOException, IndexOutOfBoundsException {
		this(new File(ref), new File(pred));
	} // end constructor

	/**
	 * Creates an evaluator by comparing a reference file with a prediction file.
	 * The first tab-separated column of each line is taken as the label and
	 * labels are compared as strings by {@link #eval3(List, List)}.
	 *
	 * @param refFile  the reference file
	 * @param predFile the prediction file
	 * @throws IOException               if either file cannot be read
	 * @throws IndexOutOfBoundsException if the files have a different number of lines
	 */
	public Evaluator(File refFile, File predFile) throws IOException, IndexOutOfBoundsException {
		this();

		// read() keeps labels as strings, as eval3 expects; readRef() instead
		// yields Double labels for the numeric eval().
		List<String> ref = read(refFile);
		List<String> pred = read(predFile);

		checkSameSize(ref, pred);
		eval3(ref, pred);
	} // end constructor

	/**
	 * Creates an evaluator by comparing two lists of string labels with
	 * {@link #eval3(List, List)}.
	 *
	 * @param ref  reference labels (elements must be {@link String}s)
	 * @param pred predicted labels (elements must be {@link String}s)
	 * @throws IndexOutOfBoundsException if the lists have different sizes
	 */
	public Evaluator(List<?> ref, List<?> pred) throws IndexOutOfBoundsException {
		this();
		checkSameSize(ref, pred);
		eval3(ref, pred);
	} // end constructor

	/**
	 * Throws {@link IndexOutOfBoundsException} when the two label lists differ in length.
	 */
	private static void checkSameSize(List<?> ref, List<?> pred) {
		if (ref.size() != pred.size()) {
			throw new IndexOutOfBoundsException(ref.size() + " != " + pred.size());
		}
	}

	/** @return the micro-averaged true negatives */
	public int getTN() {
		return (int) microTN;
	}

	/** @return the true negatives of class {@code c} */
	public int getTN(int c) {
		return (int) tn[c];
	}

	/** @return the micro-averaged true positives */
	public int getTP() {
		return (int) microTP;
	}

	/** @return the true positives of class {@code c} */
	public int getTP(int c) {
		return (int) tp[c];
	}

	/** @return the micro-averaged false positives */
	public int getFP() {
		return (int) microFP;
	}

	/** @return the false positives of class {@code c} */
	public int getFP(int c) {
		return (int) fp[c];
	}

	/** @return the micro-averaged false negatives */
	public int getFN() {
		return (int) microFN;
	}

	/** @return the false negatives of class {@code c} */
	public int getFN(int c) {
		return (int) fn[c];
	}

	/** @return micro-averaged precision {@code TP / (TP + FP)}, or 0 when undefined */
	public double getPrecision() {
		return safeDivide(microTP, microTP + microFP);
	}

	/** @return precision of class {@code i}, or 0 when undefined */
	public double getPrecision(int i) {
		return safeDivide(tp[i], tp[i] + fp[i]);
	}

	/** @return micro-averaged recall {@code TP / (TP + FN)}, or 0 when undefined */
	public double getRecall() {
		return safeDivide(microTP, microTP + microFN);
	}

	/** @return recall of class {@code i}, or 0 when undefined */
	public double getRecall(int i) {
		return safeDivide(tp[i], tp[i] + fn[i]);
	}

	/** @return micro-averaged F1, or 0 when precision and recall are both 0 */
	public double getF1() {
		return harmonicMean(getPrecision(), getRecall());
	} // end getF1()

	/**
	 * Returns the fraction of correctly labelled examples and logs it at INFO level.
	 *
	 * @return {@code correct / total}, or 0 when no example was evaluated
	 */
	public double getAccuracy() {
		double accuracy = accuracy();
		logger.info("getAccuracy " + correct + "/" + total + "=" + accuracy);
		return accuracy;
	} // end getAccuracy

	/**
	 * @return {@code (tp + tn) / (tp + tn + fp + fn)} for class {@code i}, or 0 when all counts are 0
	 */
	public double getAccuracy(int i) {
		return safeDivide(tp[i] + tn[i], tp[i] + tn[i] + fp[i] + fn[i]);
	} // end getAccuracy

	/** @return F1 of class {@code i}, or 0 when its precision and recall are both 0 */
	public double getF1(int i) {
		return harmonicMean(getPrecision(i), getRecall(i));
	} // end getF1

	/** @return the number of evaluated examples */
	public int getTotal() {
		return total;
	}

	/** @return the number of examples whose prediction equals the reference */
	public int getCorrect() {
		return correct;
	}

	/** Accuracy without logging; 0 when {@code total} is 0. */
	private double accuracy() {
		return safeDivide(correct, total);
	}

	/** Returns {@code num / den}, or 0 when {@code den} is 0 (matches the zero-guard used by all metrics). */
	private static double safeDivide(double num, double den) {
		if (den == 0) {
			return 0;
		}
		return num / den;
	}

	/** Returns the F1 combination of precision and recall, or 0 when both are 0. */
	private static double harmonicMean(double prec, double recall) {
		if (prec + recall == 0) {
			return 0;
		}
		return (2 * prec * recall) / (prec + recall);
	}

	/** Returns the text before the first tab of {@code line} (the whole line when it has no tab). */
	private static String firstColumn(String line) {
		int tab = line.indexOf('\t');
		return tab < 0 ? line : line.substring(0, tab);
	}

	/**
	 * Reads reference labels for the case y = {-1, +1}: a first column equal to
	 * {@code "0"} becomes -1.0, {@code "1"} becomes 1.0, and any other line is skipped.
	 *
	 * @param f the file to read
	 * @return the labels as {@link Double}s
	 * @throws IOException if the file cannot be read
	 */
	protected List<Double> readRef(File f) throws IOException {
		List<Double> list = new ArrayList<Double>();
		BufferedReader reader = new BufferedReader(new FileReader(f));
		try {
			String line;
			while ((line = reader.readLine()) != null) {
				String label = firstColumn(line);
				if (label.equals("0")) {
					list.add(Double.valueOf(-1.0));
				}
				else if (label.equals("1")) {
					list.add(Double.valueOf(1.0));
				}
			}
		} finally {
			reader.close();
		}
		return list;
	} // end readRef

	/**
	 * Reads one label per line: the text before the first tab. Every line,
	 * including blank ones, yields an entry so positions stay aligned.
	 *
	 * @param f the file to read
	 * @return the labels as {@link String}s
	 * @throws IOException if the file cannot be read
	 */
	protected List<String> read(File f) throws IOException {
		List<String> list = new ArrayList<String>();
		BufferedReader reader = new BufferedReader(new FileReader(f));
		try {
			String line;
			while ((line = reader.readLine()) != null) {
				list.add(firstColumn(line));
			}
		} finally {
			reader.close();
		}
		return list;
	} // end read

	/**
	 * Compares string labels; {@code "null"} is the negative label. Only the
	 * micro-averaged counters, {@code correct} and {@code total} are updated.
	 * A wrong non-null prediction for a non-null reference counts as both a
	 * false positive and a false negative.
	 *
	 * @param ref  reference labels (elements must be {@link String}s)
	 * @param pred predicted labels (elements must be {@link String}s), at least as long as {@code ref}
	 */
	protected void eval3(List<?> ref, List<?> pred) {
		for (int i = 0; i < ref.size(); i++) {
			String target = (String) ref.get(i);
			String v = (String) pred.get(i);

			if (v.equals(target)) {
				++correct;
			}

			if (v.equals("null")) {
				if (target.equals(v)) {
					microTN++;
				}
				else {
					microFN++;
				}
			}
			else {
				if (target.equals("null")) {
					microFP++;
				}
				else if (target.equals(v)) {
					microTP++;
				}
				else {
					microFP++;
					microFN++;
				}
			}

			++total;
		} // end for
	} // end eval3

	/**
	 * Compares numeric labels (y = {0, 1, 2, ...}); 0 is the negative label.
	 * Only the micro-averaged counters, {@code correct} and {@code total} are updated.
	 *
	 * @param ref  reference labels (elements must be {@link Double}s)
	 * @param pred predicted labels (elements must be {@link Double}s), at least as long as {@code ref}
	 */
	protected void eval(List<?> ref, List<?> pred) {
		for (int i = 0; i < ref.size(); i++) {
			double target = ((Double) ref.get(i)).doubleValue();
			double v = ((Double) pred.get(i)).doubleValue();

			if (v == target) {
				++correct;
			}

			if (v == 0) {
				if (target == v) {
					microTN++;
				}
				else {
					microFN++;
				}
			}
			else {
				if (target == 0) {
					microFP++;
				}
				else if (target == v) {
					microTP++;
				}
				else {
					microFP++;
					microFN++;
				}
			}

			++total;
		} // end for
	} // end eval

	/**
	 * Compares numeric labels (y = {0, 1, 2, ...}) updating both the
	 * micro-averaged and the per-class counters; 0 is the negative label.
	 * Labels are cast to {@code int} to index the per-class arrays, so a
	 * label outside {@code [0, MAX_NUMBER_OF_CLASSES)} raises
	 * {@link ArrayIndexOutOfBoundsException} when its counter is updated.
	 *
	 * @param ref  reference labels (elements must be {@link Double}s)
	 * @param pred predicted labels (elements must be {@link Double}s), at least as long as {@code ref}
	 */
	protected void eval2(List<?> ref, List<?> pred) {
		logger.info("Evaluator.eval2");
		for (int i = 0; i < ref.size(); i++) {
			double y = ((Double) ref.get(i)).doubleValue();
			double v = ((Double) pred.get(i)).doubleValue();

			if (v == y) {
				++correct;
			}

			if (v == 0) {
				if (y == v) {
					microTN++;
				}
				else {
					microFN++;
					fn[(int) y]++;
				}
			}
			else {
				if (y == 0) {
					microFP++;
					fp[(int) v]++;
				}
				else if (y == v) {
					microTP++;
					tp[(int) v]++;
				}
				else {
					microFP++;
					microFN++;
					fp[(int) v]++;
					fn[(int) y]++;
				}
			}

			++total;
		} // end for

		logger.info(correct + "/" + total + "=" + accuracy());
	} // end eval2

	/**
	 * Adds all counters (micro-averaged, per-class, {@code total} and
	 * {@code correct}) of another evaluator to this one.
	 *
	 * @param eval the evaluator whose counts are added
	 */
	public void add(Evaluator eval) {
		microTP += eval.getTP();
		microTN += eval.getTN();
		microFP += eval.getFP();
		microFN += eval.getFN();

		total += eval.getTotal();
		correct += eval.getCorrect();

		for (int i = 0; i < MAX_NUMBER_OF_CLASSES; i++) {
			tp[i] += eval.getTP(i);
			tn[i] += eval.getTN(i);
			fp[i] += eval.getFP(i);
			fn[i] += eval.getFN(i);
		} // end for i
	} // end add

	/**
	 * Returns a single-class evaluator built from the counts of class {@code i}
	 * and this evaluator's {@code total}.
	 *
	 * @param i the class index
	 * @return a new evaluator holding class {@code i}'s counts
	 */
	public Evaluator get(int i) {
		return new Evaluator(getTP(i), getFP(i), getFN(i), total);
	} // end get

	/**
	 * Loads counters from a table in the format produced by {@link #write(Writer)}:
	 * a header line, then tab-separated rows whose first four columns after
	 * the label are tp, fp, fn and total. The label is either a class index
	 * or {@code "micro"}; blank lines are ignored. True negatives are derived
	 * as {@code total - tp - fp - fn}, like the {@code (tp, fp, fn, total)}
	 * constructor, and {@code correct} is set to micro TP + TN.
	 * <p>
	 * When the table has no micro row ({@link #write(Writer)} omits it when
	 * exactly one class is listed), the micro counters are rebuilt as the
	 * sum of the class rows. The reader is not closed.
	 *
	 * @param r the source of the table
	 * @throws IOException if reading fails
	 */
	public void read(Reader r) throws IOException {
		LineNumberReader lr = new LineNumberReader(r);

		// first line is the column header
		if (lr.readLine() == null) {
			return;
		}

		boolean microRowSeen = false;
		int classRows = 0;
		double sumTP = 0;
		double sumFP = 0;
		double sumFN = 0;

		String line;
		while ((line = lr.readLine()) != null) {
			if (line.trim().length() == 0) {
				continue;
			}
			String[] s = line.split("\t");

			if (s[0].equals("micro")) {
				microTP = Integer.parseInt(s[1]);
				microFP = Integer.parseInt(s[2]);
				microFN = Integer.parseInt(s[3]);
				total = Integer.parseInt(s[4]);
				microTN = total - microTP - microFP - microFN;
				microRowSeen = true;
			}
			else {
				int i = Integer.parseInt(s[0]);
				tp[i] = Integer.parseInt(s[1]);
				fp[i] = Integer.parseInt(s[2]);
				fn[i] = Integer.parseInt(s[3]);
				total = Integer.parseInt(s[4]);
				tn[i] = total - tp[i] - fp[i] - fn[i];
				sumTP += tp[i];
				sumFP += fp[i];
				sumFN += fn[i];
				classRows++;
			}
		}

		if (!microRowSeen && classRows > 0) {
			microTP = sumTP;
			microFP = sumFP;
			microFN = sumFN;
			microTN = total - microTP - microFP - microFN;
		}
		if (microRowSeen || classRows > 0) {
			// summing tp[i] + tn[i] over several classes would count examples more than once
			correct = (int) (microTP + microTN);
		}
	} // end read

	/**
	 * Writes the metrics table returned by {@link #toString()} and flushes the writer.
	 *
	 * @param w the destination; it is flushed but not closed
	 * @throws IOException if writing fails
	 */
	public void write(Writer w) throws IOException {
		w.write(formatTable());
		w.flush();
	} // end write

	/**
	 * Returns a tab-separated table: a header, one row per class in
	 * {@code 1 .. MAX_NUMBER_OF_CLASSES - 1} with a non-zero tp, fp or fn
	 * (columns c, tp, fp, fn, total, prec, recall, F1), and a "micro" row
	 * that also reports accuracy unless exactly one class row is present.
	 */
	@Override
	public String toString() {
		return formatTable();
	} // end toString

	/** Builds the table shared by {@link #toString()} and {@link #write(Writer)}. */
	private String formatTable() {
		StringBuilder sb = new StringBuilder();
		sb.append("c\ttp\tfp\tfn\ttotal\tprec\trecall\tF1\tacc\n");

		int count = 0;
		for (int i = 1; i < MAX_NUMBER_OF_CLASSES; i++) {
			if ((tp[i] != 0) || (fp[i] != 0) || (fn[i] != 0)) {
				sb.append(i).append('\t')
						.append((int) tp[i]).append('\t')
						.append((int) fp[i]).append('\t')
						.append((int) fn[i]).append('\t')
						.append(total).append('\t')
						.append(decFormatter.format(getPrecision(i))).append('\t')
						.append(decFormatter.format(getRecall(i))).append('\t')
						.append(decFormatter.format(getF1(i))).append('\n');
				count++;
			}
		} // end for i

		// A single class row already equals the micro row; with no class rows
		// (eval3 only updates micro counters) the micro row is the only summary.
		if (count != 1) {
			sb.append("micro\t").append(getTP()).append('\t')
					.append(getFP()).append('\t')
					.append(getFN()).append('\t')
					.append(getTotal()).append('\t')
					.append(decFormatter.format(getPrecision())).append('\t')
					.append(decFormatter.format(getRecall())).append('\t')
					.append(decFormatter.format(getF1())).append('\t')
					.append(decFormatter.format(accuracy())).append('\n');
		}
		return sb.toString();
	}

	/**
	 * Command-line entry point: compares a reference file with an answer file
	 * and prints micro precision, recall and F1, first unformatted and then
	 * with two decimals. The log4j configuration file is taken from the
	 * {@code log-config} system property (default {@code log-config.txt}).
	 *
	 * @param args reference file and answer file
	 * @throws Exception if the files cannot be read or compared
	 */
	public static void main(String[] args) throws Exception {
		String logConfig = System.getProperty("log-config");
		if (logConfig == null) {
			logConfig = "log-config.txt";
		}

		PropertyConfigurator.configure(logConfig);

		if (args.length != 2) {
			System.err.println("java -mx512M " + Evaluator.class.getName() + " reference-file answer-file");
			// a usage error is a failure, not a successful run
			System.exit(1);
		}

		Evaluator eval = new Evaluator(new File(args[0]), new File(args[1]));

		DecimalFormat formatter = new DecimalFormat("0.00");
		System.out.println(eval.getPrecision() + "\t" + eval.getRecall() + "\t" + eval.getF1());
		System.out.println(formatter.format(eval.getPrecision()) + "\t" + formatter.format(eval.getRecall()) + "\t" + formatter.format(eval.getF1()));
	} // end main
} // end class Evaluator




© 2015 - 2025 Weber Informatics LLC | Privacy Policy