All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.paimon.spark.catalyst.optimizer.MergePaimonScalarSubqueriersBase.scala Maven / Gradle / Ivy

There is a newer version: 0.9.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.paimon.spark.catalyst.optimizer

import org.apache.paimon.spark.PaimonScan
import org.apache.paimon.spark.util.CTERelationRefUtils

import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, CreateNamedStruct, Expression, ExprId, GetStructField, LeafExpression, Literal, NamedExpression, PredicateHelper, ScalarSubquery, Unevaluable}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, CTERelationDef, Filter, Join, LogicalPlan, Project, Subquery, WithCTE}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{SCALAR_SUBQUERY, SCALAR_SUBQUERY_REFERENCE, TreePattern}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructType}

import scala.collection.mutable.ArrayBuffer

/**
 * Most of the code is copied from
 * [[org.apache.spark.sql.catalyst.optimizer.MergeScalarSubqueries]].
 *
 * Spark cannot merge scalar subqueries over DataSource V2 relations on its own. It lacks a unified
 * interface to decide whether two [[DataSourceV2ScanRelation]]s can be merged and reused. This
 * rule therefore extends [[tryMergePlans]] to check and merge [[DataSourceV2ScanRelation]]s, so
 * that scalar subqueries over paimon tables can be merged.
 */
trait MergePaimonScalarSubqueriersBase extends Rule[LogicalPlan] with PredicateHelper {

  /**
   * Entry point of the rule. Merges non-correlated, deterministic scalar subqueries of `plan` that
   * can share a common plan.
   *
   * `plan` is returned unchanged when any of these holds:
   *   - subquery reuse is disabled;
   *   - `plan` is itself a [[Subquery]];
   *   - `plan` is a [[WithCTE]].
   */
  def apply(plan: LogicalPlan): LogicalPlan = {
    plan match {
      // Subquery reuse needs to be enabled for this optimization.
      case _ if !conf.getConf(SQLConf.SUBQUERY_REUSE_ENABLED) => plan

      // This rule does a whole plan traversal, no need to run on subqueries.
      case _: Subquery => plan

      // Plans with CTEs are not supported for now.
      case _: WithCTE => plan

      case _ => extractCommonScalarSubqueries(plan)
    }
  }

  /**
   * An item in the cache of merged scalar subqueries.
   *
   * @param attributes
   *   Attributes that form the struct scalar return value of a merged subquery.
   * @param plan
   *   The plan of a merged scalar subquery.
   * @param merged
   *   Flags whether this item is the result of merging subqueries. `attributes.size == 1` does not
   *   always mean the plan is unmerged. Two subqueries can differ ([[checkIdenticalPlans]] is
   *   false) only because one of them has an extra [[Project]] node. After merging them,
   *   `attributes.size` is still 1 but the merged flag is true.
   */
  case class Header(attributes: Seq[Attribute], plan: LogicalPlan, merged: Boolean)

  /**
   * Runs both traversals of the rule:
   *   1. Replace scalar subqueries with [[ScalarSubqueryReference]]s while building the cache.
   *   2. Turn every merged cache entry into a [[CTERelationDef]] and resolve the references.
   *
   * The result is wrapped in a [[WithCTE]] when at least one entry was merged.
   */
  private def extractCommonScalarSubqueries(plan: LogicalPlan) = {
    val cache = ArrayBuffer.empty[Header]
    val planWithReferences = insertReferences(plan, cache)
    // Entries are rewritten in cache order. `removeReferences` casts the plan of every merged entry
    // it resolves to `CTERelationDef`. NOTE(review): this relies on an entry only referencing
    // merged entries that were already rewritten — confirm for nested subqueries.
    cache.zipWithIndex.foreach {
      case (header, i) =>
        cache(i) = cache(i).copy(plan = if (header.merged) {
          CTERelationDef(
            createProject(header.attributes, removeReferences(header.plan, cache)),
            underSubquery = true)
        } else {
          removeReferences(header.plan, cache)
        })
    }
    val newPlan = removeReferences(planWithReferences, cache)
    val subqueryCTEs = cache.filter(_.merged).map(_.plan.asInstanceOf[CTERelationDef])
    if (subqueryCTEs.nonEmpty) {
      WithCTE(newPlan, subqueryCTEs)
    } else {
      newPlan
    }
  }

  // First traversal builds up the cache and inserts `ScalarSubqueryReference`s to the plan.
  // Only non-correlated, deterministic scalar subqueries are replaced.
  private def insertReferences(plan: LogicalPlan, cache: ArrayBuffer[Header]): LogicalPlan = {
    plan.transformUpWithSubqueries {
      case n =>
        n.transformExpressionsUpWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY)) {
          case s: ScalarSubquery if !s.isCorrelated && s.deterministic =>
            val (subqueryIndex, headerIndex) = cacheSubquery(s.plan, cache)
            ScalarSubqueryReference(subqueryIndex, headerIndex, s.dataType, s.exprId)
        }
    }
  }

  // Caching returns the index of the subquery in the cache and the index of the scalar member in
  // its "Header". A new plan is handled in one of three ways:
  //   - it is identical to a cached plan and reuses that entry;
  //   - it is merged into a cached plan, which updates that entry in place;
  //   - otherwise it is appended as a new, unmerged entry.
  private def cacheSubquery(plan: LogicalPlan, cache: ArrayBuffer[Header]): (Int, Int) = {
    // Only the first output attribute is tracked (scalar subqueries produce a single column).
    val output = plan.output.head
    cache.zipWithIndex
      .collectFirst(Function.unlift {
        case (header, subqueryIndex) =>
          checkIdenticalPlans(plan, header.plan)
            .map {
              outputMap =>
                val mappedOutput = mapAttributes(output, outputMap)
                val headerIndex = header.attributes.indexWhere(_.exprId == mappedOutput.exprId)
                subqueryIndex -> headerIndex
            }
            .orElse(tryMergePlans(plan, header.plan).map {
              case (mergedPlan, outputMap) =>
                val mappedOutput = mapAttributes(output, outputMap)
                val existingIndex = header.attributes.indexWhere(_.exprId == mappedOutput.exprId)
                // Append the output to the header unless the merged plan already exposes it.
                val (newHeaderAttributes, headerIndex) = if (existingIndex == -1) {
                  (header.attributes :+ mappedOutput, header.attributes.size)
                } else {
                  (header.attributes, existingIndex)
                }
                cache(subqueryIndex) = Header(newHeaderAttributes, mergedPlan, merged = true)
                subqueryIndex -> headerIndex
            })
      })
      .getOrElse {
        cache += Header(Seq(output), plan, merged = false)
        cache.length - 1 -> 0
      }
  }

  // If 2 plans are identical return the attribute mapping from the new to the cached version.
  protected def checkIdenticalPlans(
      newPlan: LogicalPlan,
      cachedPlan: LogicalPlan): Option[AttributeMap[Attribute]] = {
    if (newPlan.canonicalized == cachedPlan.canonicalized) {
      Some(AttributeMap(newPlan.output.zip(cachedPlan.output)))
    } else {
      None
    }
  }

  // Recursively traverse down and try merging 2 plans. If merge is possible then return the merged
  // plan with the attribute mapping from the new to the merged version.
  // Please note that merging arbitrary plans can be complicated, the current version supports only
  // some of the most important nodes.
  private def tryMergePlans(
      newPlan: LogicalPlan,
      cachedPlan: LogicalPlan): Option[(LogicalPlan, AttributeMap[Attribute])] = {
    checkIdenticalPlans(newPlan, cachedPlan)
      .map(cachedPlan -> _)
      .orElse((newPlan, cachedPlan) match {
        case (np: Project, cp: Project) =>
          tryMergePlans(np.child, cp.child).map {
            case (mergedChild, outputMap) =>
              val (mergedProjectList, newOutputMap) =
                mergeNamedExpressions(np.projectList, outputMap, cp.projectList)
              val mergedPlan = Project(mergedProjectList, mergedChild)
              mergedPlan -> newOutputMap
          }
        case (np, cp: Project) =>
          tryMergePlans(np, cp.child).map {
            case (mergedChild, outputMap) =>
              val (mergedProjectList, newOutputMap) =
                mergeNamedExpressions(np.output, outputMap, cp.projectList)
              val mergedPlan = Project(mergedProjectList, mergedChild)
              mergedPlan -> newOutputMap
          }
        case (np: Project, cp) =>
          tryMergePlans(np.child, cp).map {
            case (mergedChild, outputMap) =>
              val (mergedProjectList, newOutputMap) =
                mergeNamedExpressions(np.projectList, outputMap, cp.output)
              val mergedPlan = Project(mergedProjectList, mergedChild)
              mergedPlan -> newOutputMap
          }
        case (np: Aggregate, cp: Aggregate) if supportedAggregateMerge(np, cp) =>
          tryMergePlans(np.child, cp.child).flatMap {
            case (mergedChild, outputMap) =>
              val mappedNewGroupingExpression =
                np.groupingExpressions.map(mapAttributes(_, outputMap))
              // Order of grouping expression does matter as merging different grouping orders can
              // introduce "extra" shuffles/sorts that might not present in all of the original
              // subqueries.
              if (
                mappedNewGroupingExpression.map(_.canonicalized) ==
                  cp.groupingExpressions.map(_.canonicalized)
              ) {
                val (mergedAggregateExpressions, newOutputMap) =
                  mergeNamedExpressions(np.aggregateExpressions, outputMap, cp.aggregateExpressions)
                val mergedPlan =
                  Aggregate(cp.groupingExpressions, mergedAggregateExpressions, mergedChild)
                Some(mergedPlan -> newOutputMap)
              } else {
                None
              }
          }

        case (np: Filter, cp: Filter) =>
          tryMergePlans(np.child, cp.child).flatMap {
            case (mergedChild, outputMap) =>
              val mappedNewCondition = mapAttributes(np.condition, outputMap)
              // Comparing the canonicalized form is required to ignore different forms of the same
              // expression.
              if (mappedNewCondition.canonicalized == cp.condition.canonicalized) {
                val mergedPlan = cp.withNewChildren(Seq(mergedChild))
                Some(mergedPlan -> outputMap)
              } else {
                None
              }
          }

        case (np: Join, cp: Join) if np.joinType == cp.joinType && np.hint == cp.hint =>
          tryMergePlans(np.left, cp.left).flatMap {
            case (mergedLeft, leftOutputMap) =>
              tryMergePlans(np.right, cp.right).flatMap {
                case (mergedRight, rightOutputMap) =>
                  val outputMap = leftOutputMap ++ rightOutputMap
                  val mappedNewCondition = np.condition.map(mapAttributes(_, outputMap))
                  // Comparing the canonicalized form is required to ignore different forms of
                  // the same expression and `AttributeReference.qualifier`s in `cp.condition`.
                  if (
                    mappedNewCondition.map(_.canonicalized) == cp.condition.map(_.canonicalized)
                  ) {
                    val mergedPlan = cp.withNewChildren(Seq(mergedLeft, mergedRight))
                    Some(mergedPlan -> outputMap)
                  } else {
                    None
                  }
              }
          }
        case (
              newV2ScanRelation: DataSourceV2ScanRelation,
              cachedV2ScanRelation: DataSourceV2ScanRelation) =>
          tryMergeDataSourceV2ScanRelation(newV2ScanRelation, cachedV2ScanRelation)

        // Otherwise merging is not possible.
        case _ => None
      })
  }

  /**
   * Hook used by [[tryMergePlans]] for two [[DataSourceV2ScanRelation]]s that are not identical.
   * Implementations return the merged plan and the attribute mapping from the new to the merged
   * version, or `None` when the relations cannot be merged.
   */
  def tryMergeDataSourceV2ScanRelation(
      newV2ScanRelation: DataSourceV2ScanRelation,
      cachedV2ScanRelation: DataSourceV2ScanRelation)
      : Option[(LogicalPlan, AttributeMap[Attribute])]

  /**
   * Returns true when `newPartitioning` is canonically equal to `cachedPartitioning`. Before the
   * comparison, the new partitioning is remapped through `outputAttrMap`.
   */
  protected def samePartitioning(
      newPartitioning: Option[Seq[Expression]],
      cachedPartitioning: Option[Seq[Expression]],
      outputAttrMap: AttributeMap[Attribute]): Boolean = {
    val mappedNewPartitioning = newPartitioning.map(_.map(mapAttributes(_, outputAttrMap)))
    mappedNewPartitioning.map(_.map(_.canonicalized)) == cachedPartitioning.map(
      _.map(_.canonicalized))
  }

  /**
   * Merges two [[PaimonScan]]s that read the same table with the same pushed filters and the same
   * pushed limit.
   *
   * Results:
   *   - `Some(scan2)` when both scans require the same schema;
   *   - otherwise, a copy of `scan2` whose required schema is the union of both required schemas;
   *   - `None` when the scans cannot be merged.
   *
   * The union keeps `scan2`'s fields in their original order and appends the fields that only
   * `scan1` requires. This makes the merged schema order stable. A `Set`-based union would lose
   * the field order for schemas with more than four fields.
   * NOTE(review): `scan2` is presumably the cached scan — confirm at the call sites.
   */
  protected def mergePaimonScan(scan1: PaimonScan, scan2: PaimonScan): Option[PaimonScan] = {
    if (
      scan1.table == scan2.table &&
      scan1.filters == scan2.filters &&
      scan1.pushDownLimit == scan2.pushDownLimit
    ) {

      if (scan1.requiredSchema == scan2.requiredSchema) {
        Some(scan2)
      } else {
        // `distinct` keeps the first occurrence of each field, so `scan2`'s fields stay in place
        // and the fields only `scan1` requires are appended in their own order.
        val mergedRequiredSchema =
          StructType((scan2.requiredSchema.fields ++ scan1.requiredSchema.fields).distinct)
        Some(scan2.copy(requiredSchema = mergedRequiredSchema))
      }
    } else {
      None
    }
  }

  // Wraps `plan` in a Project that packs `attributes` into one struct column named "mergedValue".
  // Each attribute's name is used as its struct field name.
  private def createProject(attributes: Seq[Attribute], plan: LogicalPlan): Project = {
    Project(
      Seq(
        Alias(
          CreateNamedStruct(attributes.flatMap(a => Seq(Literal(a.name), a))),
          "mergedValue")()),
      plan)
  }

  // Replaces every attribute in `expr` with its mapping in `outputMap`.
  // Unmapped attributes are kept as they are.
  protected def mapAttributes[T <: Expression](expr: T, outputMap: AttributeMap[Attribute]): T = {
    expr.transform { case a: Attribute => outputMap.getOrElse(a, a) }.asInstanceOf[T]
  }

  // Applies `outputMap` attribute mapping on attributes of `newExpressions` and merges them into
  // `cachedExpressions`. Returns the merged expressions and the attribute mapping from the new to
  // the merged version that can be propagated up during merging nodes.
  // An expression that is semantically equal to an existing one is reused; a top-level Alias is
  // ignored in that comparison. Otherwise the mapped expression is appended.
  private def mergeNamedExpressions(
      newExpressions: Seq[NamedExpression],
      outputMap: AttributeMap[Attribute],
      cachedExpressions: Seq[NamedExpression]) = {
    val mergedExpressions = ArrayBuffer[NamedExpression](cachedExpressions: _*)
    val newOutputMap = AttributeMap(newExpressions.map {
      ne =>
        val mapped = mapAttributes(ne, outputMap)
        val withoutAlias = mapped match {
          case Alias(child, _) => child
          case e => e
        }
        ne.toAttribute -> mergedExpressions
          .find {
            case Alias(child, _) => child.semanticEquals(withoutAlias)
            case e => e.semanticEquals(withoutAlias)
          }
          .getOrElse {
            mergedExpressions += mapped
            mapped
          }
          .toAttribute
    })
    (mergedExpressions.toSeq, newOutputMap)
  }

  // Only allow aggregates of the same implementation because merging different implementations
  // could cause performance regression.
  private def supportedAggregateMerge(newPlan: Aggregate, cachedPlan: Aggregate) = {
    val newPlanAggregateExpressions = newPlan.aggregateExpressions.flatMap(_.collect {
      case a: AggregateExpression => a
    })
    val cachedPlanAggregateExpressions = cachedPlan.aggregateExpressions.flatMap(_.collect {
      case a: AggregateExpression => a
    })
    val newPlanSupportsHashAggregate = Aggregate.supportsHashAggregate(
      newPlanAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
    val cachedPlanSupportsHashAggregate = Aggregate.supportsHashAggregate(
      cachedPlanAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
    // Parentheses only make the default precedence (`==` > `&&` > `||`) explicit.
    (newPlanSupportsHashAggregate && cachedPlanSupportsHashAggregate) ||
    (newPlanSupportsHashAggregate == cachedPlanSupportsHashAggregate && {
      val newPlanSupportsObjectHashAggregate =
        Aggregate.supportsObjectHashAggregate(newPlanAggregateExpressions)
      val cachedPlanSupportsObjectHashAggregate =
        Aggregate.supportsObjectHashAggregate(cachedPlanAggregateExpressions)
      (newPlanSupportsObjectHashAggregate && cachedPlanSupportsObjectHashAggregate) ||
      (newPlanSupportsObjectHashAggregate == cachedPlanSupportsObjectHashAggregate)
    })
  }

  // Second traversal replaces each `ScalarSubqueryReference` with one of two forms:
  //   - `GetStructField(ScalarSubquery(CTERelationRef to the merged plan))` when the plan was
  //     merged from multiple subqueries;
  //   - `ScalarSubquery(original plan)` when it was not.
  private def removeReferences(plan: LogicalPlan, cache: ArrayBuffer[Header]) = {
    plan.transformUpWithSubqueries {
      case n =>
        n.transformExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY_REFERENCE)) {
          case ssr: ScalarSubqueryReference =>
            val header = cache(ssr.subqueryIndex)
            if (header.merged) {
              val subqueryCTE = header.plan.asInstanceOf[CTERelationDef]
              GetStructField(
                createScalarSubquery(
                  CTERelationRefUtils.createCTERelationRef(
                    subqueryCTE.id,
                    _resolved = true,
                    subqueryCTE.output),
                  ssr.exprId),
                ssr.headerIndex)
            } else {
              createScalarSubquery(header.plan, ssr.exprId)
            }
        }
    }
  }

  /**
   * Builds the [[ScalarSubquery]] that replaces a [[ScalarSubqueryReference]], reusing the
   * original subquery's `exprId`. NOTE(review): presumably abstract so that subclasses can follow
   * Spark-version-specific constructors — confirm.
   */
  protected def createScalarSubquery(plan: LogicalPlan, exprId: ExprId): ScalarSubquery

}

/**
 * Temporary placeholder for a scalar subquery. The first traversal of the merge rule inserts it,
 * and the second traversal replaces it again.
 *
 * @param subqueryIndex
 *   Position of the (possibly merged) subquery in the rule's cache.
 * @param headerIndex
 *   Position of this value within the cached entry's attributes. It is used as the struct field
 *   index when the subquery was merged.
 * @param dataType
 *   Data type of the original scalar subquery.
 * @param exprId
 *   Expression id of the original scalar subquery, reused when the subquery is re-created.
 */
case class ScalarSubqueryReference(
    subqueryIndex: Int,
    headerIndex: Int,
    dataType: DataType,
    exprId: ExprId)
  extends LeafExpression
  with Unevaluable {

  final override val nodePatterns: Seq[TreePattern] = Seq(SCALAR_SUBQUERY_REFERENCE)

  // The placeholder is always reported as nullable.
  override def nullable: Boolean = true

  // Only the numeric part of `exprId` is shown when the expression is printed.
  override def stringArgs: Iterator[Any] =
    Seq[Any](subqueryIndex, headerIndex, dataType, exprId.id).iterator
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy