All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.diffkt.FloatTensorOperations.kt Maven / Gradle / Ivy

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

package org.diffkt

import org.diffkt.external.*
import org.diffkt.model.BatchNormResult
import org.diffkt.model.baseBatchNorm
import org.diffkt.random.RandomKey
import org.diffkt.random.Sha512Random
import kotlin.math.ceil
import kotlin.math.pow
import shapeTyping.annotations.AllowUnreduced
import shapeTyping.annotations.SType

/**
 * Operations that can be shared among multiple float tensor implementations.
 */
internal abstract class FloatTensorOperations : Operations {
    @SType("S: Shape")
    private fun wrap(value: @SType("S") DTensor): @SType("S") FloatTensor {
        if (value is FloatTensor) return value
        TODO("Cannot (automatically) convert to FloatTensor")
    }

    @SType("S: Shape")
    override fun plus(
        left: @SType("S") DTensor,
        right: @SType("S") DTensor,
        derivativeId: DerivativeID
    ): @SType("S") DTensor {
        require(derivativeId == NoDerivativeID)
        val l = wrap(left)
        val r = wrap(right)
        return l.zip(r) { x, y -> x + y }
    }

    @SType("S: Shape")
    override fun minus(
        left: @SType("S") DTensor,
        right: @SType("S") DTensor,
        derivativeId: DerivativeID
    ): @SType("S") DTensor {
        require(derivativeId == NoDerivativeID)
        val l = wrap(left)
        val r = wrap(right)
        return l.zip(r) { x, y -> x - y }
    }

    @SType("S: Shape")
    override fun times(
        left: @SType("S") DTensor,
        right: @SType("S") DTensor,
        derivativeId: DerivativeID
    ): @SType("S") DTensor {
        require(derivativeId == NoDerivativeID)
        return wrap(left).zip(wrap(right)) { l, r -> l * r }
    }

    @SType("S: Shape")
    override fun timesScalar(left: DScalar, right: @SType("S") DTensor, derivativeId: DerivativeID): @SType("S") DTensor {
        require(derivativeId == NoDerivativeID)
        val r = wrap(right)
        require(left is FloatScalar)
        val leftValue = left.value
        return r.map { x -> leftValue * x }
    }

    @SType("S: Shape")
    override fun div(
        left: @SType("S") DTensor,
        right: @SType("S") DTensor,
        derivativeId: DerivativeID
    ): @SType("S") DTensor {
        require(left is FloatTensor)
        require(right is FloatTensor)
        return left.zip(right) { xx, yy -> xx / yy }
    }

    @SType("S: Shape")
    override fun zeroOfSameKind(x: DTensor, shape: @SType("S") Shape): @SType("S") FloatTensor {
        return FloatTensor.zeros(shape)
    }

    @SType("S: Shape")
    @AllowUnreduced
    override fun identityGradientOfSameKind(x: DTensor, halfShape: @SType("S") Shape): @SType("concat(S,S)") FloatTensor {
        return StridedFloatTensor.identityGradient(halfShape)
    }

    @SType("S: Shape")
    override fun unaryMinus(x: @SType("S") DTensor): @SType("S") DTensor {
        require(x is FloatTensor)
        return x.map { -it }
    }

    override fun matmul(
        x: DTensor,
        y: DTensor,
        a: Shape,
        b: Shape,
        c: Shape,
        d: Shape,
        derivativeId: DerivativeID
    ): DTensor {
        val left = (x as FloatTensor).asStrided()
        val right = (y as FloatTensor).asStrided()
        require(c.rank != 0)
        if (a.rank < 12 && b.rank == 1 && c.rank == 1 && d.rank == 1)
            return Dnnl.matmul(left, right, a, b, d)

        fun processC(axis: Int, leftOffset: Int, rightOffset: Int): Float {
            assert(axis < c.rank)
            val leftStride = left.strides[a.rank + b.rank + axis]
            val rightStride = right.strides[a.rank + axis]
            var leftOff = leftOffset
            var rightOff = rightOffset
            var sum = 0F
            if (axis == c.rank - 1) {
                // This is the inner loop.  Can Dnnl help?
                for (i in 0 until c[axis]) {
                    sum += left.data[leftOff] * right.data[rightOff]
                    leftOff += leftStride
                    rightOff += rightStride
                }
            } else {
                for (i in 0 until c[axis]) {
                    sum += processC(axis + 1, leftOff, rightOff)
                    leftOff += leftStride
                    rightOff += rightStride
                }
            }

            return sum
        }

        val resultShape = a + b + d
        val data = FloatArray(resultShape.product)
        var next = 0

        fun processD(axis: Int, leftOffset: Int, rightOffset: Int) {
            assert(axis <= d.rank)
            if (d.rank == 0) {
                data[next++] = processC(0, leftOffset, rightOffset)
                return
            }
            val rightStride = right.strides[a.rank + c.rank + axis]
            var rightOff = rightOffset
            if (axis == d.rank - 1) {
                for (i in 0 until d[axis]) {
                    data[next++] = processC(0, leftOffset, rightOff)
                    rightOff += rightStride
                }
            } else {
                for (i in 0 until d[axis]) {
                    processD(axis + 1, leftOffset, rightOff)
                    rightOff += rightStride
                }
            }
        }

        fun processB(axis: Int, leftOffset: Int, rightOffset: Int) {
            if (axis >= b.rank) {
                processD(0, leftOffset, rightOffset)
                return
            }

            val leftStride = left.strides[a.rank + axis]
            var leftOff = leftOffset
            for (i in 0 until b[axis]) {
                processB(axis + 1, leftOff, rightOffset)
                leftOff += leftStride
            }
        }

        fun processA(axis: Int, leftOffset: Int, rightOffset: Int) {
            if (axis >= a.rank) {
                processB(0, leftOffset, rightOffset)
                return
            }

            val leftStride = left.strides[axis]
            val rightStride = right.strides[axis]
            var leftOff = leftOffset
            var rightOff = rightOffset
            for (i in 0 until a[axis]) {
                processA(axis + 1, leftOff, rightOff)
                leftOff += leftStride
                rightOff += rightStride
            }
        }

        processA(0, left.offset, right.offset)
        assert(next == data.size)
        return FloatTensor(resultShape, data)
    }

    @SType("S1: Shape, S2: Shape")
    @AllowUnreduced
    override fun outerProduct(
        x: @SType("S1") DTensor,
        y: @SType("S2") DTensor,
        derivativeId: DerivativeID
    ): @SType("concat(S1, S2)") DTensor {
        val left = (x as @SType("S1") FloatTensor).asStrided()
        val right = (y as @SType("S2") FloatTensor).asStrided()
        val resultData = FloatArray(left.size * right.size)
        var k = 0
        for (i in 0 until left.size) {
            val l = left.at(i)
            for (j in 0 until right.size) {
                val r = right.at(j)
                resultData[k++] = l * r
            }
        }
        return FloatTensor(left.shape + right.shape, resultData)
    }

    @SType("S: Shape")
    override fun sin(x: @SType("S") DTensor): @SType("S") DTensor {
        require(x is FloatTensor)
        return x.map { kotlin.math.sin(it) }
    }

    @SType("S: Shape")
    override fun cos(x: @SType("S") DTensor): @SType("S") DTensor {
        require(x is FloatTensor)
        return x.map { kotlin.math.cos(it) }
    }

    @SType("S: Shape")
    override fun tan(x: @SType("S") DTensor): @SType("S") DTensor {
        require(x is FloatTensor)
        return x.map { kotlin.math.tan(it) }
    }

    @SType("S: Shape")
    override fun atan(x: @SType("S") DTensor): @SType("S") DTensor {
        require(x is FloatTensor)
        return x.map { kotlin.math.atan(it) }
    }

    @SType("S: Shape")
    override fun exp(x: @SType("S") DTensor): @SType("S") DTensor {
        require(x is FloatTensor)
        return x.map { kotlin.math.exp(it) }
    }

    @SType("S: Shape")
    override fun ln(x: @SType("S") DTensor): @SType("S") DTensor {
        require(x is FloatTensor)
        return x.map { kotlin.math.ln(it) }
    }

    @SType("S: Shape")
    override fun lgamma(x: @SType("S") DTensor): @SType("S") DTensor {
        require(x is FloatTensor)
        return x.map { Math.lgamma(it) }
    }

    @SType("S: Shape")
    override fun digamma(x: @SType("S") DTensor): @SType("S") DTensor {
        require(x is FloatTensor)
        return x.map { Math.digamma(it) }
    }

    @SType("S: Shape")
    override fun polygamma(n: Int, x: @SType("S") DTensor): @SType("S") DTensor {
        require(x is FloatTensor)
        return x.map { Math.polygamma(n, it) }
    }

    @SType("S: Shape")
    override fun sqrt(x: @SType("S") DTensor): @SType("S") DTensor {
        require(x is FloatTensor)
        return x.map { kotlin.math.sqrt(it) }
    }

    @SType("S: Shape")
    override fun tanh(x: @SType("S") DTensor): @SType("S") DTensor {
        require(x is FloatTensor)
        return x.map { kotlin.math.tanh(it) }
    }

    override fun meld(values: List, derivativeId: DerivativeID): DTensor {
        require(derivativeId == NoDerivativeID)
        // Compute the sizes of the values
        val totalSize = values.map { it.shape.product() }.sum()
        val serializedData = FloatArray(totalSize)
        var i = 0
        for (v in values) {
            when (v) {
                is FloatScalar ->
                    serializedData[i++] = v.value
                is FloatTensor -> {
                    for (pos in 0 until v.size) serializedData[i++] = v.at(pos)
                }
                else -> throw IllegalArgumentException("meld not supported for ${v::class.qualifiedName}")
            }
        }
        assert(i == totalSize)
        return FloatTensor(Shape(totalSize), serializedData)
    }

    override fun split(x: DTensor, shapes: List): List {
        require(x is FloatTensor)
        val sizes = shapes.map { it.product() }
        var nextStart = 0
        return List(shapes.size) {
            val shape = shapes[it]
            val size = sizes[it]
            val partData = FloatArray(size) { i -> x.at(i + nextStart) }
            val part = FloatTensor(shape, partData)
            nextStart += size
            part
        }
    }

    @AllowUnreduced
    @SType("S1: Shape, S2: Shape, A: Dim")
    override fun concat(
        left: @SType("S1")  DTensor,
        right: @SType("S2") DTensor,
        axis: @SType("A") Int,
        derivativeId: DerivativeID
    ): @SType("concatOnAxis(S1, S2, A)") DTensor {
        require(derivativeId == NoDerivativeID)
        val l = (left as FloatTensor).asStrided()
        val r = (right as FloatTensor).asStrided()
        return concat(listOf(l, r), axis, derivativeId)
    }

    override fun concat(slices: List, axis: Int, derivativeId: DerivativeID): DTensor {
        /** Copy a value to a section of a destination array */
        fun fillDim(
            value: Float,
            dest: FloatArray,
            destShape: Shape,
            destStart: Int,
            len: Int,
            axis: Int,
            skipZero: Boolean
        ) {
            if (skipZero && value == 0f) return

            val nframes = destShape.take(axis).product()
            val cellSize = destShape.drop(axis + 1).product()
            val destFrameSize = destShape[axis]

            for (i in 0 until nframes) {
                val dstOffset = ((i * destFrameSize) + destStart) * cellSize
                for (cellOffset in 0 until cellSize * len) {
                    dest[dstOffset + cellOffset] = value
                }
            }
        }

        /**
         * copy some portion of a tensor data array along a particular dimension
         * to a destination array.
         * NOTE: assumes src Tensor has Natural (contiguous, full) layout
         */
        fun copyDim(
            src: FloatArray,
            srcShape: Shape,
            srcStart: Int,
            dest: FloatArray,
            destShape: Shape,
            destStart: Int,
            len: Int,
            axis: Int
        ): FloatArray {
            val nframes = destShape.take(axis).product()
            val cellSize = destShape.drop(axis + 1).product()
            val srcFrameSize = srcShape[axis]
            val destFrameSize = destShape[axis]

            var i = 0
            while (i < nframes) {
                val srcOff = ((i * srcFrameSize) + srcStart) * cellSize
                val dstOff = ((i * destFrameSize) + destStart) * cellSize
                System.arraycopy(src, srcOff, dest, dstOff, cellSize * len)
                i += 1
            }
            return dest
        }

        /** helper, calls [copyDim] or [fillDim] as needed to copy source tensor data to dest
         * as specified.
         * Note: may copy src tensor to a contiguous layout version, in which case that tensor
         * is returned, otherwise src itself.
         *
         * @param skipZero doesn't copy a singleton zero over to the dst array.
         *   This should be used as an optimization when the dst array was already
         *   initialized to zero.
         */
        fun copyDimFromTensor(
            src: StridedFloatTensor,
            srcStart: Int,
            dest: FloatArray,
            destShape: Shape,
            destStart: Int,
            len: Int,
            axis: Int,
            skipZero: Boolean = false
        ): FloatTensor {
            return if (src.layout == StridedUtils.Layout.SINGLETON) {
                fillDim(src.data[src.offset], dest, destShape, destStart, len, axis, skipZero)
                src
            } else {
                val srcContig = src.normalize()
                copyDim(srcContig.data, srcContig.shape, srcStart, dest, destShape, destStart, len, axis)
                srcContig
            }
        }

        require(derivativeId == NoDerivativeID)
        require(slices.size > 0)
        val stridedSlices = slices.map { (it as FloatTensor).asStrided() }
        val first = stridedSlices[0]
        require(axis >= 0 && axis < first.rank)
        if (stridedSlices.size == 1)
            return first
        val s = first.shape.updated(axis, 1)
        require(stridedSlices.all { first.rank == it.rank && s == it.shape.updated(axis, 1) })

        val newDim = stridedSlices.map { it.shape[axis] }.sum()
        val newShape = s.updated(axis, newDim)
        val newData = FloatArray(newShape.product)
        var next = 0
        for (slice in stridedSlices) {
            val amount = slice.shape[axis]
            copyDimFromTensor(slice, 0, newData, newShape, next, amount, axis, skipZero = true)
            next += amount
        }
        return FloatTensor(newShape, newData)
    }

    @SType("S: Shape")
    override fun broadcastTo(x: DTensor, newShape: @SType("S") Shape): @SType("S") DTensor {
        require(x is FloatTensor)
        val s = x.asStrided()
        val (newStrides, _) = Broadcasting.getBroadcastedStridesAndAxes(s.shape, s.strides, newShape)
        return StridedFloatTensor(newShape, offset = s.offset, newStrides, s.data)
    }

    override fun convImpl(
        signal: DTensor,
        filter: DTensor,
        hStride: Int,
        vStride: Int,
        padding: Convolve.Padding2D,
        derivativeId: DerivativeID
    ): DTensor {
        require(signal is FloatTensor)
        require(filter is FloatTensor)
        val sN = signal.normalize()
        val fN = filter.normalize()

        require(shouldSendToCpp(0, Dnnl, sN, fN))
        val signalShape = sN.shape
        val filterShape = fN.shape

        val imageChannels = signalShape[3]
        val filterChannels = filterShape[3]

        if (imageChannels != filterChannels)
            throw RuntimeException("the size of the filter's inChannel ($filterChannels) must match the input depth ($imageChannels)")
        if (hStride < 1 || vStride < 1)
            throw RuntimeException("Horizontal stride ($hStride) and vertical stride ($vStride) must be greater than 0.")

        // Make out shape
        val numsignal = signalShape[Convolve.N_AXIS]
        val numfilter = filterShape[Convolve.N_AXIS]

        val endRow = signalShape[Convolve.H_AXIS] + padding.bottom - filterShape[Convolve.H_AXIS]
        val endCol = signalShape[Convolve.W_AXIS] + padding.right - filterShape[Convolve.W_AXIS]

        val outHeight = ceil((endRow + padding.top + 1).toFloat() / vStride).toInt()
        val outWidth = ceil((endCol + padding.left + 1).toFloat() / hStride).toInt()

        val outShape = Shape(numsignal, outHeight, outWidth, numfilter)
        // End make out shape

        return StridedFloatTensor.contiguous(outShape) {
            Dnnl.conv2d(
                // output
                outShape.dims,
                it,
                // imgs
                signalShape.dims,
                sN.data,
                // filter
                filterShape.dims,
                fN.data,
                // strides
                vStride, // height
                hStride, // width
                // padding
                padding.left,
                padding.right,
                padding.top,
                padding.bottom
            )
        }
    }

    /**
     * Helper function to do the stride calculations for expand.
     *
     * @return strides for the return value of expand
     */
    private fun expandStrides(x: StridedFloatTensor, shape: Shape): IntArray {
        require(x.rank == shape.rank)
        return IntArray(x.rank) { i ->
            if (x.shape[i] != shape[i]) 0 else x.strides[i]
        }
    }
    override fun expand(x: DTensor, newShape: Shape): DTensor {
        require(x is FloatTensor)
        val s = x.asStrided()
        val newStrides = expandStrides(s, newShape)
        return StridedFloatTensor(newShape, s.offset, newStrides, s.data)
    }

    /**
     * A copy-free implementation of flip for strided float tensors.
     */
    private fun flip(x: StridedFloatTensor, axis: Int): StridedFloatTensor {
        require(axis >= 0 && axis < x.rank)
        if (x.shape[axis] == 1 || x.strides[axis] == 0) return x
        val newOffset = x.offset + (x.shape[axis] - 1) * x.strides[axis]
        val newStrides = x.strides.clone(); newStrides[axis] = -x.strides[axis]
        return StridedFloatTensor(x.shape, newOffset, newStrides, x.data)
    }

    @SType("S: Shape")
    override fun flip(x: @SType("S") DTensor, axes: IntArray): @SType("S") DTensor {
        require(x is FloatTensor)
        return axes.fold(x.asStrided()) { a, i -> flip(a, i) }
    }

    override fun transpose(x: DTensor, axes: IntArray): DTensor {
        require(x is FloatTensor)
        val n = x.asStrided()
        return n.operations.transpose(n, axes)
    }

    override fun logSoftmax(x: DTensor, axis: Int): DTensor {
        require(x is FloatTensor)
        val normalized = x.normalize()
        val resData = FloatArray(normalized.size)
        Dnnl.logSoftmax(normalized.shape.dims, normalized.data, resData, axis)
        return FloatTensor(normalized.shape, resData)
    }

    override fun logSoftmaxGrad(x: DTensor, axis: Int, logSoftmax: DTensor, upstream: DTensor): DTensor {
        require(upstream.shape == x.shape) {
            "LogSoftmax does not support derivatives of functions that do not return scalars"
        }
        val normalUpstream = wrap(upstream).normalize()
        val normalLogSoftmax = wrap(logSoftmax).normalize()
        val gradData = FloatArray(x.size)
        Dnnl.logSoftmaxGrad(x.shape.dims, gradData, normalUpstream.data, normalLogSoftmax.data, axis)
        return FloatTensor(x.shape, gradData)
    }

    @SType("S: Shape")
    override fun pow(base: @SType("S") DTensor, exponent: Float): @SType("S") DTensor {
        require(base is FloatTensor)
        return base.map { it.pow(exponent) }
    }

    override fun view1(x: DTensor, indices: IntArray): DTensor {
        require(x is FloatTensor)
        return StridedFloatTensorOperations.view1(x.asStrided(), indices)
    }

    override fun view2(x: DTensor, index: Int, axis: Int): DTensor {
        require(x is FloatTensor)
        return StridedFloatTensorOperations.view2(x.asStrided(), index, axis)
    }

    override fun view3(x: DTensor, index: IntRange, axis: Int): DTensor {
        require(x is FloatTensor)
        return StridedFloatTensorOperations.view3(x.asStrided(), index, axis)
    }

    override fun reshape(x: DTensor, newShape: Shape): DTensor {
        require(x is FloatTensor)
        val n = x.normalize()
        return n.operations.reshape(n, newShape)
    }

    override fun reshapeToScalar(x: DTensor): DScalar {
        require(x is FloatTensor)
        return FloatScalar(x.at(0))
    }

    override fun squeeze(x: DTensor, axis: Int): DTensor {
        require(x is FloatTensor)
        return StridedFloatTensorOperations.squeeze(x.asStrided(), axis)
    }

    override fun unsqueeze(x: DTensor, axis: Int): DTensor {
        require(x is FloatTensor)
        return StridedFloatTensorOperations.unsqueeze(x.asStrided(), axis)
    }

    @SType("S: Shape")
    override fun relu(x: @SType("S") DTensor): @SType("S") DTensor {
        require(x is FloatTensor)
        return x.map { v -> if (v <= 0F) 0f else v }
    }

    override fun reluGrad(x: DTensor, reluUpstream: DTensor, derivativeId: DerivativeID): DTensor {
        val xx = x.expandAndBroadcastToTangent(reluUpstream)
        require(xx is FloatTensor)
        require(reluUpstream is FloatTensor)
        return xx.zip(reluUpstream) { x1, up1 -> if (x1 <= 0F) 0f else up1 }
    }

    @SType("S: Shape")
    override fun sigmoid(x: @SType("S") DTensor): @SType("S") DTensor {
        require(x is FloatTensor)
        return x.map { sigmoidElem(it) }
    }

    override fun sum(x: DTensor, axes: IntArray, keepDims: Boolean): DTensor {
        require(x is FloatTensor)
        return x.reduce(Float::plus, axes, keepDims)
    }

    override fun avgPool(x: DTensor, poolHeight: Int, poolWidth: Int): DTensor {
        require(x is FloatTensor)
        val numItems = x.shape[Convolve.N_AXIS]
        val inHeight = x.shape[Convolve.H_AXIS]
        val inWidth = x.shape[Convolve.W_AXIS]
        val channels = x.shape.drop(Convolve.C_AXIS)
        val numChannels = channels.product

        require(inHeight % poolHeight == 0) {
            "input height ($inHeight) must be divisible by pool height ($poolHeight)" }
        require(inWidth % poolWidth == 0) {
            "input width ($inWidth) must be divisible by pool width ($poolWidth)" }

        val outHeight = inHeight / poolHeight
        val outWidth = inWidth / poolWidth

        val outShape = Shape(numItems, outHeight, outWidth) + channels
        val outStream = FloatArray(outShape.product)
        Dnnl.avgPool(
            // result
            intArrayOf(numItems, outHeight, outWidth, numChannels),
            outStream,
            // input
            intArrayOf(numItems, inHeight, inWidth, numChannels),
            x.normalize().data,
            // pool height and width
            poolHeight,
            poolWidth
        )

        return FloatTensor(outShape, outStream)
    }

    override fun avgPoolGrad(x: DTensor, poolHeight: Int, poolWidth: Int): DTensor {
        require(x is FloatTensor)
        val numItems = x.shape[Convolve.N_AXIS]
        val inHeight = x.shape[Convolve.H_AXIS]
        val inWidth = x.shape[Convolve.W_AXIS]
        val channels = x.shape.drop(Convolve.C_AXIS)
        val numChannels = channels.product
        val outHeight = inHeight * poolHeight
        val outWidth = inWidth * poolWidth

        val outShape = Shape(numItems, outHeight, outWidth) + channels
        val outStream = FloatArray(outShape.product)
        Dnnl.avgPoolGrad(
            // result
            intArrayOf(numItems, outHeight, outWidth, numChannels),
            outStream,
            // seed
            intArrayOf(numItems, inHeight, inWidth, numChannels),
            x.normalize().data,
            // pool height and width
            poolHeight,
            poolWidth
        )
        return FloatTensor(outShape, outStream)
    }

    override fun batchNorm(input: DTensor, scaleShift: DTensor, derivativeId: DerivativeID): BatchNormResult {
        return baseBatchNorm(input, scaleShift)
    }

    override fun maxPoolWithIndices(
        x: DTensor,
        poolHeight: Int,
        poolWidth: Int,
        withIndices: Boolean
    ): Pair?> {
        require(x is FloatTensor)
        val N_AXIS = 0
        val H_AXIS = 1
        val W_AXIS = 2
        val C_AXIS = 3

        val numItems = x.shape[N_AXIS]
        val inHeight = x.shape[H_AXIS]
        val inWidth = x.shape[W_AXIS]
        val numChannels = x.shape[C_AXIS]
        require(inHeight % poolHeight == 0) {
            "input height ($inHeight) must be divisible by pool height ($poolHeight)" }
        require(inWidth % poolWidth == 0) {
            "input width ($inWidth) must be divisible by pool width ($poolWidth)" }
        val outHeight = inHeight / poolHeight
        val outWidth = inWidth / poolWidth
        val outShape = Shape(numItems, outHeight, outWidth, numChannels)

        val values = FloatArray(outShape.product)
        var nextValuePos = 0
        val positions: ArrayList? = if (withIndices) ArrayList() else null

        val indices = IntArray(4)
        for (item in 0 until numItems) {
            indices[N_AXIS] = item
            for (h0 in 0 until inHeight step poolHeight) {
                for (w0 in 0 until inWidth step poolWidth) {
                    for (channel in 0 until numChannels) {
                        indices[C_AXIS] = channel
                        var maxH = 0
                        var maxW = 0
                        var maxValue = Float.NEGATIVE_INFINITY
                        for (h1 in 0 until poolHeight) {
                            for (w1 in 0 until poolHeight) {
                                val h = h0 + h1
                                val w = w0 + w1
                                indices[H_AXIS] = h
                                indices[W_AXIS] = w
                                val value = x.getAt(indices)
                                if (value > maxValue) {
                                    maxH = h
                                    maxW = w
                                    maxValue = value
                                }
                            }
                        }
                        indices[H_AXIS] = maxH
                        indices[W_AXIS] = maxW
                        values[nextValuePos++] = maxValue
                        positions?.add(indices.clone())
                    }
                }
            }
        }

        return Pair(FloatTensor(outShape, values), positions)
    }

    override fun gather(x: DTensor, indices: List, axis: Int, paddingIndex: Int): DTensor {
        require(x is FloatTensor)

        val rank = x.rank
        require(rank > 0) { "gather: tensor must be of rank > 0" }
        require(axis >= 0) { "gather: axis $axis < 0" }
        require(axis < rank) { "gather: axis $axis >= tensor rank $rank" }
        require(indices.all { it < x.shape[axis] })

        val newShape = x.shape.updated(axis, indices.size)
        val newData = FloatArray(newShape.product)

        val numOfIterations = x.shape.take(axis).product
        val elementsPerIteration = x.shape.drop(axis + 1).product
        val stride = x.shape.drop(axis).product

        var idx = 0
        for (n in 0 until numOfIterations) {
            for (i in indices) {
                if (i == paddingIndex) {
                    idx += elementsPerIteration
                } else {
                    val startingPoint = (n * stride) + (i * elementsPerIteration)
                    for (e in 0 until elementsPerIteration) {
                        newData[idx] = x.at(startingPoint + e)
                        idx++
                    }
                }
            }
        }
        return FloatTensor(newShape, newData)
    }

    override fun gatherAtIndices(x: DTensor, indices: List): DTensor {
        require(x is FloatTensor)
        require(indices.all { it.size == indices[0].size })
        // TODO: https://github.com/facebookincubator/diffkt/issues/89 Dnnl could probably help us here
        val shapePerIndex = x.shape.drop(indices[0].size)
        val elementsPerIndex = shapePerIndex.product
        val totalSize = indices.size * elementsPerIndex
        val newData = FloatArray(totalSize)
        val contigStrides = StridedUtils.contigStrides(x.shape)
        for (i in indices.indices) {
            val index = indices[i]
            val destPos = i * elementsPerIndex
            val srcPos = index.foldIndexed(0, { elementIndex, accum, elem -> accum + elem*contigStrides[elementIndex] })
            for (j in 0 until elementsPerIndex) {
                newData[destPos + j] = x.at(srcPos + j)
            }
        }
        val newShape = shapePerIndex.prepend(indices.size)
        assert(newShape.product == newData.size)
        return FloatTensor(newShape, newData)
    }

    override fun scatter(x: DTensor, indices: List, axis: Int, newShape: Shape, paddingIndex: Int): DTensor {
        require(x is FloatTensor)

        val rank = x.rank
        require(rank > 0) { "scatter: tensor must be of rank > 0" }
        require(axis >= 0) { "scatter: axis $axis < 0" }
        require(axis < rank) { "scatter: axis $axis >= tensor rank $rank" }
        require(indices.all { it < newShape[axis] })

        if (axis == 0 && newShape.rank == 2) {
            return scatterSparseRow(x, indices, newShape, paddingIndex)
        } else {
            return scatterDense(x, indices, axis, newShape, paddingIndex)
        }

    }

    private fun scatterDense(x: FloatTensor, indices: List, axis: Int, newShape: Shape, paddingIndex: Int): FloatTensor {
        val newData = FloatArray(newShape.product)

        val numOfIterations = x.shape.take(axis).product
        val elementsPerIteration = x.shape.drop(axis + 1).product
        val xStride = x.shape.drop(axis).product
        val newDataStride = newShape.drop(axis).product

        for (n in 0 until numOfIterations) {
            val xStartingPoint = n * xStride
            val newDataStartingPoint = n * newDataStride
            indices.forEachIndexed { i, index ->
                if (index == paddingIndex) return@forEachIndexed
                for (e in 0 until elementsPerIteration) {
                    val newIndex = newDataStartingPoint + (index * elementsPerIteration) + e
                    newData[newIndex] += x.at(xStartingPoint+ (i * elementsPerIteration) + e)
                }
            }
        }
        return FloatTensor(newShape, newData)
    }

    private fun scatterSparseRow(x: FloatTensor, indices: List, newShape: Shape, paddingIndex: Int): FloatTensor {
        val sortedIndices = indices.withIndex().sortedBy { it.value }
        var numUnique = 0
        var prev = -1
        for (i in sortedIndices) {
            if (i.value != prev && i.value != paddingIndex) {
                numUnique++
                prev = i.value
            }
        }
        if (numUnique == newShape[0]) return scatterDense(x, indices, 0, newShape, paddingIndex)
        val tableWidth = x.shape.last
        require(tableWidth == newShape[1]) { "scatterSparseRow: different number of columns not yet supported"}
        val dataWidth = newShape[1] + 1
        val newData = FloatArray(dataWidth * numUnique)

        prev = -1
        var newRowIndex = 0
        sortedIndices.forEach { i ->
            if (i.value == paddingIndex) return@forEach
            if (prev != -1 && i.value != prev) {
                newRowIndex++
            }
            val offset = newRowIndex * dataWidth
            newData[offset] = i.value.toFloat()

            for (j in 1 until dataWidth) {
                newData[offset + j] += x.at(tableWidth * i.index + j - 1)
            }
            prev = i.value
        }
        return SparseRowFloatTensor(newShape, newData)
    }

    override fun scatterAtIndices(x: DTensor, indices: List, newShape: Shape): DTensor {
        require(x is FloatTensor)
        // TODO: https://github.com/facebookincubator/diffkt/issues/89 Dnnl could probably help us here
        val shapePerIndex = newShape.drop(indices[0].size)
        val elementsPerIndex = shapePerIndex.product
        assert(x.size == indices.size * elementsPerIndex)
        val newData = FloatArray(newShape.product)
        val contigStrides = StridedUtils.contigStrides(newShape)
        for (i in indices.indices) {
            val index = indices[i]
            val srcPos = i * elementsPerIndex
            val destPos = index.foldIndexed(0, { elementIndex, accum, elem -> accum + elem*contigStrides[elementIndex] })
            for (j in 0 until elementsPerIndex) {
                newData[destPos + j] = x.at(srcPos + j)
            }
        }
        return FloatTensor(newShape, newData)
    }

    override fun gamma(alpha: DTensor, randomKey: RandomKey): DTensor {
        return (randomKey as Sha512Random).gamma(alpha as FloatTensor)
    }

    @SType("S: Shape")
    override fun compare(
        left: @SType("S") DTensor,
        right: @SType("S") DTensor,
        comparison: ComparisonKind
    ): @SType("S") DTensor {
        require(left is FloatTensor)
        require(right is FloatTensor)
        return left.zip(right) { l, r -> if (compare(l, r, comparison)) 1f else 0f }
    }

    @SType("S: Shape")
    override fun ifThenElse(
        condition: @SType("S") DTensor,
        whenTrue: @SType("S") DTensor,
        whenFalse: @SType("S") DTensor,
        derivativeId: DerivativeID
    ): @SType("S") DTensor {
        require(condition is FloatTensor)
        require(whenTrue is FloatTensor)
        require(whenFalse is FloatTensor)
        return condition.zip2(whenTrue, whenFalse) { a, b, c -> if (a > 0f) b else c }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy