
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
package org.diffkt
import org.diffkt.Convolve.PaddingStyle
import org.diffkt.model.BatchNormResult
import org.diffkt.model.baseBatchNorm
import org.diffkt.random.RandomKey
import org.diffkt.tracing.TracingTensorOperations
import shapeTyping.annotations.AllowUnreduced
import shapeTyping.annotations.SType
@AllowUnreduced
interface Operations {
val name: String
/** For scalar operations, returns the corresponding tensor operations. */
val tensor: Operations get() = this
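// This hook lets a scalar operand be promoted when it meets a tensor operand:
// commonKind (below) calls `operations.tensor` on the scalar side, while
// tensor-valued backends simply return `this` through the default above.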
@SType("S: Shape")
fun plus(left: @SType("S") DTensor, right: @SType("S") DTensor, derivativeId: DerivativeID): @SType("S") DTensor
@SType("S: Shape")
fun minus(left: @SType("S") DTensor, right: @SType("S") DTensor, derivativeId: DerivativeID): @SType("S") DTensor
@SType("S: Shape")
fun times(left: @SType("S") DTensor, right: @SType("S") DTensor, derivativeId: DerivativeID): @SType("S") DTensor
@SType("S: Shape")
fun timesScalar(left: DScalar, right: @SType("S") DTensor, derivativeId: DerivativeID): @SType("S") DTensor
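// The default implementation of div below rewrites division as multiplication by
// the elementwise reciprocal, so a backend only needs `times` and `pow`; a backend
// may override it with a dedicated division kernel.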
@SType("S: Shape")
fun div(left: @SType("S") DTensor, right: @SType("S") DTensor, derivativeId: DerivativeID): @SType("S") DTensor =
left * right.pow(-1)
@SType("S: Shape")
fun zeroOfSameKind(x: DTensor, shape: @SType("S") Shape): @SType("S") DTensor
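// identityGradientOfSameKind builds the generalized identity: for halfShape S, the
// result has shape concat(S, S) and holds d x[i] / d x[j], which is 1 where the
// multi-indices i and j coincide and 0 elsewhere.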
@SType("S: Shape")
fun identityGradientOfSameKind(x: DTensor, halfShape: @SType("S") Shape): @SType("concat(S,S)") DTensor
@SType("S: Shape")
fun unaryMinus(x: @SType("S") DTensor): @SType("S") DTensor
fun matmul(x: DTensor, y: DTensor, a: Shape, b: Shape, c: Shape, d: Shape, derivativeId: DerivativeID): DTensor
@SType("S1: Shape, S2: Shape")
fun outerProduct(x: @SType("S1") DTensor, y: @SType("S2") DTensor, derivativeId: DerivativeID): @SType("concat(S1, S2)") DTensor
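// Shapewise, the outer product above maps inputs of shapes S1 and S2 to a result of
// shape concat(S1, S2), with result[i, j] = x[i] * y[j] for multi-indices i and j.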
@SType("S: Shape") fun sin(x: @SType("S") DTensor): DTensor
@SType("S: Shape") fun cos(x: @SType("S") DTensor): DTensor
@SType("S: Shape") fun tan(x: @SType("S") DTensor): DTensor
@SType("S: Shape") fun atan(x: @SType("S") DTensor): DTensor
@SType("S: Shape") fun exp(x: @SType("S") DTensor): DTensor
@SType("S: Shape") fun ln(x: @SType("S") DTensor): DTensor
@SType("S: Shape") fun lgamma(x: @SType("S") DTensor): DTensor
@SType("S: Shape") fun digamma(x: @SType("S") DTensor): DTensor
@SType("S: Shape") fun polygamma(n: Int, x: @SType("S") DTensor): DTensor
@SType("S: Shape") fun sqrt(x: @SType("S") DTensor): DTensor
@SType("S: Shape") fun tanh(x: @SType("S") DTensor): DTensor
fun meld(values: List<DTensor>, derivativeId: DerivativeID): DTensor
fun split(x: DTensor, shapes: List<Shape>): List<DTensor>
@SType("S1: Shape, S2: Shape, A: Dim")
fun concat(
left: @SType("S1") DTensor,
right: @SType("S2") DTensor,
axis: @SType("A") Int,
derivativeId: DerivativeID
): @SType("concatOnAxis(S1, S2, A)") DTensor
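// concatOnAxis(S1, S2, A) above: S1 and S2 must agree on every axis except A, and
// the result's extent on axis A is S1[A] + S2[A].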
fun concat(slices: List<DTensor>, axis: Int, derivativeId: DerivativeID): DTensor
@SType("S: Shape")
fun broadcastTo(x: DTensor, newShape: @SType("S") Shape): @SType("S") DTensor
/**
* Applies a convolution to the tensor [signal] using [filter]. Both signal and filter must be of rank 4.
* The expected shape of signal is NHWC (batch size, height, width, channels)
* and the expected filter shape is OHWI (output channels, height, width, input channels), where C == I.
*
* The padding applied to the signal can be
* [PaddingStyle.Valid], [PaddingStyle.Same], [PaddingStyle.Full], or [PaddingStyle.Explicit].
*/
fun convImpl(
signal: DTensor,
filter: DTensor,
hStride: Int,
vStride: Int,
padding: Convolve.Padding2D,
derivativeId: DerivativeID
): DTensor
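// Illustrative sketch, not part of this interface: convolving four 28x28 RGB images
// with eight 3x3 filters at stride 1, following the NHWC/OHWI contract documented
// above. The tensor factory and padding construction here are assumptions.
//
//   val signal: DTensor = FloatTensor.zeros(Shape(4, 28, 28, 3))  // NHWC
//   val filter: DTensor = FloatTensor.zeros(Shape(8, 3, 3, 3))    // OHWI, I == C == 3
//   val padding: Convolve.Padding2D = TODO("backend-specific; see PaddingStyle.Same")
//   val out = signal.operations.convImpl(signal, filter, 1, 1, padding, NoDerivativeID)
//   // With Same padding, out has shape (4, 28, 28, 8).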
@SType("S: Shape")
fun expand(x: DTensor, newShape: @SType("S") Shape): @SType("S") DTensor
@SType("S: Shape")
fun flip(x: @SType("S") DTensor, axes: IntArray): @SType("S") DTensor
fun logSoftmax(x: DTensor, axis: Int): DTensor = baseLogSoftmax(x, axis)
fun logSoftmaxGrad(x: DTensor, axis: Int, logSoftmax: DTensor, upstream: DTensor): DTensor
@SType("S: Shape")
fun pow(base: @SType("S") DTensor, exponent: Float): @SType("S") DTensor
fun view1(x: DTensor, indices: IntArray): DTensor
fun view2(x: DTensor, index: Int, axis: Int): DTensor
fun view3(x: DTensor, index: IntRange, axis: Int): DTensor
fun reshape(x: DTensor, newShape: Shape): DTensor
fun reshapeToScalar(x: DTensor): DScalar
fun squeeze(x: DTensor, axis: Int): DTensor
fun unsqueeze(x: DTensor, axis: Int): DTensor
fun transpose(x: DTensor, axes: IntArray): DTensor
@SType("S: Shape") fun relu(x: @SType("S") DTensor): @SType("S") DTensor
@SType("S: Shape") fun sigmoid(x: @SType("S") DTensor): @SType("S") DTensor
fun reluGrad(x: DTensor, reluUpstream: DTensor, derivativeId: DerivativeID): DTensor
fun sum(x: DTensor, axes: IntArray, keepDims: Boolean): DTensor
fun avgPool(x: DTensor, poolHeight: Int, poolWidth: Int): DTensor
fun avgPoolGrad(x: DTensor, poolHeight: Int, poolWidth: Int): DTensor
fun batchNorm(input: DTensor, scaleShift: DTensor, derivativeId: DerivativeID): BatchNormResult =
baseBatchNorm(input, scaleShift)
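// The default above delegates to the underived baseBatchNorm and ignores
// derivativeId; backends that propagate derivatives are expected to override it.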
fun maxPoolWithIndices(
x: DTensor,
poolHeight: Int,
poolWidth: Int,
withIndices: Boolean): Pair<DTensor, List<IntArray>?>
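// When withIndices is true, the second component records where each pooled maximum
// was taken from so the gradient can be routed back to it; otherwise it is null.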
fun gather(x: DTensor, indices: List<Int>, axis: Int, paddingIndex: Int): DTensor
fun gatherAtIndices(x: DTensor, indices: List<IntArray>): DTensor
fun scatter(x: DTensor, indices: List<Int>, axis: Int, newShape: Shape, paddingIndex: Int): DTensor
fun scatterAtIndices(x: DTensor, indices: List<IntArray>, newShape: Shape): DTensor
fun gamma(alpha: DTensor, randomKey: RandomKey): DTensor
@SType("S: Shape")
fun compare(left: @SType("S") DTensor, right: @SType("S") DTensor, comparison: ComparisonKind): @SType("S") DTensor
@SType("S: Shape")
fun ifThenElse(condition: @SType("S") DTensor, whenTrue: @SType("S") DTensor, whenFalse: @SType("S") DTensor, derivativeId: DerivativeID): @SType("S") DTensor
}
/**
* For a binary operation, find the common operations kind.
*/
internal fun commonKind(l: DTensor, r: DTensor): Pair<Operations, DerivativeID> {
val ls = l.derivativeID.sequence
val rs = r.derivativeID.sequence
if (ls > rs) {
val operations = if (l.isScalar && !r.isScalar) l.operations.tensor else l.operations
return Pair(operations, l.derivativeID)
}
if (rs > ls || rs != 0) {
val operations = if (r.isScalar && !l.isScalar) r.operations.tensor else r.operations
return Pair(operations, r.derivativeID)
}
assert(rs == 0 && ls == 0)
if (l.operations == TracingTensorOperations || r.operations == TracingTensorOperations)
return Pair(TracingTensorOperations, NoDerivativeID)
if (l.operations == SparseFloatTensorOperations || r.operations == SparseFloatTensorOperations)
return Pair(SparseFloatTensorOperations, NoDerivativeID)
if (l.operations == SparseRowFloatTensorOperations || r.operations == SparseRowFloatTensorOperations)
return Pair(SparseRowFloatTensorOperations, NoDerivativeID)
if (l.operations == StridedFloatTensorOperations || r.operations == StridedFloatTensorOperations)
return Pair(StridedFloatTensorOperations, NoDerivativeID)
if (l.operations == FloatScalarOperations && r.isScalar || r.operations == FloatScalarOperations && l.isScalar)
return Pair(FloatScalarOperations, NoDerivativeID)
if (l.operations == r.operations) // Sparse or GPU
return Pair(l.operations, NoDerivativeID)
throw Error("Incompatible tensor kinds: ${l.operations.name} and ${r.operations.name}")
}
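// Illustrative behavior sketch (the concrete values are assumptions, not asserted here):
//   val t = FloatTensor.zeros(Shape(2, 3))  // strided float tensor, derivativeID sequence 0
//   val s = FloatScalar(1f)                 // float scalar, derivativeID sequence 0
//   commonKind(t, s)  // == Pair(StridedFloatTensorOperations, NoDerivativeID)
//   // Had t been wrapped for forward differentiation, its higher derivativeID.sequence
//   // would win and commonKind would return that wrapper's operations and DerivativeID.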