// org.hipparchus.analysis.FunctionUtils (Maven / Gradle / Ivy)
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* This is not the original file distributed by the Apache Software Foundation
* It has been modified by the Hipparchus project
*/
package org.hipparchus.analysis;
import org.hipparchus.analysis.differentiation.DSFactory;
import org.hipparchus.analysis.differentiation.DerivativeStructure;
import org.hipparchus.analysis.differentiation.MultivariateDifferentiableFunction;
import org.hipparchus.analysis.differentiation.UnivariateDifferentiableFunction;
import org.hipparchus.analysis.function.Identity;
import org.hipparchus.exception.LocalizedCoreFormats;
import org.hipparchus.exception.MathIllegalArgumentException;
import org.hipparchus.util.MathArrays;
import org.hipparchus.util.MathUtils;
/**
* Utilities for manipulating function objects.
*
*/
public class FunctionUtils {
/**
 * Private constructor: this class only exposes static methods
 * and is not meant to be instantiated.
 */
private FunctionUtils() {
    // nothing to initialize
}
/**
 * Builds the composition of several univariate functions.
 * <p>
 * The last function of the argument list is applied first and the first one
 * is applied last, so that {@code compose(f1, f2, f3)} evaluates
 * {@code f1(f2(f3(x)))}. When the list is empty, the returned function
 * gives back its argument unchanged.
 * </p>
 *
 * @param f functions to compose, outermost function first.
 * @return the composite function.
 */
public static UnivariateFunction compose(final UnivariateFunction ... f) {
    return new UnivariateFunction() {
        /** {@inheritDoc} */
        @Override
        public double value(double x) {
            double current = x;
            // walk the list backwards: the innermost function is the last one
            int index = f.length;
            while (index > 0) {
                index--;
                current = f[index].value(current);
            }
            return current;
        }
    };
}
/**
 * Builds the composition of several differentiable univariate functions.
 * <p>
 * The last function of the argument list is applied first and the first one
 * is applied last, so that {@code compose(f1, f2, f3)} evaluates
 * {@code f1(f2(f3(x)))}, both for plain values and for
 * {@link DerivativeStructure} arguments. When the list is empty, the
 * returned function gives back its argument unchanged.
 * </p>
 *
 * @param f functions to compose, outermost function first.
 * @return the composite function.
 */
public static UnivariateDifferentiableFunction compose(final UnivariateDifferentiableFunction ... f) {
    return new UnivariateDifferentiableFunction() {

        /** {@inheritDoc} */
        @Override
        public double value(final double t) {
            double current = t;
            // innermost function (last in the list) is applied first
            for (int k = f.length; k > 0; k--) {
                current = f[k - 1].value(current);
            }
            return current;
        }

        /** {@inheritDoc} */
        @Override
        public DerivativeStructure value(final DerivativeStructure t) {
            DerivativeStructure current = t;
            // innermost function (last in the list) is applied first
            for (int k = f.length; k > 0; k--) {
                current = f[k - 1].value(current);
            }
            return current;
        }

    };
}
/**
 * Adds functions.
 * <p>
 * The returned function evaluates every function of the list at the same
 * abscissa and sums the results, in the order of the list. When the list
 * is empty, the returned function is the constant 0 (empty sum).
 * </p>
 *
 * @param f List of functions.
 * @return a function that computes the sum of the functions.
 */
public static UnivariateFunction add(final UnivariateFunction ... f) {
    return new UnivariateFunction() {
        /** {@inheritDoc} */
        @Override
        public double value(double x) {
            if (f.length == 0) {
                // empty sum: return the additive identity instead of
                // failing with an ArrayIndexOutOfBoundsException on f[0]
                return 0.0;
            }
            // start from the first term (rather than from 0.0) so that
            // the sign of a -0.0 result is preserved
            double sum = f[0].value(x);
            for (int i = 1; i < f.length; i++) {
                sum += f[i].value(x);
            }
            return sum;
        }
    };
}
/**
 * Adds differentiable functions.
 * <p>
 * The returned function evaluates every function of the list at the same
 * abscissa and sums the results, in the order of the list. When the list
 * is empty, the returned function is the constant 0 (empty sum); for a
 * {@link DerivativeStructure} argument, the zero constant is built with
 * the factory of that argument.
 * </p>
 *
 * @param f List of functions.
 * @return a function that computes the sum of the functions.
 */
public static UnivariateDifferentiableFunction add(final UnivariateDifferentiableFunction ... f) {
    return new UnivariateDifferentiableFunction() {

        /** {@inheritDoc} */
        @Override
        public double value(final double t) {
            if (f.length == 0) {
                // empty sum: return the additive identity instead of
                // failing with an ArrayIndexOutOfBoundsException on f[0]
                return 0.0;
            }
            double sum = f[0].value(t);
            for (int i = 1; i < f.length; i++) {
                sum += f[i].value(t);
            }
            return sum;
        }

        /** {@inheritDoc}
         * @throws MathIllegalArgumentException if functions are not consistent with each other
         */
        @Override
        public DerivativeStructure value(final DerivativeStructure t)
            throws MathIllegalArgumentException {
            if (f.length == 0) {
                // empty sum: a zero constant built by the same factory as the argument
                return t.getFactory().constant(0.0);
            }
            DerivativeStructure sum = f[0].value(t);
            for (int i = 1; i < f.length; i++) {
                sum = sum.add(f[i].value(t));
            }
            return sum;
        }

    };
}
/**
 * Multiplies functions.
 * <p>
 * The returned function evaluates every function of the list at the same
 * abscissa and multiplies the results, in the order of the list. When the
 * list is empty, the returned function is the constant 1 (empty product).
 * </p>
 *
 * @param f List of functions.
 * @return a function that computes the product of the functions.
 */
public static UnivariateFunction multiply(final UnivariateFunction ... f) {
    return new UnivariateFunction() {
        /** {@inheritDoc} */
        @Override
        public double value(double x) {
            if (f.length == 0) {
                // empty product: return the multiplicative identity instead of
                // failing with an ArrayIndexOutOfBoundsException on f[0]
                return 1.0;
            }
            double product = f[0].value(x);
            for (int i = 1; i < f.length; i++) {
                product *= f[i].value(x);
            }
            return product;
        }
    };
}
/**
 * Multiplies differentiable functions.
 * <p>
 * The returned function evaluates every function of the list at the same
 * abscissa and multiplies the results, in the order of the list. When the
 * list is empty, the returned function is the constant 1 (empty product);
 * for a {@link DerivativeStructure} argument, the unit constant is built
 * with the factory of that argument.
 * </p>
 *
 * @param f List of functions.
 * @return a function that computes the product of the functions.
 */
public static UnivariateDifferentiableFunction multiply(final UnivariateDifferentiableFunction ... f) {
    return new UnivariateDifferentiableFunction() {

        /** {@inheritDoc} */
        @Override
        public double value(final double t) {
            if (f.length == 0) {
                // empty product: return the multiplicative identity instead of
                // failing with an ArrayIndexOutOfBoundsException on f[0]
                return 1.0;
            }
            double product = f[0].value(t);
            for (int i = 1; i < f.length; i++) {
                product *= f[i].value(t);
            }
            return product;
        }

        /** {@inheritDoc} */
        @Override
        public DerivativeStructure value(final DerivativeStructure t) {
            if (f.length == 0) {
                // empty product: a unit constant built by the same factory as the argument
                return t.getFactory().constant(1.0);
            }
            DerivativeStructure product = f[0].value(t);
            for (int i = 1; i < f.length; i++) {
                product = product.multiply(f[i].value(t));
            }
            return product;
        }

    };
}
/**
 * Combines two univariate functions through a bivariate one, giving
 * {@code h(x) = combiner(f(x), g(x))}.
 * <p>
 * At each evaluation, {@code f} is evaluated first, then {@code g}, and
 * both results are handed over to the combiner.
 * </p>
 *
 * @param combiner Combiner function.
 * @param f Function providing the first argument of the combiner.
 * @param g Function providing the second argument of the combiner.
 * @return the composite function.
 */
public static UnivariateFunction combine(final BivariateFunction combiner,
                                         final UnivariateFunction f,
                                         final UnivariateFunction g) {
    return new UnivariateFunction() {
        /** {@inheritDoc} */
        @Override
        public double value(double x) {
            final double first  = f.value(x);
            final double second = g.value(x);
            return combiner.value(first, second);
        }
    };
}
/**
 * Returns a MultivariateFunction h(x[]) defined by
 * <pre>
 * h(x[]) = combiner(...combiner(combiner(initialValue, f(x[0])), f(x[1]))..., f(x[x.length-1]))
 * </pre>
 * <p>
 * The coordinates are folded from left to right, starting from
 * {@code initialValue}. When the point is empty, nothing is folded and the
 * function returns {@code initialValue}.
 * </p>
 *
 * @param combiner Combiner function.
 * @param f Function applied to each coordinate before it is combined.
 * @param initialValue Initial value.
 * @return a collector function.
 */
public static MultivariateFunction collector(final BivariateFunction combiner,
                                             final UnivariateFunction f,
                                             final double initialValue) {
    return new MultivariateFunction() {
        /** {@inheritDoc} */
        @Override
        public double value(double[] point) {
            // plain left fold seeded with the initial value; seeding the loop
            // (instead of reading point[0] up front) makes empty points valid
            double result = initialValue;
            for (final double coordinate : point) {
                result = combiner.value(result, f.value(coordinate));
            }
            return result;
        }
    };
}
/**
 * Returns a MultivariateFunction h(x[]) defined by
 * <pre>
 * h(x[]) = combiner(...combiner(combiner(initialValue, x[0]), x[1])..., x[x.length-1])
 * </pre>
 * <p>
 * This is the three-argument collector used with an {@link Identity}
 * function, i.e. the raw coordinates are combined.
 * </p>
 *
 * @param combiner Combiner function.
 * @param initialValue Initial value.
 * @return a collector function.
 */
public static MultivariateFunction collector(final BivariateFunction combiner,
                                             final double initialValue) {
    // coordinates are combined as they are: no transformation is applied
    final Identity identity = new Identity();
    return collector(combiner, identity, initialValue);
}
/**
 * Turns a binary function into a unary one by freezing its first argument.
 *
 * @param f Binary function.
 * @param fixed Value the first argument of {@code f} is frozen to.
 * @return the unary function {@code h(x) = f(fixed, x)}.
 */
public static UnivariateFunction fix1stArgument(final BivariateFunction f,
                                                final double fixed) {
    return new UnivariateFunction() {
        /** {@inheritDoc} */
        @Override
        public double value(double second) {
            // only the second argument remains free
            return f.value(fixed, second);
        }
    };
}
/**
 * Turns a binary function into a unary one by freezing its second argument.
 *
 * @param f Binary function.
 * @param fixed Value the second argument of {@code f} is frozen to.
 * @return the unary function {@code h(x) = f(x, fixed)}.
 */
public static UnivariateFunction fix2ndArgument(final BivariateFunction f,
                                                final double fixed) {
    return new UnivariateFunction() {
        /** {@inheritDoc} */
        @Override
        public double value(double first) {
            // only the first argument remains free
            return f.value(first, fixed);
        }
    };
}
/**
 * Samples a univariate real function at regularly spaced abscissas.
 * <p>
 * The interval {@code [min, max]} is split into {@code n} sections of equal
 * width {@code (max - min) / n}; the function is evaluated at the lower end
 * of each section, i.e. from {@code min} up to {@code max - (max - min) / n}.
 * The upper bound {@code max} itself is therefore never sampled.
 * </p>
 *
 * @param f Function to be sampled
 * @param min Lower bound of the interval (included).
 * @param max Upper bound of the interval (excluded).
 * @param n Number of sample points.
 * @return the array of the {@code n} samples.
 * @throws MathIllegalArgumentException if the number of sample points
 * {@code n} is not strictly positive, or if the lower bound {@code min}
 * is greater than or equal to the upper bound {@code max}.
 */
public static double[] sample(UnivariateFunction f, double min, double max, int n)
    throws MathIllegalArgumentException {

    // the number of samples is validated first, then the bounds
    if (n <= 0) {
        throw new MathIllegalArgumentException(LocalizedCoreFormats.NOT_POSITIVE_NUMBER_OF_SAMPLES,
                                               Integer.valueOf(n));
    }
    if (min >= max) {
        throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_LARGE_BOUND_EXCLUDED,
                                               min, max);
    }

    final double[] samples = new double[n];
    final double step = (max - min) / n;
    int k = 0;
    while (k < n) {
        // each abscissa is computed from min directly, not accumulated
        samples[k] = f.value(min + k * step);
        k++;
    }
    return samples;
}
/** Convert regular functions to {@link UnivariateDifferentiableFunction}.
 * <p>
 * This method handles the case with one free parameter and several derivatives.
 * For the case with several free parameters and only first order derivatives,
 * see {@link #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)}.
 * There is no direct support for intermediate cases, with several free parameters
 * and order 2 or more derivatives, as it would be difficult to specify all the
 * cross derivatives.
 * </p>
 * <p>
 * Note that the derivatives are expected to be computed only with respect to the
 * raw parameter x of the base function, i.e. they are df/dx, d<sup>2</sup>f/dx<sup>2</sup>, ...
 * Even if the built function is later used in a composition like f(sin(t)), the provided
 * derivatives should not apply the composition with sine and its derivatives by
 * themselves. The composition is done here, by calling {@code compose} on the argument,
 * so the result properly contains f(sin(t)), df(sin(t))/dt, d<sup>2</sup>f(sin(t))/dt<sup>2</sup>
 * although the provided derivative functions know nothing about the sine function.
 * </p>
 *
 * @param f base function f(x)
 * @param derivatives derivatives of the base function, in increasing differentiation order
 * @return a differentiable function with value and all specified derivatives; evaluating it
 * on a {@link DerivativeStructure} whose order exceeds the number of provided derivatives
 * throws a {@link MathIllegalArgumentException}
 * @see #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)
 * @see #derivative(UnivariateDifferentiableFunction, int)
 */
public static UnivariateDifferentiableFunction toDifferentiable(final UnivariateFunction f,
                                                                final UnivariateFunction ... derivatives) {
    return new UnivariateDifferentiableFunction() {

        /** {@inheritDoc} */
        @Override
        public double value(final double x) {
            return f.value(x);
        }

        /** {@inheritDoc} */
        @Override
        public DerivativeStructure value(final DerivativeStructure x) {

            // one derivative function is needed per requested order
            final int order = x.getOrder();
            if (order > derivatives.length) {
                throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_LARGE,
                                                       order, derivatives.length);
            }

            // value and derivatives of the base function at the raw abscissa
            final double abscissa = x.getValue();
            final double[] packed = new double[order + 1];
            packed[0] = f.value(abscissa);
            for (int k = 1; k <= order; ++k) {
                packed[k] = derivatives[k - 1].value(abscissa);
            }

            // chain rule with respect to the argument's own derivatives
            return x.compose(packed);

        }

    };
}
/** Convert regular functions to {@link MultivariateDifferentiableFunction}.
 * <p>
 * This method handles the case with several free parameters and only first order derivatives.
 * For the case with one free parameter and several derivatives,
 * see {@link #toDifferentiable(UnivariateFunction, UnivariateFunction...)}.
 * There is no direct support for intermediate cases, with several free parameters
 * and order 2 or more derivatives, as it would be difficult to specify all the
 * cross derivatives.
 * </p>
 * <p>
 * Note that the gradient is expected to be computed only with respect to the
 * raw parameters x of the base function, i.e. it is df/dx<sub>1</sub>, df/dx<sub>2</sub>, ...
 * Even if the built function is later used in a composition like f(sin(t), cos(t)), the provided
 * gradient should not apply the composition with sine or cosine and their derivatives by
 * itself. The composition is done here, and the result properly contains
 * f(sin(t), cos(t)), df(sin(t), cos(t))/dt although the provided gradient function
 * knows nothing about the sine or cosine functions.
 * </p>
 *
 * @param f base function f(x)
 * @param gradient gradient of the base function
 * @return a differentiable function with value and gradient; evaluating it on
 * {@link DerivativeStructure} components of order greater than 1 throws a
 * {@link MathIllegalArgumentException}
 * @see #toDifferentiable(UnivariateFunction, UnivariateFunction...)
 * @see #derivative(MultivariateDifferentiableFunction, int[])
 */
public static MultivariateDifferentiableFunction toDifferentiable(final MultivariateFunction f,
                                                                  final MultivariateVectorFunction gradient) {
    return new MultivariateDifferentiableFunction() {

        /** {@inheritDoc} */
        @Override
        public double value(final double[] point) {
            return f.value(point);
        }

        /** {@inheritDoc} */
        @Override
        public DerivativeStructure value(final DerivativeStructure[] point) {

            final int dimension = point.length;

            // extract the plain coordinates, rejecting anything beyond first order
            final double[] coordinates = new double[dimension];
            for (int k = 0; k < dimension; ++k) {
                final DerivativeStructure component = point[k];
                coordinates[k] = component.getValue();
                final int componentOrder = component.getOrder();
                if (componentOrder > 1) {
                    throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_LARGE,
                                                           componentOrder, 1);
                }
            }

            // evaluate the regular functions at the plain coordinates
            final double baseValue = f.value(coordinates);
            final double[] baseGradient = gradient.value(coordinates);
            MathUtils.checkDimension(baseGradient.length, dimension);

            // the first component is taken as the reference for the number
            // of free parameters and for the factory building the result
            final DerivativeStructure reference = point[0];
            final int parameters = reference.getFreeParameters();

            final double[] partials = new double[dimension];
            final double[] packed = new double[parameters + 1];
            packed[0] = baseValue;

            // selector toggled to "first order with respect to parameter p" on each pass
            final int[] selector = new int[parameters];
            for (int p = 0; p < parameters; ++p) {

                selector[p] = 1;
                for (int k = 0; k < dimension; ++k) {
                    partials[k] = point[k].getPartialDerivative(selector);
                }
                selector[p] = 0;

                // chain rule: df/dp = sum over k of (df/dx_k) * (dx_k/dp)
                packed[p + 1] = MathArrays.linearCombination(baseGradient, partials);

            }

            return reference.getFactory().build(packed);

        }

    };
}
/** Convert an {@link UnivariateDifferentiableFunction} to an
 * {@link UnivariateFunction} computing n<sup>th</sup> order derivative.
 * <p>
 * This converter is only a convenience method. Beware computing only one derivative does
 * not save any computation as the original function will really be called under the hood.
 * The derivative will be extracted from the full {@link DerivativeStructure} result.
 * </p>
 *
 * @param f original function, with value and all its derivatives
 * @param order of the derivative to extract
 * @return function computing the derivative at required order
 * @see #derivative(MultivariateDifferentiableFunction, int[])
 * @see #toDifferentiable(UnivariateFunction, UnivariateFunction...)
 */
public static UnivariateFunction derivative(final UnivariateDifferentiableFunction f, final int order) {

    // one free parameter, differentiated up to the requested order;
    // the factory is built once and shared by all evaluations
    final DSFactory factory = new DSFactory(1, order);

    return new UnivariateFunction() {
        /** {@inheritDoc} */
        @Override
        public double value(final double x) {
            final DerivativeStructure variable = factory.variable(0, x);
            final DerivativeStructure full = f.value(variable);
            return full.getPartialDerivative(order);
        }
    };

}
/** Convert an {@link MultivariateDifferentiableFunction} to an
 * {@link MultivariateFunction} computing n<sup>th</sup> order derivative.
 * <p>
 * This converter is only a convenience method. Beware computing only one derivative does
 * not save any computation as the original function will really be called under the hood.
 * The derivative will be extracted from the full {@link DerivativeStructure} result.
 * </p>
 * <p>
 * The {@code orders} array is copied when this method is called, so later changes made
 * by the caller to that array have no effect on the returned function.
 * </p>
 *
 * @param f original function, with value and all its derivatives
 * @param orders of the derivative to extract, for each free parameters
 * @return function computing the derivative at required order
 * @see #derivative(UnivariateDifferentiableFunction, int)
 * @see #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)
 */
public static MultivariateFunction derivative(final MultivariateDifferentiableFunction f, final int[] orders) {

    // private snapshot: the total order below is computed only once, so the
    // orders used at evaluation time must not change behind our back
    final int[] requestedOrders = orders.clone();

    // the maximum differentiation order is the sum of all orders
    int sum = 0;
    for (final int order : requestedOrders) {
        sum += order;
    }
    final int sumOrders = sum;

    return new MultivariateFunction() {

        /** Factory used for building derivatives. */
        private DSFactory factory;

        /** {@inheritDoc} */
        @Override
        public double value(final double[] point) {

            // work on a local reference so that a single evaluation uses one
            // and the same factory from the dimension check to the last variable
            DSFactory current = factory;
            if (current == null || point.length != current.getCompiler().getFreeParameters()) {
                // rebuild the factory in case of mismatch
                current = new DSFactory(point.length, sumOrders);
                factory = current;
            }

            // set up the input parameters
            final DerivativeStructure[] dsPoint = new DerivativeStructure[point.length];
            for (int i = 0; i < point.length; ++i) {
                dsPoint[i] = current.variable(i, point[i]);
            }

            return f.value(dsPoint).getPartialDerivative(requestedOrders);

        }

    };
}
}