// spire.math.Interval.scala
package spire.math
import Predef.{any2stringadd => _, _}
import spire.algebra._
import spire.math.poly.Term
import spire.math.interval._
import spire.math.interval.Bound._
import spire.syntax.field._
import spire.syntax.nroot._
import spire.syntax.order._
import java.lang.Double.isNaN
/**
* Interval represents a set of values, usually numbers.
*
* Intervals have upper and lower bounds. Each bound can be one of
* four kinds:
*
* * Closed: The boundary value is included in the interval.
* * Open: The boundary value is excluded from the interval.
* * Unbound: There is no boundary value.
* * Empty: The interval itself is empty.
*
* When the underlying type of the interval supports it, intervals may
* be used in arithmetic. There are several possible interpretations
* of interval arithmetic: the interval can represent uncertainty
* about a single value (for instance, a quantity +/- tolerance in
* engineering) or it can represent all values in the interval
* simultaneously. In this implementation we have chosen to use the
* probabilistic interpretation.
*
* One common pitfall with interval arithmetic is that many familiar
* algebraic relations do not hold. For instance, given two intervals
* a and b:
*
* a == b does not imply a * a == a * b
*
* Consider a = b = [-1, 1]. Since any number times itself is
* non-negative, a * a = [0, 1]. However, a * b = [-1, 1], since we
* may actually have a=1 and b=-1.
*
* These situations will result in loss of precision (in the form of
* wider intervals). The result is not wrong per se, but less
* accurate than it could be.
*/
sealed abstract class Interval[A](implicit order: Order[A]) { lhs =>
// Openness flags are packed into an Int: bit 0 (value 1) is set when the
// lower bound is open, bit 1 (value 2) is set when the upper bound is open.
// A cleared bit means the corresponding bound is closed.

// Both bounds closed.
@inline protected[this] final def isClosed(flags: Int): Boolean = 0 == flags
// Lower bound closed (bit 0 clear).
@inline protected[this] final def isClosedLower(flags: Int): Boolean = lowerFlag(flags) == 0
// Upper bound closed (bit 1 clear).
@inline protected[this] final def isClosedUpper(flags: Int): Boolean = upperFlag(flags) == 0
// Both bounds open.
@inline protected[this] final def isOpen(flags: Int): Boolean = 3 == flags
// Lower bound open (bit 0 set).
@inline protected[this] final def isOpenLower(flags: Int): Boolean = lowerFlag(flags) != 0
// Upper bound open (bit 1 set).
@inline protected[this] final def isOpenUpper(flags: Int): Boolean = upperFlag(flags) != 0
// Keep only the lower-bound bit.
@inline protected[this] final def lowerFlag(flags: Int): Int = flags & 0x1
// Keep only the upper-bound bit.
@inline protected[this] final def upperFlag(flags: Int): Int = flags & 0x2
// Toggle the lower-bound bit.
@inline protected[this] final def reverseLowerFlag(flags: Int): Int = flags ^ 0x1
// Toggle the upper-bound bit.
@inline protected[this] final def reverseUpperFlag(flags: Int): Int = flags ^ 0x2
// Toggle both bits.
@inline protected[this] final def reverseFlags(flags: Int): Int = flags ^ 0x3
// Move the lower-bound bit into the upper-bound position.
protected[this] final def lowerFlagToUpper(flags: Int): Int = lowerFlag(flags) << 1
// Move the upper-bound bit into the lower-bound position.
protected[this] final def upperFlagToLower(flags: Int): Int = upperFlag(flags) >>> 1
// Exchange the lower-bound and upper-bound bits.
@inline protected[this] final def swapFlags(flags: Int): Int =
  lowerFlagToUpper(flags) | upperFlagToLower(flags)
// True when the lower bound (lower1, flags1) admits everything admitted by
// the lower bound (lower2, flags2): either it starts strictly earlier, or the
// values coincide and bound 1 is closed or bound 2 is open.
protected[this] def lowerPairBelow(lower1: A, flags1: Int, lower2: A, flags2: Int): Boolean =
  if (lower1 < lower2) true
  else if (lower1 === lower2) isClosedLower(flags1) || isOpenLower(flags2)
  else false
// Mirror image of lowerPairBelow for upper bounds: bound 1 ends strictly
// later, or the values coincide and bound 1 is closed or bound 2 is open.
protected[this] def upperPairAbove(upper1: A, flags1: Int, upper2: A, flags2: Int): Boolean =
  if (upper1 > upper2) true
  else if (upper1 === upper2) isClosedUpper(flags1) || isOpenUpper(flags2)
  else false
/** True only when this interval is an `Empty`. */
def isEmpty: Boolean =
  this match {
    case _: Empty[_] => true
    case _ => false
  }
/** Negation of `isEmpty`. */
def nonEmpty: Boolean =
  !this.isEmpty
/** True only when this interval is a `Point`. */
def isPoint: Boolean =
  this match {
    case _: Point[_] => true
    case _ => false
  }
/** True when the interval has a point at-or-below `t` and a point at-or-above `t`. */
def contains(t: A): Boolean =
  if (hasAtOrBelow(t)) hasAtOrAbove(t) else false
/** Logical negation of `contains`. */
def doesNotContain(t: A): Boolean =
  !(hasAtOrBelow(t) && hasAtOrAbove(t))
/** True when the interval has points strictly below and strictly above `t`. */
def crosses(t: A): Boolean =
  if (hasBelow(t)) hasAbove(t) else false
/** `crosses` specialised to the additive identity of `A`. */
def crossesZero(implicit ev: AdditiveMonoid[A]): Boolean =
  crosses(ev.zero)
/**
 * Reports whether this interval is bounded on both sides.
 *
 * `Above`, `Below` and `All` each carry an `Unbound` on at least one
 * side, so they are not bounded. Every other interval kind
 * (`Empty`, `Point`, `Bounded`) is reported as bounded.
 */
def isBounded: Boolean =
  this match {
    // unbounded in at least one direction
    case Below(_, _) | Above(_, _) | All() => false
    case _ => true
  }
/** The lower bound of this interval; supplied by each concrete interval kind. */
def lowerBound: Bound[A]
/** The upper bound of this interval; supplied by each concrete interval kind. */
def upperBound: Bound[A]
/** Maps both bounds through `f` and rebuilds an interval from the results. */
def mapBounds[B: Order](f: A => B): Interval[B] = {
  val mappedLower: Bound[B] = lowerBound.map(f)
  val mappedUpper: Bound[B] = upperBound.map(f)
  Interval.fromBounds(mappedLower, mappedUpper)
}
/** Applies `f` to the (lower, upper) pair of bounds. */
def fold[B](f: (Bound[A], Bound[A]) => B): B =
  f.apply(lowerBound, upperBound)
/**
 * Tests whether every point of `rhs` is also a point of this interval.
 *
 * The cases are order-sensitive: the special kinds (All, Empty, Point) are
 * eliminated first so that later cases only see Above, Below and Bounded.
 * Note that `(_, Empty())` is tested before `(Empty(), _)`, so the empty
 * interval is reported as a superset of itself.
 */
def isSupersetOf(rhs: Interval[A]): Boolean = (lhs, rhs) match {
// deal with All, Empty and Point on either left or right side
case (All(), _) => true
case (_, All()) => false
case (_, Empty()) => true
case (Empty(), _) => false
case (Point(lhsval), Point(rhsval)) => lhsval === rhsval
case (Point(_), _) => false // rhs cannot be Empty or Point
case (_, Point(rhsval)) => lhs.contains(rhsval)
// remaining cases are Above, Below and Bounded, we deal first with the obvious false
case (Above(_, _), Below(_, _)) => false
case (Below(_, _), Above(_, _)) => false
case (Bounded(_, _, _), Below(_, _)) => false
case (Bounded(_, _, _), Above(_, _)) => false
// lhs is unbounded above: only the lower bounds need comparing
case (Above(lower1, flags1), Bounded(lower2, _, flags2)) =>
lowerPairBelow(lower1, flags1, lower2, flags2)
case (Above(lower1, flags1), Above(lower2, flags2)) =>
lowerPairBelow(lower1, flags1, lower2, flags2)
// lhs is unbounded below: only the upper bounds need comparing
case (Below(upper1, flags1), Below(upper2, flags2)) =>
upperPairAbove(upper1, flags1, upper2, flags2)
case (Below(upper1, flags1), Bounded(_, upper2, flags2)) =>
upperPairAbove(upper1, flags1, upper2, flags2)
// both bounded: lhs must cover rhs on both sides
case (Bounded(lower1, upper1, flags1), Bounded(lower2, upper2, flags2)) =>
lowerPairBelow(lower1, flags1, lower2, flags2) &&
upperPairAbove(upper1, flags1, upper2, flags2)
}
/** Superset test that additionally requires the two intervals to differ. */
def isProperSupersetOf(rhs: Interval[A]): Boolean =
  if (lhs == rhs) false else lhs.isSupersetOf(rhs)
/** `isSupersetOf` with the roles of the operands exchanged. */
def isSubsetOf(rhs: Interval[A]): Boolean =
  rhs.isSupersetOf(lhs)
/** `isProperSupersetOf` with the roles of the operands exchanged. */
def isProperSubsetOf(rhs: Interval[A]): Boolean =
  rhs.isProperSupersetOf(lhs)
// Does this interval contain any points strictly above t?
def hasAbove(t: A): Boolean = this match {
  case All() | Above(_, _) => true // unbounded upward
  case Empty() => false
  case Point(p) => p > t
  case Below(upper, _) => upper > t
  case Bounded(_, upper, _) => upper > t
}
// Does this interval contain any points strictly below t?
def hasBelow(t: A): Boolean = this match {
  case All() | Below(_, _) => true // unbounded downward
  case Empty() => false
  case Point(p) => p < t
  case Above(lower, _) => lower < t
  case Bounded(lower, _, _) => lower < t
}
// Does this interval contain any points at or above t?
def hasAtOrAbove(t: A): Boolean = this match {
  case Empty() => false
  case All() | Above(_, _) => true // unbounded upward
  case Point(p) => p >= t
  case Below(upper, flags) =>
    // a closed upper bound equal to t still counts
    if (upper > t) true else isClosedUpper(flags) && upper === t
  case Bounded(_, upper, flags) =>
    if (upper > t) true else isClosedUpper(flags) && upper === t
}
// Does this interval contain any points at or below t?
def hasAtOrBelow(t: A): Boolean = this match {
  case Empty() => false
  case All() | Below(_, _) => true // unbounded downward
  case Point(p) => p <= t
  case Above(lower, flags) =>
    // a closed lower bound equal to t still counts
    if (lower < t) true else isClosedLower(flags) && lower === t
  case Bounded(lower, _, flags) =>
    if (lower < t) true else isClosedLower(flags) && lower === t
}
/** True only when this interval is a `Point` whose value equals `t`. */
def isAt(t: A): Boolean =
  this match {
    case pt: Point[A] => t === pt.value
    case _ => false
  }
/** True when the intersection with `rhs` is non-empty. */
def intersects(rhs: Interval[A]): Boolean =
  lhs.intersect(rhs).nonEmpty
/** Operator alias for `intersect`. */
def &(rhs: Interval[A]): Interval[A] =
  lhs.intersect(rhs)
/** Builds the interval from the greater lower bound and the lesser upper bound. */
def intersect(rhs: Interval[A]): Interval[A] = {
  val lo: Bound[A] = maxLower(lhs.lowerBound, rhs.lowerBound, true)
  val hi: Bound[A] = minUpper(lhs.upperBound, rhs.upperBound, true)
  Interval.fromBounds(lo, hi)
}
/**
 * Complement of this interval, as a list of zero, one or two intervals.
 * Each boundary keeps its value but has its openness inverted and is moved
 * to the opposite side (a lower bound becomes an upper bound and vice versa).
 */
def unary_~(): List[Interval[A]] = {
  // upper-bound flag produced by inverting a lower-bound flag
  def invertedAsUpper(lf: Int): Int = lowerFlagToUpper(reverseLowerFlag(lf))
  // lower-bound flag produced by inverting an upper-bound flag
  def invertedAsLower(uf: Int): Int = upperFlagToLower(reverseUpperFlag(uf))
  this match {
    case Empty() =>
      List(All())
    case All() =>
      Nil
    case Point(p) =>
      List(Interval.below(p), Interval.above(p))
    case Above(lower, lf) =>
      List(Below(lower, invertedAsUpper(lf)))
    case Below(upper, uf) =>
      List(Above(upper, invertedAsLower(uf)))
    case Bounded(lower, upper, flags) =>
      val belowPart = Below(lower, invertedAsUpper(lowerFlag(flags)))
      val abovePart = Above(upper, invertedAsLower(upperFlag(flags)))
      List(belowPart, abovePart)
  }
}
/**
 * Set difference: the parts of this interval not covered by `rhs`.
 * Computed by intersecting this interval with each piece of the complement
 * of `rhs` and discarding empty results.
 */
def --(rhs: Interval[A]): List[Interval[A]] =
  if (!lhs.intersects(rhs)) {
    // nothing to remove; an empty lhs yields no pieces at all
    if (lhs.isEmpty) Nil else List(lhs)
  } else {
    (~rhs).flatMap { piece =>
      val overlap = lhs & piece
      if (overlap.nonEmpty) List(overlap) else Nil
    }
  }
/** Splits into the parts strictly below and strictly above `t` (`t` itself is excluded). */
def split(t: A): (Interval[A], Interval[A]) = {
  val lowerPart = this intersect Interval.below(t)
  val upperPart = this intersect Interval.above(t)
  (lowerPart, upperPart)
}
/** `split` at the additive identity. */
def splitAtZero(implicit ev: AdditiveMonoid[A]): (Interval[A], Interval[A]) =
  split(ev.zero)
/** Applies `f` to each half produced by `splitAtZero`. */
def mapAroundZero[B](f: Interval[A] => B)(implicit ev: AdditiveMonoid[A]): (B, B) = {
  val (belowZero, aboveZero) = splitAtZero
  (f(belowZero), f(aboveZero))
}
/** Operator alias for `union`. */
def |(rhs: Interval[A]): Interval[A] =
  lhs.union(rhs)
/**
 * Builds the interval from the lesser lower bound and the greater upper
 * bound; any gap between the two operands is therefore included.
 */
def union(rhs: Interval[A]): Interval[A] = {
  val lo: Bound[A] = minLower(lhs.lowerBound, rhs.lowerBound, false)
  val hi: Bound[A] = maxUpper(lhs.upperBound, rhs.upperBound, false)
  Interval.fromBounds(lo, hi)
}
/** Renders the interval in bracket notation: `[` / `]` for closed, `(` / `)` for open or unbounded sides. */
override def toString(): String =
  this match {
    case Empty() =>
      "(Ø)"
    case All() =>
      "(-∞, ∞)"
    case Point(p) =>
      s"[$p]"
    case Above(lower, flags) if isClosedLower(flags) =>
      s"[$lower, ∞)"
    case Above(lower, _) =>
      s"($lower, ∞)"
    case Below(upper, flags) if isClosedUpper(flags) =>
      s"(-∞, $upper]"
    case Below(upper, _) =>
      s"(-∞, $upper)"
    case Bounded(lower, upper, flags) =>
      val left = if (isClosedLower(flags)) s"[$lower" else s"($lower"
      val right = if (isClosedUpper(flags)) s"$upper]" else s"$upper)"
      s"$left, $right"
  }
/**
 * Interval of absolute values.
 *
 * An interval that does not cross zero is returned as-is, or negated when
 * it has points below zero. One that crosses zero is folded onto
 * `[0, max(-lower, upper)]`, or onto `[0, ∞)` when it is not `Bounded`.
 */
def abs(implicit m: AdditiveGroup[A]): Interval[A] =
  if (!crossesZero) {
    if (hasBelow(m.zero)) -this else this
  } else {
    // only Bounded, Above or Below can cross zero
    this match {
      case Bounded(lower, upper, fs) =>
        val mirrored = -lower
        if (mirrored > upper) Bounded(m.zero, mirrored, lowerFlagToUpper(fs))
        else if (upper > mirrored) Bounded(m.zero, upper, upperFlag(fs))
        // equal magnitudes: the endpoint is open only if both originals were open
        else Bounded(m.zero, mirrored, lowerFlagToUpper(fs) & upperFlag(fs))
      case _ => // Above or Below
        Interval.atOrAbove(m.zero)
    }
  }
// (A vmin B) is the interval containing (a min b) for every a in A and b in B:
// built from the lesser of the lower bounds and the lesser of the upper bounds.
def vmin(rhs: Interval[A])(implicit m: AdditiveMonoid[A]): Interval[A] = {
  val lo: Bound[A] = minLower(lhs.lowerBound, rhs.lowerBound, true)
  val hi: Bound[A] = minUpper(lhs.upperBound, rhs.upperBound, true)
  Interval.fromBounds(lo, hi)
}
// (A vmax B) is the interval containing (a max b) for every a in A and b in B:
// built from the greater of the lower bounds and the greater of the upper bounds.
def vmax(rhs: Interval[A])(implicit m: AdditiveMonoid[A]): Interval[A] = {
  val lo: Bound[A] = maxLower(lhs.lowerBound, rhs.lowerBound, true)
  val hi: Bound[A] = maxUpper(lhs.upperBound, rhs.upperBound, true)
  Interval.fromBounds(lo, hi)
}
/** Combines lower bound with lower bound and upper with upper using `f`, then rebuilds an interval. */
def combine(rhs: Interval[A])(f: (A, A) => A): Interval[A] = {
  val combinedLower = lhs.lowerBound.combine(rhs.lowerBound)(f)
  val combinedUpper = lhs.upperBound.combine(rhs.upperBound)(f)
  Interval.fromBounds(combinedLower, combinedUpper)
}
/** Interval addition: bounds are added pairwise. */
def +(rhs: Interval[A])(implicit ev: AdditiveSemigroup[A]): Interval[A] =
  combine(rhs)((x, y) => x + y)
/** Interval subtraction, expressed as addition of the negated operand. */
def -(rhs: Interval[A])(implicit ev: AdditiveGroup[A]): Interval[A] = {
  val negated: Interval[A] = -rhs
  lhs + negated
}
// scalastyle:off method.length
/**
 * Interval multiplication.
 *
 * Dispatches on the kinds of both operands (see the match at the end):
 * Empty is absorbing, a Point operand delegates to multiplication by a
 * scalar, All is returned unchanged, and every remaining pairing of
 * Above / Below / Bounded is handled by one of the local helpers below.
 *
 * In the helpers, "strongZero" means one operand has a closed endpoint
 * exactly at zero; the product's endpoint is then closed (flags = 0)
 * regardless of the other operand's openness.
 */
def *(rhs: Interval[A])(implicit ev: Semiring[A]): Interval[A] = {
val z = ev.zero
// [lower1, ∞) * [lower2, ∞): All if either operand reaches below zero
def aboveAbove(lower1: A, lf1: Int, lower2: A, lf2: Int): Interval[A] = {
val lower1s = lower1.compare(z)
val lower2s = lower2.compare(z)
if (lower1s < 0 || lower2s < 0) All() else {
val strongZero = (lower1s == 0 && isClosedLower(lf1)) || (lower2s == 0 && isClosedLower(lf2))
val flags = if (strongZero) 0 else lf1 | lf2
Above(lower1 * lower2, flags)
}
}
// (-∞, upper1] * (-∞, upper2]: All if either operand reaches above zero
def belowBelow(upper1: A, uf1: Int, upper2: A, uf2: Int): Interval[A] = {
val upper1s = upper1.compare(z)
val upper2s = upper2.compare(z)
if (upper1s > 0 || upper2s > 0) All() else {
val strongZero = (upper1s == 0 && isClosedUpper(uf1)) || (upper2s == 0 && isClosedUpper(uf2))
val flags = if (strongZero) 0 else upperFlagToLower(uf1) | upperFlagToLower(uf2)
Above(upper1 * upper2, flags)
}
}
// [lower1, ∞) * (-∞, upper2]: All if the first reaches below zero or the second above zero
def aboveBelow(lower1: A, lf1: Int, upper2: A, uf2: Int): Interval[A] = {
val lower1s = lower1.compare(z)
val upper2s = upper2.compare(z)
if (lower1s < 0 || upper2s > 0) All () else {
val strongZero = (lower1s == 0 && isClosedLower(lf1)) || (upper2s == 0 && isClosedUpper(uf2))
val flags = if (strongZero) 0 else lowerFlagToUpper(lf1) | uf2
Below(lower1 * upper2, flags)
}
}
// [lower1, ∞) * bounded interval 2
def aboveBounded(lower1: A, lf1: Int, lower2: A, upper2: A, flags2: Int): Interval[A] = {
val lower1s = lower1.compare(z)
val lower2s = lower2.compare(z)
val upper2s = upper2.compare(z)
val hasBelowZero1 = lower1s < 0
val hasBelowZero2 = lower2s < 0
val hasAboveZero2 = upper2s > 0
if (hasBelowZero2 && hasAboveZero2) All() // bounded interval crosses zero
else if (hasAboveZero2) { // bounded interval is fully above zero
if (hasBelowZero1) // the minimal point is lower1(-) * upper2(+)
Above(lower1 * upper2, lf1 | upperFlagToLower(flags2))
else { // the minimal point is lower1(+) * lower2(+)
val strongZero = (lower1s == 0 && isClosedLower(lf1)) || (lower2s == 0 && isClosedLower(flags2))
val flags = if (strongZero) 0 else lf1 | lowerFlag(flags2)
Above(lower1 * lower2, flags)
}
} else { // bounded interval is fully below zero
assert(hasBelowZero2)
if (hasBelowZero1) { // the maximal point is lower1(-) * lower2(-)
val strongZero = (lower1s == 0 && isClosedLower(lf1)) || (lower2s == 0 && isClosedLower(flags2))
val flags = if (strongZero) 0 else lowerFlagToUpper(lf1) | lowerFlagToUpper(flags2)
Below(lower1 * lower2, flags)
}
else { // the maximal point is lower1(+) * upper2(-)
val strongZero = (lower1s == 0 && isClosedLower(lf1)) || (upper2s == 0 && isClosedUpper(flags2))
val flags = if (strongZero) 0 else lowerFlagToUpper(lf1) | upperFlag(flags2)
Below(lower1 * upper2, flags)
}
}
}
// (-∞, upper1] * bounded interval 2
def belowBounded(upper1: A, uf1: Int, lower2: A, upper2: A, flags2: Int): Interval[A] = {
val upper1s = upper1.compare(z)
val lower2s = lower2.compare(z)
val upper2s = upper2.compare(z)
val hasAboveZero1 = upper1s > 0
val hasBelowZero2 = lower2s < 0
val hasAboveZero2 = upper2s > 0
if (hasBelowZero2 && hasAboveZero2) All() // bounded interval crosses zero
else if (hasAboveZero2) { // bounded interval is fully above zero
if (hasAboveZero1) // the maximal point is upper1(+) * upper2(+)
Below(upper1 * upper2, uf1 | upperFlag(flags2))
else { // the maximal point is upper1(-) * lower2(+)
val strongZero = (upper1s == 0 && isClosedUpper(uf1)) || (lower2s == 0 && isClosedLower(flags2))
val flags = if (strongZero) 0 else uf1 | lowerFlagToUpper(flags2)
Below(upper1 * lower2, flags)
}
} else { // bounded interval is fully below zero
if (hasAboveZero1) { // the minimal point is upper1(+) * lower2(-)
val strongZero = (lower2s == 0 && isClosedLower(flags2))
val flags = if (strongZero) 0 else upperFlagToLower(uf1) | lowerFlag(flags2)
Above(upper1 * lower2, flags)
}
else { // the minimal point is upper1(-) * upper2(-)
val strongZero = (upper1s == 0 && isClosedUpper(uf1)) || (upper2s == 0 && isClosedUpper(flags2))
val flags = if (strongZero) 0 else upperFlagToLower(uf1) | upperFlagToLower(flags2)
Above(upper1 * upper2, flags)
}
}
}
// bounded * bounded: form the four endpoint products (forcing a closed zero
// whenever one factor is a closed zero endpoint) and take their union
def boundedBounded(bd1: Bounded[A], bd2: Bounded[A]): Interval[A] = {
val lb1 = bd1.lowerBound
val ub1 = bd1.upperBound
val lb2 = bd2.lowerBound
val ub2 = bd2.upperBound
val lb1sz = lb1.a === z && lb1.isClosed
val lb2sz = lb2.a === z && lb2.isClosed
val ub1sz = ub1.a === z && ub1.isClosed
val ub2sz = ub2.a === z && ub2.isClosed
val ll = if (lb1sz || lb2sz) interval.Closed(z) else (lb1 *~ lb2)
val lu = if (lb1sz || ub2sz) interval.Closed(z) else (lb1 *~ ub2)
val ul = if (ub1sz || lb2sz) interval.Closed(z) else (ub1 *~ lb2)
val uu = if (ub1sz || ub2sz) interval.Closed(z) else (ub1 *~ ub2)
ValueBound.union4(ll, lu, ul, uu)
}
(lhs, rhs) match {
case (Empty(), _) => lhs
case (_, Empty()) => rhs
case (Point(lv), _) => rhs * lv // use multiplication by scalar
case (_, Point(rv)) => lhs * rv
// now lhs and rhs are both intervals with more than one point
case (All(), _) => lhs
case (_, All()) => rhs
case (Above(lower1, lf1), Above(lower2, lf2)) => aboveAbove(lower1, lf1, lower2, lf2)
case (Above(lower1, lf1), Below(upper2, uf2)) => aboveBelow(lower1, lf1, upper2, uf2)
case (Below(upper1, uf1), Above(lower2, lf2)) => aboveBelow(lower2, lf2, upper1, uf1)
case (Below(upper1, uf1), Below(upper2, uf2)) => belowBelow(upper1, uf1, upper2, uf2)
case (Above(lower1, lf1), Bounded(lower2, upper2, flags2)) =>
aboveBounded(lower1, lf1, lower2, upper2, flags2)
case (Bounded(lower1, upper1, flags1), Above(lower2, lf2)) =>
aboveBounded(lower2, lf2, lower1, upper1, flags1)
case (Below(upper1, uf1), Bounded(lower2, upper2, flags2)) =>
belowBounded(upper1, uf1, lower2, upper2, flags2)
case (Bounded(lower1, upper1, flags1), Below(upper2, uf2)) =>
belowBounded(upper2, uf2, lower1, upper1, flags1)
case (bd1: Bounded[A], bd2: Bounded[A]) => boundedBounded(bd1, bd2)
}
}
/**
 * Interval of reciprocals (1 / x for each x in this interval).
 *
 * Throws `java.lang.ArithmeticException("/ by zero")` when this interval is
 * `All`, has points on both sides of zero, or has a closed endpoint equal
 * to zero. A `Point` is passed straight to `v.reciprocal`; no zero check
 * is made for it in this method.
 */
def reciprocal(implicit ev: Field[A]): Interval[A] = {
val z = ev.zero
def error: Nothing = throw new java.lang.ArithmeticException("/ by zero")
this match {
case All() => error
case Empty() => this
case Above(lower, lf) =>
// tuple is (sign of lower relative to zero, lower bound is closed)
(lower.compare(z), isClosedLower(lf)) match {
case (x, _) if x < 0 => error // crosses zero
case (0, true) => error // contains zero
case (0, false) => this
// lower > 0: result is (0, 1/lower) with an open bound at zero; the
// upper bound inherits the openness of the original lower bound
case _ => Bounded(z, lower.reciprocal, 1 | lowerFlagToUpper(lf))
}
case Below(upper, uf) =>
// tuple is (sign of upper relative to zero, upper bound is closed)
(upper.compare(z), isClosedUpper(uf)) match {
case (x, _) if x > 0 => error // crosses zero
case (0, true) => error // contains zero
case (0, false) => this
// upper < 0: result is (1/upper, 0) with an open bound at zero; the
// lower bound inherits the openness of the original upper bound
case _ => Bounded(upper.reciprocal, z, 2 | upperFlagToLower(uf))
}
case Point(v) => Point(v.reciprocal)
case Bounded(lower, upper, flags) =>
// tuple is (sign of lower, sign of upper, lower closed, upper closed)
(lower.compare(z), upper.compare(z), isClosedLower(flags), isClosedUpper(flags)) match {
case (x, y, _, _) if x < 0 && y > 0 => error // crosses zero
case (0, _, true, _) => error // contains zero
case (_, 0, _, true) => error // contains zero
// open endpoint at zero: the reciprocal is unbounded on that side
case (0, _, false, _) => Above(upper.reciprocal, upperFlagToLower(flags))
case (_, 0, _, false) => Below(lower.reciprocal, lowerFlagToUpper(flags))
// zero not involved: endpoints swap roles, and so do their flags
case _ => Bounded(upper.reciprocal, lower.reciprocal, swapFlags(flags))
}
}
}
// scalastyle:on method.length
/**
 * Interval division, expressed as multiplication by a reciprocal. When
 * either operand is a `Point`, the scalar multiplication overload is used.
 */
def /(rhs: Interval[A])(implicit ev: Field[A]): Interval[A] =
  lhs match {
    case Point(lv) =>
      rhs.reciprocal * lv
    case _ =>
      rhs match {
        case Point(rv) => lhs * rv.reciprocal
        case _ => lhs * rhs.reciprocal
      }
  }
/** Division by a scalar: multiplies by the scalar's reciprocal. */
def /(rhs: A)(implicit ev: Field[A]): Interval[A] = {
  val inverse: A = rhs.reciprocal
  lhs * inverse
}
/** Shifts every endpoint by `rhs`; openness flags are unchanged. */
def +(rhs: A)(implicit ev: AdditiveSemigroup[A]): Interval[A] =
  this match {
    case All() | Empty() => this
    case Point(v) => Point(v + rhs)
    case Above(l, lf) => Above(l + rhs, lf)
    case Below(u, uf) => Below(u + rhs, uf)
    case Bounded(l, u, flags) => Bounded(l + rhs, u + rhs, flags)
  }
/** Subtraction of a scalar, expressed as addition of its negation. */
def -(rhs: A)(implicit ev: AdditiveGroup[A]): Interval[A] = {
  val negated: A = -rhs
  this + negated
}
/** Negation: endpoints are negated and swap sides, carrying their openness with them. */
def unary_-()(implicit ev: AdditiveGroup[A]): Interval[A] =
  this match {
    case All() | Empty() => this
    case Point(v) => Point(-v)
    case Above(l, lf) => Below(-l, lowerFlagToUpper(lf))
    case Below(u, uf) => Above(-u, upperFlagToLower(uf))
    case Bounded(l, u, f) => Bounded(-u, -l, swapFlags(f))
  }
/**
 * Multiplication by a scalar.
 *
 * A negative factor flips the interval (endpoints and flags swap sides), a
 * zero factor yields `Interval.zero`, and a positive factor scales the
 * endpoints in place.
 */
def *(rhs: A)(implicit ev: Semiring[A]): Interval[A] = {
  val zero = ev.zero
  if (rhs < zero) {
    this match {
      case All() | Empty() => this
      case Point(v) => Point(v * rhs)
      case Above(l, lf) => Below(l * rhs, lowerFlagToUpper(lf))
      case Below(u, uf) => Above(u * rhs, upperFlagToLower(uf))
      case Bounded(l, u, f) => Bounded(u * rhs, l * rhs, swapFlags(f))
    }
  } else if (rhs === zero) {
    Interval.zero
  } else {
    this match {
      case All() | Empty() => this
      case Point(v) => Point(v * rhs)
      case Above(l, lf) => Above(l * rhs, lf)
      case Below(u, uf) => Below(u * rhs, uf)
      case Bounded(l, u, flags) => Bounded(l * rhs, u * rhs, flags)
    }
  }
}
/**
 * Raises the interval to the non-negative integer power `k` by repeated
 * squaring. For even `k` the computation starts from `abs`. Throws
 * `IllegalArgumentException` for a negative exponent.
 */
def pow(k: Int)(implicit r: Ring[A]): Interval[A] = {
  // computes base^remaining * acc by squaring `base` and halving `remaining`
  def loop(base: Interval[A], remaining: Int, acc: Interval[A]): Interval[A] =
    if (remaining == 1) {
      base * acc
    } else {
      val nextAcc = if ((remaining & 1) == 1) base * acc else acc
      loop(base * base, remaining >>> 1, nextAcc)
    }
  k match {
    case _ if k < 0 =>
      throw new IllegalArgumentException(s"negative exponent: $k")
    case 0 =>
      Interval.point(r.one)
    case 1 =>
      this
    case _ =>
      val start: Interval[A] = if ((k & 1) == 0) abs else this
      loop(start, k - 1, start)
  }
}
/**
 * k-th root, applied endpoint-wise with openness flags unchanged. Throws
 * `IllegalArgumentException` when `k` is even and the interval has points
 * below zero.
 */
def nroot(k: Int)(implicit r: Ring[A], n: NRoot[A]): Interval[A] =
  if (k == 1) {
    this
  } else if ((k & 1) == 0 && hasBelow(r.zero)) {
    throw new IllegalArgumentException("can't take even root of negative number")
  } else {
    this match {
      case Empty() | All() => this
      case Point(v) => Point(v.nroot(k))
      case Bounded(l, u, flags) => Bounded(l.nroot(k), u.nroot(k), flags)
      case Above(l, lf) => Above(l.nroot(k), lf)
      case Below(u, uf) => Below(u.nroot(k), uf)
    }
  }
/** Square root: `nroot` with k = 2. */
def sqrt(implicit r: Ring[A], n: NRoot[A]): Interval[A] = nroot(k = 2)
/**
 * Greatest usable value of the interval, if it has an upper endpoint. An
 * open upper endpoint is pulled inward by `epsilon`. Empty, All and Above
 * yield None.
 */
def top(epsilon: A)(implicit r: AdditiveGroup[A]): Option[A] = {
  def inward(upper: A, flags: Int): A =
    if (isOpenUpper(flags)) upper - epsilon else upper
  this match {
    case Point(v) =>
      Some(v)
    case Below(upper, uf) =>
      Some(inward(upper, uf))
    case Bounded(_, upper, flags) =>
      Some(inward(upper, flags))
    case Empty() | All() | Above(_, _) =>
      None // TOCHECK: changed semantics, Empty().top == None
  }
}
/**
 * Least usable value of the interval, if it has a lower endpoint. An open
 * lower endpoint is pushed inward by `epsilon`. Empty, All and Below
 * yield None.
 */
def bottom(epsilon: A)(implicit r: AdditiveGroup[A]): Option[A] = {
  def inward(lower: A, flags: Int): A =
    if (isOpenLower(flags)) lower + epsilon else lower
  this match {
    case Point(v) =>
      Some(v)
    case Above(lower, lf) =>
      Some(inward(lower, lf))
    case Bounded(lower, _, flags) =>
      Some(inward(lower, flags))
    case Empty() | All() | Below(_, _) =>
      None // TOCHECK: changed semantics, Empty().bottom == None
  }
}
import spire.random.{Dist, Uniform}
/**
 * Builds a `Dist[A]` by applying `u` to the pair (`bottom(epsilon)`,
 * `top(epsilon)`), substituting `min` / `max` when the interval has no
 * endpoint on that side.
 */
def dist(min: A, max: A, epsilon: A)(implicit u: Uniform[A], r: AdditiveGroup[A]): Dist[A] = {
  val low: A = bottom(epsilon).getOrElse(min)
  val high: A = top(epsilon).getOrElse(max)
  u(low, high)
}
/**
 * Apply the given polynomial to the interval.
 *
 * Each coefficient of `p` is lifted to a point interval, and the
 * resulting interval-valued polynomial is evaluated at this interval.
 * The intent is to cover the set
 *
 * result = { p(x) | x ∈ interval }
 *
 */
def translate(p: Polynomial[A])(implicit ev: Field[A]): Interval[A] = {
  val lifted = p.terms.map { case Term(coeff, exp) => Term(Interval.point(coeff), exp) }
  val intervalPoly = Polynomial(lifted)
  intervalPoly(this)
}
// optional unicode operators: symbolic aliases for the named set operations
def ∋(rhs: A): Boolean = lhs.contains(rhs)
def ∌(rhs: A): Boolean = !lhs.contains(rhs)
def ∈:(a: A): Boolean = lhs.contains(a)
def ∉:(a: A): Boolean = !lhs.contains(a)
def ∩(rhs: Interval[A]): Interval[A] = lhs.intersect(rhs)
def ∪(rhs: Interval[A]): Interval[A] = lhs.union(rhs)
def \(rhs: Interval[A]): List[Interval[A]] = lhs.--(rhs)
def ⊂(rhs: Interval[A]): Boolean = lhs.isProperSubsetOf(rhs)
def ⊃(rhs: Interval[A]): Boolean = lhs.isProperSupersetOf(rhs)
def ⊆(rhs: Interval[A]): Boolean = lhs.isSubsetOf(rhs)
def ⊇(rhs: Interval[A]): Boolean = lhs.isSupersetOf(rhs)
// xyz
// Find the "first" value for an iterator that starts at `bound`. A closed
// bound is the start itself; an open bound is stepped once, so the caller
// must pass a step that points into the interval (we always add the step).
// An empty bound yields zero, and an unbound side is rejected with
// `unboundError`.
private[this] def getStart(bound: Bound[A], step: A, unboundError: String)(implicit ev: AdditiveMonoid[A]): A =
  bound match {
    case Closed(x) => x
    case Open(x) => x + step
    case EmptyBound() => ev.zero
    case Unbound() => throw new IllegalArgumentException(unboundError)
  }
/**
 * Build an Iterator[A] from an Interval[A] and a (step: A)
 * parameter.
 *
 * A positive 'step' means we are proceeding from the lower bound
 * up, and a negative 'step' means we are proceeding from the upper
 * bound down. In each case, the interval must be bounded on the
 * side we are starting with (though it may be unbound on the
 * opposite side). A zero 'step' is not allowed.
 *
 * The step is repeatedly added to the starting parameter as long as
 * the sum remains within the interval. This means that arithmetic
 * error can accumulate (e.g. with doubles). However, this method
 * does overflow checking to ensure that Intervals parameterized on
 * integer types will behave correctly.
 *
 * Users who want to avoid accumulating arithmetic error should consider
 * starting with an Interval[Rational], calling iterator with the
 * exact step desired, then mapping to the original type
 * (e.g. Double). For example:
 *
 * val ns = Interval.closed(Rational(0), Rational(5))
 * val it = ns.iterator(Rational(1, 7)).map(_.toDouble)
 *
 * This method provides some of the same functionality as Scala's
 * NumericRange class.
 *
 * Throws IllegalArgumentException when 'step' is zero, when 'step' is
 * positive and there is no lower bound, or when 'step' is negative and
 * there is no upper bound.
 */
def iterator(step: A)(implicit ev: AdditiveMonoid[A], nt: NumberTag[A]): Iterator[A] = {
  // build an iterator, using start, step, and continue.
  // this can be used in cases where we don't have to worry about
  // overflow (e.g. Double, or Rational).
  def iter0(start: A, continue: A => Boolean): Iterator[A] =
    new Iterator[A] {
      var x: A = start
      def hasNext: Boolean = continue(x)
      def next(): A = {
        val r = x
        x += step
        r
      }
    }
  // build an iterator, using start, step, continue, and test.
  // test is used to detect overflow in cases where it can happen.
  // it won't always be necessary but there isn't currently a typeclass
  // that lets us know when we need to do it.
  def iter1(start: A, continue: A => Boolean, test: (A, A) => Boolean): Iterator[A] =
    new Iterator[A] {
      var x: A = start
      var ok: Boolean = true
      def hasNext: Boolean = ok && continue(x)
      def next(): A = {
        val r = x
        val following = x + step
        // stop once adding the step no longer moves in the expected direction
        if (test(x, following)) { x = following } else { ok = false }
        r
      }
    }
  // `safe` means the caller has already established that stepping past the
  // limiting bound cannot wrap around, so the overflow test can be skipped.
  def iter(start: A, safe: Boolean, continue: A => Boolean, test: (A, A) => Boolean): Iterator[A] =
    if (nt.overflows && !safe) {
      iter1(start, continue, test)
    } else {
      iter0(start, continue)
    }
  // build the actual iterator, which primarily relies on figuring
  // out which "direction" we are moving (based on the sign of the
  // step) as well as what kind of limiting bounds we have.
  if (step === ev.zero) {
    throw new IllegalArgumentException("zero step")
  } else if (step > ev.zero) {
    val x = getStart(lowerBound, step, "positive step with no lower bound")
    val test = (x1: A, x2: A) => x1 < x2
    upperBound match {
      case EmptyBound() => Iterator.empty
      case Unbound() => iter(x, false, _ => true, test)
      case Closed(y) => iter(x, y + step > y, _ <= y, test)
      case Open(y) => iter(x, y + step > y, _ < y, test)
    }
  } else {
    // a negative step starts from the upper bound, so that is the side
    // that must exist
    val x = getStart(upperBound, step, "negative step with no upper bound")
    val test = (x1: A, x2: A) => x1 > x2
    lowerBound match {
      case EmptyBound() => Iterator.empty
      case Unbound() => iter(x, false, _ => true, test)
      case Closed(y) => iter(x, y + step < y, _ >= y, test)
      case Open(y) => iter(x, y + step < y, _ > y, test)
    }
  }
}
/** Runs `f` on every value produced by `iterator(step)`. */
def loop(step: A)(f: A => Unit)(implicit ev: AdditiveMonoid[A], nt: NumberTag[A]): Unit = {
  val values = iterator(step)
  values.foreach(f)
}
/** Left fold over the values produced by `iterator(step)`, starting from `init`. */
def foldOver[B](init: B, step: A)(f: (B, A) => B)(implicit ev: AdditiveMonoid[A], nt: NumberTag[A]): B = {
  val values = iterator(step)
  values.foldLeft(init)(f)
}
}
// The interval with no boundary value on either side.
case class All[A: Order] private[spire] () extends Interval[A] {
  def lowerBound: Unbound[A] = Unbound[A]()
  def upperBound: Unbound[A] = Unbound[A]()
}
// Interval with a lower endpoint and no upper boundary; `flags` records
// whether the lower endpoint is open or closed.
case class Above[A: Order] private[spire] (lower: A, flags: Int) extends Interval[A] {
  def lowerBound: ValueBound[A] = if (isClosedLower(flags)) Closed(lower) else Open(lower)
  def upperBound: Unbound[A] = Unbound[A]()
}
// Interval with an upper endpoint and no lower boundary; `flags` records
// whether the upper endpoint is open or closed.
case class Below[A: Order] private[spire] (upper: A, flags: Int) extends Interval[A] {
  def lowerBound: Unbound[A] = Unbound[A]()
  def upperBound: ValueBound[A] = if (isClosedUpper(flags)) Closed(upper) else Open(upper)
}
// Bounded, non-empty interval with lower < upper; `flags` records the
// openness of each endpoint.
case class Bounded[A: Order] private[spire] (lower: A, upper: A, flags: Int) extends Interval[A] {
  require(lower < upper) // TODO: remove after refactoring
  def lowerBound: ValueBound[A] = if (isClosedLower(flags)) Closed(lower) else Open(lower)
  def upperBound: ValueBound[A] = if (isClosedUpper(flags)) Closed(upper) else Open(upper)
}
// Interval holding exactly one value; both bounds are closed at that value.
case class Point[A: Order] private[spire] (value: A) extends Interval[A] {
  def lowerBound: Closed[A] = Closed[A](value)
  def upperBound: Closed[A] = Closed[A](value)
}
// The interval containing no values; both bounds are EmptyBound.
case class Empty[A: Order] private[spire] () extends Interval[A] {
  def lowerBound: EmptyBound[A] = EmptyBound[A]()
  def upperBound: EmptyBound[A] = EmptyBound[A]()
}
object Interval {
import interval._
// Builds an interval from two endpoints and raw openness flags: Bounded when
// lower < upper, Point when the endpoints coincide and both are closed,
// Empty otherwise.
private[spire] def withFlags[A: Order](lower: A, upper: A, flags: Int): Interval[A] =
  if (lower < upper) {
    Bounded(lower, upper, flags)
  } else {
    val closedSingleton = lower === upper && flags == 0
    if (closedSingleton) Point(lower) else Interval.empty[A]
  }
/** The empty interval. */
def empty[A](implicit o: Order[A]): Interval[A] = Empty[A]()
/** The interval containing only `a`. */
def point[A: Order](a: A): Interval[A] = Point[A](a)
/** The point interval at the semiring's zero. */
def zero[A](implicit o: Order[A], r: Semiring[A]): Interval[A] = Point[A](r.zero)
/** The interval unbounded on both sides. */
def all[A: Order]: Interval[A] = All[A]()
/** Two-argument constructor: same as `closed(lower, upper)`. */
def apply[A: Order](lower: A, upper: A): Interval[A] = Interval.closed(lower, upper)
/**
 * Return an Interval[Rational] that corresponds to the error bounds
 * for the given Double value.
 *
 * The error bounds are represented as a closed interval, whose
 * lower point is midway between d and the adjacent Double value
 * below it. Similarly, the upper bound is the point midway between
 * d and the adjacent Double value above it.
 *
 * There are three Double values that return "special" intervals:
 *
 * Infinity => Interval.above(Double.MaxValue)
 * -Infinity => Interval.below(Double.MinValue)
 * NaN => Interval.empty
 *
 * NOTE(review): for d == Double.MaxValue (or Double.MinValue) one of
 * the adjacent values is infinite; confirm how Rational(...) treats an
 * infinite Double before relying on this method at those extremes.
 */
def errorBounds(d: Double): Interval[Rational] =
  if (d == Double.PositiveInfinity) {
    Interval.above(Double.MaxValue)
  } else if (d == Double.NegativeInfinity) {
    Interval.below(Double.MinValue)
  } else if (isNaN(d)) {
    Interval.empty[Rational]
  } else {
    // Math.nextAfter steps toward its second argument, so the direction
    // must be -Infinity to reach the adjacent value *below* d for every d.
    // (A direction of -1.0 would step upward for d < -1.0 and not move at
    // all for d == -1.0.)
    val n0 = Rational(Math.nextAfter(d, Double.NegativeInfinity))
    val n1 = Rational(d)
    val n2 = Rational(Math.nextUp(d))
    Interval((n1 - n0) / 2 + n0, (n2 - n1) / 2 + n1)
  }
// Named values for the openness flag bits: bit 0 (value 1) marks an open
// lower bound, bit 1 (value 2) marks an open upper bound; 0 means closed.
@inline private[spire] final def closedLowerFlags: Int = 0
@inline private[spire] final def openLowerFlags: Int = 1
@inline private[spire] final def closedUpperFlags: Int = 0
@inline private[spire] final def openUpperFlags: Int = 2
/**
 * Constructs an interval from bounds without comparing them.
 *
 * The caller must guarantee lower < upper whenever both bounds carry
 * a value; no comparison is made here.
 *
 * - Two EmptyBounds give Empty.
 *
 * - If at least one argument is Unbound, the result is All,
 *   Above(x, _), or Below(y, _).
 *
 * - If both arguments are Open/Closed (e.g. Open(x), Open(y)),
 *   the result is a Bounded interval with matching flags.
 *
 * - An EmptyBound paired with anything else is rejected with
 *   IllegalArgumentException.
 *
 * This method cannot construct Point intervals.
 */
private[spire] def fromOrderedBounds[A: Order](lower: Bound[A], upper: Bound[A]): Interval[A] =
  (lower, upper) match {
    // empty
    case (EmptyBound(), EmptyBound()) => empty
    // unbounded on at least one side
    case (Unbound(), Unbound()) => all
    case (Unbound(), Open(y)) => below(y)
    case (Unbound(), Closed(y)) => atOrBelow(y)
    case (Open(x), Unbound()) => above(x)
    case (Closed(x), Unbound()) => atOrAbove(x)
    // a value on both sides
    case (Closed(x), Closed(y)) => Bounded(x, y, closedLowerFlags | closedUpperFlags)
    case (Closed(x), Open(y)) => Bounded(x, y, closedLowerFlags | openUpperFlags)
    case (Open(x), Closed(y)) => Bounded(x, y, openLowerFlags | closedUpperFlags)
    case (Open(x), Open(y)) => Bounded(x, y, openLowerFlags | openUpperFlags)
    // a lone EmptyBound is never valid
    case (EmptyBound(), _) | (_, EmptyBound()) =>
      throw new IllegalArgumentException("invalid empty bound")
  }
/**
 * Constructs an interval from two bounds, delegating to the checked
 * constructors (`closed`, `open`, `openLower`, `openUpper`) when both
 * bounds carry a value. An EmptyBound paired with a non-empty bound is
 * rejected with IllegalArgumentException.
 */
def fromBounds[A: Order](lower: Bound[A], upper: Bound[A]): Interval[A] =
  (lower, upper) match {
    // empty
    case (EmptyBound(), EmptyBound()) => empty
    // unbounded on at least one side
    case (Unbound(), Unbound()) => all
    case (Unbound(), Open(y)) => below(y)
    case (Unbound(), Closed(y)) => atOrBelow(y)
    case (Open(x), Unbound()) => above(x)
    case (Closed(x), Unbound()) => atOrAbove(x)
    // a value on both sides
    case (Closed(x), Closed(y)) => closed(x, y)
    case (Closed(x), Open(y)) => openUpper(x, y)
    case (Open(x), Closed(y)) => openLower(x, y)
    case (Open(x), Open(y)) => open(x, y)
    // a lone EmptyBound is never valid
    case (EmptyBound(), _) | (_, EmptyBound()) =>
      throw new IllegalArgumentException("invalid empty bound")
  }
/** Closed interval [lower, upper]: a Point when the endpoints are equal, Empty when lower > upper. */
def closed[A: Order](lower: A, upper: A): Interval[A] =
  (lower compare upper) match {
    case c if c < 0 => Bounded(lower, upper, 0)
    case 0 => Point(lower)
    case _ => Interval.empty[A]
  }
/** Open interval (lower, upper); Empty unless lower < upper. */
def open[A: Order](lower: A, upper: A): Interval[A] =
  if (lower < upper) Bounded(lower, upper, openLowerFlags | openUpperFlags)
  else Interval.empty[A]
/** Half-open interval (lower, upper]; Empty unless lower < upper. */
def openLower[A: Order](lower: A, upper: A): Interval[A] =
  if (lower < upper) Bounded(lower, upper, openLowerFlags | closedUpperFlags)
  else Interval.empty[A]
/** Half-open interval [lower, upper); Empty unless lower < upper. */
def openUpper[A: Order](lower: A, upper: A): Interval[A] =
  if (lower < upper) Bounded(lower, upper, closedLowerFlags | openUpperFlags)
  else Interval.empty[A]
/** (a, ∞): everything strictly above `a`. */
def above[A: Order](a: A): Interval[A] = Above(a, openLowerFlags)
/** (-∞, a): everything strictly below `a`. */
def below[A: Order](a: A): Interval[A] = Below(a, openUpperFlags)
/** [a, ∞): `a` and everything above it. */
def atOrAbove[A: Order](a: A): Interval[A] = Above(a, closedLowerFlags)
/** (-∞, a]: `a` and everything below it. */
def atOrBelow[A: Order](a: A): Interval[A] = Below(a, closedUpperFlags)
// Patterns used when parsing an interval from a String. Each tolerates
// spaces around its delimiters.
// NullRe: the empty-interval literal "(Ø)".
private val NullRe = "^ *\\( *Ø *\\) *$".r
// SingleRe: one comma-free value in square brackets, e.g. "[3]"; captures the value.
private val SingleRe = "^ *\\[ *([^,]+) *\\] *$".r
// PairRe: opening bracket, two comma-separated values, closing bracket;
// captures (opening bracket, first value, second value, closing bracket).
private val PairRe = "^ *(\\[|\\() *(.+?) *, *(.+?) *(\\]|\\)) *$".r
/**
 * Parses an interval from bracket notation, e.g. "[1, 2)", "(-∞, 3]",
 * "[5]" or "(Ø)". Endpoint text is converted with `Rational(...)`.
 * Throws NumberFormatException when `s` matches none of the
 * recognised shapes.
 */
def apply(s: String): Interval[Rational] = {
  def num(text: String): Rational = Rational(text)
  s match {
    case NullRe() =>
      Interval.empty[Rational]
    case SingleRe(value) =>
      Interval.point(num(value))
    case PairRe(leftBracket, first, second, rightBracket) =>
      (leftBracket, first, second, rightBracket) match {
        // infinite endpoints are recognised before the general cases
        case ("(", "-∞", "∞", ")") => Interval.all[Rational]
        case ("(", "-∞", hi, ")") => Interval.below(num(hi))
        case ("(", "-∞", hi, "]") => Interval.atOrBelow(num(hi))
        case ("(", lo, "∞", ")") => Interval.above(num(lo))
        case ("[", lo, "∞", ")") => Interval.atOrAbove(num(lo))
        case ("[", lo, hi, "]") => Interval.closed(num(lo), num(hi))
        case ("(", lo, hi, ")") => Interval.open(num(lo), num(hi))
        case ("[", lo, hi, ")") => Interval.openUpper(num(lo), num(hi))
        case ("(", lo, hi, "]") => Interval.openLower(num(lo), num(hi))
        case _ => throw new NumberFormatException("Impossible: " + s)
      }
    case _ =>
      throw new NumberFormatException("For input string: " + s)
  }
}
/** `Eq` instance for intervals; equality is delegated to `==` on the intervals themselves. */
implicit def eq[A: Eq]: Eq[Interval[A]] = {
  final class IntervalEq extends Eq[Interval[A]] {
    def eqv(x: Interval[A], y: Interval[A]): Boolean = x == y
  }
  new IntervalEq
}
/**
 * `Semiring` instance for intervals: zero is the point interval at
 * `ev.zero`, and plus / times / pow forward to the interval operations.
 */
implicit def semiring[A](implicit ev: Ring[A], o: Order[A]): Semiring[Interval[A]] =
  new Semiring[Interval[A]] {
    def zero: Interval[A] = Point[A](ev.zero)
    def plus(x: Interval[A], y: Interval[A]): Interval[A] = x.+(y)
    def times(x: Interval[A], y: Interval[A]): Interval[A] = x.*(y)
    override def pow(x: Interval[A], k: Int): Interval[A] = x.pow(k)
  }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy