com.datastax.spark.connector.rdd.CassandraMergeJoinRDD.scala

/**
  * Copyright DataStax, Inc.
  *
  * Please see the included license file for details.
  */
package com.datastax.spark.connector.rdd

import java.io.IOException

import com.datastax.bdp.util.ScalaJavaUtil.asScalaFuture

import scala.language.existentials
import scala.reflect.ClassTag
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.metrics.InputMetricsUpdater
import org.apache.spark.rdd.RDD
import org.apache.spark.{Partition, SparkContext, TaskContext}
import com.datastax.oss.driver.api.core.CqlSession
import com.datastax.oss.driver.api.core.cql.{BoundStatement, Row}
import com.datastax.oss.driver.api.core.metadata.Metadata
import com.datastax.oss.driver.api.core.metadata.token.Token
import com.datastax.spark.connector.CassandraRowMetadata
import com.datastax.spark.connector.cql.{CassandraConnector, ColumnDef}
import com.datastax.spark.connector.rdd.partitioner.{CassandraPartition, CqlTokenRange, NodeAddresses}
import com.datastax.spark.connector.rdd.reader.{PrefetchingResultSetIterator, RowReader}
import com.datastax.spark.connector.types.ColumnType
import com.datastax.spark.connector.util.Quote._
import com.datastax.spark.connector.util.{CountingIterator, MergeJoinIterator, NameTools, schemaFromCassandra}

import scala.concurrent.Await
import scala.concurrent.duration.Duration
import scala.jdk.CollectionConverters._

/**
  * An RDD that pulls from two separate CassandraTableScanRDDs which share partition keys and
  * keyspaces. The two tables are joined on read using a merge iterator: as long as the join is
  * on the token of the partition key, the two ordered iterators can be consumed in lockstep.
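  *
  * A minimal usage sketch (the keyspace and table names below are hypothetical; both tables
  * are assumed to share the same partition key definition):
  *
  * {{{
  * import com.datastax.spark.connector._
  *
  * val left  = sc.cassandraTable("ks", "events")      // CassandraTableScanRDD[CassandraRow]
  * val right = sc.cassandraTable("ks", "event_tags")  // same partition key as "events"
  * val joined = new CassandraMergeJoinRDD(sc, left, right)  // RDD[(Seq[CassandraRow], Seq[CassandraRow])]
  * }}}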
  */
class CassandraMergeJoinRDD[L,R](
    sc: SparkContext,
    leftScanRDD: CassandraTableScanRDD[L],
    rightScanRDD: CassandraTableScanRDD[R])(
  implicit
    leftClassTag: ClassTag[L],
    rightClassTag: ClassTag[R])
  extends RDD[(Seq[L], Seq[R])](sc, Seq.empty){

  val connector = leftScanRDD.connector

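  /** Looks up the partition key columns of the given table, or throws an [[IOException]]
    * with name suggestions when the keyspace or table cannot be found. */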
  def getPartitionKey(connector: CassandraConnector, keyspaceName: String, tableName: String): Seq[ColumnDef] = {
    schemaFromCassandra(connector, Some(keyspaceName), Some(tableName)).tables.headOption match {
      case Some(table) => table.partitionKey
      case None => {
        val metadata: Metadata = connector.withSessionDo(_.getMetadata)
        val suggestions = NameTools.getSuggestions(metadata, keyspaceName, tableName)
        val errorMessage = NameTools.getErrorString(keyspaceName, Some(tableName), suggestions)
        throw new IOException(errorMessage)
      }
    }
  }

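  /** Verifies that both input tables have partition keys with identical column types,
    * which is required for the token-based merge join to be well defined. */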
  def checkValidMergeJoin(): Unit = {
    val leftPartitionKeyTypes = getPartitionKey(
      leftScanRDD.connector,
      leftScanRDD.keyspaceName,
      leftScanRDD.tableName)
      .map(_.columnType)

    val rightPartitionKeyTypes = getPartitionKey(
      rightScanRDD.connector,
      rightScanRDD.keyspaceName,
      rightScanRDD.tableName)
      .map(_.columnType)

    assert(leftPartitionKeyTypes == rightPartitionKeyTypes,
      "Partition key types do not match between the left and right RDDs in MergeJoin")
  }

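  /** Alias under which TOKEN(partition key) is selected in every scan query; the merge
    * iterator aligns the two row streams on this value. */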
  val TokenColumn = "mj_tok_col"
  def tokenExtractor(row: Row): Token = {
    row.getToken(TokenColumn)
  }

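  /** Quoted, comma-separated partition key columns of the left table, used to build the
    * token-range predicate of each generated query. */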
  private lazy val leftPartitionKeyStr =
    getPartitionKey(connector, leftScanRDD.keyspaceName, leftScanRDD.tableName)
      .map(_.columnName)
      .map(quote)
      .mkString(", ")

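  /** Builds the CQL query and bind values for a single token range: the selected columns plus
    * TOKEN(partition key) aliased as [[TokenColumn]], the range predicate combined with any
    * user-supplied WHERE predicates, and optional ORDER BY / LIMIT clauses. */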
  private def tokenRangeToCqlQuery[T](
    fromRDD: CassandraTableScanRDD[T],
    range: CqlTokenRange[_, _]): (String, Seq[Any]) = {

    val columns = fromRDD.selectedColumnRefs.map(_.cql).mkString(", ")
    val pk = fromRDD.tableDef.partitionKey.map(colDef => quote(colDef.columnName)).mkString(",")
    val (cql, values) = range.cql(leftPartitionKeyStr)
    val filter = (cql +: fromRDD.where.predicates).filter(_.nonEmpty).mkString(" AND ")
    val limitClause = fromRDD.limit.map(limit => s"LIMIT $limit").getOrElse("")
    val orderBy = fromRDD.clusteringOrder.map(_.toCql(fromRDD.tableDef)).getOrElse("")
    val quotedKeyspaceName = quote(fromRDD.keyspaceName)
    val quotedTableName = quote(fromRDD.tableName)
    val queryTemplate =
      s"SELECT $columns, TOKEN($pk) as $TokenColumn " +
        s"FROM $quotedKeyspaceName.$quotedTableName " +
        s"WHERE $filter $orderBy $limitClause ALLOW FILTERING"
    val queryParamValues = values ++ fromRDD.where.values
    (queryTemplate, queryParamValues)
  }

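  /** Prepares the query, converts the bind values to their Cassandra representations, and
    * applies idempotence, page size and consistency level from the [[ReadConf]]. */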
  private def createStatement(
    session: CqlSession,
    readConf: ReadConf,
    cql: String,
    values: Any*): BoundStatement = {

    try {
      val stmt = session.prepare(cql)
      val converters = stmt.getVariableDefinitions.asScala
        .map(v => ColumnType.converterToCassandra(v.getType))
        .toArray
      val convertedValues =
        for ((value, converter) <- values zip converters)
          yield converter.convert(value)
      stmt.bind(convertedValues: _*)
        .setIdempotent(true)
        .setPageSize(readConf.fetchSizeInRows)
        .setConsistencyLevel(readConf.consistencyLevel)
    }
    catch {
      case t: Throwable =>
        throw new IOException(s"Exception during preparation of $cql: ${t.getMessage}", t)
    }
  }

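  /** Lazily converts driver [[Row]]s into the target type using the RDD's [[RowReader]]. */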
  private def convertRowIterator[T](
      iterator: Iterator[Row],
      rowReader: RowReader[T],
      columnMetaData: CassandraRowMetadata): Iterator[T] = {

    iterator.map{ row =>
      rowReader.read(row, columnMetaData)
    }
  }

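  /** Executes the query for a single token range asynchronously and returns the row metadata
    * together with a prefetching, metrics-counting iterator over the resulting rows. */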
  private def fetchTokenRange[T](
    session: CqlSession,
    fromRDD: CassandraTableScanRDD[T],
    range: CqlTokenRange[_, _],
    inputMetricsUpdater: InputMetricsUpdater): (CassandraRowMetadata, Iterator[Row]) = {

    val (cql, values) = tokenRangeToCqlQuery[T](fromRDD, range)
    logDebug(
      s"Fetching data for range ${range} " +
        s"with $cql " +
        s"with params ${values.mkString("[", ",", "]")}")
    val stmt = createStatement(session, fromRDD.readConf, cql, values: _*)

    import com.datastax.spark.connector.util.Threads.BlockingIOExecutionContext

    val fetchResult = asScalaFuture(session.executeAsync(stmt)).map { rs =>
      val columnNames = fromRDD.selectedColumnRefs.map(_.selectedAs).toIndexedSeq ++ Seq(TokenColumn)
      val columnMetaData = CassandraRowMetadata.fromResultSet(columnNames, rs, session)
      val iterator = new PrefetchingResultSetIterator(rs)
      val iteratorWithMetrics = iterator.map(inputMetricsUpdater.updateMetrics)
      logDebug(s"Row iterator for range $range obtained successfully.")
      (columnMetaData, iteratorWithMetrics)
    }.recover {
      case t: Throwable => throw new IOException(s"Exception during execution of $cql: ${t.getMessage}", t)
    }
    Await.result(fetchResult, Duration.Inf)
  }

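  /** Computes one Spark partition: scans every token range of the split from both tables,
    * merges the two token-ordered row streams, and converts each group of co-located rows
    * into the left and right target types. Sessions are closed when the task completes. */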
  @DeveloperApi
  override def compute(split: Partition, context: TaskContext): Iterator[(Seq[L], Seq[R])] = {

    // Open two sessions only if the cluster configurations differ; otherwise share one session.
    def openSessions(): (CqlSession, CqlSession) = {
      if (leftScanRDD.connector == rightScanRDD.connector) {
        val session = leftScanRDD.connector.openSession()
        (session, session)
      } else {
        (leftScanRDD.connector.openSession(), rightScanRDD.connector.openSession())
      }
    }

    def closeSessions(leftSession: CqlSession, rightSession : CqlSession): Unit = {
      if (leftSession != rightSession) rightSession.close()
      leftSession.close()
    }

    val (leftSession, rightSession) = openSessions()

    type V = t forSome { type t }
    type T = t forSome { type t <: com.datastax.spark.connector.rdd.partitioner.dht.Token[V] }
    val partition = split.asInstanceOf[CassandraPartition[V, T]]
    val tokenRanges = partition.tokenRanges

    val metricsReadConf = new ReadConf(taskMetricsEnabled =
      leftScanRDD.readConf.taskMetricsEnabled || rightScanRDD.readConf.taskMetricsEnabled)

    val metricsUpdater = InputMetricsUpdater(context, metricsReadConf)

    val mergingIterator = tokenRanges.iterator.flatMap { tokenRange =>
      val (leftMetadata, leftRowIterator) = fetchTokenRange(leftSession, leftScanRDD, tokenRange, metricsUpdater)
      val (rightMetadata, rightRowIterator) = fetchTokenRange(rightSession, rightScanRDD, tokenRange, metricsUpdater)

      val rowMerger = new MergeJoinIterator[Row, Row, Token](
        leftRowIterator,
        rightRowIterator,
        tokenExtractor,
        tokenExtractor
      )
      rowMerger.map { case (t: Token, lRows: Seq[Row], rRows: Seq[Row]) => (
        t,
        convertRowIterator(lRows.iterator, leftScanRDD.rowReader, leftMetadata).toList,
        convertRowIterator(rRows.iterator, rightScanRDD.rowReader, rightMetadata).toList)
      }
    }

    val countingIterator = new CountingIterator(mergingIterator)

    context.addTaskCompletionListener { (context) =>
      val duration = metricsUpdater.finish() / 1000000000d
      logDebug(
        f"""Fetched ${countingIterator.count} rows from
            |${leftScanRDD.keyspaceName} ${leftScanRDD.tableName} and ${rightScanRDD.tableName}
            |for partition ${partition.index} in $duration%.3f s.""".stripMargin)
      closeSessions(leftSession, rightSession)
      context
    }
    val iteratorWithoutToken = countingIterator.map(tuple => (tuple._2, tuple._3))
    iteratorWithoutToken
  }

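  /** Uses the partitions of whichever input RDD was split into more pieces, after checking
    * that the two tables can be merge-joined. */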
  override protected def getPartitions: Array[Partition] = {
    checkValidMergeJoin()
    if (leftScanRDD.partitions.length >= rightScanRDD.partitions.length)
      leftScanRDD.partitions else rightScanRDD.partitions
  }

  override def getPreferredLocations(split: Partition): Seq[String] =
    split.asInstanceOf[CassandraPartition[_, _]].endpoints
}
