All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.asher_stern.crf.function.optimization.LbfgsMinimizer Maven / Gradle / Ivy

Go to download

Implementation of linear-chain Conditional Random Fields (CRF) in pure Java

There is a newer version: 1.2.0
Show newest version
package com.asher_stern.crf.function.optimization;

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.ListIterator;

import org.apache.log4j.Logger;

import com.asher_stern.crf.function.DerivableFunction;
import com.asher_stern.crf.utilities.CrfException;
import com.asher_stern.crf.utilities.DerivableFunctionWithLastCache;
import com.asher_stern.crf.utilities.VectorUtilities;

/**
 * Implementation of L-BFGS algorithm for minimizing a function.
 * 
* L-BFGS stands for "Limited memory BFGS", where "BFGS" is an acronym of * "Broyden Fletcher Goldfarb Shanno" who developed the BFGS algorithm. *
* The BFGS algorithm approximates Newton method for optimization, by approximating * the inverse of the Hessian without calculating the exact Hessian. * The L-BFGS algorithm approximates the BFGS algorithm by approximating calculations * that are performed with the inverse of the Hessian, but stores neither the * inverse of the Hessian nor its approximation in the memory. * Thus the L-BFGS algorithm is much cheaper in space complexity notion. *
* The L-BFGS algorithm is described in the book "Numerical Optimization" by Jorge Nocedal and Stephen J. Wright, * Chapter 9. The book can be downloaded from http://www.bioinfo.org.cn/~wangchao/maa/Numerical_Optimization.pdf * * * * @author Asher Stern * Date: Nov 7, 2014 * */ public class LbfgsMinimizer extends Minimizer { public static final int DEFAULT_NUMBER_OF_PREVIOUS_ITERATIONS_TO_MEMORIZE = 20; public static final double DEFAULT_CONVERGENCE = 0.001; public LbfgsMinimizer(DerivableFunction function) { this(function,DEFAULT_NUMBER_OF_PREVIOUS_ITERATIONS_TO_MEMORIZE,DEFAULT_CONVERGENCE); } public LbfgsMinimizer(DerivableFunction function, int numberOfPreviousIterationsToMemorize, double convergence) { // super(function); super(new DerivableFunctionWithLastCache(function)); this.numberOfPreviousIterationsToMemorize = numberOfPreviousIterationsToMemorize; this.convergence = convergence; } @Override public void find() { previousItrations = new LinkedList(); LineSearch lineSearch = new ArmijoLineSearch(); point = new double[function.size()]; for (int i=0;inumberOfPreviousIterationsToMemorize) { previousItrations.removeLast(); } if (previousItrations.size()>numberOfPreviousIterationsToMemorize) {throw new CrfException("BUG");} if (value>previousValue) {logger.error("LBFGS: value > previous value");} ++forLogger_iterationIndex; if (logger.isInfoEnabled()) {logger.info("LBFGS iteration "+forLogger_iterationIndex+": value = "+String.format("%-3.3f", value));} } while(Math.abs(previousValue-value)>convergence); calculated = true; } @Override public double getValue() { if (!calculated) {throw new CrfException("Not calculated.");} return value; } @Override public double[] getPoint() { if (!calculated) {throw new CrfException("Not calculated.");} return point; } private double[] twoLoopRecursion(double[] point) { ArrayList rhoList = new ArrayList(previousItrations.size()); ArrayList alphaList = new ArrayList(previousItrations.size()); double[] q = function.gradient(point); for (PointAndGradientSubstractions substractions : previousItrations) { double rho = 1.0/VectorUtilities.product(substractions.getGradientSubstraction(), substractions.getPointSubstraction()); rhoList.add(rho); double alpha = rho*VectorUtilities.product(substractions.getPointSubstraction(), q); alphaList.add(alpha); q = VectorUtilities.substractVectors(q, VectorUtilities.multiplyByScalar(alpha, substractions.getGradientSubstraction()) ); } double[] r = calculateInitial_r_forTwoLoopRecursion(q); ListIterator previousIterationsIterator = previousItrations.listIterator(previousItrations.size()); ListIterator rhoIterator = rhoList.listIterator(rhoList.size()); ListIterator alphaIterator = alphaList.listIterator(alphaList.size()); while (previousIterationsIterator.hasPrevious()&&rhoIterator.hasPrevious()&&alphaIterator.hasPrevious()) { PointAndGradientSubstractions substractions = previousIterationsIterator.previous(); double rho = rhoIterator.previous(); double alpha = alphaIterator.previous(); double betta = rho * VectorUtilities.product(substractions.getGradientSubstraction(), r); r = VectorUtilities.addVectors( r, VectorUtilities.multiplyByScalar(alpha-betta, substractions.getPointSubstraction()) ); } if ((previousIterationsIterator.hasPrevious()||rhoIterator.hasPrevious()||alphaIterator.hasPrevious())) {throw new CrfException("BUG");} return r; } private double[] calculateInitial_r_forTwoLoopRecursion(double[] q) { double gamma = 1.0; if (previousItrations.size()>=1) { PointAndGradientSubstractions lastSubstraction = previousItrations.get(0); gamma = VectorUtilities.product(lastSubstraction.getPointSubstraction(), lastSubstraction.getGradientSubstraction()) / VectorUtilities.product(lastSubstraction.getGradientSubstraction(), lastSubstraction.getGradientSubstraction()); } double[] r = VectorUtilities.multiplyByScalar(gamma, q); return r; } private final int numberOfPreviousIterationsToMemorize; // m private final double convergence; private LinkedList previousItrations; // newest is pushed to the beginning. private double[] point = null; private double value = 0.0; private boolean calculated = false; private static final Logger logger = Logger.getLogger(LbfgsMinimizer.class); }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy