com.datastax.spark.connector.datasource.JoinHelper.scala

package com.datastax.spark.connector.datasource

import java.util.concurrent.Future

import com.datastax.oss.driver.api.core.cql.{PreparedStatement, Row, SimpleStatement}
import com.datastax.oss.driver.api.core.{ConsistencyLevel, CqlIdentifier, CqlSession}
import com.datastax.spark.connector.cql.{TableDef, getRowBinarySize}
import com.datastax.spark.connector.datasource.ScanHelper.CqlQueryParts
import com.datastax.spark.connector.rdd.{CassandraLimit, CqlWhereClause, ReadConf}
import com.datastax.spark.connector.util.CqlWhereParser.{EqPredicate, InListPredicate, InPredicate, RangePredicate}
import com.datastax.spark.connector.util.{CqlWhereParser, Logging}
import com.datastax.spark.connector.writer.{BoundStatementBuilder, RateLimiter, RowWriter}
import com.datastax.spark.connector.{AllColumns, CassandraRowMetadata, ColumnRef, ColumnSelector, PartitionKeyColumns, PrimaryKeyColumns, SomeColumns}

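/** Helpers for joining a Spark dataset against a Cassandra table: they build the
  * per-key SELECT, prepare it, and provide the supporting machinery (row metadata,
  * bound statement builder, prefetching, rate limiting) to execute it per key. */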
object JoinHelper extends Logging {

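  /** Resolves a [[ColumnSelector]] into the concrete columns to join on.
    * Only named primary key or partition key columns may be used; joining on all
    * columns or on CQL function calls is rejected. */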
  def joinColumnNames(joinColumns: ColumnSelector, tableDef: TableDef): Seq[ColumnRef] = joinColumns match {
    case AllColumns => throw new IllegalArgumentException(
      "Unable to join against all columns in a Cassandra Table. Only primary key columns allowed."
    )
    case PrimaryKeyColumns =>
      tableDef.primaryKey.map(col => col.columnName: ColumnRef)
    case PartitionKeyColumns =>
      tableDef.partitionKey.map(col => col.columnName: ColumnRef)
    case SomeColumns(cs@_*) =>
      ScanHelper.checkColumnsExistence(cs, tableDef)
      cs.map {
        case c: ColumnRef => c
        case _ => throw new IllegalArgumentException(
          "Unable to join against unnamed columns. No CQL Functions allowed."
        )
      }
  }

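  /** Builds the CQL SELECT used for the join. Each join column becomes a named bind
    * marker (`col = :col`) ANDed onto the user-supplied WHERE clause; a column may
    * not appear in both the join clause and the WHERE clause. */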
  def getJoinQueryString(
    tableDef: TableDef,
    joinColumns: Seq[ColumnRef],
    queryParts: CqlQueryParts): String = {

    val whereClauses = queryParts.whereClause.predicates.flatMap(CqlWhereParser.parse)
    val joinColumnNames = joinColumns.map(_.columnName)

    val joinColumnPredicates = whereClauses.collect {
      case EqPredicate(c, _) if joinColumnNames.contains(c) => c
      case InPredicate(c) if joinColumnNames.contains(c) => c
      case InListPredicate(c, _) if joinColumnNames.contains(c) => c
      case RangePredicate(c, _, _) if joinColumnNames.contains(c) => c
    }.toSet

    require(
      joinColumnPredicates.isEmpty,
      s"""Columns specified in both the join on clause and the where clause.
         |Partition key columns are always part of the join clause.
         |Columns in both: ${joinColumnPredicates.mkString(", ")}""".stripMargin
    )

    logDebug("Generating Single Key Query Prepared Statement String")
    logDebug(s"SelectedColumns : ${queryParts.selectedColumnRefs} -- JoinColumnNames : $joinColumnNames")
    val columns = queryParts.selectedColumnRefs.map(_.cql).mkString(", ")
    val joinWhere = joinColumnNames.map(name => s"${CqlIdentifier.fromInternal(name).asCql(true)} = :$name")
    val limitClause = CassandraLimit.limitToClause(queryParts.limitClause)
    val orderBy = queryParts.clusteringOrder.map(_.toCql(tableDef)).getOrElse("")
    val filter = (queryParts.whereClause.predicates ++ joinWhere).mkString(" AND ")
    val quotedKeyspaceName = CqlIdentifier.fromInternal(tableDef.keyspaceName).asCql(true)
    val quotedTableName = CqlIdentifier.fromInternal(tableDef.tableName).asCql(true)
    val query =
      s"SELECT $columns " +
        s"FROM $quotedKeyspaceName.$quotedTableName " +
        s"WHERE $filter $orderBy $limitClause"
    logDebug(s"Query : $query")
    query
  }

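  /** Prepares the join query at the given consistency level. The statement is marked
    * idempotent so the driver may safely retry or speculatively execute it. */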
  def getJoinPreparedStatement(
    session: CqlSession,
    queryString: String,
    consistencyLevel: ConsistencyLevel): PreparedStatement = {

    val stmt = SimpleStatement.newInstance(queryString).setConsistencyLevel(consistencyLevel).setIdempotent(true)
    session.prepare(stmt)
  }

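  /** Builds the row metadata (column names plus driver codecs) needed to decode the
    * rows returned by the prepared join statement. */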
  def getCassandraRowMetadata(
    session: CqlSession,
    statement: PreparedStatement,
    selectedColumnRefs: IndexedSeq[ColumnRef]): CassandraRowMetadata = {

    val codecRegistry = session.getContext.getCodecRegistry
    val columnNames = selectedColumnRefs.map(_.selectedAs).toIndexedSeq
    CassandraRowMetadata.fromPreparedStatement(columnNames, statement, codecRegistry)
  }

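  /** Creates a [[BoundStatementBuilder]] that binds left-side keys to the prepared
    * join statement, using the session's negotiated protocol version. */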
  def getKeyBuilderStatementBuilder[L](
    session: CqlSession,
    rowWriter: RowWriter[L],
    preparedStatement: PreparedStatement,
    whereClause: CqlWhereClause): BoundStatementBuilder[L] = {

    val protocolVersion = session.getContext.getProtocolVersion
    new BoundStatementBuilder[L](rowWriter, preparedStatement, whereClause.values, protocolVersion = protocolVersion)
  }

  /** Prefetches `batchSize` futures at a time: while one batch of results is being
    * awaited, the next batch has already been materialized and is in flight. */
  def slidingPrefetchIterator[T](it: Iterator[Future[T]], batchSize: Int): Iterator[T] = {
    // Group the futures into batches, then slide a two-batch window over them so the
    // batch after the one being consumed has always been materialized (in flight).
    val (firstElements, lastElement) = it
      .grouped(batchSize)
      .sliding(2)
      .span(_ => it.hasNext)

    // Emit the head batch of every full window, then the final window, blocking on
    // each future in order.
    (firstElements.map(_.head) ++ lastElement.flatten).flatten.map(_.get)
  }

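  /** Limits the number of join requests issued per second; effectively unlimited
    * when `readsPerSec` is not configured. */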
  def requestsPerSecondRateLimiter(readConf: ReadConf): RateLimiter = {
    val readsPerSec = readConf.readsPerSec.getOrElse(Integer.MAX_VALUE).toLong
    new RateLimiter(readsPerSec, readsPerSec)
  }

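  /** When `throughputMiBPS` is configured, returns a pass-through function that
    * sleeps as needed to keep row throughput under the byte budget; otherwise
    * returns `identity`. */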
  def maybeRateLimit(readConf: ReadConf): (Row => Row) = readConf.throughputMiBPS match {
    case Some(throughput) =>
      val bytesPerSecond: Long = (throughput * 1024 * 1024).toLong
      val rateLimiter = new RateLimiter(bytesPerSecond, bytesPerSecond)
      logDebug(s"Throttling join at $bytesPerSecond bytes per second")
      (row: Row) => {
        rateLimiter.maybeSleep(getRowBinarySize(row))
        row
      }
    case None => identity[Row]
  }
}
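
/* Illustrative usage sketch (not part of the original file): assuming a `session`
 * (CqlSession), a `tableDef`, `queryParts`, and a `rowWriter` for the left-side row
 * type are in scope, the helpers above compose roughly as follows:
 *
 *   val joinRefs = JoinHelper.joinColumnNames(PartitionKeyColumns, tableDef)
 *   val cql      = JoinHelper.getJoinQueryString(tableDef, joinRefs, queryParts)
 *   val prepared = JoinHelper.getJoinPreparedStatement(session, cql, ConsistencyLevel.LOCAL_ONE)
 *   val metadata = JoinHelper.getCassandraRowMetadata(session, prepared, queryParts.selectedColumnRefs)
 *   val builder  = JoinHelper.getKeyBuilderStatementBuilder(session, rowWriter, prepared, queryParts.whereClause)
 *
 * Each left-side key is then bound via the builder, executed asynchronously, and the
 * resulting futures can be pipelined through `slidingPrefetchIterator`.
 */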