
// commonMain.space.kscience.kmath.nd.NDAlgebra.kt Maven / Gradle / Ivy
package space.kscience.kmath.nd
import space.kscience.kmath.operations.Field
import space.kscience.kmath.operations.Ring
import space.kscience.kmath.operations.Space
import space.kscience.kmath.structures.*
/**
 * Thrown when the expected and the actual shape of an ND structure differ.
 *
 * Both shapes are rendered into the message, e.g. `Shape [4] doesn't fit in expected shape [2, 3].`
 *
 * @property expected the shape that was required.
 * @property actual the shape that was actually encountered.
 */
public class ShapeMismatchException(
    public val expected: IntArray,
    public val actual: IntArray,
) : RuntimeException(
    "Shape ${actual.contentToString()} doesn't fit in expected shape ${expected.contentToString()}.",
)
/**
 * The base interface for all ND-algebra implementations.
 *
 * @param T the type of ND-structure element.
 * @param C the type of the element context, i.e. the algebra the lambdas below are run in.
 */
public interface NDAlgebra<T, C> {
    /**
     * The shape of ND-structures this algebra operates on.
     */
    public val shape: IntArray

    /**
     * The algebra over elements of ND structure.
     */
    public val elementContext: C

    /**
     * Produces a new [NDStructure] using the given [initializer], which is run in the element context and
     * receives an element index.
     */
    public fun produce(initializer: C.(IntArray) -> T): NDStructure<T>

    /**
     * Maps elements from one structure to another one by applying [transform] to them.
     */
    public fun NDStructure<T>.map(transform: C.(T) -> T): NDStructure<T>

    /**
     * Maps elements from one structure to another one by applying [transform] to them alongside with their indices.
     */
    public fun NDStructure<T>.mapIndexed(transform: C.(index: IntArray, T) -> T): NDStructure<T>

    /**
     * Combines two structures into one, using [transform] to merge their elements.
     */
    public fun combine(a: NDStructure<T>, b: NDStructure<T>, transform: C.(T, T) -> T): NDStructure<T>

    /**
     * Element-wise invocation of function working on [T] on a [NDStructure].
     *
     * Implemented via [map], so the receiver function is applied to every element of [structure].
     */
    public operator fun Function1<T, T>.invoke(structure: NDStructure<T>): NDStructure<T> =
        structure.map { value -> this@invoke(value) }

    public companion object
}
/**
 * Checks if given elements are consistent with this context.
 *
 * @param structures the structures to check.
 * @return [structures] unchanged when every one of them has this algebra's [shape][NDAlgebra.shape].
 * @throws ShapeMismatchException for the first structure whose shape differs from this algebra's shape.
 */
internal fun <T, C> NDAlgebra<T, C>.checkShape(vararg structures: NDStructure<T>): Array<out NDStructure<T>> {
    // `firstOrNull`, not `singleOrNull`: the latter yields null when two or more structures mismatch,
    // which would let the check pass silently.
    val mismatched = structures.firstOrNull { !shape.contentEquals(it.shape) }
    if (mismatched != null) throw ShapeMismatchException(shape, mismatched.shape)
    return structures
}
/**
 * Verifies that a single [element] matches the [shape][NDAlgebra.shape] of this algebra.
 *
 * @param element the structure to validate.
 * @return [element] itself when the shapes agree.
 * @throws ShapeMismatchException when they do not.
 */
internal fun <T, C> NDAlgebra<T, C>.checkShape(element: NDStructure<T>): NDStructure<T> =
    if (element.shape.contentEquals(shape)) element else throw ShapeMismatchException(shape, element.shape)
/**
 * Space of [NDStructure].
 *
 * @param T the type of the element contained in ND structure.
 * @param S the type of space of structure elements.
 */
public interface NDSpace<T, S : Space<T>> : Space<NDStructure<T>>, NDAlgebra<T, S> {
    /**
     * Element-wise addition.
     *
     * @param a the addend.
     * @param b the augend.
     * @return the sum.
     */
    public override fun add(a: NDStructure<T>, b: NDStructure<T>): NDStructure<T> =
        combine(a, b) { aValue, bValue -> add(aValue, bValue) }

    /**
     * Element-wise multiplication by scalar.
     *
     * @param a the multiplicand.
     * @param k the multiplier.
     * @return the product.
     */
    public override fun multiply(a: NDStructure<T>, k: Number): NDStructure<T> = a.map { multiply(it, k) }

    // TODO move to extensions after KEEP-176

    /**
     * Adds an element to every element of an ND structure: `result[i] = this[i] + arg`.
     *
     * @receiver the addend.
     * @param arg the augend.
     * @return the sum.
     */
    public operator fun NDStructure<T>.plus(arg: T): NDStructure<T> = this.map { value -> add(value, arg) }

    /**
     * Subtracts an element from every element of an ND structure: `result[i] = this[i] - arg`.
     *
     * @receiver the minuend.
     * @param arg the subtrahend.
     * @return the difference.
     */
    // Operand order matters: the structure element is the minuend (was previously computed as `arg - value`).
    public operator fun NDStructure<T>.minus(arg: T): NDStructure<T> = this.map { value -> add(value, -arg) }

    /**
     * Adds an element to every element of an ND structure: `result[i] = this + arg[i]`.
     *
     * @receiver the addend.
     * @param arg the augend.
     * @return the sum.
     */
    public operator fun T.plus(arg: NDStructure<T>): NDStructure<T> = arg.map { value -> add(this@plus, value) }

    /**
     * Subtracts every element of an ND structure from an element: `result[i] = this - arg[i]`.
     *
     * @receiver the minuend.
     * @param arg the subtrahend.
     * @return the difference.
     */
    // Operand order matters: the scalar receiver is the minuend (was previously computed as `value - this`).
    public operator fun T.minus(arg: NDStructure<T>): NDStructure<T> = arg.map { value -> add(this@minus, -value) }

    public companion object
}
/**
 * Ring of [NDStructure].
 *
 * @param T the type of the element contained in ND structure.
 * @param R the type of ring of structure elements.
 */
public interface NDRing<T, R : Ring<T>> : Ring<NDStructure<T>>, NDSpace<T, R> {
    /**
     * Element-wise multiplication.
     *
     * @param a the multiplicand.
     * @param b the multiplier.
     * @return the product.
     */
    public override fun multiply(a: NDStructure<T>, b: NDStructure<T>): NDStructure<T> =
        combine(a, b) { aValue, bValue -> multiply(aValue, bValue) }

    //TODO move to extensions after KEEP-176

    /**
     * Multiplies every element of an ND structure by an element: `result[i] = this[i] * arg`.
     *
     * @receiver the multiplicand.
     * @param arg the multiplier.
     * @return the product.
     */
    // Keep the structure element on the left so the result is correct for non-commutative rings too.
    public operator fun NDStructure<T>.times(arg: T): NDStructure<T> = this.map { value -> multiply(value, arg) }

    /**
     * Multiplies an element by every element of an ND structure: `result[i] = this * arg[i]`.
     *
     * @receiver the multiplicand.
     * @param arg the multiplier.
     * @return the product.
     */
    public operator fun T.times(arg: NDStructure<T>): NDStructure<T> = arg.map { value -> multiply(this@times, value) }

    public companion object
}
/**
 * Field of [NDStructure].
 *
 * @param T the type of the element contained in ND structure.
 * @param F the type field of structure elements.
 */
public interface NDField<T, F : Field<T>> : Field<NDStructure<T>>, NDRing<T, F> {
    /**
     * Element-wise division.
     *
     * @param a the dividend.
     * @param b the divisor.
     * @return the quotient.
     */
    public override fun divide(a: NDStructure<T>, b: NDStructure<T>): NDStructure<T> =
        combine(a, b) { aValue, bValue -> divide(aValue, bValue) }

    //TODO move to extensions after KEEP-176

    /**
     * Divides every element of an ND structure by an element: `result[i] = this[i] / arg`.
     *
     * @receiver the dividend.
     * @param arg the divisor.
     * @return the quotient.
     */
    // Operand order matters: the structure element is the dividend (was previously computed as `arg / value`).
    public operator fun NDStructure<T>.div(arg: T): NDStructure<T> = this.map { value -> divide(value, arg) }

    /**
     * Divides an element by every element of an ND structure: `result[i] = this / arg[i]`.
     *
     * @receiver the dividend.
     * @param arg the divisor.
     * @return the quotient.
     */
    // Operand order matters: the scalar receiver is the dividend (was previously computed as `value / this`).
    public operator fun T.div(arg: NDStructure<T>): NDStructure<T> = arg.map { value -> divide(this@div, value) }

// @ThreadLocal
// public companion object {
// private val realNDFieldCache: MutableMap = hashMapOf()
//
// /**
// * Create a nd-field for [Double] values or pull it from cache if it was created previously.
// */
// public fun real(vararg shape: Int): RealNDField = realNDFieldCache.getOrPut(shape) { RealNDField(shape) }
//
// /**
// * Create an ND field with boxing generic buffer.
// */
// public fun > boxing(
// field: F,
// vararg shape: Int,
// bufferFactory: BufferFactory = Buffer.Companion::boxing,
// ): BufferedNDField = BufferedNDField(shape, field, bufferFactory)
//
// /**
// * Create a most suitable implementation for nd-field using reified class.
// */
// @Suppress("UNCHECKED_CAST")
// public inline fun > auto(field: F, vararg shape: Int): NDField =
// when {
// T::class == Double::class -> real(*shape) as NDField
// T::class == Complex::class -> complex(*shape) as BufferedNDField
// else -> BoxingNDField(shape, field, Buffer.Companion::auto)
// }
// }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy