org.apache.spark.sql.redis.RedisSourceRelation.scala
package org.apache.spark.sql.redis
import java.util.UUID
import java.util.{List => JList}
import com.redislabs.provider.redis.rdd.Keys
import com.redislabs.provider.redis.util.ConnectionUtils.withConnection
import com.redislabs.provider.redis.util.Logging
import com.redislabs.provider.redis.util.PipelineUtils._
import com.redislabs.provider.redis.{ReadWriteConfig, RedisConfig, RedisDataTypeHash, RedisDataTypeString, RedisEndpoint, RedisNode, toRedisContext}
import org.apache.commons.lang3.SerializationUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.redis.RedisSourceRelation._
import org.apache.spark.sql.sources.{BaseRelation, Filter, InsertableRelation, PrunedFilteredScan}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import redis.clients.jedis.{PipelineBase, Protocol}
import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
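/**
  * Spark SQL relation backing the Redis data source. It writes dataframe rows to Redis
  * ([[InsertableRelation]]) under "<table>:<id>" keys using the configured persistence model,
  * persists the schema separately so the dataframe can be loaded back, and reads rows matching
  * the 'table' or 'keys.pattern' option with column pruning ([[PrunedFilteredScan]]).
  */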
class RedisSourceRelation(override val sqlContext: SQLContext,
                          parameters: Map[String, String],
                          userSpecifiedSchema: Option[StructType])
  extends BaseRelation
    with InsertableRelation
    with PrunedFilteredScan
    with Keys
    with Serializable
    with Logging {
  private implicit val redisConfig: RedisConfig = RedisConfig.fromSparkConfAndParameters(
    sqlContext.sparkContext.getConf, parameters)

  implicit private val readWriteConfig: ReadWriteConfig = {
    val global = ReadWriteConfig.fromSparkConf(sqlContext.sparkContext.getConf)
    // override global config with dataframe specific settings
    global.copy(
      scanCount = parameters.get(SqlOptionScanCount).map(_.toInt).getOrElse(global.scanCount),
      maxPipelineSize = parameters.get(SqlOptionMaxPipelineSize).map(_.toInt).getOrElse(global.maxPipelineSize)
    )
  }

  logInfo(s"Redis config initial host: ${redisConfig.initialHost}")

  @transient private val sc = sqlContext.sparkContext

  /**
    * Will be filled while saving data to Redis or reading from Redis.
    */
  @volatile private var currentSchema: StructType = _
  /** parameters (sorted alphabetically) */
  private val filterKeysByTypeEnabled = parameters.get(SqlOptionFilterKeysByType).exists(_.toBoolean)
  private val inferSchemaEnabled = parameters.get(SqlOptionInferSchema).exists(_.toBoolean)
  private val iteratorGroupingSize = parameters.get(SqlOptionIteratorGroupingSize).map(_.toInt)
    .getOrElse(SqlOptionIteratorGroupingSizeDefault)
  private val keyColumn = parameters.get(SqlOptionKeyColumn)
  private val keyName = keyColumn.getOrElse("_id")
  private val keysPatternOpt: Option[String] = parameters.get(SqlOptionKeysPattern)
  private val numPartitions = parameters.get(SqlOptionNumPartitions).map(_.toInt)
    .getOrElse(SqlOptionNumPartitionsDefault)
  private val persistenceModel = parameters.getOrDefault(SqlOptionModel, SqlOptionModelHash)
  private val persistence = RedisPersistence(persistenceModel)
  private val tableNameOpt: Option[String] = parameters.get(SqlOptionTableName)
  private val ttl = parameters.get(SqlOptionTTL).map(_.toInt).getOrElse(0)
  /**
    * Redis key pattern for rows, based on either the 'keys.pattern' or the 'table' parameter.
    */
  private val dataKeyPattern = keysPatternOpt
    .orElse(tableNameOpt.map(tableName => tableDataKeyPattern(tableName)))
    .getOrElse {
      val msg = s"Neither '$SqlOptionKeysPattern' nor '$SqlOptionTableName' option is set."
      throw new IllegalArgumentException(msg)
    }
  /**
    * Supports key column extraction from a Redis key prefix pattern. Otherwise,
    * the Redis key is returned unmodified.
    */
  private val keysPrefixPattern =
    if (dataKeyPattern.endsWith("*") && dataKeyPattern.count(_ == '*') == 1) {
      dataKeyPattern
    } else {
      ""
    }
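  // For illustration: with the pattern "person:*" ("person" is just an example table name),
  // the prefix "person:" is stripped from each Redis key when the key column is read back.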
  // check specified parameters
  if (tableNameOpt.isDefined && keysPatternOpt.isDefined) {
    throw new IllegalArgumentException(s"Both options '$SqlOptionKeysPattern' and '$SqlOptionTableName' are set. " +
      s"Use only one of them.")
  }
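  /**
    * Resolves the dataframe schema: a user-specified schema takes precedence; otherwise the
    * schema is either inferred from a sampled row ('infer.schema') or loaded from Redis,
    * where it was stored by a previous write.
    */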
  override def schema: StructType = {
    if (currentSchema == null) {
      currentSchema = userSpecifiedSchema.getOrElse {
        if (inferSchemaEnabled) inferSchema() else loadSchema()
      }
    }
    currentSchema
  }
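  /**
    * Writes the dataframe to Redis. When 'overwrite' is true, all keys matching the data key
    * pattern are deleted first; rows are then written in pipelined batches, grouping keys by
    * Redis node and applying the configured TTL if any.
    */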
  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
    val schema = userSpecifiedSchema.getOrElse(data.schema)
    // write schema, so that we can load dataframe back
    currentSchema = saveSchema(schema)
    if (overwrite) {
      // truncate the table
      sc.fromRedisKeyPattern(dataKeyPattern).foreachPartition { partition =>
        groupKeysByNode(redisConfig.hosts, partition).foreach { case (node, keys) =>
          val conn = node.connect()
          foreachWithPipeline(conn, keys) { (pipeline, key) =>
            (pipeline: PipelineBase).del(key) // fix ambiguous reference to overloaded definition
          }
          conn.close()
        }
      }
    }

    // write data
    data.foreachPartition { partition: Iterator[Row] =>
      // grouped iterator to only allocate memory for a portion of rows
      partition.grouped(iteratorGroupingSize).foreach { batch =>
        // the following can be optimized to not create a map
        val rowsWithKey: Map[String, Row] = batch.map(row => dataKeyId(row) -> row).toMap
        groupKeysByNode(redisConfig.hosts, rowsWithKey.keysIterator).foreach { case (node, keys) =>
          val conn = node.connect()
          foreachWithPipeline(conn, keys) { (pipeline, key) =>
            val row = rowsWithKey(key)
            val encodedRow = persistence.encodeRow(keyName, row)
            persistence.save(pipeline, key, encodedRow, ttl)
          }
          conn.close()
        }
      }
    }
  }
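  /**
    * Reads rows back from Redis. When no columns are required (e.g. a plain count), only the
    * keys are scanned and empty rows are returned; otherwise values are fetched per Redis node
    * with pipelined reads and decoded against the pruned schema.
    */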
  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
    logInfo("build scan")
    val keysRdd = sc.fromRedisKeyPattern(dataKeyPattern, partitionNum = numPartitions)
    if (requiredColumns.isEmpty) {
      keysRdd.map { _ =>
        new GenericRow(Array[Any]())
      }
    } else {
      // filter schema columns; they should be in the same order as the given 'requiredColumns'
      val requiredSchema = {
        val fieldsMap = schema.fields.map(f => (f.name, f)).toMap
        val requiredFields = requiredColumns.map { c =>
          fieldsMap(c)
        }
        StructType(requiredFields)
      }
      val keyType =
        if (persistenceModel == SqlOptionModelBinary) {
          RedisDataTypeString
        } else {
          RedisDataTypeHash
        }
      keysRdd.mapPartitions { partition =>
        // grouped iterator to only allocate memory for a portion of rows
        partition.grouped(iteratorGroupingSize).flatMap { batch =>
          groupKeysByNode(redisConfig.hosts, batch.iterator)
            .flatMap { case (node, keys) =>
              scanRows(node, keys, keyType, requiredSchema, requiredColumns)
            }
        }
      }
    }
  }
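  // No filters are pushed down to Redis; Spark re-applies all of them after the scan.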
  override def unhandledFilters(filters: Array[Filter]): Array[Filter] = filters
  /**
    * @return true if no data exists in redis
    */
  def isEmpty: Boolean = {
    sc.fromRedisKeyPattern(dataKeyPattern, partitionNum = numPartitions).isEmpty()
  }

  /**
    * @return true if data exists in redis
    */
  def nonEmpty: Boolean = {
    !isEmpty
  }

  /**
    * @return table name
    */
  private def tableName(): String = {
    tableNameOpt.getOrElse(throw new IllegalArgumentException(s"Option '$SqlOptionTableName' is not set."))
  }

  /**
    * @return redis key for the row
    */
  private def dataKeyId(row: Row): String = {
    val id = keyColumn.map(id => row.getAs[Any](id)).map(_.toString).getOrElse(uuid())
    dataKey(tableName(), id)
  }
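  // For illustration: with table "person" and key.column "id" (example names), a row whose id
  // is 123 is stored under the Redis key "person:123"; without a key column a random UUID is used.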
  /**
    * infer schema from a random redis row
    */
  private def inferSchema(): StructType = {
    if (persistenceModel != SqlOptionModelHash) {
      throw new IllegalArgumentException(s"Cannot infer schema from model '$persistenceModel'. " +
        s"Currently, only '$SqlOptionModelHash' is supported")
    }
    val keys = sc.fromRedisKeyPattern(dataKeyPattern)
    if (keys.isEmpty()) {
      throw new IllegalStateException("No key is available")
    } else {
      val firstKey = keys.first()
      val node = getMasterNode(redisConfig.hosts, firstKey)
      withConnection(node.connect()) { conn =>
        val results = conn.hgetAll(firstKey).asScala.toSeq :+ keyName -> firstKey
        val fields = results.map(kv => StructField(kv._1, StringType)).toArray
        StructType(fields)
      }
    }
  }

  /**
    * write schema to redis
    */
  private def saveSchema(schema: StructType): StructType = {
    val key = schemaKey(tableName())
    logInfo(s"saving schema $key")
    val schemaNode = getMasterNode(redisConfig.hosts, key)
    val conn = schemaNode.connect()
    val schemaBytes = SerializationUtils.serialize(schema)
    conn.set(key.getBytes, schemaBytes)
    conn.close()
    schema
  }

  /**
    * read schema from redis
    */
  private def loadSchema(): StructType = {
    val key = schemaKey(tableName())
    logInfo(s"loading schema $key")
    val schemaNode = getMasterNode(redisConfig.hosts, key)
    val conn = schemaNode.connect()
    val schemaBytes = conn.get(key.getBytes)
    if (schemaBytes == null) {
      throw new IllegalStateException(s"Unable to read dataframe schema by key '$key'. " +
        s"If the dataframe was not persisted by Spark, provide a schema explicitly with .schema() " +
        s"or use the 'infer.schema' option.")
    }
    val schema = SerializationUtils.deserialize[StructType](schemaBytes)
    conn.close()
    schema
  }
  /**
    * read rows from redis
    */
  private def scanRows(node: RedisNode, keys: Seq[String], keyType: String, schema: StructType,
                       requiredColumns: Seq[String]): Seq[Row] = {
    withConnection(node.connect()) { conn =>
      val filteredKeys =
        if (filterKeysByTypeEnabled) {
          val keyTypes = mapWithPipeline(conn, keys) { (pipeline, key) =>
            pipeline.`type`(key)
          }
          keys.zip(keyTypes).filter(_._2 == keyType).map(_._1)
        } else {
          keys
        }
      val pipelineValues = mapWithPipeline(conn, filteredKeys) { (pipeline, key) =>
        persistence.load(pipeline, key, requiredColumns)
      }
      val keysAndValues = filteredKeys.zip(pipelineValues)
      // if a specific key (not a pattern) is provided and the value doesn't exist, filter it out
      val filteredKeysAndValues =
        if (Keys.isRedisRegex(dataKeyPattern)) {
          keysAndValues
        } else {
          keysAndValues.filter {
            case (_, null) => false // binary model
            case (_, value: JList[_]) if value.forall(_ == null) => false // hash model
            case _ => true
          }
        }
      filteredKeysAndValues.map { case (key, value) =>
        val keyMap = keyName -> tableKey(keysPrefixPattern, key)
        persistence.decodeRow(keyMap, value, schema, requiredColumns)
      }
    }
  }
}
object RedisSourceRelation {
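  /** Key under which the dataframe schema is persisted, e.g. "_spark:person:schema" ("person" is an example table name). */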
  def schemaKey(tableName: String): String = s"_spark:$tableName:schema"

  def dataKey(tableName: String, id: String = uuid()): String = {
    s"$tableName:$id"
  }

  def uuid(): String = UUID.randomUUID().toString.replace("-", "")

  def tableDataKeyPattern(tableName: String): String = s"$tableName:*"
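  /**
    * Strips the key prefix so that only the id is returned as the key column value,
    * e.g. tableKey("person:*", "person:John") returns "John" ("person" is an example table name).
    */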
  def tableKey(keysPrefixPattern: String, redisKey: String): String = {
    if (keysPrefixPattern.endsWith("*")) {
      // keysPattern*
      redisKey.substring(keysPrefixPattern.length - 1)
    } else {
      redisKey
    }
  }
}
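// A minimal usage sketch (illustrative, not part of this file). It assumes a SparkSession with
// "spark.redis.host"/"spark.redis.port" configured and a dataframe `df` that has an "id" column;
// the table name "person" is an example. The Sql* option constants used below come from this package.
//
//   df.write
//     .format("org.apache.spark.sql.redis")
//     .option(SqlOptionTableName, "person")
//     .option(SqlOptionKeyColumn, "id")
//     .save()
//
//   val loaded = spark.read
//     .format("org.apache.spark.sql.redis")
//     .option(SqlOptionTableName, "person")
//     .load()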