com.opengamma.strata.math.impl.differentiation.ScalarFirstOrderDifferentiator Maven / Gradle / Ivy
Show all versions of strata-math Show documentation
/*
* Copyright (C) 2009 - present by OpenGamma Inc. and the OpenGamma group of companies
*
* Please see distribution for license.
*/
package com.opengamma.strata.math.impl.differentiation;
import java.util.function.Function;
import com.opengamma.strata.collect.ArgChecker;
import com.opengamma.strata.math.MathException;
/**
* Differentiates a scalar function with respect to its argument using finite difference.
*
* For a function $y = f(x)$ where $x$ and $y$ are scalars, this class produces
* a gradient function $g(x)$, i.e. a function that returns the gradient for
* each point $x$, where $g$ is the scalar $\frac{dy}{dx}$.
*/
public class ScalarFirstOrderDifferentiator
implements Differentiator {
private static final double DEFAULT_EPS = 1e-5;
private static final double MIN_EPS = Math.sqrt(Double.MIN_NORMAL);
private final double eps;
private final double twoEps;
private final FiniteDifferenceType differenceType;
/**
* Creates an instance using the default value of eps (10-5) and central differencing type.
*/
public ScalarFirstOrderDifferentiator() {
this(FiniteDifferenceType.CENTRAL, DEFAULT_EPS);
}
/**
* Creates an instance using the default value of eps (10-5).
*
* @param differenceType the differencing type to be used in calculating the gradient function
*/
public ScalarFirstOrderDifferentiator(FiniteDifferenceType differenceType) {
this(differenceType, DEFAULT_EPS);
}
/**
* Creates an instance.
*
* If the size of the domain is very small or very large, consider re-scaling first.
* If this value is too small, the result will most likely be dominated by noise.
* Use around 10-5 times the domain size.
*
* @param differenceType the differencing type to be used in calculating the gradient function
* @param eps the step size used to approximate the derivative
*/
public ScalarFirstOrderDifferentiator(FiniteDifferenceType differenceType, double eps) {
ArgChecker.notNull(differenceType, "differenceType");
ArgChecker.isTrue(eps >= MIN_EPS,
"eps of {} is too small. Please choose a value > {}, such as 1e-5*size of domain", eps, MIN_EPS);
this.differenceType = differenceType;
this.eps = eps;
this.twoEps = 2 * eps;
}
//-------------------------------------------------------------------------
@Override
public Function differentiate(Function function) {
ArgChecker.notNull(function, "function");
switch (differenceType) {
case FORWARD:
return new Function() {
@SuppressWarnings("synthetic-access")
@Override
public Double apply(Double x) {
ArgChecker.notNull(x, "x");
return (function.apply(x + eps) - function.apply(x)) / eps;
}
};
case CENTRAL:
return new Function() {
@SuppressWarnings("synthetic-access")
@Override
public Double apply(Double x) {
ArgChecker.notNull(x, "x");
return (function.apply(x + eps) - function.apply(x - eps)) / twoEps;
}
};
case BACKWARD:
return new Function() {
@SuppressWarnings("synthetic-access")
@Override
public Double apply(Double x) {
ArgChecker.notNull(x, "x");
return (function.apply(x) - function.apply(x - eps)) / eps;
}
};
default:
throw new IllegalArgumentException("Can only handle forward, backward and central differencing");
}
}
//-------------------------------------------------------------------------
@Override
public Function differentiate(
Function function,
Function domain) {
ArgChecker.notNull(function, "function");
ArgChecker.notNull(domain, "domain");
double[] wFwd = new double[] {-3. / twoEps, 4. / twoEps, -1. / twoEps};
double[] wCent = new double[] {-1. / twoEps, 0., 1. / twoEps};
double[] wBack = new double[] {1. / twoEps, -4. / twoEps, 3. / twoEps};
return new Function() {
@SuppressWarnings("synthetic-access")
@Override
public Double apply(Double x) {
ArgChecker.notNull(x, "x");
ArgChecker.isTrue(domain.apply(x), "point {} is not in the function domain", x.toString());
double[] y = new double[3];
double[] w;
if (!domain.apply(x + eps)) {
if (!domain.apply(x - eps)) {
throw new MathException("cannot get derivative at point " + x.toString());
}
y[0] = function.apply(x - twoEps);
y[1] = function.apply(x - eps);
y[2] = function.apply(x);
w = wBack;
} else {
if (!domain.apply(x - eps)) {
y[0] = function.apply(x);
y[1] = function.apply(x + eps);
y[2] = function.apply(x + twoEps);
w = wFwd;
} else {
y[0] = function.apply(x - eps);
y[2] = function.apply(x + eps);
w = wCent;
}
}
double res = y[0] * w[0] + y[2] * w[2];
if (w[1] != 0) {
res += y[1] * w[1];
}
return res;
}
};
}
}