All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.twitter.algebird.macros.Cuber.scala Maven / Gradle / Ivy

There is a newer version: 0.13.4
Show newest version
package com.twitter.algebird.macros

import scala.language.experimental.{ macros => sMacros }
import scala.reflect.macros.Context
import scala.reflect.runtime.universe._

/**
 * "Cubes" a case class or tuple, i.e. for a tuple of type
 * (T1, T2, ... , TN) generates all 2^N possible combinations of type
 * (Option[T1], Option[T2], ... , Option[TN]).
 *
 * This is useful for comparing some metric across all possible subsets of fields.
 * For example, suppose we have a set of people represented as
 * case class Person(gender: String, age: Int, height: Double)
 * and we want to know the average height of
 *  - people, grouped by gender and age
 *  - people, grouped by only gender
 *  - people, grouped by only age
 *  - all people
 *
 * Then we could do
 * > import com.twitter.algebird.macros.Cuber.cuber
 * > val people: List[Person]
 * > val averageHeights: Map[(Option[String], Option[Int]), Double] =
 * >   people.flatMap { p => cuber((p.gender, p.age)).map((_, p)) }
 * >     .groupBy(_._1)
 * >     .mapValues { xs => val heights = xs.map(_._2.height); heights.sum / heights.length }
 *
 * (After `groupBy`, each `xs` is a collection of `(key, person)` pairs, hence `_._2`.)
 *
 * @tparam I the type being cubed
 */
trait Cuber[I] {
  /** The type of each generated combination (for macro-generated instances, a tuple of Options). */
  type K

  /** Returns every combination of `in`'s fields, each absent field represented as `None`. */
  def apply(in: I): TraversableOnce[K]
}

/**
 * Given a TupleN (or a case class with N fields), produces a sequence of
 * (N + 1) tuples, each of arity N, such that for every k from 0 to N there is
 * a tuple whose first k elements are Somes followed by (N - k) Nones.
 *
 * This is useful for comparing some metric across multiple layers of
 * some hierarchy.
 * For example, suppose we have some climate data represented as
 * case class Data(continent: String, country: String, city: String, temperature: Double)
 * and we want to know the average temperatures of
 *   - each continent
 *   - each (continent, country) pair
 *   - each (continent, country, city) triple
 *
 * We group by the (continent, country) and (continent, country, city)
 * prefixes, rather than by country or city alone, because, for example,
 * grouping by city alone would accidentally combine the results for
 * Paris, Texas and Paris, France.
 *
 * Then we could do
 * > import com.twitter.algebird.macros.Roller.roller
 * > val data: List[Data]
 * > val averageTemps: Map[(Option[String], Option[String], Option[String]), Double] =
 * >   data.flatMap { d => roller((d.continent, d.country, d.city)).map((_, d)) }
 * >     .groupBy(_._1)
 * >     .mapValues { xs => val temps = xs.map(_._2.temperature); temps.sum / temps.length }
 *
 * (After `groupBy`, each `xs` is a collection of `(key, data)` pairs, hence `_._2`.)
 *
 * @tparam I the type being rolled up
 */
trait Roller[I] {
  /** The type of each generated prefix (for macro-generated instances, a tuple of Options). */
  type K

  /** Returns the N + 1 prefixes of `in`'s fields, with the fields outside each prefix set to `None`. */
  def apply(in: I): TraversableOnce[K]
}

object Cuber {
  /**
   * Summons a [[Cuber]] for the case class or tuple `T`. The instance is generated
   * at compile time by [[cuberImpl]].
   */
  implicit def cuber[T]: Cuber[T] = macro cuberImpl[T]

  /**
   * Macro implementation of [[cuber]].
   *
   * For a `T` with N fields (1 <= N <= 22), this expands to a `Cuber[T]` whose `K` is
   * `TupleN[Option[F1], ..., Option[FN]]` and whose `apply` returns all 2^N such tuples.
   * For each `i` in `0 until 2^N`, field j (1-based) is `Some(value)` when bit (j - 1)
   * of `i` is set, and `None` otherwise. Index 0 is therefore all `None`s and the last
   * index is all `Some`s.
   *
   * Expansion is aborted when `T` has no fields or more than 22 fields, because no
   * `TupleN` exists for those arities. `ensureCaseClass` (defined elsewhere) runs first;
   * presumably it rejects non-case-class types, but confirm against its definition.
   */
  def cuberImpl[T](c: Context)(implicit T: c.WeakTypeTag[T]): c.Expr[Cuber[T]] = {
    import c.universe._

    ensureCaseClass(c)

    val params = getParams(c)
    val arity = params.length
    // Interpolate `T.tpe` rather than `T`: the tag itself prints as `WeakTypeTag[...]`.
    if (arity > 22)
      c.abort(c.enclosingPosition, s"Cannot create Cuber for ${T.tpe} because it has more than 22 parameters.")
    if (arity == 0)
      c.abort(c.enclosingPosition, s"Cannot create Cuber for ${T.tpe} because it has no parameters.")

    // TupleN[Option[F1], ..., Option[FN]]
    val tupleName = {
      val optionTypes = getParamTypes(c).map { t => tq"_root_.scala.Option[$t]" }
      val tupleType = newTypeName(s"Tuple$arity")
      tq"_root_.scala.$tupleType[..$optionTypes]"
    }

    // Local names some1 ... someN, one per field, in declaration order.
    val someNames = (1 to arity).map { index => newTermName(s"some$index") }

    // Each field is wrapped in Some exactly once, and that value is shared by all 2^N tuples.
    val somes = params.zip(someNames).map {
      case (param, name) => q"val $name = _root_.scala.Some(in.$param)"
    }

    // Field j (1-based) is present in combination `i` iff bit (j - 1) of `i` is set.
    val options = someNames.zipWithIndex.map {
      case (some, bit) => q"if (((1 << $bit) & i) == 0) _root_.scala.None else $some"
    }

    val cuber = q"""
    new _root_.com.twitter.algebird.macros.Cuber[${T}] {
      type K = $tupleName
      def apply(in: ${T}): _root_.scala.Seq[K] = {
        ..$somes
        (0 until (1 << $arity)).map { i =>
          new K(..$options)
        }
      }
    }
    """
    c.Expr[Cuber[T]](cuber)
  }
}

object Roller {
  /**
   * Summons a [[Roller]] for the case class or tuple `T`. The instance is generated
   * at compile time by [[rollerImpl]].
   */
  implicit def roller[T]: Roller[T] = macro rollerImpl[T]

  /**
   * Macro implementation of [[roller]].
   *
   * For a `T` with N fields (1 <= N <= 22), this expands to a `Roller[T]` whose `K` is
   * `TupleN[Option[F1], ..., Option[FN]]` and whose `apply` returns N + 1 such tuples.
   * The k-th tuple (k = 0 .. N) holds `Some(value)` for the first k fields and `None`
   * for the rest, so the result runs from all `None`s to all `Some`s.
   *
   * Expansion is aborted when `T` has no fields or more than 22 fields, because no
   * `TupleN` exists for those arities. `ensureCaseClass` (defined elsewhere) runs first;
   * presumably it rejects non-case-class types, but confirm against its definition.
   */
  def rollerImpl[T](c: Context)(implicit T: c.WeakTypeTag[T]): c.Expr[Roller[T]] = {
    import c.universe._

    ensureCaseClass(c)

    val params = getParams(c)
    val arity = params.length
    // Interpolate `T.tpe` rather than `T`: the tag itself prints as `WeakTypeTag[...]`.
    if (arity > 22)
      c.abort(c.enclosingPosition, s"Cannot create Roller for ${T.tpe} because it has more than 22 parameters.")
    if (arity == 0)
      c.abort(c.enclosingPosition, s"Cannot create Roller for ${T.tpe} because it has no parameters.")

    // TupleN[Option[F1], ..., Option[FN]]
    val tupleName = {
      val optionTypes = getParamTypes(c).map { t => tq"_root_.scala.Option[$t]" }
      val tupleType = newTypeName(s"Tuple$arity")
      tq"_root_.scala.$tupleType[..$optionTypes]"
    }

    // Local names some1 ... someN, one per field, in declaration order.
    val someNames = (1 to arity).map { index => newTermName(s"some$index") }

    // Each field is wrapped in Some exactly once, and that value is shared by all N + 1 tuples.
    val somes = params.zip(someNames).map {
      case (param, name) => q"val $name = _root_.scala.Some(in.$param)"
    }

    // Tuple k keeps the first k fields (0-based idx < k) as Some and sets the rest to None.
    val items = (0 to arity).map { k =>
      val args = someNames.zipWithIndex.map {
        case (some, idx) => if (idx < k) q"$some" else q"_root_.scala.None"
      }
      q"new K(..$args)"
    }

    // `Seq` is fully qualified so that the expansion cannot pick up a `Seq`
    // defined or imported at the call site (macro hygiene).
    val roller = q"""
    new _root_.com.twitter.algebird.macros.Roller[${T}] {
      type K = $tupleName
      def apply(in: ${T}): _root_.scala.Seq[K] = {
        ..$somes
        _root_.scala.Seq(..$items)
      }
    }
    """
    c.Expr[Roller[T]](roller)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy