All Downloads are FREE. Search and download functionalities are using the official Maven repository.

gov.sandia.cognition.learning.algorithm.regression.LinearRegression Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                LinearRegression.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 *
 * Copyright September 5, 2007, Sandia Corporation.  Under the terms of Contract
 * DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
 * or on behalf of the U.S. Government. Export of this program may require a
 * license from the United States Government. See CopyrightHistory.txt for
 * complete details.
 *
 */

package gov.sandia.cognition.learning.algorithm.regression;

import gov.sandia.cognition.annotation.CodeReview;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationReferences;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.math.UnivariateStatisticsUtil;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.function.scalar.LinearDiscriminantWithBias;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.statistics.method.AbstractConfidenceStatistic;
import gov.sandia.cognition.statistics.distribution.ChiSquareDistribution;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ArgumentChecker;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;

/**
 * Computes the least-squares regression for a LinearCombinationFunction
 * given a dataset.  A LinearCombinationFunction is a weighted linear
 * combination of (potentially) nonlinear basis functions.  This looks like
 * y(x) = b + w'x, where "b" is a scalar bias and "w" is a weight vector.
 * The internal class LinearRegression.Statistic returns the goodness-of-fit
 * statistics for a set of target-estimate pairs, include a p-value for the
 * null hypothesis significance.
 *
 * @author  Kevin R. Dixon
 * @since   2.0
 */
@CodeReview(
    reviewer="Kevin R. Dixon",
    date="2008-09-02",
    changesNeeded=false,
    comments={
        "Made minor changes to javadoc",
        "Looks fine."
    }
)
@PublicationReferences(
    references={
        @PublicationReference(
            author="Wikipedia",
            title="Linear regression",
            type=PublicationType.WebPage,
            year=2008,
            url="http://en.wikipedia.org/wiki/Linear_regression"
        )
        ,
        @PublicationReference(
            author="Wikipedia",
            title="Tikhonov regularization",
            type=PublicationType.WebPage,
            year=2011,
            url="http://en.wikipedia.org/wiki/Tikhonov_regularization",
            notes="Despite what Wikipedia says, this is always called Ridge Regression"
        )
    }
)
public class LinearRegression
    extends AbstractCloneableSerializable
    implements SupervisedBatchLearner
{

    /**
     * Tolerance for the pseudo inverse in the learn method, {@value}.
     */
    public static final double DEFAULT_PSEUDO_INVERSE_TOLERANCE = 1e-10;

    /**
     * Default regularization, {@value}.
     */
    public static final double DEFAULT_REGULARIZATION = 0.0;

    /**
     * Flag to use a pseudoinverse.  True to use the expensive, but more
     * accurate, pseudoinverse routine.  False uses a very fast, but
     * numerically less stable LU solver.  Default value is "true".
     */
    private boolean usePseudoInverse;

    /**
     * L2 ridge regularization term, must be nonnegative, a value of zero is
     * equivalent to unregularized regression.
     */
    private double regularization;

    /**
     * Creates a new instance of LinearRegression
     */
    public LinearRegression()
    {
        this( DEFAULT_REGULARIZATION, true );
    }

    /**
     * Creates a new instance of LinearRegression
     * @param regularization
     * L2 ridge regularization term, must be nonnegative, a value of zero is
     * equivalent to unregularized regression.
     * @param usePseudoInverse
     * Flag to use a pseudoinverse.  True to use the expensive, but more
     * accurate, pseudoinverse routine.  False uses a very fast, but
     * numerically less stable LU solver.  Default value is "true".
     */
    public LinearRegression(
        double regularization,
        boolean usePseudoInverse )
    {
        this.setRegularization(regularization);
        this.setUsePseudoInverse(usePseudoInverse);
    }

    @Override
    public LinearRegression clone()
    {
        return (LinearRegression) super.clone();
    }

    /**
     * Computes the linear regression for the given Collection of
     * InputOutputPairs.  The inputs of the pairs is the independent variable,
     * and the pair output is the dependent variable (variable to predict).
     * The pairs can have an associated weight to bias the regression equation.
     * @param data 
     * Collection of InputOutputPairs for the variables.  Can be 
     * WeightedInputOutputPairs.
     * @return 
     * LinearCombinationFunction that minimizes the RMS error of the outputs.
     */
    @Override
    public LinearDiscriminantWithBias learn(
        Collection> data )
    {

        // We need to cheat to figure out how many coefficients we need...
        // So we'll push the first sample through... wasteful, but general
        int numCoefficients = CollectionUtil.getFirst(data).getInput().convertToVector().getDimensionality();
        int numSamples = data.size();

        Matrix X = MatrixFactory.getDefault().createMatrix( numCoefficients+1, numSamples );
        Matrix Xt = MatrixFactory.getDefault().createMatrix( numSamples, numCoefficients+1 );
        Vector y = VectorFactory.getDefault().createVector( numSamples );

        Vector one = VectorFactory.getDefault().copyValues(1.0);
        int n = 0;
        for (InputOutputPair pair : data)
        {
            double output = pair.getOutput();
            Vector input = pair.getInput().convertToVector().stack(one);

            // We don't want Xt to have the weight factor too
            final double weight = DatasetUtil.getWeight(pair);
            if( weight != 1.0 )
            {
                // We can use scaleEquals() here because of the stack() method
                input.scaleEquals(weight);
                output *= weight;
            }
            Xt.setRow( n, input );
            X.setColumn( n, input );
            y.setElement( n, output );
            n++;
        }

        // Solve for the coefficients
        Vector coefficients;
        if( this.getUsePseudoInverse() )
        {
            Matrix pseudoInverse = X.pseudoInverse(DEFAULT_PSEUDO_INVERSE_TOLERANCE);
            coefficients = y.times( pseudoInverse );
        }
        else
        {
            Matrix lhs = X.times( Xt );
            if( this.regularization > 0.0 )
            {
                for( int i = 0; i < numSamples; i++ )
                {
                    double v = lhs.getElement(i, i);
                    lhs.setElement(i, i, v + this.regularization);
                }
            }
            Vector rhs = y.times( Xt );
            coefficients = lhs.solve( rhs );
        }

        Vector w = coefficients.subVector(0, numCoefficients-1);
        double bias = coefficients.getElement(numCoefficients);
        return new LinearDiscriminantWithBias( w, bias );
    }

    /**
     * Getter for usePseudoInverse
     * @return
     * Flag to use a pseudoinverse.  True to use the expensive, but more
     * accurate, pseudoinverse routine.  False uses a very fast, but
     * numerically less stable LU solver.  Default value is "true".
     */
    public boolean getUsePseudoInverse()
    {
        return this.usePseudoInverse;
    }

    /**
     * Setter for usePseudoInverse
     * @param usePseudoInverse
     * Flag to use a pseudoinverse.  True to use the expensive, but more
     * accurate, pseudoinverse routine.  False uses a very fast, but
     * numerically less stable LU solver.  Default value is "true".
     */
    public void setUsePseudoInverse(
        boolean usePseudoInverse )
    {
        this.usePseudoInverse = usePseudoInverse;
    }

    /**
     * Getter for regularization
     * @return
     * L2 ridge regularization term, must be nonnegative, a value of zero is
     * equivalent to unregularized regression.
     */
    public double getRegularization()
    {
        return this.regularization;
    }

    /**
     * Setter for regularization
     * @param regularization
     * L2 ridge regularization term, must be nonnegative, a value of zero is
     * equivalent to unregularized regression.
     */
    public void setRegularization(
        double regularization)
    {
        ArgumentChecker.assertIsNonNegative("regularization", regularization);
        this.regularization = regularization;
    }
    
    /**
     * Computes regression statistics using a chi-square measure of the
     * statistical significance of the learned approximator
     */
    public static class Statistic
        extends AbstractConfidenceStatistic
    {

        /**
         * Gets the value of the chi-square variable,
         * Total weighted sum-squared error between the targets and estimates
         */
        private double chiSquare;

        /**
         * Root mean-squared error of the targets and estimates
         */
        private double rootMeanSquaredError;

        /**
         * Average L1-norm error (absolute value difference) between the 
         * targets and estimates
         */
        private double meanL1Error;

        /**
         * Pearson Correlation between the targets and estimates, [-1,1]
         */
        private double targetEstimateCorrelation;

        /**
         * Fraction of variance unaccounted for in the predictions, [0,1]
         */
        private double unpredictedErrorFraction;

        /**
         * Number of samples used to create the Regression
         */
        private int numSamples;

        /**
         * Number of parameters in the learned approximator
         */
        private int numParameters;

        /**
         * Degrees of freedom in the Regression = numSamples-numParameters
         */
        private double degreesOfFreedom;

        /**
         * Creates a new instance of Statistic
         * @param targets
         * Collection of ground-truth targets for the learned approximator
         * @param estimates
         * Collection of estimates from the learned approximator
         * @param numParameters
         * Number of parameters in the learned approximator
         */
        public Statistic(
            Collection targets,
            Collection estimates,
            int numParameters )
        {
            super( 0.0 );
            Collection weights =
                Collections.nCopies( targets.size(), new Double( 1.0 ) );
            this.computeStatistics( targets, estimates, weights, numParameters );

        }

        /**
         * Creates a new instance of Statistic
         * @param targets
         * Collection of ground-truth targets for the learned approximator
         * @param estimates
         * Collection of estimates from the learned approximator
         * @param weights
         * Collection of weights to apply to the corresponding target-estimate
         * pair
         * @param numParameters
         * Number of parameters in the learned approximator
         */
        public Statistic(
            Collection targets,
            Collection estimates,
            Collection weights,
            int numParameters )
        {
            super( 0.0 );
            this.computeStatistics( targets, estimates, weights, numParameters );
        }

        /**
         * Copy Constructor
         * @param other
         * Statistic to copy
         */
        private Statistic(
            Statistic other )
        {
            super( other.getNullHypothesisProbability() );
            this.setDegreesOfFreedom( other.getDegreesOfFreedom() );
            this.setMeanL1Error( other.getMeanL1Error() );
            this.setNumParameters( other.getNumParameters() );
            this.setNumSamples( other.getNumSamples() );
            this.setRootMeanSquaredError( other.getRootMeanSquaredError() );
            this.setTargetEstimateCorrelation( other.getTargetEstimateCorrelation() );
            this.setUnpredictedErrorFraction( other.getUnpredictedErrorFraction() );
        }

        @Override
        public Statistic clone()
        {
            return (Statistic) super.clone();
        }

        /**
         * Creates a new instance of Statistic
         * @param targets
         * Collection of ground-truth targets for the learned approximator
         * @param estimates
         * Collection of estimates from the learned approximator
         * @param weights
         * Collection of weights to apply to the corresponding target-estimate
         * pair
         * @param numParameters
         * Number of parameters in the learned approximator
         */
        private void computeStatistics(
            Collection targets,
            Collection estimates,
            Collection weights,
            int numParameters )
        {
            
            if ((targets.size() != estimates.size()) &&
                (targets.size() != weights.size()))
            {
                throw new IllegalArgumentException(
                    "Targets, Estimates, and Weights must be the same size!" );
            }

            // Compute the errors between the targets and the estimates
            int num = targets.size();
            ArrayList errors = new ArrayList( num );
            double averageL1Error = 0.0;
            double weightSum = 0.0;

            Iterator it = targets.iterator();
            Iterator ie = estimates.iterator();
            Iterator iw = weights.iterator();
            for (int n = 0; n < num; n++)
            {
                double estimate = ie.next();
                double target = it.next();
                double weight = iw.next();

                double error = weight * (target - estimate);
                errors.add( error );

                averageL1Error += Math.abs( error );
                weightSum += weight;
            }
            if (weightSum > 0)
            {
                averageL1Error /= weightSum;
            }
            else
            {
                averageL1Error = 0.0;
            }

            // Make sure that the DOFs stays above 0.0
            // (so let's just cap it to a minimum of 1.0)
            double dofs = num - numParameters;
            if( dofs < 1.0 )
            {
                dofs = 1.0;
            }
            double chi2 = UnivariateStatisticsUtil.computeSumSquaredDifference( errors, 0.0 );

            double pvalue = 1.0 - ChiSquareDistribution.CDF.evaluate(
                chi2, dofs );

            double rmsError =
                UnivariateStatisticsUtil.computeRootMeanSquaredError( errors, 0.0 );

            double correlation =
                UnivariateStatisticsUtil.computeCorrelation( targets, estimates );

            double unpredictedFraction = 1.0 - (correlation * correlation);

            this.setNullHypothesisProbability( pvalue );
            this.setChiSquare( chi2 );
            this.setDegreesOfFreedom( dofs );
            this.setMeanL1Error( averageL1Error );
            this.setNumSamples( num );
            this.setRootMeanSquaredError( rmsError );
            this.setTargetEstimateCorrelation( correlation );
            this.setUnpredictedErrorFraction( unpredictedFraction );
            this.setNumParameters( numParameters );

        }
        
        /**
         * Getter for rootMeanSquaredError
         * @return 
         * Root mean-squared error of the targets and estimates
         */
        public double getRootMeanSquaredError()
        {
            return this.rootMeanSquaredError;
        }

        /**
         * Setter fpr rootMeanSquaredError
         * @param rootMeanSquaredError 
         * Root mean-squared error of the targets and estimates
         */
        protected void setRootMeanSquaredError(
            double rootMeanSquaredError )
        {
            this.rootMeanSquaredError = rootMeanSquaredError;
        }

        /**
         * Getter for targetEstimateCorrelation
         * @return 
         * Pearson Correlation between the targets and estimates, [-1,1]
         */
        public double getTargetEstimateCorrelation()
        {
            return this.targetEstimateCorrelation;
        }

        /**
         * Setter for targetEstimateCorrelation
         * @param targetEstimateCorrelation 
         * Pearson Correlation between the targets and estimates, [-1,1]
         */
        protected void setTargetEstimateCorrelation(
            double targetEstimateCorrelation )
        {
            this.targetEstimateCorrelation = targetEstimateCorrelation;
        }

        /**
         * Getter for unpredictedErrorFraction
         * @return 
         * Fraction of variance unaccounted for in the predictions, [0,1]
         */
        public double getUnpredictedErrorFraction()
        {
            return this.unpredictedErrorFraction;
        }

        /**
         * Setter for unpredictedErrorFraction
         * @param unpredictedErrorFraction 
         * Fraction of variance unaccounted for in the predictions, [0,1]
         */
        protected void setUnpredictedErrorFraction(
            double unpredictedErrorFraction )
        {
            this.unpredictedErrorFraction = unpredictedErrorFraction;
        }

        /**
         * Getter for numSamples
         * @return 
         * Number of samples used to create the Regression
         */
        public int getNumSamples()
        {
            return this.numSamples;
        }

        /**
         * Setter for numSamples
         * @param numSamples 
         * Number of samples used to create the Regression
         */
        protected void setNumSamples(
            int numSamples )
        {
            this.numSamples = numSamples;
        }

        /**
         * Getter for degreesOfFreedom
         * @return 
         * Degrees of freedom in the Regression = numSamples-numParameters
         */
        public double getDegreesOfFreedom()
        {
            return this.degreesOfFreedom;
        }

        /**
         * Setter for degreesOfFreedom
         * @param degreesOfFreedom 
         * Degrees of freedom in the Regression = numSamples-numParameters
         */
        protected void setDegreesOfFreedom(
            double degreesOfFreedom )
        {
            this.degreesOfFreedom = degreesOfFreedom;
        }

        /**
         * Getter for meanL1Error
         * @return 
         * Average L1-norm error (absolute value difference) between the 
         * targets and estimates
         */
        public double getMeanL1Error()
        {
            return this.meanL1Error;
        }

        /**
         * Setter for meanL1Error
         * @param meanL1Error 
         * Average L1-norm error (absolute value difference) between the 
         * targets and estimates
         */
        protected void setMeanL1Error(
            double meanL1Error )
        {
            this.meanL1Error = meanL1Error;
        }

        /**
         * Getter for numParameters
         * @return
         * Number of parameters in the learned approximator
         */
        public int getNumParameters()
        {
            return this.numParameters;
        }

        /**
         * Setter for numParameters
         * @param numParameters
         * Number of parameters in the learned approximator
         */
        public void setNumParameters(
            int numParameters )
        {
            this.numParameters = numParameters;
        }

        /**
         * Getter for chiSquare
         * @return
         * Gets the value of the chi-square variable,
         * Total weighted sum-squared error between the targets and estimates
         */
        public double getChiSquare()
        {
            return this.chiSquare;
        }

        /**
         * Setter for chiSquare
         * @param chiSquare
         * Gets the value of the chi-square variable,
         * Total weighted sum-squared error between the targets and estimates
         */
        public void setChiSquare(
            double chiSquare )
        {
            this.chiSquare = chiSquare;
        }

        @Override
        public double getTestStatistic()
        {
            return this.getChiSquare();
        }

    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy