
// commonMain.space.kscience.kmath.linear.MatrixContext.kt (Maven / Gradle / Ivy)
package space.kscience.kmath.linear
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.Ring
import space.kscience.kmath.operations.SpaceOperations
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.operations.sum
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.BufferFactory
import space.kscience.kmath.structures.asSequence
import kotlin.reflect.KClass
/**
 * Basic operations on matrices. Operates on [Matrix].
 *
 * @param T the type of items in the matrices.
 * @param M the type of operated matrices.
 */
public interface MatrixContext<T : Any, out M : Matrix<T>> : SpaceOperations<Matrix<T>> {
    /**
     * Produces a matrix with this context and given dimensions.
     *
     * @param rows the number of rows.
     * @param columns the number of columns.
     * @param initializer supplies the element for the index pair `(i, j)`.
     * @return the produced matrix.
     */
    public fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): M

    /**
     * Produces a point compatible with matrix space (and possibly optimized for it).
     *
     * The default implementation delegates to [Buffer.boxing].
     *
     * @param size the number of elements in the point.
     * @param initializer supplies the element for each index.
     * @return the produced point.
     */
    public fun point(size: Int, initializer: (Int) -> T): Point<T> = Buffer.boxing(size, initializer)

    /**
     * Resolves a binary operation by name. The name `"dot"` is mapped to the matrix dot product; every other
     * name is delegated to the super implementation.
     *
     * @param operation the name of the operation.
     * @return a function applying the operation to two matrices.
     */
    @Suppress("UNCHECKED_CAST")
    public override fun binaryOperationFunction(operation: String): (left: Matrix<T>, right: Matrix<T>) -> M =
        when (operation) {
            "dot" -> { left, right -> left dot right }
            // The inherited function is declared in terms of Matrix<T>; the cast narrows its result type to M.
            else -> super.binaryOperationFunction(operation) as (Matrix<T>, Matrix<T>) -> M
        }

    /**
     * Computes the dot product of this matrix and another one.
     *
     * @receiver the multiplicand.
     * @param other the multiplier.
     * @return the dot product.
     */
    public infix fun Matrix<T>.dot(other: Matrix<T>): M

    /**
     * Computes the dot product of this matrix and a vector.
     *
     * @receiver the multiplicand.
     * @param vector the multiplier.
     * @return the dot product.
     */
    public infix fun Matrix<T>.dot(vector: Point<T>): Point<T>

    /**
     * Multiplies a matrix by its element.
     *
     * @receiver the multiplicand.
     * @param value the multiplier.
     * @return the product.
     */
    public operator fun Matrix<T>.times(value: T): M

    /**
     * Multiplies an element by a matrix of it.
     *
     * The default implementation delegates to `m * this`.
     *
     * @receiver the multiplicand.
     * @param m the multiplier.
     * @return the product.
     */
    public operator fun T.times(m: Matrix<T>): M = m * this

    /**
     * Gets a feature from the matrix. This function may return some additional features to
     * [kscience.kmath.nd.NDStructure.getFeature].
     *
     * @param F the type of feature.
     * @param m the matrix.
     * @param type the [KClass] instance of [F].
     * @return a feature object or `null` if it isn't present.
     */
    @UnstableKMathAPI
    public fun <F : Any> getFeature(m: Matrix<T>, type: KClass<F>): F? = m.getFeature(type)

    public companion object {
        /**
         * A structured matrix with custom buffer.
         *
         * @param ring the ring over matrix elements.
         * @param bufferFactory the factory of buffers backing the matrices; boxing buffers by default.
         * @return a matrix context built on [ring] and [bufferFactory].
         */
        public fun <T : Any, R : Ring<T>> buffered(
            ring: R,
            bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
        ): GenericMatrixContext<T, R, BufferMatrix<T>> = BufferMatrixContext(ring, bufferFactory)

        /**
         * Automatic buffered matrix, unboxed if it is possible.
         *
         * @param ring the ring over matrix elements.
         * @return a matrix context built on [ring] with automatically selected buffers.
         */
        public inline fun <reified T : Any, R : Ring<T>> auto(ring: R): GenericMatrixContext<T, R, BufferMatrix<T>> =
            buffered(ring, Buffer.Companion::auto)
    }
}
/**
 * Gets a feature from the matrix. This function may return some additional features to
 * [kscience.kmath.nd.NDStructure.getFeature].
 *
 * Reified shortcut for the member `getFeature(m, type)`: the [KClass] argument is taken from [F].
 *
 * @param T the type of items in the matrices.
 * @param F the type of feature.
 * @receiver the [MatrixContext] of [T].
 * @param m the matrix.
 * @return a feature object or `null` if it isn't present.
 */
@UnstableKMathAPI
public inline fun <T : Any, reified F : Any> MatrixContext<T, *>.getFeature(m: Matrix<T>): F? =
    getFeature(m, F::class)
/**
 * Partial implementation of [MatrixContext] for matrices of [Ring].
 *
 * @param T the type of items in the matrices.
 * @param R the type of ring of matrix elements.
 * @param M the type of operated matrices.
 */
public interface GenericMatrixContext<T : Any, R : Ring<T>, out M : Matrix<T>> : MatrixContext<T, M> {
    /**
     * The ring over matrix elements.
     */
    public val elementContext: R

    /**
     * Computes the dot product of this matrix and another one.
     *
     * @receiver the multiplicand.
     * @param other the multiplier.
     * @return the dot product.
     * @throws IllegalArgumentException if the column count of the receiver differs from the row count of [other].
     */
    public override infix fun Matrix<T>.dot(other: Matrix<T>): M {
        //TODO add typed error
        require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" }

        return produce(rowNum, other.colNum) { i, j ->
            val row = rows[i]
            val column = other.columns[j]
            // Element (i, j) is the sum of pairwise products of the i-th row and the j-th column.
            elementContext { sum(row.asSequence().zip(column.asSequence(), ::multiply)) }
        }
    }

    /**
     * Computes the dot product of this matrix and a vector.
     *
     * @receiver the multiplicand.
     * @param vector the multiplier.
     * @return the dot product.
     * @throws IllegalArgumentException if the column count of the receiver differs from the size of [vector].
     */
    public override infix fun Matrix<T>.dot(vector: Point<T>): Point<T> {
        //TODO add typed error
        require(colNum == vector.size) { "Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})" }

        return point(rowNum) { i ->
            val row = rows[i]
            elementContext { sum(row.asSequence().zip(vector.asSequence(), ::multiply)) }
        }
    }

    /**
     * Negates every element of this matrix.
     *
     * @receiver the matrix to negate.
     * @return a matrix of the same dimensions with negated elements.
     */
    public override operator fun Matrix<T>.unaryMinus(): M =
        produce(rowNum, colNum) { i, j -> elementContext { -get(i, j) } }

    /**
     * Adds two matrices element by element.
     *
     * @param a the first addend.
     * @param b the second addend.
     * @return the element-wise sum.
     * @throws IllegalArgumentException if the dimensions of [a] and [b] differ.
     */
    public override fun add(a: Matrix<T>, b: Matrix<T>): M {
        require(a.rowNum == b.rowNum && a.colNum == b.colNum) {
            "Matrix operation dimension mismatch. [${a.rowNum},${a.colNum}] + [${b.rowNum},${b.colNum}]"
        }

        return produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] + b[i, j] } }
    }

    /**
     * Subtracts [b] from this matrix element by element.
     *
     * @receiver the minuend.
     * @param b the subtrahend.
     * @return the element-wise difference.
     * @throws IllegalArgumentException if the dimensions of the receiver and [b] differ.
     */
    public override operator fun Matrix<T>.minus(b: Matrix<T>): M {
        require(rowNum == b.rowNum && colNum == b.colNum) {
            "Matrix operation dimension mismatch. [$rowNum,$colNum] - [${b.rowNum},${b.colNum}]"
        }

        // Subtraction, not addition: the previous implementation summed the elements here.
        return produce(rowNum, colNum) { i, j -> elementContext { get(i, j) - b[i, j] } }
    }

    /**
     * Multiplies every element of a matrix by a number.
     *
     * @param a the matrix.
     * @param k the numeric multiplier.
     * @return the scaled matrix.
     */
    public override fun multiply(a: Matrix<T>, k: Number): M =
        produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] * k } }

    /**
     * Multiplies every element of this matrix by a ring element.
     *
     * @receiver the multiplicand.
     * @param value the multiplier.
     * @return the product.
     */
    public override operator fun Matrix<T>.times(value: T): M =
        produce(rowNum, colNum) { i, j -> elementContext { get(i, j) * value } }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy