All Downloads are FREE. Search and download functionalities are using the official Maven repository.

commonMain.space.kscience.kmath.nd.NDAlgebra.kt Maven / Gradle / Ivy

package space.kscience.kmath.nd

import space.kscience.kmath.operations.Field
import space.kscience.kmath.operations.Ring
import space.kscience.kmath.operations.Space
import space.kscience.kmath.structures.*

/**
 * Thrown when the actual shape of an ND-structure differs from the shape an operation expected.
 *
 * The message lists both shapes, e.g. `Shape [3, 2] doesn't fit in expected shape [2, 3].`
 *
 * @property expected the expected shape.
 * @property actual the actual shape.
 */
public class ShapeMismatchException(public val expected: IntArray, public val actual: IntArray) :
    RuntimeException("Shape ${actual.contentToString()} doesn't fit in expected shape ${expected.contentToString()}.")

/**
 * The base interface for all ND-algebra implementations.
 *
 * An algebra is bound to a single [shape]. Element-level operations are performed with
 * [elementContext] as the lambda receiver.
 *
 * @param T the type of ND-structure element.
 * @param C the type of the element context.
 */
public interface NDAlgebra<T, C> {
    /**
     * The shape of ND-structures this algebra operates on.
     */
    public val shape: IntArray

    /**
     * The algebra over elements of ND structure.
     */
    public val elementContext: C

    /**
     * Produces a new [NDStructure] using the given [initializer].
     *
     * The initializer has the element algebra [C] as its receiver and receives the index of the element being
     * produced.
     */
    public fun produce(initializer: C.(IntArray) -> T): NDStructure<T>

    /**
     * Maps elements from one structure to another one by applying [transform] to them.
     */
    public fun NDStructure<T>.map(transform: C.(T) -> T): NDStructure<T>

    /**
     * Maps elements from one structure to another one by applying [transform] to them alongside with their indices.
     */
    public fun NDStructure<T>.mapIndexed(transform: C.(index: IntArray, T) -> T): NDStructure<T>

    /**
     * Combines two structures element-wise into one by applying [transform] to each pair of elements.
     */
    public fun combine(a: NDStructure<T>, b: NDStructure<T>, transform: C.(T, T) -> T): NDStructure<T>

    /**
     * Element-wise invocation of a function working on [T] on an [NDStructure].
     */
    public operator fun Function1<T, T>.invoke(structure: NDStructure<T>): NDStructure<T> =
        structure.map { value -> this@invoke(value) }

    public companion object
}

/**
 * Checks if given elements are consistent with this context.
 *
 * Every structure must have exactly the same shape as [NDAlgebra.shape].
 *
 * @param structures the structures to check.
 * @return the same [structures] array, unchanged, when every shape matches.
 * @throws ShapeMismatchException reporting the first structure whose shape differs from [NDAlgebra.shape].
 */
internal fun <T, C> NDAlgebra<T, C>.checkShape(vararg structures: NDStructure<T>): Array<out NDStructure<T>> {
    // firstOrNull, not singleOrNull: singleOrNull returns null when two or more structures mismatch,
    // which silently let every mismatched structure through.
    val mismatched = structures.firstOrNull { !shape.contentEquals(it.shape) }
    if (mismatched != null) throw ShapeMismatchException(shape, mismatched.shape)
    return structures
}

/**
 * Checks if given element is consistent with this context.
 *
 * @param element the structure to check.
 * @return [element] itself when its shape equals [NDAlgebra.shape].
 * @throws ShapeMismatchException if the shape of [element] differs from [NDAlgebra.shape].
 */
internal fun <T, C> NDAlgebra<T, C>.checkShape(element: NDStructure<T>): NDStructure<T> {
    if (!element.shape.contentEquals(shape)) throw ShapeMismatchException(shape, element.shape)
    return element
}

/**
 * Space of [NDStructure].
 *
 * @param T the type of the element contained in ND structure.
 * @param S the type of space of structure elements.
 */
public interface NDSpace<T, S : Space<T>> : Space<NDStructure<T>>, NDAlgebra<T, S> {
    /**
     * Element-wise addition.
     *
     * @param a the addend.
     * @param b the augend.
     * @return the sum.
     */
    public override fun add(a: NDStructure<T>, b: NDStructure<T>): NDStructure<T> =
        combine(a, b) { aValue, bValue -> add(aValue, bValue) }

    /**
     * Element-wise multiplication by scalar.
     *
     * @param a the multiplicand.
     * @param k the multiplier.
     * @return the product.
     */
    public override fun multiply(a: NDStructure<T>, k: Number): NDStructure<T> = a.map { multiply(it, k) }

    // TODO move to extensions after KEEP-176

    /**
     * Adds an element to every element of this ND structure: each `x` becomes `x + arg`.
     *
     * @receiver the addend.
     * @param arg the augend.
     * @return the sum.
     */
    public operator fun NDStructure<T>.plus(arg: T): NDStructure<T> = this.map { value -> add(value, arg) }

    /**
     * Subtracts an element from every element of this ND structure: each `x` becomes `x - arg`.
     *
     * @receiver the minuend.
     * @param arg the subtrahend.
     * @return the difference.
     */
    // Operand order matters: the previous `add(arg, -value)` computed `arg - x`.
    public operator fun NDStructure<T>.minus(arg: T): NDStructure<T> = this.map { value -> add(value, -arg) }

    /**
     * Adds this element to every element of the ND structure [arg]: each `x` becomes `this + x`.
     *
     * @receiver the addend.
     * @param arg the augend.
     * @return the sum.
     */
    public operator fun T.plus(arg: NDStructure<T>): NDStructure<T> = arg.map { value -> add(this@plus, value) }

    /**
     * Subtracts every element of the ND structure [arg] from this element: each `x` becomes `this - x`.
     *
     * @receiver the minuend.
     * @param arg the subtrahend.
     * @return the difference.
     */
    // Operand order matters: the previous `add(-this, value)` computed `x - this`.
    public operator fun T.minus(arg: NDStructure<T>): NDStructure<T> = arg.map { value -> add(this@minus, -value) }

    public companion object
}

/**
 * Ring of [NDStructure].
 *
 * @param T the type of the element contained in ND structure.
 * @param R the type of ring of structure elements.
 */
public interface NDRing<T, R : Ring<T>> : Ring<NDStructure<T>>, NDSpace<T, R> {
    /**
     * Element-wise multiplication.
     *
     * @param a the multiplicand.
     * @param b the multiplier.
     * @return the product.
     */
    public override fun multiply(a: NDStructure<T>, b: NDStructure<T>): NDStructure<T> =
        combine(a, b) { aValue, bValue -> multiply(aValue, bValue) }

    //TODO move to extensions after KEEP-176

    /**
     * Multiplies every element of this ND structure by an element: each `x` becomes `x * arg`.
     *
     * The receiver element stays on the left, so the result is also correct for non-commutative rings.
     *
     * @receiver the multiplicand.
     * @param arg the multiplier.
     * @return the product.
     */
    public operator fun NDStructure<T>.times(arg: T): NDStructure<T> = this.map { value -> multiply(value, arg) }

    /**
     * Multiplies an element by every element of the ND structure [arg]: each `x` becomes `this * x`.
     *
     * @receiver the multiplicand.
     * @param arg the multiplier.
     * @return the product.
     */
    public operator fun T.times(arg: NDStructure<T>): NDStructure<T> = arg.map { value -> multiply(this@times, value) }

    public companion object
}

/**
 * Field of [NDStructure].
 *
 * @param T the type of the element contained in ND structure.
 * @param F the type of field of structure elements.
 */
public interface NDField<T, F : Field<T>> : Field<NDStructure<T>>, NDRing<T, F> {
    /**
     * Element-wise division.
     *
     * @param a the dividend.
     * @param b the divisor.
     * @return the quotient.
     */
    public override fun divide(a: NDStructure<T>, b: NDStructure<T>): NDStructure<T> =
        combine(a, b) { aValue, bValue -> divide(aValue, bValue) }

    //TODO move to extensions after KEEP-176
    /**
     * Divides every element of this ND structure by an element: each `x` becomes `x / arg`.
     *
     * @receiver the dividend.
     * @param arg the divisor.
     * @return the quotient.
     */
    // Operand order matters: the previous `divide(arg, value)` computed `arg / x`.
    public operator fun NDStructure<T>.div(arg: T): NDStructure<T> = this.map { value -> divide(value, arg) }

    /**
     * Divides an element by every element of the ND structure [arg]: each `x` becomes `this / x`.
     *
     * @receiver the dividend.
     * @param arg the divisor.
     * @return the quotient.
     */
    // Operand order matters: the previous `divide(it, this)` computed `x / this`.
    public operator fun T.div(arg: NDStructure<T>): NDStructure<T> = arg.map { value -> divide(this@div, value) }

//    @ThreadLocal
//    public companion object {
//        private val realNDFieldCache: MutableMap = hashMapOf()
//
//        /**
//         * Create a nd-field for [Double] values or pull it from cache if it was created previously.
//         */
//        public fun real(vararg shape: Int): RealNDField = realNDFieldCache.getOrPut(shape) { RealNDField(shape) }
//
//        /**
//         * Create an ND field with boxing generic buffer.
//         */
//        public fun > boxing(
//            field: F,
//            vararg shape: Int,
//            bufferFactory: BufferFactory = Buffer.Companion::boxing,
//        ): BufferedNDField = BufferedNDField(shape, field, bufferFactory)
//
//        /**
//         * Create a most suitable implementation for nd-field using reified class.
//         */
//        @Suppress("UNCHECKED_CAST")
//        public inline fun > auto(field: F, vararg shape: Int): NDField =
//            when {
//                T::class == Double::class -> real(*shape) as NDField
//                T::class == Complex::class -> complex(*shape) as BufferedNDField
//                else -> BoxingNDField(shape, field, Buffer.Companion::auto)
//            }
//    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy