![JAR search and dependency download from the Maven repository](/logo.png)
edu.berkeley.nlp.math.LBFGSMinimizer Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of berkeleyparser Show documentation
Show all versions of berkeleyparser Show documentation
The Berkeley parser analyzes the grammatical structure of natural language using probabilistic context-free grammars (PCFGs).
The newest version!
package edu.berkeley.nlp.math;
import java.io.Serializable;
import java.util.LinkedList;
import edu.berkeley.nlp.util.CallbackFunction;
import edu.berkeley.nlp.util.Logger;
/**
 * Limited-memory BFGS (L-BFGS) minimizer for {@link DifferentiableFunction}s.
 *
 * <p>The minimizer keeps at most {@code maxHistorySize} of the most recent
 * input differences and derivative differences. On every iteration it uses
 * them to apply an implicit inverse-Hessian approximation to the current
 * derivative, negates the result to obtain a search direction, and hands that
 * direction to a {@link BacktrackingLineSearcher}.
 *
 * <p>Instances are stateful: the difference history and the
 * "already dumped" flag survive across calls to {@code minimize}. Call
 * {@link #dumpHistory()} to discard the history explicitly. The iteration
 * callback is {@code transient} and therefore is not serialized.
 *
 * @author Dan Klein
 */
public class LBFGSMinimizer implements GradientMinimizer, Serializable
{
    private static final long serialVersionUID = 36473897808840226L;

    /** Small constant added to the denominator of the relative-change test in {@link #converged}. */
    double EPS = 1e-10;

    /** Upper bound on the number of iterations performed by {@code minimize}. */
    int maxIterations = 20;

    /** Maximum number of difference pairs retained in each history list. */
    int maxHistorySize = 5;

    /** Differences between consecutive guesses; index 0 is the most recent. */
    LinkedList<double[]> inputDifferenceVectorList = new LinkedList<double[]>();

    /** Differences between consecutive derivatives; index 0 is the most recent. */
    LinkedList<double[]> derivativeDifferenceVectorList = new LinkedList<double[]>();

    /** Optional callback invoked at the end of every iteration that does not return. */
    transient CallbackFunction iterCallbackFunction = null;

    /** Convergence is not tested until the iteration index reaches this value. */
    int minIterations = -1;

    /** Step-size multiplier given to the line searcher on iteration 0. */
    double initialStepSizeMultiplier = 0.01;

    /** Step-size multiplier given to the line searcher on every later iteration. */
    double stepSizeMultiplier = 0.5;

    /** When true, the first detected convergence dumps the history and repeats the iteration. */
    boolean dumpHistoryBeforeConverge = false;

    /** Set once the pre-convergence dump has happened; this class never resets it. */
    boolean alreadyDumped = false;

    /** When positive, the history is dumped whenever the iteration index is a multiple of it. */
    int historyDropIters = -1;

    /** Enables progress and status messages through {@code Logger.logs}. */
    boolean verbose = true;

    /** Creates a minimizer with the default settings. */
    public LBFGSMinimizer()
    {
    }

    /**
     * Creates a minimizer with a custom iteration limit.
     *
     * @param maxIterations upper bound on the number of iterations
     */
    public LBFGSMinimizer(int maxIterations)
    {
        this.maxIterations = maxIterations;
    }

    /**
     * @param dumpHistoryBeforeConverge whether the first detected convergence
     *            should dump the history and redo the iteration instead of
     *            returning
     */
    public void setDumpHistoryBeforeConverge(boolean dumpHistoryBeforeConverge)
    {
        this.dumpHistoryBeforeConverge = dumpHistoryBeforeConverge;
    }

    /**
     * @param verbose whether status messages are logged
     */
    public void setVerbose(boolean verbose)
    {
        this.verbose = verbose;
    }

    /**
     * Requests that the history be dumped whenever the iteration index is a
     * multiple of {@code numIters}; values that are not positive disable this.
     *
     * @param numIters dump period in iterations
     */
    public void dumpHistoryPeriodically(int numIters)
    {
        this.historyDropIters = numIters;
    }

    /**
     * Sets the iteration index from which convergence is tested. The method
     * name is misspelled; it is kept so existing callers continue to work.
     * {@link #setMinIterations(int)} is the correctly spelled equivalent.
     *
     * @param minIterations first iteration index at which convergence is tested
     */
    public void setMinIteratons(int minIterations)
    {
        this.minIterations = minIterations;
    }

    /**
     * Correctly spelled equivalent of {@link #setMinIteratons(int)}. Delegates
     * to the legacy method so that subclasses overriding it are still honored.
     *
     * @param minIterations first iteration index at which convergence is tested
     */
    public void setMinIterations(int minIterations)
    {
        setMinIteratons(minIterations);
    }

    /**
     * @param maxIterations upper bound on the number of iterations
     */
    public void setMaxIterations(int maxIterations)
    {
        this.maxIterations = maxIterations;
    }

    /**
     * @param initialStepSizeMultiplier multiplier given to the line searcher
     *            on iteration 0
     */
    public void setInitialStepSizeMultiplier(double initialStepSizeMultiplier)
    {
        this.initialStepSizeMultiplier = initialStepSizeMultiplier;
    }

    /**
     * @param stepSizeMultiplier multiplier given to the line searcher on every
     *            iteration after the first
     */
    public void setStepSizeMultiplier(double stepSizeMultiplier)
    {
        this.stepSizeMultiplier = stepSizeMultiplier;
    }

    /**
     * @param maxHistorySize maximum number of difference pairs to retain; a
     *            smaller value takes effect the next time the history is
     *            updated
     */
    public void setMaxHistorySize(int maxHistorySize)
    {
        this.maxHistorySize = maxHistorySize;
    }

    /**
     * User callback function to test or examine weights at the end of each
     * iteration.
     *
     * @param callbackFunction
     *            Will get called with the following args (double[]
     *            currentGuess, int iterDone, double value, double[] derivative).
     *            You don't have to read any or all of these.
     */
    public void setIterationCallbackFunction(CallbackFunction callbackFunction)
    {
        this.iterCallbackFunction = callbackFunction;
    }

    /** Discards all stored input and derivative differences. */
    public void dumpHistory()
    {
        inputDifferenceVectorList.clear();
        derivativeDifferenceVectorList.clear();
    }

    /**
     * Applies the current implicit inverse-Hessian approximation to a
     * derivative. The result is not negated; {@code minimize} negates its own
     * direction before searching along it.
     *
     * @param dimension length of the diagonal used as the initial approximation
     * @param derivative derivative to transform
     * @return the transformed derivative
     * @throws RuntimeException if a stored difference pair has a zero inner product
     */
    public double[] getSearchDirection(int dimension, double[] derivative)
    {
        double[] initialInverseHessianDiagonal = getInitialInverseHessianDiagonal(dimension);
        return implicitMultiply(initialInverseHessianDiagonal, derivative);
    }

    /**
     * Builds the diagonal used as the initial inverse-Hessian approximation: a
     * constant array whose entries are 1.0 while the history is empty, and
     * otherwise the ratio computed from the most recent difference pair.
     *
     * @param dimension length of the returned array
     * @return the constant diagonal
     */
    protected double[] getInitialInverseHessianDiagonal(int dimension)
    {
        return DoubleArrays.constantArray(initialInverseHessianScale(), dimension);
    }

    /**
     * Same as {@link #getInitialInverseHessianDiagonal(int)}, taking the length
     * from {@code function.dimension()}.
     */
    private double[] getInitialInverseHessianDiagonal(DifferentiableFunction function)
    {
        return DoubleArrays.constantArray(initialInverseHessianScale(), function.dimension());
    }

    /**
     * Computes the scalar placed on the initial diagonal: 1.0 when no history
     * is stored, otherwise (y . s) / (y . y) where y is the most recent
     * derivative difference and s the most recent input difference.
     */
    private double initialInverseHessianScale()
    {
        if (derivativeDifferenceVectorList.isEmpty())
        {
            return 1.0;
        }
        double[] lastDerivativeDifference = getLastDerivativeDifference();
        double[] lastInputDifference = getLastInputDifference();
        double num = DoubleArrays.innerProduct(lastDerivativeDifference, lastInputDifference);
        double den = DoubleArrays.innerProduct(lastDerivativeDifference, lastDerivativeDifference);
        return num / den;
    }

    /**
     * Minimizes {@code function} without printing per-iteration progress.
     *
     * @param function function to minimize
     * @param initial starting point; it is cloned, not modified
     * @param tolerance relative-change threshold passed to {@link #converged}
     * @return the final guess
     */
    public double[] minimize(DifferentiableFunction function, double[] initial, double tolerance)
    {
        return minimize(function, initial, tolerance, false);
    }

    /**
     * Minimizes {@code function} starting from {@code initial}.
     *
     * <p>The loop ends when {@link #converged} reports convergence (once the
     * iteration index has reached {@code minIterations}) or after
     * {@code maxIterations} iterations. History left over from an earlier call
     * on the same instance is not cleared automatically.
     *
     * @param function function to minimize
     * @param initial starting point; it is cloned, not modified
     * @param tolerance relative-change threshold passed to {@link #converged}
     * @param printProgress whether to log the value reached by each iteration
     *            (only has an effect while {@code verbose} is true)
     * @return the guess at convergence, or the last guess when the iteration
     *         limit is reached
     * @throws RuntimeException if a stored difference pair has a zero inner product
     */
    public double[] minimize(DifferentiableFunction function, double[] initial, double tolerance, boolean printProgress)
    {
        BacktrackingLineSearcher lineSearcher = new BacktrackingLineSearcher();
        double[] guess = DoubleArrays.clone(initial);
        int iteration;
        for (iteration = 0; iteration < maxIterations; iteration++)
        {
            if (historyDropIters > 0 && iteration % historyDropIters == 0)
            {
                dumpHistory();
                if (verbose)
                {
                    Logger.logs("[LBFGSMinimizer.minimize] Dumped History at iter %d", iteration);
                }
            }

            double[] derivative = function.derivativeAt(guess);
            double value = function.valueAt(guess);

            // Search along the negated quasi-Newton direction.
            double[] initialInverseHessianDiagonal = getInitialInverseHessianDiagonal(function);
            double[] direction = implicitMultiply(initialInverseHessianDiagonal, derivative);
            DoubleArrays.scale(direction, -1.0);

            lineSearcher.stepSizeMultiplier = (iteration == 0) ? initialStepSizeMultiplier : stepSizeMultiplier;

            double[] nextGuess = doLineSearch(function, lineSearcher, guess, direction);
            double nextValue = function.valueAt(nextGuess);
            double[] nextDerivative = function.derivativeAt(nextGuess);
            if (printProgress)
            {
                printProgress(iteration, nextValue);
            }

            if (iteration >= minIterations && converged(value, nextValue, tolerance))
            {
                if (verbose)
                {
                    Logger.logs("[LBFGSMinimizer.minimize] Converged.");
                }
                if (!dumpHistoryBeforeConverge || alreadyDumped)
                {
                    return nextGuess;
                }
                // First convergence with the dump option on: forget the
                // history and repeat this iteration from the same guess.
                dumpHistory();
                if (verbose)
                {
                    Logger.logs("[LBFGSMinimizer.minimize] Dumping History. Doing Iteration Over");
                }
                alreadyDumped = true;
                iteration--; // cancels the loop increment so the index is reused
                continue;
            }

            updateHistories(guess, nextGuess, derivative, nextDerivative);
            guess = nextGuess;
            if (iterCallbackFunction != null)
            {
                iterCallbackFunction.callback(guess, iteration, nextValue, nextDerivative);
            }
        }
        if (verbose)
        {
            Logger.logs("[LBFGSMinimizer.minimize] Stopped after " + iteration + " iterations.");
        }
        return guess;
    }

    /**
     * This is an entry point for subclasses. The default implementation
     * delegates to {@code lineSearcher.minimize}.
     *
     * @param function function being minimized
     * @param lineSearcher searcher already configured with the step-size multiplier
     * @param guess current point
     * @param direction search direction
     * @return the point chosen by the line search
     */
    protected double[] doLineSearch(DifferentiableFunction function, BacktrackingLineSearcher lineSearcher, double[] guess, double[] direction)
    {
        return lineSearcher.minimize(function, guess, direction);
    }

    /** Logs the value reached at the end of an iteration when {@code verbose} is set. */
    private void printProgress(int iteration, double nextValue)
    {
        if (verbose)
        {
            Logger.logs("[LBFGSMinimizer.minimize] Iteration %d ended with value %.6f", iteration, nextValue);
        }
    }

    /**
     * Tests whether the relative change between two consecutive function
     * values is below {@code tolerance}.
     *
     * <p>The change is measured against the mean magnitude of the two values
     * plus {@code EPS}. Taking magnitudes before adding {@code EPS} keeps the
     * denominator away from zero even when the values are negative or have
     * opposite signs; for non-negative values the result is the same as adding
     * {@code EPS} to their plain sum.
     *
     * @param value function value before the step
     * @param nextValue function value after the step
     * @param tolerance relative-change threshold
     * @return true when the values are equal or their relative change is
     *         below {@code tolerance}
     */
    protected boolean converged(double value, double nextValue, double tolerance)
    {
        if (value == nextValue)
        {
            return true;
        }
        double valueChange = Math.abs(nextValue - value);
        double valueAverage = (Math.abs(nextValue) + Math.abs(value) + EPS) / 2.0;
        return valueChange / valueAverage < tolerance;
    }

    /**
     * Records the step just taken by pushing {@code nextGuess - guess} and
     * {@code nextDerivative - derivative} onto the front of the histories.
     */
    protected void updateHistories(double[] guess, double[] nextGuess, double[] derivative, double[] nextDerivative)
    {
        double[] guessChange = DoubleArrays.addMultiples(nextGuess, 1.0, guess, -1.0);
        double[] derivativeChange = DoubleArrays.addMultiples(nextDerivative, 1.0, derivative, -1.0);
        pushOntoList(guessChange, inputDifferenceVectorList);
        pushOntoList(derivativeChange, derivativeDifferenceVectorList);
    }

    /**
     * Inserts {@code vector} at the front of {@code vectorList} and drops the
     * oldest entries until the list respects {@code maxHistorySize}. Trimming
     * in a loop matters when the limit was lowered after the history filled up.
     */
    private void pushOntoList(double[] vector, LinkedList<double[]> vectorList)
    {
        vectorList.addFirst(vector);
        while (!vectorList.isEmpty() && vectorList.size() > maxHistorySize)
        {
            vectorList.removeLast();
        }
    }

    /** Number of difference pairs currently stored. */
    private int historySize()
    {
        return inputDifferenceVectorList.size();
    }

    /**
     * @param num 0 is previous, 1 is the one before that
     * @return the stored input difference at that position
     */
    private double[] getInputDifference(int num)
    {
        return inputDifferenceVectorList.get(num);
    }

    /**
     * @param num 0 is previous, 1 is the one before that
     * @return the stored derivative difference at that position
     */
    private double[] getDerivativeDifference(int num)
    {
        return derivativeDifferenceVectorList.get(num);
    }

    /** Most recently stored derivative difference. */
    private double[] getLastDerivativeDifference()
    {
        return derivativeDifferenceVectorList.getFirst();
    }

    /** Most recently stored input difference. */
    private double[] getLastInputDifference()
    {
        return inputDifferenceVectorList.getFirst();
    }

    /**
     * Applies the implicit inverse-Hessian approximation to {@code derivative}
     * with the L-BFGS two-loop recursion.
     *
     * <p>Index 0 of the history is the most recent pair, so the first pass
     * walks from index 0 upward (newest to oldest) and the second pass walks
     * back down (oldest to newest). This ordering makes the most recent pair
     * the last update applied.
     *
     * @param initialInverseHessianDiagonal diagonal of the initial approximation
     * @param derivative derivative to transform; it is cloned, not modified
     * @return the transformed derivative
     * @throws RuntimeException if a stored pair has a zero inner product
     */
    private double[] implicitMultiply(double[] initialInverseHessianDiagonal, double[] derivative)
    {
        final int size = historySize();
        double[] rho = new double[size];
        double[] alpha = new double[size];
        double[] right = DoubleArrays.clone(derivative);

        // First pass: newest pair to oldest pair.
        for (int i = 0; i < size; i++)
        {
            double[] inputDifference = getInputDifference(i);
            double[] derivativeDifference = getDerivativeDifference(i);
            rho[i] = DoubleArrays.innerProduct(inputDifference, derivativeDifference);
            if (rho[i] == 0.0)
            {
                throw new RuntimeException("[LBFGSMinimizer.implicitMultiply]: Curvature problem.");
            }
            alpha[i] = DoubleArrays.innerProduct(inputDifference, right) / rho[i];
            right = DoubleArrays.addMultiples(right, 1.0, derivativeDifference, -1.0 * alpha[i]);
        }

        double[] left = DoubleArrays.pointwiseMultiply(initialInverseHessianDiagonal, right);

        // Second pass: oldest pair back to newest pair.
        for (int i = size - 1; i >= 0; i--)
        {
            double[] inputDifference = getInputDifference(i);
            double[] derivativeDifference = getDerivativeDifference(i);
            double beta = DoubleArrays.innerProduct(derivativeDifference, left) / rho[i];
            left = DoubleArrays.addMultiples(left, 1.0, inputDifference, alpha[i] - beta);
        }
        return left;
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy