All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.twitter.scalding.spark_backend.Op.scala Maven / Gradle / Ivy

The newest version!
package com.twitter.scalding.spark_backend

import org.apache.spark.{HashPartitioner, Partitioner}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel
import scala.concurrent.{ExecutionContext, Future}
import scala.reflect.ClassTag
import com.twitter.scalding.{Config, FutureCache}
import com.twitter.scalding.typed.TypedSource

sealed abstract class Op[+A] {
  import Op.{Transformed, fakeClassTag}

  def run(session: SparkSession)(implicit ec: ExecutionContext): Future[RDD[_ <: A]]

  def map[B](fn: A => B): Op[B] =
    Transformed[A, B](this, _.map(fn))
  def concatMap[B](fn: A => TraversableOnce[B]): Op[B] =
    Transformed[A, B](this, _.flatMap(fn))
  def filter(fn: A => Boolean): Op[A] =
    Transformed[A, A](this, _.filter(fn))
  def persist(sl: StorageLevel): Op[A] =
    Transformed[A, A](this, _.persist(sl))
  def mapPartitions[B](fn: Iterator[A] => Iterator[B]): Op[B] =
    Transformed[A, B](this, _.mapPartitions(fn, preservesPartitioning = true))
}

object Op extends Serializable {
  // TODO, this may be just inefficient, or it may be wrong
  implicit private def fakeClassTag[A]: ClassTag[A] = ClassTag(classOf[AnyRef]).asInstanceOf[ClassTag[A]]

  implicit class PairOp[K, V](val op: Op[(K, V)]) extends AnyVal {
    def flatMapValues[U](fn: V => TraversableOnce[U]): Op[(K, U)] =
      Transformed[(K, V), (K, U)](op, _.flatMapValues(fn))
    def mapValues[U](fn: V => U): Op[(K, U)] =
      Transformed[(K, V), (K, U)](op, _.mapValues(fn))
    def mapGroup[U](fn: (K, Iterator[V]) => Iterator[U])(implicit ordK: Ordering[K]): Op[(K, U)] =
      Transformed[(K, V), (K, U)](op, { rdd: RDD[(K, V)] =>
        val numPartitions = rdd.getNumPartitions
        val partitioner = rdd.partitioner.getOrElse(new HashPartitioner(numPartitions))
        val partitioned = rdd.repartitionAndSortWithinPartitions(partitioner)
        partitioned.mapPartitions({ its =>
          // since we are sorted, the adjacent keys are next to each other
          val grouped = Iterators.groupSequential(its)
          grouped.flatMap { case (k, vs) => fn(k, vs).map((k, _)) }
        }, preservesPartitioning = true)
      })

    def hashJoin[U, W](right: Op[(K, U)])(fn: (K, V, Iterable[U]) => Iterator[W]): Op[(K, W)] =
      HashJoinOp(op, right, fn)

    def sorted(implicit ordK: Ordering[K], ordV: Ordering[V]): Op[(K, V)] =
      Transformed[(K, V), (K, V)](op, { rdd: RDD[(K, V)] =>
        // The idea here is that we put the key and the value in
        // logical key, but partition only on the left part of the key
        val numPartitions = rdd.getNumPartitions
        val partitioner = rdd.partitioner.getOrElse(new HashPartitioner(numPartitions))
        val keyOnlyPartioner = KeyHashPartitioner(partitioner)
        val unitValue: RDD[((K, V), Unit)] = rdd.map { kv => (kv, ()) }
        val partitioned = unitValue.repartitionAndSortWithinPartitions(keyOnlyPartioner)
        partitioned.mapPartitions({ its =>
          // discard the unit value
          its.map { case (kv, _) => kv }
        }, preservesPartitioning = true) // the keys haven't changed
      })

    def sortedMapGroup[U](fn: (K, Iterator[V]) => Iterator[U])(implicit ordK: Ordering[K], ordV: Ordering[V]): Op[(K, U)] =
      Transformed[(K, V), (K, U)](op, { rdd: RDD[(K, V)] =>
        // The idea here is that we put the key and the value in
        // logical key, but partition only on the left part of the key
        val numPartitions = rdd.getNumPartitions
        val partitioner = rdd.partitioner.getOrElse(new HashPartitioner(numPartitions))
        val keyOnlyPartioner = KeyHashPartitioner(partitioner)
        val unitValue: RDD[((K, V), Unit)] = rdd.map { kv => (kv, ()) }
        val partitioned = unitValue.repartitionAndSortWithinPartitions(keyOnlyPartioner)
        partitioned.mapPartitions({ its =>
          // discard the unit value
          val kviter = its.map { case (kv, _) => kv }
          // since we are sorted first by key, then value, the keys are grouped
          val grouped = Iterators.groupSequential(kviter)
          grouped.flatMap { case (k, vs) => fn(k, vs).map((k, _)) }
        }, preservesPartitioning = true) // the keys haven't changed
      })
  }

  private case class KeyHashPartitioner(partitioner: Partitioner) extends Partitioner {
    override def numPartitions: Int = partitioner.numPartitions

    override def getPartition(keyValue: Any): Int = {
      val key: Any = keyValue.asInstanceOf[(Any, Any)]._1
      partitioner.getPartition(key)
    }

  }

  implicit class InvariantOp[A](val op: Op[A]) extends AnyVal {
    def ++(that: Op[A]): Op[A] =
      op match {
        case Empty => that
        case nonEmpty =>
          that match {
            case Empty => nonEmpty
            case thatNE =>
              Merged(nonEmpty, thatNE)
          }
      }
  }

  object Empty extends Op[Nothing] {
    def run(session: SparkSession)(implicit ec: ExecutionContext) =
      Future(session.sparkContext.emptyRDD[Nothing])

    override def map[B](fn: Nothing => B): Op[B] = this
    override def concatMap[B](fn: Nothing => TraversableOnce[B]): Op[B] = this
    override def filter(fn: Nothing => Boolean): Op[Nothing] = this
    override def persist(sl: StorageLevel): Op[Nothing] = this
    override def mapPartitions[B](fn: Iterator[Nothing] => Iterator[B]): Op[B] = this
  }

  final case class FromIterable[A](iterable: Iterable[A]) extends Op[A] {
    def run(session: SparkSession)(implicit ec: ExecutionContext): Future[RDD[_ <: A]] =
      Future(session.sparkContext.makeRDD(iterable.toSeq, 1))
  }

  final case class Source[A](conf: Config, original: TypedSource[A], input: Option[SparkSource[A]]) extends Op[A] {
    def run(session: SparkSession)(implicit ec: ExecutionContext): Future[RDD[_ <: A]] =
      input match {
        case None => Future.failed(new IllegalArgumentException(s"source $original was not connected to a spark source"))
        case Some(src) => src.read(session, conf)
      }
  }

  private def widen[A](r: RDD[_ <: A]): RDD[A] =
    //r.map { a => a }
    // or we could just cast
    r.asInstanceOf[RDD[A]]

  final case class Transformed[Z, A](input: Op[Z], fn: RDD[Z] => RDD[A]) extends Op[A] {
    @transient private val cache = new FutureCache[SparkSession, RDD[_ <: A]]

    def run(session: SparkSession)(implicit ec: ExecutionContext): Future[RDD[_ <: A]] =
      cache.getOrElseUpdate(session,
        input.run(session).map { rdd => fn(widen(rdd)) })
  }

  final case class Merged[A](left: Op[A], right: Op[A]) extends Op[A] {
    @transient private val cache = new FutureCache[SparkSession, RDD[_ <: A]]

    def run(session: SparkSession)(implicit ec: ExecutionContext): Future[RDD[_ <: A]] =
      cache.getOrElseUpdate(session, {
        // start running in parallel
        val lrdd = left.run(session)
        val rrdd = right.run(session)
        for {
          l <- lrdd
          r <- rrdd
        } yield widen[A](l) ++ widen[A](r)
      })
  }

  final case class HashJoinOp[A, B, C, D](left: Op[(A, B)], right: Op[(A, C)], joiner: (A, B, Iterable[C]) => Iterator[D]) extends Op[(A, D)] {
    @transient private val cache = new FutureCache[SparkSession, RDD[_ <: (A, D)]]

    def run(session: SparkSession)(implicit ec: ExecutionContext): Future[RDD[_ <: (A, D)]] =
      cache.getOrElseUpdate(session, {
        // start running in parallel
        val rrdd = right.run(session)
        val lrdd = left.run(session)

        rrdd.flatMap { rightRdd =>
          // TODO: spark has some thing to send replicated data to nodes
          // we should materialize the small side, use the above, then
          // implement a join using mapPartitions
          val rightMap: Map[A, List[C]] = rightRdd
            .toLocalIterator
            .toList
            .groupBy(_._1)
            .map { case (k, vs) => (k, vs.map(_._2)) }

          val bcastMap = session.sparkContext.broadcast(rightMap)
          lrdd.map { leftrdd =>

            val localJoiner = joiner

            leftrdd.mapPartitions({ it: Iterator[(A, B)] =>
              val rightMap = bcastMap.value
              it.flatMap { case (a, b) => localJoiner(a, b, rightMap.getOrElse(a, Nil)).map((a, _)) }
            }, preservesPartitioning = true)
          }
        }
      })

  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy