org.apache.mahout.math.solver.LSMR Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mahout-math Show documentation
Show all versions of mahout-math Show documentation
High performance scientific and technical computing data structures and methods,
mostly based on CERN's
Colt Java API
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.math.solver;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.Functions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Solves sparse least-squares using the LSMR algorithm.
*
* LSMR solves the system of linear equations A * X = B. If the system is inconsistent, it solves
* the least-squares problem min ||b - Ax||_2. A is a rectangular matrix of dimension m-by-n, where
* all cases are allowed: m=n, m>n, or m<n. B is a vector of length m. The matrix A may be dense
* or sparse (usually sparse).
*
* Some additional configurable properties adjust the behavior of the algorithm.
*
* If you set lambda to a non-zero value then LSMR solves the regularized least-squares problem min
* ||(B) - ( A )X|| ||(0) (lambda*I) ||_2 where LAMBDA is a scalar. If LAMBDA is not set,
* the system is solved without regularization.
*
* You can also set aTolerance and bTolerance. These cause LSMR to iterate until a certain backward
* error estimate is smaller than some quantity depending on ATOL and BTOL. Let RES = B - A*X be
* the residual vector for the current approximate solution X. If A*X = B seems to be consistent,
* LSMR terminates when NORM(RES) <= ATOL*NORM(A)*NORM(X) + BTOL*NORM(B). Otherwise, LSMR terminates
* when NORM(A'*RES) <= ATOL*NORM(A)*NORM(RES). If both tolerances are 1.0e-6 (say), the final
* NORM(RES) should be accurate to about 6 digits. (The final X will usually have fewer correct
* digits, depending on cond(A) and the size of LAMBDA.)
*
* The default value for ATOL and BTOL is 1e-6.
*
* Ideally, they should be estimates of the relative error in the entries of A and B respectively.
* For example, if the entries of A have 7 correct digits, set ATOL = 1e-7. This prevents the
* algorithm from doing unnecessary work beyond the uncertainty of the input data.
*
* You can also set conditionLimit. In that case, LSMR terminates if an estimate of cond(A) exceeds
* conditionLimit. For compatible systems Ax = b, conditionLimit could be as large as 1.0e+12 (say).
* For least-squares problems, conditionLimit should be less than 1.0e+8. If conditionLimit is not
* set, the default value is 1e+8. Maximum precision can be obtained by setting aTolerance =
* bTolerance = conditionLimit = 0, but the number of iterations may then be excessive.
*
* Setting iterationLimit causes LSMR to terminate if the number of iterations reaches
* iterationLimit. The default is iterationLimit = min(m,n). For ill-conditioned systems, a
* larger value of ITNLIM may be needed.
*
* Setting localSize causes LSMR to run with rerorthogonalization on the last localSize v_k's.
* (v-vectors generated by Golub-Kahan bidiagonalization) If localSize is not set, LSMR runs without
* reorthogonalization. A localSize > max(n,m) performs reorthogonalization on all v_k's.
* Reorthgonalizing only u_k or both u_k and v_k are not an option here. Details are discussed in
* the SIAM paper.
*
* getTerminationReason() gives the reason for termination. ISTOP = 0 means X=0 is a solution. = 1
* means X is an approximate solution to A*X = B, according to ATOL and BTOL. = 2 means X
* approximately solves the least-squares problem according to ATOL. = 3 means COND(A) seems to be
* greater than CONLIM. = 4 is the same as 1 with ATOL = BTOL = EPS. = 5 is the same as 2 with ATOL
* = EPS. = 6 is the same as 3 with CONLIM = 1/EPS. = 7 means ITN reached ITNLIM before the other
* stopping conditions were satisfied.
*
* getIterationCount() gives ITN = the number of LSMR iterations.
*
* getResidualNorm() gives an estimate of the residual norm: NORMR = norm(B-A*X).
*
* getNormalEquationResidual() gives an estimate of the residual for the normal equation: NORMAR =
* NORM(A'*(B-A*X)).
*
* getANorm() gives an estimate of the Frobenius norm of A.
*
* getCondition() gives an estimate of the condition number of A.
*
* getXNorm() gives an estimate of NORM(X).
*
* LSMR uses an iterative method. For further information, see D. C.-L. Fong and M. A. Saunders
* LSMR: An iterative algorithm for least-square problems Draft of 03 Apr 2010, to be submitted to
* SISC.
*
* David Chin-lung Fong [email protected] Institute for Computational and Mathematical
* Engineering Stanford University
*
* Michael Saunders [email protected] Systems Optimization Laboratory Dept of
* MS&E, Stanford University. -----------------------------------------------------------------------
*/
public final class LSMR {
private static final Logger log = LoggerFactory.getLogger(LSMR.class);
private final double lambda;
private int localSize;
private int iterationLimit;
private double conditionLimit;
private double bTolerance;
private double aTolerance;
private int localPointer;
private Vector[] localV;
private double residualNorm;
private double normalEquationResidual;
private double xNorm;
private int iteration;
private double normA;
private double condA;
public int getIterationCount() {
return iteration;
}
public double getResidualNorm() {
return residualNorm;
}
public double getNormalEquationResidual() {
return normalEquationResidual;
}
public double getANorm() {
return normA;
}
public double getCondition() {
return condA;
}
public double getXNorm() {
return xNorm;
}
/**
* LSMR uses an iterative method to solve a linear system. For further information, see D. C.-L.
* Fong and M. A. Saunders LSMR: An iterative algorithm for least-square problems Draft of 03 Apr
* 2010, to be submitted to SISC.
*
* 08 Dec 2009: First release version of LSMR. 09 Apr 2010: Updated documentation and default
* parameters. 14 Apr 2010: Updated documentation. 03 Jun 2010: LSMR with local
* reorthogonalization (full reorthogonalization is also implemented)
*
* David Chin-lung Fong [email protected] Institute for Computational and
* Mathematical Engineering Stanford University
*
* Michael Saunders [email protected] Systems Optimization Laboratory Dept of
* MS&E, Stanford University. -----------------------------------------------------------------------
*/
public LSMR() {
// Set default parameters.
lambda = 0;
aTolerance = 1.0e-6;
bTolerance = 1.0e-6;
conditionLimit = 1.0e8;
iterationLimit = -1;
localSize = 0;
}
public Vector solve(Matrix A, Vector b) {
/*
% Initialize.
hdg1 = ' itn x(1) norm r norm A''r';
hdg2 = ' compatible LS norm A cond A';
pfreq = 20; % print frequency (for repeating the heading)
pcount = 0; % print counter
% Determine dimensions m and n, and
% form the first vectors u and v.
% These satisfy beta*u = b, alpha*v = A'u.
*/
log.debug(" itn x(1) norm r norm A'r");
log.debug(" compatible LS norm A cond A");
Matrix transposedA = A.transpose();
Vector u = b;
double beta = u.norm(2);
if (beta > 0) {
u = u.divide(beta);
}
Vector v = transposedA.times(u);
int m = A.numRows();
int n = A.numCols();
int minDim = Math.min(m, n);
if (iterationLimit == -1) {
iterationLimit = minDim;
}
if (log.isDebugEnabled()) {
log.debug("LSMR - Least-squares solution of Ax = b, based on Matlab Version 1.02, 14 Apr 2010, "
+ "Mahout version {}", getClass().getPackage().getImplementationVersion());
log.debug(String.format("The matrix A has %d rows and %d cols, lambda = %.4g, atol = %g, btol = %g",
m, n, lambda, aTolerance, bTolerance));
}
double alpha = v.norm(2);
if (alpha > 0) {
v.assign(Functions.div(alpha));
}
// Initialization for local reorthogonalization
localPointer = 0;
// Preallocate storage for storing the last few v_k. Since with
// orthogonal v_k's, Krylov subspace method would converge in not
// more iterations than the number of singular values, more
// space is not necessary.
localV = new Vector[Math.min(localSize, minDim)];
boolean localOrtho = false;
if (localSize > 0) {
localOrtho = true;
localV[0] = v;
}
// Initialize variables for 1st iteration.
iteration = 0;
double zetabar = alpha * beta;
double alphabar = alpha;
Vector h = v;
Vector hbar = zeros(n);
Vector x = zeros(n);
// Initialize variables for estimation of ||r||.
double betadd = beta;
// Initialize variables for estimation of ||A|| and cond(A)
double aNorm = alpha * alpha;
// Items for use in stopping rules.
double normb = beta;
double ctol = 0;
if (conditionLimit > 0) {
ctol = 1 / conditionLimit;
}
residualNorm = beta;
// Exit if b=0 or A'b = 0.
normalEquationResidual = alpha * beta;
if (normalEquationResidual == 0) {
return x;
}
// Heading for iteration log.
if (log.isDebugEnabled()) {
double test2 = alpha / beta;
// log.debug('{} {}', hdg1, hdg2);
log.debug("{} {}", iteration, x.get(0));
log.debug("{} {}", residualNorm, normalEquationResidual);
double test1 = 1;
log.debug("{} {}", test1, test2);
}
//------------------------------------------------------------------
// Main iteration loop.
//------------------------------------------------------------------
double rho = 1;
double rhobar = 1;
double cbar = 1;
double sbar = 0;
double betad = 0;
double rhodold = 1;
double tautildeold = 0;
double thetatilde = 0;
double zeta = 0;
double d = 0;
double maxrbar = 0;
double minrbar = 1.0e+100;
StopCode stop = StopCode.CONTINUE;
while (iteration <= iterationLimit && stop == StopCode.CONTINUE) {
iteration++;
// Perform the next step of the bidiagonalization to obtain the
// next beta, u, alpha, v. These satisfy the relations
// beta*u = A*v - alpha*u,
// alpha*v = A'*u - beta*v.
u = A.times(v).minus(u.times(alpha));
beta = u.norm(2);
if (beta > 0) {
u.assign(Functions.div(beta));
// store data for local-reorthogonalization of V
if (localOrtho) {
localVEnqueue(v);
}
v = transposedA.times(u).minus(v.times(beta));
// local-reorthogonalization of V
if (localOrtho) {
v = localVOrtho(v);
}
alpha = v.norm(2);
if (alpha > 0) {
v.assign(Functions.div(alpha));
}
}
// At this point, beta = beta_{k+1}, alpha = alpha_{k+1}.
// Construct rotation Qhat_{k,2k+1}.
double alphahat = Math.hypot(alphabar, lambda);
double chat = alphabar / alphahat;
double shat = lambda / alphahat;
// Use a plane rotation (Q_i) to turn B_i to R_i
double rhoold = rho;
rho = Math.hypot(alphahat, beta);
double c = alphahat / rho;
double s = beta / rho;
double thetanew = s * alpha;
alphabar = c * alpha;
// Use a plane rotation (Qbar_i) to turn R_i^T to R_i^bar
double rhobarold = rhobar;
double zetaold = zeta;
double thetabar = sbar * rho;
double rhotemp = cbar * rho;
rhobar = Math.hypot(cbar * rho, thetanew);
cbar = cbar * rho / rhobar;
sbar = thetanew / rhobar;
zeta = cbar * zetabar;
zetabar = -sbar * zetabar;
// Update h, h_hat, x.
hbar = h.minus(hbar.times(thetabar * rho / (rhoold * rhobarold)));
x.assign(hbar.times(zeta / (rho * rhobar)), Functions.PLUS);
h = v.minus(h.times(thetanew / rho));
// Estimate of ||r||.
// Apply rotation Qhat_{k,2k+1}.
double betaacute = chat * betadd;
double betacheck = -shat * betadd;
// Apply rotation Q_{k,k+1}.
double betahat = c * betaacute;
betadd = -s * betaacute;
// Apply rotation Qtilde_{k-1}.
// betad = betad_{k-1} here.
double thetatildeold = thetatilde;
double rhotildeold = Math.hypot(rhodold, thetabar);
double ctildeold = rhodold / rhotildeold;
double stildeold = thetabar / rhotildeold;
thetatilde = stildeold * rhobar;
rhodold = ctildeold * rhobar;
betad = -stildeold * betad + ctildeold * betahat;
// betad = betad_k here.
// rhodold = rhod_k here.
tautildeold = (zetaold - thetatildeold * tautildeold) / rhotildeold;
double taud = (zeta - thetatilde * tautildeold) / rhodold;
d += betacheck * betacheck;
residualNorm = Math.sqrt(d + (betad - taud) * (betad - taud) + betadd * betadd);
// Estimate ||A||.
aNorm += beta * beta;
normA = Math.sqrt(aNorm);
aNorm += alpha * alpha;
// Estimate cond(A).
maxrbar = Math.max(maxrbar, rhobarold);
if (iteration > 1) {
minrbar = Math.min(minrbar, rhobarold);
}
condA = Math.max(maxrbar, rhotemp) / Math.min(minrbar, rhotemp);
// Test for convergence.
// Compute norms for convergence testing.
normalEquationResidual = Math.abs(zetabar);
xNorm = x.norm(2);
// Now use these norms to estimate certain other quantities,
// some of which will be small near a solution.
double test1 = residualNorm / normb;
double test2 = normalEquationResidual / (normA * residualNorm);
double test3 = 1 / condA;
double t1 = test1 / (1 + normA * xNorm / normb);
double rtol = bTolerance + aTolerance * normA * xNorm / normb;
// The following tests guard against extremely small values of
// atol, btol or ctol. (The user may have set any or all of
// the parameters atol, btol, conlim to 0.)
// The effect is equivalent to the normAl tests using
// atol = eps, btol = eps, conlim = 1/eps.
if (iteration > iterationLimit) {
stop = StopCode.ITERATION_LIMIT;
}
if (1 + test3 <= 1) {
stop = StopCode.CONDITION_MACHINE_TOLERANCE;
}
if (1 + test2 <= 1) {
stop = StopCode.LEAST_SQUARE_CONVERGED_MACHINE_TOLERANCE;
}
if (1 + t1 <= 1) {
stop = StopCode.CONVERGED_MACHINE_TOLERANCE;
}
// Allow for tolerances set by the user.
if (test3 <= ctol) {
stop = StopCode.CONDITION;
}
if (test2 <= aTolerance) {
stop = StopCode.CONVERGED;
}
if (test1 <= rtol) {
stop = StopCode.TRIVIAL;
}
// See if it is time to print something.
if (log.isDebugEnabled()) {
if ((n <= 40) || (iteration <= 10) || (iteration >= iterationLimit - 10) || ((iteration % 10) == 0)
|| (test3 <= 1.1 * ctol) || (test2 <= 1.1 * aTolerance) || (test1 <= 1.1 * rtol)
|| (stop != StopCode.CONTINUE)) {
statusDump(x, normA, condA, test1, test2);
}
}
} // iteration loop
// Print the stopping condition.
log.debug("Finished: {}", stop.getMessage());
return x;
/*
if show
fprintf('\n\nLSMR finished')
fprintf('\n%s', msg(istop+1,:))
fprintf('\nistop =%8g normr =%8.1e' , istop, normr )
fprintf(' normA =%8.1e normAr =%8.1e', normA, normAr)
fprintf('\nitn =%8g condA =%8.1e' , itn , condA )
fprintf(' normx =%8.1e\n', normx)
end
*/
}
private void statusDump(Vector x, double normA, double condA, double test1, double test2) {
log.debug("{} {}", residualNorm, normalEquationResidual);
log.debug("{} {}", iteration, x.get(0));
log.debug("{} {}", test1, test2);
log.debug("{} {}", normA, condA);
}
private static Vector zeros(int n) {
return new DenseVector(n);
}
//-----------------------------------------------------------------------
// stores v into the circular buffer localV
//-----------------------------------------------------------------------
private void localVEnqueue(Vector v) {
if (localV.length > 0) {
localV[localPointer] = v;
localPointer = (localPointer + 1) % localV.length;
}
}
//-----------------------------------------------------------------------
// Perform local reorthogonalization of V
//-----------------------------------------------------------------------
private Vector localVOrtho(Vector v) {
for (Vector old : localV) {
if (old != null) {
double x = v.dot(old);
v = v.minus(old.times(x));
}
}
return v;
}
private enum StopCode {
CONTINUE("Not done"),
TRIVIAL("The exact solution is x = 0"),
CONVERGED("Ax - b is small enough, given atol, btol"),
LEAST_SQUARE_CONVERGED("The least-squares solution is good enough, given atol"),
CONDITION("The estimate of cond(Abar) has exceeded condition limit"),
CONVERGED_MACHINE_TOLERANCE("Ax - b is small enough for this machine"),
LEAST_SQUARE_CONVERGED_MACHINE_TOLERANCE("The least-squares solution is good enough for this machine"),
CONDITION_MACHINE_TOLERANCE("Cond(Abar) seems to be too large for this machine"),
ITERATION_LIMIT("The iteration limit has been reached");
private final String message;
StopCode(String message) {
this.message = message;
}
public String getMessage() {
return message;
}
}
public void setAtolerance(double aTolerance) {
this.aTolerance = aTolerance;
}
public void setBtolerance(double bTolerance) {
this.bTolerance = bTolerance;
}
public void setConditionLimit(double conditionLimit) {
this.conditionLimit = conditionLimit;
}
public void setIterationLimit(int iterationLimit) {
this.iterationLimit = iterationLimit;
}
public void setLocalSize(int localSize) {
this.localSize = localSize;
}
public double getLambda() {
return lambda;
}
public double getAtolerance() {
return aTolerance;
}
public double getBtolerance() {
return bTolerance;
}
}