cc.mallet.optimize.LimitedMemoryBFGS

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

/** 
   @author Aron Culotta [email protected]
 */

/**
 Limited Memory BFGS, as described in Byrd, Nocedal, and Schnabel,
 "Representations of Quasi-Newton Matrices and Their Use in Limited
 Memory Methods"
 */
package cc.mallet.optimize;

import java.util.logging.*;
import java.util.LinkedList;

import cc.mallet.optimize.BackTrackLineSearch;
import cc.mallet.optimize.LineOptimizer;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;

public class LimitedMemoryBFGS implements Optimizer {
	
	private static Logger logger = MalletLogger.getLogger("edu.umass.cs.mallet.base.ml.maximize.LimitedMemoryBFGS");

	boolean converged = false;
	Optimizable.ByGradientValue optimizable;
	final int maxIterations = 1000;	
	// xxx need a more principled stopping point
	//final double tolerance = .0001;
	private double tolerance = .0001;
	final double gradientTolerance = .001;
	final double eps = 1.0e-5;

	// The number of corrections used in BFGS update
	// ideally 3 <= m <= 7. Larger m means more cpu time, memory.
	final int m = 4;

	// Line search function
	private LineOptimizer.ByGradient lineMaximizer;
	
	public LimitedMemoryBFGS (Optimizable.ByGradientValue function) {
		this.optimizable = function;
		lineMaximizer = new BackTrackLineSearch (function);
	}
	
	public Optimizable getOptimizable () { return this.optimizable; }
	public boolean isConverged () { return converged; }


	/**
	 * Sets the LineOptimizer.ByGradient to use in L-BFGS optimization.
	 * @param lineOpt line optimizer for L-BFGS
	 */
	public void setLineOptimizer(LineOptimizer.ByGradient lineOpt) {
		lineMaximizer = lineOpt;
	}

	// State of search
	// g = gradient
	// s = list of m previous "parameters" values
	// y = list of m previous "g" values
	// rho = intermediate calculation
	double [] g, oldg, direction, parameters, oldParameters;
	LinkedList s = new LinkedList();
	LinkedList y = new LinkedList();
	LinkedList rho = new LinkedList();
	double [] alpha;
	static double step = 1.0;
	int iterations;

	private OptimizerEvaluator.ByGradient eval = null;

	// CPAL - added this
	public void setTolerance(double newtol) {
		this.tolerance = newtol;
	}

	public void setEvaluator (OptimizerEvaluator.ByGradient eval) {
		this.eval = eval;
	}
	
	public int getIteration () {
		return iterations;
	}

	public boolean optimize () {
		return optimize (Integer.MAX_VALUE);
	}

	public boolean optimize (int numIterations) {

		double initialValue = optimizable.getValue();
		logger.fine("Entering L-BFGS.optimize(). Initial Value="+initialValue);		

		if (g == null) { //first time through
			
			logger.fine("First time through L-BFGS");
			iterations = 0;
			s = new LinkedList();
			y = new LinkedList();
			rho = new LinkedList();
			alpha = new double[m];	    

			for (int i = 0; i < m; i++) {
				alpha[i] = 0.0;
			}

			parameters = new double[optimizable.getNumParameters()];
			oldParameters = new double[optimizable.getNumParameters()];
			g = new double[optimizable.getNumParameters()];
			oldg = new double[optimizable.getNumParameters()];
			direction = new double[optimizable.getNumParameters()];

			optimizable.getParameters (parameters);
			System.arraycopy (parameters, 0, oldParameters, 0, parameters.length);

			optimizable.getValueGradient (g);
			System.arraycopy (g, 0, oldg, 0, g.length);
			System.arraycopy (g, 0, direction, 0, g.length);

			// If the initial gradient is zero, there is nothing to do
			if (MatrixOps.twoNorm (direction) == 0) {
				logger.info("L-BFGS initial gradient is zero; saying converged");
				g = null;
				converged = true;
				return true;
			}
			logger.fine ("direction.2norm: " + MatrixOps.twoNorm (direction));

			// Take an initial step along the normalized gradient direction
			MatrixOps.timesEquals (direction, 1.0 / MatrixOps.twoNorm (direction));
			step = lineMaximizer.optimize (direction, step);
			if (step == 0.0) { // could not step in this direction; give up
				g = null; // reset search
				step = 1.0;
				throw new OptimizationException("Line search could not step in the current direction. " +
						"(This is not necessarily cause for alarm. Sometimes this happens close to the maximum," +
						" where the function may be very flat.)");
			}

			optimizable.getParameters (parameters);
			optimizable.getValueGradient (g);
			logger.fine ("after initial jump: \ndirection.2norm: " +
					MatrixOps.twoNorm (direction) + " \ngradient.2norm: " +
					MatrixOps.twoNorm (g));
		}

		for (int iterationCount = 0; iterationCount < numIterations; iterationCount++) {
			double value = optimizable.getValue();
			logger.fine ("L-BFGS iteration=" + iterationCount + ", value=" + value +
					" g.twoNorm: " + MatrixOps.twoNorm (g) +
					" oldg.twoNorm: " + MatrixOps.twoNorm (oldg));

			// Get the differences between the previous two gradient and parameter vectors
			double sy = 0.0;
			double yy = 0.0;
			for (int i = 0; i < oldParameters.length; i++) {
				// -inf - (-inf) = 0; inf - inf = 0
				if (Double.isInfinite(parameters[i]) &&
						Double.isInfinite(oldParameters[i]) &&
						(parameters[i] * oldParameters[i] > 0)) {
					oldParameters[i] = 0.0;
				}
				else {
					oldParameters[i] = parameters[i] - oldParameters[i];
				}
				
				if (Double.isInfinite(g[i]) &&
						Double.isInfinite(oldg[i]) &&
						(g[i] * oldg[i] > 0)) {
					oldg[i] = 0.0;
				}
				else {
					oldg[i] = g[i] - oldg[i];
				}
				
				sy += oldParameters[i] * oldg[i]; 	 // si * yi
				yy += oldg[i] * oldg[i];
				direction[i] = g[i];
			}

			if ( sy > 0 ) {
				throw new InvalidOptimizableException ("sy = "+sy+" > 0" );
			}

			double gamma = sy / yy;	 // scaling factor
			
			if ( gamma > 0 ) {
				throw new InvalidOptimizableException ("gamma = "+gamma+" > 0" );
			}

			push (rho, 1.0/sy);
			// These arrays are now the *differences* between parameters and gradient.
			push (s, oldParameters);
			push (y, oldg);
			
			assert (s.size() == y.size()) : "s.size: " + s.size() + " y.size: " + y.size();

			//
			// This next section is where we calculate the new direction
			//
				
			// First work backwards, from the most recent difference vectors
			for (int i = s.size() - 1; i >= 0; i--) {
				alpha[i] = ((Double)rho.get(i)).doubleValue() * MatrixOps.dotProduct ( (double[])s.get(i), direction );
				MatrixOps.plusEquals (direction, (double[])y.get(i), -1.0 * alpha[i]);
			}
			
			// Scale the direction by the ratio of s'y and y'y
			MatrixOps.timesEquals(direction, gamma);
			
			// Now work forwards, from the oldest to the newest difference vectors
			for (int i = 0; i < y.size(); i++) {
				double beta =
					(((Double)rho.get(i)).doubleValue()) *
					MatrixOps.dotProduct((double[])y.get(i), direction);
				MatrixOps.plusEquals(direction, (double[])s.get(i), alpha[i] - beta);
			}
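			// (Together, the two loops above form the standard L-BFGS two-loop recursion:
			// "direction" now holds the gradient multiplied by the implicit inverse-Hessian
			// approximation built from the stored s, y difference vectors.)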

			// Move the current values to the "last iteration" buffers and negate the search direction
			for (int i=0; i < oldg.length; i++) {
				oldParameters[i] = parameters[i];
				oldg[i] = g[i];
				direction[i] *= -1.0;
			}
			
			logger.fine ("before linesearch: direction.gradient.dotprod: "+
					MatrixOps.dotProduct(direction,g)+"\ndirection.2norm: " +
					MatrixOps.twoNorm (direction) + "\nparameters.2norm: " +
					MatrixOps.twoNorm(parameters));
			
			// Test whether the gradient is ok
			//TestMaximizable.testValueAndGradientInDirection (maxable, direction);

			// Do a line search in the current direction		
			step = lineMaximizer.optimize(direction, step);
			
			if (step == 0.0) { // could not step in this direction. 
				g = null; // reset search
				step = 1.0;
				// xxx Temporary test; passed OK
 			    // TestMaximizable.testValueAndGradientInDirection (maxable, direction);
				throw new OptimizationException("Line search could not step in the current direction. " +
						"(This is not necessarily cause for alarm. Sometimes this happens close to the maximum," +
						" where the function may be very flat.)");
				//	return false;
			}
			optimizable.getParameters (parameters);
			optimizable.getValueGradient(g);
			logger.fine ("after linesearch: direction.2norm: " +
					MatrixOps.twoNorm (direction));					
			double newValue = optimizable.getValue();

			// Test for terminations
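			// #1: relative change in value: 2*|newValue - value| <= tolerance * (|newValue| + |value| + eps)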
			if (2.0 * Math.abs(newValue-value) <= tolerance * (Math.abs(newValue) + Math.abs(value) + eps)) {
				logger.info("Exiting L-BFGS on termination #1:\nvalue difference below tolerance (oldValue: " + value + " newValue: " + newValue);
				converged = true;
				return true;
			}
			double gg = MatrixOps.twoNorm(g);
			if (gg < gradientTolerance) {
				logger.fine("Exiting L-BFGS on termination #2: \ngradient="+gg+" < "+gradientTolerance);
				converged = true;
				return true;
			}	    
			if (gg == 0.0) {
				logger.fine("Exiting L-BFGS on termination #3: \ngradient==0.0");
				converged = true;
				return true;
			}
			logger.fine("Gradient = "+gg);
			iterations++;
			if (iterations > maxIterations) {
				System.err.println("Too many iterations in L-BFGS.java. Continuing with current parameters.");
				converged = true;
				return true;
				//throw new IllegalStateException ("Too many iterations.");
			}

			//end of iteration. call evaluator
			if (eval != null && ! eval.evaluate (optimizable, iterationCount)) {
				logger.fine ("Exiting L-BFGS on termination #4: evaluator returned false.");
				converged = true;
				return false;
			}
		}
		return false;
	}

	/** Resets the previous gradients and values that are used to
	 * approximate the Hessian. NOTE - If the {@link Optimizable} object
	 * is modified externally, this method should be called to avoid
	 * IllegalStateExceptions. */
	public void reset () {
		g = null;
	}

	/**
	 * Pushes a new object onto the queue l
	 * @param l linked list queue of Matrix obj's
	 * @param toadd matrix to push onto queue
	 */
	private void push(LinkedList l, double[] toadd) {
		assert(l.size() <= m);
		if (l.size() == m) {
			// remove oldest matrix and add newest to end of list.
			// to make this more efficient, actually overwrite
			// memory of oldest matrix

			// this overwrites the oldest matrix
			double[] last = (double[]) l.get(0);
			System.arraycopy(toadd, 0, last, 0, toadd.length);
			Object ptr = last;
			// this readjusts the pointers in the list
			for (int i = 0; i < m - 1; i++) {
				Object temp = l.get(i + 1);
				l.set(i, ptr);
				ptr = temp;
			}
			l.set(m - 1, ptr);
		}
		else {
			// the list is not full yet, so just copy the new vector and append it
			double[] newArray = new double[toadd.length];
			System.arraycopy (toadd, 0, newArray, 0, toadd.length);
			l.addLast(newArray);
		}
	}
}
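The optimizer is driven through MALLET's Optimizable.ByGradientValue interface: the caller supplies an object that exposes its parameters, its objective value, and its gradient, and LimitedMemoryBFGS maximizes it. Below is a minimal, hypothetical sketch (not part of the MALLET distribution) that maximizes the concave quadratic f(x) = -(x0 - 2)^2 - (x1 + 1)^2, whose maximum lies at (2, -1); the class name QuadraticExample is made up for illustration.

import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.OptimizationException;

public class QuadraticExample implements Optimizable.ByGradientValue {

	// Current parameter vector (x0, x1).
	private final double[] x = new double[2];

	public int getNumParameters() { return x.length; }
	public void getParameters(double[] buffer) { System.arraycopy(x, 0, buffer, 0, x.length); }
	public double getParameter(int i) { return x[i]; }
	public void setParameters(double[] params) { System.arraycopy(params, 0, x, 0, x.length); }
	public void setParameter(int i, double v) { x[i] = v; }

	// Objective value: a concave quadratic, so maximization converges to (2, -1).
	public double getValue() {
		return -(x[0] - 2) * (x[0] - 2) - (x[1] + 1) * (x[1] + 1);
	}

	// Gradient of the objective with respect to the parameters.
	public void getValueGradient(double[] buffer) {
		buffer[0] = -2.0 * (x[0] - 2);
		buffer[1] = -2.0 * (x[1] + 1);
	}

	public static void main(String[] args) {
		QuadraticExample objective = new QuadraticExample();
		LimitedMemoryBFGS optimizer = new LimitedMemoryBFGS(objective);
		try {
			optimizer.optimize();  // runs until a termination test fires or maxIterations is hit
		} catch (OptimizationException e) {
			// The line search may fail to make progress very close to the maximum;
			// the parameters reached so far are still usable.
		}
		System.out.println("converged=" + optimizer.isConverged()
				+ " x0=" + objective.x[0] + " x1=" + objective.x[1]
				+ " value=" + objective.getValue());
	}
}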



