All downloads are free. The search and download functionality uses the official Maven repository.

edu.berkeley.nlp.math.LBFGSMinimizer Maven / Gradle / Ivy

Go to download

The Berkeley parser analyzes the grammatical structure of natural language using probabilistic context-free grammars (PCFGs).

The newest version!
package edu.berkeley.nlp.math;

import java.io.Serializable;
import java.util.LinkedList;

import edu.berkeley.nlp.util.CallbackFunction;
import edu.berkeley.nlp.util.Logger;

/**
 * Gradient-based minimizer that keeps a bounded history of recent input and
 * derivative differences and uses it to precondition the derivative into a
 * quasi-Newton (L-BFGS two-loop recursion) search direction, followed by a
 * line search along the negated direction.
 * <p>
 * The difference history lives in instance fields, so it is carried across
 * calls to {@code minimize} until {@link #dumpHistory()} clears it.
 * 
 * @author Dan Klein
 */
public class LBFGSMinimizer implements GradientMinimizer, Serializable
{
	private static final long serialVersionUID = 36473897808840226L;

	/** Small constant added to the denominator of the relative-change test in {@link #converged}. */
	double EPS = 1e-10;

	/** Upper bound on the number of iterations performed by {@code minimize}. */
	int maxIterations = 20;

	/** Maximum number of difference pairs retained in each history list. */
	int maxHistorySize = 5;

	/** Input (guess) differences; index 0 is the most recent pair. */
	LinkedList<double[]> inputDifferenceVectorList = new LinkedList<double[]>();

	/** Derivative differences; index 0 is the most recent pair. */
	LinkedList<double[]> derivativeDifferenceVectorList = new LinkedList<double[]>();

	/** Optional per-iteration callback; transient, so it is not serialized. */
	transient CallbackFunction iterCallbackFunction = null;

	/** Convergence is only tested once the iteration index reaches this value. */
	int minIterations = -1;

	/** Step size multiplier handed to the line searcher on iteration 0. */
	double initialStepSizeMultiplier = 0.01;

	/** Step size multiplier handed to the line searcher on every later iteration. */
	double stepSizeMultiplier = 0.5;

	/** If true, the first detected convergence clears the history and redoes the iteration. */
	boolean dumpHistoryBeforeConverge = false;

	/** Set once the convergence-time history dump has happened; never reset by this class. */
	boolean alreadyDumped = false;

	/** If positive, the history is cleared whenever the iteration index is a multiple of it. */
	int historyDropIters = -1;

	/** Controls whether progress and status messages are logged. */
	boolean verbose = true;

	public LBFGSMinimizer()
	{
	}

	public LBFGSMinimizer(int maxIterations)
	{
		this.maxIterations = maxIterations;
	}

	public void setDumpHistoryBeforeConverge(boolean dumpHistoryBeforeConverge)
	{
		this.dumpHistoryBeforeConverge = dumpHistoryBeforeConverge;
	}

	public void setVerbose(boolean verbose)
	{
		this.verbose = verbose;
	}

	/**
	 * Requests that the history be cleared every {@code numIters} iterations
	 * (only takes effect when {@code numIters} is positive).
	 */
	public void dumpHistoryPeriodically(int numIters)
	{
		this.historyDropIters = numIters;
	}

	/**
	 * Sets the iteration index before which convergence is not tested.
	 */
	public void setMinIterations(int minIterations)
	{
		this.minIterations = minIterations;
	}

	/**
	 * Misspelled legacy name kept for backward compatibility; prefer
	 * {@link #setMinIterations(int)}.
	 */
	public void setMinIteratons(int minIterations)
	{
		setMinIterations(minIterations);
	}

	public void setMaxIterations(int maxIterations)
	{
		this.maxIterations = maxIterations;
	}

	public void setInitialStepSizeMultiplier(double initialStepSizeMultiplier)
	{
		this.initialStepSizeMultiplier = initialStepSizeMultiplier;
	}

	public void setStepSizeMultiplier(double stepSizeMultiplier)
	{
		this.stepSizeMultiplier = stepSizeMultiplier;
	}

	public void setMaxHistorySize(int maxHistorySize)
	{
		this.maxHistorySize = maxHistorySize;
	}

	/**
	 * User callback function to test or examine weights at the end of each
	 * iteration
	 * 
	 * @param callbackFunction
	 *            Will get called with the following args (double[]
	 *            currentGuess, int iterDone, double value, double[] derivative)
	 *            You don't have to read any or all of these.
	 */
	public void setIterationCallbackFunction(CallbackFunction callbackFunction)
	{
		this.iterCallbackFunction = callbackFunction;
	}

	/**
	 * Discards all stored input and derivative differences.
	 */
	public void dumpHistory()
	{
		inputDifferenceVectorList.clear();
		derivativeDifferenceVectorList.clear();
	}

	/**
	 * Applies the current history-based inverse Hessian approximation to
	 * {@code derivative}. The result is NOT negated; {@code minimize} scales
	 * its own direction by -1 before searching.
	 * 
	 * @param dimension
	 *            length of the diagonal used as the initial approximation
	 * @param derivative
	 *            derivative vector to precondition
	 * @return the preconditioned derivative
	 */
	public double[] getSearchDirection(int dimension, double[] derivative)
	{
		double[] initialInverseHessianDiagonal = getInitialInverseHessianDiagonal(dimension);
		return implicitMultiply(initialInverseHessianDiagonal, derivative);
	}

	/**
	 * Builds the constant diagonal used as the initial inverse Hessian
	 * approximation.
	 * 
	 * @param dimension
	 *            length of the returned array
	 * @return an array of {@code dimension} entries, all equal to the current
	 *         scale (1.0 while the history is empty)
	 */
	protected double[] getInitialInverseHessianDiagonal(int dimension)
	{
		return DoubleArrays.constantArray(initialInverseHessianScale(), dimension);
	}

	/**
	 * Same as the {@code int} overload, sized by {@code function.dimension()}.
	 * Deliberately does not route through the protected overload so that
	 * subclass overrides of that method keep affecting only
	 * {@link #getSearchDirection}.
	 */
	private double[] getInitialInverseHessianDiagonal(DifferentiableFunction function)
	{
		return DoubleArrays.constantArray(initialInverseHessianScale(), function.dimension());
	}

	/**
	 * Scale of the initial inverse Hessian diagonal: 1.0 with an empty history,
	 * otherwise (y . s) / (y . y) for the most recent derivative difference y
	 * and input difference s. A zero y yields NaN here; the same pair then
	 * triggers the curvature exception in {@link #implicitMultiply}.
	 */
	private double initialInverseHessianScale()
	{
		if (derivativeDifferenceVectorList.isEmpty())
		{
			return 1.0;
		}
		double[] lastDerivativeDifference = getLastDerivativeDifference();
		double[] lastInputDifference = getLastInputDifference();
		double num = DoubleArrays.innerProduct(lastDerivativeDifference, lastInputDifference);
		double den = DoubleArrays.innerProduct(lastDerivativeDifference, lastDerivativeDifference);
		return num / den;
	}

	/**
	 * Minimizes {@code function} starting from {@code initial} without
	 * per-iteration progress output.
	 */
	public double[] minimize(DifferentiableFunction function, double[] initial, double tolerance)
	{
		return minimize(function, initial, tolerance, false);
	}

	/**
	 * Minimizes {@code function} starting from a copy of {@code initial}.
	 * 
	 * @param function
	 *            function supplying values, derivatives and its dimension
	 * @param initial
	 *            starting point; cloned, not modified by this method itself
	 * @param tolerance
	 *            relative value change below which the search is considered
	 *            converged (see {@link #converged})
	 * @param printProgress
	 *            whether to report the value after each iteration (still
	 *            subject to the verbose flag)
	 * @return the point reached on convergence, or the last guess once
	 *         {@code maxIterations} iterations have run
	 * @throws RuntimeException
	 *             if a stored difference pair has zero curvature
	 */
	public double[] minimize(DifferentiableFunction function, double[] initial, double tolerance, boolean printProgress)
	{
		BacktrackingLineSearcher lineSearcher = new BacktrackingLineSearcher();
		double[] guess = DoubleArrays.clone(initial);
		int iteration;
		for (iteration = 0; iteration < maxIterations; iteration++)
		{
			if (historyDropIters > 0 && iteration % historyDropIters == 0)
			{
				dumpHistory();
				if (verbose) Logger.logs("[LBFGSMinimizer.minimize] Dumped History at iter %d", iteration);
			}
			double[] derivative = function.derivativeAt(guess);
			double value = function.valueAt(guess);
			double[] initialInverseHessianDiagonal = getInitialInverseHessianDiagonal(function);
			double[] direction = implicitMultiply(initialInverseHessianDiagonal, derivative);
			// Search along the negated preconditioned derivative.
			DoubleArrays.scale(direction, -1.0);
			lineSearcher.stepSizeMultiplier = (iteration == 0) ? initialStepSizeMultiplier : stepSizeMultiplier;
			double[] nextGuess = doLineSearch(function, lineSearcher, guess, direction);
			double nextValue = function.valueAt(nextGuess);
			double[] nextDerivative = function.derivativeAt(nextGuess);
			if (printProgress) printProgress(iteration, nextValue);

			if (iteration >= minIterations && converged(value, nextValue, tolerance))
			{
				if (verbose) Logger.logs("[LBFGSMinimizer.minimize] Converged.");
				if (dumpHistoryBeforeConverge && !alreadyDumped)
				{
					// Clear the history once and redo this iteration from the same
					// guess; the decrement cancels the loop's increment.
					dumpHistory();
					if (verbose) Logger.logs("[LBFGSMinimizer.minimize] Dumping History. Doing Iteration Over");
					alreadyDumped = true;
					iteration--;
					continue;
				}
				return nextGuess;
			}
			updateHistories(guess, nextGuess, derivative, nextDerivative);
			guess = nextGuess;
			value = nextValue;
			derivative = nextDerivative;
			if (iterCallbackFunction != null)
			{
				iterCallbackFunction.callback(guess, iteration, value, derivative);
			}
		}
		if (verbose) Logger.logs("[LBFGSMinimizer.minimize] Stopped after " + iteration + " iterations.");
		return guess;
	}

	/**
	 * This is an entry point for subclasses
	 * 
	 * @param function
	 *            function being minimized
	 * @param lineSearcher
	 *            searcher already configured with the step size multiplier
	 * @param guess
	 *            current point
	 * @param direction
	 *            search direction
	 * @return the point chosen by the line searcher
	 */
	protected double[] doLineSearch(DifferentiableFunction function, BacktrackingLineSearcher lineSearcher, double[] guess, double[] direction)
	{
		return lineSearcher.minimize(function, guess, direction);
	}

	private void printProgress(int iteration, double nextValue)
	{
		if (verbose) Logger.logs("[LBFGSMinimizer.minimize] Iteration %d ended with value %.6f", iteration, nextValue);
	}

	/**
	 * Relative-change convergence test.
	 * 
	 * @return true if the two values are equal, or if their absolute difference
	 *         divided by the mean of their magnitudes (plus EPS) is below
	 *         {@code tolerance}
	 */
	protected boolean converged(double value, double nextValue, double tolerance)
	{
		if (value == nextValue) return true;
		double valueChange = Math.abs(nextValue - value);
		// Average the magnitudes rather than taking the magnitude of the sum, so
		// values of opposite sign cannot cancel the denominator towards zero.
		double valueAverage = (Math.abs(nextValue) + Math.abs(value) + EPS) / 2.0;
		return valueChange / valueAverage < tolerance;
	}

	/**
	 * Records the newest difference pair. The two differences are built with
	 * {@code DoubleArrays.addMultiples(a, 1.0, b, -1.0)}, presumably a - b.
	 */
	protected void updateHistories(double[] guess, double[] nextGuess, double[] derivative, double[] nextDerivative)
	{
		double[] guessChange = DoubleArrays.addMultiples(nextGuess, 1.0, guess, -1.0);
		double[] derivativeChange = DoubleArrays.addMultiples(nextDerivative, 1.0, derivative, -1.0);
		pushOntoList(guessChange, inputDifferenceVectorList);
		pushOntoList(derivativeChange, derivativeDifferenceVectorList);
	}

	/** Adds {@code vector} at the front and drops the oldest entry once the list exceeds maxHistorySize. */
	private void pushOntoList(double[] vector, LinkedList<double[]> vectorList)
	{
		vectorList.addFirst(vector);
		if (vectorList.size() > maxHistorySize) vectorList.removeLast();
	}

	private int historySize()
	{
		return inputDifferenceVectorList.size();
	}

	private double[] getInputDifference(int num)
	{
		// 0 is previous, 1 is the one before that
		return inputDifferenceVectorList.get(num);
	}

	private double[] getDerivativeDifference(int num)
	{
		return derivativeDifferenceVectorList.get(num);
	}

	private double[] getLastDerivativeDifference()
	{
		return derivativeDifferenceVectorList.getFirst();
	}

	private double[] getLastInputDifference()
	{
		return inputDifferenceVectorList.getFirst();
	}

	/**
	 * Multiplies {@code derivative} by the inverse Hessian approximation implied
	 * by the stored difference pairs (two-loop recursion).
	 * 
	 * @param initialInverseHessianDiagonal
	 *            diagonal applied between the two passes
	 * @param derivative
	 *            vector to multiply; cloned before use
	 * @return the product
	 * @throws RuntimeException
	 *             if any stored pair has a zero inner product (no curvature
	 *             information)
	 */
	private double[] implicitMultiply(double[] initialInverseHessianDiagonal, double[] derivative)
	{
		final int size = historySize();
		double[] rho = new double[size];
		double[] alpha = new double[size];
		double[] right = DoubleArrays.clone(derivative);
		// First pass: from the newest pair (index 0) to the oldest.
		for (int i = 0; i < size; i++)
		{
			double[] inputDifference = getInputDifference(i);
			double[] derivativeDifference = getDerivativeDifference(i);
			rho[i] = DoubleArrays.innerProduct(inputDifference, derivativeDifference);
			if (rho[i] == 0.0) throw new RuntimeException("[LBFGSMinimizer.implicitMultiply]: Curvature problem.");
			alpha[i] = DoubleArrays.innerProduct(inputDifference, right) / rho[i];
			right = DoubleArrays.addMultiples(right, 1.0, derivativeDifference, -1.0 * alpha[i]);
		}
		double[] left = DoubleArrays.pointwiseMultiply(initialInverseHessianDiagonal, right);
		// Second pass: from the oldest pair back to the newest.
		for (int i = size - 1; i >= 0; i--)
		{
			double[] inputDifference = getInputDifference(i);
			double[] derivativeDifference = getDerivativeDifference(i);
			double beta = DoubleArrays.innerProduct(derivativeDifference, left) / rho[i];
			left = DoubleArrays.addMultiples(left, 1.0, inputDifference, alpha[i] - beta);
		}
		return left;
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy