gov.sandia.cognition.learning.algorithm.regression.MultivariateLinearRegression Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cognitive-foundry Show documentation
A single jar with all the Cognitive Foundry components.
/*
* File: MultivariateLinearRegression.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright Jun 22, 2011, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.learning.algorithm.regression;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationReferences;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.vector.MultivariateDiscriminantWithBias;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ArgumentChecker;
import java.util.Collection;
/**
* Performs multivariate regression with an explicit bias term, with optional
* L2 regularization.
* @author Kevin R. Dixon
* @since 3.2.1
*/
@PublicationReferences(
references={
@PublicationReference(
author="Wikipedia",
title="Linear regression",
type=PublicationType.WebPage,
year=2008,
url="http://en.wikipedia.org/wiki/Linear_regression"
)
,
@PublicationReference(
author="Wikipedia",
title="Tikhonov regularization",
type=PublicationType.WebPage,
year=2011,
url="http://en.wikipedia.org/wiki/Tikhonov_regularization",
notes="Despite what Wikipedia says, this is always called Ridge Regression"
)
}
)
public class MultivariateLinearRegression
extends AbstractCloneableSerializable
implements SupervisedBatchLearner
{
/**
* Default regularization, {@value}.
*/
public static final double DEFAULT_REGULARIZATION = 0.0;
/**
* Tolerance for the pseudo inverse in the learn method, {@value}.
*/
public static final double DEFAULT_PSEUDO_INVERSE_TOLERANCE = 1e-10;
/**
* Flag to use a pseudoinverse. True to use the expensive, but more
* accurate, pseudoinverse routine. False uses a very fast, but
* numerically less stable LU solver. Default value is "true".
*/
private boolean usePseudoInverse;
/**
* L2 ridge regularization term, must be nonnegative, a value of zero is
* equivalent to unregularized regression.
*/
private double regularization;
/**
* Creates a new instance of MultivariateLinearRegression
*/
public MultivariateLinearRegression()
{
this.setUsePseudoInverse(true);
}
@Override
public MultivariateLinearRegression clone()
{
return (MultivariateLinearRegression) super.clone();
}
@Override
public MultivariateDiscriminantWithBias learn(
Collection extends InputOutputPair extends Vector, Vector>> data)
{
// We need to cheat to figure out how many coefficients we need...
// So we'll push the first sample through... wasteful, but general
InputOutputPair extends Vector,Vector> first =
CollectionUtil.getFirst(data);
int M = first.getInput().getDimensionality();
int N = first.getOutput().getDimensionality();
int numSamples = data.size();
Matrix X = MatrixFactory.getDefault().createMatrix( M+1, numSamples );
Matrix Xt = MatrixFactory.getDefault().createMatrix( numSamples, M+1 );
Matrix Y = MatrixFactory.getDefault().createMatrix( N, numSamples );
Matrix Yt = MatrixFactory.getDefault().createMatrix( numSamples, N );
// The matrix equation looks like:
// y = C*[f0(x) f1(x) ... fn(x) ], fi() is the ith basis function
int i = 0;
Vector one = VectorFactory.getDefault().copyValues(1.0);
for (InputOutputPair extends Vector, Vector> pair : data)
{
Vector output = pair.getOutput();
Vector input = pair.getInput().convertToVector().stack(one);
final double weight = DatasetUtil.getWeight(pair);
if( weight != 1.0 )
{
// We can use scaleEquals() here because of the stack() method
input.scaleEquals(weight);
output = output.scale(weight);
}
Xt.setRow( i, input );
X.setColumn( i, input );
Y.setColumn( i, output );
Yt.setRow( i, output );
i++;
}
// Solve for the coefficients
Matrix coefficients;
if( this.getUsePseudoInverse() )
{
Matrix pseudoInverse = Xt.pseudoInverse(DEFAULT_PSEUDO_INVERSE_TOLERANCE);
coefficients = pseudoInverse.times( Yt ).transpose();
}
else
{
Matrix lhs = X.times( Xt );
if( this.regularization > 0.0 )
{
for( i = 0; i < M+1; i++ )
{
double v = lhs.getElement(i, i);
lhs.setElement(i, i, v + this.regularization);
}
}
Matrix rhs = Y.times( Xt );
coefficients = lhs.solve( rhs.transpose() ).transpose();
}
Matrix discriminant = coefficients.getSubMatrix(0, N-1, 0, M-1);
Vector bias = coefficients.getColumn(M);
return new MultivariateDiscriminantWithBias( discriminant, bias );
}
/**
* Getter for usePseudoInverse
* @return
* Flag to use a pseudoinverse. True to use the expensive, but more
* accurate, pseudoinverse routine. False uses a very fast, but
* numerically less stable LU solver. Default value is "true".
*/
public boolean getUsePseudoInverse()
{
return this.usePseudoInverse;
}
/**
* Setter for usePseudoInverse
* @param usePseudoInverse
* Flag to use a pseudoinverse. True to use the expensive, but more
* accurate, pseudoinverse routine. False uses a very fast, but
* numerically less stable LU solver. Default value is "true".
*/
public void setUsePseudoInverse(
boolean usePseudoInverse )
{
this.usePseudoInverse = usePseudoInverse;
}
/**
* Getter for regularization
* @return
* L2 ridge regularization term, must be nonnegative, a value of zero is
* equivalent to unregularized regression.
*/
public double getRegularization()
{
return this.regularization;
}
/**
* Setter for regularization
* @param regularization
* L2 ridge regularization term, must be nonnegative, a value of zero is
* equivalent to unregularized regression.
*/
public void setRegularization(
double regularization)
{
ArgumentChecker.assertIsNonNegative("regularization", regularization);
this.regularization = regularization;
}
}