gov.sandia.cognition.learning.algorithm.regression.KernelWeightedRobustRegression Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cognitive-foundry Show documentation
Show all versions of cognitive-foundry Show documentation
A single jar with all the Cognitive Foundry components.
/*
* File: KernelWeightedRobustRegression.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright Nov 18, 2007, Sandia Corporation. Under the terms of Contract
* DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
* or on behalf of the U.S. Government. Export of this program may require a
* license from the United States Government. See CopyrightHistory.txt for
* complete details.
*
*
*/
package gov.sandia.cognition.learning.algorithm.regression;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.DefaultWeightedInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.data.WeightedInputOutputPair;
import gov.sandia.cognition.learning.function.kernel.Kernel;
import java.util.ArrayList;
/**
* KernelWeightedRobustRegression takes a supervised learning algorithm that
* operates on a weighted collection of InputOutputPairs and modifies the
* weight of a sample based on the dataset output and its corresponding
* estimate from the Evaluator from the supervised learning algorithm at each
* iteration. This weight is added to the dataset sample and the supervised
* learning algorithm is run again. This process repeats until the weights
* converge. This algorithm is a direct generalization of the LOESS-based
* (LOWESS-based) Robust Regression using a general learner and kernel.
*
* A typical use case is using a regression algorithm (LinearRegression or
* DecoupledVectorLinearRegression) and a RadialBasisKernel. This results in
* a regression algorithm that learns to "ignore" outliers and fit the
* remaining data. (Think of fitting a height-versus-age curve and an 8-foot
* tall Yao Ming made it into your training set, skewing your results with that
* massive outlier.)
*
* KernelWeightedRobustRegression is different from LocallyWeightedLearning in
* that KWRR creates a global function approximator and holds for all inputs.
* Thus, learning time for KWRR is relatively high up front, but evaluation time
* is relatively low. On the other hand, LWL creates a local function
* approximator in response to each evaluation, and LWL does not create a global
* function approximator. As such, LWL has (almost) no up-front learning time,
* but each evaluation requires a relatively high evaluation.
*
* KWRR is more appropriate when you know the general structure of your data,
* but it is riddled with outliers. LWL is more appropriate when you don't
* know/understand the general trend of your data AND you can afford evaluation
* time to be somewhat costly.
*
* @param Input class for the Evaluator and inputs on the
* InputOutputPairs dataset
* @param Output class for the Evaluator, outputs on the
* InputOutputPairs dataset. Furthermore, the Kernel must be able to
* evaluate OutputTypes.
* @author Kevin R. Dixon
* @since 2.0
*/
public class KernelWeightedRobustRegression
extends AbstractAnytimeSupervisedBatchLearner>
{
/**
* DecoupledVectorFunction that is being optimized
*/
private Evaluator super InputType, ? extends OutputType> result;
/**
* Internal learning algorithm that computes optimal solutions
* given the current weightedData. The iterationLearner should operate on
* WeightedInputOutputPairs (we have a hard time enforcing this, as many
* learning algorithms operate both on InputOutputPairs and
* WeightedInputOutputPairs)
*/
private SupervisedBatchLearner iterationLearner;
/**
* Kernel function that provides the weighting for the estimate error,
* generally the Kernel should weight accurate estimates higher than
* inaccurate estimates.
*/
private Kernel super OutputType> kernelWeightingFunction;
/**
* Tolerance before stopping the algorithm
*/
private double tolerance;
/**
* Default maximum number of iterations before stopping
*/
public static final int DEFAULT_MAX_ITERATIONS = 100;
/**
* Default tolerance stopping criterion
*/
public static final double DEFAULT_TOLERANCE = 1e-5;
/**
* Creates a new instance of RobustRegression
* @param iterationLearner
* Internal learning algorithm that computes optimal solutions
* given the current weightedData. The iterationLearner should operate on
* WeightedInputOutputPairs (we have a hard time enforcing this, as many
* learning algorithms operate both on InputOutputPairs and
* WeightedInputOutputPairs and their prototype is "? extends InputOutputPair")
* @param kernelWeightingFunction
* Kernel function that provides the weighting for the estimate error,
* generally the Kernel should weight accurate estimates higher than
* inaccurate estimates.
*/
public KernelWeightedRobustRegression(
SupervisedBatchLearner iterationLearner,
Kernel super OutputType> kernelWeightingFunction )
{
this( iterationLearner, kernelWeightingFunction, DEFAULT_MAX_ITERATIONS, DEFAULT_TOLERANCE );
}
/**
* Creates a new instance of RobustRegression
* @param iterationLearner
* Internal learning algorithm that computes optimal solutions
* given the current weightedData. The iterationLearner should operate on
* WeightedInputOutputPairs (we have a hard time enforcing this, as many
* learning algorithms operate both on InputOutputPairs and
* WeightedInputOutputPairs and their prototype is "? extends InputOutputPair")
* @param kernelWeightingFunction
* Kernel function that provides the weighting for the estimate error,
* generally the Kernel should weight accurate estimates higher than
* inaccurate estimates.
* @param maxIterations
* @param tolerance
* Tolerance before stopping the algorithm
*/
public KernelWeightedRobustRegression(
SupervisedBatchLearner iterationLearner,
Kernel super OutputType> kernelWeightingFunction,
int maxIterations,
double tolerance )
{
super( maxIterations );
this.setLearned( null );
this.setTolerance( tolerance );
this.setKernelWeightingFunction( kernelWeightingFunction );
this.setIterationLearner( iterationLearner );
}
/**
* Weighted copy of the data
*/
private ArrayList> weightedData;
protected boolean initializeAlgorithm()
{
this.weightedData =
new ArrayList>(
this.data.size() );
for (InputOutputPair extends InputType, ? extends OutputType> pair : this.data)
{
// Create an initial weighted dataset, using the weights from
// the original dataset, if available, otherwise just use
// uniform weights
double weight = DatasetUtil.getWeight(pair);
this.weightedData.add( new DefaultWeightedInputOutputPair(
pair.getInput(), pair.getOutput(), weight ) );
}
return true;
}
protected boolean step()
{
// Compute the learner based on the current weighting of the samples
this.result = this.iterationLearner.learn( this.weightedData );
// Update the weight set using the result function and the Kernel
// and track the how much the weights have changed
double change = this.updateWeights( this.result );
// If the weights have stabilized, then we're done.
return (change > this.tolerance);
}
protected void cleanupAlgorithm()
{
}
/**
* Updates the weightedData from the given prediction function and
* the internal Kernel
* @param f
* Prediction function to use to update the weights of the weightedData
* using the Kernel
* @return Mean L1 norm of the weight change
*/
private double updateWeights(
Evaluator super InputType, ? extends OutputType> f )
{
double change = 0.0;
for (DefaultWeightedInputOutputPair pair
: this.weightedData)
{
// Use the kernel to determine the new weight of the sample
// Generally, the kernel should weight accurate samples more
// than inaccurate samples
OutputType yhat = f.evaluate( pair.getInput() );
double weightNew = this.kernelWeightingFunction.evaluate(
pair.getOutput(), yhat );
double weightOld = pair.getWeight();
change += Math.abs( weightNew - weightOld );
pair.setWeight( weightNew );
}
change /= this.weightedData.size();
return change;
}
/**
* Getter for kernelWeightingFunction
* @return
* Kernel function that provides the weighting for the estimate error,
* generally the Kernel should weight accurate estimates higher than
* inaccurate estimates.
*/
public Kernel super OutputType> getKernelWeightingFunction()
{
return this.kernelWeightingFunction;
}
/**
* Getter for kernelWeightingFunction
* @param kernelWeightingFunction
* Kernel function that provides the weighting for the estimate error,
* generally the Kernel should weight accurate estimates higher than
* innaccurate estimates.
*/
public void setKernelWeightingFunction(
Kernel super OutputType> kernelWeightingFunction )
{
this.kernelWeightingFunction = kernelWeightingFunction;
}
/**
* Getter for tolerance
* @return
* Tolerance before stopping the algorithm
*/
public double getTolerance()
{
return this.tolerance;
}
/**
* Setter for tolerance
* @param tolerance
* Tolerance before stopping the algorithm
*/
public void setTolerance(
double tolerance )
{
if (tolerance <= 0.0)
{
throw new IllegalArgumentException(
"Tolerance must be > 0.0" );
}
this.tolerance = tolerance;
}
/**
* Getter for result
* @param result
* DecoupledVectorFunction that is being optimized
*/
public void setLearned(
Evaluator result )
{
this.result = result;
}
public Evaluator super InputType, ? extends OutputType> getResult()
{
return this.result;
}
/**
* Getter for iterationLearner
* @return
* Internal learning algorithm that computes optimal solutions
* given the current weightedData. The iterationLearner should operate on
* WeightedInputOutputPairs (we have a hard time enforcing this, as many
* learning algorithms operate both on InputOutputPairs and
* WeightedInputOutputPairs)
*/
public SupervisedBatchLearner getIterationLearner()
{
return this.iterationLearner;
}
/**
*
* @param iterationLearner
* Internal learning algorithm that computes optimal solutions
* given the current weightedData. The iterationLearner should operate on
* WeightedInputOutputPairs (we have a hard time enforcing this, as many
* learning algorithms operate both on InputOutputPairs and
* WeightedInputOutputPairs)
*/
public void setIterationLearner(
SupervisedBatchLearner iterationLearner )
{
this.iterationLearner = iterationLearner;
}
}