All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.asher_stern.crf.function.optimization.LbfgsMinimizer Maven / Gradle / Ivy

Go to download

Implementation of linear-chain Conditional Random Fields (CRF) in pure Java

The newest version!
package com.asher_stern.crf.function.optimization;

import java.math.BigDecimal;
import java.util.ArrayList;
import static com.asher_stern.crf.utilities.ArithmeticUtilities.*;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.ListIterator;

import org.apache.log4j.Logger;

import com.asher_stern.crf.function.DerivableFunction;
import com.asher_stern.crf.utilities.CrfException;
import com.asher_stern.crf.utilities.DerivableFunctionWithLastCache;
import com.asher_stern.crf.utilities.StringUtilities;
import com.asher_stern.crf.utilities.VectorUtilities;

/**
 * Implementation of L-BFGS algorithm for minimizing a function.
 * 
* L-BFGS stands for "Limited memory BFGS", where "BFGS" is an acronym of * "Broyden Fletcher Goldfarb Shanno" who developed the BFGS algorithm. *
* The BFGS algorithm approximates Newton method for optimization, by approximating * the inverse of the Hessian without calculating the exact Hessian. * The L-BFGS algorithm approximates the BFGS algorithm by approximating calculations * that are performed with the inverse of the Hessian, but stores neither the * inverse of the Hessian nor its approximation in the memory. * Thus the L-BFGS algorithm is much cheaper in space complexity notion. *
* The L-BFGS algorithm is described in the book "Numerical Optimization" by Jorge Nocedal and Stephen J. Wright, * Chapter 9. The book can be downloaded from http://www.bioinfo.org.cn/~wangchao/maa/Numerical_Optimization.pdf * * * * @author Asher Stern * Date: Nov 7, 2014 * */ public class LbfgsMinimizer extends Minimizer { public static final int DEFAULT_NUMBER_OF_PREVIOUS_ITERATIONS_TO_MEMORIZE = 20; public static final BigDecimal DEFAULT_GRADIENT_CONVERGENCE = big(0.01); public LbfgsMinimizer(DerivableFunction function) { this(function,DEFAULT_NUMBER_OF_PREVIOUS_ITERATIONS_TO_MEMORIZE, DEFAULT_GRADIENT_CONVERGENCE); } public LbfgsMinimizer(DerivableFunction function, int numberOfPreviousIterationsToMemorize, BigDecimal convergence) { // super(function); super(new DerivableFunctionWithLastCache(function)); this.numberOfPreviousIterationsToMemorize = numberOfPreviousIterationsToMemorize; this.convergence = convergence; this.convergenceSquare = safeMultiply(this.convergence,this.convergence); } public void setInitialPoint(BigDecimal[] initialPoint) { if (initialPoint.length!=function.size()) throw new CrfException("Wrong length of initial point specified by the caller."); this.initialPoint = initialPoint; } public void setDebugInfo(DebugInfo debugInfo) { this.debugInfo = debugInfo; } @Override public void find() { previousItrations = new LinkedList(); LineSearch lineSearch = new ArmijoLineSearch(); initializeInitialPoint(); value = function.value(point); if (logger.isInfoEnabled()) {logger.info("LBFGS: initial value = "+StringUtilities.bigDecimalToString(value));} BigDecimal[] gradient = function.gradient(point); BigDecimal previousValue = value; int forLogger_iterationIndex=0; while (VectorUtilities.euclideanNormSquare(gradient).compareTo(convergenceSquare)>0) { if (logger.isDebugEnabled()) {logger.debug(String.format("Gradient norm square = %s", StringUtilities.bigDecimalToString(VectorUtilities.euclideanNormSquare(gradient)) ));} previousValue = value; BigDecimal[] previousPoint = Arrays.copyOf(point, point.length); BigDecimal[] previousGradient = Arrays.copyOf(gradient, gradient.length); // 1. Update point (which is the vector "x"). BigDecimal[] direction = VectorUtilities.multiplyByScalar(BigDecimal.ONE.negate(), twoLoopRecursion(point)); BigDecimal alpha_rate = lineSearch.findRate(function, point, direction); point = VectorUtilities.addVectors(point, VectorUtilities.multiplyByScalar(alpha_rate, direction)); // 2. Prepare next iteration value = function.value(point); gradient = function.gradient(point); previousItrations.add(new PointAndGradientSubstractions(VectorUtilities.subtractVectors(point, previousPoint), VectorUtilities.subtractVectors(gradient, previousGradient))); if (previousItrations.size()>numberOfPreviousIterationsToMemorize) { previousItrations.removeFirst(); } if (previousItrations.size()>numberOfPreviousIterationsToMemorize) {throw new CrfException("BUG");} // 3. Print log messages ++forLogger_iterationIndex; if (value.compareTo(previousValue)>0) {logger.error("LBFGS: value > previous value");} if (logger.isInfoEnabled()) {logger.info("LBFGS iteration "+forLogger_iterationIndex+": value = "+StringUtilities.bigDecimalToString(value) );} if ( (debugInfo!=null) && (logger.isInfoEnabled()) ) { logger.info(debugInfo.info(point)); } } calculated = true; } @Override public BigDecimal getValue() { if (!calculated) {throw new CrfException("Not calculated.");} return value; } @Override public BigDecimal[] getPoint() { if (!calculated) {throw new CrfException("Not calculated.");} return point; } public static interface DebugInfo { public String info(BigDecimal[] point); } private void initializeInitialPoint() { point = new BigDecimal[function.size()]; if (this.initialPoint==null) { for (int i=0;i rhoList = new ArrayList(previousItrations.size()); ArrayList alphaList = new ArrayList(previousItrations.size()); BigDecimal[] q = function.gradient(point); for (PointAndGradientSubstractions substractions : previousItrations) { BigDecimal rho = safeDivide(BigDecimal.ONE, VectorUtilities.product(substractions.getGradientSubstraction(), substractions.getPointSubstraction())); rhoList.add(rho); BigDecimal alpha = safeMultiply(rho, VectorUtilities.product(substractions.getPointSubstraction(), q)); alphaList.add(alpha); q = VectorUtilities.subtractVectors(q, VectorUtilities.multiplyByScalar(alpha, substractions.getGradientSubstraction()) ); } BigDecimal[] r = calculateInitial_r_forTwoLoopRecursion(q); ListIterator previousIterationsIterator = previousItrations.listIterator(previousItrations.size()); ListIterator rhoIterator = rhoList.listIterator(rhoList.size()); ListIterator alphaIterator = alphaList.listIterator(alphaList.size()); while (previousIterationsIterator.hasPrevious()&&rhoIterator.hasPrevious()&&alphaIterator.hasPrevious()) { PointAndGradientSubstractions substractions = previousIterationsIterator.previous(); BigDecimal rho = rhoIterator.previous(); BigDecimal alpha = alphaIterator.previous(); BigDecimal beta = safeMultiply(rho, VectorUtilities.product(substractions.getGradientSubstraction(), r)); r = VectorUtilities.addVectors( r, VectorUtilities.multiplyByScalar(safeSubtract(alpha,beta) , substractions.getPointSubstraction()) ); } if ((previousIterationsIterator.hasPrevious()||rhoIterator.hasPrevious()||alphaIterator.hasPrevious())) {throw new CrfException("BUG");} return r; } private BigDecimal[] calculateInitial_r_forTwoLoopRecursion(BigDecimal[] q) { BigDecimal gamma = BigDecimal.ONE; if (previousItrations.size()>=1) { PointAndGradientSubstractions lastSubstraction = previousItrations.get(0); gamma = safeDivide( VectorUtilities.product(lastSubstraction.getPointSubstraction(), lastSubstraction.getGradientSubstraction()) , VectorUtilities.product(lastSubstraction.getGradientSubstraction(), lastSubstraction.getGradientSubstraction()) ); } BigDecimal[] r = VectorUtilities.multiplyByScalar(gamma, q); return r; } // input private final int numberOfPreviousIterationsToMemorize; // m private final BigDecimal convergence; private final BigDecimal convergenceSquare; private BigDecimal[] initialPoint = null; private DebugInfo debugInfo = null; // internals private LinkedList previousItrations; // newest is pushed to the beginning. private boolean calculated = false; // output private BigDecimal[] point = null; private BigDecimal value = BigDecimal.ZERO; private static final Logger logger = Logger.getLogger(LbfgsMinimizer.class); }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy