All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.twitter.util.Memoize.scala Maven / Gradle / Ivy

The newest version!
package com.twitter.util

import java.util.concurrent.{CountDownLatch => JCountDownLatch}

import scala.annotation.tailrec

object Memoize {
  /**
   * Thread-safe memoization for a function.
   *
   * This works like a lazy val indexed by the input value. The memo
   * is held as part of the state of the returned function, so keeping
   * a reference to the function will keep a reference to the
   * (unbounded) memo table. The memo table will never forget a
   * result, and will retain a reference to the corresponding input
   * values as well.
   *
   * If the computation has side-effects, they will happen exactly
   * once per input, even if multiple threads attempt to memoize the
   * same input at one time, unless the computation throws an
   * exception. If an exception is thrown, then the result will not be
   * stored, and the computation will be attempted again upon the next
   * access. Only one value will be computed at a time. The overhead
   * required to ensure that the effects happen only once is paid only
   * in the case of a miss (once per input over the life of the memo
   * table). Computations for different input values will not block
   * each other.
   *
   * The combination of these factors means that this method is useful
   * for functions that will only ever be called on small numbers of
   * inputs, are expensive compared to a hash lookup and the memory
   * overhead, and will be called repeatedly.
   */
  def apply[A, B](f: A => B): A => B =
    new Function1[A, B] {
      private[this] var memo = Map.empty[A, Either[JCountDownLatch, B]]

      /**
       * What to do if we do not find the value already in the memo
       * table.
       */
      @tailrec private[this] def missing(a: A): B =
        synchronized {
          // With the lock, check to see what state the value is in.
          memo.get(a) match {
            case None =>
              // If it's missing, then claim the slot by putting in a
              // CountDownLatch that will be completed when the value is
              // available.
              val latch = new JCountDownLatch(1)
              memo = memo + (a -> Left(latch))

              // The latch wrapped in Left indicates that the value
              // needs to be computed in this thread, and then the
              // latch counted down.
              Left(latch)

            case Some(other) =>
              // This is either the latch that will indicate that the
              // work has been done, or the computed value.
              Right(other)
          }
        } match {
          case Right(Right(b)) =>
            // The computation is already done.
            b

          case Right(Left(latch)) =>
            // Someone else is doing the computation.
            latch.await()

            // This recursive call will happen when there is an
            // exception computing the value, or if the value is
            // currently being computed.
            missing(a)

          case Left(latch) =>
            // Compute the value outside of the synchronized block.
            val b =
              try {
                f(a)
              } catch {
                case t: Throwable =>
                  // If there was an exception running the
                  // computation, then we need to make sure we do not
                  // starve any waiters before propagating the
                  // exception.
                  synchronized { memo = memo - a }
                  latch.countDown()
                  throw t
              }

            // Update the memo table to indicate that the work has
            // been done, and signal to any waiting threads that the
            // work is complete.
            synchronized { memo = memo + (a -> Right(b)) }
            latch.countDown()
            b
        }

      override def apply(a: A): B =
        // Look in the (possibly stale) memo table. If the value is
        // present, then it is guaranteed to be the final value. If it
        // is absent, call missing() to determine what to do.
        memo.get(a) match {
          case Some(Right(b)) => b
          case _              => missing(a)
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy