All Downloads are FREE. Search and download functionalities are using the official Maven repository.

commonMain.space.kscience.kmath.operations.BigInt.kt Maven / Gradle / Ivy

package space.kscience.kmath.operations

import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.BufferedNDRing
import space.kscience.kmath.nd.NDAlgebra
import space.kscience.kmath.operations.BigInt.Companion.BASE
import space.kscience.kmath.operations.BigInt.Companion.BASE_SIZE
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.MutableBuffer
import kotlin.math.log2
import kotlin.math.max
import kotlin.math.min
import kotlin.math.sign

public typealias Magnitude = UIntArray
public typealias TBase = ULong

/**
 * Kotlin Multiplatform implementation of Big Integer numbers (KBigInteger).
 *
 * @author Robert Drynkin (https://github.com/robdrynkin) and Peter Klimai (https://github.com/pklimai)
 */
@OptIn(UnstableKMathAPI::class)
public object BigIntField : Field, RingWithNumbers {
    override val zero: BigInt = BigInt.ZERO
    override val one: BigInt = BigInt.ONE

    override fun add(a: BigInt, b: BigInt): BigInt = a.plus(b)
    override fun number(value: Number): BigInt = value.toLong().toBigInt()

    override fun multiply(a: BigInt, k: Number): BigInt = a.times(number(k))

    override fun multiply(a: BigInt, b: BigInt): BigInt = a.times(b)

    public operator fun String.unaryPlus(): BigInt = this.parseBigInteger() ?: error("Can't parse $this as big integer")

    public operator fun String.unaryMinus(): BigInt =
        -(this.parseBigInteger() ?: error("Can't parse $this as big integer"))

    override fun divide(a: BigInt, b: BigInt): BigInt = a.div(b)
}

public class BigInt internal constructor(
    private val sign: Byte,
    private val magnitude: Magnitude
    ) : Comparable {
    public override fun compareTo(other: BigInt): Int = when {
        (sign == 0.toByte()) and (other.sign == 0.toByte()) -> 0
        sign < other.sign -> -1
        sign > other.sign -> 1
        else -> sign * compareMagnitudes(magnitude, other.magnitude)
    }

    public override fun equals(other: Any?): Boolean =
        if (other is BigInt) compareTo(other) == 0 else error("Can't compare KBigInteger to a different type")

    public override fun hashCode(): Int = magnitude.hashCode() + sign

    public fun abs(): BigInt = if (sign == 0.toByte()) this else BigInt(1, magnitude)

    public operator fun unaryMinus(): BigInt =
        if (this.sign == 0.toByte()) this else BigInt((-this.sign).toByte(), this.magnitude)

    public operator fun plus(b: BigInt): BigInt = when {
        b.sign == 0.toByte() -> this
        sign == 0.toByte() -> b
        this == -b -> ZERO
        sign == b.sign -> BigInt(sign, addMagnitudes(magnitude, b.magnitude))

        else -> {
            val comp = compareMagnitudes(magnitude, b.magnitude)

            if (comp == 1)
                BigInt(sign, subtractMagnitudes(magnitude, b.magnitude))
            else
                BigInt((-sign).toByte(), subtractMagnitudes(b.magnitude, magnitude))
        }
    }

    public operator fun minus(b: BigInt): BigInt = this + (-b)

    public operator fun times(b: BigInt): BigInt = when {
        this.sign == 0.toByte() -> ZERO
        b.sign == 0.toByte() -> ZERO
//          TODO: Karatsuba
        else -> BigInt((this.sign * b.sign).toByte(), multiplyMagnitudes(this.magnitude, b.magnitude))
    }

    public operator fun times(other: UInt): BigInt = when {
        sign == 0.toByte() -> ZERO
        other == 0U -> ZERO
        else -> BigInt(sign, multiplyMagnitudeByUInt(magnitude, other))
    }

    public operator fun times(other: Int): BigInt = if (other > 0)
        this * kotlin.math.abs(other).toUInt()
    else
        -this * kotlin.math.abs(other).toUInt()

    public operator fun div(other: UInt): BigInt = BigInt(this.sign, divideMagnitudeByUInt(this.magnitude, other))

    public operator fun div(other: Int): BigInt = BigInt(
        (this.sign * other.sign).toByte(),
        divideMagnitudeByUInt(this.magnitude, kotlin.math.abs(other).toUInt())
    )

    private fun division(other: BigInt): Pair {
        // Long division algorithm:
        //     https://en.wikipedia.org/wiki/Division_algorithm#Integer_division_(unsigned)_with_remainder
        // TODO: Implement more effective algorithm
        var q = ZERO
        var r = ZERO

        val bitSize =
            (BASE_SIZE * (this.magnitude.size - 1) + log2(this.magnitude.lastOrNull()?.toFloat() ?: 0f + 1)).toInt()

        for (i in bitSize downTo 0) {
            r = r shl 1
            r = r or ((abs(this) shr i) and ONE)

            if (r >= abs(other)) {
                r -= abs(other)
                q += (ONE shl i)
            }
        }

        return Pair(BigInt((this.sign * other.sign).toByte(), q.magnitude), r)
    }

    public operator fun div(other: BigInt): BigInt = division(other).first

    public infix fun shl(i: Int): BigInt {
        if (this == ZERO) return ZERO
        if (i == 0) return this
        val fullShifts = i / BASE_SIZE + 1
        val relShift = i % BASE_SIZE
        val shiftLeft = { x: UInt -> if (relShift >= 32) 0U else x shl relShift }
        val shiftRight = { x: UInt -> if (BASE_SIZE - relShift >= 32) 0U else x shr (BASE_SIZE - relShift) }
        val newMagnitude = Magnitude(magnitude.size + fullShifts)

        for (j in magnitude.indices) {
            newMagnitude[j + fullShifts - 1] = shiftLeft(this.magnitude[j])

            if (j != 0)
                newMagnitude[j + fullShifts - 1] = newMagnitude[j + fullShifts - 1] or shiftRight(this.magnitude[j - 1])
        }

        newMagnitude[magnitude.size + fullShifts - 1] = shiftRight(magnitude.last())
        return BigInt(this.sign, stripLeadingZeros(newMagnitude))
    }

    public infix fun shr(i: Int): BigInt {
        if (this == ZERO) return ZERO
        if (i == 0) return this
        val fullShifts = i / BASE_SIZE
        val relShift = i % BASE_SIZE
        val shiftRight = { x: UInt -> if (relShift >= 32) 0U else x shr relShift }
        val shiftLeft = { x: UInt -> if (BASE_SIZE - relShift >= 32) 0U else x shl (BASE_SIZE - relShift) }
        if (this.magnitude.size - fullShifts <= 0) return ZERO
        val newMagnitude: Magnitude = Magnitude(magnitude.size - fullShifts)

        for (j in fullShifts until magnitude.size) {
            newMagnitude[j - fullShifts] = shiftRight(magnitude[j])

            if (j != magnitude.size - 1)
                newMagnitude[j - fullShifts] = newMagnitude[j - fullShifts] or shiftLeft(magnitude[j + 1])
        }

        return BigInt(this.sign, stripLeadingZeros(newMagnitude))
    }

    public infix fun or(other: BigInt): BigInt {
        if (this == ZERO) return other
        if (other == ZERO) return this
        val resSize = max(magnitude.size, other.magnitude.size)
        val newMagnitude: Magnitude = Magnitude(resSize)

        for (i in 0 until resSize) {
            if (i < magnitude.size) newMagnitude[i] = newMagnitude[i] or magnitude[i]
            if (i < other.magnitude.size) newMagnitude[i] = newMagnitude[i] or other.magnitude[i]
        }

        return BigInt(1, stripLeadingZeros(newMagnitude))
    }

    public infix fun and(other: BigInt): BigInt {
        if ((this == ZERO) or (other == ZERO)) return ZERO
        val resSize = min(this.magnitude.size, other.magnitude.size)
        val newMagnitude: Magnitude = Magnitude(resSize)
        for (i in 0 until resSize) newMagnitude[i] = this.magnitude[i] and other.magnitude[i]
        return BigInt(1, stripLeadingZeros(newMagnitude))
    }

    public operator fun rem(other: Int): Int {
        val res = this - (this / other) * other
        return if (res == ZERO) 0 else res.sign * res.magnitude[0].toInt()
    }

    public operator fun rem(other: BigInt): BigInt = this - (this / other) * other

    public fun modPow(exponent: BigInt, m: BigInt): BigInt = when {
        exponent == ZERO -> ONE
        exponent % 2 == 1 -> (this * modPow(exponent - ONE, m)) % m

        else -> {
            val sqRoot = modPow(exponent / 2, m)
            (sqRoot * sqRoot) % m
        }
    }

    override fun toString(): String {
        if (this.sign == 0.toByte()) {
            return "0x0"
        }
        var res: String = if (this.sign == (-1).toByte()) "-0x" else "0x"
        var numberStarted = false

        for (i in this.magnitude.size - 1 downTo 0) {
            for (j in BASE_SIZE / 4 - 1 downTo 0) {
                val curByte = (this.magnitude[i] shr 4 * j) and 0xfU
                if (numberStarted or (curByte != 0U)) {
                    numberStarted = true
                    res += hexMapping[curByte]
                }
            }
        }

        return res
    }

    public companion object {
        public const val BASE: ULong = 0xffffffffUL
        public const val BASE_SIZE: Int = 32
        public val ZERO: BigInt = BigInt(0, uintArrayOf())
        public val ONE: BigInt = BigInt(1, uintArrayOf(1u))

        private val hexMapping: HashMap = hashMapOf(
            0U to "0", 1U to "1", 2U to "2", 3U to "3",
            4U to "4", 5U to "5", 6U to "6", 7U to "7",
            8U to "8", 9U to "9", 10U to "a", 11U to "b",
            12U to "c", 13U to "d", 14U to "e", 15U to "f"
        )

        private fun compareMagnitudes(mag1: Magnitude, mag2: Magnitude): Int {
            when {
                mag1.size > mag2.size -> return 1
                mag1.size < mag2.size -> return -1
                else -> {
                    for (i in mag1.size - 1 downTo 0) {
                        if (mag1[i] > mag2[i]) {
                            return 1
                        } else if (mag1[i] < mag2[i]) {
                            return -1
                        }
                    }
                    return 0
                }
            }
        }

        private fun addMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude {
            val resultLength = max(mag1.size, mag2.size) + 1
            val result = Magnitude(resultLength)
            var carry = 0uL

            for (i in 0 until resultLength - 1) {
                val res = when {
                    i >= mag1.size -> mag2[i].toULong() + carry
                    i >= mag2.size -> mag1[i].toULong() + carry
                    else -> mag1[i].toULong() + mag2[i].toULong() + carry
                }

                result[i] = (res and BASE).toUInt()
                carry = (res shr BASE_SIZE)
            }

            result[resultLength - 1] = carry.toUInt()
            return stripLeadingZeros(result)
        }

        private fun subtractMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude {
            val resultLength = mag1.size
            val result = Magnitude(resultLength)
            var carry = 0L

            for (i in 0 until resultLength) {
                var res =
                    if (i < mag2.size) mag1[i].toLong() - mag2[i].toLong() - carry
                    else mag1[i].toLong() - carry

                carry = if (res < 0) 1 else 0
                res += carry * (BASE + 1UL).toLong()

                result[i] = res.toUInt()
            }

            return stripLeadingZeros(result)
        }

        private fun multiplyMagnitudeByUInt(mag: Magnitude, x: UInt): Magnitude {
            val resultLength = mag.size + 1
            val result = Magnitude(resultLength)
            var carry = 0uL

            for (i in mag.indices) {
                val cur: ULong = carry + mag[i].toULong() * x.toULong()
                result[i] = (cur and BASE).toUInt()
                carry = cur shr BASE_SIZE
            }
            result[resultLength - 1] = (carry and BASE).toUInt()

            return stripLeadingZeros(result)
        }

        private fun multiplyMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude {
            val resultLength = mag1.size + mag2.size
            val result = Magnitude(resultLength)

            for (i in mag1.indices) {
                var carry = 0uL

                for (j in mag2.indices) {
                    val cur: ULong = result[i + j].toULong() + mag1[i].toULong() * mag2[j].toULong() + carry
                    result[i + j] = (cur and BASE).toUInt()
                    carry = cur shr BASE_SIZE
                }

                result[i + mag2.size] = (carry and BASE).toUInt()
            }

            return stripLeadingZeros(result)
        }

        private fun divideMagnitudeByUInt(mag: Magnitude, x: UInt): Magnitude {
            val resultLength = mag.size
            val result = Magnitude(resultLength)
            var carry = 0uL

            for (i in mag.size - 1 downTo 0) {
                val cur: ULong = mag[i].toULong() + (carry shl BASE_SIZE)
                result[i] = (cur / x).toUInt()
                carry = cur % x
            }

            return stripLeadingZeros(result)
        }
    }
}

private fun stripLeadingZeros(mag: Magnitude): Magnitude {
    if (mag.isEmpty() || mag.last() != 0U) return mag
    var resSize = mag.size - 1

    while (mag[resSize] == 0U) {
        if (resSize == 0) break
        resSize -= 1
    }

    return mag.sliceArray(IntRange(0, resSize))
}

public fun abs(x: BigInt): BigInt = x.abs()

/**
 * Convert this [Int] to [BigInt]
 */
public fun Int.toBigInt(): BigInt = BigInt(sign.toByte(), uintArrayOf(kotlin.math.abs(this).toUInt()))

/**
 * Convert this [Long] to [BigInt]
 */
public fun Long.toBigInt(): BigInt = BigInt(
    sign.toByte(),
    stripLeadingZeros(
        uintArrayOf(
            (kotlin.math.abs(this).toULong() and BASE).toUInt(),
            ((kotlin.math.abs(this).toULong() shr BASE_SIZE) and BASE).toUInt()
        )
    )
)

/**
 * Convert UInt to [BigInt]
 */
public fun UInt.toBigInt(): BigInt = BigInt(1, uintArrayOf(this))

/**
 * Convert ULong to [BigInt]
 */
public fun ULong.toBigInt(): BigInt = BigInt(
    1,
    stripLeadingZeros(
        uintArrayOf(
            (this and BASE).toUInt(),
            ((this shr BASE_SIZE) and BASE).toUInt()
        )
    )
)

/**
 * Create a [BigInt] with this array of magnitudes with protective copy
 */
public fun UIntArray.toBigInt(sign: Byte): BigInt {
    require(sign != 0.toByte() || !isNotEmpty())
    return BigInt(sign, copyOf())
}

private val hexChToInt: MutableMap = hashMapOf(
    '0' to 0, '1' to 1, '2' to 2, '3' to 3,
    '4' to 4, '5' to 5, '6' to 6, '7' to 7,
    '8' to 8, '9' to 9, 'A' to 10, 'B' to 11,
    'C' to 12, 'D' to 13, 'E' to 14, 'F' to 15
)

/**
 * Returns null if a valid number can not be read from a string
 */
public fun String.parseBigInteger(): BigInt? {
    val sign: Int
    val sPositive: String

    when {
        this[0] == '+' -> {
            sign = +1
            sPositive = this.substring(1)
        }
        this[0] == '-' -> {
            sign = -1
            sPositive = this.substring(1)
        }
        else -> {
            sPositive = this
            sign = +1
        }
    }

    var res = BigInt.ZERO
    var digitValue = BigInt.ONE
    val sPositiveUpper = sPositive.toUpperCase()

    if (sPositiveUpper.startsWith("0X")) {  // hex representation
        val sHex = sPositiveUpper.substring(2)

        for (ch in sHex.reversed()) {
            if (ch == '_') continue
            res += digitValue * (hexChToInt[ch] ?: return null)
            digitValue *= 16.toBigInt()
        }
    } else for (ch in sPositiveUpper.reversed()) {
        // decimal representation
        if (ch == '_') continue
        if (ch !in '0'..'9') {
            return null
        }
        res += digitValue * (ch.toInt() - '0'.toInt())
        digitValue *= 10.toBigInt()
    }

    return res * sign
}

public inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): Buffer =
    boxing(size, initializer)

public inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): MutableBuffer =
    boxing(size, initializer)

public fun NDAlgebra.Companion.bigInt(vararg shape: Int): BufferedNDRing =
    BufferedNDRing(shape, BigIntField, Buffer.Companion::bigInt)




© 2015 - 2025 Weber Informatics LLC | Privacy Policy