All Downloads are FREE. Search and download functionalities are using the official Maven repository.

gov.sandia.cognition.learning.algorithm.regression.LinearBasisRegression Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                LinearBasisRegression.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 *
 * Copyright September 5, 2007, Sandia Corporation.  Under the terms of Contract
 * DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
 * or on behalf of the U.S. Government. Export of this program may require a
 * license from the United States Government. See CopyrightHistory.txt for
 * complete details.
 *
 */

package gov.sandia.cognition.learning.algorithm.regression;

import gov.sandia.cognition.annotation.CodeReview;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.DefaultWeightedInputOutputPair;
import gov.sandia.cognition.learning.data.WeightedInputOutputPair;
import gov.sandia.cognition.learning.function.scalar.LinearDiscriminant;
import gov.sandia.cognition.learning.function.scalar.VectorFunctionLinearDiscriminant;
import gov.sandia.cognition.learning.function.vector.ScalarBasisSet;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;

/**
 * Computes the least-squares regression for a LinearCombinationFunction
 * given a dataset.  A LinearCombinationFunction is a weighted linear
 * combination of (potentially) nonlinear basis functions.  This looks like
 * y(x) = a0*f0(x) + a1*f1(x) + ... + an*fn(x) and so forth.
 * The internal class LinearRegression.Statistic returns the goodness-of-fit
 * statistics for a set of target-estimate pairs, include a p-value for the
 * null hypothesis significance.
 *
 * @param    Input class for the basis functions, for example, Double,
 *          Vector, String.
 * @author  Kevin R. Dixon
 * @since   2.0
 */
@CodeReview(
    reviewer="Kevin R. Dixon",
    date="2008-09-02",
    changesNeeded=false,
    comments={
        "Made minor changes to javadoc",
        "Looks fine."
    }
)
@PublicationReference(
    author="Wikipedia",
    title="Linear regression",
    type=PublicationType.WebPage,
    year=2008,
    url="http://en.wikipedia.org/wiki/Linear_regression"
)
public class LinearBasisRegression
    extends AbstractCloneableSerializable
    implements SupervisedBatchLearner>
{

    /**
     * Tolerance for the pseudo inverse in the learn method, {@value}.
     */
    public static final double DEFAULT_PSEUDO_INVERSE_TOLERANCE = 1e-10;

    /**
     * Function that maps the InputType to a Vector
     */
    private Evaluator inputToVectorMap;
    
    /**
     * Flag to use a pseudoinverse.  True to use the expensive, but more
     * accurate, pseudoinverse routine.  False uses a very fast, but
     * numerically less stable LU solver.  Default value is "true".
     */
    private boolean usePseudoInverse;

    /**
     * Creates a new instance of LinearRegression
     * @param basisFunctions 
     * Basis functions to create the ScalarBasisSet from
     */
    public LinearBasisRegression(
        Collection> basisFunctions )
    {
        this( new ScalarBasisSet( basisFunctions ) );
    }

    /**
     * Creates a new instance of LinearRegression
     * @param inputToVectorMap 
     * Function that maps the InputType to a Vector
     */
    public LinearBasisRegression(
        ScalarBasisSet inputToVectorMap )
    {
        this( (Evaluator) inputToVectorMap );
    }

    /**
     * Creates a new instance of LinearRegression
     * @param inputToVectorMap
     * Function that maps the InputType to a Vector
     */
    public LinearBasisRegression(
        Evaluator inputToVectorMap )
    {
        this.setInputToVectorMap( inputToVectorMap );
        this.setUsePseudoInverse( true );
    }

    @Override
    public LinearBasisRegression clone()
    {
        @SuppressWarnings("unchecked")
        LinearBasisRegression clone =
            (LinearBasisRegression) super.clone();
        clone.setInputToVectorMap(
            ObjectUtil.cloneSmart( this.getInputToVectorMap() ) );
        return clone;
    }

    /**
     * Computes the linear regression for the given Collection of
     * InputOutputPairs.  The inputs of the pairs is the independent variable,
     * and the pair output is the dependent variable (variable to predict).
     * The pairs can have an associated weight to bias the regression equation.
     * @param data 
     * Collection of InputOutputPairs for the variables.  Can be 
     * WeightedInputOutputPairs.
     * @return 
     * LinearCombinationFunction that minimizes the RMS error of the outputs.
     */
    @Override
    public VectorFunctionLinearDiscriminant learn(
        Collection> data )
    {
        // Create the vector-based dataset first
        ArrayList> vectorData =
            new ArrayList>( data.size() );
        for (InputOutputPair pair : data)
        {
            double weight = DatasetUtil.getWeight(pair);
            Vector xrow = this.inputToVectorMap.evaluate( pair.getInput() );
            Double output = pair.getOutput();
            vectorData.add( DefaultWeightedInputOutputPair.create( xrow, output, weight ) );
        }
        
        LinearRegression linear = new LinearRegression();
        linear.setUsePseudoInverse(this.getUsePseudoInverse());
        LinearDiscriminant weights = linear.learn(vectorData);
        return new VectorFunctionLinearDiscriminant(
            this.inputToVectorMap, weights );
    }

    /**
     * Getter for inputToVectorMap
     * @return
     * Function that maps the InputType to a Vector
     */
    public Evaluator getInputToVectorMap()
    {
        return this.inputToVectorMap;
    }

    /**
     * Setter for inputToVectorMap
     * @param inputToVectorMap
     * Function that maps the InputType to a Vector
     */
    public void setInputToVectorMap(
        Evaluator inputToVectorMap )
    {
        this.inputToVectorMap = inputToVectorMap;
    }

    /**
     * Getter for usePseudoInverse
     * @return
     * Flag to use a pseudoinverse.  True to use the expensive, but more
     * accurate, pseudoinverse routine.  False uses a very fast, but
     * numerically less stable LU solver.  Default value is "true".
     */
    public boolean getUsePseudoInverse()
    {
        return this.usePseudoInverse;
    }

    /**
     * Setter for usePseudoInverse
     * @param usePseudoInverse
     * Flag to use a pseudoinverse.  True to use the expensive, but more
     * accurate, pseudoinverse routine.  False uses a very fast, but
     * numerically less stable LU solver.  Default value is "true".
     */
    public void setUsePseudoInverse(
        boolean usePseudoInverse )
    {
        this.usePseudoInverse = usePseudoInverse;
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy