All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.stripe.brushfire.TDigest.scala Maven / Gradle / Ivy

package com.stripe.brushfire

import com.tdunning.math.stats.TDigest
import com.twitter.algebird.{Monoid, Semigroup}

import Predicate.Lt

/**
 * Semigroup that merges t-digests into a freshly allocated digest, leaving the inputs unmodified.
 *
 * Both `plus` and `sumOption` size the result with the largest compression among their inputs,
 * so summing a collection agrees with reducing it pairwise via `plus`.
 */
private [this] case object TDigestSemigroup extends Semigroup[TDigest] {
  override def plus(l: TDigest, r: TDigest): TDigest = merge(Seq(l, r))

  override def sumOption(iter: TraversableOnce[TDigest]): Option[TDigest] = {
    // materialize once: the max compression must be known before the result digest is created,
    // and only references are held, not the digests' contents
    val digests = iter.toIndexedSeq
    if (digests.isEmpty) None else Some(merge(digests))
  }

  /**
   * Merge a non-empty sequence of digests into a new digest whose compression is the maximum of
   * the inputs' compressions.
   */
  private def merge(digests: Seq[TDigest]): TDigest = {
    val td = TDigest.createDigest(digests.map(_.compression()).max)
    digests.foreach(d => td.add(d))
    td
  }
}

object TDigestSplitter {
  /**
   * Estimate how much of `digest`'s weight lies on either side of the split point `q`.
   *
   * The left count is the digest's total size scaled by its cumulative distribution at `q`
   * (bounded to [0, 1]) and truncated to a whole number. The right count is the remainder,
   * floored at zero.
   *
   * @return a pair of (left count, right count)
   */
  private def splitCounts(q: Double, digest: TDigest): (Long, Long) = {
    val total = digest.size()
    // cdf() sometimes reports values outside [0, 1]. Without bounding them, the estimated
    // target distribution is far from realistic, and the Evaluator frequently picks a
    // suboptimal split point.
    val fraction = digest.cdf(q).max(0.0).min(1.0)
    val leftCount = (fraction * total.toDouble).toLong
    val rightCount = math.max(total - leftCount, 0L)
    (leftCount, rightCount)
  }

  /**
   * A singleton map `key -> count` when `count` is positive, otherwise an empty map.
   */
  private def positiveOrEmpty[L](key: L, count: Long): Map[L, Long] =
    if (count > 0L) Map(key -> count) else Map.empty[L, Long]

  /**
   * Split a single target's digest at `q` into its (left, right) label-count distributions.
   * A side with no estimated weight is represented by an empty map.
   */
  private def targetDistribution[L](q: Double)(target: L, digest: TDigest): (Map[L, Long], Map[L, Long]) =
    splitCounts(q, digest) match {
      case (leftCount, rightCount) =>
        (positiveOrEmpty(target, leftCount), positiveOrEmpty(target, rightCount))
    }
}

/**
 * A [[Splitter]] for continuous `Double` features.
 *
 * It keeps one t-digest of observed feature values per target label, and proposes `Lt` splits
 * at `k` evenly spaced quantiles (1/k, 2/k, ..., 1) of the merged distribution.
 *
 * @param k           number of candidate quantiles to try as split points
 * @param compression t-digest compression used when creating new digests in `create`
 * @tparam L          target label type
 */
case class TDigestSplitter[L](k: Int = 25, compression: Double = 100.0) extends Splitter[Double, Map[L, Long]] {
  override type S = Map[L, TDigest]

  /**
   * Propose candidate splits from the per-label digests in `stats`.
   *
   * @return distinct splits that place at least some estimated weight on one side
   */
  override def split(parent: Map[L, Long], stats: S): Iterable[Split[Double, Map[L, Long]]] = {
    implicit val tds: Semigroup[TDigest] = TDigestSemigroup
    import TDigestSplitter.targetDistribution

    val splits = for {
      // merge the statistics from all targets
      merged <- Semigroup.sumOption(stats.valuesIterator).toSeq

      // generate the requested number of splits evenly between [1/k, 1]
      // we can skip 0 because the predicate is LessThan, and no targets should
      // exist below the 0th quantile
      i <- 1 to k

      // estimate the quantile value (in feature units) from the merged statistics;
      // this becomes a potential split point in the resulting tree
      q = merged.quantile(i.toDouble / k.toDouble)

      // an undefined quantile cannot serve as a split threshold
      if !q.isNaN

      // then estimate the target distribution using each target's statistics.
      // Iterate rather than map over the Map: mapping a Map with a pair-returning function
      // builds a new Map keyed by the left distribution, which collapses targets whose
      // left side is empty and drops their right-side counts.
      (left, right) = Monoid.sum(stats.iterator.map { case (target, digest) =>
        targetDistribution(q)(target, digest)
      })

      // the goodness score of an entirely empty split should not be the best
      // and so they can be discarded immediately
      if left.nonEmpty || right.nonEmpty
    } yield Split(Lt(q), left, right)

    // if the input is not continuous or has too few examples, we end up with fewer distinct
    // splits than requested, and the rest are duplicates, so remove the dupes
    splits.distinct
  }

  /** Per-label merge of the label -> digest statistics. */
  override def semigroup: Semigroup[S] = {
    implicit val tds: Semigroup[TDigest] = TDigestSemigroup
    implicitly[Semigroup[S]]
  }

  /**
   * Build statistics for a single observation: one digest per label, holding `value`
   * weighted by that label's count.
   *
   * A strict `map` is used instead of `mapValues`. `mapValues` returns a lazy, non-serializable
   * view that would allocate a new digest on every access.
   */
  override def create(value: Double, target: Map[L, Long]): S =
    target.map { case (label, count) => label -> weightedDigest(value, count) }

  /**
   * A new digest containing `value` with total weight `count`.
   *
   * `TDigest.add` takes an `Int` weight, so counts above `Int.MaxValue` are added in
   * Int-sized chunks instead of failing.
   */
  private def weightedDigest(value: Double, count: Long): TDigest = {
    val td = TDigest.createDigest(compression)
    if (count <= Int.MaxValue.toLong) {
      td.add(value, count.toInt)
    } else {
      var remaining = count
      while (remaining > 0L) {
        val weight = math.min(remaining, Int.MaxValue.toLong).toInt
        td.add(value, weight)
        remaining -= weight
      }
    }
    td
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy