All downloads are free. Search and download functionality uses the official Maven repository.

com.datastax.spark.connector.rdd.CassandraTableScanRDD.scala Maven / Gradle / Ivy

The newest version!
package com.datastax.spark.connector.rdd

import com.datastax.spark.connector._
import com.datastax.spark.connector.cql._
import com.datastax.spark.connector.datasource.ScanHelper
import com.datastax.spark.connector.datasource.ScanHelper.CqlQueryParts
import com.datastax.spark.connector.rdd.CassandraLimit._
import com.datastax.spark.connector.rdd.partitioner.dht.{TokenFactory, Token => ConnectorToken}
import com.datastax.spark.connector.rdd.partitioner.{CassandraPartition, CassandraPartitionGenerator, CqlTokenRange, NodeAddresses, _}
import com.datastax.spark.connector.rdd.reader._
import com.datastax.spark.connector.util.CountingIterator
import com.datastax.spark.connector.writer.RowWriterFactory
import org.apache.spark.metrics.InputMetricsUpdater
import org.apache.spark.rdd.{PartitionCoalescer, RDD}
import org.apache.spark.{Partition, Partitioner, SparkContext, TaskContext}

import java.io.IOException
import scala.language.existentials
import scala.reflect.ClassTag
import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal


/** RDD representing a table scan of a Cassandra table.
  *
  * This class is the main entry point for analyzing data in Cassandra database with Spark.
  * Obtain objects of this class by calling
  * [[com.datastax.spark.connector.SparkContextFunctions.cassandraTable]].
  *
  * Configuration properties should be passed in the [[org.apache.spark.SparkConf SparkConf]]
  * configuration of [[org.apache.spark.SparkContext SparkContext]].
  * `CassandraRDD` needs to open connection to Cassandra, therefore it requires appropriate
  * connection property values to be present in [[org.apache.spark.SparkConf SparkConf]].
  * For the list of required and available properties, see
  * [[com.datastax.spark.connector.cql.CassandraConnector CassandraConnector]].
  *
  * `CassandraRDD` divides the data set into smaller partitions, processed locally on every
  * cluster node. A data partition consists of one or more contiguous token ranges.
  * To reduce the number of roundtrips to Cassandra, every partition is fetched in batches.
  *
  * The following properties control the number of partitions and the fetch size:
  * - spark.cassandra.input.split.sizeInMB: approx amount of data to be fetched into a single Spark
  *   partition, default 512 MB
  * - spark.cassandra.input.fetch.sizeInRows:  number of CQL rows fetched per roundtrip,
  *   default 1000
  *
  * A `CassandraRDD` object gets serialized and sent to every Spark Executor, which then
  * calls the `compute` method to fetch the data on every node. The `getPreferredLocations`
  * method tells Spark the preferred nodes to fetch a partition from, so that the data for
  * the partition are at the same node the task was sent to. If Cassandra nodes are collocated
  * with Spark nodes, the queries are always sent to the Cassandra process running on the same
  * node as the Spark Executor process, hence data are not transferred between nodes.
  * If a Cassandra node fails or gets overloaded during read, the queries are retried
  * to a different node.
  *
  * By default, reads are performed at ConsistencyLevel.LOCAL_ONE in order to leverage data-locality
  * and minimize network traffic. This read consistency level is controlled by the
  * spark.cassandra.input.consistency.level property.
  */
class CassandraTableScanRDD[R] private[connector](
    @transient val sc: SparkContext,
    val connector: CassandraConnector,
    val keyspaceName: String,
    val tableName: String,
    val columnNames: ColumnSelector = AllColumns,
    val where: CqlWhereClause = CqlWhereClause.empty,
    val limit: Option[CassandraLimit] = None,
    val clusteringOrder: Option[ClusteringOrder] = None,
    val readConf: ReadConf = ReadConf(),
    overridePartitioner: Option[Partitioner] = None,
    val tokenRangeFilter: (ConnectorToken[_], ConnectorToken[_]) => Boolean = (_, _) => true)(
  implicit
    val classTag: ClassTag[R],
    @transient val rowReaderFactory: RowReaderFactory[R])
  extends CassandraRDD[R](sc, Seq.empty)
  with CassandraTableRowReaderProvider[R] {

  override type Self = CassandraTableScanRDD[R]

  /** Returns a new RDD identical to this one except for the explicitly passed settings.
    *
    * Every parameter defaults to this RDD's current value. That includes `clusteringOrder`, so
    * e.g. `coalesce` (which only replaces `readConf`) keeps the requested clustering order.
    * The partitioner and token range filter are always carried over unchanged.
    *
    * @throws IllegalArgumentException if `sc` is null (i.e. this RDD was deserialized)
    */
  override protected def copy(
    columnNames: ColumnSelector = columnNames,
    where: CqlWhereClause = where,
    limit: Option[CassandraLimit] = limit,
    clusteringOrder: Option[ClusteringOrder] = clusteringOrder,
    readConf: ReadConf = readConf,
    connector: CassandraConnector = connector): Self = {

    require(sc != null,
      "RDD transformation requires a non-null SparkContext. " +
        "Unfortunately SparkContext in this CassandraRDD is null. " +
        "This can happen after CassandraRDD has been deserialized. " +
        "SparkContext is not Serializable, therefore it deserializes to null. " +
        "RDD transformations are not allowed inside lambdas used in other RDD transformations.")

    new CassandraTableScanRDD[R](
      sc = sc,
      connector = connector,
      keyspaceName = keyspaceName,
      tableName = tableName,
      columnNames = columnNames,
      where = where,
      limit = limit,
      clusteringOrder = clusteringOrder,
      readConf = readConf,
      overridePartitioner = overridePartitioner,
      tokenRangeFilter = tokenRangeFilter)
  }

  /** Returns the same scan, but reading rows as `B` using the implicit `RowReaderFactory[B]`. */
  override protected def convertTo[B : ClassTag : RowReaderFactory]: CassandraTableScanRDD[B] = {
    new CassandraTableScanRDD[B](
      sc = sc,
      connector = connector,
      keyspaceName = keyspaceName,
      tableName = tableName,
      columnNames = columnNames,
      where = where,
      limit = limit,
      clusteringOrder = clusteringOrder,
      readConf = readConf,
      overridePartitioner = overridePartitioner,
      tokenRangeFilter = tokenRangeFilter)
  }

  /**
    * Internal method for assigning a partitioner to this RDD, this lacks type safety checks for
    * the Partitioner of type [K, V, T]. End users will use the implicit provided in
    * [[CassandraTableScanPairRDDFunctions]]
    *
    * @param partitioner `None` removes the partitioner; `Some` must hold a `CassandraPartitioner`.
    *                    If this RDD already has a `CassandraPartitioner`, its key mapping is kept.
    * @throws IllegalArgumentException if `partitioner` holds a non-Cassandra partitioner
    */
  private[connector] def withPartitioner[K, V, T <: ConnectorToken[V]](
    partitioner: Option[Partitioner]): CassandraTableScanRDD[R] = {

    val cassPart = partitioner match {
      case Some(newPartitioner: CassandraPartitioner[_, _, _]) =>
        this.partitioner match {
          case Some(currentPartitioner: CassandraPartitioner[_, _, _]) =>
            // Preserve the key mapping set by the current partitioner.
            logDebug(
              s"""Preserving Partitioner: $currentPartitioner with mapping
                 |${currentPartitioner.keyMapping}""".stripMargin)
            Some(
              newPartitioner
                .withTableDef(tableDef)
                .withKeyMapping(currentPartitioner.keyMapping))
          case _ =>
            logDebug(s"Assigning new Partitioner $newPartitioner")
            Some(newPartitioner.withTableDef(tableDef))
        }
      case Some(other: Partitioner) => throw new IllegalArgumentException(
        s"""Unable to assign
          |non-CassandraPartitioner $other to CassandraTableScanRDD """.stripMargin)
      case None => None
    }

    new CassandraTableScanRDD[R](
      sc = sc,
      connector = connector,
      keyspaceName = keyspaceName,
      tableName = tableName,
      columnNames = columnNames,
      where = where,
      limit = limit,
      clusteringOrder = clusteringOrder,
      readConf = readConf,
      overridePartitioner = cassPart,
      tokenRangeFilter = tokenRangeFilter)
  }

  /** Selects a subset of columns mapped to the key and returns an RDD of pairs.
    * Similar to the builtin Spark keyBy method, but this one uses implicit
    * RowReaderFactory to construct the key objects.
    * The selected columns must be available in the CassandraRDD.
    *
    * If the selected columns contain the complete partition key a
    * `CassandraPartitioner` will also be created.
    *
    * @param columns column selector passed to the rrf to create the row reader,
    *                useful when the key is mapped to a tuple or a single value
    */
  def keyBy[K](columns: ColumnSelector)(implicit
    classtag: ClassTag[K],
    rrf: RowReaderFactory[K],
    rwf: RowWriterFactory[K]): CassandraTableScanRDD[(K, R)] = {

    // Reads each row as (key, value): the key via `rrf` from `columns`, the value via this RDD's reader.
    implicit val kvRRF: KeyValueRowReaderFactory[K, R] =
      new KeyValueRowReaderFactory[K, R](columns, rrf, rowReaderFactory)

    val selectedColumnNames = columns.selectFrom(tableDef).map(_.columnName).toSet
    val partitionKeyColumnNames = PartitionKeyColumns.selectFrom(tableDef).map(_.columnName).toSet

    // A partitioner can only be derived when the key covers every partition key column.
    if (partitionKeyColumnNames.subsetOf(selectedColumnNames)) {
      val partitioner = partitionGenerator.partitioner[K](columns)
      logDebug(s"Made partitioner $partitioner for $this")
      convertTo[(K, R)].withPartitioner(partitioner)
    } else {
      convertTo[(K, R)]
    }
  }

  /** Extracts a key of the given class from the given columns.
    *
    *  @see `keyBy(ColumnSelector)` */
  def keyBy[K](columns: ColumnRef*)(implicit
    classtag: ClassTag[K],
    rrf: RowReaderFactory[K],
    rwf: RowWriterFactory[K]): CassandraTableScanRDD[(K, R)] =
    keyBy(SomeColumns(columns: _*))

  /** Extracts a key of the given class from all the available columns.
    *
    * @see `keyBy(ColumnSelector)` */
  def keyBy[K]()(implicit
    classtag: ClassTag[K],
    rrf: RowReaderFactory[K],
    rwf: RowWriterFactory[K]): CassandraTableScanRDD[(K, R)] =
    keyBy(AllColumns)

  /** Minimal number of splits requested from the partition generator: twice the default parallelism, plus one. */
  def minimalSplitCount: Int = {
    context.defaultParallelism * 2 + 1
  }

  @transient lazy val partitionGenerator = ScanHelper.getPartitionGenerator(
    connector,
    tableDef,
    where,
    minimalSplitCount,
    readConf.splitCount,
    splitSize)

  /**
    * This method overrides the default Spark behavior and will not create a `CoalescedRDD`. Instead
    * it reduces the number of partitions by adjusting the partitioning of C* data on read. Using this
    * method will override spark.cassandra.input.split.sizeInMB.
    * The method is useful together with `where()`, when the actual size of the data is smaller than
    * the table size. It has no effect if a partition key is used in the where clause.
    *
    * @param numPartitions      number of partitions
    * @param shuffle            whether to call shuffle after
    * @param partitionCoalescer is ignored if no shuffle, or just passed to shuffled CoalesceRDD
    * @param ord                only forwarded to Spark's `coalesce` when `shuffle` is true
    * @return new CassandraTableScanRDD with predefined number of partitions
    */
  override def coalesce(numPartitions: Int, shuffle: Boolean = false, partitionCoalescer: Option[PartitionCoalescer])(implicit ord: Ordering[R] = null): RDD[R]
  = {
    val rdd = copy(readConf = readConf.copy(splitCount = Some(numPartitions)))
    if (shuffle) {
      rdd.superCoalesce(numPartitions, shuffle, partitionCoalescer)
    } else {
      rdd
    }
  }

  /** Delegates to Spark's own `RDD.coalesce`, bypassing the override above. */
  private def superCoalesce(numPartitions: Int, shuffle: Boolean = false, partitionCoalescer: Option[PartitionCoalescer])(implicit ord: Ordering[R] = null): RDD[R] =
    super.coalesce(numPartitions, shuffle, partitionCoalescer)


  @transient override val partitioner: Option[Partitioner] = overridePartitioner

  /** Uses the `CassandraPartitioner`'s partitions when one is set, otherwise the partition generator's.
    *
    * @throws IllegalArgumentException if a non-Cassandra partitioner is set
    */
  override def getPartitions: Array[Partition] = {
    verify() // let's fail fast
    val partitions: Array[Partition] = partitioner match {
      case Some(cassPartitioner: CassandraPartitioner[_, _, _]) =>
        cassPartitioner.verify()
        cassPartitioner.partitions.toArray[Partition]
      case Some(other: Partitioner) =>
        throw new IllegalArgumentException(s"Invalid partitioner $other")

      case None => partitionGenerator.partitions.toArray[Partition]
    }

    logDebug(s"Created total ${partitions.length} partitions for $keyspaceName.$tableName.")
    logTrace("Partitions: \n" + partitions.mkString("\n"))
    partitions
  }

  override def getPreferredLocations(split: Partition): Seq[String] =
    split.asInstanceOf[CassandraPartition[_, _]].endpoints

  /** Scans every token range of `split` accepted by `tokenRangeFilter` and reads the rows as `R`.
    *
    * Failures while starting a token range query are rethrown as `IOException` (fatal errors
    * propagate unchanged). The scanner is closed when the task completes.
    */
  override def compute(split: Partition, context: TaskContext): Iterator[R] = {
    val partition = split.asInstanceOf[CassandraPartition[TokenFactory.V, TokenFactory.T]]

    // Only let token ranges that shouldn't be skipped through.
    val tokenRanges = partition.tokenRanges.filter { cqlRange =>
      val start = cqlRange.range.start.asInstanceOf[ConnectorToken[_]]
      val end = cqlRange.range.end.asInstanceOf[ConnectorToken[_]]
      val accepted = tokenRangeFilter(start, end)
      // Evaluated once per token range, so keep it below INFO to avoid flooding executor logs.
      logTrace(s"tokenRangeFilter($start, $end) = $accepted")
      accepted
    }
    val metricsUpdater = InputMetricsUpdater(context, readConf)

    val selectedColumnNames = selectedColumnRefs.map(_.selectedAs).toIndexedSeq

    val scanner = connector.connectionFactory.getScanner(readConf, connector.conf, selectedColumnNames)
    val queryParts = CqlQueryParts(selectedColumnRefs, where, limit, clusteringOrder)

    // Iterator flatMap trick flattens the iterator-of-iterator structure into a single iterator.
    // flatMap on iterator is lazy, therefore a query for the next token range is executed not earlier
    // than all of the rows returned by the previous query have been consumed
    val rowIterator = tokenRanges.iterator.flatMap { range =>
      try {
        val scanResult = ScanHelper.fetchTokenRange(scanner, tableDef, queryParts, range, consistencyLevel, fetchSize)
        val iteratorWithMetrics = scanResult.rows.map(metricsUpdater.updateMetrics)
        iteratorWithMetrics.map(rowReader.read(_, scanResult.metadata))
      } catch {
        // Fatal errors (OOM, interrupts, ...) must propagate as-is instead of being wrapped.
        case NonFatal(t) =>
          throw new IOException(s"Exception during scan execution for $range : ${t.getMessage}", t)
      }
    }
    val countingIterator = new CountingIterator(rowIterator, limitForIterator(limit))

    context.addTaskCompletionListener { taskContext =>
      try {
        val duration = metricsUpdater.finish() / 1000000000d
        logDebug(f"Fetched ${countingIterator.count} rows from $keyspaceName.$tableName " +
          f"for partition ${partition.index} in $duration%.3f s.")
      } finally {
        // Release the scanner even if metrics finalisation or logging fails.
        scanner.close()
      }
      taskContext
    }

    countingIterator
  }

  override def toEmptyCassandraRDD: EmptyCassandraRDD[R] = {
    new EmptyCassandraRDD[R](
      sc = sc,
      keyspaceName = keyspaceName,
      tableName = tableName,
      columnNames = columnNames,
      where = where,
      limit = limit,
      clusteringOrder = clusteringOrder,
      readConf = readConf)
  }

  /** Counts the rows matched by this scan using `CassandraTableScanRDD.countRDD`.
    *
    * @return total row count; 0 when no token range is scanned at all
    */
  override def cassandraCount(): Long = {
    columnNames match {
      // `_: SomeColumns` matches any projection; `SomeColumns(_)` would only match a single column.
      case _: SomeColumns =>
        logWarning("You are about to count rows but an explicit projection has been specified.")
      case _ =>
    }

    // fold, not reduce: if tokenRangeFilter rejects every token range the count RDD is empty,
    // and RDD.reduce throws on an empty RDD.
    CassandraTableScanRDD.countRDD(this).fold(0L)(_ + _)
  }

}

object CassandraTableScanRDD {

  /** Creates an RDD reading every column of `keyspaceName.tableName` as `T`, with no filtering.
    * The connector and read configuration are both derived from `sc`.
    */
  def apply[T : ClassTag : RowReaderFactory](
    sc: SparkContext,
    keyspaceName: String,
    tableName: String): CassandraTableScanRDD[T] =
    fullTableScan[T](sc, keyspaceName, tableName)

  /** Creates a key/value RDD over `keyspaceName.tableName` whose partitioner is built from the
    * table's partition key columns, keyed by `K`.
    */
  def apply[K, V](
      sc: SparkContext,
      keyspaceName: String,
      tableName: String)(
    implicit
      keyCT: ClassTag[K],
      valueCT: ClassTag[V],
      rrf: RowReaderFactory[(K, V)],
      rwf: RowWriterFactory[K]): CassandraTableScanRDD[(K, V)] = {
    val unpartitioned = fullTableScan[(K, V)](sc, keyspaceName, tableName)
    val keyPartitioner = unpartitioned.partitionGenerator.partitioner[K](PartitionKeyColumns)
    unpartitioned.withPartitioner(keyPartitioner)
  }

  /** Shared construction for both `apply` overloads: all columns, empty where clause,
    * connector and read configuration taken from `sc`.
    */
  private def fullTableScan[T](
    sc: SparkContext,
    keyspaceName: String,
    tableName: String)(
    implicit ct: ClassTag[T], reader: RowReaderFactory[T]): CassandraTableScanRDD[T] =
    new CassandraTableScanRDD[T](
      sc = sc,
      connector = CassandraConnector(sc),
      keyspaceName = keyspaceName,
      tableName = tableName,
      readConf = ReadConf.fromSparkConf(sc.getConf),
      columnNames = AllColumns,
      where = CqlWhereClause.empty)

  /**
    * Builds the row-count RDD used by `cassandraCount()` (and, per the original note, by the Spark SQL
    * Cassandra source) to push counting down to Cassandra. It reuses every setting of `rdd`
    * (filters, limit, ordering, read configuration, token range filter) but projects only the row count.
    *
    * NOTE(review): summing all elements gives the total count. `compute` issues one query per token
    * range, so there is presumably one value per scanned token range rather than per Spark
    * partition — confirm.
    *
    * @param rdd the scan whose rows should be counted
    * @tparam R row type of `rdd` (irrelevant to the result)
    * @return an RDD of partial row counts
    */
  def countRDD[R] (rdd: CassandraTableScanRDD[R]): CassandraTableScanRDD[Long] = {
    val countOnly: ColumnSelector = SomeColumns(RowCountRef)
    new CassandraTableScanRDD[Long](
      sc = rdd.sc,
      connector = rdd.connector,
      keyspaceName = rdd.keyspaceName,
      tableName = rdd.tableName,
      columnNames = countOnly,
      where = rdd.where,
      limit = rdd.limit,
      clusteringOrder = rdd.clusteringOrder,
      readConf = rdd.readConf,
      tokenRangeFilter = rdd.tokenRangeFilter)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy