All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.optimization.impl.Nesterov1 Maven / Gradle / Ivy

package com.expleague.ml.optimization.impl;

import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.VecTools;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.util.logging.Logger;
import com.expleague.ml.optimization.FuncConvex;
import com.expleague.ml.optimization.Optimize;

/**
 * Accelerated gradient optimizer following Nesterov's estimate-sequence scheme.
 * <p>
 * Each iteration picks a weight {@code alpha} in {@code [0, 1]} from the quadratic
 * {@code L * alpha^2 = (1 - alpha) * gamma + alpha * m}, forms the point {@code y} from the
 * current iterate {@code x} and the auxiliary sequence {@code v}, takes a gradient step of
 * length {@code 1 / L} from {@code y}, and updates {@code v} and {@code gamma}. Here {@code L}
 * is {@link FuncConvex#getGradLipParam()} and {@code m} is
 * {@link FuncConvex#getGlobalConvexParam()}. Iteration stops once
 * {@code norm(gradient(y)) / m <= eps}.
 * <p>
 * User: qde
 * Date: 06.09.13
 * Time: 16:45
 */
public class Nesterov1 implements Optimize {
  private static final Logger LOG = Logger.create(Nesterov1.class);

  /** Starting point used by {@link #optimize(FuncConvex)}. */
  private final Vec x0;
  /** Threshold for the stopping test {@code norm(grad) / m > eps}. */
  private final double eps;

  /**
   * @param x0  default starting point, used when {@link #optimize(FuncConvex)} is called
   * @param eps threshold compared against {@code norm(grad) / m}; iteration continues while the
   *            ratio is strictly greater than this value
   */
  public Nesterov1(final Vec x0, final double eps) {
    this.x0 = x0;
    this.eps = eps;
  }

  /**
   * Runs the accelerated scheme starting from the given point.
   * <p>
   * The starting vector is only read; the first {@code func.xdim()} coordinates are copied into
   * a fresh working vector.
   *
   * @param func function to minimize; supplies the dimension, the constants {@code L} and
   *             {@code m}, and the gradient
   * @param x0   starting point (overrides the one passed to the constructor)
   * @return a newly allocated vector of dimension {@code func.xdim()} holding the last iterate
   * @throws IllegalArgumentException if {@code L} or {@code m} is not strictly positive (the
   *                                  stopping test divides by {@code m}, the step length is
   *                                  {@code 1 / L}), or if on some iteration neither root of the
   *                                  alpha equation lies in {@code [0, 1]}
   */
  @Override
  public Vec optimize(final FuncConvex func, final Vec x0) {
    final int n = func.xdim();
    final double L = func.getGradLipParam();
    final double m = func.getGlobalConvexParam();
    // Written as !(v > 0) so that NaN is rejected too. Without this check a negative m made the
    // loop stop after a single step and m == 0 made the stopping ratio infinite.
    if (!(L > 0) || !(m > 0)) {
      throw new IllegalArgumentException(
          "Gradient Lipschitz and convexity parameters must be positive, got L=" + L + ", m=" + m);
    }
    final double invL = 1.0 / L;

    // Working iterate; always an ArrayVec of dimension n, like the result of the original code.
    final Vec x = new ArrayVec(n);
    for (int i = 0; i < n; i++) {
      x.set(i, x0.get(i));
    }
    // The auxiliary sequence starts at the same point as the iterate.
    final Vec v = VecTools.copy(x);
    final Vec y = new ArrayVec(n);
    Vec grad;

    double gamma = L;
    int iter = 0;
    do {
      final double alpha = computeAlpha(gamma, m, L, iter);
      final double gammaNext = (1 - alpha) * gamma + alpha * m;

      // Scalar factors are hoisted out of the element loops; the grouping matches Java's
      // left-to-right evaluation of the original expressions, so results are unchanged.
      final double vWeight = alpha * gamma;
      final double yDenominator = gamma + alpha * m;
      for (int i = 0; i < n; i++) {
        y.set(i, (vWeight * v.get(i) + gammaNext * x.get(i)) / yDenominator);
      }

      grad = func.gradient().trans(y);

      // Once y is built the old x is no longer needed and v is only read element-wise,
      // so both are updated in place instead of going through per-iteration copies.
      final double invGammaNext = 1.0 / gammaNext;
      final double vDecay = (1 - alpha) * gamma;
      final double yWeight = alpha * m;
      for (int i = 0; i < n; i++) {
        final double yi = y.get(i);
        final double gi = grad.get(i);
        x.set(i, yi - invL * gi);
        v.set(i, invGammaNext * (vDecay * v.get(i) + yWeight * yi - alpha * gi));
      }

      gamma = gammaNext;
      iter++;

    } while ((VecTools.norm(grad) / m) > eps);

    //LOG.message("N1 iterations = " + iter);
    return x;
  }

  /**
   * Runs the accelerated scheme from the starting point supplied to the constructor.
   *
   * @param func function to minimize
   * @return the last iterate, see {@link #optimize(FuncConvex, Vec)}
   */
  @Override
  public Vec optimize(FuncConvex func) {
    return optimize(func, x0);
  }

  /**
   * Solves {@code L * alpha^2 + (gamma - m) * alpha - gamma = 0} and returns a root lying in
   * {@code [0, 1]}, preferring the larger root when it qualifies.
   *
   * @param gamma current value of the gamma sequence
   * @param m     convexity parameter
   * @param L     gradient Lipschitz parameter
   * @param iter  iteration number, used only in the error message
   * @return the selected root
   * @throws IllegalArgumentException if neither root lies in {@code [0, 1]} (this includes NaN roots)
   */
  private static double computeAlpha(final double gamma, final double m, final double L, final int iter) {
    final double D = m * m - 2 * gamma * m + 4 * gamma * L + gamma * gamma;
    final double sqrtD = Math.sqrt(D);
    final double largerRoot = ((m - gamma) + sqrtD) / (2 * L);
    if (0 <= largerRoot && largerRoot <= 1) {
      return largerRoot;
    }
    final double smallerRoot = ((m - gamma) - sqrtD) / (2 * L);
    if (0 <= smallerRoot && smallerRoot <= 1) {
      return smallerRoot;
    }
    throw new IllegalArgumentException("Roots are not in the interval, something was wrong on iter#" + iter);
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy