/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.nvidia.spark.rapids.tool.planparser

import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, WeakHashMap}
import scala.util.control.NonFatal
import scala.util.matching.Regex

import com.nvidia.spark.rapids.tool.qualification.PluginTypeChecker

import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SparkPlanGraphCluster, SparkPlanGraphNode}
import org.apache.spark.sql.rapids.tool.{AppBase, BuildSide, ExecHelper, JoinType, RDDCheckHelper, ToolUtils, UnsupportedExpr}
import org.apache.spark.sql.rapids.tool.util.ToolsPlanGraph

object OpActions extends Enumeration {
  type OpAction = Value
  val NONE, IgnoreNoPerf, IgnorePerf, Triage = Value
}

object OpTypes extends Enumeration {
  type OpType = Value
  val ReadExec, ReadRDD, WriteExec, Exec, Expr, UDF, DataSet = Value
}

object UnsupportedReasons extends Enumeration {
  type UnsupportedReason = Value
  val IS_UDF, CONTAINS_UDF,
      IS_DATASET, CONTAINS_DATASET,
      IS_UNSUPPORTED, CONTAINS_UNSUPPORTED_EXPR,
      UNSUPPORTED_IO_FORMAT = Value

  // Mutable map to cache custom reasons
  private val customReasonsCache = WeakHashMap.empty[String, Value]

  // Method to get or create a custom reason
  def CUSTOM_REASON(reason: String): Value = {
    customReasonsCache.getOrElseUpdate(reason, new Val(nextId, reason))
  }

  def reportUnsupportedReason(unsupportedReason: UnsupportedReason): String = {
    unsupportedReason match {
      case IS_UDF => "Is UDF"
      case CONTAINS_UDF => "Contains UDF"
      case IS_DATASET => "Is Dataset or RDD"
      case CONTAINS_DATASET => "Contains Dataset or RDD"
      case IS_UNSUPPORTED => "Unsupported"
      case CONTAINS_UNSUPPORTED_EXPR => "Contains unsupported expr"
      case UNSUPPORTED_IO_FORMAT => "Unsupported IO format"
      case customReason => customReason.toString
    }
  }
}
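
// An illustrative usage sketch (not from the original source; the reason string below is
// arbitrary): CUSTOM_REASON caches each custom string, so repeated lookups return the same
// enumeration value, and reportUnsupportedReason falls through to the custom value's name.
//   val r1 = UnsupportedReasons.CUSTOM_REASON("Requires ANSI mode")
//   val r2 = UnsupportedReasons.CUSTOM_REASON("Requires ANSI mode")  // same cached value as r1
//   UnsupportedReasons.reportUnsupportedReason(r1)                   // => "Requires ANSI mode"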

case class UnsupportedExecSummary(
    sqlId: Long,
    execId: Long,
    execValue: String,
    opType: OpTypes.OpType,
    reason: UnsupportedReasons.UnsupportedReason,
    opAction: OpActions.OpAction,
    isExpression: Boolean = false) {

  val finalOpType: String = if (opType.equals(OpTypes.UDF) || opType.equals(OpTypes.DataSet)) {
    OpTypes.Exec.toString
  } else {
    opType.toString
  }

  val unsupportedOperator: String = execValue

  val details: String = UnsupportedReasons.reportUnsupportedReason(reason)
}

case class ExecInfo(
    sqlID: Long,
    exec: String,
    expr: String,
    speedupFactor: Double,
    duration: Option[Long],
    nodeId: Long,
    opType: OpTypes.OpType,
    isSupported: Boolean,
    children: Option[Seq[ExecInfo]], // only one level deep
    var stages: Set[Int],
    var shouldRemove: Boolean,
    var unsupportedExecReason: String,
    unsupportedExprs: Seq[UnsupportedExpr],
    dataSet: Boolean,
    udf: Boolean,
    shouldIgnore: Boolean) {

  private def childrenToString = {
    val str = children.map { c =>
      c.map("       " + _.toString).mkString("\n")
    }.getOrElse("")
    if (str.nonEmpty) {
      "\n" + str
    } else {
      str
    }
  }

  override def toString: String = {
    s"exec: $exec, expr: $expr, sqlID: $sqlID , speedupFactor: $speedupFactor, " +
      s"duration: $duration, nodeId: $nodeId, " +
      s"isSupported: $isSupported, children: " +
      s"$childrenToString, stages: ${stages.mkString(",")}, " +
      s"shouldRemove: $shouldRemove, shouldIgnore: $shouldIgnore"
  }

  def setStages(stageIDs: Set[Int]): Unit = {
    stages = stageIDs
  }

  def appendToStages(stageIDs: Set[Int]): Unit = {
    stages ++= stageIDs
  }

  def setShouldRemove(value: Boolean): Unit = {
    shouldRemove ||= value
  }

  def setUnsupportedExecReason(reason: String): Unit = {
    unsupportedExecReason = reason
  }

  // Helper function to determine the unsupported reason
  def determineUnsupportedReason(reason: String,
      knownReason: UnsupportedReasons.Value): UnsupportedReasons.Value = {
    if (reason.nonEmpty) UnsupportedReasons.CUSTOM_REASON(reason) else knownReason
  }

  def getOpAction: OpActions.OpAction = {
    // shouldRemove is checked first because sometimes an exec could have both flags set to true,
    // in which case we care about the "NoPerf" part
    if (isSupported) {
      OpActions.NONE
    } else {
      if (shouldRemove) {
        OpActions.IgnoreNoPerf
      } else if (shouldIgnore) {
        OpActions.IgnorePerf
      } else {
        OpActions.Triage
      }
    }
  }

  private def getUnsupportedReason: UnsupportedReasons.UnsupportedReason = {
    if (children.isDefined) {
      // TODO: Handle the children
    }

    if (udf) {
      UnsupportedReasons.CONTAINS_UDF
    } else if (dataSet) {
      if (unsupportedExprs.isEmpty) { // case when the node itself is a DataSet or RDD
        UnsupportedReasons.IS_DATASET
      } else {
        UnsupportedReasons.CONTAINS_DATASET
      }
    } else if (unsupportedExprs.nonEmpty) {
      UnsupportedReasons.CONTAINS_UNSUPPORTED_EXPR
    } else {
      opType match {
        case OpTypes.ReadExec | OpTypes.WriteExec => UnsupportedReasons.UNSUPPORTED_IO_FORMAT
        case _ => UnsupportedReasons.IS_UNSUPPORTED
      }
    }
  }

  def getUnsupportedExecSummaryRecord(execId: Long): Seq[UnsupportedExecSummary] = {
    // Get the custom reason if it exists
    val execUnsupportedReason = determineUnsupportedReason(unsupportedExecReason,
      getUnsupportedReason)

    // Initialize the result with the exec summary
    val res = ArrayBuffer(UnsupportedExecSummary(sqlID, execId, exec, opType,
      execUnsupportedReason, getOpAction))

    // TODO: Should we iterate on exec children?
    // Add the unsupported expressions to the results; if there are any custom reasons, add them
    // to the result appropriately.
    if (unsupportedExprs.nonEmpty) {
      val exprKnownReason = execUnsupportedReason match {
        case UnsupportedReasons.CONTAINS_UDF => UnsupportedReasons.IS_UDF
        case UnsupportedReasons.CONTAINS_DATASET => UnsupportedReasons.IS_DATASET
        case UnsupportedReasons.UNSUPPORTED_IO_FORMAT => UnsupportedReasons.UNSUPPORTED_IO_FORMAT
        case _ => UnsupportedReasons.IS_UNSUPPORTED
      }

      unsupportedExprs.foreach { expr =>
        val exprUnsupportedReason = determineUnsupportedReason(expr.unsupportedReason,
          exprKnownReason)
        res += UnsupportedExecSummary(sqlID, execId, expr.exprName, OpTypes.Expr,
          exprUnsupportedReason, getOpAction, isExpression = true)
      }
    }
    res
  }
}

object ExecInfo {
  // Used to create an ExecInfo without recalculating the dataSet or UDF flags.
  // This is helpful when we know that the node description may contain patterns that could be
  // mistakenly identified as UDFs.
  def createExecNoNode(sqlID: Long,
      exec: String,
      expr: String,
      speedupFactor: Double,
      duration: Option[Long],
      nodeId: Long,
      opType: OpTypes.OpType,
      isSupported: Boolean,
      children: Option[Seq[ExecInfo]], // only one level deep
      stages: Set[Int] = Set.empty,
      shouldRemove: Boolean = false,
      unsupportedExecReason: String = "",
      unsupportedExprs: Seq[UnsupportedExpr] = Seq.empty,
      dataSet: Boolean = false,
      udf: Boolean = false): ExecInfo = {
    // Set the ignoreFlag
    // 1- we ignore any exec with UDF
    // 2- we ignore any exec with dataset
    // 3- Finally we ignore any exec matching the lookup table
    // if the opType is RDD, then we automatically enable the datasetFlag
    val finalDataSet = dataSet || opType.equals(OpTypes.ReadRDD)
    val shouldIgnore = udf || finalDataSet || ExecHelper.shouldIgnore(exec)
    val removeFlag = shouldRemove || ExecHelper.shouldBeRemoved(exec)
    val finalOpType = if (udf) {
      OpTypes.UDF
    } else if (dataSet) {
      // we still want the ReadRDD to stand out from other RDDs. So, we use the original
      // dataSetFlag
      OpTypes.DataSet
    } else {
      opType
    }
    // Set the supported Flag
    val supportedFlag = isSupported && !udf && !finalDataSet
    ExecInfo(
      sqlID,
      exec,
      expr,
      speedupFactor,
      duration,
      nodeId,
      finalOpType,
      supportedFlag,
      children,
      stages,
      removeFlag,
      unsupportedExecReason,
      unsupportedExprs,
      finalDataSet,
      udf,
      shouldIgnore
    )
  }

  def apply(
      node: SparkPlanGraphNode,
      sqlID: Long,
      exec: String,
      expr: String,
      speedupFactor: Double,
      duration: Option[Long],
      nodeId: Long,
      isSupported: Boolean,
      children: Option[Seq[ExecInfo]], // only one level deep
      stages: Set[Int] = Set.empty,
      shouldRemove: Boolean = false,
      unsupportedExecReason: String = "",
      unsupportedExprs: Seq[UnsupportedExpr] = Seq.empty,
      dataSet: Boolean = false,
      udf: Boolean = false,
      opType: OpTypes.OpType = OpTypes.Exec): ExecInfo = {
    // Some exec names need to be trimmed, such as "Scan".
    // Example: "Scan parquet ." -> "Scan parquet."
    // Scan node names need trimming.
    val nodeName = node.name.trim
    // we don't want to mark the *InPandas and ArrowEvalPythonExec as unsupported with UDF
    val containsUDF = udf || ExecHelper.isUDF(node)
    // Check if the node has dataset operations and, if so, change it to not supported.
    val rddCheckRes = RDDCheckHelper.isDatasetOrRDDPlan(nodeName, node.desc)
    val ds = dataSet || rddCheckRes.isRDD

    // If the expression is RDD because of the node name, then we do not want to add the
    // unsupported expressions because they become bogus.
    val finalUnsupportedExpr = if (rddCheckRes.nodeDescRDD) {
      Seq.empty[UnsupportedExpr]
    } else {
      unsupportedExprs
    }
    createExecNoNode(
      sqlID,
      exec,
      expr,
      speedupFactor,
      duration,
      nodeId,
      opType,
      isSupported,
      children,
      stages,
      shouldRemove,
      unsupportedExecReason,
      finalUnsupportedExpr,
      ds,
      containsUDF
    )
  }
}

case class PlanInfo(
    appID: String,
    sqlID: Long,
    sqlDesc: String,
    execInfo: Seq[ExecInfo]
)

object SQLPlanParser extends Logging {

  val equiJoinRegexPattern = """\[([\w#, +*\\\-\.<>=$\`\(\)]+\])""".r

  val functionPattern = """(\w+)\(.*\)""".r

  val functionPrefixPattern = """(\w+)\(""".r // match words followed by an opening parenthesis

  val windowFunctionPattern = """(\w+)\(""".r

  val aggregatePrefixes = Set(
    "finalmerge_", // DB specific prefix for final merge agg functions
    "partial_",    // used for partials
    "merge_"       // Used for partial merge
  )

  val ignoreExpressions = Set("any", "cast", "ansi_cast", "decimal", "decimaltype", "every",
    "some",
    "list",
    // some ops turn into literals and they should not cause any fallbacks
    "current_database", "current_user", "current_timestamp",
    // ArrayBuffer is a Scala class and may appear in some of the JavaRDDs/UDAFs
    "arraybuffer", "arraytype",
    // TODO: we may need later to consider that structs indicate unsupported data types,
    //  but for now we just ignore it to avoid false positives.
    //  StructType and StructField show up from expressions like ("from_json").
    //  We do not want them to appear as independent expressions.
    "structfield", "structtype")

  // As of RAPIDS plugin rev 2b09372, only parse_url(*,HOST|PROTOCOL|QUERY|PATH[,*]) is supported.
  // The following partToExtract values, parse_url(*,REF|FILE|AUTHORITY|USERINFO[,*]), are not
  // supported.
  val unsupportedParseURLParts = Set("FILE", "REF", "AUTHORITY", "USERINFO")
  // Define a pattern to identify whether a string contains the unsupported partToExtract values
  // of parse_url.
  val regExParseURLPart =
    s"(?i)parse_url\\(.*,\\s*(${unsupportedParseURLParts.mkString("|")})(?:\\s*,.*)*\\)".r
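  // Illustrative examples (a sketch, not exhaustive) of what the pattern above is expected to
  // match; group(1) captures the unsupported partToExtract:
  //   regExParseURLPart.findFirstMatchIn("parse_url(url_col#7, REF, false)").map(_.group(1))
  //   // => Some("REF")
  //   regExParseURLPart.findFirstMatchIn("parse_url(url_col#7, HOST, false)")
  //   // => None (HOST is a supported part)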

  /**
   * This function is used to create a set of nodes that should be skipped while parsing the Execs
   * of a specific node.
   * When a reused expression appears in a SparkPlan, the sparkPlanGraph constructed from the
   * eventlog will have duplicates for all the ancestors of the exec (i.e., "ReusedExchange").
   * This leads to a gap in the GPU speedups across different platforms which generate the graph
   * without duplicates.
   * A workaround is to detect all the duplicate nodes so that we can mark them as "shouldRemove".
   * If a wholeGen node has all its children labeled as ancestors of a reused exchange, then the
   * wholeGen node should also be added to the same set.
   * @param planGraph the graph generated for a spark plan. This graph can be different depending on
   *                  the spark-sql jar version used to construct the graph from existing eventlogs.
   * @return a set of node IDs to be skipped during the aggregation of the speedups.
   */
  private def buildSkippedReusedNodesForPlan(planGraph: SparkPlanGraph): Set[Long] = {
    def findNodeAncestors(planGraph: SparkPlanGraph,
        graphNode: SparkPlanGraphNode): mutable.Set[Long] = {
      // Given a node in the graph, this function goes backward to find all the ancestors of
      // the node, including the node itself.
      val visited = mutable.Set[Long](graphNode.id)
      val q1 = mutable.Queue[Long](graphNode.id)
      while (q1.nonEmpty) {
        val curNode = q1.dequeue()
        val allSinkEdges = planGraph.edges
          .filter(e => e.toId == curNode)
          .filterNot(e => visited.contains(e.fromId))
        for (currEdge <- allSinkEdges) {
          q1.enqueue(currEdge.fromId)
          visited += currEdge.fromId
        }
      }
      // Loop on the wholeGen to see if any of them is covered by the ancestors path.
      // This implies that the wholeStageCodeGen is also reused.
      // Note that the following logic can be moved to the WholeStageCodegen parser. Handling the
      // logic here has advantages:
      //   1- no need to append to the final set
      //   2- keep the logic in one place.
      val allStageNodes = planGraph.nodes.filter(
        stageNode => stageNode.name.contains("WholeStageCodegen"))
      allStageNodes.filter { n =>
        n.asInstanceOf[SparkPlanGraphCluster].nodes.forall(c => visited.contains(c.id))
      }.foreach(wNode => visited += wNode.id)
      visited
    }

    // Create a list of all the candidate leaf nodes. This includes wholeStageCodeGen nodes.
    val candidateNodes = planGraph.allNodes.filter(n => reuseExecs.contains(n.name))
    candidateNodes.flatMap(findNodeAncestors(planGraph, _)).toSet
  }

  def parseSQLPlan(
      appID: String,
      planInfo: SparkPlanInfo,
      sqlID: Long,
      sqlDesc: String,
      checker: PluginTypeChecker,
      app: AppBase): PlanInfo = {
    val planGraph = ToolsPlanGraph(planInfo)
    // Find all the graph nodes that should be excluded and pass them to parsePlanNode
    val excludedNodes = buildSkippedReusedNodesForPlan(planGraph)
    // we want the sub-graph nodes to be inside of the wholeStageCodeGen so use nodes
    // vs allNodes
    val execInfos = planGraph.nodes.flatMap { node =>
      parsePlanNode(node, sqlID, checker, app, reusedNodeIds = excludedNodes)
    }
    PlanInfo(appID, sqlID, sqlDesc, execInfos)
  }

  def getStagesInSQLNode(node: SparkPlanGraphNode, app: AppBase): Set[Int] = {
    val nodeAccums = node.metrics.map(_.accumulatorId)
    nodeAccums.flatMap(app.accumManager.getAccStageIds).toSet
  }

  // Set containing execs that refer to other execs. We keep this as a collection to allow
  // appending more execs in the future as necessary.
  // Note that the Spark graph may create duplicate nodes when any of the following execs exists.
  private val reuseExecs = Set("ReusedExchange")

  def parsePlanNode(
      node: SparkPlanGraphNode,
      sqlID: Long,
      checker: PluginTypeChecker,
      app: AppBase,
      reusedNodeIds: Set[Long]
  ): Seq[ExecInfo] = {
    // Avoid counting duplicate nodes. We mark them as shouldRemove to neutralize their impact on
    // speedups.
    val isDupNode = reusedNodeIds.contains(node.id)
    // Normalize the execName by removing the trailing '$' character, if present.
    // This is necessary because in Scala, the '$' character is often appended to the names of
    // generated classes or objects, and we want to match the base name regardless of this suffix.
    val normalizedNodeName = node.name.stripSuffix("$")
    if (isDupNode) {
      // Log that information. This should not cause a significant increase in log size.
      logDebug(s"Marking [sqlID = ${sqlID}, node = ${normalizedNodeName}] as shouldRemove. " +
        s"Reason: duplicate - ancestor of ReusedExchange")
    }
    if (normalizedNodeName.contains("WholeStageCodegen")) {
      // this is special because it is a SparkPlanGraphCluster vs SparkPlanGraphNode
      WholeStageExecParser(node.asInstanceOf[SparkPlanGraphCluster], checker, sqlID, app,
        reusedNodeIds).parse
    } else {
      val execInfos = try {
        normalizedNodeName match {
          case "AggregateInPandas" =>
            AggregateInPandasExecParser(node, checker, sqlID).parse
          case "ArrowEvalPython" =>
            ArrowEvalPythonExecParser(node, checker, sqlID).parse
          case "BatchScan" =>
            BatchScanExecParser(node, checker, sqlID, app).parse
          case "BroadcastExchange" =>
            BroadcastExchangeExecParser(node, checker, sqlID, app).parse
          case "BroadcastHashJoin" =>
            BroadcastHashJoinExecParser(node, checker, sqlID).parse
          case "BroadcastNestedLoopJoin" =>
            BroadcastNestedLoopJoinExecParser(node, checker, sqlID).parse
          case "CartesianProduct" =>
            CartesianProductExecParser(node, checker, sqlID).parse
          case "Coalesce" =>
            CoalesceExecParser(node, checker, sqlID).parse
          case "CollectLimit" =>
            CollectLimitExecParser(node, checker, sqlID).parse
          case "CustomShuffleReader" | "AQEShuffleRead" =>
            CustomShuffleReaderExecParser(node, checker, sqlID).parse
          case "Exchange" =>
            ShuffleExchangeExecParser(node, checker, sqlID, app).parse
          case "Expand" =>
            ExpandExecParser(node, checker, sqlID).parse
          case "Filter" =>
            FilterExecParser(node, checker, sqlID).parse
          case "FlatMapGroupsInPandas" =>
            FlatMapGroupsInPandasExecParser(node, checker, sqlID).parse
          case "Generate" =>
            GenerateExecParser(node, checker, sqlID).parse
          case "GlobalLimit" =>
            GlobalLimitExecParser(node, checker, sqlID).parse
          case "HashAggregate" =>
            HashAggregateExecParser(node, checker, sqlID, app).parse
          case "LocalLimit" =>
            LocalLimitExecParser(node, checker, sqlID).parse
          case "InMemoryTableScan" =>
            InMemoryTableScanExecParser(node, checker, sqlID).parse
          case i if DataWritingCommandExecParser.isWritingCmdExec(i) =>
            DataWritingCommandExecParser.parseNode(node, checker, sqlID)
          case "MapInPandas" =>
            MapInPandasExecParser(node, checker, sqlID).parse
          case "ObjectHashAggregate" =>
            ObjectHashAggregateExecParser(node, checker, sqlID, app).parse
          case "Project" =>
            ProjectExecParser(node, checker, sqlID).parse
          case "PythonMapInArrow" | "MapInArrow" =>
            PythonMapInArrowExecParser(node, checker, sqlID).parse
          case "Range" =>
            RangeExecParser(node, checker, sqlID).parse
          case "Sample" =>
            SampleExecParser(node, checker, sqlID).parse
          case "ShuffledHashJoin" =>
            ShuffledHashJoinExecParser(node, checker, sqlID, app).parse
          case "Sort" =>
            SortExecParser(node, checker, sqlID).parse
          case s if ReadParser.isScanNode(s) =>
            FileSourceScanExecParser(node, checker, sqlID, app).parse
          case "SortAggregate" =>
            SortAggregateExecParser(node, checker, sqlID).parse
          case smj if SortMergeJoinExecParser.accepts(smj) =>
            SortMergeJoinExecParser(node, checker, sqlID).parse
          case "SubqueryBroadcast" =>
            SubqueryBroadcastExecParser(node, checker, sqlID, app).parse
          case sqe if SubqueryExecParser.accepts(sqe) =>
            SubqueryExecParser.parseNode(node, checker, sqlID, app)
          case "TakeOrderedAndProject" =>
            TakeOrderedAndProjectExecParser(node, checker, sqlID).parse
          case "Union" =>
            UnionExecParser(node, checker, sqlID).parse
          case "Window" =>
            WindowExecParser(node, checker, sqlID).parse
          case "WindowInPandas" =>
            WindowInPandasExecParser(node, checker, sqlID).parse
          case "WindowGroupLimit" =>
            WindowGroupLimitParser(node, checker, sqlID).parse
          case wfe if WriteFilesExecParser.accepts(wfe) =>
            WriteFilesExecParser(node, checker, sqlID).parse
          case _ =>
            // Execs that are members of reuseExecs (i.e., ReusedExchange) should be marked as
            // supported but with shouldRemove flag set to True.
            // Setting the "shouldRemove" is handled at the end of the function.
            ExecInfo(node, sqlID, normalizedNodeName, expr = "", 1, duration = None, node.id,
              isSupported = reuseExecs.contains(normalizedNodeName), None)
        }
      } catch {
        // Error parsing expression could trigger an exception. If the exception is not handled,
        // the application will be skipped. We need to suppress exceptions here to avoid
        // sacrificing the entire app analysis.
        // Note that:
        //  - The exec will be considered unsupported.
        //  - No need to add the SQL to the failed SQLs, because this will cause the app to be
        //    labeled as "Not Applicable" which is not preferred at this point.
        case NonFatal(e) =>
          logWarning(s"Unexpected error parsing plan node ${normalizedNodeName}. " +
          s" sqlID = ${sqlID}", e)
          ExecInfo(node, sqlID, normalizedNodeName, expr = "", 1, duration = None, node.id,
            isSupported = false, None)
      }
      val stagesInNode = getStagesInSQLNode(node, app)
      execInfos.setStages(stagesInNode)
      // shouldRemove is set to true if the exec is a member of "execsToBeRemoved" or if the node
      // is a duplicate
      execInfos.setShouldRemove(isDupNode)
      // Set the custom reasons for unsupported execs
      val unsupportedExecsReason = checker.getNotSupportedExecsReason(execInfos.exec)
      execInfos.setUnsupportedExecReason(unsupportedExecsReason)
      Seq(execInfos)
    }
  }

  /**
   * This function is used to calculate an average speedup factor. The input
   * is assumed to be an array of doubles where each element is >= 1. If the input array
   * is empty we return 1 because we assume we don't slow things down. Generally
   * the array shouldn't be empty, but if there is some weird case we don't want to
   * blow up, just say we don't speed it up.
   */
  def averageSpeedup(arr: Seq[Double]): Double = {
    if (arr.isEmpty) {
      1.0
    } else {
      val sum = arr.sum
      ToolUtils.calculateAverage(sum, arr.size, 2)
    }
  }
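  // A minimal usage sketch (assuming ToolUtils.calculateAverage rounds the mean to the given
  // number of decimal places):
  //   averageSpeedup(Seq(2.0, 3.0, 4.0))  // => 3.0
  //   averageSpeedup(Seq.empty)           // => 1.0 (assume no slowdown)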

  /**
   * Get the total duration by finding the accumulator with the largest value.
   * This is because each accumulator has a value and an update. As tasks end
   * they just update the value = value + update, so the largest value will be
   * the duration.
   */
  def getTotalDuration(accumId: Option[Long], app: AppBase): Option[Long] = {
    accumId match {
      case Some(x) => app.accumManager.getMaxStageValue(x)
      case _ => None
    }
  }

  def getDriverTotalDuration(accumId: Option[Long], app: AppBase): Option[Long] = {
    val accums = accumId.flatMap(id => app.driverAccumMap.get(id))
      .getOrElse(ArrayBuffer.empty)
    val accumValues = accums.map(_.value)
    val maxDuration = if (accumValues.isEmpty) {
      None
    } else {
      Some(accumValues.max)
    }
    maxDuration
  }

  private def ignoreExpression(expr: String): Boolean = {
    ignoreExpressions.contains(expr.toLowerCase)
  }

  private def getFunctionName(functionPattern: Regex, expr: String): Option[String] = {
    val funcName = functionPattern.findFirstMatchIn(expr) match {
      case Some(func) =>
        val func1 = func.group(1)
        // There are some functions which are not expressions hence should be ignored.
        // For example: In the physical plan cast is usually presented as function call
        // `cast(value#9 as date)`. We add other function names to the result.
        if (!ignoreExpression(func1)) {
          Some(func1)
        } else {
          None
        }
      case _ => logDebug(s"Incorrect expression - $expr")
        None
    }
    funcName
  }

  // This method aims at doing some common processing to an expression before
  // we start parsing it. For example, some special handling is required for some functions.
  def processSpecialFunctions(expr: String): String = {
    // For parse_url, we only support parse_url(*,HOST|PROTOCOL|QUERY|PATH[,*]).
    // So we want to be able to define that parse_url(*,REF|FILE|AUTHORITY|USERINFO[,*])
    // is not supported.

    // The following regex uses forward references to find matches for parse_url(*)
    // we need to use forward references because otherwise multiple occurrences will be matched
    // only once.
    // https://stackoverflow.com/questions/47162098/is-it-possible-to-match-nested-brackets-with-a-
    // regex-without-using-recursion-or/47162099#47162099
    // example parse_url:
    // Project [url_col#7, parse_url(url_col#7, HOST, false) AS HOST#9,
    //          parse_url(url_col#7, QUERY, false) AS QUERY#10]
    val parseURLPattern = ("parse_url(?=\\()(?:(?=.*?\\((?!.*?\\1)(.*\\)(?!.*\\2).*))(?=.*?\\)" +
      "(?!.*?\\2)(.*)).)+?.*?(?=\\1)[^(]*(?=\\2$)").r
    val allMatches = parseURLPattern.findAllMatchIn(expr)
    if (allMatches.nonEmpty) {
      var newExpr = expr
      allMatches.foreach { parse_call =>
        // iterate on all matches replacing parse_url by parse_url_{parttoextract} if any
        // note that we do replaceFirst because we want to map 1-to-1 and the order does
        // not matter here.
        val matched = parse_call.matched
        val extractPart = regExParseURLPart.findFirstMatchIn(matched).map(_.group(1))
        if (extractPart.isDefined) {
          val replacedParseClass =
            matched.replaceFirst("parse_url\\(", s"parse_url_${extractPart.get.toLowerCase}(")
          newExpr = newExpr.replace(matched, replacedParseClass)
        }
      }
      newExpr
    } else {
      expr
    }
  }
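  // An illustrative sketch of the expected rewrite (hypothetical column names), assuming the
  // forward-reference regex above captures each parse_url(...) call as intended:
  //   processSpecialFunctions("parse_url(url_col#7, REF, false) AS ref#9")
  //   // => "parse_url_ref(url_col#7, REF, false) AS ref#9"
  //   processSpecialFunctions("parse_url(url_col#7, HOST, false) AS host#9")
  //   // => unchanged, because HOST is a supported partToExtract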

  private def getAllFunctionNames(regPattern: Regex, expr: String,
      groupInd: Int = 1, isAggr: Boolean = true): Set[String] = {
    // Returns all matches in an expression. This can be used when the SQL expression is not
    // tokenized.
    val newExpr = processSpecialFunctions(expr)

    // first get all the functionNames
    val exprss =
      regPattern.findAllMatchIn(newExpr).map(_.group(groupInd)).toSet

    // For aggregate expressions we want to process the results to remove the prefix
    // DB: remove the "^partial_" and "^finalmerge_" prefixes
    // TODO:
    //    For performance's sake, we can turn off the aggregate processing by enabling it only
    //    when needed. However, for now, we always do this processing until we are confident we
    //    know the correct place to turn the flag on/off. We can use the argument isAggr only
    //    when needed.
    val results = if (isAggr) {
      exprss.collect {
        case func =>
          aggregatePrefixes.find(func.startsWith(_)).map(func.replaceFirst(_, "")).getOrElse(func)
      }
    } else {
      exprss
    }
    results.filterNot(ignoreExpression(_))
  }
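  // A small sketch of the expected behavior (hypothetical expression), showing the aggregate
  // prefix handling described above:
  //   getAllFunctionNames(functionPrefixPattern, "partial_count(letter#84)")
  //   // => Set("count"), because the "partial_" prefix is stripped when isAggr is true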

  def parseProjectExpressions(exprStr: String): Array[String] = {
    // Project [cast(value#136 as string) AS value#144, CEIL(value#136) AS CEIL(value)#143L]
    // This extracts only the function names. The pattern matches a function name followed
    // by `(`. We use regex to extract all the function names below:
    getAllFunctionNames(functionPrefixPattern, exprStr).toArray
  }

  // This parser is used for SortAggregateExec, HashAggregateExec and ObjectHashAggregateExec
  def parseAggregateExpressions(exprStr: String): Array[String] = {
    val parsedExpressions = ArrayBuffer[String]()
    // (keys=[num#83], functions=[partial_collect_list(letter#84, 0, 0), partial_count(letter#84)])
    // Currently we only parse the functions expressions.
    // "Keys" parsing is disabled for now because we won't be able to detect the types

    // A map (value -> parseEnabled) between the group and the parsing metadata
    val patternMap = Map(
      "functions" -> true,
      "keys" -> false
    )
    // It won't hurt to define a pattern that is neutral to the order of the functions/keys.
    // This can avoid mismatches when exprStr comes in the form of (functions=[], keys=[]).
    val pattern = """^\((keys|functions)=\[(.*)\]\s*,\s*(keys|functions)=\[(.*)\]\s*\)$""".r
    // Iterate through the matches and exclude disabled clauses
    pattern.findAllMatchIn(exprStr).foreach { m =>
      // The matching groups are:
      // 0 -> entire expression
      // 1 -> "keys"; 2 -> keys' expression
      // 3 -> "functions"; 4 -> functions' expression
      Array(1, 3).foreach { group_ind =>
        val group_value = m.group(group_ind)
        if (patternMap.getOrElse(group_value, false)) {
          val clauseExpr = m.group(group_ind + 1)
          // No need to split the expr any further because we are only interested in function names
          val used_functions = getAllFunctionNames(functionPrefixPattern, clauseExpr)
          parsedExpressions ++= used_functions
        }
      }
    }
    parsedExpressions.distinct.toArray
  }
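  // An illustrative example based on the sample expression in the comment above (a sketch of
  // the expected output, not a guarantee):
  //   parseAggregateExpressions(
  //     "(keys=[num#83], functions=[partial_collect_list(letter#84, 0, 0), " +
  //       "partial_count(letter#84)])")
  //   // => roughly Array("collect_list", "count"): keys are skipped and prefixes are stripped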

  def parseWindowExpressions(exprStr: String): Array[String] = {
    val parsedExpressions = ArrayBuffer[String]()
    // [sum(cast(level#30 as bigint)) windowspecdefinition(device#29, id#28 ASC NULLS FIRST,
    // specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS sum#35L,
    // row_number() windowspecdefinition(device#29, id#28 ASC NULLS FIRST, specifiedwindowframe
    // (RowFrame, unboundedpreceding$(), currentrow$())) AS row_number#41], [device#29],
    // [id#28 ASC NULLS FIRST]

    // This splits the string to get only the expressions in WindowExec. So we first split the
    // string on closing bracket ] and get the first element from the array. This is followed
    // by removing the leading and trailing brackets and removing "cast" as it is not an expr.
    // Lastly we split the string on the keyword windowspecdefinition so that each array element
    // except the last one contains one window aggregate function.
    // sum(level#30 as bigint))
    // (device#29, id#28 ASC NULLS FIRST, .....  AS sum#35L, row_number()
    // (device#29, id#28 ASC NULLS FIRST, ......  AS row_number#41
    val windowExprs = exprStr.split("(?<=\\])")(0).
        trim.replaceAll("""^\[+""", "").replaceAll("""\]+$""", "").
        replaceAll("cast\\(", "").split("windowspecdefinition").map(_.trim)

    // Get function name from each array element except the last one as it doesn't contain
    // any window function
    if (windowExprs.nonEmpty) {
      windowExprs.dropRight(1).foreach { windowExprString =>
        val windowFunc = windowFunctionPattern.findAllIn(windowExprString).toList
        val expr = windowFunc.lastOption.getOrElse("")
        val functionName = getFunctionName(windowFunctionPattern, expr)
        functionName match {
          case Some(func) => parsedExpressions += func
          case _ => // NO OP
        }
      }
    }
    parsedExpressions.distinct.toArray
  }
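  // For the sample WindowExec string in the comment above, the expected result is roughly
  // Array("sum", "row_number"); the windowspecdefinition helpers themselves are not reported.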

  def parseWindowGroupLimitExpressions(exprStr: String): Array[String] = {
    // [category#16], [amount#17 DESC NULLS LAST], dense_rank(amount#17), 2, Final

    // This splits the string to get only the ranking expression in WindowGroupLimitExec.
    // We split the string on comma and get the third element from the array.
    // dense_rank(amount#17)
    val rankLikeExpr = exprStr.split(", ").lift(2).map(_.trim)
    // Get function name from WindowExpression
    rankLikeExpr.flatMap { rankExpr =>
      windowFunctionPattern.findFirstIn(rankExpr).flatMap { rankLikeFunc =>
        getFunctionName(windowFunctionPattern, rankLikeFunc)
      }
    }.toArray
  }
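  // For the sample string in the comment above, the expected result is roughly
  // Array("dense_rank"): only the ranking expression (the third comma-separated element) is
  // parsed.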

  def parseExpandExpressions(exprStr: String): Array[String] = {
    // [List(x#1564, hex(y#1455L)#1565, CEIL(z#1456)#1566L, 0),
    // List(x#1564, hex(y#1455L)#1565, null, 1), .......
    // , spark_grouping_id#1567L]
    // For Spark320+, the expandExpressions has different format
    //  [[x#23, CEIL(y#11L)#24L, hex(cast(z#12 as bigint))#25, 0]
    // Parsing:
    // The goal is to extract all valid functions from the expand.
    // It is important to take the following into considerations:
    //  - Some values can be NULLs. That's why we cannot limit the extract to the first row.
    //  - Nested brackets/parenthesis makes it challenging to use regex that contains
    //    brackets/parenthesis to extract expressions.
    // The implementation uses regex to extract all function names and returns a distinct set of
    // function names.
    // This is a one-line implementation, but it can be a memory/time bottleneck.
    getAllFunctionNames(functionPrefixPattern, exprStr).toArray
  }
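  // For the sample Expand strings in the comment above, the expected result is roughly
  // Array("hex", "CEIL"); "List" is dropped because it appears in ignoreExpressions.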

  def parseTakeOrderedExpressions(exprStr: String): Array[String] = {
    val parsedExpressions = ArrayBuffer[String]()
    // (limit=2, orderBy=[FLOOR(z#796) ASC NULLS FIRST,
    // CEIL(y#795L) ASC NULLS FIRST,y#1588L ASC NULLS FIRST], output=[x#794,y#796L,z#795])
    val pattern = """orderBy=\[([\w#, \(\)]+\])""".r
    val orderString = pattern.findFirstMatchIn(exprStr)
    // This is to split multiple column names in the orderBy clause of TakeOrderedAndProjectExec.
    // First we remove orderBy from the string and then split the resultant string.
    // The string is split on a delimiter containing FIRST or LAST, which is the last string
    // of each column in this Exec. This produces an array containing the
    // column names. Finally we remove the brackets from the beginning and end to get only
    // the expressions. The result will be as below.
    // Array(FLOOR(z#796) ASC NULLS FIRST,, CEIL(y#795L) ASC NULLS FIRST)
    if (orderString.isDefined) {
      val parenRemoved = orderString.get.toString.replaceAll("orderBy=", "").
        split("(?<=FIRST,)|(?<=LAST,)").map(_.trim).map(
        _.replaceAll("""^\[+""", "").replaceAll("""\]+$""", ""))
      parenRemoved.foreach { expr =>
        val functionName = getFunctionName(functionPattern, expr)
        functionName match {
          case Some(func) => parsedExpressions += func
          case _ => // NO OP
        }
      }
    }
    parsedExpressions.distinct.toArray
  }
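  // For the sample TakeOrderedAndProject string in the comment above, the expected result is
  // roughly Array("FLOOR", "CEIL"); plain column references such as y#1588L carry no function
  // name.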

  def parseGenerateExpressions(exprStr: String): Array[String] = {
    // Get the function names from the GenerateExec. The GenerateExec has the following format:
    // 1. Generate explode(arrays#1306), [id#1304], true, [col#1426]
    // 2. Generate json_tuple(values#1305, Zipcode, ZipCodeType, City), [id#1304],
    // false, [c0#1407, c1#1408, c2#1409]
    getAllFunctionNames(functionPrefixPattern, exprStr).toArray
  }

  private def addFunctionNames(exprs: String, parsedExpressions: ArrayBuffer[String]): Unit = {
    val functionNames = getAllFunctionNames(functionPrefixPattern, exprs).toArray
    functionNames.foreach(parsedExpressions += _)
  }

  // This parser is used for BroadcastHashJoin, ShuffledHashJoin and SortMergeJoin
  def parseEquijoinsExpressions(exprStr: String): (Array[String], Boolean) = {
    // ShuffledHashJoin [name#11, CEIL(DEPT#12)], [name#28, CEIL(DEPT_ID#27)], Inner, BuildLeft
    // SortMergeJoin [name#11, CEIL(dept#12)], [name#28, CEIL(dept_id#27)], Inner
    // BroadcastHashJoin [name#11, CEIL(dept#12)], [name#28, CEIL(dept_id#27)], Inner,
    // BuildRight, false
    // BroadcastHashJoin exprString: [i_item_id#56], [i_item_id#56#86], ExistenceJoin(exists#86),
    // BuildRight
    val parsedExpressions = ArrayBuffer[String]()
    // Get all the join expressions and split it with delimiter :: so that it could be used to
    // parse function names (if present) later.
    val joinExprs = equiJoinRegexPattern.findAllMatchIn(exprStr).mkString("::")
    // Get joinType and buildSide(if applicable)
    val joinParams = equiJoinRegexPattern.replaceAllIn(
      exprStr, "").split(",").map(_.trim).filter(_.nonEmpty)
    val joinType = if (joinParams.nonEmpty) {
      joinParams(0).split("\\(")(0).trim
    } else {
      ""
    }

    // This is to differentiate between SortMergeJoin and other Joins. SortMergeJoin has no
    // buildSide.
    val possibleBuildSides = Set(BuildSide.BuildLeft, BuildSide.BuildRight)
    val buildSide = joinParams.find(possibleBuildSides.contains).getOrElse("")
    val isSortMergeJoin = buildSide.isEmpty

    val joinCondition = joinParams.dropWhile(param =>
          possibleBuildSides.contains(param) || param.contains(joinType)).map(_.trim).mkString(",")
    // Get the individual expressions which are later used to get the function names.
    val colExpressions = joinExprs.split("::").map(_.trim).map(
      _.replaceAll("""^\[+|\]+$""", "")).map(_.split(",")).flatten.map(_.trim)
    colExpressions.foreach(expr => addFunctionNames(expr, parsedExpressions))
    if (joinCondition.nonEmpty) {
      val conditionExprs = parseConditionalExpressions(joinCondition)
      conditionExprs.foreach(parsedExpressions += _)
    }
    // Check corner cases for SortMergeJoin
    val isSortMergeSupported = !(isSortMergeJoin &&
        joinCondition.nonEmpty && isSMJConditionUnsupported(joinCondition))

    (parsedExpressions.distinct.toArray, equiJoinSupportedTypes(buildSide, joinType)
        && isSortMergeSupported)
  }
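  // An illustrative sketch (hypothetical expression; the exact exprStr layout depends on the
  // Spark version emitting the plan):
  //   parseEquijoinsExpressions(
  //     "[name#11, CEIL(dept#12)], [name#28, CEIL(dept_id#27)], Inner, BuildRight, false")
  //   // => roughly (Array("CEIL"), true): an Inner join with BuildRight is treated as supported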

  def isSMJConditionUnsupported(joinCondition: String): Boolean = {
    // TODO: This is a temporary solution to check for unsupported conditions in SMJ.
    // Remove these checks once below issues are resolved:
    // https://github.com/NVIDIA/spark-rapids/issues/11213
    // https://github.com/NVIDIA/spark-rapids/issues/11214

    // Regular expressions for corner cases that mark the SMJ as not supported
    val castAsDateRegex = """(?i)\bcast\(\s*.+\s+as\s+date\s*\)""".r
    val lowerInRegex = """(?i)\blower\(\s*.+\s*\)\s+in\s*(\((?:[^\(\)]*|.*)\)|\bsubquery#\d+\b)""".r

    // Split the joinCondition by logical operators (AND/OR)
    val conditions = joinCondition.split("\\s+(?i)(AND|OR)\\s+").map(_.trim)
    conditions.exists { condition =>
      // Check for the specific corner cases that mark the SMJ as not supported
      castAsDateRegex.findFirstIn(condition).isDefined ||
          lowerInRegex.findFirstIn(condition).isDefined
    }
  }
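  // Illustrative sketches of the corner cases checked above (hypothetical column names):
  //   isSMJConditionUnsupported("(cast(d1#1 as date) = cast(d2#2 as date))")  // => true
  //   isSMJConditionUnsupported("(lower(name#3) IN (a, b, c))")               // => true
  //   isSMJConditionUnsupported("(id1#4 = id2#5)")                            // => false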

  def parseNestedLoopJoinExpressions(exprStr: String): (Array[String], Boolean) = {
    // BuildRight, LeftOuter, ((CEIL(cast(id1#1490 as double)) <= cast(id2#1496 as bigint))
    // AND (cast(id1#1490 as bigint) < CEIL(cast(id2#1496 as double))))
    // Get joinType and buildSide by splitting the input string.
    val nestedLoopParameters = exprStr.split(",", 3)
    val buildSide = nestedLoopParameters(0).trim
    val joinType = nestedLoopParameters(1).trim

    // Check if a condition is present on the join columns; otherwise return an empty array.
    val parsedExpressions = if (nestedLoopParameters.size > 2) {
      parseConditionalExpressions(exprStr)
    } else {
      Array[String] ()
    }
    (parsedExpressions, nestedLoopJoinSupportedTypes(buildSide, joinType))
  }

  private def isJoinTypeSupported(joinType: String): Boolean = {
    // There is a caveat for FullOuter join for equiJoins.
    // FullOuter join is not supported with struct keys, but we return true for all
    // data structures.
    joinType match {
      case JoinType.Cross => true
      case JoinType.Inner => true
      case JoinType.LeftSemi => true
      case JoinType.FullOuter => true
      case JoinType.LeftOuter => true
      case JoinType.RightOuter => true
      case JoinType.LeftAnti => true
      case JoinType.ExistenceJoin => true
      case _ => false
    }
  }

  private def equiJoinSupportedTypes(buildSide: String, joinType: String): Boolean = {
    val joinTypeSupported = isJoinTypeSupported(joinType)
    // We are checking if the joinType is supported for the buildSide. If the buildSide is not
    // in the supportedBuildSides map then we are assuming that the
    // joinType is supported for that buildSide.
    val buildSideSupported = BuildSide.supportedBuildSides.getOrElse(
      buildSide, JoinType.allsupportedJoinType).contains(joinType)

    joinTypeSupported && buildSideSupported
  }

  private def nestedLoopJoinSupportedTypes(buildSide: String, joinType: String): Boolean = {
    // Full Outer join not supported in BroadcastNestedLoopJoin
    val joinTypeSupported = if (joinType != JoinType.FullOuter) {
      isJoinTypeSupported(joinType)
    } else {
      false
    }
    // This is from GpuBroadcastNestedLoopJoinMeta.tagPlanForGpu where join is
    // not supported on GPU if below condition is met.
    val buildSideNotSupported = if (buildSide == BuildSide.BuildLeft) {
      joinType == JoinType.LeftOuter || joinType == JoinType.LeftSemi ||
        joinType == JoinType.LeftAnti
    } else if (buildSide == BuildSide.BuildRight) {
      joinType == JoinType.RightOuter
    } else {
      false
    }
    joinTypeSupported && !buildSideNotSupported
  }

  def parseSortExpressions(exprStr: String): Array[String] = {
    val parsedExpressions = ArrayBuffer[String]()
    // Sort [round(num#126, 0) ASC NULLS FIRST, letter#127 DESC NULLS LAST], true, 0
    val pattern = """\[([\w#, \(\)]+\])""".r
    val sortString = pattern.findFirstMatchIn(exprStr)
    // This is to split multiple column names in SortExec. A sort key may have a function
    // applied to a column. The string is split on a delimiter containing FIRST or LAST, which
    // is the last string of each column in SortExec. This produces an array containing the
    // column names. Finally we remove the brackets from the beginning and end to get only
    // the expressions. The result will be as below.
    // paranRemoved = Array(round(num#7, 0) ASC NULLS FIRST,, letter#8 DESC NULLS LAST)
    if (sortString.isDefined) {
      val paranRemoved = sortString.get.toString.split("(?<=FIRST,)|(?<=LAST,)").
          map(_.trim).map(_.replaceAll("""^\[+""", "").replaceAll("""\]+$""", ""))
      paranRemoved.foreach { expr =>
        val functionName = getFunctionName(functionPattern, expr)
        functionName match {
          case Some(func) => parsedExpressions += func
          case _ => // NO OP
        }
      }
    }
    parsedExpressions.distinct.toArray
  }

  def parseFilterExpressions(exprStr: String): Array[String] = {
    // Filter ((isnotnull(s_state#68) AND (s_state#68 = TN)) OR (hex(cast(value#0 as bigint)) = B))
    parseConditionalExpressions(exprStr)
  }
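  // For the sample Filter string in the comment above, the expected result is roughly
  // Array("isnotnull", "hex", "And", "EqualTo", "Or"); "cast" is dropped because it is listed
  // in ignoreExpressions, and column references such as s_state#68 are ignored.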

  // The scope is to extract expressions from a conditional expression.
  // Ideally, parsing conditional expressions needs to build a tree. The current implementation is
  // a simplified version that does not accurately pick up the LHS and RHS of each predicate.
  // Instead, it extracts function names and expressions on a best-effort basis.
  def parseConditionalExpressions(exprStr: String): Array[String] = {
    // Captures any word followed by '('
    // isnotnull(, StringEndsWith(
    val functionsRegEx = """((\w+))\(""".r
    // Captures binary operators followed by '('
    // AND(, OR(, NOT(, =(, <(, >(
    val binaryOpsNoSpaceRegEx = """(^|\s+)((AND|OR|NOT|IN|=|<=>|<|>|>=|\++|-|\*+))(\(+)""".r
    // Capture reserved words at the end of expression. Those should be considered literal
    // and hence are ignored.
    // Binary operators cannot be at the end of the string, or end of expression.
    // For example we know that the following AND is a literal value, not the operator AND.
    // So, we can filter that out from the results.
    //     PushedFilters: [IsNotNull(c_customer_id), StringEndsWith(c_customer_id,AND)]
    //     Filter (isnotnull(names#15) AND StartsWith(names#15, AND))
    // AND), AND$
    val nonBinaryOperatorsRegEx = """\s+((AND|OR|NOT|=|<=>|<|>|>=|\++|-|\*+))($|\)+)""".r
    // Capture all "("
    val parenthesisStartRegEx = """(\(+)""".r
    // Capture all ")"
    val parenthesisEndRegEx = """(\)+)""".r

    val parsedExpressions = ArrayBuffer[String]()
    var processedExpr = exprStr
    // Step-1: make sure that any binary operator won't mix up with functionNames
    // For example AND(, isnotNull()
    binaryOpsNoSpaceRegEx.findAllMatchIn(exprStr).foreach { m =>
      // replace things like 'AND(' with 'AND ('
      val str = s"${m.group(2)}\\(+"
      processedExpr = str.r.replaceAllIn(processedExpr, s"${m.group(2)} \\(")
    }

    // Step-2: Extract function names from the expression
    val functionMatches = functionsRegEx.findAllMatchIn(processedExpr)
    parsedExpressions ++=
      functionMatches.map(_.group(1)).filterNot(ignoreExpression(_))
    // remove all function calls. No need to keep them in the expression
    processedExpr = functionsRegEx.replaceAllIn(processedExpr, " ")

    // Step-3: remove literal variables so we do not treat them as Binary operators
    // Simply replace them by white space.
    processedExpr = nonBinaryOperatorsRegEx.replaceAllIn(processedExpr, " ")

    // Step-4: remove remaining parentheses '(', ')' and commas if we had functionCalls
    if (!functionMatches.isEmpty) {
      // remove ","
      processedExpr = processedExpr.replaceAll(",", " ")
    }
    processedExpr = parenthesisStartRegEx.replaceAllIn(processedExpr, " ")
    processedExpr = parenthesisEndRegEx.replaceAllIn(processedExpr, " ")

    // Step-5: now we should have a simplified expression that can be tokenized on white
    // space delimiter
    processedExpr.split("\\s+").foreach { token =>
      token match {
        case "NOT" => parsedExpressions += "Not"
        case "=" => parsedExpressions += "EqualTo"
        case "<=>" => parsedExpressions += "EqualNullSafe"
        case "<" => parsedExpressions += "LessThan"
        case ">" => parsedExpressions += "GreaterThan"
        case "<=" => parsedExpressions += "LessThanOrEqual"
        case ">=" => parsedExpressions += "GreaterThanOrEqual"
        case "<<" => parsedExpressions += "ShiftLeft"
        case ">>" => parsedExpressions += "ShiftRight"
        case ">>>" => parsedExpressions += "ShiftRightUnsigned"
        case "+" => parsedExpressions += "Add"
        case "-" => parsedExpressions += "Subtract"
        case "*" => parsedExpressions += "Multiply"
        case "IN" => parsedExpressions += "In"
        case "OR" | "||" =>
          // Some Spark2.x eventlogs may have '||' instead of 'OR'
          parsedExpressions += "Or"
        case "&&" | "AND" =>
          // Some Spark2.x eventlogs may have '&&' instead of 'AND'
          parsedExpressions += "And"
        case t if t.contains("#") =>
          // This is a variable name. Ignore those.
        case _ =>
          // Anything else could be a literal value or something we do not handle yet.
          // Ignore it for now.
          logDebug(s"Unrecognized Token - $token")
      }
    }

    parsedExpressions.distinct.toArray
  }
}
