All Downloads are FREE. Search and download functionalities are using the official Maven repository.

commonMain.space.kscience.kmath.expressions.SimpleAutoDiff.kt Maven / Gradle / Ivy

package space.kscience.kmath.expressions

import space.kscience.kmath.linear.Point
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.asBuffer
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract

/*
 * Implementation of backward-mode automatic differentiation.
 * Initial gist by Roman Elizarov: https://gist.github.com/elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d
 */


/**
 * A value taking part in a [SimpleAutoDiffField] computation.
 *
 * This base class does not override `equals`/`hashCode`, so instances are compared by identity. The field relies on
 * that: it keeps a separate derivative for every intermediate value, even when two values hold equal numbers.
 *
 * @param T the type of the wrapped value.
 * @property value the value computed during the forward pass.
 */
public open class AutoDiffValue<out T>(public val value: T)

/**
 * Represents result of [simpleAutoDiff] call.
 *
 * @param T the non-nullable type of value.
 * @property value the value of the differentiated function at the binding point.
 * @param derivativeValues the mapping from variable identities ([Symbol.identity]) to the derivatives of the
 * result with respect to those variables.
 * @property context The field over [T].
 */
public class DerivationResult<T : Any>(
    public val value: T,
    private val derivativeValues: Map<String, T>,
    public val context: Field<T>,
) {
    /**
     * Returns the derivative of the result with respect to [variable], or [Ring.zero] in [context] when
     * [variable] was not among the bindings of the differentiation call.
     */
    public fun derivative(variable: Symbol): T = derivativeValues[variable.identity] ?: context.zero

    /**
     * Computes the divergence, i.e. the sum (in [context]) of the partial derivatives of the result with respect
     * to every bound variable.
     */
    public fun div(): T = context { sum(derivativeValues.values) }
}

/**
 * Computes the gradient for variables in given order.
 *
 * @param variables the variables to differentiate with respect to; the i-th element of the returned [Point] is the
 * derivative with respect to `variables[i]`. Variables that were not bound yield the zero of
 * [DerivationResult.context].
 * @return the gradient as a [Point] of the same length as [variables].
 * @throws IllegalStateException if [variables] is empty.
 */
public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T> {
    check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" }
    return variables.map(::derivative).asBuffer()
}

/**
 * Runs differentiation and establishes [SimpleAutoDiffField] context inside the block of code.
 *
 * [body] is evaluated exactly once with every entry of [bindings] available as a bound variable (forward pass).
 * The recorded operations are then replayed in reverse (backward pass) to accumulate the derivatives of the
 * returned value with respect to each bound variable. Read the results through [DerivationResult.value] and
 * [DerivationResult.derivative].
 *
 * Example:
 * ```
 * val x by symbol // define variable(s)
 * val y = RealField.simpleAutoDiff(x to 2.0) {
 *     val xv = checkNotNull(bindSymbolOrNull(x)) // the differentiable value bound to x
 *     sqr(xv) + 5 * xv + 3 // write the formula in the autodiff context
 * }
 * assertEquals(17.0, y.value) // the value of the result
 * assertEquals(9.0, y.derivative(x)) // dy/dx = 2 * x + 5
 * ```
 *
 * @param bindings the values of the variables at which the function and its derivatives are evaluated.
 * @param body the action in [SimpleAutoDiffField] context returning the [AutoDiffValue] to differentiate.
 * @return the result of differentiation.
 */
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
    bindings: Map<Symbol, T>,
    body: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
): DerivationResult<T> {
    contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }

    return SimpleAutoDiffField(this, bindings).differentiate(body)
}

/**
 * Convenience overload of [simpleAutoDiff] that takes the variable bindings as `symbol to value` pairs.
 *
 * The pairs are collected with [toMap], so if a symbol appears more than once the last pair wins.
 */
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
    vararg bindings: Pair<Symbol, T>,
    body: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
): DerivationResult<T> = simpleAutoDiff(bindings.toMap(), body)

/**
 * Represents field in context of which functions can be derived.
 *
 * Implements reverse-mode (backward) automatic differentiation. Every operation computes its result immediately
 * (forward pass) and registers, through [derive], a block that later propagates the derivative of the final result
 * back to its operands. [differentiate] replays those blocks in reverse registration order (backward pass).
 *
 * Derivatives accumulated during the backward pass are never reset, so an instance should be differentiated only
 * once.
 *
 * @param T the type of the values.
 * @param F the type of the field over [T].
 * @property context the field over [T] in which values and derivatives are computed.
 * @param bindings the values of the variables; each one is exposed through [bindSymbolOrNull].
 */
@OptIn(UnstableKMathAPI::class)
public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
    public val context: F,
    bindings: Map<Symbol, T>,
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>>, RingWithNumbers<AutoDiffValue<T>> {
    public override val zero: AutoDiffValue<T>
        get() = const(context.zero)

    public override val one: AutoDiffValue<T>
        get() = const(context.one)

    // This stack contains interleaved pairs: a backward block followed by the value it is applied to.
    private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
    private var sp: Int = 0

    // Derivatives of intermediate (non-variable) values; bound variables keep theirs in their own `d` field.
    private val derivatives: MutableMap<AutoDiffValue<T>, T> = hashMapOf()

    private val bindings: Map<String, AutoDiffVariableWithDerivative<T>> = bindings.entries.associate {
        it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, context.zero)
    }

    /**
     * Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result
     * with respect to this variable.
     *
     * Equality and hashing are by [identity], so it can be compared with any [Symbol] of the same name.
     *
     * @param T the non-nullable type of value.
     * @property value The value of this variable.
     * @property d The derivative accumulated for this variable so far.
     */
    private class AutoDiffVariableWithDerivative<T : Any>(
        override val identity: String,
        value: T,
        var d: T,
    ) : AutoDiffValue<T>(value), Symbol {
        override fun toString(): String = identity
        override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
        override fun hashCode(): Int = identity.hashCode()
    }

    public override fun bindSymbolOrNull(symbol: Symbol): AutoDiffValue<T>? = bindings[symbol.identity]

    private fun getDerivative(variable: AutoDiffValue<T>): T =
        (variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: context.zero

    private fun setDerivative(variable: AutoDiffValue<T>, value: T) {
        if (variable is AutoDiffVariableWithDerivative) variable.d = value else derivatives[variable] = value
    }

    /**
     * Pops (value, block) pairs off the stack and applies each block to its value in [context].
     * The value was pushed last, so it is popped first.
     */
    @Suppress("UNCHECKED_CAST")
    private fun runBackwardPass() {
        while (sp > 0) {
            val value = stack[--sp]
            val block = stack[--sp] as F.(Any?) -> Unit
            context.block(value)
        }
    }

    override fun const(value: T): AutoDiffValue<T> = AutoDiffValue(value)

    /**
     * A variable accessing inner state of derivatives.
     * Use this value in inner builders to avoid creating additional derivative bindings.
     *
     * Reading returns the derivative accumulated so far, or the zero of [context] if none has been recorded.
     */
    public var AutoDiffValue<T>.d: T
        get() = getDerivative(this)
        set(value) = setDerivative(this, value)

    /**
     * Creates a constant (non-variable) value by evaluating [block] in [context].
     */
    public inline fun const(block: F.() -> T): AutoDiffValue<T> = const(context.block())

    /**
     * Performs update of derivative after the rest of the formula in the back-pass.
     *
     * Blocks run in reverse order of registration. Any operation that consumes [value] is registered after it, so
     * when [block] runs, `z.d` already holds the complete derivative of the result with respect to `z`.
     *
     * For example, implementation of `sin` function is:
     *
     * ```
     * fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sin(x: AutoDiffValue<T>): AutoDiffValue<T> =
     *     derive(const { sin(x.value) }) { z -> // call derive with function result
     *         x.d += z.d * cos(x.value) // update derivative using chain rule and derivative of the function
     *     }
     * ```
     *
     * @return [value] unchanged.
     */
    @Suppress("UNCHECKED_CAST")
    public fun <R> derive(value: R, block: F.(R) -> Unit): R {
        // Save the block and its value on the stack for the backward pass; grow the stack when it is full.
        // sp and stack.size are both always even, so a non-full stack always has room for the whole pair.
        if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
        stack[sp++] = block
        stack[sp++] = value
        return value
    }

    /**
     * Evaluates [function] in this context and runs the backward pass, seeding the derivative of its result with
     * the one of [context].
     *
     * @return the value of the result together with its derivatives with respect to every bound variable.
     */
    internal fun differentiate(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DerivationResult<T> {
        val result = function()
        result.d = context.one // computing derivative w.r.t result
        runBackwardPass()
        return DerivationResult(result.value, bindings.mapValues { it.value.d }, context)
    }

    // Overloads for Double constants

    public override operator fun Number.plus(b: AutoDiffValue<T>): AutoDiffValue<T> =
        derive(const { this@plus.toDouble() * one + b.value }) { z ->
            b.d += z.d
        }

    public override operator fun AutoDiffValue<T>.plus(b: Number): AutoDiffValue<T> = b.plus(this)

    public override operator fun Number.minus(b: AutoDiffValue<T>): AutoDiffValue<T> =
        derive(const { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d }

    public override operator fun AutoDiffValue<T>.minus(b: Number): AutoDiffValue<T> =
        derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }

    // Basic math (+, -, *, /)

    public override fun add(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
        derive(const { a.value + b.value }) { z ->
            a.d += z.d
            b.d += z.d
        }

    public override fun multiply(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
        derive(const { a.value * b.value }) { z ->
            a.d += z.d * b.value
            b.d += z.d * a.value
        }

    public override fun divide(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
        derive(const { a.value / b.value }) { z ->
            a.d += z.d / b.value
            // d(a / b) / db = -a / b^2
            b.d -= z.d * a.value / (b.value * b.value)
        }

    public override fun multiply(a: AutoDiffValue<T>, k: Number): AutoDiffValue<T> =
        derive(const { k.toDouble() * a.value }) { z ->
            a.d += z.d * k.toDouble()
        }
}

/**
 * An expression that computes its value, and its first derivatives on demand, with a fresh [SimpleAutoDiffField]
 * on every evaluation.
 *
 * @property field the field over [T] used for the computation.
 * @property function the function to evaluate, written in [SimpleAutoDiffField] context.
 */
public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
    public val field: F,
    public val function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
) : FirstDerivativeExpression<T, Expression<T>>() {
    /**
     * Evaluates [function] at [arguments] (forward pass only; no derivatives are computed).
     */
    public override operator fun invoke(arguments: Map<Symbol, T>): T =
        SimpleAutoDiffField(field, arguments).function().value

    /**
     * Returns an expression computing the derivative of [function] with respect to [symbol].
     *
     * Never returns `null`: if [symbol] is not among the arguments, the expression evaluates to zero.
     */
    public override fun derivativeOrNull(symbol: Symbol): Expression<T> = Expression { arguments ->
        val derivationResult = SimpleAutoDiffField(field, arguments).differentiate(function)
        derivationResult.derivative(symbol)
    }
}

/**
 * Generate [AutoDiffProcessor] for [SimpleAutoDiffExpression].
 *
 * @param field the field over [T] used by every produced expression.
 */
public fun <T : Any, F : Field<T>> simpleAutoDiff(
    field: F,
): AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>, Expression<T>> =
    AutoDiffProcessor { function ->
        SimpleAutoDiffExpression(field, function)
    }

// Extensions for differentiation of various basic mathematical functions

/**
 * Square of [x]: `x ^ 2`.
 *
 * Backward pass: `dx += dz * 2 * x`.
 */
public fun <T : Any, F : Field<T>> SimpleAutoDiffField<T, F>.sqr(x: AutoDiffValue<T>): AutoDiffValue<T> =
    derive(const { x.value * x.value }) { z -> x.d += z.d * 2 * x.value }

/**
 * Square root of [x]: `x ^ 1/2`.
 *
 * Backward pass: `dx += dz * 0.5 / sqrt(x)`, reusing the already computed result `z`.
 * The derivative is mathematically undefined at `x = 0`.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sqrt(x: AutoDiffValue<T>): AutoDiffValue<T> =
    derive(const { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value }

/**
 * Raises [x] to a constant power [y]: `x ^ y`.
 *
 * Backward pass: `dx += dz * y * x ^ (y - 1)`.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
    x: AutoDiffValue<T>,
    y: Double,
): AutoDiffValue<T> =
    derive(const { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) }

/**
 * Raises [x] to a constant integer power [y]; delegates to the [Double]-exponent overload.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
    x: AutoDiffValue<T>,
    y: Int,
): AutoDiffValue<T> = pow(x, y.toDouble())

/**
 * Exponential of [x]: `exp(x)`.
 *
 * Backward pass: `dx += dz * exp(x)`, reusing the already computed result `z`.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.exp(x: AutoDiffValue<T>): AutoDiffValue<T> =
    derive(const { exp(x.value) }) { z -> x.d += z.d * z.value }

/**
 * Natural logarithm of [x]: `ln(x)`.
 *
 * Backward pass: `dx += dz / x`.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.ln(x: AutoDiffValue<T>): AutoDiffValue<T> =
    derive(const { ln(x.value) }) { z -> x.d += z.d / x.value }

/**
 * Raises [x] to a differentiable power [y]: `x ^ y`, computed as `exp(y * ln(x))`.
 *
 * Derivatives therefore flow into both [x] and [y]. Because the computation goes through `ln(x)`, it is only
 * meaningful where `ln(x)` is defined (positive [x] for real fields).
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
    x: AutoDiffValue<T>,
    y: AutoDiffValue<T>,
): AutoDiffValue<T> =
    exp(y * ln(x))

/**
 * Sine of [x].
 *
 * Backward pass: `dx += dz * cos(x)`.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sin(x: AutoDiffValue<T>): AutoDiffValue<T> =
    derive(const { sin(x.value) }) { z -> x.d += z.d * cos(x.value) }

/**
 * Cosine of [x].
 *
 * Backward pass: `dx -= dz * sin(x)`.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.cos(x: AutoDiffValue<T>): AutoDiffValue<T> =
    derive(const { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) }

/**
 * Tangent of [x].
 *
 * Backward pass: `dx += dz / cos(x) ^ 2`.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.tan(x: AutoDiffValue<T>): AutoDiffValue<T> =
    derive(const { tan(x.value) }) { z ->
        val c = cos(x.value)
        x.d += z.d / (c * c)
    }

/**
 * Arcsine of [x].
 *
 * Backward pass: `dx += dz / sqrt(1 - x ^ 2)`.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.asin(x: AutoDiffValue<T>): AutoDiffValue<T> =
    derive(const { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) }

/**
 * Arccosine of [x].
 *
 * Backward pass: `dx -= dz / sqrt(1 - x ^ 2)`.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.acos(x: AutoDiffValue<T>): AutoDiffValue<T> =
    derive(const { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) }

/**
 * Arctangent of [x].
 *
 * Backward pass: `dx += dz / (1 + x ^ 2)`.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.atan(x: AutoDiffValue<T>): AutoDiffValue<T> =
    derive(const { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) }

/**
 * Hyperbolic sine of [x].
 *
 * Backward pass: `dx += dz * cosh(x)`.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sinh(x: AutoDiffValue<T>): AutoDiffValue<T> =
    derive(const { sinh(x.value) }) { z -> x.d += z.d * cosh(x.value) }

/**
 * Hyperbolic cosine of [x].
 *
 * Backward pass: `dx += dz * sinh(x)`.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.cosh(x: AutoDiffValue<T>): AutoDiffValue<T> =
    derive(const { cosh(x.value) }) { z -> x.d += z.d * sinh(x.value) }

/**
 * Hyperbolic tangent of [x].
 *
 * Backward pass: `dx += dz / cosh(x) ^ 2`.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.tanh(x: AutoDiffValue<T>): AutoDiffValue<T> =
    derive(const { tanh(x.value) }) { z ->
        val c = cosh(x.value)
        x.d += z.d / (c * c)
    }

/**
 * Inverse hyperbolic sine of [x].
 *
 * Backward pass: `dx += dz / sqrt(1 + x ^ 2)`.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.asinh(x: AutoDiffValue<T>): AutoDiffValue<T> =
    derive(const { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) }

/**
 * Inverse hyperbolic cosine of [x].
 *
 * Backward pass: `dx += dz / sqrt(x ^ 2 - 1)`, with `x ^ 2 - 1` evaluated as `(x - 1) * (x + 1)`.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.acosh(x: AutoDiffValue<T>): AutoDiffValue<T> =
    derive(const { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) }

/**
 * Inverse hyperbolic tangent of [x].
 *
 * Backward pass: `dx += dz / (1 - x ^ 2)`.
 */
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.atanh(x: AutoDiffValue<T>): AutoDiffValue<T> =
    derive(const { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) }

/**
 * A [SimpleAutoDiffField] over an [ExtendedField] that itself implements [ExtendedField] for [AutoDiffValue]s.
 *
 * Each member delegates to the top-level extension of the same name on [SimpleAutoDiffField]. The upcast
 * `this as SimpleAutoDiffField<T, F>` is what selects that extension. Without it, the call would resolve to the
 * member itself (members take precedence over extensions) and recurse forever.
 *
 * @param context the extended field over [T].
 * @param bindings the values of the variables of the differentiated function.
 */
public class SimpleAutoDiffExtendedField<T : Any, F : ExtendedField<T>>(
    context: F,
    bindings: Map<Symbol, T>,
) : ExtendedField<AutoDiffValue<T>>, SimpleAutoDiffField<T, F>(context, bindings) {
    // x ^ 2
    public fun sqr(x: AutoDiffValue<T>): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).sqr(x)

    // x ^ 1/2
    public override fun sqrt(arg: AutoDiffValue<T>): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).sqrt(arg)

    // x ^ y (const)
    public override fun power(arg: AutoDiffValue<T>, pow: Number): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).pow(arg, pow.toDouble())

    // exp(x)
    public override fun exp(arg: AutoDiffValue<T>): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).exp(arg)

    // ln(x)
    public override fun ln(arg: AutoDiffValue<T>): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).ln(arg)

    // x ^ y (any); computed via ln, so only meaningful where ln(x) is defined
    public fun pow(
        x: AutoDiffValue<T>,
        y: AutoDiffValue<T>,
    ): AutoDiffValue<T> = exp(y * ln(x))

    // sin(x)
    public override fun sin(arg: AutoDiffValue<T>): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).sin(arg)

    // cos(x)
    public override fun cos(arg: AutoDiffValue<T>): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).cos(arg)

    public override fun tan(arg: AutoDiffValue<T>): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).tan(arg)

    public override fun asin(arg: AutoDiffValue<T>): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).asin(arg)

    public override fun acos(arg: AutoDiffValue<T>): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).acos(arg)

    public override fun atan(arg: AutoDiffValue<T>): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).atan(arg)

    public override fun sinh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).sinh(arg)

    public override fun cosh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).cosh(arg)

    public override fun tanh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).tanh(arg)

    public override fun asinh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).asinh(arg)

    public override fun acosh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).acosh(arg)

    public override fun atanh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
        (this as SimpleAutoDiffField<T, F>).atanh(arg)
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy