All Downloads are FREE. Search and download functionalities are using the official Maven repository.

commonMain.space.kscience.kmath.linear.MatrixContext.kt Maven / Gradle / Ivy

package space.kscience.kmath.linear

import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.Ring
import space.kscience.kmath.operations.SpaceOperations
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.operations.sum
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.BufferFactory
import space.kscience.kmath.structures.asSequence
import kotlin.reflect.KClass

/**
 * Basic operations on matrices. Operates on [Matrix].
 *
 * @param T the type of items in the matrices.
 * @param M the type of operated matrices.
 */
public interface MatrixContext<T : Any, out M : Matrix<T>> : SpaceOperations<Matrix<T>> {
    /**
     * Produces a matrix with this context and given dimensions.
     *
     * @param rows the number of rows of the produced matrix.
     * @param columns the number of columns of the produced matrix.
     * @param initializer computes the element located at row `i`, column `j`.
     * @return the new matrix.
     */
    public fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): M

    /**
     * Produces a point compatible with matrix space (and possibly optimized for it).
     *
     * The default implementation returns a boxing [Buffer]; implementations may override it.
     *
     * @param size the number of elements of the point.
     * @param initializer computes the element at a given index.
     * @return the new point.
     */
    public fun point(size: Int, initializer: (Int) -> T): Point<T> = Buffer.boxing(size, initializer)

    /**
     * Resolves a binary operation by its name.
     *
     * `"dot"` is mapped to the matrix product ([dot]); every other name is delegated to the
     * [SpaceOperations] implementation.
     *
     * @param operation the name of the operation.
     * @return a function applying the operation to two matrices.
     */
    @Suppress("UNCHECKED_CAST")
    public override fun binaryOperationFunction(operation: String): (left: Matrix<T>, right: Matrix<T>) -> M =
        when (operation) {
            "dot" -> { left, right -> left dot right }
            // The inherited function is typed in terms of Matrix<T>; narrowing its result to M is unchecked.
            else -> super.binaryOperationFunction(operation) as (Matrix<T>, Matrix<T>) -> M
        }

    /**
     * Computes the dot product of this matrix and another one.
     *
     * @receiver the multiplicand.
     * @param other the multiplier.
     * @return the dot product.
     */
    public infix fun Matrix<T>.dot(other: Matrix<T>): M

    /**
     * Computes the dot product of this matrix and a vector.
     *
     * @receiver the multiplicand.
     * @param vector the multiplier.
     * @return the dot product.
     */
    public infix fun Matrix<T>.dot(vector: Point<T>): Point<T>

    /**
     * Multiplies every element of a matrix by a scalar element.
     *
     * @receiver the multiplicand.
     * @param value the multiplier.
     * @return the product.
     */
    public operator fun Matrix<T>.times(value: T): M

    /**
     * Multiplies a scalar element by a matrix (scalar on the left).
     *
     * The default implementation delegates to `m * this`.
     *
     * @receiver the multiplicand.
     * @param m the multiplier.
     * @return the product.
     */
    public operator fun T.times(m: Matrix<T>): M = m * this

    /**
     * Gets a feature from the matrix. This function may return some additional features to
     * [space.kscience.kmath.nd.NDStructure.getFeature].
     *
     * The default implementation only asks the matrix itself (`m.getFeature(type)`).
     *
     * @param F the type of feature.
     * @param m the matrix.
     * @param type the [KClass] instance of [F].
     * @return a feature object or `null` if it isn't present.
     */
    @UnstableKMathAPI
    public fun <F : Any> getFeature(m: Matrix<T>, type: KClass<F>): F? = m.getFeature(type)

    public companion object {

        /**
         * Creates a [GenericMatrixContext] over [ring] whose matrices are backed by buffers produced by
         * [bufferFactory].
         *
         * @param ring the ring of matrix elements.
         * @param bufferFactory the factory for element buffers; boxing buffers by default.
         * @return a buffered matrix context.
         */
        public fun <T : Any, R : Ring<T>> buffered(
            ring: R,
            bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
        ): GenericMatrixContext<T, R, BufferMatrix<T>> = BufferMatrixContext(ring, bufferFactory)

        /**
         * Creates a buffered matrix context whose buffers are created by `Buffer.auto`, which is intended to
         * use unboxed storage when possible.
         *
         * @param ring the ring of matrix elements.
         * @return a buffered matrix context.
         */
        public inline fun <reified T : Any, R : Ring<T>> auto(ring: R): GenericMatrixContext<T, R, BufferMatrix<T>> =
            buffered(ring, Buffer.Companion::auto)
    }
}

/**
 * Gets a feature from the matrix. This function may return some additional features to
 * [space.kscience.kmath.nd.NDStructure.getFeature].
 *
 * Reified shorthand for `getFeature(m, F::class)`.
 *
 * @param T the type of items in the matrices.
 * @param F the type of feature.
 * @receiver the [MatrixContext] of [T].
 * @param m the matrix.
 * @return a feature object or `null` if it isn't present.
 */
@UnstableKMathAPI
public inline fun <T : Any, reified F : Any> MatrixContext<T, *>.getFeature(m: Matrix<T>): F? =
    getFeature(m, F::class)

/**
 * Partial implementation of [MatrixContext] for matrices of [Ring].
 *
 * All element arithmetic is performed inside [elementContext].
 *
 * @param T the type of items in the matrices.
 * @param R the type of ring of matrix elements.
 * @param M the type of operated matrices.
 */
public interface GenericMatrixContext<T : Any, R : Ring<T>, out M : Matrix<T>> : MatrixContext<T, M> {
    /**
     * The ring over matrix elements.
     */
    public val elementContext: R

    /**
     * Matrix product.
     *
     * Element `(i, j)` is the ring sum of pairwise products of row `i` of this matrix and column `j` of [other].
     *
     * @throws IllegalArgumentException if the column count of this matrix differs from the row count of [other].
     */
    public override infix fun Matrix<T>.dot(other: Matrix<T>): M {
        //TODO add typed error
        require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" }

        // Read the row/column views once, outside the per-element initializer.
        val leftRows = rows
        val rightColumns = other.columns

        return produce(rowNum, other.colNum) { i, j ->
            val row = leftRows[i]
            val column = rightColumns[j]
            elementContext { sum(row.asSequence().zip(column.asSequence(), ::multiply)) }
        }
    }

    /**
     * Matrix-vector product.
     *
     * Element `i` is the ring sum of pairwise products of row `i` of this matrix and [vector].
     *
     * @throws IllegalArgumentException if the column count of this matrix differs from the size of [vector].
     */
    public override infix fun Matrix<T>.dot(vector: Point<T>): Point<T> {
        //TODO add typed error
        require(colNum == vector.size) { "Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})" }

        // Read the row views once, outside the per-element initializer.
        val matrixRows = rows

        return point(rowNum) { i ->
            val row = matrixRows[i]
            elementContext { sum(row.asSequence().zip(vector.asSequence(), ::multiply)) }
        }
    }

    /**
     * Negates every element of the matrix.
     */
    public override operator fun Matrix<T>.unaryMinus(): M =
        produce(rowNum, colNum) { i, j -> elementContext { -get(i, j) } }

    /**
     * Element-wise sum of [a] and [b].
     *
     * @throws IllegalArgumentException if the dimensions of [a] and [b] differ.
     */
    public override fun add(a: Matrix<T>, b: Matrix<T>): M {
        require(a.rowNum == b.rowNum && a.colNum == b.colNum) {
            "Matrix operation dimension mismatch. [${a.rowNum},${a.colNum}] + [${b.rowNum},${b.colNum}]"
        }

        return produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] + b[i, j] } }
    }

    /**
     * Element-wise difference of this matrix and [b].
     *
     * @throws IllegalArgumentException if the dimensions of this matrix and [b] differ.
     */
    public override operator fun Matrix<T>.minus(b: Matrix<T>): M {
        require(rowNum == b.rowNum && colNum == b.colNum) {
            "Matrix operation dimension mismatch. [$rowNum,$colNum] - [${b.rowNum},${b.colNum}]"
        }

        // Subtract (the previous implementation mistakenly added the elements).
        return produce(rowNum, colNum) { i, j -> elementContext { get(i, j) - b[i, j] } }
    }

    /**
     * Multiplies every element of [a] by the number [k] using the element context's `T * Number` operator.
     */
    public override fun multiply(a: Matrix<T>, k: Number): M =
        produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] * k } }

    /**
     * Multiplies every element of this matrix by the ring element [value].
     */
    public override operator fun Matrix<T>.times(value: T): M =
        produce(rowNum, colNum) { i, j -> elementContext { get(i, j) * value } }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy