
// Source file: space/kscience/kmath/nd4j/Nd4jArrayAlgebra.kt (as listed on a Maven / Gradle / Ivy repository viewer)
package space.kscience.kmath.nd4j
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.*
import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.*
/**
 * Ensures that [array] has exactly the shape of this algebra and hands it back unchanged.
 *
 * @param array the ND4J array to validate.
 * @return [array] itself when its shape matches [NDAlgebra.shape].
 * @throws ShapeMismatchException when the shapes differ; the algebra's shape is passed as the first
 * argument and the array's shape as the second.
 */
internal fun NDAlgebra<*, *>.checkShape(array: INDArray): INDArray {
    // INDArray reports its shape as LongArray while the algebra stores an IntArray.
    val actualShape: IntArray = array.shape().toIntArray()
    return if (shape.contentEquals(actualShape)) array else throw ShapeMismatchException(shape, actualShape)
}
/**
 * Represents [NDAlgebra] over [Nd4jArrayStructure].
 *
 * New structures are allocated with [Nd4j.create] using the algebra's [shape] and turned into
 * [Nd4jArrayStructure] via [wrap].
 *
 * @param T the type of ND-structure element.
 * @param C the type of the element context.
 */
public interface Nd4jArrayAlgebra<T, C> : NDAlgebra<T, C> {
    /**
     * Wraps [INDArray] to [Nd4jArrayStructure].
     */
    public fun INDArray.wrap(): Nd4jArrayStructure<T>

    /**
     * The [INDArray] backing this structure.
     *
     * @throws ShapeMismatchException if the structure's shape differs from the algebra's [shape].
     * @throws NotImplementedError if the structure is not an [Nd4jArrayStructure].
     */
    public val NDStructure<T>.ndArray: INDArray
        get() = when {
            !shape.contentEquals(this@Nd4jArrayAlgebra.shape) -> throw ShapeMismatchException(
                this@Nd4jArrayAlgebra.shape,
                shape,
            )
            this is Nd4jArrayStructure -> ndArray //TODO check strides
            else -> TODO("Only Nd4jArrayStructure can currently be converted to INDArray")
        }

    /**
     * Allocates a new structure of [shape] and fills every index with [initializer].
     */
    public override fun produce(initializer: C.(IntArray) -> T): Nd4jArrayStructure<T> {
        val struct = Nd4j.create(*shape).wrap()
        struct.indicesIterator().forEach { index -> struct[index] = elementContext.initializer(index) }
        return struct
    }

    /**
     * Returns a copy of this structure's backing array ([INDArray.dup]) with [transform] applied to every element.
     */
    public override fun NDStructure<T>.map(transform: C.(T) -> T): Nd4jArrayStructure<T> {
        val newStruct = ndArray.dup().wrap()
        newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementContext.transform(value) }
        return newStruct
    }

    /**
     * Builds a new structure whose element at each index is [transform] applied to that index and the
     * receiver's element at the same index.
     *
     * @throws ShapeMismatchException if the receiver's shape differs from the algebra's shape.
     */
    public override fun NDStructure<T>.mapIndexed(
        transform: C.(index: IntArray, T) -> T,
    ): Nd4jArrayStructure<T> {
        val algebraShape = this@Nd4jArrayAlgebra.shape
        // The receiver is read with the algebra's indices, so its shape must match; this mirrors the
        // validation performed by the ndArray getter.
        if (!shape.contentEquals(algebraShape)) throw ShapeMismatchException(algebraShape, shape)
        val result = Nd4j.create(*algebraShape).wrap()
        result.indicesIterator().forEach { index -> result[index] = elementContext.transform(index, this[index]) }
        return result
    }

    /**
     * Builds a new structure whose element at each index is [transform] of the elements of [a] and [b]
     * at that index.
     *
     * @throws ShapeMismatchException if [a] or [b] does not have the algebra's shape.
     */
    public override fun combine(
        a: NDStructure<T>,
        b: NDStructure<T>,
        transform: C.(T, T) -> T,
    ): Nd4jArrayStructure<T> {
        // Both operands are indexed with the algebra's indices, so their shapes must match it.
        if (!a.shape.contentEquals(shape)) throw ShapeMismatchException(shape, a.shape)
        if (!b.shape.contentEquals(shape)) throw ShapeMismatchException(shape, b.shape)
        val result = Nd4j.create(*shape).wrap()
        result.indicesIterator().forEach { index -> result[index] = elementContext.transform(a[index], b[index]) }
        return result
    }
}
/**
 * Represents [NDSpace] over [Nd4jArrayStructure].
 *
 * Arithmetic is delegated to the corresponding [INDArray] operations (`add`, `sub`, `neg`, `mul`, `div`)
 * on the backing arrays rather than being computed element by element.
 *
 * @param T the type of the element contained in ND structure.
 * @param S the type of space of structure elements.
 */
public interface Nd4jArraySpace<T, S : Space<T>> : NDSpace<T, S>, Nd4jArrayAlgebra<T, S> {
    /** A structure of [shape] filled with zeros; a new array is allocated on every access. */
    public override val zero: Nd4jArrayStructure<T>
        get() = Nd4j.zeros(*shape).wrap()

    /** Element-wise sum of [a] and [b]. */
    public override fun add(a: NDStructure<T>, b: NDStructure<T>): Nd4jArrayStructure<T> {
        return a.ndArray.add(b.ndArray).wrap()
    }

    /** Element-wise difference of this structure and [b]. */
    public override operator fun NDStructure<T>.minus(b: NDStructure<T>): Nd4jArrayStructure<T> {
        return ndArray.sub(b.ndArray).wrap()
    }

    /** Element-wise negation of this structure. */
    public override operator fun NDStructure<T>.unaryMinus(): Nd4jArrayStructure<T> {
        return ndArray.neg().wrap()
    }

    /** Multiplies every element of [a] by the scalar [k]. */
    public override fun multiply(a: NDStructure<T>, k: Number): Nd4jArrayStructure<T> {
        return a.ndArray.mul(k).wrap()
    }

    /** Divides every element of this structure by the scalar [k]. */
    public override operator fun NDStructure<T>.div(k: Number): Nd4jArrayStructure<T> {
        return ndArray.div(k).wrap()
    }

    /** Multiplies every element of this structure by the scalar [k]. */
    public override operator fun NDStructure<T>.times(k: Number): Nd4jArrayStructure<T> {
        return ndArray.mul(k).wrap()
    }
}
/**
 * Represents [NDRing] over [Nd4jArrayStructure].
 *
 * @param T the type of the element contained in ND structure.
 * @param R the type of ring of structure elements.
 */
@OptIn(UnstableKMathAPI::class)
public interface Nd4jArrayRing<T, R : Ring<T>> : NDRing<T, R>, Nd4jArraySpace<T, R> {
    /** A structure of [shape] filled with ones; a new array is allocated on every access. */
    public override val one: Nd4jArrayStructure<T>
        get() = Nd4j.ones(*shape).wrap()

    /** Element-wise (Hadamard) product of [a] and [b]. */
    public override fun multiply(a: NDStructure<T>, b: NDStructure<T>): Nd4jArrayStructure<T> {
        return a.ndArray.mul(b.ndArray).wrap()
    }
    //
    // public override operator fun Nd4jArrayStructure.minus(b: Number): Nd4jArrayStructure {
    // check(this)
    // return ndArray.sub(b).wrap()
    // }
    //
    // public override operator fun Nd4jArrayStructure.plus(b: Number): Nd4jArrayStructure {
    // check(this)
    // return ndArray.add(b).wrap()
    // }
    //
    // public override operator fun Number.minus(b: Nd4jArrayStructure): Nd4jArrayStructure {
    // check(b)
    // return b.ndArray.rsub(this).wrap()
    // }

    public companion object {
        // Caches are keyed by List<Int> rather than IntArray: arrays use identity equals/hashCode, and every
        // vararg call passes a fresh array, so an IntArray key would never produce a cache hit.
        private val intNd4jArrayRingCache: ThreadLocal<MutableMap<List<Int>, IntNd4jArrayRing>> =
            ThreadLocal.withInitial { hashMapOf() }

        private val longNd4jArrayRingCache: ThreadLocal<MutableMap<List<Int>, LongNd4jArrayRing>> =
            ThreadLocal.withInitial { hashMapOf() }

        /**
         * Creates an [NDRing] for [Int] values or pulls it from the current thread's cache if one with an equal
         * shape was created previously.
         */
        public fun int(vararg shape: Int): Nd4jArrayRing<Int, IntRing> =
            // copyOf() keeps the cached ring's shape independent of the caller's array.
            intNd4jArrayRingCache.get().getOrPut(shape.toList()) { IntNd4jArrayRing(shape.copyOf()) }

        /**
         * Creates an [NDRing] for [Long] values or pulls it from the current thread's cache if one with an equal
         * shape was created previously.
         */
        public fun long(vararg shape: Int): Nd4jArrayRing<Long, LongRing> =
            longNd4jArrayRingCache.get().getOrPut(shape.toList()) { LongNd4jArrayRing(shape.copyOf()) }

        /**
         * Creates a most suitable implementation of [NDRing] using reified class.
         *
         * @throws UnsupportedOperationException if [T] is neither [Int] nor [Long].
         */
        @Suppress("UNCHECKED_CAST")
        public inline fun <reified T : Any> auto(vararg shape: Int): Nd4jArrayRing<T, Ring<T>> = when {
            T::class == Int::class -> int(*shape) as Nd4jArrayRing<T, Ring<T>>
            T::class == Long::class -> long(*shape) as Nd4jArrayRing<T, Ring<T>>
            else -> throw UnsupportedOperationException("This factory method only supports Int and Long types.")
        }
    }
}
/**
 * Represents [NDField] over [Nd4jArrayStructure].
 *
 * @param T the type of the element contained in ND structure.
 * @param F the type field of structure elements.
 */
public interface Nd4jArrayField<T, F : Field<T>> : NDField<T, F>, Nd4jArrayRing<T, F> {
    /** Element-wise quotient of [a] and [b]. */
    public override fun divide(a: NDStructure<T>, b: NDStructure<T>): Nd4jArrayStructure<T> =
        a.ndArray.div(b.ndArray).wrap()

    /** Divides this scalar by every element of [b] (reverse division). */
    public override operator fun Number.div(b: NDStructure<T>): Nd4jArrayStructure<T> = b.ndArray.rdiv(this).wrap()

    public companion object {
        // Caches are keyed by List<Int> rather than IntArray: arrays use identity equals/hashCode, and every
        // vararg call passes a fresh array, so an IntArray key would never produce a cache hit.
        private val floatNd4jArrayFieldCache: ThreadLocal<MutableMap<List<Int>, FloatNd4jArrayField>> =
            ThreadLocal.withInitial { hashMapOf() }

        private val realNd4jArrayFieldCache: ThreadLocal<MutableMap<List<Int>, RealNd4jArrayField>> =
            ThreadLocal.withInitial { hashMapOf() }

        /**
         * Creates an [NDField] for [Float] values or pulls it from the current thread's cache if one with an
         * equal shape was created previously. The declared return type is [Nd4jArrayRing], but the returned
         * instance is a [FloatNd4jArrayField].
         */
        public fun float(vararg shape: Int): Nd4jArrayRing<Float, FloatField> =
            // copyOf() keeps the cached field's shape independent of the caller's array.
            floatNd4jArrayFieldCache.get().getOrPut(shape.toList()) { FloatNd4jArrayField(shape.copyOf()) }

        /**
         * Creates an [NDField] for [Double] values or pulls it from the current thread's cache if one with an
         * equal shape was created previously. The declared return type is [Nd4jArrayRing], but the returned
         * instance is a [RealNd4jArrayField].
         */
        public fun real(vararg shape: Int): Nd4jArrayRing<Double, RealField> =
            realNd4jArrayFieldCache.get().getOrPut(shape.toList()) { RealNd4jArrayField(shape.copyOf()) }

        /**
         * Creates a most suitable implementation of [NDField] using reified class.
         *
         * @throws UnsupportedOperationException if [T] is neither [Float] nor [Double].
         */
        @Suppress("UNCHECKED_CAST")
        public inline fun <reified T : Any> auto(vararg shape: Int): Nd4jArrayField<T, Field<T>> = when {
            T::class == Float::class -> float(*shape) as Nd4jArrayField<T, Field<T>>
            T::class == Double::class -> real(*shape) as Nd4jArrayField<T, Field<T>>
            else -> throw UnsupportedOperationException("This factory method only supports Float and Double types.")
        }
    }
}
/**
 * Represents [NDField] over [Nd4jArrayStructure] of [Double].
 *
 * @property shape the shape every wrapped [INDArray] must have; [wrap] rejects other shapes.
 */
public class RealNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField<Double, RealField> {
    public override val elementContext: RealField
        get() = RealField

    public override fun INDArray.wrap(): Nd4jArrayStructure<Double> = checkShape(this).asRealStructure()

    /** Divides every element by [arg]. */
    public override operator fun NDStructure<Double>.div(arg: Double): Nd4jArrayStructure<Double> {
        return ndArray.div(arg).wrap()
    }

    /** Adds [arg] to every element. */
    public override operator fun NDStructure<Double>.plus(arg: Double): Nd4jArrayStructure<Double> {
        return ndArray.add(arg).wrap()
    }

    /** Subtracts [arg] from every element. */
    public override operator fun NDStructure<Double>.minus(arg: Double): Nd4jArrayStructure<Double> {
        return ndArray.sub(arg).wrap()
    }

    /** Multiplies every element by [arg]. */
    public override operator fun NDStructure<Double>.times(arg: Double): Nd4jArrayStructure<Double> {
        return ndArray.mul(arg).wrap()
    }

    /** Divides this scalar by every element of [arg] (reverse division). */
    public override operator fun Double.div(arg: NDStructure<Double>): Nd4jArrayStructure<Double> {
        return arg.ndArray.rdiv(this).wrap()
    }

    /** Subtracts every element of [arg] from this scalar (reverse subtraction). */
    public override operator fun Double.minus(arg: NDStructure<Double>): Nd4jArrayStructure<Double> {
        return arg.ndArray.rsub(this).wrap()
    }
}
/**
 * Represents [NDField] over [Nd4jArrayStructure] of [Float].
 *
 * @property shape the shape every wrapped [INDArray] must have; [wrap] rejects other shapes.
 */
public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField<Float, FloatField> {
    public override val elementContext: FloatField
        get() = FloatField

    public override fun INDArray.wrap(): Nd4jArrayStructure<Float> = checkShape(this).asFloatStructure()

    /** Divides every element by [arg]. */
    public override operator fun NDStructure<Float>.div(arg: Float): Nd4jArrayStructure<Float> {
        return ndArray.div(arg).wrap()
    }

    /** Adds [arg] to every element. */
    public override operator fun NDStructure<Float>.plus(arg: Float): Nd4jArrayStructure<Float> {
        return ndArray.add(arg).wrap()
    }

    /** Subtracts [arg] from every element. */
    public override operator fun NDStructure<Float>.minus(arg: Float): Nd4jArrayStructure<Float> {
        return ndArray.sub(arg).wrap()
    }

    /** Multiplies every element by [arg]. */
    public override operator fun NDStructure<Float>.times(arg: Float): Nd4jArrayStructure<Float> {
        return ndArray.mul(arg).wrap()
    }

    /** Divides this scalar by every element of [arg] (reverse division). */
    public override operator fun Float.div(arg: NDStructure<Float>): Nd4jArrayStructure<Float> {
        return arg.ndArray.rdiv(this).wrap()
    }

    /** Subtracts every element of [arg] from this scalar (reverse subtraction). */
    public override operator fun Float.minus(arg: NDStructure<Float>): Nd4jArrayStructure<Float> {
        return arg.ndArray.rsub(this).wrap()
    }
}
/**
 * Represents [NDRing] over [Nd4jArrayStructure] of [Int].
 *
 * @property shape the shape every wrapped [INDArray] must have; [wrap] rejects other shapes.
 */
public class IntNd4jArrayRing(public override val shape: IntArray) : Nd4jArrayRing<Int, IntRing> {
    public override val elementContext: IntRing
        get() = IntRing

    public override fun INDArray.wrap(): Nd4jArrayStructure<Int> = checkShape(this).asIntStructure()

    /** Adds [arg] to every element. */
    public override operator fun NDStructure<Int>.plus(arg: Int): Nd4jArrayStructure<Int> {
        return ndArray.add(arg).wrap()
    }

    /** Subtracts [arg] from every element. */
    public override operator fun NDStructure<Int>.minus(arg: Int): Nd4jArrayStructure<Int> {
        return ndArray.sub(arg).wrap()
    }

    /** Multiplies every element by [arg]. */
    public override operator fun NDStructure<Int>.times(arg: Int): Nd4jArrayStructure<Int> {
        return ndArray.mul(arg).wrap()
    }

    /** Subtracts every element of [arg] from this scalar (reverse subtraction). */
    public override operator fun Int.minus(arg: NDStructure<Int>): Nd4jArrayStructure<Int> {
        return arg.ndArray.rsub(this).wrap()
    }
}
/**
 * Represents [NDRing] over [Nd4jArrayStructure] of [Long].
 *
 * @property shape the shape every wrapped [INDArray] must have; [wrap] rejects other shapes.
 */
public class LongNd4jArrayRing(public override val shape: IntArray) : Nd4jArrayRing<Long, LongRing> {
    public override val elementContext: LongRing
        get() = LongRing

    public override fun INDArray.wrap(): Nd4jArrayStructure<Long> = checkShape(this).asLongStructure()

    /** Adds [arg] to every element. */
    public override operator fun NDStructure<Long>.plus(arg: Long): Nd4jArrayStructure<Long> {
        return ndArray.add(arg).wrap()
    }

    /** Subtracts [arg] from every element. */
    public override operator fun NDStructure<Long>.minus(arg: Long): Nd4jArrayStructure<Long> {
        return ndArray.sub(arg).wrap()
    }

    /** Multiplies every element by [arg]. */
    public override operator fun NDStructure<Long>.times(arg: Long): Nd4jArrayStructure<Long> {
        return ndArray.mul(arg).wrap()
    }

    /** Subtracts every element of [arg] from this scalar (reverse subtraction). */
    public override operator fun Long.minus(arg: NDStructure<Long>): Nd4jArrayStructure<Long> {
        return arg.ndArray.rsub(this).wrap()
    }
}
// (Repository viewer footer: © 2015 - 2025 Weber Informatics LLC | Privacy Policy)