All downloads are free. Search and download functionality uses the official Maven repository.

com.datastax.spark.connector.DatasetFunctions.scala Maven / Gradle / Ivy

The newest version!
package com.datastax.spark.connector

import com.datastax.oss.driver.api.core.ProtocolVersion
import com.datastax.spark.connector.cql._
import com.datastax.spark.connector.mapper.DataFrameColumnMapper
import org.apache.spark.SparkContext
import org.apache.spark.sql.{Dataset, Encoder}
import org.apache.spark.sql.cassandra.{AlwaysOn, CassandraSourceRelation, DirectJoinSetting}

/**
 * Provides Cassandra-specific methods on a [[org.apache.spark.sql.Dataset]]
 * (and therefore also on a [[org.apache.spark.sql.DataFrame]]).
 *
 * @param dataset the Dataset these methods operate on
 * @tparam K element type of the Dataset
 */
class DatasetFunctions[K: Encoder](dataset: Dataset[K]) extends Serializable {

  /**
   * SparkContext of the wrapped Dataset; used to build the default [[CassandraConnector]].
   * `@transient` because SparkContext is not serializable while this class is.
   */
  @transient val sparkContext: SparkContext = dataset.sqlContext.sparkContext

  /**
   * Returns the wrapped Dataset with the given direct-join setting applied via
   * `CassandraSourceRelation.setDirectJoin`.
   * NOTE(review): presumably controls whether joins against a Cassandra source use the
   * direct-join strategy — confirm against `CassandraSourceRelation`.
   *
   * @param directJoinSetting the setting to apply; defaults to `AlwaysOn`
   */
  def directJoin(directJoinSetting: DirectJoinSetting = AlwaysOn): Dataset[K] =
    CassandraSourceRelation.setDirectJoin(dataset, directJoinSetting)

  /**
   * Creates a C* table based on the Dataset Struct provided. Optionally
   * takes in a list of partition columns or clustering columns names. When absent
   * the first column will be used as the partition key and there will be no clustering
   * keys. All clustering columns are created in ascending order.
   *
   * @param keyspaceName         keyspace in which to create the table
   * @param tableName            name of the table to create
   * @param partitionKeyColumns  partition key column names; `None` uses the schema-derived default
   * @param clusteringKeyColumns clustering column names; `None` means no clustering columns
   * @param connector            connector used to reach the cluster
   * @throws IllegalArgumentException if a named column is not in the Dataset schema, no
   *                                  partition key results, or a key column is listed twice
   */
  @deprecated("Use DatasourceV2 Catalog Api", "3.0.0")
  def createCassandraTable(
    keyspaceName: String,
    tableName: String,
    partitionKeyColumns: Option[Seq[String]] = None,
    clusteringKeyColumns: Option[Seq[String]] = None)(
  implicit
    connector: CassandraConnector = CassandraConnector(sparkContext)): Unit = {

    // The schema-derived table definition is only needed to choose the default partition key,
    // so skip the extra session round-trip when the caller supplies the keys explicitly.
    val partitionKeyNames = partitionKeyColumns.getOrElse {
      val protocolVersion = connector.withSessionDo(_.getContext.getProtocolVersion)
      new DataFrameColumnMapper(dataset.schema)
        .newTable(keyspaceName, tableName, protocolVersion)
        .partitionKey
        .map(_.columnName)
    }
    val clusteringKeys = clusteringKeyColumns
      .getOrElse(Seq.empty)
      .map((_, ClusteringColumn.Ascending))

    createCassandraTableEx(keyspaceName, tableName, partitionKeyNames, clusteringKeys)(connector)
  }

  /**
   * Creates a C* table based on the Dataset Struct provided.
   * Takes in a list of partition columns, clustering columns names, and optionally, the table options.
   * Every schema column that is not a key column becomes a regular column, kept in schema order.
   *
   * @param keyspaceName         keyspace in which to create the table
   * @param tableName            name of the table to create
   * @param partitionKeyColumns  partition key column names, in key order; must be non-empty
   * @param clusteringKeyColumns clustering column names with their sorting order, in key order
   * @param ifNotExists          passed through to the table definition's `ifNotExists` flag
   * @param tableOptions         passed through to the table definition's `tableOptions`
   * @param connector            connector used to reach the cluster
   * @throws IllegalArgumentException if a key column is not in the Dataset schema, no partition
   *                                  key column is given, or a column is listed more than once
   *                                  across the partition and clustering keys
   */
  @deprecated("Use DatasourceV2 Catalog Api", "3.0.0")
  def createCassandraTableEx(
    keyspaceName: String,
    tableName: String,
    partitionKeyColumns: Seq[String],
    clusteringKeyColumns: Seq[(String, ClusteringColumn.SortingOrder)],
    ifNotExists: Boolean = false,
    tableOptions: Map[String, String] = Map())(
  implicit
    connector: CassandraConnector = CassandraConnector(sparkContext)): Unit = {

    val clusteringKeyNames = clusteringKeyColumns.map(_._1)
    val keyNames = partitionKeyColumns ++ clusteringKeyNames

    // Fail fast with a clear message instead of sending CQL that Cassandra would reject.
    require(partitionKeyColumns.nonEmpty,
      s"At least one partition key column is required to create table $keyspaceName.$tableName")
    val duplicatedKeyNames = keyNames.diff(keyNames.distinct).distinct
    require(duplicatedKeyNames.isEmpty,
      "Columns may appear only once across partition and clustering keys: " +
        duplicatedKeyNames.mkString(", "))

    val protocolVersion = connector.withSessionDo(_.getContext.getProtocolVersion)
    val rawTable = new DataFrameColumnMapper(dataset.schema).newTable(keyspaceName, tableName, protocolVersion)
    val columnMapping = rawTable.columnByName
    val columnNames = columnMapping.keys.toSet

    def missingColumnException(columnName: String, columnType: String) = {
      new IllegalArgumentException(
        s""""$columnName" not Found. Unable to make specified column $columnName a $columnType.
          |Available Columns: $columnNames""".stripMargin)
    }

    /** Looks up a requested key column in the schema-derived definition, failing if absent. */
    def keyColumn(columnName: String, columnType: String) =
      columnMapping.getOrElse(columnName, throw missingColumnException(columnName, columnType))

    // Derive regular columns from the ordered column lists rather than a Set, so the generated
    // CQL lists them deterministically in schema order.
    val keyNameSet = keyNames.toSet
    val regularColumns = (rawTable.partitionKey ++ rawTable.clusteringColumns ++ rawTable.regularColumns)
      .filterNot(col => keyNameSet.contains(col.columnName))
      .map(_.copy(columnRole = RegularColumn))

    val table = rawTable.copy(
      partitionKey = partitionKeyColumns.map { name =>
        keyColumn(name, "Partition Key Column").copy(columnRole = PartitionKeyColumn)
      },
      clusteringColumns = clusteringKeyColumns.zipWithIndex.map { case ((name, order), index) =>
        keyColumn(name, "Clustering Column").copy(columnRole = ClusteringColumn(index, order))
      },
      regularColumns = regularColumns,
      ifNotExists = ifNotExists,
      tableOptions = tableOptions)

    connector.withSessionDo(session => session.execute(table.cql))
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy