/*
 * Copyright (c) 2021-2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.rapids.tool

import scala.collection.mutable
import scala.util.{Failure, Success, Try}

import com.nvidia.spark.rapids.tool.planparser.SubqueryExecParser
import com.nvidia.spark.rapids.tool.profiling.ProfileUtils.replaceDelimiter
import com.nvidia.spark.rapids.tool.qualification.QualOutputWriter
import org.apache.maven.artifact.versioning.ComparableVersion

import org.apache.spark.internal.{config, Logging}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SparkPlanGraphNode}
import org.apache.spark.sql.rapids.tool.util.ToolsPlanGraph

object ToolUtils extends Logging {
  // List of recommended file-encodings on the GPUs.
  val SUPPORTED_ENCODINGS = Seq("UTF-8")
  // the prefix of keys defined by the RAPIDS plugin
  val PROPS_RAPIDS_KEY_PREFIX = "spark.rapids"
  // List of keys from sparkProperties that may point to RAPIDS jars.
  // Note that we ignore "spark.yarn.secondary.jars" for now as it does not include a full path.
  val POSSIBLE_JARS_PROPERTIES = Set("spark.driver.extraClassPath",
    "spark.executor.extraClassPath",
    "spark.yarn.dist.jars",
    "spark.repl.local.jars")
  val RAPIDS_JAR_REGEX = "(.*rapids-4-spark.*jar)|(.*cudf.*jar)".r

  // Add more entries to this lookup table as necessary.
  // There is no need to list all supported versions.
  private val lookupVersions = Map(
    "311" -> new ComparableVersion("3.1.1"), // default build version
    "320" -> new ComparableVersion("3.2.0"), // introduced reusedExchange
    "330" -> new ComparableVersion("3.3.0"), // used to check for memoryOverheadFactor
    "331" -> new ComparableVersion("3.3.1"),
    "340" -> new ComparableVersion("3.4.0"),  // introduces jsonProtocolChanges
    "350" -> new ComparableVersion("3.5.0")  // introduces windowGroupLimit
  )

  // The Spark runtime version of the current environment. We need this outside of the test
  // module because we extend the supported runtimes for different platforms such as Databricks.
  lazy val sparkRuntimeVersion = {
    org.apache.spark.SPARK_VERSION
  }

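  /**
   * Compares two version strings using Maven's ComparableVersion ordering.
   * Returns 0 and logs an error if the two versions cannot be compared.
   *
   * Illustrative examples:
   * {{{
   *   ToolUtils.compareVersions("3.3.0", "3.2.1")  // > 0
   *   ToolUtils.compareVersions("3.1.1", "3.1.1")  // == 0
   *   ToolUtils.compareVersions("3.1.1", "3.2.0")  // < 0
   * }}}
   */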
  def compareVersions(verA: String, verB: String): Int = {
    Try {
      val verObjA = new ComparableVersion(verA)
      val verObjB = new ComparableVersion(verB)
      verObjA.compareTo(verObjB)
    } match {
      case Success(compRes) => compRes
      case Failure(t) =>
        logError(s"exception comparing two versions [$verA, $verB]", t)
        0
    }
  }

  def runtimeIsSparkVersion(refVersion: String): Boolean = {
    compareVersions(refVersion, sparkRuntimeVersion) == 0
  }

  def compareToSparkVersion(currVersion: String, lookupVersion: String): Int = {
    val lookupVersionObj = lookupVersions(lookupVersion)
    val currVersionObj = new ComparableVersion(currVersion)
    currVersionObj.compareTo(lookupVersionObj)
  }

  def isSpark320OrLater(sparkVersion: String = sparkRuntimeVersion): Boolean = {
    compareToSparkVersion(sparkVersion, "320") >= 0
  }

  def isSpark330OrLater(sparkVersion: String = sparkRuntimeVersion): Boolean = {
    compareToSparkVersion(sparkVersion, "330") >= 0
  }

  def isSpark331OrLater(sparkVersion: String = sparkRuntimeVersion): Boolean = {
    compareToSparkVersion(sparkVersion, "331") >= 0
  }

  def isSpark340OrLater(sparkVersion: String = sparkRuntimeVersion): Boolean = {
    compareToSparkVersion(sparkVersion, "340") >= 0
  }

  def isSpark350OrLater(sparkVersion: String = sparkRuntimeVersion): Boolean = {
    compareToSparkVersion(sparkVersion, "350") >= 0
  }

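  // Illustrative: an app is treated as GPU-enabled when, for example, its properties contain
  //   spark.plugins=com.nvidia.spark.SQLPlugin and spark.rapids.sql.enabled is unset or "true".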
  def isPluginEnabled(properties: Map[String, String]): Boolean = {
    (properties.getOrElse(config.PLUGINS.key, "").contains("com.nvidia.spark.SQLPlugin")
      && properties.getOrElse("spark.rapids.sql.enabled", "true").toBoolean)
  }

  def showString(df: DataFrame, numRows: Int): String = {
    df.showString(numRows, 0)
  }

  // Given two duration values, calculate a human-readable percent
  // rounded to 2 decimal places, e.g. 39.12%.
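  // Illustrative: calculateDurationPercent(2543, 6500) == 39.12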
  def calculateDurationPercent(first: Long, total: Long): Double = {
    val firstDec = BigDecimal.decimal(first)
    val totalDec = BigDecimal.decimal(total)
    if (firstDec == 0 || totalDec == 0) {
      0.toDouble
    } else {
      val res = (firstDec / totalDec) * 100
      formatDoubleValue(res, 2)
    }
  }

  // Given a sum and a count, calculate a human-readable average
  // rounded to the specified number of decimal places.
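  // Illustrative: calculateAverage(10.0, 4, 2) == 2.5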
  def calculateAverage(first: Double, size: Long, places: Int): Double = {
    val firstDec = BigDecimal.decimal(first)
    val sizeDec = BigDecimal.decimal(size)
    if (firstDec == 0 || sizeDec == 0) {
      0.toDouble
    } else {
      val res = (firstDec / sizeDec)
      formatDoubleValue(res, places)
    }
  }

  def formatDoubleValue(bigValNum: BigDecimal, places: Int): Double = {
    bigValNum.setScale(places, BigDecimal.RoundingMode.HALF_UP).toDouble
  }

  def formatDoublePrecision(valNum: Double): String = {
    truncateDoubleToTwoDecimal(valNum).toString
  }

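  // Truncates (does not round) a double to two decimal places,
  // e.g. (illustrative) truncateDoubleToTwoDecimal(39.129) == 39.12.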
  def truncateDoubleToTwoDecimal(valNum: Double): Double = {
    // floor is applied after multiplying by 100. This keeps the number "as is" up to two decimals.
    math.floor(valNum * 100) / 100
  }

  /**
   * Converts a sequence of elements to a single string that can be appended to a formatted text.
   * Delegates to [[com.nvidia.spark.rapids.tool.profiling.ProfileUtils.replaceDelimiter]] to
   * replace what is used as a text delimiter with something else.
   *
   * @param values the sequence of elements to join together.
   * @param separator the separator string to use.
   * @param txtDelimiter the delimiter used by the output file format (i.e., comma for CSV).
   * @return a single string in which the string representations (via toString) of all elements
   *         are joined by the given separator, with any occurrences of the text delimiter
   *         replaced.
   */
  def renderTextField(values: Seq[Any], separator: String, txtDelimiter: String): String = {
    replaceDelimiter(values.mkString(separator), txtDelimiter)
  }

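  // Illustrative (assuming none of the values contain the CSV delimiter):
  //   formatComplexTypes(Seq("array<int>", "struct<a:int>")) == "array<int>;struct<a:int>"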
  def formatComplexTypes(
      values: Seq[String], fileDelimiter: String = QualOutputWriter.CSV_DELIMITER): String = {
    renderTextField(values, ";", fileDelimiter)
  }

  def formatPotentialProblems(
      values: Seq[String], fileDelimiter: String = QualOutputWriter.CSV_DELIMITER): String = {
    renderTextField(values, ":", fileDelimiter)
  }

  /**
   * Given a Spark property key, this predicate checks whether it is related to RAPIDS
   * configurations. Note that "related RAPIDS properties" do not always have the
   * 'spark.rapids' prefix.
   *
   * @param pKey the Spark property key
   * @return true if it is directly related to RAPIDS
   */
  def isRapidsPropKey(pKey: String): Boolean = {
    pKey.startsWith(PROPS_RAPIDS_KEY_PREFIX) || pKey.startsWith("spark.executorEnv.UCX") ||
      pKey.startsWith("spark.shuffle.manager") || pKey.equals("spark.shuffle.service.enabled")
  }

  /**
   * Checks whether the given file encoding is supported for all Ops.
   * @param fileEncoding the encoding value read from the application configs
   * @return true if the file encoding is supported
   */
  def isFileEncodingRecommended(fileEncoding: String): Boolean = {
    fileEncoding.matches("(?i)utf-?8")
  }

  /**
   * Collects the paths that point to RAPIDS jars from a map of properties.
   * @param properties the map of properties holding the app configuration.
   * @return the set of unique file paths that match the RAPIDS jar patterns.
   */
  def extractRAPIDSJarsFromProps(properties: collection.Map[String, String]): Set[String] = {
    properties.filterKeys(POSSIBLE_JARS_PROPERTIES.contains(_)).collect {
      case (_, pVal) if pVal.matches(RAPIDS_JAR_REGEX.regex) =>
        pVal.split(",").filter(_.matches(RAPIDS_JAR_REGEX.regex))
    }.flatten.toSet
  }
}

object JoinType {
  val Inner = "Inner"
  val Cross = "Cross"
  val LeftOuter = "LeftOuter"
  val RightOuter = "RightOuter"
  val FullOuter = "FullOuter"
  val LeftSemi = "LeftSemi"
  val LeftAnti = "LeftAnti"
  val ExistenceJoin = "ExistenceJoin"

  val supportedJoinTypeForBuildRight = Set(Inner, Cross, LeftOuter, LeftSemi,
    LeftAnti, FullOuter, ExistenceJoin)

  val supportedJoinTypeForBuildLeft = Set(Inner, Cross, RightOuter, FullOuter)

  val allsupportedJoinType = Set(Inner, Cross, LeftOuter, RightOuter, FullOuter, LeftSemi,
    LeftAnti, ExistenceJoin)
}

object BuildSide {
  val BuildLeft = "BuildLeft"
  val BuildRight = "BuildRight"

  val supportedBuildSides = Map(BuildLeft -> JoinType.supportedJoinTypeForBuildLeft,
    BuildRight -> JoinType.supportedJoinTypeForBuildRight)
}

object SQLMetricsStats {
  val SIZE_METRIC = "size"
  val TIMING_METRIC = "timing"
  val NS_TIMING_METRIC = "nsTiming"
  val AVERAGE_METRIC = "average"
  val SUM_METRIC = "sum"

  def hasStats(metrics: String): Boolean = {
    metrics match {
      case SIZE_METRIC | TIMING_METRIC | NS_TIMING_METRIC | AVERAGE_METRIC => true
      case _ => false
    }
  }
}

case class RDDCheckResult(
    nodeNameRDD: Boolean,
    nodeDescRDD: Boolean,
    expr: Set[String] = Set.empty) {
  def isRDD: Boolean = nodeNameRDD || nodeDescRDD
}

object RDDCheckHelper {
  // regular expressions used to detect Dataset/RDD usage in node descriptions
  private val dataSetRDDRegExDescLookup = Set(
    ".*\\$Lambda\\$.*".r,
    ".*\\.apply$".r
  )
  // regular expressions used to detect Dataset/RDD usage in node names
  private val dataSetOrRDDRegExLookup = Set(
    "ExistingRDD$".r,
    "^Scan ExistingRDD.*".r,
    "SerializeFromObject$".r,
    "DeserializeToObject$".r,
    "MapPartitions$".r,
    "MapElements$".r,
    "AppendColumns$".r,
    "AppendColumnsWithObject$".r,
    "MapGroups$".r,
    "FlatMapGroupsInR$".r,
    "FlatMapGroupsInRWithArrow$".r,
    "CoGroup$".r
  )

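  // Illustrative: a node named "Scan ExistingRDD[obj#1]" matches dataSetOrRDDRegExLookup, so
  // the plan is flagged as Dataset/RDD without inspecting the node description.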
  def isDatasetOrRDDPlan(nodeName: String, nodeDesc: String): RDDCheckResult = {
    val nodeNameRdd = dataSetOrRDDRegExLookup.exists(regEx => nodeName.trim.matches(regEx.regex))
    // As an optimization, we do not search for matches inside the node description
    // unless necessary.
    val nodeDescRdd = !nodeNameRdd &&
      dataSetRDDRegExDescLookup.exists(regEx => nodeDesc.matches(regEx.regex))
    // TODO: catch the expressions that match the regular expression so we can pass it later to
    //       the reporting
    RDDCheckResult(nodeNameRdd, nodeDescRdd)
  }
}


object ExecHelper {
  private val UDFRegExLookup = Set(
    ".*UDF.*".r
  )

  // we don't want to mark the *InPandas and ArrowEvalPython execs as unsupported due to UDFs
  private val skipUDFCheckExecs = Seq("ArrowEvalPython", "AggregateInPandas",
    "FlatMapGroupsInPandas", "MapInPandas", "WindowInPandas", "PythonMapInArrow", "MapInArrow")

  // Set containing execs that should be labeled as "shouldRemove"
  private val execsToBeRemoved = Set(
    "GenerateBloomFilter",      // Exclusive on AWS. Ignore it as metrics cannot be evaluated.
    "ReusedExchange",           // reusedExchange should not be added to speedups
    "ColumnarToRow",            // for now, assume everything is columnar
    // Our customer-integration team requested this to be added to the list of execs to be removed.
    "ResultQueryStage",
    // AdaptiveSparkPlan is not a real exec. It is a wrapper for the whole plan.
    // Our customer-integration team requested this to be added to the list of execs to be removed.
    "AdaptiveSparkPlan",        // according to request from our customer facing team
    SubqueryExecParser.execName // Subquery represents a simple collect
  )

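  // Illustrative: a node whose description contains "UDF" is flagged as a UDF, unless the node
  // name matches one of the Pandas/Arrow execs in skipUDFCheckExecs above.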
  def isUDF(node: SparkPlanGraphNode): Boolean = {
    if (skipUDFCheckExecs.exists(node.name.contains(_))) {
      false
    } else {
      UDFRegExLookup.exists(regEx => node.desc.matches(regEx.regex))
    }
  }

  def shouldBeRemoved(nodeName: String): Boolean = {
    execsToBeRemoved.contains(nodeName)
  }

  ///////////////////////////////////////////
  // start definitions of execs to be ignored
  // Collect Limit replacement can be slower on the GPU. Disabled by default.
  private val CollectLimit = "CollectLimit"
  // Some DDL's and table commands which can be ignored
  private val ExecuteCreateViewCommand = "Execute CreateViewCommand"
  private val LocalTableScan = "LocalTableScan"
  private val ExecuteCreateDatabaseCommand = "Execute CreateDatabaseCommand"
  private val ExecuteDropDatabaseCommand = "Execute DropDatabaseCommand"
  private val ExecuteCreateTableAsSelectCommand = "Execute CreateTableAsSelectCommand"
  private val ExecuteCreateTableCommand = "Execute CreateTableCommand"
  private val ExecuteDropTableCommand = "Execute DropTableCommand"
  private val ExecuteCreateDataSourceTableAsSelectCommand = "Execute " +
    "CreateDataSourceTableAsSelectCommand"
  private val SetCatalogAndNamespace = "SetCatalogAndNamespace"
  private val ExecuteSetCommand = "Execute SetCommand"
  private val ResultQueryStage = "ResultQueryStage"
  private val ExecAddJarsCommand = "Execute AddJarsCommand"
  private val ExecInsertIntoHadoopFSRelationCommand = "Execute InsertIntoHadoopFsRelationCommand"
  private val ScanJDBCRelation = "Scan JDBCRelation"
  private val ScanOneRowRelation = "Scan OneRowRelation"
  private val CommandResult = "CommandResult"
  private val ExecuteAlterTableRecoverPartitionsCommand =
    "Execute AlterTableRecoverPartitionsCommand"
  private val ExecuteCreateFunctionCommand = "Execute CreateFunctionCommand"
  private val CreateHiveTableAsSelectCommand = "Execute CreateHiveTableAsSelectCommand"
  private val ExecuteDeleteCommand = "Execute DeleteCommand"
  private val ExecuteDescribeTableCommand = "Execute DescribeTableCommand"
  private val ExecuteRefreshTable = "Execute RefreshTable"
  private val ExecuteRepairTableCommand = "Execute RepairTableCommand"
  private val ExecuteShowPartitionsCommand = "Execute ShowPartitionsCommand"
  private val ExecuteClearCacheCommand = "Execute ClearCacheCommand"
  private val ExecuteOptimizeTableCommandEdge = "Execute OptimizeTableCommandEdge"
  // DeltaLakeOperations
  private val ExecUpdateCommandEdge = "Execute UpdateCommandEdge"
  private val ExecDeleteCommandEdge = "Execute DeleteCommandEdge"
  private val ExecDescribeDeltaHistoryCommand = "Execute DescribeDeltaHistoryCommand"
  private val ExecShowPartitionsDeltaCommand = "Execute ShowPartitionsDeltaCommand"

  def getAllIgnoreExecs: Set[String] = Set(CollectLimit,
    ExecuteCreateViewCommand, LocalTableScan, ExecuteCreateTableCommand,
    ExecuteDropTableCommand, ExecuteCreateDatabaseCommand, ExecuteDropDatabaseCommand,
    ExecuteCreateTableAsSelectCommand, ExecuteCreateDataSourceTableAsSelectCommand,
    SetCatalogAndNamespace, ExecuteSetCommand,
    ResultQueryStage,
    ExecAddJarsCommand,
    ExecInsertIntoHadoopFSRelationCommand,
    ScanJDBCRelation,
    ScanOneRowRelation,
    CommandResult,
    ExecUpdateCommandEdge,
    ExecDeleteCommandEdge,
    ExecDescribeDeltaHistoryCommand,
    ExecShowPartitionsDeltaCommand,
    ExecuteAlterTableRecoverPartitionsCommand,
    ExecuteCreateFunctionCommand,
    CreateHiveTableAsSelectCommand,
    ExecuteDeleteCommand,
    ExecuteDescribeTableCommand,
    ExecuteRefreshTable,
    ExecuteRepairTableCommand,
    ExecuteShowPartitionsCommand,
    ExecuteClearCacheCommand,
    ExecuteOptimizeTableCommandEdge,
    SubqueryExecParser.execName
  )

  def shouldIgnore(execName: String): Boolean = {
    getAllIgnoreExecs.contains(execName)
  }
}

case class UnsupportedExpr(exprName: String, unsupportedReason: String)

object MlOps {
  val sparkml = "spark.ml."
  val xgBoost = "spark.XGBoost"
  val pysparkLog = "py4j.GatewayConnection.run" // pyspark eventlog contains py4j
}

object MlOpsEventLogType {
  val pyspark = "pyspark"
  val scala = "scala"
}

object SupportedMLFuncsName {
  val funcName: Map[String, String] = Map(
    "org.apache.spark.ml.clustering.KMeans.fit" -> "KMeans",
    "org.apache.spark.ml.feature.PCA.fit" -> "PCA",
    "org.apache.spark.ml.regression.LinearRegression.train" -> "LinearRegression",
    "org.apache.spark.ml.classification.RandomForestClassifier.train" -> "RandomForestClassifier",
    "org.apache.spark.ml.regression.RandomForestRegressor.train" -> "RandomForestRegressor",
    "ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier.train" -> "XGBoost"
  )
}

class AppEventlogProcessException(message: String) extends Exception(message)

case class GpuEventLogException(
    message: String = "Cannot parse event logs from GPU run: skipping this file")
    extends AppEventlogProcessException(message)

case class PhotonEventLogException(
    message: String = "Encountered Databricks Photon event log: skipping this file!")
    extends AppEventlogProcessException(message)

case class StreamingEventLogException(
    message: String = "Encountered Spark Structured Streaming Job: skipping this file!")
    extends AppEventlogProcessException(message)

case class IncorrectAppStatusException(
    message: String = "Application status is incorrect. Missing AppInfo")
    extends AppEventlogProcessException(message)

case class UnsupportedMetricNameException(metricName: String)
    extends AppEventlogProcessException(
      s"Unsupported metric name found in the event log: $metricName")

// Case class used as a container to hold the tuple information,
// simplifying method arguments and caching.
case class SqlPlanInfoGraphEntry(
    sqlID: Long,
    planInfo: SparkPlanInfo,
    sparkPlanGraph: SparkPlanGraph
)

// A class used to cache the SQLPlanInfoGraphs
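// Typical usage (illustrative; processPlanGraph is a hypothetical caller-side helper):
//   val buffer = SqlPlanInfoGraphBuffer()
//   val entry = buffer.addSqlPlanInfoGraph(sqlID, planInfo)
//   buffer.sqlPlanInfoGraphs.foreach(e => processPlanGraph(e.sparkPlanGraph))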
class SqlPlanInfoGraphBuffer {
  // A set to hold the SqlPlanInfoGraphEntry. LinkedHashSet to maintain the order of insertion.
  val sqlPlanInfoGraphs = mutable.LinkedHashSet[SqlPlanInfoGraphEntry]()
  def addSqlPlanInfoGraph(sqlID: Long, planInfo: SparkPlanInfo): SqlPlanInfoGraphEntry = {
    val newEntry = SqlPlanInfoGraphBuffer.createEntry(sqlID, planInfo)
    sqlPlanInfoGraphs += newEntry
    newEntry
  }
}

object SqlPlanInfoGraphBuffer {
  def apply(): SqlPlanInfoGraphBuffer = new SqlPlanInfoGraphBuffer()
  def createEntry(sqlID: Long, planInfo: SparkPlanInfo): SqlPlanInfoGraphEntry = {
    val planGraph = ToolsPlanGraph(planInfo)
    SqlPlanInfoGraphEntry(sqlID, planInfo, planGraph)
  }
}

// Case class to represent a failed AppInfo creation
case class FailureApp(
    status: String,
    message: String
)
