/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.adaptive

import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.internal.SQLConf

/**
 * A rule to optimize the shuffle reader to a local reader iff no additional shuffles
 * will be introduced:
 * 1. If the input plan is a shuffle, add the local reader directly, as we can never
 *    introduce extra shuffles in this case.
 * 2. Otherwise, add local readers to the probe side of broadcast hash joins, then run
 *    `EnsureRequirements` to check whether any additional shuffle is introduced. If so,
 *    revert all the local readers.
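 *
 * As an illustrative sketch of case 2 (assuming the build side is on the right), the
 * intended transformation is:
 * {{{
 *   BroadcastHashJoin(BuildRight)          BroadcastHashJoin(BuildRight)
 *   :- ShuffleQueryStage            =>     :- CustomShuffleReader(local)
 *   +- BroadcastQueryStage                 :  +- ShuffleQueryStage
 *                                          +- BroadcastQueryStage
 * }}}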
*/
case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] {
import OptimizeLocalShuffleReader._
private val ensureRequirements = EnsureRequirements(conf)

  // The build side is a broadcast query stage which should already have been optimized
  // with a local reader, so we only need to deal with the probe side here.
private def createProbeSideLocalReader(plan: SparkPlan): SparkPlan = {
val optimizedPlan = plan.transformDown {
case join @ BroadcastJoinWithShuffleLeft(shuffleStage, BuildRight) =>
val localReader = createLocalReader(shuffleStage)
join.asInstanceOf[BroadcastHashJoinExec].copy(left = localReader)
case join @ BroadcastJoinWithShuffleRight(shuffleStage, BuildLeft) =>
val localReader = createLocalReader(shuffleStage)
join.asInstanceOf[BroadcastHashJoinExec].copy(right = localReader)
}
val numShuffles = ensureRequirements.apply(optimizedPlan).collect {
case e: ShuffleExchangeExec => e
}.length
    // Check whether any additional shuffle is introduced. If so, revert the local readers.
if (numShuffles > 0) {
      logDebug("OptimizeLocalShuffleReader rule is not applied because " +
        "additional shuffles would be introduced.")
plan
} else {
optimizedPlan
}
}
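
  // Wrap a shuffle stage in a local reader. If the stage is already wrapped in a
  // `CustomShuffleReaderExec`, the number of its existing partition specs is reused as
  // the advisory parallelism so the local reader keeps a similar parallelism.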
private def createLocalReader(plan: SparkPlan): CustomShuffleReaderExec = {
plan match {
case c @ CustomShuffleReaderExec(s: ShuffleQueryStageExec, _, _) =>
CustomShuffleReaderExec(
s, getPartitionSpecs(s, Some(c.partitionSpecs.length)), LOCAL_SHUFFLE_READER_DESCRIPTION)
case s: ShuffleQueryStageExec =>
CustomShuffleReaderExec(s, getPartitionSpecs(s, None), LOCAL_SHUFFLE_READER_DESCRIPTION)
}
}
// TODO: this method assumes all shuffle blocks are the same data size. We should calculate the
// partition start indices based on block size to avoid data skew.
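  // Illustrative example (numbers chosen for this sketch, not taken from the source):
  // with numMappers = 2, numReducers = 10 and advisoryParallelism = Some(4), each mapper
  // output is split into expectedParallelism / numMappers = 2 ranges, splitPoints is
  // Seq(0, 5), and the result is PartialMapperPartitionSpec(m, 0, 5) and (m, 5, 10)
  // for each mapIndex m in 0 and 1, i.e. 4 local partitions in total.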
private def getPartitionSpecs(
shuffleStage: ShuffleQueryStageExec,
advisoryParallelism: Option[Int]): Seq[ShufflePartitionSpec] = {
val numMappers = shuffleStage.shuffle.numMappers
val numReducers = shuffleStage.shuffle.numPartitions
val expectedParallelism = advisoryParallelism.getOrElse(numReducers)
val splitPoints = if (numMappers == 0) {
Seq.empty
} else {
equallyDivide(numReducers, math.max(1, expectedParallelism / numMappers))
}
(0 until numMappers).flatMap { mapIndex =>
(splitPoints :+ numReducers).sliding(2).map {
case Seq(start, end) => PartialMapperPartitionSpec(mapIndex, start, end)
}
}
}
/**
 * To equally divide n elements into m buckets: each bucket gets n/m elements, and each
 * of the first n%m buckets gets one extra element. Returns a sequence of length
 * numBuckets, where each value is the start index of the corresponding bucket.
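 *
 * For example, equallyDivide(10, 3) returns Seq(0, 4, 7), i.e. buckets [0, 4), [4, 7)
 * and [7, 10): the first 10 % 3 = 1 bucket gets one extra element.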
*/
private def equallyDivide(numElements: Int, numBuckets: Int): Seq[Int] = {
val elementsPerBucket = numElements / numBuckets
val remaining = numElements % numBuckets
val splitPoint = (elementsPerBucket + 1) * remaining
(0 until remaining).map(_ * (elementsPerBucket + 1)) ++
(remaining until numBuckets).map(i => splitPoint + (i - remaining) * elementsPerBucket)
}
override def apply(plan: SparkPlan): SparkPlan = {
if (!conf.getConf(SQLConf.LOCAL_SHUFFLE_READER_ENABLED)) {
return plan
}
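    // Case 1 of the class doc: the whole plan is a readable shuffle (stage or reader),
    // so add the local reader directly; otherwise fall back to case 2 and only touch
    // the probe side of broadcast hash joins.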
plan match {
case s: SparkPlan if canUseLocalShuffleReader(s) =>
createLocalReader(s)
case s: SparkPlan =>
createProbeSideLocalReader(s)
}
}
}

object OptimizeLocalShuffleReader {
val LOCAL_SHUFFLE_READER_DESCRIPTION: String = "local"
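
  // Extractors matching a broadcast hash join whose left (resp. right) child is a
  // shuffle stage eligible for a local reader. They return that child together with the
  // join's build side, so callers can require that the matched side is the probe side.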
object BroadcastJoinWithShuffleLeft {
def unapply(plan: SparkPlan): Option[(SparkPlan, BuildSide)] = plan match {
case join: BroadcastHashJoinExec if canUseLocalShuffleReader(join.left) =>
Some((join.left, join.buildSide))
case _ => None
}
}
object BroadcastJoinWithShuffleRight {
def unapply(plan: SparkPlan): Option[(SparkPlan, BuildSide)] = plan match {
case join: BroadcastHashJoinExec if canUseLocalShuffleReader(join.right) =>
Some((join.right, join.buildSide))
case _ => None
}
}
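
  /**
   * A shuffle stage can be read locally only if its partition count is allowed to
   * change (`canChangeNumPartitions`) and its map output statistics are available.
   */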
def canUseLocalShuffleReader(plan: SparkPlan): Boolean = plan match {
case s: ShuffleQueryStageExec =>
s.shuffle.canChangeNumPartitions && s.mapStats.isDefined
case CustomShuffleReaderExec(s: ShuffleQueryStageExec, _, _) =>
s.shuffle.canChangeNumPartitions && s.mapStats.isDefined
case _ => false
}
}