com.lucidworks.spark.rdd.StreamingSolrRDD.scala

package com.lucidworks.spark.rdd

import com.lucidworks.spark.query.{SolrStreamIterator, StreamingExpressionResultIterator, TupleStreamIterator}
import com.lucidworks.spark.util.QueryConstants._
import com.lucidworks.spark.util.{SolrQuerySupport, SolrSupport}
import com.lucidworks.spark._
import org.apache.solr.client.solrj.SolrQuery
import org.apache.solr.common.params.ShardParams
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.{Partition, SparkContext, TaskContext}

import scala.collection.JavaConverters

class StreamingSolrRDD(
    zkHost: String,
    collection: String,
    @transient private val sc: SparkContext,
    requestHandler: Option[String] = None,
    query: Option[String] = Option(DEFAULT_QUERY),
    fields: Option[Array[String]] = None,
    rows: Option[Int] = Option(DEFAULT_PAGE_SIZE),
    splitField: Option[String] = None,
    splitsPerShard: Option[Int] = Some(1),
    solrQuery: Option[SolrQuery] = None,
    uKey: Option[String] = None,
    val accumulator: Option[SparkSolrAccumulator] = None)
  extends SolrRDD[java.util.Map[_, _]](zkHost, collection, sc, uKey = uKey)
  with LazyLogging {

  protected def copy(
    requestHandler: Option[String] = requestHandler,
    query: Option[String] = query,
    fields: Option[Array[String]] = fields,
    rows: Option[Int] = rows,
    splitField: Option[String] = splitField,
    splitsPerShard: Option[Int] = splitsPerShard,
    solrQuery: Option[SolrQuery] = solrQuery): StreamingSolrRDD = {
    new StreamingSolrRDD(zkHost, collection, sc, requestHandler, query, fields, rows, splitField, splitsPerShard, solrQuery, uKey, accumulator)
  }
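
  // The fluent setters further down (query, select, rows, splitField, ...) all
  // route through this copy(), so each call produces a new immutable RDD rather
  // than mutating this one. For example (values are illustrative):
  //
  //   val narrowed = rdd.query("title_s:spark").select("id,title_s").rows(500)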

  /**
   * Get an iterator that pulls documents from Solr's /export handler on a single shard.
   */
  @throws(classOf[Exception])
  private def getExportHandlerBasedIterator(shardUrl: String, query: SolrQuery, numWorkers: Int, workerId: Int): SolrStreamIterator = {

    // Each query is sent directly to a single shard, so disable distributed search
    query.set("distrib", false)

    val sorts = query.getSorts
    val sortParam = query.get("sort")
    if ((sorts == null || sorts.isEmpty) && (sortParam == null || sortParam.isEmpty)) {
      val fields = query.getFields
      if (fields != null) {
        // getFields returns a comma-separated string; check for an exact "id" field
        // rather than a substring match (which would also match e.g. "identifier")
        val fieldList = fields.split(",")
        if (fieldList.contains("id")) {
          query.addSort("id", SolrQuery.ORDER.asc)
        } else {
          query.addSort(fieldList(0), SolrQuery.ORDER.asc)
        }
      } else {
        query.addSort("id", SolrQuery.ORDER.asc)
      }
      logger.warn(s"Added required sort clause: "+query.getSorts+
        "; this is probably incorrect so you should provide your own sort criteria.")
    }

    new SolrStreamIterator(shardUrl, SolrSupport.getCachedCloudClient(zkHost), SolrSupport.getCachedHttpSolrClient(shardUrl, zkHost), query, numWorkers, workerId)
  }
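
  // Hedged sketch (not in the original source): Solr's /export handler requires
  // an explicit sort on docValues-enabled fields, so callers can avoid the
  // auto-added sort above by setting one up front. Field names here are
  // illustrative:
  //
  //   val q = new SolrQuery("*:*")
  //   q.setFields("id", "timestamp_tdt")
  //   q.addSort("timestamp_tdt", SolrQuery.ORDER.asc) // satisfies the /export sort requirement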


  @DeveloperApi
  override def compute(split: Partition, context: TaskContext): Iterator[java.util.Map[_, _]] = {
    logger.debug(s"Computing split: ${split.index}")
    val iterator: TupleStreamIterator = split match {
      case partition: CloudStreamPartition =>
        logger.info(s"Using StreamingExpressionResultIterator to process streaming expression for $partition")
        val resultsIterator = new StreamingExpressionResultIterator(
          SolrSupport.getCachedCloudClient(zkHost),
          SolrSupport.getCachedHttpSolrClient(SolrSupport.getSolrBaseUrl(zkHost) + partition.collection, zkHost), // the baseUrl here is just a placeholder; it is later replaced with a valid host name in {@code SparkSolrClientCache#getHttpSolrClient}
          partition.collection,
          partition.params)
        resultsIterator
      case partition: ExportHandlerPartition =>

        val url = getReplicaToQuery(partition, context.attemptNumber())
        val query = partition.query
        logger.debug(s"Using the shard url ${url} for getting partition data for split: ${split.index}")
        val solrRequestHandler = requestHandler.getOrElse(DEFAULT_REQUEST_HANDLER)
        query.setRequestHandler(solrRequestHandler)
        logger.debug(s"Using export handler to fetch documents from ${partition.preferredReplica} for query: ${partition.query}")
        val resultsIterator = getExportHandlerBasedIterator(url, query, partition.numWorkers, partition.workerId)
        context.addTaskCompletionListener[Unit] { _ =>
          logger.info(f"Fetched ${resultsIterator.getNumDocs} rows from shard $url for partition ${split.index}")
        }
        resultsIterator
      case partition: AnyRef => throw new Exception(s"Unknown partition type '${partition.getClass}'")
    }
    accumulator.foreach(iterator.setAccumulator)
    JavaConverters.asScalaIteratorConverter(iterator.iterator()).asScala
  }

  override def getPartitions: Array[Partition] = {
    val query = solrQuery.getOrElse(buildQuery)
    val rq = requestHandler.getOrElse(DEFAULT_REQUEST_HANDLER)
    if (rq == QT_STREAM || rq == QT_SQL) {
      logger.info(s"Using SolrCloud stream partitioning scheme to process request to $rq for collection $collection using query: $query")
      return Array(CloudStreamPartition(0, zkHost, collection, query))
    }
    logger.info(s"Updated Solr query: ${query.toString}")

    val shardsTolerant: Boolean =
      Option(query.get(ShardParams.SHARDS_TOLERANT)).exists(_.toBoolean)

    val shards = SolrSupport.buildShardList(zkHost, collection, shardsTolerant)
    val numReplicas = shards.head.replicas.length
    val numSplits = splitsPerShard.getOrElse(calculateSplitsPerShard(query, shards.size, numReplicas, 100000))
    logger.debug(s"Using splitField=$splitField, splitsPerShard=$splitsPerShard, and numReplicas=$numReplicas for computing partitions.")

    val partitions : Array[Partition] = if (numSplits > 1) {
      val splitFieldName = splitField.getOrElse(DEFAULT_SPLIT_FIELD)
      logger.debug(s"Applied $numSplits intra-shard splits on the $splitFieldName field for $collection to better utilize all active replicas. Set the 'split_field' option to override this behavior or set the 'splits_per_shard' option = 1 to disable splits per shard.")
      query.set("partitionKeys", splitFieldName)
      // Workaround for SOLR-10490. TODO: Replace with SolrPartitioner#getSplitPartitions once SOLR-10490 is resolved
      SolrPartitioner.getExportHandlerPartitions(shards, query, splitFieldName, numSplits)
    } else {
      // Either there is no explicit split field and only one replica, or splits_per_shard was explicitly set to 1; no intra-shard splitting is needed
      SolrPartitioner.getExportHandlerPartitions(shards, query)
    }

    if (logger.underlying.isTraceEnabled()) {
      logger.trace(s"Found ${partitions.length} partitions: ${partitions.mkString(",")}")
    } else {
      logger.info(s"Found ${partitions.length} partitions.")
    }
    partitions
  }
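
  // Illustrative sketch (assumption; values are hypothetical): partitioning can
  // be steered through the fluent setters below. With 2 shards and 4 splits per
  // shard on the "_version_" field, this yields 8 ExportHandlerPartition
  // instances, one per (shard, split) pair:
  //
  //   val rdd = StreamingSolrRDD("localhost:9983", "collection1", sc)
  //     .splitField("_version_")
  //     .splitsPerShard(4)
  //   // rdd.partitions.length == 2 * 4, assuming the split field is splittable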

  override def query(q: String): StreamingSolrRDD = copy(query = Some(q))

  override def query(solrQuery: SolrQuery): StreamingSolrRDD = copy(solrQuery = Some(solrQuery))

  override def select(fl: String): StreamingSolrRDD = copy(fields = Some(fl.split(",")))

  override def select(fl: Array[String]): StreamingSolrRDD = copy(fields = Some(fl))

  override def rows(rows: Int): StreamingSolrRDD = copy(rows = Some(rows))

  override def doSplits(): StreamingSolrRDD = copy(splitField = Some(DEFAULT_SPLIT_FIELD))

  override def splitField(field: String): StreamingSolrRDD = copy(splitField = Some(field))

  override def splitsPerShard(splitsPerShard: Int): StreamingSolrRDD = copy(splitsPerShard = Some(splitsPerShard))

  override def requestHandler(requestHandler: String): StreamingSolrRDD = copy(requestHandler = Some(requestHandler))

  override def buildQuery: SolrQuery = {
    var solrQuery: SolrQuery = SolrQuerySupport.toQuery(query.get)
    // Only override the field list when the parsed query already defined one
    if (solrQuery.getFields != null && solrQuery.getFields.nonEmpty) {
      solrQuery = solrQuery.setFields(fields.getOrElse(Array.empty[String]): _*)
    }
    // Only override the row count when the parsed query already set one
    if (solrQuery.getRows != null) {
      solrQuery = solrQuery.setRows(rows.get)
    }

    solrQuery.set("collection", collection)
    solrQuery
  }
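
  // Worked example (hedged; the exact defaults come from QueryConstants and
  // SolrQuerySupport.toQuery, which live outside this file): with the default
  // constructor arguments, buildQuery parses DEFAULT_QUERY, applies the
  // fields/rows overrides above only when the parsed query already set them,
  // and pins the request to this RDD's collection, roughly:
  //
  //   val q = new StreamingSolrRDD("localhost:9983", "books", sc).buildQuery
  //   // q.get("collection") == "books"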

}

object StreamingSolrRDD {
  def apply(zkHost: String, collection: String, sparkContext: SparkContext): StreamingSolrRDD = {
    new StreamingSolrRDD(zkHost, collection, sparkContext)
  }
}
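
// ---------------------------------------------------------------------------
// Usage sketch (not part of the original source; the object name, ZooKeeper
// address, and collection are hypothetical): builds a StreamingSolrRDD against
// a local SolrCloud, narrows it with the fluent API, and counts the results.
object StreamingSolrRDDUsageExample {
  import org.apache.spark.{SparkConf, SparkContext}

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("streaming-solr-rdd-example").setMaster("local[*]")
    val sc = new SparkContext(conf)
    try {
      val rdd = StreamingSolrRDD("localhost:9983", "collection1", sc)
        .query("*:*")        // each fluent call returns a new immutable RDD via copy()
        .select("id")        // /export requires docValues-enabled fields
        .splitsPerShard(1)   // disable intra-shard splitting for this run
      println(s"Fetched ${rdd.count()} documents from collection1")
    } finally {
      sc.stop()
    }
  }
}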



