All downloads are free. Search and download functionality uses the official Maven repository.

com.asher_stern.crf.function.optimization.GradientDescentOptimizer Maven / Gradle / Ivy

Go to download

Implementation of linear-chain Conditional Random Fields (CRF) in pure Java

The newest version!
package com.asher_stern.crf.function.optimization;

import static com.asher_stern.crf.utilities.ArithmeticUtilities.big;
import static com.asher_stern.crf.utilities.ArithmeticUtilities.safeAdd;
import static com.asher_stern.crf.utilities.ArithmeticUtilities.safeMultiply;
import static com.asher_stern.crf.utilities.ArithmeticUtilities.safeSubtract;

import java.math.BigDecimal;

import org.apache.log4j.Logger;

import com.asher_stern.crf.function.DerivableFunction;
import com.asher_stern.crf.utilities.CrfException;
import com.asher_stern.crf.utilities.StringUtilities;
import com.asher_stern.crf.utilities.VectorUtilities;

/**
 * A {@link Minimizer} which updates the function's input by moving it along the negation of its
 * gradient.
 * This method is simple but inefficient.
 * 
 * @author Asher Stern
 * Date: Nov 6, 2014
 *
 */
public class GradientDescentOptimizer extends Minimizer
{
	// Default step-size coefficient (unused by the Armijo-based implementation of find(),
	// but applied if ConstantLineSearch is substituted — see the three-argument constructor).
	public static final BigDecimal DEFAULT_RATE = big(0.01);
	// Default maximum allowed improvement between consecutive iterations below which the
	// optimizer is considered converged.
	public static final BigDecimal DEFAULT_CONVERGENCE_THRESHOLD = big(0.0001);
	
	/**
	 * Creates an optimizer for the given function, using {@link #DEFAULT_RATE} and
	 * {@link #DEFAULT_CONVERGENCE_THRESHOLD}.
	 * See {@link GradientDescentOptimizer#GradientDescentOptimizer(DerivableFunction, double, double)}.
	 *
	 * @param function the function whose minimum is to be found.
	 */
	public GradientDescentOptimizer(DerivableFunction function)
	{
		this(function, DEFAULT_RATE, DEFAULT_CONVERGENCE_THRESHOLD);
	}
	
	/**
	 * Creates an optimizer for the given function with an explicit rate and convergence
	 * threshold.
	 *
	 * @param function the function to optimize (find its minimum).
	 * @param rate a coefficient by which the gradient would be multiplied at each descent
	 * step. It is <b>not used</b> by this implementation, which instead lets
	 * {@link ArmijoLineSearch} pick a good rate automatically at every step; developers may
	 * switch the code to {@link ConstantLineSearch} to make this given rate effective.
	 * @param convergenceThreshold the maximum allowed gap between the result returned by
	 * this optimizer and the true optimum — the returned result may be merely
	 * "close enough" to the real minimum, differing from it slightly.
	 */
	public GradientDescentOptimizer(DerivableFunction function,BigDecimal rate,BigDecimal convergenceThreshold)
	{
		super(function);
		this.convergenceThreshold = convergenceThreshold;
		this.rate = rate;
	}
	
	
	@Override
	public void find()
	{
//		LineSearch lineSearch = new ConstantLineSearch(rate);
		LineSearch lineSearch = new ArmijoLineSearch();
		
		int size = function.size();
		point = new BigDecimal[size];
		for (int i=0;i 0);
		if (logger.isDebugEnabled()){logger.debug("Gradient-descent: number of iterations: "+debug_iterationIndex);}
		calculated = true;
		
	}
	
	
	/**
	 * Returns the minimum value found for the function.
	 *
	 * @throws CrfException if {@link #find()} has not been run yet.
	 */
	@Override
	public BigDecimal getValue()
	{
		if (calculated)
		{
			return value;
		}
		throw new CrfException("Not calculated");
	}
	/**
	 * Returns the input point at which the function attains the minimum that was found.
	 *
	 * @throws CrfException if {@link #find()} has not been run yet.
	 */
	@Override
	public BigDecimal[] getPoint()
	{
		if (calculated)
		{
			return point;
		}
		throw new CrfException("Not calculated");
	}
	
	
	public static final void singleStepUpdate(final int size, final BigDecimal[] point, BigDecimal[] gradient, final BigDecimal rate)
	{
		// size must be equal to point.length 
		for (int i=0;i




© 2015 - 2024 Weber Informatics LLC | Privacy Policy