package com.datastax.spark.connector.rdd
import com.datastax.oss.driver.api.core.CqlSession
import com.datastax.oss.driver.api.core.cql.AsyncResultSet
import com.datastax.oss.driver.internal.core.cql.ResultSets
import com.datastax.spark.connector._
import com.datastax.spark.connector.util.maybeExecutingAs
import com.datastax.spark.connector.cql._
import com.datastax.spark.connector.datasource.JoinHelper
import com.datastax.spark.connector.rdd.reader._
import com.datastax.spark.connector.writer._
import com.google.common.util.concurrent.{FutureCallback, Futures, SettableFuture}
import org.apache.spark.rdd.RDD
import scala.reflect.ClassTag
import org.apache.spark.metrics.InputMetricsUpdater
import scala.concurrent.ExecutionContext
import scala.util.{Failure, Success}
/**
 * An [[org.apache.spark.rdd.RDD RDD]] that will do a selecting join between the `left` RDD and the
 * specified Cassandra table. This will perform individual selects to retrieve the rows from Cassandra
 * and will take advantage of RDDs that have been partitioned with the
 * [[com.datastax.spark.connector.rdd.partitioner.ReplicaPartitioner]].
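 *
 * Example (a usage sketch; the keyspace `test`, table `kv` and its `key` column are illustrative
 * and are not defined by this class):
 * {{{
 *   import com.datastax.spark.connector._
 *
 *   // Pair every left key with the matching row from test.kv, or None when no row exists.
 *   val joined = sc.parallelize(1 to 100)
 *     .map(Tuple1(_))
 *     .leftJoinWithCassandraTable("test", "kv")
 *     .on(SomeColumns("key"))
 *   // joined contains (Tuple1[Int], Option[CassandraRow]) pairs
 * }}}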
 *
 * @tparam L item type on the left side of the join (any RDD)
 * @tparam R item type on the right side of the join (fetched from Cassandra)
 */
class CassandraLeftJoinRDD[L, R] (
    override val left: RDD[L],
    val keyspaceName: String,
    val tableName: String,
    val connector: CassandraConnector,
    val columnNames: ColumnSelector = AllColumns,
    val joinColumns: ColumnSelector = PartitionKeyColumns,
    val where: CqlWhereClause = CqlWhereClause.empty,
    val limit: Option[CassandraLimit] = None,
    val clusteringOrder: Option[ClusteringOrder] = None,
    val readConf: ReadConf = ReadConf(),
    manualRowReader: Option[RowReader[R]] = None,
    override val manualRowWriter: Option[RowWriter[L]] = None)(
  implicit
    val leftClassTag: ClassTag[L],
    val rightClassTag: ClassTag[R],
    @transient val rowWriterFactory: RowWriterFactory[L],
    @transient val rowReaderFactory: RowReaderFactory[R])
  extends CassandraRDD[(L, Option[R])](left.sparkContext, left.dependencies)
  with CassandraTableRowReaderProvider[R]
  with AbstractCassandraJoin[L, Option[R]] {

  override type Self = CassandraLeftJoinRDD[L, R]

  override protected val classTag = rightClassTag
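
  // Use the explicitly supplied RowReader when one is provided (e.g. by applyToRDD), otherwise
  // build one from the RowReaderFactory for the selected columns.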
  override lazy val rowReader: RowReader[R] = manualRowReader match {
    case Some(rr) => rr
    case None => rowReaderFactory.rowReader(tableDef, columnNames.selectFrom(tableDef))
  }
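
  /**
   * Returns a copy of this RDD with the given attributes replaced; called by the fluent
   * `select` / `where` / `limit` / `withReadConf` methods inherited from [[CassandraRDD]].
   */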
  override protected def copy(
    columnNames: ColumnSelector = columnNames,
    where: CqlWhereClause = where,
    limit: Option[CassandraLimit] = limit,
    clusteringOrder: Option[ClusteringOrder] = None,
    readConf: ReadConf = readConf,
    connector: CassandraConnector = connector
  ): Self = {

    new CassandraLeftJoinRDD[L, R](
      left = left,
      keyspaceName = keyspaceName,
      tableName = tableName,
      connector = connector,
      columnNames = columnNames,
      joinColumns = joinColumns,
      where = where,
      limit = limit,
      clusteringOrder = clusteringOrder,
      readConf = readConf
    )
  }
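
  /**
   * Counts the matching right-side rows by re-joining with a row-count projection
   * ([[RowCountRef]]) instead of fetching whole rows, then sums the per-key counts
   * (keys with no match contribute 0).
   */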
  override def cassandraCount(): Long = {
    columnNames match {
      case SomeColumns(_) =>
        logWarning("You are about to count rows but an explicit projection has been specified.")
      case _ =>
    }

    val counts =
      new CassandraLeftJoinRDD[L, Long](
        left = left,
        connector = connector,
        keyspaceName = keyspaceName,
        tableName = tableName,
        columnNames = SomeColumns(RowCountRef),
        joinColumns = joinColumns,
        where = where,
        limit = limit,
        clusteringOrder = clusteringOrder,
        readConf = readConf
      )

    counts.map(_._2.getOrElse(0L)).reduce(_ + _)
  }
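
  /**
   * Returns a copy of this RDD that joins on the given columns instead of the default
   * partition key columns, e.g. `.on(SomeColumns("customer_id"))` (illustrative column name).
   */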
  def on(joinColumns: ColumnSelector): CassandraLeftJoinRDD[L, R] = {
    new CassandraLeftJoinRDD[L, R](
      left = left,
      connector = connector,
      keyspaceName = keyspaceName,
      tableName = tableName,
      columnNames = columnNames,
      joinColumns = joinColumns,
      where = where,
      limit = limit,
      clusteringOrder = clusteringOrder,
      readConf = readConf
    )
  }

  /**
   * Turns this CassandraLeftJoinRDD into a factory for converting other RDDs after being serialized.
   * This method is intended for streaming operations, as it allows us to serialize a template JoinRDD
   * and then use that serializable template in the DStream closure. This gives us a fully serializable
   * leftJoinWithCassandra operation.
   */
  private[connector] def applyToRDD(left: RDD[L]): CassandraLeftJoinRDD[L, R] = {
    new CassandraLeftJoinRDD[L, R](
      left,
      keyspaceName,
      tableName,
      connector,
      columnNames,
      joinColumns,
      where,
      limit,
      clusteringOrder,
      readConf,
      Some(rowReader),
      Some(rowWriter)
    )
  }
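
  /**
   * For every left item, binds and asynchronously executes the join statement, pairing the item with
   * `Some(row)` for each matching Cassandra row, or with `None` when no rows match. Queries are issued
   * through the requests-per-second rate limiter and at most `readConf.parallelismLevel` responses are
   * prefetched at a time.
   */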
  private[rdd] def fetchIterator(
    session: CqlSession,
    bsb: BoundStatementBuilder[L],
    rowMetadata: CassandraRowMetadata,
    leftIterator: Iterator[L],
    metricsUpdater: InputMetricsUpdater
  ): Iterator[(L, Option[R])] = {

    import com.datastax.spark.connector.util.Threads.BlockingIOExecutionContext

    val queryExecutor = QueryExecutor(session, readConf.parallelismLevel, None, None)

    def pairWithRight(left: L): SettableFuture[Iterator[(L, Option[R])]] = {
      val resultFuture = SettableFuture.create[Iterator[(L, Option[R])]]
      val leftSide = Iterator.continually(left)

      val stmt = bsb.bind(left)
        .update(_.setPageSize(readConf.fetchSizeInRows))
        .executeAs(readConf.executeAs)
      queryExecutor.executeAsync(stmt).onComplete {
        case Success(rs) =>
          val resultSet = new PrefetchingResultSetIterator(rs)
          val iteratorWithMetrics = resultSet.map(metricsUpdater.updateMetrics)
          /* This is a less than ideal place to actually rate limit: because we are buffering
             these futures, we will most likely exceed our threshold. */
          val throttledIterator = iteratorWithMetrics.map(maybeRateLimit)
          val rightSide = resultSet.isEmpty match {
            case true => Iterator.single(None)
            case false => throttledIterator.map(r => Some(rowReader.read(r, rowMetadata)))
          }
          resultFuture.set(leftSide.zip(rightSide))
        case Failure(throwable) =>
          resultFuture.setException(throwable)
      }
      resultFuture
    }
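
    // Throttle how quickly new join queries are issued, then keep at most readConf.parallelismLevel
    // responses in flight while the joined pairs are consumed downstream.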
    val queryFutures = leftIterator.map(left => {
      requestsPerSecondRateLimiter.maybeSleep(1)
      pairWithRight(left)
    })

    JoinHelper.slidingPrefetchIterator(queryFutures, readConf.parallelismLevel).flatten
  }
}