All downloads are free. Search and download functionality uses the official Maven repository.

commonMain.space.kscience.kmath.nd.NDStructure.kt Maven / Gradle / Ivy

package space.kscience.kmath.nd

import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.BufferFactory
import space.kscience.kmath.structures.MutableBuffer
import space.kscience.kmath.structures.asSequence
import kotlin.jvm.JvmName
import kotlin.native.concurrent.ThreadLocal
import kotlin.reflect.KClass

/**
 * Represents n-dimensional structure, i.e. multidimensional container of items of the same type and size. The number
 * of dimensions and items in an array is defined by its shape, which is a sequence of non-negative integers that
 * specify the sizes of each dimension.
 *
 * @param T the type of items.
 */
public interface NDStructure<T> {
    /**
     * The shape of structure, i.e. non-empty sequence of non-negative integers that specify sizes of dimensions of
     * this structure.
     */
    public val shape: IntArray

    /**
     * The count of dimensions in this structure. It should be equal to size of [shape].
     */
    public val dimension: Int get() = shape.size

    /**
     * Returns the value at the specified indices.
     *
     * @param index the indices, one per dimension.
     * @return the value.
     */
    public operator fun get(index: IntArray): T

    /**
     * Returns the sequence of all the elements associated by their indices.
     *
     * @return the lazy sequence of pairs of indices to values.
     */
    public fun elements(): Sequence<Pair<IntArray, T>>

    // Force implementations to override equality and hash code (see [NDStructure.contentEquals] for the intended
    // content-based equality).
    public override fun equals(other: Any?): Boolean
    public override fun hashCode(): Int

    /**
     * Feature is additional property or hint that does not directly affect the structure, but could in some cases help
     * optimize operations and performance. If the feature is not present, null is returned.
     */
    @UnstableKMathAPI
    public fun <T : Any> getFeature(type: KClass<T>): T? = null

    public companion object {
        /**
         * Indicates whether some [NDStructure] is equal to another one: both must have the same [shape] and equal
         * elements at every index.
         */
        public fun contentEquals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean {
            if (st1 === st2) return true

            // The element-wise comparison below only visits the indices of st1. Without a shape check a structure
            // would compare equal to any larger one sharing its leading elements, and the reverse call would throw.
            if (!st1.shape.contentEquals(st2.shape)) return false

            // fast comparison of buffers if possible
            if (st1 is NDBuffer<*> && st2 is NDBuffer<*> && st1.strides == st2.strides)
                return st1.buffer.contentEquals(st2.buffer)

            // element by element comparison if it could not be avoided
            return st1.elements().all { (index, value) -> value == st2[index] }
        }

        /**
         * Creates a NDStructure with explicit buffer factory.
         *
         * Strides should be reused if possible.
         */
        public fun <T> build(
            strides: Strides,
            bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
            initializer: (IntArray) -> T,
        ): NDBuffer<T> = NDBuffer(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })

        /**
         * Inline create NDStructure with non-boxing buffer implementation if it is possible
         */
        public inline fun <reified T : Any> auto(
            strides: Strides,
            crossinline initializer: (IntArray) -> T,
        ): NDBuffer<T> = NDBuffer(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })

        /**
         * Creates a NDStructure over [strides], choosing the buffer implementation from the explicit element [type].
         */
        public inline fun <T : Any> auto(
            type: KClass<T>,
            strides: Strides,
            crossinline initializer: (IntArray) -> T,
        ): NDBuffer<T> = NDBuffer(strides, Buffer.auto(type, strides.linearSize) { i -> initializer(strides.index(i)) })

        /**
         * Creates a NDStructure of the given [shape] (using cached [DefaultStrides]) with an explicit buffer factory.
         */
        public fun <T> build(
            shape: IntArray,
            bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
            initializer: (IntArray) -> T,
        ): NDBuffer<T> = build(DefaultStrides(shape), bufferFactory, initializer)

        /**
         * Creates a NDStructure of the given [shape], preferring a non-boxing buffer implementation.
         */
        public inline fun <reified T : Any> auto(
            shape: IntArray,
            crossinline initializer: (IntArray) -> T,
        ): NDBuffer<T> = auto(DefaultStrides(shape), initializer)

        /**
         * Vararg-shape variant of [auto].
         */
        @JvmName("autoVarArg")
        public inline fun <reified T : Any> auto(
            vararg shape: Int,
            crossinline initializer: (IntArray) -> T,
        ): NDBuffer<T> =
            auto(DefaultStrides(shape), initializer)

        /**
         * Vararg-shape variant of [auto] with an explicit element [type].
         */
        public inline fun <T : Any> auto(
            type: KClass<T>,
            vararg shape: Int,
            crossinline initializer: (IntArray) -> T,
        ): NDBuffer<T> = auto(type, DefaultStrides(shape), initializer)
    }
}

/**
 * Returns the value at the specified indices.
 *
 * Vararg convenience overload of [NDStructure.get], allowing `structure[i, j]` instead of
 * `structure[intArrayOf(i, j)]`.
 *
 * @param index the indices, one per dimension.
 * @return the value.
 */
public operator fun <T> NDStructure<T>.get(vararg index: Int): T = get(index)

/**
 * Reified shorthand for [NDStructure.getFeature]: returns the feature of type [T] if the structure provides one,
 * `null` otherwise.
 */
@UnstableKMathAPI
public inline fun <reified T : Any> NDStructure<*>.getFeature(): T? = getFeature(T::class)

/**
 * Represents mutable [NDStructure].
 */
public interface MutableNDStructure<T> : NDStructure<T> {
    /**
     * Sets the item at the specified indices, replacing the previous value.
     *
     * @param index the indices, one per dimension.
     * @param value the new value.
     */
    public operator fun set(index: IntArray, value: T)
}

/**
 * Transform a structure element-by-element in place.
 *
 * @param action computes the new value of each element from its indices and its current value.
 */
public inline fun <T> MutableNDStructure<T>.mapInPlace(action: (IntArray, T) -> T): Unit =
    elements().forEach { (index, oldValue) -> this[index] = action(index, oldValue) }

/**
 * A way to convert ND index to linear one and back.
 */
public interface Strides {
    /**
     * Shape of NDStructure
     */
    public val shape: IntArray

    /**
     * Array strides
     */
    public val strides: List<Int>

    /**
     * Get linear index (offset) from multidimensional index
     */
    public fun offset(index: IntArray): Int

    /**
     * Get multidimensional index from linear index (offset)
     */
    public fun index(offset: Int): IntArray

    /**
     * The size of linear buffer to accommodate all elements of ND-structure corresponding to strides
     */
    public val linearSize: Int

    // TODO introduce a fast way to calculate index of the next element?

    /**
     * Iterate over ND indices in a natural order, i.e. in order of increasing linear offset from 0 until
     * [linearSize].
     */
    public fun indices(): Sequence<IntArray> = (0 until linearSize).asSequence().map {
        index(it)
    }
}

/**
 * Simple implementation of [Strides].
 *
 * The first index varies fastest in linear order: `strides[i]` is the product of the sizes of all dimensions before
 * `i`, and the last entry of [strides] is the total number of elements ([linearSize]).
 */
public class DefaultStrides private constructor(override val shape: IntArray) : Strides {
    override val linearSize: Int
        get() = strides[shape.size]

    /**
     * Strides for memory access: cumulative products of [shape] starting from 1 (`shape.size + 1` entries).
     */
    override val strides: List<Int> by lazy {
        sequence {
            var current = 1
            yield(1)

            shape.forEach {
                current *= it
                yield(current)
            }
        }.toList()
    }

    /**
     * Converts a multidimensional [index] to a linear offset.
     *
     * @throws IndexOutOfBoundsException if [index] does not have exactly one entry per dimension or any entry lies
     * outside its dimension.
     */
    override fun offset(index: IntArray): Int {
        // Reject indices of the wrong rank: a shorter index would otherwise silently address only a prefix of the
        // dimensions, and a longer one would fail with an unrelated array access error on `shape[i]`.
        if (index.size != shape.size) throw IndexOutOfBoundsException(
            "Index ${index.contentToString()} does not match shape ${shape.contentToString()}"
        )
        return index.mapIndexed { i, value ->
            if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})")
            value * strides[i]
        }.sum()
    }

    /**
     * Converts a linear [offset] back to a multidimensional index.
     *
     * @throws IndexOutOfBoundsException if [offset] is not in `0 until linearSize`; previously such offsets silently
     * produced negative or out-of-shape indices.
     */
    override fun index(offset: Int): IntArray {
        if (offset < 0 || offset >= linearSize) throw IndexOutOfBoundsException("Offset $offset out of bounds: [0, $linearSize)")
        val res = IntArray(shape.size)
        var current = offset
        // Walk from the slowest-varying (last) dimension down to the first one.
        var strideIndex = strides.size - 2

        while (strideIndex >= 0) {
            res[strideIndex] = (current / strides[strideIndex])
            current %= strides[strideIndex]
            strideIndex--
        }

        return res
    }

    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (other !is DefaultStrides) return false
        if (!shape.contentEquals(other.shape)) return false
        return true
    }

    override fun hashCode(): Int = shape.contentHashCode()

    @ThreadLocal
    public companion object {
        // Keyed by the shape *contents*. IntArray uses identity equality, so keying by the array itself missed the
        // cache for every equal shape passed as a different array and added one entry per distinct array instance.
        // NOTE(review): @ThreadLocal comes from kotlin.native.concurrent; on other platforms this map appears to be a
        // plain shared HashMap — confirm thread-safety expectations for JVM callers.
        private val defaultStridesCache = HashMap<List<Int>, Strides>()

        /**
         * Cached builder for default strides
         */
        public operator fun invoke(shape: IntArray): Strides =
            // Copy the shape so later mutation of the caller's array cannot corrupt a cached, shared instance.
            defaultStridesCache.getOrPut(shape.toList()) { DefaultStrides(shape.copyOf()) }
    }
}

/**
 * Represents [NDStructure] over [Buffer].
 *
 * @param T the type of items.
 * @param strides The strides to access elements of [Buffer] by linear indices.
 * @param buffer The underlying buffer; its size must equal `strides.linearSize`.
 */
public open class NDBuffer<T>(
    public val strides: Strides,
    buffer: Buffer<T>,
) : NDStructure<T> {

    init {
        if (strides.linearSize != buffer.size) {
            error("Expected buffer size of ${strides.linearSize}, but found ${buffer.size}")
        }
    }

    public open val buffer: Buffer<T> = buffer

    override operator fun get(index: IntArray): T = buffer[strides.offset(index)]

    override val shape: IntArray get() = strides.shape

    override fun elements(): Sequence<Pair<IntArray, T>> = strides.indices().map {
        it to this[it]
    }

    /**
     * Content equality with any [NDStructure]: same shape and equal elements at every index.
     */
    override fun equals(other: Any?): Boolean {
        if (other !is NDStructure<*>) return false
        // NDStructure.contentEquals walks only this structure's indices, so reject a shape mismatch up front to keep
        // equality symmetric.
        if (!shape.contentEquals(other.shape)) return false
        return NDStructure.contentEquals(this, other)
    }

    /**
     * Hash consistent with [equals]: it depends only on the shape and the elements, not on the concrete [Strides] or
     * [Buffer] implementation, so equal structures always produce equal hashes.
     */
    override fun hashCode(): Int {
        var result = shape.contentHashCode()
        // Visit elements in a canonical index order so that strides with different linear layouts hash alike.
        DefaultStrides(shape).indices().forEach { index ->
            result = 31 * result + get(index).hashCode()
        }
        return result
    }

    override fun toString(): String {
        val bufferRepr: String = when (shape.size) {
            1 -> buffer.asSequence().joinToString(prefix = "[", postfix = "]", separator = ", ")
            2 -> (0 until shape[0]).joinToString(prefix = "[", postfix = "]", separator = ", ") { i ->
                (0 until shape[1]).joinToString(prefix = "[", postfix = "]", separator = ", ") { j ->
                    val offset = strides.offset(intArrayOf(i, j))
                    buffer[offset].toString()
                }
            }
            // Higher-dimensional buffers are not expanded.
            else -> "..."
        }
        return "NDBuffer(shape=${shape.contentToString()}, buffer=$bufferRepr)"
    }
}

/**
 * Transform structure to a new structure using provided [BufferFactory] and optimizing if argument is [NDBuffer]
 *
 * When the receiver is an [NDBuffer], its strides are reused and the underlying buffer is mapped linearly; otherwise
 * the result uses [DefaultStrides] for the receiver's shape and reads each element through [NDStructure.get].
 *
 * @param factory creates the result buffer.
 * @param transform maps each element.
 */
public inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(
    factory: BufferFactory<R> = Buffer.Companion::auto,
    crossinline transform: (T) -> R,
): NDBuffer<R> {
    return if (this is NDBuffer<T>)
        NDBuffer(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) })
    else {
        val strides = DefaultStrides(shape)
        NDBuffer(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
    }
}

/**
 * Mutable ND buffer based on linear [MutableBuffer].
 *
 * @param strides The strides to access elements of [buffer] by linear indices.
 * @param buffer The underlying mutable buffer; its size is validated by the [NDBuffer] constructor.
 */
public class MutableNDBuffer<T>(
    strides: Strides,
    buffer: MutableBuffer<T>,
) : NDBuffer<T>(strides, buffer), MutableNDStructure<T> {

    // No size check here: the NDBuffer constructor runs first and already rejects a mismatched buffer, so a second
    // check in this class could never fire.

    // Keep the statically typed constructor argument instead of down-casting super.buffer.
    override val buffer: MutableBuffer<T> = buffer

    override operator fun set(index: IntArray, value: T): Unit = buffer.set(strides.offset(index), value)
}

/**
 * Combines this structure with [struct] element by element into a new structure of the same shape.
 *
 * @param struct the other operand; must have the same shape as the receiver.
 * @param block computes each result element from the receiver's and [struct]'s elements at the same index.
 * @throws IllegalArgumentException if the shapes differ.
 */
public inline fun <reified T : Any> NDStructure<T>.combine(
    struct: NDStructure<T>,
    crossinline block: (T, T) -> T,
): NDStructure<T> {
    require(shape.contentEquals(struct.shape)) { "Shape mismatch in structure combination" }
    return NDStructure.auto(shape) { block(this[it], struct[it]) }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy