All Downloads are FREE. Search and download functionalities are using the official Maven repository.

gov.sandia.cognition.learning.performance.MeanSquaredErrorEvaluator Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                MeanSquaredErrorEvaluator.java
 * Authors:             Justin Basilico
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 *
 * Copyright September 26, 2007, Sandia Corporation.  Under the terms of Contract
 * DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
 * or on behalf of the U.S. Government. Export of this program may require a
 * license from the United States Government. See CopyrightHistory.txt for
 * complete details.
 *
 *
 */

package gov.sandia.cognition.learning.performance;

import gov.sandia.cognition.learning.data.TargetEstimatePair;
import java.util.Collection;

/**
 * The {@code MeanSquaredError} class implements the method for computing the
 * performance of a supervised learner for a scalar function by the mean squared
 * between the target and estimated outputs.
 *
 * @param    The type of the input to the evaluator to compute the
 *          performance of.
 * @author Justin Basilico
 * @since  2.0
 */
public class MeanSquaredErrorEvaluator
    extends AbstractSupervisedPerformanceEvaluator
{

    /**
     * Creates a new instance of MeanSquaredError.
     */
    public MeanSquaredErrorEvaluator()
    {
        super();
    }

    /**
     * {@inheritDoc}
     *
     * @param  data {@inheritDoc}
     * @return {@inheritDoc}
     */
    public Double evaluatePerformance(
        final Collection> data )
    {
        return MeanSquaredErrorEvaluator.compute( data );
    }

    /**
     * Computes the mean squared error for the given pairs of values. The
     * squared difference between the two values in each pair is computed and
     * then the mean over all the values is returned.
     *
     * @param  data The data to compute the mean squared error over.
     * @return The mean squared error.
     */
    public static double compute(
        final Collection> data )
    {
        // Since we compute the mean we need to know how many items there are.
        final int count = data.size();

        if (count <= 0)
        {
            // There must be at least one item to compute a mean.
            return 0.0;
        }

        // Compute the error for each pair and add it to the sum.
        double errorSum = 0.0;
        for (TargetEstimatePair pair : data)
        {
            final double target = pair.getTarget();
            final double estimate = pair.getEstimate();
            final double difference = target - estimate;

            // The error is the squared difference.
            final double error = difference * difference;
            errorSum += error;
        }

        // Compute the mean of the error sum.
        return errorSum / (double) count;
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy