
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
package org.diffkt.external

import org.diffkt.FloatTensor
import org.diffkt.Shape
import org.diffkt.StridedFloatTensor
object Dnnl : ExternalLib {
    private const val DYLIB_NAME = "libdnnlops_jni"

    private var _isLoaded = false
    override val isLoaded get() = _isLoaded

    init {
        try {
            loadLib(DYLIB_NAME)
            _isLoaded = true
        } catch (e: Exception) {
            // Native library unavailable; isLoaded stays false.
        }
    }
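    // Callers are expected to consult isLoaded before using the wrappers below.
    // A hypothetical call site (fallbackAdd is illustrative, not part of this file):
    //
    //     val sum = if (Dnnl.isLoaded) Dnnl.add(x, y) else fallbackAdd(x, y)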
    /** Elementwise addition of two equal-shape strided tensors via DNNL. */
    fun add(left: StridedFloatTensor, right: StridedFloatTensor): StridedFloatTensor {
        require(left.shape == right.shape) { "Add requires matching tensor shapes" }
        return StridedFloatTensor.contiguous(left.shape) {
            add(left.shape.dims, left.strides, right.strides, left.offset, right.offset, it, left.data, right.data)
        }
    }

    /** Elementwise subtraction of two equal-shape strided tensors via DNNL. */
    fun sub(left: StridedFloatTensor, right: StridedFloatTensor): StridedFloatTensor {
        require(left.shape == right.shape) { "Sub requires matching tensor shapes" }
        return StridedFloatTensor.contiguous(left.shape) {
            sub(left.shape.dims, left.strides, right.strides, left.offset, right.offset, it, left.data, right.data)
        }
    }
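    // Both wrappers accept arbitrarily strided/offset inputs and always produce
    // a contiguous result. A minimal usage sketch (hypothetical call site; x and
    // y stand for any two StridedFloatTensors of equal shape):
    //
    //     val sum  = Dnnl.add(x, y)  // elementwise x + y
    //     val diff = Dnnl.sub(x, y)  // elementwise x - y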
    /** Matrix multiply; the result has shape `a + b + d`. */
    fun matmul(left: StridedFloatTensor, right: StridedFloatTensor, a: Shape, b: Shape, d: Shape): StridedFloatTensor {
        val newShape = a + b + d
        val res = FloatArray(newShape.product)
        matmul(left.shape.dims, left.strides, left.offset, right.shape.dims, right.strides, right.offset, res, left.data, right.data)
        return StridedFloatTensor(newShape, res)
    }
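    // Shape bookkeeping, illustrated (an assumption about the intended batched
    // convention: `a` is the shared batch prefix, `b` the left rows, `d` the
    // right columns, with the contracted dimension implied by the operands):
    //
    //     left  : Shape(2, 3, 4)   // a = Shape(2), b = Shape(3)
    //     right : Shape(2, 4, 5)   // d = Shape(5)
    //     result: a + b + d = Shape(2, 3, 5)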
    fun mulScalar(x: FloatTensor, alpha: Float): FloatTensor {
        val xn = x.normalize()
        return StridedFloatTensor.contiguous(x.shape) {
            // Scalar multiply via the linear op: judging by its parameter names,
            // linear computes scale * input + shift, so this is alpha * x + 0.
            linear(xn.shape.dims, xn.strides, xn.offset, it, xn.data, alpha, 0f)
        }
    }
    /**
     * Convenience wrapper for DNNL batchnorm grad.
     *
     * @return Pair(input grad, scale-and-shift grad)
     */
    fun batchNormGrad(
        seed: FloatTensor,
        input: FloatTensor,
        scaleShift: FloatTensor,
        mean: FloatTensor,
        variance: FloatTensor
    ): Pair<FloatTensor, FloatTensor> {
        require(input.rank == 4 && input.shape == seed.shape) {
            "input and seed must be rank 4 and have the same shape"
        }
        val C = input.shape[3]
        require(mean.shape == Shape(C) && variance.shape == mean.shape) { "mean and variance must have Shape($C)" }
        require(scaleShift.shape == Shape(2, C)) { "scaleShift must have shape ${Shape(2, C)}" }
        val inputGrad = StridedFloatTensor.contigZeros(input.shape)
        val scaleShiftGrad = StridedFloatTensor.contigZeros(scaleShift.shape)
        batchNormGrad(
            inputGrad.shape.dims, inputGrad.data, scaleShiftGrad.data,
            seed.normalize().data, input.normalize().data, scaleShift.normalize().data, mean.normalize().data,
            variance.normalize().data
        )
        return Pair(inputGrad, scaleShiftGrad)
    }
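    // A minimal usage sketch (hypothetical shapes; NHWC layout is implied by
    // the rank-4 check with C taken from shape[3]):
    //
    //     // input, seed: Shape(n, h, w, c); scaleShift: Shape(2, c);
    //     // mean, variance: Shape(c)
    //     val (inputGrad, scaleShiftGrad) =
    //         Dnnl.batchNormGrad(seed, input, scaleShift, mean, variance)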
    // --- External functions ---

    private external fun add(
        shape: IntArray,
        lhsStrides: IntArray,
        rhsStrides: IntArray,
        lhsOffset: Int,
        rhsOffset: Int,
        result: FloatArray,
        lhs: FloatArray,
        rhs: FloatArray
    )

    external fun batchNorm(
        resultShape: IntArray,
        result: FloatArray,
        mean: FloatArray,
        variance: FloatArray,
        input: FloatArray,
        scaleShift: FloatArray
    )

    private external fun batchNormGrad(
        resultShape: IntArray,
        inputGrad: FloatArray,
        scaleShiftGrad: FloatArray,
        seed: FloatArray,
        input: FloatArray,
        scaleShift: FloatArray,
        mean: FloatArray,
        variance: FloatArray
    )

    external fun conv2d(
        resultShape: IntArray,
        result: FloatArray,
        inputShape: IntArray,
        input: FloatArray,
        filtersShape: IntArray,
        filters: FloatArray,
        hstride: Int,
        vstride: Int,
        paddingLeft: Int,
        paddingRight: Int,
        paddingTop: Int,
        paddingBottom: Int
    )
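    // Standard convolution output-size arithmetic, for reference (assumes
    // oneDNN's usual no-dilation convention; not taken from the native source):
    //
    //     outWidth  = (inWidth  + paddingLeft + paddingRight  - filterWidth)  / hstride + 1
    //     outHeight = (inHeight + paddingTop  + paddingBottom - filterHeight) / vstride + 1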
    external fun conv2dGradImage(
        resultShape: IntArray,
        result: FloatArray,
        seedShape: IntArray,
        seed: FloatArray,
        filtersShape: IntArray,
        filters: FloatArray,
        hstride: Int,
        vstride: Int,
        paddingLeft: Int,
        paddingRight: Int,
        paddingTop: Int,
        paddingBottom: Int
    )

    external fun conv2dGradFilter(
        resultShape: IntArray,
        result: FloatArray,
        seedShape: IntArray,
        seed: FloatArray,
        imagesShape: IntArray,
        images: FloatArray,
        hstride: Int,
        vstride: Int,
        paddingLeft: Int,
        paddingRight: Int,
        paddingTop: Int,
        paddingBottom: Int
    )

    external fun linear(
        shape: IntArray,
        strides: IntArray,
        offset: Int,
        res: FloatArray,
        input: FloatArray,
        scale: Float,
        shift: Float
    )

    external fun logSoftmax(
        shape: IntArray,
        input: FloatArray,
        res: FloatArray,
        axis: Int
    )

    /** Given the result of the forward op and the seed, returns the grad. */
    external fun logSoftmaxGrad(
        shape: IntArray,
        grad: FloatArray,
        seed: FloatArray,
        fwdRes: FloatArray,
        axis: Int
    )
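    // For reference, the standard log-softmax VJP (an assumption about what the
    // native kernel computes, given that fwdRes is the forward output):
    //
    //     grad = seed - exp(fwdRes) * sum(seed, axis)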
    external fun maxPool(
        resultShape: IntArray,
        result: FloatArray,
        workspace: ByteArray,
        imagesShape: IntArray,
        images: FloatArray,
        poolHeight: Int,
        poolWidth: Int
    )

    external fun maxPoolGrad(
        resultShape: IntArray,
        result: FloatArray,
        workspace: ByteArray,
        seedShape: IntArray,
        seed: FloatArray,
        poolHeight: Int,
        poolWidth: Int
    )
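    // The workspace buffer carries max-pooling state (in oneDNN, the positions
    // of the selected maxima) from the forward call to the gradient call, so the
    // same array must be threaded through both. A sketch (hypothetical shapes
    // and workspaceSize; the required size comes from the native side):
    //
    //     val ws = ByteArray(workspaceSize)
    //     maxPool(outShape.dims, out, ws, imagesShape.dims, images, 2, 2)
    //     maxPoolGrad(gradShape.dims, grad, ws, seedShape.dims, seed, 2, 2)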
    private external fun mulScalar(
        shape: IntArray,
        result: FloatArray,
        lhs: FloatArray,
        rhs: Float
    )

    external fun avgPool(
        resultShape: IntArray,
        result: FloatArray,
        imagesShape: IntArray,
        images: FloatArray,
        poolHeight: Int,
        poolWidth: Int
    )

    external fun avgPoolGrad(
        resultShape: IntArray,
        result: FloatArray,
        seedShape: IntArray,
        seed: FloatArray,
        poolHeight: Int,
        poolWidth: Int
    )

    external fun reduceSum(
        resultShape: IntArray,
        result: FloatArray,
        inputShape: IntArray,
        input: FloatArray
    )

    external fun relu(
        shape: IntArray,
        result: FloatArray,
        input: FloatArray
    )

    external fun reluGrad(
        shape: IntArray,
        result: FloatArray,
        seed: FloatArray,
        input: FloatArray
    )

    private external fun sub(
        shape: IntArray,
        lhsStrides: IntArray,
        rhsStrides: IntArray,
        lhsOffset: Int,
        rhsOffset: Int,
        result: FloatArray,
        lhs: FloatArray,
        rhs: FloatArray
    )

    external fun matmul(
        lhsShape: IntArray,
        lhsStrides: IntArray,
        lhsOffset: Int,
        rhsShape: IntArray,
        rhsStrides: IntArray,
        rhsOffset: Int,
        result: FloatArray,
        lhs: FloatArray,
        rhs: FloatArray
    )
}