/*-
* #%L
* Genome Damage and Stability Centre SMLM Package
*
* Software for single molecule localisation microscopy (SMLM)
* %%
* Copyright (C) 2011 - 2023 Alex Herbert
* %%
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public
* License along with this program. If not, see
* <https://www.gnu.org/licenses/>.
* #L%
*/
package uk.ac.sussex.gdsc.smlm.fitting.nonlinear;
import java.util.Objects;
import java.util.logging.Level;
import java.util.logging.Logger;
import uk.ac.sussex.gdsc.core.utils.BitFlagUtils;
import uk.ac.sussex.gdsc.smlm.fitting.FisherInformationMatrix;
import uk.ac.sussex.gdsc.smlm.fitting.FitStatus;
import uk.ac.sussex.gdsc.smlm.fitting.FunctionSolverType;
import uk.ac.sussex.gdsc.smlm.function.Gradient1Function;
import uk.ac.sussex.gdsc.smlm.function.Gradient1FunctionStore;
import uk.ac.sussex.gdsc.smlm.function.GradientFunction;
import uk.ac.sussex.gdsc.smlm.function.ValueFunction;
import uk.ac.sussex.gdsc.smlm.function.ValueProcedure;
/**
* Abstract class for FunctionSolvers that use update steps to the current parameters.
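*
* <p>Sub-classes provide the fit value, the update step and the acceptance test via
* {@link #computeFitValue(double[])}, {@link #computeStep(double[])} and
* {@link #accept(double, double[], double, double[])}. The fit loop is approximately
* (an illustrative sketch of the logic, not the exact implementation):
*
* <pre>{@code
* double value = computeFitValue(a);
* for (;;) {
*   computeStep(step);
*   bounds.applyBounds(a, step, newA);
*   double newValue = computeFitValue(newA);
*   if (tc.converged(value, a, newValue, newA) != 0) {
*     break; // Converged or reached the maximum iterations
*   }
*   if (accept(value, a, newValue, newA)) {
*     System.arraycopy(newA, 0, a, 0, a.length);
*     value = newValue;
*     bounds.accepted(a, newA);
*   }
* }
* }</pre>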
*/
public abstract class SteppingFunctionSolver extends BaseFunctionSolver {
private static final Logger logger = Logger.getLogger(SteppingFunctionSolver.class.getName());
/** The trace level for debugging the fit computation. */
private static final Level traceLevel = Level.FINEST;
/** The gradient indices. */
protected int[] gradientIndices;
/** The tolerance checker. */
protected final ToleranceChecker tc;
/** The bounds. */
protected final ParameterBounds bounds;
private double[] weights;
/**
* Simple procedure to copy the function values into the provided array.
*/
private static class SimpleValueProcedure implements ValueProcedure {
int index;
double[] fx;
SimpleValueProcedure(double[] fx) {
this.fx = fx;
}
@Override
public void execute(double value) {
fx[index++] = value;
}
}
/**
* Create a new stepping function solver.
*
* @param type the function solver type
* @param function the gradient function
* @throws NullPointerException if the function is null
*/
public SteppingFunctionSolver(FunctionSolverType type, Gradient1Function function) {
this(type, function, new ToleranceChecker(1e-3, 1e-6), null);
}
/**
* Create a new stepping function solver.
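*
* <p>An illustrative construction of the collaborators (the tolerance values shown are only an
* example; {@code function} is assumed to be a suitable {@link Gradient1Function}):
*
* <pre>{@code
* ToleranceChecker tc = new ToleranceChecker(1e-3, 1e-6);
* ParameterBounds bounds = ParameterBounds.create(function);
* }</pre>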
*
* @param type the function solver type
* @param function the gradient function
* @param tc the tolerance checker
* @param bounds the parameter bounds (may be null)
* @throws NullPointerException if the function or tolerance checker is null
* @throws IllegalArgumentException if the bounds are not constructed with the same gradient
* function
*/
public SteppingFunctionSolver(FunctionSolverType type, Gradient1Function function,
ToleranceChecker tc, ParameterBounds bounds) {
super(type, function);
this.tc = Objects.requireNonNull(tc, "tolerance checker");
if (bounds == null) {
bounds = ParameterBounds.create(function);
} else if (bounds.getGradientFunction() != function) {
throw new IllegalArgumentException(
"Bounds must be constructed with the same gradient function");
}
this.bounds = bounds;
}
/**
* Compute fit.
*
* @param y the data
* @param fx the function values f(x) (may be null)
* @param a the parameters
* @param parametersVariance the output parameter variances (may be null)
* @return the fit status
*/
@Override
protected FitStatus computeFit(double[] y, double[] fx, double[] a, double[] parametersVariance) {
// Lay out a simple iteration loop for a stepping solver.
// The sub-class must compute the next step.
// This class handles attenuation of the step.
// The sub-class determines if the step is accepted or rejected.
gradientIndices = function.gradientIndices();
final double[] step = new double[gradientIndices.length];
final double[] newA = a.clone();
// Initialise for fitting
bounds.initialise();
tc.reset();
final String name = this.getClass().getSimpleName();
try {
lastY = prepareFitValue(y, a);
// First evaluation
double currentValue = computeFitValue(a);
if (logger.isLoggable(traceLevel)) {
log("%s Value [%s] = %s : %s", name, tc.getIterations(), currentValue, a);
}
int status = 0;
for (;;) {
// Compute next step
computeStep(step);
if (logger.isLoggable(traceLevel)) {
log("%s Step [%s] = %s", name, tc.getIterations(), step);
}
// Apply bounds to the step
bounds.applyBounds(a, step, newA);
// Evaluate
final double newValue = computeFitValue(newA);
if (logger.isLoggable(traceLevel)) {
log("%s Value [%s] = %s : %s", name, tc.getIterations(), newValue, newA);
}
// Check stopping criteria
status = tc.converged(currentValue, a, newValue, newA);
if (logger.isLoggable(traceLevel)) {
log("%s Status [%s] = %s", name, tc.getIterations(), status);
}
if (status != 0) {
value = newValue;
System.arraycopy(newA, 0, a, 0, a.length);
break;
}
// Check if the step was an improvement
if (accept(currentValue, a, newValue, newA)) {
if (logger.isLoggable(traceLevel)) {
log("%s Accepted [%s]", name, tc.getIterations());
}
currentValue = newValue;
System.arraycopy(newA, 0, a, 0, a.length);
bounds.accepted(a, newA);
}
}
if (logger.isLoggable(traceLevel)) {
log("%s End [%s] = %s", name, tc.getIterations(), status);
}
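// The tolerance checker packs the exit reason into bit flags: any converged flag counts as
// success; reaching the maximum iterations without converging is handled below.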
if (BitFlagUtils.anySet(status, ToleranceChecker.STATUS_CONVERGED)) {
if (logger.isLoggable(traceLevel)) {
log("%s Converged [%s]", name, tc.getIterations());
}
// A solver may compute both at the same time...
if (parametersVariance != null) {
computeDeviationsAndValues(parametersVariance, fx);
} else if (fx != null) {
computeValues(fx);
}
return FitStatus.OK;
}
// Check the iterations
if (BitFlagUtils.areSet(status, ToleranceChecker.STATUS_MAX_ITERATIONS)) {
return FitStatus.TOO_MANY_ITERATIONS;
}
// We should not reach here unless we missed something
return FitStatus.FAILED_TO_CONVERGE;
} catch (final FunctionSolverException ex) {
// Debugging
final String msg = ex.getMessage();
logger.log(Level.FINE, () -> String.format("%s failed: %s%s", getClass().getSimpleName(),
ex.fitStatus.getName(), (msg == null) ? "" : " - " + msg));
return ex.fitStatus;
} finally {
iterations = tc.getIterations();
// Allow subclasses to increment this
if (evaluations == 0) {
evaluations = iterations;
}
}
}
/**
* Log progress from the solver.
*
* @param format the format
* @param args the arguments
*/
private static void log(String format, Object... args) {
// Convert arrays to a single string
for (int i = 0; i < args.length; i++) {
if (args[i] instanceof double[]) {
args[i] = java.util.Arrays.toString((double[]) args[i]);
}
}
logger.log(traceLevel, () -> String.format(format, args));
}
/**
* Prepare y for fitting, e.g. ensure strictly positive values.
*
* @param y the data
* @param a the parameters
* @return the new y
*/
protected abstract double[] prepareFitValue(double[] y, double[] a);
/**
* Compute the fit value using the parameters. The y data is the same as that passed to
* {@link #prepareFitValue(double[], double[])}.
*
* <p>This method is followed by a call to {@link #computeStep(double[])} so the step could be
* pre-computed here.
*
* @param a the parameters
* @return the fit value
*/
protected abstract double computeFitValue(double[] a);
/**
* Compute the update step for the current parameters.
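*
* <p>A minimal illustrative sketch of a gradient-descent style step (the {@code stepSize} and
* {@code gradient} values are hypothetical, not members of this class):
*
* <pre>{@code
* for (int i = 0; i < step.length; i++) {
*   step[i] = -stepSize * gradient[i];
* }
* }</pre>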
*
* @param step the output step for each of the gradient indices
*/
protected abstract void computeStep(double[] step);
/**
* Determine if the step should be accepted. If accepted then the current parameters and function
* value are updated and any bounds on the step size may be updated.
*
* <p>Note that although this class handles convergence on the value/parameters, it is left to the
* sub-class to determine if each step should be accepted.
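*
* <p>For example a sub-class minimising the fit value might accept any step that lowers it
* (an illustrative criterion only, not the behaviour of any specific solver):
*
* <pre>{@code
* return newValue < currentValue;
* }</pre>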
*
* @param currentValue the current value
* @param a the current parameters
* @param newValue the new value
* @param newA the new parameters
* @return true if the step is accepted
*/
protected abstract boolean accept(double currentValue, double[] a, double newValue,
double[] newA);
/**
* Compute the deviations for the parameters a from the last call to
* {@link #computeFitValue(double[])}. Optionally store the function values.
*
* @param parametersVariance the parameter deviations
* @param fx the function values f(x) (may be null)
*/
protected void computeDeviationsAndValues(double[] parametersVariance, double[] fx) {
// Use a dedicated solver optimised for inverting the matrix diagonal.
// The last Hessian matrix should be stored in the working alpha.
final FisherInformationMatrix m = computeLastFisherInformationMatrix(fx);
setDeviations(parametersVariance, m);
}
/**
* Compute the Fisher Information matrix for the parameters a from the last call to
* {@link #computeFitValue(double[])}. This can be used to set the covariances for each of the
* fitted parameters.
*
* <p>Alternatively a sub-class can override
* {@link #computeDeviationsAndValues(double[], double[])} directly and provide a dummy
* implementation of this function as it will not be used, e.g. throw an exception.
*
* @param fx the function values f(x) (may be null)
* @return the Fisher Information matrix
*/
protected abstract FisherInformationMatrix computeLastFisherInformationMatrix(double[] fx);
/**
* Compute the function y-values using the y and parameters a from the last call to
* {@link #computeFitValue(double[])}.
*
* <p>Utility method to compute the function values using the preinitialised function. Sub-classes
* may override this if they have cached the function values from the last execution of a forEach
* procedure.
*
* <p>The base gradient function is used. If sub-classes wrap the function (e.g. with
* per-observation weights) then these will be omitted.
*
* @param fx the function values f(x)
*/
protected void computeValues(double[] fx) {
final ValueFunction function = (ValueFunction) this.function;
function.forEach(new SimpleValueProcedure(fx));
}
@Override
protected boolean computeValue(double[] y, double[] fx, double[] a) {
// If the fx array is not null then wrap the gradient function.
// Compute the value and the wrapper will store the values appropriately.
// Then reset the gradient function.
// Note: If a sub-class wraps the function with weights
// then the weights will not be stored in the function value.
// Only the value produced by the original function is stored:
// Wrapped (+weights) < FunctionStore < Function
// However if the base function is already wrapped then this will occur:
// Wrapped (+weights) < FunctionStore < Wrapped (+precomputed) < Function
gradientIndices = function.gradientIndices();
if (fx != null && fx.length == ((Gradient1Function) function).size()) {
final GradientFunction tmp = function;
function = new Gradient1FunctionStore((Gradient1Function) function, fx, null);
lastY = prepareFunctionValue(y, a);
value = computeFunctionValue(a);
function = tmp;
} else {
lastY = prepareFunctionValue(y, a);
value = computeFunctionValue(a);
}
return true;
}
/**
* Prepare y for computing the function value, e.g. ensure strictly positive values.
*
* @param y the data
* @param a the parameters
* @return the new y
*/
protected abstract double[] prepareFunctionValue(double[] y, double[] a);
/**
* Compute the function value. The y data is the same as that passed to
* {@link #prepareFunctionValue(double[], double[])}.
*
* @param a the parameters
* @return the function value
*/
protected abstract double computeFunctionValue(double[] a);
@Override
protected FisherInformationMatrix computeFisherInformationMatrix(double[] y, double[] a) {
gradientIndices = function.gradientIndices();
y = prepareFunctionFisherInformationMatrix(y, a);
return computeFunctionFisherInformationMatrix(y, a);
}
/**
* Prepare y for computing the Fisher information matrix, e.g. ensure strictly positive values.
*
* @param y the data
* @param a the parameters
* @return the new y
*/
protected abstract double[] prepareFunctionFisherInformationMatrix(double[] y, double[] a);
/**
* Compute the Fisher information matrix.
*
* @param y the data
* @param a the parameters
* @return the Fisher Information matrix
*/
protected abstract FisherInformationMatrix computeFunctionFisherInformationMatrix(double[] y,
double[] a);
@Override
public boolean isBounded() {
// Bounds are tighter than constraints and this solver supports bounds.
return true;
}
@Override
public void setBounds(double[] lower, double[] upper) {
bounds.setBounds(lower, upper);
}
/**
* Warning: If the function is changed then the clamp values may require updating. However setting
* a new function does not set the clamp values to null to allow caching when the clamp values are
* unchanged, e.g. evaluation of a different function in the same parameter space.
*
* <p>Setting a new function removes the current bounds.
*
* @param function the new gradient function
*/
@Override
public void setGradientFunction(GradientFunction function) {
super.setGradientFunction(function);
bounds.setGradientFunction(function);
}
@Override
public boolean isWeighted() {
return true;
}
@Override
public void setWeights(double[] weights) {
this.weights = weights;
}
/**
* Gets the weights for the observations, e.g. the per-observation variance.
*
* @param n the number of observations
* @return the weights (or null if the weights are not set or do not match the size n)
*/
public double[] getWeights(int n) {
return (weights == null || weights.length != n) ? null : weights;
}
}