org.apache.spark.sql.execution.UserProvidedPlanner.scala Maven / Gradle / Ivy
The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.spark.sql.execution
import org.apache.spark.internal.Logging
import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.logical.{JoinTopK, LogicalPlan}
import org.apache.spark.sql.internal.SQLConf
/**
 * Extractor that recognises a [[JoinTopK]] node and splits its join condition into
 * equi-join keys plus a residual (non-key) condition.
 */
private object ExtractJoinTopKKeys extends Logging with PredicateHelper {

  /**
   * (k, scoreExpr, rankAttr, joinType, leftKeys, rightKeys, condition, leftChild, rightChild)
   *
   * `condition` is the `And` of every conjunct that was not consumed as a plain `EqualTo`
   * join key, or `None` when no such conjunct remains.
   */
  type ReturnType =
    (Int, NamedExpression, Seq[Attribute], JoinType, Seq[Expression], Seq[Expression],
      Option[Expression], LogicalPlan, LogicalPlan)

  /**
   * Matches a [[JoinTopK]] whose condition contains at least one (null-safe) equality
   * between an expression evaluable on the left child and one evaluable on the right child.
   *
   * Keys are oriented so that `leftKeys(i)` is evaluable on the left child and
   * `rightKeys(i)` on the right child, whichever side of the equality they appeared on.
   *
   * @param plan the logical plan to inspect
   * @return the extracted components (see [[ReturnType]]), or `None` when `plan` is not a
   *         `JoinTopK` or its condition yields no equi-join key
   */
  def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
    case join @ JoinTopK(k, left, right, joinType, condition) =>
      logDebug(s"Considering join on: $condition")
      val predicates = condition.map(splitConjunctivePredicates).getOrElse(Nil)
      val joinKeys = predicates.flatMap {
        case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => Some((l, r))
        case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => Some((r, l))
        // Replace null with the data type's default value for the joining key, so rows
        // with null in it can be joined together.
        case EqualNullSafe(l, r) if canEvaluate(l, left) && canEvaluate(r, right) =>
          Some((Coalesce(Seq(l, Literal.default(l.dataType))),
            Coalesce(Seq(r, Literal.default(r.dataType)))))
        case EqualNullSafe(l, r) if canEvaluate(l, right) && canEvaluate(r, left) =>
          Some((Coalesce(Seq(r, Literal.default(r.dataType))),
            Coalesce(Seq(l, Literal.default(l.dataType)))))
        case _ => None
      }
      // Only plain EqualTo keys are removed from the residual condition. Null-safe keys
      // are intentionally kept: their Coalesce-ed form also equates a null with a value
      // equal to the type's default, so the original `<=>` must still be re-checked.
      val otherPredicates = predicates.filterNot {
        case EqualTo(l, r) =>
          canEvaluate(l, left) && canEvaluate(r, right) ||
            canEvaluate(l, right) && canEvaluate(r, left)
        case _ => false
      }
      if (joinKeys.nonEmpty) {
        val (leftKeys, rightKeys) = joinKeys.unzip
        logDebug(s"leftKeys:$leftKeys | rightKeys:$rightKeys")
        Some((k, join.scoreExpr, join.rankAttr, joinType, leftKeys, rightKeys,
          otherPredicates.reduceOption(And), left, right))
      } else {
        // Without at least one equi-join key this extractor does not apply.
        None
      }
    case _ =>
      None
  }
}
/**
 * Physical planning strategy for top-k joins.
 *
 * @param conf SQL configuration; not consulted by the planning logic below.
 */
private[sql] class UserProvidedPlanner(val conf: SQLConf) extends Strategy {

  /**
   * Plans a `JoinTopK` that has at least one equi-join key as a single
   * `ShuffledHashJoinTopKExec` over the (later-planned) children.
   *
   * @param plan the logical plan to plan
   * @return a one-element sequence with the top-k join operator, or `Nil` when
   *         [[ExtractJoinTopKKeys]] does not match `plan`
   */
  override def apply(plan: LogicalPlan): Seq[SparkPlan] =
    ExtractJoinTopKKeys.unapply(plan).toSeq.map {
      case (k, scoreExpr, rankAttr, _, leftKeys, rightKeys, condition, left, right) =>
        val leftPlan = planLater(left)
        val rightPlan = planLater(right)
        joins.ShuffledHashJoinTopKExec(
          k, leftKeys, rightKeys, condition, leftPlan, rightPlan)(scoreExpr, rankAttr)
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy