All Downloads are FREE. Search and download functionalities are using the official Maven repository.

spire.math.poly.BigDecimalRootRefinement.scala Maven / Gradle / Ivy

package spire.math
package poly

import scala.annotation.tailrec

import java.math.{ BigDecimal => JBigDecimal, RoundingMode, MathContext }

import spire.algebra._
import spire.std.bigDecimal._
import spire.syntax.signed._

/**
 * A polynomial root together with what is currently known about its location.
 *
 * @param context       the polynomial and the precision rules in force
 * @param approximation the current knowledge about where the root lies
 */
case class BigDecimalRootRefinement(
  context: BigDecimalRootRefinement.ApproximationContext,
  approximation: BigDecimalRootRefinement.Approximation
) {
  import BigDecimalRootRefinement._

  /**
   * A single decimal standing in for the root. An exact root and the upper
   * bound of a `BoundedLeft` are passed through `context.floor`; in every
   * other case the lower bound is passed through `context.ceil`.
   */
  def approximateValue: JBigDecimal =
    approximation match {
      case ExactRoot(exact) => context.floor(exact)
      case BoundedLeft(_, upper) => context.floor(upper)
      case Bounded(lower, _, _, _, _) => context.ceil(lower)
      case BoundedRight(lower, _) => context.ceil(lower)
      case Unbounded(lower, _) => context.ceil(lower)
    }

  /**
   * Refines the root under an `AbsoluteContext` with the given scale. Returns
   * `this` when the current context is already an `AbsoluteContext` whose
   * scale is at least `scale`.
   */
  def refine(scale: Int): BigDecimalRootRefinement = {
    def rebuilt: BigDecimalRootRefinement = {
      val target = AbsoluteContext(context.poly, scale)
      BigDecimalRootRefinement(target, refineApproximation(target))
    }

    context match {
      case AbsoluteContext(_, current) if current >= scale => this
      case _ => rebuilt
    }
  }

  /**
   * Refines the root under a `RelativeContext` built from `mc`. Returns `this`
   * when the current context is already a `RelativeContext` whose precision is
   * at least that of `mc`.
   */
  def refine(mc: MathContext): BigDecimalRootRefinement = {
    def rebuilt: BigDecimalRootRefinement = {
      val target = RelativeContext(context.poly, mc)
      BigDecimalRootRefinement(target, refineApproximation(target))
    }

    context match {
      case RelativeContext(_, current) if current.getPrecision >= mc.getPrecision => this
      case _ => rebuilt
    }
  }

  /**
   * Re-runs the refinement of the current approximation under `ctx`.
   *
   * An exact root is returned unchanged. A `Bounded` approximation resumes
   * from its stored endpoints, values and `n`. The remaining cases first
   * derive decimal endpoints (`ctx.ceil` for a rational lower bound,
   * `ctx.floor` for a rational upper bound) and exact rational versions of
   * any decimal endpoint, then start over from those.
   */
  def refineApproximation(ctx: ApproximationContext): Approximation = {
    // Exact conversion of a decimal endpoint to a Rational.
    def toRational(x: JBigDecimal): Rational =
      Rational(new BigDecimal(x, MathContext.UNLIMITED))

    approximation match {
      case exact: ExactRoot =>
        exact
      case Bounded(lower, lowerValue, upper, upperValue, digits) =>
        QIR(ctx, lower, lowerValue, upper, upperValue, digits)
      case BoundedLeft(lower, upperApprox) =>
        val lowerApprox = ctx.ceil(lower)
        QIR(ctx, lower, toRational(upperApprox), lowerApprox, upperApprox)
      case BoundedRight(lowerApprox, upper) =>
        val lower = toRational(lowerApprox)
        QIR(ctx, lower, upper, lowerApprox, ctx.floor(upper))
      case Unbounded(lower, upper) =>
        val lowerApprox = ctx.ceil(lower)
        QIR(ctx, lower, upper, lowerApprox, ctx.floor(upper))
    }
  }
}

object BigDecimalRootRefinement {
  /**
   * Starts a refinement for the root of `poly` bounded by `lowerBound` and
   * `upperBound`, using an `AbsoluteContext` with its default scale and an
   * `Unbounded` approximation.
   */
  def apply(poly: Polynomial[BigDecimal], lowerBound: Rational, upperBound: Rational): BigDecimalRootRefinement = {
    // The polynomial is expected to evaluate exactly, so whatever MathContext
    // the caller attached to each coefficient is replaced by UNLIMITED.
    val exactPoly = poly.map { coeff => new BigDecimal(coeff.bigDecimal, MathContext.UNLIMITED) }
    val start = Unbounded(lowerBound, upperBound)
    BigDecimalRootRefinement(AbsoluteContext(exactPoly), start)
  }

  /** Same as the three-argument `apply`, immediately refined to `scale`. */
  def apply(poly: Polynomial[BigDecimal], lowerBound: Rational, upperBound: Rational, scale: Int): BigDecimalRootRefinement = {
    val initial = apply(poly, lowerBound, upperBound)
    initial.refine(scale)
  }

  /** Same as the three-argument `apply`, immediately refined to `mc`. */
  def apply(poly: Polynomial[BigDecimal], lowerBound: Rational, upperBound: Rational, mc: MathContext): BigDecimalRootRefinement = {
    val initial = apply(poly, lowerBound, upperBound)
    initial.refine(mc)
  }

  // Signed and Order instances for java.math.BigDecimal; each operation simply
  // delegates to the corresponding method on the Java class.
  private implicit object JBigDecimalOrder extends Signed[JBigDecimal] with Order[JBigDecimal] {
    def compare(x: JBigDecimal, y: JBigDecimal): Int = x.compareTo(y)
    def abs(a: JBigDecimal): JBigDecimal = a.abs()
    def signum(a: JBigDecimal): Int = a.signum()
  }

  // Factor that RelativeContext.getEps multiplies a bit length by before
  // rounding up. Presumably log base 10 of 2, i.e. decimal digits per bit --
  // this assumes `log(n, base)` argument order; confirm against spire.math.log.
  private val bits2dec: Double = log(2, 10)

  // What is currently known about the location of the root; this is the type
  // of BigDecimalRootRefinement.approximation.
  sealed abstract class Approximation
  // The polynomial evaluated to zero at `root`.
  case class ExactRoot(root: JBigDecimal) extends Approximation
  // Both endpoints are decimals and the polynomial's value at each is kept
  // alongside, together with `n`, the value the refinement loop resumes with.
  case class Bounded(lowerBound: JBigDecimal, lowerBoundValue: JBigDecimal, upperBound: JBigDecimal, upperBoundValue: JBigDecimal, n: Int) extends Approximation
  // Exact rational lower bound, decimal upper bound.
  case class BoundedLeft(lowerBound: Rational, upperBound: JBigDecimal) extends Approximation
  // Decimal lower bound, exact rational upper bound.
  case class BoundedRight(lowerBound: JBigDecimal, upperBound: Rational) extends Approximation
  // Only the exact rational bounds are known; this is the state `apply`
  // starts from.
  case class Unbounded(lowerBound: Rational, upperBound: Rational) extends Approximation

  // The polynomial being solved plus the rounding/precision rules used while
  // refining one of its roots. Implemented by AbsoluteContext and
  // RelativeContext.
  sealed abstract class ApproximationContext {
    // The polynomial whose root is being refined.
    def poly: Polynomial[BigDecimal]
    // Scale `e` of the tolerance at `x`; callers build the tolerance as
    // JBigDecimal.valueOf(1, e), i.e. 10^-e.
    def getEps(x: JBigDecimal): Int
    // Value of `poly` at `x`, rounded according to the context.
    def evalExact(x: JBigDecimal): JBigDecimal
    // Decimal approximation of `x` at the context's precision; by name,
    // intended to round downwards (see each implementation for the rounding
    // mode actually used).
    def floor(x: Rational): JBigDecimal
    // Decimal approximation of `x` at the context's precision; by name,
    // intended to round upwards (see each implementation).
    def ceil(x: Rational): JBigDecimal
    // As floor(Rational), for a value that is already a decimal.
    def floor(x: JBigDecimal): JBigDecimal
    // As ceil(Rational), for a value that is already a decimal.
    def ceil(x: JBigDecimal): JBigDecimal
  }

  /**
   * Approximation context with a fixed scale: every value it produces is
   * rounded to `scale`, and the tolerance exponent is `scale` everywhere.
   *
   * @param poly  the polynomial whose root is being refined
   * @param scale the scale applied by every rounding operation
   */
  case class AbsoluteContext private[poly] (poly: Polynomial[BigDecimal], scale: Int = Int.MinValue) extends ApproximationContext {
    // The tolerance does not depend on the point being examined.
    def getEps(x: JBigDecimal): Int = scale

    // Evaluate with unlimited precision, then round away from zero so that a
    // non-zero value keeps its sign instead of collapsing to zero.
    def evalExact(x: JBigDecimal): JBigDecimal = {
      val exactInput = new BigDecimal(x, MathContext.UNLIMITED)
      val exactValue = poly(exactInput).bigDecimal
      exactValue.setScale(scale, RoundingMode.UP)
    }

    def floor(x: Rational): JBigDecimal = {
      val rounded = x.toBigDecimal(scale, RoundingMode.FLOOR)
      rounded.bigDecimal
    }

    def ceil(x: Rational): JBigDecimal = {
      val rounded = x.toBigDecimal(scale, RoundingMode.CEILING)
      rounded.bigDecimal
    }

    def floor(x: JBigDecimal): JBigDecimal = x.setScale(scale, RoundingMode.FLOOR)

    def ceil(x: JBigDecimal): JBigDecimal = x.setScale(scale, RoundingMode.CEILING)
  }

  /**
   * Approximation context that works to a number of significant digits (the
   * precision of `mc`) rather than to a fixed scale.
   *
   * @param poly the polynomial whose root is being refined
   * @param mc   its precision drives `getEps`, `floor` and `ceil`; `evalExact`
   *             rounds with `mc` as given
   */
  case class RelativeContext private[poly] (poly: Polynomial[BigDecimal], mc: MathContext = new MathContext(0)) extends ApproximationContext {
    // x.scale, minus an estimate of the decimal digit count of x's unscaled
    // value (bit length * bits2dec, rounded up), plus the precision and one
    // extra digit.
    def getEps(x: JBigDecimal): Int =
      x.scale - spire.math.ceil(x.unscaledValue.bitLength * bits2dec).toInt + mc.getPrecision + 1

    // Evaluate with unlimited precision, then round to `mc`.
    def evalExact(x: JBigDecimal): JBigDecimal =
      poly(new BigDecimal(x, MathContext.UNLIMITED))
        .bigDecimal
        .round(mc)

    // floor must round towards negative infinity and ceil towards positive
    // infinity, matching the JBigDecimal overloads below and AbsoluteContext.
    // The two Rational overloads previously had the modes swapped, which
    // pushed the decimal bounds outside the rational interval they were
    // supposed to stay within.
    def floor(x: Rational): JBigDecimal =
      x.toBigDecimal(new MathContext(mc.getPrecision, RoundingMode.FLOOR)).bigDecimal

    def ceil(x: Rational): JBigDecimal =
      x.toBigDecimal(new MathContext(mc.getPrecision, RoundingMode.CEILING)).bigDecimal

    def floor(x: JBigDecimal): JBigDecimal =
      x.round(new MathContext(mc.getPrecision, RoundingMode.FLOOR))

    def ceil(x: JBigDecimal): JBigDecimal =
      x.round(new MathContext(mc.getPrecision, RoundingMode.CEILING))
  }

  /**
   * Entry point of "Quadratic Interval Refinement for Real Roots" (John
   * Abbott) for `BigDecimal`, used while the root is only known to lie in the
   * open rational interval `(lowerBound, upperBound)`.
   *
   * The polynomial is evaluated at the decimal endpoints `lb` and `ub`. If
   * they bracket the root with opposite, non-zero signs, refinement continues
   * in the decimal-only `QIR` overload; otherwise the best approximation that
   * could be established is returned.
   *
   * @param context    the polynomial and the precision rules
   * @param lowerBound exact lower bound of the isolating interval
   * @param upperBound exact upper bound of the isolating interval
   * @param lb         decimal approximation of `lowerBound`
   * @param ub         decimal approximation of `upperBound`
   * @return an `ExactRoot`, a refined `Bounded`, a half-decimal
   *         `BoundedLeft`/`BoundedRight`, or `Unbounded` when the decimal
   *         endpoints do not form a non-empty interval
   */
  private def QIR(
    context: ApproximationContext,
    lowerBound: Rational,
    upperBound: Rational,
    lb: JBigDecimal,
    ub: JBigDecimal
  ): Approximation = {
    import context._

    // We avoid rational arithmetic by using the fact:
    // poly.compose(x + a/b) == poly.compose(x/b).compose(b*x + a)
    // If we first scale the polynomial by b^poly.degree, then the roots
    // remain the same, but the first composition will partially cancel with
    // b^poly.degree, leaving an integer.
    def shift(poly: Polynomial[BigDecimal], h: Rational): Polynomial[BigDecimal] = {
      val n = poly.degree
      poly
        .mapTerms { case Term(coeff, k) =>
          val a = BigDecimal(h.denominator.toBigInteger.pow(n - k), MathContext.UNLIMITED)
          Term(coeff * a, k)
        }
        .compose(Polynomial.linear[BigDecimal](BigDecimal(h.denominator.toBigInteger, MathContext.UNLIMITED), BigDecimal(h.numerator.toBigInteger, MathContext.UNLIMITED)))
        .removeZeroRoots
    }

    // Returns true if there is a root in the open sub-interval (l, r).
    def hasRoot(l: Rational, r: Rational): Boolean =
      if (l != r) {
        // Use Descartes' rule of signs to see if the root in the open interval
        // is actually in the sub interval (l, r). The rule counts positive
        // roots, so (l, r) has to be mapped onto (0, inf) first:
        //   1. shift by l:          roots in (l, r)        -> (0, r - l)
        //   2. reverse coefficients: roots t -> 1/t, so    -> (1/(r - l), inf)
        //   3. shift by 1/(r - l):                         -> (0, inf)
        // Without step 2 the test would look at roots above l + 1/(r - l).
        val shifted = shift(poly, l)
        val mapped = shift(shifted.reciprocal, (r - l).reciprocal)
        mapped.signVariations % 2 == 1
      } else {
        false
      }

    // QIR expects that (lx,rx) contain exactly 1 root and that they evaluate
    // to different, non-zero signs. However, we may find 1 or both of the
    // bounds to be 0, or to have "overshot" the root, so we have to either
    // find a good (lx, rx) or return an approximation early.
    //
    // lyOpt / ryOpt carry the polynomial's value at lx / rx when it has
    // already been computed for exactly that point; None means "evaluate".
    def adjust(lx: JBigDecimal, lyOpt: Option[JBigDecimal], rx: JBigDecimal, ryOpt: Option[JBigDecimal]): Approximation = {
      def qlx = Rational(new BigDecimal(lx, MathContext.UNLIMITED))
      def qrx = Rational(new BigDecimal(rx, MathContext.UNLIMITED))

      if (lx.compareTo(rx) < 0) {
        val ly = lyOpt.getOrElse(evalExact(lx))
        val ry = ryOpt.getOrElse(evalExact(rx))
        if (ly.signum == 0) {
          if (qlx > lowerBound) {
            // We've "bumped" the lowerbound up to an exact root, coincidentally.
            ExactRoot(lx)
          } else {
            // We try to push lx up a bit to get the sign to change. The value
            // at the new lx is unknown, so it must be re-evaluated (None);
            // reusing the old zero would misreport the nudged point as a root.
            adjust(lx.add(JBigDecimal.valueOf(1, getEps(lx))), None, rx, Some(ry))
          }
        } else if (ry.signum == 0) {
          if (qrx < upperBound) {
            // We've "bumped" the upperbound down to an exact root, coincidentally.
            ExactRoot(rx)
          } else {
            // We try to push rx down a bit to get the sign to change; as
            // above, the value at the new rx has to be re-evaluated.
            adjust(lx, Some(ly), rx.subtract(JBigDecimal.valueOf(1, getEps(rx))), None)
          }
        } else if (ry.signum == ly.signum) {
          // We've managed to overshoot the actual root, but since we're still
          // "in-bounds", we know it's in either the left cut off bit or the
          // right. Report the endpoints that were actually tested (lx, rx),
          // which may differ from the initial lb/ub after a nudge.
          if (hasRoot(lowerBound, qlx)) {
            BoundedLeft(lowerBound, lx)
          } else {
            BoundedRight(rx, upperBound)
          }
        } else {
          // Yay! We've successfully approximated the lower/upper bounds with
          // big decimal, while keeping the root within (lx, rx).
          QIR(context, lx, ly, rx, ry)
        }
      } else {
        // We overshot a root.
        Unbounded(lowerBound, upperBound)
      }
    }

    adjust(lb, None, ub, None)
  }

  /**
   * An implementation of "Quadratic Interval Refinement for Real Roots" by
   * John Abbot for `BigDecimal`.
   *
   * Works on a decimal interval whose endpoint values are already known and
   * narrows it until its width is at most the context's tolerance, or until
   * a point where the polynomial evaluates to zero is hit.
   *
   * @param context         the polynomial and the precision rules
   * @param lowerBound      left endpoint of the interval
   * @param lowerBoundValue value of the polynomial at `lowerBound`
   * @param upperBound      right endpoint of the interval
   * @param upperBoundValue value of the polynomial at `upperBound`
   * @param n0              number of decimal digits used for the first
   *                        interpolation step; when `n0 <= 0` the search
   *                        starts with the five-way split in `loop0` instead
   * @return a `Bounded` interval no wider than the tolerance, or an
   *         `ExactRoot`
   */
  // scalastyle:off method.length
  private def QIR(
    context: ApproximationContext,
    lowerBound: JBigDecimal, lowerBoundValue: JBigDecimal,
    upperBound: JBigDecimal, upperBoundValue: JBigDecimal,
    n0: Int = 0
  ): Approximation = {
    import context._

    // Main refinement step on (lx, rx) with values (ly, ry). `n` is the
    // number of decimal digits kept in the interpolated position; it is
    // doubled each time the step succeeds.
    @tailrec
    def loop(
      lx: JBigDecimal,
      ly: JBigDecimal,
      rx: JBigDecimal,
      ry: JBigDecimal,
      n: Int
    ): Approximation = {
      val dx = rx.subtract(lx)
      // Use the larger of the two scales, i.e. the smaller tolerance 10^-scale.
      val scale = max(getEps(lx), getEps(rx))
      val eps = JBigDecimal.valueOf(1, scale)
      if (dx.compareTo(eps) <= 0) {
        // The interval is narrow enough; keep n so a later refinement resumes here.
        Bounded(lx, ly, rx, ry, n)
      } else {
        val dy = ly.subtract(ry)
        // Linear-interpolation estimate ly / (ly - ry) of the root's relative
        // position in the interval, rounded to n decimal places.
        val s = ly.divide(dy, n, RoundingMode.HALF_UP)
        // s.ulp is 10^-n, so delta is the width of one of 10^n sub-intervals.
        val delta = dx.multiply(s.ulp)
        val k = s.unscaledValue
        // Grid point number k: x1 = lx + k * delta.
        val x1 = lx.add(delta.multiply(new JBigDecimal(k)))
        val y1 = evalExact(x1)
        val s1 = y1.sign
        if (s1 == ly.sign) {
          // x1 is on the left-hand side of the root; try the next grid point up.
          val x2 = x1.add(delta)
          val y2 = evalExact(x2)
          val s2 = y2.sign
          // Same sign again: the estimate missed, fall back to loop0 on the
          // original interval. Opposite sign: (x1, x2) brackets the root.
          // Otherwise y2 has neither endpoint's sign and x2 is returned as the root.
          if (s2 == s1) loop0(lx, ly, rx, ry)
          else if (s2 == ry.sign) loop(x1, y1, x2, y2, 2 * n)
          else ExactRoot(x2)
        } else if (s1 == ry.sign) {
          // x1 is on the right-hand side of the root; try the next grid point down.
          val x0 = x1.subtract(delta)
          val y0 = evalExact(x0)
          val s0 = y0.sign
          if (s0 == s1) loop0(lx, ly, rx, ry)
          else if (s0 == ly.sign) loop(x0, y0, x1, y1, 2 * n)
          else ExactRoot(x0)
        } else {
          // y1 matches neither endpoint's sign; x1 is returned as the root.
          ExactRoot(x1)
        }
      }
    }

    // Given three consecutive points with their values, returns any point
    // whose value is zero as an exact root; otherwise continues in whichever
    // of the two adjacent sub-intervals shows a sign change, restarting at n = 1.
    def bisect(
      x0: JBigDecimal,
      y0: JBigDecimal,
      x1: JBigDecimal,
      y1: JBigDecimal,
      x2: JBigDecimal,
      y2: JBigDecimal
    ): Approximation = {
      if (y0.signum == 0) ExactRoot(x0)
      else if (y1.signum == 0) ExactRoot(x1)
      else if (y2.signum == 0) ExactRoot(x2)
      else if (y0.sign != y1.sign) loop(x0, y0, x1, y1, 1)
      else loop(x1, y1, x2, y2, 1)
    }

    // Fallback/initial step: splits (x0, x5) into five equal parts and
    // locates the part containing the sign change, evaluating as few of the
    // interior points x1..x4 as possible.
    def loop0(
      x0: JBigDecimal,
      y0: JBigDecimal,
      x5: JBigDecimal,
      y5: JBigDecimal
    ): Approximation = {
      val dy = y0.subtract(y5)
      // Interpolated position of the root in tenths of the interval; it only
      // decides which half is probed first.
      val k = y0
        .divide(dy, 1, RoundingMode.HALF_UP)
        .unscaledValue
        .intValue
      // One fifth of the interval width.
      val eps = x5.subtract(x0).divide(new JBigDecimal(5))
      // Interior point number k (x0 + k * eps) together with its value.
      def eval(k: Int): (JBigDecimal, JBigDecimal) = {
        val x = new JBigDecimal(k).multiply(eps).add(x0)
        val y = evalExact(x)
        (x, y)
      }

      if (k < 5) {
        // The root is expected in the left half: probe x2 first.
        val (x2, y2) = eval(2)
        if (y2.sign != y0.sign) {
          val (x1, y1) = eval(1)
          bisect(x0, y0, x1, y1, x2, y2)
        } else {
          val (x3, y3) = eval(3)
          if (y3.sign == y5.sign) {
            loop(x2, y2, x3, y3, 1)
          } else {
            val (x4, y4) = eval(4)
            bisect(x3, y3, x4, y4, x5, y5)
          }
        }
      } else {
        // The root is expected in the right half: probe x3 first.
        val (x3, y3) = eval(3)
        if (y3.sign != y5.sign) {
          val (x4, y4) = eval(4)
          bisect(x3, y3, x4, y4, x5, y5)
        } else {
          val (x2, y2) = eval(2)
          if (y2.sign == y0.sign) {
            loop(x2, y2, x3, y3, 1)
          } else {
            val (x1, y1) = eval(1)
            bisect(x0, y0, x1, y1, x2, y2)
          }
        }
      }
    }

    val lx = lowerBound
    val ly = lowerBoundValue
    val rx = upperBound
    val ry = upperBoundValue

    // Without a usable digit count from an earlier run, start with the
    // five-way split; otherwise resume the interpolation step at n0 digits.
    if (n0 <= 0) {
      loop0(lx, ly, rx, ry)
    } else {
      loop(lx, ly, rx, ry, n0)
    }
  }
  // scalastyle:on method.length
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy