All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.sql.catalyst.analysis.ResolveHints.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.analysis

import java.util.Locale

import scala.collection.mutable

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, IntegerLiteral, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.internal.SQLConf


/**
 * Collection of rules related to hints. The hints currently handled here are join strategy hints
 * and the COALESCE / REPARTITION / REPARTITION_BY_RANGE hints.
 *
 * Note that this is split into separate rules because in the future we might introduce new hint
 * rules that have different ordering requirements from join strategies.
 */
object ResolveHints {

  /**
   * The list of allowed join strategy hints is defined in [[JoinStrategyHint.strategies]], and a
   * sequence of relation aliases can be specified with a join strategy hint, e.g., "MERGE(a, c)",
   * "BROADCAST(a)". A join strategy hint plan node will be inserted on top of any relation (that
   * is not aliased differently), subquery, or common table expression that match the specified
   * name.
   *
   * The hint resolution works by recursively traversing down the query plan to find a relation or
   * subquery that matches one of the specified relation aliases. The traversal does not go
   * beyond any view reference, with clause or subquery alias.
   *
   * This rule must happen before common table expressions.
   */
  class ResolveJoinStrategyHints(conf: SQLConf) extends Rule[LogicalPlan] {
    // Every alias under which a join strategy hint may be written, as declared by the strategies.
    private val STRATEGY_HINT_NAMES = JoinStrategyHint.strategies.flatMap(_.hintAliases)

    private val hintErrorHandler = conf.hintErrorHandler

    def resolver: Resolver = conf.resolver

    // Builds the HintInfo for the strategy one of whose aliases equals `hintName`, ignoring case.
    // The strategy is left empty when no alias matches.
    private def createHintInfo(hintName: String): HintInfo = {
      val upperHintName = hintName.toUpperCase(Locale.ROOT)
      HintInfo(strategy =
        JoinStrategyHint.strategies.find(
          _.hintAliases.map(_.toUpperCase(Locale.ROOT)).contains(upperHintName)))
    }

    // This method checks if given multi-part identifiers are matched with each other.
    // The [[ResolveJoinStrategyHints]] rule is applied before the resolution batch
    // in the analyzer and we cannot semantically compare them at this stage.
    // Therefore, we follow a simple rule; they match if an identifier in a hint
    // is a tail of an identifier in a relation. This process is independent of a session
    // catalog (`currentDb` in [[SessionCatalog]]) and it just compares them literally.
    //
    // For example,
    //  * in a query `SELECT /*+ BROADCAST(t) */ * FROM db1.t JOIN t`,
    //    the broadcast hint will match both tables, `db1.t` and `t`,
    //    even when the current db is `db2`.
    //  * in a query `SELECT /*+ BROADCAST(default.t) */ * FROM default.t JOIN t`,
    //    the broadcast hint will match the left-side table only, `default.t`.
    private def matchedIdentifier(identInHint: Seq[String], identInQuery: Seq[String]): Boolean = {
      // A hint identifier with more parts than the relation identifier can never be its tail.
      identInHint.length <= identInQuery.length &&
        identInHint.zip(identInQuery.takeRight(identInHint.length))
          .forall { case (hintPart, queryPart) => resolver(hintPart, queryPart) }
    }

    // Full multi-part name of a subquery alias: its qualifier followed by its own name.
    private def extractIdentifier(r: SubqueryAlias): Seq[String] = {
      r.identifier.qualifier :+ r.identifier.name
    }

    /**
     * Walks down `plan` and wraps every relation or subquery alias whose identifier matches one
     * of `relationsInHint` in a [[ResolvedHint]] carrying the strategy named by `hintName`.
     *
     * @param plan the subtree to search
     * @param relationsInHint the multi-part identifiers given as hint parameters
     * @param relationsInHintWithMatch accumulator, updated in place with every hint identifier
     *                                 that was found to match something in the subtree
     * @param hintName the name of the hint as written by the user
     * @return the subtree with hints inserted where identifiers matched
     */
    private def applyJoinStrategyHint(
        plan: LogicalPlan,
        relationsInHint: Set[Seq[String]],
        relationsInHintWithMatch: mutable.HashSet[Seq[String]],
        hintName: String): LogicalPlan = {
      // Whether to continue recursing down the tree
      var recurse = true

      // Returns true if some hint identifier matches `identInQuery`. As a side effect, the first
      // matching hint identifier is recorded as matched. Note that this is used in pattern guards
      // below, so the recording happens while the cases are being tried in order.
      def matchedIdentifierInHint(identInQuery: Seq[String]): Boolean = {
        val matched = relationsInHint.find(matchedIdentifier(_, identInQuery))
        matched.foreach(relationsInHintWithMatch.add)
        matched.nonEmpty
      }

      val newNode = CurrentOrigin.withOrigin(plan.origin) {
        plan match {
          // A matching relation that already carries a hint: merge the new hint into the old one.
          case ResolvedHint(u @ UnresolvedRelation(ident), hint)
              if matchedIdentifierInHint(ident) =>
            ResolvedHint(u, createHintInfo(hintName).merge(hint, hintErrorHandler))

          case ResolvedHint(r: SubqueryAlias, hint)
              if matchedIdentifierInHint(extractIdentifier(r)) =>
            ResolvedHint(r, createHintInfo(hintName).merge(hint, hintErrorHandler))

          // A matching relation without a hint: put a fresh hint node on top of it.
          case UnresolvedRelation(ident) if matchedIdentifierInHint(ident) =>
            ResolvedHint(plan, createHintInfo(hintName))

          case r: SubqueryAlias if matchedIdentifierInHint(extractIdentifier(r)) =>
            ResolvedHint(plan, createHintInfo(hintName))

          case _: ResolvedHint | _: View | _: With | _: SubqueryAlias =>
            // Don't traverse down these nodes.
            // For an existing strategy hint, there is no chance for a match from this point down.
            // The rest (view, with, subquery) indicates different scopes that we shouldn't traverse
            // down. Note that technically when this rule is executed, we haven't completed view
            // resolution yet and as a result the view part should be deadcode. I'm leaving it here
            // to be more future proof in case we change the way we do view resolution.
            recurse = false
            plan

          case _ =>
            plan
        }
      }

      // Only descend when this node was left as it was and it is not a scope boundary.
      // NOTE(review): in the first two cases above, if the merged hint info equals the existing
      // one, `newNode` may compare equal to `plan` and we still descend into the child of the
      // existing hint -- confirm this is intended.
      if ((plan fastEquals newNode) && recurse) {
        newNode.mapChildren { child =>
          applyJoinStrategyHint(child, relationsInHint, relationsInHintWithMatch, hintName)
        }
      } else {
        newNode
      }
    }

    /**
     * Resolves every [[UnresolvedHint]] whose name is a join strategy hint alias. Without
     * parameters the hint is applied to the whole child subtree; otherwise it is applied to the
     * relations named by the parameters, and names that matched nothing are reported to the
     * hint error handler.
     */
    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
      case h: UnresolvedHint if STRATEGY_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) =>
        if (h.parameters.isEmpty) {
          // If there is no table alias specified, apply the hint on the entire subtree.
          ResolvedHint(h.child, createHintInfo(h.name))
        } else {
          // Otherwise, find within the subtree query plans to apply the hint.
          val relationNamesInHint = h.parameters.map {
            case tableName: String => UnresolvedAttribute.parseAttributeName(tableName)
            case tableId: UnresolvedAttribute => tableId.nameParts
            case unsupported => throw new AnalysisException("Join strategy hint parameter " +
              s"should be an identifier or string but was $unsupported (${unsupported.getClass})")
          }.toSet
          val relationsInHintWithMatch = new mutable.HashSet[Seq[String]]
          val applied = applyJoinStrategyHint(
            h.child, relationNamesInHint, relationsInHintWithMatch, h.name)

          // Filters unmatched relation identifiers in the hint
          val unmatchedIdents = relationNamesInHint -- relationsInHintWithMatch
          hintErrorHandler.hintRelationsNotFound(h.name, h.parameters, unmatchedIdents)
          applied
        }
    }
  }

  /**
   * COALESCE Hint accepts names "COALESCE", "REPARTITION", and "REPARTITION_BY_RANGE".
   */
  class ResolveCoalesceHints(conf: SQLConf) extends Rule[LogicalPlan] {

    /**
     * Extracts a partition number from a single hint parameter, which may be given either as an
     * integer literal expression or as a plain `Int`.
     */
    private object PartitionNumber {
      def unapply(param: Any): Option[Int] = param match {
        case IntegerLiteral(numPartitions) => Some(numPartitions)
        case numPartitions: Int => Some(numPartitions)
        case _ => None
      }
    }

    /**
     * Checks that every hint parameter is an [[UnresolvedAttribute]] and returns the parameters
     * typed as such.
     *
     * @param hintName upper-cased hint name, used in the error message
     * @param partitionExprs the hint parameters that are expected to be columns
     * @throws AnalysisException if at least one parameter is not an [[UnresolvedAttribute]]
     */
    private def requireColumns(
        hintName: String, partitionExprs: Seq[Any]): Seq[UnresolvedAttribute] = {
      val invalidParams = partitionExprs.filterNot(_.isInstanceOf[UnresolvedAttribute])
      if (invalidParams.nonEmpty) {
        throw new AnalysisException(s"$hintName Hint parameter should include columns, but " +
          s"${invalidParams.mkString(", ")} found")
      }
      partitionExprs.collect { case column: UnresolvedAttribute => column }
    }

    /**
     * This function handles hints for "COALESCE" and "REPARTITION".
     * The "COALESCE" hint only has a partition number as a parameter. The "REPARTITION" hint
     * has a partition number, columns, or both of them as parameters.
     *
     * @param shuffle true for "REPARTITION", false for "COALESCE"
     * @param hint the hint node to replace
     */
    private def createRepartition(
        shuffle: Boolean, hint: UnresolvedHint): LogicalPlan = {
      val hintName = hint.name.toUpperCase(Locale.ROOT)

      def createRepartitionByExpression(
          numPartitions: Int, partitionExprs: Seq[Any]): RepartitionByExpression = {
        // Sort orders are checked first so that they get the dedicated message pointing the user
        // to REPARTITION_BY_RANGE rather than the generic "should include columns" error.
        val sortOrders = partitionExprs.filter(_.isInstanceOf[SortOrder])
        if (sortOrders.nonEmpty) throw new IllegalArgumentException(
          s"""Invalid partitionExprs specified: $sortOrders
             |For range partitioning use REPARTITION_BY_RANGE instead.
           """.stripMargin)
        val columns: Seq[Expression] = requireColumns(hintName, partitionExprs)
        RepartitionByExpression(columns, hint.child, numPartitions)
      }

      hint.parameters match {
        case Seq(PartitionNumber(numPartitions)) =>
          Repartition(numPartitions, shuffle, hint.child)
        // The "COALESCE" hint (shuffle = false) must have a partition number only
        case _ if !shuffle =>
          throw new AnalysisException(s"$hintName Hint expects a partition number as a parameter")

        // Every case below is reached with shuffle = true only, i.e. for "REPARTITION".
        case param @ Seq(PartitionNumber(numPartitions), _*) =>
          createRepartitionByExpression(numPartitions, param.tail)
        case param @ Seq(_*) =>
          createRepartitionByExpression(conf.numShufflePartitions, param)
      }
    }

    /**
     * This function handles hints for "REPARTITION_BY_RANGE".
     * The "REPARTITION_BY_RANGE" hint must have column names and a partition number is optional.
     *
     * NOTE(review): an empty column list is not rejected here; confirm that the resulting
     * [[RepartitionByExpression]] with no expressions is acceptable.
     */
    private def createRepartitionByRange(hint: UnresolvedHint): RepartitionByExpression = {
      val hintName = hint.name.toUpperCase(Locale.ROOT)

      def createRepartitionByExpression(
          numPartitions: Int, partitionExprs: Seq[Any]): RepartitionByExpression = {
        // Only plain columns pass validation, so each one is wrapped in an ascending sort order.
        val sortOrder = requireColumns(hintName, partitionExprs)
          .map(column => SortOrder(column, Ascending))
        RepartitionByExpression(sortOrder, hint.child, numPartitions)
      }

      hint.parameters match {
        case param @ Seq(PartitionNumber(numPartitions), _*) =>
          createRepartitionByExpression(numPartitions, param.tail)
        case param @ Seq(_*) =>
          createRepartitionByExpression(conf.numShufflePartitions, param)
      }
    }

    /**
     * Replaces "REPARTITION", "COALESCE" and "REPARTITION_BY_RANGE" hints (matched ignoring case)
     * with the corresponding repartition node. Any other hint is left as it is.
     */
    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
      case hint @ UnresolvedHint(hintName, _, _) => hintName.toUpperCase(Locale.ROOT) match {
          case "REPARTITION" =>
            createRepartition(shuffle = true, hint)
          case "COALESCE" =>
            createRepartition(shuffle = false, hint)
          case "REPARTITION_BY_RANGE" =>
            createRepartitionByRange(hint)
          case _ => hint
        }
    }
  }

  object ResolveCoalesceHints {
    /** Upper-case names of the hints that the [[ResolveCoalesceHints]] rule handles. */
    val COALESCE_HINT_NAMES: Set[String] = Set(
      "COALESCE",
      "REPARTITION",
      "REPARTITION_BY_RANGE")
  }

  /**
   * Removes all the hints, used to remove invalid hints provided by the user.
   * This must be executed after all the other hint rules are executed.
   */
  class RemoveAllHints(conf: SQLConf) extends Rule[LogicalPlan] {

    // Notified about every hint that this rule drops.
    private val hintErrorHandler = conf.hintErrorHandler

    /**
     * Replaces each remaining [[UnresolvedHint]] with its child, after reporting the hint to the
     * error handler as not recognized.
     */
    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
      case UnresolvedHint(name, parameters, child) =>
        hintErrorHandler.hintNotRecognized(name, parameters)
        child
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy