/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.nvidia.spark.rapids
import java.time.ZoneId
import scala.collection.mutable
import com.nvidia.spark.rapids.jni.GpuTimeZoneDB
import com.nvidia.spark.rapids.shims.{DistributionUtil, SparkShimImpl}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryExpression, Cast, ComplexTypeMergingExpression, Expression, QuaternaryExpression, RuntimeReplaceable, String2TrimExpression, TernaryExpression, TimeZoneAwareExpression, UnaryExpression, UTCTimestamp, WindowExpression, WindowFunction}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, ImperativeAggregate, TypedImperativeAggregate}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.trees.{TreeNodeTag, UnaryLike}
import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.execution.command.{DataWritingCommand, RunnableCommand}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec}
import org.apache.spark.sql.execution.python.AggregateInPandasExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.TimeZoneDB
import org.apache.spark.sql.rapids.aggregate.{CpuToGpuAggregateBufferConverter, GpuToCpuAggregateBufferConverter}
import org.apache.spark.sql.rapids.execution.{GpuBroadcastHashJoinMetaBase, GpuBroadcastNestedLoopJoinMetaBase}
import org.apache.spark.sql.types.{ArrayType, DataType, DateType, MapType, StringType, StructType}
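/**
* Information provided by the replacement rule that wrapped an operation. This is a sketch of a
* doc comment for an otherwise undocumented trait: the fields are used to look up configs and to
* build user-facing messages about the wrapped operation.
*/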
trait DataFromReplacementRule {
val operationName: String
def incompatDoc: Option[String] = None
def disabledMsg: Option[String] = None
def confKey: String
def getChecks: Option[TypeChecks[_]]
}
/**
* A version of DataFromReplacementRule that is used when no replacement rule can be found.
*/
final class NoRuleDataFromReplacementRule extends DataFromReplacementRule {
override val operationName: String = ""
override def confKey = "NOT_FOUND"
override def getChecks: Option[TypeChecks[_]] = None
}
object RapidsMeta {
val gpuSupportedTag = TreeNodeTag[Set[String]]("rapids.gpu.supported")
}
/**
* Holds metadata about a stage in the physical plan that is separate from the plan itself.
* This is helpful in deciding when to replace part of the plan with a GPU enabled version.
*
* @param wrapped what we are wrapping
* @param conf the config
* @param parent the parent of this node, if there is one.
* @param rule holds information related to the config for this object, typically this is the rule
* used to wrap the stage.
* @tparam INPUT the exact type of the class we are wrapping.
* @tparam BASE the generic base class for this type of stage, i.e. SparkPlan, Expression, etc.
* @tparam OUTPUT when converting to a GPU enabled version of the plan, the generic base
* type for all GPU enabled versions.
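*
* A minimal sketch of the intended lifecycle (the `plan` and `rapidsConf` values are
* hypothetical):
* {{{
*   val meta: SparkPlanMeta[SparkPlan] = GpuOverrides.wrapPlan(plan, rapidsConf, None)
*   meta.tagForGpu()                   // tags children first, then this node
*   val result: SparkPlan = meta.convertIfNeeded()  // GPU version only if replaceable
*   println(meta.explain(all = true))  // human-readable report of what can run on the GPU
* }}}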
*/
abstract class RapidsMeta[INPUT <: BASE, BASE, OUTPUT <: BASE](
val wrapped: INPUT,
val conf: RapidsConf,
val parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule) {
/**
* The wrapped plans that should be examined
*/
val childPlans: Seq[SparkPlanMeta[_]]
/**
* The wrapped expressions that should be examined
*/
val childExprs: Seq[BaseExprMeta[_]]
/**
* The wrapped scans that should be examined
*/
val childScans: Seq[ScanMeta[_]]
/**
* The wrapped partitioning that should be examined
*/
val childParts: Seq[PartMeta[_]]
/**
* The wrapped data writing commands that should be examined
*/
val childDataWriteCmds: Seq[DataWritingCommandMeta[_]]
/** The wrapped runnable commands that should be examined. */
val childRunnableCmds: Seq[RunnableCommandMeta[_]] = Seq.empty
/**
* Convert what this wraps to a GPU enabled version.
*/
def convertToGpu(): OUTPUT
/**
* Keep this on the CPU, but possibly convert the children under it to run on the GPU if enabled.
* By default this just returns what it wraps. For some types of operators/stages,
* like SparkPlan, each part of the query can be converted independently of the other parts, so
* in such a subclass this should be overridden to do the correct thing.
*/
def convertToCpu(): BASE = wrapped
protected var cannotBeReplacedReasons: Option[mutable.Set[String]] = None
private var mustBeReplacedReasons: Option[mutable.Set[String]] = None
private var cannotReplaceAnyOfPlanReasons: Option[mutable.Set[String]] = None
private var shouldBeRemovedReasons: Option[mutable.Set[String]] = None
private var typeConversionReasons: Option[mutable.Set[String]] = None
protected var cannotRunOnGpuBecauseOfSparkPlan: Boolean = false
protected var cannotRunOnGpuBecauseOfCost: Boolean = false
import RapidsMeta.gpuSupportedTag
/**
* Recursively force a section of the plan back onto CPU, stopping once a plan
* is reached that is already on CPU.
*/
final def recursiveCostPreventsRunningOnGpu(): Unit = {
if (canThisBeReplaced && !mustThisBeReplaced) {
costPreventsRunningOnGpu()
childDataWriteCmds.foreach(_.recursiveCostPreventsRunningOnGpu())
childRunnableCmds.foreach(_.recursiveCostPreventsRunningOnGpu())
}
}
final def costPreventsRunningOnGpu(): Unit = {
cannotRunOnGpuBecauseOfCost = true
willNotWorkOnGpu("Removed by cost-based optimizer")
childExprs.foreach(_.recursiveCostPreventsRunningOnGpu())
childParts.foreach(_.recursiveCostPreventsRunningOnGpu())
childScans.foreach(_.recursiveCostPreventsRunningOnGpu())
}
final def recursiveSparkPlanPreventsRunningOnGpu(): Unit = {
cannotRunOnGpuBecauseOfSparkPlan = true
childExprs.foreach(_.recursiveSparkPlanPreventsRunningOnGpu())
childParts.foreach(_.recursiveSparkPlanPreventsRunningOnGpu())
childScans.foreach(_.recursiveSparkPlanPreventsRunningOnGpu())
childDataWriteCmds.foreach(_.recursiveSparkPlanPreventsRunningOnGpu())
childRunnableCmds.foreach(_.recursiveSparkPlanPreventsRunningOnGpu())
}
final def recursiveSparkPlanRemoved(): Unit = {
shouldBeRemoved("parent plan is removed")
childExprs.foreach(_.recursiveSparkPlanRemoved())
childParts.foreach(_.recursiveSparkPlanRemoved())
childScans.foreach(_.recursiveSparkPlanRemoved())
childDataWriteCmds.foreach(_.recursiveSparkPlanRemoved())
childRunnableCmds.foreach(_.recursiveSparkPlanRemoved())
}
/**
* Call this to indicate that this should not be replaced with a GPU enabled version
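*
* A sketch of a typical call from a tag method (the condition is hypothetical):
* {{{
*   if (containsUnsupportedType) {
*     willNotWorkOnGpu("the input schema contains a type the GPU does not support")
*   }
* }}}
*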
* @param because why it should not be replaced.
*/
final def willNotWorkOnGpu(because: String): Unit = {
cannotBeReplacedReasons.get.add(because)
// annotate the real Spark plan with the reason as well so that the information is
// available during query stage planning when AQE is on
wrapped match {
case p: SparkPlan =>
p.setTagValue(gpuSupportedTag,
p.getTagValue(gpuSupportedTag).getOrElse(Set.empty) + because)
case _ =>
}
}
final def mustBeReplaced(because: String): Unit = {
mustBeReplacedReasons.get.add(because)
}
/**
* Call this if there is a condition found that the entire plan is not allowed
* to run on the GPU.
*/
final def entirePlanWillNotWork(because: String): Unit = {
cannotReplaceAnyOfPlanReasons.get.add(because)
// recursively tag the plan so that AQE does not attempt
// to run any of the child query stages on the GPU
willNotWorkOnGpu(because)
childPlans.foreach(_.entirePlanWillNotWork(because))
}
final def shouldBeRemoved(because: String): Unit =
shouldBeRemovedReasons.get.add(because)
/**
* Call this method to record information about type conversions via DataTypeMeta.
*/
final def addConvertedDataType(expression: Expression, typeMeta: DataTypeMeta): Unit = {
typeConversionReasons.get.add(
s"$expression: ${typeMeta.reasonForConversion}")
}
/**
* Returns true if this node should be removed.
*/
final def shouldThisBeRemoved: Boolean = shouldBeRemovedReasons.exists(_.nonEmpty)
/**
* Returns true iff this could be replaced.
*/
final def canThisBeReplaced: Boolean = cannotBeReplacedReasons.exists(_.isEmpty)
/**
* Returns true iff this must be replaced because its children have already been
* replaced and this needs to also be replaced for compatibility.
*/
final def mustThisBeReplaced: Boolean = mustBeReplacedReasons.exists(_.nonEmpty)
/**
* Returns the list of reasons the entire plan can't be replaced. An empty
* set means the entire plan is OK to be replaced; do the normal per-exec
* and per-child checking.
*/
final def entirePlanExcludedReasons: Set[String] = {
cannotReplaceAnyOfPlanReasons.getOrElse(mutable.Set.empty).toSet
}
/**
* Returns true iff all of the expressions and their children could be replaced.
*/
def canExprTreeBeReplaced: Boolean = childExprs.forall(_.canExprTreeBeReplaced)
/**
* Returns true iff all of the scans can be replaced.
*/
def canScansBeReplaced: Boolean = childScans.forall(_.canThisBeReplaced)
/**
* Returns true iff all of the partitioning can be replaced.
*/
def canPartsBeReplaced: Boolean = childParts.forall(_.canThisBeReplaced)
/**
* Return true if the resulting node in the plan will support columnar execution
*/
def supportsColumnar: Boolean = canThisBeReplaced
/**
* Returns true iff all of the data writing commands can be replaced.
*/
def canDataWriteCmdsBeReplaced: Boolean = childDataWriteCmds.forall(_.canThisBeReplaced)
def confKey: String = rule.confKey
final val operationName: String = rule.operationName
final val incompatDoc: Option[String] = rule.incompatDoc
def isIncompat: Boolean = incompatDoc.isDefined
final val disabledMsg: Option[String] = rule.disabledMsg
def isDisabledByDefault: Boolean = disabledMsg.isDefined
def initReasons(): Unit = {
cannotBeReplacedReasons = Some(mutable.Set[String]())
mustBeReplacedReasons = Some(mutable.Set[String]())
shouldBeRemovedReasons = Some(mutable.Set[String]())
cannotReplaceAnyOfPlanReasons = Some(mutable.Set[String]())
typeConversionReasons = Some(mutable.Set[String]())
}
/**
* Tag all of the children to see if they are GPU compatible first.
* Do basic common verification for the operators, and then call
* [[tagSelfForGpu]]
*/
final def tagForGpu(): Unit = {
childScans.foreach(_.tagForGpu())
childParts.foreach(_.tagForGpu())
childExprs.foreach(_.tagForGpu())
childDataWriteCmds.foreach(_.tagForGpu())
childRunnableCmds.foreach(_.tagForGpu())
childPlans.foreach(_.tagForGpu())
initReasons()
if (!conf.isOperatorEnabled(confKey, isIncompat, isDisabledByDefault)) {
if (isIncompat && !conf.isIncompatEnabled) {
willNotWorkOnGpu(s"the GPU version of ${wrapped.getClass.getSimpleName}" +
s" is not 100% compatible with the Spark version. ${incompatDoc.get}. To enable this" +
s" $operationName despite the incompatibilities please set the config" +
s" $confKey to true. You could also set ${RapidsConf.INCOMPATIBLE_OPS} to true" +
s" to enable all incompatible ops")
} else if (isDisabledByDefault) {
willNotWorkOnGpu(s"the $operationName ${wrapped.getClass.getSimpleName} has" +
s" been disabled, and is disabled by default because ${disabledMsg.get}. Set $confKey" +
s" to true if you wish to enable it")
} else {
willNotWorkOnGpu(s"the $operationName ${wrapped.getClass.getSimpleName} has" +
s" been disabled. Set $confKey to true if you wish to enable it")
}
}
tagSelfForGpu()
}
/**
* Do any extra checks and tag yourself if you are compatible or not. Be aware that this may
* already have been marked as incompatible for a number of reasons.
*
* All of your children should have already been tagged, so if there are situations where you
* need to disqualify your children for various reasons, you may do that here too.
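*
* A sketch of a typical override (the specific check is hypothetical):
* {{{
*   override def tagSelfForGpu(): Unit = {
*     if (wrapped.someUnsupportedOption) {
*       willNotWorkOnGpu("someUnsupportedOption is not supported on the GPU")
*     }
*   }
* }}}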
*/
def tagSelfForGpu(): Unit
protected def indent(append: StringBuilder, depth: Int): Unit =
append.append(" " * depth)
def replaceMessage: String = "run on GPU"
def noReplacementPossibleMessage(reasons: String): String = s"cannot run on GPU because $reasons"
def suppressWillWorkOnGpuInfo: Boolean = false
private def willWorkOnGpuInfo: String = cannotBeReplacedReasons match {
case None => "NOT EVALUATED FOR GPU YET"
case Some(v) if v.isEmpty &&
(cannotRunOnGpuBecauseOfSparkPlan || shouldThisBeRemoved) => "could " + replaceMessage
case Some(v) if v.isEmpty => "will " + replaceMessage
case Some(v) =>
noReplacementPossibleMessage(v.mkString("; "))
}
private def willBeRemovedInfo: String = shouldBeRemovedReasons match {
case None => ""
case Some(v) if v.isEmpty => ""
case Some(v) =>
val reasons = v.mkString("; ")
s" but is going to be removed because $reasons"
}
private def typeConversionInfo: String = typeConversionReasons match {
case None => ""
case Some(v) if v.isEmpty => ""
case Some(v) =>
"The data type of following expressions will be converted in GPU runtime: " +
v.mkString("; ")
}
/**
* When converting this to a string should we include the string representation of what this
* wraps too? This is off by default.
*/
protected val printWrapped = false
final private def getIndicatorChar: String = {
if (shouldThisBeRemoved) {
"#"
} else if (cannotRunOnGpuBecauseOfCost) {
"$"
} else if (canThisBeReplaced) {
if (cannotRunOnGpuBecauseOfSparkPlan) {
"@"
} else if (cannotRunOnGpuBecauseOfCost) {
"$"
} else {
"*"
}
} else {
"!"
}
}
def checkTimeZoneId(sessionZoneId: ZoneId): Unit = {
// Both the Spark session time zone and the JVM default time zone must be supported.
if (!TimeZoneDB.isSupportedTimezone(sessionZoneId)) {
willNotWorkOnGpu("Not supported zone id. " +
s"Actual session local zone id: $sessionZoneId")
}
val defaultZoneId = ZoneId.systemDefault()
if (!TimeZoneDB.isSupportedTimezone(defaultZoneId)) {
willNotWorkOnGpu(s"Not supported zone id. Actual default zone id: $defaultZoneId")
}
}
/**
* Create a string representation of this in the given builder.
* @param strBuilder where to place the string representation.
* @param depth how far down the tree this is.
* @param all should all the data be printed or just what does not work on the GPU?
*/
def print(strBuilder: StringBuilder, depth: Int, all: Boolean): Unit = {
if ((all || !canThisBeReplaced || cannotRunOnGpuBecauseOfSparkPlan) &&
!suppressWillWorkOnGpuInfo) {
indent(strBuilder, depth)
strBuilder.append(getIndicatorChar)
strBuilder.append(operationName)
.append(" <")
.append(wrapped.getClass.getSimpleName)
.append("> ")
if (printWrapped) {
strBuilder.append(wrapped)
.append(" ")
}
strBuilder.append(willWorkOnGpuInfo).
append(willBeRemovedInfo)
typeConversionInfo match {
case info if info.isEmpty =>
case info => strBuilder.append(". ").append(info)
}
strBuilder.append("\n")
}
printChildren(strBuilder, depth, all)
}
private final def printChildren(append: StringBuilder, depth: Int, all: Boolean): Unit = {
childScans.foreach(_.print(append, depth + 1, all))
childParts.foreach(_.print(append, depth + 1, all))
childExprs.foreach(_.print(append, depth + 1, all))
childDataWriteCmds.foreach(_.print(append, depth + 1, all))
childRunnableCmds.foreach(_.print(append, depth + 1, all))
childPlans.foreach(_.print(append, depth + 1, all))
}
def explain(all: Boolean): String = {
val appender = new StringBuilder()
print(appender, 0, all)
appender.toString()
}
override def toString: String = {
explain(true)
}
}
/**
* Base class for metadata around `Partitioning`.
*/
abstract class PartMeta[INPUT <: Partitioning](part: INPUT,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends RapidsMeta[INPUT, Partitioning, GpuPartitioning](part, conf, parent, rule) {
override val childPlans: Seq[SparkPlanMeta[_]] = Seq.empty
override val childExprs: Seq[BaseExprMeta[_]] = Seq.empty
override val childScans: Seq[ScanMeta[_]] = Seq.empty
override val childParts: Seq[PartMeta[_]] = Seq.empty
override val childDataWriteCmds: Seq[DataWritingCommandMeta[_]] = Seq.empty
override final def tagSelfForGpu(): Unit = {
rule.getChecks.foreach(_.tag(this))
if (!canExprTreeBeReplaced) {
willNotWorkOnGpu("not all expressions can be replaced")
}
tagPartForGpu()
}
def tagPartForGpu(): Unit = {}
}
/**
* Metadata for Partitioning with no rule found
*/
final class RuleNotFoundPartMeta[INPUT <: Partitioning](
part: INPUT,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]])
extends PartMeta[INPUT](part, conf, parent, new NoRuleDataFromReplacementRule) {
override def tagPartForGpu(): Unit = {
willNotWorkOnGpu(s"GPU does not currently support the operator ${part.getClass}")
}
override def convertToGpu(): GpuPartitioning =
throw new IllegalStateException("Cannot be converted to GPU")
}
/**
* Base class for metadata around `Scan`.
*/
abstract class ScanMeta[INPUT <: Scan](scan: INPUT,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends RapidsMeta[INPUT, Scan, GpuScan](scan, conf, parent, rule) {
override val childPlans: Seq[SparkPlanMeta[_]] = Seq.empty
override val childExprs: Seq[BaseExprMeta[_]] = Seq.empty
override val childScans: Seq[ScanMeta[_]] = Seq.empty
override val childParts: Seq[PartMeta[_]] = Seq.empty
override val childDataWriteCmds: Seq[DataWritingCommandMeta[_]] = Seq.empty
override def tagSelfForGpu(): Unit = {}
def supportsRuntimeFilters: Boolean = false
}
/**
* Metadata for `Scan` with no rule found
*/
final class RuleNotFoundScanMeta[INPUT <: Scan](
scan: INPUT,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]])
extends ScanMeta[INPUT](scan, conf, parent, new NoRuleDataFromReplacementRule) {
override def tagSelfForGpu(): Unit = {
willNotWorkOnGpu(s"GPU does not currently support the operator ${scan.getClass}")
}
override def convertToGpu(): GpuScan =
throw new IllegalStateException("Cannot be converted to GPU")
}
/**
* Base class for metadata around `DataWritingCommand`.
*/
abstract class DataWritingCommandMeta[INPUT <: DataWritingCommand](
cmd: INPUT,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends RapidsMeta[INPUT, DataWritingCommand, GpuDataWritingCommand](cmd, conf, parent, rule) {
override val childPlans: Seq[SparkPlanMeta[_]] = Seq.empty
override val childExprs: Seq[BaseExprMeta[_]] = Seq.empty
override val childScans: Seq[ScanMeta[_]] = Seq.empty
override val childParts: Seq[PartMeta[_]] = Seq.empty
override val childDataWriteCmds: Seq[DataWritingCommandMeta[_]] = Seq.empty
val checkTimeZone: Boolean = true
final override def tagSelfForGpu(): Unit = {
if (checkTimeZone) {
timezoneCheck()
}
tagSelfForGpuInternal()
}
protected def tagSelfForGpuInternal(): Unit = {}
// Check whether the data types of the input/output contain a timestamp type, which
// is time zone related.
// Only the UTC time zone is allowed, to be consistent with previous behavior
// for [[DataWritingCommand]]. Override [[checkTimeZone]] to skip the
// UTC time zone check in a subclass of [[DataWritingCommand]].
def timezoneCheck(): Unit = {
val types = (wrapped.inputSet.map(_.dataType) ++ wrapped.outputSet.map(_.dataType)).toSet
if (types.exists(GpuOverrides.isOrContainsTimestamp(_))) {
if (!GpuOverrides.isUTCTimezone()) {
willNotWorkOnGpu("Only UTC timezone is supported. " +
s"Current timezone settings: (JVM : ${ZoneId.systemDefault()}, " +
s"session: ${SQLConf.get.sessionLocalTimeZone}). ")
}
}
}
}
/**
* Metadata for `DataWritingCommand` with no rule found
*/
final class RuleNotFoundDataWritingCommandMeta[INPUT <: DataWritingCommand](
cmd: INPUT,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]])
extends DataWritingCommandMeta[INPUT](cmd, conf, parent, new NoRuleDataFromReplacementRule) {
override def tagSelfForGpuInternal(): Unit = {
willNotWorkOnGpu(s"GPU does not currently support the operator ${cmd.getClass}")
}
override def convertToGpu(): GpuDataWritingCommand =
throw new IllegalStateException("Cannot be converted to GPU")
}
/**
* Base class for metadata around `SparkPlan`.
*/
abstract class SparkPlanMeta[INPUT <: SparkPlan](plan: INPUT,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends RapidsMeta[INPUT, SparkPlan, GpuExec](plan, conf, parent, rule) {
def tagForExplain(): Unit = {
if (!canThisBeReplaced) {
childExprs.foreach(_.recursiveSparkPlanPreventsRunningOnGpu())
childParts.foreach(_.recursiveSparkPlanPreventsRunningOnGpu())
childScans.foreach(_.recursiveSparkPlanPreventsRunningOnGpu())
childDataWriteCmds.foreach(_.recursiveSparkPlanPreventsRunningOnGpu())
childRunnableCmds.foreach(_.recursiveSparkPlanPreventsRunningOnGpu())
}
if (shouldThisBeRemoved) {
childExprs.foreach(_.recursiveSparkPlanRemoved())
childParts.foreach(_.recursiveSparkPlanRemoved())
childScans.foreach(_.recursiveSparkPlanRemoved())
childDataWriteCmds.foreach(_.recursiveSparkPlanRemoved())
childRunnableCmds.foreach(_.recursiveSparkPlanRemoved())
}
childPlans.foreach(_.tagForExplain())
}
def requireAstForGpuOn(exprMeta: BaseExprMeta[_]): Unit = {
// willNotWorkOnGpu does not deduplicate reasons. Most of the time that is fine
// but here we want to avoid adding the reason twice, because this method can be
// called multiple times, and also the reason can automatically be added in if
// a child expression would not work in the non-AST case either.
// So only add it if canExprTreeBeReplaced changed after requiring that the
// given expression is AST-able.
val previousExprReplaceVal = canExprTreeBeReplaced
exprMeta.requireAstForGpu()
val newExprReplaceVal = canExprTreeBeReplaced
if (previousExprReplaceVal != newExprReplaceVal &&
!newExprReplaceVal) {
willNotWorkOnGpu("not all expressions can be replaced")
}
}
override val childPlans: Seq[SparkPlanMeta[SparkPlan]] =
plan.children.map(GpuOverrides.wrapPlan(_, conf, Some(this)))
override val childExprs: Seq[BaseExprMeta[_]] =
plan.expressions.map(GpuOverrides.wrapExpr(_, conf, Some(this)))
override val childScans: Seq[ScanMeta[_]] = Seq.empty
override val childParts: Seq[PartMeta[_]] = Seq.empty
override val childDataWriteCmds: Seq[DataWritingCommandMeta[_]] = Seq.empty
def namedChildExprs: Map[String, Seq[BaseExprMeta[_]]] = Map.empty
var cpuCost: Double = 0
var gpuCost: Double = 0
var estimatedOutputRows: Option[BigInt] = None
override def convertToCpu(): SparkPlan = {
wrapped.withNewChildren(childPlans.map(_.convertIfNeeded()))
}
def getReasonsNotToReplaceEntirePlan: Set[String] = {
val childReasons = childPlans.flatMap(_.getReasonsNotToReplaceEntirePlan)
entirePlanExcludedReasons ++ childReasons
}
// For adaptive execution we have to ensure we mark everything properly
// the first time through and that has to match what happens when AQE
// splits things up and does the subquery analysis at the shuffle boundaries.
// If the AQE subquery analysis changes the plan from what is originally
// marked we can end up with mismatches like happened in:
// https://github.com/NVIDIA/spark-rapids/issues/1423
// AQE splits subqueries at shuffle boundaries which means that it only
// sees the children at that point. So in our exchange fixup we only
// look at the children and mark the exchange as will-not-work-on-GPU if the
// child can't be replaced.
private def fixUpExchangeOverhead(): Unit = {
childPlans.foreach(_.fixUpExchangeOverhead())
if (wrapped.isInstanceOf[ShuffleExchangeExec] &&
!SparkShimImpl.isExecutorBroadcastShuffle(wrapped.asInstanceOf[ShuffleExchangeExec]) &&
!childPlans.exists(_.supportsColumnar) &&
(plan.conf.adaptiveExecutionEnabled ||
!parent.exists(_.supportsColumnar))) {
// Some platforms can present a plan where the root of the plan is a shuffle followed by
// an AdaptiveSparkPlanExec. If it looks like the child AdaptiveSparkPlanExec will end up
// on the GPU, then this shuffle should be on the GPU as well.
val shuffle = wrapped.asInstanceOf[ShuffleExchangeExec]
val isChildOnGpu = shuffle.child match {
case ap: AdaptiveSparkPlanExec if parent.isEmpty => GpuOverrides.probablyGpuPlan(ap, conf)
case _ => false
}
if (!isChildOnGpu) {
willNotWorkOnGpu("Columnar exchange without columnar children is inefficient")
}
childPlans.head.wrapped
.getTagValue(GpuOverrides.preRowToColProjection).foreach { r2c =>
wrapped.setTagValue(GpuOverrides.preRowToColProjection, r2c)
}
}
}
private def fixUpBroadcastJoins(): Unit = {
childPlans.foreach(_.fixUpBroadcastJoins())
wrapped match {
case _: BroadcastHashJoinExec =>
this.asInstanceOf[GpuBroadcastHashJoinMetaBase].checkTagForBuildSide()
case _: BroadcastNestedLoopJoinExec =>
this.asInstanceOf[GpuBroadcastNestedLoopJoinMetaBase].checkTagForBuildSide()
case _ => // noop
}
}
/**
* Run rules that happen for the entire tree after it has been tagged initially.
*/
def runAfterTagRules(): Unit = {
// In the first pass tagSelfForGpu will deal with each operator individually.
// Children will be tagged first and then their parents will be tagged. This gives
// flexibility when tagging yourself to look at your children and disable yourself if your
// children are not all on the GPU. In some cases we need to be able to disable our
// children too, or in this case run a rule that will disable operations when looking at
// more of the tree. These exceptions should be documented here. We need to take special care
// that we take into account all side-effects of these changes, because we are **not**
// re-triggering the rules associated with parents, grandparents, etc. If things get too
// complicated we may need to update this to have something with triggers, but then we would
// have to be very careful to avoid loops in the rules.
// RULES:
// 1) If a file scan plan runs on the CPU while the plans that follow it run on the GPU,
// then a GpuRowToColumnar will be inserted. GpuRowToColumnar invalidates input_file_xxx
// operations, so input_file_xxx in the following GPU operators would get empty values.
// InputFileBlockRule prevents the SparkPlans in the range
// [SparkPlan (with the first input_file_xxx expression), FileScan) from running on the GPU.
InputFileBlockRule(this.asInstanceOf[SparkPlanMeta[SparkPlan]])
// 2) For shuffles, avoid replacing the shuffle if the child is not going to be replaced.
fixUpExchangeOverhead()
// 3) Some child nodes can't run on the GPU if their parent nodes can't run on the GPU.
// WriteFilesExec is a new operator introduced in Spark 3.4.0; we did not extract
// shim code for it, for simplicity.
tagChildAccordingToParent(this.asInstanceOf[SparkPlanMeta[SparkPlan]], "WriteFilesExec")
// 4) InputFileBlockRule may change the meta of a broadcast join and its child plans,
// and this change may cause a mismatch between the join and its build-side
// BroadcastExchangeExec, leading to errors. We need to fix the mismatch.
fixUpBroadcastJoins()
}
/**
* Tag a child node as unable to run on the GPU if the parent node can't run on the GPU and
* the child node's class name is `typeName`.
* From Spark 3.4.0, the plan looks like:
* InsertIntoHadoopFsRelationCommand
* +- WriteFiles
* +- sub plan
* Instead of:
* InsertIntoHadoopFsRelationCommand
* +- sub plan
* WriteFiles is a temporary node with no input or output; it acts like a tag node.
* @param p the plan meta to recursively walk
* @param typeName the simple class name to match
*/
private def tagChildAccordingToParent(p: SparkPlanMeta[SparkPlan], typeName: String): Unit = {
p.childPlans.foreach(e => tagChildAccordingToParent(e, typeName))
if (p.wrapped.getClass.getSimpleName.equals(typeName)) {
assert(p.parent.isDefined)
if (!p.parent.get.canThisBeReplaced) {
// parent can't run on GPU, also tag this.
p.willNotWorkOnGpu(
s"$typeName can't run on GPU because parent can't run on GPU")
}
}
}
override final def tagSelfForGpu(): Unit = {
rule.getChecks.foreach(_.tag(this))
if (!canExprTreeBeReplaced) {
willNotWorkOnGpu("not all expressions can be replaced")
}
if (!canScansBeReplaced) {
willNotWorkOnGpu("not all scans can be replaced")
}
if (!canPartsBeReplaced) {
willNotWorkOnGpu("not all partitioning can be replaced")
}
if (!canDataWriteCmdsBeReplaced) {
willNotWorkOnGpu("not all data writing commands can be replaced")
}
if (!childRunnableCmds.forall(_.canThisBeReplaced)) {
willNotWorkOnGpu("not all commands can be replaced")
}
// All exec metas extend SparkPlanMeta. We need to check whether the requiredChildDistribution
// is recognized. If it is an unrecognized Distribution then we fall back to the CPU.
plan.requiredChildDistribution.foreach { d =>
if (!DistributionUtil.isSupported(d)) {
willNotWorkOnGpu(s"unsupported required distribution: $d")
}
}
checkExistingTags()
tagPlanForGpu()
}
/**
* When AQE is enabled and we are planning a new query stage, we need to look at metadata
* previously stored on the Spark plan to determine whether this operator can run on the GPU.
*/
def checkExistingTags(): Unit = {
wrapped.getTagValue(RapidsMeta.gpuSupportedTag)
.foreach(_.diff(cannotBeReplacedReasons.get)
.foreach(willNotWorkOnGpu))
}
/**
* Called to verify that this plan will work on the GPU. Generic checks will have already been
* done. In general this method should only tag this operator as bad. If it needs to tag
* one of its children please take special care to update the comment inside
* `tagSelfForGpu` so we don't end up with something that could be cyclical.
*/
def tagPlanForGpu(): Unit = {}
/**
* If this is enabled to be converted to a GPU version convert it and return the result, else
* do what is needed to possibly convert the rest of the plan.
*/
final def convertIfNeeded(): SparkPlan = {
if (shouldThisBeRemoved) {
if (childPlans.isEmpty) {
throw new IllegalStateException("can't remove when plan has no children")
} else if (childPlans.size > 1) {
throw new IllegalStateException("can't remove when plan has more than 1 child")
}
childPlans.head.convertIfNeeded()
} else {
if (canThisBeReplaced) {
convertToGpu()
} else {
convertToCpu()
}
}
}
/**
* Gets the output attributes of the current SparkPlanMeta; this is supposed to be called during
* type checking for the current plan.
*
* By default, it simply returns the output of the wrapped plan. Specific plans can
* override outputTypeMetas to apply custom conversions to the output of the wrapped plan. Plans
* which just pass through the schema of the child plan can set useOutputAttributesOfChild to
* true, in order to propagate the custom conversions of the child plan if they exist.
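*
* A sketch of a subclass applying a custom conversion (the widened decimal type is
* hypothetical):
* {{{
*   override protected lazy val outputTypeMetas: Option[Seq[DataTypeMeta]] =
*     Some(wrapped.output.map(attr => DataTypeMeta(attr, Some(DecimalType(38, 10)))))
* }}}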
*/
def outputAttributes: Seq[Attribute] = outputTypeMetas match {
case Some(typeMetas) =>
require(typeMetas.length == wrapped.output.length,
"The length of outputTypeMetas doesn't match to the length of plan's output")
wrapped.output.zip(typeMetas).map {
case (ar, meta) if meta.typeConverted =>
addConvertedDataType(ar, meta)
AttributeReference(ar.name, meta.dataType.get, ar.nullable, ar.metadata)(
ar.exprId, ar.qualifier)
case (ar, _) =>
ar
}
case None if useOutputAttributesOfChild =>
require(wrapped.children.length == 1,
"useOutputAttributesOfChild ONLY works on UnaryPlan")
// We pass through the outputAttributes of the child plan only if they will actually be
// applied at runtime. We can pass through either if the child plan can be replaced by GPU
// overrides, or if the child plan is available for runtime type conversion. The latter
// condition indicates that a CPU-to-GPU data transition will be introduced as the
// pre-processing of the adjacent GpuRowToColumnarExec, even though the child plan can't
// produce output attributes for the GPU.
// Otherwise, we should fetch the outputAttributes from the wrapped plan.
//
// We can safely call childPlan.canThisBeReplaced here, because outputAttributes is called
// via tagSelfForGpu. At this point, tagging of the child plan has already taken place.
if (childPlans.head.canThisBeReplaced || childPlans.head.availableRuntimeDataTransition) {
childPlans.head.outputAttributes
} else {
wrapped.output
}
case None =>
wrapped.output
}
/**
* Returns whether the resulting SparkPlan supports columnar execution
*/
override def supportsColumnar: Boolean = wrapped.supportsColumnar || canThisBeReplaced
/**
* Overrides this method to implement custom conversions for specific plans.
*/
protected lazy val outputTypeMetas: Option[Seq[DataTypeMeta]] = None
/**
* Whether to pass through the outputAttributes of childPlan's meta, only for UnaryPlan
*/
protected val useOutputAttributesOfChild: Boolean = false
/**
* Whether a runtime data transition exists for the wrapped plan. If true, the overriding
* of output attributes will always work even when the wrapped plan can't be replaced by GPU
* overrides.
*/
val availableRuntimeDataTransition: Boolean = false
}
/**
* Metadata for `SparkPlan` with no rule found
*/
final class RuleNotFoundSparkPlanMeta[INPUT <: SparkPlan](
plan: INPUT,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]])
extends SparkPlanMeta[INPUT](plan, conf, parent, new NoRuleDataFromReplacementRule) {
override def tagPlanForGpu(): Unit =
willNotWorkOnGpu(s"GPU does not currently support the operator ${plan.getClass}")
override def convertToGpu(): GpuExec =
throw new IllegalStateException("Cannot be converted to GPU")
}
/**
* Metadata for a `SparkPlan` that should not be replaced and should not produce any warnings.
*/
final class DoNotReplaceOrWarnSparkPlanMeta[INPUT <: SparkPlan](
plan: INPUT,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]])
extends SparkPlanMeta[INPUT](plan, conf, parent, new NoRuleDataFromReplacementRule) {
/** We don't want to spam the user with messages about these operators */
override def suppressWillWorkOnGpuInfo: Boolean = true
override def tagPlanForGpu(): Unit =
willNotWorkOnGpu(s"there is no need to replace ${plan.getClass}")
override def convertToGpu(): GpuExec =
throw new IllegalStateException("Cannot be converted to GPU")
}
sealed abstract class ExpressionContext
object ProjectExprContext extends ExpressionContext {
override def toString: String = "project"
}
/**
* This is a special context. All other contexts are determined by the Spark query in a generic way.
* AST support in many cases is an optimization and so it is tagged and checked after it is
* determined that this operation will run on the GPU. In other cases it is required. In those cases
* AST support is determined and used when tagging the metas to see if they will work on the GPU or
* not. This part is not done automatically.
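*
* A sketch of how the requirement is applied (conditionMeta is a hypothetical expression
* meta for a join condition that must be AST-able):
* {{{
*   conditionMeta.foreach(requireAstForGpuOn)
* }}}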
*/
object AstExprContext extends ExpressionContext {
override def toString: String = "AST"
val notSupportedMsg = "this expression does not support AST"
}
object GroupByAggExprContext extends ExpressionContext {
override def toString: String = "aggregation"
}
object ReductionAggExprContext extends ExpressionContext {
override def toString: String = "reduction"
}
object WindowAggExprContext extends ExpressionContext {
override def toString: String = "window"
}
object ExpressionContext {
private[this] def findParentPlanMeta(meta: BaseExprMeta[_]): Option[SparkPlanMeta[_]] =
meta.parent match {
case Some(p: BaseExprMeta[_]) => findParentPlanMeta(p)
case Some(p: SparkPlanMeta[_]) => Some(p)
case _ => None
}
def getAggregateFunctionContext(meta: BaseExprMeta[_]): ExpressionContext = {
val parent = findParentPlanMeta(meta)
assert(parent.isDefined, "It is expected that an aggregate function is a child of a SparkPlan")
parent.get.wrapped match {
case agg: SparkPlan if SparkShimImpl.isWindowFunctionExec(agg) =>
WindowAggExprContext
case agg: AggregateInPandasExec =>
if (agg.groupingExpressions.isEmpty) {
ReductionAggExprContext
} else {
GroupByAggExprContext
}
case agg: BaseAggregateExec =>
// Since Spark 3.5, Python UDFs are wrapped in AggregateInPandasExec. UDFs from earlier
// versions of Spark are handled by this BaseAggregateExec case.
if (agg.groupingExpressions.isEmpty) {
ReductionAggExprContext
} else {
GroupByAggExprContext
}
case _ => throw new IllegalStateException(
s"Found an aggregation function in an unexpected context $parent")
}
}
def getRegularOperatorContext(meta: RapidsMeta[_, _, _]): ExpressionContext = meta.wrapped match {
case _: Expression if meta.parent.isDefined => getRegularOperatorContext(meta.parent.get)
case _ => ProjectExprContext
}
}
/**
* The metadata around `DataType`, which records the original data type, the desired data type for
* GPU overrides, and the reason for a potential conversion. The metadata ensures that
* TypeChecks tags the actual data types used at GPU runtime, since the data types of GPU
* overrides may differ slightly from their CPU counterparts.
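*
* A minimal sketch of recording a conversion (the reason text is hypothetical):
* {{{
*   val meta = new DataTypeMeta(
*     wrapped = Some(FloatType),
*     desired = Some(DoubleType),
*     reason = Some("the GPU computes this in double precision"))
*   assert(meta.typeConverted)
*   // meta.reasonForConversion ==
*   //   "Converted FloatType to DoubleType, because the GPU computes this in double precision"
* }}}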
*/
class DataTypeMeta(
val wrapped: Option[DataType],
desired: Option[DataType] = None,
reason: Option[String] = None) {
lazy val dataType: Option[DataType] = desired match {
case Some(dt) => Some(dt)
case None => wrapped
}
// typeConverted will only be true if a resolved data type exists and it differs from the wrapped one
lazy val typeConverted: Boolean = dataType.nonEmpty && dataType != wrapped
/**
* Returns the reason for the conversion, if one exists.
*/
def reasonForConversion: String = {
val reasonMsg = (if (typeConverted) reason else None)
.map(r => s", because $r").getOrElse("")
s"Converted ${wrapped.getOrElse("N/A")} to " +
s"${dataType.getOrElse("N/A")}" + reasonMsg
}
}
object DataTypeMeta {
/**
* Create a DataTypeMeta from an Expression.
*/
def apply(expr: Expression, overrideType: Option[DataType]): DataTypeMeta = {
val wrapped = try {
Some(expr.dataType)
} catch {
case _: java.lang.UnsupportedOperationException => None
case _: org.apache.spark.SparkException => None
}
new DataTypeMeta(wrapped, overrideType)
}
}
/**
* Base class for metadata around `Expression`.
*/
abstract class BaseExprMeta[INPUT <: Expression](
expr: INPUT,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends RapidsMeta[INPUT, Expression, Expression](expr, conf, parent, rule) {
private val cannotBeAstReasons: mutable.Set[String] = mutable.Set.empty
override val childPlans: Seq[SparkPlanMeta[_]] = Seq.empty
override val childExprs: Seq[BaseExprMeta[_]] =
expr.children.map(GpuOverrides.wrapExpr(_, conf, Some(this)))
override val childScans: Seq[ScanMeta[_]] = Seq.empty
override val childParts: Seq[PartMeta[_]] = Seq.empty
override val childDataWriteCmds: Seq[DataWritingCommandMeta[_]] = Seq.empty
override val printWrapped: Boolean = true
def dataType: DataType = expr.dataType
val ignoreUnsetDataTypes = false
override def canExprTreeBeReplaced: Boolean =
canThisBeReplaced && super.canExprTreeBeReplaced
/**
* Gets the DataTypeMeta of current BaseExprMeta, which is supposed to be called in the
* tag methods of expression-level type checks.
*
* By default, it simply returns the data type of the wrapped expression. But specific
* expressions can easily override the data type for type checking by calling the
* method `overrideDataType`.
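*
* A sketch of such an override inside a tag method (the target type is hypothetical):
* {{{
*   // pretend the GPU version produces a LongType where the CPU produced an IntegerType
*   overrideDataType(LongType)
*   assert(typeMeta.typeConverted)
* }}}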
*/
def typeMeta: DataTypeMeta = DataTypeMeta(wrapped.asInstanceOf[Expression], overrideType)
/**
* Overrides the data type of the wrapped expression during type checking.
*
* NOTICE: This method will NOT modify the wrapped expression itself. Therefore, the actual
* transition on data type is still necessary when converting this expression to GPU.
*/
def overrideDataType(dt: DataType): Unit = overrideType = Some(dt)
private var overrideType: Option[DataType] = None
lazy val context: ExpressionContext = expr match {
case _: WindowExpression => WindowAggExprContext
case _: WindowFunction => WindowAggExprContext
case _: AggregateFunction => ExpressionContext.getAggregateFunctionContext(this)
case _: AggregateExpression => ExpressionContext.getAggregateFunctionContext(this)
case _ => ExpressionContext.getRegularOperatorContext(this)
}
val isFoldableNonLitAllowed: Boolean = conf.isFoldableNonLitAllowed
// There are 3 levels of timezone checks in the GPU plan tagging phase:
// Level 1: Check whether an expression is related to the timezone. This is achieved by
// [[needTimeZoneCheck]] below.
// Level 2: Check whether the related expression has been implemented with timezone support.
// There is a toggle flag [[isTimeZoneSupported]] for this. If false, fall back to the
// UTC-only check as before. If true, move to the next level of checking. When we add
// timezone support for a related function, [[isTimeZoneSupported]] should be overridden
// to true.
// Level 3: Check whether the desired timezone is supported by the GPU kernel.
def checkExprForTimezone(): Unit = {
// Level 1 check
if (!needTimeZoneCheck) return
// Level 2 check
if (!isTimeZoneSupported) return checkUTCTimezone(this, getZoneId())
// Level 3 check
val zoneId = getZoneId()
if (!GpuTimeZoneDB.isSupportedTimeZone(zoneId)) {
willNotWorkOnGpu(TimeZoneDB.timezoneNotSupportedStr(zoneId.toString))
}
}
protected def getZoneId(): ZoneId = {
this.wrapped match {
case tzExpr: TimeZoneAwareExpression => tzExpr.zoneId
case ts: UTCTimestamp => {
assert(false, s"Have to override getZoneId() of BaseExprMeta in ${this.getClass.toString}")
throw new IllegalArgumentException(s"Failed to get zone id from ${ts.getClass.toString}")
}
case _ => throw new IllegalArgumentException(
s"Zone check should never happen for ${this.getClass.toString}, " +
"which is not timezone related")
}
}
// Level 1 timezone checking flag.
// Both [[isTimeZoneSupported]] and [[needTimeZoneCheck]] are needed to determine whether a
// timezone check is required. For cast expressions, only some cases need the check, depending
// on the expression's data type and its child's data type.
//
//+------------------------+-------------------+-----------------------------------------+
//| Value | needTimeZoneCheck | isTimeZoneSupported |
//+------------------------+-------------------+-----------------------------------------+
//| TimezoneAwareExpression| True | False by default, True when implemented |
//| Others | False | N/A (will not be checked) |
//+------------------------+-------------------+-----------------------------------------+
lazy val needTimeZoneCheck: Boolean = {
wrapped match {
// The CurrentDate expression will not go through this even though it's a
// `TimeZoneAwareExpression`. It is treated as a literal in RAPIDS.
case _: TimeZoneAwareExpression =>
if (wrapped.isInstanceOf[Cast]) {
val cast = wrapped.asInstanceOf[Cast]
needsTimeZone(cast.child.dataType, cast.dataType)
} else if (PlanShims.isAnsiCast(wrapped)) {
val (from, to) = PlanShims.extractAnsiCastTypes(wrapped)
needsTimeZone(from, to)
} else {
true
}
case _ => false
}
}
// Mostly based on Spark's existing [[Cast.needsTimeZone]] method. Two changes are made:
// 1. The date-related cases are overridden, based on the merged
// https://github.com/apache/spark/pull/40524.
// 2. The existing `needsTimeZone` doesn't consider complex-type-to-string casts, which are
// timezone related (including struct/map/array to string).
private[this] def needsTimeZone(from: DataType, to: DataType): Boolean = (from, to) match {
case (StringType, DateType) => false
case (DateType, StringType) => false
case (ArrayType(fromType, _), StringType) => needsTimeZone(fromType, to)
case (MapType(fromKey, fromValue, _), StringType) =>
needsTimeZone(fromKey, to) || needsTimeZone(fromValue, to)
case (StructType(fromFields), StringType) =>
fromFields.exists(f => needsTimeZone(f.dataType, to))
// Avoid copying full implementation here. Otherwise needs to create shim for TimestampNTZ
// since Spark 3.4.0
case _ => Cast.needsTimeZone(from, to)
}
// Level 2 timezone checking flag; override to true when the function supports timezones.
// Irrelevant if the expression is not timezone related as defined by [[needTimeZoneCheck]].
def isTimeZoneSupported: Boolean = false
/**
* Timezone check which only allows UTC timezone. This is consistent with previous behavior.
*
* @param meta the meta to tag if the check fails
* @param zoneId the zone id to check
*/
def checkUTCTimezone(meta: RapidsMeta[_, _, _], zoneId: ZoneId): Unit = {
if (!GpuOverrides.isUTCTimezone(zoneId)) {
meta.willNotWorkOnGpu(
TimeZoneDB.nonUTCTimezoneNotSupportedStr(meta.wrapped.getClass.toString))
}
}
final override def tagSelfForGpu(): Unit = {
if (wrapped.foldable && !GpuOverrides.isLit(wrapped) && !isFoldableNonLitAllowed) {
willNotWorkOnGpu(s"Cannot run on GPU. Is ConstantFolding excluded? Expression " +
s"$wrapped is foldable and operates on non literals")
}
rule.getChecks.foreach(_.tag(this))
checkExprForTimezone()
tagExprForGpu()
}
/**
* Called to verify that this expression will work on the GPU. For most expressions without
* extra checks all of the checks should have already been done.
*/
def tagExprForGpu(): Unit = {}
final def willNotWorkInAst(because: String): Unit = cannotBeAstReasons.add(because)
final def canThisBeAst: Boolean = {
tagForAst()
childExprs.forall(_.canThisBeAst) && cannotBeAstReasons.isEmpty
}
/**
* Check whether this node itself can be converted to AST. It will not recursively check its
* children. It's used to check join condition AST-ability in top-down fashion.
*/
lazy val canSelfBeAst: Boolean = {
tagForAst()
cannotBeAstReasons.isEmpty
}
final def requireAstForGpu(): Unit = {
tagForAst()
cannotBeAstReasons.foreach { reason =>
willNotWorkOnGpu(s"AST is required and $reason")
}
childExprs.foreach(_.requireAstForGpu())
}
private var taggedForAst = false
private final def tagForAst(): Unit = {
if (!taggedForAst) {
if (wrapped.foldable && !GpuOverrides.isLit(wrapped)) {
willNotWorkInAst(s"Cannot convert to AST. Is ConstantFolding excluded? Expression " +
s"$wrapped is foldable and operates on non literals")
}
rule.getChecks.foreach {
case exprCheck: ExprChecks => exprCheck.tagAst(this)
case other => throw new IllegalArgumentException(s"Unexpected check found $other")
}
tagSelfForAst()
taggedForAst = true
}
}
/** Called to verify that this expression will work as a GPU AST expression. */
protected def tagSelfForAst(): Unit = {
// NOOP
}
protected def willWorkInAstInfo: String = {
if (cannotBeAstReasons.isEmpty) {
"will run in AST"
} else {
s"cannot be converted to GPU AST because ${cannotBeAstReasons.mkString(";")}"
}
}
/**
* Create a string explanation for whether this expression tree can be converted to an AST
* @param strBuilder where to place the string representation.
* @param depth how far down the tree this is.
* @param all should all the data be printed or just what does not work in the AST?
*/
protected def printAst(strBuilder: StringBuilder, depth: Int, all: Boolean): Unit = {
if (all || !canThisBeAst) {
indent(strBuilder, depth)
strBuilder.append(operationName)
.append(" <")
.append(wrapped.getClass.getSimpleName)
.append("> ")
if (printWrapped) {
strBuilder.append(wrapped)
.append(" ")
}
strBuilder.append(willWorkInAstInfo).append("\n")
}
childExprs.foreach(_.printAst(strBuilder, depth + 1, all))
}
def explainAst(all: Boolean): String = {
tagForAst()
val appender = new StringBuilder()
printAst(appender, 0, all)
appender.toString()
}
}
abstract class ExprMeta[INPUT <: Expression](
expr: INPUT,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends BaseExprMeta[INPUT](expr, conf, parent, rule) {
override def convertToGpu(): GpuExpression
}
/**
* Base class for metadata around `RuntimeReplaceable` expressions. We will never
* see a RuntimeReplaceable expression at execution, as it is converted to the actual
* Expression by the time we get it. We need this here because some expressions,
* e.g. UnaryPositive, don't extend UnaryExpression.
*/
abstract class RuntimeReplaceableUnaryExprMeta[INPUT <: RuntimeReplaceable](
expr: INPUT,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends UnaryExprMetaBase[INPUT](expr, conf, parent, rule)
/** Base metadata class for RuntimeReplaceable expressions that support conversion to AST as well */
abstract class RuntimeReplaceableUnaryAstExprMeta[INPUT <: RuntimeReplaceable](
expr: INPUT,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends RuntimeReplaceableUnaryExprMeta[INPUT](expr, conf, parent, rule)
/**
* Base class for metadata around `UnaryExpression`.
*/
abstract class UnaryExprMeta[INPUT <: Expression with UnaryLike[Expression]](
expr: INPUT,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule) extends UnaryExprMetaBase[INPUT](expr, conf, parent, rule)
protected abstract class UnaryExprMetaBase[INPUT <: Expression](
expr: INPUT,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends ExprMeta[INPUT](expr, conf, parent, rule) {
override final def convertToGpu(): GpuExpression =
convertToGpu(childExprs.head.convertToGpu())
def convertToGpu(child: Expression): GpuExpression
/**
* `ConstantFolding` executes early in the logical plan process, which
* simplifies many things before we get to the physical plan. If you enable
* AQE, some optimizations can cause new expressions to show up that would have been
* folded in by the logical plan optimizer (like `cast(null as bigint)` which just
* becomes Literal(null, Long) after `ConstantFolding`), so enabling this here
* allows us to handle these when they are generated by an AQE rule.
*/
override val isFoldableNonLitAllowed: Boolean = true
}
/** Base metadata class for unary expressions that support conversion to AST as well */
abstract class UnaryAstExprMeta[INPUT <: UnaryExpression](
expr: INPUT,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends UnaryExprMeta[INPUT](expr, conf, parent, rule) {
}
/**
* Base class for metadata around `AggregateFunction`.
*/
abstract class AggExprMeta[INPUT <: AggregateFunction](
val expr: INPUT,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends ExprMeta[INPUT](expr, conf, parent, rule) {
override final def tagExprForGpu(): Unit = {
tagAggForGpu()
if (needsAnsiCheck) {
GpuOverrides.checkAndTagAnsiAgg(ansiTypeToCheck, this)
}
}
// not all aggregates override this
def tagAggForGpu(): Unit = {}
override final def convertToGpu(): GpuExpression =
convertToGpu(childExprs.map(_.convertToGpu()))
def convertToGpu(childExprs: Seq[Expression]): GpuExpression
// Set to false if the aggregate cannot overflow and therefore
// shouldn't error under ANSI mode
val needsAnsiCheck: Boolean = true
// The type to use to determine whether the aggregate could overflow.
// Set to None, if we should fallback for all types
val ansiTypeToCheck: Option[DataType] = Some(expr.dataType)
}
/**
* Base class for metadata around `ImperativeAggregate`.
*/
abstract class ImperativeAggExprMeta[INPUT <: ImperativeAggregate](
expr: INPUT,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends AggExprMeta[INPUT](expr, conf, parent, rule) {
def convertToGpu(childExprs: Seq[Expression]): GpuExpression
}
/**
* Base class for metadata around `TypedImperativeAggregate`.
*/
abstract class TypedImperativeAggExprMeta[INPUT <: TypedImperativeAggregate[_]](
expr: INPUT,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends ImperativeAggExprMeta[INPUT](expr, conf, parent, rule) {
/**
* Returns aggregation buffer with the actual data type under GPU runtime. This method is
* called to override the data types of typed imperative aggregation buffers during GPU
* overriding.
*/
def aggBufferAttribute: AttributeReference
/**
* Returns a buffer converter that can generate an Expression to transform the aggregation
* buffer of the wrapped function from CPU format to GPU format. The conversion occurs on the
* CPU, so the generated expression should be a CPU Expression executed row by row.
*/
def createCpuToGpuBufferConverter(): CpuToGpuAggregateBufferConverter =
throw new NotImplementedError("The method should be implemented by specific functions")
/**
* Returns a buffer converter that can generate an Expression to transform the aggregation
* buffer of the wrapped function from GPU format to CPU format. The conversion occurs on the
* CPU, so the generated expression should be a CPU Expression executed row by row.
*/
def createGpuToCpuBufferConverter(): GpuToCpuAggregateBufferConverter =
throw new NotImplementedError("The method should be implemented by specific functions")
/**
* Whether the buffers of the current aggregate can be converted between CPU and GPU formats
* at runtime. If true, both createCpuToGpuBufferConverter and
* createGpuToCpuBufferConverter are assumed to be implemented.
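*
* A sketch of a subclass enabling conversion (the converter classes are hypothetical and
* assumed to be implemented elsewhere):
* {{{
*   override val supportBufferConversion: Boolean = true
*   override def createCpuToGpuBufferConverter(): CpuToGpuAggregateBufferConverter =
*     new MyCpuToGpuConverter()
*   override def createGpuToCpuBufferConverter(): GpuToCpuAggregateBufferConverter =
*     new MyGpuToCpuConverter()
* }}}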
*/
val supportBufferConversion: Boolean = false
}
/**
* Base class for metadata around `BinaryExpression`.
*/
abstract class BinaryExprMeta[INPUT <: BinaryExpression](
expr: INPUT,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends ExprMeta[INPUT](expr, conf, parent, rule) {
override final def convertToGpu(): GpuExpression = {
val Seq(lhs, rhs) = childExprs.map(_.convertToGpu())
convertToGpu(lhs, rhs)
}
def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression
}
/** Base metadata class for binary expressions that support conversion to AST */
abstract class BinaryAstExprMeta[INPUT <: BinaryExpression](
expr: INPUT,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends BinaryExprMeta[INPUT](expr, conf, parent, rule) {
override def tagSelfForAst(): Unit = {
if (wrapped.left.dataType != wrapped.right.dataType) {
willNotWorkInAst("AST binary expression operand types must match, found " +
s"${wrapped.left.dataType},${wrapped.right.dataType}")
}
}
}
/**
* Base class for metadata around `TernaryExpression`.
*/
abstract class TernaryExprMeta[INPUT <: TernaryExpression](
expr: INPUT,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends ExprMeta[INPUT](expr, conf, parent, rule) {
override final def convertToGpu(): GpuExpression = {
val Seq(child0, child1, child2) = childExprs.map(_.convertToGpu())
convertToGpu(child0, child1, child2)
}
def convertToGpu(val0: Expression, val1: Expression,
val2: Expression): GpuExpression
}
/**
* Base class for metadata around `QuaternaryExpression`.
*/
abstract class QuaternaryExprMeta[INPUT <: QuaternaryExpression](
expr: INPUT,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends ExprMeta[INPUT](expr, conf, parent, rule) {
override final def convertToGpu(): GpuExpression = {
val Seq(child0, child1, child2, child3) = childExprs.map(_.convertToGpu())
convertToGpu(child0, child1, child2, child3)
}
def convertToGpu(val0: Expression, val1: Expression,
val2: Expression, val3: Expression): GpuExpression
}
abstract class String2TrimExpressionMeta[INPUT <: String2TrimExpression](
expr: INPUT,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends ExprMeta[INPUT](expr, conf, parent, rule) {
override final def convertToGpu(): GpuExpression = {
val gpuCol :: gpuTrimParam = childExprs.map(_.convertToGpu())
convertToGpu(gpuCol, gpuTrimParam.headOption)
}
def convertToGpu(column: Expression, target: Option[Expression] = None): GpuExpression
}
/**
* Base class for metadata around `ComplexTypeMergingExpression`.
*/
abstract class ComplexTypeMergingExprMeta[INPUT <: ComplexTypeMergingExpression](
expr: INPUT,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends ExprMeta[INPUT](expr, conf, parent, rule) {
override final def convertToGpu(): GpuExpression =
convertToGpu(childExprs.map(_.convertToGpu()))
def convertToGpu(childExprs: Seq[Expression]): GpuExpression
}
/**
* Metadata for `Expression` with no rule found
*/
final class RuleNotFoundExprMeta[INPUT <: Expression](
expr: INPUT,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]])
extends ExprMeta[INPUT](expr, conf, parent, new NoRuleDataFromReplacementRule) {
override def tagExprForGpu(): Unit =
willNotWorkOnGpu(s"GPU does not currently support the operator ${expr.getClass}")
override def convertToGpu(): GpuExpression =
throw new IllegalStateException(s"Cannot be converted to GPU ${expr.getClass} " +
s"${expr.dataType} $expr")
}
/** Base class for metadata around `RunnableCommand`. */
abstract class RunnableCommandMeta[INPUT <: RunnableCommand](
cmd: INPUT,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends RapidsMeta[INPUT, RunnableCommand, RunnableCommand](cmd, conf, parent, rule)
{
override val childPlans: Seq[SparkPlanMeta[_]] = Seq.empty
override val childExprs: Seq[BaseExprMeta[_]] = Seq.empty
override val childScans: Seq[ScanMeta[_]] = Seq.empty
override val childParts: Seq[PartMeta[_]] = Seq.empty
override val childDataWriteCmds: Seq[DataWritingCommandMeta[_]] = Seq.empty
}
/** Metadata for `RunnableCommand` with no rule found */
final class RuleNotFoundRunnableCommandMeta[INPUT <: RunnableCommand](
cmd: INPUT,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]])
extends RunnableCommandMeta[INPUT](cmd, conf, parent, new NoRuleDataFromReplacementRule) {
// Do not complain by default, as many commands are metadata-only.
override def suppressWillWorkOnGpuInfo: Boolean = true
override def tagSelfForGpu(): Unit =
willNotWorkOnGpu(s"GPU does not currently support the command ${cmd.getClass}")
override def convertToGpu(): RunnableCommand =
throw new IllegalStateException("Cannot be converted to GPU")
}