All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.sparklinedata.druid.DruidRDD.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.sparklinedata.druid

import org.apache.http.client.methods.{HttpExecutionAware, HttpRequestBase}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRowWithSchema
import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLTimestamp
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.{InterruptibleIterator, Logging, Partition, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.druid.{DruidPlanner, DruidQueryCostModel, DummyResultIterator}
import org.apache.spark.sql.sparklinedata.execution.metrics.DruidQueryExecutionMetric
import org.joda.time.Interval
import org.sparklinedata.druid.client.{CancellableHolder, ConnectionManager, DruidQueryServerClient, ResultRow}
import org.sparklinedata.druid.metadata._

import scala.collection._
import scala.collection.convert.decorateAsScala._
import java.util.concurrent.ConcurrentHashMap

import com.fasterxml.jackson.core.Base64Variants
import org.apache.http.concurrent.Cancellable

import scala.util.Random

abstract class DruidPartition extends Partition {
  def queryClient(useSmile : Boolean,
                  httpMaxPerRoute : Int, httpMaxTotal : Int) : DruidQueryServerClient
  def intervals : List[Interval]
  def segIntervals : List[(DruidSegmentInfo, Interval)]

  def setIntervalsOnQuerySpec(q : QuerySpec) : QuerySpec = {
    if ( segIntervals == null) {
      q.setIntervals(intervals)
    } else {
      q.setSegIntervals(segIntervals)
    }
  }
}

class HistoricalPartition(idx: Int, hs : HistoricalServerAssignment) extends DruidPartition {
  private[druid] var _index: Int = idx
  def index = _index
  val hsName = hs.server.host

  def queryClient(useSmile : Boolean,
                  httpMaxPerRoute : Int, httpMaxTotal : Int) : DruidQueryServerClient = {
    ConnectionManager.init(httpMaxPerRoute, httpMaxTotal)
    new DruidQueryServerClient(hsName, useSmile)
  }

  val intervals : List[Interval] = hs.segmentIntervals.map(_._2)

  val segIntervals : List[(DruidSegmentInfo, Interval)] = hs.segmentIntervals

  override def setIntervalsOnQuerySpec(q : QuerySpec) : QuerySpec = {
    val r = super.setIntervalsOnQuerySpec(q)
    if (r.context.isDefined) {
      r.context.get.queryId = s"${r.context.get.queryId}-${index}"
    }
    r
  }
}

class BrokerPartition(idx: Int,
                      val broker : String,
                      val i : Interval) extends DruidPartition {
  override def index: Int = idx
  def queryClient(useSmile : Boolean,
                  httpMaxPerRoute : Int, httpMaxTotal : Int)  : DruidQueryServerClient = {
    ConnectionManager.init(httpMaxPerRoute, httpMaxTotal)
    new DruidQueryServerClient(broker, useSmile)
  }

  def intervals : List[Interval] = List(i)

  def segIntervals : List[(DruidSegmentInfo, Interval)] = null
}

class DruidRDD(sqlContext: SQLContext,
               drInfo : DruidRelationInfo,
                val dQuery : DruidQuery)  extends  RDD[InternalRow](sqlContext.sparkContext, Nil) {

  val recordDruidQuery = DruidPlanner.getConfValue(sqlContext,
    DruidPlanner.DRUID_RECORD_QUERY_EXECUTION
  )
  val druidQueryAcc : DruidQueryExecutionMetric = if (recordDruidQuery) {
    new DruidQueryExecutionMetric()
  } else {
    null
  }
  val numSegmentsPerQuery = dQuery.numSegmentsPerQuery
  val schema = dQuery.schema(drInfo)
  val useSmile = dQuery.useSmile && smileCompatible(schema)
  val drOptions = drInfo.options
  val drFullName = drInfo.fullName
  val drDSIntervals = drInfo.druidDS.intervals
  val inputEstimate = DruidQueryCostModel.estimateInput(dQuery.q, drInfo)
  val outputEstimate = DruidQueryCostModel.estimateOutput(dQuery.q, drInfo)
  val (httpMaxPerRoute, httpMaxTotal) = (
    DruidPlanner.getConfValue(sqlContext,
      DruidPlanner.DRUID_CONN_POOL_MAX_CONNECTIONS_PER_ROUTE),
    DruidPlanner.getConfValue(sqlContext,
      DruidPlanner.DRUID_CONN_POOL_MAX_CONNECTIONS)
    )

  val sparkToDruidColName: Map[String, String] =
    dQuery.q.mapSparkColNameToDruidColName(drInfo).map(identity)
  // scalastyle:off line.size.limit
  // why map identity?
  // see http://stackoverflow.com/questions/17709995/notserializableexception-for-mapstring-string-alias
  // scalastyle:on line.size.limit

  def druidColName(n : String) = sparkToDruidColName.getOrElse(n, n)

  /*
   * for now if there is a binary field don't use Smile.
   */
  def smileCompatible(typ : StructType) : Boolean = {
    typ.fields.foldLeft(true) {
      case (false, _) => false
      case (_, StructField(_, BinaryType, _, _)) => false
      case (_, StructField(_, st:StructType, _,_)) => smileCompatible(st)
      case _  => true
    }
  }

  @DeveloperApi
  override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {

    val p = split.asInstanceOf[DruidPartition]
    val mQry = p.setIntervalsOnQuerySpec(dQuery.q)
    Utils.logQuery(mQry)
    /*
     * Druid Query execution steps:
     * 1. Register queryId with TaskCancelHandler
     * 2. Provide the [[org.sparklinedata.druid.client.CancellableHolder]] to the DruidClient,
     *    so it can relay any [[org.apache.http.concurrent.Cancellable]] resources to
     *    the holder.
     * 3. If an error occurs and the holder's ''wasCancelTriggered'' flag is set, a
     *    DummyResultIterator is returned. By this time the Task has been cancelled by
     *    the ''Spark Executor'', so it is safe to return an empty iterator.
     * 4. Always clear this Druid Query from the TaskCancelHandler
     */
    var cancelCallback : TaskCancelHandler.TaskCancelHolder = null
    var dr: CloseableIterator[ResultRow] = null
    var client : DruidQueryServerClient = null
    val qryId = mQry.context.map(_.queryId).getOrElse(s"q-${System.nanoTime()}")
    var qrySTime = System.currentTimeMillis()
    var qrySTimeStr = s"${new java.util.Date()}"
    try {
      cancelCallback = TaskCancelHandler.registerQueryId(qryId, context)
      client = p.queryClient(useSmile, httpMaxPerRoute, httpMaxTotal)
      client.setCancellableHolder(cancelCallback)
      qrySTime = System.currentTimeMillis()
      qrySTimeStr = s"${new java.util.Date()}"
      dr = mQry.executeQuery(client)
    } catch {
      case _ if cancelCallback.wasCancelTriggered && client != null => {
        dr = new DummyResultIterator()
      }
      case e: Throwable => throw e
    }
    finally {
      TaskCancelHandler.clearQueryId(qryId)
    }

    val druidExecTime = (System.currentTimeMillis() - qrySTime)
    var numRows : Int = 0

    context.addTaskCompletionListener{ context =>
      val queryExecTime = (System.currentTimeMillis() - qrySTime)
      if (recordDruidQuery) {
        druidQueryAcc.add(
          DruidQueryExecutionView(
            context.stageId,
            context.partitionId(),
            context.taskAttemptId(),
            s"${client.host}:${client.port}",
            if (p.segIntervals == null) None
            else {
              Some(
                p.segIntervals.map(t => (t._1.identifier, t._2.toString))
              )
            },
            qrySTimeStr,
            druidExecTime,
            queryExecTime,
            numRows,
            Utils.queryToString(mQry)
          )
        )
      }
      dr.closeIfNeeded()
    }

    val r = new InterruptibleIterator[ResultRow](context, dr)
    val nameToTF = dQuery.getValTFMap

    /*
     * multiple output fields may project on the same druid column with different
     * transformations. True for the dimensions of a spatialIndex point.
     */
    def tfName(f : StructField) : Option[String] = {
      if (nameToTF.contains(f.name)) {
        nameToTF.get(f.name)
      } else {
        nameToTF.get(druidColName(f.name))
      }
    }

    r.map { r =>
      numRows += 1
      val row = new GenericInternalRowWithSchema(schema.fields.map
      (f => DruidValTransform.sparkValue(
        f, r.event(druidColName(f.name)), tfName(f))), schema)
      row
    }
  }

  override protected def getPartitions: Array[Partition] = {
    if (dQuery.queryHistoricalServer) {
    val hAssigns = DruidMetadataCache.assignHistoricalServers(
      drFullName,
      drOptions,
      dQuery.intervalSplits
    )
      var idx = -1

      val l  = (for(
        hA <- hAssigns;
           segIns <- hA.segmentIntervals.sliding(numSegmentsPerQuery,numSegmentsPerQuery)
      ) yield {
        idx = idx + 1
        new HistoricalPartition(idx, new HistoricalServerAssignment(hA.server, segIns))
      }
        )

      val l1 : Array[Partition] = Random.shuffle(l).zipWithIndex.map{t =>
        val p = t._1
        val i = t._2
        p._index = i
        p
      }.toArray

      l1
  } else {
      // ensure DataSource is in the Metadata Cache.
      DruidMetadataCache.getDataSourceInfo(drFullName, drOptions)
      val broker = DruidMetadataCache.getDruidClusterInfo(drFullName,
        drOptions).curatorConnection.getBroker
      dQuery.intervalSplits.zipWithIndex.map(t => new BrokerPartition(t._2, broker, t._1)).toArray
    }
  }
}

/**
  * conversion from Druid values to Spark values. Most of the conversion cases are handled by
  * cast expressions in the [[org.apache.spark.sql.execution.Project]] operator above the
  * DruidRelation Operator; but some values needs massaging like TimeStamps, Strings...
  */
object DruidValTransform {

  private[this] val dTZ = org.joda.time.DateTimeZone.getDefault

  private[this] val toTSWithTZAdj = (druidVal: Any) => {
    val dvLong = if (druidVal.isInstanceOf[Double]) {
      druidVal.asInstanceOf[Double].toLong
    } else if (druidVal.isInstanceOf[BigInt]) {
      druidVal.asInstanceOf[BigInt].toLong
    } else if (druidVal.isInstanceOf[String]){
      druidVal.asInstanceOf[String].toLong
    }else {
      druidVal
    }

    new org.joda.time.DateTime(dvLong, dTZ).getMillis() * 1000.asInstanceOf[SQLTimestamp]
  }

  private[this] val toTS = (druidVal: Any) => {
    if (druidVal.isInstanceOf[Double]) {
      druidVal.asInstanceOf[Double].longValue().asInstanceOf[SQLTimestamp]
    } else if (druidVal.isInstanceOf[BigInt]) {
      druidVal.asInstanceOf[BigInt].toLong.asInstanceOf[SQLTimestamp]
    } else {
      druidVal
    }
  }

  private[this] val toString = (druidVal: Any) => {
    UTF8String.fromString(druidVal.toString)
  }

  private[this] val toInt = (druidVal: Any) => {
    if (druidVal.isInstanceOf[Double]) {
      druidVal.asInstanceOf[Double].toInt
    } else if (druidVal.isInstanceOf[BigInt]) {
      druidVal.asInstanceOf[BigInt].toInt
    } else if (druidVal.isInstanceOf[String]) {
      druidVal.asInstanceOf[String].toInt
    }else {
      druidVal
    }
  }

  private[this] val toLong = (druidVal: Any) => {
    if (druidVal.isInstanceOf[Double]) {
      druidVal.asInstanceOf[Double].toLong
    } else if (druidVal.isInstanceOf[BigInt]) {
      druidVal.asInstanceOf[BigInt].toLong
    } else if (druidVal.isInstanceOf[String]) {
      druidVal.asInstanceOf[String].toLong
    }else {
      druidVal
    }
  }

  private[this] val toFloat = (druidVal: Any) => {
    if (druidVal.isInstanceOf[Double]) {
      druidVal.asInstanceOf[Double].toFloat
    } else if (druidVal.isInstanceOf[BigInt]) {
      druidVal.asInstanceOf[BigInt].toFloat
    } else if (druidVal.isInstanceOf[String]) {
      druidVal.asInstanceOf[String].toFloat
    }else {
      druidVal
    }
  }

  private[this] def pointToDouble(dim : Int)(druidVal : Any) : Any = {
    if (druidVal == null) return null
    val point : Array[Double] = druidVal.toString.split(",").map(_.toDouble)
    if ( dim >= 0 && dim < point.length ) {
      point(dim)
    } else null
  }

  def dimConversion(i : Int) : String = {
    s"dim$i"
  }

  /**
    * conversion from Druid values to Spark values. Most of the conversion cases are handled by
    * cast expressions in the [[org.apache.spark.sql.execution.Project]] operator above the
    * DruidRelation Operator; but Strings need to be converted to [[UTF8String]] strings.
    *
    * @param f
    * @param druidVal
    * @return
    */
  def defaultValueConversion(f : StructField, druidVal : Any) : Any = f.dataType match {
    case TimestampType if druidVal.isInstanceOf[Double] =>
      druidVal.asInstanceOf[Double].longValue().asInstanceOf[SQLTimestamp]
    case StringType if druidVal != null => UTF8String.fromString(druidVal.toString)
    case LongType if druidVal.isInstanceOf[BigInt] =>
      druidVal.asInstanceOf[BigInt].longValue()
    case LongType if druidVal.isInstanceOf[Double] =>
      druidVal.asInstanceOf[Double].longValue()
    case BinaryType if druidVal.isInstanceOf[String] => {
      Base64Variants.getDefaultVariant.decode(druidVal.asInstanceOf[String])
    }
    case _ => druidVal
  }

  // TODO: create an enum of TFs
  private[this] val tfMap: Map[String, Any => Any] = {
    Map[String, Any => Any](
      "toTSWithTZAdj" -> toTSWithTZAdj,
      "toTS" -> toTS,
      "toString" -> toString,
      "toInt" -> toInt,
      "toLong" -> toLong,
      "toFloat" -> toFloat
    ) ++ (0 to 20).map { i =>
      s"dim$i" -> pointToDouble(i) _
    }
  }

  def sparkValue(f : StructField, druidVal: Any, tfName: Option[String]): Any = {
    tfName match {
      case Some(tf) if (tfMap.contains(tf) && druidVal != null) => tfMap(tf)(druidVal)
      case _ => defaultValueConversion(f, druidVal)
    }
  }

  def getTFName(sparkDT: DataType, adjForTZ: Boolean = false): String = sparkDT match {
    case TimestampType if adjForTZ => "toTSWithTZAdj"
    case TimestampType if !adjForTZ => "toTS"
    case StringType if !adjForTZ => "toString"
    case ShortType | IntegerType => "toInt"
    case LongType => "toLong"
    case FloatType => "toFloat"
    case _ => ""
  }
}

/**
  * The '''TaskCancel Thread''' tracks the Spark tasks that are executing Druid Queries.
  * Periodically(currently every 5 secs) it checks if any of the Spark Tasks have been
  * cancelled and relays this to the current [[org.apache.http.concurrent.Cancellable]] associated
  * with the [[org.apache.http.client.methods.HttpExecutionAware]] connectio handling the
  * ''Druid Query''
  *
  */
object TaskCancelHandler extends Logging {

  private val taskMap : concurrent.Map[String, (Cancellable, TaskCancelHolder, TaskContext)] =
    new ConcurrentHashMap[String, (Cancellable, TaskCancelHolder, TaskContext)]().asScala

  class TaskCancelHolder(val queryId : String,
                         val taskContext : TaskContext) extends CancellableHolder {
    def setCancellable(c : Cancellable) : Unit = {
      log.debug("set cancellable for query {}", queryId)
      taskMap(queryId) = (c, this, taskContext)
    }
    @volatile
    var wasCancelTriggered : Boolean = false
  }


  def registerQueryId(queryId : String, taskContext : TaskContext) : TaskCancelHolder = {
    log.debug("register query {}", queryId)
    new TaskCancelHolder(queryId, taskContext)
  }

  def clearQueryId(queryId : String) : Unit = taskMap.remove(queryId)

  val secs5 : Long = 5 * 1000

  object cancelCheckThread extends Runnable with Logging {

    def run() : Unit = {

      while(true) {
        Thread.sleep(secs5)
        log.debug(s"cancelThread woke up")
        var canceledTasks : Seq[String] = Seq()
        taskMap.foreach{t =>
          val (queryId, (req, cancellableHolder, taskContext)) = t
          log.debug(s"checking task stageid=${taskContext.stageId()}, " +
            s"partitionId=${taskContext.partitionId()}, " +
            s"isInterrupted=${taskContext.isInterrupted()}")
          if (taskContext.isInterrupted()) {
            try {
              cancellableHolder.wasCancelTriggered = true
              req.cancel()
              log.info("aborted http request for query {}: {}", Array[Any](queryId, req))
              canceledTasks = canceledTasks :+ queryId
            } catch {
              case e : Throwable => log.warn("failed to abort http request: {}", req)
            }
          }

        }
        canceledTasks.foreach(t => clearQueryId(t))
      }

    }

  }

  {
    val t = new Thread(cancelCheckThread)
    t.setName("DruidRDD-TaskCancelCheckThread")
    t.setDaemon(true)
    t.start()
  }


}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy