All downloads are free. Search and download functionality uses the official Maven repository.

gov.sandia.cognition.learning.algorithm.regression.LogisticRegression Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                LogisticRegression.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 * 
 * Copyright Nov 27, 2007, Sandia Corporation.  Under the terms of Contract
 * DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
 * or on behalf of the U.S. Government. Export of this program may require a
 * license from the United States Government. See CopyrightHistory.txt for
 * complete details.
 * 
 */

package gov.sandia.cognition.learning.algorithm.regression;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationReferences;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.evaluator.CompositeEvaluatorPair;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.scalar.LinearDiscriminantWithBias;
import gov.sandia.cognition.math.ProbabilityUtil;
import gov.sandia.cognition.math.matrix.DiagonalMatrix;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.statistics.distribution.LogisticDistribution;
import gov.sandia.cognition.util.ArgumentChecker;
import gov.sandia.cognition.util.ObjectUtil;

/**
 * Performs Logistic Regression by means of the iterative reweighted least
 * squares (IRLS) algorithm, where the logistic function has an explicit bias
 * term, and a diagonal L2 regularization term.  When the regularization term
 * is zero, this is equivalent to unregularized regression.  The targets for
 * the data should be probabilities, [0,1].
 * <p>
 * Each input has a constant 1 appended to it (for the bias) and becomes one
 * column of the data matrix {@code X}.  Each iteration then solves
 * {@code (X*R*X' + lambda*I) * wnew = X*R*z}, where {@code R} is diagonal
 * with entries {@code weight*yhat*(1-yhat)} and
 * {@code z = X'*w + (y-yhat)/(yhat*(1-yhat))} is the working response (the
 * correction term is taken as zero when {@code yhat*(1-yhat)} is zero).  The
 * regularization {@code lambda} is added to every diagonal entry of the
 * system, so the bias term is penalized as well.  Iteration stops once the
 * 2-norm of the parameter change is no greater than the tolerance, or when
 * the maximum number of iterations is reached.
 * 
 * @author Kevin R. Dixon
 * @since 2.0
 */
@PublicationReferences(
    references={
        @PublicationReference(
            author="Tommi S. Jaakkola",
            title="Machine learning: lecture 5",
            type=PublicationType.WebPage,
            year=2004,
            url="http://www.ai.mit.edu/courses/6.867-f04/lectures/lecture-5-ho.pdf",
            notes="Good formulation of logistic regression on slides 15-20"
        ),
        @PublicationReference(
            author={
                "Paul Komarek",
                "Andrew Moore"
            },
            title="Making Logistic Regression A Core Data Mining Tool With TR-IRLS",
            publication="Proceedings of the 5th International Conference on Data Mining Machine Learning",
            type=PublicationType.Conference,
            year=2005,
            url="http://www.autonlab.org/autonweb/14717.html",
            notes="Good practical overview of logistic regression"
        ),
        @PublicationReference(
            author="Christopher M. Bishop",
            title="Pattern Recognition and Machine Learning",
            type=PublicationType.Book,
            year=2006,
            pages={207,208},
            notes="Section 4.3.3"
        )
    }
)
public class LogisticRegression
    extends AbstractAnytimeSupervisedBatchLearner<Vectorizable, Double, LogisticRegression.Function>
{

    /**
     * Default number of iterations before stopping, {@value}
     */
    public static final int DEFAULT_MAX_ITERATIONS = 100;
    
    /**
     * Default tolerance change in weights before stopping, {@value}
     */
    public static final double DEFAULT_TOLERANCE = 1e-10;
    
    /**
     * Default regularization, {@value}.
     */
    public static final double DEFAULT_REGULARIZATION = 0.0;

    /**
     * The object to optimize, used as a factory on successive runs of the
     * algorithm.
     */
    private LogisticRegression.Function objectToOptimize;

    /**
     * Return value from the algorithm
     */
    private LogisticRegression.Function result;
    
    /**
     * Tolerance change in weights before stopping
     */
    private double tolerance;

    /**
     * L2 ridge regularization term, must be nonnegative, a value of zero is
     * equivalent to unregularized regression.
     */
    private double regularization;

    /**
     * Weighting for each sample (diagonal, one entry per sample).
     */
    private transient DiagonalMatrix W;
    
    /**
     * Derivative of each sample's estimate, yhat*(1-yhat); multiplied by the
     * sample weights partway through each step.
     */
    private transient DiagonalMatrix R;

    /**
     * Inverse of R (zero where the derivative is zero).
     */
    private transient DiagonalMatrix Ri;

    /**
     * Data matrix where each column is an input sample with a trailing 1
     * appended for the bias.
     */
    private transient Matrix X;
    
    /**
     * Transpose of the data matrix
     */
    private transient Matrix Xt;
    
    /**
     * Target value minus the estimated value
     */
    private transient Vector err;

    /**
     * Default constructor, with no regularization.
     */
    public LogisticRegression()
    {
        this( DEFAULT_REGULARIZATION );
    }
    
    /**
     * Creates a new instance of LogisticRegression
     * @param regularization
     * L2 ridge regularization term, must be nonnegative, a value of zero is
     * equivalent to unregularized regression.
     */
    public LogisticRegression(
        double regularization )
    {
        this( regularization, DEFAULT_TOLERANCE, DEFAULT_MAX_ITERATIONS );
    }
    
    /**
     * Creates a new instance of LogisticRegression
     * @param regularization
     * L2 ridge regularization term, must be nonnegative, a value of zero is
     * equivalent to unregularized regression.
     * @param tolerance
     * Tolerance change in weights before stopping, must be nonnegative.
     * @param maxIterations
     * Maximum number of iterations before stopping
     */
    public LogisticRegression(
        double regularization,
        double tolerance,
        int maxIterations )
    {
        super( maxIterations );
        this.setRegularization(regularization);
        this.setTolerance( tolerance );
    }

    /**
     * Clones this learner, also cloning the object to optimize and the
     * current result (null-safe).
     *
     * @return
     * A copy of this learner.
     */
    @Override
    public LogisticRegression clone()
    {
        LogisticRegression clone = (LogisticRegression) super.clone();
        clone.setObjectToOptimize(
            ObjectUtil.cloneSafe( this.getObjectToOptimize() ) );
        clone.setResult(
            ObjectUtil.cloneSafe( this.getResult() ) );
        return clone;
    }

    /**
     * Prepares the IRLS working state.
     * <p>
     * It builds the bias-augmented data matrix and its transpose, the
     * per-sample weight matrix, and the scratch buffers.  It seeds the result
     * with a clone of the object to optimize, first creating a new
     * {@link Function} of the data's input dimensionality if none was set.
     * Every target is checked with
     * {@code ProbabilityUtil.assertIsProbability}.
     *
     * @return
     * True if learning can proceed; false if the data is null or empty, in
     * which case the result is cleared to null.
     */
    @Override
    protected boolean initializeAlgorithm()
    {
        if( this.data == null || this.data.isEmpty() )
        {
            // No sample to infer the input dimensionality from, and nothing
            // to fit: clear any stale result and tell the caller to stop.
            this.setResult( null );
            return false;
        }

        final int M = this.data.iterator().next().getInput().convertToVector().getDimensionality();
        final int N = this.data.size();
        
        if( this.getObjectToOptimize() == null )
        {
            this.setObjectToOptimize( new Function( M ) );
        }
        
        this.setResult( this.getObjectToOptimize().clone() );
        
        this.R = MatrixFactory.getDiagonalDefault().createMatrix( N, N );
        this.Ri = MatrixFactory.getDiagonalDefault().createMatrix( N, N );
        this.X = MatrixFactory.getDefault().createMatrix( M+1, N );
        this.err = VectorFactory.getDefault().createVector( N );
        this.W = MatrixFactory.getDiagonalDefault().createMatrix( N, N );
        
        int n = 0;
        // Appended to every input so the last parameter acts as the bias.
        final Vector one = VectorFactory.getDefault().copyValues(1.0);
        for( InputOutputPair<? extends Vectorizable, Double> sample : this.data )
        {
            ProbabilityUtil.assertIsProbability(sample.getOutput());
            this.X.setColumn( n, sample.getInput().convertToVector().stack(one) );
            this.W.setElement( n, DatasetUtil.getWeight(sample) );
            n++;
        }
        this.Xt = this.X.transpose();
        
        return true;
    }

    /**
     * Performs one IRLS update of the result's parameters.
     *
     * @return
     * True if the 2-norm of the parameter change exceeds the tolerance (keep
     * iterating); false otherwise.
     */
    @Override
    protected boolean step()
    {
        final LogisticRegression.Function f = this.getResult();

        // Residuals and logistic derivatives r = yhat*(1-yhat) per sample.
        int n = 0;
        for( InputOutputPair<? extends Vectorizable, Double> sample : this.data )
        {
            final double y = sample.getOutput();
            final double yhat = f.evaluate( sample.getInput() );
            final double r = yhat*(1.0-yhat);
            this.err.setElement( n, (y - yhat) );
            this.R.setElement( n, r );
            // A saturated prediction (r == 0) contributes no correction term.
            this.Ri.setElement( n, (r!=0.0) ? 1.0/r : 0.0 );
            n++;
        }
        
        final Vector w = f.convertToVector();

        // Working response z = X'w + R^{-1}(y - yhat).
        final Vector z = w.times( this.X );
        z.plusEquals( this.Ri.times( this.err ) );

        // Fold the sample weights into R, then form X R X' (+ lambda I).
        this.R.timesEquals( this.W );
        final Matrix lhs = this.X.times( this.R.times( this.Xt ) );
        if( this.regularization != 0.0 )
        {
            // Ridge term on every parameter, including the bias.
            final int numParameters = lhs.getNumRows();
            for( int i = 0; i < numParameters; i++ )
            {
                lhs.setElement( i, i, lhs.getElement( i, i ) + this.regularization );
            }
        }

        final Vector rhs = this.X.times( this.R.times( z ) );
        
        final Vector wnew = lhs.solve( rhs );
        f.convertFromVector( wnew );
        
        final double delta = wnew.minus( w ).norm2();
        return delta > this.getTolerance();
    }

    /**
     * Releases the per-run working matrices and vectors.
     */
    @Override
    protected void cleanupAlgorithm()
    {
        this.X = null;
        this.Xt = null;
        this.err = null;
        this.R = null;
        this.Ri = null;
        this.W = null;
    }

    /**
     * Getter for objectToOptimize
     * @return
     * The object to optimize, used as a factory on successive runs of the
     * algorithm.
     */
    public LogisticRegression.Function getObjectToOptimize()
    {
        return this.objectToOptimize;
    }

    /**
     * Setter for objectToOptimize
     * @param objectToOptimize
     * The object to optimize, used as a factory on successive runs of the
     * algorithm.
     */
    public void setObjectToOptimize(
        LogisticRegression.Function objectToOptimize )
    {
        this.objectToOptimize = objectToOptimize;
    }

    @Override
    public LogisticRegression.Function getResult()
    {
        return this.result;
    }

    /**
     * Setter for result
     * @param result
     * Return value from the algorithm
     */
    public void setResult(
        LogisticRegression.Function result )
    {
        this.result = result;
    }

    /**
     * Getter for tolerance
     * @return
     * Tolerance change in weights before stopping, must be nonnegative.
     */
    public double getTolerance()
    {
        return this.tolerance;
    }

    /**
     * Setter for tolerance
     * @param tolerance
     * Tolerance change in weights before stopping, must be nonnegative.
     */
    public void setTolerance(
        double tolerance )
    {
        ArgumentChecker.assertIsNonNegative("tolerance", tolerance);
        this.tolerance = tolerance;
    }

    /**
     * Getter for regularization
     * @return 
     * L2 ridge regularization term, must be nonnegative, a value of zero is
     * equivalent to unregularized regression.
     */
    public double getRegularization()
    {
        return this.regularization;
    }

    /**
     * Setter for regularization
     * @param regularization
     * L2 ridge regularization term, must be nonnegative, a value of zero is
     * equivalent to unregularized regression.
     */
    public void setRegularization(
        double regularization)
    {
        ArgumentChecker.assertIsNonNegative("regularization", regularization);
        this.regularization = regularization;
    }
    
    /**
     * The learned model: a {@link LinearDiscriminantWithBias} whose output is
     * passed through the logistic CDF (sigmoid).  Its parameter vector is the
     * parameter vector of the linear discriminant.
     */
    public static class Function
        extends CompositeEvaluatorPair<Vectorizable, Double, Double>
        implements Vectorizable
    {
        /**
         * Creates a new {@link LogisticRegression.Function} with no
         * components set.
         */
        public Function()
        {
            super();
        }
        
        /**
         * Creates a new instance of Function
         * @param dimensionality
         * Dimensionality of the inputs
         */
        public Function(
            int dimensionality )
        {   
            super( new LinearDiscriminantWithBias(
                VectorFactory.getDefault().createVector( dimensionality ), 0.0 ),
                new LogisticDistribution.CDF() );
        }
        
        @Override
        public Function clone()
        {
            return (Function) super.clone();
        }

        /**
         * Returns the parameters of the linear discriminant.
         * @return
         * Parameter vector of the first (linear) component.
         */
        @Override
        public Vector convertToVector()
        {
            return ((Vectorizable) this.getFirst()).convertToVector();
        }

        /**
         * Sets the parameters of the linear discriminant.
         * @param parameters
         * New parameter vector for the first (linear) component.
         */
        @Override
        public void convertFromVector(
            Vector parameters )
        {
            ((Vectorizable) this.getFirst()).convertFromVector( parameters );
        }

        @Override
        public LinearDiscriminantWithBias getFirst()
        {
            return (LinearDiscriminantWithBias) super.getFirst();
        }

        @Override
        public LogisticDistribution.CDF getSecond()
        {
            return (LogisticDistribution.CDF) super.getSecond();
        }
        
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy