com.asher_stern.crf.function.optimization.GradientDescentOptimizer Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of CRF Show documentation
Implementation of linear-chain Conditional Random Fields (CRF) in pure Java
package com.asher_stern.crf.function.optimization;
import org.apache.log4j.Logger;
import com.asher_stern.crf.function.DerivableFunction;
import com.asher_stern.crf.utilities.CrfException;
import com.asher_stern.crf.utilities.StringUtilities;
import com.asher_stern.crf.utilities.VectorUtilities;
/**
* A {@link Minimizer} which updates the function's input by moving it along the negation of its
* gradient.
* This method is simple but inefficient.
* <p>
* Call {@link #find()} first; afterwards the result is available through {@link #getValue()} and
* {@link #getPoint()}. Calling either accessor before {@link #find()} has completed throws
* {@link CrfException}.
*
* @author Asher Stern
* Date: Nov 6, 2014
*
*/
public class GradientDescentOptimizer extends Minimizer
{
/** Default for the {@code rate} constructor argument. */
public static final double DEFAULT_RATE = 0.01;
/** Default for the {@code convergenceThreshold} constructor argument. */
public static final double DEFAULT_CONVERGENCE_THRESHOLD = 0.0001;
/**
* Creates an optimizer using {@link #DEFAULT_RATE} and {@link #DEFAULT_CONVERGENCE_THRESHOLD}.
*
* @param function the function to minimize
*/
public GradientDescentOptimizer(DerivableFunction function)
{
this(function,DEFAULT_RATE,DEFAULT_CONVERGENCE_THRESHOLD);
}
/**
* Creates an optimizer with explicit parameters.
*
* @param function the function to minimize (passed on to {@link Minimizer})
* @param rate step size for a constant line search; the active line search in {@link #find()}
*        ({@code ArmijoLineSearch}) does not use it -- it is referenced only by the commented-out
*        {@code ConstantLineSearch} alternative
* @param convergenceThreshold threshold used by the stopping condition of {@link #find()}
*        (the exact comparison is in a corrupted part of that method -- see the NOTE there)
*/
public GradientDescentOptimizer(DerivableFunction function,double rate,double convergenceThreshold)
{
super(function);
this.rate = rate;
this.convergenceThreshold = convergenceThreshold;
}
/**
* Runs gradient descent on the function and stores the final point and value, then marks the
* result as calculated so that {@link #getValue()} and {@link #getPoint()} may be called.
* <p>
* The starting point is a fresh array of {@code function.size()} entries (Java zero-initializes
* it). Step sizes are meant to come from an {@code ArmijoLineSearch}. The number of iterations is
* logged at debug level.
*/
@Override
public void find()
{
// Alternative kept for reference: fixed step size equal to the constructor's rate.
// LineSearch lineSearch = new ConstantLineSearch(rate);
LineSearch lineSearch = new ArmijoLineSearch();
int size = function.size();
point = new double[size];
// NOTE(review): the next line is corrupted and does not compile as written (the for-header lacks
// its second ';'). Text between a '<' and the following '>' appears to have been stripped, as
// HTML extraction does. What is lost:
//   - the bound and body of the initialization loop;
//   - the declaration of debug_iterationIndex, which is used below;
//   - the descent iterations themselves;
//   - the start of a loop condition ending in "> convergenceThreshold".
// Restore this from the upstream source. The raw types LineSearch / ArmijoLineSearch above may
// likewise have lost generic type arguments -- TODO confirm.
for (int i=0;iconvergenceThreshold);
if (logger.isDebugEnabled()){logger.debug("Gradient-descent: number of iterations: "+debug_iterationIndex);}
// Results are now available to getValue()/getPoint().
calculated = true;
}
/**
* Returns the function value stored by {@link #find()} (presumably the value at
* {@link #getPoint()}).
*
* @return the stored function value
* @throws CrfException if {@link #find()} has not completed
*/
@Override
public double getValue()
{
if (!calculated) throw new CrfException("Not calculated");
return value;
}
/**
* Returns the point found by {@link #find()}.
*
* @return the internal point array itself, not a copy -- modifying it changes this optimizer's
*         stored result
* @throws CrfException if {@link #find()} has not completed
*/
@Override
public double[] getPoint()
{
if (!calculated) throw new CrfException("Not calculated");
return point;
}
// Suppressed because, in the visible code, rate is referenced only by the commented-out
// ConstantLineSearch line in find(); the active ArmijoLineSearch does not take it.
@SuppressWarnings("unused")
private final double rate;
private final double convergenceThreshold;
// Set at the end of find(); guards getValue() and getPoint().
private boolean calculated = false;
// Result of find(): the stored function value and the point array.
private double value;
private double[] point;
private static final Logger logger = Logger.getLogger(GradientDescentOptimizer.class);
}