// gov.sandia.cognition.learning.algorithm.regression.LogisticRegression
// Artifact: cognitive-foundry (Maven / Gradle / Ivy).
// A single jar with all the Cognitive Foundry components.
/*
* File: LogisticRegression.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright Nov 27, 2007, Sandia Corporation. Under the terms of Contract
* DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
* or on behalf of the U.S. Government. Export of this program may require a
* license from the United States Government. See CopyrightHistory.txt for
* complete details.
*
*/
package gov.sandia.cognition.learning.algorithm.regression;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationReferences;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.evaluator.CompositeEvaluatorPair;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.scalar.LinearDiscriminantWithBias;
import gov.sandia.cognition.math.ProbabilityUtil;
import gov.sandia.cognition.math.matrix.DiagonalMatrix;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.statistics.distribution.LogisticDistribution;
import gov.sandia.cognition.util.ArgumentChecker;
import gov.sandia.cognition.util.ObjectUtil;
/**
* Performs Logistic Regression by means of the iterative reweighted least
* squares (IRLS) algorithm, where the logistic function has an explicit bias
* term, and a diagonal L2 regularization term. When the regularization term
* is zero, this is equivalent to unregularized regression. The targets for
* the data should be probabilities, [0,1].
*
* @author Kevin R. Dixon
* @since 2.0
*/
@PublicationReferences(
references={
@PublicationReference(
author="Tommi S. Jaakkola",
title="Machine learning: lecture 5",
type=PublicationType.WebPage,
year=2004,
url="http://www.ai.mit.edu/courses/6.867-f04/lectures/lecture-5-ho.pdf",
notes="Good formulation of logistic regression on slides 15-20"
),
@PublicationReference(
author={
"Paul Komarek",
"Andrew Moore"
},
title="Making Logistic Regression A Core Data Mining Tool With TR-IRLS",
publication="Proceedings of the 5th International Conference on Data Mining Machine Learning",
type=PublicationType.Conference,
year=2005,
url="http://www.autonlab.org/autonweb/14717.html",
notes="Good practical overview of logistic regression"
),
@PublicationReference(
author="Christopher M. Bishop",
title="Pattern Recognition and Machine Learning",
type=PublicationType.Book,
year=2006,
pages={207,208},
notes="Section 4.3.3"
)
}
)
public class LogisticRegression
extends AbstractAnytimeSupervisedBatchLearner
{
/**
* Default number of iterations before stopping, {@value}
*/
public static final int DEFAULT_MAX_ITERATIONS = 100;
/**
* Default tolerance change in weights before stopping, {@value}
*/
public static final double DEFAULT_TOLERANCE = 1e-10;
/**
* Default regularization, {@value}.
*/
public static final double DEFAULT_REGULARIZATION = 0.0;
/**
* The object to optimize, used as a factory on successive runs of the
* algorithm.
*/
private LogisticRegression.Function objectToOptimize;
/**
* Return value from the algorithm
*/
private LogisticRegression.Function result;
/**
* Tolerance change in weights before stopping
*/
private double tolerance;
/**
* L2 ridge regularization term, must be nonnegative, a value of zero is
* equivalent to unregularized regression.
*/
private double regularization;
/**
* Default constructor, with no regularization.
*/
public LogisticRegression()
{
this( DEFAULT_REGULARIZATION );
}
/**
* Creates a new instance of LogisticRegression
* @param regularization
* L2 ridge regularization term, must be nonnegative, a value of zero is
* equivalent to unregularized regression.
*/
public LogisticRegression(
double regularization )
{
this( regularization, DEFAULT_TOLERANCE, DEFAULT_MAX_ITERATIONS );
}
/**
* Creates a new instance of LogisticRegression
* @param regularization
* L2 ridge regularization term, must be nonnegative, a value of zero is
* equivalent to unregularized regression.
* @param tolerance
* Tolerance change in weights before stopping
* @param maxIterations
* Maximum number of iterations before stopping
*/
public LogisticRegression(
double regularization,
double tolerance,
int maxIterations )
{
super( maxIterations );
this.setRegularization(regularization);
this.setTolerance( tolerance );
}
/**
* Weighting for each sample
*/
private transient DiagonalMatrix W;
/**
* Derivative of each sample's estimate
*/
private transient DiagonalMatrix R;
/**
* Inverse of R
*/
private transient DiagonalMatrix Ri;
/**
* Data matrix where each column is an input sample
*/
private transient Matrix X;
/**
* Transpose of the data matrix
*/
private transient Matrix Xt;
/**
* Target value minus the estimated value
*/
private transient Vector err;
@Override
public LogisticRegression clone()
{
LogisticRegression clone = (LogisticRegression) super.clone();
clone.setObjectToOptimize(
ObjectUtil.cloneSafe( this.getObjectToOptimize() ) );
clone.setResult(
ObjectUtil.cloneSafe( this.getResult() ) );
return clone;
}
@Override
protected boolean initializeAlgorithm()
{
int M = this.data.iterator().next().getInput().convertToVector().getDimensionality();
int N = this.data.size();
if( this.getObjectToOptimize() == null )
{
this.setObjectToOptimize( new Function( M ) );
}
this.setResult( this.getObjectToOptimize().clone() );
this.R = MatrixFactory.getDiagonalDefault().createMatrix( N, N );
this.Ri = MatrixFactory.getDiagonalDefault().createMatrix( N, N );
this.X = MatrixFactory.getDefault().createMatrix( M+1, N );
this.err = VectorFactory.getDefault().createVector( N );
this.W = MatrixFactory.getDiagonalDefault().createMatrix( N, N );
int n = 0;
Vector one = VectorFactory.getDefault().copyValues(1.0);
for( InputOutputPair extends Vectorizable,Double> sample : this.data )
{
ProbabilityUtil.assertIsProbability(sample.getOutput());
this.X.setColumn( n, sample.getInput().convertToVector().stack(one) );
this.W.setElement( n, DatasetUtil.getWeight(sample) );
n++;
}
this.Xt = this.X.transpose();
return true;
}
@Override
protected boolean step()
{
int n = 0;
LogisticRegression.Function f = this.getResult();
for( InputOutputPair extends Vectorizable,Double> sample : this.data )
{
final double y = sample.getOutput();
final double yhat = f.evaluate( sample.getInput() );
final double r = yhat*(1.0-yhat);
this.err.setElement( n, (y - yhat) );
this.R.setElement( n, r );
this.Ri.setElement( n, (r!=0.0) ? 1.0/r : 0.0 );
n++;
}
Vector w = f.convertToVector();
Vector z = w.times( this.X );
z.plusEquals( this.Ri.times( this.err ) );
this.R.timesEquals(this.W);
Matrix lhs = this.X.times( this.R.times( this.Xt ) );
if( this.regularization != 0.0 )
{
final int N = this.X.getNumRows();
for( int i = 0; i < N; i++ )
{
final double v = lhs.getElement(i, i);
lhs.setElement(i, i, v + this.regularization);
}
}
Vector rhs = this.X.times( this.R.times( z ) );
Vector wnew = lhs.solve( rhs );
f.convertFromVector( wnew );
double delta = wnew.minus( w ).norm2();
return delta > this.getTolerance();
}
@Override
protected void cleanupAlgorithm()
{
this.X = null;
this.Xt = null;
this.err = null;
this.R = null;
this.Ri = null;
this.W = null;
}
/**
* Getter for objectToOptimize
* @return
* The object to optimize, used as a factory on successive runs of the
* algorithm.
*/
public LogisticRegression.Function getObjectToOptimize()
{
return this.objectToOptimize;
}
/**
* Setter for objectToOptimize
* @param objectToOptimize
* The object to optimize, used as a factory on successive runs of the
* algorithm.
*/
public void setObjectToOptimize(
LogisticRegression.Function objectToOptimize )
{
this.objectToOptimize = objectToOptimize;
}
@Override
public LogisticRegression.Function getResult()
{
return this.result;
}
/**
* Setter for result
* @param result
* Return value from the algorithm
*/
public void setResult(
LogisticRegression.Function result )
{
this.result = result;
}
/**
* Getter for tolerance
* @return
* Tolerance change in weights before stopping, must be nonnegative.
*/
public double getTolerance()
{
return this.tolerance;
}
/**
* Setter for tolerance
* @param tolerance
* Tolerance change in weights before stopping, must be nonnegative.
*/
public void setTolerance(
double tolerance )
{
ArgumentChecker.assertIsNonNegative("tolerance", tolerance);
this.tolerance = tolerance;
}
/**
* Getter for regularization
* @return
* L2 ridge regularization term, must be nonnegative, a value of zero is
* equivalent to unregularized regression.
*/
public double getRegularization()
{
return this.regularization;
}
/**
* Setter for regularization
* @param regularization
* L2 ridge regularization term, must be nonnegative, a value of zero is
* equivalent to unregularized regression.
*/
public void setRegularization(
double regularization)
{
ArgumentChecker.assertIsNonNegative("regularization", regularization);
this.regularization = regularization;
}
/**
* Class that is a linear discriminant, followed by a sigmoid function.
*/
public static class Function
extends CompositeEvaluatorPair
implements Vectorizable
{
/**
* Creates a new {@link LogisticRegression.Function}.
*/
public Function()
{
super();
}
/**
* Creates a new instance of Function
* @param dimensionality
* Dimensionality of the inputs
*/
public Function(
int dimensionality )
{
super( new LinearDiscriminantWithBias(
VectorFactory.getDefault().createVector( dimensionality ), 0.0 ),
new LogisticDistribution.CDF() );
}
@Override
public Function clone()
{
return (Function) super.clone();
}
@Override
public Vector convertToVector()
{
return ((Vectorizable) this.getFirst()).convertToVector();
}
@Override
public void convertFromVector(
Vector parameters )
{
((Vectorizable) this.getFirst()).convertFromVector( parameters );
}
@Override
public LinearDiscriminantWithBias getFirst()
{
return (LinearDiscriminantWithBias) super.getFirst();
}
@Override
public LogisticDistribution.CDF getSecond()
{
return (LogisticDistribution.CDF) super.getSecond();
}
}
}