All downloads are free. The search and download functionality uses the official Maven repository.

gov.sandia.cognition.statistics.bayesian.BayesianLinearRegression Maven / Gradle / Ivy

/*
 * File:                BayesianLinearRegression.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 * 
 * Copyright Mar 11, 2010, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
 * license for use of this work by or on behalf of the U.S. Government.
 * Export of this program may require a license from the United States
 * Government. See CopyrightHistory.txt for complete details.
 * 
 */

package gov.sandia.cognition.statistics.bayesian;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationReferences;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.IncrementalLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.statistics.AbstractSufficientStatistic;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussian;
import gov.sandia.cognition.statistics.distribution.UnivariateGaussian;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.Collection;

/**
 * Computes a Bayesian linear estimator for a given feature function
 * and a set of observed data.
 * @author Kevin R. Dixon
 * @since 3.0
 */
@PublicationReferences(
    references={
        @PublicationReference(
            author="Christopher M. Bishop",
            title="Pattern Recognition and Machine Learning",
            type=PublicationType.Book,
            year=2006,
            pages={152,159}
        )
        ,
        @PublicationReference(
            author="Hanna M. Wallach",
            title="Introduction to Gaussian Process Regression",
            type=PublicationType.Misc,
            year=2005,
            url="http://www.cs.umass.edu/~wallach/talks/gp_intro.pdf"
        )
        ,
        @PublicationReference(
            author="Wikipedia",
            title="Bayesian linear regression",
            type=PublicationType.WebPage,
            year=2010,
            url="http://en.wikipedia.org/wiki/Bayesian_linear_regression"
        )
    }
)
public class BayesianLinearRegression
    extends AbstractCloneableSerializable
    implements BayesianRegression
{

    /**
     * Default output variance, {@value}.
     */
    public static final double DEFAULT_OUTPUT_VARIANCE = 1.0;

    /**
     * Default weight variance, {@value}.
     */
    public static final double DEFAULT_WEIGHT_VARIANCE = 1.0;

    /**
     * Assumed known variance of the outputs (measurements),
     * must be greater than zero.
     */
    protected double outputVariance;

    /**
     * Prior distribution of the weights, typically a zero-mean,
     * diagonal-variance distribution.
     */
    protected MultivariateGaussian weightPrior;

    /** 
     * Creates a new instance of BayesianLinearRegression 
     * @param dimensionality
     * Sets up the parameters (except featureMap) for the given dimensionality
     * of objects in feature space.
     */
    public BayesianLinearRegression(
        int dimensionality )
    {
        this( DEFAULT_OUTPUT_VARIANCE,
            new MultivariateGaussian( VectorFactory.getDefault().createVector(dimensionality),
                MatrixFactory.getDefault().createIdentity(dimensionality,dimensionality).scale( DEFAULT_WEIGHT_VARIANCE ) ) );
    }

    /**
     * Creates a new instance of BayesianLinearRegression
     * @param outputVariance
     * Assumed known variance of the outputs (measurements),
     * must be greater than zero.
     * @param weightPrior
     * Prior distribution of the weights, typically a zero-mean,
     * diagonal-variance distribution.
     */
    public BayesianLinearRegression(
        double outputVariance,
        MultivariateGaussian weightPrior)
    {
        this.setOutputVariance(outputVariance);
        this.setWeightPrior(weightPrior);
    }

    @Override
    public BayesianLinearRegression clone()
    {
        @SuppressWarnings("unchecked")
        BayesianLinearRegression clone = (BayesianLinearRegression) super.clone();
        clone.setWeightPrior( ObjectUtil.cloneSafe( this.getWeightPrior() ) );
        return clone;
    }

    @Override
    public MultivariateGaussian.PDF learn(
        Collection> data)
    {
        MultivariateGaussian prior = this.getWeightPrior();

        RingAccumulator Cin = new RingAccumulator();
        Matrix Ci = prior.getCovarianceInverse().clone();
        Cin.accumulate( Ci );
        RingAccumulator zn = new RingAccumulator();
        Vector z = Ci.times( prior.getMean() );
        zn.accumulate( z );

        for (InputOutputPair pair : data)
        {
            Vector x1 = pair.getInput().convertToVector();
            Vector x2 = x1.clone();
            final double beta = DatasetUtil.getWeight(pair) / this.outputVariance;
            if( beta != 1.0 )
            {
                x2.scaleEquals(beta);
            }
            Cin.accumulate( x1.outerProduct(x2) );

            final double y = pair.getOutput();
            if( y != 1.0 )
            {
                x2.scaleEquals(y);
            }
            zn.accumulate( x2 );
        }

        Ci = Cin.getSum();
        Matrix C = Ci.inverse();
        z = zn.getSum();
        Vector mean = C.times( z );

        return new MultivariateGaussian.PDF( mean, C );
    }


    /**
     * Creates the distribution from which the outputs are generated, given
     * the weights and the input to consider.
     * @param input
     * Input to condition on
     * @param weights
     * Weights that determine the mean
     * @return
     * Conditional distribution from which outputs are generated.
     */
    @Override
    public UnivariateGaussian createConditionalDistribution(
        Vectorizable input,
        Vector weights )
    {
        double mean = input.convertToVector().dotProduct(weights);
        return new UnivariateGaussian( mean, this.getOutputVariance() );
    }

    /**
     * Getter for weightPrior
     * @return
     * Prior distribution of the weights, typically a zero-mean,
     * diagonal-variance distribution.
     */
    public MultivariateGaussian getWeightPrior()
    {
        return this.weightPrior;
    }

    /**
     * Setter for weightPrior
     * @param weightPrior
     * Prior distribution of the weights, typically a zero-mean,
     * diagonal-variance distribution.
     */
    public void setWeightPrior(
        MultivariateGaussian weightPrior)
    {
        this.weightPrior = weightPrior;
    }

    /**
     * Getter for outputVariance
     * @return
     * Assumed known variance of the outputs (measurements),
     * must be greater than zero.
     */
    public double getOutputVariance()
    {
        return this.outputVariance;
    }

    /**
     * Setter for outputVariance
     * @param outputVariance
     * Assumed known variance of the outputs (measurements),
     * must be greater than zero.
     */
    public void setOutputVariance(
        double outputVariance)
    {
        if( outputVariance <= 0.0 )
        {
            throw new IllegalArgumentException(
                "outputVariance must be > 0.0" );
        }
        this.outputVariance = outputVariance;
    }

    /**
     * Creates the predictive distribution of outputs given the weight posterior
     * @param posterior
     * Posterior distribution of weights.
     * @return
     * Predictive distribution of outputs given the posterior.
     */
    @Override
    public BayesianLinearRegression.PredictiveDistribution createPredictiveDistribution(
        MultivariateGaussian posterior )
    {
        return new PredictiveDistribution( posterior );
    }

    /**
     * Creates the predictive distribution for the likelihood of a given point.
     */
    @PublicationReference(
        author="Christopher M. Bishop",
        title="Pattern Recognition and Machine Learning",
        type=PublicationType.Book,
        year=2006,
        pages=156
    )
    public class PredictiveDistribution
        extends AbstractCloneableSerializable
        implements Evaluator
    {

        /**
         * Posterior distribution of the weights given the data.
         */
        private MultivariateGaussian posterior;

        /**
         * Creates a new instance of PredictiveDistribution
         * @param posterior
         * Posterior distribution of the weights given the data.
         */
        public PredictiveDistribution(
            MultivariateGaussian posterior )
        {
            this.posterior = posterior;
        }

        @Override
        public UnivariateGaussian.PDF evaluate(
            Vectorizable input)
        {
            // Bishop's equations 3.58-3.59
            Vector x = input.convertToVector();
            double mean = x.dotProduct( this.posterior.getMean() );
            double variance = x.times( this.posterior.getCovariance() ).dotProduct(x) + outputVariance;
            return new UnivariateGaussian.PDF( mean, variance );
        }

    }

    /**
     * Incremental estimator for BayesianLinearRegression
     */
    public static class IncrementalEstimator
        extends BayesianLinearRegression
        implements IncrementalLearner, IncrementalEstimator.SufficientStatistic>
    {

        /**
         * Creates a new instance of IncrementalEstimator
         * @param dimensionality
         * Sets up the parameters (except featureMap) for the given dimensionality
         * of objects in feature space.
         */
        public IncrementalEstimator(
            int dimensionality )
        {
            super( dimensionality );
        }

        /**
         * Creates a new instance of IncrementalEstimator
         * @param outputVariance
         * Assumed known variance of the outputs (measurements),
         * must be greater than zero.
         * @param weightPrior
         * Prior distribution of the weights, typically a zero-mean,
         * diagonal-variance distribution.
         */
        public IncrementalEstimator(
            double outputVariance,
            MultivariateGaussian weightPrior)
        {
            super( outputVariance, weightPrior);
        }

        @Override
        public IncrementalEstimator.SufficientStatistic createInitialLearnedObject()
        {
            return new SufficientStatistic(this.getWeightPrior());
        }

        @Override
        public MultivariateGaussian.PDF learn(
            Collection> data)
        {
            IncrementalEstimator.SufficientStatistic target =
                this.createInitialLearnedObject();
            this.update(target, data);
            return target.create();
        }

        @Override
        public void update(
            IncrementalEstimator.SufficientStatistic target,
            InputOutputPair data)
        {
            target.update(data);
        }

        @Override
        public void update(
            SufficientStatistic target,
            Iterable> data)
        {
            target.update(data);
        }

        /**
         * SufficientStatistic for incremental Bayesian linear regression
         */
        public class SufficientStatistic
            extends AbstractSufficientStatistic, MultivariateGaussian>
        {

            /**
             * "z" statistic, proportional to the mean
             */
            private Vector z;

            /**
             * Covariance inverse, sometimes called "precision"
             */
            private Matrix covarianceInverse;

            /**
             * Creates a new instance of SufficientStatistic
             * @param prior
             * Prior on the weights
             */
            public SufficientStatistic(
                MultivariateGaussian prior )
            {
                super();

                if( prior != null )
                {
                    this.covarianceInverse = prior.getCovarianceInverse().clone();
                    this.z = this.covarianceInverse.times( prior.getMean() );
                    this.count = 1;
                }
                else
                {
                    this.covarianceInverse = null;
                    this.z = null;
                    this.count = 0;
                }
            }

            @Override
            public void update(
                InputOutputPair value)
            {
                this.count++;
                Vector v = value.getInput().convertToVector();
                Vector x1 = v;
                Vector x2 = v.clone();
                final double y = value.getOutput();
                final double beta = DatasetUtil.getWeight(value) / outputVariance;
                if( beta != 1.0 )
                {
                    x2.scaleEquals(beta);
                }

                if( this.covarianceInverse == null )
                {
                    this.covarianceInverse = x1.outerProduct(x2);
                }
                else
                {
                    this.covarianceInverse.plusEquals( x1.outerProduct(x2) );
                }

                if( y != 1.0 )
                {
                    x2.scaleEquals( y );
                }

                if( this.z == null )
                {
                    this.z = x2;
                }
                else
                {
                    this.z.plusEquals( x2 );
                }
            }

            @Override
            public MultivariateGaussian.PDF create()
            {
                MultivariateGaussian.PDF g =
                    new MultivariateGaussian.PDF(this.getDimensionality());
                this.create(g);
                return g;
            }

            @Override
            public void create(
                MultivariateGaussian distribution)
            {
                distribution.setMean(this.getMean());
                distribution.setCovarianceInverse(this.getCovarianceInverse());
            }

            /**
             * Getter for covarianceInverse
             * @return
             * Covariance inverse, sometimes called "precision"
             */
            public Matrix getCovarianceInverse()
            {
                return this.covarianceInverse;
            }

            /**
             * Getter for z
             * @return
             * "z" statistic, proportional to the mean
             */
            public Vector getZ()
            {
                return this.z;
            }
            
            /**
             * Computes the mean of the Gaussian, but involves a matrix
             * inversion and multiplication, so it's expensive.
             * @return
             * Mean of the Gaussian.
             */
            public Vector getMean()
            {
                return this.covarianceInverse.inverse().times( this.z );
            }

            /**
             * Gets the dimensionality of the underlying Gaussian
             * @return
             * Dimensionality of the underlying Gaussian
             */
            public int getDimensionality()
            {
                return this.getZ().getDimensionality();
            }

        }

    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy