gov.sandia.cognition.statistics.bayesian.BayesianLinearRegression Maven / Gradle / Ivy
/*
* File: BayesianLinearRegression.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright Mar 11, 2010, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.statistics.bayesian;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationReferences;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.IncrementalLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.statistics.AbstractSufficientStatistic;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussian;
import gov.sandia.cognition.statistics.distribution.UnivariateGaussian;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.Collection;
/**
* Computes a Bayesian linear estimator for a given feature function
* and a set of observed data.
* @author Kevin R. Dixon
* @since 3.0
*/
@PublicationReferences(
references={
@PublicationReference(
author="Christopher M. Bishop",
title="Pattern Recognition and Machine Learning",
type=PublicationType.Book,
year=2006,
pages={152,159}
)
,
@PublicationReference(
author="Hanna M. Wallach",
title="Introduction to Gaussian Process Regression",
type=PublicationType.Misc,
year=2005,
url="http://www.cs.umass.edu/~wallach/talks/gp_intro.pdf"
)
,
@PublicationReference(
author="Wikipedia",
title="Bayesian linear regression",
type=PublicationType.WebPage,
year=2010,
url="http://en.wikipedia.org/wiki/Bayesian_linear_regression"
)
}
)
public class BayesianLinearRegression
extends AbstractCloneableSerializable
implements BayesianRegression
{
/**
* Default output variance, {@value}.
*/
public static final double DEFAULT_OUTPUT_VARIANCE = 1.0;
/**
* Default weight variance, {@value}.
*/
public static final double DEFAULT_WEIGHT_VARIANCE = 1.0;
/**
* Assumed known variance of the outputs (measurements),
* must be greater than zero.
*/
protected double outputVariance;
/**
* Prior distribution of the weights, typically a zero-mean,
* diagonal-variance distribution.
*/
protected MultivariateGaussian weightPrior;
/**
* Creates a new instance of BayesianLinearRegression
* @param dimensionality
* Sets up the parameters (except featureMap) for the given dimensionality
* of objects in feature space.
*/
public BayesianLinearRegression(
int dimensionality )
{
this( DEFAULT_OUTPUT_VARIANCE,
new MultivariateGaussian( VectorFactory.getDefault().createVector(dimensionality),
MatrixFactory.getDefault().createIdentity(dimensionality,dimensionality).scale( DEFAULT_WEIGHT_VARIANCE ) ) );
}
/**
* Creates a new instance of BayesianLinearRegression
* @param outputVariance
* Assumed known variance of the outputs (measurements),
* must be greater than zero.
* @param weightPrior
* Prior distribution of the weights, typically a zero-mean,
* diagonal-variance distribution.
*/
public BayesianLinearRegression(
double outputVariance,
MultivariateGaussian weightPrior)
{
this.setOutputVariance(outputVariance);
this.setWeightPrior(weightPrior);
}
@Override
public BayesianLinearRegression clone()
{
@SuppressWarnings("unchecked")
BayesianLinearRegression clone = (BayesianLinearRegression) super.clone();
clone.setWeightPrior( ObjectUtil.cloneSafe( this.getWeightPrior() ) );
return clone;
}
@Override
public MultivariateGaussian.PDF learn(
Collection extends InputOutputPair extends Vectorizable, Double>> data)
{
MultivariateGaussian prior = this.getWeightPrior();
RingAccumulator Cin = new RingAccumulator();
Matrix Ci = prior.getCovarianceInverse().clone();
Cin.accumulate( Ci );
RingAccumulator zn = new RingAccumulator();
Vector z = Ci.times( prior.getMean() );
zn.accumulate( z );
for (InputOutputPair extends Vectorizable, Double> pair : data)
{
Vector x1 = pair.getInput().convertToVector();
Vector x2 = x1.clone();
final double beta = DatasetUtil.getWeight(pair) / this.outputVariance;
if( beta != 1.0 )
{
x2.scaleEquals(beta);
}
Cin.accumulate( x1.outerProduct(x2) );
final double y = pair.getOutput();
if( y != 1.0 )
{
x2.scaleEquals(y);
}
zn.accumulate( x2 );
}
Ci = Cin.getSum();
Matrix C = Ci.inverse();
z = zn.getSum();
Vector mean = C.times( z );
return new MultivariateGaussian.PDF( mean, C );
}
/**
* Creates the distribution from which the outputs are generated, given
* the weights and the input to consider.
* @param input
* Input to condition on
* @param weights
* Weights that determine the mean
* @return
* Conditional distribution from which outputs are generated.
*/
@Override
public UnivariateGaussian createConditionalDistribution(
Vectorizable input,
Vector weights )
{
double mean = input.convertToVector().dotProduct(weights);
return new UnivariateGaussian( mean, this.getOutputVariance() );
}
/**
* Getter for weightPrior
* @return
* Prior distribution of the weights, typically a zero-mean,
* diagonal-variance distribution.
*/
public MultivariateGaussian getWeightPrior()
{
return this.weightPrior;
}
/**
* Setter for weightPrior
* @param weightPrior
* Prior distribution of the weights, typically a zero-mean,
* diagonal-variance distribution.
*/
public void setWeightPrior(
MultivariateGaussian weightPrior)
{
this.weightPrior = weightPrior;
}
/**
* Getter for outputVariance
* @return
* Assumed known variance of the outputs (measurements),
* must be greater than zero.
*/
public double getOutputVariance()
{
return this.outputVariance;
}
/**
* Setter for outputVariance
* @param outputVariance
* Assumed known variance of the outputs (measurements),
* must be greater than zero.
*/
public void setOutputVariance(
double outputVariance)
{
if( outputVariance <= 0.0 )
{
throw new IllegalArgumentException(
"outputVariance must be > 0.0" );
}
this.outputVariance = outputVariance;
}
/**
* Creates the predictive distribution of outputs given the weight posterior
* @param posterior
* Posterior distribution of weights.
* @return
* Predictive distribution of outputs given the posterior.
*/
@Override
public BayesianLinearRegression.PredictiveDistribution createPredictiveDistribution(
MultivariateGaussian posterior )
{
return new PredictiveDistribution( posterior );
}
/**
* Creates the predictive distribution for the likelihood of a given point.
*/
@PublicationReference(
author="Christopher M. Bishop",
title="Pattern Recognition and Machine Learning",
type=PublicationType.Book,
year=2006,
pages=156
)
public class PredictiveDistribution
extends AbstractCloneableSerializable
implements Evaluator
{
/**
* Posterior distribution of the weights given the data.
*/
private MultivariateGaussian posterior;
/**
* Creates a new instance of PredictiveDistribution
* @param posterior
* Posterior distribution of the weights given the data.
*/
public PredictiveDistribution(
MultivariateGaussian posterior )
{
this.posterior = posterior;
}
@Override
public UnivariateGaussian.PDF evaluate(
Vectorizable input)
{
// Bishop's equations 3.58-3.59
Vector x = input.convertToVector();
double mean = x.dotProduct( this.posterior.getMean() );
double variance = x.times( this.posterior.getCovariance() ).dotProduct(x) + outputVariance;
return new UnivariateGaussian.PDF( mean, variance );
}
}
/**
* Incremental estimator for BayesianLinearRegression
*/
public static class IncrementalEstimator
extends BayesianLinearRegression
implements IncrementalLearner, IncrementalEstimator.SufficientStatistic>
{
/**
* Creates a new instance of IncrementalEstimator
* @param dimensionality
* Sets up the parameters (except featureMap) for the given dimensionality
* of objects in feature space.
*/
public IncrementalEstimator(
int dimensionality )
{
super( dimensionality );
}
/**
* Creates a new instance of IncrementalEstimator
* @param outputVariance
* Assumed known variance of the outputs (measurements),
* must be greater than zero.
* @param weightPrior
* Prior distribution of the weights, typically a zero-mean,
* diagonal-variance distribution.
*/
public IncrementalEstimator(
double outputVariance,
MultivariateGaussian weightPrior)
{
super( outputVariance, weightPrior);
}
@Override
public IncrementalEstimator.SufficientStatistic createInitialLearnedObject()
{
return new SufficientStatistic(this.getWeightPrior());
}
@Override
public MultivariateGaussian.PDF learn(
Collection extends InputOutputPair extends Vectorizable, Double>> data)
{
IncrementalEstimator.SufficientStatistic target =
this.createInitialLearnedObject();
this.update(target, data);
return target.create();
}
@Override
public void update(
IncrementalEstimator.SufficientStatistic target,
InputOutputPair extends Vectorizable, Double> data)
{
target.update(data);
}
@Override
public void update(
SufficientStatistic target,
Iterable extends InputOutputPair extends Vectorizable, Double>> data)
{
target.update(data);
}
/**
* SufficientStatistic for incremental Bayesian linear regression
*/
public class SufficientStatistic
extends AbstractSufficientStatistic, MultivariateGaussian>
{
/**
* "z" statistic, proportional to the mean
*/
private Vector z;
/**
* Covariance inverse, sometimes called "precision"
*/
private Matrix covarianceInverse;
/**
* Creates a new instance of SufficientStatistic
* @param prior
* Prior on the weights
*/
public SufficientStatistic(
MultivariateGaussian prior )
{
super();
if( prior != null )
{
this.covarianceInverse = prior.getCovarianceInverse().clone();
this.z = this.covarianceInverse.times( prior.getMean() );
this.count = 1;
}
else
{
this.covarianceInverse = null;
this.z = null;
this.count = 0;
}
}
@Override
public void update(
InputOutputPair extends Vectorizable, Double> value)
{
this.count++;
Vector v = value.getInput().convertToVector();
Vector x1 = v;
Vector x2 = v.clone();
final double y = value.getOutput();
final double beta = DatasetUtil.getWeight(value) / outputVariance;
if( beta != 1.0 )
{
x2.scaleEquals(beta);
}
if( this.covarianceInverse == null )
{
this.covarianceInverse = x1.outerProduct(x2);
}
else
{
this.covarianceInverse.plusEquals( x1.outerProduct(x2) );
}
if( y != 1.0 )
{
x2.scaleEquals( y );
}
if( this.z == null )
{
this.z = x2;
}
else
{
this.z.plusEquals( x2 );
}
}
@Override
public MultivariateGaussian.PDF create()
{
MultivariateGaussian.PDF g =
new MultivariateGaussian.PDF(this.getDimensionality());
this.create(g);
return g;
}
@Override
public void create(
MultivariateGaussian distribution)
{
distribution.setMean(this.getMean());
distribution.setCovarianceInverse(this.getCovarianceInverse());
}
/**
* Getter for covarianceInverse
* @return
* Covariance inverse, sometimes called "precision"
*/
public Matrix getCovarianceInverse()
{
return this.covarianceInverse;
}
/**
* Getter for z
* @return
* "z" statistic, proportional to the mean
*/
public Vector getZ()
{
return this.z;
}
/**
* Computes the mean of the Gaussian, but involves a matrix
* inversion and multiplication, so it's expensive.
* @return
* Mean of the Gaussian.
*/
public Vector getMean()
{
return this.covarianceInverse.inverse().times( this.z );
}
/**
* Gets the dimensionality of the underlying Gaussian
* @return
* Dimensionality of the underlying Gaussian
*/
public int getDimensionality()
{
return this.getZ().getDimensionality();
}
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy