gov.sandia.cognition.learning.performance.MeanSquaredErrorEvaluator Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cognitive-foundry Show documentation
Show all versions of cognitive-foundry Show documentation
A single jar with all the Cognitive Foundry components.
/*
* File: MeanSquaredErrorEvaluator.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright September 26, 2007, Sandia Corporation. Under the terms of Contract
* DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
* or on behalf of the U.S. Government. Export of this program may require a
* license from the United States Government. See CopyrightHistory.txt for
* complete details.
*
*
*/
package gov.sandia.cognition.learning.performance;
import gov.sandia.cognition.learning.data.TargetEstimatePair;
import java.util.Collection;
/**
* The {@code MeanSquaredError} class implements the method for computing the
* performance of a supervised learner for a scalar function by the mean squared
* between the target and estimated outputs.
*
* @param The type of the input to the evaluator to compute the
* performance of.
* @author Justin Basilico
* @since 2.0
*/
public class MeanSquaredErrorEvaluator
extends AbstractSupervisedPerformanceEvaluator
{
/**
* Creates a new instance of MeanSquaredError.
*/
public MeanSquaredErrorEvaluator()
{
super();
}
/**
* {@inheritDoc}
*
* @param data {@inheritDoc}
* @return {@inheritDoc}
*/
public Double evaluatePerformance(
final Collection extends TargetEstimatePair extends Double,? extends Double>> data )
{
return MeanSquaredErrorEvaluator.compute( data );
}
/**
* Computes the mean squared error for the given pairs of values. The
* squared difference between the two values in each pair is computed and
* then the mean over all the values is returned.
*
* @param data The data to compute the mean squared error over.
* @return The mean squared error.
*/
public static double compute(
final Collection extends TargetEstimatePair extends Double, ? extends Double>> data )
{
// Since we compute the mean we need to know how many items there are.
final int count = data.size();
if (count <= 0)
{
// There must be at least one item to compute a mean.
return 0.0;
}
// Compute the error for each pair and add it to the sum.
double errorSum = 0.0;
for (TargetEstimatePair extends Double, ? extends Double> pair : data)
{
final double target = pair.getTarget();
final double estimate = pair.getEstimate();
final double difference = target - estimate;
// The error is the squared difference.
final double error = difference * difference;
errorSum += error;
}
// Compute the mean of the error sum.
return errorSum / (double) count;
}
}