All downloads are free. Search and download functionality uses the official Maven repository.

hex.optimization.ADMM Maven / Gradle / Ivy

package hex.optimization;

import water.H2O;
import water.MemoryManager;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.MathUtils;
import water.util.MathUtils.Norm;

/**
 * Created by tomasnykodym on 3/2/15.
 */
public class ADMM {

  /**
   * Solver for the smooth sub-problem of each ADMM iteration (the x-update).
   */
  public interface ProximalSolver {
    /**
     * Per-coefficient ADMM penalty parameters. {@link L1Solver} divides the L1 penalty by these
     * and uses them to weight the dual variable in its convergence test.
     */
    public double []  rho();

    /**
     * Solves the x-update and writes the result into {@code result}.
     * {@link L1Solver} passes {@code z - u} as {@code beta_given}, or {@code null} when there is
     * neither an L1 penalty nor bounds.
     *
     * @return {@code false} to stop the ADMM iterations
     */
    public boolean solve(double [] beta_given, double [] result);

    /** Whether this solver provides a gradient. NOTE(review): {@link L1Solver} does not consult it. */
    public boolean hasGradient();

    /** Gradient information of the smooth objective at {@code beta}. */
    public OptimizationUtils.GradientInfo gradient(double [] beta);

    /** Iteration count of the underlying solver (presumably; not used by {@link L1Solver}). */
    public int iter();
  }

  /**
   * ADMM driver for problems with an L1 penalty and/or box constraints on the coefficients.
   * The smooth part is delegated to a {@link ProximalSolver}; this class performs the z-update
   * (soft-thresholding, then clipping to the bounds), the scaled dual (u) update and the
   * convergence checks.
   */
  public static class L1Solver {
    /** Relative tolerance on the primal/dual residuals. */
    final double RELTOL;
    /** Absolute tolerance on the residuals (scaled by sqrt(N) in {@link #solve}). */
    final double ABSTOL;
    /** Norm of the projected subgradient at the last evaluated solution. */
    double gerr;
    /** Number of ADMM iterations performed by the last call to {@link #solve}. */
    int iter;
    /** Maximum accepted {@link #gerr} for a converged solution. */
    final double _eps;
    /** Maximum number of ADMM iterations. */
    final int max_iter;

    /** Norm used to reduce the subgradient vector to the scalar {@link #gerr}. */
    MathUtils.Norm _gradientNorm = Norm.L_Infinite;

    /** Scaled dual variable; when non-null at the start of {@link #solve} it is used as a warm start. */
    public double [] _u;

    public static double DEFAULT_RELTOL = 1e-2;
    public static double DEFAULT_ABSTOL = 1e-4;

    public L1Solver setGradientNorm(MathUtils.Norm n) {_gradientNorm = n; return this;}

    public L1Solver(double eps, int max_iter, double [] u) {
      this(eps,max_iter,DEFAULT_RELTOL,DEFAULT_ABSTOL,u);
    }

    public L1Solver(double eps, int max_iter, double reltol, double abstol, double [] u) {
      _eps = eps; this.max_iter = max_iter;
      _u = u;
      RELTOL = reltol;
      ABSTOL = abstol;
    }

    /** Optional progress callback, invoked every 5th iteration. */
    public L_BFGS.ProgressMonitor _pm;

    /** Same as {@link #solve(ProximalSolver, double[], double, boolean, double[], double[])} without bounds. */
    public boolean solve(ProximalSolver solver, double[] res, double lambda, boolean hasIntercept) {
      return solve(solver, res, lambda, hasIntercept, null, null);
    }

    /**
     * Computes the norm of the projected L1 subgradient at {@code z} and stores it in {@link #gerr}.
     *
     * <p>A coefficient sitting on an active bound (at {@code lb} with a positive gradient, or at
     * {@code ub} with a negative gradient) satisfies its optimality condition, so it contributes
     * zero. Every other penalized coefficient gets the L1 subgradient adjustment; the intercept
     * (last entry, only when {@code hasIntercept}) keeps its raw gradient.
     *
     * @param z current solution (not modified)
     * @param grad gradient of the smooth objective at {@code z} (not modified; a copy is adjusted)
     * @param lambda L1 penalty
     * @param lb lower bounds, or {@code null}
     * @param ub upper bounds, or {@code null}
     * @param hasIntercept whether the last coefficient is an unpenalized intercept
     * @return the new value of {@link #gerr}
     */
    private double computeErr(double[] z, double[] grad, double lambda, double[] lb, double[] ub, boolean hasIntercept) {
      grad = grad.clone();
      gerr = 0;
      // Detect active bounds on the raw gradient, before it is adjusted below.
      boolean[] atActiveBound = new boolean[z.length];
      if (lb != null)
        for (int j = 0; j < z.length; ++j)
          if (z[j] == lb[j] && grad[j] > 0)
            atActiveBound[j] = true;
      if (ub != null)
        for (int j = 0; j < z.length; ++j)
          if (z[j] == ub[j] && grad[j] < 0)
            atActiveBound[j] = true;
      // subgrad() always leaves the last entry alone (treating it as the intercept); without an
      // intercept that entry is L1-penalized by solve() as well, so adjust it here.
      subgrad(lambda, z, grad);
      int last = grad.length - 1;
      if (!hasIntercept && last >= 0)
        grad[last] = l1Subgrad(lambda, z[last], grad[last]);
      for (int j = 0; j < z.length; ++j)
        if (atActiveBound[j]) grad[j] = 0;
      switch(_gradientNorm) {
        case L_Infinite:
          gerr = ArrayUtils.linfnorm(grad,false);
          break;
        case L2_2:
          gerr = ArrayUtils.l2norm2(grad, false);
          break;
        case L2:
          gerr = Math.sqrt(ArrayUtils.l2norm2(grad, false));
          break;
        case L1:
          gerr = ArrayUtils.l1norm(grad,false);
          break;
        default:
          throw H2O.unimpl();
      }
      return gerr;
    }

    /**
     * Runs ADMM for an L1-penalized (penalty {@code l1pen}; the intercept, if any, is unpenalized)
     * and/or box-constrained problem whose smooth part is handled by {@code solver}'s x-update.
     *
     * @param solver solver for the smooth x-update
     * @param z in: starting coefficients; out: the solution (modified in place)
     * @param l1pen L1 penalty
     * @param hasIntercept whether the last coefficient is an unpenalized intercept
     * @param lb lower bounds, or {@code null}
     * @param ub upper bounds, or {@code null}
     * @return {@code true} if the residual tolerances were met and {@link #gerr} did not exceed
     *         {@code _eps}, or if there was neither a penalty nor bounds (the solver is then called
     *         once); {@code false} if {@code max_iter} was reached or the solver stopped the loop
     */
    public boolean solve(ProximalSolver solver, double[] z, double l1pen, boolean hasIntercept, double[] lb, double[] ub) {
      gerr = Double.POSITIVE_INFINITY;
      iter = 0;
      if (l1pen == 0 && lb == null && ub == null) {
        solver.solve(null, z);
        return true;
      }
      int hasIcpt = hasIntercept?1:0;
      int N = z.length;
      double abstol = ABSTOL * Math.sqrt(N);
      double [] rho = solver.rho();
      double [] x = z.clone();
      double [] beta_given = MemoryManager.malloc8d(N);
      double [] u;
      if(_u != null) {
        // warm start from the previous dual variable
        u = _u;
        for (int i = 0; i < beta_given.length - hasIcpt; ++i)
          beta_given[i] = z[i] - _u[i];
      } else u = _u = MemoryManager.malloc8d(z.length);
      // per-coefficient soft-threshold; the intercept (if any) stays at 0, i.e. unpenalized
      double [] kappa = MemoryManager.malloc8d(rho.length);
      if(l1pen > 0)
        for(int i = 0; i < N-hasIcpt; ++i)
          kappa[i] = rho[i] != 0?l1pen/rho[i]:0;
      int i;
      double orlx = 1.0; // over-relaxation factor (1.0 = no over-relaxation)
      double reltol = RELTOL;
      for (i = 0; i < max_iter && solver.solve(beta_given, x); ++i) {
        if(_pm != null && (i + 1) % 5 == 0)_pm.progress(z,solver.gradient(z));
        // z-update (soft-threshold, then clip to bounds) and scaled dual u-update
        double rnorm = 0, snorm = 0, unorm = 0, xnorm = 0;
        for (int j = 0; j < N - hasIcpt; ++j) {
          double xj = x[j];
          double zjold = z[j];
          double x_hat = xj * orlx + (1 - orlx) * zjold;
          double zj = shrinkage(x_hat + u[j], kappa[j]);
          if (lb != null && zj < lb[j])
            zj = lb[j];
          if (ub != null && zj > ub[j])
            zj = ub[j];
          u[j] += x_hat - zj;
          beta_given[j] = zj - u[j];
          double r = xj - zj;
          double s = zj - zjold;
          rnorm += r * r;
          snorm += s * s;
          xnorm += xj * xj;
          unorm += rho[j] * rho[j] * u[j] * u[j];
          z[j] = zj;
        }
        if (hasIntercept) {
          // the intercept is only clipped to its bounds, never shrunk
          int idx = x.length - 1;
          double icpt = x[idx];
          if (lb != null && icpt < lb[idx])
            icpt = lb[idx];
          if (ub != null && icpt > ub[idx])
            icpt = ub[idx];
          double r = x[idx] - icpt;
          double s = icpt - z[idx];
          u[idx] += r;
          beta_given[idx] = icpt - u[idx];
          rnorm += r * r;
          snorm += s * s;
          xnorm += icpt * icpt;
          unorm += rho[idx] * rho[idx] * u[idx] * u[idx];
          z[idx] = icpt;
        }
        // NOTE(review): rnorm/snorm are squared norms, compared against thresholds built from
        // un-squared norms (sqrt(xnorm), sqrt(unorm)); the textbook ADMM stopping rule compares
        // ||r|| and ||s||. Left unchanged because altering it shifts convergence behavior.
        if (rnorm < (abstol + (reltol * Math.sqrt(xnorm))) && snorm < (abstol + reltol * Math.sqrt(unorm))) {
          double oldGerr = gerr;
          computeErr(z, solver.gradient(z)._gradient, l1pen, lb, ub, hasIntercept);
          if (gerr > _eps) {
            // residuals are small but the KKT error is not: tighten the tolerances (down to a floor) and keep going
            Log.debug("ADMM.L1Solver: iter = " + i + " , gerr =  " + gerr + ", oldGerr = " + oldGerr + ", rnorm = " + rnorm + ", snorm  " + snorm);
            if(abstol > 1e-12) abstol *= .1;
            if(reltol > 1e-10) reltol *= .1;
            continue;
          }
          iter = i;
          if(_pm != null && (i + 1) % 5 == 0)_pm.progress(z,solver.gradient(z));
          return true;
        }
      }
      // Not converged: either the iteration limit was hit (i == max_iter) or solver.solve returned false.
      computeErr(z, solver.gradient(z)._gradient, l1pen, lb, ub, hasIntercept);
      if(i == max_iter)
        Log.warn("ADMM solver reached maximum number of iterations (" + max_iter + ")");
      else
        Log.warn("ADMM solver stopped after " + i + " iterations. (max_iter=" + max_iter + ")");
      if(gerr > _eps) Log.warn("ADMM solver finished with gerr = " + gerr + " >  eps = " + _eps);
      iter = i;
      if(_pm != null && (i + 1) % 5 == 0)_pm.progress(z,solver.gradient(z));
      return false;
    }

    @Override public String toString(){
      return "iter = " + iter + ", gerr = " + gerr;
    }

    /**
     * Estimates an ADMM penalty parameter rho for one coefficient from the L1 penalty and an
     * estimate {@code x} of that coefficient's solution without the L1 penalty.
     *
     * <p>For {@code l1pen > 0} and {@code x != 0} the estimate is a quarter of the positive root r
     * of {@code |x|*r^2 - l1pen*r - l1pen = 0}. If either bound is finite, that value is replaced:
     * 10 when {@code x} is closer than 1e-4 to a bound (or outside the bounds), 0.1 otherwise.
     *
     * @param x estimate of the unpenalized solution; an infinite value yields 0
     * @param l1pen L1 penalty
     * @param lb lower bound; an infinite value means unbounded below
     * @param ub upper bound; an infinite value means unbounded above
     * @return the rho estimate
     */
    public static double estimateRho(double x, double l1pen, double lb, double ub){
      if(Double.isInfinite(x))return 0; // happens for all zeros
      double rho = 0;
      if(l1pen != 0 && x != 0) {
        if (x > 0) {
          double D = l1pen * (l1pen + 4 * x);
          if (D >= 0) {
            D = Math.sqrt(D);
            double r = (l1pen + D) / (2 * x);
            if (r > 0) rho = r;
            else
              Log.warn("negative rho estimate(1)! r = " + r);
          }
        } else if (x < 0) {
          double D = l1pen * (l1pen - 4 * x);
          if (D >= 0) {
            D = Math.sqrt(D);
            double r = -(l1pen + D) / (2 * x);
            if (r > 0) rho = r;
            else Log.warn("negative rho estimate(2)!  r = " + r);
          }
        }
        rho *= .25;
      }
      if(!Double.isInfinite(lb) || !Double.isInfinite(ub)) {
        // true when x is within 1e-4 of a bound or outside [lb, ub]
        boolean oob = -Math.min(x - lb, ub - x) > -1e-4;
        rho = oob?10:1e-1;
      }
      return rho;
    }
  }

  /**
   * Soft-thresholding operator: {@code sign(x) * max(|x| - kappa, 0)}.
   *
   * @param x value to shrink
   * @param kappa threshold (expected to be non-negative)
   * @return the shrunk value; 0 when {@code |x| <= kappa}
   */
  public static double shrinkage(double x, double kappa) {
    double sign = x < 0?-1:1;
    double sx = x*sign;
    return sx <= kappa?0:sign*(sx - kappa);
  }

  /**
   * Replaces {@code grad} in place with the L1 subgradient residual for penalty {@code lambda} at
   * {@code beta}. The last entry is treated as an unpenalized intercept and left unchanged.
   * Does nothing when {@code beta} is {@code null}.
   */
  public static void subgrad(final double lambda, final double [] beta, final double [] grad){
    if(beta == null)return;
    for(int i = 0; i < grad.length-1; ++i)
      grad[i] = l1Subgrad(lambda, beta[i], grad[i]);
  }

  /**
   * L1 subgradient residual of a single gradient entry {@code g} at coefficient {@code b}.
   * For {@code b != 0} the penalty's subgradient is {@code sign(b)*lambda}, so the residual is
   * {@code g + sign(b)*lambda}, soft-thresholded by a small tolerance ({@code 1e-4*lambda}).
   * At {@code b == 0} any gradient with {@code |g| <= lambda} is optimal, so {@code g} is
   * soft-thresholded by {@code lambda}.
   */
  private static double l1Subgrad(double lambda, double b, double g) {
    if (b < 0) return shrinkage(g - lambda, lambda * 1e-4);
    if (b > 0) return shrinkage(g + lambda, lambda * 1e-4);
    return shrinkage(g, lambda);
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy