All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.datastax.spark.connector.datasource.ScanHelper.scala Maven / Gradle / Ivy

The newest version!
package com.datastax.spark.connector.datasource

import java.io.IOException
import com.datastax.oss.driver.api.core.CqlIdentifier._
import com.datastax.oss.driver.api.core.cql.BoundStatement
import com.datastax.oss.driver.api.core.{ConsistencyLevel, CqlSession}
import com.datastax.spark.connector.cql.{CassandraConnector, ScanResult, Scanner, TableDef}
import com.datastax.spark.connector.rdd.CassandraLimit.limitToClause
import com.datastax.spark.connector.rdd.partitioner.dht.{Token, TokenFactory}
import com.datastax.spark.connector.rdd.partitioner.{CassandraPartitionGenerator, CqlTokenRange, DataSizeEstimates}
import com.datastax.spark.connector.rdd.{CassandraLimit, ClusteringOrder, CqlWhereClause}
import com.datastax.spark.connector.types.ColumnType
import com.datastax.spark.connector.util.CqlWhereParser.{EqPredicate, InListPredicate, InPredicate, Predicate, RangePredicate}
import com.datastax.spark.connector.util.Quote.quote
import com.datastax.spark.connector.util.{CqlWhereParser, Logging}
import com.datastax.spark.connector.{ColumnName, ColumnRef, TTL, WriteTime}
import org.apache.spark.sql.cassandra.DsePredicateRules.StorageAttachedIndex

import scala.jdk.CollectionConverters._

object ScanHelper extends Logging {

  /** Validates that every requested column reference can be read from `tableDef`.
    *
    * Plain column names must exist in the table; `TTL` and `WriteTime` selectors are only
    * accepted for regular (non-key) columns. Any other kind of [[ColumnRef]] is passed
    * through without checks.
    *
    * @param columns  column references to validate
    * @param tableDef table the columns are expected to belong to
    * @return `columns` unchanged and in the same order, when all of them are valid
    * @throws IOException on the first column reference that fails validation
    */
  def checkColumnsExistence(columns: Seq[ColumnRef], tableDef: TableDef): Seq[ColumnRef] = {
    val allColumnNames = tableDef.columns.map(_.columnName).toSet
    val regularColumnNames = tableDef.regularColumns.map(_.columnName).toSet
    val keyspaceName = tableDef.keyspaceName
    val tableName = tableDef.tableName

    def checkSingleColumn(column: ColumnRef): ColumnRef = {
      column match {
        case ColumnName(columnName, _) =>
          if (!allColumnNames.contains(columnName))
            throw new IOException(s"Column $column not found in table $keyspaceName.$tableName")
        case TTL(columnName, _) =>
          if (!regularColumnNames.contains(columnName))
            throw new IOException(s"TTL can be obtained only for regular columns, " +
              s"but column $columnName is not a regular column in table $keyspaceName.$tableName.")
        case WriteTime(columnName, _) =>
          // The message names WRITETIME (not TTL) so the error identifies the selector that failed.
          if (!regularColumnNames.contains(columnName))
            throw new IOException(s"WRITETIME can be obtained only for regular columns, " +
              s"but column $columnName is not a regular column in table $keyspaceName.$tableName.")
        case _ =>
      }

      column
    }

    columns.map(checkSingleColumn)
  }


  // Existential placeholders for the token value type and the token type. The concrete types
  // depend on the cluster's partitioner, which is resolved at runtime (see the TokenFactory
  // lookup in getPartitionGenerator), so they cannot be named statically here.
  type V = t forSome {type t}
  type T = t forSome {type t <: Token[V]}

  /** Executes the scan query for a single token range.
    *
    * The CQL text and bind values come from [[tokenRangeToCqlQuery]]. The statement is
    * prepared and bound via [[prepareScanStatement]], configured with the given consistency
    * level and page size, routed using the range's start token, and then handed to
    * `scanner.scan`.
    *
    * @param scanner          scanner providing the session and executing the statement
    * @param tableDef         table being read
    * @param queryParts       selected columns, where clause, limit and clustering order
    * @param range            token range to read
    * @param consistencyLevel consistency level set on the bound statement
    * @param fetchSize        page size set on the bound statement
    * @return the result produced by `scanner.scan`
    * @throws IOException if preparing or binding the statement fails
    */
  def fetchTokenRange(
    scanner: Scanner,
    tableDef: TableDef,
    queryParts: CqlQueryParts,
    range: CqlTokenRange[_, _],
    consistencyLevel: ConsistencyLevel,
    fetchSize: Int): ScanResult = {

    val session = scanner.getSession()

    val (cql, values) = tokenRangeToCqlQuery(range, tableDef, queryParts)

    logDebug(
      s"Fetching data for range ${range.cql(partitionKeyStr(tableDef))} " +
        s"with $cql " +
        s"with params ${values.mkString("[", ",", "]")}")

    val stmt = prepareScanStatement(session, cql, values: _*)
      .setConsistencyLevel(consistencyLevel)
      .setPageSize(fetchSize)
      .setRoutingToken(range.range.startNativeToken())

    val scanResult = scanner.scan(stmt)
    logDebug(s"Row iterator for range ${range.cql(partitionKeyStr(tableDef))} obtained successfully.")

    scanResult
  }

  /** Returns the table's partition key column names, quoted and joined with ", ". */
  def partitionKeyStr(tableDef: TableDef): String = {
    tableDef.partitionKey.map(_.columnName).map(quote).mkString(", ")
  }

  /** Builds the SELECT statement text and its positional bind values for one token range.
    *
    * If the where clause already restricts the whole partition key (see
    * [[containsPartitionKey]]), no token-range predicate is added. Otherwise the predicate and
    * its values come from `range.cql`. The token-range values come first, followed by the
    * where clause values, matching the order of the predicates in the generated WHERE clause.
    * The query always ends with `ALLOW FILTERING`.
    *
    * @return the CQL text and the values to bind to its markers, in order
    */
  def tokenRangeToCqlQuery(
    range: CqlTokenRange[_, _],
    tableDef: TableDef,
    cqlQueryParts: CqlQueryParts): (String, Seq[Any]) = {

    val columns = cqlQueryParts.selectedColumnRefs.map(_.cql).mkString(", ")

    val (cql, values) = if (containsPartitionKey(tableDef, cqlQueryParts.whereClause)) {
      ("", Seq.empty)
    } else {
      range.cql(partitionKeyStr(tableDef))
    }
    // Drop the empty token predicate (partition-key case) so no dangling "AND" is produced.
    val filter = (cql +: cqlQueryParts.whereClause.predicates).filter(_.nonEmpty).mkString(" AND ")
    val limitClause = limitToClause(cqlQueryParts.limitClause)
    val orderBy = cqlQueryParts.clusteringOrder.map(_.toCql(tableDef)).getOrElse("")
    val keyspaceName = fromInternal(tableDef.keyspaceName)
    val tableName = fromInternal(tableDef.tableName)
    val queryTemplate =
      s"SELECT $columns " +
        s"FROM ${keyspaceName.asCql(true)}.${tableName.asCql(true)} " +
        s"WHERE $filter $orderBy $limitClause ALLOW FILTERING"
    val queryParamValues = values ++ cqlQueryParts.whereClause.values
    (queryTemplate, queryParamValues)
  }

  /** Checks whether `clause` restricts every partition key column with `=` or `IN`.
    *
    * @return `true` when every partition key column is restricted with an equality or IN predicate
    * @throws UnsupportedOperationException if a range predicate targets a partition key column
    *         that is not a Storage Attached Index target. It is also thrown when only part of
    *         the partition key is restricted and the restricted columns are not all indexed.
    */
  def containsPartitionKey(tableDef: TableDef, clause: CqlWhereClause): Boolean = {
    val pk = tableDef.partitionKey.map(_.columnName).toSet
    val wherePredicates: Seq[Predicate] = clause.predicates.flatMap(CqlWhereParser.parse)
    val saiColumns = tableDef.indexes.filter(_.className.contains(StorageAttachedIndex)).map(_.target)

    val whereColumns: Set[String] = wherePredicates.collect {
      case EqPredicate(c, _) if pk.contains(c) => c
      case InPredicate(c) if pk.contains(c) => c
      case InListPredicate(c, _) if pk.contains(c) => c
      case RangePredicate(c, _, _) if pk.contains(c) && !saiColumns.contains(c) =>
        throw new UnsupportedOperationException(
          s"Range predicates on partition key columns (here: $c) are " +
            s"not supported in where. Use filter instead.")
    }.toSet

    val primaryKeyComplete = whereColumns.nonEmpty && whereColumns.size == pk.size
    // Vacuously true when no partition key column is restricted, so an unrestricted scan passes.
    val whereColumnsAllIndexed = whereColumns.forall(tableDef.isIndexed)

    if (!primaryKeyComplete && !whereColumnsAllIndexed) {
      val missing = pk -- whereColumns
      throw new UnsupportedOperationException(
        s"Partition key predicate must include all partition key columns or partition key columns need" +
          s" to be indexed. Missing columns: ${missing.mkString(",")}"
      )
    }
    primaryKeyComplete
  }

  /** Prepares `cql`, converts `values` to driver types and binds them.
    *
    * Each value is converted with the converter for the type of the bind marker at the same
    * position. The resulting statement is marked idempotent.
    *
    * @throws IOException wrapping any non-fatal failure during preparation or binding; fatal
    *         errors such as `VirtualMachineError` or `InterruptedException` propagate unchanged
    */
  def prepareScanStatement(session: CqlSession, cql: String, values: Any*): BoundStatement = {
    import scala.util.control.NonFatal

    try {
      val stmt = session.prepare(cql)
      val converters = stmt.getVariableDefinitions.asScala
        .map(v => ColumnType.converterToCassandra(v.getType))
        .toArray
      val convertedValues =
        for ((value, converter) <- values zip converters)
          yield converter.convert(value)
      stmt.bind(convertedValues: _*)
        .setIdempotent(true)
    }
    catch {
      case NonFatal(t) =>
        throw new IOException(s"Exception during preparation of $cql: ${t.getMessage}", t)
    }
  }

  /** Chooses a partition generator for scanning `tableDef`.
    *
    * If the where clause restricts the full partition key, a single split is used. Otherwise
    * `partitionCount` is used when given. Failing that, the split count is
    * `estimatedDataSize / splitSize`, clamped to at least `max(minSplitCount, 1)` and at most
    * `Int.MaxValue`. If the size estimate looks invalid (`Long.MaxValue` or negative), a warning
    * is logged and `minSplitCount` is used.
    *
    * @param minSplitCount  lower bound on the split count when estimating from data size
    * @param partitionCount explicit split count that overrides size estimation
    * @param splitSize      target number of bytes per split
    */
  def getPartitionGenerator(
    connector: CassandraConnector,
    tableDef: TableDef,
    whereClause: CqlWhereClause,
    minSplitCount: Int,
    partitionCount: Option[Int],
    splitSize: Long): CassandraPartitionGenerator[V, T] = {

    implicit val tokenFactory = TokenFactory.forSystemLocalPartitioner(connector)

    if (containsPartitionKey(tableDef, whereClause)) {
      CassandraPartitionGenerator(connector, tableDef, 1)
    } else {
      partitionCount match {
        case Some(splitCount) => {
          CassandraPartitionGenerator(connector, tableDef, splitCount)
        }
        case None => {
          val estimateDataSize = new DataSizeEstimates(connector, tableDef.keyspaceName, tableDef.tableName).dataSizeInBytes
          val splitCount = if (estimateDataSize == Long.MaxValue || estimateDataSize < 0) {
            logWarning(
              s"""Size Estimates has overflowed and calculated that the data size is Infinite.
                 |Falling back to $minSplitCount (2 * SparkCores + 1) Split Count.
                 |This is most likely occurring because you are reading size_estimates
                 |from a DataCenter which has very small primary ranges. Explicitly set
                 |the splitCount when reading to manually adjust this.""".stripMargin)
            minSplitCount
          } else {
            // Clamp before narrowing: a plain Long.toInt would wrap for very large estimates
            // and silently collapse the split count to minSplitCount.
            val splitCountEstimate = math.min(estimateDataSize / splitSize, Int.MaxValue.toLong).toInt
            Math.max(splitCountEstimate, Math.max(minSplitCount, 1))
          }
          CassandraPartitionGenerator(connector, tableDef, splitCount)
        }
      }
    }
  }

  /** The parts of a scan query that do not depend on the token range being read.
    *
    * @param selectedColumnRefs columns emitted in the SELECT list, in order
    * @param whereClause        additional predicates and their bind values
    * @param limitClause        optional LIMIT / PER PARTITION LIMIT
    * @param clusteringOrder    optional ORDER BY on clustering columns
    */
  case class CqlQueryParts(
    selectedColumnRefs: IndexedSeq[ColumnRef],
    whereClause: CqlWhereClause = CqlWhereClause.empty,
    limitClause: Option[CassandraLimit] = None,
    clusteringOrder: Option[ClusteringOrder] = None)

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy