/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.analysis

import scala.collection.mutable

import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

/**
 * Throws user-facing errors when passed invalid queries that fail to analyze.
 */
trait CheckAnalysis extends PredicateHelper {

  /**
   * Override to provide additional checks for correct analysis.
   * These rules will be evaluated after our built-in check rules.
   */
  val extendedCheckRules: Seq[LogicalPlan => Unit] = Nil
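
  // A minimal sketch of how an extension could hook in. The object name and the
  // row-count rule below are hypothetical, for illustration only:
  //
  //   object MyAnalysisChecker extends CheckAnalysis {
  //     override val extendedCheckRules: Seq[LogicalPlan => Unit] = Seq(
  //       // Hypothetical rule: reject plans whose statically known row bound
  //       // exceeds a limit. LogicalPlan.maxRows is an Option[Long].
  //       plan => if (plan.maxRows.exists(_ > 1000000L)) {
  //         failAnalysis("query may produce more than 1000000 rows")
  //       }
  //     )
  //   }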

  protected def failAnalysis(msg: String): Nothing = {
    throw new AnalysisException(msg)
  }

  protected def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = {
    exprs.flatMap(_.collect {
      case e: Generator => e
    }).length > 1
  }
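
  // For illustration (hypothetical table t with array columns a and b):
  //   SELECT explode(a), explode(b) FROM t
  // contains two generators in one projection and is rejected by the Project
  // check in checkAnalysis below.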

  protected def hasMapType(dt: DataType): Boolean = {
    dt.existsRecursively(_.isInstanceOf[MapType])
  }

  protected def mapColumnInSetOperation(plan: LogicalPlan): Option[Attribute] = plan match {
    case _: Intersect | _: Except | _: Distinct =>
      plan.output.find(a => hasMapType(a.dataType))
    case d: Deduplicate =>
      d.keys.find(a => hasMapType(a.dataType))
    case _ => None
  }

  private def checkLimitClause(limitExpr: Expression): Unit = {
    limitExpr match {
      case e if !e.foldable => failAnalysis(
        "The limit expression must evaluate to a constant value, but got " +
          limitExpr.sql)
      case e if e.dataType != IntegerType => failAnalysis(
        s"The limit expression must be integer type, but got " +
          e.dataType.catalogString)
      case e =>
        e.eval() match {
          case null => failAnalysis(
            s"The evaluated limit expression must not be null, but got ${limitExpr.sql}")
          case v: Int if v < 0 => failAnalysis(
            s"The limit expression must be equal to or greater than 0, but got $v")
          case _ => // OK
        }
    }
  }
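
  // Illustrative queries rejected by this check (hypothetical table t):
  //   SELECT * FROM t LIMIT -1       -- negative limit value
  //   SELECT * FROM t LIMIT 1.5      -- limit expression is not integer type
  //   SELECT * FROM t LIMIT rand()   -- limit expression is not foldable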

  def checkAnalysis(plan: LogicalPlan): Unit = {
    // We transform up and order the rules so as to catch the first possible failure instead
    // of the result of cascading resolution failures.
    plan.foreachUp {

      case p if p.analyzed => // Skip already analyzed sub-plans

      case u: UnresolvedRelation =>
        u.failAnalysis(s"Table or view not found: ${u.tableIdentifier}")

      case operator: LogicalPlan =>
        // Check argument data types of higher-order functions downwards first.
        // If the arguments of the higher-order functions are resolved but the type check fails,
        // the argument functions will not get resolved, but we should report the argument type
        // check failure instead of claiming the argument functions are unresolved.
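        // For example (illustrative): in `transform(5, x -> x + 1)` the first
        // argument is not an array, so the argument type check fails here and a
        // data type mismatch error is reported instead of an unresolved lambda.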
        operator transformExpressionsDown {
          case hof: HigherOrderFunction
              if hof.argumentsResolved && hof.checkArgumentDataTypes().isFailure =>
            hof.checkArgumentDataTypes() match {
              case TypeCheckResult.TypeCheckFailure(message) =>
                hof.failAnalysis(
                  s"cannot resolve '${hof.sql}' due to argument data type mismatch: $message")
            }
        }

        operator transformExpressionsUp {
          case a: Attribute if !a.resolved =>
            val from = operator.inputSet.map(_.qualifiedName).mkString(", ")
            a.failAnalysis(s"cannot resolve '${a.sql}' given input columns: [$from]")

          case e: Expression if e.checkInputDataTypes().isFailure =>
            e.checkInputDataTypes() match {
              case TypeCheckResult.TypeCheckFailure(message) =>
                e.failAnalysis(
                  s"cannot resolve '${e.sql}' due to data type mismatch: $message")
            }

          case c: Cast if !c.resolved =>
            failAnalysis(s"invalid cast from ${c.child.dataType.catalogString} to " +
              c.dataType.catalogString)

          case g: Grouping =>
            failAnalysis("grouping() can only be used with GroupingSets/Cube/Rollup")
          case g: GroupingID =>
            failAnalysis("grouping_id() can only be used with GroupingSets/Cube/Rollup")

          case w @ WindowExpression(AggregateExpression(_, _, true, _), _) =>
            failAnalysis(s"Distinct window functions are not supported: $w")

          case w @ WindowExpression(_: OffsetWindowFunction,
            WindowSpecDefinition(_, order, frame: SpecifiedWindowFrame))
             if order.isEmpty || !frame.isOffset =>
            failAnalysis("An offset window function can only be evaluated in an ordered " +
              s"row-based window frame with a single offset: $w")

          case WindowExpression(_: PythonUDF,
            WindowSpecDefinition(_, _, frame: SpecifiedWindowFrame))
              if !frame.isUnbounded =>
            failAnalysis("Only unbounded window frame is supported with Pandas UDFs.")

          case w @ WindowExpression(e, s) =>
            // Only allow window functions with an aggregate expression or an offset window
            // function or a Pandas window UDF.
            e match {
              case _: AggregateExpression | _: OffsetWindowFunction | _: AggregateWindowFunction =>
                w
              case f: PythonUDF if PythonUDF.isWindowPandasUDF(f) =>
                w
              case _ =>
                failAnalysis(s"Expression '$e' not supported within a window function.")
            }
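
          // Illustrative expressions rejected by the window checks above
          // (hypothetical columns x and y):
          //   count(DISTINCT x) OVER (PARTITION BY y)   -- distinct window aggregate
          //   substring(x, 1, 2) OVER (PARTITION BY y)  -- not a window function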

          case s: SubqueryExpression =>
            checkSubqueryExpression(operator, s)
            s
        }

        operator match {
          case etw: EventTimeWatermark =>
            etw.eventTime.dataType match {
              case s: StructType
                if s.find(_.name == "end").map(_.dataType) == Some(TimestampType) =>
              case _: TimestampType =>
              case _ =>
                failAnalysis(
                  s"Event time must be defined on a window or a timestamp, but " +
                  s"${etw.eventTime.name} is of type ${etw.eventTime.dataType.catalogString}")
            }
          case f: Filter if f.condition.dataType != BooleanType =>
            failAnalysis(
              s"filter expression '${f.condition.sql}' " +
                s"of type ${f.condition.dataType.catalogString} is not a boolean.")

          case Filter(condition, _) if hasNullAwarePredicateWithinNot(condition) =>
            failAnalysis("Null-aware predicate sub-queries cannot be used in nested " +
              s"conditions: $condition")

          case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType =>
            failAnalysis(
              s"join condition '${condition.sql}' " +
                s"of type ${condition.dataType.catalogString} is not a boolean.")

          case Aggregate(groupingExprs, aggregateExprs, child) =>
            def isAggregateExpression(expr: Expression) = {
              expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(expr)
            }

            def checkValidAggregateExpression(expr: Expression): Unit = expr match {
              case expr: Expression if isAggregateExpression(expr) =>
                val aggFunction = expr match {
                  case agg: AggregateExpression => agg.aggregateFunction
                  case udf: PythonUDF => udf
                }
                aggFunction.children.foreach { child =>
                  child.foreach {
                    case expr: Expression if isAggregateExpression(expr) =>
                      failAnalysis(
                        s"It is not allowed to use an aggregate function in the argument of " +
                          s"another aggregate function. Please use the inner aggregate function " +
                          s"in a sub-query.")
                    case other => // OK
                  }

                  if (!child.deterministic) {
                    failAnalysis(
                      s"nondeterministic expression ${expr.sql} should not " +
                        s"appear in the arguments of an aggregate function.")
                  }
                }
              case e: Attribute if groupingExprs.isEmpty =>
                // Collect all [[AggregateExpression]]s.
                val aggExprs = aggregateExprs.filter(_.collect {
                  case a: AggregateExpression => a
                }.nonEmpty)
                failAnalysis(
                  s"grouping expressions sequence is empty, " +
                    s"and '${e.sql}' is not an aggregate function. " +
                    s"Wrap '${aggExprs.map(_.sql).mkString("(", ", ", ")")}' in windowing " +
                    s"function(s) or wrap '${e.sql}' in first() (or first_value) " +
                    s"if you don't care which value you get."
                )
              case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) =>
                failAnalysis(
                  s"expression '${e.sql}' is neither present in the group by, " +
                    s"nor is it an aggregate function. " +
                    "Add to group by or wrap in first() (or first_value) if you don't care " +
                    "which value you get.")
              case e if groupingExprs.exists(_.semanticEquals(e)) => // OK
              case e => e.children.foreach(checkValidAggregateExpression)
            }

            def checkValidGroupingExprs(expr: Expression): Unit = {
              if (expr.find(_.isInstanceOf[AggregateExpression]).isDefined) {
                failAnalysis(
                  "aggregate functions are not allowed in GROUP BY, but found " + expr.sql)
              }

              // Check if the data type of expr is orderable.
              if (!RowOrdering.isOrderable(expr.dataType)) {
                failAnalysis(
                  s"expression ${expr.sql} cannot be used as a grouping expression " +
                    s"because its data type ${expr.dataType.catalogString} is not an orderable " +
                    s"data type.")
              }

              if (!expr.deterministic) {
                // This is just a sanity check, our analysis rule PullOutNondeterministic should
                // already pull out those nondeterministic expressions and evaluate them in
                // a Project node.
                failAnalysis(s"nondeterministic expression ${expr.sql} should not " +
                  s"appear in grouping expression.")
              }
            }

            groupingExprs.foreach(checkValidGroupingExprs)
            aggregateExprs.foreach(checkValidAggregateExpression)
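
            // Illustrative queries rejected by the checks above (hypothetical
            // table t with columns x, y and a map-typed column m):
            //   SELECT max(min(x)) FROM t GROUP BY y   -- nested aggregate functions
            //   SELECT x, count(*) FROM t GROUP BY y   -- x neither grouped nor aggregated
            //   SELECT count(*) FROM t GROUP BY m      -- map type is not orderable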

          case Sort(orders, _, _) =>
            orders.foreach { order =>
              if (!RowOrdering.isOrderable(order.dataType)) {
                failAnalysis(
                  s"sorting is not supported for columns of type ${order.dataType.catalogString}")
              }
            }

          case GlobalLimit(limitExpr, _) => checkLimitClause(limitExpr)

          case LocalLimit(limitExpr, _) => checkLimitClause(limitExpr)

          case _: Union | _: SetOperation if operator.children.length > 1 =>
            def dataTypes(plan: LogicalPlan): Seq[DataType] = plan.output.map(_.dataType)
            def ordinalNumber(i: Int): String = i match {
              case 0 => "first"
              case 1 => "second"
              case 2 => "third"
              case i => s"${i + 1}th"
            }
            val ref = dataTypes(operator.children.head)
            operator.children.tail.zipWithIndex.foreach { case (child, ti) =>
              // Check the number of columns
              if (child.output.length != ref.length) {
                failAnalysis(
                  s"""
                    |${operator.nodeName} can only be performed on tables with the same number
                    |of columns, but the first table has ${ref.length} columns and
                    |the ${ordinalNumber(ti + 1)} table has ${child.output.length} columns
                  """.stripMargin.replace("\n", " ").trim())
              }
              // Check if the data types match.
              dataTypes(child).zip(ref).zipWithIndex.foreach { case ((dt1, dt2), ci) =>
                // SPARK-18058: we shall not care about the nullability of columns
                if (TypeCoercion.findWiderTypeForTwo(dt1.asNullable, dt2.asNullable).isEmpty) {
                  failAnalysis(
                    s"""
                      |${operator.nodeName} can only be performed on tables with compatible
                      |column types. ${dt1.catalogString} <> ${dt2.catalogString} at the
                      |${ordinalNumber(ci)} column of the ${ordinalNumber(ti + 1)} table
                    """.stripMargin.replace("\n", " ").trim())
                }
              }
            }
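
            // For illustration (hypothetical tables t1(a INT, b INT), t2(c STRING),
            // and t3 with a map-typed column m):
            //   SELECT a, b FROM t1 UNION SELECT c FROM t2   -- column count mismatch
            //   SELECT a FROM t1 INTERSECT SELECT m FROM t3  -- no compatible wider type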

          // If the view output doesn't have the same number of columns as either the child
          // output or the query column names, throw an AnalysisException.
          // If the view's child output can't be up-cast to the view output,
          // throw an AnalysisException, too.
          case v @ View(desc, output, child) if child.resolved && output != child.output =>
            val queryColumnNames = desc.viewQueryColumnNames
            val queryOutput = if (queryColumnNames.nonEmpty) {
              if (output.length != queryColumnNames.length) {
                // If the view output doesn't have the same number of columns with the query column
                // names, throw an AnalysisException.
                throw new AnalysisException(
                  s"The view output ${output.mkString("[", ",", "]")} doesn't have the same " +
                    "number of columns as the query column names " +
                    s"${queryColumnNames.mkString("[", ",", "]")}")
              }
              val resolver = SQLConf.get.resolver
              queryColumnNames.map { colName =>
                child.output.find { attr =>
                  resolver(attr.name, colName)
                }.getOrElse(throw new AnalysisException(
                  s"Attribute with name '$colName' is not found in " +
                    s"'${child.output.map(_.name).mkString("(", ",", ")")}'"))
              }
            } else {
              child.output
            }

            output.zip(queryOutput).foreach {
              case (attr, originAttr) if !attr.dataType.sameType(originAttr.dataType) =>
                // The dataType of the output attributes may not be the same as that of the
                // view output, so we should cast the attribute to the dataType of the view
                // output attribute. Throw an AnalysisException if the cast can't be
                // performed or might truncate.
                if (Cast.mayTruncate(originAttr.dataType, attr.dataType) ||
                  !Cast.canCast(originAttr.dataType, attr.dataType)) {
                  throw new AnalysisException(s"Cannot up cast ${originAttr.sql} from " +
                    s"${originAttr.dataType.catalogString} to ${attr.dataType.catalogString} " +
                    "as it may truncate\n")
                }
              case _ =>
            }
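
            // Illustrative scenario (hypothetical objects): a view v declaring an INT
            // column over `SELECT i FROM t` fails this check if t.i is later changed
            // to STRING, because casting STRING to INT may truncate.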

          case _ => // Falls back to the following checks
        }

        operator match {
          case o if o.children.nonEmpty && o.missingInput.nonEmpty =>
            val missingAttributes = o.missingInput.mkString(",")
            val input = o.inputSet.mkString(",")
            val msgForMissingAttributes = s"Resolved attribute(s) $missingAttributes missing " +
              s"from $input in operator ${operator.simpleString}."

            val resolver = plan.conf.resolver
            val attrsWithSameName = o.missingInput.filter { missing =>
              o.inputSet.exists(input => resolver(missing.name, input.name))
            }

            val msg = if (attrsWithSameName.nonEmpty) {
              val sameNames = attrsWithSameName.map(_.name).mkString(",")
              s"$msgForMissingAttributes Attribute(s) with the same name appear in the " +
                s"operation: $sameNames. Please check if the right attribute(s) are used."
            } else {
              msgForMissingAttributes
            }

            failAnalysis(msg)

          case p @ Project(exprs, _) if containsMultipleGenerators(exprs) =>
            failAnalysis(
              s"""Only a single table generating function is allowed in a SELECT clause, found:
                 | ${exprs.map(_.sql).mkString(",")}""".stripMargin)

          case j: Join if !j.duplicateResolved =>
            val conflictingAttributes = j.left.outputSet.intersect(j.right.outputSet)
            failAnalysis(
              s"""
                 |Failure when resolving conflicting references in Join:
                 |$plan
                 |Conflicting attributes: ${conflictingAttributes.mkString(",")}
                 |""".stripMargin)

          case i: Intersect if !i.duplicateResolved =>
            val conflictingAttributes = i.left.outputSet.intersect(i.right.outputSet)
            failAnalysis(
              s"""
                 |Failure when resolving conflicting references in Intersect:
                 |$plan
                 |Conflicting attributes: ${conflictingAttributes.mkString(",")}
               """.stripMargin)

          case e: Except if !e.duplicateResolved =>
            val conflictingAttributes = e.left.outputSet.intersect(e.right.outputSet)
            failAnalysis(
              s"""
                 |Failure when resolving conflicting references in Except:
                 |$plan
                 |Conflicting attributes: ${conflictingAttributes.mkString(",")}
               """.stripMargin)

          // TODO: although map type is not orderable, technically map type should be able to be
          // used in equality comparison, remove this type check once we support it.
          case o if mapColumnInSetOperation(o).isDefined =>
            val mapCol = mapColumnInSetOperation(o).get
            failAnalysis("Cannot have map type columns in DataFrame which calls " +
              s"set operations (intersect, except, etc.), but the type of column ${mapCol.name} " +
              "is " + mapCol.dataType.catalogString)

          case o if o.expressions.exists(!_.deterministic) &&
            !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] &&
            !o.isInstanceOf[Aggregate] && !o.isInstanceOf[Window] =>
            // Aggregate is excluded here because its expressions are checked by the rule above.
            failAnalysis(
              s"""nondeterministic expressions are only allowed in
                 |Project, Filter, Aggregate or Window, found:
                 | ${o.expressions.map(_.sql).mkString(",")}
                 |in operator ${operator.simpleString}
               """.stripMargin)

          case _: UnresolvedHint =>
            throw new IllegalStateException(
              "Internal error: logical hint operator should have been removed during analysis")

          case _ => // Analysis successful!
        }
    }
    extendedCheckRules.foreach(_(plan))
    plan.foreachUp {
      case o if !o.resolved => failAnalysis(s"unresolved operator ${o.simpleString}")
      case _ =>
    }

    plan.setAnalyzed()
  }

  /**
   * Validates subquery expressions in the plan. Upon failure, throws a user-facing error.
   */
  private def checkSubqueryExpression(plan: LogicalPlan, expr: SubqueryExpression): Unit = {
    def checkAggregateInScalarSubquery(
        conditions: Seq[Expression],
        query: LogicalPlan, agg: Aggregate): Unit = {
      // Make sure correlated scalar subqueries contain one row for every outer row by
      // enforcing that they are aggregates containing exactly one aggregate expression.
      val aggregates = agg.expressions.flatMap(_.collect {
        case a: AggregateExpression => a
      })
      if (aggregates.isEmpty) {
        failAnalysis("The output of a correlated scalar subquery must be aggregated")
      }

      // SPARK-18504/SPARK-18814: Block cases where GROUP BY columns
      // are not part of the correlated columns.
      val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references))
      // Collect the local references from the correlated predicate in the subquery.
      val subqueryColumns = getCorrelatedPredicates(query).flatMap(_.references)
        .filterNot(conditions.flatMap(_.references).contains)
      val correlatedCols = AttributeSet(subqueryColumns)
      val invalidCols = groupByCols -- correlatedCols
      // GROUP BY columns must be a subset of columns in the predicates
      if (invalidCols.nonEmpty) {
        failAnalysis(
          "A GROUP BY clause in a scalar correlated subquery " +
            "cannot contain non-correlated columns: " +
            invalidCols.mkString(","))
      }
    }
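
    // Illustrative correlated scalar subqueries (hypothetical tables t1 and t2):
    //   SELECT (SELECT t2.x FROM t2 WHERE t2.y = t1.y) FROM t1
    //     -- rejected: the subquery output is not aggregated
    //   SELECT (SELECT max(x) FROM t2 WHERE t2.y = t1.y GROUP BY t2.z) FROM t1
    //     -- rejected: GROUP BY column t2.z is not a correlated column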

    // Skip subquery aliases added by the Analyzer.
    // For projects, do the necessary mapping and skip to its child.
    def cleanQueryInScalarSubquery(p: LogicalPlan): LogicalPlan = p match {
      case s: SubqueryAlias => cleanQueryInScalarSubquery(s.child)
      case p: Project => cleanQueryInScalarSubquery(p.child)
      case child => child
    }

    // Validate the subquery plan.
    checkAnalysis(expr.plan)

    expr match {
      case ScalarSubquery(query, conditions, _) =>
        // Scalar subquery must return one column as output.
        if (query.output.size != 1) {
          failAnalysis(
            s"Scalar subquery must return only one column, but got ${query.output.size}")
        }

        if (conditions.nonEmpty) {
          cleanQueryInScalarSubquery(query) match {
            case a: Aggregate => checkAggregateInScalarSubquery(conditions, query, a)
            case Filter(_, a: Aggregate) => checkAggregateInScalarSubquery(conditions, query, a)
            case fail => failAnalysis(s"Correlated scalar subqueries must be aggregated: $fail")
          }

          // Only certain operators are allowed to host subquery expression containing
          // outer references.
          plan match {
            case _: Filter | _: Aggregate | _: Project => // Ok
            case other => failAnalysis(
              "Correlated scalar sub-queries can only be used in a " +
                s"Filter/Aggregate/Project: $plan")
          }
        }

      case inSubqueryOrExistsSubquery =>
        plan match {
          case _: Filter => // Ok
          case _ =>
            failAnalysis(s"IN/EXISTS predicate sub-queries can only be used in a Filter: $plan")
        }
    }
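
    // For example (illustrative): `SELECT EXISTS (SELECT 1 FROM t2) FROM t1` is
    // rejected above because the EXISTS predicate is hosted by a Project, not a Filter.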

    // Validate that the correlations appearing in the query are valid and
    // allowed by Spark.
    checkCorrelationsInSubquery(expr.plan)
  }

  /**
   * Validates that the outer references appearing inside the subquery
   * are allowed.
   */
  private def checkCorrelationsInSubquery(sub: LogicalPlan): Unit = {
    // Validate that correlated aggregate expressions do not contain a mixture
    // of outer and local references.
    def checkMixedReferencesInsideAggregateExpr(expr: Expression): Unit = {
      expr.foreach {
        case a: AggregateExpression if containsOuter(a) =>
          val outer = a.collect { case OuterReference(e) => e.toAttribute }
          val local = a.references -- outer
          if (local.nonEmpty) {
            val msg =
              s"""
                 |Found an aggregate expression in a correlated predicate that has both
                 |outer and local references, which is not supported yet.
                 |Aggregate expression: ${SubExprUtils.stripOuterReference(a).sql},
                 |Outer references: ${outer.map(_.sql).mkString(", ")},
                 |Local references: ${local.map(_.sql).mkString(", ")}.
               """.stripMargin.replace("\n", " ").trim()
            failAnalysis(msg)
          }
        case _ =>
      }
    }
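
    // For example (illustrative): in
    //   SELECT (SELECT sum(t2.x + t1.y) FROM t2) FROM t1
    // the aggregate mixes the local reference t2.x with the outer reference t1.y
    // and is rejected by the check above.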

    // Make sure a plan's subtree does not contain outer references
    def failOnOuterReferenceInSubTree(p: LogicalPlan): Unit = {
      if (hasOuterReferences(p)) {
        failAnalysis(s"Accessing outer query column is not allowed in:\n$p")
      }
    }

    // Make sure a plan's expressions do not contain:
    // 1. Aggregate expressions that have mixture of outer and local references.
    // 2. Expressions containing outer references on plan nodes other than Filter.
    def failOnInvalidOuterReference(p: LogicalPlan): Unit = {
      p.expressions.foreach(checkMixedReferencesInsideAggregateExpr)
      if (!p.isInstanceOf[Filter] && p.expressions.exists(containsOuter)) {
        failAnalysis(
          "Expressions referencing the outer query are not supported outside of WHERE/HAVING " +
            s"clauses:\n$p")
      }
    }
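
    // For example (illustrative): `SELECT (SELECT t1.a + 1 FROM t2 LIMIT 1) FROM t1`
    // places the outer reference t1.a in a Project rather than a Filter and is
    // rejected by failOnInvalidOuterReference.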

    // SPARK-17348: A potential incorrect result case.
    // When a correlated predicate is a non-equality predicate,
    // certain operators are not permitted from the operator
    // hosting the correlated predicate up to the operator on the outer table.
    // Otherwise, the pull up of the correlated predicate
    // will generate a plan with a different semantics
    // which could return incorrect result.
    // Currently we check for Aggregate and Window operators
    //
    // Below shows an example of a Logical Plan during Analyzer phase that
    // show this problem. Pulling the correlated predicate [outer(c2#77) >= ..]
    // through the Aggregate (or Window) operator could alter the result of
    // the Aggregate.
    //
    // Project [c1#76]
    // +- Project [c1#87, c2#88]
    // :  (Aggregate or Window operator)
    // :  +- Filter [outer(c2#77) >= c2#88)]
    // :     +- SubqueryAlias t2, `t2`
    // :        +- Project [_1#84 AS c1#87, _2#85 AS c2#88]
    // :           +- LocalRelation [_1#84, _2#85]
    // +- SubqueryAlias t1, `t1`
    // +- Project [_1#73 AS c1#76, _2#74 AS c2#77]
    // +- LocalRelation [_1#73, _2#74]
    // SPARK-35080: The same issue can happen to correlated equality predicates when
    // they do not guarantee one-to-one mapping between inner and outer attributes.
    // For example:
    // Table:
    //   t1(a, b): [(0, 6), (1, 5), (2, 4)]
    //   t2(c): [(6)]
    //
    // Query:
    //   SELECT c, (SELECT COUNT(*) FROM t1 WHERE a + b = c) FROM t2
    //
    // Original subquery plan:
    //   Aggregate [count(1)]
    //   +- Filter ((a + b) = outer(c))
    //      +- LocalRelation [a, b]
    //
    // Plan after pulling up correlated predicates:
    //   Aggregate [a, b] [count(1), a, b]
    //   +- LocalRelation [a, b]
    //
    // Plan after rewrite:
    //   Project [c1, count(1)]
    //   +- Join LeftOuter ((a + b) = c)
    //      :- LocalRelation [c]
    //      +- Aggregate [a, b] [count(1), a, b]
    //         +- LocalRelation [a, b]
    //
    // The right hand side of the join transformed from the subquery will output
    //   count(1) | a | b
    //      1     | 0 | 6
    //      1     | 1 | 5
    //      1     | 2 | 4
    // and the plan after rewrite will give the original query incorrect results.
    def failOnUnsupportedCorrelatedPredicate(predicates: Seq[Expression], p: LogicalPlan): Unit = {
      if (predicates.nonEmpty) {
        // Report a non-supported case as an exception
        failAnalysis("Correlated column is not allowed in predicate " +
          s"${predicates.map(_.sql).mkString}:\n$p")
      }
    }

    def containsAttribute(e: Expression): Boolean = {
      e.find(_.isInstanceOf[Attribute]).isDefined
    }

    // Given a correlated predicate, check if it is either a non-equality predicate or
    // equality predicate that does not guarantee one-on-one mapping between inner and
    // outer attributes. When the correlated predicate does not contain any attribute
    // (i.e. only has outer references), it is supported and should return false. E.g.:
    //   (a = outer(c)) -> false
    //   (outer(c) = outer(d)) -> false
    //   (a > outer(c)) -> true
    //   (a + b = outer(c)) -> true
    // The last one is true because there can be multiple combinations of (a, b) that
    // satisfy the equality condition. For example, if outer(c) = 0, then both (0, 0)
    // and (-1, 1) can make the predicate evaluate to true.
    def isUnsupportedPredicate(condition: Expression): Boolean = condition match {
      // Only allow equality condition with one side being an attribute and another
      // side being an expression without attributes from the inner query. Note
      // OuterReference is a leaf node and will not be found here.
      case Equality(_: Attribute, b) => containsAttribute(b)
      case Equality(a, _: Attribute) => containsAttribute(a)
      case e @ Equality(_, _) => containsAttribute(e)
      case _ => true
    }

    val unsupportedPredicates = mutable.ArrayBuffer.empty[Expression]

    // Simplify the predicates before validating any unsupported correlation patterns in the plan.
    AnalysisHelper.allowInvokingTransformsInAnalyzer { BooleanSimplification(sub).foreachUp {
      // Whitelist operators allowed in a correlated subquery
      // There are 4 categories:
      // 1. Operators that are allowed anywhere in a correlated subquery, and,
      //    by definition of the operators, they either do not contain
      //    any columns or cannot host outer references.
      // 2. Operators that are allowed anywhere in a correlated subquery
      //    so long as they do not host outer references.
      // 3. Operators that need special handling. These operators are
      //    Filter, Join, Aggregate, and Generate.
      // 4. Any operators that are not in the above list are allowed
      //    in a correlated subquery only if they are not on a correlation path.
      //    In other words, these operators are allowed only under a correlation point.
      //
      // A correlation path is defined as the sub-tree of all the operators that
      // are on the path from the operator hosting the correlated expressions
      // up to the operator producing the correlated values.

      // Category 1:
      // ResolvedHint, Distinct, LeafNode, Repartition, and SubqueryAlias
      case _: ResolvedHint | _: Distinct | _: LeafNode | _: Repartition | _: SubqueryAlias =>

      // Category 2:
      // These operators can be anywhere in a correlated subquery,
      // so long as they do not host outer references in the operators.
      case p: Project =>
        failOnInvalidOuterReference(p)

      case s: Sort =>
        failOnInvalidOuterReference(s)

      case r: RepartitionByExpression =>
        failOnInvalidOuterReference(r)

      // Category 3:
      // Filter is one of the two operators allowed to host correlated expressions.
      // The other operator is Join. Filter can be anywhere in a correlated subquery.
      case f: Filter =>
        val (correlated, _) = splitConjunctivePredicates(f.condition).partition(containsOuter)
        unsupportedPredicates ++= correlated.filter(isUnsupportedPredicate)
        failOnInvalidOuterReference(f)

      // Aggregate cannot host any correlated expressions
      // It can be on a correlation path if the correlation contains
      // only supported correlated equality predicates.
      // It cannot be on a correlation path if the correlation has
      // non-equality correlated predicates.
      case a: Aggregate =>
        failOnInvalidOuterReference(a)
        failOnUnsupportedCorrelatedPredicate(unsupportedPredicates.toSeq, a)

      // Join can host correlated expressions.
      case j @ Join(left, right, joinType, _) =>
        joinType match {
          // Inner join, like Filter, can be anywhere.
          case _: InnerLike =>
            failOnInvalidOuterReference(j)

          // Left outer join's right operand cannot be on a correlation path.
          // LeftAnti and ExistenceJoin are special cases of LeftOuter.
          // Note that ExistenceJoin cannot be expressed externally in either SQL or the
          // DataFrame API, so it should not show up here in the analysis phase. This is
          // just a safety net.
          //
          // LeftSemi does not allow output from the right operand.
          // Any correlated references in the subplan
          // of the right operand cannot be pulled up.
          case LeftOuter | LeftSemi | LeftAnti | ExistenceJoin(_) =>
            failOnInvalidOuterReference(j)
            failOnOuterReferenceInSubTree(right)

          // Likewise, Right outer join's left operand cannot be on a correlation path.
          case RightOuter =>
            failOnInvalidOuterReference(j)
            failOnOuterReferenceInSubTree(left)

          // Any other join types not explicitly listed above,
          // including Full outer join, are treated as Category 4.
          case _ =>
            failOnOuterReferenceInSubTree(j)
        }

      // A Generator with join=true, i.e., expressed with
      // LATERAL VIEW [OUTER], is similar to an inner join:
      // it allows correlation under it
      // but must not host any outer references.
      // Note:
      // Generator with requiredChildOutput.isEmpty is treated as Category 4.
      case g: Generate if g.requiredChildOutput.nonEmpty =>
        failOnInvalidOuterReference(g)

      // Category 4: Any other operators not in the above 3 categories
      // cannot be on a correlation path, that is they are allowed only
      // under a correlation point but they and their descendant operators
      // are not allowed to have any correlated expressions.
      case p =>
        failOnOuterReferenceInSubTree(p)
    }}
  }
}