/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.planning

import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._

trait OperationHelper {
  type ReturnType = (Seq[NamedExpression], Seq[Expression], LogicalPlan)

  protected def collectAliases(fields: Seq[Expression]): AttributeMap[Expression] =
    AttributeMap(fields.collect {
      case a: Alias => (a.toAttribute, a.child)
    })

  protected def substitute(aliases: AttributeMap[Expression])(expr: Expression): Expression = {
    // Use transformUp instead of transformDown to avoid an infinite loop when there is
    // an Alias whose exprId is the same as its child attribute's.
    expr.transformUp {
      case a @ Alias(ref: AttributeReference, name) =>
        aliases.get(ref)
          .map(Alias(_, name)(a.exprId, a.qualifier))
          .getOrElse(a)

      case a: AttributeReference =>
        aliases.get(a)
          .map(Alias(_, a.name)(a.exprId, a.qualifier)).getOrElse(a)
    }
  }
}
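
// A minimal sketch (hypothetical; not part of the original file) showing what the
// helpers above do: `collectAliases` maps `c1 -> key` for a field list containing
// `key AS c1`, and `substitute` then rewrites references to `c1` back into `key`.
private object OperationHelperUsageExample extends OperationHelper {
  // Inline any aliases defined in `fields` into `expr`.
  def inlineAliases(fields: Seq[NamedExpression], expr: Expression): Expression =
    substitute(collectAliases(fields))(expr)
}
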
/**
 * A pattern that matches any number of project or filter operations on top of another relational
 * operator. All filter operators are collected and their conditions are broken up and returned
 * together with the top project operator.
 * [[org.apache.spark.sql.catalyst.expressions.Alias Aliases]] are in-lined/substituted if
 * necessary.
 */
object PhysicalOperation extends OperationHelper with PredicateHelper {
  def unapply(plan: LogicalPlan): Option[ReturnType] = {
    val (fields, filters, child, _) = collectProjectsAndFilters(plan)
    Some((fields.getOrElse(child.output), filters, child))
  }

  /**
   * Collects all deterministic projects and filters, in-lining/substituting aliases if necessary.
   * Here are two examples for alias in-lining/substitution.
   * Before:
   * {{{
   *   SELECT c1 FROM (SELECT key AS c1 FROM t1) t2 WHERE c1 > 10
   *   SELECT c1 AS c2 FROM (SELECT key AS c1 FROM t1) t2 WHERE c1 > 10
   * }}}
   * After:
   * {{{
   *   SELECT key AS c1 FROM t1 WHERE key > 10
   *   SELECT key AS c2 FROM t1 WHERE key > 10
   * }}}
   */
  private def collectProjectsAndFilters(plan: LogicalPlan):
      (Option[Seq[NamedExpression]], Seq[Expression], LogicalPlan, AttributeMap[Expression]) =
    plan match {
      case Project(fields, child) if fields.forall(_.deterministic) =>
        val (_, filters, other, aliases) = collectProjectsAndFilters(child)
        val substitutedFields = fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
        (Some(substitutedFields), filters, other, collectAliases(substitutedFields))

      case Filter(condition, child) if condition.deterministic =>
        val (fields, filters, other, aliases) = collectProjectsAndFilters(child)
        val substitutedCondition = substitute(aliases)(condition)
        (fields, filters ++ splitConjunctivePredicates(substitutedCondition), other, aliases)

      case h: ResolvedHint =>
        collectProjectsAndFilters(h.child)

      case other =>
        (None, Nil, other, AttributeMap(Seq()))
    }
}
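
// A minimal sketch (hypothetical; not part of the original file) of how a planner
// strategy typically consumes PhysicalOperation: a single match peels an arbitrary
// stack of deterministic Project/Filter operators off the top of a plan, yielding
// the columns to keep, the split filter conditions, and the plan underneath.
private object PhysicalOperationUsageExample {
  def describe(plan: LogicalPlan): String = plan match {
    case PhysicalOperation(projects, filters, relation) =>
      s"scan of ${relation.nodeName}: keep ${projects.size} column(s), " +
        s"push down ${filters.size} predicate(s)"
  }
}
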
/**
 * A variant of [[PhysicalOperation]]. It matches any number of project or filter
 * operations even if they are non-deterministic, as long as they satisfy the
 * requirements of the CollapseProject and CombineFilters rules.
 */
object ScanOperation extends OperationHelper with PredicateHelper {
  type ScanReturnType = Option[(Option[Seq[NamedExpression]],
    Seq[Expression], LogicalPlan, AttributeMap[Expression])]

  def unapply(plan: LogicalPlan): Option[ReturnType] = {
    collectProjectsAndFilters(plan) match {
      case Some((fields, filters, child, _)) =>
        Some((fields.getOrElse(child.output), filters, child))
      case None => None
    }
  }

  private def hasCommonNonDeterministic(
      expr: Seq[Expression],
      aliases: AttributeMap[Expression]): Boolean = {
    expr.exists(_.collect {
      case a: AttributeReference if aliases.contains(a) => aliases(a)
    }.exists(!_.deterministic))
  }

  private def collectProjectsAndFilters(plan: LogicalPlan): ScanReturnType = {
    plan match {
      case Project(fields, child) =>
        collectProjectsAndFilters(child) match {
          case Some((_, filters, other, aliases)) =>
            // Follow CollapseProject and only keep going if the collected Projects
            // do not have common non-deterministic expressions.
            if (!hasCommonNonDeterministic(fields, aliases)) {
              val substitutedFields =
                fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
              Some((Some(substitutedFields), filters, other, collectAliases(substitutedFields)))
            } else {
              None
            }
          case None => None
        }

      case Filter(condition, child) =>
        collectProjectsAndFilters(child) match {
          case Some((fields, filters, other, aliases)) =>
            // Follow CombineFilters and only keep going if 1) the collected Filters
            // and this filter are all deterministic, or 2) this filter is the first
            // collected filter and has no non-deterministic expressions in common
            // with the Project below it.
            val substitutedCondition = substitute(aliases)(condition)
            val canCombineFilters = (filters.nonEmpty && filters.forall(_.deterministic) &&
              substitutedCondition.deterministic) || filters.isEmpty
            if (canCombineFilters && !hasCommonNonDeterministic(Seq(condition), aliases)) {
              Some((fields, filters ++ splitConjunctivePredicates(substitutedCondition),
                other, aliases))
            } else {
              None
            }
          case None => None
        }

      case h: ResolvedHint =>
        collectProjectsAndFilters(h.child)

      case other =>
        Some((None, Nil, other, AttributeMap(Seq())))
    }
  }
}
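
// A minimal sketch (hypothetical; not part of the original file) contrasting
// ScanOperation with PhysicalOperation: here unapply can return None, so callers
// must keep a fallback case for plans whose non-deterministic Projects/Filters
// cannot be safely collapsed.
private object ScanOperationUsageExample {
  def describe(plan: LogicalPlan): String = plan match {
    case ScanOperation(projects, filters, relation) =>
      s"collapsible scan of ${relation.nodeName} with ${filters.size} predicate(s)"
    case other =>
      s"cannot collapse ${other.nodeName} into a scan"
  }
}
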
/**
 * A pattern that finds joins with equality conditions that can be evaluated using equi-join.
 *
 * Null-safe equality (`<=>`) is rewritten into ordinary equality on the join keys by replacing
 * null with the key type's default value, plus an extra `isnull(l)`/`isnull(r)` key pair so
 * that a genuine null only matches another null (and not the default value itself).
 */
object ExtractEquiJoinKeys extends Logging with PredicateHelper {
  /** (joinType, leftKeys, rightKeys, condition, leftChild, rightChild, joinHint) */
  type ReturnType =
    (JoinType, Seq[Expression], Seq[Expression],
      Option[Expression], LogicalPlan, LogicalPlan, JoinHint)

  def unapply(join: Join): Option[ReturnType] = join match {
    case Join(left, right, joinType, condition, hint) =>
      logDebug(s"Considering join on: $condition")
      // Find equi-join predicates that can be evaluated before the join, and thus can be used
      // as join keys.
      val predicates = condition.map(splitConjunctivePredicates).getOrElse(Nil)
      val joinKeys = predicates.flatMap {
        case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => None
        case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => Some((l, r))
        case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => Some((r, l))
        // Replace null with a default value for the joining key, so that rows with a null
        // key can still be joined together.
        case EqualNullSafe(l, r) if canEvaluate(l, left) && canEvaluate(r, right) =>
          Seq((Coalesce(Seq(l, Literal.default(l.dataType))),
            Coalesce(Seq(r, Literal.default(r.dataType)))),
            (IsNull(l), IsNull(r))
          )
        case EqualNullSafe(l, r) if canEvaluate(l, right) && canEvaluate(r, left) =>
          Seq((Coalesce(Seq(r, Literal.default(r.dataType))),
            Coalesce(Seq(l, Literal.default(l.dataType)))),
            (IsNull(r), IsNull(l))
          )
        case other => None
      }
      val otherPredicates = predicates.filterNot {
        case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => false
        case Equality(l, r) =>
          canEvaluate(l, left) && canEvaluate(r, right) ||
            canEvaluate(l, right) && canEvaluate(r, left)
        case _ => false
      }

      if (joinKeys.nonEmpty) {
        val (leftKeys, rightKeys) = joinKeys.unzip
        logDebug(s"leftKeys:$leftKeys | rightKeys:$rightKeys")
        Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right, hint))
      } else {
        None
      }
  }
}
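
// A minimal sketch (hypothetical; not part of the original file) of how a join
// strategy consumes ExtractEquiJoinKeys. Note that a null-safe condition
// `a <=> b` surfaces here as two key pairs, coalesce(a, default)/coalesce(b, default)
// and isnull(a)/isnull(b), per the rewrite above.
private object ExtractEquiJoinKeysUsageExample {
  def describe(plan: LogicalPlan): String = plan match {
    case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, otherCondition, _, _, _) =>
      s"$joinType equi-join on ${leftKeys.size} key pair(s)" +
        otherCondition.map(c => s", residual condition: $c").getOrElse("")
    case _ =>
      "no evaluable equi-join keys"
  }
}
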
/**
 * A pattern that collects the filter and inner joins.
 *
 *          Filter
 *            |
 *        inner Join
 *          /    \            ---->      (Seq(plan0, plan1, plan2), conditions)
 *      Filter   plan2
 *        |
 *  inner join
 *      /    \
 *   plan0    plan1
 *
 * Note: This pattern currently only works for left-deep trees.
 */
object ExtractFiltersAndInnerJoins extends PredicateHelper {

  /**
   * Flattens all inner joins that are next to each other.
   * Returns a list of logical plans to be joined, each paired with the InnerLike join type
   * that linked it, so explicit cross joins are preserved. Also returns the entire list of
   * join conditions for the left-deep tree.
   */
  def flattenJoin(plan: LogicalPlan, parentJoinType: InnerLike = Inner)
      : (Seq[(LogicalPlan, InnerLike)], Seq[Expression]) = plan match {
    case Join(left, right, joinType: InnerLike, cond, hint) if hint == JoinHint.NONE =>
      val (plans, conditions) = flattenJoin(left, joinType)
      (plans ++ Seq((right, joinType)), conditions ++
        cond.toSeq.flatMap(splitConjunctivePredicates))
    case Filter(filterCondition, j @ Join(_, _, _: InnerLike, _, hint)) if hint == JoinHint.NONE =>
      val (plans, conditions) = flattenJoin(j)
      (plans, conditions ++ splitConjunctivePredicates(filterCondition))

    case _ => (Seq((plan, parentJoinType)), Seq.empty)
  }

  def unapply(plan: LogicalPlan)
      : Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression])]
      = plan match {
    case f @ Filter(filterCondition, j @ Join(_, _, joinType: InnerLike, _, hint))
        if hint == JoinHint.NONE =>
      Some(flattenJoin(f))
    case j @ Join(_, _, joinType, _, hint) if hint == JoinHint.NONE =>
      Some(flattenJoin(j))
    case _ => None
  }
}
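
// A minimal sketch (hypothetical; not part of the original file): the extractor
// turns a left-deep tree of inner joins (plus interleaved filters) into a flat
// list of plans and a single pool of conjunctive conditions, which is the shape
// that join-reordering optimizations work on.
private object ExtractFiltersAndInnerJoinsUsageExample {
  def describe(plan: LogicalPlan): String = plan match {
    case ExtractFiltersAndInnerJoins(plans, conditions) if plans.size > 1 =>
      s"reorderable: ${plans.size} plans, ${conditions.size} condition(s)"
    case _ =>
      "not a flattenable inner-join tree"
  }
}
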
/**
 * An extractor used when planning the physical execution of an aggregation. Compared with a
 * logical aggregation, the following transformations are performed:
 *  - Unnamed grouping expressions are named so that they can be referred to across phases of
 *    aggregation.
 *  - Aggregations that appear multiple times are deduplicated.
 *  - The computation of the aggregations themselves is separated from the final result. For
 *    example, the `count` in `count + 1` will be split into an [[AggregateExpression]] and a final
 *    computation that computes `count.resultAttribute + 1`.
 */
object PhysicalAggregation {
  // groupingExpressions, aggregateExpressions, resultExpressions, child
  type ReturnType =
    (Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan)

  def unapply(a: Any): Option[ReturnType] = a match {
    case logical.Aggregate(groupingExpressions, resultExpressions, child) =>
      // A single aggregate expression might appear multiple times in resultExpressions.
      // In order to avoid evaluating an individual aggregate function multiple times, we'll
      // build a set of semantically distinct aggregate expressions and re-write expressions so
      // that they reference the single copy of the aggregate function which actually gets computed.
      // Non-deterministic aggregate expressions are not deduplicated.
      val equivalentAggregateExpressions = new EquivalentExpressions
      val aggregateExpressions = resultExpressions.flatMap { expr =>
        expr.collect {
          // addExpr() always returns false for non-deterministic expressions,
          // so they are never deduplicated.
          case agg: AggregateExpression
            if !equivalentAggregateExpressions.addExpr(agg) => agg
          case udf: PythonUDF
            if PythonUDF.isGroupedAggPandasUDF(udf) &&
              !equivalentAggregateExpressions.addExpr(udf) => udf
        }
      }

      val namedGroupingExpressions = groupingExpressions.map {
        case ne: NamedExpression => ne -> ne
        // If the expression is not a NamedExpression, we add an alias,
        // so that when we generate the result of the operator, the Aggregate operator
        // can directly get the Seq of attributes representing the grouping expressions.
        case other =>
          val withAlias = Alias(other, other.toString)()
          other -> withAlias
      }
      val groupExpressionMap = namedGroupingExpressions.toMap

      // The original `resultExpressions` are a set of expressions which may reference
      // aggregate expressions, grouping column values, and constants. When the aggregate operator
      // emits output rows, we will use `resultExpressions` to generate an output projection
      // which takes the grouping columns and final aggregate result buffer as input.
      // Thus, we must re-write the result expressions so that their attributes match up with
      // the attributes of the final result projection's input row:
      val rewrittenResultExpressions = resultExpressions.map { expr =>
        expr.transformDown {
          case ae: AggregateExpression =>
            // The final aggregation buffer's attributes will be `finalAggregationAttributes`,
            // so replace each aggregate expression by its corresponding attribute in the set:
            equivalentAggregateExpressions.getEquivalentExprs(ae).headOption
              .getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute
          // Similar to AggregateExpression
          case ue: PythonUDF if PythonUDF.isGroupedAggPandasUDF(ue) =>
            equivalentAggregateExpressions.getEquivalentExprs(ue).headOption
              .getOrElse(ue).asInstanceOf[PythonUDF].resultAttribute
          case expression =>
            // Since we're using `namedGroupingExpressions` to extract the grouping key
            // columns, we need to replace grouping key expressions with their corresponding
            // attributes. We do not rely on the equality check here since attributes may
            // differ cosmetically. Instead, we use semanticEquals.
            groupExpressionMap.collectFirst {
              case (expr, ne) if expr semanticEquals expression => ne.toAttribute
            }.getOrElse(expression)
        }.asInstanceOf[NamedExpression]
      }

      Some((
        namedGroupingExpressions.map(_._2),
        aggregateExpressions,
        rewrittenResultExpressions,
        child))

    case _ => None
  }
}
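
// A minimal sketch (hypothetical; not part of the original file) of consuming
// PhysicalAggregation: the deduplicated aggregate expressions come back separately
// from the result projection built over their `resultAttribute`s, which is what
// lets a query like `SELECT count(*) + count(*) ...` evaluate the count only once.
private object PhysicalAggregationUsageExample {
  def describe(plan: LogicalPlan): String = plan match {
    case PhysicalAggregation(groupingExprs, aggExprs, resultExprs, child) =>
      s"aggregate over ${child.nodeName}: ${groupingExprs.size} grouping key(s), " +
        s"${aggExprs.size} distinct aggregate(s), ${resultExprs.size} result column(s)"
    case _ =>
      "not an aggregation"
  }
}
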
/**
 * An extractor used when planning physical execution of a window. This extractor outputs
 * the window function type of the logical window.
 *
 * The input logical window must contain only one type of window function, which is ensured by
 * the rule ExtractWindowExpressions in the analyzer.
 */
object PhysicalWindow {
  // windowFunctionType, windowExpression, partitionSpec, orderSpec, child
  private type ReturnType =
    (WindowFunctionType, Seq[NamedExpression], Seq[Expression], Seq[SortOrder], LogicalPlan)

  def unapply(a: Any): Option[ReturnType] = a match {
    case expr @ logical.Window(windowExpressions, partitionSpec, orderSpec, child) =>

      // The window expressions should not be empty here; otherwise it's a bug.
      if (windowExpressions.isEmpty) {
        throw new AnalysisException(s"Window expression is empty in $expr")
      }

      val windowFunctionType = windowExpressions.map(WindowFunctionType.functionType)
        .reduceLeft { (t1: WindowFunctionType, t2: WindowFunctionType) =>
          if (t1 != t2) {
            // We shouldn't have different window function types here; otherwise it's a bug.
            throw new AnalysisException(
              s"Found different window function type in $windowExpressions")
          } else {
            t1
          }
        }

      Some((windowFunctionType, windowExpressions, partitionSpec, orderSpec, child))

    case _ => None
  }
}
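
// A minimal sketch (hypothetical; not part of the original file) of consuming
// PhysicalWindow: the single WindowFunctionType lets a planner choose between
// physical window implementations (e.g. SQL window functions vs. grouped-agg
// Pandas UDFs) with one match.
private object PhysicalWindowUsageExample {
  def describe(plan: LogicalPlan): String = plan match {
    case PhysicalWindow(funcType, windowExprs, partitionSpec, orderSpec, _) =>
      s"$funcType window: ${windowExprs.size} expression(s), " +
        s"${partitionSpec.size} partition key(s), ${orderSpec.size} ordering key(s)"
    case _ =>
      "not a window"
  }
}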