All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.twitter.algebird.Batched.scala Maven / Gradle / Ivy

There is a newer version: 0.13.4
Show newest version
package com.twitter.algebird

import scala.annotation.tailrec

/**
 * Batched: the free semigroup.
 *
 * For any type `T`, `Batched[T]` represents a way to lazily combine T
 * values as a semigroup would (i.e. associatively). A `Semigroup[T]`
 * instance can be used to recover a `T` value from a `Batched[T]`.
 *
 * Like other free structures, Batched trades space for time. A sum of
 * batched values defers the underlying semigroup action, instead
 * storing all values in memory (in a tree structure). If an
 * underlying semigroup is available, `Batched.compactingSemigroup` and
 * `Batched.compactingMonoid` can be configured to periodically sum the
 * tree so that it does not grow without bound (see `compact`).
 *
 * `Batched[T]` values are guaranteed not to be empty -- that is, they
 * will contain at least one `T` value.
 */
sealed abstract class Batched[T] extends Serializable {

  /**
   * Sum all the `T` values in this batch using the given semigroup.
   *
   * @param sg semigroup used to combine the values
   * @return the sum of every value in the batch
   */
  def sum(implicit sg: Semigroup[T]): T

  /**
   * Combine two batched values.
   *
   * No summing happens here: this just creates a new tree node
   * whose left child is `this` and whose right child is `that`.
   *
   * @param that the batch to place to the right of this one
   * @return a batch containing the values of `this` followed by `that`
   */
  def combine(that: Batched[T]): Batched[T] =
    Batched.Items(this, that)

  /**
   * Compact this batch once its size has reached `batchSize`.
   *
   * Compacting means summing every value in the tree and storing the
   * summed value in a new single-item batch. A batch holding fewer
   * than `batchSize` values is returned unchanged.
   *
   * @param batchSize the size at or above which the tree is summed
   * @param s semigroup used to sum the values
   * @return `this` if `size < batchSize`, otherwise a single-item batch
   */
  def compact(batchSize: Int)(implicit s: Semigroup[T]): Batched[T] =
    if (size < batchSize) this else Batched.Item(sum(s))

  /**
   * Add more values to a batched value.
   *
   * Each new value is combined on the right, so the existing tree
   * becomes the left child at every step: the tree grows to the left.
   *
   * @param that values to append, in order
   * @return a batch containing this batch's values followed by `that`
   */
  def append(that: TraversableOnce[T]): Batched[T] =
    that.foldLeft(this)((b, t) => b.combine(Batched(t)))

  /**
   * Provide an iterator over the underlying tree structure.
   *
   * This is the order used by `.sum`.
   *
   * This iterator traverses the tree from left-to-right. If the
   * original expression was (w + x + y + z), this iterator returns w,
   * x, y, and then z.
   */
  def iterator: Iterator[T] =
    this match {
      case Batched.Item(t) => Iterator.single(t)
      case b => new Batched.ForwardItemsIterator(b)
    }

  /**
   * Convert the batch to a `List[T]`, in left-to-right order.
   *
   * Walking the tree right-to-left and prepending yields the list in
   * forward order without a final `reverse`.
   */
  def toList: List[T] =
    reverseIterator.foldLeft(List.empty[T])((ts, t) => t :: ts)

  /**
   * Provide a reversed iterator over the underlying tree structure.
   *
   * This iterator traverses the tree from right-to-left. If the
   * original expression was (w + x + y + z), this iterator returns z,
   * y, x, and then w.
   */
  def reverseIterator: Iterator[T] =
    this match {
      case Batched.Item(t) => Iterator.single(t)
      case b => new Batched.ReverseItemsIterator(b)
    }

  /**
   * Report the number of `T` values stored in the underlying tree.
   *
   * This is an O(1) operation -- each subtree knows how big it is.
   */
  def size: Int
}

object Batched {

  /**
   * Construct a batch from a single value.
   */
  def apply[T](t: T): Batched[T] =
    Item(t)

  /**
   * Construct an optional batch from a collection of values.
   *
   * Since batches cannot be empty, this method returns `None` if `ts`
   * is empty, and `Some(batch)` otherwise. Values appear in the batch
   * in the order `ts` yields them.
   *
   * `ts` is traversed exactly once. This makes the method safe for
   * one-shot iterators and for non-strict collections whose traversal
   * recomputes elements.
   */
  def items[T](ts: TraversableOnce[T]): Option[Batched[T]] = {
    val it = ts.toIterator
    if (it.hasNext) {
      val t0 = it.next()
      Some(Item(t0).append(it))
    } else {
      None
    }
  }

  /**
   * Equivalence for batches.
   *
   * Batches are equivalent if they sum to the same value. Since the
   * free semigroup is associative, it's not correct to take tree
   * structure into account when determining equality.
   *
   * One thing to note here is that two equivalent batches might
   * produce different lists (for instance, if one of the batches has
   * more zeros in it than another one).
   */
  implicit def equiv[A](implicit e: Equiv[A], s: Semigroup[A]): Equiv[Batched[A]] =
    new Equiv[Batched[A]] {
      def equiv(x: Batched[A], y: Batched[A]): Boolean =
        e.equiv(x.sum(s), y.sum(s))
    }

  /**
   * The free semigroup for batched values.
   *
   * This semigroup just accumulates batches and doesn't ever evaluate
   * them to flatten the tree.
   */
  implicit def semigroup[A]: Semigroup[Batched[A]] =
    new Semigroup[Batched[A]] {
      def plus(x: Batched[A], y: Batched[A]): Batched[A] = x combine y
    }

  /**
   * Compacting semigroup for batched values.
   *
   * Whenever a sum produces a tree holding `batchSize` or more values,
   * the tree is compacted (summed to a single item) using the implicit
   * `Semigroup[A]`.
   *
   * @throws IllegalArgumentException if `batchSize <= 0`
   */
  def compactingSemigroup[A: Semigroup](batchSize: Int): Semigroup[Batched[A]] =
    new BatchedSemigroup[A](batchSize)

  /**
   * Compacting monoid for batched values.
   *
   * Whenever a sum produces a tree holding `batchSize` or more values,
   * the tree is compacted (summed to a single item) using the implicit
   * `Monoid[A]`.
   *
   * It's worth noting that `x + 0` here will produce the same sum as
   * `x`, but `.toList` will produce different lists (one will have an
   * extra zero).
   *
   * @throws IllegalArgumentException if `batchSize <= 0`
   */
  def compactingMonoid[A: Monoid](batchSize: Int): Monoid[Batched[A]] =
    new BatchedMonoid[A](batchSize)

  /**
   * This aggregator batches up `agg` so that all the addition can be
   * performed at once.
   *
   * It is useful when `sumOption` is much faster than using `plus`
   * (e.g. when there is temporary mutable state used to make
   * summation fast).
   */
  def aggregator[A, B, C](batchSize: Int, agg: Aggregator[A, B, C]): Aggregator[A, Batched[B], C] =
    new Aggregator[A, Batched[B], C] {
      def prepare(a: A): Batched[B] = Item(agg.prepare(a))
      def semigroup: Semigroup[Batched[B]] = new BatchedSemigroup(batchSize)(agg.semigroup)
      def present(b: Batched[B]): C = agg.present(b.sum(agg.semigroup))
    }

  /**
   * This monoid aggregator batches up `agg` so that all the addition
   * can be performed at once.
   *
   * It is useful when `sumOption` is much faster than using `plus`
   * (e.g. when there is temporary mutable state used to make
   * summation fast).
   */
  def monoidAggregator[A, B, C](batchSize: Int, agg: MonoidAggregator[A, B, C]): MonoidAggregator[A, Batched[B], C] =
    new MonoidAggregator[A, Batched[B], C] {
      def prepare(a: A): Batched[B] = Item(agg.prepare(a))
      def monoid: Monoid[Batched[B]] = new BatchedMonoid(batchSize)(agg.monoid)
      def present(b: Batched[B]): C = agg.present(b.sum(agg.semigroup))
    }

  /**
   * A `Fold` that sums its inputs with the implicit `Semigroup[T]`.
   *
   * Inputs are buffered in a batch, which is compacted each time it
   * reaches `batchSize` values. The result is `None` when nothing was
   * folded, otherwise `Some(sum)`.
   */
  def foldOption[T: Semigroup](batchSize: Int): Fold[T, Option[T]] =
    Fold.foldLeft[T, Option[Batched[T]]](Option.empty[Batched[T]]) {
      case (Some(b), t) => Some(b.combine(Item(t)).compact(batchSize))
      case (None, t) => Some(Item(t))
    }.map(_.map(_.sum))

  /**
   * A `Fold` that sums its inputs with the monoid `m`.
   *
   * The batch starts as a single `m.zero`, so folding no inputs yields
   * `m.zero`. The batch is compacted each time it reaches `batchSize`
   * values.
   */
  def fold[T](batchSize: Int)(implicit m: Monoid[T]): Fold[T, T] =
    Fold.foldLeft[T, Batched[T]](Batched(m.zero)) { (b, t) =>
      b.combine(Item(t)).compact(batchSize)
    }.map(_.sum)

  /**
   * This represents a single (unbatched) value.
   */
  private[algebird] case class Item[T](t: T) extends Batched[T] {
    def size: Int = 1
    def sum(implicit sg: Semigroup[T]): T = t
  }

  /**
   * This represents two (or more) batched values being added.
   *
   * The actual addition is deferred until the `.sum` method is called.
   */
  private[algebird] case class Items[T](left: Batched[T], right: Batched[T]) extends Batched[T] {
    // Items#size will always be >= 2, since each child holds at least one value.
    val size: Int = left.size + right.size

    // The tree is never empty, so `sumOption` always returns `Some` here.
    def sum(implicit sg: Semigroup[T]): T =
      sg.sumOption(new ForwardItemsIterator(this)).get
  }

  /**
   * Abstract iterator through a batch's tree.
   *
   * This class is agnostic about whether the traversal is
   * left-to-right or right-to-left. The abstract method `descend`
   * controls which direction the iterator moves.
   */
  private[algebird] abstract class ItemsIterator[A](root: Batched[A]) extends Iterator[A] {
    // Subtrees still waiting to be visited, the next one at the head.
    var stack: List[Batched[A]] = Nil
    // Becomes false once the last value has been handed out by `next()`.
    var running: Boolean = true
    // The value the next call to `next()` will return.
    //
    // NOTE: this invokes the subclass's `descend` while this class is
    // still being constructed. That is safe only because `stack` is
    // declared (and therefore initialized) above this line, and the
    // subclasses below declare no fields of their own.
    var ready: A = descend(root)

    /** Advance to the next pending subtree, or stop if none remain. */
    def ascend(): Unit =
      stack match {
        case Nil =>
          running = false
        case h :: t =>
          stack = t
          ready = descend(h)
      }

    /**
     * Walk down from `v` to its first leaf (in traversal order), pushing
     * the sibling subtrees passed along the way onto `stack`, and
     * return that leaf's value.
     */
    def descend(v: Batched[A]): A

    def hasNext: Boolean =
      running

    def next(): A =
      if (running) {
        val result = ready
        ascend()
        result
      } else {
        throw new NoSuchElementException("next on empty iterator")
      }
  }

  /**
   * Left-to-right iterator through a batch's tree.
   */
  private[algebird] class ForwardItemsIterator[A](root: Batched[A]) extends ItemsIterator[A](root) {
    def descend(v: Batched[A]): A = {
      @inline @tailrec def descend0(v: Batched[A]): A =
        v match {
          case Items(lhs, rhs) =>
            // Visit the left side first; come back for the right later.
            stack = rhs :: stack
            descend0(lhs)
          case Item(value) =>
            value
        }
      descend0(v)
    }
  }

  /**
   * Right-to-left iterator through a batch's tree.
   */
  private[algebird] class ReverseItemsIterator[A](root: Batched[A]) extends ItemsIterator[A](root) {
    def descend(v: Batched[A]): A = {
      @inline @tailrec def descend0(v: Batched[A]): A =
        v match {
          case Items(lhs, rhs) =>
            // Visit the right side first; come back for the left later.
            stack = lhs :: stack
            descend0(rhs)
          case Item(value) =>
            value
        }
      descend0(v)
    }
  }
}

/**
 * A semigroup on `Batched[T]` that keeps batch trees bounded.
 *
 * Every `plus` first joins its two arguments into a single tree. If
 * that tree holds `batchSize` or more values, it is then collapsed
 * into one item by summing with the implicit `Semigroup[T]`.
 *
 * @param batchSize tree size at which `plus` compacts; must be positive
 * @throws IllegalArgumentException if `batchSize <= 0`
 */
class BatchedSemigroup[T: Semigroup](batchSize: Int) extends Semigroup[Batched[T]] {

  require(batchSize > 0, s"Batch size must be > 0, found: $batchSize")

  def plus(a: Batched[T], b: Batched[T]): Batched[T] = {
    val merged: Batched[T] = a combine b
    merged.compact(batchSize)
  }
}

/**
 * A compacting monoid on `Batched[T]`.
 *
 * Addition behaves exactly as in `BatchedSemigroup`: trees that reach
 * `batchSize` values are summed down to one item with the implicit
 * `Monoid[T]`. The identity is a single-item batch holding that
 * monoid's zero.
 *
 * @param batchSize tree size at which `plus` compacts; must be positive
 */
class BatchedMonoid[T: Monoid](batchSize: Int) extends BatchedSemigroup[T](batchSize) with Monoid[Batched[T]] {
  val zero: Batched[T] = Batched(implicitly[Monoid[T]].zero)

  // A batch counts as non-zero exactly when its total is non-zero.
  // If we knew that (a + b = 0) only when a = 0 and b = 0, we could skip
  // the full sum and test the items individually, e.g.
  //   b.iterator.exists(implicitly[Monoid[T]].isNonZero)
  override def isNonZero(b: Batched[T]): Boolean = {
    val total: T = b.sum
    implicitly[Monoid[T]].isNonZero(total)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy