All Downloads are FREE. Search and download functionalities are using the official Maven repository.

net.finmath.montecarlo.automaticdifferentiation.backward.RandomVariableDifferentiableAAD Maven / Gradle / Ivy

Go to download

finmath lib is a Mathematical Finance Library in Java. It provides algorithms and methodologies related to mathematical finance.

There is a newer version: 6.0.19
Show newest version
/*
 * (c) Copyright Christian P. Fries, Germany. Contact: [email protected].
 *
 * Created on 17.06.2017
 */
package net.finmath.montecarlo.automaticdifferentiation.backward;

import java.io.IOException;
import java.io.Serializable;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import java.util.function.IntToDoubleFunction;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;

import net.finmath.functions.DoubleTernaryOperator;
import net.finmath.montecarlo.RandomVariableFromDoubleArray;
import net.finmath.montecarlo.automaticdifferentiation.RandomVariableDifferentiable;
import net.finmath.montecarlo.automaticdifferentiation.backward.RandomVariableDifferentiableAADFactory.DiracDeltaApproximationMethod;
import net.finmath.montecarlo.conditionalexpectation.LinearRegression;
import net.finmath.stochastic.ConditionalExpectationEstimator;
import net.finmath.stochastic.RandomVariable;
import net.finmath.stochastic.Scalar;

/**
 * Implementation of RandomVariableDifferentiable using
 * the backward algorithmic differentiation (adjoint algorithmic differentiation, AAD).
 *
 * This class implements the optimized stochastic ADD as it is described in
 * ssrn.com/abstract=2995695.
 *
 * The class implements the special treatment of the conditional expectation operator as it is described in
 * ssrn.com/abstract=3000822.
 *
 * The class implements the special treatment of indicator functions as it is described in
 * ssrn.com/abstract=3282667.
 *
 * For details see http://christianfries.com/finmath/stochasticautodiff/.
 *
 * The class is serializable. Upon de-serialization the value of {@link #getID()} may be changed to ensure unique IDs in de-serialization context.
 *
 * @author Christian Fries
 * @author Stefan Sedlmair
 * @version 1.1
 */
public class RandomVariableDifferentiableAAD implements RandomVariableDifferentiable {

	private static final long serialVersionUID = 2459373647785530657L;

	private static final int typePriorityDefault = 3;

	private static final RandomVariable one = new Scalar(1.0);

	private final int typePriority;

	private static AtomicLong indexOfNextRandomVariable = new AtomicLong(0);

	private enum OperatorType {
		ADD, MULT, DIV, SUB, SQUARED, SQRT, LOG, SIN, COS, EXP, INVERT, CAP, FLOOR, ABS,
		ADDPRODUCT, ADDRATIO, SUBRATIO, CHOOSE, DISCOUNT, ACCRUE, POW, MIN, MAX, AVERAGE, VARIANCE,
		STDEV, STDERROR, SVARIANCE, AVERAGE2, VARIANCE2,
		STDEV2, STDERROR2, CONDITIONAL_EXPECTATION
	}

	/**
	 * A node in the operator tree. It
	 * stores an id (the index m), the operator (the function f_m), and the arguments.
	 * It also stores reference to the argument values, if required.
	 *
	 * @author Christian Fries
	 */
	private static class OperatorTreeNode implements Serializable {

		private static final long serialVersionUID = -8428352552169568990L;

		private final Long id;
		private final OperatorType operatorType;
		private final List arguments;
		private final List argumentValues;
		private final Object operator;
		private final RandomVariableDifferentiableAADFactory factory;

		private static final RandomVariable zero = new Scalar(0.0);
		private static final RandomVariable one = new Scalar(1.0);
		private static final RandomVariable minusOne = new Scalar(-1.0);

		OperatorTreeNode(final OperatorType operatorType, final List arguments, List argumentValues, final Object operator, final RandomVariableDifferentiableAADFactory factory) {
			super();
			id = indexOfNextRandomVariable.getAndIncrement();
			this.operatorType = operatorType;
			this.arguments = arguments;
			this.operator = operator;
			this.factory = factory;

			/*
			 * This is the simple modification which reduces memory requirements.
			 */
			if(operatorType != null && (operatorType.equals(OperatorType.ADD) || operatorType.equals(OperatorType.SUB))) {
				// Addition does not need to retain arguments
				argumentValues = null;
			}
			else if(operatorType != null && operatorType.equals(OperatorType.AVERAGE)) {
				// Average does not need to retain arguments
				argumentValues = null;
			}
			else if(operatorType != null && operatorType.equals(OperatorType.MULT)) {
				// Product only needs to retain factors on differentiables
				if(arguments.get(0) == null) {
					argumentValues.set(1, null);
				}
				if(arguments.get(1) == null) {
					argumentValues.set(0, null);
				}
			}
			else if(operatorType != null && operatorType.equals(OperatorType.DIV)) {
				// Division only needs to retain numerator if denominator is differentiable
				if(arguments.get(1) == null) {
					argumentValues.set(0, null);
				}
			}
			else if(operatorType != null && operatorType.equals(OperatorType.ADDPRODUCT)) {
				// Addition does not need to retain arguments
				argumentValues.set(0, null);
				// Addition of product only needs to retain factors on differentiables
				if(arguments.get(1) == null) {
					argumentValues.set(2, null);
				}
				if(arguments.get(2) == null) {
					argumentValues.set(1, null);
				}
			}
			else if(operatorType != null && operatorType.equals(OperatorType.ACCRUE)) {
				// Addition of product only needs to retain factors on differentiables
				if(arguments.get(1) == null && arguments.get(2) == null) {
					argumentValues.set(0, null);
				}
				if(arguments.get(0) == null && arguments.get(1) == null) {
					argumentValues.set(1, null);
				}
				if(arguments.get(0) == null && arguments.get(2) == null) {
					argumentValues.set(2, null);
				}
			}
			else if(operatorType != null && operatorType.equals(OperatorType.CHOOSE)) {
				if(arguments.get(0) == null) {
					argumentValues.set(1, null);
					argumentValues.set(2, null);
				}
			}

			this.argumentValues = argumentValues;
		}

		/*
		 * This implements the update rule D_i = D_i + Dm * d fm/dxi where i are the arguments of this node and m is this node.
		 */
		private void propagateDerivativesFromResultToArgument(final Map derivatives) {
			if(arguments == null) {
				// The node has no arguments (it is a leaf node the tree). Do nothing.
				return;
			}

			for(int argumentIndex = 0; argumentIndex < arguments.size(); argumentIndex++) {
				final OperatorTreeNode argument = arguments.get(argumentIndex);
				if(argument != null) {
					final Long argumentID = argument.id;

					final RandomVariable partialDerivative	= getPartialDerivative(argument, argumentIndex);
					RandomVariable derivative			= derivatives.get(id);
					RandomVariable argumentDerivative	= derivatives.get(argumentID);

					/*
					 * Special treatment of some stochastic operators
					 */
					switch(operatorType) {
					case AVERAGE:
						// Implementation of AVERAGE (see https://ssrn.com/abstract=2995695 for details).
						derivative = derivative.average();
						break;
					case CONDITIONAL_EXPECTATION:
						// Implementation of CONDITIONAL_EXPECTATION (see https://ssrn.com/abstract=2995695 for details).
						final ConditionalExpectationEstimator estimator = (ConditionalExpectationEstimator)operator;
						derivative = estimator.getConditionalExpectation(derivative);
						break;
					case CHOOSE:
						// Implementation of CHOOSE (INDICATOR_FUNCTION)
						if(argumentIndex == 0 && (factory.getDiracDeltaApproximationMethod() == DiracDeltaApproximationMethod.REGRESSION_ON_DENSITY || factory.getDiracDeltaApproximationMethod() == DiracDeltaApproximationMethod.REGRESSION_ON_DISTRIBUITON)) {
							derivative = getDiracDeltaRegression(derivative, argumentValues.get(0));
						}
						break;
					default:
						// Ordinary operator - nothing to do
						break;
					}

					/*
					 * Add the product of current nodes derivative and the vertex partialDerivative to the argument derivative
					 */
					if(argumentDerivative == null) {
						// argumentDerivative is zero. Initialize value
						argumentDerivative = derivative.mult(partialDerivative);
					}
					else {
						// Add product to given value
						argumentDerivative = argumentDerivative.addProduct(partialDerivative, derivative);
					}

					derivatives.put(argumentID, argumentDerivative);
				}
			}
		}

		/**
		 * Calculate the partial derivative of this node with respect to an argument node.
		 * Since a function f may use an argument node X in multiple arguments, say f(X,X), we need to provide index
		 * of the argument with respect to which the differentiation is performed (thanks to Vincent E. for pointing to this).
		 *
		 * @param differential The node of the argument.
		 * @param differentialIndex The index of the argument in the functions argument list.
		 * @return The value of the partial derivative.
		 */
		private RandomVariable getPartialDerivative(final OperatorTreeNode differential, final int differentialIndex) {

			if(!arguments.contains(differential)) {
				return zero;
			}

			final RandomVariable X = arguments.size() > 0 && argumentValues != null ? argumentValues.get(0) : null;
			final RandomVariable Y = arguments.size() > 1 && argumentValues != null ? argumentValues.get(1) : null;
			final RandomVariable Z = arguments.size() > 2 && argumentValues != null ? argumentValues.get(2) : null;

			RandomVariable derivative;

			switch(operatorType) {
			/* functions with one argument  */
			case SQUARED:
				derivative = X.mult(2.0);
				break;
			case SQRT:
				derivative = X.sqrt().invert().mult(0.5);
				break;
			case EXP:
				derivative = X.exp();
				break;
			case LOG:
				derivative = X.invert();
				break;
			case SIN:
				derivative = X.cos();
				break;
			case COS:
				derivative = X.sin().mult(-1.0);
				break;
			case INVERT:
				derivative = X.invert().squared().mult(-1);
				break;
			case AVERAGE:
				derivative = one;
				break;
			case CONDITIONAL_EXPECTATION:
				derivative = one;
				break;
			case VARIANCE:
				derivative = X.sub(X.getAverage()*(2.0*X.size()-1.0)/X.size()).mult(2.0/X.size());
				break;
			case STDEV:
				derivative = X.sub(X.getAverage()*(2.0*X.size()-1.0)/X.size()).mult(2.0/X.size()).mult(0.5).div(Math.sqrt(X.getVariance()));
				break;
			case MIN:
				final double min = X.getMin();
				derivative = X.apply(new DoubleUnaryOperator() {
					@Override
					public double applyAsDouble(final double x) {
						return (x == min) ? 1.0 : 0.0;
					}
				});
				break;
			case MAX:
				final double max = X.getMax();
				derivative = X.apply(new DoubleUnaryOperator() {
					@Override
					public double applyAsDouble(final double x) {
						return (x == max) ? 1.0 : 0.0;
					}
				});
				break;
			case ABS:
				derivative = X.choose(one, minusOne);
				break;
			case STDERROR:
				derivative = X.sub(X.getAverage()*(2.0*X.size()-1.0)/X.size()).mult(2.0/X.size()).mult(0.5).div(Math.sqrt(X.getVariance() * X.size()));
				break;
			case SVARIANCE:
				derivative = X.sub(X.getAverage()*(2.0*X.size()-1.0)/X.size()).mult(2.0/(X.size()-1));
				break;
			case ADD:
				derivative = one;
				break;
			case SUB:
				derivative = differentialIndex == 0 ? one : minusOne;
				break;
			case MULT:
				derivative = differentialIndex == 0 ? Y : X;
				break;
			case DIV:
				derivative = differentialIndex == 0 ? Y.invert() : X.div(Y.squared()).mult(-1);
				break;
			case CAP:
				if(differentialIndex == 0) {
					derivative = X.sub(Y).choose(zero, one);
				}
				else {
					derivative = X.sub(Y).choose(one, zero);
				}
				break;
			case FLOOR:
				if(differentialIndex == 0) {
					derivative = X.sub(Y).choose(one, zero);
				}
				else {
					derivative = X.sub(Y).choose(zero, one);
				}
				break;
			case AVERAGE2:
				derivative = differentialIndex == 0 ? Y : X;
				break;
			case VARIANCE2:
				derivative = differentialIndex == 0 ? Y.mult(2.0).mult(X.mult(Y.add(X.getAverage(Y)*(X.size()-1)).sub(X.getAverage(Y)))) :
					X.mult(2.0).mult(Y.mult(X.add(Y.getAverage(X)*(X.size()-1)).sub(Y.getAverage(X))));
				break;
			case STDEV2:
				derivative = differentialIndex == 0 ? Y.mult(2.0).mult(X.mult(Y.add(X.getAverage(Y)*(X.size()-1)).sub(X.getAverage(Y)))).div(Math.sqrt(X.getVariance(Y))) :
					X.mult(2.0).mult(Y.mult(X.add(Y.getAverage(X)*(X.size()-1)).sub(Y.getAverage(X)))).div(Math.sqrt(Y.getVariance(X)));
				break;
			case STDERROR2:
				derivative = differentialIndex == 0 ? Y.mult(2.0).mult(X.mult(Y.add(X.getAverage(Y)*(X.size()-1)).sub(X.getAverage(Y)))).div(Math.sqrt(X.getVariance(Y) * X.size())) :
					X.mult(2.0).mult(Y.mult(X.add(Y.getAverage(X)*(X.size()-1)).sub(Y.getAverage(X)))).div(Math.sqrt(Y.getVariance(X) * Y.size()));
				break;
			case POW:
				// second argument will always be deterministic and constant (currently pow does not exist with a random variable exponent)
				derivative = (differentialIndex == 0) ? X.pow(Y.doubleValue() - 1.0).mult(Y) : zero;
				break;
			case ADDPRODUCT:
				if(differentialIndex == 0) {
					derivative = one;
				} else if(differentialIndex == 1) {
					derivative = Z;
				} else {
					derivative = Y;
				}
				break;
			case ADDRATIO:
				if(differentialIndex == 0) {
					derivative = one;
				} else if(differentialIndex == 1) {
					derivative = Z.invert();
				} else {
					derivative = Y.div(Z.squared()).mult(-1.0);
				}
				break;
			case SUBRATIO:
				if(differentialIndex == 0) {
					derivative = one;
				} else if(differentialIndex == 1) {
					derivative = Z.invert().mult(-1.0);
				} else {
					derivative = Y.div(Z.squared());
				}
				break;
			case ACCRUE:
				if(differentialIndex == 0) {
					derivative = Y.mult(Z).add(1.0);
				} else if(differentialIndex == 1) {
					derivative = X.mult(Z);
				} else {
					derivative = X.mult(Y);
				}
				break;
			case DISCOUNT:
				if(differentialIndex == 0) {
					derivative = Y.mult(Z).add(1.0).invert();
				} else if(differentialIndex == 1) {
					derivative = X.mult(Z).div(Y.mult(Z).add(1.0).squared()).mult(-1.0);
				} else {
					derivative = X.mult(Y).div(Y.mult(Z).add(1.0).squared()).mult(-1.0);
				}
				break;
			case CHOOSE:
				if(differentialIndex == 0) {
					switch(factory.getDiracDeltaApproximationMethod()) {
					case ONE:
					{
						derivative = Y.sub(Z);
						break;
					}
					case ZERO:
					{
						derivative = zero;
						break;
					}
					case DISCRETE_DELTA:
					{
						/*
						 * Approximation via local finite difference
						 * (see https://ssrn.com/abstract=2995695 for details).
						 */
						final double epsilon = factory.getDiracDeltaApproximationWidthPerStdDev()*X.getStandardDeviation();
						if(Double.isInfinite(epsilon)) {
							derivative = Y.sub(Z);
						}
						else if(epsilon > 0) {
							derivative = Y.sub(Z);
							derivative = derivative.mult(X.add(epsilon/2).choose(one, zero));
							derivative = derivative.mult(X.sub(epsilon/2).choose(zero, one));
							derivative = derivative.div(epsilon);
						}
						else {
							derivative = zero;
						}
						break;
					}
					case REGRESSION_ON_DENSITY:
					case REGRESSION_ON_DISTRIBUITON:
					{
						derivative = Y.sub(Z);
						break;
					}
					default:
					{
						throw new UnsupportedOperationException("Diract Delta Approximation Method " + factory.getDiracDeltaApproximationMethod().name() + " not supported.");
					}
					}
				} else if(differentialIndex == 1) {
					derivative = X.choose(one, zero);
				} else {
					derivative = X.choose(zero, one);
				}
				break;
			default:
				throw new IllegalArgumentException("Operation " + operatorType.name() + " not supported in differentiation.");
			}

			return derivative;
		}

		private RandomVariable getDiracDeltaRegression(RandomVariable derivative, final RandomVariable indicator) {
			final double diracDeltaApproximationWidthPerStdDev = factory.getDiracDeltaApproximationWidthPerStdDev();
			final double epsilon = diracDeltaApproximationWidthPerStdDev*indicator.getStandardDeviation();

			final RandomVariable localizedOne = (indicator.add(epsilon/2).choose(one, zero)).mult(indicator.sub(epsilon/2).choose(zero, one));

			final boolean isDirectDeltaRegressionUseRegressionOnAdjointDerivative = false;	// currently disabled, was used in experiments
			if(isDirectDeltaRegressionUseRegressionOnAdjointDerivative) {
				final RandomVariable localizedValue = indicator.mult(localizedOne);
				final RandomVariable[] regressionBasisFunctions = new RandomVariable[] {
						localizedOne,
						localizedValue,
						localizedValue.squared()
				};
				derivative = localizedOne.mult((new LinearRegression(regressionBasisFunctions)).getRegressionCoefficients(derivative)[0]).div(localizedOne.getAverage());
			}
			else {
				derivative = derivative.mult(localizedOne).div(localizedOne.getAverage());
			}

			return derivative.mult(getDensityRegression(indicator));

		}

		private double getDensityRegression(final RandomVariable indicator) {
			final double diracDeltaApproximationDensityRegressionWidthPerStdDev = factory.getDiracDeltaApproximationDensityRegressionWidthPerStdDev();

			/*
			 * Density regression
			 */
			final double underlyingStdDev = indicator.getStandardDeviation();
			final int numberOfSamplePointsHalf = 50;			// @TODO numberOfSamplePoints should become a parameter.
			final double sampleIntervalWidthHalf = diracDeltaApproximationDensityRegressionWidthPerStdDev/2 * underlyingStdDev / numberOfSamplePointsHalf;
			final double[] samplePointX = new double[numberOfSamplePointsHalf*2];
			final double[] samplePointY = new double[numberOfSamplePointsHalf*2];
			double sampleInterval = sampleIntervalWidthHalf;
			final RandomVariable indicatorPositiveValues = indicator.choose(new Scalar(1.0), new Scalar(0.0));
			final RandomVariable indicatorNegativeValues = indicator.choose(new Scalar(0.0), new Scalar(1.0));

			switch(factory.getDiracDeltaApproximationMethod()) {
			case REGRESSION_ON_DENSITY:
			{
				for(int i=0; i extractOperatorTreeNodes(final List arguments) {
			return arguments != null ? arguments.stream().map( OperatorTreeNode::of ).collect(Collectors.toList()) : null;
		}

		private static List extractOperatorValues(final List arguments) {
			return arguments != null ? arguments.stream().map( OperatorTreeNode::getValue ).collect(Collectors.toList()) : null;
		}

		private void readObject(final java.io.ObjectInputStream stream) throws IOException, ClassNotFoundException {
			stream.defaultReadObject();
			// Reassign id
			try {
				final Field idField = this.getClass().getDeclaredField("id");
				idField.setAccessible(true);
				idField.set(this, indexOfNextRandomVariable.getAndIncrement());
				idField.setAccessible(false);
			} catch (NoSuchFieldException | SecurityException | IllegalArgumentException | IllegalAccessException e) {
				throw new RuntimeException("Unable to re-assing id of " + this.getClass().getSimpleName() + ".", e);
			}
		}
	}

	/*
	 * Data model. We maintain the underlying values and a link to the node in the operator tree.
	 */
	private RandomVariable values;
	private final OperatorTreeNode operatorTreeNode;
	private final RandomVariableDifferentiableAADFactory factory;

	public RandomVariableDifferentiableAAD(final RandomVariable values, final List argumentOperatorTreeNodes, final List argumentValues, final ConditionalExpectationEstimator estimator, final OperatorType operator, final RandomVariableDifferentiableAADFactory factory, final int methodArgumentTypePriority) {
		super();
		this.values = values;
		operatorTreeNode = new OperatorTreeNode(operator, argumentOperatorTreeNodes, argumentValues, estimator, factory);
		this.factory = factory != null ? factory : new RandomVariableDifferentiableAADFactory();
		typePriority = methodArgumentTypePriority;
	}

	public static RandomVariableDifferentiableAAD of(final double value) {
		return new RandomVariableDifferentiableAAD(value);
	}

	public static RandomVariableDifferentiableAAD of(final RandomVariable randomVariable) {
		return new RandomVariableDifferentiableAAD(randomVariable);
	}

	public RandomVariableDifferentiableAAD(final double value) {
		this(new Scalar(value), null, null, null);
	}

	public RandomVariableDifferentiableAAD(final RandomVariable randomVariable) {
		this(randomVariable, null, null, randomVariable instanceof RandomVariableDifferentiableAAD ? ((RandomVariableDifferentiableAAD)randomVariable).getFactory() : null);
	}

	public RandomVariableDifferentiableAAD(final RandomVariable values, final RandomVariableDifferentiableAADFactory factory) {
		this(values, null, null, factory);
	}

	private RandomVariableDifferentiableAAD(final RandomVariable values, final List arguments, final OperatorType operator, final RandomVariableDifferentiableAADFactory factory) {
		this(values, arguments, null, operator, factory);
	}

	public RandomVariableDifferentiableAAD(final RandomVariable values, final List arguments, final ConditionalExpectationEstimator estimator, final OperatorType operator, final RandomVariableDifferentiableAADFactory factory) {
		this(values, arguments, estimator, operator, factory, typePriorityDefault);
	}

	public RandomVariableDifferentiableAAD(final RandomVariable values, final List arguments, final ConditionalExpectationEstimator estimator, final OperatorType operator, final RandomVariableDifferentiableAADFactory factory, final int methodArgumentTypePriority) {
		this(values, OperatorTreeNode.extractOperatorTreeNodes(arguments), OperatorTreeNode.extractOperatorValues(arguments), estimator, operator, factory, methodArgumentTypePriority);
	}

	public OperatorTreeNode getOperatorTreeNode() {
		return operatorTreeNode;
	}

	/**
	 * Returns the underlying values.
	 *
	 * @return The underling values.
	 */
	@Override
	public RandomVariable getValues(){
		return values;
	}

	public RandomVariableDifferentiableAADFactory getFactory() {
		return factory;
	}

	@Override
	public Long getID(){
		return getOperatorTreeNode().id;
	}

	/**
	 * Returns the gradient of this random variable with respect to all its leaf nodes.
	 * The method calculated the map \( v \mapsto \frac{d u}{d v} \) where \( u \) denotes this.
	 *
	 * Performs a backward automatic differentiation.
	 *
	 * @return The gradient map.
	 */
	@Override
	public Map getGradient(final Set independentIDs) {

		// The map maintaining the derivatives id -> derivative
		final Map derivatives = new HashMap<>();
		// Put derivative of this node w.r.t. itself
		derivatives.put(getID(), one);

		// The set maintaining the independents. Note: TreeMap is maintaining a sorting on the keys.
		final TreeMap independents = new TreeMap<>();
		// Initialize with root node
		independents.put(getID(), getOperatorTreeNode());

		while(independents.size() > 0) {
			// Get and remove node with the highest id in independents
			final Map.Entry independentEntry = independents.pollLastEntry();
			final Long id = independentEntry.getKey();
			final OperatorTreeNode independent = independentEntry.getValue();

			// Process this node (node with highest id in independents)
			final List arguments = independent.arguments;
			if(arguments != null && arguments.size() > 0) {
				// Node has arguments: Propagate derivative to arguments.
				independent.propagateDerivativesFromResultToArgument(derivatives);

				// Remove id of this node from derivatives - keep only leaf nodes.
				if(isGradientRetainsLeafNodesOnly()) {
					derivatives.remove(id);
				}

				// Add all non leaf node arguments to the list of independents
				for(final OperatorTreeNode argument : arguments) {
					// If an argument is null, it is a (non-differentiable) constant.
					if(argument != null) {
						independents.put(argument.id, argument);
					}
				}
			}

			if(independentIDs != null && independentIDs.contains(id)) {
				derivatives.remove(id);
			}
		}

		return derivatives;
	}

	private boolean isGradientRetainsLeafNodesOnly() {
		return getFactory() != null && getFactory().isGradientRetainsLeafNodesOnly();
	}

	@Override
	public Map getTangents(final Set dependentIDs) {
		throw new UnsupportedOperationException();
	}

	/*
	 * The following methods are end points since they return double values.
	 * You cannot differentiate these results.
	 */

	@Override
	public boolean equals(final RandomVariable randomVariable) {
		return getValues().equals(randomVariable);
	}

	@Override
	public double getFiltrationTime() {
		return getValues().getFiltrationTime();
	}

	@Override
	public int getTypePriority() {
		return typePriority;
	}

	@Override
	public double get(final int pathOrState) {
		return getValues().get(pathOrState);
	}

	@Override
	public int size() {
		return getValues().size();
	}

	@Override
	public boolean isDeterministic() {
		return getValues().isDeterministic();
	}

	@Override
	public double[] getRealizations() {
		return getValues().getRealizations();
	}

	@Override
	public Double doubleValue() {
		return getValues().doubleValue();
	}

	@Override
	public double getMin() {
		return getValues().getMin();
	}

	@Override
	public double getMax() {
		return getValues().getMax();
	}

	@Override
	public double getAverage() {
		return getValues().getAverage();
	}

	@Override
	public double getAverage(final RandomVariable probabilities) {
		return getValues().getAverage(probabilities);
	}

	@Override
	public double getVariance() {
		return getValues().getVariance();
	}

	@Override
	public double getVariance(final RandomVariable probabilities) {
		return getValues().getVariance(probabilities);
	}

	@Override
	public double getSampleVariance() {
		return getValues().getSampleVariance();
	}

	@Override
	public double getStandardDeviation() {
		return getValues().getStandardDeviation();
	}

	@Override
	public double getStandardDeviation(final RandomVariable probabilities) {
		return getValues().getStandardDeviation(probabilities);
	}

	@Override
	public double getStandardError() {
		return getValues().getStandardError();
	}

	@Override
	public double getStandardError(final RandomVariable probabilities) {
		return getValues().getStandardError(probabilities);
	}

	@Override
	public double getQuantile(final double quantile) {
		return getValues().getQuantile(quantile);
	}

	@Override
	public double getQuantile(final double quantile, final RandomVariable probabilities) {
		return getValues().getQuantile(quantile, probabilities);
	}

	@Override
	public double getQuantileExpectation(final double quantileStart, final double quantileEnd) {
		return getValues().getQuantileExpectation(quantileStart, quantileEnd);
	}

	@Override
	public double[] getHistogram(final double[] intervalPoints) {
		return getValues().getHistogram(intervalPoints);
	}

	@Override
	public double[][] getHistogram(final int numberOfPoints, final double standardDeviations) {
		return getValues().getHistogram(numberOfPoints, standardDeviations);
	}

	/*
	 * The following methods are differentiable operations.
	 */

	@Override
	public RandomVariable cache() {
		values = values.cache();
		return this;
	}

	@Override
	public RandomVariable cap(final double cap) {
		return new RandomVariableDifferentiableAAD(
				getValues().cap(cap),
				Arrays.asList(this.getOperatorTreeNode(), null),
				Arrays.asList(this.getValues(), new Scalar(cap)),
				null,
				OperatorType.CAP,
				getFactory(),
				typePriorityDefault);
	}

	@Override
	public RandomVariable floor(final double floor) {
		return new RandomVariableDifferentiableAAD(
				getValues().floor(floor),
				Arrays.asList(this.getOperatorTreeNode(), null),
				Arrays.asList(this.getValues(), new Scalar(floor)),
				null,
				OperatorType.FLOOR,
				getFactory(),
				typePriorityDefault);
	}

	@Override
	public RandomVariable add(final double value) {
		return new RandomVariableDifferentiableAAD(
				getValues().add(value),
				Arrays.asList(this.getOperatorTreeNode(), null),
				Arrays.asList(null /* this */, null /* new RandomVariableFromDoubleArray(value) */),		// For ADD we do not need all arguments to evaluate the differential
				null,
				OperatorType.ADD,
				getFactory(),
				typePriorityDefault);
	}

	@Override
	public RandomVariable sub(final double value) {
		return new RandomVariableDifferentiableAAD(
				getValues().sub(value),
				Arrays.asList(this.getOperatorTreeNode(), null),
				Arrays.asList(null /* this */, null /* new RandomVariableFromDoubleArray(value) */),		// For SUB we do not need all arguments to evaluate the differential
				null,
				OperatorType.SUB,
				getFactory(),
				typePriorityDefault);
	}

	@Override
	public RandomVariable mult(final double value) {
		return new RandomVariableDifferentiableAAD(
				getValues().mult(value),
				Arrays.asList(this.getOperatorTreeNode(), null),
				Arrays.asList(null, new Scalar(value)),		// For MULT with constant we do not need all arguments to evaluate the differential
				null,
				OperatorType.MULT,
				getFactory(),
				typePriorityDefault);
	}

	@Override
	public RandomVariable div(final double value) {
		return new RandomVariableDifferentiableAAD(
				getValues().div(value),
				Arrays.asList(this.getOperatorTreeNode(), null),
				Arrays.asList(null, new Scalar(value)),		// For DIV with constant we do not need all arguments to evaluate the differential
				null,
				OperatorType.DIV,
				getFactory(),
				typePriorityDefault);
	}

	@Override
	public RandomVariable pow(final double exponent) {
		return new RandomVariableDifferentiableAAD(
				getValues().pow(exponent),
				Arrays.asList(this, new Scalar(exponent)),
				OperatorType.POW,
				getFactory());
	}

	@Override
	public RandomVariable average() {
		return new RandomVariableDifferentiableAAD(
				getValues().average(),
				Arrays.asList(new RandomVariable[]{ this }),
				OperatorType.AVERAGE,
				getFactory());
	}

	@Override
	public RandomVariable getConditionalExpectation(final ConditionalExpectationEstimator estimator) {
		return new RandomVariableDifferentiableAAD(
				getValues().getConditionalExpectation(estimator),
				Arrays.asList(new RandomVariable[]{ this }),
				estimator,
				OperatorType.CONDITIONAL_EXPECTATION,
				getFactory());

	}

	@Override
	public RandomVariable squared() {
		return new RandomVariableDifferentiableAAD(
				getValues().squared(),
				Arrays.asList(new RandomVariable[]{ this }),
				OperatorType.SQUARED,
				getFactory());
	}

	@Override
	public RandomVariable sqrt() {
		return new RandomVariableDifferentiableAAD(
				getValues().sqrt(),
				Arrays.asList(new RandomVariable[]{ this }),
				OperatorType.SQRT,
				getFactory());
	}

	@Override
	public RandomVariable exp() {
		return new RandomVariableDifferentiableAAD(
				getValues().exp(),
				Arrays.asList(new RandomVariable[]{ this }),
				OperatorType.EXP,
				getFactory());
	}

	@Override
	public RandomVariable log() {
		return new RandomVariableDifferentiableAAD(
				getValues().log(),
				Arrays.asList(new RandomVariable[]{ this }),
				OperatorType.LOG,
				getFactory());
	}

	@Override
	public RandomVariable sin() {
		return new RandomVariableDifferentiableAAD(
				getValues().sin(),
				Arrays.asList(new RandomVariable[]{ this }),
				OperatorType.SIN,
				getFactory());
	}

	@Override
	public RandomVariable cos() {
		return new RandomVariableDifferentiableAAD(
				getValues().cos(),
				Arrays.asList(new RandomVariable[]{ this }),
				OperatorType.COS,
				getFactory());
	}

	/*
	 * Binary operators: checking for return type priority.
	 */

	@Override
	public RandomVariable add(final RandomVariable randomVariable) {
		if(randomVariable.getTypePriority() > this.getTypePriority()) {
			// Check type priority
			return randomVariable.add(this);
		}

		return new RandomVariableDifferentiableAAD(
				getValues().add(randomVariable.getValues()),
				Arrays.asList(this.getOperatorTreeNode(), OperatorTreeNode.of(randomVariable)),
				Arrays.asList(null /* this.getValues() */, null /* randomVariable.getValues() */),		// For ADD we do not need all arguments to evaluate the differential
				null,
				OperatorType.ADD,
				getFactory(),
				typePriorityDefault);
	}

	@Override
	public RandomVariable sub(final RandomVariable randomVariable) {
		if(randomVariable.getTypePriority() > this.getTypePriority()) {
			// Check type priority
			return randomVariable.bus(this);
		}

		return new RandomVariableDifferentiableAAD(
				getValues().sub(randomVariable.getValues()),
				Arrays.asList(this.getOperatorTreeNode(), OperatorTreeNode.of(randomVariable)),
				Arrays.asList(null /* this.getValues() */, null /* randomVariable.getValues() */),		// For SUB we do not need all arguments to evaluate the differential
				null,
				OperatorType.SUB,
				getFactory(),
				typePriorityDefault);
	}

	@Override
	public RandomVariable bus(final RandomVariable randomVariable) {
		if(randomVariable.getTypePriority() > this.getTypePriority()) {
			// Check type priority
			return randomVariable.sub(this);
		}

		return new RandomVariableDifferentiableAAD(
				getValues().bus(randomVariable.getValues()),
				Arrays.asList(OperatorTreeNode.of(randomVariable), this.getOperatorTreeNode()),			// SUB with swapped arguments
				Arrays.asList(null, null),																// For SUB we do not need all arguments to evaluate the differential
				null,
				OperatorType.SUB,
				getFactory(),
				typePriorityDefault);
	}

	@Override
	public RandomVariable mult(final RandomVariable randomVariable) {
		if(randomVariable.getTypePriority() > this.getTypePriority()) {
			// Check type priority
			return randomVariable.mult(this);
		}

		return new RandomVariableDifferentiableAAD(
				getValues().mult(randomVariable.getValues()),
				Arrays.asList(this.getOperatorTreeNode(), OperatorTreeNode.of(randomVariable)),
				Arrays.asList(this.getValues(), randomVariable.getValues()),
				null,
				OperatorType.MULT,
				getFactory(),
				typePriorityDefault);
	}

	@Override
	public RandomVariable div(final RandomVariable randomVariable) {
		if(randomVariable.getTypePriority() > this.getTypePriority()) {
			// Check type priority
			return randomVariable.vid(this);
		}

		return new RandomVariableDifferentiableAAD(
				getValues().div(randomVariable.getValues()),
				Arrays.asList(this.getOperatorTreeNode(), OperatorTreeNode.of(randomVariable)),
				Arrays.asList(this.getValues(), randomVariable.getValues()),
				null,
				OperatorType.DIV,
				getFactory(),
				typePriorityDefault);
	}

	@Override
	public RandomVariable vid(final RandomVariable randomVariable) {
		if(randomVariable.getTypePriority() > this.getTypePriority()) {
			// Check type priority
			return randomVariable.div(this);
		}

		return new RandomVariableDifferentiableAAD(
				getValues().vid(randomVariable.getValues()),
				Arrays.asList(OperatorTreeNode.of(randomVariable), this.getOperatorTreeNode()),	// DIV with swapped arguments
				Arrays.asList(randomVariable.getValues(), this.getValues()),					// DIV with swapped arguments
				null,
				OperatorType.DIV,
				getFactory(),
				typePriorityDefault);
	}

	@Override
	public RandomVariable cap(final RandomVariable randomVariable) {
		if(randomVariable.getTypePriority() > this.getTypePriority()) {
			// Check type priority
			return randomVariable.cap(this);
		}

		return new RandomVariableDifferentiableAAD(
				getValues().cap(randomVariable.getValues()),
				Arrays.asList(this, randomVariable),
				OperatorType.CAP,
				getFactory());
	}

	@Override
	public RandomVariable floor(final RandomVariable floor) {
		if(floor.getTypePriority() > this.getTypePriority()) {
			// Check type priority
			return floor.floor(this);
		}

		return new RandomVariableDifferentiableAAD(
				getValues().floor(floor.getValues()),
				Arrays.asList(this, floor),
				OperatorType.FLOOR,
				getFactory());
	}

	@Override
	public RandomVariable accrue(final RandomVariable rate, final double periodLength) {
		if(rate.getTypePriority() > this.getTypePriority()) {
			// Check type priority
			return rate.mult(periodLength).add(1.0).mult(this);
		}

		return new RandomVariableDifferentiableAAD(
				getValues().accrue(rate.getValues(), periodLength),
				Arrays.asList(this, rate, new RandomVariableFromDoubleArray(periodLength)),
				OperatorType.ACCRUE,
				getFactory());
	}

	@Override
	public RandomVariable discount(final RandomVariable rate, final double periodLength) {
		if(rate.getTypePriority() > this.getTypePriority()) {
			// Check type priority
			return rate.mult(periodLength).add(1.0).invert().mult(this);
		}

		return new RandomVariableDifferentiableAAD(
				getValues().discount(rate.getValues(), periodLength),
				Arrays.asList(this, rate, new RandomVariableFromDoubleArray(periodLength)),
				OperatorType.DISCOUNT,
				getFactory());
	}

	@Override
	public RandomVariable choose(final RandomVariable valueIfTriggerNonNegative, final RandomVariable valueIfTriggerNegative) {
		return new RandomVariableDifferentiableAAD(
				getValues().choose(valueIfTriggerNonNegative.getValues(), valueIfTriggerNegative.getValues()),
				Arrays.asList(this.getOperatorTreeNode(), OperatorTreeNode.of(valueIfTriggerNonNegative), OperatorTreeNode.of(valueIfTriggerNegative)),
				Arrays.asList(this.getValues(), valueIfTriggerNonNegative.getValues(), valueIfTriggerNegative.getValues()),
				null,
				OperatorType.CHOOSE,
				getFactory(),
				typePriorityDefault);
	}

	@Override
	public RandomVariable invert() {
		return new RandomVariableDifferentiableAAD(
				getValues().invert(),
				Arrays.asList(new RandomVariable[]{ this }),
				OperatorType.INVERT,
				getFactory());
	}

	@Override
	public RandomVariable abs() {
		return new RandomVariableDifferentiableAAD(
				getValues().abs(),
				Arrays.asList(new RandomVariable[]{ this }),
				OperatorType.ABS,
				getFactory());
	}

	@Override
	public RandomVariable addProduct(final RandomVariable factor1, final double factor2) {
		if(factor1.getTypePriority() > this.getTypePriority()) {
			// Check type priority
			return factor1.mult(factor2).add(this);
		}

		return new RandomVariableDifferentiableAAD(
				getValues().addProduct(factor1.getValues(), factor2),
				Arrays.asList(this.getOperatorTreeNode(), OperatorTreeNode.of(factor1), null),
				Arrays.asList(this.getValues(), factor1.getValues(), new Scalar(factor2)),
				null,
				OperatorType.ADDPRODUCT,
				getFactory(),
				typePriorityDefault);
	}

	@Override
	public RandomVariable addProduct(final RandomVariable factor1, final RandomVariable factor2) {
		if(factor1.getTypePriority() > this.getTypePriority() || factor2.getTypePriority() > this.getTypePriority()) {
			// Check type priority
			return factor1.mult(factor2).add(this);
		}

		return new RandomVariableDifferentiableAAD(
				getValues().addProduct(factor1.getValues(), factor2.getValues()),
				Arrays.asList(this.getOperatorTreeNode(), OperatorTreeNode.of(factor1), OperatorTreeNode.of(factor2)),
				Arrays.asList(this.getValues(), factor1.getValues(), factor2.getValues()),
				null,
				OperatorType.ADDPRODUCT,
				getFactory(),
				typePriorityDefault);
	}

	@Override
	public RandomVariable addRatio(final RandomVariable numerator, final RandomVariable denominator) {
		if(numerator.getTypePriority() > this.getTypePriority() || denominator.getTypePriority() > this.getTypePriority()) {
			// Check type priority
			return numerator.div(denominator).add(this);
		}

		return new RandomVariableDifferentiableAAD(
				getValues().addRatio(numerator.getValues(), denominator.getValues()),
				Arrays.asList(this, numerator, denominator),
				OperatorType.ADDRATIO,
				getFactory());
	}

	@Override
	public RandomVariable subRatio(final RandomVariable numerator, final RandomVariable denominator) {
		if(numerator.getTypePriority() > this.getTypePriority() || denominator.getTypePriority() > this.getTypePriority()) {
			// Check type priority
			return numerator.div(denominator).mult(-1).add(this);
		}

		return new RandomVariableDifferentiableAAD(
				getValues().subRatio(numerator.getValues(), denominator.getValues()),
				Arrays.asList(this, numerator, denominator),
				OperatorType.SUBRATIO,
				getFactory());
	}

	/*
	 * The following methods are end points, the result is not differentiable.
	 */

	@Override
	public RandomVariable isNaN() {
		return getValues().isNaN();
	}

	@Override
	public IntToDoubleFunction getOperator() {
		return getValues().getOperator();
	}

	@Override
	public DoubleStream getRealizationsStream() {
		return getValues().getRealizationsStream();
	}

	@Override
	public RandomVariable apply(final DoubleUnaryOperator operator) {
		throw new UnsupportedOperationException("Applying functions is not supported.");
	}

	@Override
	public RandomVariable apply(final DoubleBinaryOperator operator, final RandomVariable argument) {
		throw new UnsupportedOperationException("Applying functions is not supported.");
	}

	@Override
	public RandomVariable apply(final DoubleTernaryOperator operator, final RandomVariable argument1, final RandomVariable argument2) {
		throw new UnsupportedOperationException("Applying functions is not supported.");
	}

	public RandomVariable getVarianceAsRandomVariableAAD(){
		/*returns deterministic AAD random variable */
		return new RandomVariableDifferentiableAAD(
				new RandomVariableFromDoubleArray(getVariance()),
				Arrays.asList(new RandomVariable[]{ this }),
				OperatorType.VARIANCE,
				getFactory());
	}

	public RandomVariable getSampleVarianceAsRandomVariableAAD() {
		/*returns deterministic AAD random variable */
		return new RandomVariableDifferentiableAAD(
				new RandomVariableFromDoubleArray(getSampleVariance()),
				Arrays.asList(new RandomVariable[]{ this }),
				OperatorType.SVARIANCE,
				getFactory());
	}

	public RandomVariable 	getStandardDeviationAsRandomVariableAAD(){
		/*returns deterministic AAD random variable */
		return new RandomVariableDifferentiableAAD(
				new RandomVariableFromDoubleArray(getStandardDeviation()),
				Arrays.asList(new RandomVariable[]{ this }),
				OperatorType.STDEV,
				getFactory());
	}

	public RandomVariable getStandardErrorAsRandomVariableAAD(){
		/*returns deterministic AAD random variable */
		return new RandomVariableDifferentiableAAD(
				new RandomVariableFromDoubleArray(getStandardError()),
				Arrays.asList(new RandomVariable[]{ this }),
				OperatorType.STDERROR,
				getFactory());
	}

	public RandomVariable 	getMinAsRandomVariableAAD(){
		/*returns deterministic AAD random variable */
		return new RandomVariableDifferentiableAAD(
				new RandomVariableFromDoubleArray(getMin()),
				Arrays.asList(new RandomVariable[]{ this }),
				OperatorType.MIN,
				getFactory());
	}

	public RandomVariable 	getMaxAsRandomVariableAAD(){
		/*returns deterministic AAD random variable */
		return new RandomVariableDifferentiableAAD(
				new RandomVariableFromDoubleArray(getMax()),
				Arrays.asList(new RandomVariable[]{ this }),
				OperatorType.MAX,
				getFactory());
	}

	@Override
	public String toString() {
		return "RandomVariableDifferentiableAAD [values=" + values + ",\n ID=" + getID() + "]";
	}

	@Override
	public RandomVariableDifferentiable getCloneIndependent() {
		return new RandomVariableDifferentiableAAD(this.getValues());
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy