All downloads are free. Search and download functionality uses the official Maven repository.

com.stripe.rainier.compute.PartialEvaluator.scala Maven / Gradle / Ivy

The newest version!
package com.stripe.rainier.compute

/** Substitutes one row of [[Column]] data into a [[Real]] expression graph.
  *
  * While walking the graph, every `Column` is replaced by its entry at
  * `rowIndex`. Every composite node with at least one replaced input is
  * rebuilt from the rewritten inputs. Nodes with no `Column` beneath them are
  * returned as-is.
  *
  * Two memo tables avoid repeated work:
  *  - `cache` holds nodes rewritten for this row. Each evaluator starts with
  *    an empty one.
  *  - `noChange` holds nodes found to be unaffected by the substitution. It is
  *    handed to the evaluator created by [[next]], so it carries over between
  *    rows.
  *
  * @param noChange nodes already known to come back unchanged
  * @param rowIndex index passed to `Column.values` for every column reached
  */
class PartialEvaluator(var noChange: Set[Real], rowIndex: Int) {
  var cache = Map.empty[Real, Real]

  /** Outcome of evaluating one node: the resulting node and whether it was rewritten. */
  private type Evaluated = (Real, Boolean)

  /** Evaluator for row `rowIndex + 1`. It inherits `noChange` but starts with a fresh `cache`. */
  def next(): PartialEvaluator = new PartialEvaluator(noChange, rowIndex + 1)

  /** Evaluates `real` for this row, consulting and updating the memo tables.
    *
    * @return the rewritten node paired with `true`, or `real` itself paired
    *         with `false` when nothing beneath it was substituted
    */
  def apply(real: Real): (Real, Boolean) =
    if (noChange(real)) {
      (real, false)
    } else {
      cache.get(real) match {
        case Some(memoized) => (memoized, true)
        case None =>
          val (result, rewritten) = eval(real)
          // Rewritten nodes are specific to this row, so they go in the per-row cache.
          // Unchanged nodes go in the set that survives into later rows.
          if (rewritten) cache = cache.updated(real, result)
          else noChange = noChange + real
          (result, rewritten)
      }
    }

  /** Rewrites a single node, recursing through [[apply]] for its inputs. */
  private def eval(real: Real): (Real, Boolean) = real match {
    case Scalar(_) =>
      (real, false)
    case column: Column =>
      (column.values(rowIndex), true)
    case line: Line =>
      val pairs = line.ax.toList.map { case (x, a) => (apply(x), apply(a)) }
      val (offset, offsetRewritten) = apply(line.b)
      if (anyRewritten(pairs) || offsetRewritten) {
        val products = pairs.map { case ((x, _), (a, _)) => x * a }
        (Real.sum(products) + offset, true)
      } else {
        (real, false)
      }
    case logLine: LogLine =>
      val pairs = logLine.ax.toList.map { case (x, a) => (apply(x), apply(a)) }
      if (!anyRewritten(pairs)) {
        (real, false)
      } else {
        // anyRewritten is false for an empty list, so `factors` is non-empty
        // here and `reduce` is safe.
        val factors = pairs.map { case ((x, _), (a, _)) => x.pow(a) }
        (factors.reduce(_ * _), true)
      }
    case Unary(original, op) =>
      apply(original) match {
        case (operand, true) => (RealOps.unary(operand, op), true)
        case _               => (real, false)
      }
    case Compare(left, right) =>
      val (lhs, lhsRewritten) = apply(left)
      val (rhs, rhsRewritten) = apply(right)
      if (lhsRewritten || rhsRewritten) (RealOps.compare(lhs, rhs), true)
      else (real, false)
    case Pow(base, exponent) =>
      val (baseValue, baseRewritten) = apply(base)
      val (exponentValue, exponentRewritten) = apply(exponent)
      if (baseRewritten || exponentRewritten) (baseValue.pow(exponentValue), true)
      else (real, false)
    case lookup: Lookup =>
      val (index, indexRewritten) = apply(lookup.index)
      val table = lookup.table.map(apply)
      if (indexRewritten || table.exists(_._2))
        (Lookup(index, table.map(_._1), lookup.low), true)
      else
        (lookup, false)
    case p: Parameter =>
      (p, false)
  }

  /** True when either side of any evaluated (x, a) coefficient pair was rewritten. */
  private def anyRewritten(pairs: List[(Evaluated, Evaluated)]): Boolean =
    pairs.exists { case ((_, xRewritten), (_, aRewritten)) => xRewritten || aRewritten }
}

/** Factory and row-unrolling helpers for [[PartialEvaluator]]. */
object PartialEvaluator {

  /** Evaluator for row `index` with nothing yet recorded as unchanged. */
  def apply(index: Int): PartialEvaluator =
    new PartialEvaluator(Set.empty[Real], index)

  /** Sums `real` evaluated at every row in `0 until nRows`.
    *
    * Row evaluators are chained with `next()`. As a result, subexpressions
    * found unchanged in an earlier row are carried forward in `noChange`.
    * Returns `Real.zero` when `nRows <= 0`.
    */
  def inline(real: Real, nRows: Int): Real = {
    @scala.annotation.tailrec
    def accumulate(row: Int, evaluator: PartialEvaluator, total: Real): Real =
      if (row >= nRows) {
        total
      } else {
        // Evaluate this row before calling next(), so that the following
        // evaluator inherits whatever this row added to noChange.
        val (rowValue, _) = evaluator(real)
        accumulate(row + 1, evaluator.next(), total + rowValue)
      }

    accumulate(0, PartialEvaluator(0), Real.zero)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy