All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.sql.rapids.tool.AppBase.scala Maven / Gradle / Ivy

/*
 * Copyright (c) 2021-2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.rapids.tool

import java.io.InputStream
import java.util.zip.GZIPInputStream

import scala.collection.immutable
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, LinkedHashSet, Map, SortedMap}

import com.nvidia.spark.rapids.SparkRapidsBuildInfoEvent
import com.nvidia.spark.rapids.tool.{DatabricksEventLog, DatabricksRollingEventLogFilesFileReader, EventLogInfo}
import com.nvidia.spark.rapids.tool.planparser.{HiveParseHelper, ReadParser}
import com.nvidia.spark.rapids.tool.planparser.HiveParseHelper.isHiveTableScanNode
import com.nvidia.spark.rapids.tool.profiling.{BlockManagerRemovedCase, DataSourceCase, DriverAccumCase, JobInfoClass, ResourceProfileInfoCase, SQLExecutionInfoClass, SQLPlanMetricsCase}
import com.nvidia.spark.rapids.tool.qualification.AppSubscriber
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}

import org.apache.spark.deploy.history.{EventLogFileReader, EventLogFileWriter}
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.{SparkListenerEvent, StageInfo}
import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.execution.ui.SparkPlanGraphNode
import org.apache.spark.sql.rapids.tool.store.{AccumManager, StageModel, StageModelManager, TaskModelManager}
import org.apache.spark.sql.rapids.tool.util.{EventUtils, RapidsToolsConfUtil, ToolsPlanGraph, UTF8Source}
import org.apache.spark.util.Utils

abstract class AppBase(
    val eventLogInfo: Option[EventLogInfo],
    val hadoopConf: Option[Configuration]) extends Logging with ClusterTagPropHandler {

  var appMetaData: Option[AppMetaData] = None

  // appId is string is stored as a field in the AppMetaData class
  def appId: String = {
    appMetaData match {
      case Some(meta) => meta.appId.getOrElse("")
      case _ => ""
    }
  }

  lazy val attemptId: Int = appMetaData.map(_.attemptId).getOrElse(1)

  // Store map of executorId to executor info
  val executorIdToInfo = new HashMap[String, ExecutorInfoClass]()
  // resourceProfile id to resource profile info
  val resourceProfIdToInfo = new HashMap[Int, ResourceProfileInfoCase]()
  var blockManagersRemoved: ArrayBuffer[BlockManagerRemovedCase] =
    ArrayBuffer[BlockManagerRemovedCase]()
  // The data source information
  val dataSourceInfo: ArrayBuffer[DataSourceCase] = ArrayBuffer[DataSourceCase]()

  // jobId to job info
  val jobIdToInfo = new HashMap[Int, JobInfoClass]()
  val jobIdToSqlID: HashMap[Int, Long] = HashMap.empty[Int, Long]

  // SQL containing any Dataset operation or RDD to DataSet/DataFrame operation
  val sqlIDToDataSetOrRDDCase: HashSet[Long] = HashSet[Long]()
  // Map (sqlID <-> String(problematic issues))
  // Use LinkedHashSet of Strings to preserve the order of insertion.
  val sqlIDtoProblematic: HashMap[Long, LinkedHashSet[String]] =
    HashMap[Long, LinkedHashSet[String]]()
  // sqlId to sql info
  val sqlIdToInfo = new HashMap[Long, SQLExecutionInfoClass]()
  val sqlIdToStages = new HashMap[Long, ArrayBuffer[Int]]()
  // sqlPlans stores HashMap (sqlID <-> SparkPlanInfo)
  // SortedMap is used to keep the order of the sqlPlans since AQEs can overrides the existing ones
  var sqlPlans: Map[Long, SparkPlanInfo] = SortedMap[Long, SparkPlanInfo]()
  var sqlPlanMetricsAdaptive: ArrayBuffer[SQLPlanMetricsCase] = ArrayBuffer[SQLPlanMetricsCase]()

  // accum id to task stage accum info
  lazy val accumManager: AccumManager = new AccumManager()

  lazy val stageManager: StageModelManager = new StageModelManager()
  // Container that manages TaskIno including SparkMetrics.
  // A task is added during a TaskEnd eventLog
  lazy val taskManager: TaskModelManager = new TaskModelManager()

  var driverAccumMap: HashMap[Long, ArrayBuffer[DriverAccumCase]] =
    HashMap[Long, ArrayBuffer[DriverAccumCase]]()

  var sparkRapidsBuildInfo: SparkRapidsBuildInfoEvent = SparkRapidsBuildInfoEvent(immutable.Map(),
    immutable.Map(), immutable.Map(), immutable.Map())

  // Returns the String value of the eventlog or empty if it is not defined. Note that the eventlog
  // won't be defined for running applications
  def getEventLogPath: String = {
    eventLogInfo.map(_.eventLog).getOrElse(new Path("")).toString
  }

  // Update the endTime of the application and calculate the duration.
  // This is called while handling ApplicationEnd event
  def updateEndTime(newEndTime: Long): Unit = {
    appMetaData.foreach(_.setEndTime(newEndTime))
  }

  // Returns a boolean flag to indicate whether the endTime was estimated.
  def isAppDurationEstimated: Boolean = {
    appMetaData.map(_.isDurationEstimated).getOrElse(false)
  }

  // Returns the AppName
  def getAppName: String = {
    appMetaData.map(_.appName).getOrElse("")
  }

  // Returns optional endTime in ms.
  def getAppEndTime: Option[Long] = {
    appMetaData.flatMap(_.endTime)
  }

  // Returns optional wallClock duration of the Application
  def getAppDuration: Option[Long] = {
    appMetaData.flatMap(_.duration)
  }

  // Returns a boolean true/false. This is used to check whether processing an eventlog was
  // successful.
  def isAppMetaDefined: Boolean = appMetaData.isDefined

  /**
   * Sets an estimated endTime to the application based on the function passed as an argument.
   * First it checks if the endTime is already defined or not.
   * This method is a temporary refactor because both QualAppInfo and ProfAppInfo have different
   * ways of estimating the endTime.
   *
   * @param callBack function to estimate the endTime
   */
  def estimateAppEndTime(callBack: () => Option[Long]): Unit = {
    if (getAppEndTime.isEmpty) {
      val estimatedResult = callBack()
      estimatedResult.foreach(eT => appMetaData.foreach(_.setEndTime(eT, estimated = true)))
    }
  }

  def guestimateAppEndTimeCB(): () => Option[Long] = {
    () => None
  }

  // time in ms
  def calculateAppDuration(): Option[Long] = {
    if (appMetaData.isDefined) {
      val appMeta = appMetaData.get
      val startTime = appMeta.startTime
      if (startTime > 0) {
        estimateAppEndTime(guestimateAppEndTimeCB())
      }
      getAppDuration
    } else {
      None
    }
  }

  /**
   * Calculates total core seconds which is the sum over executor core seconds. Executor
   * core seconds is computed as executor duration (s) multiplied by num of cores.
   */
  def calculateTotalCoreSec(): Long = {
    var totalCoreSec: Double = 0
    executorIdToInfo.foreach { case(_, eInfo) =>
      val eStartTime = eInfo.addTime
      var eEndTime = eInfo.removeTime
      if (eEndTime == 0L) {
        getAppEndTime match {
          case Some(appEndTime) =>
            eEndTime = appEndTime
          case None =>
            logInfo("Unable to find either executor or app end time: " +
              "setting executor duration to 0")
            eEndTime = eStartTime
        }
      }
      totalCoreSec += (eEndTime - eStartTime).toDouble / 1000 * eInfo.totalCores
    }
    // round up for edge case when total core seconds is in range [0, 1)
    math.ceil(totalCoreSec).toLong
  }

  def getOrCreateExecutor(executorId: String, addTime: Long): ExecutorInfoClass = {
    executorIdToInfo.getOrElseUpdate(executorId, {
      new ExecutorInfoClass(executorId, addTime)
    })
  }

  def getOrCreateStage(info: StageInfo): StageModel = {
    val stage = stageManager.addStageInfo(info)
    stage
  }

  def getAllStagesForJobsInSqlQuery(sqlID: Long): Seq[Int] = {
    val jobsIdsInSQLQuery = jobIdToSqlID.filter { case (_, sqlIdForJob) =>
      sqlIdForJob == sqlID
    }.keys.toSeq
    jobsIdsInSQLQuery.flatMap { jId =>
      jobIdToInfo.get(jId).map(_.stageIds)
    }.flatten
  }

  def cleanupAccumId(accId: Long): Unit = {
    accumManager.removeAccumInfo(accId)
    driverAccumMap.remove(accId)
  }

  def cleanupStages(stageIds: Set[Int]): Unit = {
    // stageIdToInfo can have multiple stage attempts, remove all of them
    stageManager.removeStages(stageIds)
  }

  def cleanupSQL(sqlID: Long): Unit = {
    sqlIDToDataSetOrRDDCase.remove(sqlID)
    sqlIDtoProblematic.remove(sqlID)
    sqlIdToInfo.remove(sqlID)
    sqlPlans.remove(sqlID)
    val dsToRemove = dataSourceInfo.filter(_.sqlID == sqlID)
    dsToRemove.foreach(dataSourceInfo -= _)

    val jobsInSql = jobIdToSqlID.filter { case (_, sqlIdForJob) =>
      sqlIdForJob == sqlID
    }.keys
    jobsInSql.foreach { jobId =>
      // must call cleanupStage first
      // clean when no other jobs need that stage
      // not sure about races here but lets check the jobs and assume we can clean
      // when none of them reference this stage
      val stagesNotInOtherJobs = jobIdToInfo.get(jobId).map { jInfo =>
        val stagesInJobToRemove = jInfo.stageIds.toSet
        // check to make sure no other jobs reference the same stage
        val allOtherJobs = jobIdToInfo - jobId
        val allOtherStageIds = allOtherJobs.values.flatMap(_.stageIds).toSet
        stagesInJobToRemove.filter(!allOtherStageIds.contains(_))
      }
      stagesNotInOtherJobs.foreach(cleanupStages(_))
      jobIdToSqlID.remove(_)
      jobIdToInfo.remove(_)
    }
  }

  def processEvent(event: SparkListenerEvent): Boolean

  private def openEventLogInternal(log: Path, fs: FileSystem): InputStream = {
    EventLogFileWriter.codecName(log) match {
      case c if c.isDefined && c.get.equals("gz") =>
        val in = fs.open(log)
        try {
          new GZIPInputStream(in)
        } catch {
          case e: Throwable =>
            in.close()
            throw e
        }
      case _ => EventLogFileReader.openEventLog(log, fs)
    }
  }

  /**
   * Internal function to process all the events
   */
  private def processEventsInternal(): Unit = {
    eventLogInfo match {
      case Some(eventLog) =>
        val eventLogPath = eventLog.eventLog
        logInfo("Start Parsing Event Log: " + eventLogPath.toString)

        // at this point all paths should be valid event logs or event log dirs
        val hconf = hadoopConf.getOrElse(RapidsToolsConfUtil.newHadoopConf)
        val fs = eventLogPath.getFileSystem(hconf)
        var totalNumEvents = 0L
        val readerOpt = eventLog match {
          case _: DatabricksEventLog =>
            Some(new DatabricksRollingEventLogFilesFileReader(fs, eventLogPath))
          case _ => EventLogFileReader(fs, eventLogPath)
        }

        if (readerOpt.isDefined) {
          val reader = readerOpt.get
          val runtimeGetFromJsonMethod = EventUtils.getEventFromJsonMethod
          reader.listEventLogFiles.foreach { file =>
            Utils.tryWithResource(openEventLogInternal(file.getPath, fs)) { in =>
              UTF8Source.fromInputStream(in).getLines().find { line =>
                // Using find as foreach with conditional to exit early if we are done.
                // Do NOT use a while loop as it is much much slower.
                totalNumEvents += 1
                runtimeGetFromJsonMethod.apply(line) match {
                  case Some(e) => processEvent(e)
                  case None => false
                }
              }
            }
          }
        } else {
          logError(s"Error getting reader for ${eventLogPath.getName}")
        }
        logInfo(s"Total number of events parsed: $totalNumEvents for ${eventLogPath.toString}")
      case None => logInfo("Streaming events to application")
    }
  }

  private val UDFRegex = ".*UDF.*"

  private val potentialIssuesRegexMap = Map(
    UDFRegex -> "UDF",
    ".*current_timestamp\\(.*\\).*" -> "TIMEZONE current_timestamp()",
    ".*to_timestamp\\(.*\\).*" -> "TIMEZONE to_timestamp()",
    ".*hour\\(.*\\).*" -> "TIMEZONE hour()",
    ".*minute\\(.*\\).*" -> "TIMEZONE minute()",
    ".*second\\(.*\\).*" -> "TIMEZONE second()"
  )

  def findPotentialIssues(desc: String): Set[String] =  {
    val potentialIssuesRegexs = potentialIssuesRegexMap
    val issues = potentialIssuesRegexs.filterKeys(desc.matches(_))
    issues.values.toSet
  }

  /**
   * Builds cluster information based on executor nodes.
   * If executor nodes exist, calculates the number of hosts and total cores,
   * and extracts executor and driver instance types (databricks only)
   */
  protected def buildClusterInfo: Unit = {}

  // The ReadSchema metadata is only in the eventlog for DataSource V1 readers
  def checkMetadataForReadSchema(
      sqlPlanInfoGraph: SqlPlanInfoGraphEntry): ArrayBuffer[DataSourceCase] = {
    // check if planInfo has ReadSchema
    val allMetaWithSchema = AppBase.getPlanMetaWithSchema(sqlPlanInfoGraph.planInfo)
    val allNodes = sqlPlanInfoGraph.sparkPlanGraph.allNodes
    val results = ArrayBuffer[DataSourceCase]()

    allMetaWithSchema.foreach { plan =>
      val meta = plan.metadata
      val readSchema = ReadParser.formatSchemaStr(meta.getOrElse("ReadSchema", ""))
      val scanNode = allNodes.filter(node => {
        // Get ReadSchema of each Node and sanitize it for comparison
        val trimmedNode = AppBase.trimSchema(ReadParser.parseReadNode(node).schema)
        readSchema.contains(trimmedNode)
      }).filter(ReadParser.isScanNode(_))

      // If the ReadSchema is empty or if it is PhotonScan, then we don't need to
      // add it to the dataSourceInfo
      // Processing Photon eventlogs issue: https://github.com/NVIDIA/spark-rapids-tools/issues/251
      if (scanNode.nonEmpty) {
        results += DataSourceCase(
          sqlPlanInfoGraph.sqlID,
          scanNode.head.id,
          meta.getOrElse("Format", "unknown"),
          meta.getOrElse("Location", "unknown"),
          meta.getOrElse("PushedFilters", "unknown"),
          readSchema
        )
      }
    }
    // "scan hive" has no "ReadSchema" defined. So, we need to look explicitly for nodes
    // that are scan hive and add them one by one to the dataSource
    if (hiveEnabled) { // only scan for hive when the CatalogImplementation is using hive
      val allPlanWithHiveScan = AppBase.getPlanInfoWithHiveScan(sqlPlanInfoGraph.planInfo)
      allPlanWithHiveScan.foreach { hiveReadPlan =>
        val sqlGraph = ToolsPlanGraph(hiveReadPlan)
        val hiveScanNode = sqlGraph.allNodes.head
        val scanHiveMeta = HiveParseHelper.parseReadNode(hiveScanNode)
        results += DataSourceCase(
          sqlPlanInfoGraph.sqlID,
          hiveScanNode.id,
          scanHiveMeta.format,
          scanHiveMeta.location,
          scanHiveMeta.filters,
          scanHiveMeta.schema
        )
      }
    }
    dataSourceInfo ++= results
    results
  }

  // This will find scans for DataSource V2, if the schema is very large it
  // will likely be incomplete and have ... at the end.
  def checkGraphNodeForReads(
      sqlID: Long, node: SparkPlanGraphNode): Option[DataSourceCase] = {
    if (ReadParser.isDataSourceV2Node(node)) {
      val res = ReadParser.parseReadNode(node)
      val dsCase = DataSourceCase(
        sqlID,
        node.id,
        res.format,
        res.location,
        res.filters,
        res.schema)
      dataSourceInfo += dsCase
      Some(dsCase)
    } else {
      None
    }
  }

  protected def reportComplexTypes: (Seq[String], Seq[String]) = {
    if (dataSourceInfo.nonEmpty) {
      val schema = dataSourceInfo.map { ds => ds.schema }
      AppBase.parseReadSchemaForNestedTypes(schema)
    } else {
      (Seq(), Seq())
    }
  }

  protected def probNotDataset: HashMap[Long, LinkedHashSet[String]] = {
    sqlIDtoProblematic.filterNot { case (sqlID, _) => sqlIDToDataSetOrRDDCase.contains(sqlID) }
  }

  protected def getPotentialProblemsForDf: Seq[String] = {
    probNotDataset.values.flatten.toSet.toSeq
  }

  /**
   * Registers the attempt ID for the application and updates the tracker map if the attemptId is
   * greater than the existing attemptId.
   */
  def registerAttemptId(): Unit = {
    if (isAppMetaDefined) {
      val currentAttemptId = sparkProperties.getOrElse("spark.app.attempt.id", "1").toInt
      appMetaData.foreach(_.setAttemptId(currentAttemptId))
      AppSubscriber.subscribeAppAttempt(appId, currentAttemptId)
    }
  }

  protected def postCompletion(): Unit = {
    registerAttemptId()
    calculateAppDuration()
  }

  /**
   * Wrapper function to process all the events followed by any
   * post completion tasks.
   */
  def processEvents(): Unit = {
    processEventsInternal()
    postCompletion()
  }
}

object AppBase {

  def parseReadSchemaForNestedTypes(
      schema: ArrayBuffer[String]): (Seq[String], Seq[String]) = {
    val tempStringBuilder = new StringBuilder()
    val individualSchema: ArrayBuffer[String] = new ArrayBuffer()
    var angleBracketsCount = 0
    var parenthesesCount = 0
    val distinctSchema = schema.distinct.filter(_.nonEmpty).mkString(",")

    // Get the nested types i.e everything between < >
    for (char <- distinctSchema) {
      char match {
        case '<' => angleBracketsCount += 1
        case '>' => angleBracketsCount -= 1
        // If the schema has decimals, Example decimal(6,2) then we have to make sure it has both
        // opening and closing parentheses(unless the string is incomplete due to V2 reader).
        case '(' => parenthesesCount += 1
        case ')' => parenthesesCount -= 1
        case _ =>
      }
      if (angleBracketsCount == 0 && parenthesesCount == 0 && char.equals(',')) {
        individualSchema += tempStringBuilder.toString
        tempStringBuilder.setLength(0)
      } else {
        tempStringBuilder.append(char)
      }
    }
    if (tempStringBuilder.nonEmpty) {
      individualSchema += tempStringBuilder.toString
    }

    // If DataSource V2 is used, then Schema may be incomplete with ... appended at the end.
    // We determine complex types and nested complex types until ...
    val incompleteSchema = individualSchema.filter(x => x.contains("..."))
    val completeSchema = individualSchema.filterNot(x => x.contains("..."))

    // Check if it has types
    val incompleteTypes = incompleteSchema.map { x =>
      if (x.contains("...") && x.contains(":")) {
        val schemaTypes = x.split(":", 2)
        if (schemaTypes.size == 2) {
          val partialSchema = schemaTypes(1).split("\\.\\.\\.")
          if (partialSchema.size == 1) {
            partialSchema(0)
          } else {
            ""
          }
        } else {
          ""
        }
      } else {
        ""
      }
    }
    // Omit columnName and get only schemas
    val completeTypes = completeSchema.map { x =>
      val schemaTypes = x.split(":", 2)
      if (schemaTypes.size == 2) {
        schemaTypes(1)
      } else {
        ""
      }
    }
    val schemaTypes = completeTypes ++ incompleteTypes

    // Filter only complex types.
    // Example: array, array>
    val complexTypes = schemaTypes.filter(x =>
      x.startsWith("array<") || x.startsWith("map<") || x.startsWith("struct<"))

    // Determine nested complex types from complex types
    // Example: array> is nested complex type.
    val nestedComplexTypes = complexTypes.filter(complexType => {
      val startIndex = complexType.indexOf('<')
      val closedBracket = complexType.lastIndexOf('>')
      // If String is incomplete due to dsv2, then '>' may not be present. In that case traverse
      // until length of the incomplete string
      val lastIndex = if (closedBracket == -1) {
        complexType.length - 1
      } else {
        closedBracket
      }
      val string = complexType.substring(startIndex, lastIndex + 1)
      string.contains("array<") || string.contains("struct<") || string.contains("map<")
    })

    (complexTypes.filter(_.nonEmpty), nestedComplexTypes.filter(_.nonEmpty))
  }

  private def trimSchema(str: String): String = {
    val index = str.lastIndexOf(",")
    if (index != -1 && str.contains("...")) {
      str.substring(0, index)
    } else {
      str
    }
  }

  private def getPlanMetaWithSchema(planInfo: SparkPlanInfo): Seq[SparkPlanInfo] = {
    // TODO: This method does not belong to AppBase. It should move to another member.
    val childRes = planInfo.children.flatMap(getPlanMetaWithSchema(_))
    if (planInfo.metadata != null && planInfo.metadata.contains("ReadSchema")) {
      childRes :+ planInfo
    } else {
      childRes
    }
  }

  // Finds all the nodes that scan a hive table
  private def getPlanInfoWithHiveScan(planInfo: SparkPlanInfo): Seq[SparkPlanInfo] = {
    // TODO: This method does not belong to AppBAse. It should move to another member.
    val childRes = planInfo.children.flatMap(getPlanInfoWithHiveScan(_))
    if (isHiveTableScanNode(planInfo.nodeName)) {
      childRes :+ planInfo
    } else {
      childRes
    }
  }

  def handleException(e: Exception, path: EventLogInfo): FailureApp = {
    val (status, message): (String, String) = e match {
      case incorrectStatusEx: IncorrectAppStatusException =>
        ("unknown", incorrectStatusEx.getMessage)
      case skippedEx: AppEventlogProcessException =>
        ("skipped", skippedEx.getMessage)
      case _: com.fasterxml.jackson.core.JsonParseException =>
        ("unknown", s"Error parsing JSON: ${path.eventLog.toString}")
      case _: IllegalArgumentException =>
        ("unknown", s"Error parsing file: ${path.eventLog.toString}")
      case ue: Exception =>
        // catch all exceptions and skip that file
        ("unknown", s"Got unexpected exception processing file:" +
          s"${path.eventLog.toString}. ${ue.getMessage} ")
    }

    FailureApp(status, s"${e.getClass.getSimpleName}: $message")
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy