All downloads are free. Search and download functionality uses the official Maven repository.

space.kscience.kmath.nd4j.Nd4jArrayAlgebra.kt Maven / Gradle / Ivy

There is a newer version: 0.3.1
Show newest version
package space.kscience.kmath.nd4j

import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.*
import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.*

/**
 * Ensures that [array] has exactly the shape of this algebra.
 *
 * @param array the ND4J array to validate.
 * @return [array] itself (not a copy) when its shape matches [NDAlgebra.shape].
 * @throws ShapeMismatchException when the shapes differ; `expected` is the algebra shape, `actual` the array shape.
 */
internal fun NDAlgebra<*, *>.checkShape(array: INDArray): INDArray {
    val actualShape: IntArray = array.shape().toIntArray()
    return if (shape.contentEquals(actualShape)) array else throw ShapeMismatchException(shape, actualShape)
}


/**
 * Represents [NDAlgebra] over [Nd4jArrayStructure].
 *
 * Structures produced by this algebra are backed by ND4J [INDArray]s whose shape equals [shape]; structures of any
 * other shape are rejected with [ShapeMismatchException].
 *
 * @param T the type of ND-structure element.
 * @param C the type of the element context.
 */
public interface Nd4jArrayAlgebra<T, C : Algebra<T>> : NDAlgebra<T, C> {
    /**
     * Wraps [INDArray] to [Nd4jArrayStructure].
     */
    public fun INDArray.wrap(): Nd4jArrayStructure<T>

    /**
     * The [INDArray] representation of this structure.
     *
     * An [Nd4jArrayStructure] exposes its own backing array. Any other [NDStructure] is copied element by element
     * into a freshly produced array of this algebra.
     *
     * @throws ShapeMismatchException if the structure's shape differs from this algebra's shape.
     */
    public val NDStructure<T>.ndArray: INDArray
        get() {
            this@Nd4jArrayAlgebra.checkStructureShape(this)
            return when (this) {
                is Nd4jArrayStructure -> ndArray // TODO check strides
                else -> {
                    // Generic structures have no INDArray backing, so materialize one via produce.
                    val source = this
                    produce { index -> source[index] }.ndArray
                }
            }
        }

    public override fun produce(initializer: C.(IntArray) -> T): Nd4jArrayStructure<T> {
        val struct = Nd4j.create(*shape).wrap()
        struct.indicesIterator().forEach { struct[it] = elementContext.initializer(it) }
        return struct
    }

    public override fun NDStructure<T>.map(transform: C.(T) -> T): Nd4jArrayStructure<T> {
        // Duplicate first so the source structure is never mutated.
        val newStruct = ndArray.dup().wrap()
        newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementContext.transform(value) }
        return newStruct
    }

    /**
     * Applies [transform] to every element together with its index.
     *
     * @throws ShapeMismatchException if this structure's shape differs from the algebra's shape.
     */
    public override fun NDStructure<T>.mapIndexed(
        transform: C.(index: IntArray, T) -> T,
    ): Nd4jArrayStructure<T> {
        // Indices are generated from the algebra's shape, so the receiver must match it.
        this@Nd4jArrayAlgebra.checkStructureShape(this)
        val new = Nd4j.create(*this@Nd4jArrayAlgebra.shape).wrap()
        new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(idx, this[idx]) }
        return new
    }

    /**
     * Combines [a] and [b] element-wise with [transform].
     *
     * @throws ShapeMismatchException if either operand's shape differs from the algebra's shape.
     */
    public override fun combine(
        a: NDStructure<T>,
        b: NDStructure<T>,
        transform: C.(T, T) -> T,
    ): Nd4jArrayStructure<T> {
        checkStructureShape(a)
        checkStructureShape(b)
        val new = Nd4j.create(*shape).wrap()
        new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(a[idx], b[idx]) }
        return new
    }
}

/**
 * Throws [ShapeMismatchException] unless [structure] has exactly the shape of this algebra.
 */
private fun NDAlgebra<*, *>.checkStructureShape(structure: NDStructure<*>) {
    if (!shape.contentEquals(structure.shape)) throw ShapeMismatchException(shape, structure.shape)
}

/**
 * Represents [NDSpace] over [Nd4jArrayStructure].
 *
 * Each operation converts its operands with `ndArray` (which validates their shape), delegates to the matching
 * element-wise [INDArray] operation and wraps the result.
 *
 * @param T the type of the element contained in ND structure.
 * @param S the type of space of structure elements.
 */
public interface Nd4jArraySpace<T, S : Space<T>> : NDSpace<T, S>, Nd4jArrayAlgebra<T, S> {

    /** A new structure filled with zeros; a fresh array is allocated on every access. */
    public override val zero: Nd4jArrayStructure<T>
        get() = Nd4j.zeros(*shape).wrap()

    public override fun add(a: NDStructure<T>, b: NDStructure<T>): Nd4jArrayStructure<T> =
        a.ndArray.add(b.ndArray).wrap()

    public override operator fun NDStructure<T>.minus(b: NDStructure<T>): Nd4jArrayStructure<T> =
        ndArray.sub(b.ndArray).wrap()

    public override operator fun NDStructure<T>.unaryMinus(): Nd4jArrayStructure<T> =
        ndArray.neg().wrap()

    public override fun multiply(a: NDStructure<T>, k: Number): Nd4jArrayStructure<T> =
        a.ndArray.mul(k).wrap()

    public override operator fun NDStructure<T>.div(k: Number): Nd4jArrayStructure<T> =
        ndArray.div(k).wrap()

    public override operator fun NDStructure<T>.times(k: Number): Nd4jArrayStructure<T> =
        ndArray.mul(k).wrap()
}

/**
 * Represents [NDRing] over [Nd4jArrayStructure].
 *
 * Scalar arithmetic (`structure + element`, etc.) is implemented by the concrete subclasses, such as
 * [IntNd4jArrayRing] and [LongNd4jArrayRing].
 *
 * @param T the type of the element contained in ND structure.
 * @param R the type of ring of structure elements.
 */
@OptIn(UnstableKMathAPI::class)
public interface Nd4jArrayRing<T, R : Ring<T>> : NDRing<T, R>, Nd4jArraySpace<T, R> {

    /** A new structure filled with ones; a fresh array is allocated on every access. */
    public override val one: Nd4jArrayStructure<T>
        get() = Nd4j.ones(*shape).wrap()

    /** Element-wise (Hadamard) product of [a] and [b]. */
    public override fun multiply(a: NDStructure<T>, b: NDStructure<T>): Nd4jArrayStructure<T> =
        a.ndArray.mul(b.ndArray).wrap()

    public companion object {
        // Keyed by List<Int> rather than IntArray: arrays use identity equals/hashCode, so an IntArray key would
        // never match a fresh vararg array and the per-thread map would grow on every call.
        private val intNd4jArrayRingCache: ThreadLocal<MutableMap<List<Int>, IntNd4jArrayRing>> =
            ThreadLocal.withInitial { hashMapOf() }

        private val longNd4jArrayRingCache: ThreadLocal<MutableMap<List<Int>, LongNd4jArrayRing>> =
            ThreadLocal.withInitial { hashMapOf() }

        /**
         * Creates an [NDRing] for [Int] values or pull it from cache if it was created previously.
         */
        public fun int(vararg shape: Int): Nd4jArrayRing<Int, IntRing> =
            intNd4jArrayRingCache.get().getOrPut(shape.toList()) { IntNd4jArrayRing(shape.copyOf()) }

        /**
         * Creates an [NDRing] for [Long] values or pull it from cache if it was created previously.
         */
        public fun long(vararg shape: Int): Nd4jArrayRing<Long, LongRing> =
            longNd4jArrayRingCache.get().getOrPut(shape.toList()) { LongNd4jArrayRing(shape.copyOf()) }

        /**
         * Creates a most suitable implementation of [NDRing] using reified class.
         *
         * @throws UnsupportedOperationException if [T] is neither [Int] nor [Long].
         */
        @Suppress("UNCHECKED_CAST")
        public inline fun <reified T : Any> auto(vararg shape: Int): Nd4jArrayRing<T, Ring<T>> = when {
            T::class == Int::class -> int(*shape) as Nd4jArrayRing<T, Ring<T>>
            T::class == Long::class -> long(*shape) as Nd4jArrayRing<T, Ring<T>>
            else -> throw UnsupportedOperationException("This factory method only supports Int and Long types.")
        }
    }
}

/**
 * Represents [NDField] over [Nd4jArrayStructure].
 *
 * @param T the type of the element contained in ND structure.
 * @param F the type field of structure elements.
 */
public interface Nd4jArrayField<T, F : Field<T>> : NDField<T, F>, Nd4jArrayRing<T, F> {

    /** Element-wise quotient of [a] by [b]. */
    public override fun divide(a: NDStructure<T>, b: NDStructure<T>): Nd4jArrayStructure<T> =
        a.ndArray.div(b.ndArray).wrap()

    /** Divides this scalar by every element of [b] (reverse division). */
    public override operator fun Number.div(b: NDStructure<T>): Nd4jArrayStructure<T> = b.ndArray.rdiv(this).wrap()

    public companion object {
        // Keyed by List<Int>: IntArray keys compare by identity, which would defeat the cache.
        private val floatNd4jArrayFieldCache: ThreadLocal<MutableMap<List<Int>, FloatNd4jArrayField>> =
            ThreadLocal.withInitial { hashMapOf() }

        private val realNd4jArrayFieldCache: ThreadLocal<MutableMap<List<Int>, RealNd4jArrayField>> =
            ThreadLocal.withInitial { hashMapOf() }

        /**
         * Creates an [NDField] for [Float] values or pull it from cache if it was created previously.
         */
        public fun float(vararg shape: Int): Nd4jArrayField<Float, FloatField> =
            floatNd4jArrayFieldCache.get().getOrPut(shape.toList()) { FloatNd4jArrayField(shape.copyOf()) }

        /**
         * Creates an [NDField] for [Double] values or pull it from cache if it was created previously.
         */
        public fun real(vararg shape: Int): Nd4jArrayField<Double, RealField> =
            realNd4jArrayFieldCache.get().getOrPut(shape.toList()) { RealNd4jArrayField(shape.copyOf()) }

        /**
         * Creates a most suitable implementation of [NDField] using reified class.
         *
         * @throws UnsupportedOperationException if [T] is neither [Float] nor [Double].
         */
        @Suppress("UNCHECKED_CAST")
        public inline fun <reified T : Any> auto(vararg shape: Int): Nd4jArrayField<T, Field<T>> = when {
            T::class == Float::class -> float(*shape) as Nd4jArrayField<T, Field<T>>
            T::class == Double::class -> real(*shape) as Nd4jArrayField<T, Field<T>>
            else -> throw UnsupportedOperationException("This factory method only supports Float and Double types.")
        }
    }
}

/**
 * Represents [NDField] over [Nd4jArrayStructure] of [Double].
 *
 * @param shape the shape of every structure this field produces or accepts.
 */
public class RealNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField<Double, RealField> {
    public override val elementContext: RealField
        get() = RealField

    public override fun INDArray.wrap(): Nd4jArrayStructure<Double> = checkShape(this).asRealStructure()

    public override operator fun NDStructure<Double>.div(arg: Double): Nd4jArrayStructure<Double> =
        ndArray.div(arg).wrap()

    public override operator fun NDStructure<Double>.plus(arg: Double): Nd4jArrayStructure<Double> =
        ndArray.add(arg).wrap()

    public override operator fun NDStructure<Double>.minus(arg: Double): Nd4jArrayStructure<Double> =
        ndArray.sub(arg).wrap()

    public override operator fun NDStructure<Double>.times(arg: Double): Nd4jArrayStructure<Double> =
        ndArray.mul(arg).wrap()

    /** Divides this scalar by every element of [arg]. */
    public override operator fun Double.div(arg: NDStructure<Double>): Nd4jArrayStructure<Double> =
        arg.ndArray.rdiv(this).wrap()

    /** Subtracts every element of [arg] from this scalar. */
    public override operator fun Double.minus(arg: NDStructure<Double>): Nd4jArrayStructure<Double> =
        arg.ndArray.rsub(this).wrap()
}

/**
 * Represents [NDField] over [Nd4jArrayStructure] of [Float].
 *
 * @param shape the shape of every structure this field produces or accepts.
 */
public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField<Float, FloatField> {
    public override val elementContext: FloatField
        get() = FloatField

    public override fun INDArray.wrap(): Nd4jArrayStructure<Float> = checkShape(this).asFloatStructure()

    public override operator fun NDStructure<Float>.div(arg: Float): Nd4jArrayStructure<Float> =
        ndArray.div(arg).wrap()

    public override operator fun NDStructure<Float>.plus(arg: Float): Nd4jArrayStructure<Float> =
        ndArray.add(arg).wrap()

    public override operator fun NDStructure<Float>.minus(arg: Float): Nd4jArrayStructure<Float> =
        ndArray.sub(arg).wrap()

    public override operator fun NDStructure<Float>.times(arg: Float): Nd4jArrayStructure<Float> =
        ndArray.mul(arg).wrap()

    /** Divides this scalar by every element of [arg]. */
    public override operator fun Float.div(arg: NDStructure<Float>): Nd4jArrayStructure<Float> =
        arg.ndArray.rdiv(this).wrap()

    /** Subtracts every element of [arg] from this scalar. */
    public override operator fun Float.minus(arg: NDStructure<Float>): Nd4jArrayStructure<Float> =
        arg.ndArray.rsub(this).wrap()
}

/**
 * Represents [NDRing] over [Nd4jArrayStructure] of [Int].
 *
 * @param shape the shape of every structure this ring produces or accepts.
 */
public class IntNd4jArrayRing(public override val shape: IntArray) : Nd4jArrayRing<Int, IntRing> {
    public override val elementContext: IntRing
        get() = IntRing

    public override fun INDArray.wrap(): Nd4jArrayStructure<Int> = checkShape(this).asIntStructure()

    public override operator fun NDStructure<Int>.plus(arg: Int): Nd4jArrayStructure<Int> =
        ndArray.add(arg).wrap()

    public override operator fun NDStructure<Int>.minus(arg: Int): Nd4jArrayStructure<Int> =
        ndArray.sub(arg).wrap()

    public override operator fun NDStructure<Int>.times(arg: Int): Nd4jArrayStructure<Int> =
        ndArray.mul(arg).wrap()

    /** Subtracts every element of [arg] from this scalar. */
    public override operator fun Int.minus(arg: NDStructure<Int>): Nd4jArrayStructure<Int> =
        arg.ndArray.rsub(this).wrap()
}

/**
 * Represents [NDRing] over [Nd4jArrayStructure] of [Long].
 *
 * @param shape the shape of every structure this ring produces or accepts.
 */
public class LongNd4jArrayRing(public override val shape: IntArray) : Nd4jArrayRing<Long, LongRing> {
    public override val elementContext: LongRing
        get() = LongRing

    public override fun INDArray.wrap(): Nd4jArrayStructure<Long> = checkShape(this).asLongStructure()

    public override operator fun NDStructure<Long>.plus(arg: Long): Nd4jArrayStructure<Long> =
        ndArray.add(arg).wrap()

    public override operator fun NDStructure<Long>.minus(arg: Long): Nd4jArrayStructure<Long> =
        ndArray.sub(arg).wrap()

    public override operator fun NDStructure<Long>.times(arg: Long): Nd4jArrayStructure<Long> =
        ndArray.mul(arg).wrap()

    /** Subtracts every element of [arg] from this scalar. */
    public override operator fun Long.minus(arg: NDStructure<Long>): Nd4jArrayStructure<Long> =
        arg.ndArray.rsub(this).wrap()
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy