All downloads are free. Search and download functionality uses the official Maven repository.

weka.core.ConjugateGradientOptimization Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA), a machine learning workbench. This is the stable version. Apart from bugfixes, this version does not receive any other updates.

There is a newer version: 3.8.6
Show newest version
/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    ConjugateGradientOptimization.java
 *    Copyright (C) 2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.core;

import java.util.Arrays;

import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;

/**
 * This subclass of Optimization.java implements conjugate gradient descent
 * rather than BFGS updates, by overriding findArgmin(), with the same tests for
 * convergence, and applies the same line search code (lnsrch). Note that
 * constraints are NOT actually supported: every bound passed to findArgmin()
 * must be Double.NaN. Using this class instead of Optimization.java can reduce
 * runtime when there are many parameters, since no approximation of the
 * (inverse) Hessian is maintained.
 * 
 * Uses the second hybrid method proposed in "An Efficient Hybrid Conjugate
 * Gradient Method for Unconstrained Optimization" by Dai and Yuan (2001), i.e.
 * the multiplier max(0, min(beta_HS, beta_DY)). See also the information in
 * the getTechnicalInformation() method.
 * 
 * @author Eibe Frank
 * @version $Revision: 10203 $
 */
public abstract class ConjugateGradientOptimization extends Optimization
  implements RevisionHandler {

  /**
   * Returns an instance of a TechnicalInformation object, containing detailed
   * information about the technical background of this class, e.g., paper
   * reference or book this class is based on.
   * 
   * @return the technical information about this class (the Dai and Yuan
   *         article, with the Hager and Zhang survey as additional reference)
   */
  @Override
  public TechnicalInformation getTechnicalInformation() {
    TechnicalInformation result = new TechnicalInformation(Type.ARTICLE);
    result.setValue(Field.AUTHOR, "Y.H. Dai and Y. Yuan");
    result.setValue(Field.YEAR, "2001");
    result.setValue(Field.TITLE,
      "An Efficient Hybrid Conjugate Gradient Method for Unconstrained Optimization");
    result.setValue(Field.JOURNAL, "Annals of Operations Research");
    result.setValue(Field.VOLUME, "103");
    result.setValue(Field.PAGES, "33-47");

    // add() registers a new, empty entry and returns it. Its fields must be
    // set on the returned object: setting them on 'result' would overwrite
    // the Dai and Yuan reference above and lose the survey entirely.
    TechnicalInformation additional = result.add(Type.ARTICLE);
    additional.setValue(Field.AUTHOR, "W.W. Hager and H. Zhang");
    additional.setValue(Field.YEAR, "2006");
    additional.setValue(Field.TITLE,
      "A survey of nonlinear conjugate gradient methods");
    additional.setValue(Field.JOURNAL, "Pacific Journal of Optimization");
    additional.setValue(Field.VOLUME, "2");
    additional.setValue(Field.PAGES, "35-58");

    return result;
  }

  /**
   * Constructor that sets MAXITS to 2000 by default and the parameter in the
   * second weak Wolfe condition (m_BETA) to 0.1.
   */
  public ConjugateGradientOptimization() {
    setMaxIteration(2000);
    // To make line search more exact, recommended for non-linear CGD
    m_BETA = 0.1;
  }

  /**
   * Main algorithm: minimizes the objective function starting from initX
   * using hybrid (Dai-Yuan / Hestenes-Stiefel) conjugate gradient descent.
   * NOTE: constraints are not actually supported.
   * 
   * @param initX initial point of x, assuming no value's on the bound!
   * @param constraints lower and upper bound arrays, indexed like initX; both
   *          arrays must contain only Double.NaN
   * @return the solution of x, null if the maximum number of iterations was
   *         reached without convergence (the last point is then stored in
   *         m_X)
   * @throws Exception if any constraint is not NaN, the objective function
   *           evaluates to NaN, or the line search makes a zero step
   */
  @Override
  public double[] findArgmin(double[] initX, double[][] constraints)
    throws Exception {

    int l = initX.length;

    // Constraints are not supported: reject them before doing any
    // (potentially expensive) function or gradient evaluation.
    for (int i = 0; i < l; i++) {
      if (!Double.isNaN(constraints[0][i])
        || !Double.isNaN(constraints[1][i])) {
        throw new Exception("Cannot deal with constraints, sorry.");
      }
    }

    // Initial value of the objective function
    m_f = objectiveFunction(initX);
    if (Double.isNaN(m_f)) {
      throw new Exception("Objective function value is NaN!");
    }

    // Get gradient at initial point
    double[] grad = evaluateGradient(initX);
    double[] oldGrad, oldX;
    double[] deltaX = new double[l];
    double[] direct = new double[l];

    // First direction is steepest descent; also compute the squared length
    // of the initial gradient
    double sum = 0;
    for (int i = 0; i < grad.length; i++) {
      direct[i] = -grad[i];
      sum += grad[i] * grad[i];
    }

    // Maximum step length for the line search (same formula as in
    // Optimization.java; note that here it is based on the gradient norm)
    double stpmax = m_STPMX * Math.max(Math.sqrt(sum), l);

    // No variable is ever fixed or at a bound, but lnsrch() takes these
    boolean[] isFixed = new boolean[l];
    DynamicIntArray wsBdsIndx = new DynamicIntArray(l);

    // Iterate on a copy of the starting point
    double[] x = Arrays.copyOf(initX, l);

    boolean finished = false;
    for (int step = 0; step < m_MAXITS; step++) {

      if (m_Debug) {
        System.err.println("\nIteration # " + step + ":");
      }

      oldX = x;
      oldGrad = grad;

      // Make a copy of direction vector because it may get modified in
      // lnsrch; the unmodified direction is needed for the multiplier below
      double[] directB = Arrays.copyOf(direct, direct.length);

      // Perform a line search based on new direction
      m_IsZeroStep = false;
      x = lnsrch(x, grad, directB, stpmax, isFixed, constraints, wsBdsIndx);
      if (m_IsZeroStep) {
        throw new Exception("Exiting due to zero step.");
      }

      // Convergence test on the relative change of x
      double test = 0.0;
      for (int h = 0; h < x.length; h++) {
        deltaX[h] = x[h] - oldX[h];
        double tmp = Math.abs(deltaX[h]) / Math.max(Math.abs(x[h]), 1.0);
        if (tmp > test) {
          test = tmp;
        }
      }
      if (test < m_Zero) {
        if (m_Debug) {
          System.err.println("\nDeltaX converged: " + test);
        }
        finished = true;
        break;
      }

      // Check zero gradient
      grad = evaluateGradient(x);
      test = 0.0;
      for (int g = 0; g < l; g++) {
        double tmp = Math.abs(grad[g]) * Math.max(Math.abs(directB[g]), 1.0)
          / Math.max(Math.abs(m_f), 1.0);
        if (tmp > test) {
          test = tmp;
        }
      }

      if (test < m_Zero) {
        if (m_Debug) {
          for (int i = 0; i < l; i++) {
            System.err.println(grad[i] + " " + directB[i] + " " + m_f);
          }
          System.err.println("Gradient converged: " + test);
        }
        finished = true;
        break;
      }

      // Calculate multiplier: both formulas share the denominator
      // (g_new - g_old)^T d
      double betaHSNumerator = 0, betaDYNumerator = 0;
      double betaHSandDYDenominator = 0;
      for (int i = 0; i < grad.length; i++) {
        betaDYNumerator += grad[i] * grad[i];
        betaHSNumerator += (grad[i] - oldGrad[i]) * grad[i];
        betaHSandDYDenominator += (grad[i] - oldGrad[i]) * direct[i];
      }
      double betaHS = betaHSNumerator / betaHSandDYDenominator;
      double betaDY = betaDYNumerator / betaHSandDYDenominator;

      if (m_Debug) {
        System.err.println("Beta HS: " + betaHS);
        System.err.println("Beta DY: " + betaDY);
      }

      // Second hybrid method of Dai and Yuan
      double beta = Math.max(0, Math.min(betaHS, betaDY));
      // A zero denominator yields an infinite or NaN multiplier, which would
      // turn the whole direction into Infinity/NaN; restart with steepest
      // descent instead.
      if (Double.isNaN(beta) || Double.isInfinite(beta)) {
        beta = 0;
      }

      for (int i = 0; i < direct.length; i++) {
        direct[i] = -grad[i] + beta * direct[i];
      }
    }

    if (finished) {
      if (m_Debug) {
        System.err.println("Minimum found.");
      }
      m_f = objectiveFunction(x);
      if (Double.isNaN(m_f)) {
        throw new Exception("Objective function value is NaN!");
      }
      return x;
    }

    if (m_Debug) {
      System.err.println("Cannot find minimum -- too many iterations!");
    }
    m_X = x;
    return null;
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy