All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.sql.cassandra.CassandraSourceRelation.scala Maven / Gradle / Ivy

There is a newer version: 3.0.0-alpha2
Show newest version
package org.apache.spark.sql.cassandra

import java.net.InetAddress
import java.util.{Locale, UUID}

import com.datastax.spark.connector.cql.{CassandraConnector, CassandraConnectorConf, Schema}
import com.datastax.spark.connector.rdd.partitioner.DataSizeEstimates
import com.datastax.spark.connector.rdd.{CassandraRDD, CassandraTableScanRDD, ReadConf}
import com.datastax.spark.connector.rdd.partitioner.dht.TokenFactory.forSystemLocalPartitioner
import com.datastax.spark.connector.types.{InetType, UUIDType, VarIntType}
import com.datastax.spark.connector.util.Quote._
import com.datastax.spark.connector.util.{ConfigParameter, Logging, ReflectionUtil}
import com.datastax.spark.connector.writer.{SqlRowWriter, WriteConf}
import com.datastax.spark.connector.{SomeColumns, _}
import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.cassandra.CassandraSQLRow.CassandraSQLRowReader
import org.apache.spark.sql.cassandra.DataTypeConverter._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SQLContext, sources}
import org.apache.spark.unsafe.types.UTF8String

/**
 * Implements [[BaseRelation]], [[InsertableRelation]] and [[PrunedFilteredScan]].
 * It inserts data into and scans a Cassandra table. If `filterPushdown` is true, it pushes
 * some filters down to CQL.
 *
 * @param tableRef            keyspace/table (and optional cluster) this relation points at
 * @param userSpecifiedSchema schema to expose instead of the one derived from the table
 * @param filterPushdown      whether filters may be translated into a CQL where clause
 * @param confirmTruncate     must be true for an overwriting insert to truncate the table
 * @param tableSizeInBytes    size reported by [[sizeInBytes]] when defined
 * @param connector           connector used for schema lookup, truncation and scans
 * @param readConf            read configuration made implicitly available to the scan RDD
 * @param writeConf           write configuration passed to `saveToCassandra`
 * @param sparkConf           configuration handed to the predicate push down rules
 * @param sqlContext          owning SQL context
 */
private[cassandra] class CassandraSourceRelation(
    tableRef: TableRef,
    userSpecifiedSchema: Option[StructType],
    filterPushdown: Boolean,
    confirmTruncate: Boolean,
    tableSizeInBytes: Option[Long],
    connector: CassandraConnector,
    readConf: ReadConf,
    writeConf: WriteConf,
    sparkConf: SparkConf,
    override val sqlContext: SQLContext)
  extends BaseRelation
  with InsertableRelation
  with PrunedFilteredScan
  with Logging {

  /** Table metadata fetched from Cassandra when the relation is constructed. */
  private[this] val tableDef = Schema.tableFromCassandra(
    connector,
    tableRef.keyspace,
    tableRef.table)

  /** The user supplied schema if present, otherwise one derived from the table columns. */
  override def schema: StructType =
    userSpecifiedSchema.getOrElse(StructType(tableDef.columns.map(toStructField)))

  /**
   * Saves `data` to the Cassandra table.
   *
   * When `overwrite` is set the table is truncated first, but only if `confirmTruncate`
   * was enabled; otherwise nothing is written and an exception is raised.
   *
   * @throws UnsupportedOperationException if `overwrite` is requested without `confirmTruncate`
   */
  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
    if (overwrite) {
      // Guard first so that we never touch the table without explicit confirmation.
      if (!confirmTruncate) {
        throw new UnsupportedOperationException(
          """You are attempting to use overwrite mode which will truncate
          |this table prior to inserting data. If you would merely like
          |to change data already in the table use the "Append" mode.
          |To actually truncate please pass in true value to the option
          |"confirm.truncate" when saving. """.stripMargin)
      }
      val keyspace = quote(tableRef.keyspace)
      val table = quote(tableRef.table)
      connector.withSessionDo(session => session.execute(s"TRUNCATE $keyspace.$table"))
    }

    implicit val rwf = SqlRowWriter.Factory
    val columns = SomeColumns(data.columns.map(x => x: ColumnRef): _*)
    data.rdd.saveToCassandra(tableRef.keyspace, tableRef.table, columns, writeConf)
  }

  /** The configured table size, or the SQLConf default size when none is known. */
  override def sizeInBytes: Long =
    tableSizeInBytes.getOrElse(sqlContext.conf.defaultSizeInBytes)

  // Implicits picked up by `cassandraTable` below; they must be defined before `baseRdd`.
  implicit val cassandraConnector: CassandraConnector = connector
  implicit val readconf: ReadConf = readConf
  private[this] val baseRdd =
    sqlContext.sparkContext.cassandraTable[CassandraSQLRow](tableRef.keyspace, tableRef.table)

  /** Full scan: the base RDD without column pruning or pushed down filters. */
  def buildScan(): RDD[Row] = baseRdd.asInstanceOf[RDD[Row]]

  /**
   * With push down enabled only the filters the push down rules leave to Spark are
   * returned; with push down disabled every filter is returned.
   */
  override def unhandledFilters(filters: Array[Filter]): Array[Filter] =
    if (filterPushdown) predicatePushDown(filters).handledBySpark.toArray else filters

  /**
   * User supplied push down rules, resolved from a comma separated list of object names.
   * The list is read from the Spark configuration first and from the local properties of
   * the Spark context second; when neither is set the parameter default is used.
   */
  lazy val additionalRules: Seq[CassandraPredicateRules] = {
    import CassandraSourceRelation.AdditionalCassandraPushDownRulesParam
    val sc = sqlContext.sparkContext

    /* So we can set this in testing to different values without
     making a new context check local property as well */
    val userClasses: Option[String] =
      sc.getConf.getOption(AdditionalCassandraPushDownRulesParam.name)
        .orElse(Option(sc.getLocalProperty(AdditionalCassandraPushDownRulesParam.name)))

    userClasses match {
      case Some(classes) =>
        classes
          .trim
          .split("""\s*,\s*""")
          // A blank value or a stray leading comma yields empty tokens; skip them
          // instead of asking reflection to resolve an object with an empty name.
          .filter(_.nonEmpty)
          .map(ReflectionUtil.findGlobalObject[CassandraPredicateRules])
          // Reversed because the rules are applied with foldRight in predicatePushDown,
          // which visits the last element first.
          .reverse
      case None => AdditionalCassandraPushDownRulesParam.default
    }
  }

  /**
   * Runs the built in push down rules followed by the user defined [[additionalRules]]
   * and returns the resulting analyzed predicates.
   */
  private def predicatePushDown(filters: Array[Filter]) = {
    logInfo(s"Input Predicates: [${filters.mkString(", ")}]")

    val pv = connector.withClusterDo(_.getConfiguration.getProtocolOptions.getProtocolVersion)

    /** Apply built in rules **/
    val bcpp = new BasicCassandraPredicatePushDown(filters.toSet, tableDef, pv)
    val basicPushdown = AnalyzedPredicates(bcpp.predicatesToPushDown, bcpp.predicatesToPreserve)
    logDebug(s"Basic Rules Applied:\n$basicPushdown")

    /** Apply any user defined rules **/
    val finalPushdown = additionalRules.foldRight(basicPushdown) { (rules, pushdowns) =>
      val pd = rules(pushdowns, tableDef, sparkConf)
      logDebug(s"Applied ${rules.getClass.getSimpleName} Pushdown Filters:\n$pd")
      pd
    }

    logDebug(s"Final Pushdown filters:\n$finalPushdown")
    finalPushdown
  }

  /**
   * Scan restricted to `requiredColumns`. When push down is enabled, the filters the
   * push down rules assign to Cassandra are turned into a CQL where clause.
   */
  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
    val filteredRdd =
      if (filterPushdown) {
        val pushdownFilters = predicatePushDown(filters).handledByCassandra.toArray
        maybePushdownFilters(baseRdd, pushdownFilters)
      } else {
        baseRdd
      }
    maybeSelect(filteredRdd, requiredColumns)
  }

  /** Define a type for CassandraRDD[CassandraSQLRow]. It's used by following methods */
  private type RDDType = CassandraRDD[CassandraSQLRow]

  /**
   * Transfer selection to limit to columns specified.
   *
   * With no required columns a table scan RDD is replaced by its count RDD, and one empty
   * row is emitted per counted row, so no column data has to be fetched.
   */
  private def maybeSelect(rdd: RDDType, requiredColumns: Array[String]): RDD[Row] = {
    val prunedRdd = if (requiredColumns.nonEmpty) {
      rdd.select(requiredColumns.map(column => column: ColumnRef): _*)
    } else {
      rdd match {
        case rdd: CassandraTableScanRDD[_] =>
          CassandraTableScanRDD.countRDD(rdd).mapPartitions(_.flatMap { count =>
            // Iterate in Long space: `count.toInt` would silently truncate counts above
            // Int.MaxValue and produce too few (or no) rows.
            Iterator.iterate(0L)(_ + 1L).takeWhile(_ < count).map(_ => CassandraSQLRow.empty)
          })
        case _ => rdd
      }
    }
    prunedRdd.asInstanceOf[RDD[Row]]
  }

  /** Push down filters to CQL query; the RDD is returned unchanged when no values are bound */
  private def maybePushdownFilters(rdd: RDDType, filters: Seq[Filter]): RDDType = {
    whereClause(filters) match {
      case (cql, values) if values.nonEmpty => rdd.where(cql, values: _*)
      case _ => rdd
    }
  }

  /**
   * Construct Cql clause and retrieve the values from filter.
   *
   * @return the CQL fragment with `?` placeholders and the values to bind, in order
   * @throws UnsupportedOperationException for any filter other than =, <, <=, >, >= and In
   */
  private def filterToCqlAndValue(filter: Filter): (String, Seq[Any]) = {
    /** Shared shape of all single value comparisons. */
    def comparison(attribute: String, operator: String, value: Any): (String, Seq[Any]) =
      (s"${quote(attribute)} $operator ?", Seq(toCqlValue(attribute, value)))

    filter match {
      case sources.EqualTo(attribute, value)            => comparison(attribute, "=", value)
      case sources.LessThan(attribute, value)           => comparison(attribute, "<", value)
      case sources.LessThanOrEqual(attribute, value)    => comparison(attribute, "<=", value)
      case sources.GreaterThan(attribute, value)        => comparison(attribute, ">", value)
      case sources.GreaterThanOrEqual(attribute, value) => comparison(attribute, ">=", value)
      case sources.In(attribute, values) =>
        val placeholders = values.map(_ => "?").mkString("(", ", ", ")")
        (quote(attribute) + " IN " + placeholders, toCqlValues(attribute, values))
      case _ =>
        throw new UnsupportedOperationException(
          s"It's not a valid filter $filter to be pushed down, " +
            "only =, >, <, >=, <= and In are allowed.")
    }
  }

  /** Converts every value of an IN filter with [[toCqlValue]]. */
  private def toCqlValues(columnName: String, values: Array[Any]): Seq[Any] =
    values.map(toCqlValue(columnName, _)).toSeq

  /**
   * Adapts a Spark SQL value to the type of the target column:
   *  - a `Decimal` becomes a `BigInteger` when the column is a VarInt column
   *  - a `UTF8String` becomes an `InetAddress` for Inet columns and a `UUID` for UUID columns
   * Every other value is passed through unchanged.
   */
  private def toCqlValue(columnName: String, value: Any): Any = {
    value match {
      case decimal: Decimal =>
        val isVarIntColumn = tableDef.columnByName(columnName).columnType == VarIntType
        if (isVarIntColumn) decimal.toJavaBigDecimal.toBigInteger else decimal
      case utf8String: UTF8String =>
        tableDef.columnByName(columnName).columnType match {
          case InetType => InetAddress.getByName(utf8String.toString)
          case UUIDType => UUID.fromString(utf8String.toString)
          case _ => utf8String
        }
      case other => other
    }
  }

  /** Construct where clause from pushdown filters: fragments joined by AND plus all values */
  private def whereClause(pushdownFilters: Seq[Filter]): (String, Seq[Any]) = {
    val cqlValue = pushdownFilters.map(filterToCqlAndValue)
    val cql = cqlValue.map(_._1).mkString(" AND ")
    val args = cqlValue.flatMap(_._2)
    (cql, args)
  }
}

/**
 * Configuration parameters of the Cassandra DataFrame source together with the factory
 * that assembles a [[CassandraSourceRelation]] from a table reference and its options.
 */
object CassandraSourceRelation {
  val ReferenceSection = "Cassandra DataFrame Source Parameters"

  val TableSizeInBytesParam = ConfigParameter[Option[Long]](
    name = "spark.cassandra.table.size.in.bytes",
    section = ReferenceSection,
    default = None,
    description =
      """Used by DataFrames Internally, will be updated in a future release to
        |retrieve size from Cassandra. Can be set manually now""".stripMargin
  )

  val AdditionalCassandraPushDownRulesParam = ConfigParameter[List[CassandraPredicateRules]] (
    name = "spark.cassandra.sql.pushdown.additionalClasses",
    section = ReferenceSection,
    default = List.empty,
    description =
      """A comma separated list of classes to be used (in order) to apply additional
        | pushdown rules for Cassandra Dataframes. Classes must implement CassandraPredicateRules
      """.stripMargin
  )

  val Properties = Seq(
    AdditionalCassandraPushDownRulesParam,
    TableSizeInBytesParam
  )

  /** Cluster name used when the table reference does not name one. */
  val defaultClusterName = "default"

  /**
   * Builds a relation for `tableRef`.
   *
   * The effective configuration is produced by [[consolidateConfs]]. The table size is
   * taken from `spark.cassandra.table.size.in.bytes` when set; otherwise it is read from
   * the data size estimates, and left undefined when the estimate is not positive.
   *
   * @param tableRef   table to expose
   * @param sqlContext owning SQL context
   * @param options    source options (push down flag, truncate confirmation, Cassandra confs)
   * @param schema     optional user specified schema
   */
  def apply(
    tableRef: TableRef,
    sqlContext: SQLContext,
    options: CassandraSourceOptions = CassandraSourceOptions(),
    schema : Option[StructType] = None) : CassandraSourceRelation = {

    val sparkConf = sqlContext.sparkContext.getConf
    val sqlConf = sqlContext.getAllConfs
    val conf =
      consolidateConfs(sparkConf, sqlConf, tableRef, options.cassandraConfs)
    val tableSizeInBytesString = conf.getOption(TableSizeInBytesParam.name)
    val cassandraConnector =
      new CassandraConnector(CassandraConnectorConf(conf))
    val tableSizeInBytes = tableSizeInBytesString match {
      case Some(size) => Option(size.toLong)
      case None =>
        // No manual setting: fall back to the size estimates.
        val tokenFactory = forSystemLocalPartitioner(cassandraConnector)
        val dataSizeInBytes =
          new DataSizeEstimates(
            cassandraConnector,
            tableRef.keyspace,
            tableRef.table)(tokenFactory).totalDataSizeInBytes
        // A non positive estimate carries no information, so report "unknown".
        if (dataSizeInBytes > 0L) Option(dataSizeInBytes) else None
    }
    val readConf = ReadConf.fromSparkConf(conf)
    val writeConf = WriteConf.fromSparkConf(conf)

    new CassandraSourceRelation(
      tableRef = tableRef,
      userSpecifiedSchema = schema,
      filterPushdown = options.pushdown,
      confirmTruncate = options.confirmTruncate,
      tableSizeInBytes = tableSizeInBytes,
      connector = cassandraConnector,
      readConf = readConf,
      writeConf = writeConf,
      sparkConf = conf,
      sqlContext = sqlContext)
  }

  /**
   * Consolidate Cassandra conf settings in the order of
   * table level -> keyspace level -> cluster level ->
   * default. Use the first available setting. Default
   * settings are stored in SparkConf.
   *
   * Table level options are accepted both under the lower-cased property name and under
   * the exact property name. Options that are not known connector properties are copied
   * into the result unchanged.
   *
   * @param sparkConf base configuration; it is cloned, never modified
   * @param sqlConf   SQL level settings, optionally prefixed with `cluster:keyspace/`,
   *                  `cluster/` or `default/`
   * @param tableRef  supplies the cluster and keyspace used to build the prefixes
   * @param tableConf table level options
   * @return a new SparkConf holding the consolidated settings
   */
  def consolidateConfs(
    sparkConf: SparkConf,
    sqlConf: Map[String, String],
    tableRef: TableRef,
    tableConf: Map[String, String]) : SparkConf = {

    //Default settings
    val conf = sparkConf.clone()
    val cluster = tableRef.cluster.getOrElse(defaultClusterName)
    val ks = tableRef.keyspace
    //Keyspace/Cluster level settings
    for (prop <- DefaultSource.confProperties) {
      val lowerCasedProp = prop.toLowerCase(Locale.ROOT)
      val value =
        tableConf.get(lowerCasedProp)
          // Also honour the exact spelling: a mixed-case property passed under its real
          // name used to be ignored here and then dropped by the exclusion below.
          .orElse(tableConf.get(prop))
          .orElse(sqlConf.get(s"$cluster:$ks/$prop"))
          .orElse(sqlConf.get(s"$cluster/$prop"))
          .orElse(sqlConf.get(s"default/$prop"))
          .orElse(sqlConf.get(prop))
      value.foreach(conf.set(prop, _))
    }
    //Set all user properties
    // Known properties were handled above under either spelling; excluding both keeps the
    // lower-cased duplicates of mixed-case properties out of the resulting configuration.
    val lowerCasedProps = DefaultSource.confProperties.map(_.toLowerCase(Locale.ROOT))
    conf.setAll(tableConf -- DefaultSource.confProperties -- lowerCasedProps)
    conf
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy