All downloads are free. Search and download functionality uses the official Maven repository.

spire.algebra.NRoot.scala Maven / Gradle / Ivy

The newest version!
package spire.algebra

import spire.math._
import spire.macrosk.Ops

import scala.{specialized => spec, math => mth}
import java.math.MathContext
import java.lang.Math

/**
 * Type class for types that support `n`-th roots, together with `log` and
 * `fpow` (raising a value to a power of the same type).
 *
 * Results of `nroot` and `sqrt` are only guaranteed to be approximations,
 * except in the case of `Real`.
 *
 * In general, `nroot` with an even `n` is not defined for negative numbers,
 * and the behaviour when it is attempted is undefined. Raising an exception
 * would be preferable, but some types defer computation, and testing whether
 * such a value is negative may be undesirable. So do not rely on an
 * `ArithmeticException` to protect you from bad arithmetic!
 */
trait NRoot[@spec(Double,Float,Int,Long) A] {

  /** Returns the `n`-th root of `a`. */
  def nroot(a: A, n: Int): A

  /** Returns the square root of `a`; by default delegates to `nroot` with `n = 2`. */
  def sqrt(a: A): A =
    nroot(a, n = 2)

  /** Returns the logarithm of `a`. */
  def log(a: A): A

  /** Returns `a` raised to the power `b`. */
  def fpow(a: A, b: A): A
}

/**
 * Wraps a value `lhs` together with its `NRoot[A]` instance so that the
 * type-class operations can be called as methods on the value:
 * `nroot(k)`, `sqrt()`, `log()` and `fpow(y)`.
 *
 * Every method is a macro implemented in `spire.macrosk.Ops` (`binop` for the
 * one-argument methods, `unop` for the zero-argument ones).
 * NOTE(review): presumably each macro expands to the same-named method on the
 * implicit `n` (e.g. `n.nroot(lhs, rhs)`) — confirm against `Ops`. The
 * implicit conversion that produces `NRootOps` is not visible here.
 */
final class NRootOps[A](lhs: A)(implicit n: NRoot[A]) {
  def nroot(rhs: Int): A = macro Ops.binop[Int, A]
  def sqrt(): A = macro Ops.unop[A]
  def log(): A = macro Ops.unop[A]
  def fpow(rhs: A): A = macro Ops.binop[A, A]
}

/**
 * `NRoot` instance for `Double`, backed by `java.lang.Math`.
 *
 * Odd roots of negative numbers are computed as `-nroot(-a, k)`, so they
 * yield a negative result instead of the `NaN` that `Math.pow` returns for a
 * negative base with a fractional exponent. Even roots of negative numbers
 * still yield `NaN`.
 */
trait DoubleIsNRoot extends NRoot[Double] {
  def nroot(a: Double, k: Int): Double =
    if (k == 3) {
      // cbrt is within 1 ulp and accepts negative input; 1 / 3.0 is inexact.
      Math.cbrt(a)
    } else if (a < 0 && k % 2 != 0) {
      -Math.pow(-a, 1 / k.toDouble)
    } else {
      Math.pow(a, 1 / k.toDouble)
    }

  override def sqrt(a: Double): Double = Math.sqrt(a)
  def log(a: Double): Double = Math.log(a)
  def fpow(a: Double, b: Double): Double = Math.pow(a, b)
}

/**
 * `NRoot` instance for `Float`. Work is done in `Double` via `java.lang.Math`
 * and narrowed back to `Float`.
 *
 * Odd roots of negative numbers are computed as `-nroot(-a, k)`, so they
 * yield a negative result instead of the `NaN` that `Math.pow` returns for a
 * negative base with a fractional exponent. Even roots of negative numbers
 * still yield `NaN`.
 */
trait FloatIsNRoot extends NRoot[Float] {
  def nroot(a: Float, k: Int): Float =
    if (k == 3) {
      // cbrt is within 1 ulp and accepts negative input; 1 / 3.0 is inexact.
      Math.cbrt(a).toFloat
    } else if (a < 0 && k % 2 != 0) {
      (-Math.pow(-a, 1 / k.toDouble)).toFloat
    } else {
      Math.pow(a, 1 / k.toDouble).toFloat
    }

  override def sqrt(a: Float): Float = Math.sqrt(a).toFloat
  def log(a: Float): Float = Math.log(a).toFloat
  def fpow(a: Float, b: Float): Float = Math.pow(a, b).toFloat
}


/**
 * `NRoot` instance for `Rational` that forwards every operation to the
 * corresponding method on `Rational` itself. The abstract `context` is
 * declared implicit, so it is in implicit scope for those forwarded calls.
 */
trait RationalIsNRoot extends NRoot[Rational] {
  implicit def context: ApproximationContext[Rational]

  def nroot(a: Rational, k: Int): Rational = a nroot k

  def log(a: Rational): Rational = a.log

  def fpow(a: Rational, b: Rational): Rational = a pow b
}

/**
 * `NRoot` instance for `Real`. Only `nroot` is implemented (delegating to
 * `Real.nroot`); `log` and `fpow` are not implemented yet and throw.
 */
trait RealIsNRoot extends NRoot[Real] {
  def nroot(a: Real, k: Int): Real = a nroot k

  /** Not yet implemented for `Real`. */
  @throws[UnsupportedOperationException]("always")
  def log(a: Real): Real =
    throw new UnsupportedOperationException("NRoot[Real].log is not implemented")

  /** Not yet implemented for `Real`. */
  @throws[UnsupportedOperationException]("always")
  def fpow(a: Real, b: Real): Real =
    throw new UnsupportedOperationException("NRoot[Real].fpow is not implemented")
}


/**
 * `NRoot` instance for `BigDecimal`. Roots are computed by `NRoot.nroot`
 * to the precision of the argument's own `MathContext`; an argument whose
 * context has precision <= 0 (unlimited) is rejected with an
 * `ArithmeticException`. `log` and `fpow` delegate to `fun`.
 */
trait BigDecimalIsNRoot extends NRoot[BigDecimal] {
  def nroot(a: BigDecimal, k: Int): BigDecimal = {
    val ctx = a.mc
    if (ctx.getPrecision > 0) NRoot.nroot(a, k, ctx)
    else throw new ArithmeticException("Cannot find the nroot of a BigDecimal with unlimited precision.")
  }

  def log(a: BigDecimal): BigDecimal = fun.log(a)

  def fpow(a: BigDecimal, b: BigDecimal): BigDecimal = fun.pow(a, b)
}


/**
 * `NRoot` instance for `Int`.
 *
 * `nroot(x, n)` is the `n`-th root truncated toward zero:
 *  - for `x >= 0`, the largest `r >= 0` with `r^n <= x`;
 *  - for `x < 0` and odd `n`, `-nroot(-x, n)` (also correct for `Int.MinValue`);
 *  - for `x < 0` and even `n`, `0` (the root is undefined for such input).
 * `n` is expected to be positive; `n == 0` throws `ArithmeticException`.
 */
trait IntIsNRoot extends NRoot[Int] {
  def nroot(x: Int, n: Int): Int = {
    // Greedy bit-by-bit search for the largest r >= 0 with r^n <= m. For
    // |m| <= 2^31 and n >= 2 the root has no bit above 31 / n set. `m` is a
    // Double so that |Int.MinValue| = 2^31 is representable. Math.pow is exact
    // when the integer result fits in a Double, which holds for every candidate
    // tried here (all candidate powers stay below 2^53).
    @scala.annotation.tailrec
    def findnroot(m: Double, prev: Int, add: Int): Int =
      if (add == 0) prev
      else {
        val next = prev | add
        findnroot(m, if (Math.pow(next, n) <= m) next else prev, add >> 1)
      }

    if (n == 1) {
      // The search start below would be 1 << 31, which is negative.
      x
    } else if (x < 0 && n % 2 == 1) {
      -findnroot(-x.toDouble, 0, 1 << (31 / n))
    } else {
      findnroot(x.toDouble, 0, 1 << (31 / n))
    }
  }

  def log(a: Int): Int = Math.log(a.toDouble).toInt
  def fpow(a: Int, b: Int): Int = Math.pow(a, b).toInt
}

/**
 * `NRoot` instance for `Long`.
 *
 * `nroot(x, n)` is the `n`-th root truncated toward zero:
 *  - for `x >= 0`, the largest `r >= 0` with `r^n <= x`, computed exactly;
 *  - for `x < 0` and odd `n`, `-nroot(-x, n)` (also correct for `Long.MinValue`);
 *  - for `x < 0` and even `n`, `0` (the root is undefined for such input).
 * `n` is expected to be positive; `n == 0` throws `ArithmeticException`.
 */
trait LongIsNRoot extends NRoot[Long] {
  def nroot(x: Long, n: Int): Long =
    if (n == 1) {
      // The search start in floorRoot would be 1L << 63, which is negative.
      x
    } else if (x < 0 && n % 2 == 1) {
      if (x != Long.MinValue) {
        -floorRoot(-x, n)
      } else if (63 % n == 0) {
        // |x| = 2^63 does not fit in a Long; it is a perfect n-th power exactly
        // when n divides 63, and then its root is 2^(63 / n).
        -(1L << (63 / n))
      } else {
        // Otherwise the floor root of 2^63 equals the floor root of 2^63 - 1.
        -floorRoot(Long.MaxValue, n)
      }
    } else {
      floorRoot(x, n)
    }

  /**
   * Largest `r >= 0` with `r^n <= m` (`0` when `m < 0`), for `n >= 2`.
   *
   * Powers are evaluated exactly in `Long` arithmetic rather than with
   * `Math.pow`, whose `Double` result cannot tell e.g. `2^62` from `2^62 - 1`;
   * multiplication stops before it could overflow.
   */
  private def floorRoot(m: Long, n: Int): Long = {
    // true iff r^n <= m, for r >= 1.
    def fits(r: Long): Boolean =
      if (r == 1L) m >= 1L
      else {
        var acc = 1L
        var i = 0
        // acc <= m / r guarantees acc * r <= m, so acc never overflows.
        while (i < n && acc <= m / r) {
          acc *= r
          i += 1
        }
        i == n
      }

    // Greedy bit-by-bit search; the root of any m <= Long.MaxValue has no bit
    // above 63 / n set.
    @scala.annotation.tailrec
    def search(prev: Long, add: Long): Long =
      if (add == 0L) prev
      else {
        val next = prev | add
        search(if (fits(next)) next else prev, add >> 1)
      }

    search(0L, 1L << (63 / n))
  }

  def log(a: Long): Long = Math.log(a.toDouble).toLong
  def fpow(a: Long, b: Long): Long = fun.pow(a, b)
}

/**
 * `NRoot` instance for `BigInt`.
 *
 * `nroot(a, k)` returns the largest `r >= 0` with `r^k <= a` for `a >= 0`,
 * and `-nroot(-a, k)` for negative `a` with odd `k`. For negative `a` with
 * an even `k` it throws `ArithmeticException`.
 */
trait BigIntIsNRoot extends NRoot[BigInt] {
  def nroot(a: BigInt, k: Int): BigInt = if (a < 0 && k % 2 == 1) {
    -nroot(-a, k)
  } else if (a < 0) {
    throw new ArithmeticException("Cannot find %d-root of negative number." format k)
  } else {
    // Greedy bit-by-bit search from bit `i` downwards: keep bit i set iff the
    // candidate raised to the k-th power still does not exceed `a`.
    @scala.annotation.tailrec
    def findNroot(b: BigInt, i: Int): BigInt = if (i < 0) {
      b
    } else {
      val c = b setBit i
      if ((c pow k) <= a) findNroot(c, i - 1) else findNroot(b, i - 1)
    }

    // The k-th root of an L-bit value has at most ceil(L / k) bits, so its
    // highest set bit is at most (L - 1) / k. Starting there avoids raising
    // huge candidates to the k-th power for bits that can never be set.
    val top = if (k > 0) (a.bitLength - 1) / k else a.bitLength - 1
    findNroot(0, top)
  }

  def log(a: BigInt): BigInt = fun.log(BigDecimal(a)).toBigInt
  def fpow(a: BigInt, b: BigInt): BigInt = fun.pow(BigDecimal(a), BigDecimal(b)).toBigInt
}

/**
 * `NRoot` instance for `SafeLong`.
 *
 * `nroot` and `log` branch on the value's representation via `a.fold`,
 * routing the `Long` case to `NRoot.LongIsNRoot` and the `BigInt` case to
 * `NRoot.BigIntIsNRoot`. `fpow` always converts both operands to `BigInt`.
 */
trait SafeLongIsNRoot extends NRoot[SafeLong] {
  def nroot(a: SafeLong, k: Int): SafeLong =
    a.fold(
      small => SafeLong(NRoot.LongIsNRoot.nroot(small, k)),
      big => SafeLong(NRoot.BigIntIsNRoot.nroot(big, k))
    )

  def log(a: SafeLong): SafeLong =
    a.fold(
      small => SafeLong(NRoot.LongIsNRoot.log(small)),
      big => SafeLong(NRoot.BigIntIsNRoot.log(big))
    )

  def fpow(a: SafeLong, b: SafeLong): SafeLong = {
    val base = a.toBigInt
    val exponent = b.toBigInt
    SafeLong(NRoot.BigIntIsNRoot.fpow(base, exponent))
  }
}

object NRoot {
  /** Summons the implicit `NRoot` instance for `A`. */
  @inline final def apply[@spec(Int,Long,Float,Double) A](implicit ev:NRoot[A]) = ev

  implicit object IntIsNRoot extends IntIsNRoot
  implicit object LongIsNRoot extends LongIsNRoot
  implicit object BigIntIsNRoot extends BigIntIsNRoot
  implicit object SafeLongIsNRoot extends SafeLongIsNRoot

  implicit object FloatIsNRoot extends FloatIsNRoot
  implicit object DoubleIsNRoot extends DoubleIsNRoot
  implicit object BigDecimalIsNRoot extends BigDecimalIsNRoot

  /** `NRoot[Rational]` whose `context` is the given `ApproximationContext`. */
  implicit def rationalIsNRoot(implicit c:ApproximationContext[Rational]) = new RationalIsNRoot {
    implicit def context = c
  }

  implicit object RealIsNRoot extends RealIsNRoot

  /**
   * This will return the largest integer that meets some criteria. Specifically,
   * if we're looking for some integer `x` and `f(x')` is guaranteed to return
   * `true` iff `x' <= x`, then this will return `x`.
   *
   * This can be used, for example, to find an integer `x` s.t.
   * `x * x <= y < (x+1)*(x+1)`, by using `intSearch(x => x * x <= y)`.
   */
  private def intSearch(f: Int => Boolean): Int = {
    // Index of the first power of two for which `f` fails; bounds the search.
    val ceil = (0 until 32) find (i => !f(1 << i)) getOrElse 33
    if (ceil == 0) {
      0
    } else {
      // Set bits greedily from the top, keeping each one only if `f` still holds.
      (0 /: ((ceil - 1) to 0 by -1)) { (x, i) =>
        val y = x | (1 << i)
        if (f(y)) y else x
      }
    }
  }


  /**
   * Returns the digits to the right of the decimal point of `x / y` in base
   * `r` if x < y. The stream is lazy and ends once the remainder reaches 0.
   */
  private def decDiv(x: BigInt, y: BigInt, r: Int): Stream[BigInt] = {
    val expanded = x * r
    val quot = expanded / y
    val rem = expanded - (quot * y)

    if (rem == 0) {
      Stream.cons(quot, Stream.empty)
    } else {
      Stream.cons(quot, decDiv(rem, y, r))
    }
  }


  /** Returns the digits of `x` in base `r`, most significant first (Nil for 0). */
  private def digitize(x: BigInt, r: Int, prev: List[Int] = Nil): List[Int] =
    if (x == 0) prev else digitize(x / r, r, (x % r).toInt :: prev)


  /** Converts a list of digits in base `r` (most significant first) to a `BigInt`. */
  private def undigitize(digits: Seq[Int], r: Int): BigInt =
    (BigInt(0) /: digits)(_ * r + _)


  // 1 billion: because it's the largest positive Int power of 10.
  private val radix = 1000000000


  /**
   * An implementation of the shifting n-th root algorithm for BigDecimal. For
   * the BigDecimal a, this is guaranteed to be accurate up to the precision
   * specified in ctxt.
   *
   * See http://en.wikipedia.org/wiki/Shifting_nth_root_algorithm
   *
   * @param a A (positive if k % 2 == 0) `BigDecimal`; any scale, including a
   *          negative one (e.g. `1E+5`), is accepted.
   * @param k A positive `Int` greater than 1. `k == 0` returns 1.
   * @param ctxt The `MathContext` to bound the precision of the result.
   * @return A `BigDecimal` approximation to the `k`-th root of `a`.
   * @throws ArithmeticException if `a` is negative and `k` is even.
   */
  def nroot(a: BigDecimal, k: Int, ctxt: MathContext): BigDecimal = if (k == 0) {
    BigDecimal(1)
  } else if (a.signum < 0) {
    if (k % 2 == 0) {
      throw new ArithmeticException("%d-root of negative number" format k)
    } else {
      -nroot(-a, k, ctxt)
    }
  } else {
    // A negative scale would make `BigInt(10) pow scale` throw (negative
    // exponent). Rescaling to 0 only appends zeros to the unscaled value, so
    // it is exact.
    val normalized =
      if (a.scale < 0) a.bigDecimal.setScale(0) else a.bigDecimal
    val underlying = BigInt(normalized.unscaledValue)
    val scale = BigInt(10) pow normalized.scale
    val intPart = digitize(underlying / scale, radix)
    val fracPart = decDiv(underlying % scale, scale, radix) map (_.toInt)
    // Left-pad the integer digits so they split evenly into groups of k.
    val leader = if (intPart.size % k == 0) Stream.empty else {
      Stream.fill(k - intPart.size % k)(0)
    }
    val digits = leader ++ intPart.toStream ++ fracPart ++ Stream.continually(0)
    val radixPowK = BigInt(radix) pow k

    // Total # of base-10^9 root digits to compute: ceil(precision / 9) plus
    // guard digits. Note: I originally had `+ 1` here, but some edge cases
    // were missed, so now it is `+ 2`.
    val maxSize = (ctxt.getPrecision + 8) / 9 + 2

    def findRoot(digits: Stream[Int], y: BigInt, r: BigInt, i: Int): (Int, BigInt) = {
      val y_ = y * radix
      // The next group of k base-10^9 digits of the radicand.
      val chunk = undigitize(digits take k, radix)
      // Note: target grows quite fast (so I imagine (y_ + b) pow k does too).
      val target = radixPowK * r + chunk + (y_ pow k)
      // Largest next root digit d with (y_ + d)^k <= target.
      val b = intSearch(d => ((y_ + d) pow k) <= target)

      val ny = y_ + b

      if (i == maxSize) {
        (i, ny)
      } else {
        val nr = target - (ny pow k)

        // TODO: Add stopping condition for when nr == 0 and there are no more
        // digits. Tricky part is refactoring to know when digits end...

        findRoot(digits drop k, ny, nr, i + 1)
      }
    }

    val (size, unscaled) = findRoot(digits, 0, 0, 1)
    // The first ceil(intPart.size / k) root digits lie left of the decimal
    // point; each remaining base-10^9 digit contributes 9 decimal places.
    val newscale = (size - (intPart.size + k - 1) / k) * 9
    BigDecimal(unscaled, newscale, ctxt)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy