All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.optimization.OptimizationUtils Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.optimization;


import water.Iced;
import water.util.ArrayUtils;
import water.util.Log;

import java.util.Arrays;

/**
 * Created by tomasnykodym on 9/29/15.
 */
public class OptimizationUtils {

  public static class GradientInfo extends Iced {
    public double _objVal;
    public double [] _gradient;

    public GradientInfo(double objVal, double [] grad){
      _objVal = objVal;
      _gradient = grad;
    }

    public boolean isValid(){
      if(Double.isNaN(_objVal))
        return false;
      return !ArrayUtils.hasNaNsOrInfs(_gradient);
    }
    @Override
    public String toString(){
      return " objVal = " + _objVal + ", " + Arrays.toString(_gradient);
    }

  }

  /**
   *  Provides ginfo computation and line search evaluation specific to given problem.
   *  Typically just a wrapper around MRTask calls.
   */
  public interface GradientSolver {
    /**
     * Evaluate ginfo at solution beta.
     * @param beta
     * @return
     */
    GradientInfo getGradient(double [] beta);
    GradientInfo getObjective(double [] beta);
  }


  public interface LineSearchSolver {
    boolean evaluate(double [] direction);
    double step();
    GradientInfo ginfo();
    LineSearchSolver setInitialStep(double s);
    int nfeval();
    double getObj();
    double[] getX();
  }

  public static final class SimpleBacktrackingLS implements LineSearchSolver {
    private double [] _beta;
    final double _stepDec = .33;
    private double _step;
    private final GradientSolver _gslvr;
    private GradientInfo _ginfo; // gradient info excluding l1 penalty
    private double _objVal; // objective including l1 penalty
    final double _l1pen;
    int _maxfev = 20;
    double _minStep = 1e-4;


    public SimpleBacktrackingLS(GradientSolver gslvr, double [] betaStart, double l1pen) {
      this(gslvr, betaStart, l1pen, gslvr.getObjective(betaStart));
    }
    public SimpleBacktrackingLS(GradientSolver gslvr, double [] betaStart, double l1pen, GradientInfo ginfo) {
      _gslvr = gslvr;
      _beta = betaStart;
      _ginfo = ginfo;
      _l1pen = l1pen;
      _objVal = _ginfo._objVal + _l1pen * ArrayUtils.l1norm(_beta,true);  // gam smoothness should have been added
    }
    public int nfeval() {return -1;}

    @Override
    public double getObj() {return _objVal;}

    @Override
    public double[] getX() {return _beta;}

    public LineSearchSolver setInitialStep(double s){
      return this;
    }

    @Override
    public boolean evaluate(double[] direction) {
      double step = 1;
      double minStep = 1;
      for(double d:direction) {
        d = Math.abs(1e-4/d);
        if(d < minStep) minStep = d;
      }
      double [] newBeta = direction.clone();
      for(int i = 0; i < _maxfev && step >= minStep; ++i, step*= _stepDec) {
        GradientInfo ginfo = _gslvr.getObjective(ArrayUtils.wadd(_beta,direction,newBeta,step));
        double objVal = ginfo._objVal + _l1pen * ArrayUtils.l1norm(newBeta,true);
        if(objVal < _objVal){
          _ginfo = ginfo;
          _objVal = objVal;
          _beta = newBeta;
          _step = step;
          return true;
        }
      }
      return false;
    }

    @Override
    public double step() {
      return _step;
    }

    @Override
    public GradientInfo ginfo() {
      return _ginfo;
    }

    @Override public String toString(){return "";}
  }



  public static final class MoreThuente implements LineSearchSolver {
    double _stMin, _stMax;

    double _initialStep = 1;
    double _minRelativeImprovement = 1e-8;
    private final GradientSolver _gslvr;
    private double [] _beta;


    public MoreThuente(GradientSolver gslvr, double [] betaStart){
      this(gslvr,betaStart,gslvr.getGradient(betaStart),.1,.1,1e-2);
    }
    public MoreThuente(GradientSolver gslvr, double [] betaStart, GradientInfo ginfo){
      this(gslvr,betaStart,ginfo,.1,.1,1e-8);
    }
    public MoreThuente(GradientSolver gslvr, double [] betaStart, GradientInfo ginfo, double ftol, double gtol, double xtol){
      _gslvr = gslvr;
      _beta = betaStart;
      _ginfox = ginfo;
      if(ginfo._gradient == null)
        throw new IllegalArgumentException("GradientInfo for MoreThuente line search solver must include gradient");
      _ftol = ftol;
      _gtol = gtol;
      _xtol = xtol;
    }

    public MoreThuente setInitialStep(double t) {_initialStep = t; return this;}

    @Override
    public int nfeval() {
      return _iter;
    }

    @Override
    public double getObj() {return ginfo()._objVal;}

    @Override
    public double[] getX() { return _beta;}

    double _xtol = 1e-8;
    double _ftol = .1; // .2/.25 works
    double _gtol = .1;
    double _xtrapf = 4;

    // fval, dg and step of the best step so far
    double _fvx;
    double _dgx;
    double _stx;

    double _bestStep;
    GradientInfo _betGradient; // gradient info with at least minimal relative improvement and best value of augmented function
    double  _bestPsiVal; // best value of augmented function
    GradientInfo _ginfox;

    // fval, dg and step of the best step so far
    double _fvy;
    double _dgy;
    double _sty;

    boolean _brackt;
    boolean _bound;

    int _returnStatus;

    public final String [] messages = new String[]{
      "In progress or not evaluated", // 0
      "The sufficient decrease condition and the directional derivative condition hold.", // 1
      "Relative width of the interval of uncertainty is at most xtol.", // 2
      "Number of calls to gradient solver has reached the limit.", // 3
      "The step is at the lower bound stpmin.", // 4
      "The step is at the upper bound stpmax.", // 5
      "Rounding errors prevent further progress, ftol/gtol tolerances may be too small.", // 6
      "Non-negative differential." // 7
    };

    private double nextStep(GradientInfo ginfo, double dg, double stp, double off) {
      double fvp = ginfo._objVal - stp*off;
      double dgp = dg - off;
      double fvx = _fvx - _stx * off;
      double fvy = _fvy - _sty * off;
      double stx = _stx;
      double sty = _sty;
      double dgx = _dgx - off;
      double dgy = _dgy - off;

      if ((_brackt && (stp <= Math.min(stx,sty) || stp >= Math.max(stx,sty))) || dgx*(stp-stx) >= 0.0)
        return Double.NaN;
      double theta = 3 * (fvx - fvp) / (stp - stx) + dgx + dgp;
      double s = Math.max(Math.max(Math.abs(theta),Math.abs(dgx)),Math.abs(dgp));
      double sInv = 1/s;
      double ts = theta*sInv;
      double gamma = s*Math.sqrt(Math.max(0., (ts*ts) - ((dgx * sInv) * (dgp*sInv))));

      int info = 0;
      // case 1
      double nextStep;
      if (fvp > fvx) {
        info = 1;
        if (stp < stx) gamma = -gamma;
        _bound = true;
        _brackt = true;
        double p = (gamma - dgx) + theta;
        double q = ((gamma - dgx) + gamma) + dgp;
        double r = p / q;
        double stpc = stx + r * (stp - stx);
        double stpq = stx + ((dgx / ((fvx - fvp) / (stp - stx) + dgx)) / 2) * (stp - stx);
        nextStep = (Math.abs(stpc - stx) < Math.abs(stpq - stx)) ? stpc : stpc + (stpq - stpc) / 2;
      } else  if (dgp * dgx  < 0) { // case 2
        info = 2;
        if (stp > stx) gamma = -gamma;
        _bound = false;
        _brackt = true;
        double p = (gamma - dgp) + theta;
        double q = ((gamma - dgp) + gamma) + dgx;
        double r = p / q;
        double stpc = stp + r * (stx - stp);
        double stpq = stp + (dgp / (dgp - dgx)) * (stx - stp);
        nextStep = (Math.abs(stpc - stp) > Math.abs(stpq - stp)) ? stpc : stpq;
      } else if (Math.abs(dgp) < Math.abs(dgx)) { // case 3
        info = 3;
        if (stp > stx) gamma = -gamma;
        _bound = true;
        double p = gamma - dgp + theta;
        double q = gamma + dgx - dgp + gamma;
        double r = p / q;
        double stpc;
        if (r < 0.0 && gamma != 0.0)
          stpc = stp + r * (stx - stp);
        else if (stp > stx)
          stpc = _stMax;
        else
          stpc = _stMin;
        // stpq = stp + (dp/(dp-dx))*(stx - stp);
        double stpq = stp + (dgp / (dgp - dgx)) * (stx - stp);
        if (_brackt)
          nextStep = (Math.abs(stp - stpc) < Math.abs(stp - stpq)) ? stpc : stpq;
        else
          nextStep = (Math.abs(stp - stpc) > Math.abs(stp - stpq)) ? stpc : stpq;
      } else {
        // case 4
        info = 4;
        _bound = false;
        if (_brackt) {
          theta = 3 * (fvp - fvy) / (sty - stp) + dgy + dgp;
          gamma = Math.sqrt(theta * theta - dgy * dgp);
          if (stp > sty) gamma = -gamma;
          double p = (gamma - dgp) + theta;
          double q = ((gamma - dgp) + gamma) + dgy;
          double r = p / q;
          nextStep = stp + r * (sty - stp);
        } else
          nextStep = stp > stx ? _stMax : _stMin;
      }

      if(fvp > fvx) {
        _sty = stp;
        _fvy = ginfo._objVal;
        _dgy = dg;
      } else {
        if(dgp * dgx < 0) {
          _sty = _stx;
          _fvy = _fvx;
          _dgy = _dgx;
        }
        _stx = stp;
        _fvx = ginfo._objVal;
        _dgx = dg;
        _ginfox = ginfo;
      }
      if(nextStep > _stMax)
        nextStep = _stMax;
      if(nextStep < _stMin)
        nextStep = _stMin;
      if (_brackt & _bound)
        if (_sty > _stx)
          nextStep = Math.min(_stx + .66 * (_sty - _stx), nextStep);
        else
          nextStep = Math.max(_stx + .66 * (_sty - _stx), nextStep);
      return nextStep;
    }

    public String toString(){
      return "MoreThuente line search, iter = " + _iter + ", status = " + messages[_returnStatus] + ", step = " + _stx + ", I = " + "[" + _stMin + ", " + _stMax + "], grad = " + _dgx + ", bestObj = "  + _fvx;
    }
    private int _iter;

    int _maxfev = 20;
    double _maxStep = 1e10;
    double _minStep = 1e-10;
    @Override
    public boolean evaluate(double [] direction) {
      double oldObjval = _ginfox._objVal;
      double step = _initialStep;
      _bound = false;
      _brackt = false;
      _stx = _sty = 0;
      _stMin = _stMax = 0;
      _betGradient = null;
      _bestPsiVal = Double.POSITIVE_INFINITY;
      _bestStep = 0;
      double maxObj = _ginfox._objVal - _minRelativeImprovement*_ginfox._objVal;
      final double dgInit = ArrayUtils.innerProduct(_ginfox._gradient, direction);
      final double dgtest = dgInit * _ftol;
      if(dgtest > 1e-4) Log.warn("MoreThuente LS: got possitive differential " + dgtest);
      if(dgtest >= 0) {
        _returnStatus = 7;
        return false;
      }
      double [] beta = new double[_beta.length];
      double width = _maxStep - _minStep;
      double oldWidth = 2*width;
      boolean stage1 = true;
      _fvx = _fvy = _ginfox._objVal;
      _dgx = _dgy = dgInit;
      _iter = 0;

      while (true) {
        if (_brackt) {
          _stMin = Math.min(_stx, _sty);
          _stMax = Math.max(_stx, _sty);
        } else {
          _stMin = _stx;
          _stMax = step + _xtrapf * (step - _stx);
        }
        step = Math.min(step,_maxStep);
        step = Math.max(step,_minStep);
        double maxFval = oldObjval + step * dgtest;

        for (int i = 0; i < beta.length; ++i)
          beta[i] = _beta[i] + step * direction[i];
        GradientInfo newGinfo = _gslvr.getGradient(beta); // gradient for one class, active cols if rcc=true
        if(newGinfo._objVal < maxObj && (_betGradient == null || (newGinfo._objVal - maxFval) < _bestPsiVal)){
          _bestPsiVal = (newGinfo._objVal - maxFval);
          _betGradient = newGinfo;
          _bestStep = step;
        }
        ++_iter;
        if(_iter < _maxfev && (!Double.isNaN(step) && (Double.isNaN(newGinfo._objVal) || Double.isInfinite(newGinfo._objVal) || ArrayUtils.hasNaNsOrInfs(newGinfo._gradient)))) {
          _brackt = true;
          _sty = step;
          _maxStep = step;
          _fvy = Double.POSITIVE_INFINITY;
          _dgy = Double.MAX_VALUE;
          step *= .5;
          continue;
        }
        double dgp = ArrayUtils.innerProduct(newGinfo._gradient, direction);
        if(Double.isNaN(step) || _brackt && (step <= _stMin || step >= _stMax)) {
          _returnStatus = 6;
          break;
        }
        if (step == _maxStep && newGinfo._objVal <= maxFval & dgp <= dgtest){
          _returnStatus = 5;
          _stx = step;
          _ginfox = newGinfo;
          break;
        }
        if (step == _minStep && (newGinfo._objVal > maxFval | dgp >= dgtest)){
          _returnStatus = 4;
          if(_betGradient != null) {
             _stx = _bestStep;
             _ginfox = _betGradient;
          } else {
            _stx = step;
            _ginfox = newGinfo;
          }
          break;
        }
        if (_iter >= _maxfev){
          _returnStatus = 3;
          if(_betGradient != null) {
            _stx = _bestStep;
            _ginfox = _betGradient;
          } else {
            _stx = step;
            _ginfox = newGinfo;
          }
          break;
        }
        if (_brackt && (_stMax-_stMin) <= _xtol*_stMax) {
          _ginfox = newGinfo;
          _returnStatus = 2;
          break;
        }
        // check for convergence
        if (newGinfo._objVal < maxFval && Math.abs(dgp) <= -_gtol * dgInit) { // got solution satisfying both conditions
          _stx = step;
          _dgx = dgp;
          _fvx = newGinfo._objVal;
          _ginfox = newGinfo;
          _returnStatus = 1;
          break;
        }
        // f > ftest1 || dg < min(ftol,gtol)*dginit
        stage1 = stage1 && (newGinfo._objVal > maxFval || dgp < dgtest);
        boolean useAugmentedFuntcion = stage1 && newGinfo._objVal <= _fvx && newGinfo._objVal > maxFval;
        double off = useAugmentedFuntcion?dgtest:0;
        double nextStep = nextStep(newGinfo,dgp,step,off);
        if (_brackt) {
          if (Math.abs(_sty - _stx) >= .66 * oldWidth)
            nextStep = _stx + .5 * (_sty - _stx);
          oldWidth = width;
          width = Math.abs(_sty - _stx);
        }
        step = nextStep;
      }
      boolean succ = _ginfox._objVal < oldObjval;
      if(succ) {
        // make sure we have correct beta (not all return cases have valid current beta!)
        for (int i = 0; i < beta.length; ++i)
          beta[i] = _beta[i] + _stx * direction[i];
        _beta = beta;
      }
      return succ;
    }

    @Override
    public double step() {return _stx;}


    @Override
    public GradientInfo ginfo() {
      return _ginfox;
    }
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy