All downloads are free. Search and download functionality uses the official Maven repository.

com.stripe.rainier.optimizer.LBFGS Maven / Gradle / Ivy

The newest version!
/* Heavily modified from:
 * RISO: an implementation of distributed belief networks. (Copyright 1999, Robert Dodier).
 *
 * License: Apache license, version 2.
 */
package com.stripe.rainier.optimizer;

/**
 * Limited-memory BFGS minimizer driven by reverse communication.
 *
 * <p>The caller owns the objective function. It evaluates the function value
 * and gradient at the current contents of the {@code x} array handed to the
 * constructor and passes them to {@link #apply(double, double[])}. Each call
 * either overwrites {@code x} in place with the next trial point and returns
 * {@code false} (meaning: evaluate again and call back), or returns
 * {@code true} once {@code ||g|| <= eps * max(1, ||x||)}.
 *
 * <p>Layout of the work array {@code w} (n = dimension, m = history size):
 * <pre>
 *   [0, n)                  scratch: gradient at the start of the line search,
 *                           and the vector being built by the two-loop recursion
 *   [n, n+m)                1 / (y . s) for each stored correction pair
 *   [n+m, n+2m)             coefficients saved by the first recursion loop
 *   [ispt, ispt + n*m)      search directions / steps s, one slot per pair
 *   [iypt, iypt + n*m)      gradient differences y, one slot per pair
 * </pre>
 *
 * <p>Instances keep the whole optimizer state in mutable fields and are meant
 * to be driven by a single caller loop.
 */
class LBFGS {
    // ---- line search tuning constants ------------------------------------

    /** Curvature condition tolerance. */
    private static final double GTOL = 0.9;
    /** Smallest and largest step the line search will try. */
    private static final double STPMIN = 1e-20;
    private static final double STPMAX = 1e20;
    /** Relative width of the bracket below which the line search stops. */
    private static final double XTOL = 1e-16;
    /** Sufficient decrease tolerance. */
    private static final double FTOL = 0.0001;
    /** Maximum number of function evaluations per line search. */
    private static final int MAXFEV = 20;
    private static final double P5 = 0.5;
    private static final double P66 = 0.66;
    /** Extrapolation factor used while the minimizer is not bracketed. */
    private static final double XTRAPF = 4;

    // ---- L-BFGS state ------------------------------------------------------

    private final double[] x;   // caller's point, updated in place
    private final int m;        // number of stored correction pairs
    private final int n;        // problem dimension
    private final double eps;   // termination tolerance
    private final double[] w;   // work array, see class comment
    // Holds the diagonal scaling of the initial inverse Hessian while a search
    // direction is computed, and the line search base point while a line
    // search is running (mcsrch overwrites it; apply refills it before reuse).
    private final double[] diag;
    private final int ispt;     // offset of the s slots in w
    private final int iypt;     // offset of the y slots in w

    private double stp1;        // first step length, 1 / ||g0||
    private int iter;           // iterations started so far
    private int point;          // slot the current search direction lives in
    private int npt;            // offset (point * n) of the pair stored last
    private int info;           // line search status; -1 = waiting for f and g
    private double stp;         // current step length
    private int nfev;           // evaluations used by the current line search

    // ---- line search state (must survive between apply calls) --------------

    private double dginit;      // directional derivative at step 0
    private double dgtest;      // FTOL * dginit
    private double finit;       // function value at step 0
    private double stmin;
    private double stmax;
    private double width;
    private double width1;
    private boolean stage1 = false;
    private int infoc;          // case number reported by mcstep, 0 = bad input
    private boolean brackt;     // true once a minimizer has been bracketed

    // One-element arrays so that mcstep can update them in place.
    private final double[] dgx = new double[1];
    private final double[] dgy = new double[1];
    private final double[] fx = new double[1];
    private final double[] fy = new double[1];
    private final double[] dgxm = new double[1];
    private final double[] dgym = new double[1];
    private final double[] fxm = new double[1];
    private final double[] fym = new double[1];

    private double stx;         // best step so far
    private double sty;         // other endpoint of the interval of uncertainty

    /**
     * Creates an optimizer working on {@code x} in place.
     *
     * @param x   starting point; this array is overwritten with every trial
     *            point and holds the solution when {@link #apply} returns true
     * @param m   number of corrections used in the BFGS update. Values of
     *            less than 3 are not recommended; large values will result in
     *            excessive computing time. 3-7 is recommended.
     * @param eps accuracy with which the solution is to be found; iteration
     *            stops when {@code ||g|| <= eps * max(1, ||x||)}
     * @throws IllegalArgumentException if {@code m} is less than 1
     */
    public LBFGS(double[] x, int m, double eps) {
        if (m < 1) {
            throw new IllegalArgumentException("m must be at least 1, got " + m);
        }
        this.x = x;
        this.m = m;
        this.n = x.length;
        this.eps = eps;

        w = new double[n * (2 * m + 1) + 2 * m];

        iter = 0;
        point = 0;

        diag = new double[n];
        for (int i = 0; i < n; i++) {
            diag[i] = 1;
        }

        ispt = n + 2 * m;
        iypt = ispt + n * m;
    }

    /**
     * Advances the optimization by one function evaluation.
     *
     * @param f objective value at the current contents of {@code x}
     * @param g gradient at the current contents of {@code x} (length n)
     * @return {@code true} when the termination test is met (or the very first
     *         gradient is exactly zero); {@code false} when {@code x} has been
     *         overwritten with a new trial point that must be evaluated and
     *         passed to the next call
     * @throws IllegalStateException if the computed search direction is not a
     *         descent direction
     */
    public boolean apply(double f, double[] g) {
        // False when we are re-entering in the middle of a line search.
        boolean startIteration = false;
        if (iter == 0) {
            double gnorm0 = Math.sqrt(ddot(n, g, 0, 1, g, 0, 1));
            if (gnorm0 == 0) {
                // Already at a stationary point: the termination test holds
                // for any eps >= 0, and 1 / gnorm0 would be infinite.
                return true;
            }

            // First direction is steepest descent, scaled by diag.
            for (int i = 0; i < n; i++) {
                w[ispt + i] = -g[i] * diag[i];
            }
            stp1 = 1 / gnorm0;
            startIteration = true;
        }

        while (true) {
            if (startIteration) {
                iter = iter + 1;
                info = 0;
                if (iter != 1) {
                    computeDirection(g);
                }

                nfev = 0;
                stp = 1;
                if (iter == 1) {
                    stp = stp1;
                }

                // Remember the gradient at the start of the line search.
                for (int i = 0; i < n; i++) {
                    w[i] = g[i];
                }
            }

            mcsrch(f, g);

            if (info == -1) {
                // x now holds a trial point; hand control back to the caller.
                return false;
            }

            // Line search finished (with any status): store the new pair.
            npt = point * n;
            for (int i = 0; i < n; i++) {
                w[ispt + npt + i] = stp * w[ispt + npt + i];
                w[iypt + npt + i] = g[i] - w[i];
            }

            point = point + 1;
            if (point == m) {
                point = 0;
            }

            double gnorm = Math.sqrt(ddot(n, g, 0, 1, g, 0, 1));
            double xnorm = Math.sqrt(ddot(n, x, 0, 1, x, 0, 1));
            xnorm = Math.max(1.0, xnorm);

            if (gnorm / xnorm <= eps) {
                return true;
            }

            startIteration = true; // from now on, execute whole loop
        }
    }

    /**
     * Builds the next search direction from the stored correction pairs with
     * the two-loop recursion and writes it into slot {@code point} of the s
     * area of {@code w}. Also refreshes {@code diag} with ys/yy.
     */
    private void computeDirection(double[] g) {
        int bound = iter - 1;
        if (iter > m) {
            bound = m;
        }

        double ys = ddot(n, w, iypt + npt, 1, w, ispt + npt, 1);
        double yy = ddot(n, w, iypt + npt, 1, w, iypt + npt, 1);
        for (int i = 0; i < n; i++) {
            diag[i] = ys / yy;
        }

        // 1 / (y . s) for the pair stored last (the slot just before point).
        int cp = point;
        if (point == 0) {
            cp = m;
        }
        w[n + cp - 1] = 1 / ys;

        for (int i = 0; i < n; i++) {
            w[i] = -g[i];
        }

        // First loop: walk the pairs from newest to oldest.
        cp = point;
        for (int i = 0; i < bound; i++) {
            cp = cp - 1;
            if (cp == -1) {
                cp = m - 1;
            }
            double sq = ddot(n, w, ispt + cp * n, 1, w, 0, 1);
            int inmc = n + m + cp;
            int iycn = iypt + cp * n;
            w[inmc] = w[n + cp] * sq;
            daxpy(n, -w[inmc], w, iycn, 1, w, 0, 1);
        }

        for (int i = 0; i < n; i++) {
            w[i] = diag[i] * w[i];
        }

        // Second loop: walk back from oldest to newest.
        for (int i = 0; i < bound; i++) {
            double yr = ddot(n, w, iypt + cp * n, 1, w, 0, 1);
            double beta = w[n + cp] * yr;
            int inmc = n + m + cp;
            beta = w[inmc] - beta;
            int iscn = ispt + cp * n;
            daxpy(n, beta, w, iscn, 1, w, 0, 1);
            cp = cp + 1;
            if (cp == m) {
                cp = 0;
            }
        }

        for (int i = 0; i < n; i++) {
            w[ispt + point * n + i] = w[i];
        }
    }

    /**
     * Line search along the direction stored in slot {@code point}, written
     * for reverse communication.
     *
     * <p>When entered with {@code info != -1} a new search is set up, the
     * first trial point is written to {@code x}, {@code info} is set to -1
     * and the method returns. When entered with {@code info == -1} the
     * arguments are taken as the value and gradient at that trial point;
     * the method then either finishes with {@code info} in 1..6 or writes
     * the next trial point to {@code x} and returns with {@code info == -1}.
     *
     * <p>Status codes on completion: 1 = sufficient decrease and curvature
     * conditions hold; 2 = bracket narrower than XTOL * stmax; 3 = MAXFEV
     * evaluations used; 4 = step at STPMIN; 5 = step at STPMAX; 6 = step
     * left the bracket or mcstep rejected its input.
     *
     * @param f objective value at the current contents of {@code x}
     * @param g gradient at the current contents of {@code x}
     * @throws IllegalStateException if the direction is not a descent direction
     */
    private void mcsrch(double f, double[] g) {
        final int is0 = ispt + point * n;
        if (info != -1) {
            infoc = 1;

            // Compute the initial gradient in the search direction
            // and check that s is a descent direction.
            dginit = 0;
            for (int j = 0; j < n; j++) {
                dginit = dginit + g[j] * w[is0 + j];
            }

            if (dginit >= 0) {
                throw new IllegalStateException(
                    "dginit = " + dginit + ": search direction is not a descent direction");
            }

            brackt = false;
            stage1 = true;
            nfev = 0;
            finit = f;
            dgtest = FTOL * dginit;
            width = STPMAX - STPMIN;
            width1 = width / P5;

            // Save the base point of the line search.
            for (int j = 0; j < n; j++) {
                diag[j] = x[j];
            }

            // The variables stx, fx, dgx contain the values of the step,
            // function, and directional derivative at the best step.
            // The variables sty, fy, dgy contain the value of the step,
            // function, and derivative at the other endpoint of
            // the interval of uncertainty.
            // The variables stp, f, dg contain the values of the step,
            // function, and derivative at the current step.
            stx = 0;
            fx[0] = finit;
            dgx[0] = dginit;
            sty = 0;
            fy[0] = finit;
            dgy[0] = dginit;
        }

        while (true) {
            if (info != -1) {
                // Set the minimum and maximum steps to correspond
                // to the present interval of uncertainty.
                if (brackt) {
                    stmin = Math.min(stx, sty);
                    stmax = Math.max(stx, sty);
                } else {
                    stmin = stx;
                    stmax = stp + XTRAPF * (stp - stx);
                }

                // Force the step to be within the global step bounds.
                stp = Math.max(stp, STPMIN);
                stp = Math.min(stp, STPMAX);

                // If an unusual termination is to occur then let
                // stp be the lowest point obtained so far.
                if ((brackt && (stp <= stmin || stp >= stmax))
                        || nfev >= MAXFEV - 1
                        || infoc == 0
                        || (brackt && stmax - stmin <= XTOL * stmax)) {
                    stp = stx;
                }

                // Move x to the trial point and return to the caller
                // to obtain f and g there.
                for (int j = 0; j < n; j++) {
                    x[j] = diag[j] + stp * w[is0 + j];
                }

                info = -1;
                return;
            }

            // f and g are the values at the trial point requested earlier.
            info = 0;
            nfev = nfev + 1;
            double dg = 0;
            for (int j = 0; j < n; j++) {
                dg = dg + g[j] * w[is0 + j];
            }

            double ftest1 = finit + stp * dgtest;

            // Test for convergence. Later tests deliberately override earlier ones.
            if ((brackt && (stp <= stmin || stp >= stmax)) || infoc == 0) {
                info = 6;
            }
            if (stp == STPMAX && f <= ftest1 && dg <= dgtest) {
                info = 5;
            }
            if (stp == STPMIN && (f > ftest1 || dg >= dgtest)) {
                info = 4;
            }
            if (nfev >= MAXFEV) {
                info = 3;
            }
            if (brackt && stmax - stmin <= XTOL * stmax) {
                info = 2;
            }
            if (f <= ftest1 && Math.abs(dg) <= GTOL * (-dginit)) {
                info = 1;
            }

            // Check for termination.
            if (info != 0) {
                return;
            }

            // In the first stage we seek a step for which the modified
            // function has a nonpositive value and nonnegative derivative.
            if (stage1 && f <= ftest1 && dg >= Math.min(FTOL, GTOL) * dginit) {
                stage1 = false;
            }

            // A modified function is used to predict the step only if
            // we have not obtained a step for which the modified
            // function has a nonpositive function value and nonnegative
            // derivative, and if a lower function value has been
            // obtained but the decrease is not sufficient.
            if (stage1 && f <= fx[0] && f > ftest1) {
                // Define the modified function and derivative values.
                double fm = f - stp * dgtest;
                fxm[0] = fx[0] - stx * dgtest;
                fym[0] = fy[0] - sty * dgtest;
                double dgm = dg - dgtest;
                dgxm[0] = dgx[0] - dgtest;
                dgym[0] = dgy[0] - dgtest;

                // Update the interval of uncertainty and compute the new step.
                mcstep(fxm, dgxm, fym, dgym, fm, dgm);

                // Reset the function and gradient values for f.
                fx[0] = fxm[0] + stx * dgtest;
                fy[0] = fym[0] + sty * dgtest;
                dgx[0] = dgxm[0] + dgtest;
                dgy[0] = dgym[0] + dgtest;
            } else {
                // Update the interval of uncertainty and compute the new step.
                mcstep(fx, dgx, fy, dgy, f, dg);
            }

            // Force a sufficient decrease in the size of the
            // interval of uncertainty.
            if (brackt) {
                if (Math.abs(sty - stx) >= P66 * width1) {
                    stp = stx + P5 * (sty - stx);
                }
                width1 = width;
                width = Math.abs(sty - stx);
            }
        }
    }

    /**
     * Computes a safeguarded step for the line search and updates the
     * interval of uncertainty for a minimizer of the function.
     *
     * <p>The step values live in instance fields rather than parameters:
     * {@code stx} is the step with the least function value, {@code sty} the
     * other endpoint of the interval of uncertainty, and {@code stp} the
     * current step, which is replaced by the new step. {@code stmin} and
     * {@code stmax} bound the new step. If {@code brackt} is true when this
     * method returns, a minimizer has been bracketed between {@code stx} and
     * {@code sty}. {@code infoc} is set to the case number 1-4 on success and
     * to 0 when the inputs are inconsistent, in which case nothing else is
     * changed.
     *
     * <p>Values that must be modified are passed as 1-element arrays.
     *
     * @param fx function value at {@code stx}; updated in place
     * @param dx derivative at {@code stx}; it must be negative in the
     *           direction of the step, that is, {@code dx} and
     *           {@code stp - stx} must have opposite signs; updated in place
     * @param fy function value at {@code sty}; updated in place
     * @param dy derivative at {@code sty}; updated in place
     * @param fp function value at the current step {@code stp}
     * @param dp derivative at the current step {@code stp}
     *
     * @author Jorge J. More, David J. Thuente: original Fortran version,
     *         as part of Minpack project. Argonne Nat'l Laboratory, June 1983.
     *         Robert Dodier: Java translation, August 1997.
     */
    private void mcstep(double[] fx, double[] dx, double[] fy, double[] dy, double fp, double dp) {
        boolean bounded;
        double gamma, p, q, r, s, sgnd, stpc, stpf, stpq, theta;

        infoc = 0;

        // Reject inconsistent input.
        if ((brackt && (stp <= Math.min(stx, sty) || stp >= Math.max(stx, sty)))
                || dx[0] * (stp - stx) >= 0.0
                || stmax < stmin) {
            return;
        }

        // Determine if the derivatives have opposite sign.
        sgnd = dp * (dx[0] / Math.abs(dx[0]));

        if (fp > fx[0]) {
            // First case. A higher function value.
            // The minimum is bracketed. If the cubic step is closer
            // to stx than the quadratic step, the cubic step is taken,
            // else the average of the cubic and quadratic steps is taken.
            infoc = 1;
            bounded = true;
            theta = 3 * (fx[0] - fp) / (stp - stx) + dx[0] + dp;
            s = max3(Math.abs(theta), Math.abs(dx[0]), Math.abs(dp));
            gamma = s * Math.sqrt(sqr(theta / s) - (dx[0] / s) * (dp / s));
            if (stp < stx) {
                gamma = -gamma;
            }
            p = (gamma - dx[0]) + theta;
            q = ((gamma - dx[0]) + gamma) + dp;
            r = p / q;
            stpc = stx + r * (stp - stx);
            stpq = stx + ((dx[0] / ((fx[0] - fp) / (stp - stx) + dx[0])) / 2) * (stp - stx);
            if (Math.abs(stpc - stx) < Math.abs(stpq - stx)) {
                stpf = stpc;
            } else {
                stpf = stpc + (stpq - stpc) / 2;
            }
            brackt = true;
        } else if (sgnd < 0.0) {
            // Second case. A lower function value and derivatives of
            // opposite sign. The minimum is bracketed. If the cubic
            // step is closer to stx than the quadratic (secant) step,
            // the cubic step is taken, else the quadratic step is taken.
            infoc = 2;
            bounded = false;
            theta = 3 * (fx[0] - fp) / (stp - stx) + dx[0] + dp;
            s = max3(Math.abs(theta), Math.abs(dx[0]), Math.abs(dp));
            gamma = s * Math.sqrt(sqr(theta / s) - (dx[0] / s) * (dp / s));
            if (stp > stx) {
                gamma = -gamma;
            }
            p = (gamma - dp) + theta;
            q = ((gamma - dp) + gamma) + dx[0];
            r = p / q;
            stpc = stp + r * (stx - stp);
            stpq = stp + (dp / (dp - dx[0])) * (stx - stp);
            if (Math.abs(stpc - stp) > Math.abs(stpq - stp)) {
                stpf = stpc;
            } else {
                stpf = stpq;
            }
            brackt = true;
        } else if (Math.abs(dp) < Math.abs(dx[0])) {
            // Third case. A lower function value, derivatives of the
            // same sign, and the magnitude of the derivative decreases.
            // The cubic step is only used if the cubic tends to infinity
            // in the direction of the step or if the minimum of the cubic
            // is beyond stp. Otherwise the cubic step is defined to be
            // either stmin or stmax. The quadratic (secant) step is also
            // computed and if the minimum is bracketed then the step
            // closest to stx is taken, else the step farthest away is taken.
            infoc = 3;
            bounded = true;
            theta = 3 * (fx[0] - fp) / (stp - stx) + dx[0] + dp;
            s = max3(Math.abs(theta), Math.abs(dx[0]), Math.abs(dp));
            gamma = s * Math.sqrt(Math.max(0, sqr(theta / s) - (dx[0] / s) * (dp / s)));
            if (stp > stx) {
                gamma = -gamma;
            }
            p = (gamma - dp) + theta;
            q = (gamma + (dx[0] - dp)) + gamma;
            r = p / q;
            if (r < 0.0 && gamma != 0.0) {
                stpc = stp + r * (stx - stp);
            } else if (stp > stx) {
                stpc = stmax;
            } else {
                stpc = stmin;
            }
            stpq = stp + (dp / (dp - dx[0])) * (stx - stp);
            if (brackt) {
                if (Math.abs(stp - stpc) < Math.abs(stp - stpq)) {
                    stpf = stpc;
                } else {
                    stpf = stpq;
                }
            } else {
                if (Math.abs(stp - stpc) > Math.abs(stp - stpq)) {
                    stpf = stpc;
                } else {
                    stpf = stpq;
                }
            }
        } else {
            // Fourth case. A lower function value, derivatives of the
            // same sign, and the magnitude of the derivative does
            // not decrease. If the minimum is not bracketed, the step
            // is either stmin or stmax, else the cubic step is taken.
            infoc = 4;
            bounded = false;
            if (brackt) {
                theta = 3 * (fp - fy[0]) / (sty - stp) + dy[0] + dp;
                s = max3(Math.abs(theta), Math.abs(dy[0]), Math.abs(dp));
                gamma = s * Math.sqrt(sqr(theta / s) - (dy[0] / s) * (dp / s));
                if (stp > sty) {
                    gamma = -gamma;
                }
                p = (gamma - dp) + theta;
                q = ((gamma - dp) + gamma) + dy[0];
                r = p / q;
                stpc = stp + r * (sty - stp);
                stpf = stpc;
            } else if (stp > stx) {
                stpf = stmax;
            } else {
                stpf = stmin;
            }
        }

        // Update the interval of uncertainty. This update does not
        // depend on the new step or the case analysis above.
        if (fp > fx[0]) {
            sty = stp;
            fy[0] = fp;
            dy[0] = dp;
        } else {
            if (sgnd < 0.0) {
                sty = stx;
                fy[0] = fx[0];
                dy[0] = dx[0];
            }
            stx = stp;
            fx[0] = fp;
            dx[0] = dp;
        }

        // Compute the new step and safeguard it.
        stpf = Math.min(stmax, stpf);
        stpf = Math.max(stmin, stpf);
        stp = stpf;

        if (brackt && bounded) {
            if (sty > stx) {
                stp = Math.min(stx + P66 * (sty - stx), stp);
            } else {
                stp = Math.max(stx + P66 * (sty - stx), stp);
            }
        }
    }

    /**
     * Computes {@code dy := dy + da * dx} over {@code n} strided elements.
     * Adapted from the subroutine daxpy in lbfgs.f.
     *
     * @param n    number of elements to process; nothing happens if {@code n <= 0}
     * @param da   scalar multiplier; nothing happens if it is zero
     * @param dx   source vector storage
     * @param ix0  offset of the first logical element in {@code dx}
     * @param incx stride in {@code dx}; with a negative stride the walk starts
     *             at offset {@code ix0 + (1 - n) * incx}
     * @param dy   destination vector storage, updated in place
     * @param iy0  offset of the first logical element in {@code dy}
     * @param incy stride in {@code dy}, same convention as {@code incx}
     */
    public static void daxpy(int n, double da, double[] dx, int ix0, int incx,
                             double[] dy, int iy0, int incy) {
        if (n <= 0 || da == 0) {
            return;
        }
        int ix = ix0 + (incx < 0 ? (1 - n) * incx : 0);
        int iy = iy0 + (incy < 0 ? (1 - n) * incy : 0);
        for (int i = 0; i < n; i++) {
            dy[iy] = dy[iy] + da * dx[ix];
            ix += incx;
            iy += incy;
        }
    }

    /**
     * Computes the dot product of two strided vectors.
     * Adapted from the subroutine ddot in lbfgs.f.
     *
     * @param n    number of elements to process
     * @param dx   first vector storage
     * @param ix0  offset of the first logical element in {@code dx}
     * @param incx stride in {@code dx}; with a negative stride the walk starts
     *             at offset {@code ix0 + (1 - n) * incx}
     * @param dy   second vector storage
     * @param iy0  offset of the first logical element in {@code dy}
     * @param incy stride in {@code dy}, same convention as {@code incx}
     * @return the sum of the element-wise products, or 0 if {@code n <= 0}
     */
    public static double ddot(int n, double[] dx, int ix0, int incx,
                              double[] dy, int iy0, int incy) {
        if (n <= 0) {
            return 0;
        }
        int ix = ix0 + (incx < 0 ? (1 - n) * incx : 0);
        int iy = iy0 + (incy < 0 ? (1 - n) * incy : 0);
        double dtemp = 0;
        for (int i = 0; i < n; i++) {
            dtemp = dtemp + dx[ix] * dy[iy];
            ix += incx;
            iy += incy;
        }
        return dtemp;
    }

    /** Returns {@code x * x}. */
    static double sqr(double x) {
        return x * x;
    }

    /** Returns the largest of the three arguments. */
    static double max3(double x, double y, double z) {
        return x < y ? (y < z ? z : y) : (x < z ? z : x);
    }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy