/*
 * Copyright DataStax, Inc.
 *
 * Please see the included license file for details.
 */

package com.datastax.spark.connector.rdd

import java.lang.{String => JString}
import java.util.{Map => JMap}

import scala.annotation.meta.param
import scala.language.existentials
import scala.reflect.ClassTag
import org.apache.spark.rdd.{RDD, UnionRDD}
import org.apache.spark.{Partitioner, SparkContext}
import com.datastax.spark.connector.cql.{CassandraConnector, Schema, TableDef}
import com.datastax.spark.connector.rdd.partitioner.dht.{Token, TokenFactory}
import com.datastax.spark.connector.rdd.partitioner.{BucketingRangeIndex, CassandraPartition, TokenGenerator, TokenRangeWithPartitionIndex}
import com.datastax.spark.connector.util.{Logging, schemaFromCassandra}
import com.datastax.spark.connector.writer.RowWriter
import scala.jdk.CollectionConverters._

/**
  * A Java-friendly API for DseGraphUnionedRDD, making it easier to call
  * from VertexInputRDD.
  */
object DseGraphUnionedRDD {
  def fromJava[R: ClassTag](
      sc: SparkContext,
      rdds: java.util.List[RDD[R]],
      keyspace: String,
      graphLabels: java.util.List[String])(
      implicit
      connector: CassandraConnector): DseGraphUnionedRDD[R] = {
    val rddSeq: Seq[RDD[R]] = rdds.asScala.toSeq
    val labelSeq: Seq[String] = graphLabels.asScala.toSeq
    new DseGraphUnionedRDD(sc, rddSeq, keyspace, labelSeq)
  }
}
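
/* A minimal usage sketch of the Java-friendly factory (assumes `javaRdds` and
 * `javaLabels` are java.util.Lists built by the calling Java code; both names
 * are hypothetical):
 *
 * {{{
 * implicit val connector: CassandraConnector = CassandraConnector(sc.getConf)
 * val unioned: DseGraphUnionedRDD[CassandraRow] =
 *   DseGraphUnionedRDD.fromJava(sc, javaRdds, "my_graph", javaLabels)
 * }}}
 */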

/**
  * An extension of UnionRDD which automatically assigns a partitioner based on the way DSE Graph
  * stores and partitions vertex information. The graphLabels must map one-to-one with the RDDs
  * provided, in order. This ordering is used to build a mapping between each label and the RDD
  * which holds the data for that label.
  */
class DseGraphUnionedRDD[R: ClassTag](
    sc: SparkContext,
    rdds: Seq[RDD[R]],
    keyspace: String,
    graphLabels: Seq[String])(
    implicit
    connector: CassandraConnector = CassandraConnector(sc.getConf))
    extends UnionRDD[R](sc, rdds) {

  assert(graphLabels.length == rdds.length)

  override val partitioner: Option[Partitioner] = Some(new DseGraphPartitioner(keyspace, rdds, graphLabels))
}
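
/* A minimal Scala-side construction sketch (the per-label scan RDDs are
 * hypothetical; each is assumed to read the vertex table backing one label,
 * in the same order as `labels`):
 *
 * {{{
 * import com.datastax.spark.connector._
 *
 * val labels = Seq("person", "software")
 * val rdds = labels.map(label => sc.cassandraTable("my_graph", s"${label}_p"))
 * val unioned = new DseGraphUnionedRDD(sc, rdds, "my_graph", labels)
 * unioned.partitioner // Some(DseGraphPartitioner)
 * }}}
 */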


/**
  * A custom partitioner specifically for RDDs made for DSE Graph.
  *
  * The general idea is, for a vertex:
  *
  *    Determine the ~label property
  *    Determine which RDD represents the data for that label
  *    Determine the C* token of the vertex given its label
  *    Determine which partition in the found RDD would contain that token
  *    Determine the offset of that RDD's partitions in the UnionRDD
  *    Return the partition index added to the offset
  *
  * See the usage sketch after the class body.
  */
class DseGraphPartitioner[V, T <: Token[V]](
    keyspaceName: String,
    @(transient @param) rdds: Seq[RDD[_]],
    graphLabels: Seq[String])(
    implicit
    val connector: CassandraConnector)
    extends Partitioner with Logging {

  type ITR = TokenRangeWithPartitionIndex[V, T]

  /* To avoid a dependency on TinkerPop we explicitly set this constant. This may change in future TinkerPop versions. */
  val LabelAccessor = "~label"

  /** SparkConf is not Serializable, so we only use it locally to obtain a connector configuration **/
  @transient lazy val conf = rdds.head.sparkContext.getConf

  /** We can only use this locally since we cannot serialize the RDDs **/
  @transient private[connector] val rddToLabel = rdds.zip(graphLabels).toMap

  implicit private val tokenFactory: TokenFactory[V, T] =
    TokenFactory.forSystemLocalPartitioner(connector).asInstanceOf[TokenFactory[V, T]]

  implicit lazy private val tokenOrdering = tokenFactory.tokenOrdering
  implicit lazy private val tokenBucketing = tokenFactory.tokenBucketing

  /**
    * Within the unioned RDD, the partitions of each parent RDD are offset by the total number of
    * partitions in the RDDs which precede it in the dependency ordering. This lets us map a
    * partition index within a parent RDD to its index in the union, given the graph label.
    *
    * Serializable
    */
  private[connector] val labelToRddOffset: Map[String, Int] = {
    var cumSum = 0
    val offsets = for (rdd <- rdds) yield {
      val offset = (rddToLabel(rdd), cumSum)
      cumSum += rdd.partitions.length
      offset
    }
    offsets.toMap
  }
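
  /* For example (hypothetical partition counts): if the RDDs for labels
   * Seq("person", "software", "knows") have 4, 2 and 3 partitions respectively,
   * this yields:
   *
   * {{{
   * Map("person" -> 0, "software" -> 4, "knows" -> 6)
   * }}}
   *
   * so partition 1 of the "knows" RDD becomes union partition 6 + 1 = 7.
   */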

  /**
    * All of the partitions in each parent RDD. This is serialized so that we can build our
    * lookup tables remotely.
    */
  private[connector] val labelToPartitions: Map[String, Array[CassandraPartition[V, T]]] = {
    rdds.map { rdd =>
      (rddToLabel(rdd), rdd.partitions.map { case part: CassandraPartition[V, T] => part })
    }.toMap
  }

  /**
    * Generates a quick-lookup data structure for every vertex label. These structures allow O(1)
    * lookup of the correct partition for a given Cassandra token.
    *
    * BucketingRangeIndex is not serializable, so this relies on the partitions being serialized
    * and the hashing structures being built remotely.
    */
  @transient private[connector] lazy val labelToTokenRangeLookup: Map[String, BucketingRangeIndex[ITR, T]] = {
    graphLabels.map { label =>
      val cassandraPartitions = labelToPartitions(label)
      val indexedTokenRanges = for (p <- cassandraPartitions; tr <- p.tokenRanges) yield {
        TokenRangeWithPartitionIndex(tr.range, p.index)
      }

      (label, new BucketingRangeIndex[ITR, T](indexedTokenRanges))
    }.toMap
  }

  /**
    * Looks up the C* schema for every RDD in this union RDD. The table definitions are required
    * to determine the partition key when creating TokenGenerators.
    *
    * TableDefs are serializable, so this map can be built locally and shipped to executors.
    */
  private[connector] val labelToTableDef: Map[String, TableDef] = {
    graphLabels.map { labelName =>
      (labelName,
          (schemaFromCassandra(connector, Some(keyspaceName), Some(s"${labelName}_p")).tables ++
            schemaFromCassandra(connector, Some(keyspaceName), Some(s"${labelName}_e")).tables)
              .headOption
              .getOrElse(
                throw new IllegalArgumentException(s"""Couldn't find table for label ${labelName}""")))
    }.toMap
  }

  /**
    * Creates a Map -> C* token mapper for each label, which extracts the required partition key
    * columns from a Map representation of a vertex and transforms them into a C* token.
    *
    * Must be built remotely: TokenGenerators are not serializable, but they only rely on labels
    * and TableDefs, both of which are serializable.
    */
  @transient lazy private[connector] val labelToTokenGenerator: Map[String, TokenGenerator[JMap[JString, AnyRef]]] =
    graphLabels.map { graphLabel =>
      val tableDef = labelToTableDef(graphLabel)
      (graphLabel,
          new TokenGenerator(
            connector,
            tableDef,
            new MapRowWriter(tableDef.partitionKey.map(_.columnName))))
    }.toMap


  /**
    * For a given Map-based vertex key, determine which table it belongs to. From that table, look
    * up the correct RDD and its partitioning. Then generate the token for the now-known table and
    * compare it against that RDD's quick-lookup bucket structure to find the partition index,
    * which is finally added to the RDD's offset within the union.
    *
    * Triggers the creation of all lazy transient properties.
    */
  override def getPartition(key: Any): Int = key match {
    case vertexId: JMap[JString, AnyRef]@unchecked => {
      val label: String = vertexId.asScala.getOrElse(
        LabelAccessor,
        throw new IllegalArgumentException(s"Couldn't find $LabelAccessor in key $key"))
          .asInstanceOf[String]

      val tokenGenerator = labelToTokenGenerator(label)
      val token = tokenFactory.tokenFromString(tokenGenerator.getStringTokenFor(vertexId))
      val bucketRange = labelToTokenRangeLookup(label)

      val originalPartitionIndex = bucketRange
        .rangesContaining(token)
        .headOption
        .getOrElse(throw new IllegalArgumentException(
          s"Couldn't find $token in for $label"))
        .partitionIndex

      val offSetOfRDD = labelToRddOffset(label)

      originalPartitionIndex + offSetOfRDD
    }
    case x => throw new IllegalArgumentException(s"Unable to partition object $x (${x.getClass})")
  }

  override val numPartitions: Int = rdds.map(_.partitions.length).sum
}
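
/* A minimal sketch of how the partitioner is consulted (reusing the
 * hypothetical `unioned` RDD from the construction sketch above; the vertex id
 * map must contain "~label" plus the partition key columns of the label's
 * backing table):
 *
 * {{{
 * import scala.jdk.CollectionConverters._
 *
 * val vertexId: JMap[JString, AnyRef] =
 *   Map[String, AnyRef]("~label" -> "person", "name" -> "alice").asJava
 * val unionPartition = unioned.partitioner.get.getPartition(vertexId)
 * }}}
 */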

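/**
  * A RowWriter which reads column values directly from a Java Map keyed by column name. Used by
  * the TokenGenerators above to extract partition key values from a vertex id map.
  */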
class MapRowWriter(override val columnNames: Seq[String]) extends RowWriter[JMap[JString, AnyRef]] {
  override def readColumnValues(data: JMap[JString, AnyRef], buffer: Array[Any]): Unit =
    columnNames.zipWithIndex.foreach { case (columnName, index) =>
      buffer(index) = data.asScala.getOrElse(columnName,
        throw new IllegalArgumentException(s"""Couldn't find $columnName in $data, unable to generate token"""))
    }
}
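
/* A minimal sketch of MapRowWriter in isolation (column names and values are
 * hypothetical):
 *
 * {{{
 * import scala.jdk.CollectionConverters._
 *
 * val writer = new MapRowWriter(Seq("name"))
 * val buffer = new Array[Any](1)
 * writer.readColumnValues(Map[String, AnyRef]("name" -> "alice").asJava, buffer)
 * // buffer(0) == "alice"; a missing column throws IllegalArgumentException
 * }}}
 */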



