All Downloads are FREE. Search and download functionalities are using the official Maven repository.

gov.sandia.cognition.learning.algorithm.regression.MultivariateLinearRegression Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                MultivariateLinearRegression.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 * 
 * Copyright Jun 22, 2011, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
 * license for use of this work by or on behalf of the U.S. Government.
 * Export of this program may require a license from the United States
 * Government. See CopyrightHistory.txt for complete details.
 * 
 */

package gov.sandia.cognition.learning.algorithm.regression;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationReferences;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.vector.MultivariateDiscriminantWithBias;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ArgumentChecker;
import java.util.Collection;

/**
 * Performs multivariate regression with an explicit bias term, with optional
 * L2 regularization.
 * @author Kevin R. Dixon
 * @since 3.2.1
 */
@PublicationReferences(
    references={
        @PublicationReference(
            author="Wikipedia",
            title="Linear regression",
            type=PublicationType.WebPage,
            year=2008,
            url="http://en.wikipedia.org/wiki/Linear_regression"
        )
        ,
        @PublicationReference(
            author="Wikipedia",
            title="Tikhonov regularization",
            type=PublicationType.WebPage,
            year=2011,
            url="http://en.wikipedia.org/wiki/Tikhonov_regularization",
            notes="Despite what Wikipedia says, this is always called Ridge Regression"
        )
    }
)
public class MultivariateLinearRegression 
    extends AbstractCloneableSerializable
    implements SupervisedBatchLearner
{

    /**
     * Default regularization, {@value}.
     */
    public static final double DEFAULT_REGULARIZATION = 0.0;    
    
    /**
     * Tolerance for the pseudo inverse in the learn method, {@value}.
     */
    public static final double DEFAULT_PSEUDO_INVERSE_TOLERANCE = 1e-10;

    /**
     * Flag to use a pseudoinverse.  True to use the expensive, but more
     * accurate, pseudoinverse routine.  False uses a very fast, but
     * numerically less stable LU solver.  Default value is "true".
     */
    private boolean usePseudoInverse;

    /**
     * L2 ridge regularization term, must be nonnegative, a value of zero is
     * equivalent to unregularized regression.
     */
    private double regularization;    
    
    /** 
     * Creates a new instance of MultivariateLinearRegression 
     */
    public MultivariateLinearRegression()
    {
        this.setUsePseudoInverse(true);
    }

    @Override
    public MultivariateLinearRegression clone()
    {
        return (MultivariateLinearRegression) super.clone();
    }

    @Override
    public MultivariateDiscriminantWithBias learn(
        Collection> data)
    {
        // We need to cheat to figure out how many coefficients we need...
        // So we'll push the first sample through... wasteful, but general
        InputOutputPair first =
            CollectionUtil.getFirst(data);
        int M = first.getInput().getDimensionality();
        int N = first.getOutput().getDimensionality();
        int numSamples = data.size();

        Matrix X =  MatrixFactory.getDefault().createMatrix( M+1, numSamples );
        Matrix Xt = MatrixFactory.getDefault().createMatrix( numSamples, M+1 );
        Matrix Y =  MatrixFactory.getDefault().createMatrix( N, numSamples );
        Matrix Yt =  MatrixFactory.getDefault().createMatrix( numSamples, N );

        // The matrix equation looks like:
        // y = C*[f0(x) f1(x) ... fn(x) ], fi() is the ith basis function
        int i = 0;
        Vector one = VectorFactory.getDefault().copyValues(1.0);
        for (InputOutputPair pair : data)
        {
            Vector output = pair.getOutput();
            Vector input = pair.getInput().convertToVector().stack(one);
            final double weight = DatasetUtil.getWeight(pair);
            if( weight != 1.0 )
            {
                // We can use scaleEquals() here because of the stack() method
                input.scaleEquals(weight);
                output = output.scale(weight);
            }
            Xt.setRow( i, input );
            X.setColumn( i, input );
            Y.setColumn( i, output );
            Yt.setRow( i, output );
            i++;
        }

        // Solve for the coefficients
        Matrix coefficients;
        if( this.getUsePseudoInverse() )
        {
            Matrix pseudoInverse = Xt.pseudoInverse(DEFAULT_PSEUDO_INVERSE_TOLERANCE);
            coefficients = pseudoInverse.times( Yt ).transpose();
        }
        else
        {
            Matrix lhs = X.times( Xt );
            if( this.regularization > 0.0 )
            {
                for( i = 0; i < M+1; i++ )
                {
                    double v = lhs.getElement(i, i);
                    lhs.setElement(i, i, v + this.regularization);
                }
            }
            Matrix rhs = Y.times( Xt );
            coefficients = lhs.solve( rhs.transpose() ).transpose();
        }
        
        Matrix discriminant = coefficients.getSubMatrix(0, N-1, 0, M-1);
        Vector bias = coefficients.getColumn(M);
        
        return new MultivariateDiscriminantWithBias( discriminant, bias );
    }

    /**
     * Getter for usePseudoInverse
     * @return
     * Flag to use a pseudoinverse.  True to use the expensive, but more
     * accurate, pseudoinverse routine.  False uses a very fast, but
     * numerically less stable LU solver.  Default value is "true".
     */
    public boolean getUsePseudoInverse()
    {
        return this.usePseudoInverse;
    }

    /**
     * Setter for usePseudoInverse
     * @param usePseudoInverse
     * Flag to use a pseudoinverse.  True to use the expensive, but more
     * accurate, pseudoinverse routine.  False uses a very fast, but
     * numerically less stable LU solver.  Default value is "true".
     */
    public void setUsePseudoInverse(
        boolean usePseudoInverse )
    {
        this.usePseudoInverse = usePseudoInverse;
    }

    /**
     * Getter for regularization
     * @return
     * L2 ridge regularization term, must be nonnegative, a value of zero is
     * equivalent to unregularized regression.
     */
    public double getRegularization()
    {
        return this.regularization;
    }

    /**
     * Setter for regularization
     * @param regularization
     * L2 ridge regularization term, must be nonnegative, a value of zero is
     * equivalent to unregularized regression.
     */
    public void setRegularization(
        double regularization)
    {
        ArgumentChecker.assertIsNonNegative("regularization", regularization);
        this.regularization = regularization;
    }    
    
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy