gov.sandia.cognition.learning.algorithm.regression.LinearBasisRegression Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cognitive-foundry Show documentation
A single jar with all the Cognitive Foundry components.
/*
* File: LinearBasisRegression.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright September 5, 2007, Sandia Corporation. Under the terms of Contract
* DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
* or on behalf of the U.S. Government. Export of this program may require a
* license from the United States Government. See CopyrightHistory.txt for
* complete details.
*
*/
package gov.sandia.cognition.learning.algorithm.regression;
import gov.sandia.cognition.annotation.CodeReview;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.DefaultWeightedInputOutputPair;
import gov.sandia.cognition.learning.data.WeightedInputOutputPair;
import gov.sandia.cognition.learning.function.scalar.LinearDiscriminant;
import gov.sandia.cognition.learning.function.scalar.VectorFunctionLinearDiscriminant;
import gov.sandia.cognition.learning.function.vector.ScalarBasisSet;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
/**
* Computes the least-squares regression for a LinearCombinationFunction
* given a dataset. A LinearCombinationFunction is a weighted linear
* combination of (potentially) nonlinear basis functions. This looks like
* y(x) = a0*f0(x) + a1*f1(x) + ... + an*fn(x) and so forth.
* The internal class LinearRegression.Statistic returns the goodness-of-fit
* statistics for a set of target-estimate pairs, include a p-value for the
* null hypothesis significance.
*
* @param Input class for the basis functions, for example, Double,
* Vector, String.
* @author Kevin R. Dixon
* @since 2.0
*/
@CodeReview(
reviewer="Kevin R. Dixon",
date="2008-09-02",
changesNeeded=false,
comments={
"Made minor changes to javadoc",
"Looks fine."
}
)
@PublicationReference(
author="Wikipedia",
title="Linear regression",
type=PublicationType.WebPage,
year=2008,
url="http://en.wikipedia.org/wiki/Linear_regression"
)
public class LinearBasisRegression
extends AbstractCloneableSerializable
implements SupervisedBatchLearner>
{
/**
* Tolerance for the pseudo inverse in the learn method, {@value}.
*/
public static final double DEFAULT_PSEUDO_INVERSE_TOLERANCE = 1e-10;
/**
* Function that maps the InputType to a Vector
*/
private Evaluator super InputType, Vector> inputToVectorMap;
/**
* Flag to use a pseudoinverse. True to use the expensive, but more
* accurate, pseudoinverse routine. False uses a very fast, but
* numerically less stable LU solver. Default value is "true".
*/
private boolean usePseudoInverse;
/**
* Creates a new instance of LinearRegression
* @param basisFunctions
* Basis functions to create the ScalarBasisSet from
*/
public LinearBasisRegression(
Collection extends Evaluator super InputType, Double>> basisFunctions )
{
this( new ScalarBasisSet( basisFunctions ) );
}
/**
* Creates a new instance of LinearRegression
* @param inputToVectorMap
* Function that maps the InputType to a Vector
*/
public LinearBasisRegression(
ScalarBasisSet inputToVectorMap )
{
this( (Evaluator super InputType, Vector>) inputToVectorMap );
}
/**
* Creates a new instance of LinearRegression
* @param inputToVectorMap
* Function that maps the InputType to a Vector
*/
public LinearBasisRegression(
Evaluator super InputType, Vector> inputToVectorMap )
{
this.setInputToVectorMap( inputToVectorMap );
this.setUsePseudoInverse( true );
}
@Override
public LinearBasisRegression clone()
{
@SuppressWarnings("unchecked")
LinearBasisRegression clone =
(LinearBasisRegression) super.clone();
clone.setInputToVectorMap(
ObjectUtil.cloneSmart( this.getInputToVectorMap() ) );
return clone;
}
/**
* Computes the linear regression for the given Collection of
* InputOutputPairs. The inputs of the pairs is the independent variable,
* and the pair output is the dependent variable (variable to predict).
* The pairs can have an associated weight to bias the regression equation.
* @param data
* Collection of InputOutputPairs for the variables. Can be
* WeightedInputOutputPairs.
* @return
* LinearCombinationFunction that minimizes the RMS error of the outputs.
*/
@Override
public VectorFunctionLinearDiscriminant learn(
Collection extends InputOutputPair extends InputType, Double>> data )
{
// Create the vector-based dataset first
ArrayList> vectorData =
new ArrayList>( data.size() );
for (InputOutputPair extends InputType, Double> pair : data)
{
double weight = DatasetUtil.getWeight(pair);
Vector xrow = this.inputToVectorMap.evaluate( pair.getInput() );
Double output = pair.getOutput();
vectorData.add( DefaultWeightedInputOutputPair.create( xrow, output, weight ) );
}
LinearRegression linear = new LinearRegression();
linear.setUsePseudoInverse(this.getUsePseudoInverse());
LinearDiscriminant weights = linear.learn(vectorData);
return new VectorFunctionLinearDiscriminant(
this.inputToVectorMap, weights );
}
/**
* Getter for inputToVectorMap
* @return
* Function that maps the InputType to a Vector
*/
public Evaluator super InputType, Vector> getInputToVectorMap()
{
return this.inputToVectorMap;
}
/**
* Setter for inputToVectorMap
* @param inputToVectorMap
* Function that maps the InputType to a Vector
*/
public void setInputToVectorMap(
Evaluator super InputType, Vector> inputToVectorMap )
{
this.inputToVectorMap = inputToVectorMap;
}
/**
* Getter for usePseudoInverse
* @return
* Flag to use a pseudoinverse. True to use the expensive, but more
* accurate, pseudoinverse routine. False uses a very fast, but
* numerically less stable LU solver. Default value is "true".
*/
public boolean getUsePseudoInverse()
{
return this.usePseudoInverse;
}
/**
* Setter for usePseudoInverse
* @param usePseudoInverse
* Flag to use a pseudoinverse. True to use the expensive, but more
* accurate, pseudoinverse routine. False uses a very fast, but
* numerically less stable LU solver. Default value is "true".
*/
public void setUsePseudoInverse(
boolean usePseudoInverse )
{
this.usePseudoInverse = usePseudoInverse;
}
}