All Downloads are FREE. Search and download functionalities use the official Maven repository.

gov.sandia.cognition.learning.algorithm.minimization.FunctionMinimizerGradientDescent Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                FunctionMinimizerGradientDescent.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 *
 * Copyright November 8, 2007, Sandia Corporation.  Under the terms of Contract
 * DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
 * or on behalf of the U.S. Government. Export of this program may require a
 * license from the United States Government. See CopyrightHistory.txt for
 * complete details.
 *
 *
 */

package gov.sandia.cognition.learning.algorithm.minimization;

import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.math.DifferentiableEvaluator;
import gov.sandia.cognition.math.matrix.Vector;

/**
 * This is an implementation of the classic Gradient Descent algorithm, also
 * known as Steepest Descent, Backpropagation (for neural nets), or Hill 
 * Climbing.  This algorithm takes a small step in the direction indicated by
 * the gradient.  This implementation is "efficient" in that it only uses
 * gradient calls during minimization (not function calls).  We also use a
 * momentum term to mimic "heavy ball" optimization to speed up learning and
 * avoid local minima.
 *

* A few words of advice: Don't use this algorithm. I'm not one of those * hard-core "gradient descent sucks" people, but there are uniformly better * algorithms out there, such as BFGS and conjugate gradient. It's really here * for illustrative purposes and essentially contains absolutely no advantage * over BFGS or conjugate gradient minimization, except its simplicity. If * you're looking for something quick and dirty, then be my guest. However, * please consider using BFGS or CG instead. (CG is like GD, but where the * momentum and step size are optimally selected for each step.) In my * experience, non-derivative algorithms, like Powell's method, are more * efficient and have better convergence than GD. *

* Oh, yeah. The other minimization algorithms don't require you to guess * parameters either. * * @author Kevin R. Dixon * @since 2.0 * */ public class FunctionMinimizerGradientDescent extends AbstractAnytimeFunctionMinimizer> { /** * The learning rate (or step size), must be (0,1], typically ~0.1 */ private double learningRate; /** * The momentum rate, must be [0,1), typically ~0.8 */ private double momentum; /** * Default learning rate */ public static final double DEFAULT_LEARNING_RATE = 0.1; /** * Default momentum */ public static final double DEFAULT_MOMENTUM = 0.8; /** * Default tolerance */ public static final double DEFAULT_TOLERANCE = 1e-7; /** * Default max iterations */ public static final int DEFAULT_MAX_ITERATIONS = 1000000; /** * Creates a new instance of FunctionMinimizerGradientDescent */ public FunctionMinimizerGradientDescent() { this( DEFAULT_LEARNING_RATE, DEFAULT_MOMENTUM ); } /** * Creates a new instance of FunctionMinimizerGradientDescent * @param learningRate * The learning rate (or step size), must be (0,1], typically ~0.1 * @param momentum * The momentum rate, must be [0,1), typically ~0.8 */ public FunctionMinimizerGradientDescent( double learningRate, double momentum ) { this( learningRate, momentum, null, DEFAULT_TOLERANCE, DEFAULT_MAX_ITERATIONS ); } /** * Creates a new instance of FunctionMinimizerGradientDescent * @param learningRate * The learning rate (or step size), must be (0,1], typically ~0.1 * @param momentum * The momentum rate, must be [0,1), typically ~0.8 * @param initialGuess * Initial guess of the minimum * @param tolerance * Tolerance of the algorithm, must be >=0.0, typically 1e-5 * @param maxIterations * Maximum number of iterations before stopping, must be >0, typically ~1000 */ public FunctionMinimizerGradientDescent( double learningRate, double momentum, Vector initialGuess, double tolerance, int maxIterations ) { super( initialGuess, tolerance, maxIterations ); this.setLearningRate( learningRate ); 
this.setMomentum( momentum ); } /** * Previous input change, used for adding momentum */ private Vector previousDelta; /** * {@inheritDoc} * @return {@inheritDoc} */ protected boolean initializeAlgorithm() { this.previousDelta = null; this.result = new DefaultInputOutputPair( this.initialGuess.clone(), null ); return true; } /** * {@inheritDoc} * @return {@inheritDoc} */ protected boolean step() { Vector xhat = this.result.getInput(); // Compute the gradient and scale it by the learningRate Vector gradient = this.data.differentiate( xhat ); Vector delta = gradient.scale( -this.learningRate ); // See if we should add a momentum term if (this.previousDelta != null) { if (this.momentum != 0.0) { delta.plusEquals( this.previousDelta.scale( this.momentum ) ); } } this.previousDelta = delta; xhat.plusEquals( delta ); return !MinimizationStoppingCriterion.convergence( xhat, null, gradient, delta, this.getTolerance() ); } /** * {@inheritDoc} */ protected void cleanupAlgorithm() { double yhat = this.data.evaluate( this.result.getInput() ); this.result = DefaultInputOutputPair.create( this.result.getInput(), yhat); } /** * Getter for learningRate * @return * The learning rate (or step size), must be (0,1], typically ~0.1 */ public double getLearningRate() { return this.learningRate; } /** * Setter for learningRate * @param learningRate * The learning rate (or step size), must be (0,1], typically ~0.1 */ public void setLearningRate( double learningRate ) { if ((learningRate <= 0.0) || (learningRate > 1.0)) { throw new IllegalArgumentException( "Learning rate " + learningRate + " must be (0,1]." 
); } this.learningRate = learningRate; } /** * Setter for momentum * @return * The momentum rate, must be [0,1), typically ~0.8 */ public double getMomentum() { return this.momentum; } /** * Getter for momentum * @param momentum * The momentum rate, must be [0,1), typically ~0.8 */ public void setMomentum( double momentum ) { if ((momentum < 0.0) || (momentum >= 1.0)) { throw new IllegalArgumentException( "momentum must be 0.0 <= momentum < 1.0" ); } this.momentum = momentum; } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy