All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.streaming.rabbitmq.distributed.RabbitMQRDD.scala Maven / Gradle / Ivy

/**
 * Copyright (C) 2015 Stratio (http://stratio.com)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.streaming.rabbitmq.distributed

import akka.actor.ActorSystem
import com.rabbitmq.client.ConsumerCancelledException
import com.rabbitmq.client.QueueingConsumer.Delivery
import com.typesafe.config.ConfigFactory
import org.apache.spark.partial.{BoundedDouble, CountEvaluator, PartialResult}
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.rabbitmq.consumer.Consumer
import org.apache.spark.streaming.rabbitmq.consumer.Consumer._
import org.apache.spark.util.{NextIterator, Utils}
import org.apache.spark.{Accumulator, Logging, Partition, SparkContext, SparkException, TaskContext}

import scala.collection.JavaConversions._
import scala.concurrent.duration._
import scala.reflect.ClassTag
import scala.util.{Failure, Success, Try}

/**
 * RDD whose partitions pull messages from RabbitMQ while they are computed.
 *
 * Each partition opens its own consumer on one queue and reads until the configured receive time elapses or the
 * optional per-partition message limit is reached; every delivery is converted with `messageHandler`. Because
 * `compute` opens a fresh consumer each time, every job run over a non-persisted instance consumes new messages.
 *
 * @param sc               Spark context the RDD belongs to
 * @param distributedKeys  queue/exchange/connection keys to consume from; when empty they are derived from
 *                         `rabbitMQParams`
 * @param rabbitMQParams   connection and consumer options shared by every partition
 * @param countAccumulator incremented by every partition with the number of messages it handled successfully;
 *                         read by `count()`
 * @param messageHandler   converts a raw delivery into an element of the RDD
 * @tparam R element type
 */
private[rabbitmq]
class RabbitMQRDD[R: ClassTag](
    @transient sc: SparkContext,
    distributedKeys: Seq[RabbitMQDistributedKey],
    rabbitMQParams: Map[String, String],
    val countAccumulator: Accumulator[Long],
    messageHandler: Delivery => R
  ) extends RDD[R](sc, Nil) with Logging {

  /** Total obtained by the first `count()` call (read from `countAccumulator`); `None` until then. */
  @volatile private var totalCalculated: Option[Long] = None

  /**
   * Returns the number of messages successfully handled by all partitions.
   *
   * The first call runs a job that drains every partition and then reads `countAccumulator`. The value is cached,
   * so later calls return it immediately without consuming more messages.
   * NOTE(review): a message whose handler failed still yields a (null) element but is not counted here.
   */
  override def count(): Long = {
    totalCalculated.getOrElse {
      withScope {
        sc.runJob(this, Utils.getIteratorSize _)
        val total = countAccumulator.value
        totalCalculated = Some(total)
        total
      }
    }
  }

  /**
   * Returns an approximate element count, waiting at most `timeout` milliseconds.
   *
   * When `count()` has already run, the cached total is returned as a final result with confidence 1.0.
   * Otherwise an approximate counting job is run over all partitions.
   */
  override def countApprox(
      timeout: Long,
      confidence: Double = 0.95): PartialResult[BoundedDouble] = {
    if (totalCalculated.isDefined) {
      val c = count()
      new PartialResult(new BoundedDouble(c, 1.0, c, c), true)
    } else {
      withScope {
        val countElements: (TaskContext, Iterator[R]) => Long = (_, iter) => Utils.getIteratorSize(iter)
        val evaluator = new CountEvaluator(partitions.length, confidence)
        sc.runApproximateJob(this, countElements, evaluator, timeout)
      }
    }
  }

  /**
   * Returns whether the RDD has no elements.
   *
   * Uses the cached total when `count()` has already run. Otherwise it is true when there are no partitions, or
   * when `take(1)` returns nothing (`take` itself triggers `count()` first).
   */
  override def isEmpty(): Boolean = {
    totalCalculated.fold {
      withScope {
        partitions.length == 0 || take(1).length == 0
      }
    } { total => total == 0L }
  }

  /**
   * Makes sure `count()` has run (so the total is cached) and then delegates to `RDD.take`.
   * NOTE(review): `RDD.take` runs its own job, which opens new consumers and so reads messages other than the
   * ones that were counted.
   */
  override def take(num: Int): Array[R] = {
    if (totalCalculated.isEmpty) count()
    super.take(num)
  }

  /**
   * Builds one partition per (distributed key, parallelism slot) pair, so the partition count is the number of
   * keys multiplied by the parallelism read from `rabbitMQParams`.
   *
   * When `distributedKeys` is empty, the keys are derived from `rabbitMQParams`. Each partition gets the index
   * `parallelism * keyIndex + slot` and the RDD-wide params merged with the key's own connection params (the
   * key's values win on conflicts).
   *
   * @return the partitions of this RDD
   */
  override def getPartitions: Array[Partition] = {
    val parallelism = getParallelism(rabbitMQParams)
    val keys =
      if (distributedKeys.nonEmpty) distributedKeys
      else getDistributedKeysParams(rabbitMQParams)

    keys.zipWithIndex.flatMap { case (key, index) =>
      (0 until parallelism).map { indexParallelism =>
        new RabbitMQPartition(
          parallelism * index + indexParallelism,
          key.queue,
          key.exchangeAndRouting,
          rabbitMQParams ++ key.connectionParams,
          // presumably flags that several partitions share this key; confirm in RabbitMQPartition
          parallelism > 1
        )
      }
    }.toArray
  }

  /**
   * Prefers to run each partition on the RabbitMQ host(s) given in its connection params.
   *
   * @param thePart partition to compute the locations for; must be a `RabbitMQPartition`
   * @return a single-element sequence with the value returned by `getHosts`
   *         (NOTE(review): confirm whether that value can hold several comma-separated hosts)
   */
  override def getPreferredLocations(thePart: Partition): Seq[String] = {
    val part = thePart.asInstanceOf[RabbitMQPartition]

    Seq(getHosts(part.connectionParams))
  }

  /** Opens a consumer for the given partition and returns an iterator over the handled deliveries. */
  override def compute(thePart: Partition, context: TaskContext): Iterator[R] = {
    val rabbitMQPartition = thePart.asInstanceOf[RabbitMQPartition]

    log.debug(s"Computing Partition: ${thePart.index} from \t[${rabbitMQPartition.toStringPretty}]")

    new RabbitMQRDDIterator(rabbitMQPartition, context)
  }

  /**
   * Iterator that consumes one partition's queue until the receive timeout fires, the consumer is cancelled, or
   * the optional per-partition message limit is reached.
   */
  private class RabbitMQRDDIterator(
      part: RabbitMQPartition,
      context: TaskContext) extends NextIterator[R] {

    // NextIterator only closes itself when it is exhausted. Without this listener, a partially consumed iterator
    // (e.g. from `take`) would leak its channel, keep its timeout scheduled and never report its count to the
    // accumulator. closeIfNeeded() is a no-op once the iterator is closed. On interruption, the shared actor
    // system and the connections are also shut down, as before.
    context.addTaskCompletionListener(taskContext => {
      if (taskContext.isInterrupted()) {
        RabbitMQRDD.shutDownActorSystem()
        log.info(s"Task interrupted, closing RabbitMQ connections in partition: ${part.index}")
        closeIfNeeded()
        closeConnections()
      } else {
        closeIfNeeded()
      }
    })

    // Connection params of this partition: the RDD-wide params merged with the key's own (see getPartitions).
    val rabbitParams = part.connectionParams
    // Consumer bound to this partition's queue; Consumer(...) may reuse a previously created connection, and a new
    // channel is created on it.
    val consumer = getConsumer(part, rabbitParams)
    val queueConsumer = consumer.startConsumer
    // Messages handled successfully by this partition; added to countAccumulator on close.
    @volatile var numMessages = 0

    // The shared actor system schedules the timeout that bounds how long (ms) this partition receives data.
    val system = RabbitMQRDD.getActorSystem
    import system.dispatcher

    val receiveTime = getMaxReceiveTime(rabbitParams)
    val maxMessagesPerPartition = getMaxMessagesPerPartition(rabbitParams)

    // When the receive time elapses, mark the iterator finished and cancel the consumer, which makes a blocked
    // nextDelivery() fail with ConsumerCancelledException.
    val scheduleProcess = system.scheduler.scheduleOnce(receiveTime milliseconds) {
      finished = true
      queueConsumer.handleCancel("timeout")
    }

    log.info(s"Receiving data in Partition ${part.index} from \t[${part.toStringPretty}]")

    /**
     * Blocks for the next delivery and returns the handled element.
     *
     * Ends the iteration on timeout, cancellation or when the message limit is reached. Any other receive error
     * closes the iterator and is rethrown as a `SparkException`.
     */
    override def getNext(): R = {
      synchronized {
        val limitReached = maxMessagesPerPartition.exists(limit => numMessages >= limit)
        if (finished || limitReached) {
          finishIterationAndReturn()
        } else {
          Try(queueConsumer.nextDelivery()) match {
            case Success(delivery) =>
              processDelivery(delivery)
            case Failure(e: ConsumerCancelledException) =>
              finishIterationAndReturn()
            case Failure(e) =>
              Try {
                finished = true
                closeIfNeeded()
              }
              throw new SparkException(s"Error receiving data from RabbitMQ with error: ${e.getLocalizedMessage}", e)
          }
        }
      }
    }

    /**
     * Applies `messageHandler` to a delivery.
     *
     * On success, the delivery is acked (when manual acks are enabled) and counted. On failure, it is nacked
     * (when manual acks are enabled) and null is returned as the element.
     */
    private def processDelivery(delivery: Delivery): R = {
      Try(messageHandler(delivery)) match {
        case Success(data) =>
          if (sendingBasicAckFromParams(rabbitParams))
            consumer.sendBasicAck(delivery)
          numMessages += 1
          data
        case Failure(e) =>
          if (sendingBasicAckFromParams(rabbitParams)) {
            log.warn(s"failed to process message. Sending noack ...", e)
            consumer.sendBasicNAck(delivery)
          }
          null.asInstanceOf[R]
      }
    }

    /**
     * Reports this partition's count to `countAccumulator` (which the driver reads), cancels the timeout and
     * closes the consumer.
     */
    override def close(): Unit = {
      countAccumulator += numMessages
      log.info(s"******* Received $numMessages messages by Partition : ${part.index}  before close Channel ******")
      // The completion listener is registered before these fields are assigned, so they are still null if the
      // constructor failed part-way; guard against that instead of throwing from the listener.
      Option(scheduleProcess).foreach(_.cancel())
      Option(consumer).foreach(_.close())
    }

    /** Marks the iteration as finished; the returned null is discarded by NextIterator. */
    private def finishIterationAndReturn(): R = {
      finished = true
      null.asInstanceOf[R]
    }

    /**
     * Creates a consumer for the partition's queue and exchange/routing setup. When fair dispatch is enabled, it
     * also applies the configured prefetch count as QoS.
     */
    private def getConsumer(part: RabbitMQPartition, consumerParams: Map[String, String]): Consumer = {
      val newConsumer = Consumer(consumerParams)
      newConsumer.setQueue(
        part.queue,
        part.exchangeAndRouting.exchangeName,
        part.exchangeAndRouting.exchangeType,
        part.exchangeAndRouting.routingKeys,
        consumerParams
      )
      // With several consumers on the same queue, fair dispatch (a bounded prefetch) avoids one consumer holding
      // messages the others could process.
      if (getFairDispatchFromParams(consumerParams))
        newConsumer.setFairDispatchQoS(getPrefetchCountFromParams(consumerParams))

      newConsumer
    }
  }

}

/** Factory for [[RabbitMQRDD]] and holder of the actor system shared by its partition iterators in this JVM. */
private[rabbitmq]
object RabbitMQRDD extends Logging {

  /** Shared actor system used to schedule receive timeouts; created lazily by `getActorSystem`. */
  @volatile private var system: Option[ActorSystem] = None

  /** Creates a [[RabbitMQRDD]] with the given keys, params, count accumulator and message handler. */
  def apply[R: ClassTag](sc: SparkContext,
                         distributedKeys: Seq[RabbitMQDistributedKey],
                         rabbitMQParams: Map[String, String],
                         countAccumulator: Accumulator[Long],
                         messageHandler: Delivery => R
                        ): RabbitMQRDD[R] = {

    new RabbitMQRDD[R](sc, distributedKeys, rabbitMQParams, countAccumulator, messageHandler)
  }

  /**
   * Returns the shared actor system. A new daemonic one is created when none exists or the previous one has
   * terminated.
   */
  def getActorSystem: ActorSystem = synchronized {
    system.filterNot(_.isTerminated).getOrElse {
      val created = akka.actor.ActorSystem(
        s"system-${System.currentTimeMillis()}",
        ConfigFactory.load(ConfigFactory.parseString("akka.daemonic=on"))
      )
      system = Option(created)
      created
    }
  }

  /** Starts shutting down the shared actor system (if any) and forgets it, so the next caller gets a fresh one. */
  def shutDownActorSystem(): Unit = synchronized {
    system.foreach { actorSystem =>
      log.debug(s"Shutting down actor system: ${actorSystem.name}")
      actorSystem.shutdown()
    }
    // shutdown() does not wait for termination, so isTerminated can still be false right after it returns.
    // Dropping the reference stops getActorSystem from handing out a system whose scheduler is going away.
    system = None
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy