All Downloads are FREE. Search and download functionalities are using the official Maven repository.

commonMain.space.kscience.kmath.linear.LupDecomposition.kt Maven / Gradle / Ivy

package space.kscience.kmath.linear

import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.getFeature
import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.BufferAccessor2D
import space.kscience.kmath.structures.MutableBuffer
import space.kscience.kmath.structures.MutableBufferFactory

/**
 * Generic [LupDecompositionFeature] backed by a single combined factor matrix.
 *
 * [lu] stores both triangular factors: entries strictly below the diagonal belong to L, entries on and
 * above the diagonal belong to U. [pivot] maps each row of the permuted matrix to its source row, and
 * [even] tells whether the permutation was produced by an even number of row swaps.
 *
 * The factor views [l], [u] and [p] are virtual matrices that read [lu] / [pivot] on access.
 */
public class LupDecomposition(
    public val context: MatrixContext>,
    public val elementContext: Field,
    public val lu: Matrix,
    public val pivot: IntArray,
    private val even: Boolean,
) : LupDecompositionFeature, DeterminantFeature {
    /**
     * Lower-triangular factor L.
     *
     * Its diagonal is all [Ring.one]; the values below the diagonal are taken from [lu].
     */
    override val l: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { row, column ->
        when {
            column > row -> elementContext.zero
            column == row -> elementContext.one
            else -> lu[row, column]
        }
    } + LFeature


    /**
     * Upper-triangular factor U, diagonal included, read straight from [lu].
     */
    override val u: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { row, column ->
        if (column < row) elementContext.zero else lu[row, column]
    } + UFeature

    /**
     * Row permutation matrix P.
     *
     * Row `r` contains a single [Ring.one] in column `pivot[r]`; every other entry is [Ring.zero].
     */
    override val p: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { row, column ->
        if (pivot[row] != column) elementContext.zero else elementContext.one
    }

    /**
     * Determinant of the decomposed matrix.
     *
     * It is the product of U's diagonal, negated when the permutation is odd. It is computed lazily
     * on first access.
     */
    override val determinant: T by lazy {
        elementContext {
            var product = if (even) one else -one
            for (k in 0 until lu.shape[0]) product = product * lu[k, k]
            product
        }
    }

}

/**
 * Absolute value in the context's ordered element field.
 *
 * Returns [value] unchanged when it is strictly greater than zero, and its additive inverse otherwise.
 */
@PublishedApi
internal fun , F : Field> GenericMatrixContext.abs(value: T): T {
    val zero = elementContext.zero
    return if (value <= zero) elementContext { -value } else value
}

/**
 * Creates an LUP decomposition of a square generic [matrix] with partial (row) pivoting.
 *
 * Both triangular factors end up in one working buffer. The part strictly below the diagonal holds L,
 * whose unit diagonal is implicit. The diagonal and the part above it hold U. Row swaps made while
 * pivoting are recorded in the pivot array, and their parity is kept for the sign of the determinant.
 *
 * @param factory buffer factory used to allocate the working storage.
 * @param elementContext field providing arithmetic on matrix elements.
 * @param matrix the square matrix to decompose. It is only read; all work happens on a buffer created
 *   from it.
 * @param checkSingular predicate applied to the magnitude of every chosen pivot. Returning `true`
 *   aborts the decomposition.
 * @return the decomposition, bound to this matrix context.
 * @throws IllegalArgumentException if [matrix] is not square.
 * @throws IllegalStateException if [checkSingular] rejects a pivot.
 */
public fun > MatrixContext>.lup(
    factory: MutableBufferFactory,
    elementContext: Field,
    matrix: Matrix,
    checkSingular: (T) -> Boolean,
): LupDecomposition {
    require(matrix.rowNum == matrix.colNum) { "LU decomposition supports only square matrices" }
    val m = matrix.colNum
    // Start from the identity permutation (matrix is square, so m == rowNum); rows are swapped below.
    val pivot = IntArray(m) { it }

    //TODO just waits for KEEP-176
    BufferAccessor2D(matrix.rowNum, matrix.colNum, factory).run {
        elementContext {
            val lu = create(matrix)

            // Parity of the row permutation; flipped on every swap and used for the determinant sign.
            var even = true

            // Loop over columns
            for (col in 0 until m) {
                // Upper part: finalize the entries of this column that lie above the diagonal.
                for (row in 0 until col) {
                    val luRow = lu.row(row)
                    var sum = luRow[col]
                    for (i in 0 until row) sum -= luRow[i] * lu[i, col]
                    luRow[col] = sum
                }

                // Lower part (diagonal included): compute the pivot candidates and remember the row
                // whose candidate has the largest magnitude.
                var max = col // permutation row
                // Start below any magnitude so the first candidate always wins.
                var largest = -one

                for (row in col until m) {
                    val luRow = lu.row(row)
                    var sum = luRow[col]
                    for (i in 0 until col) sum -= luRow[i] * lu[i, col]
                    luRow[col] = sum

                    // maintain best permutation choice
                    if (abs(sum) > largest) {
                        largest = abs(sum)
                        max = row
                    }
                }

                // Singularity check on the chosen pivot
                check(!checkSingular(abs(lu[max, col]))) { "The matrix is singular" }

                // Move the pivot row onto the diagonal if it is not already there.
                if (max != col) {
                    val luMax = lu.row(max)
                    val luCol = lu.row(col)

                    for (i in 0 until m) {
                        val tmp = luMax[i]
                        luMax[i] = luCol[i]
                        luCol[i] = tmp
                    }

                    val temp = pivot[max]
                    pivot[max] = pivot[col]
                    pivot[col] = temp
                    even = !even
                }

                // Scale the entries below the diagonal by the chosen pivot to obtain this column of L.
                val luDiag = lu[col, col]
                for (row in col + 1 until m) lu[row, col] /= luDiag
            }

            return LupDecomposition(this@lup, elementContext, lu.collect(), pivot, even)
        }
    }
}

/**
 * Creates an LUP decomposition of [matrix].
 *
 * It uses this context's [elementContext] for arithmetic and `MutableBuffer.Companion::auto` for the
 * working buffers.
 *
 * @param checkSingular predicate on a pivot's magnitude; `true` marks the matrix as singular.
 */
public inline fun , F : Field> GenericMatrixContext>.lup(
    matrix: Matrix,
    noinline checkSingular: (T) -> Boolean,
): LupDecomposition = lup(
    factory = MutableBuffer.Companion::auto,
    elementContext = elementContext,
    matrix = matrix,
    checkSingular = checkSingular,
)

/**
 * Creates an LUP decomposition of a [RealField] matrix.
 *
 * A pivot whose magnitude is below `1e-11` is treated as singular.
 */
public fun MatrixContext>.lup(matrix: Matrix): LupDecomposition {
    return lup(factory = Buffer.Companion::real, elementContext = RealField, matrix = matrix) { pivotMagnitude ->
        pivotMagnitude < 1e-11
    }
}

/**
 * Solves `A * X = matrix` for X, where A is the matrix this decomposition was computed from.
 *
 * Each column of [matrix] is an independent right-hand side. The permuted right-hand side is first
 * solved against L by forward substitution (L has an implicit unit diagonal), then against U by back
 * substitution.
 *
 * @param factory buffer factory used to allocate the working storage.
 * @param matrix right-hand side; it must have as many rows as the decomposed matrix.
 * @return the solution, produced through [LupDecomposition.context].
 * @throws IllegalArgumentException if the row count of [matrix] differs from the decomposition size.
 */
public fun  LupDecomposition.solveWithLup(
    factory: MutableBufferFactory,
    matrix: Matrix,
): Matrix {
    require(matrix.rowNum == pivot.size) { "Matrix dimension mismatch. Expected ${pivot.size}, but got ${matrix.rowNum}" }

    BufferAccessor2D(matrix.rowNum, matrix.colNum, factory).run {
        elementContext {
            // Apply the row permutation to b: row `row` of bp is row `pivot[row]` of the input (bp = P * b).
            val bp = create { row, col -> matrix[pivot[row], col] }

            // Solve LY = b (forward substitution; L's diagonal is one, so no division is needed)
            for (col in pivot.indices) {
                val bpCol = bp.row(col)

                for (i in col + 1 until pivot.size) {
                    val bpI = bp.row(i)
                    val luICol = lu[i, col]
                    for (j in 0 until matrix.colNum) {
                        bpI[j] -= bpCol[j] * luICol
                    }
                }
            }

            // Solve UX = Y (back substitution, dividing by U's diagonal)
            for (col in pivot.size - 1 downTo 0) {
                val bpCol = bp.row(col)
                val luDiag = lu[col, col]
                for (j in 0 until matrix.colNum) bpCol[j] /= luDiag

                for (i in 0 until col) {
                    val bpI = bp.row(i)
                    val luICol = lu[i, col]
                    for (j in 0 until matrix.colNum) bpI[j] -= bpCol[j] * luICol
                }
            }

            return context.produce(pivot.size, matrix.colNum) { i, j -> bp[i, j] }
        }
    }
}

/**
 * Solves the system for the right-hand side [matrix], using `MutableBuffer.Companion::auto` for the
 * working buffers.
 */
public inline fun  LupDecomposition.solveWithLup(matrix: Matrix): Matrix {
    return solveWithLup(factory = MutableBuffer.Companion::auto, matrix = matrix)
}

/**
 * Solves the system of linear equations `a * x = b` using LUP decomposition.
 *
 * If [a] already carries a decomposition as a feature, that decomposition is reused. Otherwise [a] is
 * decomposed with [bufferFactory] and [checkSingular].
 *
 * @param a the square coefficient matrix.
 * @param b right-hand side; each column is solved independently.
 * @param bufferFactory buffer factory used for the decomposition and solving working storage.
 * @param checkSingular predicate on a pivot's magnitude; `true` marks the matrix as singular.
 * @return the solution matrix `x`.
 */
@OptIn(UnstableKMathAPI::class)
public inline fun , F : Field> GenericMatrixContext>.solveWithLup(
    a: Matrix,
    b: Matrix,
    noinline bufferFactory: MutableBufferFactory = MutableBuffer.Companion::auto,
    noinline checkSingular: (T) -> Boolean,
): Matrix {
    // Reuse a decomposition attached to `a` as a feature, if any; otherwise compute a fresh one.
    val decomposition = a.getFeature() ?: lup(bufferFactory, elementContext, a, checkSingular)
    return decomposition.solveWithLup(bufferFactory, b)
}

/**
 * Inverts a square [matrix] by solving `matrix * x = one(n, n)` with LUP decomposition.
 *
 * @param bufferFactory buffer factory used for the working storage.
 * @param checkSingular predicate on a pivot's magnitude; `true` marks the matrix as singular.
 */
public inline fun , F : Field> GenericMatrixContext>.inverseWithLup(
    matrix: Matrix,
    noinline bufferFactory: MutableBufferFactory = MutableBuffer.Companion::auto,
    noinline checkSingular: (T) -> Boolean,
): Matrix {
    val unit = one(matrix.rowNum, matrix.colNum)
    return solveWithLup(a = matrix, b = unit, bufferFactory = bufferFactory, checkSingular = checkSingular)
}


/**
 * Solves `a * x = b` for [RealField] matrices using LUP decomposition.
 *
 * A decomposition already attached to [a] is reused. Otherwise [a] is decomposed with real buffers, and
 * any pivot whose magnitude is below `1e-11` is treated as singular.
 */
@OptIn(UnstableKMathAPI::class)
public fun RealMatrixContext.solveWithLup(a: Matrix, b: Matrix): Matrix {
    // Use existing decomposition if it is provided by matrix
    val bufferFactory: MutableBufferFactory = MutableBuffer.Companion::real
    // Fall back to a fresh decomposition; the lambda receives the magnitude of each chosen pivot.
    val decomposition: LupDecomposition = a.getFeature() ?: lup(bufferFactory, RealField, a) { it < 1e-11 }
    return decomposition.solveWithLup(bufferFactory, b)
}

/**
 * Inverts a square matrix using LUP decomposition.
 *
 * It solves `matrix * x = one(n, n)`. A non-square matrix causes an error.
 */
public fun RealMatrixContext.inverseWithLup(matrix: Matrix): Matrix {
    val unit = one(matrix.rowNum, matrix.colNum)
    return solveWithLup(matrix, unit)
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy