All downloads are free. Search and download functionality uses the official Maven repository.

space.kscience.kmath.nd4j.Nd4jArrayIterator.kt Maven / Gradle / Ivy

There is a newer version: 0.3.1
Show newest version
package space.kscience.kmath.nd4j

import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.api.shape.Shape

/**
 * Iterates over the multidimensional index of every element of [iterateOver].
 *
 * Linear positions `0 until iterateOver.length()` are visited in increasing order. Each position is converted
 * to an index tuple with [Shape.ind2subC] when the array's ordering is `'c'`, and with [Shape.ind2sub] otherwise.
 */
private class Nd4jArrayIndicesIterator(private val iterateOver: INDArray) : Iterator<IntArray> {
    // Long, not Int: INDArray.length() is a Long, and an Int counter would overflow (and never terminate)
    // on arrays with more than Int.MAX_VALUE elements.
    private var i: Long = 0L

    override fun hasNext(): Boolean = i < iterateOver.length()

    /**
     * Returns the index tuple of the next element.
     *
     * @throws NoSuchElementException if every element has already been visited.
     */
    override fun next(): IntArray {
        if (!hasNext()) throw NoSuchElementException("INDArray indices iterator is exhausted")
        val linearIndex = i++

        val la = if (iterateOver.ordering() == 'c')
            Shape.ind2subC(iterateOver, linearIndex)
        else
            Shape.ind2sub(iterateOver, linearIndex)

        return checkNotNull(la) { "ND4J returned no indices for linear index $linearIndex" }.toIntArray()
    }
}

/** Returns an iterator over the multidimensional index of every element of this array. */
internal fun INDArray.indicesIterator(): Iterator<IntArray> = Nd4jArrayIndicesIterator(this)

/**
 * Common base for iterators that yield an `(index, value)` pair for every element of [iterateOver].
 *
 * Linear positions `0 until iterateOver.length()` are visited in increasing order. Each position is converted
 * to an index tuple with [Shape.ind2subC] when the array's ordering is `'c'`, and with [Shape.ind2sub] otherwise.
 * The element value is then read through [getSingle].
 *
 * @param T the element type produced by [getSingle].
 */
private sealed class Nd4jArrayIteratorBase<T>(protected val iterateOver: INDArray) : Iterator<Pair<IntArray, T>> {
    // Long, not Int: INDArray.length() is a Long, and an Int counter would overflow (and never terminate)
    // on arrays with more than Int.MAX_VALUE elements.
    private var i: Long = 0L

    final override fun hasNext(): Boolean = i < iterateOver.length()

    /** Reads the element of [iterateOver] located at [indices]. */
    abstract fun getSingle(indices: LongArray): T

    /**
     * Returns the index tuple of the next element paired with its value.
     *
     * @throws NoSuchElementException if every element has already been visited.
     */
    final override fun next(): Pair<IntArray, T> {
        if (!hasNext()) throw NoSuchElementException("INDArray element iterator is exhausted")
        val linearIndex = i++

        val la = if (iterateOver.ordering() == 'c')
            Shape.ind2subC(iterateOver, linearIndex)
        else
            Shape.ind2sub(iterateOver, linearIndex)

        val indices = checkNotNull(la) { "ND4J returned no indices for linear index $linearIndex" }
        return indices.toIntArray() to getSingle(indices)
    }
}

/** Yields `(index, value)` pairs, reading each element as a [Double] via [INDArray.getDouble]. */
private class Nd4jArrayRealIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase<Double>(iterateOver) {
    override fun getSingle(indices: LongArray): Double = iterateOver.getDouble(*indices)
}

/** Returns an iterator over `(index, value)` pairs of this array, with values read as [Double]. */
internal fun INDArray.realIterator(): Iterator<Pair<IntArray, Double>> = Nd4jArrayRealIterator(this)

/** Yields `(index, value)` pairs, reading each element as a [Long] via [INDArray.getLong]. */
private class Nd4jArrayLongIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase<Long>(iterateOver) {
    override fun getSingle(indices: LongArray): Long = iterateOver.getLong(*indices)
}

/** Returns an iterator over `(index, value)` pairs of this array, with values read as [Long]. */
internal fun INDArray.longIterator(): Iterator<Pair<IntArray, Long>> = Nd4jArrayLongIterator(this)

/**
 * Yields `(index, value)` pairs, reading each element as an [Int] via [INDArray.getInt].
 *
 * The `getInt` overload used here takes `Int` indices, so the `Long` indices are converted with `toIntArray()`
 * before the lookup.
 */
private class Nd4jArrayIntIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase<Int>(iterateOver) {
    override fun getSingle(indices: LongArray): Int = iterateOver.getInt(*indices.toIntArray())
}

/** Returns an iterator over `(index, value)` pairs of this array, with values read as [Int]. */
internal fun INDArray.intIterator(): Iterator<Pair<IntArray, Int>> = Nd4jArrayIntIterator(this)

/** Yields `(index, value)` pairs, reading each element as a [Float] via [INDArray.getFloat]. */
private class Nd4jArrayFloatIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase<Float>(iterateOver) {
    override fun getSingle(indices: LongArray): Float = iterateOver.getFloat(*indices)
}

/** Returns an iterator over `(index, value)` pairs of this array, with values read as [Float]. */
internal fun INDArray.floatIterator(): Iterator<Pair<IntArray, Float>> = Nd4jArrayFloatIterator(this)




© 2015 - 2025 Weber Informatics LLC | Privacy Policy