All Downloads are FREE. Search and download functionalities are using the official Maven repository.

spire.math.Algebraic.scala Maven / Gradle / Ivy

package spire.math

import java.lang.Long.numberOfLeadingZeros
import java.lang.Double.{ isInfinite, isNaN }
import java.math.{ MathContext, RoundingMode, BigInteger, BigDecimal => JBigDecimal }
import java.util.concurrent.atomic.AtomicReference

import scala.annotation.tailrec
import scala.math.{ ScalaNumber, ScalaNumericConversions }
import scala.reflect.ClassTag

import spire.Platform
import spire.algebra.{Eq, EuclideanRing, Field, IsAlgebraic, NRoot, Order, Ring, Sign, Signed}
import spire.algebra.Sign.{ Positive, Negative, Zero }
import spire.macros.Checked.checked
import spire.math.poly.{ Term, BigDecimalRootRefinement, RootFinder, Roots }
import spire.std.bigInt._
import spire.std.bigDecimal._
import spire.std.long._
import spire.syntax.std.seq._

/**
 * Algebraic provides an exact number type for algebraic numbers. Algebraic
 * numbers are roots of polynomials with rational coefficients. With it, we can
 * represent expressions involving addition, multiplication, division, n-roots
 * (eg. `sqrt` or `cbrt`), and roots of rational polynomials. So, it is similar
 * to [[Rational]], but adds roots as a valid, exact operation. The cost is that
 * this will not be as fast as [[Rational]] for many operations.
 *
 * In general, you can assume all operations on this number type are exact,
 * except for those that explicitly construct approximations to an Algebraic
 * number, such as `toBigDecimal`.
 *
 * For an overview of the ideas, algorithms, and proofs of this number type,
 * you can read the following papers:
 *
 *  - "On Guaranteed Accuracy Computation." C. K. Yap.
 *  - "Recent Progress in Exact Geometric Computation." C. Li, S. Pion, and C. K. Yap.
 *  - "A New Constructive Root Bound for Algebraic Expressions" by C. Li & C. K. Yap.
 *  - "A Separation Bound for Real Algebraic Expressions." C. Burnikel, et al.
 */
@SerialVersionUID(1L)
final class Algebraic private (val expr: Algebraic.Expr)
extends ScalaNumber with ScalaNumericConversions with Serializable {
  import Algebraic.{ Zero, One, Expr, MinIntValue, MaxIntValue, MinLongValue, MaxLongValue, JBigDecimalOrder, roundExact, BFMSS, LiYap }

  /**
   * Returns an `Int` with the same sign as this algebraic number. Algebraic
   * numbers support exact sign tests, so this is guaranteed to be accurate.
   */
  def signum: Int = expr.signum

  /**
   * Returns the sign of this Algebraic number. Algebraic numbers support exact
   * sign tests, so this is guaranteed to be accurate.
   */
  def sign: Sign = Sign(signum)

  /**
   * Return a non-negative `Algebraic` with the same magnitude as this one.
   */
  def abs: Algebraic =
    if (this.signum < 0) -this else this

  /** Returns the negation of this number (wraps the expression in a `Neg` node). */
  def unary_- : Algebraic =
    new Algebraic(Expr.Neg(expr))

  /** Returns the sum of `this` and `that` as a new `Add` expression node. */
  def +(that: Algebraic): Algebraic =
    new Algebraic(Expr.Add(this.expr, that.expr))

  /** Returns the difference of `this` and `that` as a new `Sub` expression node. */
  def -(that: Algebraic): Algebraic =
    new Algebraic(Expr.Sub(this.expr, that.expr))

  /** Returns the product of `this` and `that` as a new `Mul` expression node. */
  def *(that: Algebraic): Algebraic =
    new Algebraic(Expr.Mul(this.expr, that.expr))

  /**
   * Returns the quotient of `this` and `that` as a new `Div` expression node.
   * Only the node is built here; nothing is evaluated by this call.
   */
  def /(that: Algebraic): Algebraic =
    new Algebraic(Expr.Div(this.expr, that.expr))

  /**
   * Returns an `Algebraic` whose value is just the integer part of
   * `this / that`. This operation is exact.
   */
  def quot(that: Algebraic): Algebraic =
    this /~ that

  /** An alias for [[quot]]. */
  def /~(that: Algebraic): Algebraic =
    Algebraic((this / that).toBigInt)

  /**
   * Returns an `Algebraic` whose value is the difference between `this` and
   * `(this /~ that) * that` -- the modulus.
   */
  def mod(that: Algebraic): Algebraic =
    this % that

  /** An alias for [[mod]]. */
  def %(that: Algebraic): Algebraic =
    this - (this /~ that) * that

  /** Returns the square root of this number. */
  def sqrt: Algebraic = nroot(2)

  /** Returns the cube root of this number. */
  def cbrt: Algebraic = nroot(3)

  /**
   * Returns the `k`-th root of this number. A negative `k` yields the
   * reciprocal of the `-k`-th root.
   *
   * @throws java.lang.ArithmeticException if `k` is 0
   */
  def nroot(k: Int): Algebraic = if (k < 0) {
    new Algebraic(Expr.Div(Expr.ConstantLong(1), Expr.KRoot(this.expr, -k)))
  } else if (k > 0) {
    new Algebraic(Expr.KRoot(this.expr, k))
  } else {
    throw new ArithmeticException("divide by zero (0-root)")
  }

  /**
   * Raise this number to the `k`-th power. A negative `k` yields the
   * reciprocal of the `-k`-th power.
   *
   * @throws java.lang.ArithmeticException if `k` is `Int.MinValue` (it cannot
   *         be negated), or if `k` is 0 and this number is 0
   */
  def pow(k: Int): Algebraic =
    if (k == Int.MinValue) {
      throw new ArithmeticException(s"illegal exponent (${Int.MinValue})")
    } else if (k == 0) {
      if (signum == 0) {
        throw new ArithmeticException("undeterminate result (0^0)")
      } else {
        One
      }
    } else if (k == 1) {
      this
    } else if (k < 0) {
      new Algebraic(Expr.Div(Expr.ConstantLong(1), this.pow(-k).expr))
    } else {
      new Algebraic(Expr.Pow(this.expr, k))
    }

  // Ordering operators; all of them delegate to the exact `compare`.
  def <  (that: Algebraic): Boolean = compare(that) <  0
  def >  (that: Algebraic): Boolean = compare(that) >  0
  def <= (that: Algebraic): Boolean = compare(that) <= 0
  def >= (that: Algebraic): Boolean = compare(that) >= 0

  /**
   * Returns an integer with the same sign as `this - that`. Specifically, if
   * `this < that`, then the sign is negative, if `this > that`, then the
   * sign is positive, otherwise `this == that` and this returns 0.
   */
  def compare(that: Algebraic): Int = (this - that).signum

  /**
   * Returns `true` iff this Algebraic number is exactly 0.
   */
  def isZero: Boolean = signum == 0

  /**
   * Universal equality. Dispatches on the runtime type of `that`: other
   * `Algebraic` values use the exact `===`, the remaining numeric types are
   * either converted to an `Algebraic` and compared or asked to compare
   * themselves against `this`, and everything else falls back to
   * `unifiedPrimitiveEquals`.
   */
  override def equals(that: Any): Boolean = that match {
    case (that: Algebraic) => this === that
    case (that: Real) => this.toReal == that
    case (that: Number) => this.compare(Algebraic(that.toBigDecimal)) == 0
    case (that: Rational) => this.compare(Algebraic(that)) == 0
    case (that: BigInt) => isWhole && toBigInt == that
    case (that: Natural) => isWhole && signum >= 0 && that == toBigInt
    case (that: SafeLong) => isWhole && that == this
    case (that: Complex[_]) => that == this
    case (that: Quaternion[_]) => that == this
    case (that: BigDecimal) => try {
      toBigDecimal(that.mc) == that
    } catch {
      // An ArithmeticException raised while approximating is treated as "not equal".
      case ae: ArithmeticException => false
    }
    case _ => unifiedPrimitiveEquals(that)
  }

  /** Exact equality test between two `Algebraic` numbers. */
  def ===(that: Algebraic): Boolean =
    this.compare(that) == 0

  /** Exact inequality test; the negation of `===`. */
  def =!=(that: Algebraic): Boolean =
    !(this === that)

  /**
   * Whole numbers that fit in a `Long` use `unifiedPrimitiveHashcode`; every
   * other value is hashed from its `DECIMAL64` approximation.
   */
  override def hashCode: Int = if (isWhole && isValidLong) {
    unifiedPrimitiveHashcode
  } else {
    val x = toBigDecimal(java.math.MathContext.DECIMAL64)
    x.underlying.unscaledValue.hashCode + 23 * x.scale.hashCode + 17
  }

  /**
   * Renders the stored expression tree as text. Every sub-expression is
   * rendered recursively and wrapped in parentheses, so the output does not
   * depend on operator precedence.
   */
  def toExprString: String = {
    import Expr._

    def recur(e: Expr): String = e match {
      case ConstantLong(n) => n.toString
      case ConstantDouble(n) => n.toString
      case ConstantBigDecimal(n) => n.toString
      case ConstantRational(n) => s"(${n})"
      case ConstantRoot(poly, i, _, _) => s"root($poly, $i)"
      // Recurse into the operand; interpolating `sub` directly would print the
      // case-class form (e.g. `Add(...)`) instead of the rendered expression.
      case Neg(sub) => s"-(${recur(sub)})"
      case Add(lhs, rhs) => s"(${recur(lhs)}) + (${recur(rhs)})"
      case Sub(lhs, rhs) => s"(${recur(lhs)}) - (${recur(rhs)})"
      case Mul(lhs, rhs) => s"(${recur(lhs)}) * (${recur(rhs)})"
      case Div(lhs, rhs) => s"(${recur(lhs)}) / (${recur(rhs)})"
      case KRoot(sub, 2) => s"(${recur(sub)}).sqrt"
      case KRoot(sub, 3) => s"(${recur(sub)}).cbrt"
      case KRoot(sub, k) => s"(${recur(sub)}).nroot($k)"
      // Print the actual exponent and parenthesise the base, like KRoot does.
      case Pow(sub, k) => s"(${recur(sub)}).pow($k)"
    }

    recur(expr)
  }

  /**
   * Shows the `DECIMAL64` approximation of this number. When the
   * approximation is not exactly equal to this number it is prefixed with `~`.
   */
  override def toString: String = {
    val approx = toBigDecimal(MathContext.DECIMAL64)
    if (this == Algebraic(approx)) {
      if (approx.signum == 0) {
        "Algebraic(0)"
      } else {
        s"Algebraic(${approx.bigDecimal.stripTrailingZeros})"
      }
    } else {
      s"Algebraic(~$approx)"
    }
  }

  /**
   * Returns the nearest, valid `Int` value to this Algebraic, without going
   * further away from 0 (eg. truncation).
   *
   * If this `Algebraic` represented 1.2, then this would return 1. If this
   * represented -3.3, then this would return -3. If this value is greater than
   * `Int.MaxValue`, then `Int.MaxValue` is returned. If this value is less
   * than `Int.MinValue`, then `Int.MinValue` is returned.
   */
  def intValue: Int = {
    val n = toBigInt
    if (n < MinIntValue) Int.MinValue
    else if (n > MaxIntValue) Int.MaxValue
    else n.intValue
  }

  /**
   * Returns the nearest, valid `Long` value to this Algebraic, without going
   * further away from 0 (eg. truncation).
   *
   * If this `Algebraic` represented 1.2, then this would return 1. If this
   * represented -3.3, then this would return -3. If this value is greater than
   * `Long.MaxValue`, then `Long.MaxValue` is returned. If this value is less
   * than `Long.MinValue`, then `Long.MinValue` is returned.
   */
  def longValue: Long = {
    val n = toBigInt
    if (n < MinLongValue) Long.MinValue
    else if (n > MaxLongValue) Long.MaxValue
    else n.longValue
  }

  /**
   * Returns a `Float` that approximates this value. If the exponent is too
   * large to fit in a float, the `Float.PositiveInfinity` or
   * `Float.NegativeInfinity` is returned.
   */
  def floatValue: Float = toBigDecimal(MathContext.DECIMAL32).toFloat

  /**
   * Returns a `Double` that approximates this value. If the exponent is too
   * large to fit in a double, the `Double.PositiveInfinity` or
   * `Double.NegativeInfinity` is returned.
   */
  def doubleValue: Double = toBigDecimal(MathContext.DECIMAL64).toDouble

  /**
   * Returns the nearest, valid `BigInt` value to this Algebraic, without going
   * further away from 0 (eg. truncation).
   *
   * If this `Algebraic` represented 1.2, then this would return 1. If this
   * represented -3.3, then this would return -3.
   */
  def toBigInt: BigInt =
    toBigDecimal(0, RoundingMode.DOWN).toBigInt

  /**
   * Absolute approximation to `scale` decimal places with the given rounding
   * mode. Rounding is always exact.
   */
  def toBigDecimal(scale: Int, roundingMode: RoundingMode): BigDecimal =
    BigDecimal(roundExact(this, expr.toBigDecimal(scale + 2), scale, roundingMode))

  /**
   * Relative approximation to the precision specified in `mc` with the given
   * rounding mode. Rounding is always exact. The sign is always correct; the
   * sign of the returned `BigDecimal` matches the sign of the exact value this
   * `Algebraic` represents.
   *
   * @param mc the precision and rounding mode of the final result
   * @return an approximation to the value of this algebraic number
   */
  def toBigDecimal(mc: MathContext): BigDecimal = {
    import Expr._

    val roundingMode = mc.getRoundingMode

    // Approximates `e` to (roughly) `digits` significant digits; each case
    // asks its operands for some extra digits before rounding its own result.
    def rec(e: Expr, digits: Int): JBigDecimal = e match {
      case ConstantLong(n) =>
        new JBigDecimal(n, new MathContext(digits, roundingMode))
      case ConstantDouble(n) =>
        new JBigDecimal(n, new MathContext(digits, roundingMode))
      case ConstantBigDecimal(n) =>
        n.bigDecimal.round(new MathContext(digits, roundingMode))
      case ConstantRational(n) =>
        val num = new JBigDecimal(n.numerator.toBigInteger)
        val den = new JBigDecimal(n.denominator.toBigInteger)
        num.divide(den, new MathContext(digits, roundingMode))
      case ConstantRoot(poly, _, lb, ub) =>
        // Ugh - on an airplane and can't trust BigDecimal's constructors.
        val poly0 = poly.map { n => new BigDecimal(new JBigDecimal(n.bigInteger), MathContext.UNLIMITED) }
        BigDecimalRootRefinement(poly0, lb, ub, new MathContext(digits, roundingMode)).approximateValue
      case Neg(sub) =>
        rec(sub, digits).negate()
      // An exact zero sum/difference is detected with the exact sign test first.
      case Add(_, _) | Sub(_, _) if e.signum == 0 =>
        JBigDecimal.ZERO
      case Add(lhs, rhs) =>
        val digits0 = digits + e.separationBound.decimalDigits.toInt + 1
        val lValue = rec(lhs, digits0)
        val rValue = rec(rhs, digits0)
        lValue.add(rValue, new MathContext(digits, roundingMode))
      case Sub(lhs, rhs) =>
        val digits0 = digits + e.separationBound.decimalDigits.toInt + 1
        val lValue = rec(lhs, digits0)
        val rValue = rec(rhs, digits0)
        lValue.subtract(rValue, new MathContext(digits, roundingMode))
      case Mul(lhs, rhs) =>
        val lValue = rec(lhs, digits + 1)
        val rValue = rec(rhs, digits + 2)
        lValue.multiply(rValue, new MathContext(digits, roundingMode))
      case Div(lhs, rhs) =>
        val rValue = rec(rhs, digits + 2)
        if (rValue.compareTo(JBigDecimal.ZERO) == 0)
          throw new ArithmeticException("divide by zero")
        val lValue = rec(lhs, digits + 2)
        lValue
          .divide(rValue, new MathContext(digits + 2, roundingMode))
          .round(new MathContext(digits, roundingMode))
      case KRoot(sub, k) =>
        Algebraic.nroot(rec(sub, digits + 2), k, new MathContext(digits + 2, roundingMode))
          .round(new MathContext(digits, roundingMode))
      case Pow(sub, k) =>
        val subValue = rec(sub, digits + ceil(log(k.toDouble)).toInt)
        subValue.pow(k, new MathContext(digits, roundingMode))
    }
    val approx = rec(expr, mc.getPrecision + 2)
    // Scale at which `approx` carries exactly `mc.getPrecision` significant digits.
    val newScale = approx.scale - approx.precision + mc.getPrecision
    val adjustedApprox =
      if (newScale <= approx.scale) approx.setScale(newScale + 1, RoundingMode.DOWN)
      else approx
    roundExact(this, adjustedApprox, newScale, roundingMode)
      .round(mc) // We perform a final round, since roundExact uses scales.
  }

  /**
   * Returns `true` iff this Algebraic exactly represents a valid `BigInt`.
   */
  def isWhole: Boolean = this == Algebraic(toBigInt)

  /**
   * Returns `true` if this Algebraic number is a whole number (no fractional
   * part) and fits within the bounds of an `Int`. That is, if `x.isValidInt`,
   * then `Algebraic(x.toInt) == x`.
   */
  override def isValidInt: Boolean = {
    val n = toBigInt
    (n <= MaxIntValue) &&
    (n >= MinIntValue) &&
    (this == Algebraic(n))
  }

  /**
   * Returns `true` if this Algebraic number is a whole number (no fractional
   * part) and fits within the bounds of an `Long`. That is, if `x.isValidLong`,
   * then `Algebraic(x.toLong) == x`.
   */
  def isValidLong: Boolean = {
    val n = toBigInt
    (n <= MaxLongValue) &&
    (n >= MinLongValue) &&
    (this == Algebraic(n))
  }

  /**
   * Returns `true` iff this is a rational expression (ie contains no n-root
   * expressions). Otherwise it is a radical expression and returns false.
   */
  def isRational: Boolean = expr.flags.isRational

  /**
   * If this is a rational expressions, then it returns the exact value as a
   * [[Rational]]. Otherwise, this is a radical expression and `None` is
   * returned.
   */
  def toRational: Option[Rational] =
    if (expr.flags.isRational) {
      // `evaluateWith` demands NRoot/RootFinder instances; this stub throws if
      // either is ever invoked.
      implicit val nroot: NRoot[Rational] with RootFinder[Rational] =
        new NRoot[Rational] with RootFinder[Rational] {
          private def fail =
            throw new ArithmeticException(s"Rational cannot support exact algebraic operations")
          def nroot(a: Rational, n: Int): Rational = fail
          def fpow(a:Rational, b:Rational): Rational = fail
          def findRoots(poly: Polynomial[Rational]): Roots[Rational] = fail
        }
      Some(evaluateWith[Rational])
    } else {
      None
    }

  /**
   * Evaluates this algebraic expression with a different number type. All
   * `Algebraic` numbers store the entire expression tree, so we can use this
   * to *replay* the stored expression using a different type. This will
   * accumulate errors as if the number type had been used from the beginning
   * and is only really suitable for more exact number types, like [[Real]].
   *
   * TODO: Eq/ClassTag come from poly.map - would love to get rid of them.
   */
  def evaluateWith[A: Field: NRoot: RootFinder: Eq: ClassTag](implicit conv: ConvertableTo[A]): A = {
    import spire.syntax.field._
    import spire.syntax.nroot._
    import Expr._

    def eval(e: Expr): A = e match {
      case ConstantLong(n) => conv.fromLong(n)
      case ConstantDouble(n) => conv.fromDouble(n)
      case ConstantBigDecimal(n) => conv.fromBigDecimal(n)
      case ConstantRational(n) => conv.fromRational(n)
      case ConstantRoot(poly, i, _, _) =>
        RootFinder[A].findRoots(poly.map(conv.fromBigInt)).get(i)
      case Neg(n) => -eval(n)
      case Add(a, b) => eval(a) + eval(b)
      case Sub(a, b) => eval(a) - eval(b)
      case Mul(a, b) => eval(a) * eval(b)
      case Div(a, b) => eval(a) / eval(b)
      case KRoot(a, k) => eval(a).nroot(k)
      case Pow(a, k) => eval(a).pow(k)
    }

    eval(expr)
  }

  /**
   * Returns an exact [[Real]] representation of this number.
   */
  def toReal: Real = evaluateWith[Real]

  // ScalaNumber. Because of course all Scala numbers are wrappers.
  def underlying: AnyRef = this
}

object Algebraic extends AlgebraicInstances {

  /** Returns an Algebraic expression equal to 0. */
  val Zero: Algebraic = new Algebraic(Expr.ConstantLong(0L)) // exact integer literal 0

  /** Returns an Algebraic expression equal to 1. */
  val One: Algebraic = new Algebraic(Expr.ConstantLong(1L)) // exact integer literal 1

  /** Returns an Algebraic expression equivalent to `n`. */
  implicit def apply(n: Int): Algebraic = {
    // Ints are stored as Long literals; there is no dedicated Int node.
    val literal = Expr.ConstantLong(n.toLong)
    new Algebraic(literal)
  }

  /** Returns an Algebraic expression equivalent to `n`. */
  def apply(n: Long): Algebraic = {
    val literal = Expr.ConstantLong(n)
    new Algebraic(literal)
  }

  /**
   * Returns an Algebraic expression equivalent to `n`, if `n` is finite. If
   * `n` is either infinite or `NaN`, then an `IllegalArgumentException` is
   * thrown.
   */
  def apply(n: Float): Algebraic = {
    // Widen to Double and reuse the Double constructor (and its validation).
    val widened: Double = n.toDouble
    Algebraic(widened)
  }

  /**
   * Returns an Algebraic expression equivalent to `n`, if `n` is finite. If
   * `n` is either infinite or `NaN`, then an `IllegalArgumentException` is
   * thrown.
   */
  implicit def apply(n: Double): Algebraic =
    // Only finite doubles have an exact value that can be stored in the tree.
    if (java.lang.Double.isInfinite(n)) {
      throw new IllegalArgumentException("cannot construct infinite Algebraic")
    } else if (java.lang.Double.isNaN(n)) {
      throw new IllegalArgumentException("cannot construct Algebraic from NaN")
    } else {
      new Algebraic(Expr.ConstantDouble(n))
    }

  /** Returns an Algebraic expression equivalent to `n`. */
  def apply(n: BigInt): Algebraic = {
    // BigInts are stored as BigDecimal literals.
    val asDecimal = BigDecimal(n)
    new Algebraic(Expr.ConstantBigDecimal(asDecimal))
  }

  /** Returns an Algebraic expression equivalent to `n`. */
  def apply(n: BigDecimal): Algebraic = {
    val literal = Expr.ConstantBigDecimal(n)
    new Algebraic(literal)
  }

  /** Returns an Algebraic expression equivalent to `n`. */
  def apply(n: Rational): Algebraic = {
    val literal = Expr.ConstantRational(n)
    new Algebraic(literal)
  }

  /**
   * Returns an Algebraic expression whose value is equivalent to the `i`-th
   * real root of the [[Polynomial]] `poly`. If `i` is negative or does not
   * index a real root (eg the value is greater than or equal to the number of
   * real roots) then an `ArithmeticException` is thrown. Roots are indexed
   * starting at 0.  So if there are 3 roots, then they are indexed as 0, 1,
   * and 2.
   *
   * @param poly the polynomial containing at least i real roots
   * @param i    the index (0-based) of the root
   * @return an algebraic whose value is the i-th root of the polynomial
   */
  def root(poly: Polynomial[Rational], i: Int): Algebraic = {
    // Reject negative indices before doing any root isolation work.
    if (i < 0)
      throw new ArithmeticException(s"invalid real root index: $i")

    // Work on an integer-coefficient polynomial and isolate its real roots.
    val zpoly = Roots.removeFractions(poly)
    val intervals = Roots.isolateRoots(zpoly)
    if (i >= intervals.size)
      throw new ArithmeticException(s"cannot extract root $i, there are only ${intervals.size} roots")

    intervals(i) match {
      // A point interval means the root is an exact rational value.
      case Point(exact) =>
        new Algebraic(Expr.ConstantRational(exact))
      // Otherwise keep the polynomial together with its isolating interval.
      case Bounded(lower, upper, _) =>
        new Algebraic(Expr.ConstantRoot(zpoly, i, lower, upper))
      case _ =>
        throw new RuntimeException("invalid isolated root interval")
    }
  }

  /**
   * Returns all of the real roots of the given polynomial, in order from
   * smallest to largest.
   *
   * @param poly the polynomial to return the real roots of
   * @return all the real roots of `poly`
   */
  def roots(poly: Polynomial[Rational]): Vector[Algebraic] = {
    val zpoly = Roots.removeFractions(poly)
    val isolated = Roots.isolateRoots(zpoly)
    // Pair every isolating interval with its index, which ConstantRoot stores.
    isolated.zipWithIndex.map {
      case (Point(exact), _) =>
        new Algebraic(Expr.ConstantRational(exact))
      case (Bounded(lower, upper, _), index) =>
        new Algebraic(Expr.ConstantRoot(zpoly, index, lower, upper))
      case x =>
        throw new RuntimeException(s"invalid isolated root interval: $x")
    }
  }

  /**
   * Returns an Algebraic whose value is the real root within (lb, ub). This is
   * potentially unsafe, as we assume that exactly 1 real root lies within the
   * interval, otherwise the results are undetermined.
   *
   * @param poly a polynomial with a real root within (lb, ub)
   * @param i    the index of the root in the polynomial
   * @param lb   the lower bound of the open interval containing the root
   * @param ub   the upper bound of the open interval containing the root
   */
  def unsafeRoot(poly: Polynomial[BigInt], i: Int, lb: Rational, ub: Rational): Algebraic = {
    // No validation happens here: the arguments are stored in the node as given.
    val node = Expr.ConstantRoot(poly, i, lb, ub)
    new Algebraic(node)
  }

  /**
   * Returns an Algebraic expression equivalent to `BigDecimal(n)`. If `n` is
   * not parseable as a `BigDecimal` then an exception is thrown.
   */
  def apply(n: String): Algebraic = {
    // Parsing is delegated to java.math.BigDecimal's String constructor.
    val parsed = new JBigDecimal(n)
    Algebraic(BigDecimal(parsed))
  }

  /**
   * The [[Algebraic]] expression AST. `Algebraic` simply stores an expression
   * tree representing all operations performed on it. We then use this tree to
   * deduce certain properties about the algebraic expression and use them to
   * perform exact sign tests, compute approximations, etc.
   *
   * Generally, this should be regarded as an internal implementation detail of
   * `Algebraic`.
   */
  sealed abstract class Expr extends Serializable {
    import Expr._

    /** The raw flag bits of this node; each concrete node supplies its own. */
    protected def flagBits: Int

    /**
     * The quickly computable flags of this expression, wrapping `flagBits`.
     *
     * @note NOTE(review): the original comment explaining why `flags` is
     * derived from `flagBits` (rather than stored directly) was left
     * unfinished; the reason is not visible here -- confirm with the author.
     */
    def flags: Flags = new Flags(flagBits)

    // Cache of bounds computed so far, keyed by the function that produced them.
    private val bounds: Platform.TrieMap[ZeroBoundFunction, Any] = Platform.TrieMap()

    /**
     * Looks up the bound `zbf` assigns to this expression, computing and
     * caching it on first use.
     */
    def getBound(zbf: ZeroBoundFunction): zbf.Bound = {
      val cached = bounds.getOrElseUpdate(zbf, zbf(this))
      cached.asInstanceOf[zbf.Bound]
    }

    // 0 means "not computed yet"; a computed degree bound is a product
    // starting from 1.
    @volatile
    private var cachedDegreeBound: Long = 0L

    /**
     * Collects the distinct `KRoot` nodes in this sub-tree (including this
     * node, if it is one) and, as a side effect, fills in `cachedDegreeBound`
     * when it has not been computed yet.
     */
    private def radicalNodes(): Set[KRoot] = {
      val fromChildren: Set[KRoot] =
        children.foldLeft(Set.empty[KRoot])(_ ++ _.radicalNodes())
      val all = this match {
        case kroot: KRoot => fromChildren + kroot
        case _ => fromChildren
      }
      if (cachedDegreeBound == 0L) {
        // Product of the root indices; `checked` guards the multiplication.
        cachedDegreeBound = all.foldLeft(1L) { (product, node) =>
          checked { product * node.k }
        }
      }
      all
    }

    /**
     * A bound on the degree of this expression: the product of the indices
     * of the distinct radical nodes it contains.
     */
    def degreeBound: Long = {
      if (cachedDegreeBound == 0L) radicalNodes()
      cachedDegreeBound
    }

    /**
     * The BFMSS separation bound of this expression.
     */
    def bfmssBound: BitBound = {
      val bits = getBound(BFMSS).getBitBound(degreeBound)
      new BitBound(bits)
    }

    /**
     * The Li & Yap separation bound of this expression.
     */
    def liYapBound: BitBound = {
      val bits = getBound(LiYap).getBitBound(degreeBound)
      new BitBound(bits)
    }

    /**
     * A separation bound for this expression, as a bit bound: the smaller of
     * the BFMSS and Li & Yap bounds. A separation bound is a lower bound on
     * the value of this expression that only holds when the expression is not
     * 0. Approximating the expression accurately enough to land on one side
     * of the bound therefore decides whether the value is 0 and, if it is
     * not, what its sign is.
     */
    def separationBound: BitBound =
      bfmssBound.min(liYapBound)

    /**
     * An absolute approximation of this expression as a BigDecimal, accurate
     * up to +/- 10^-digits.
     */
    def toBigDecimal(digits: Int): JBigDecimal

    /**
     * An upper bound on the absolute value of this expression, as a bit
     * bound.
     */
    def upperBound: BitBound

    /**
     * A lower bound on the absolute value of this expression, as a bit
     * bound.
     *
     * TODO: We could do better here wrt to addition (need a fastSignum: Option[Int])
     */
    def lowerBound: BitBound = -separationBound

    /** An integer with the same sign as this expression. */
    def signum: Int

    /**
     * The sub-expressions this expression directly depends on. `Add`, for
     * example, has 2 children (its left-hand and right-hand sides), while a
     * numeric literal such as `ConstantDouble` or `ConstantRational` has
     * none.
     */
    def children: List[Expr]
  }

  object Expr {

    /**
     * A set of flags for algebraic expressions, so we can quickly determine
     * some properties, like whether the expression is rational, radical, what
     * types of leaf nodes it has, etc. This is used to help guide algorithmic
     * choices, such as what separation bound to use.
     */
    final class Flags(val bits: Int) extends AnyVal {
      import Flags._

      /** Combines two flag sets; a flag is set if it is set in either operand. */
      def | (that: Flags): Flags = new Flags(this.bits | that.bits)

      // True when any bit of `mask` is set in this flag set.
      private def isSet(mask: Int): Boolean = (bits & mask) != 0

      /** `true` iff the radical flag is not set. */
      def isRational: Boolean = !isRadical

      /** `true` iff the radical flag is set. */
      def isRadical: Boolean = isSet(RadicalFlag)

      /** `true` iff the expression has a `ConstantDouble` leaf node. */
      def hasDoubleLeaf: Boolean = isSet(HasDoubleLeaf)

      /** `true` iff the expression has a `ConstantBigDecimal` leaf node. */
      def hasBigDecimalLeaf: Boolean = isSet(HasBigDecimalLeaf)

      /** `true` iff the expression has a `ConstantRational` leaf node. */
      def hasRationalLeaf: Boolean = isSet(HasRationalLeaf)
    }

    object Flags {
      // Individual flag bits; each one occupies its own bit position.
      final val RadicalFlag = 0x1
      final val HasDoubleLeaf = 0x2
      final val HasBigDecimalLeaf = 0x4
      final val HasRationalLeaf = 0x8

      // Ready-made flag sets for the leaf node kinds and for radicals.
      final val IntegerLeaf: Flags = new Flags(0x0)
      final val DoubleLeaf: Flags = new Flags(HasDoubleLeaf)
      final val BigDecimalLeaf: Flags = new Flags(HasBigDecimalLeaf)
      final val RationalLeaf: Flags = new Flags(HasRationalLeaf)
      final val IsRadical: Flags = new Flags(RadicalFlag)
    }

    /** Constant expressions are leaf nodes that contain literal numbers. */
    sealed abstract class Constant[A] extends Expr {
      // The literal carried by this leaf.
      def value: A
      // Leaves have no sub-expressions.
      def children: List[Expr] = List.empty
    }

    /** Unary expressions contain only a single child expression. */
    sealed abstract class UnaryExpr extends Expr {
      // The single operand of this expression.
      val sub: Expr
      def children: List[Expr] = List(sub)
    }

    /** Binary expressions contain 2 child expressions. */
    sealed abstract class BinaryExpr extends Expr {
      // Left-hand and right-hand operands.
      val lhs: Expr
      val rhs: Expr
      // A binary node carries the union of both operands' flags.
      val flagBits: Int = lhs.flags.|(rhs.flags).bits
      def children: List[Expr] = List(lhs, rhs)
    }

    @SerialVersionUID(0L)
    case class ConstantLong(value: Long) extends Constant[Long] {
      def flagBits: Int = Flags.IntegerLeaf.bits

      // Bit bound on |value|. 0 and Long.MinValue are special-cased; for
      // everything else the bound comes from the leading zeros of |value| - 1.
      def upperBound: BitBound = value match {
        case 0L => new BitBound(0L)
        case Long.MinValue => new BitBound(64)
        case _ => new BitBound(64 - numberOfLeadingZeros(abs(value) - 1))
      }

      def signum: Int = value.signum

      // The long is exact, so this only sets the requested scale.
      def toBigDecimal(digits: Int): JBigDecimal = {
        val exact = new JBigDecimal(value)
        exact.setScale(digits, RoundingMode.HALF_UP)
      }
    }

    @SerialVersionUID(0L)
    case class ConstantDouble(value: Double) extends Constant[Double] {
      def flagBits: Int = Flags.DoubleLeaf.bits

      /**
       * Upper bound on `|value|`, in bits (i.e. an exponent of 2).
       *
       * For a finite, non-zero double `|value| < 2^(e + 1)`, where `e` is the
       * unbiased binary exponent returned by `Math.getExponent`. The bound
       * must be a base-2 quantity; a natural logarithm would underestimate it
       * for larger values (e.g. `ceil(ln 5) = 2`, yet `2^2 < 5`).
       */
      def upperBound: BitBound = if (value == 0D) {
        new BitBound(0)
      } else {
        new BitBound(java.lang.Math.getExponent(value) + 1L)
      }

      def signum: Int =
        if (value < 0D) -1
        else if (value > 0D) 1
        else 0

      // The double's exact binary value is expanded, then rounded to `digits`
      // decimal places.
      def toBigDecimal(digits: Int): JBigDecimal =
        new JBigDecimal(value).setScale(digits, RoundingMode.HALF_UP)
    }

    @SerialVersionUID(0L)
    case class ConstantBigDecimal(value: BigDecimal) extends Constant[BigDecimal] {
      def flagBits: Int = Flags.BigDecimalLeaf.bits

      /**
       * Upper bound on `|value|`, in bits (i.e. an exponent of 2).
       *
       * The magnitude is first rounded up to 4 significant digits, and the
       * natural logarithm of that is converted to base 2. Using the natural
       * logarithm directly would underestimate the bound for larger values.
       */
      def upperBound: BitBound = if (value.signum == 0) {
        new BitBound(0)
      } else {
        // We just need a couple of digits, really.
        val mc = new MathContext(4, RoundingMode.UP)
        val naturalLog = log(value.abs(mc))
        // log2(x) = ln(x) / ln(2)
        new BitBound(ceil(naturalLog.toDouble / log(2D)).toLong)
      }

      def signum: Int = value.signum

      def toBigDecimal(digits: Int): JBigDecimal =
        value.bigDecimal.setScale(digits, RoundingMode.HALF_UP)
    }

    @SerialVersionUID(0L)
    case class ConstantRational(value: Rational) extends Constant[Rational] {
      def flagBits: Int = Flags.RationalLeaf.bits

      // Bit bound derived from the bit lengths of numerator and denominator.
      def upperBound: BitBound = {
        val numeratorBits = value.numerator.abs.bitLength
        val denominatorBits = value.denominator.bitLength
        new BitBound(numeratorBits - denominatorBits + 1)
      }

      def signum: Int = value.signum

      // Long division to `digits` decimal places, truncating toward zero.
      def toBigDecimal(digits: Int): JBigDecimal = {
        val numerator = new JBigDecimal(value.numerator.toBigInteger)
        val denominator = new JBigDecimal(value.denominator.toBigInteger)
        numerator.divide(denominator, digits, RoundingMode.DOWN)
      }
    }

    @SerialVersionUID(0L)
    case class ConstantRoot(poly: Polynomial[BigInt], i: Int, lb: Rational, ub: Rational) extends Constant[Polynomial[BigInt]] {
      def value: Polynomial[BigInt] = poly

      def flagBits: Int = Flags.IsRadical.bits

      // Bounds the root by whichever interval end point is used: `ub` when it
      // is positive, `lb` otherwise.
      def upperBound: BitBound = {
        val bits =
          if (ub.signum > 0) ub.numerator.bitLength - ub.denominator.bitLength + 1
          else lb.numerator.abs.bitLength - lb.denominator.bitLength + 1
        new BitBound(bits)
      }

      // The sign is read off the interval end points: `lb`'s sign unless it
      // is 0, in which case `ub`'s.
      def signum: Int = lb.signum match {
        case 0 => ub.signum
        case s => s
      }

      // Holds the most recent refinement of the root's interval.
      private val refinement: AtomicReference[BigDecimalRootRefinement] = {
        val decimalPoly = poly.map { n => new BigDecimal(new JBigDecimal(n.bigInteger), MathContext.UNLIMITED) }
        new AtomicReference(BigDecimalRootRefinement(decimalPoly, lb, ub))
      }

      // Refines the stored refinement to `digits`, stores the result back and
      // returns its approximate value.
      def toBigDecimal(digits: Int): JBigDecimal = {
        val refined = refinement.get.refine(digits)
        refinement.set(refined)
        refined.approximateValue
      }

      // Coefficients of the polynomial's max and min terms.
      def lead: BigInt = poly.maxTerm.coeff
      def tail: BigInt = poly.minTerm.coeff
    }

    @SerialVersionUID(0L)
    case class Neg(sub: Expr) extends UnaryExpr {
      // Flags and magnitude bound are inherited from the operand unchanged.
      def flagBits: Int = sub.flags.bits
      def upperBound: BitBound = sub.upperBound
      def signum: Int = -sub.signum
      def toBigDecimal(digits: Int): JBigDecimal = {
        val operand = sub.toBigDecimal(digits)
        operand.negate()
      }
    }

    @SerialVersionUID(0L)
    sealed abstract class AddOrSubExpr extends BinaryExpr {
      /** One bit more than the larger of the two operands' upper bounds. */
      def upperBound: BitBound =
        new BitBound(max(lhs.upperBound.bitBound, rhs.upperBound.bitBound) + 1)

      /**
       * The sign of the sum/difference, computed once and cached.
       *
       * @throws java.lang.ArithmeticException if the precision needed to
       *         decide the sign exceeds what an `Int` digit count can express
       */
      lazy val signum: Int = {
        // Number of digits needed to reach past the separation bound.
        val maxDigits = separationBound.decimalDigits + 1

        // An adaptive algorithm to find the sign. Rather than just compute
        // this number to `maxDigits` precision, we start with a smaller
        // precision and keep adding digits until we get one that isn't 0.
        @tailrec def loop(digits0: Long): Int = {
          val digits = min(digits0, min(maxDigits, Int.MaxValue)).toInt
          val approx = toBigDecimal(digits + 1).setScale(digits, RoundingMode.DOWN)
          if (approx.signum != 0 || digits >= maxDigits) {
            // Either a non-zero digit was found, or `maxDigits` was reached;
            // in both cases the approximation's sign is returned.
            approx.signum
          } else if (digits == Int.MaxValue) {
            throw new ArithmeticException("required precision to calculate sign is too high")
          } else {
            loop(2 * digits0)
          }
        }

        loop(4)
      }

      /**
       * Evaluates both operands with one extra digit, adds or subtracts them
       * depending on the concrete node type, and truncates to `digits`.
       */
      def toBigDecimal(digits: Int): JBigDecimal = {
        val lValue = lhs.toBigDecimal(digits + 1)
        val rValue = rhs.toBigDecimal(digits + 1)
        val sum = this match {
          case (_: Add) => lValue.add(rValue)
          case (_: Sub) => lValue.subtract(rValue)
        }
        sum.setScale(digits, RoundingMode.DOWN)
      }
    }

    /** The sum `lhs + rhs`; all behaviour comes from [[AddOrSubExpr]]. */
    @SerialVersionUID(0L)
    case class Add(lhs: Expr, rhs: Expr) extends AddOrSubExpr

    /** The difference `lhs - rhs`; all behaviour comes from [[AddOrSubExpr]]. */
    @SerialVersionUID(0L)
    case class Sub(lhs: Expr, rhs: Expr) extends AddOrSubExpr

    /** The product `lhs * rhs`. */
    @SerialVersionUID(0L)
    case class Mul(lhs: Expr, rhs: Expr) extends BinaryExpr {
      // Bit bounds are logarithmic, so the bound of a product is the sum.
      def upperBound: BitBound = lhs.upperBound + rhs.upperBound

      def signum: Int = lhs.signum * rhs.signum

      /**
       * Approximates each operand with enough extra digits to absorb the
       * magnitude of the other one, multiplies, and truncates to `digits`.
       *
       * @throws IllegalArgumentException if either operand would need
       *                                  `Int.MaxValue` or more digits
       */
      def toBigDecimal(digits: Int): JBigDecimal = {
        val lDigits = checked(rhs.upperBound.decimalDigits + digits + 1)
        val rDigits = checked(lhs.upperBound.decimalDigits + digits + 1)
        if (lDigits < Int.MaxValue && rDigits < Int.MaxValue) {
          val product = lhs.toBigDecimal(lDigits.toInt).multiply(rhs.toBigDecimal(rDigits.toInt))
          product.setScale(digits, RoundingMode.DOWN)
        } else {
          throw new IllegalArgumentException("required precision is too high")
        }
      }
    }

    /** The quotient `lhs / rhs`. */
    @SerialVersionUID(0L)
    case class Div(lhs: Expr, rhs: Expr) extends BinaryExpr {
      def upperBound: BitBound = lhs.upperBound - rhs.lowerBound

      /**
       * Product of the operand signs.
       *
       * @throws ArithmeticException if the divisor's sign is 0
       */
      def signum: Int = rhs.signum match {
        case 0 => throw new ArithmeticException("divide by 0")
        case divisorSign => lhs.signum * divisorSign
      }

      /**
       * Approximates numerator and denominator to precisions derived from the
       * divisor's lower bound and the dividend's upper bound, divides, and
       * truncates to `digits`. All of the arithmetic is overflow-checked.
       *
       * @throws IllegalArgumentException if either operand would need
       *                                  `Int.MaxValue` or more digits
       */
      def toBigDecimal(digits: Int): JBigDecimal = checked {
        val divisorFloor = rhs.lowerBound.decimalDigits
        val lDigits = digits + 2 - divisorFloor
        val rDigits = max(
          1 - divisorFloor,
          digits + 4 - 2 * divisorFloor + lhs.upperBound.decimalDigits
        )
        if (lDigits < Int.MaxValue && rDigits < Int.MaxValue) {
          val numerator = lhs.toBigDecimal(lDigits.toInt)
          val denominator = rhs.toBigDecimal(rDigits.toInt)
          numerator
            .divide(denominator, digits + 1, RoundingMode.DOWN)
            .setScale(digits, RoundingMode.DOWN)
        } else {
          throw new IllegalArgumentException("required precision is too high")
        }
      }
    }

    /** The `k`-th root of the sub-expression `sub`. */
    @SerialVersionUID(0L)
    case class KRoot(sub: Expr, k: Int) extends UnaryExpr {
      val flagBits: Int = (sub.flags | Flags.IsRadical).bits

      /**
       * If `|sub| <= 2^b` then `|sub|^(1/k) <= 2^(b/k)`, so the bit bound is
       * `b / k` rounded up. Written as `(b + (k - 1)) / k`, which for `k == 2`
       * is the familiar `(b + 1) / 2`. Dividing by a fixed 2 regardless of `k`
       * would produce a bound that is too small for higher roots of values
       * below 1 (negative `b`).
       */
      def upperBound: BitBound = (sub.upperBound + (k - 1)) / k

      /**
       * The sign of the radicand, which must not be negative.
       *
       * @throws ArithmeticException if `sub` is negative
       */
      def signum: Int = {
        val s = sub.signum
        if (s >= 0) s
        else throw new ArithmeticException(s"$k-root of negative number")
      }

      /**
       * Approximates `sub` to a working precision derived from `digits` and
       * the lower bound of `sub`, then takes the `k`-th root of that
       * approximation, truncated to `digits` digits.
       *
       * @throws IllegalArgumentException if the working precision would be
       *                                  `Int.MaxValue` or more digits
       */
      def toBigDecimal(digits: Int): JBigDecimal = {
        val digits0 = max(
          checked(digits + 1),
          checked(1 - (sub.lowerBound.decimalDigits + 1) / 2)
        )
        if (digits0 >= Int.MaxValue) {
          throw new IllegalArgumentException("required precision is too high")
        } else {
          val value = sub.toBigDecimal(digits0.toInt)
          Algebraic.nroot(value, k, digits, RoundingMode.DOWN)
        }
      }

      // To avoid multiple traversals during degreeBound, we cache the hashCode
      // for KRoots.
      override lazy val hashCode: Int =
        sub.hashCode * 23 + k * 29 + 13
    }

    /** The sub-expression `sub` raised to the integer power `k`, with `k > 1`. */
    @SerialVersionUID(0L)
    case class Pow(sub: Expr, k: Int) extends UnaryExpr {
      require(k > 1)

      def flagBits: Int = sub.flags.bits

      // Bit bounds are logarithmic, so raising to the k-th power scales them by k.
      def upperBound: BitBound = sub.upperBound * k

      /**
       * Sign of the power: 0 for a zero base, positive for an even exponent,
       * otherwise the sign of the base.
       */
      def signum: Int = sub.signum match {
        case 0 if k < 0 => throw new ArithmeticException("divide by 0")
        case 0 if k == 0 => throw new ArithmeticException("indeterminate")
        case 0 => 0
        case s if k % 2 == 0 => if (s < 0) 1 else s
        case s => s
      }

      /**
       * Approximates `sub` with extra digits proportional to `ceil(lg2(k))`
       * and to the decimal size of its upper bound, then raises that
       * approximation to the `k`-th power.
       *
       * @throws IllegalArgumentException if the working precision would be
       *                                  `Int.MaxValue` or more digits
       */
      def toBigDecimal(digits: Int): JBigDecimal = {
        // We could possibly do better here. Investigate.
        val height = 32 - java.lang.Integer.numberOfLeadingZeros(k - 1) // ceil(lg2(k))
        val maxDigits = checked(digits + height * (1 + sub.upperBound.decimalDigits))
        if (maxDigits >= Int.MaxValue)
          throw new IllegalArgumentException("required precision is too high")
        sub.toBigDecimal(maxDigits.toInt).pow(k)
      }
    }
  }

  /**
   * A bound expressed as a power of 2: depending on context, a `BitBound`
   * stands for either `2^bitBound` (an upper bound) or `2^-bitBound` (a lower
   * bound). It is a value class around the `Long` exponent, and its
   * arithmetic operates directly on that exponent.
   */
  final class BitBound(val bitBound: Long) extends AnyVal {
    import BitBound.bitsToDecimalDigits

    /**
     * The minimum number of absolute decimal digits needed to represent this
     * bound.
     */
    def decimalDigits: Long = bitsToDecimalDigits(bitBound)

    def unary_- : BitBound = new BitBound(-bitBound)

    // Arithmetic against another bound.
    def +(that: BitBound): BitBound = new BitBound(bitBound + that.bitBound)
    def -(that: BitBound): BitBound = new BitBound(bitBound - that.bitBound)
    def *(that: BitBound): BitBound = new BitBound(bitBound * that.bitBound)
    def /(that: BitBound): BitBound = new BitBound(bitBound / that.bitBound)

    // Arithmetic against a plain number of bits.
    def +(rhs: Int): BitBound = new BitBound(bitBound + rhs)
    def -(rhs: Int): BitBound = new BitBound(bitBound - rhs)
    def *(rhs: Int): BitBound = new BitBound(bitBound * rhs)
    def /(rhs: Int): BitBound = new BitBound(bitBound / rhs)

    /** The bound with the smaller exponent. */
    def min(that: BitBound): BitBound =
      if (that.bitBound <= bitBound) that else this

    override def toString: String = s"BitBound($bitBound)"
  }

  /** Companion holding the bits-to-decimal-digits conversion used by [[BitBound]]. */
  object BitBound {
    // 2^-52, the gap between 1.0 and the next representable Double.
    private val Epsilon: Double = 2.220446049250313E-16

    // Slightly greater than 1; multiplying by it inflates the conversion
    // factor below by a few units of floating-point error.
    private val FudgeFactor: Double = 1D + 4D * Epsilon

    // Conversion factor from a number of bits to a number of decimal digits
    // (log base 10 of 2), inflated by FudgeFactor.
    private val lg2ToLg10: Double = log(2, 10) * FudgeFactor

    // Converts a bit count to a decimal digit count, rounding up.
    private def bitsToDecimalDigits(n: Long): Long =
      ceil(n.toDouble * lg2ToLg10).toLong

    /** Creates a bound of `n` bits. */
    final def apply(n: Int): BitBound = new BitBound(n)
  }

  /**
   * Returns a number that is approximately equal to `x.pow(1/n)`. This number
   * is useful as initial values in converging n-root algorithms, but not as a
   * general purpose n-root algorithm. There are no guarantees about the
   * accuracy here.
   *
   * @param x the value whose root is approximated
   * @param n the degree of the root; degrees above 306 are approximated by
   *          the 306-th root
   * @return  the approximation, rounded to `MathContext.DECIMAL64`
   */
  final def nrootApprox(x: JBigDecimal, n: Int): JBigDecimal = {
    // Essentially, we'd like to just find `x.doubleValue.pow(1D / n)`, but x
    // may not be approximable as a finite Double (eg. exponent is larger than
    // 308). So, we basically treat x as a number `a*10^(i+j)`, where
    // `a*10^i` is approximable as a Double and `j % n == 0`. Then, we can
    // approximate the n-th root as `pow(a*10^i, 1 / n) * 10^(j/n)`.

    // If n > ~308, then we could end up with an "approximate" value that is
    // an Infinity, which is no good. So, we approximate all roots > 306 with
    // 306-th root.
    val k = min(n, 306)
    // We need to ensure that the scale of our approximate number leaves `j`
    // evenly divisible by k. So, we start by calculating the scale required to
    // put the decimal place after the first digit
    val width = (ceil(x.unscaledValue.bitLength * log(2) / log(10)) - 1).toInt
    // We then add in (x.scale - width) % k to our initial scale so that the
    // remaining exponent is divisible by k.
    val safeWidth = width + (x.scale - width) % k
    // The magnitude of x, rescaled so that it fits in a Double.
    val approx = new JBigDecimal(x.unscaledValue.abs, safeWidth).doubleValue
    // Take the root in Double arithmetic, restore the sign of x, and put back
    // the power of ten that was factored out (divided by k).
    new JBigDecimal(x.signum * pow(approx, 1D / k))
      .scaleByPowerOfTen(-(x.scale - safeWidth) / k)
      .round(MathContext.DECIMAL64)
  }

  /**
   * Approximates the k-th root of `signedValue` with Newton's method. The
   * stopping tolerance is not fixed: before every step `getEps` is given the
   * current approximation and returns a digit count `d`, and the tolerance
   * for that step is `pow(10, -d)`. A constant `getEps` yields an absolute
   * approximation; one that inspects the current approximation yields a
   * relative approximation.
   *
   * The root is computed on the absolute value and the sign of `signedValue`
   * is restored at the end; a zero input returns zero immediately.
   */
  private final def nroot(signedValue: JBigDecimal, k: Int)(getEps: JBigDecimal => Int): JBigDecimal =
    if (signedValue.signum == 0) {
      JBigDecimal.ZERO
    } else {
      val magnitude = signedValue.abs
      val degree = new JBigDecimal(k)

      @tailrec def refine(guess: JBigDecimal, lastDigits: Int, lastEps: JBigDecimal): JBigDecimal = {
        val digits = getEps(guess)
        // Re-use the previous tolerance when the digit count did not change.
        val eps =
          if (digits == lastDigits) lastEps
          else JBigDecimal.ONE.movePointLeft(digits)
        // Newton step: (magnitude / guess^(k-1) - guess) / k
        val step = magnitude
          .divide(guess.pow(k - 1), digits, RoundingMode.HALF_UP)
          .subtract(guess)
          .divide(degree, digits, RoundingMode.HALF_UP)
        if (step.abs.compareTo(eps) > 0) refine(guess.add(step), digits, eps)
        else guess
      }

      val root = refine(nrootApprox(magnitude, k), Int.MinValue, JBigDecimal.ZERO)
      if (signedValue.signum < 0) root.negate else root
    }

  // Log base 10 of 2: multiplying a bit length by this converts it to a
  // (fractional) number of decimal digits.
  private val bits2dec: Double = log(2, 10)

  /**
   * Relative approximation of the n-th root of `value`, to the precision
   * given by `mc`. The rounding mode of `mc` is applied only when chopping
   * off the surplus digits of the approximation, so the result may be
   * inaccurate.
   */
  final def nroot(value: JBigDecimal, n: Int, mc: MathContext): JBigDecimal = {
    val unrounded = nroot(value, n) { approx =>
      // Estimated number of decimal digits in the unscaled value of `approx`.
      val approxDecimalWidth = ceil(approx.unscaledValue.bitLength * bits2dec).toInt
      approx.scale - approxDecimalWidth + mc.getPrecision + 1
    }
    unrounded.round(mc)
  }

  /**
   * Absolute approximation of the n-th root of `value`, to `scale` digits
   * past the decimal point. `roundingMode` is applied only when chopping off
   * the surplus digit of the approximation, so the result may be inaccurate.
   */
  final def nroot(value: JBigDecimal, n: Int, scale: Int, roundingMode: RoundingMode): JBigDecimal = {
    // Work with one digit more than requested, then cut it off.
    val workingDigits = scale + 1
    val root = nroot(value, n)(_ => workingDigits)
    root.setScale(scale, roundingMode)
  }

  /** Orders `java.math.BigDecimal` values through their own `compareTo`. */
  private implicit val JBigDecimalOrder: Order[JBigDecimal] = new Order[JBigDecimal] {
    def compare(x: JBigDecimal, y: JBigDecimal): Int = x.compareTo(y)
  }

  /**
   * Rounds the approximation `approx` of the Algebraic number `exact` to
   * `scale` digits using rounding mode `mode`, consulting `exact` whenever
   * the approximation alone cannot decide the result. For example, if `exact`
   * represents 0.15, the mode is `HALF_UP` and the scale is 1, the result is
   * guaranteed to be 0.2.
   *
   * @param exact  the exact value, used as reference in tricky cases
   * @param approx the approximate value to round
   * @param scale  the scale of the result
   * @param mode   the rounding mode to apply
   */
  private def roundExact(exact: Algebraic, approx: JBigDecimal, scale: Int, mode: RoundingMode): JBigDecimal = {
    import RoundingMode.{ CEILING, FLOOR, UP }

    approx.signum match {
      case 0 =>
        // The approximation is 0; only modes that round away from zero in
        // the direction of the exact value can produce a non-zero result.
        mode match {
          case UP | CEILING if exact.signum > 0 =>
            new JBigDecimal(BigInteger.ONE, scale)
          case UP | FLOOR if exact.signum < 0 =>
            new JBigDecimal(BigInteger.ONE.negate, scale)
          case _ =>
            approx.setScale(scale, RoundingMode.DOWN)
        }

      case sign if sign > 0 =>
        roundPositive(exact, approx, scale, mode)

      case _ =>
        // Round the mirrored (positive) value and negate the result; the two
        // directed modes swap roles under negation.
        val mirroredMode = mode match {
          case CEILING => FLOOR
          case FLOOR => CEILING
          case _ => mode
        }
        roundPositive(-exact, approx.abs, scale, mirroredMode).negate()
    }
  }

  /**
   * Rounds `approx` to `scale` digits under `mode`, falling back to a
   * comparison against `exact` whenever the dropped digits are too close to a
   * rounding boundary to decide from the approximation alone. `roundExact`
   * only calls this with a positive `approx`.
   *
   * @param exact  the exact value that `approx` approximates
   * @param approx the approximate value to round
   * @param scale  the scale of the result
   * @param mode   the rounding mode to apply
   */
  private def roundPositive(exact: Algebraic, approx: JBigDecimal, scale: Int, mode: RoundingMode): JBigDecimal = {
    import RoundingMode.{ CEILING, FLOOR, DOWN, UP, HALF_DOWN, HALF_UP, HALF_EVEN, UNNECESSARY }

    // Number of digits `approx` carries beyond the requested scale.
    val cutoff = approx.scale - scale
    if (cutoff == 0) {
      // Nothing to do here.
      approx
    } else if (cutoff < 0) {
      // Just add some 0s and we're done!
      approx.setScale(scale, RoundingMode.DOWN)
    } else if (cutoff > 18) {
      // We'd like to work with Long arithmetic, if possible. Our rounding is
      // exact anyways, so it doesn't hurt to remove some digits.
      roundPositive(exact, approx.setScale(scale + 18, RoundingMode.DOWN), scale, mode)
    } else {
      // 10^cutoff: dividing the unscaled value by it separates the digits we
      // keep (quotient) from the digits we drop (remainder).
      val unscale = spire.math.pow(10L, cutoff.toLong)
      val Array(truncatedUnscaledValue, bigRemainder) =
        approx
          .unscaledValue
          .divideAndRemainder(BigInteger.valueOf(unscale))
      val truncated = new JBigDecimal(truncatedUnscaledValue, scale)
      // One unit in the last kept digit.
      def epsilon: JBigDecimal = new JBigDecimal(BigInteger.ONE, scale)
      val remainder = bigRemainder.longValue
      val rounded = mode match {
        case UNNECESSARY =>
          // Returned as-is; no check is made that the dropped digits are zero.
          truncated

        case HALF_DOWN | HALF_UP | HALF_EVEN =>
          // The dropped digits are within one unit of the half-way point, so
          // the approximation cannot be trusted to pick a side.
          val dangerZoneStart = (unscale / 2) - 1
          val dangerZoneStop = dangerZoneStart + 2
          if (remainder >= dangerZoneStart && remainder <= dangerZoneStop) {
            // The exact half-way point: `truncated` followed by the digit 5.
            val splitter = BigDecimal(new JBigDecimal(
              truncatedUnscaledValue.multiply(BigInteger.TEN).add(BigInteger.valueOf(5)),
              scale + 1
            ))
            val cmp = exact compare Algebraic(splitter)
            val roundUp = (mode: @unchecked) match {
              case HALF_DOWN => cmp > 0
              case HALF_UP => cmp >= 0
              // On an exact tie, round up only if the last kept digit is odd.
              case HALF_EVEN => cmp > 0 || cmp == 0 && truncatedUnscaledValue.testBit(0)
            }
            if (roundUp) truncated.add(epsilon)
            else truncated
          } else if (remainder < dangerZoneStart) {
            truncated
          } else {
            truncated.add(epsilon)
          }

        case CEILING | UP =>
          // Stay on `truncated` only when the dropped digits are at most 1 and
          // the exact value is confirmed not to exceed `truncated`.
          if (remainder <= 1 && exact <= Algebraic(BigDecimal(truncated))) {
            truncated
          } else {
            truncated.add(epsilon)
          }

        case FLOOR | DOWN =>
          if (remainder <= 0) {
            // Nothing was dropped; the exact value may still lie just below
            // `truncated`.
            if (exact < Algebraic(BigDecimal(truncated))) {
              truncated.subtract(epsilon)
            } else {
              truncated
            }
          } else if (remainder >= (unscale - 1)) {
            // The dropped digits are within one unit of the next step up; the
            // exact value may already have reached it.
            val roundedUp = truncated.add(epsilon)
            if (exact >= Algebraic(BigDecimal(roundedUp))) {
              roundedUp
            } else {
              truncated
            }
          } else {
            truncated
          }
      }

      rounded
    }
  }

  // The limits of the Int and Long ranges, as BigIntegers.
  private val MaxIntValue: BigInteger = BigInteger.valueOf(Int.MaxValue.toLong)
  private val MinIntValue: BigInteger = BigInteger.valueOf(Int.MinValue.toLong)
  private val MaxLongValue: BigInteger = BigInteger.valueOf(Long.MaxValue)
  private val MinLongValue: BigInteger = BigInteger.valueOf(Long.MinValue)

  /**
   * A zero bound function, defined over an algebraic expression algebra.
   * Implementations map an expression to a per-node `Bound`; the class is
   * sealed, so all implementations live in this file.
   */
  sealed abstract class ZeroBoundFunction {

    /**
     * Some state that is computed for each node in the expression tree. This
     * state is typically memoized, to avoid recomputation.
     */
    type Bound

    /** Computes the `Bound` for the root node of `expr`. */
    def apply(expr: Algebraic.Expr): Bound
  }

  /**
   * An implementation of "A New Constructive Root Bound for Algebraic
   * Expressions" by Chen Li & Chee Yap.
   *
   * All quantities are tracked as `Long` bit counts rather than as exact
   * numbers, and the arithmetic on them is wrapped in `checked`.
   */
  @SerialVersionUID(0L)
  case object LiYap extends ZeroBoundFunction {
    import Expr._

    final case class Bound(
      /** Bound on the leading coefficient. */
      lc: Long,
      /** Bound on the trailing coefficient. */
      tc: Long,
      /** Bound on the measure. */
      measure: Long,
      /** Lower bound on the value. */
      lb: Long,
      /** Upper bound on the value. */
      ub: Long
    ) {
      /**
       * Combines the upper bound on the value, the leading coefficient bound
       * and `degreeBound` into a single bit count, with overflow checking.
       */
      def getBitBound(degreeBound: Long): Long = checked {
        ub * (degreeBound - 1) + lc
      }
    }

    /**
     * Computes the bound for `expr` from the bounds of its sub-expressions,
     * which are obtained through `getBound(this)`.
     */
    def apply(expr: Algebraic.Expr): Bound = checked {
      // Unfortunately, we must call degreeBound early, to avoid many redundant
      // traversals of the Expr tree. Getting this out of the way early on
      // means that we will traverse the tree once and populate the degreeBound
      // cache in all nodes right away. If we do it in a bottom up fashion,
      // then we risk terrible runtime behaviour.
      val degreeBound = expr.degreeBound
      expr match {
        // All constant kinds are converted to a Rational first.
        case ConstantLong(n) =>
          rational(Rational(n))

        case ConstantDouble(n) =>
          rational(Rational(n))

        case ConstantBigDecimal(n) =>
          rational(Rational(n))

        case ConstantRational(n) =>
          rational(n)

        case root @ ConstantRoot(poly, _, _, _) =>
          // Bound on the euclidean distance of the coefficients.
          val distBound = poly.terms.map { case Term(c, _) =>
            2L * c.bitLength
          }.qsum / 2L + 1L
          Bound(
            root.lead.bitLength + 1L,
            root.tail.bitLength + 1L,
            distBound,
            Roots.lowerBound(poly),
            Roots.upperBound(poly)
          )

        // Negation leaves every component of the bound unchanged.
        case Neg(sub) =>
          sub.getBound(this)

        case expr: AddOrSubExpr =>
          val lhsExpr = expr.lhs
          val rhsExpr = expr.rhs
          val lhs = lhsExpr.getBound(this)
          val rhs = rhsExpr.getBound(this)
          val lc = lhs.lc * rhsExpr.degreeBound + rhs.lc * lhsExpr.degreeBound
          // For sums, the trailing coefficient bound and the measure share the
          // same value, built from the operands' measures.
          val tc = lhs.measure * rhsExpr.degreeBound + rhs.measure * lhsExpr.degreeBound + 2 * degreeBound
          val measure = tc
          val ub = max(lhs.ub, rhs.ub) + 1
          val lb = max(-measure, -(ub * (degreeBound - 1) + lc))
          Bound(lc, tc, measure, lb, ub)

        case Mul(lhsExpr, rhsExpr) =>
          val lhs = lhsExpr.getBound(this)
          val rhs = rhsExpr.getBound(this)
          val lc = lhs.lc * rhsExpr.degreeBound + rhs.lc * lhsExpr.degreeBound
          val tc = lhs.tc * rhsExpr.degreeBound + rhs.tc * lhsExpr.degreeBound
          val measure = lhs.measure * rhsExpr.degreeBound + rhs.measure * lhsExpr.degreeBound
          val lb = lhs.lb + rhs.lb
          val ub = lhs.ub + rhs.ub
          Bound(lc, tc, measure, lb, ub)

        case Div(lhsExpr, rhsExpr) =>
          val lhs = lhsExpr.getBound(this)
          val rhs = rhsExpr.getBound(this)
          // Compared with Mul, the divisor's lc and tc swap roles, as do its
          // lower and upper value bounds.
          val lc = lhs.lc * rhsExpr.degreeBound + rhs.tc * lhsExpr.degreeBound
          val tc = lhs.tc * rhsExpr.degreeBound + rhs.lc * lhsExpr.degreeBound
          val measure = lhs.measure * rhsExpr.degreeBound + rhs.measure * lhsExpr.degreeBound
          val lb = lhs.lb - rhs.ub
          val ub = lhs.ub - rhs.lb
          Bound(lc, tc, measure, lb, ub)

        case KRoot(subExpr, k) =>
          val sub = subExpr.getBound(this)
          val lb = sub.lb / k
          // sub.ub / k, plus one whenever the division leaves a remainder.
          val ub = if (sub.ub % k == 0) (sub.ub / k)
                   else ((sub.ub / k) + 1)
          Bound(sub.lc, sub.tc, sub.measure, lb, ub)

        case Pow(subExpr, k) =>
          // Every component scales linearly with the exponent.
          val sub = subExpr.getBound(this)
          Bound(sub.lc * k, sub.tc * k, sub.measure * k, sub.lb * k, sub.ub * k)
      }
    }

    /**
     * Bound for a rational constant, derived from the bit lengths of its
     * numerator and denominator.
     */
    private def rational(n: Rational): Bound = {
      // TODO: We can do better here. The + 1 isn't always needed in a & b.
      // Also, the upper and lower bounds could be much tighter if we actually
      // partially perform the division.
      val a = n.numerator.abs.bitLength + 1
      if (n.denominator == BigInt(1)) {
        // Integer: no denominator contribution.
        Bound(0, a, a, a - 1, a)
      } else {
        val b = n.denominator.bitLength + 1
        Bound(b, a, max(a, b), a - b - 1, a - b + 1)
      }
    }
  }

  /**
   * An implementation of "A Separation Bound for Real Algebraic Expressions",
   * by Burnikel, Funke, Mehlhorn, Schirra, and Schmitt. This provides a good
   * [[ZeroBoundFunction]] for use in sign tests.
   *
   * Unlike the paper, we use log-arithmetic instead of working with exact,
   * big integer values. This means our bound isn't technically as good as it
   * could be, but we save the cost of working with big integer arithmetic. We
   * also perform all log arithmetic using `Long`s and check for overflow
   * (throwing `ArithmeticException`s when detected). In practice we shouldn't
   * hit this limit, but in case we do, we prefer to throw over failing
   * silently.
   *
   * Because the values are logarithms, the paper's products become sums and
   * its powers become products in the rules below.
   */
  @SerialVersionUID(0L)
  case object BFMSS extends ZeroBoundFunction {
    import Expr._

    /** Our state that we store, per node: the logs of the paper's `l` and `u`. */
    final case class Bound(l: Long, u: Long) {
      /** Combines `l`, `u` and `degreeBound` into one bit count, with overflow checking. */
      def getBitBound(degreeBound: Long): Long = checked {
        l + u * (degreeBound - 1)
      }
    }

    /**
     * Computes the bound for `expr` from the bounds of its sub-expressions,
     * which are obtained through `getBound(this)`.
     */
    def apply(expr: Algebraic.Expr): Bound = expr match {
      case ConstantLong(n) => integer(n)
      case ConstantDouble(n) => rational(n)
      case ConstantBigDecimal(n) => rational(n)
      case ConstantRational(n) => rational(n)
      case root @ ConstantRoot(poly, _, _, _) =>
        Bound(root.lead.bitLength + 1, Roots.upperBound(poly))
      case Neg(sub) => sub.getBound(this)
      // Addition and subtraction share the same rule.
      case Add(lhs, rhs) => add(lhs.getBound(this), rhs.getBound(this))
      case Sub(lhs, rhs) => add(lhs.getBound(this), rhs.getBound(this))
      case Mul(lhs, rhs) => mul(lhs.getBound(this), rhs.getBound(this))
      case Div(lhs, rhs) => div(lhs.getBound(this), rhs.getBound(this))
      case KRoot(sub, k) => nroot(sub.getBound(this), k)
      case Pow(sub, k) => pow(sub.getBound(this), k)
    }

    private def integer(n: Long): Bound =
      integer(BigInt(n))

    // An integer has l = 1 in the paper (0 in log form) and u = |n|.
    private def integer(n: SafeLong): Bound =
      Bound(0, n.abs.bitLength + 1)

    private def rational(n: Double): Bound =
      rational(BigDecimal(n))

    private def rational(n: BigDecimal): Bound =
      rational(Rational(n))

    // A rational is treated as the quotient of two integers.
    private def rational(n: Rational): Bound =
      div(integer(n.numerator), integer(n.denominator))

    // We're not being fair to the BFMSS bound here. We're really just
    // setting a bound on the max value. However, the alternative would
    // require us to work outside of log arithmetic.
    private def add(lhs: Bound, rhs: Bound): Bound = checked {
      Bound(
        lhs.l + rhs.l,
        math.max(lhs.u + rhs.l, lhs.l + rhs.u) + 1
      )
    }

    private def mul(lhs: Bound, rhs: Bound): Bound = checked {
      Bound(
        lhs.l + rhs.l,
        lhs.u + rhs.u
      )
    }

    // Division multiplies by the reciprocal, so the divisor's l and u swap.
    private def div(lhs: Bound, rhs: Bound): Bound = checked {
      Bound(
        lhs.l + rhs.u,
        lhs.u + rhs.l
      )
    }

    // The paper's rule for a k-th root of A:
    //   u(A) <  l(A):  u = u(A),  l = (u(A)^(k-1) * l(A))^(1/k)
    //   u(A) >= l(A):  l = l(A),  u = (u(A) * l(A)^(k-1))^(1/k)
    // In log form both roots become a weighted sum divided by k.
    private def nroot(sub: Bound, k: Int): Bound = checked {
      if (sub.u < sub.l) {
        Bound(
          (sub.l + (k - 1) * sub.u) / k,
          sub.u
        )
      } else {
        Bound(
          sub.l,
          (sub.u + (k - 1) * sub.l) / k
        )
      }
    }

    /**
     * Scales both components by `k`, using overflow-checked additions.
     *
     * @throws IllegalArgumentException if `k` is 0 or negative
     */
    private def pow(sub: Bound, k: Int): Bound = {
      // Computes `acc * k + extra` by repeated doubling.
      @tailrec def sum(acc: Long, k: Int, extra: Long): Long =
        if (k == 1) {
          checked(acc + extra)
        } else {
          val x =
            if ((k & 1) == 1) checked(acc + extra)
            else extra
          sum(checked(acc + acc), k >>> 1, x)
        }

      if (k > 1) {
        Bound(
          sum(sub.l, k - 1, sub.l),
          sum(sub.u, k - 1, sub.u)
        )
      } else if (k == 1) {
        sub
      } else if (k == 0) {
        throw new IllegalArgumentException("exponent cannot be 0")
      } else {
        throw new IllegalArgumentException("exponent cannot be negative")
      }
    }
  }
}

/**
 * Implicit instances for [[Algebraic]]: its algebra and its number tag.
 *
 * Both implicits carry explicit types, identical to the types that would
 * otherwise be inferred, so that implicit resolution does not depend on
 * inference order.
 */
trait AlgebraicInstances {
  implicit final val AlgebraicAlgebra: AlgebraicAlgebra = new AlgebraicAlgebra

  import NumberTag._
  implicit final val AlgebraicTag: LargeTag[Algebraic] = new LargeTag[Algebraic](Exact, Algebraic(0))
}

/**
 * `Field` and `NRoot` operations for [[Algebraic]]. Every operation delegates
 * to the corresponding method on `Algebraic`, except `fpow`, which is not
 * supported.
 */
private[math] trait AlgebraicIsFieldWithNRoot extends Field[Algebraic] with NRoot[Algebraic] {
  // Constants
  def zero: Algebraic = Algebraic.Zero
  def one: Algebraic = Algebraic.One

  // Ring operations
  def plus(a: Algebraic, b: Algebraic): Algebraic = a + b
  def negate(a: Algebraic): Algebraic = -a
  override def minus(a: Algebraic, b: Algebraic): Algebraic = a - b
  override def times(a: Algebraic, b: Algebraic): Algebraic = a * b
  override def pow(a: Algebraic, b: Int): Algebraic = a.pow(b)

  // Division-related operations
  def quot(a: Algebraic, b: Algebraic): Algebraic = a /~ b
  def mod(a: Algebraic, b: Algebraic): Algebraic = a % b
  def gcd(a: Algebraic, b: Algebraic): Algebraic = euclid(a, b)(Eq[Algebraic])
  def div(a: Algebraic, b: Algebraic): Algebraic = a / b

  // Roots and fractional powers
  def nroot(a: Algebraic, k: Int): Algebraic = a.nroot(k)
  def fpow(a: Algebraic, b: Algebraic): Algebraic = {
    throw new UnsupportedOperationException("unsupported operation")
  }

  // Conversions
  override def fromInt(n: Int): Algebraic = Algebraic(n)
  override def fromDouble(n: Double): Algebraic = Algebraic(n)
}

/**
 * `IsAlgebraic` operations for [[Algebraic]]: conversions, rounding to whole
 * numbers, sign queries and ordering. Rounding goes through `toBigDecimal`
 * with a scale of 0 and the matching rounding mode.
 */
private[math] trait AlgebraicIsReal extends IsAlgebraic[Algebraic] {
  // Conversions
  def toDouble(x: Algebraic): Double = x.toDouble
  def toAlgebraic(x: Algebraic): Algebraic = x

  // Rounding to whole numbers
  def ceil(a: Algebraic): Algebraic = {
    val rounded = a.toBigDecimal(0, RoundingMode.CEILING)
    Algebraic(rounded)
  }
  def floor(a: Algebraic): Algebraic = {
    val rounded = a.toBigDecimal(0, RoundingMode.FLOOR)
    Algebraic(rounded)
  }
  def round(a: Algebraic): Algebraic = {
    val rounded = a.toBigDecimal(0, RoundingMode.HALF_EVEN)
    Algebraic(rounded)
  }
  def isWhole(a: Algebraic): Boolean = a.isWhole

  // Sign
  override def sign(a: Algebraic): Sign = a.sign
  def signum(a: Algebraic): Int = a.signum
  def abs(a: Algebraic): Algebraic = a.abs

  // Ordering
  def compare(x: Algebraic, y: Algebraic): Int = x.compare(y)
  override def eqv(x: Algebraic, y: Algebraic): Boolean = (x compare y) == 0
  override def neqv(x: Algebraic, y: Algebraic): Boolean = (x compare y) != 0
}

/**
 * The serializable algebra for [[Algebraic]], combining the field and n-root
 * operations with the ordering, sign and rounding operations.
 */
@SerialVersionUID(1L)
class AlgebraicAlgebra extends AlgebraicIsFieldWithNRoot with AlgebraicIsReal with Serializable




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy