All downloads are free. The search and download features use the official Maven repository.

com.asher_stern.crf.function.optimization.GradientDescentOptimizer Maven / Gradle / Ivy

Go to download

Implementation of linear-chain Conditional Random Fields (CRF) in pure Java

There is a newer version: 1.2.0
Show newest version
package com.asher_stern.crf.function.optimization;

import org.apache.log4j.Logger;

import com.asher_stern.crf.function.DerivableFunction;
import com.asher_stern.crf.utilities.CrfException;
import com.asher_stern.crf.utilities.StringUtilities;
import com.asher_stern.crf.utilities.VectorUtilities;

/**
 * A {@link Minimizer} which updates the function's input by moving it along the negation of its
 * gradient.
 * This method is simple but inefficient.
 * <p>
 * Usage, as far as this class shows: construct it with a {@link DerivableFunction}, call
 * {@link #find()}, then read the result with {@link #getValue()} and {@link #getPoint()}.
 * Both getters throw a {@link CrfException} until {@link #find()} has completed.
 * 
 * @author Asher Stern
 * Date: Nov 6, 2014
 *
 */
public class GradientDescentOptimizer extends Minimizer
{
	/** Rate passed on by the single-argument constructor. */
	public static final double DEFAULT_RATE = 0.01;
	/** Convergence threshold passed on by the single-argument constructor. */
	public static final double DEFAULT_CONVERGENCE_THRESHOLD = 0.0001;
	
	/**
	 * Creates an optimizer for the given function using {@link #DEFAULT_RATE} and
	 * {@link #DEFAULT_CONVERGENCE_THRESHOLD}.
	 * 
	 * @param function the function handed to the {@link Minimizer} super-constructor.
	 */
	public GradientDescentOptimizer(DerivableFunction function)
	{
		this(function,DEFAULT_RATE,DEFAULT_CONVERGENCE_THRESHOLD);
	}
	
	/**
	 * Creates an optimizer for the given function with an explicit rate and convergence threshold.
	 * 
	 * @param function the function handed to the {@link Minimizer} super-constructor.
	 * @param rate stored in this object; the visible code never reads it (see the field comment).
	 * @param convergenceThreshold stored in this object; referenced by the tail of the (truncated)
	 * loop statement in {@link #find()}.
	 */
	public GradientDescentOptimizer(DerivableFunction function,double rate,double convergenceThreshold)
	{
		super(function);
		this.rate = rate;
		this.convergenceThreshold = convergenceThreshold;
	}
	
	
	/**
	 * Runs the optimization and marks this object as calculated, which enables
	 * {@link #getValue()} and {@link #getPoint()}.
	 * <p>
	 * NOTE(review): the body of this method is incomplete in this copy of the file - see the
	 * comment above the <code>for</code> statement.
	 */
	@Override
	public void find()
	{
		// The constant-rate line search is disabled; an Armijo line search is constructed instead.
//		LineSearch lineSearch = new ConstantLineSearch(rate);
		LineSearch lineSearch = new ArmijoLineSearch();
		
		// "function" is not declared in this class; presumably it is a field inherited from
		// Minimizer - TODO confirm.
		int size = function.size();
		// A freshly allocated double[] is all zeros, so the starting point is the origin.
		point = new double[size];
		// NOTE(review): the next line appears to be truncated. Everything between the '<' of the
		// loop condition and a later '>' comparison seems to have been lost (probably stripped as
		// an HTML tag when this page was extracted). The lost text presumably contained the rest
		// of this initialisation loop, the declaration of debug_iterationIndex, the assignment of
		// "value", and the iteration loop whose exit test ends in "> convergenceThreshold".
		// TODO restore from the original source; this line does not compile as written.
		for (int i=0;iconvergenceThreshold);
		if (logger.isDebugEnabled()){logger.debug("Gradient-descent: number of iterations: "+debug_iterationIndex);}
		// From here on getValue() and getPoint() stop throwing.
		calculated = true;
		
	}
	
	
	/**
	 * Returns the contents of the <code>value</code> field.
	 * 
	 * @return the stored value; presumably the function's value at {@link #getPoint()}, but the
	 * assignment is not visible in this copy of the file - TODO confirm.
	 * @throws CrfException if {@link #find()} has not completed yet.
	 */
	@Override
	public double getValue()
	{
		if (!calculated) throw new CrfException("Not calculated");
		return value;
	}
	/**
	 * Returns the point array allocated by {@link #find()}.
	 * 
	 * @return the internal array itself, not a copy, so changes made by the caller are visible
	 * to this object.
	 * @throws CrfException if {@link #find()} has not completed yet.
	 */
	@Override
	public double[] getPoint()
	{
		if (!calculated) throw new CrfException("Not calculated");
		return point;
	}
	
	// Stored by the constructor but never read in the visible code: its only mention is in the
	// commented-out ConstantLineSearch line in find(), hence the "unused" suppression.
	@SuppressWarnings("unused")
	private final double rate;
	// Appears at the end of the truncated loop statement in find(), right after where a '>'
	// comparison seems to have been lost.
	private final double convergenceThreshold;
	
	// Set to true as the last statement of find(); guards both getters.
	private boolean calculated = false;
	// Returned by getValue(); not assigned anywhere in the visible code.
	private double value;
	// Allocated in find() with length function.size(); returned by getPoint().
	private double[] point;
	
	
	private static final Logger logger = Logger.getLogger(GradientDescentOptimizer.class);
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy