// Source: com/stripe/rainier/compute/Coefficients.scala (Maven / Gradle / Ivy artifact listing, newest version).
package com.stripe.rainier.compute
import scala.annotation.tailrec
/** A sparse collection of non-constant terms, each paired with a constant
  * coefficient.
  *
  * The companion object's constructors drop zero coefficients, so the intended
  * invariant is that a term with a zero coefficient is simply absent.
  */
sealed trait Coefficients extends Product with Serializable {

  /** True when there are no terms at all. */
  def isEmpty: Boolean

  /** Number of terms held. */
  def size: Int

  /** The coefficients, without their terms. */
  def coefficients: Iterable[Constant]

  /** The terms, without their coefficients. */
  def terms: Iterable[NonConstant]

  /** All `(term, coefficient)` pairs. */
  def toList: List[(NonConstant, Constant)]

  /** All terms keyed to their coefficients. */
  def toMap: Map[NonConstant, Constant]

  /** For each term, the triple `(term, coefficient, rest)` where `rest` holds
    * every other term together with its coefficient.
    */
  def withComplements: Iterable[(NonConstant, Constant, Coefficients)]

  /** Applies `fn` to every coefficient, keeping the terms unchanged. */
  def mapCoefficients(fn: Constant => Constant): Coefficients

  /** Combines `this` with `other`, summing the coefficients of shared terms. */
  def merge(other: Coefficients): Coefficients

  /** Adds `pair._2` to the coefficient of `pair._1`, inserting the term if it
    * is not already present.
    */
  def +(pair: (NonConstant, Constant)): Coefficients
}
/** Constructors and the three concrete representations of [[Coefficients]].
  *
  * Every constructor here keeps the invariant that no stored coefficient is
  * zero: a zero coefficient means the term is absent. `Empty` holds no terms,
  * `One` holds exactly one, and `Many` is used for two or more.
  */
object Coefficients {

  /** Coefficients holding `term` with coefficient one. */
  def apply(term: NonConstant): Coefficients = apply(term -> Constant.One)

  /** Coefficients holding a single `(term, coefficient)` pair, or `Empty` when
    * the coefficient is zero.
    */
  def apply(pair: (NonConstant, Constant)): Coefficients =
    if (pair._2.isZero)
      Empty
    else
      One(pair._1, pair._2)

  /** Builds coefficients from a sequence of `(term, coefficient)` pairs.
    *
    * Zero coefficients are dropped. Repeated terms have their coefficients
    * summed, and are dropped if that sum is zero. Surviving terms keep the
    * order of their first occurrence in `seq`.
    */
  def apply(seq: Seq[(NonConstant, Constant)]): Coefficients = {
    val filtered = seq.filterNot(_._2.isZero)
    if (filtered.isEmpty)
      Empty
    else if (filtered.size == 1)
      apply(filtered.head)
    else
      fromNonZeroPairs(filtered)
  }

  /** Sums repeated terms in `pairs` (all non-zero) and picks the smallest representation. */
  private def fromNonZeroPairs(pairs: Seq[(NonConstant, Constant)]): Coefficients = {
    // Per-term sums, plus the distinct terms in reverse first-seen order.
    val (sums, reversedTerms) =
      pairs.foldLeft((Map.empty[NonConstant, Constant], List.empty[NonConstant])) {
        case ((acc, seen), (term, coefficient)) =>
          acc.get(term) match {
            case Some(previous) => (acc.updated(term, previous + coefficient), seen)
            case None           => (acc.updated(term, coefficient), term :: seen)
          }
      }
    // Repeated terms may cancel; those must not be stored with a zero coefficient.
    val nonZero = sums.filterNot { case (_, coefficient) => coefficient.isZero }
    reversedTerms.reverse.filter(t => nonZero.contains(t)) match {
      case Nil           => Empty
      case single :: Nil => One(single, nonZero(single))
      case several       => Many(nonZero, several)
    }
  }

  /** The representation with no terms. */
  private final case object EmptyCoefficients extends Coefficients {
    val isEmpty = true
    val size = 0
    val coefficients: List[Constant] = Nil
    val terms: List[NonConstant] = Nil
    val toList: List[(NonConstant, Constant)] = Nil
    val toMap: Map[NonConstant, Constant] = Map.empty[NonConstant, Constant]
    val withComplements: List[(NonConstant, Constant, Coefficients)] = Nil
    def mapCoefficients(fn: Constant => Constant): Coefficients = this
    // Adding to nothing yields just the pair (or still nothing if it is zero).
    def +(pair: (NonConstant, Constant)): Coefficients = Coefficients(pair)
    def merge(other: Coefficients): Coefficients = other
  }

  val Empty: Coefficients = EmptyCoefficients

  /** Exactly one term with a non-zero coefficient. */
  final case class One(term: NonConstant, coefficient: Constant)
      extends Coefficients {
    val size = 1
    val isEmpty = false
    def coefficients: List[Constant] = List(coefficient)
    def terms: List[NonConstant] = List(term)
    def toList: List[(NonConstant, Constant)] = List((term, coefficient))
    def toMap: Map[NonConstant, Constant] = Map(term -> coefficient)
    def withComplements: List[(NonConstant, Constant, Coefficients)] =
      List((term, coefficient, Empty))

    /** Applies `fn` to the coefficient, yielding `Empty` if the result is zero. */
    def mapCoefficients(fn: Constant => Constant): Coefficients =
      Coefficients(term -> fn(coefficient))

    def merge(other: Coefficients): Coefficients = other + (term -> coefficient)

    def +(pair: (NonConstant, Constant)): Coefficients =
      if (pair._1 == term) {
        val newCoefficient = coefficient + pair._2
        if (newCoefficient.isZero)
          Empty
        else
          One(term, newCoefficient)
      } else {
        // A different term: the general constructor drops a zero coefficient.
        Coefficients(pair :: toList)
      }
  }

  /** Two or more terms, each with a non-zero coefficient.
    *
    * @param toMap the coefficient of every term
    * @param terms the keys of `toMap` in a fixed order; `+` prepends new terms
    */
  final case class Many(toMap: Map[NonConstant, Constant], terms: List[NonConstant])
      extends Coefficients {
    val isEmpty = false
    def size: Int = toMap.size
    def coefficients: Iterable[Constant] = toMap.values
    def toList: List[(NonConstant, Constant)] = terms.map { x =>
      x -> toMap(x)
    }

    /** Applies `fn` to every coefficient, dropping terms whose result is zero. */
    def mapCoefficients(fn: Constant => Constant): Coefficients =
      Coefficients(toList.map { case (x, a) => x -> fn(a) })

    def withComplements: List[(NonConstant, Constant, Coefficients)] = {
      val total = terms.size

      // `seen` holds the already-visited terms (most recent first) and `rest`
      // the terms still to visit. Together with the current head they always
      // partition `terms`, so the length of `rest.tail` is total - seenSize - 1.
      @tailrec
      def loop(
          acc: List[(NonConstant, Constant, Coefficients)],
          seen: List[NonConstant],
          seenSize: Int,
          rest: List[NonConstant]): List[(NonConstant, Constant, Coefficients)] =
        rest match {
          case head :: tail =>
            val tailSize = total - seenSize - 1
            // `:::` copies its left operand, so put the shorter list on the left.
            val complementTerms =
              if (seenSize > tailSize)
                tail ::: seen
              else
                seen ::: tail
            val complement = complementTerms match {
              case single :: Nil => One(single, toMap(single))
              case _             => Many(toMap - head, complementTerms)
            }
            loop((head, toMap(head), complement) :: acc, head :: seen, seenSize + 1, tail)
          case Nil =>
            acc
        }

      loop(Nil, Nil, 0, terms)
    }

    /** Folds the smaller side's pairs into the larger side. */
    def merge(other: Coefficients): Coefficients =
      if (other.size > size)
        toList.foldLeft(other) { case (acc, pair) => acc + pair }
      else
        other.toList.foldLeft(this: Coefficients) { case (acc, pair) => acc + pair }

    def +(pair: (NonConstant, Constant)): Coefficients = {
      val (term, coefficient) = pair
      // Adding zero changes nothing and must never insert a zero entry.
      if (coefficient.isZero)
        this
      else
        toMap.get(term) match {
          case None =>
            Many(toMap + pair, term :: terms)
          case Some(existing) =>
            val newCoefficient = coefficient + existing
            if (!newCoefficient.isZero)
              Many(toMap + (term -> newCoefficient), terms)
            else {
              // The term cancelled out; shrink the representation if needed.
              val newMap = toMap - term
              terms.filterNot(_ == term) match {
                case Nil         => Empty
                case last :: Nil => One(last, newMap(last))
                case remaining   => Many(newMap, remaining)
              }
            }
        }
    }
  }
}