
commonMain.space.kscience.kmath.expressions.SimpleAutoDiff.kt Maven / Gradle / Ivy
package space.kscience.kmath.expressions
import space.kscience.kmath.linear.Point
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.asBuffer
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
/*
* Implementation of backward-mode automatic differentiation.
* Initial gist by Roman Elizarov: https://gist.github.com/elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d
*/
/**
 * A value taking part in automatic differentiation inside a [SimpleAutoDiffField].
 *
 * Plain instances represent constants and intermediate results. Their derivatives are not stored here; they are
 * tracked by the [SimpleAutoDiffField] that produced them.
 *
 * @param T the type of the wrapped value.
 * @property value the wrapped value.
 */
public open class AutoDiffValue<out T>(public val value: T)
/**
 * Represents the result of a [simpleAutoDiff] call: the value of the differentiated expression together with
 * its derivatives with respect to the bound variables.
 *
 * @param T the non-nullable type of value.
 * @property value the value of the result.
 * @param derivativeValues the mapping from variable identities ([Symbol.identity]) to the derivative of the
 * result with respect to that variable.
 * @property context the field over [T].
 */
public class DerivationResult<T : Any>(
    public val value: T,
    private val derivativeValues: Map<String, T>,
    public val context: Field<T>,
) {
    /**
     * Returns the derivative of the result with respect to [variable], or [Ring.zero] of [context] when
     * [variable] was not bound during differentiation.
     */
    public fun derivative(variable: Symbol): T = derivativeValues[variable.identity] ?: context.zero

    /**
     * Computes the divergence: the sum, in [context], of the derivatives with respect to all bound variables.
     */
    public fun div(): T = context { sum(derivativeValues.values) }
}
/**
 * Computes the gradient of the result with respect to [variables], in the given order.
 *
 * A variable that was not bound during differentiation contributes [Ring.zero] (see [DerivationResult.derivative]).
 *
 * @throws IllegalStateException if no variables are given.
 */
public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T> {
    check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" }
    return variables.map(::derivative).asBuffer()
}
/**
 * Runs differentiation and establishes [SimpleAutoDiffField] context inside the block of code.
 *
 * [body] is invoked exactly once in a fresh [SimpleAutoDiffField] in which [bindings] are available as variables.
 * The returned value is then differentiated with respect to each of them.
 *
 * Example:
 * ```
 * val x by symbol // define variable(s)
 * val y = RealField.simpleAutoDiff(x to 2.0) { sqr(bindSymbol(x)) + 5 * bindSymbol(x) + 3 } // write the formula in AD context
 * assertEquals(17.0, y.value) // the value of the result (y)
 * assertEquals(9.0, y.derivative(x)) // dy/dx
 * ```
 *
 * @param bindings the variables to differentiate with respect to, mapped to their values.
 * @param body the action in [SimpleAutoDiffField] context returning the [AutoDiffValue] to differentiate.
 * @return the result of differentiation.
 */
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
    bindings: Map<Symbol, T>,
    body: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
): DerivationResult<T> {
    contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
    return SimpleAutoDiffField(this, bindings).differentiate(body)
}
/**
 * Runs differentiation with the variables given as [bindings] pairs. This is a convenience overload of the
 * map-based [simpleAutoDiff].
 *
 * If the same symbol occurs in several pairs, the last one wins (as in [toMap]). [body] is invoked exactly once.
 */
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
    vararg bindings: Pair<Symbol, T>,
    body: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
): DerivationResult<T> {
    contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
    return simpleAutoDiff(bindings.toMap(), body)
}
/**
 * A [Field] over [AutoDiffValue] in whose context expressions can be differentiated using backward-mode
 * automatic differentiation.
 *
 * Every operation computes its value immediately and records, via [derive], a block that propagates derivatives
 * backwards. [differentiate] then:
 * 1. evaluates a function;
 * 2. sets the derivative of its result to one;
 * 3. replays the recorded blocks, last recorded first.
 *
 * @param T the non-nullable type of values.
 * @param F the type of the field over [T].
 * @property context the field over [T] that performs the underlying arithmetic.
 * @param bindings the variables available in this context, mapped to their values.
 */
@OptIn(UnstableKMathAPI::class)
public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
    public val context: F,
    bindings: Map<Symbol, T>,
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>>, RingWithNumbers<AutoDiffValue<T>> {
    public override val zero: AutoDiffValue<T>
        get() = const(context.zero)

    public override val one: AutoDiffValue<T>
        get() = const(context.one)

    // Tape for the backward pass: entries are pushed as (block, value) pairs, so `sp` is always even
    // and the (even) capacity is always exhausted exactly at a pair boundary.
    private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
    private var sp: Int = 0

    // Derivatives of non-variable values; bound variables keep theirs in their own `d` field.
    private val derivatives: MutableMap<AutoDiffValue<T>, T> = hashMapOf()

    // Bound variables keyed by symbol identity; every derivative starts at zero.
    private val bindings: Map<String, AutoDiffVariableWithDerivative<T>> = bindings.entries.associate {
        it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, context.zero)
    }

    /**
     * A differentiable variable: it carries its value and the derivative of the differentiation ([simpleAutoDiff])
     * result with respect to itself.
     *
     * Equality and hash code depend only on [identity], so it equals any [Symbol] with the same identity.
     *
     * @param T the non-nullable type of value.
     * @property identity the identity of the symbol this variable is bound to.
     * @param value the value of this variable.
     * @property d the derivative accumulated for this variable during the backward pass.
     */
    private class AutoDiffVariableWithDerivative<T : Any>(
        override val identity: String,
        value: T,
        var d: T,
    ) : AutoDiffValue<T>(value), Symbol {
        override fun toString(): String = identity
        override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
        override fun hashCode(): Int = identity.hashCode()
    }

    public override fun bindSymbolOrNull(symbol: Symbol): AutoDiffValue<T>? = bindings[symbol.identity]

    // Variables store their own derivative; other values are looked up in `derivatives`, defaulting to zero.
    private fun getDerivative(variable: AutoDiffValue<T>): T =
        (variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: context.zero

    private fun setDerivative(variable: AutoDiffValue<T>, value: T) {
        if (variable is AutoDiffVariableWithDerivative) variable.d = value else derivatives[variable] = value
    }

    /**
     * Pops the recorded (block, value) pairs, last recorded first, and applies each block to its value in [context].
     */
    @Suppress("UNCHECKED_CAST")
    private fun runBackwardPass() {
        while (sp > 0) {
            val value = stack[--sp]
            val block = stack[--sp] as F.(Any?) -> Unit
            context.block(value)
        }
    }

    override fun const(value: T): AutoDiffValue<T> = AutoDiffValue(value)

    /**
     * The derivative currently accumulated for this value in this context.
     *
     * Backward blocks passed to [derive] read and update it to propagate derivatives. Reading it for a value that
     * has no recorded derivative yields zero, without creating a new binding.
     */
    public var AutoDiffValue<T>.d: T
        get() = getDerivative(this)
        set(value) = setDerivative(this, value)

    /**
     * Wraps the value computed by [block] in [context] as a constant [AutoDiffValue].
     */
    public inline fun const(block: F.() -> T): AutoDiffValue<T> = const(context.block())

    /**
     * Records [block] to run on [value] during the backward pass and returns [value] unchanged.
     *
     * Blocks run in reverse order of recording. For example, `sin` is implemented as:
     *
     * ```
     * fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sin(x: AutoDiffValue<T>): AutoDiffValue<T> =
     *     derive(const { sin(x.value) }) { z -> // record the result together with its backward block
     *         x.d += z.d * cos(x.value) // chain rule: accumulate into the derivative of x
     *     }
     * ```
     */
    public fun <R> derive(value: R, block: F.(R) -> Unit): R {
        // Grow the tape when full; capacity stays even because it starts at 8 and only doubles.
        if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
        stack[sp++] = block
        stack[sp++] = value
        return value
    }

    /**
     * Evaluates [function] in this context and runs the backward pass from its result.
     *
     * @return the value of [function] and its derivatives with respect to all bound variables.
     */
    internal fun differentiate(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DerivationResult<T> {
        val result = function()
        result.d = context.one // seed: the derivative of the result with respect to itself is one
        runBackwardPass()
        return DerivationResult(result.value, bindings.mapValues { it.value.d }, context)
    }

    // Overloads for Number constants; the constant operand does not receive a derivative.

    public override operator fun Number.plus(b: AutoDiffValue<T>): AutoDiffValue<T> =
        derive(const { this@plus.toDouble() * one + b.value }) { z ->
            b.d += z.d
        }

    public override operator fun AutoDiffValue<T>.plus(b: Number): AutoDiffValue<T> = b.plus(this)

    public override operator fun Number.minus(b: AutoDiffValue<T>): AutoDiffValue<T> =
        derive(const { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d }

    public override operator fun AutoDiffValue<T>.minus(b: Number): AutoDiffValue<T> =
        derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }

    // Basic math (+, -, *, /)

    public override fun add(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
        derive(const { a.value + b.value }) { z ->
            a.d += z.d
            b.d += z.d
        }

    public override fun multiply(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
        derive(const { a.value * b.value }) { z ->
            a.d += z.d * b.value
            b.d += z.d * a.value
        }

    public override fun divide(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
        derive(const { a.value / b.value }) { z ->
            a.d += z.d / b.value
            b.d -= z.d * a.value / (b.value * b.value)
        }

    public override fun multiply(a: AutoDiffValue<T>, k: Number): AutoDiffValue<T> =
        derive(const { k.toDouble() * a.value }) { z ->
            a.d += z.d * k.toDouble()
        }
}
/**
 * A [FirstDerivativeExpression] backed by [SimpleAutoDiffField].
 *
 * Every evaluation builds a fresh differentiation context from the supplied arguments and runs [function] in it.
 *
 * @param T the non-nullable type of values.
 * @param F the type of the field over [T].
 * @property field the field over [T].
 * @property function the expression to evaluate and differentiate, written in [SimpleAutoDiffField] context.
 */
public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
    public val field: F,
    public val function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
) : FirstDerivativeExpression<T, Expression<T>>() {
    /**
     * Evaluates [function] at [arguments] without running a backward pass.
     */
    public override operator fun invoke(arguments: Map<Symbol, T>): T =
        SimpleAutoDiffField(field, arguments).function().value

    /**
     * Returns an expression computing the derivative of [function] with respect to [symbol].
     *
     * Each invocation of the returned expression runs a full forward and backward pass at its arguments.
     */
    public override fun derivativeOrNull(symbol: Symbol): Expression<T> = Expression { arguments ->
        SimpleAutoDiffField(field, arguments).differentiate(function).derivative(symbol)
    }
}
/**
 * Generates an [AutoDiffProcessor] that wraps a function written in [SimpleAutoDiffField] context into a
 * [SimpleAutoDiffExpression] over [field].
 */
public fun <T : Any, F : Field<T>> simpleAutoDiff(
    field: F,
): AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>, Expression<T>> =
    AutoDiffProcessor { function -> SimpleAutoDiffExpression(field, function) }
// Extensions for differentiation of various basic mathematical functions
/**
 * Square `x ^ 2`; the backward pass adds `2 * x` times the result's derivative.
 */
public fun <T : Any, F : Field<T>> SimpleAutoDiffField<T, F>.sqr(x: AutoDiffValue<T>): AutoDiffValue<T> =
    derive(const { x.value * x.value }) { z -> x.d += z.d * 2 * x.value }
/**
 * Square root `x ^ 1/2`.
 *
 * The derivative `1 / (2 * sqrt(x))` is computed from the already-known result.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sqrt(x: AutoDiffValue<T>): AutoDiffValue<T> =
    derive(const { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value }
/**
 * Power with a constant exponent, `x ^ y`, with derivative `y * x ^ (y - 1)`.
 *
 * When [y] is zero the result is constant, so no derivative is propagated. This avoids the `0 * x ^ -1` term,
 * which is NaN at `x = 0` in floating-point fields.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
    x: AutoDiffValue<T>,
    y: Double,
): AutoDiffValue<T> =
    derive(const { power(x.value, y) }) { z ->
        if (y != 0.0) x.d += z.d * y * power(x.value, y - 1)
    }
/**
 * Power with a constant integer exponent `x ^ y`; delegates to the [Double]-exponent overload.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
    x: AutoDiffValue<T>,
    y: Int,
): AutoDiffValue<T> = pow(x, y.toDouble())
/**
 * Exponent `exp(x)`.
 *
 * Since `d/dx exp(x) = exp(x)`, the backward pass reuses the computed result.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.exp(x: AutoDiffValue<T>): AutoDiffValue<T> =
    derive(const { exp(x.value) }) { z -> x.d += z.d * z.value }
/**
 * Natural logarithm `ln(x)`, with derivative `1 / x`.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.ln(x: AutoDiffValue<T>): AutoDiffValue<T> =
    derive(const { ln(x.value) }) { z -> x.d += z.d / x.value }
/**
 * Power with a variable exponent `x ^ y`, computed as `exp(y * ln(x))` so that derivatives reach both [x] and [y].
 *
 * Because it goes through `ln(x)`, it is only meaningful where the field's `ln` is defined
 * (for real numbers, positive [x]).
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
    x: AutoDiffValue<T>,
    y: AutoDiffValue<T>,
): AutoDiffValue<T> =
    exp(y * ln(x))
/**
 * Sine `sin(x)`, with derivative `cos(x)`.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sin(x: AutoDiffValue<T>): AutoDiffValue<T> =
    derive(const { sin(x.value) }) { z -> x.d += z.d * cos(x.value) }
/**
 * Cosine `cos(x)`, with derivative `-sin(x)`.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.cos(x: AutoDiffValue<T>): AutoDiffValue<T> =
    derive(const { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) }
/**
 * Tangent `tan(x)`, with derivative `1 / cos(x)^2`.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.tan(x: AutoDiffValue<T>): AutoDiffValue<T> =
    derive(const { tan(x.value) }) { z ->
        val cosX = cos(x.value)
        x.d += z.d / (cosX * cosX)
    }
/**
 * Arcsine `asin(x)`, with derivative `1 / sqrt(1 - x^2)`.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.asin(x: AutoDiffValue<T>): AutoDiffValue<T> =
    derive(const { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) }
/**
 * Arccosine `acos(x)`, with derivative `-1 / sqrt(1 - x^2)`.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.acos(x: AutoDiffValue<T>): AutoDiffValue<T> =
    derive(const { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) }
/**
 * Arctangent `atan(x)`, with derivative `1 / (1 + x^2)`.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.atan(x: AutoDiffValue<T>): AutoDiffValue<T> =
    derive(const { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) }
/**
 * Hyperbolic sine `sinh(x)`, with derivative `cosh(x)`.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sinh(x: AutoDiffValue<T>): AutoDiffValue<T> =
    derive(const { sinh(x.value) }) { z -> x.d += z.d * cosh(x.value) }
/**
 * Hyperbolic cosine `cosh(x)`, with derivative `sinh(x)`.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.cosh(x: AutoDiffValue<T>): AutoDiffValue<T> =
    derive(const { cosh(x.value) }) { z -> x.d += z.d * sinh(x.value) }
/**
 * Hyperbolic tangent `tanh(x)`, with derivative `1 / cosh(x)^2`.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.tanh(x: AutoDiffValue<T>): AutoDiffValue<T> =
    derive(const { tanh(x.value) }) { z ->
        val coshX = cosh(x.value)
        x.d += z.d / (coshX * coshX)
    }
/**
 * Inverse hyperbolic sine `asinh(x)`, with derivative `1 / sqrt(1 + x^2)`.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.asinh(x: AutoDiffValue<T>): AutoDiffValue<T> =
    derive(const { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) }
/**
 * Inverse hyperbolic cosine `acosh(x)`, with derivative `1 / sqrt((x - 1) * (x + 1))`.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.acosh(x: AutoDiffValue<T>): AutoDiffValue<T> =
    derive(const { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) }
/**
 * Inverse hyperbolic tangent `atanh(x)`, with derivative `1 / (1 - x^2)`.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.atanh(x: AutoDiffValue<T>): AutoDiffValue<T> =
    derive(const { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) }
/**
 * An [ExtendedField] over [AutoDiffValue].
 *
 * It is a [SimpleAutoDiffField] that also exposes the elementary functions as members. Each member delegates to the
 * top-level extension of the same name defined for [SimpleAutoDiffField].
 *
 * @param T the non-nullable type of values.
 * @param F the type of the extended field over [T].
 * @param context the extended field over [T].
 * @param bindings the variables available in this context, mapped to their values.
 */
public class SimpleAutoDiffExtendedField<T : Any, F : ExtendedField<T>>(
    context: F,
    bindings: Map<Symbol, T>,
) : ExtendedField<AutoDiffValue<T>>, SimpleAutoDiffField<T, F>(context, bindings) {
    // Each delegation upcasts `this` to SimpleAutoDiffField so the call resolves to the top-level extension.
    // Kotlin prefers members over extensions, so calling e.g. `sqrt(arg)` directly would call this member again.

    /** Square `x ^ 2`. */
    public fun sqr(x: AutoDiffValue<T>): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).sqr(x)

    /** Square root `x ^ 1/2`. */
    public override fun sqrt(arg: AutoDiffValue<T>): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).sqrt(arg)

    /** Power with a constant exponent; [pow] is converted with [Number.toDouble]. */
    public override fun power(arg: AutoDiffValue<T>, pow: Number): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).pow(arg, pow.toDouble())

    /** Exponent `exp(x)`. */
    public override fun exp(arg: AutoDiffValue<T>): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).exp(arg)

    /** Natural logarithm `ln(x)`. */
    public override fun ln(arg: AutoDiffValue<T>): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).ln(arg)

    /** Power with a variable exponent `x ^ y`, computed as `exp(y * ln(x))`. */
    public fun pow(
        x: AutoDiffValue<T>,
        y: AutoDiffValue<T>,
    ): AutoDiffValue<T> = exp(y * ln(x))

    /** Sine `sin(x)`. */
    public override fun sin(arg: AutoDiffValue<T>): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).sin(arg)

    /** Cosine `cos(x)`. */
    public override fun cos(arg: AutoDiffValue<T>): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).cos(arg)

    /** Tangent `tan(x)`. */
    public override fun tan(arg: AutoDiffValue<T>): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).tan(arg)

    /** Arcsine `asin(x)`. */
    public override fun asin(arg: AutoDiffValue<T>): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).asin(arg)

    /** Arccosine `acos(x)`. */
    public override fun acos(arg: AutoDiffValue<T>): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).acos(arg)

    /** Arctangent `atan(x)`. */
    public override fun atan(arg: AutoDiffValue<T>): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).atan(arg)

    /** Hyperbolic sine `sinh(x)`. */
    public override fun sinh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).sinh(arg)

    /** Hyperbolic cosine `cosh(x)`. */
    public override fun cosh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).cosh(arg)

    /** Hyperbolic tangent `tanh(x)`. */
    public override fun tanh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).tanh(arg)

    /** Inverse hyperbolic sine `asinh(x)`. */
    public override fun asinh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).asinh(arg)

    /** Inverse hyperbolic cosine `acosh(x)`. */
    public override fun acosh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).acosh(arg)

    /** Inverse hyperbolic tangent `atanh(x)`. */
    public override fun atanh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).atanh(arg)
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy