
// Source file: commonMain/space/kscience/kmath/linear/MatrixWrapper.kt
package space.kscience.kmath.linear
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.Structure2D
import space.kscience.kmath.nd.getFeature
import space.kscience.kmath.operations.Ring
import space.kscience.kmath.structures.asBuffer
import kotlin.math.sqrt
import kotlin.reflect.KClass
import kotlin.reflect.safeCast
/**
 * A [Matrix] that holds [MatrixFeature] objects.
 *
 * All matrix operations are delegated to [origin]; the wrapper only adds the [features] set on top of it.
 * Equality and hash code are delegated to [origin] as well, so [features] do not take part in either.
 *
 * @param T the type of items.
 * @property origin the matrix that actually stores or computes the elements.
 * @property features the features attached to [origin] by this wrapper.
 */
public class MatrixWrapper<T : Any> internal constructor(
    public val origin: Matrix<T>,
    public val features: Set<MatrixFeature>,
) : Matrix<T> by origin {

    /**
     * Get the first feature matching given class. Does not guarantee that matrix has only one feature matching
     * the criteria.
     *
     * The wrapper's own [features] are searched first; if none of them is an instance of [type], the request is
     * forwarded to [origin].
     *
     * Note: the type parameter of this function is the requested feature type, not the element type of the matrix.
     *
     * @param type the class of the requested feature.
     * @return the matching feature or `null` if neither the wrapper nor [origin] provides one.
     */
    @UnstableKMathAPI
    override fun <T : Any> getFeature(type: KClass<T>): T? =
        type.safeCast(features.find { type.isInstance(it) })
            ?: origin.getFeature(type)

    // Features are deliberately ignored: comparison and hashing are those of the wrapped matrix.
    override fun equals(other: Any?): Boolean = origin == other

    override fun hashCode(): Int = origin.hashCode()

    override fun toString(): String = "MatrixWrapper(matrix=$origin, features=$features)"
}
/**
 * Return the original matrix. If this is a wrapper, return its origin. If not, return this matrix.
 * The origin does not necessarily store all features.
 */
@UnstableKMathAPI
public val <T : Any> Matrix<T>.origin: Matrix<T>
    get() = (this as? MatrixWrapper)?.origin ?: this
/**
 * Add a single feature to a [Matrix].
 *
 * If the receiver is already a [MatrixWrapper], a new wrapper around the same origin is created with
 * [newFeature] added to the existing feature set. Otherwise the receiver is wrapped with [newFeature] as its
 * only feature.
 *
 * @param newFeature the feature to attach.
 * @return a new [MatrixWrapper]; the receiver is not modified.
 */
public operator fun <T : Any> Matrix<T>.plus(newFeature: MatrixFeature): MatrixWrapper<T> =
    if (this is MatrixWrapper) {
        // Re-wrap the origin instead of nesting wrappers.
        MatrixWrapper(origin, features + newFeature)
    } else {
        MatrixWrapper(this, setOf(newFeature))
    }
/**
 * Add a collection of features to a [Matrix].
 *
 * If the receiver is already a [MatrixWrapper], a new wrapper around the same origin is created with
 * [newFeatures] added to the existing feature set. Otherwise the receiver is wrapped with [newFeatures]
 * converted to a set.
 *
 * @param newFeatures the features to attach.
 * @return a new [MatrixWrapper]; the receiver is not modified.
 */
public operator fun <T : Any> Matrix<T>.plus(newFeatures: Collection<MatrixFeature>): MatrixWrapper<T> =
    if (this is MatrixWrapper) {
        // Re-wrap the origin instead of nesting wrappers.
        MatrixWrapper(origin, features + newFeatures)
    } else {
        MatrixWrapper(this, newFeatures.toSet())
    }
/**
 * Build a square matrix from given elements.
 *
 * The side of the matrix is the integer square root of the number of [elements]; the elements are passed to
 * [BufferMatrix] as a single buffer in the order they were given.
 *
 * @param elements the matrix elements; their count must be a perfect square.
 * @return a matrix whose number of rows and columns both equal the square root of the element count.
 * @throws IllegalArgumentException if the number of [elements] is not a perfect square.
 */
public fun <T : Any> Structure2D.Companion.square(vararg elements: T): Matrix<T> {
    val size: Int = sqrt(elements.size.toDouble()).toInt()
    require(size * size == elements.size) { "The number of elements ${elements.size} is not a full square" }
    return BufferMatrix(size, size, elements.asBuffer())
}
/**
 * Diagonal matrix of ones. The matrix is virtual: no actual matrix is created, each element is computed on
 * access as the ring's `one` when the row and column indices are equal and the ring's `zero` otherwise.
 * The result is tagged with [UnitFeature].
 *
 * NOTE(review): the type-argument list `<T, R, *>` of [GenericMatrixContext] is an assumption about its
 * declaration (element type, ring, matrix type) - confirm against the declaration.
 *
 * @param rows the number of rows.
 * @param columns the number of columns.
 */
public fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R, *>.one(rows: Int, columns: Int): Matrix<T> =
    VirtualMatrix(rows, columns) { i, j ->
        if (i == j) elementContext.one else elementContext.zero
    } + UnitFeature
/**
 * A virtual matrix of zeroes: no actual matrix is created, every element is the ring's `zero` computed on
 * access. The result is tagged with [ZeroFeature].
 *
 * NOTE(review): the type-argument list `<T, R, *>` of [GenericMatrixContext] is an assumption about its
 * declaration (element type, ring, matrix type) - confirm against the declaration.
 *
 * @param rows the number of rows.
 * @param columns the number of columns.
 */
public fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R, *>.zero(rows: Int, columns: Int): Matrix<T> =
    VirtualMatrix(rows, columns) { _, _ -> elementContext.zero } + ZeroFeature
/**
 * A [MatrixFeature] carried by a transposed view that keeps a reference to the matrix it was created from.
 *
 * @param T the type of items.
 * @property original the matrix that was transposed.
 */
public class TransposedFeature<T : Any>(public val original: Matrix<T>) : MatrixFeature
/**
 * Create a virtual transposed matrix without copying anything. `A.transpose().transpose() === A`
 *
 * If the receiver already carries a [TransposedFeature], the matrix stored in that feature is returned as is.
 * Otherwise a [VirtualMatrix] with swapped dimensions is created that reads element `(i, j)` from `(j, i)` of
 * the receiver, and it is tagged with a [TransposedFeature] pointing back to the receiver.
 */
@OptIn(UnstableKMathAPI::class)
public fun <T : Any> Matrix<T>.transpose(): Matrix<T> =
    getFeature<TransposedFeature<T>>()?.original
        // `+` binds tighter than `?:`; the parentheses only make the original grouping explicit.
        ?: (VirtualMatrix(colNum, rowNum) { i, j -> get(j, i) } + TransposedFeature(this))
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy