gov.sandia.cognition.learning.performance.RootMeanSquaredErrorEvaluator Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cognitive-foundry Show documentation
Show all versions of cognitive-foundry Show documentation
A single jar with all the Cognitive Foundry components.
/*
* File: RootMeanSquaredErrorEvaluator.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright December 18, 2007, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
* See CopyrightHistory.txt for complete details.
*
*
*/
package gov.sandia.cognition.learning.performance;
import gov.sandia.cognition.learning.data.TargetEstimatePair;
import java.util.Collection;
/**
* The {@code RootMeanSquaredErrorEvaluator} class implements a method for
* computing the performance of a supervised learner for a scalar function by
* the root mean squared error (RMSE or RSE) between the target and estimated
* outputs.
*
* @param The type of the input to the evaluator to compute the
* performance of.
* @author Justin Basilico
* @since 2.0
*/
public class RootMeanSquaredErrorEvaluator
extends AbstractSupervisedPerformanceEvaluator
{
/**
* Creates a new {@code RootMeanSquaredErrorEvaluator}.
*/
public RootMeanSquaredErrorEvaluator()
{
super();
}
/**
* {@inheritDoc}
*
* @param data {@inheritDoc}
* @return {@inheritDoc}
*/
public Double evaluatePerformance(
final Collection extends TargetEstimatePair extends Double, ? extends Double>> data )
{
return RootMeanSquaredErrorEvaluator.compute( data );
}
/**
* Computes the mean squared error for the given pairs of values. The
* squared difference between the two values in each pair is computed and
* then the mean over all the values is returned.
*
* @param data The data to compute the mean squared error over.
* @return The mean squared error.
*/
public static double compute(
final Collection extends TargetEstimatePair extends Double, ? extends Double>> data )
{
// Since we compute the mean we need to know how many items there are.
final int count = data.size();
if (count <= 0)
{
// There must be at least one item to compute a mean.
return 0.0;
}
// Compute the error for each pair and add it to the sum.
double errorSum = 0.0;
for (TargetEstimatePair extends Double, ? extends Double> pair : data)
{
final double target = pair.getTarget();
final double estimate = pair.getEstimate();
final double difference = target - estimate;
// The error is the squared difference.
final double error = difference * difference;
errorSum += error;
}
// Compute the mean of the error sum.
return Math.sqrt( errorSum / (double) count );
}
}