// All Downloads are FREE. The search and download functionality uses the official Maven repository.

// com.lucidworks.spark.rdd.SelectSolrRDD.scala Maven / Gradle / Ivy

package com.lucidworks.spark.rdd

import com.lucidworks.spark.query.StreamingResultsIterator
import com.lucidworks.spark.util.QueryConstants._
import com.lucidworks.spark.util.{SolrQuerySupport, SolrSupport}
import com.lucidworks.spark._
import org.apache.solr.client.solrj.SolrQuery
import org.apache.solr.common.SolrDocument
import org.apache.solr.common.params.ShardParams
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.{Partition, SparkContext, TaskContext}

import scala.collection.JavaConverters

/**
 * An RDD of [[SolrDocument]]s produced by running a Solr query against a collection.
 *
 * Partitioning (see `getPartitions`):
 *  - when `maxRows` is set, a single `SolrLimitPartition` is used for the whole query;
 *  - otherwise partitions come from `SolrPartitioner`, either one set per shard or
 *    several intra-shard splits per shard.
 *
 * The builder-style methods (`query`, `select`, `rows`, ...) never modify this
 * instance; each returns a new `SelectSolrRDD` created through `copy`.
 *
 * @param zkHost         handed to `SolrSupport` to build the shard list and obtain Solr clients
 *                       (presumably the ZooKeeper connect string -- confirm)
 * @param collection     collection to read; also set as the `collection` parameter of built queries
 * @param sc             Spark context, forwarded to `SolrRDD`
 * @param requestHandler request handler set on every partition query; `DEFAULT_REQUEST_HANDLER` when unset
 * @param query          query string parsed by `SolrQuerySupport.toQuery` when no `solrQuery` is given
 * @param fields         field list used by `buildQuery`
 * @param rows           row count used by `buildQuery`
 * @param splitField     field used for intra-shard splits; `DEFAULT_SPLIT_FIELD` when unset
 * @param splitsPerShard explicit number of splits per shard; computed by `calculateSplitsPerShard` when unset
 * @param solrQuery      pre-built query; when present it is used instead of `buildQuery`
 * @param uKey           forwarded to `SolrRDD` as `uKey`
 * @param maxRows        when set, forces a single partition limited to this many rows
 * @param accumulator    when set, attached to the results iterator in `compute`
 */
class SelectSolrRDD(
    zkHost: String,
    collection: String,
    @transient private val sc: SparkContext,
    requestHandler: Option[String] = None,
    query : Option[String] = Option(DEFAULT_QUERY),
    fields: Option[Array[String]] = None,
    rows: Option[Int] = Option(DEFAULT_PAGE_SIZE),
    splitField: Option[String] = None,
    splitsPerShard: Option[Int] = None,
    solrQuery: Option[SolrQuery] = None,
    uKey: Option[String] = None,
    val maxRows: Option[Int] = None,
    val accumulator: Option[SparkSolrAccumulator] = None)
  extends SolrRDD[SolrDocument](zkHost, collection, sc, uKey = uKey)
  with LazyLogging {

  /**
   * Creates a new `SelectSolrRDD` that differs from this one only in the
   * arguments supplied. `zkHost`, `collection`, the Spark context and the
   * accumulator are always carried over unchanged.
   */
  protected def copy(
    requestHandler: Option[String] = requestHandler,
    query: Option[String] = query,
    fields: Option[Array[String]] = fields,
    rows: Option[Int] = rows,
    splitField: Option[String] = splitField,
    splitsPerShard: Option[Int] = splitsPerShard,
    solrQuery: Option[SolrQuery] = solrQuery,
    uKey: Option[String] = uKey,
    maxRows: Option[Int] = maxRows): SelectSolrRDD = {
      new SelectSolrRDD(zkHost, collection, sc, requestHandler, query, fields, rows, splitField, splitsPerShard, solrQuery, uKey, maxRows, accumulator)
  }

  /**
   * Streams the documents belonging to one partition.
   *
   * @param split   a `SelectSolrRDDPartition` (read from a single replica URL using the
   *                partition's cursor mark) or a `SolrLimitPartition` (read through the
   *                cloud client, capped at the partition's `maxRows`)
   * @param context task context; its attempt number is passed to `getReplicaToQuery`
   *                and a completion listener is registered to log the fetched row count
   * @return a Scala iterator over the streamed documents
   * @throws Exception if `split` is of any other partition type
   */
  @DeveloperApi
  override def compute(split: Partition, context: TaskContext): Iterator[SolrDocument] = {
    val handler = requestHandler.getOrElse(DEFAULT_REQUEST_HANDLER)

    val iterator: StreamingResultsIterator = split match {
      case partition: SelectSolrRDDPartition =>
        val url = getReplicaToQuery(partition, context.attemptNumber())
        val query = partition.query
        logger.debug(s"Using the shard url ${url} for getting partition data for split: ${split.index}")
        query.setRequestHandler(handler)
        logger.debug(s"Using cursorMarks to fetch documents from ${partition.preferredReplica} for query: ${partition.query}")
        val resultsIterator = new StreamingResultsIterator(SolrSupport.getCachedHttpSolrClient(url, zkHost), partition.query, partition.cursorMark)
        context.addTaskCompletionListener[Unit] { _ =>
          logger.info(f"Fetched ${resultsIterator.getNumDocs} rows from shard $url for partition ${split.index}")
        }
        resultsIterator

      case p: SolrLimitPartition =>
        // this is a single partition for the entire query ... we'll read all rows at once
        val query = p.query
        query.setRequestHandler(handler)
        query.setRows(p.maxRows)
        query.set("distrib.singlePass", "true")
        query.setStart(null) // important! must start as null else the Iterator will advance the start position by the row size
        val resultsIterator = new StreamingResultsIterator(SolrSupport.getCachedCloudClient(p.zkhost), query)
        resultsIterator.setMaxSampleDocs(p.maxRows)
        context.addTaskCompletionListener[Unit] { _ =>
          logger.info(f"Fetched ${resultsIterator.getNumDocs} rows from the limit (${p.maxRows}) partition of ${p.collection}")
        }
        resultsIterator

      case partition: AnyRef =>
        throw new Exception(s"Unknown partition type '${partition.getClass}'")
    }

    accumulator.foreach(acc => iterator.setAccumulator(acc))
    JavaConverters.asScalaIteratorConverter(iterator.iterator()).asScala
  }

  /**
   * Computes the partitions of this RDD.
   *
   * The query is the user-supplied `solrQuery` if present, otherwise the result
   * of `buildQuery`. A user-supplied query is copied first: the steps below add
   * defaults and a `_version_`-style range filter to the query, and the same
   * `SolrQuery` instance is shared by every RDD derived through `copy`, so
   * mutating it in place would leak those changes to the caller and pile up
   * filters across RDDs.
   *
   * @return a single `SolrLimitPartition` when `maxRows` is set, otherwise the
   *         shard or split partitions produced by `SolrPartitioner`
   */
  override def getPartitions: Array[Partition] = {
    val query: SolrQuery = solrQuery.map(_.getCopy).getOrElse(buildQuery)
    val rq = requestHandler.getOrElse(DEFAULT_REQUEST_HANDLER)

    maxRows match {
      // if the user requested a max # of rows, use a single partition
      case Some(limit) =>
        logger.debug(s"Using a single limit partition for a maxRows=${maxRows} query.")
        Array[Partition](SolrLimitPartition(0, zkHost, collection, limit, query))

      case None =>
        // An absent shards.tolerant parameter means "not tolerant".
        val shardsTolerant: Boolean =
          Option(query.get(ShardParams.SHARDS_TOLERANT)).exists(_.toBoolean)
        val shards = SolrSupport.buildShardList(zkHost, collection, shardsTolerant)
        // NOTE(review): `head` throws NoSuchElementException if the shard list is empty.
        val numReplicas = shards.head.replicas.length
        val numSplits = splitsPerShard.getOrElse(calculateSplitsPerShard(query, shards.size, numReplicas))

        logger.debug(s"rq = $rq, setting query defaults for query = $query uniqueKey = $uniqueKey")
        SolrQuerySupport.setQueryDefaultsForShards(query, uniqueKey)
        // Freeze the index by adding a filter query on _version_ field
        SolrQuerySupport
          .getMaxVersion(SolrSupport.getCachedCloudClient(zkHost), collection, query, DEFAULT_SPLIT_FIELD)
          .foreach { maxVersion =>
            val rangeFilter = DEFAULT_SPLIT_FIELD + ":[* TO " + maxVersion + "]"
            logger.debug("Range filter added to the query: " + rangeFilter)
            query.addFilterQuery(rangeFilter)
          }

        logger.debug(s"Using splitField=$splitField, splitsPerShard=$splitsPerShard, and numReplicas=$numReplicas for computing partitions.")

        logger.info(s"Updated Solr query: ${query.toString}")
        val partitions: Array[Partition] = if (numSplits > 1) {
          val splitFieldName = splitField.getOrElse(DEFAULT_SPLIT_FIELD)
          logger.debug(s"Applied $numSplits intra-shard splits on the $splitFieldName field for $collection to better utilize all active replicas. Set the 'split_field' option to override this behavior or set the 'splits_per_shard' option = 1 to disable splits per shard.")
          SolrPartitioner.getSplitPartitions(shards, query, splitFieldName, numSplits)
        } else {
          // no explicit split field and only one replica || splits_per_shard was explicitly set to 1, no intra-shard splitting needed
          SolrPartitioner.getShardPartitions(shards, query)
        }

        // Only render every partition when trace logging is on; otherwise log just the count.
        if (logger.underlying.isTraceEnabled()) {
          logger.trace(s"Found ${partitions.length} partitions: ${partitions.mkString(",")}")
        } else {
          logger.info(s"Found ${partitions.length} partitions")
        }
        partitions
    }
  }

  /** Returns a copy of this RDD that uses the query string `q`. */
  override def query(q: String): SelectSolrRDD = copy(query = Some(q))

  /** Returns a copy of this RDD that uses the pre-built `solrQuery` instead of `buildQuery`. */
  override def query(solrQuery: SolrQuery): SelectSolrRDD = copy(solrQuery = Some(solrQuery))

  /** Returns a copy of this RDD whose field list is `fl` split on commas. */
  override def select(fl: String): SelectSolrRDD = copy(fields = Some(fl.split(",")))

  /** Returns a copy of this RDD whose field list is `fl`. */
  override def select(fl: Array[String]): SelectSolrRDD = copy(fields = Some(fl))

  /** Returns a copy of this RDD with the `rows` setting replaced. */
  override def rows(rows: Int): SelectSolrRDD = copy(rows = Some(rows))

  /** Returns a copy of this RDD that splits on `DEFAULT_SPLIT_FIELD`. */
  override def doSplits(): SelectSolrRDD = copy(splitField = Some(DEFAULT_SPLIT_FIELD))

  /** Returns a copy of this RDD that splits on `field`. */
  override def splitField(field: String): SelectSolrRDD = copy(splitField = Some(field))

  /** Returns a copy of this RDD with an explicit number of splits per shard. */
  override def splitsPerShard(splitsPerShard: Int): SelectSolrRDD = copy(splitsPerShard = Some(splitsPerShard))

  /** Returns a copy of this RDD that uses the given request handler. */
  override def requestHandler(requestHandler: String): SelectSolrRDD = copy(requestHandler = Some(requestHandler))

  /**
   * Builds the `SolrQuery` used when no pre-built query was supplied.
   *
   * The query string (falling back to `DEFAULT_QUERY` if `query` was explicitly
   * passed as `None`) is parsed with `SolrQuerySupport.toQuery`, then:
   *  - if the parsed query already carries a non-empty field list, that list is
   *    replaced by `fields` (an empty list when `fields` is unset);
   *  - if the parsed query already carries a row count and `rows` is set, the
   *    row count is replaced by `rows`;
   *  - the `collection` parameter is set to this RDD's collection.
   *
   * NOTE(review): `fields` is ignored when the parsed query has no field list of
   * its own, which looks inverted -- confirm this is intended before changing it.
   *
   * @return the query to partition and execute
   */
  override def buildQuery: SolrQuery = {
    val parsed: SolrQuery = SolrQuerySupport.toQuery(query.getOrElse(DEFAULT_QUERY))

    val withFields: SolrQuery =
      if (parsed.getFields != null && parsed.getFields.nonEmpty)
        parsed.setFields(fields.getOrElse(Array.empty[String]): _*)
      else
        parsed

    val withRows: SolrQuery = rows match {
      case Some(pageSize) if withFields.getRows != null => withFields.setRows(pageSize)
      case _ => withFields
    }

    withRows.set("collection", collection)
    withRows
  }

}

/** Companion factory for [[SelectSolrRDD]]. */
object SelectSolrRDD {

  /**
   * Creates a [[SelectSolrRDD]] over `collection`, leaving every optional
   * constructor argument at its default value.
   */
  def apply(zkHost: String, collection: String, sparkContext: SparkContext): SelectSolrRDD =
    new SelectSolrRDD(zkHost = zkHost, collection = collection, sc = sparkContext)
}




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy