
com.datastax.spark.connector.rdd.CassandraTableScanRDD.scala

package com.datastax.spark.connector.rdd

import java.io.IOException

import com.datastax.spark.connector._
import com.datastax.spark.connector.cql._
import com.datastax.spark.connector.rdd.partitioner._
import com.datastax.spark.connector.rdd.partitioner.dht.{Token => ConnectorToken}
import com.datastax.spark.connector.rdd.reader._
import com.datastax.spark.connector.types.ColumnType
import com.datastax.spark.connector.util.CqlWhereParser.{EqPredicate, InListPredicate, InPredicate, Predicate, RangePredicate}
import com.datastax.spark.connector.util.Quote._
import com.datastax.spark.connector.util.{CountingIterator, CqlWhereParser}
import com.datastax.spark.connector.writer.RowWriterFactory

import com.datastax.driver.core._
import org.apache.spark.metrics.InputMetricsUpdater
import org.apache.spark.{Partition, Partitioner, SparkContext, TaskContext}

import scala.collection.JavaConversions._
import scala.language.existentials
import scala.reflect.ClassTag


/** RDD representing a table scan of a Cassandra table.
  *
  * This class is the main entry point for analyzing data in a Cassandra database with Spark.
  * Obtain objects of this class by calling
  * [[com.datastax.spark.connector.SparkContextFunctions.cassandraTable]].
  *
  * Configuration properties should be passed in the [[org.apache.spark.SparkConf SparkConf]]
  * configuration of [[org.apache.spark.SparkContext SparkContext]].
  * `CassandraRDD` needs to open a connection to Cassandra, therefore it requires appropriate
  * connection property values to be present in [[org.apache.spark.SparkConf SparkConf]].
  * For the list of required and available properties, see
  * [[com.datastax.spark.connector.cql.CassandraConnector CassandraConnector]].
  *
  * `CassandraRDD` divides the data set into smaller partitions, processed locally on every
  * cluster node. A data partition consists of one or more contiguous token ranges.
  * To reduce the number of roundtrips to Cassandra, every partition is fetched in batches.
  *
  * The following properties control the number of partitions and the fetch size:
  * - spark.cassandra.input.split.size_in_mb: approximate amount of data to be fetched into a single Spark
  *   partition, default 64 MB
  * - spark.cassandra.input.fetch.size_in_rows:  number of CQL rows fetched per roundtrip,
  *   default 1000
  *
  * A `CassandraRDD` object gets serialized and sent to every Spark Executor, which then
  * calls the `compute` method to fetch the data on every node. The `getPreferredLocations`
  * method tells Spark the preferred nodes to fetch a partition from, so that the data for
  * the partition is on the same node the task was sent to. If Cassandra nodes are co-located
  * with Spark nodes, the queries are always sent to the Cassandra process running on the same
  * node as the Spark Executor process, so data is not transferred between nodes.
  * If a Cassandra node fails or gets overloaded during a read, the queries are retried
  * on a different node.
  *
  * By default, reads are performed at ConsistencyLevel.LOCAL_ONE in order to leverage data-locality
  * and minimize network traffic. This read consistency level is controlled by the
  * spark.cassandra.input.consistency.level property.
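  *
  * Example: a minimal table scan (a sketch; keyspace, table and column names below are
  * illustrative, and the connection host must already be set in the `SparkConf`, e.g. via
  * `spark.cassandra.connection.host`):
  * {{{
  *   import com.datastax.spark.connector._
  *
  *   val rdd = sc.cassandraTable("test", "words")   // CassandraTableScanRDD[CassandraRow]
  *   val filtered = rdd.select("word", "count").where("word = ?", "cat")
  *   filtered.collect().foreach(println)
  * }}}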
  */
class CassandraTableScanRDD[R] private[connector](
    @transient val sc: SparkContext,
    val connector: CassandraConnector,
    val keyspaceName: String,
    val tableName: String,
    val columnNames: ColumnSelector = AllColumns,
    val where: CqlWhereClause = CqlWhereClause.empty,
    val limit: Option[Long] = None,
    val clusteringOrder: Option[ClusteringOrder] = None,
    val readConf: ReadConf = ReadConf(),
    overridePartitioner: Option[Partitioner] = None)(
  implicit
    val classTag: ClassTag[R],
    @transient val rowReaderFactory: RowReaderFactory[R])
  extends CassandraRDD[R](sc, Seq.empty)
  with CassandraTableRowReaderProvider[R] {

  override type Self = CassandraTableScanRDD[R]

  override protected def copy(
    columnNames: ColumnSelector = columnNames,
    where: CqlWhereClause = where,
    limit: Option[Long] = limit,
    clusteringOrder: Option[ClusteringOrder] = None,
    readConf: ReadConf = readConf,
    connector: CassandraConnector = connector): Self = {

    require(sc != null,
      "RDD transformation requires a non-null SparkContext. " +
        "Unfortunately SparkContext in this CassandraRDD is null. " +
        "This can happen after CassandraRDD has been deserialized. " +
        "SparkContext is not Serializable, therefore it deserializes to null." +
        "RDD transformations are not allowed inside lambdas used in other RDD transformations.")

    new CassandraTableScanRDD[R](
      sc = sc,
      connector = connector,
      keyspaceName = keyspaceName,
      tableName = tableName,
      columnNames = columnNames,
      where = where,
      limit = limit,
      clusteringOrder = clusteringOrder,
      readConf = readConf,
      overridePartitioner = overridePartitioner)
  }



  override protected def convertTo[B : ClassTag : RowReaderFactory]: CassandraTableScanRDD[B] = {
    new CassandraTableScanRDD[B](
      sc = sc,
      connector = connector,
      keyspaceName = keyspaceName,
      tableName = tableName,
      columnNames = columnNames,
      where = where,
      limit = limit,
      clusteringOrder = clusteringOrder,
      readConf = readConf,
      overridePartitioner = overridePartitioner)
  }

  /**
    * Internal method for assigning a partitioner to this RDD. It lacks type-safety checks for
    * the Partitioner of type [K]. End users should use the implicit provided in
    * [[CassandraTableScanPairRDDFunctions]].
    */
  private[connector] def withPartitioner[K, V, T <: ConnectorToken[V]](
    partitioner: Option[Partitioner]): CassandraTableScanRDD[R] = {

    val cassPart = partitioner match {
      case Some(newPartitioner: CassandraPartitioner[K, V, T]) => {
        this.partitioner match {
          case Some(currentPartitioner: CassandraPartitioner[K, V, T]) =>
            // Preserve the mapping set by the current partitioner
            logDebug(
              s"""Preserving Partitioner: $currentPartitioner with mapping
                 |${currentPartitioner.keyMapping}""".stripMargin)
            Some(
              newPartitioner
                .withTableDef(tableDef)
                .withKeyMapping(currentPartitioner.keyMapping))
          case _ =>
            logDebug(s"Assigning new Partitioner $newPartitioner")
            Some(newPartitioner.withTableDef(tableDef))
        }
      }
      case Some(other: Partitioner) => throw new IllegalArgumentException(
        s"""Unable to assign
          |non-CassandraPartitioner $other to CassandraTableScanRDD """.stripMargin)
      case None => None
    }

    new CassandraTableScanRDD[R](
      sc = sc,
      connector = connector,
      keyspaceName = keyspaceName,
      tableName = tableName,
      columnNames = columnNames,
      where = where,
      limit = limit,
      clusteringOrder = clusteringOrder,
      readConf = readConf,
      overridePartitioner = cassPart)
  }

  /** Selects a subset of columns mapped to the key and returns an RDD of pairs.
    * Similar to the built-in Spark `keyBy` method, but this one uses the implicit
    * `RowReaderFactory` to construct the key objects.
    * The selected columns must be available in the CassandraRDD.
    *
    * If the selected columns contain the complete partition key, a
    * `CassandraPartitioner` will also be created.
    *
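    * Example (a sketch; assumes some `rdd: CassandraTableScanRDD[R]` over a table whose
    * partition key is the single, illustrative column `userid`):
    * {{{
    *   // Keys become Tuple1[String] values built from the "userid" column;
    *   // values remain the original row type R.
    *   val byUser = rdd.keyBy[Tuple1[String]](SomeColumns("userid"))
    * }}}
    *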
    * @param columns column selector passed to the `RowReaderFactory` to create the row reader,
    *                useful when the key is mapped to a tuple or a single value
    */
  def keyBy[K](columns: ColumnSelector)(implicit
    classtag: ClassTag[K],
    rrf: RowReaderFactory[K],
    rwf: RowWriterFactory[K]): CassandraTableScanRDD[(K, R)] = {

    val kRRF = implicitly[RowReaderFactory[K]]
    val vRRF = rowReaderFactory
    implicit val kvRRF = new KeyValueRowReaderFactory[K, R](columns, kRRF, vRRF)

    val selectedColumnNames = columns.selectFrom(tableDef).map(_.columnName).toSet
    val partitionKeyColumnNames = PartitionKeyColumns.selectFrom(tableDef).map(_.columnName).toSet

    if (selectedColumnNames.containsAll(partitionKeyColumnNames)) {
      val partitioner = partitionGenerator.partitioner[K](columns)
      logDebug(s"Made partitioner $partitioner for $this")
      convertTo[(K, R)].withPartitioner(partitioner)

    } else {
      convertTo[(K, R)]
    }
  }

  /** Extracts a key of the given class from the given columns.
    *
    *  @see `keyBy(ColumnSelector)` */
  def keyBy[K](columns: ColumnRef*)(implicit
    classtag: ClassTag[K],
    rrf: RowReaderFactory[K],
    rwf: RowWriterFactory[K]): CassandraTableScanRDD[(K, R)] =
    keyBy(SomeColumns(columns: _*))

  /** Extracts a key of the given class from all the available columns.
    *
    * @see `keyBy(ColumnSelector)` */
  def keyBy[K]()(implicit
    classtag: ClassTag[K],
    rrf: RowReaderFactory[K],
    rwf: RowWriterFactory[K]): CassandraTableScanRDD[(K, R)] =
    keyBy(AllColumns)

  @transient lazy val partitionGenerator = {
    if (containsPartitionKey(where)) {
      CassandraPartitionGenerator(connector, tableDef, Some(1), splitSize)
    } else {
      CassandraPartitionGenerator(connector, tableDef, splitCount, splitSize)
    }
  }

  @transient override val partitioner = overridePartitioner

  override def getPartitions: Array[Partition] = {
    verify() // let's fail fast
    val partitions: Array[Partition] = partitioner match {
      case Some(cassPartitioner: CassandraPartitioner[_, _, _]) => {
        cassPartitioner.verify()
        cassPartitioner.partitions.toArray[Partition]
      }
      case Some(other: Partitioner) =>
        throw new IllegalArgumentException(s"Invalid partitioner $other")

      case None => partitionGenerator.partitions.toArray[Partition]
    }

    logDebug(s"Created total ${partitions.length} partitions for $keyspaceName.$tableName.")
    logTrace("Partitions: \n" + partitions.mkString("\n"))
    partitions
  }

  private lazy val nodeAddresses = new NodeAddresses(connector)

  private lazy val partitionKeyStr =
    tableDef.partitionKey.map(_.columnName).map(quote).mkString(", ")

  override def getPreferredLocations(split: Partition): Seq[String] =
    split.asInstanceOf[CassandraPartition[_, _]].endpoints.flatMap(nodeAddresses.hostNames).toSeq

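  /** Builds the CQL query template and bind values used to scan a single token range.
    *
    * The generated query has roughly the following shape (illustrative; the actual column,
    * keyspace and table names are quoted identifiers taken from the table definition):
    * {{{
    *   SELECT "col1", "col2" FROM "ks"."table"
    *   WHERE token("pk") > ? AND token("pk") <= ? ALLOW FILTERING
    * }}}
    * When the `where` clause already restricts the full partition key, the token range
    * condition is omitted and only the user-supplied predicates remain.
    */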
  private def tokenRangeToCqlQuery(range: CqlTokenRange[_, _]): (String, Seq[Any]) = {
    val columns = selectedColumnRefs.map(_.cql).mkString(", ")
    val (cql, values) = if (containsPartitionKey(where)) {
      ("", Seq.empty)
    } else {
      range.cql(partitionKeyStr)
    }
    val filter = (cql +: where.predicates).filter(_.nonEmpty).mkString(" AND ")
    val limitClause = limit.map(limit => s"LIMIT $limit").getOrElse("")
    val orderBy = clusteringOrder.map(_.toCql(tableDef)).getOrElse("")
    val quotedKeyspaceName = quote(keyspaceName)
    val quotedTableName = quote(tableName)
    val queryTemplate =
      s"SELECT $columns " +
        s"FROM $quotedKeyspaceName.$quotedTableName " +
        s"WHERE $filter $orderBy $limitClause ALLOW FILTERING"
    val queryParamValues = values ++ where.values
    (queryTemplate, queryParamValues)
  }

  private def createStatement(session: Session, cql: String, values: Any*): Statement = {
    try {
      val stmt = session.prepare(cql)
      stmt.setConsistencyLevel(consistencyLevel)
      val converters = stmt.getVariables
        .map(v => ColumnType.converterToCassandra(v.getType))
        .toArray
      val convertedValues =
        for ((value, converter) <- values zip converters)
        yield converter.convert(value)
      val bstm = stmt.bind(convertedValues: _*)
      bstm.setFetchSize(fetchSize)
      bstm
    }
    catch {
      case t: Throwable =>
        throw new IOException(s"Exception during preparation of $cql: ${t.getMessage}", t)
    }
  }

  private def fetchTokenRange(
    session: Session,
    range: CqlTokenRange[_, _],
    inputMetricsUpdater: InputMetricsUpdater): Iterator[R] = {

    val (cql, values) = tokenRangeToCqlQuery(range)
    logDebug(
      s"Fetching data for range ${range.cql(partitionKeyStr)} " +
        s"with $cql " +
        s"with params ${values.mkString("[", ",", "]")}")
    val stmt = createStatement(session, cql, values: _*)

    try {
      val rs = session.execute(stmt)
      val columnNames = selectedColumnRefs.map(_.selectedAs).toIndexedSeq
      val columnMetaData = CassandraRowMetadata.fromResultSet(columnNames, rs)

      val iterator = new PrefetchingResultSetIterator(rs, fetchSize)
      val iteratorWithMetrics = iterator.map(inputMetricsUpdater.updateMetrics)
      val result = iteratorWithMetrics.map(rowReader.read(_, columnMetaData))
      logDebug(s"Row iterator for range ${range.cql(partitionKeyStr)} obtained successfully.")
      result
    } catch {
      case t: Throwable =>
        throw new IOException(s"Exception during execution of $cql: ${t.getMessage}", t)
    }
  }

  override def compute(split: Partition, context: TaskContext): Iterator[R] = {
    val session = connector.openSession()
    val partition = split.asInstanceOf[CassandraPartition[_, _]]
    val tokenRanges = partition.tokenRanges
    val metricsUpdater = InputMetricsUpdater(context, readConf)

    // The flatMap trick flattens the iterator-of-iterators structure into a single iterator.
    // flatMap on an iterator is lazy, therefore the query for the next token range is not executed
    // until all of the rows returned by the previous query have been consumed.
    val rowIterator = tokenRanges.iterator.flatMap(
      fetchTokenRange(session, _: CqlTokenRange[_, _], metricsUpdater))
    val countingIterator = new CountingIterator(rowIterator, limit)

    context.addTaskCompletionListener { (context) =>
      val duration = metricsUpdater.finish() / 1000000000d
      logDebug(f"Fetched ${countingIterator.count} rows from $keyspaceName.$tableName " +
        f"for partition ${partition.index} in $duration%.3f s.")
      session.close()
    }
    countingIterator
  }

  override def toEmptyCassandraRDD: EmptyCassandraRDD[R] = {
    new EmptyCassandraRDD[R](
      sc = sc,
      keyspaceName = keyspaceName,
      tableName = tableName,
      columnNames = columnNames,
      where = where,
      limit = limit,
      clusteringOrder = clusteringOrder,
      readConf = readConf)
  }

  override def cassandraCount(): Long = {
    columnNames match {
      case SomeColumns(_*) =>
        logWarning("You are about to count rows, but an explicit projection has been specified.")
      case _ =>
    }

    val counts =
      new CassandraTableScanRDD[Long](
        sc = sc,
        connector = connector,
        keyspaceName = keyspaceName,
        tableName = tableName,
        columnNames = SomeColumns(RowCountRef),
        where = where,
        limit = limit,
        clusteringOrder = clusteringOrder,
        readConf = readConf)

    counts.reduce(_ + _)
  }


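  /** Returns true when the given `where` clause restricts every partition key column with an
    * `=` or `IN` predicate, e.g. (illustrative) `where("pk1 = ? AND pk2 = ?", a, b)` for a table
    * with partition key (pk1, pk2). A range predicate on a partition key column is rejected, and
    * if only some of the partition key columns are restricted, those columns must be indexed.
    */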
  private def containsPartitionKey(clause: CqlWhereClause): Boolean = {
    val pk = tableDef.partitionKey.map(_.columnName).toSet
    val wherePredicates: Seq[Predicate] = clause.predicates.flatMap(CqlWhereParser.parse)

    val whereColumns: Set[String] = wherePredicates.collect {
      case EqPredicate(c, _) if pk.contains(c) => c
      case InPredicate(c) if pk.contains(c) => c
      case InListPredicate(c, _) if pk.contains(c) => c
      case RangePredicate(c, _, _) if pk.contains(c) =>
        throw new UnsupportedOperationException(
          s"Range predicates on partition key columns (here: $c) are " +
            s"not supported in where. Use filter instead.")
    }.toSet

    val primaryKeyComplete = whereColumns.nonEmpty && whereColumns.size == pk.size
    val whereColumnsAllIndexed = whereColumns.forall(tableDef.isIndexed)

    if (!primaryKeyComplete && !whereColumnsAllIndexed) {
      val missing = pk -- whereColumns
      throw new UnsupportedOperationException(
        s"Partition key predicate must include all partition key columns or partition key columns need" +
          s" to be indexed. Missing columns: ${missing.mkString(",")}"
      )
    }
    primaryKeyComplete
  }
}

object CassandraTableScanRDD {

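  /** Creates a `CassandraTableScanRDD` that scans all columns of the given table.
    *
    * Example (a sketch; keyspace and table names are illustrative, and the Cassandra connection
    * properties, e.g. `spark.cassandra.connection.host`, must already be present in `sc.getConf`):
    * {{{
    *   import com.datastax.spark.connector.CassandraRow
    *
    *   val rdd = CassandraTableScanRDD[CassandraRow](sc, "test", "words")
    * }}}
    */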
  def apply[T : ClassTag : RowReaderFactory](
    sc: SparkContext,
    keyspaceName: String,
    tableName: String): CassandraTableScanRDD[T] = {

    new CassandraTableScanRDD[T](
      sc = sc,
      connector = CassandraConnector(sc.getConf),
      keyspaceName = keyspaceName,
      tableName = tableName,
      readConf = ReadConf.fromSparkConf(sc.getConf),
      columnNames = AllColumns,
      where = CqlWhereClause.empty)
  }

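  /** Creates a `CassandraTableScanRDD` of `(K, V)` pairs that scans all columns of the given
    * table and assigns it a `CassandraPartitioner` built from the table's partition key columns,
    * so the resulting RDD is already partitioned by key.
    */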
  def apply[K, V](
      sc: SparkContext,
      keyspaceName: String,
      tableName: String)(
    implicit
      keyCT: ClassTag[K],
      valueCT: ClassTag[V],
      rrf: RowReaderFactory[(K, V)],
      rwf: RowWriterFactory[K]): CassandraTableScanRDD[(K, V)] = {

    val rdd = new CassandraTableScanRDD[(K, V)](
      sc = sc,
      connector = CassandraConnector(sc.getConf),
      keyspaceName = keyspaceName,
      tableName = tableName,
      readConf = ReadConf.fromSparkConf(sc.getConf),
      columnNames = AllColumns,
      where = CqlWhereClause.empty)
    rdd.withPartitioner(rdd.partitionGenerator.partitioner[K](PartitionKeyColumns))
  }
}