// net.finmath.optimizer.StochasticPathwiseLevenbergMarquardt (Maven / Gradle / Ivy)
// Show all versions of finmath-lib. Show documentation.
/*
* (c) Copyright Christian P. Fries, Germany. Contact: [email protected].
*
* Created on 16.06.2006
*/
package net.finmath.optimizer;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Vector;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.FutureTask;
import java.util.logging.Level;
import java.util.logging.Logger;
import net.finmath.functions.LinearAlgebra;
import net.finmath.montecarlo.RandomVariableFromDoubleArray;
import net.finmath.stochastic.RandomVariable;
import net.finmath.stochastic.Scalar;
/**
* This class implements a stochastic Levenberg Marquardt non-linear least-squares fit
* algorithm.
*
* The design avoids the need to define the objective function as a
* separate class. The objective function is defined by overriding a class
* method, see the sample code below.
*
*
*
* The Levenberg-Marquardt solver is implemented using multi-threading.
* The calculation of the derivatives (in case a specific implementation of
* {@code setDerivatives(RandomVariable[] parameters, RandomVariable[][] derivatives)} is not
* provided) may be performed in parallel by setting the parameter {@code numberOfThreads}.
*
*
*
* To use the solver inherit from it and implement the objective function as
* {@code setValues(RandomVariable[] parameters, RandomVariable[] values)} where values has
* to be set to the value of the objective functions for the given parameters.
*
* You may also provide a derivative for your objective function by
* additionally overriding the function {@code setDerivatives(RandomVariable[] parameters, RandomVariable[][] derivatives)},
* otherwise the solver will calculate the derivative via finite differences.
*
*
* To reject a point, it is allowed to set an element of {@code values} to {@link java.lang.Double#NaN}
* in the implementation of {@code setValues(RandomVariable[] parameters, RandomVariable[] values)}.
* Put differently: The solver handles NaN values in {@code values} as an error larger than
* the current one (regardless of the current error) and rejects the point.
*
* Note, however, that it is an error if the initial parameter guess results in a NaN value.
* That is, the solver should be initialized with an initial parameter in an admissible region.
*
*
* The following simple example finds a solution for the equation
*
*
* Sample linear system of equations.
*
* 0.0 * x1 + 1.0 * x2 = 5.0
*
*
* 2.0 * x1 + 1.0 * x2 = 10.0
*
*
*
*
*
*
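* RandomVariable[] initialParameters = { new RandomVariableFromDoubleArray(0), new RandomVariableFromDoubleArray(0) };
* RandomVariable[] weights = { new RandomVariableFromDoubleArray(1), new RandomVariableFromDoubleArray(1) };
* RandomVariable[] targetValues = { new RandomVariableFromDoubleArray(5), new RandomVariableFromDoubleArray(10) };
* int maxIteration = 100;
*
* // Passing null as parameterSteps lets the solver choose finite-difference steps adaptively.
* StochasticPathwiseLevenbergMarquardt optimizer = new StochasticPathwiseLevenbergMarquardt(
* initialParameters, targetValues, weights, null, maxIteration, null, null) {
* // Override your objective function here
* public void setValues(RandomVariable[] parameters, RandomVariable[] values) {
* values[0] = parameters[0].mult(0.0).add(parameters[1]);
* values[1] = parameters[0].mult(2.0).add(parameters[1]);
* }
* };
*
* optimizer.run();
*
* RandomVariable[] bestParameters = optimizer.getBestFitParameters();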
*
*
*
* See the example in the main method below.
*
*
* The class can be initialized to use a multi-threaded valuation. If initialized
* this way, the implementation of {@code setValues} must be thread-safe.
* The solver will evaluate the gradient of the value vector in parallel, i.e.,
* use as many threads as the number of parameters.
*
*
* Note: Iteration steps will be logged (java.util.logging) with Level.FINE
*
* @author Christian Fries
* @version 1.6
*/
public abstract class StochasticPathwiseLevenbergMarquardt implements Serializable, Cloneable, StochasticOptimizer {
private static final long serialVersionUID = 4560864869394838155L;
// Initial parameter vector where the search starts; run() works on clones of it.
private RandomVariable[] initialParameters = null;
// Finite-difference step per parameter used by the default setDerivatives; if null, an adaptive step (|p| + 1) * 1E-8 is used.
private RandomVariable[] parameterSteps = null;
// Target values the objective function values are fitted to.
private RandomVariable[] targetValues = null;
// Weights applied to the error; the constructor allocates them if null is passed.
private RandomVariable[] weights = null;
// Maximum number of iterations (checked by done()).
private int maxIteration;
// Local state of the solver
// Damping parameter, one entry per path (indexed by pathIndex in run()).
private double[] lambda;
// NOTE(review): presumably the starting value of each lambda entry - confirm in run().
private double lambdaInitialValue = 0.001;
// Factors by which a path's lambda is divided / multiplied in the lambda update of run().
private double lambdaDivisor = 1.3;
private double lambdaMultiplicator = 2.0;
// Number of paths; used to dimension the per-path parameter increments in run() (assignment not visible here).
private int numberOfPaths;
// Tolerance for the change of the root mean squared error (stop criterion of done()); the constructor defaults it to 0.0.
private RandomVariable errorTolerance;
// Number of iterations performed by run().
private int iteration = 0;
// Candidate point of the current iteration and the objective function values evaluated at it.
private RandomVariable[] parameterTest = null;
private RandomVariable[] valueTest = null;
// Current point (returned by getBestFitParameters()), its values and its derivatives indexed [parameter][value].
// The default setDerivatives takes its finite differences around this point.
private RandomVariable[] parameterCurrent = null;
private RandomVariable[] valueCurrent = null;
private RandomVariable[][] derivativeCurrent = null;
// Mean squared error at the current point; starts at +infinity so that any finite error of a test point compares as smaller.
private RandomVariable errorMeanSquaredCurrent = new RandomVariableFromDoubleArray(Double.POSITIVE_INFINITY);
// Change of the root mean squared error (presumably updated per iteration in run()); starts at +infinity so done() does not stop immediately.
private RandomVariable errorRootMeanSquaredChange = new RandomVariableFromDoubleArray(Double.POSITIVE_INFINITY);
// NOTE(review): not referenced in the visible code.
private boolean[] isParameterCurrentDerivativeValid;
/*
* Used for multi-threaded calculation of the derivative.
* The user may provide their own executor. If not, and numberOfThreads > 1,
* we will temporarily create an executor with the specified number of threads.
* Note: If an executor was provided upon construction, it will not receive a shutdown when done.
*/
private ExecutorService executor = null;
// Whether the executor should be shut down when the solver is done (default: true; see the note above on provided executors).
private boolean executorShutdownWhenDone = true;
// Logger used for the iteration steps (logged with Level.FINE according to the class documentation).
private final Logger logger = Logger.getLogger("net.finmath");
// A simple test
public static void main(String[] args) throws SolverException {
	/*
	 * Demonstration: find x1, x2 such that (0 * x1 + x2)^2 = 25 and (2 * x1 + x2)^2 = 100,
	 * starting from x1 = x2 = 2 with unit weights and unit finite-difference steps.
	 * Instead of RandomVariableFromDoubleArray, a RandomVariableDifferentiableAAD could be used
	 * for the parameters, e.g. new RandomVariableDifferentiableAAD(2).
	 */
	final RandomVariable[] startParameters = {
			new RandomVariableFromDoubleArray(2),
			new RandomVariableFromDoubleArray(2)
	};
	final RandomVariable[] unitWeights = {
			new RandomVariableFromDoubleArray(1),
			new RandomVariableFromDoubleArray(1)
	};
	final RandomVariable[] finiteDifferenceSteps = {
			new RandomVariableFromDoubleArray(1),
			new RandomVariableFromDoubleArray(1)
	};
	final int iterationLimit = 100;
	final RandomVariable[] targets = {
			new RandomVariableFromDoubleArray(25),
			new RandomVariableFromDoubleArray(100)
	};

	final StochasticPathwiseLevenbergMarquardt optimizer = new StochasticPathwiseLevenbergMarquardt(
			startParameters, targets, unitWeights, finiteDifferenceSteps, iterationLimit, null, null) {
		private static final long serialVersionUID = -282626938650139518L;

		// The objective function: values[0] = (x2)^2, values[1] = (2 * x1 + x2)^2
		@Override
		public void setValues(RandomVariable[] parameters, RandomVariable[] values) {
			values[0] = parameters[0].mult(0.0).add(parameters[1]).squared();
			values[1] = parameters[0].mult(2.0).add(parameters[1]).squared();
		}
	};

	optimizer.run();

	final RandomVariable[] bestParameters = optimizer.getBestFitParameters();
	System.out.println("The solver for problem 1 required " + optimizer.getIterations() + " iterations. The best fit parameters are:");
	int parameterIndex = 0;
	for (final RandomVariable bestParameter : bestParameters) {
		System.out.println("\tparameter[" + parameterIndex + "]: " + bestParameter);
		parameterIndex++;
	}
	System.out.println("The solver accuracy is " + optimizer.getRootMeanSquaredError());

	// The search may be continued with new target values on a clone of this optimizer
	// (see getCloneWithModifiedTargetValues); setValues does not need to be re-defined for that.
}
/**
* Create a Levenberg-Marquardt solver.
*
* @param initialParameters Initial value for the parameters where the solver starts its search.
* @param targetValues Target values to achieve.
* @param weights Weights applied to the error.
* @param parameterSteps Step used for finite difference approximation.
* @param maxIteration Maximum number of iterations.
* @param errorTolerance Error tolerance / accuracy.
* @param executorService Executor to be used for concurrent valuation of the derivatives. This is only performed if setDerivatives is not overridden. Warning: The implementation of setValues has to be thread safe!
*/
public StochasticPathwiseLevenbergMarquardt(RandomVariable[] initialParameters, RandomVariable[] targetValues, RandomVariable[] weights, RandomVariable[] parameterSteps, int maxIteration, RandomVariable errorTolerance, ExecutorService executorService) {
super();
this.initialParameters = initialParameters;
this.targetValues = targetValues;
this.weights = weights;
this.parameterSteps = parameterSteps;
this.maxIteration = maxIteration;
this.errorTolerance = errorTolerance != null ? errorTolerance : new RandomVariableFromDoubleArray(0.0);
if(weights == null) {
this.weights = new RandomVariable[targetValues.length];
for(int i=0; iWarning: If this number is larger than one, the implementation of setValues has to be thread safe!
*/
public StochasticPathwiseLevenbergMarquardt(RandomVariable[] initialParameters, RandomVariable[] targetValues, int maxIteration, int numberOfThreads) {
	// Weights, parameter steps and error tolerance are left to the defaults of the delegated constructor.
	// NOTE(review): the pool created here is handed over like an externally provided executor;
	// confirm it is shut down after run(), otherwise its threads may outlive the solver.
	this(initialParameters, targetValues, null, null, maxIteration, null, createExecutorForThreads(numberOfThreads));
}

/**
 * Creates the executor used for the concurrent valuation of the derivatives.
 *
 * @param numberOfThreads Requested number of threads.
 * @return A fixed thread pool with {@code numberOfThreads} threads if more than one thread is requested,
 *         otherwise {@code null} (derivatives are then valued sequentially).
 */
private static ExecutorService createExecutorForThreads(int numberOfThreads) {
	if(numberOfThreads > 1) {
		return Executors.newFixedThreadPool(numberOfThreads);
	}
	return null;
}
/**
 * Creates a Levenberg-Marquardt solver from lists of initial parameters and target values.
 * Weights, finite-difference steps and error tolerance are left to their defaults
 * (weights allocated by the delegated constructor, adaptive finite-difference steps, tolerance 0).
 *
 * @param initialParameters List of initial values for the parameters where the solver starts its search.
 * @param targetValues List of target values to achieve.
 * @param maxIteration Maximum number of iterations.
 * @param executorService Executor used for the concurrent valuation of the finite-difference derivatives
 *        (only relevant if setDerivatives is not overridden); may be {@code null} for sequential valuation.
 *        Warning: If an executor is given, the implementation of setValues has to be thread safe!
 */
public StochasticPathwiseLevenbergMarquardt(List initialParameters, List targetValues, int maxIteration, ExecutorService executorService) {
	this(
			numberListToDoubleArray(initialParameters),
			numberListToDoubleArray(targetValues),
			null,			// weights: default
			null,			// parameterSteps: adaptive
			maxIteration,
			null,			// errorTolerance: default
			executorService);
}
/**
 * Creates a Levenberg-Marquardt solver from lists of initial parameters and target values,
 * optionally valuing the finite-difference derivatives with several threads.
 *
 * @param initialParameters Initial value for the parameters where the solver starts its search.
 * @param targetValues Target values to achieve.
 * @param maxIteration Maximum number of iterations.
 * @param numberOfThreads Maximum number of threads. Warning: If this number is larger than one, the implementation of setValues has to be thread safe!
 */
public StochasticPathwiseLevenbergMarquardt(List initialParameters, List targetValues, int maxIteration, int numberOfThreads) {
	// Convert the lists to arrays and delegate to the array based constructor,
	// which creates a thread pool if more than one thread is requested.
	this(
			numberListToDoubleArray(initialParameters),
			numberListToDoubleArray(targetValues),
			maxIteration,
			numberOfThreads);
}
/**
* Convert a list of numbers to an array of doubles.
*
* @param listOfNumbers A list of numbers.
* @return A corresponding array of doubles executing doubleValue()
on each element.
*/
private static RandomVariable[] numberListToDoubleArray(List listOfNumbers) {
RandomVariable[] array = new RandomVariable[listOfNumbers.size()];
for(int i=0; i 1.");
}
this.lambdaMultiplicator = lambdaMultiplicator;
}
/**
 * Returns the factor by which lambda is divided (for the next iteration) if the inversion
 * of the regularized Hessian succeeds, that is, if \( H + \lambda \operatorname{diag}(H) \) is invertible.
 *
 * @return The current lambda divisor.
 */
public double getLambdaDivisor() {
	return this.lambdaDivisor;
}
/**
 * Set the divisor applied to lambda (for the next iteration) if the inversion of the regularized
 * Hessian succeeds, that is, if \( H + \lambda \operatorname{diag}(H) \) is invertible.
 *
 * This will make lambda smaller, hence let the stepping move faster.
 *
 * @param lambdaDivisor the lambdaDivisor to set. Has to be &gt; 1.
 * @throws IllegalArgumentException if lambdaDivisor is not greater than 1 (including NaN).
 */
public void setLambdaDivisor(double lambdaDivisor) {
	// Written as !(x > 1) instead of (x <= 1): every comparison with NaN is false,
	// so the latter form would silently accept NaN and poison all subsequent lambda updates.
	if(!(lambdaDivisor > 1.0)) {
		throw new IllegalArgumentException("Parameter lambdaDivisor is required to be > 1.");
	}
	this.lambdaDivisor = lambdaDivisor;
}
/**
 * Returns the best fit parameters found so far, i.e., the current parameter vector of the solver.
 *
 * A copy of the internal array is returned, so that modifying the returned array does not
 * alter the state of the solver.
 *
 * @return A copy of the current best fit parameters, or {@code null} if {@link #run()} has not been called yet.
 */
@Override
public RandomVariable[] getBestFitParameters() {
	// parameterCurrent is only allocated by run(); before that there is nothing to copy.
	return parameterCurrent == null ? null : parameterCurrent.clone();
}
/**
 * Returns the root mean squared error of the current point: the mean squared error
 * (a random variable) is averaged via {@code average()}, and the square root of that average is returned.
 *
 * @return The root of the averaged mean squared error (initially {@code Double.POSITIVE_INFINITY}).
 */
@Override
public double getRootMeanSquaredError() {
	final RandomVariable averagedMeanSquaredError = errorMeanSquaredCurrent.average();
	return averagedMeanSquaredError.sqrt().doubleValue();
}
/**
 * Sets the mean squared error associated with the current point. This value is used by
 * {@link #getRootMeanSquaredError()} and serves as the reference error against which
 * test points are compared in {@link #run()}.
 *
 * @param errorMeanSquaredCurrent The mean squared error of the current point.
 */
public void setErrorMeanSquaredCurrent(RandomVariable errorMeanSquaredCurrent) {
	this.errorMeanSquaredCurrent = errorMeanSquaredCurrent;
}
/**
 * Returns the number of iterations performed by the last call of {@link #run()}.
 *
 * @return The iteration counter.
 */
@Override
public int getIterations() {
	final int iterationsPerformed = this.iteration;
	return iterationsPerformed;
}
/**
 * Evaluates the objective function on behalf of the solver. The default implementation
 * simply delegates to {@link #setValues(RandomVariable[], RandomVariable[])}; subclasses may
 * override this hook, e.g. to prepare state before the valuation.
 *
 * @param parameters The parameter vector.
 * @param values Output: the objective function values.
 * @throws SolverException Thrown if the valuation fails.
 */
protected void prepareAndSetValues(RandomVariable[] parameters, RandomVariable[] values) throws SolverException {
	this.setValues(parameters, values);
}
/**
 * Evaluates the derivatives of the objective function on behalf of the solver. The default
 * implementation delegates to {@link #setDerivatives(RandomVariable[], RandomVariable[][])}
 * and does not use {@code values}; subclasses may override this hook to make use of it.
 *
 * @param parameters The parameter vector.
 * @param values The objective function values at {@code parameters} (unused by default).
 * @param derivatives Output: derivatives[i][j] is d(value(j)) / d(parameters(i)).
 * @throws SolverException Thrown if the valuation fails.
 */
protected void prepareAndSetDerivatives(RandomVariable[] parameters, RandomVariable[] values, RandomVariable[][] derivatives) throws SolverException {
	this.setDerivatives(parameters, derivatives);
}
/**
 * The objective function. Override this method to implement your custom
 * function.
 *
 * To reject a point, an element of {@code values} may be set to NaN; the solver treats this
 * as an error larger than the current one. When called during the finite-difference valuation
 * of the default setDerivatives, an exception thrown here turns the affected derivatives into NaN,
 * which are then replaced by 0. If an executor is used for the derivatives, this method may be
 * called concurrently and has to be thread safe.
 *
 * @param parameters Input value. The parameter vector.
 * @param values Output value. The vector of values f(i,parameters), i=1,...,n
 * @throws SolverException Thrown if the valuation fails, specific cause may be available via the {@code getCause()} method.
 */
public abstract void setValues(RandomVariable[] parameters, RandomVariable[] values) throws SolverException;
/**
 * The derivative of the objective function. You may override this method
 * if you like to implement your own derivative.
 *
 * The default implementation uses one-sided (forward) finite differences
 * (valueShifted - valueCurrent) / step around the solver's current point {@code parameterCurrent}.
 * The step is {@code parameterSteps[i]} if parameter steps were given, otherwise (|p_i| + 1) * 1E-8.
 * If the valuation of a shifted point throws, or a difference quotient is NaN, the corresponding
 * derivative is set to 0. If an executor is configured, the shifts of the individual parameters
 * are valued concurrently, so setValues has to be thread safe.
 *
 * @param parameters Input value. The parameter vector. Note: the default implementation does not read
 *        this argument; it differentiates at {@code parameterCurrent}, because {@code valueCurrent}
 *        is used as the base value of the finite differences.
 * @param derivatives Output value, where derivatives[i][j] is d(value(j)) / d(parameters(i)).
 * @throws SolverException Thrown if the valuation fails, specific cause may be available via the {@code getCause()} method.
 */
public void setDerivatives(RandomVariable[] parameters, RandomVariable[][] derivatives) throws SolverException {
	// Base point of the finite differences; it has to match valueCurrent, which is subtracted below.
	final RandomVariable[] baseParameters = parameterCurrent;
	final int numberOfParameters = baseParameters.length;

	final List<Future<RandomVariable[]>> valueFutures = new ArrayList<>(numberOfParameters);
	for (int parameterIndex = 0; parameterIndex < numberOfParameters; parameterIndex++) {
		final RandomVariable[] parametersNew = baseParameters.clone();
		final RandomVariable[] derivative = derivatives[parameterIndex];
		final int workerParameterIndex = parameterIndex;

		final Callable<RandomVariable[]> worker = new Callable<RandomVariable[]>() {
			@Override
			public RandomVariable[] call() {
				RandomVariable parameterFiniteDifference;
				if(parameterSteps != null) {
					parameterFiniteDifference = parameterSteps[workerParameterIndex];
				}
				else {
					/*
					 * Try to adaptively set a parameter shift. Note that in some
					 * applications it may be important to set parameterSteps
					 * appropriately.
					 */
					parameterFiniteDifference = parametersNew[workerParameterIndex].abs().add(1.0).mult(1E-8);
				}

				// Shift parameter value
				parametersNew[workerParameterIndex] = parametersNew[workerParameterIndex].add(parameterFiniteDifference);

				// Calculate derivative as (valueUpShift - valueCurrent) / parameterFiniteDifference
				try {
					prepareAndSetValues(parametersNew, derivative);
				} catch (Exception e) {
					// Deliberately not propagated: a failed valuation yields NaN, which is mapped to 0 below.
					Arrays.fill(derivative, new RandomVariableFromDoubleArray(Double.NaN));
				}
				for (int valueIndex = 0; valueIndex < valueCurrent.length; valueIndex++) {
					derivative[valueIndex] = derivative[valueIndex].sub(valueCurrent[valueIndex]).div(parameterFiniteDifference);
					// isNaN() is 1 for NaN, 0 otherwise; (isNaN - 0.5) * (-1) is >= 0 exactly for non-NaN,
					// so choose keeps the derivative there and uses 0 on NaN realizations.
					derivative[valueIndex] = derivative[valueIndex].isNaN().sub(0.5).mult(-1).choose(derivative[valueIndex], new Scalar(0.0));
				}
				return derivative;
			}
		};

		if(executor != null) {
			valueFutures.add(executor.submit(worker));
		}
		else {
			// No executor: run the worker synchronously, but keep the uniform Future based collection below.
			final FutureTask<RandomVariable[]> valueFutureTask = new FutureTask<>(worker);
			valueFutureTask.run();
			valueFutures.add(valueFutureTask);
		}
	}

	for (int parameterIndex = 0; parameterIndex < numberOfParameters; parameterIndex++) {
		try {
			derivatives[parameterIndex] = valueFutures.get(parameterIndex).get();
		}
		catch (InterruptedException e) {
			// Restore the interrupt status so that callers further up the stack can still observe it.
			Thread.currentThread().interrupt();
			throw new SolverException(e);
		}
		catch (ExecutionException e) {
			throw new SolverException(e);
		}
	}
}
/**
 * Stop condition of the solver. You may override this method to implement a custom stop condition.
 * Note that the method is package-private, so it can only be overridden from within
 * {@code net.finmath.optimizer}.
 *
 * @return {@code true} if the iteration counter exceeds the maximum number of iterations, or if the
 *         maximum of (change of the root mean squared error - error tolerance) is not positive,
 *         i.e., the error does not improve by more than the given tolerance; {@code false} otherwise.
 */
boolean done() {
	// Iteration budget exhausted.
	if(iteration > maxIteration) {
		return true;
	}

	// Error does not improve by more than the given error tolerance.
	final RandomVariable excessOverTolerance = errorRootMeanSquaredChange.sub(errorTolerance);
	return excessOverTolerance.getMax() <= 0;
}
@Override
public void run() throws SolverException {
try {
// Allocate memory
int numberOfParameters = initialParameters.length;
int numberOfValues = targetValues.length;
parameterTest = initialParameters.clone();
parameterCurrent = initialParameters.clone();
valueTest = new RandomVariable[numberOfValues];
valueCurrent = new RandomVariable[numberOfValues];
Arrays.fill(valueCurrent, new RandomVariableFromDoubleArray(Double.NaN));
derivativeCurrent = new RandomVariable[numberOfParameters][numberOfValues];
iteration = 0;
while(true) {
// Count iterations
iteration++;
// Calculate values for test parameters
prepareAndSetValues(parameterTest, valueTest);
// Calculate error
RandomVariable errorMeanSquaredTest = getMeanSquaredError(valueTest);
/*
* Note: The following test will be false if errorMeanSquaredTest is NaN.
* That is: NaN is consider as a rejected point.
*/
RandomVariable isPointAccepted = errorMeanSquaredCurrent.sub(errorMeanSquaredTest);
for(int parameterIndex = 0; parameterIndex= 0 ? lambda[pathIndex] / lambdaDivisor : lambda[pathIndex] * lambdaMultiplicator;
}
/*
* Calculate new derivative at parameterTest (where point is accepted).
* Note: the first argument should be parameterTest to use shortest operator tree.
*/
prepareAndSetDerivatives(parameterTest, valueTest, derivativeCurrent);
/*
* Calculate new parameterTest
*/
double[][] parameterIncrement = new double[parameterCurrent.length][numberOfPaths];
for(int pathIndex=0; pathIndex newTargetVaues, List newWeights, boolean isUseBestParametersAsInitialParameters) throws CloneNotSupportedException {
StochasticPathwiseLevenbergMarquardt clonedOptimizer = clone();
clonedOptimizer.targetValues = numberListToDoubleArray(newTargetVaues);
clonedOptimizer.weights = numberListToDoubleArray(newWeights);
if(isUseBestParametersAsInitialParameters && this.done()) {
clonedOptimizer.initialParameters = this.getBestFitParameters();
}
return clonedOptimizer;
}
}