gov.sandia.cognition.learning.function.cost.SumSquaredErrorCostFunction Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cognitive-foundry Show documentation
A single jar with all the Cognitive Foundry components.
/*
* File: SumSquaredErrorCostFunction.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright Jul 4, 2008, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.learning.function.cost;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.gradient.GradientDescendable;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.data.TargetEstimatePair;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.DefaultPair;
import java.util.Collection;
/**
* This is the sum-squared error cost function.
* @author Kevin R. Dixon
* @since 2.1
*/
public class SumSquaredErrorCostFunction
extends AbstractParallelizableCostFunction
{
/**
 * Creates a new instance of SumSquaredErrorCostFunction without a dataset.
 * The dataset handed to the superclass is {@code null}.
 */
public SumSquaredErrorCostFunction()
{
    // The explicit cast documents the intended dataset type for the null.
    this( (Collection<? extends InputOutputPair<? extends Vector, Vector>>) null );
}
/**
 * Creates a new instance of SumSquaredErrorCostFunction.
 *
 * @param dataset The dataset of examples to use to compute the error. It is
 * passed unchanged to the superclass and may be {@code null}.
 */
public SumSquaredErrorCostFunction(
    Collection<? extends InputOutputPair<? extends Vector, Vector>> dataset )
{
    super( dataset );
}
/**
 * Returns a copy of this cost function, produced by the superclass clone
 * and narrowed to this concrete type.
 *
 * @return The cloned SumSquaredErrorCostFunction.
 */
@Override
public SumSquaredErrorCostFunction clone()
{
    final SumSquaredErrorCostFunction copy =
        (SumSquaredErrorCostFunction) super.clone();
    return copy;
}
public Object evaluatePartial(
Evaluator super Vector, ? extends Vector> evaluator )
{
double sumSquaredError = 0.0;
double weightSum = 0.0;
for (InputOutputPair extends Vector,Vector> pair : this.getCostParameters() )
{
// Compute the error vector.
Vector target = pair.getOutput();
Vector estimate = evaluator.evaluate( pair.getInput() );
double errorSquared = target.euclideanDistanceSquared( estimate );
double weight = DatasetUtil.getWeight(pair);
weightSum += weight;
sumSquaredError += weight * errorSquared;
}
weightSum *= 2.0;
return new EvaluatePartialSSE( sumSquaredError, weightSum );
}
public Double evaluateAmalgamate(
Collection