All downloads are free. Search and download functionality uses the official Maven repository.

gov.sandia.cognition.learning.performance.RootMeanSquaredErrorEvaluator Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                RootMeanSquaredErrorEvaluator.java
 * Authors:             Justin Basilico
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 * 
 * Copyright December 18, 2007, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive 
 * license for use of this work by or on behalf of the U.S. Government. Export 
 * of this program may require a license from the United States Government. 
 * See CopyrightHistory.txt for complete details.
 *
 * 
 */

package gov.sandia.cognition.learning.performance;

import gov.sandia.cognition.learning.data.TargetEstimatePair;
import java.util.Collection;

/**
 * The {@code RootMeanSquaredErrorEvaluator} class implements a method for 
 * computing the performance of a supervised learner for a scalar function by 
 * the root mean squared error (RMSE or RSE) between the target and estimated 
 * outputs.
 * 
 * @param    The type of the input to the evaluator to compute the
 *          performance of.
 * @author  Justin Basilico
 * @since   2.0
 */
public class RootMeanSquaredErrorEvaluator
    extends AbstractSupervisedPerformanceEvaluator
{

    /**
     * Creates a new {@code RootMeanSquaredErrorEvaluator}.
     */
    public RootMeanSquaredErrorEvaluator()
    {
        super();
    }

    /**
     * {@inheritDoc}
     *
     * @param  data {@inheritDoc}
     * @return {@inheritDoc}
     */
    public Double evaluatePerformance(
        final Collection> data )
    {
        return RootMeanSquaredErrorEvaluator.compute( data );
    }

    /**
     * Computes the mean squared error for the given pairs of values. The
     * squared difference between the two values in each pair is computed and
     * then the mean over all the values is returned.
     *
     * @param  data The data to compute the mean squared error over.
     * @return The mean squared error.
     */
    public static double compute(
        final Collection> data )
    {
        // Since we compute the mean we need to know how many items there are.
        final int count = data.size();

        if (count <= 0)
        {
            // There must be at least one item to compute a mean.
            return 0.0;
        }

        // Compute the error for each pair and add it to the sum.
        double errorSum = 0.0;
        for (TargetEstimatePair pair : data)
        {
            final double target = pair.getTarget();
            final double estimate = pair.getEstimate();
            final double difference = target - estimate;

            // The error is the squared difference.
            final double error = difference * difference;
            errorSum += error;
        }

        // Compute the mean of the error sum.
        return Math.sqrt( errorSum / (double) count );
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy