All downloads are free. The search and download functionality uses the official Maven repository.

math.optim.CGOptimizer Maven / Gradle / Ivy

The newest version!
package math.optim;

import java.text.DecimalFormat;

import math.function.DiffMultivariateFunction;
import math.function.MultivariateFunctionResult;

/**
 * Conjugate-gradient implementation based on the code in "Numerical Recipes in
 * C" (see p. 423 and others).
 * 

* As of now, it requires a differentiable function * {@link DiffMultivariateFunction} as input. *

* The basic way to use the CGOptimizer is with the simple {@code minimize} * method: *

* DiffMultivariateFunction dmf = new SomeDiffMultivariateFunction(); *
* double[] initial = getInitialGuess();
* MultivariateFunctionResult minimum = CGOptimizer.minimize(dmf, initial); * * @author Dan Klein */ public final class CGOptimizer { private static final DecimalFormat NF = new DecimalFormat("0.000E0"); private static final boolean SIMPLE_GD = false; private static final boolean CHECK_SIMPLE_GD_CONVERGENCE = true; private static final boolean VERBOSE = false; // constants private static final double GOLD = 1.618034; private static final double GLIMIT = 100.0; private static final double TINY = 1.0e-20; // overridden in dbrent() private static final int ITMAX = 10001; private static final double EPS = 1.0e-30; private static final int RESET_FREQ = 10; // default function tolerance private static final double FUNC_DEFAULT_TOL = 1e-10; static final class Minimand implements DiffMultivariateFunction { private final DiffMultivariateFunction f; Minimand(DiffMultivariateFunction f) { this.f = f; } @Override public double valueAt(double[] x) { return -f.valueAt(x); } @Override public void derivativeAt(double[] x, double[] grad) { f.derivativeAt(x, grad); for (int i = 0; i < grad.length; ++i) { grad[i] = -grad[i]; } } } // Minimand static final class OneDimDiffFunction { private final DiffMultivariateFunction function; private final double[] initial; private final double[] direction; private final double[] currVector; private final double[] currGradient; OneDimDiffFunction(DiffMultivariateFunction function, double[] initial, double[] direction) { this.function = function; this.initial = initial.clone(); this.direction = direction.clone(); this.currVector = new double[initial.length]; this.currGradient = new double[initial.length]; } double[] vectorOf(double x) { for (int i = 0; i < initial.length; i++) { currVector[i] = initial[i] + (x * direction[i]); } return currVector; } double valueAt(double x) { return function.valueAt(vectorOf(x)); } double derivativeAt(double x) { function.derivativeAt(vectorOf(x), currGradient); double d = 0.0; for (int i = 0; i < 
currGradient.length; i++) { d += currGradient[i] * direction[i]; } return d; } } // OneDimDiffFunction private CGOptimizer() { } public static MultivariateFunctionResult maximize( DiffMultivariateFunction function, double[] initial) { return maximize(function, FUNC_DEFAULT_TOL, initial); } public static MultivariateFunctionResult maximize( DiffMultivariateFunction function, double functionTolerance, double[] initial) { return maximize(function, functionTolerance, initial, ITMAX); } public static MultivariateFunctionResult maximize( DiffMultivariateFunction function, double functionTolerance, double[] initial, int maxIterations) { return maximize(function, functionTolerance, initial, maxIterations, true); } public static MultivariateFunctionResult maximize( DiffMultivariateFunction function, double functionTolerance, double[] initial, int maxIterations, boolean silent) { return minimize(new Minimand(function), functionTolerance, initial, maxIterations, silent, true); } public static MultivariateFunctionResult minimize( DiffMultivariateFunction function, double[] initial) { return minimize(function, FUNC_DEFAULT_TOL, initial); } public static MultivariateFunctionResult minimize( DiffMultivariateFunction function, double functionTolerance, double[] initial) { return minimize(function, functionTolerance, initial, ITMAX); } public static MultivariateFunctionResult minimize( DiffMultivariateFunction function, double functionTolerance, double[] initial, int maxIterations) { return minimize(function, functionTolerance, initial, maxIterations, true); } public static MultivariateFunctionResult minimize( DiffMultivariateFunction function, double functionTolerance, double[] initial, int maxIterations, boolean silent) { return minimize(function, functionTolerance, initial, maxIterations, silent, false); } private static MultivariateFunctionResult minimize( DiffMultivariateFunction function, double functionTolerance, double[] initial, int maxIterations, boolean silent, boolean 
isMaximization) { int dimension = initial.length; double sign = isMaximization ? -1.0 : 1.0; // evaluate function double fp = function.valueAt(initial); if (VERBOSE) { System.err.println("Initial: " + fp); } double[] xi = new double[dimension]; function.derivativeAt(initial, xi); if (VERBOSE) { System.err.println("Initial at: " + arrayToString(initial)); System.err.println("Initial deriv: " + arrayToString(xi)); } // make some vectors double[] g = new double[dimension]; double[] h = new double[dimension]; double[] p = new double[dimension]; double[] bracketing = new double[3]; for (int j = 0; j < dimension; j++) { g[j] = -xi[j]; xi[j] = g[j]; h[j] = g[j]; p[j] = initial[j]; } // iterations boolean simpleGDStep = false; int iter = 1; while (iter < maxIterations) { if (!silent) { System.err.print("Iter " + iter + ' '); } // do a line min along descent direction if (VERBOSE) { System.err.println("Minimizing along " + arrayToString(xi)); } double[] p2 = lineMinimize(function, p, xi, bracketing); double fp2 = function.valueAt(p2); if (VERBOSE) { System.err.println("Result is " + fp2 + " after " + iter); System.err.println("Result at " + arrayToString(p2)); } if (!silent) { System.err.printf(" %s (delta: %s)\n", NF.format(fp2), NF.format(fp - fp2)); } // check convergence if (2.0 * fabs(fp2 - fp) <= functionTolerance * (fabs(fp2) + fabs(fp) + EPS)) { // convergence if (!CHECK_SIMPLE_GD_CONVERGENCE || simpleGDStep || SIMPLE_GD) { return new MultivariateFunctionResult(p, sign * fp2, iter); } simpleGDStep = true; } else { simpleGDStep = false; } // shift variables for (int j = 0; j < dimension; j++) { xi[j] = p2[j] - p[j]; p[j] = p2[j]; } fp = fp2; // find the new gradient function.derivativeAt(p, xi); if (!simpleGDStep && !SIMPLE_GD && (iter % RESET_FREQ != 0)) { // do the magic -- part i) // (calculate some dot products we'll need) double dgg = 0.0; double gg = 0.0; for (int j = 0; j < dimension; j++) { // g dot g gg += g[j] * g[j]; // grad dot grad // FR method is: // 
dgg += x[j]*x[j]; // PR method is: dgg += (xi[j] + g[j]) * xi[j]; } // check for miraculous convergence if (gg == 0.0) { return new MultivariateFunctionResult(p, sign * function.valueAt(p), iter); } // do the magic -- part ii) // (update the sequence in a way that tries to preserve // conjugacy) double gam = dgg / gg; for (int j = 0; j < dimension; j++) { g[j] = -xi[j]; h[j] = g[j] + gam * h[j]; xi[j] = h[j]; } } else { // miraculous simpleGD convergence double xixi = 0.0; for (int j = 0; j < dimension; j++) { xixi += xi[j] * xi[j]; } // reset cgd for (int j = 0; j < dimension; j++) { g[j] = -xi[j]; xi[j] = g[j]; h[j] = g[j]; } if (xixi == 0.0) { return new MultivariateFunctionResult(p, sign * function.valueAt(p), iter); } } ++iter; } // while // too many iterations System.err.println("Warning: exiting minimize because ITER exceeded!"); return new MultivariateFunctionResult(p, sign * function.valueAt(p), iter, false); } private static double[] lineMinimize(DiffMultivariateFunction function, double[] initial, double[] direction, double[] bracketing) { // make a 1-dim function along the direction line // THIS IS A HACK (but it's the NRiC peoples' hack) OneDimDiffFunction oneDim = new OneDimDiffFunction(function, initial, direction); // do a 1-dim line min on this function // bracket the extreme point double guess = 0.01; bracketing[0] = 0.0; bracketing[1] = guess; bracketing[2] = 0.0; mnbrak(bracketing, oneDim); double ax = bracketing[0]; double xx = bracketing[1]; double bx = bracketing[2]; // CHECK FOR END OF WORLD if (!(ax <= xx && xx <= bx) && !(bx <= xx && xx <= ax)) { System.err.println("Bad bracket order!"); } if (VERBOSE) { System.err.println("Bracketing found: " + ax + ' ' + xx + ' ' + bx); System.err.println("Bracketing found: " + oneDim.valueAt(ax) + ' ' + oneDim.valueAt(xx) + ' ' + oneDim.valueAt(bx)); } // find the extreme point double xmin = dbrent(oneDim, ax, xx, bx); // return the full vector return oneDim.vectorOf(xmin); } private static void 
mnbrak(double[] bracketing, OneDimDiffFunction func) { // inputs double ax = bracketing[0]; double fa = func.valueAt(ax); double bx = bracketing[1]; double fb = func.valueAt(bx); if (fb > fa) { // swap double tmp = fa; fa = fb; fb = tmp; tmp = ax; ax = bx; bx = tmp; } // guess cx double cx = bx + GOLD * (bx - ax); double fc = func.valueAt(cx); // loop until we get a bracket while (fb > fc) { double r = (bx - ax) * (fb - fc); double q = (bx - cx) * (fb - fa); double u = bx - ((bx - cx) * q - (bx - ax) * r) / (2.0 * sign(fmax(fabs(q - r), TINY), q - r)); double fu; double ulim = bx + GLIMIT * (cx - bx); if ((bx - u) * (u - cx) > 0.0) { fu = func.valueAt(u); if (fu < fc) { // Ax = new Double(bx); // Bx = new Double(u); // Cx = new Double(cx); bracketing[0] = bx; bracketing[1] = u; bracketing[2] = cx; return; } else if (fu > fb) { // Ax = new Double(ax); // Bx = new Double(bx); // Cx = new Double(u); bracketing[0] = ax; bracketing[1] = bx; bracketing[2] = u; return; } u = cx + GOLD * (cx - bx); fu = func.valueAt(u); } else if ((cx - u) * (u - ulim) > 0.0) { fu = func.valueAt(u); if (fu < fc) { bx = cx; cx = u; u = cx + GOLD * (cx - bx); fb = fc; fc = fu; fu = func.valueAt(u); } } else if ((u - ulim) * (ulim - cx) >= 0.0) { u = ulim; fu = func.valueAt(u); } else { u = cx + GOLD * (cx - bx); fu = func.valueAt(u); } ax = bx; bx = cx; cx = u; fa = fb; fb = fc; fc = fu; } // Ax = new Double(ax); // Bx = new Double(bx); // Cx = new Double(cx); bracketing[0] = ax; bracketing[1] = bx; bracketing[2] = cx; } private static double dbrent(OneDimDiffFunction func, double ax, double bx, double cx) { // constants final boolean dbVerbose = false; final int ITMAX = 100; final double TOL = 1.0e-4; double d = 0.0, e = 0.0; double a = (ax < cx ? ax : cx); double b = (ax > cx ? 
ax : cx); double x = bx; double v = bx; double w = bx; double fx = func.valueAt(x); double fv = fx; double fw = fx; double dx = func.derivativeAt(x); double dv = dx; double dw = dx; for (int iteration = 0; iteration < ITMAX; iteration++) { double xm = 0.5 * (a + b); double tol1 = TOL * fabs(x); double tol2 = 2.0 * tol1; if (fabs(x - xm) <= (tol2 - 0.5 * (b - a))) { if (dbVerbose) { System.err .println("dbrent returning because min is cornered"); } return x; } double u; if (fabs(e) > tol1) { double d1 = 2.0 * (b - a); double d2 = d1; if (dw != dx) { d1 = (w - x) * dx / (dx - dw); } if (dv != dx) { d2 = (v - x) * dx / (dx - dv); } double u1 = x + d1; double u2 = x + d2; boolean ok1 = ((a - u1) * (u1 - b) > 0.0 && dx * d1 <= 0.0); boolean ok2 = ((a - u2) * (u2 - b) > 0.0 && dx * d2 <= 0.0); double olde = e; e = d; if (ok1 || ok2) { if (ok1 && ok2) { d = (fabs(d1) < fabs(d2) ? d1 : d2); } else if (ok1) { d = d1; } else { d = d2; } if (fabs(d) <= fabs(0.5 * olde)) { u = x + d; if (u - a < tol2 || b - u < tol2) { d = sign(tol1, xm - x); } } else { e = (dx >= 0.0 ? a - x : b - x); d = 0.5 * e; } } else { e = (dx >= 0.0 ? a - x : b - x); d = 0.5 * e; } } else { e = (dx >= 0.0 ? 
a - x : b - x); d = 0.5 * e; } double fu; if (fabs(d) >= tol1) { u = x + d; fu = func.valueAt(u); } else { u = x + sign(tol1, d); fu = func.valueAt(u); if (fu > fx) { if (dbVerbose) { System.err .println("dbrent returning because derivative is broken"); } return x; } } double du = func.derivativeAt(u); if (fu <= fx) { if (u >= x) { a = x; } else { b = x; } v = w; fv = fw; dv = dw; w = x; fw = fx; dw = dx; x = u; fx = fu; dx = du; } else { if (u < x) { a = u; } else { b = u; } if (fu <= fw || w == x) { v = w; fv = fw; dv = dw; w = u; fw = fu; dw = du; } else if (fu < fv || v == x || v == w) { v = u; fv = fu; dv = du; } } } // Dan's addition: if (fx < func.valueAt(0.0)) { return x; } if (dbVerbose) { System.err .println("Warning: exiting dbrent because ITMAX exceeded!"); } return 0.0; } private static String arrayToString(double[] x) { int numToPrint = 5; StringBuilder sb = new StringBuilder("("); if (numToPrint > x.length) { numToPrint = x.length; } for (int i = 0; i < numToPrint; i++) { sb.append(x[i]); if (i != x.length - 1) { sb.append(", "); } } if (numToPrint < x.length) { sb.append("..."); } sb.append(')'); return sb.toString(); } private static double fabs(double x) { if (x < 0.0) { return -x; } return x; } private static double fmax(double x, double y) { if (x < y) { return y; } return x; } private static double sign(double x, double y) { if (y >= 0.0) { return fabs(x); } return -fabs(x); } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy