// org.deeplearning4j.util.NonZeroStoppingConjugateGradient (Maven / Gradle / Ivy)
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
/**
@author Andrew McCallum [email protected]
*/
package org.deeplearning4j.util;
import org.deeplearning4j.optimize.NeuralNetEpochListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import cc.mallet.optimize.BackTrackLineSearch;
import cc.mallet.optimize.LineOptimizer;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.Optimizer;
import cc.mallet.optimize.OptimizerEvaluator;
import cc.mallet.types.MatrixOps;
/**
* Modified based on cc.mallet.optimize.ConjugateGradient
* no termination when zero tolerance
*
* @author GuoDing
* @since 2013-08-25
*/
// Conjugate Gradient, Polak and Ribiere version
// from "Numeric Recipes in C", Section 10.6.
/**
 * Conjugate-gradient optimizer (Polak–Ribière variant, "Numerical Recipes in C",
 * Section 10.6), adapted from {@code cc.mallet.optimize.ConjugateGradient}.
 * The modification: setting {@code tolerance <= 0} disables the value-based
 * termination test, so the search does not stop merely because the relative
 * change in function value is tiny.
 * <p>
 * NOTE(review): not thread-safe — the search state (gradient, direction, step)
 * lives in instance fields shared across calls to {@link #optimize(int)}.
 */
public class NonZeroStoppingConjugateGradient implements Optimizer {
    private static final Logger logger = LoggerFactory.getLogger(NonZeroStoppingConjugateGradient.class);

    boolean converged = false;
    Optimizable.ByGradientValue optimizable;
    LineOptimizer.ByGradient lineMaximizer;
    double initialStepSize = 1;
    // Relative change in function value below which the search terminates; <= 0 disables the test.
    double tolerance = 0.001;
    // Two-norm of the gradient below which the search terminates.
    double gradientTolerance = 0.001;
    int maxIterations = 1000;
    private String myName = "";
    private NeuralNetEpochListener listener;
    // "eps" is a small number to rectify the special case of converging
    // to exactly zero function value
    final double eps = 1.0e-10;
    private OptimizerEvaluator.ByGradient eval;

    /**
     * Creates an optimizer over {@code function} with an explicit initial line-search step.
     *
     * @param function        objective providing value and gradient
     * @param initialStepSize first step length tried by the line search
     */
    public NonZeroStoppingConjugateGradient(Optimizable.ByGradientValue function, double initialStepSize) {
        this.initialStepSize = initialStepSize;
        this.optimizable = function;
        this.lineMaximizer = new BackTrackLineSearch(function);
        BackTrackLineSearch l = (BackTrackLineSearch) this.lineMaximizer;
        l.setAbsTolx(eps);
        // Alternative:
        //this.lineMaximizer = new GradientBracketLineOptimizer (function);
    }

    /**
     * Creates an optimizer with the default step size (0.01) and a per-iteration listener.
     */
    public NonZeroStoppingConjugateGradient(Optimizable.ByGradientValue function, NeuralNetEpochListener listener) {
        this(function, 0.01);
        this.listener = listener;
    }

    /**
     * Creates an optimizer with an explicit step size and a per-iteration listener.
     */
    public NonZeroStoppingConjugateGradient(Optimizable.ByGradientValue function, double initialStepSize,
                    NeuralNetEpochListener listener) {
        this(function, initialStepSize);
        this.listener = listener;
    }

    /**
     * Creates an optimizer with the default step size (0.01).
     */
    public NonZeroStoppingConjugateGradient(Optimizable.ByGradientValue function) {
        this(function, 0.01);
    }

    public Optimizable getOptimizable() {
        return this.optimizable;
    }

    public boolean isConverged() {
        return converged;
    }

    public void setEvaluator(OptimizerEvaluator.ByGradient eval) {
        this.eval = eval;
    }

    public void setLineMaximizer(LineOptimizer.ByGradient lineMaximizer) {
        this.lineMaximizer = lineMaximizer;
    }

    public void setInitialStepSize(double initialStepSize) {
        this.initialStepSize = initialStepSize;
    }

    public double getInitialStepSize() {
        return this.initialStepSize;
    }

    /** Returns the step length produced by the most recent line search. */
    public double getStepSize() {
        return step;
    }

    // The state of a conjugate gradient search:
    // fp = previous function value, fret = current function value,
    // gg/dgg/gam = Polak–Ribière update terms, step = line-search step length.
    double fp, gg, gam, dgg, step, fret;
    // xi = current gradient / proposed direction, g = previous gradient, h = search direction.
    double[] xi, g, h;
    int j, iterations;

    public boolean optimize() {
        return optimize(maxIterations);
    }

    public void setTolerance(double t) {
        tolerance = t;
    }

    /**
     * Runs up to {@code numIterations} conjugate-gradient iterations, continuing
     * from any previous state (call {@link #reset()} to start over).
     *
     * @param numIterations maximum number of iterations for this call
     * @return true if a convergence criterion was met, false if the iteration
     *         budget was exhausted first
     */
    public boolean optimize(int numIterations) {
        myName = Thread.currentThread().getName();
        if (converged)
            return true;
        int n = optimizable.getNumParameters();
        long last = System.currentTimeMillis();
        if (xi == null) {
            // First call (or after reset()): initialize direction to the raw gradient.
            fp = optimizable.getValue();
            xi = new double[n];
            g = new double[n];
            h = new double[n];
            optimizable.getValueGradient(xi);
            System.arraycopy(xi, 0, g, 0, n);
            System.arraycopy(xi, 0, h, 0, n);
            step = initialStepSize;
            iterations = 0;
        }
        long curr = 0;
        for (int iterationCount = 0; iterationCount < numIterations; iterationCount++) {
            curr = System.currentTimeMillis();
            if (listener != null)
                listener.epochDone(iterationCount);
            logger.info(myName + " ConjugateGradient: At iteration " + iterations + ", cost = " + fp + " -"
                            + (curr - last));
            last = curr;
            try {
                step = lineMaximizer.optimize(xi, step);
            } catch (Throwable e) {
                // Best-effort: a failed line search leaves parameters as-is and the
                // outer loop continues. Log the full throwable — logging only
                // e.getMessage() loses the stack trace (and prints "null" for
                // message-less throwables).
                logger.warn("Line search failed; continuing with current parameters", e);
            }
            fret = optimizable.getValue();
            optimizable.getValueGradient(xi);
            // This termination provided by "Numeric Recipes in C".
            // Skipped entirely when tolerance <= 0 (the "non-zero-stopping" behavior).
            if ((0 < tolerance) && (2.0 * Math.abs(fret - fp) <= tolerance * (Math.abs(fret) + Math.abs(fp) + eps))) {
                logger.info("ConjugateGradient converged: old value= " + fp + " new value= " + fret + " tolerance="
                                + tolerance);
                converged = true;
                return true;
            }
            fp = fret;
            // This termination provided by McCallum
            double twoNorm = MatrixOps.twoNorm(xi);
            if (twoNorm < gradientTolerance) {
                logger.info("ConjugateGradient converged: gradient two norm " + twoNorm + ", less than "
                                + gradientTolerance);
                converged = true;
                return true;
            }
            dgg = gg = 0.0;
            for (j = 0; j < xi.length; j++) {
                gg += g[j] * g[j];
                dgg += xi[j] * (xi[j] - g[j]);
            }
            // Guard the Polak–Ribière update: if the previous gradient is exactly
            // zero, dgg / gg would be NaN and poison the search direction h. A zero
            // gradient means we are at a stationary point, so report convergence.
            if (gg == 0.0) {
                logger.info("ConjugateGradient converged: previous gradient is exactly zero");
                converged = true;
                return true;
            }
            gam = dgg / gg;
            for (j = 0; j < xi.length; j++) {
                g[j] = xi[j];
                h[j] = xi[j] + gam * h[j];
            }
            assert (!MatrixOps.isNaN(h));
            // gdruck
            // Mallet line search algorithms stop search whenever
            // a step is found that increases the value significantly.
            // ConjugateGradient assumes that line maximization finds something
            // close
            // to the maximum in that direction. In tests, sometimes the
            // direction suggested by CG was downhill. Consequently, here I am
            // setting the search direction to the gradient if the slope is
            // negative or 0.
            if (MatrixOps.dotProduct(xi, h) > 0) {
                MatrixOps.set(xi, h);
            } else {
                logger.warn("Reverting back to GA");
                MatrixOps.set(h, xi);
            }
            iterations++;
            if (iterations > maxIterations) {
                logger.info("Too many iterations in ConjugateGradient.java");
                converged = true;
                return true;
            }
            if (eval != null) {
                eval.evaluate(optimizable, iterations);
            }
        }
        // Iteration budget exhausted without meeting a convergence criterion.
        return false;
    }

    /** Discards the search state so the next {@link #optimize(int)} starts fresh. */
    public void reset() {
        xi = null;
    }

    public int getMaxIterations() {
        return maxIterations;
    }

    public void setMaxIterations(int maxIterations) {
        this.maxIterations = maxIterations;
    }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy