// [Web-scrape artifact from the Maven repository listing page; not part of the original source file:]
// JAR search and dependency download from the Maven repository
// gov.sandia.cognition.learning.algorithm.regression.LinearRegression Maven / Gradle / Ivy
// Go to download | Show more of this group | Show more artifacts with this name
// Show all versions of cognitive-foundry | Show documentation
// A single jar with all the Cognitive Foundry components.
/*
* File: LinearRegression.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright September 5, 2007, Sandia Corporation. Under the terms of Contract
* DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
* or on behalf of the U.S. Government. Export of this program may require a
* license from the United States Government. See CopyrightHistory.txt for
* complete details.
*
*/
package gov.sandia.cognition.learning.algorithm.regression;
import gov.sandia.cognition.annotation.CodeReview;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationReferences;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.math.UnivariateStatisticsUtil;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.function.scalar.LinearDiscriminantWithBias;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.statistics.method.AbstractConfidenceStatistic;
import gov.sandia.cognition.statistics.distribution.ChiSquareDistribution;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ArgumentChecker;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
/**
* Computes the least-squares regression for a LinearCombinationFunction
* given a dataset. A LinearCombinationFunction is a weighted linear
* combination of (potentially) nonlinear basis functions. This looks like
* y(x) = b + w'x, where "b" is a scalar bias and "w" is a weight vector.
* The internal class LinearRegression.Statistic returns the goodness-of-fit
* statistics for a set of target-estimate pairs, include a p-value for the
* null hypothesis significance.
*
* @author Kevin R. Dixon
* @since 2.0
*/
@CodeReview(
reviewer="Kevin R. Dixon",
date="2008-09-02",
changesNeeded=false,
comments={
"Made minor changes to javadoc",
"Looks fine."
}
)
@PublicationReferences(
references={
@PublicationReference(
author="Wikipedia",
title="Linear regression",
type=PublicationType.WebPage,
year=2008,
url="http://en.wikipedia.org/wiki/Linear_regression"
)
,
@PublicationReference(
author="Wikipedia",
title="Tikhonov regularization",
type=PublicationType.WebPage,
year=2011,
url="http://en.wikipedia.org/wiki/Tikhonov_regularization",
notes="Despite what Wikipedia says, this is always called Ridge Regression"
)
}
)
public class LinearRegression
extends AbstractCloneableSerializable
implements SupervisedBatchLearner
{
/**
* Tolerance for the pseudo inverse in the learn method, {@value}.
*/
public static final double DEFAULT_PSEUDO_INVERSE_TOLERANCE = 1e-10;
/**
* Default regularization, {@value}.
*/
public static final double DEFAULT_REGULARIZATION = 0.0;
/**
* Flag to use a pseudoinverse. True to use the expensive, but more
* accurate, pseudoinverse routine. False uses a very fast, but
* numerically less stable LU solver. Default value is "true".
*/
private boolean usePseudoInverse;
/**
* L2 ridge regularization term, must be nonnegative, a value of zero is
* equivalent to unregularized regression.
*/
private double regularization;
/**
* Creates a new instance of LinearRegression
*/
public LinearRegression()
{
this( DEFAULT_REGULARIZATION, true );
}
/**
* Creates a new instance of LinearRegression
* @param regularization
* L2 ridge regularization term, must be nonnegative, a value of zero is
* equivalent to unregularized regression.
* @param usePseudoInverse
* Flag to use a pseudoinverse. True to use the expensive, but more
* accurate, pseudoinverse routine. False uses a very fast, but
* numerically less stable LU solver. Default value is "true".
*/
public LinearRegression(
double regularization,
boolean usePseudoInverse )
{
this.setRegularization(regularization);
this.setUsePseudoInverse(usePseudoInverse);
}
@Override
public LinearRegression clone()
{
return (LinearRegression) super.clone();
}
/**
* Computes the linear regression for the given Collection of
* InputOutputPairs. The inputs of the pairs is the independent variable,
* and the pair output is the dependent variable (variable to predict).
* The pairs can have an associated weight to bias the regression equation.
* @param data
* Collection of InputOutputPairs for the variables. Can be
* WeightedInputOutputPairs.
* @return
* LinearCombinationFunction that minimizes the RMS error of the outputs.
*/
@Override
public LinearDiscriminantWithBias learn(
Collection extends InputOutputPair extends Vectorizable, Double>> data )
{
// We need to cheat to figure out how many coefficients we need...
// So we'll push the first sample through... wasteful, but general
int numCoefficients = CollectionUtil.getFirst(data).getInput().convertToVector().getDimensionality();
int numSamples = data.size();
Matrix X = MatrixFactory.getDefault().createMatrix( numCoefficients+1, numSamples );
Matrix Xt = MatrixFactory.getDefault().createMatrix( numSamples, numCoefficients+1 );
Vector y = VectorFactory.getDefault().createVector( numSamples );
Vector one = VectorFactory.getDefault().copyValues(1.0);
int n = 0;
for (InputOutputPair extends Vectorizable, Double> pair : data)
{
double output = pair.getOutput();
Vector input = pair.getInput().convertToVector().stack(one);
// We don't want Xt to have the weight factor too
final double weight = DatasetUtil.getWeight(pair);
if( weight != 1.0 )
{
// We can use scaleEquals() here because of the stack() method
input.scaleEquals(weight);
output *= weight;
}
Xt.setRow( n, input );
X.setColumn( n, input );
y.setElement( n, output );
n++;
}
// Solve for the coefficients
Vector coefficients;
if( this.getUsePseudoInverse() )
{
Matrix pseudoInverse = X.pseudoInverse(DEFAULT_PSEUDO_INVERSE_TOLERANCE);
coefficients = y.times( pseudoInverse );
}
else
{
Matrix lhs = X.times( Xt );
if( this.regularization > 0.0 )
{
for( int i = 0; i < numSamples; i++ )
{
double v = lhs.getElement(i, i);
lhs.setElement(i, i, v + this.regularization);
}
}
Vector rhs = y.times( Xt );
coefficients = lhs.solve( rhs );
}
Vector w = coefficients.subVector(0, numCoefficients-1);
double bias = coefficients.getElement(numCoefficients);
return new LinearDiscriminantWithBias( w, bias );
}
/**
* Getter for usePseudoInverse
* @return
* Flag to use a pseudoinverse. True to use the expensive, but more
* accurate, pseudoinverse routine. False uses a very fast, but
* numerically less stable LU solver. Default value is "true".
*/
public boolean getUsePseudoInverse()
{
return this.usePseudoInverse;
}
/**
* Setter for usePseudoInverse
* @param usePseudoInverse
* Flag to use a pseudoinverse. True to use the expensive, but more
* accurate, pseudoinverse routine. False uses a very fast, but
* numerically less stable LU solver. Default value is "true".
*/
public void setUsePseudoInverse(
boolean usePseudoInverse )
{
this.usePseudoInverse = usePseudoInverse;
}
/**
* Getter for regularization
* @return
* L2 ridge regularization term, must be nonnegative, a value of zero is
* equivalent to unregularized regression.
*/
public double getRegularization()
{
return this.regularization;
}
/**
* Setter for regularization
* @param regularization
* L2 ridge regularization term, must be nonnegative, a value of zero is
* equivalent to unregularized regression.
*/
public void setRegularization(
double regularization)
{
ArgumentChecker.assertIsNonNegative("regularization", regularization);
this.regularization = regularization;
}
/**
* Computes regression statistics using a chi-square measure of the
* statistical significance of the learned approximator
*/
public static class Statistic
extends AbstractConfidenceStatistic
{
/**
* Gets the value of the chi-square variable,
* Total weighted sum-squared error between the targets and estimates
*/
private double chiSquare;
/**
* Root mean-squared error of the targets and estimates
*/
private double rootMeanSquaredError;
/**
* Average L1-norm error (absolute value difference) between the
* targets and estimates
*/
private double meanL1Error;
/**
* Pearson Correlation between the targets and estimates, [-1,1]
*/
private double targetEstimateCorrelation;
/**
* Fraction of variance unaccounted for in the predictions, [0,1]
*/
private double unpredictedErrorFraction;
/**
* Number of samples used to create the Regression
*/
private int numSamples;
/**
* Number of parameters in the learned approximator
*/
private int numParameters;
/**
* Degrees of freedom in the Regression = numSamples-numParameters
*/
private double degreesOfFreedom;
/**
* Creates a new instance of Statistic
* @param targets
* Collection of ground-truth targets for the learned approximator
* @param estimates
* Collection of estimates from the learned approximator
* @param numParameters
* Number of parameters in the learned approximator
*/
public Statistic(
Collection targets,
Collection estimates,
int numParameters )
{
super( 0.0 );
Collection weights =
Collections.nCopies( targets.size(), new Double( 1.0 ) );
this.computeStatistics( targets, estimates, weights, numParameters );
}
/**
* Creates a new instance of Statistic
* @param targets
* Collection of ground-truth targets for the learned approximator
* @param estimates
* Collection of estimates from the learned approximator
* @param weights
* Collection of weights to apply to the corresponding target-estimate
* pair
* @param numParameters
* Number of parameters in the learned approximator
*/
public Statistic(
Collection targets,
Collection estimates,
Collection weights,
int numParameters )
{
super( 0.0 );
this.computeStatistics( targets, estimates, weights, numParameters );
}
/**
* Copy Constructor
* @param other
* Statistic to copy
*/
private Statistic(
Statistic other )
{
super( other.getNullHypothesisProbability() );
this.setDegreesOfFreedom( other.getDegreesOfFreedom() );
this.setMeanL1Error( other.getMeanL1Error() );
this.setNumParameters( other.getNumParameters() );
this.setNumSamples( other.getNumSamples() );
this.setRootMeanSquaredError( other.getRootMeanSquaredError() );
this.setTargetEstimateCorrelation( other.getTargetEstimateCorrelation() );
this.setUnpredictedErrorFraction( other.getUnpredictedErrorFraction() );
}
@Override
public Statistic clone()
{
return (Statistic) super.clone();
}
/**
* Creates a new instance of Statistic
* @param targets
* Collection of ground-truth targets for the learned approximator
* @param estimates
* Collection of estimates from the learned approximator
* @param weights
* Collection of weights to apply to the corresponding target-estimate
* pair
* @param numParameters
* Number of parameters in the learned approximator
*/
private void computeStatistics(
Collection targets,
Collection estimates,
Collection weights,
int numParameters )
{
if ((targets.size() != estimates.size()) &&
(targets.size() != weights.size()))
{
throw new IllegalArgumentException(
"Targets, Estimates, and Weights must be the same size!" );
}
// Compute the errors between the targets and the estimates
int num = targets.size();
ArrayList errors = new ArrayList( num );
double averageL1Error = 0.0;
double weightSum = 0.0;
Iterator it = targets.iterator();
Iterator ie = estimates.iterator();
Iterator iw = weights.iterator();
for (int n = 0; n < num; n++)
{
double estimate = ie.next();
double target = it.next();
double weight = iw.next();
double error = weight * (target - estimate);
errors.add( error );
averageL1Error += Math.abs( error );
weightSum += weight;
}
if (weightSum > 0)
{
averageL1Error /= weightSum;
}
else
{
averageL1Error = 0.0;
}
// Make sure that the DOFs stays above 0.0
// (so let's just cap it to a minimum of 1.0)
double dofs = num - numParameters;
if( dofs < 1.0 )
{
dofs = 1.0;
}
double chi2 = UnivariateStatisticsUtil.computeSumSquaredDifference( errors, 0.0 );
double pvalue = 1.0 - ChiSquareDistribution.CDF.evaluate(
chi2, dofs );
double rmsError =
UnivariateStatisticsUtil.computeRootMeanSquaredError( errors, 0.0 );
double correlation =
UnivariateStatisticsUtil.computeCorrelation( targets, estimates );
double unpredictedFraction = 1.0 - (correlation * correlation);
this.setNullHypothesisProbability( pvalue );
this.setChiSquare( chi2 );
this.setDegreesOfFreedom( dofs );
this.setMeanL1Error( averageL1Error );
this.setNumSamples( num );
this.setRootMeanSquaredError( rmsError );
this.setTargetEstimateCorrelation( correlation );
this.setUnpredictedErrorFraction( unpredictedFraction );
this.setNumParameters( numParameters );
}
/**
* Getter for rootMeanSquaredError
* @return
* Root mean-squared error of the targets and estimates
*/
public double getRootMeanSquaredError()
{
return this.rootMeanSquaredError;
}
/**
* Setter fpr rootMeanSquaredError
* @param rootMeanSquaredError
* Root mean-squared error of the targets and estimates
*/
protected void setRootMeanSquaredError(
double rootMeanSquaredError )
{
this.rootMeanSquaredError = rootMeanSquaredError;
}
/**
* Getter for targetEstimateCorrelation
* @return
* Pearson Correlation between the targets and estimates, [-1,1]
*/
public double getTargetEstimateCorrelation()
{
return this.targetEstimateCorrelation;
}
/**
* Setter for targetEstimateCorrelation
* @param targetEstimateCorrelation
* Pearson Correlation between the targets and estimates, [-1,1]
*/
protected void setTargetEstimateCorrelation(
double targetEstimateCorrelation )
{
this.targetEstimateCorrelation = targetEstimateCorrelation;
}
/**
* Getter for unpredictedErrorFraction
* @return
* Fraction of variance unaccounted for in the predictions, [0,1]
*/
public double getUnpredictedErrorFraction()
{
return this.unpredictedErrorFraction;
}
/**
* Setter for unpredictedErrorFraction
* @param unpredictedErrorFraction
* Fraction of variance unaccounted for in the predictions, [0,1]
*/
protected void setUnpredictedErrorFraction(
double unpredictedErrorFraction )
{
this.unpredictedErrorFraction = unpredictedErrorFraction;
}
/**
* Getter for numSamples
* @return
* Number of samples used to create the Regression
*/
public int getNumSamples()
{
return this.numSamples;
}
/**
* Setter for numSamples
* @param numSamples
* Number of samples used to create the Regression
*/
protected void setNumSamples(
int numSamples )
{
this.numSamples = numSamples;
}
/**
* Getter for degreesOfFreedom
* @return
* Degrees of freedom in the Regression = numSamples-numParameters
*/
public double getDegreesOfFreedom()
{
return this.degreesOfFreedom;
}
/**
* Setter for degreesOfFreedom
* @param degreesOfFreedom
* Degrees of freedom in the Regression = numSamples-numParameters
*/
protected void setDegreesOfFreedom(
double degreesOfFreedom )
{
this.degreesOfFreedom = degreesOfFreedom;
}
/**
* Getter for meanL1Error
* @return
* Average L1-norm error (absolute value difference) between the
* targets and estimates
*/
public double getMeanL1Error()
{
return this.meanL1Error;
}
/**
* Setter for meanL1Error
* @param meanL1Error
* Average L1-norm error (absolute value difference) between the
* targets and estimates
*/
protected void setMeanL1Error(
double meanL1Error )
{
this.meanL1Error = meanL1Error;
}
/**
* Getter for numParameters
* @return
* Number of parameters in the learned approximator
*/
public int getNumParameters()
{
return this.numParameters;
}
/**
* Setter for numParameters
* @param numParameters
* Number of parameters in the learned approximator
*/
public void setNumParameters(
int numParameters )
{
this.numParameters = numParameters;
}
/**
* Getter for chiSquare
* @return
* Gets the value of the chi-square variable,
* Total weighted sum-squared error between the targets and estimates
*/
public double getChiSquare()
{
return this.chiSquare;
}
/**
* Setter for chiSquare
* @param chiSquare
* Gets the value of the chi-square variable,
* Total weighted sum-squared error between the targets and estimates
*/
public void setChiSquare(
double chiSquare )
{
this.chiSquare = chiSquare;
}
@Override
public double getTestStatistic()
{
return this.getChiSquare();
}
}
}
// [Web-scrape artifact, not part of the original source file:]
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy