
space.kscience.kmath.nd4j.Nd4jArrayIterator.kt Maven / Gradle / Ivy
package space.kscience.kmath.nd4j
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.api.shape.Shape
/**
 * Iterator over the multi-index of every element of [iterateOver].
 *
 * Linear positions `0 until iterateOver.length()` are visited in increasing order. Each position is converted
 * to a multi-index with [Shape.ind2subC] when `iterateOver.ordering()` is `'c'`, and with [Shape.ind2sub]
 * otherwise. The result is then converted with `toIntArray()`; that helper is defined outside this file
 * section and presumably narrows each coordinate to `Int`.
 */
private class Nd4jArrayIndicesIterator(private val iterateOver: INDArray) : Iterator<IntArray> {
    // Long rather than Int: INDArray.length() is a Long, and an Int counter would overflow (and go negative)
    // on arrays with more than Int.MAX_VALUE elements.
    private var i: Long = 0L

    override fun hasNext(): Boolean = i < iterateOver.length()

    /**
     * Returns the multi-index of the next element.
     *
     * @throws NoSuchElementException if every element has already been visited.
     */
    override fun next(): IntArray {
        if (!hasNext()) throw NoSuchElementException("Iterator exhausted: all ${iterateOver.length()} elements have been visited")
        val index = i++
        // ind2subC / ind2sub are Java APIs with a platform-typed long[] result; `!!` fails fast if one ever yields null.
        val la = if (iterateOver.ordering() == 'c')
            Shape.ind2subC(iterateOver, index)!!
        else
            Shape.ind2sub(iterateOver, index)!!
        return la.toIntArray()
    }
}
/** Returns an iterator over the multi-index of every element of this array, backed by [Nd4jArrayIndicesIterator]. */
internal fun INDArray.indicesIterator(): Iterator<IntArray> = Nd4jArrayIndicesIterator(this)
/**
 * Base for iterators that yield `(multi-index, value)` pairs for every element of [iterateOver].
 *
 * Linear positions `0 until iterateOver.length()` are visited in increasing order. Each position is converted
 * to a multi-index with [Shape.ind2subC] when `iterateOver.ordering()` is `'c'`, and with [Shape.ind2sub]
 * otherwise. Subclasses supply [getSingle], which reads the element value of type [T] at that multi-index.
 */
private sealed class Nd4jArrayIteratorBase<T>(protected val iterateOver: INDArray) : Iterator<Pair<IntArray, T>> {
    // Long rather than Int: INDArray.length() is a Long, and an Int counter would overflow (and go negative)
    // on arrays with more than Int.MAX_VALUE elements.
    private var i: Long = 0L

    final override fun hasNext(): Boolean = i < iterateOver.length()

    /** Reads the element of [iterateOver] located at the multi-index [indices]. */
    abstract fun getSingle(indices: LongArray): T

    /**
     * Returns the next element's multi-index (converted with `toIntArray()`) paired with its value.
     *
     * @throws NoSuchElementException if every element has already been visited.
     */
    final override fun next(): Pair<IntArray, T> {
        if (!hasNext()) throw NoSuchElementException("Iterator exhausted: all ${iterateOver.length()} elements have been visited")
        val index = i++
        // ind2subC / ind2sub are Java APIs with a platform-typed long[] result; `!!` fails fast if one ever yields null.
        val la = if (iterateOver.ordering() == 'c')
            Shape.ind2subC(iterateOver, index)!!
        else
            Shape.ind2sub(iterateOver, index)!!
        // The value is read with the full Long multi-index; only the returned key is converted.
        return la.toIntArray() to getSingle(la)
    }
}
/** Yields each element of [iterateOver] as a [Double] (via `INDArray.getDouble`), paired with its multi-index. */
private class Nd4jArrayRealIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase<Double>(iterateOver) {
    override fun getSingle(indices: LongArray): Double = iterateOver.getDouble(*indices)
}
/** Returns an iterator over `(multi-index, value)` pairs of this array, reading values as [Double]. */
internal fun INDArray.realIterator(): Iterator<Pair<IntArray, Double>> = Nd4jArrayRealIterator(this)
/** Yields each element of [iterateOver] as a [Long] (via `INDArray.getLong`), paired with its multi-index. */
private class Nd4jArrayLongIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase<Long>(iterateOver) {
    override fun getSingle(indices: LongArray): Long = iterateOver.getLong(*indices)
}
/** Returns an iterator over `(multi-index, value)` pairs of this array, reading values as [Long]. */
internal fun INDArray.longIterator(): Iterator<Pair<IntArray, Long>> = Nd4jArrayLongIterator(this)
/**
 * Yields each element of [iterateOver] as an [Int] (via `INDArray.getInt`), paired with its multi-index.
 *
 * The Long multi-index is passed through `toIntArray()` before the read, because `getInt` is called with an
 * Int vararg here.
 */
private class Nd4jArrayIntIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase<Int>(iterateOver) {
    override fun getSingle(indices: LongArray): Int = iterateOver.getInt(*indices.toIntArray())
}
/** Returns an iterator over `(multi-index, value)` pairs of this array, reading values as [Int]. */
internal fun INDArray.intIterator(): Iterator<Pair<IntArray, Int>> = Nd4jArrayIntIterator(this)
/** Yields each element of [iterateOver] as a [Float] (via `INDArray.getFloat`), paired with its multi-index. */
private class Nd4jArrayFloatIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase<Float>(iterateOver) {
    override fun getSingle(indices: LongArray): Float = iterateOver.getFloat(*indices)
}
/** Returns an iterator over `(multi-index, value)` pairs of this array, reading values as [Float]. */
internal fun INDArray.floatIterator(): Iterator<Pair<IntArray, Float>> = Nd4jArrayFloatIterator(this)
© 2015 - 2025 Weber Informatics LLC | Privacy Policy