All downloads are free. The search and download features use the official Maven repository.

org.deeplearning4j.eval.Evaluation Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.eval;


import java.util.Set;

import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.util.MatrixUtil;
import org.jblas.DoubleMatrix;
import org.jblas.SimpleBlas;


/**
 * Evaluation metrics: precision, recall, f1.
 *
 * Accumulates running counts of true positives, false negatives and
 * per-label false positives, plus a confusion matrix of
 * (actual label, predicted label) pairs. Counts are fed either through
 * {@link #eval(DoubleMatrix, DoubleMatrix)} or through the public
 * increment/add methods, and the metric methods are computed from whatever
 * has been accumulated so far.
 *
 * @author Adam Gibson
 *
 */
public class Evaluation {

	/** Number of examples whose predicted label matched the actual label (all labels pooled). */
	private double truePositives;
	/** Number of wrong predictions, keyed by the label that was (incorrectly) predicted. */
	private final Counter<Integer> falsePositives = new Counter<Integer>();
	/** Number of examples whose predicted label did not match the actual label (all labels pooled). */
	private double falseNegatives;
	/** Counts of (actual label, predicted label) pairs. */
	private final ConfusionMatrix<Integer> confusion = new ConfusionMatrix<Integer>();

	/**
	 * Collects statistics on the real outcomes vs the
	 * guesses. Each row of the two matrices is treated as one example:
	 * the label of a row is the index returned by
	 * {@code SimpleBlas.iamax} for that row, and the label of the real
	 * row is compared with the label of the guess row.
	 *
	 * NOTE(review): iamax is documented by jblas as the index of the
	 * element with the largest absolute value, so rows are presumably
	 * expected to be non-negative (e.g. one-hot / probabilities) - confirm.
	 *
	 * @param realOutcomes the real outcomes (usually binary)
	 * @param guesses the guesses (usually a probability vector)
	 * @throws IllegalArgumentException if the two matrices do not have the
	 * same length or the same number of rows
	 */
	public void eval(DoubleMatrix realOutcomes,DoubleMatrix guesses) {
		if(realOutcomes.length != guesses.length) {
			throw new IllegalArgumentException("Unable to evaluate. Outcome matrices not same length");
		}
		// Equal length does not imply equal shape (e.g. 3x2 vs 2x3); rows are
		// compared pairwise below, so the row counts have to agree as well.
		if(realOutcomes.rows != guesses.rows) {
			throw new IllegalArgumentException("Unable to evaluate. Outcome matrices have different row counts: "
					+ realOutcomes.rows + " vs " + guesses.rows);
		}

		for(int i = 0; i < realOutcomes.rows; i++) {
			DoubleMatrix currRow = realOutcomes.getRow(i);
			DoubleMatrix guessRow = guesses.getRow(i);

			int currMax = SimpleBlas.iamax(currRow);
			int guessMax = SimpleBlas.iamax(guessRow);

			addToConfusion(currMax,guessMax);

			if(currMax == guessMax) {
				incrementTruePositives();
			}
			else {
				// A miss counts once as a false negative overall and once as a
				// false positive for the label that was wrongly guessed.
				incrementFalseNegatives();
				incrementFalsePositives(guessMax);
			}
		}
	}



	/**
	 * Builds a human readable report: one line for every
	 * (actual, predicted) pair with a non-zero count in the confusion
	 * matrix, followed by the overall f1 score.
	 * @return the report text
	 */
	public String stats() {
		StringBuilder builder = new StringBuilder()
		.append("\n");
		Set<Integer> classes = confusion.getClasses();
		for(Integer clazz : classes) {
			for(Integer clazz2 : classes) {
				int count = confusion.getCount(clazz, clazz2);
				if(count != 0) {
					builder.append("\nActual Class ").append(clazz)
					.append(" was predicted with Predicted ").append(clazz2)
					.append(" with count ").append(count)
					.append(" times\n");
				}
			}
		}
		builder.append("\n==========================F1 Scores========================================");
		builder.append("\n ").append(f1());
		builder.append("\n===========================================================================");
		return builder.toString();
	}

	/**
	 * Adds one observation to the confusion matrix
	 * @param real the actual label
	 * @param guess the system guess
	 */
	public void addToConfusion(int real,int guess) {
		confusion.add(real, guess);
	}

	/**
	 * Returns the number of times the given label
	 * has actually occurred
	 * @param i the label
	 * @return the number of times the label
	 * actually occurred
	 */
	public int classCount(int i) {
		return confusion.getActualTotal(i);
	}

	/**
	 * Returns the number of times a given label was predicted
	 * @param label the label to get
	 * @return the number of times the given label was predicted
	 */
	public int numtimesPredicted(int label) {
		return confusion.getPredictedTotal(label);
	}

	/**
	 * Gets the number of times the
	 * given predicted label was output for
	 * examples of the given actual label
	 * @param actual the actual label
	 * @param predicted the predicted label
	 * @return the confusion matrix count for the (actual, predicted) pair
	 */
	public int numTimesPredicted(int actual,int predicted) {
		return confusion.getCount(actual, predicted);
	}

	/**
	 * Returns the mean of {@link #precision(int)} over every label
	 * known to the confusion matrix.
	 * @return the averaged precision, or 0 if no label has been recorded yet
	 */
	public double precision() {
		Set<Integer> classes = confusion.getClasses();
		// Nothing recorded yet: return 0 rather than 0.0 / 0 (NaN), in line
		// with recall() and precision(int).
		if(classes.isEmpty()) {
			return 0;
		}
		double prec = 0.0;
		for(Integer i : classes) {
			prec += precision(i);
		}
		return prec / (double) classes.size();
	}


	/**
	 * Returns the f1 score computed from {@link #precision()} and
	 * {@link #recall()}.
	 * @return the f1 score, or 0 if either precision or recall is 0
	 */
	public double f1() {
		return f1(precision(), recall());
	}

	/**
	 * Calculate f1 score for a given class.
	 *
	 * Combines the precision for the label with the overall
	 * {@link #recall()}; no per-label recall is tracked by this class.
	 * @param i the label to calculate f1 for
	 * @return the f1 score for the given label
	 */
	public double f1(int i) {
		return f1(precision(i), recall());
	}

	/**
	 * Harmonic mean of precision and recall, shared by both public f1 variants.
	 * @param precision the precision value
	 * @param recall the recall value
	 * @return 0 if either input is 0, otherwise 2 * p * r / (p + r)
	 */
	private static double f1(double precision,double recall) {
		if(precision == 0 || recall == 0) {
			return 0;
		}
		return 2.0 * ((precision * recall / (precision + recall)));
	}

	/**
	 * Returns the recall for the outcomes:
	 * true positives / (true positives + false negatives)
	 * @return the recall for the outcomes, or 0 if there are no true positives
	 */
	public double recall() {
		if(truePositives == 0) {
			return 0;
		}
		return truePositives / (truePositives + falseNegatives);
	}

	/**
	 * Returns the precision for a given label.
	 *
	 * Computed as the overall true positive count divided by that count
	 * plus the false positives recorded for the label; the true positive
	 * count is pooled over all labels, not tracked per label.
	 * @param i the label
	 * @return the precision for the label, or 0 if there are no true positives
	 */
	public double precision(int i) {
		if(truePositives == 0) {
			return 0;
		}
		return truePositives / (truePositives + falsePositives.getCount(i));
	}


	/** Adds one to the overall true positive count. */
	public void incrementTruePositives() {
		truePositives++;
	}

	/** Adds one to the overall false negative count. */
	public void incrementFalseNegatives() {
		falseNegatives++;
	}

	/**
	 * Adds one false positive for the given label
	 * @param i the label that was wrongly predicted
	 */
	public void incrementFalsePositives(int i) {
		falsePositives.incrementCount(i, 1.0);
	}


}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy