All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.stripe.brushfire.Voter.scala Maven / Gradle / Ivy

package com.stripe.brushfire

import com.twitter.algebird._

import AnnotatedTree.AnnotatedTreeTraversal

/** Combines multiple targets into a single prediction **/
trait Voter[T, P] { self =>

  /**
   * Combines the targets from multiple trees into a prediction.
   */
  def combine(targets: Iterable[T]): P

  /**
   * Transform the input targets by first applying the function `f`.
   */
  def contramap[S](f: S => T): Voter[S, P] = new Voter[S, P] {
    def combine(targets: Iterable[S]): P =
      self.combine(targets.map(f))
  }

  /**
   * Transform the final predictions of this `Voter` with function `f`.
   */
  def map[Q](f: P => Q): Voter[T, Q] = new Voter[T, Q] {
    def combine(targets: Iterable[T]): Q =
      f(self.combine(targets))
  }

  final def predict[K, V, A](trees: Iterable[AnnotatedTree[K, V, T, A]], row: Map[K, V])(implicit traversal: AnnotatedTreeTraversal[K, V, T, A], semigroup: Semigroup[T]): P =
    combine(trees.flatMap(_.targetFor(row)))
}

object Voter {

  /**
   * Returns a [[Voter]] that uses `f` as the `combine` function.
   */
  def apply[A, B](f: Iterable[A] => B): Voter[A, B] =
    new Voter[A, B] {
      def combine(targets: Iterable[A]): B = f(targets)
    }

  /**
   * returns a [[Voter]] that sums up the targets using the monoid for `T`.
   */
  def fromMonoid[T: Monoid]: Voter[T, T] =
    Voter(Monoid.sum(_))

  /**
   * Returns a [[Voter]] that combines the targets using an aggregator.
   */
  def fromAggregator[A, B, C](aggregate: Aggregator[A, B, C]): Voter[A, C] =
    Voter(aggregate(_))

  type FrequencyVoter[L, M] = Voter[Map[L, M], Map[L, Double]]

  def soft[L, M: Numeric]: FrequencyVoter[L, M] =
    fromMonoid[(Map[L, Double], Int)]
      .contramap[Map[L, M]](normalize(_) -> 1)
      .map { case (m, cnt) => m.mapValues(_ / cnt) }

  def mode[L, M: Ordering]: FrequencyVoter[L, M] =
    fromMonoid[Map[L, Double]]
      .contramap[Map[L, M]](m => Map(mode(m) -> 1.0))
      .map(normalize(_))

  def threshold[M](threshold: Double, voter: FrequencyVoter[Boolean, M]): Voter[Map[Boolean, M], Boolean] =
    voter.map(m => m.getOrElse(true, 0D) > threshold)

  def mean[N: Numeric]: Voter[N, Double] =
    fromAggregator(AveragedValue.numericAggregator[N])

  private def normalize[L, N](m: Map[L, N])(implicit num: Numeric[N]): Map[L, Double] = {
    val nonNeg = m.mapValues { n => math.max(num.toDouble(n), 0.0) }
    val total = math.max(nonNeg.values.sum, 1.0)
    nonNeg.mapValues { _ / total }
  }

  private def mode[L, M: Ordering](m: Map[L, M]): L = m.maxBy { _._2 }._1
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy