All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.sql.cassandra.CassandraMetadataFunctions.scala Maven / Gradle / Ivy

The newest version!
package org.apache.spark.sql.cassandra

import com.datastax.spark.connector.datasource.CassandraTable
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, ExpressionInfo, UnaryExpression, Unevaluable}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.types._
import org.apache.spark.sql.{AnalysisException, SparkSession, functions}

/**
  * Base trait for the Cassandra metadata pseudo-functions defined in this file.
  *
  * Implementations are unary Catalyst expressions that are marked `Unevaluable`: they carry
  * the wrapped column expression and a result type but provide no evaluation logic of their own.
  */
trait CassandraMetadataFunction extends UnaryExpression with Unevaluable {
  /** Name of the configuration parameter that corresponds to this metadata function. */
  def confParam: String

  /** CQL spelling of the function, derived by upper-casing [[confParam]]. */
  def cql: String = confParam.toUpperCase

  /**
    * Tells whether a Spark data type is one of the collection types handled here.
    *
    * @param dataType the Spark SQL type to inspect
    * @return `true` for array and map types, `false` for anything else
    */
  def isCollectionType(dataType: DataType) : Boolean = dataType match {
    case _: ArrayType | _: MapType => true
    case _ => false
  }
}

/**
  * Metadata function expression for the TTL of a column.
  *
  * @param child the expression the TTL is requested for
  */
case class CassandraTTL(child: Expression) extends CassandraMetadataFunction {
  /** Configuration parameter name taken from `CassandraSourceRelation.TTLParam`. */
  override def confParam: String = CassandraSourceRelation.TTLParam.name

  /** The expression's declared result type is an integer. */
  override def dataType: DataType = IntegerType

  /** The expression itself is declared non-nullable. */
  override def nullable: Boolean = false

  /** SQL rendering: the child's SQL wrapped in `TTL(...)`. */
  override def sql: String = {
    val childSql = child.sql
    s"TTL($childSql)"
  }

  /** Rebuilds this expression around a replacement child. */
  override protected def withNewChildInternal(newChild: Expression): CassandraTTL =
    CassandraTTL(newChild)
}

/**
  * Metadata function expression for the write time of a column.
  *
  * @param child the expression the write time is requested for
  */
case class CassandraWriteTime(child: Expression) extends CassandraMetadataFunction {
  /** Configuration parameter name taken from `CassandraSourceRelation.WriteTimeParam`. */
  override def confParam: String = CassandraSourceRelation.WriteTimeParam.name

  /** The expression's declared result type is a long. */
  override def dataType: DataType = LongType

  /** The expression itself is declared non-nullable. */
  override def nullable: Boolean = false

  /** SQL rendering: the child's SQL wrapped in `WRITETIME(...)`. */
  override def sql: String = {
    val childSql = child.sql
    s"WRITETIME($childSql)"
  }

  /** Rebuilds this expression around a replacement child. */
  override protected def withNewChildInternal(newChild: Expression): CassandraWriteTime =
    CassandraWriteTime(newChild)
}

/**
  * Builders, descriptors and registration helper for the `ttl` and `writetime` SQL functions.
  */
object CassandraMetadataFunction {

  /**
    * Registers the `ttl` and `writetime` functions in the function registry of `session`.
    *
    * @param session the session whose function registry is updated
    */
  def registerMetadataFunctions(session: SparkSession): Unit = {
    val registry = session.sessionState.functionRegistry
    registry.registerFunction(FunctionIdentifier("ttl"), ttlBuilder, "")
    registry.registerFunction(FunctionIdentifier("writetime"), writeTimeBuilder, "")
  }

  // Function-valued wrapper around cassandraTTLFunctionBuilder. Declared before the descriptor
  // below because the descriptor captures it during object initialization.
  private val ttlBuilder: Seq[Expression] => CassandraTTL = (input: Seq[Expression]) =>
    CassandraMetadataFunction.cassandraTTLFunctionBuilder(input)

  /** Identifier, expression info and builder triple describing the `ttl` function. */
  val cassandraTTLFunctionDescriptor
  : (FunctionIdentifier, ExpressionInfo, Seq[Expression] => CassandraTTL) = (
    FunctionIdentifier("ttl"),
    new ExpressionInfo(getClass.getSimpleName, "ttl"),
    ttlBuilder)

  /**
    * Builds a [[CassandraTTL]] from the arguments of a `ttl(...)` call.
    *
    * @param args the call arguments; exactly one is required
    * @return the TTL expression wrapping the single argument
    * @throws AnalysisException if the number of arguments is not exactly one (including zero)
    */
  def cassandraTTLFunctionBuilder(args: Seq[Expression]): CassandraTTL = args match {
    case Seq(column) => CassandraTTL(column)
    case _ =>
      // The check is "not exactly one", so the message must also make sense for zero arguments.
      throw new AnalysisException(
        s"Unable to call Cassandra ttl with ${args.length} arguments, exactly 1 is required," +
          s" given $args")
  }

  // Function-valued wrapper around cassandraWriteTimeFunctionBuilder. Declared before the
  // descriptor below because the descriptor captures it during object initialization.
  private val writeTimeBuilder: Seq[Expression] => CassandraWriteTime = (input: Seq[Expression]) =>
    CassandraMetadataFunction.cassandraWriteTimeFunctionBuilder(input)

  /** Identifier, expression info and builder triple describing the `writetime` function. */
  val cassandraWriteTimeFunctionDescriptor
  : (FunctionIdentifier, ExpressionInfo, Seq[Expression] => CassandraWriteTime) = (
    FunctionIdentifier("writetime"),
    new ExpressionInfo(getClass.getSimpleName, "writetime"),
    writeTimeBuilder)

  /**
    * Builds a [[CassandraWriteTime]] from the arguments of a `writetime(...)` call.
    *
    * @param args the call arguments; exactly one is required
    * @return the write time expression wrapping the single argument
    * @throws AnalysisException if the number of arguments is not exactly one (including zero)
    */
  def cassandraWriteTimeFunctionBuilder(args: Seq[Expression]): CassandraWriteTime = args match {
    case Seq(column) => CassandraWriteTime(column)
    case _ =>
      // The check is "not exactly one", so the message must also make sense for zero arguments.
      throw new AnalysisException(
        s"Unable to call Cassandra writetime with ${args.length} arguments, exactly 1 is" +
          s" required, given $args")
  }
}

// An UnresolvedAttribute variant that always reports itself as nullable; used to fix
// Union's output checking behavior.
class NullableUnresolvedAttribute(name: String) extends UnresolvedAttribute(Seq(name)) {
  override def nullable: Boolean = true
}

/**
  * Logical plan rule that rewrites [[CassandraMetadataFunction]] expressions (TTL / WRITETIME)
  * into references to an extra metadata column that is added to the schema and output of the
  * Cassandra relation the column comes from.
  */
object CassandraMetaDataRule extends Rule[LogicalPlan] {

  /**
    * Replaces every occurrence of `metaDataExpression` in `plan` with a reference to a metadata
    * column named `<CQL>(<column>)`, and adds that column to the matching Cassandra relation.
    *
    * @param metaDataExpression the metadata function call to eliminate; its child must be an
    *                           `AttributeReference`
    * @param plan               the plan (sub)tree containing the expression
    * @return the rewritten plan
    * @throws IllegalArgumentException if no Cassandra relation in `plan` has the column
    * @throws AnalysisException        if the column is a primary key column
    */
  def replaceMetadata(metaDataExpression: CassandraMetadataFunction, plan: LogicalPlan)
  : LogicalPlan = {
    assert(metaDataExpression.child.isInstanceOf[AttributeReference],
      s"""Can only use Cassandra Metadata Functions on Attribute References,
         |found a ${metaDataExpression.child.getClass}""".stripMargin)

    val cassandraColumnName = metaDataExpression.child.asInstanceOf[AttributeReference].name
    val cassandraCql = s"${metaDataExpression.cql}($cassandraColumnName)"

    // First Cassandra relation in the plan whose table defines the requested column
    val cassandraTable = plan.collectFirst {
      case DataSourceV2Relation(table: CassandraTable, _, _, _, _)
        if table.tableDef.columnByName.contains(cassandraColumnName) => table
    }.getOrElse(throw new IllegalArgumentException(
      s"Unable to find Cassandra Source Relation for TTL/Writetime for column $cassandraColumnName"))

    val columnDef = cassandraTable.tableDef.columnByName(cassandraColumnName)

    if (columnDef.isPrimaryKeyColumn)
      throw new AnalysisException(s"Unable to use ${metaDataExpression.cql} function on non-normal column ${columnDef.columnName}")

    // Multi-cell columns expose their metadata as an array of the function's result type.
    // This single type is shared by the Cassandra attribute, the schema field AND the null
    // placeholder below, so that every branch of a Union produces the same column type.
    val metadataType: DataType =
      if (columnDef.isMultiCell) ArrayType(metaDataExpression.dataType)
      else metaDataExpression.dataType

    //Used for CassandraRelation Leaves, giving them a reference to the underlying Metadata
    val cassandraAttributeReference =
      AttributeReference(cassandraCql, metadataType, nullable = true)()
    val cassandraField = StructField(cassandraCql, metadataType, nullable = true)

    //Used as a placeholder for everywhere except leaf nodes, to be resolved by the Catalyst Analyzer
    val unResolvedAttributeReference =  new NullableUnresolvedAttribute(cassandraCql)

    //Used for any leaf nodes that do not have the ability to produce a true Metadata Value
    val nullAttributeReference =
      Alias(functions.lit(null).cast(metadataType).expr, cassandraCql)()

    // Remove Metadata Expressions
    val metadataFunctionRemovedPlan = plan.transformAllExpressions {
      case expression: Expression if expression == metaDataExpression => unResolvedAttributeReference
    }

    // Add Metadata to CassandraSource
    val cassandraSourceModifiedPlan = metadataFunctionRemovedPlan.transform {
      case cassandraRelation@DataSourceV2Relation(table: CassandraTable, _, _, _, _)
        if table.tableDef.columnByName.contains(cassandraColumnName) =>
        val modifiedCassandraTable = table.copy(optionalSchema = Some(table.schema().add(cassandraField)))
        cassandraRelation.copy(
          modifiedCassandraTable,
          cassandraRelation.output :+ cassandraAttributeReference
        )
    }

    // True when the node's input already carries the metadata column in any of its three forms
    def containsAnyReferenceToTTL(logicalPlan: LogicalPlan): Boolean = {
      val references = Seq(cassandraAttributeReference, nullAttributeReference, unResolvedAttributeReference)
      val input = logicalPlan.inputSet
      references.exists(input.contains)
    }

    /* Find the leaves of unsatisfied TTL references. Replace them either with a Cassandra TTL attribute
    * or a null if no CassandraTTL is possible for that leaf. All other locations are marked as unresolved
    * for the next pass of the Analyzer */
    val fixedPlan = cassandraSourceModifiedPlan.transformDown {
      case node if node.missingInput.contains(unResolvedAttributeReference) =>
        node.mapChildren(_.transformUp {
          case child: Project =>
            if (containsAnyReferenceToTTL(child)) {
              //This node's input contains a value with the Cassandra TTL name, add an unresolved reference to it
              child.copy(child.projectList :+ unResolvedAttributeReference, child.child)
            } else {
              /* This node's input is missing any child reference to the Cassandra TTL we are adding add a null column reference
                 with the same name.
                 This is specifically for graphframes which unions Null References with C* columns
               */
              child.copy(child.projectList :+ nullAttributeReference, child.child)
            }
        })
    }

    fixedPlan
  }

  /**
    * Collects the metadata function calls found in the expressions of a single plan node.
    *
    * The search descends through expression children but stops at a metadata function once
    * found (its own children are not searched). Child plan nodes are not visited.
    *
    * @param logicalPlan the node whose expressions are inspected
    * @return the metadata functions in encounter order, possibly empty
    */
  def findMetadataExpressions(logicalPlan: LogicalPlan): Seq[CassandraMetadataFunction] = {
    def findMetadataExpressions(expressions: Seq[Expression]): Seq[CassandraMetadataFunction] = {
      expressions.collect {
        case metadata: CassandraMetadataFunction => Seq(metadata)
        case parent: Expression => findMetadataExpressions(parent.children)
      }.flatten
    }
    findMetadataExpressions(logicalPlan.expressions)
  }

  /**
    * Rewrites every plan node that contains metadata function calls, applying
    * [[replaceMetadata]] once per call found on that node.
    *
    * @param plan the plan to rewrite
    * @return the plan with metadata functions replaced by metadata column references
    */
  override def apply(plan: LogicalPlan): LogicalPlan = {
    plan.transform {
      case planWithMetaData: LogicalPlan =>
        // Search each node only once: folding over an empty result returns the very same
        // node instance, which transform treats as "unchanged".
        findMetadataExpressions(planWithMetaData).foldLeft[LogicalPlan](planWithMetaData) {
          case (currentPlan, expression) => replaceMetadata(expression, currentPlan)
        }
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy