hex.optimization.OptimizationUtils Maven / Gradle / Ivy
package hex.optimization;
import water.Iced;
import water.util.ArrayUtils;
import water.util.Log;
import java.util.Arrays;
/**
* Created by tomasnykodym on 9/29/15.
*/
public class OptimizationUtils {
public static class GradientInfo extends Iced {
public double _objVal;
public double [] _gradient;
public GradientInfo(double objVal, double [] grad){
_objVal = objVal;
_gradient = grad;
}
public boolean isValid(){
if(Double.isNaN(_objVal))
return false;
return !ArrayUtils.hasNaNsOrInfs(_gradient);
}
@Override
public String toString(){
return " objVal = " + _objVal + ", " + Arrays.toString(_gradient);
}
}
/**
* Provides ginfo computation and line search evaluation specific to given problem.
* Typically just a wrapper around MRTask calls.
*/
public interface GradientSolver {
/**
* Evaluate ginfo at solution beta.
* @param beta
* @return
*/
GradientInfo getGradient(double [] beta);
GradientInfo getObjective(double [] beta);
}
public interface LineSearchSolver {
boolean evaluate(double [] direction);
double step();
GradientInfo ginfo();
LineSearchSolver setInitialStep(double s);
int nfeval();
double getObj();
double[] getX();
}
public static final class SimpleBacktrackingLS implements LineSearchSolver {
private double [] _beta;
final double _stepDec = .33;
private double _step;
private final GradientSolver _gslvr;
private GradientInfo _ginfo; // gradient info excluding l1 penalty
private double _objVal; // objective including l1 penalty
final double _l1pen;
int _maxfev = 20;
double _minStep = 1e-4;
public SimpleBacktrackingLS(GradientSolver gslvr, double [] betaStart, double l1pen) {
this(gslvr, betaStart, l1pen, gslvr.getObjective(betaStart));
}
public SimpleBacktrackingLS(GradientSolver gslvr, double [] betaStart, double l1pen, GradientInfo ginfo) {
_gslvr = gslvr;
_beta = betaStart;
_ginfo = ginfo;
_l1pen = l1pen;
_objVal = _ginfo._objVal + _l1pen * ArrayUtils.l1norm(_beta,true); // gam smoothness should have been added
}
public int nfeval() {return -1;}
@Override
public double getObj() {return _objVal;}
@Override
public double[] getX() {return _beta;}
public LineSearchSolver setInitialStep(double s){
return this;
}
@Override
public boolean evaluate(double[] direction) {
double step = 1;
double minStep = 1;
for(double d:direction) {
d = Math.abs(1e-4/d);
if(d < minStep) minStep = d;
}
double [] newBeta = direction.clone();
for(int i = 0; i < _maxfev && step >= minStep; ++i, step*= _stepDec) {
GradientInfo ginfo = _gslvr.getObjective(ArrayUtils.wadd(_beta,direction,newBeta,step));
double objVal = ginfo._objVal + _l1pen * ArrayUtils.l1norm(newBeta,true);
if(objVal < _objVal){
_ginfo = ginfo;
_objVal = objVal;
_beta = newBeta;
_step = step;
return true;
}
}
return false;
}
@Override
public double step() {
return _step;
}
@Override
public GradientInfo ginfo() {
return _ginfo;
}
@Override public String toString(){return "";}
}
public static final class MoreThuente implements LineSearchSolver {
double _stMin, _stMax;
double _initialStep = 1;
double _minRelativeImprovement = 1e-8;
private final GradientSolver _gslvr;
private double [] _beta;
public MoreThuente(GradientSolver gslvr, double [] betaStart){
this(gslvr,betaStart,gslvr.getGradient(betaStart),.1,.1,1e-2);
}
public MoreThuente(GradientSolver gslvr, double [] betaStart, GradientInfo ginfo){
this(gslvr,betaStart,ginfo,.1,.1,1e-8);
}
public MoreThuente(GradientSolver gslvr, double [] betaStart, GradientInfo ginfo, double ftol, double gtol, double xtol){
_gslvr = gslvr;
_beta = betaStart;
_ginfox = ginfo;
if(ginfo._gradient == null)
throw new IllegalArgumentException("GradientInfo for MoreThuente line search solver must include gradient");
_ftol = ftol;
_gtol = gtol;
_xtol = xtol;
}
public MoreThuente setInitialStep(double t) {_initialStep = t; return this;}
@Override
public int nfeval() {
return _iter;
}
@Override
public double getObj() {return ginfo()._objVal;}
@Override
public double[] getX() { return _beta;}
double _xtol = 1e-8;
double _ftol = .1; // .2/.25 works
double _gtol = .1;
double _xtrapf = 4;
// fval, dg and step of the best step so far
double _fvx;
double _dgx;
double _stx;
double _bestStep;
GradientInfo _betGradient; // gradient info with at least minimal relative improvement and best value of augmented function
double _bestPsiVal; // best value of augmented function
GradientInfo _ginfox;
// fval, dg and step of the best step so far
double _fvy;
double _dgy;
double _sty;
boolean _brackt;
boolean _bound;
int _returnStatus;
public final String [] messages = new String[]{
"In progress or not evaluated", // 0
"The sufficient decrease condition and the directional derivative condition hold.", // 1
"Relative width of the interval of uncertainty is at most xtol.", // 2
"Number of calls to gradient solver has reached the limit.", // 3
"The step is at the lower bound stpmin.", // 4
"The step is at the upper bound stpmax.", // 5
"Rounding errors prevent further progress, ftol/gtol tolerances may be too small.", // 6
"Non-negative differential." // 7
};
private double nextStep(GradientInfo ginfo, double dg, double stp, double off) {
double fvp = ginfo._objVal - stp*off;
double dgp = dg - off;
double fvx = _fvx - _stx * off;
double fvy = _fvy - _sty * off;
double stx = _stx;
double sty = _sty;
double dgx = _dgx - off;
double dgy = _dgy - off;
if ((_brackt && (stp <= Math.min(stx,sty) || stp >= Math.max(stx,sty))) || dgx*(stp-stx) >= 0.0)
return Double.NaN;
double theta = 3 * (fvx - fvp) / (stp - stx) + dgx + dgp;
double s = Math.max(Math.max(Math.abs(theta),Math.abs(dgx)),Math.abs(dgp));
double sInv = 1/s;
double ts = theta*sInv;
double gamma = s*Math.sqrt(Math.max(0., (ts*ts) - ((dgx * sInv) * (dgp*sInv))));
int info = 0;
// case 1
double nextStep;
if (fvp > fvx) {
info = 1;
if (stp < stx) gamma = -gamma;
_bound = true;
_brackt = true;
double p = (gamma - dgx) + theta;
double q = ((gamma - dgx) + gamma) + dgp;
double r = p / q;
double stpc = stx + r * (stp - stx);
double stpq = stx + ((dgx / ((fvx - fvp) / (stp - stx) + dgx)) / 2) * (stp - stx);
nextStep = (Math.abs(stpc - stx) < Math.abs(stpq - stx)) ? stpc : stpc + (stpq - stpc) / 2;
} else if (dgp * dgx < 0) { // case 2
info = 2;
if (stp > stx) gamma = -gamma;
_bound = false;
_brackt = true;
double p = (gamma - dgp) + theta;
double q = ((gamma - dgp) + gamma) + dgx;
double r = p / q;
double stpc = stp + r * (stx - stp);
double stpq = stp + (dgp / (dgp - dgx)) * (stx - stp);
nextStep = (Math.abs(stpc - stp) > Math.abs(stpq - stp)) ? stpc : stpq;
} else if (Math.abs(dgp) < Math.abs(dgx)) { // case 3
info = 3;
if (stp > stx) gamma = -gamma;
_bound = true;
double p = gamma - dgp + theta;
double q = gamma + dgx - dgp + gamma;
double r = p / q;
double stpc;
if (r < 0.0 && gamma != 0.0)
stpc = stp + r * (stx - stp);
else if (stp > stx)
stpc = _stMax;
else
stpc = _stMin;
// stpq = stp + (dp/(dp-dx))*(stx - stp);
double stpq = stp + (dgp / (dgp - dgx)) * (stx - stp);
if (_brackt)
nextStep = (Math.abs(stp - stpc) < Math.abs(stp - stpq)) ? stpc : stpq;
else
nextStep = (Math.abs(stp - stpc) > Math.abs(stp - stpq)) ? stpc : stpq;
} else {
// case 4
info = 4;
_bound = false;
if (_brackt) {
theta = 3 * (fvp - fvy) / (sty - stp) + dgy + dgp;
gamma = Math.sqrt(theta * theta - dgy * dgp);
if (stp > sty) gamma = -gamma;
double p = (gamma - dgp) + theta;
double q = ((gamma - dgp) + gamma) + dgy;
double r = p / q;
nextStep = stp + r * (sty - stp);
} else
nextStep = stp > stx ? _stMax : _stMin;
}
if(fvp > fvx) {
_sty = stp;
_fvy = ginfo._objVal;
_dgy = dg;
} else {
if(dgp * dgx < 0) {
_sty = _stx;
_fvy = _fvx;
_dgy = _dgx;
}
_stx = stp;
_fvx = ginfo._objVal;
_dgx = dg;
_ginfox = ginfo;
}
if(nextStep > _stMax)
nextStep = _stMax;
if(nextStep < _stMin)
nextStep = _stMin;
if (_brackt & _bound)
if (_sty > _stx)
nextStep = Math.min(_stx + .66 * (_sty - _stx), nextStep);
else
nextStep = Math.max(_stx + .66 * (_sty - _stx), nextStep);
return nextStep;
}
public String toString(){
return "MoreThuente line search, iter = " + _iter + ", status = " + messages[_returnStatus] + ", step = " + _stx + ", I = " + "[" + _stMin + ", " + _stMax + "], grad = " + _dgx + ", bestObj = " + _fvx;
}
private int _iter;
int _maxfev = 20;
double _maxStep = 1e10;
double _minStep = 1e-10;
@Override
public boolean evaluate(double [] direction) {
double oldObjval = _ginfox._objVal;
double step = _initialStep;
_bound = false;
_brackt = false;
_stx = _sty = 0;
_stMin = _stMax = 0;
_betGradient = null;
_bestPsiVal = Double.POSITIVE_INFINITY;
_bestStep = 0;
double maxObj = _ginfox._objVal - _minRelativeImprovement*_ginfox._objVal;
final double dgInit = ArrayUtils.innerProduct(_ginfox._gradient, direction);
final double dgtest = dgInit * _ftol;
if(dgtest > 1e-4) Log.warn("MoreThuente LS: got possitive differential " + dgtest);
if(dgtest >= 0) {
_returnStatus = 7;
return false;
}
double [] beta = new double[_beta.length];
double width = _maxStep - _minStep;
double oldWidth = 2*width;
boolean stage1 = true;
_fvx = _fvy = _ginfox._objVal;
_dgx = _dgy = dgInit;
_iter = 0;
while (true) {
if (_brackt) {
_stMin = Math.min(_stx, _sty);
_stMax = Math.max(_stx, _sty);
} else {
_stMin = _stx;
_stMax = step + _xtrapf * (step - _stx);
}
step = Math.min(step,_maxStep);
step = Math.max(step,_minStep);
double maxFval = oldObjval + step * dgtest;
for (int i = 0; i < beta.length; ++i)
beta[i] = _beta[i] + step * direction[i];
GradientInfo newGinfo = _gslvr.getGradient(beta); // gradient for one class, active cols if rcc=true
if(newGinfo._objVal < maxObj && (_betGradient == null || (newGinfo._objVal - maxFval) < _bestPsiVal)){
_bestPsiVal = (newGinfo._objVal - maxFval);
_betGradient = newGinfo;
_bestStep = step;
}
++_iter;
if(_iter < _maxfev && (!Double.isNaN(step) && (Double.isNaN(newGinfo._objVal) || Double.isInfinite(newGinfo._objVal) || ArrayUtils.hasNaNsOrInfs(newGinfo._gradient)))) {
_brackt = true;
_sty = step;
_maxStep = step;
_fvy = Double.POSITIVE_INFINITY;
_dgy = Double.MAX_VALUE;
step *= .5;
continue;
}
double dgp = ArrayUtils.innerProduct(newGinfo._gradient, direction);
if(Double.isNaN(step) || _brackt && (step <= _stMin || step >= _stMax)) {
_returnStatus = 6;
break;
}
if (step == _maxStep && newGinfo._objVal <= maxFval & dgp <= dgtest){
_returnStatus = 5;
_stx = step;
_ginfox = newGinfo;
break;
}
if (step == _minStep && (newGinfo._objVal > maxFval | dgp >= dgtest)){
_returnStatus = 4;
if(_betGradient != null) {
_stx = _bestStep;
_ginfox = _betGradient;
} else {
_stx = step;
_ginfox = newGinfo;
}
break;
}
if (_iter >= _maxfev){
_returnStatus = 3;
if(_betGradient != null) {
_stx = _bestStep;
_ginfox = _betGradient;
} else {
_stx = step;
_ginfox = newGinfo;
}
break;
}
if (_brackt && (_stMax-_stMin) <= _xtol*_stMax) {
_ginfox = newGinfo;
_returnStatus = 2;
break;
}
// check for convergence
if (newGinfo._objVal < maxFval && Math.abs(dgp) <= -_gtol * dgInit) { // got solution satisfying both conditions
_stx = step;
_dgx = dgp;
_fvx = newGinfo._objVal;
_ginfox = newGinfo;
_returnStatus = 1;
break;
}
// f > ftest1 || dg < min(ftol,gtol)*dginit
stage1 = stage1 && (newGinfo._objVal > maxFval || dgp < dgtest);
boolean useAugmentedFuntcion = stage1 && newGinfo._objVal <= _fvx && newGinfo._objVal > maxFval;
double off = useAugmentedFuntcion?dgtest:0;
double nextStep = nextStep(newGinfo,dgp,step,off);
if (_brackt) {
if (Math.abs(_sty - _stx) >= .66 * oldWidth)
nextStep = _stx + .5 * (_sty - _stx);
oldWidth = width;
width = Math.abs(_sty - _stx);
}
step = nextStep;
}
boolean succ = _ginfox._objVal < oldObjval;
if(succ) {
// make sure we have correct beta (not all return cases have valid current beta!)
for (int i = 0; i < beta.length; ++i)
beta[i] = _beta[i] + _stx * direction[i];
_beta = beta;
}
return succ;
}
@Override
public double step() {return _stx;}
@Override
public GradientInfo ginfo() {
return _ginfox;
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy