spire.math.Natural.scala Maven / Gradle / Ivy
package spire.math
import scala.annotation.tailrec
import scala.math.{ScalaNumber, ScalaNumericConversions}
import scala.{specialized => spec}
import spire.algebra.{IsIntegral, Order, Rig, Signed}
import Natural._
// NOTE: this class works, but it is only optimal for a relatively narrow
// set of problems. For really big numbers you're probably better off
// using SafeLong or BigInt. That said, there are cases where Natural
// is faster (e.g. addition of 32- to 128-bit numbers).
// TODO: much of the recursion here (e.g. in +, -, << and >>) is not
// tail-recursive. The first goal was correctness; once that's achieved
// we need to focus on efficiency. Using a "private mutable" strategy
// similar to the one :: and ListBuffer use in Scala, we should be able
// to build Digit chains efficiently in a tail-recursive way.
/**
 * An arbitrary-precision non-negative integer, stored as a singly linked
 * chain of 32-bit digits in little-endian order: the outermost `Digit`
 * holds the least significant 32 bits and the terminal `End` holds the
 * most significant ones.
 *
 * Chains are not always normalized. For example `Natural(UInt(0), UInt(1))`
 * builds `Digit(UInt(1), End(UInt(0)))`, and `Digit.-` can leave zero-valued
 * high digits behind. The methods below must therefore give correct results
 * for chains that carry zero-valued high digits.
 */
@SerialVersionUID(0L)
sealed abstract class Natural extends ScalaNumber with ScalaNumericConversions with Serializable {
  lhs =>

  /** The least significant 32-bit digit of this number. */
  def digit: UInt

  /**
   * Folds `f` over the digits from least to most significant.
   *
   * @param a the initial accumulator
   * @param f combines the accumulator so far with the next digit
   * @return the final accumulator
   */
  def foldDigitsLeft[@spec A](a: A)(f: (A, UInt) => A): A = {
    @tailrec def recur(next: Natural, sofar: A): A = next match {
      // Thread the running accumulator through to the last digit as well.
      case End(d) => f(sofar, d)
      case Digit(d, tail) => recur(tail, f(sofar, d))
    }
    recur(this, a)
  }

  /** Folds `f` over the digits from most to least significant. */
  def foldDigitsRight[@spec A](a: A)(f: (A, UInt) => A): A =
    reversed.foldDigitsLeft(a)(f)

  /**
   * The number of significant bits (index of the highest set bit plus one),
   * or 0 for zero. Zero-valued high digits do not count, so the result is
   * correct for unnormalized chains too (`longdiv` depends on this).
   */
  def getNumBits: Int = {
    @tailrec
    def recur(next: Natural, offset: Int, bits: Int): Int = {
      val d = next.digit
      // Only a non-zero digit can raise the bit length.
      val upTo =
        if (d == UInt(0)) bits
        else offset + 32 - java.lang.Integer.numberOfLeadingZeros(d.signed)
      next match {
        case End(_) => upTo
        case Digit(_, tail) => recur(tail, offset + 32, upTo)
      }
    }
    recur(this, 0, 0)
  }

  /** The number of digits in the chain, including any zero-valued high digits. */
  def getDigitLength: Int = {
    @tailrec
    def recur(next: Natural, n: Int): Int = next match {
      case End(d) => n + 1
      case Digit(d, tail) => recur(tail, n + 1)
    }
    recur(this, 0)
  }

  /** The digits as a list, most significant first. */
  def toList: List[UInt] = {
    @tailrec
    def recur(next: Natural, sofar: List[UInt]): List[UInt] = next match {
      case End(d) => d :: sofar
      case Digit(d, tail) => recur(tail, d :: sofar)
    }
    recur(this, Nil)
  }

  /**
   * The digits as raw signed ints, most significant first.
   * Array[UInt] would be boxed, so we do this for now.
   */
  def toArray: Array[Int] = {
    val n = getDigitLength
    val arr = new Array[Int](n)
    @tailrec
    def recur(next: Natural, i: Int): Unit = next match {
      case End(d) =>
        arr(i) = d.signed
      case Digit(d, tail) =>
        arr(i) = d.signed
        recur(tail, i - 1)
    }
    // The head is the least significant digit, so fill from the back.
    recur(this, n - 1)
    arr
  }

  /**
   * The same digits in the opposite order. This is a traversal helper for
   * walking from the most significant digit, not a numeric operation.
   */
  def reversed: Natural = {
    @tailrec
    def recur(next: Natural, sofar: Natural): Natural = next match {
      case End(d) => Digit(d, sofar)
      case Digit(d, tail) => recur(tail, Digit(d, sofar))
    }
    this match {
      case Digit(d, tail) => recur(tail, End(d))
      case _ => this
    }
  }

  /** Removes zero-valued high digits; an all-zero chain becomes `End(UInt(0))`. */
  def trim: Natural = {
    @tailrec
    def recur(next: Natural): Natural = {
      next match {
        case Digit(n, tail) =>
          if (n == UInt(0)) recur(tail) else next
        case End(n) =>
          next
      }
    }
    // Reverse so the high digits come first, drop the zero ones, restore order.
    recur(reversed).reversed
  }

  def isWhole: Boolean = true
  def underlying: Object = this
  def intValue: Int = toInt
  def longValue: Long = toLong
  def floatValue: Float = toBigInt.toFloat
  def doubleValue: Double = toBigInt.toDouble

  /** The low 31 bits of the least significant digit (the sign bit is masked off). */
  override def toInt: Int = digit.toInt & 0x7fffffff

  /** The low 64 bits of the value as a Long; higher bits are shifted out. */
  override def toLong: Long = this match {
    case End(d) => d.toLong
    case Digit(d, tail) => (tail.toLong << 32L) + d.toLong
  }

  /** Exact conversion to BigInt. */
  def toBigInt: BigInt = this match {
    case End(d) => BigInt(d.toLong)
    case Digit(d, tail) => (tail.toBigInt << 32) + BigInt(d.toLong)
  }

  /**
   * Decimal representation. It peels off 9 decimal digits at a time via
   * `/% Natural.denom`, zero-padding every group except the most significant.
   */
  override def toString: String = {
    @tailrec def recur(next: Natural, s: String): String = {
      next match {
        case End(d) =>
          d.toLong.toString + s
        case Digit(d, tail) =>
          val (q, r) = next /% Natural.denom
          if (q.isZero)
            r.digit.toLong.toString + s
          else
            recur(q, "%09d%s" format (r.digit.toLong, s))
      }
    }
    recur(this, "")
  }

  /** The raw digit chain, most significant first, e.g. `Natural(1, 0)`. */
  def toRepr: String = toList.mkString("Natural(", ", ", ")")

  /** True when every digit is zero (works for unnormalized chains). */
  def isZero: Boolean = {
    @tailrec
    def recur(next: Natural): Boolean = next match {
      case End(n) =>
        n == UInt(0)
      case Digit(n, tail) =>
        if (n == UInt(0)) recur(tail) else false
    }
    recur(this)
  }

  /** True when the value is exactly one. */
  def isOne: Boolean = this match {
    case End(n) =>
      n == UInt(1)
    case Digit(n, tail) =>
      n == UInt(1) && tail.isZero
  }

  def isOdd: Boolean = (digit & UInt(1)) == UInt(1)
  def isEven: Boolean = (digit & UInt(1)) == UInt(0)

  /**
   * If this number is `2^k`, returns `k`; otherwise (including for zero)
   * returns -1. Zero-valued digits, including zero high digits, are skipped.
   */
  def powerOfTwo: Int = {
    // -1 for a zero digit, -2 if two or more bits are set,
    // otherwise the index of the single set bit.
    def test(n: UInt): Int = {
      val s = n.signed
      if (s == 0) -1
      else if ((s & -s) != s) -2
      else java.lang.Integer.numberOfTrailingZeros(s)
    }

    @tailrec
    def recur(next: Natural, shift: Int, bit: Int): Int = {
      val t = test(next.digit)
      if (t == -2) -1 // this digit alone has more than one bit set
      else if (t >= 0 && bit >= 0) -1 // a second set bit in another digit
      else {
        val found = if (t >= 0) shift + t else bit
        next match {
          case End(_) => found
          case Digit(_, tail) => recur(tail, shift + 32, found)
        }
      }
    }

    recur(this, 0, -1)
  }

  /** Compares this number with a single 32-bit value: -1, 0 or 1. */
  def compare(rhs: UInt): Int = this match {
    case End(d) =>
      if (d < rhs) -1 else if (d > rhs) 1 else 0
    case Digit(d, tail) =>
      if (tail.isZero)
        if (d > rhs) 1 else if (d < rhs) -1 else 0
      else
        1
  }

  /**
   * Numeric comparison: -1, 0 or 1. Digits are visited from least to most
   * significant; `d` carries the verdict of the lower digits, which any
   * unequal higher digit overrides. Chains of different length (including
   * unnormalized ones) compare by value.
   */
  def compare(rhs: Natural): Int = {
    def cmp(a: UInt, b: UInt, c: Int): Int =
      if (a < b) -1 else if (a > b) 1 else c
    @tailrec
    def recur(lhs: Natural, rhs: Natural, d: Int): Int = lhs match {
      case End(ld) => rhs match {
        case End(rd) => cmp(ld, rd, d)
        case _: Digit =>
          // Compare the longer remainder by value; if it equals ld,
          // the lower digits decide.
          val c = -rhs.compare(ld)
          if (c != 0) c else d
      }
      case Digit(ld, ltail) => rhs match {
        case End(rd) =>
          val c = lhs.compare(rd)
          if (c != 0) c else d
        case Digit(rd, rtail) => recur(ltail, rtail, cmp(ld, rd, d))
      }
    }
    recur(lhs, rhs, 0)
  }

  final override def equals(rhs: Any): Boolean = rhs match {
    case rhs: Natural => this === rhs
    case rhs: UInt => (lhs compare rhs) == 0
    case rhs: BigInt => lhs.toBigInt == rhs
    case rhs: SafeLong => SafeLong(lhs.toBigInt) == rhs
    case rhs: BigDecimal => rhs.isWhole && lhs.toBigInt == rhs
    case rhs: Rational => rhs.isWhole && Rational(lhs.toBigInt) == rhs
    case rhs: Algebraic => rhs == lhs
    case rhs: Real => lhs == rhs.toRational
    case rhs: Number => Number(lhs.toBigInt) == rhs
    case rhs: Complex[_] => rhs == lhs
    case rhs: Quaternion[_] => rhs == lhs
    case that => unifiedPrimitiveEquals(that)
  }

  /**
   * Hash code consistent with `equals`. It is derived from the value via
   * `toBigInt`, so numerically equal chains (e.g. `End(UInt(5))` and
   * `Digit(UInt(5), End(UInt(0)))`) hash the same, as does the equal BigInt.
   * Defining it here also stops the case classes from synthesizing a
   * structural hashCode.
   */
  final override def hashCode(): Int = toBigInt.hashCode

  def ===(rhs: Natural): Boolean =
    (lhs compare rhs) == 0

  def =!=(rhs: Natural): Boolean =
    !(this === rhs)

  def <(rhs: Natural): Boolean = (lhs compare rhs) < 0
  def <=(rhs: Natural): Boolean = (lhs compare rhs) <= 0
  def >(rhs: Natural): Boolean = (lhs compare rhs) > 0
  def >=(rhs: Natural): Boolean = (lhs compare rhs) >= 0

  def <(r: UInt): Boolean = (lhs compare r) < 0
  def <=(r: UInt): Boolean = (lhs compare r) <= 0
  def >(r: UInt): Boolean = (lhs compare r) > 0
  def >=(r: UInt): Boolean = (lhs compare r) >= 0

  // Comparisons against BigInt go through toBigInt.
  def <(r: BigInt): Boolean = (lhs.toBigInt compare r) < 0
  def <=(r: BigInt): Boolean = (lhs.toBigInt compare r) <= 0
  def >(r: BigInt): Boolean = (lhs.toBigInt compare r) > 0
  def >=(r: BigInt): Boolean = (lhs.toBigInt compare r) >= 0

  // Single-digit arithmetic, implemented in Digit and End.
  def +(rd: UInt): Natural
  def -(rd: UInt): Natural
  def *(rd: UInt): Natural
  def /~(rd: UInt): Natural = lhs / rd
  def /(rd: UInt): Natural
  def %(rd: UInt): Natural
  def /%(rd: UInt): (Natural, Natural)

  // Mixed arithmetic with BigInt produces BigInt results.
  def +(rhs: BigInt): BigInt = lhs.toBigInt + rhs
  def -(rhs: BigInt): BigInt = lhs.toBigInt - rhs
  def *(rhs: BigInt): BigInt = lhs.toBigInt * rhs
  def /~(rhs: BigInt): BigInt = lhs.toBigInt / rhs
  def /(rhs: BigInt): BigInt = lhs.toBigInt / rhs
  def %(rhs: BigInt): BigInt = lhs.toBigInt % rhs
  def /%(rhs: BigInt): (BigInt, BigInt) = lhs.toBigInt /% rhs

  /** Sum, computed digit by digit with a 64-bit carry. */
  def +(rhs: Natural): Natural = {
    def recur(left: Natural, right: Natural, carry: Long): Natural = left match {
      case End(ld) => right match {
        case End(rd) =>
          Natural(ld.toLong + rd.toLong + carry)
        case Digit(rd, rtail) =>
          val t = ld.toLong + rd.toLong + carry
          Digit(UInt(t), rtail + UInt(t >> 32))
      }
      case Digit(ld, ltail) => right match {
        case End(rd) =>
          val t = ld.toLong + rd.toLong + carry
          Digit(UInt(t), ltail + UInt(t >> 32))
        case Digit(rd, rtail) =>
          val t = ld.toLong + rd.toLong + carry
          Digit(UInt(t), recur(ltail, rtail, t >> 32))
      }
    }
    recur(lhs, rhs, 0L)
  }

  /**
   * Difference, digit by digit with borrow. Results whose high part
   * becomes `End(0)` are collapsed.
   *
   * @throws ArithmeticException if `rhs` is greater than this number
   */
  def -(rhs: Natural): Natural = {
    def recur(left: Natural, right: Natural, carry: Long): Natural = left match {
      case End(ld) => right match {
        case End(rd) =>
          Natural(ld.toLong - rd.toLong - carry)
        case Digit(rd, rtail) =>
          val t = ld.toLong - rd.toLong - carry
          val tl = rtail - UInt(-(t >> 32))
          if (tl.isInstanceOf[End] && tl.digit == UInt(0))
            End(UInt(t))
          else
            Digit(UInt(t), tl)
      }
      case Digit(ld, ltail) => right match {
        case End(rd) =>
          val t = ld.toLong - rd.toLong - carry
          val tl = ltail - UInt(-(t >> 32))
          if (tl.isInstanceOf[End] && tl.digit == UInt(0))
            End(UInt(t))
          else
            Digit(UInt(t), tl)
        case Digit(rd, rtail) =>
          val t = ld.toLong - rd.toLong - carry
          val tl = recur(ltail, rtail, -(t >> 32))
          if (tl.isInstanceOf[End] && tl.digit == UInt(0))
            End(UInt(t))
          else
            Digit(UInt(t), tl)
      }
    }
    if (lhs < rhs)
      throw new ArithmeticException("negative subtraction: %s - %s" format (lhs, rhs))
    else
      recur(lhs, rhs, 0L)
  }

  /**
   * Product. It splits both operands into low digit + tail and combines
   * the four partial products.
   */
  def *(rhs: Natural): Natural = lhs match {
    case End(ld) => rhs * ld
    case Digit(ld, ltail) => rhs match {
      case End(rd) => lhs * rd
      case Digit(rd, rtail) =>
        Digit(UInt(0), Digit(UInt(0), ltail * rtail)) +
        Digit(UInt(0), ltail * rd) +
        Digit(UInt(0), rtail * ld) +
        Natural(ld.toLong * rd.toLong)
    }
  }

  /** This number raised to `rhs`, by square-and-multiply (`x pow 0 == 1`). */
  def pow(rhs: Natural): Natural = {
    @tailrec def _pow(t: Natural, b: Natural, e: Natural): Natural = {
      if (e.isZero) t
      else if (e.isOdd) _pow(t * b, b * b, e >> 1)
      else _pow(t, b * b, e >> 1)
    }
    _pow(Natural(1), lhs, rhs)
  }

  /** This number raised to `rhs`, by square-and-multiply (`x pow 0 == 1`). */
  def pow(rhs: UInt): Natural = {
    @tailrec def _pow(t: Natural, b: Natural, e: UInt): Natural = {
      if (e == UInt(0)) t
      else if ((e & UInt(1)) == UInt(1)) _pow(t * b, b * b, e >> 1)
      else _pow(t, b * b, e >> 1)
    }
    _pow(Natural(1), lhs, rhs)
  }

  def /~(rhs: Natural): Natural = lhs / rhs

  /**
   * Truncating quotient.
   *
   * @throws IllegalArgumentException if `rhs` is zero
   */
  def /(rhs: Natural): Natural = rhs match {
    case End(rd) =>
      lhs / rd
    case _: Digit => rhs.compare(UInt(1)) match {
      case -1 => throw new IllegalArgumentException("/ by zero")
      case 0 => lhs
      case _ =>
        // Decide by value, not chain shape: rhs may carry zero high digits.
        if (lhs < rhs) {
          End(UInt(0))
        } else {
          val p = rhs.powerOfTwo
          if (p >= 0) lhs >> p else longdiv(lhs, rhs)._1
        }
    }
  }

  /**
   * Remainder of truncating division.
   *
   * @throws IllegalArgumentException if `rhs` is zero
   */
  def %(rhs: Natural): Natural = rhs match {
    case End(rd) =>
      lhs % rd
    case _: Digit => rhs.compare(UInt(1)) match {
      case -1 => throw new IllegalArgumentException("/ by zero")
      case 0 => End(UInt(0))
      case _ =>
        if (lhs < rhs) {
          lhs
        } else {
          val p = rhs.powerOfTwo
          if (p >= 0)
            lhs & ((Natural(1) << p) - UInt(1))
          else
            longdiv(lhs, rhs)._2
        }
    }
  }

  /**
   * Quotient and remainder together.
   *
   * @throws IllegalArgumentException if `rhs` is zero
   */
  def /%(rhs: Natural): (Natural, Natural) = rhs match {
    case End(rd) =>
      (lhs / rd, lhs % rd)
    case _: Digit => rhs.compare(UInt(1)) match {
      case -1 => throw new IllegalArgumentException("/ by zero")
      case 0 => (lhs, Natural(0))
      case _ =>
        if (lhs < rhs) {
          (End(UInt(0)), lhs)
        } else {
          val p = rhs.powerOfTwo
          if (p >= 0) {
            val mask = (Natural(1) << p) - UInt(1)
            (lhs >> p, lhs & mask)
          } else {
            longdiv(lhs, rhs)
          }
        }
    }
  }

  /**
   * Binary shift-and-subtract long division. Callers must ensure `denom`
   * is non-zero; a zero `denom` would never reduce `rem`. Correctness
   * relies on `getNumBits` ignoring zero high digits.
   */
  private def longdiv(num: Natural, denom: Natural): (Natural, Natural) = {
    var rem = num
    var quo = Natural(0)

    var remBits: Int = rem.getNumBits
    var denomBits: Int = denom.getNumBits
    var shift: Int = remBits - denomBits

    while (shift >= 0) {
      val shifted = denom << shift
      if (shifted <= rem) {
        quo += Natural(1) << shift
        rem -= shifted
        remBits = rem.getNumBits
        shift = remBits - denomBits
      } else {
        shift -= 1
      }
    }

    (quo, rem)
  }

  /**
   * Left shift by `n` bits: bits within digits first, then `n / 32` whole
   * zero digits prepended. NOTE(review): negative `n` is not rejected.
   */
  def <<(n: Int): Natural = {
    val m: Int = n & 0x1f
    def recur(next: Natural, carry: Long): Natural = next match {
      case End(d) =>
        Natural((d.toLong << m) | carry)
      case Digit(d, tail) =>
        val t = (d.toLong << m) | carry
        Digit(UInt(t), recur(tail, t >> 32))
    }
    val num = recur(this, 0L)
    (0 until n / 32).foldLeft(num)((n, _) => Digit(UInt(0), n))
  }

  /**
   * Drops the `n` least significant digits (a right shift by `32 * n` bits).
   * Dropping every digit yields `End(UInt(0))`.
   */
  def chop(n: Int): Natural = {
    @tailrec def recur(next: Natural, n: Int): Natural = if (n <= 0) {
      next
    } else {
      next match {
        case End(d) => End(UInt(0))
        case Digit(d, tail) => recur(tail, n - 1)
      }
    }
    recur(this, n)
  }

  /**
   * Right shift by `n` bits. Whole digits are chopped first, then the rest
   * is walked from the most significant digit down.
   */
  def >>(n: Int): Natural = {
    val m: Int = n & 0x1f
    // `carry` holds the low `m` bits of the previous (more significant)
    // digit, already moved to the top of a 32-bit word; it is OR-ed into
    // this digit's shifted value.
    def recur(next: Natural, carry: Long): Natural = next match {
      case End(d) =>
        Natural((d.toLong >> m) | carry)
      case Digit(d, tail) =>
        val t = d.toLong << (32 - m)
        Digit(UInt((t >> 32) | carry), recur(tail, t & 0xffffffffL))
    }
    recur(chop(n / 32).reversed, 0L).reversed
  }

  /** Bitwise OR (not trimmed). */
  def |(rhs: Natural): Natural = lhs match {
    case End(ld) => rhs match {
      case End(rd) => End(ld | rd)
      case Digit(rd, rtail) => Digit(ld | rd, rtail)
    }
    case Digit(ld, ltail) => rhs match {
      case End(rd) => Digit(ld | rd, ltail)
      case Digit(rd, rtail) => Digit(ld | rd, ltail | rtail)
    }
  }

  /** Bitwise OR of `rhs` into the least significant digit. */
  def |(rhs: UInt): Natural = lhs match {
    case End(ld) => End(ld | rhs)
    case Digit(ld, ltail) => Digit(ld | rhs, ltail)
  }

  /** Bitwise AND, trimmed of zero high digits. */
  def &(rhs: Natural): Natural = {
    def and(lhs: Natural, rhs: Natural): Natural = lhs match {
      case End(ld) => rhs match {
        case End(rd) => End(ld & rd)
        case Digit(rd, rtail) => End(ld & rd)
      }
      case Digit(ld, ltail) => rhs match {
        case End(rd) => End(ld & rd)
        case Digit(rd, rtail) => Digit(ld & rd, and(ltail, rtail))
      }
    }
    and(lhs, rhs).trim
  }

  /** Bitwise AND with the least significant digit; the result is a single digit. */
  def &(rhs: UInt): Natural = End(digit & rhs)

  /** Bitwise XOR, trimmed of zero high digits. */
  def ^(rhs: Natural): Natural = {
    def xor(lhs: Natural, rhs: Natural): Natural = lhs match {
      case End(ld) => rhs match {
        case End(rd) => End(ld ^ rd)
        case Digit(rd, rtail) => Digit(ld ^ rd, rtail)
      }
      case Digit(ld, ltail) => rhs match {
        case End(rd) => Digit(ld ^ rd, ltail)
        case Digit(rd, rtail) => Digit(ld ^ rd, ltail ^ rtail)
      }
    }
    xor(lhs, rhs).trim
  }

  /** Bitwise XOR of `rhs` into the least significant digit. */
  def ^(rhs: UInt): Natural = lhs match {
    case End(ld) => End(ld ^ rhs)
    case Digit(ld, ltail) => Digit(ld ^ rhs, ltail)
  }
}
// TODO: maybe split apply into apply() and fromX() variants;
// that way we can protect end users from sign problems.
/** Constructors and the concrete digit-chain cases of [[Natural]]. */
object Natural extends NaturalInstances {

  /** 10^9: `Natural.toString` emits nine decimal digits per division by this. */
  private[math] final val denom = UInt(1000000000)

  implicit def naturalToBigInt(n: Natural): BigInt = n.toBigInt

  /**
   * Builds a Natural from 32-bit digits given in big-endian order (most
   * significant first). Leading zero digits are kept as-is.
   *
   * @throws IllegalArgumentException if no digits are given
   */
  def apply(us: UInt*): Natural = {
    if (us.isEmpty) throw new IllegalArgumentException("invalid arguments")
    us.tail.foldLeft(End(us.head): Natural)((n, u) => Digit(u, n))
  }

  /**
   * Builds a Natural from a Long. A negative `n` is read as its unsigned
   * 64-bit bit pattern. The `*` implementations rely on this when a
   * 32x32-bit digit product exceeds Long.MaxValue.
   */
  def apply(n: Long): Natural = if ((n & 0xffffffffL) == n)
    End(UInt(n.toInt))
  else
    Digit(UInt(n.toInt), End(UInt((n >> 32).toInt)))

  /**
   * Builds a normalized Natural from a BigInt.
   *
   * @throws IllegalArgumentException if `n` is negative
   */
  def apply(n: BigInt): Natural = if (n < 0)
    throw new IllegalArgumentException("negative numbers not allowed: %s" format n)
  else if (n <= 0xffffffffL)
    // `<=`, so that 0xffffffff itself still fits in a single End digit.
    End(UInt(n.toLong))
  else
    Digit(UInt((n & 0xffffffffL).toLong), apply(n >> 32))

  private val ten18 = Natural(1000000000000000000L)

  /**
   * Parses a decimal string 18 digits at a time, least significant chunk
   * first. Each chunk goes through `String.toLong`, so malformed text throws
   * NumberFormatException. NOTE(review): a leading '-' is not rejected and
   * produces a meaningless value.
   */
  def apply(s: String): Natural = {
    def parse(sofar: Natural, s: String, m: Natural): Natural = if (s.length <= 18) {
      Natural(s.toLong) * m + sofar
    } else {
      val p = s.substring(s.length - 18, s.length)
      val r = s.substring(0, s.length - 18)
      parse(Natural(p.toLong) * m + sofar, r, m * ten18)
    }

    parse(Natural(0L), s, Natural(1L))
  }

  val zero: Natural = apply(0L)
  val one: Natural = apply(1L)

  /** A non-terminal digit: `d` is this position's 32 bits, `tl` the more significant rest. */
  @SerialVersionUID(0L)
  case class Digit(d: UInt, tl: Natural) extends Natural with Serializable {
    def digit: UInt = d
    def tail: Natural = tl

    /** Adds a single digit, propagating any carry into the tail. */
    def +(n: UInt): Natural = if (n == UInt(0)) {
      this
    } else {
      val t = d.toLong + n.toLong
      Digit(UInt(t), tail + UInt(t >> 32))
    }

    /**
     * Subtracts a single digit, borrowing from the tail. This can leave a
     * zero high digit, e.g. `Digit(0, End(1)) - 1 == Digit(0xffffffff, End(0))`.
     */
    def -(n: UInt): Natural = if (n == UInt(0)) {
      this
    } else {
      val t = d.toLong - n.toLong
      Digit(UInt(t), tail - UInt(-(t >> 32)))
    }

    /** Multiplies by a single digit. */
    def *(n: UInt): Natural = if (n == UInt(0))
      End(n)
    else if (n == UInt(1))
      this
    else
      Natural(d.toLong * n.toLong) + Digit(UInt(0), tl * n)

    def /(n: UInt): Natural = (this /% n)._1
    def %(n: UInt): Natural = (this /% n)._2

    /**
     * Short division by a single digit, walking from the most significant
     * digit down. The quotient may keep a zero high digit.
     *
     * @throws IllegalArgumentException if `n` is zero
     */
    def /%(n: UInt): (Natural, Natural) = {
      @tailrec
      def recur(next: Natural, rem: UInt, sofar: Natural): (Natural, Natural) = {
        // rem < n, so (rem << 32) + digit fits in an unsigned 64-bit value.
        val t: ULong = ULong(rem.toLong << 32) + ULong(next.digit.toLong)
        val q: Long = (t / ULong(n.toLong)).toLong
        val r: Long = (t % ULong(n.toLong)).toLong
        next match {
          case Natural.End(d) => (Digit(UInt(q), sofar), End(UInt(r)))
          case Natural.Digit(d, tail) => recur(tail, UInt(r), Digit(UInt(q), sofar))
        }
      }

      if (n == UInt(0)) {
        throw new IllegalArgumentException("/ by zero")
      } else if (n == UInt(1)) {
        (this, Natural(UInt(0)))
      } else {
        reversed match {
          case Digit(d, tail) =>
            val q = d / n
            val r = d % n
            recur(tail, r, End(q))
          case _ =>
            throw new IllegalArgumentException("bug in reversed")
        }
      }
    }
  }

  /** The terminal (most significant) digit of a chain. */
  @SerialVersionUID(0L)
  case class End(d: UInt) extends Natural with Serializable {
    def digit: UInt = d

    /** Adds a single digit, growing to two digits on overflow. */
    def +(n: UInt): Natural = if (n == UInt(0)) {
      this
    } else {
      val t = d.toLong + n.toLong
      if (t <= 0xffffffffL)
        End(UInt(t))
      else
        Digit(UInt(t), End(UInt(1)))
    }

    /**
     * Subtracts a single digit.
     *
     * @throws IllegalArgumentException if the result would be negative
     */
    def -(n: UInt): Natural = if (n == UInt(0)) {
      this
    } else {
      val t = d.toLong - n.toLong
      if (t >= 0L)
        End(UInt(t.toInt))
      else
        throw new IllegalArgumentException("illegal subtraction: %s %s" format (this, n))
    }

    /** Multiplies by a single digit; the full 64-bit product is kept. */
    def *(n: UInt): Natural = if (n == UInt(0))
      End(n)
    else if (n == UInt(1))
      this
    else
      Natural(d.toLong * n.toLong)

    def /(n: UInt): Natural = if (n == UInt(0))
      throw new IllegalArgumentException("/ by zero")
    else
      End(d / n)

    def %(n: UInt): Natural = if (n == UInt(0))
      throw new IllegalArgumentException("/ by zero")
    else
      End(d % n)

    def /%(n: UInt): (Natural, Natural) = (this / n, this % n)
  }
}
/** Implicit type-class instances for [[Natural]], mixed into its companion object. */
trait NaturalInstances {
  implicit final val NaturalAlgebra = new NaturalAlgebra

  import NumberTag._

  // `Natural.zero` must not be used here. This trait is initialized before
  // the body of `object Natural`, so that val would still be null at this
  // point. `Natural(0L)` only calls a method and reads no companion fields.
  // The argument order is presumably (resolution, zero, min, max, finite,
  // exact); confirm against NumberTag.CustomTag.
  implicit final val NaturalTag = new CustomTag[Natural](
    Integral, Some(Natural(0L)), Some(Natural(0L)), None, false, false)
}
/** Rig (semiring without negation) structure for Natural: 0, 1, + and *. */
private[math] trait NaturalIsRig extends Rig[Natural] {
  /** Additive identity. */
  def zero: Natural = Natural(0L)

  /** Multiplicative identity. */
  def one: Natural = Natural(1L)

  def plus(a: Natural, b: Natural): Natural = a + b

  override def times(a: Natural, b: Natural): Natural = a * b

  /**
   * `a` raised to the non-negative power `b`.
   *
   * @throws IllegalArgumentException if `b` is negative
   */
  override def pow(a: Natural, b: Int): Natural =
    if (b >= 0) a pow UInt(b)
    else throw new IllegalArgumentException("negative exponent: %s" format b)
}
/** Total ordering of Natural values, delegating to Natural's own comparison operators. */
private[math] trait NaturalOrder extends Order[Natural] {
  def compare(x: Natural, y: Natural): Int = x compare y

  override def eqv(x: Natural, y: Natural): Boolean = x == y
  override def neqv(x: Natural, y: Natural): Boolean = !(x == y)

  override def lt(x: Natural, y: Natural): Boolean = x < y
  override def lteqv(x: Natural, y: Natural): Boolean = x <= y
  override def gt(x: Natural, y: Natural): Boolean = x > y
  override def gteqv(x: Natural, y: Natural): Boolean = x >= y
}
/** Sign structure for Natural: values are never negative. */
private[math] trait NaturalIsSigned extends Signed[Natural] {
  /** A Natural is its own absolute value. */
  def abs(a: Natural): Natural = a

  /** 1 for any non-zero value, 0 for zero (never -1). */
  def signum(a: Natural): Int = if (a != Natural.zero) 1 else 0
}
/** IsIntegral view of Natural, combining its ordering and sign structure. */
private[math] trait NaturalIsReal extends IsIntegral[Natural]
    with NaturalOrder with NaturalIsSigned {
  /** Exact conversion to BigInt. */
  def toBigInt(n: Natural): BigInt = n.toBigInt

  /** Conversion to Double via Natural.toDouble. */
  def toDouble(n: Natural): Double = n.toDouble
}
/**
 * The concrete type-class instance for Natural. It bundles the Rig, Order,
 * Signed and IsIntegral operations and is exposed implicitly as
 * `NaturalInstances.NaturalAlgebra`.
 */
@SerialVersionUID(0L)
class NaturalAlgebra extends NaturalIsRig with NaturalIsReal with Serializable
© 2015 - 2025 Weber Informatics LLC | Privacy Policy