com.expleague.ml.optimization.impl.GradientDescent Maven / Gradle / Ivy
package com.expleague.ml.optimization.impl;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.util.logging.Logger;
import com.expleague.ml.optimization.FuncConvex;
import com.expleague.ml.optimization.Optimize;
import com.expleague.ml.optimization.PDQuadraticFunction;
/**
* User: qde
* Date: 25.04.13
* Time: 23:41
*/
public class GradientDescent implements Optimize {
private static final Logger LOG = Logger.create(GradientDescent.class);
private final Vec x0;
private final double eps;
public GradientDescent(final Vec x0, final double eps) {
this.x0 = x0;
this.eps = eps;
}
@Override
public Vec optimize(final FuncConvex func, Vec x0) {
final boolean isQuadraticFunc = func instanceof PDQuadraticFunction;
final double constStep = 1.0 / func.getGradLipParam();
Vec x1 = VecTools.copy(x0);
final Vec x2 = new ArrayVec(x0.dim());
Vec grad = func.gradient().trans(x0);
int iter = 0;
double distance = 1;
while (distance > eps && iter < 5000000) {
final double step = isQuadraticFunc? getStepSizeForQuadraticFunc(func, grad) : constStep;
for (int i = 0; i < x2.dim(); i++) {
x2.set(i, x1.get(i) - grad.get(i) * step);
}
x1 = VecTools.copy(x2);
grad = func.gradient().trans(x1);
distance = VecTools.norm(grad) / func.getGlobalConvexParam();
iter++;
}
LOG.message("GDM iterations = " + iter + "\n\n");
return x2;
}
@Override
public Vec optimize(FuncConvex func) {
return optimize(func, x0);
}
private double getStepSizeForQuadraticFunc(final FuncConvex func, final Vec grad) {
return VecTools.multiply(grad, grad) / ((PDQuadraticFunction) func).getQuadrPartValue(grad);
}
}