
// org.jetbrains.kotlinx.multik.ndarray.data.NDArray.kt (listing header left over from the Maven / Gradle / Ivy artifact browser; not part of the source)
/*
* Copyright 2020-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/
package org.jetbrains.kotlinx.multik.ndarray.data
import org.jetbrains.kotlinx.multik.ndarray.operations.concatenate
// NOTE(review): the type-argument lists in this file were lost during extraction (the class KDoc
// still documents `T` and `D`); they are restored here — confirm against the project's
// MultiArray / MutableMultiArray declarations.

/** One-dimensional [NDArray]. */
public typealias D1Array<T> = NDArray<T, D1>

/** Two-dimensional [NDArray]. */
public typealias D2Array<T> = NDArray<T, D2>

/** Three-dimensional [NDArray]. */
public typealias D3Array<T> = NDArray<T, D3>

/** Four-dimensional [NDArray]. */
public typealias D4Array<T> = NDArray<T, D4>

/**
 * A class that implements multidimensional arrays. This implementation is based on primitive arrays.
 * With the help of [offset], [shape] and [strides] a multidimensional representation is built
 * over a sequential homogeneous array.
 *
 * Native code uses `GetPrimitiveArrayCritical` for calculation.
 *
 * @param T type of stored values.
 * @param D dimension.
 * @param data backing storage; it is cast to [MemoryView] and exposed through the [data] property.
 * @property offset position in [data] passed to the strided iterator as the starting offset.
 * @property shape extent of every axis; must not be empty.
 * @property strides per-axis steps in [data]; defaults to `computeStrides(shape)`.
 * @property dim dimension marker of this array.
 * @property base array this one was derived from, or `null` when none was supplied.
 */
public class NDArray<T, D : Dimension> constructor(
    data: ImmutableMemoryView<T>,
    public override val offset: Int = 0,
    public override val shape: IntArray,
    public override val strides: IntArray = computeStrides(shape),
    public override val dim: D,
    public override val base: MultiArray<T, out Dimension>? = null
) : MutableMultiArray<T, D> {

    init {
        check(shape.isNotEmpty()) { "Shape can't be empty." }
    }

    /** Backing storage, viewed as mutable. */
    public override val data: MemoryView<T> = data as MemoryView<T>

    /** Number of elements: the product of all [shape] entries. */
    public override val size: Int get() = shape.fold(1, Int::times)

    /**
     * `true` when this array starts at offset 0, covers the whole of [data] and has the default
     * strides for its [shape]; in that case [data] can be read directly in order.
     */
    public override val consistent: Boolean
        get() = offset == 0 && size == data.size && strides.contentEquals(computeStrides(shape))

    /** Flat index range `0..size - 1`. */
    override val indices: IntRange
        get() {
            // todo?
            // if (dim.d != 1) throw IllegalStateException("NDArray of dimension ${dim.d}, use multiIndex.")
            return 0..size - 1
        }

    /** Progression from the all-zero multi-index to the last valid index of every axis. */
    override val multiIndices: MultiIndexProgression
        get() = IntArray(dim.d)..IntArray(dim.d) { shape[it] - 1 }

    /** `true` when [shape] is empty or is exactly `[1]`. */
    override fun isScalar(): Boolean = shape.isEmpty() || (shape.size == 1 && shape.first() == 1)

    /** `true` when [size] is zero. */
    public override fun isEmpty(): Boolean = size == 0

    /** Negation of [isEmpty]. */
    public override fun isNotEmpty(): Boolean = !isEmpty()

    /**
     * Iterates the elements. A [consistent] array iterates [data] directly; otherwise an
     * [NDArrayIterator] built from [offset], [strides] and [shape] is used.
     */
    public override operator fun iterator(): Iterator<T> =
        if (consistent) this.data.iterator() else NDArrayIterator(data, offset, strides, shape)

    /** Reified convenience overload of [asType] that derives the [DataType] from [E]. */
    public inline fun <reified E : Number> asType(): NDArray<E, D> {
        val dataType = DataType.ofKClass(E::class)
        return this.asType(dataType)
    }

    //TODO ???
    /**
     * Builds a new array of [dataType] with the same offset, shape, strides and dimension.
     *
     * Every element of [data] is passed through an unchecked `as E` cast; no numeric conversion
     * is performed in this function itself.
     */
    public fun <E : Number> asType(dataType: DataType): NDArray<E, D> {
        val newData = initMemoryView(this.data.size, dataType) { this.data[it] as E }
        return NDArray<E, D>(newData, this.offset, this.shape, this.strides, this.dim)
    }

    /** Copies [data], [shape] and [strides] as they are, keeping the same [offset] and [dim]. */
    override fun copy(): NDArray<T, D> =
        NDArray(this.data.copyOf(), this.offset, this.shape.copyOf(), this.strides.copyOf(), this.dim)

    /**
     * Returns a new array with offset 0 and default strides whose buffer holds this array's
     * elements in iteration order.
     */
    override fun deepCopy(): NDArray<T, D> =
        NDArray(packedData(), 0, this.shape.copyOf(), dim = this.dim)

    /** Returns a new one-dimensional array of length [size] holding the elements in iteration order. */
    override fun flatten(): MultiArray<T, D1> =
        D1Array(packedData(), 0, intArrayOf(size), dim = D1)

    /** Fresh buffer with the elements in iteration order: a plain copy when [consistent], otherwise gathered. */
    private fun packedData(): MemoryView<T> {
        if (consistent) return this.data.copyOf()
        val packed = initMemoryView<T>(this.size, this.dtype)
        var index = 0
        for (el in this) packed[index++] = el
        return packed
    }

    // TODO(strides? : view.reshape().reshape()?)
    // NOTE(review): every reshape overload reuses `data`/`offset` with default strides, so a view
    // with non-default strides (e.g. a transposed array) would presumably be read in buffer order
    // rather than logical order. Fixing that needs a copy and changes aliasing — left as is.

    /**
     * Reinterprets the array as one-dimensional with [dim1] elements.
     *
     * Returns `this` when it already is a 1-D array of that length, otherwise a new array over
     * the same [data].
     *
     * @throws IllegalArgumentException if [dim1] differs from [size].
     */
    override fun reshape(dim1: Int): D1Array<T> {
        // todo negative shape?
        requirePositiveShape(dim1)
        require(dim1 == size) { "Cannot reshape array of size $size into a new shape ($dim1)" }

        return if (this.dim.d == 1 && this.shape.first() == dim1) {
            this as D1Array<T>
        } else {
            D1Array<T>(this.data, this.offset, intArrayOf(dim1), dim = D1, base = base ?: this)
        }
    }

    /**
     * Reinterprets the array with shape `(dim1, dim2)`; returns `this` when the shape is unchanged.
     *
     * @throws IllegalArgumentException if `dim1 * dim2` differs from [size].
     */
    override fun reshape(dim1: Int, dim2: Int): D2Array<T> {
        val newShape = intArrayOf(dim1, dim2)
        newShape.forEach { requirePositiveShape(it) }
        require(dim1 * dim2 == size) { "Cannot reshape array of size $size into a new shape ($dim1, $dim2)" }

        return if (this.shape.contentEquals(newShape)) {
            this as D2Array<T>
        } else {
            D2Array<T>(this.data, this.offset, newShape, dim = D2, base = base ?: this)
        }
    }

    /**
     * Reinterprets the array with shape `(dim1, dim2, dim3)`; returns `this` when the shape is unchanged.
     *
     * @throws IllegalArgumentException if the product of the dimensions differs from [size].
     */
    override fun reshape(dim1: Int, dim2: Int, dim3: Int): D3Array<T> {
        val newShape = intArrayOf(dim1, dim2, dim3)
        newShape.forEach { requirePositiveShape(it) }
        require(dim1 * dim2 * dim3 == size) { "Cannot reshape array of size $size into a new shape ($dim1, $dim2, $dim3)" }

        return if (this.shape.contentEquals(newShape)) {
            this as D3Array<T>
        } else {
            D3Array<T>(this.data, this.offset, newShape, dim = D3, base = base ?: this)
        }
    }

    /**
     * Reinterprets the array with shape `(dim1, dim2, dim3, dim4)`; returns `this` when the shape is unchanged.
     *
     * @throws IllegalArgumentException if the product of the dimensions differs from [size].
     */
    override fun reshape(dim1: Int, dim2: Int, dim3: Int, dim4: Int): D4Array<T> {
        val newShape = intArrayOf(dim1, dim2, dim3, dim4)
        newShape.forEach { requirePositiveShape(it) }
        require(dim1 * dim2 * dim3 * dim4 == size) { "Cannot reshape array of size $size into a new shape ($dim1, $dim2, $dim3, $dim4)" }

        return if (this.shape.contentEquals(newShape)) {
            this as D4Array<T>
        } else {
            D4Array<T>(this.data, this.offset, newShape, dim = D4, base = base ?: this)
        }
    }

    /**
     * Reinterprets the array with five or more dimensions; returns `this` when the shape is unchanged.
     *
     * @throws IllegalArgumentException if the product of the dimensions differs from [size].
     */
    override fun reshape(dim1: Int, dim2: Int, dim3: Int, dim4: Int, vararg dims: Int): NDArray<T, DN> {
        val newShape = intArrayOf(dim1, dim2, dim3, dim4) + dims
        newShape.forEach { requirePositiveShape(it) }
        require(newShape.fold(1, Int::times) == size) {
            "Cannot reshape array of size $size into a new shape ${newShape.joinToString(prefix = "(", postfix = ")")}"
        }

        return if (this.shape.contentEquals(newShape)) {
            this as NDArray<T, DN>
        } else {
            NDArray<T, DN>(this.data, this.offset, newShape, dim = DN(newShape.size), base = base ?: this)
        }
    }

    /**
     * Returns a view over the same [data] with permuted axes.
     *
     * With no [axes] the axis order is reversed; otherwise output axis `i` takes the shape and
     * stride of input axis `axes[i]`. A one-dimensional array is returned as a new array with
     * unchanged shape and strides.
     *
     * @throws IllegalArgumentException if [axes] is non-empty and does not list every axis
     * exactly once, or contains an axis outside `0 until dim.d`.
     */
    override fun transpose(vararg axes: Int): NDArray<T, D> {
        require(axes.isEmpty() || axes.size == dim.d) { "All dimensions must be indicated." }
        for (axis in axes) require(axis in 0 until dim.d) { "Dimension must be from 0 to ${dim.d}." }
        require(axes.toSet().size == axes.size) { "The specified dimensions must be unique." }
        if (dim.d == 1) return NDArray(this.data, this.offset, this.shape, this.strides, this.dim)

        val newShape: IntArray
        val newStrides: IntArray
        if (axes.isEmpty()) {
            newShape = this.shape.reversedArray()
            newStrides = this.strides.reversedArray()
        } else {
            newShape = IntArray(this.shape.size)
            newStrides = IntArray(this.strides.size)
            for ((i, axis) in axes.withIndex()) {
                newShape[i] = this.shape[axis]
                newStrides[i] = this.strides[axis]
            }
        }
        return NDArray(this.data, this.offset, newShape, newStrides, this.dim, base = base ?: this)
    }

    /**
     * Returns a view over the same [data] without the given size-one [axes], or without every
     * size-one axis when [axes] is empty.
     *
     * The strides of the remaining axes are carried over, so strided views keep addressing the
     * same elements.
     *
     * @throws IllegalArgumentException if a listed axis does not have size one.
     */
    override fun squeeze(vararg axes: Int): NDArray<T, DN> {
        val cutAxes = if (axes.isEmpty()) {
            shape.withIndex().filter { it.value == 1 }.map { it.index }
        } else {
            require(axes.all { shape[it] == 1 }) { "Cannot select an axis to squeeze out which has size not equal to one." }
            axes.toList()
        }
        val keptAxes = this.shape.indices - cutAxes
        val newShape = this.shape.sliceArray(keptAxes)
        // Dropping a size-one axis does not change how the other axes step through the buffer.
        val newStrides = this.strides.sliceArray(keptAxes)
        return NDArray(this.data, this.offset, newShape, newStrides, DN(newShape.size), base = base ?: this)
    }

    /**
     * Returns a view over the same [data] with a size-one axis inserted at each of [axes]
     * (processed in ascending order).
     *
     * Existing strides are carried over; each inserted axis gets the stride a default layout
     * would give it (next stride times next extent, or 1 when it is the last axis).
     */
    override fun unsqueeze(vararg axes: Int): NDArray<T, DN> {
        val newShape = shape.toMutableList()
        val newStrides = strides.toMutableList()
        for (axis in axes.sorted()) {
            // Insert into the shape first so an invalid axis fails exactly as it did before.
            newShape.add(axis, 1)
            // newStrides[axis] still belongs to the axis that now sits at axis + 1 in newShape.
            val stride = if (axis < newStrides.size) newStrides[axis] * newShape[axis + 1] else 1
            newStrides.add(axis, stride)
        }

        return NDArray(
            this.data,
            this.offset,
            newShape.toIntArray(),
            newStrides.toIntArray(),
            DN(newShape.size),
            base = base ?: this
        )
    }

    /** Concatenates [other] after this array along axis 0. */
    override infix fun cat(other: MultiArray<T, D>): NDArray<T, D> =
        cat(listOf(other), 0)

    /** Concatenates [other] after this array along [axis]. */
    override fun cat(other: MultiArray<T, D>, axis: Int): NDArray<T, D> =
        cat(listOf(other), axis)

    /**
     * Concatenates this array followed by all arrays in [other] along [axis] into a new array.
     *
     * An empty [other] yields a [deepCopy] of this array.
     *
     * @throws IllegalArgumentException if the resolved axis is outside `0 until dim.d`, or if any
     * array differs from this one in rank or in an extent other than the concatenation axis.
     */
    override fun cat(other: List<MultiArray<T, D>>, axis: Int): NDArray<T, D> {
        val actualAxis = actualAxis(axis)
        require(actualAxis in 0 until dim.d) { "Axis $axis is out of bounds for array of dimension $dim" }
        if (other.isEmpty()) return deepCopy()

        // Compare against the resolved axis (the raw one may be negative) and check every input.
        require(other.all { arr ->
            arr.shape.size == this.shape.size &&
                this.shape.withIndex().all { it.index == actualAxis || it.value == arr.shape[it.index] }
        }) { "All dimensions of input arrays for the concatenation axis must match exactly." }

        val newShape = this.shape.copyOf()
        // `it.shape` / `it.size`: an unqualified `shape`/`size` here would resolve to this array.
        newShape[actualAxis] = this.shape[actualAxis] + other.sumOf { it.shape[actualAxis] }
        val newSize = this.size + other.sumOf { it.size }

        val arrays = other.toMutableList().also { it.add(0, this) }
        // NOTE(review): the index space handed to `concatenate` is derived from the first array in
        // `other` only; confirm `concatenate` copes with inputs whose extents differ along the axis.
        val concatShape =
            other.first().multiIndices.last.toMutableList().apply { add(actualAxis, arrays.size - 1) }.toIntArray()
        val result = NDArray<T, D>(initMemoryView(newSize, dtype), 0, newShape, dim = dim)
        concatenate(arrays, result, IntArray(concatShape.size)..concatShape, actualAxis)
        return result
    }

    //todo extensions
    /**
     * Returns this array typed as [D1Array].
     *
     * @throws ClassCastException if the array is not one-dimensional.
     */
    public fun asD1Array(): D1Array<T> {
        if (this.dim.d == 1) return this as D1Array<T>
        else throw ClassCastException("Cannot cast NDArray of dimension ${this.dim.d} to NDArray of dimension 1.")
    }

    //todo
    /**
     * Returns this array typed as [D2Array].
     *
     * @throws ClassCastException if the array is not two-dimensional.
     */
    public fun asD2Array(): D2Array<T> {
        if (this.dim.d == 2) return this as D2Array<T>
        else throw ClassCastException("Cannot cast NDArray of dimension ${this.dim.d} to NDArray of dimension 2.")
    }

    /**
     * Returns this array typed as [D3Array].
     *
     * @throws ClassCastException if the array is not three-dimensional.
     */
    public fun asD3Array(): D3Array<T> {
        if (this.dim.d == 3) return this as D3Array<T>
        else throw ClassCastException("Cannot cast NDArray of dimension ${this.dim.d} to NDArray of dimension 3.")
    }

    /**
     * Returns this array typed as [D4Array].
     *
     * @throws ClassCastException if the array is not four-dimensional.
     */
    public fun asD4Array(): D4Array<T> {
        if (this.dim.d == 4) return this as D4Array<T>
        else throw ClassCastException("Cannot cast NDArray of dimension ${this.dim.d} to NDArray of dimension 4.")
    }

    /**
     * Returns this array with an N-dimensional ([DN]) dimension marker.
     *
     * Arrays with more than four dimensions are returned as is; others are wrapped in a new
     * array over the same data, offset, shape and strides.
     *
     * @throws Exception if `dim.d` is -1.
     */
    public fun asDNArray(): NDArray<T, DN> {
        if (this.dim.d == -1) throw Exception("Array dimension is undefined")
        if (this.dim.d > 4) return this as NDArray<T, DN>

        return NDArray(this.data, this.offset, this.shape, this.strides, DN(this.dim.d), base = base ?: this)
    }

    /**
     * Two arrays are equal when they have the same class, size, shape, dtype and dim, and their
     * elements are pairwise equal in iteration order.
     */
    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (javaClass != other?.javaClass) return false

        other as NDArray<*, *>

        if (size != other.size) return false
        if (!shape.contentEquals(other.shape)) return false
        if (dtype != other.dtype) return false
        if (dim != other.dim) return false

        val thIt = this.iterator()
        val othIt = other.iterator()
        while (thIt.hasNext() && othIt.hasNext()) {
            if (thIt.next() != othIt.next())
                return false
        }

        return true
    }

    /** Hash combining every element in iteration order with the usual 31-multiplier scheme. */
    override fun hashCode(): Int {
        var result = 1
        for (el in this) {
            result = 31 * result + el.hashCode()
        }
        return result
    }

    /**
     * Renders the array as nested bracketed lists.
     *
     * Dimensions 1 to 4 are printed directly; higher dimensions print each slice along the
     * first axis through its own `toString`, separated by `dim.d - 1` newlines.
     */
    override fun toString(): String {
        return when (dim.d) {
            1 -> buildString {
                this@NDArray as NDArray<T, D1>
                append('[')
                for (i in 0 until shape.first()) {
                    append(this@NDArray[i])
                    if (i < shape.first() - 1)
                        append(", ")
                }
                append(']')
            }

            2 -> buildString {
                this@NDArray as NDArray<T, D2>
                append('[')
                for (ax0 in 0 until shape[0]) {
                    append('[')
                    for (ax1 in 0 until shape[1]) {
                        append(this@NDArray[ax0, ax1])
                        if (ax1 < shape[1] - 1)
                            append(", ")
                    }
                    append(']')
                    if (ax0 < shape[0] - 1)
                        append(",\n")
                }
                append(']')
            }

            3 -> buildString {
                this@NDArray as NDArray<T, D3>
                append('[')
                for (ax0 in 0 until shape[0]) {
                    append('[')
                    for (ax1 in 0 until shape[1]) {
                        append('[')
                        for (ax2 in 0 until shape[2]) {
                            append(this@NDArray[ax0, ax1, ax2])
                            if (ax2 < shape[2] - 1)
                                append(", ")
                        }
                        append(']')
                        if (ax1 < shape[1] - 1)
                            append(",\n")
                    }
                    append(']')
                    if (ax0 < shape[0] - 1)
                        append(",\n\n")
                }
                append(']')
            }

            4 -> buildString {
                this@NDArray as NDArray<T, D4>
                append('[')
                for (ax0 in 0 until shape[0]) {
                    append('[')
                    for (ax1 in 0 until shape[1]) {
                        append('[')
                        for (ax2 in 0 until shape[2]) {
                            append('[')
                            for (ax3 in 0 until shape[3]) {
                                append(this@NDArray[ax0, ax1, ax2, ax3])
                                if (ax3 < shape[3] - 1)
                                    append(", ")
                            }
                            append(']')
                            if (ax2 < shape[2] - 1)
                                append(",\n")
                        }
                        append(']')
                        if (ax1 < shape[1] - 1)
                            append(",\n\n")
                    }
                    append(']')
                    if (ax0 < shape[0] - 1)
                        append(",\n\n\n")
                }
                append(']')
            }

            else -> buildString {
                this@NDArray as NDArray<*, DN>
                // Separator between first-axis slices; the same for every slice.
                val newLine = "\n".repeat(dim.d - 1)
                append('[')
                for (ind in 0 until shape.first()) {
                    // NOTE(review): restored from a garbled span; `V` is assumed to be the
                    // project's view accessor — confirm.
                    append(this@NDArray.V[ind].toString())
                    if (ind < shape.first() - 1) {
                        append(",$newLine")
                    }
                }
                append(']')
            }
        }
    }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy (listing footer left over from the artifact browser; not part of the source)