All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.sql.rapids.execution.TrampolineUtil.scala Maven / Gradle / Ivy

There is a newer version: 24.10.1
Show newest version
/*
 * Copyright (c) 2019-2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.rapids.execution

import java.util.concurrent.{ScheduledExecutorService, ThreadPoolExecutor}

import org.apache.hadoop.conf.Configuration
import org.json4s.JsonAST

import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkMasterRegex, SparkUpgradeException, TaskContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.executor.InputMetrics
import org.apache.spark.internal.config
import org.apache.spark.internal.config.EXECUTOR_ID
import org.apache.spark.io.CompressionCodec
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.scheduler.SparkListenerEvent
import org.apache.spark.security.CryptoStreamUtils
import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.shims.DataTypeUtilsShim
import org.apache.spark.sql.rapids.shims.SparkUpgradeExceptionShims
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.{ShutdownHookManager, ThreadUtils, Utils}

object TrampolineUtil {
  def doExecuteBroadcast[T](child: SparkPlan): Broadcast[T] = child.doExecuteBroadcast()

  def isSupportedRelation(mode: BroadcastMode): Boolean =
    ShimTrampolineUtil.isSupportedRelation(mode)

  def unionLikeMerge(left: DataType, right: DataType): DataType =
    ShimTrampolineUtil.unionLikeMerge(left, right)

  def fromAttributes(attrs: Seq[Attribute]): StructType = DataTypeUtilsShim.fromAttributes(attrs)

  def toAttributes(structType: StructType): Seq[Attribute] =
    DataTypeUtilsShim.toAttributes(structType)

  def jsonValue(dataType: DataType): JsonAST.JValue = dataType.jsonValue

  /** Get a human-readable string, e.g.: "4.0 MiB", for a value in bytes. */
  def bytesToString(size: Long): String = Utils.bytesToString(size)

  /** Returns true if called from code running on the Spark driver. */
  def isDriver(env: SparkEnv): Boolean = {
    if (env != null) {
      env.executorId == SparkContext.DRIVER_IDENTIFIER
    } else {
      false
    }
  }

  def isDriver(sparkConf: SparkConf): Boolean = {
    sparkConf.get(EXECUTOR_ID).map(_ == SparkContext.DRIVER_IDENTIFIER)
      .getOrElse(isDriver(SparkEnv.get))
  }

  /**
   * Return true if the provided predicate function returns true for any
   * type node within the datatype tree.
   */
  def dataTypeExistsRecursively(dt: DataType, f: DataType => Boolean): Boolean = {
    dt.existsRecursively(f)
  }

  /**
   * Check if `a` and `b` are the same data type when ignoring nullability
   * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
   */
  def sameType(a: DataType, b: DataType): Boolean = a.sameType(b)

  def incInputRecordsRows(inputMetrics: InputMetrics, rows: Long): Unit =
    inputMetrics.incRecordsRead(rows)

  def makeSparkUpgradeException(
                                version: String,
                                message: String,
                                cause: Throwable): SparkUpgradeException = {
    SparkUpgradeExceptionShims.newSparkUpgradeException(version, message, cause)
  }

  /** Shuts down and cleans up any existing Spark session */
  def cleanupAnyExistingSession(): Unit = SparkSession.cleanupAnyExistingSession()

  def asNullable(dt: DataType): DataType = dt.asNullable

  /** Return a new InputMetrics instance */
  def newInputMetrics(): InputMetrics = new InputMetrics()

  /**
   * Increment the task's memory bytes spilled metric. If the current thread does not
   * correspond to a Spark task then this call does nothing.
   * @param amountSpilled amount of memory spilled in bytes
   */
  def incTaskMetricsMemoryBytesSpilled(amountSpilled: Long): Unit = {
    Option(TaskContext.get).foreach(_.taskMetrics().incMemoryBytesSpilled(amountSpilled))
  }

  /**
   * Increment the task's disk bytes spilled metric. If the current thread does not
   * correspond to a Spark task then this call does nothing.
   * @param amountSpilled amount of memory spilled in bytes
   */
  def incTaskMetricsDiskBytesSpilled(amountSpilled: Long): Unit = {
    Option(TaskContext.get).foreach(_.taskMetrics().incDiskBytesSpilled(amountSpilled))
  }

  /**
   * Returns a function that can be called to find Hadoop FileSystem bytes read. If
   * getFSBytesReadOnThreadCallback is called from thread r at time t, the returned callback will
   * return the bytes read on r since t.
   */
  def getFSBytesReadOnThreadCallback(): () => Long = {
    SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
  }

  /** Set the bytes read task input metric */
  def incBytesRead(inputMetrics: InputMetrics, bytesRead: Long): Unit = {
    inputMetrics.incBytesRead(bytesRead)
  }

  /** Get the simple name of a class with fixup for any Scala internal errors */
  def getSimpleName(cls: Class[_]): String = {
    Utils.getSimpleName(cls)
  }

  /** Create a `BlockManagerId` instance */
  def newBlockManagerId(
      execId: String,
      host: String,
      port: Int,
      topologyInfo: Option[String] = None): BlockManagerId =
    BlockManagerId(execId, host, port, topologyInfo)

  def getTaskMemoryManager(): TaskMemoryManager = {
    TaskContext.get.taskMemoryManager()
  }

  /** Set the task context for the current thread */
  def setTaskContext(tc: TaskContext): Unit = TaskContext.setTaskContext(tc)

  /** Remove the task context for the current thread */
  def unsetTaskContext(): Unit = TaskContext.unset()

  /** Add shutdown hook with priority */
  def addShutdownHook(priority: Int, runnable: Runnable): AnyRef = {
    ShutdownHookManager.addShutdownHook(priority)(() => runnable.run())
  }

  def classForName[C](
      className: String,
      initialize: Boolean = true,
      noSparkClassLoader: Boolean = false): Class[C] = {
    Utils.classForName(className, initialize, noSparkClassLoader)
  }

  def getSparkConf(spark: SparkSession): SQLConf = {
    spark.sqlContext.conf
  }

  def setExecutorEnv(sc: SparkContext, key: String, value: String): Unit = {
    sc.executorEnvs(key) = value
  }

  def createCodec(conf: SparkConf, codecName: String): CompressionCodec = {
    CompressionCodec.createCodec(conf, codecName)
  }

  def createCodec(conf: SparkConf): CompressionCodec = {
    CompressionCodec.createCodec(conf)
  }

  def getCodecShortName(codecName: String): String = CompressionCodec.getShortName(codecName)

  def getSerializerManager(): SerializerManager = {
    if (SparkEnv.get != null) SparkEnv.get.serializerManager else null
  }

  // For test only
  def createSerializerManager(conf: SparkConf): SerializerManager = {
    new SerializerManager(new JavaSerializer(conf), conf, Some(CryptoStreamUtils.createKey(conf)))
  }

  def isCompressSpill(conf: SparkConf): Boolean = {
    conf.get(config.SHUFFLE_SPILL_COMPRESS)
  }

  // If the master is a local mode (local or local-cluster), return the number
  // of cores per executor it is going to use, otherwise return 1.
  def getCoresInLocalMode(master: String, conf: SparkConf): Int = {
    master match {
      case SparkMasterRegex.LOCAL_CLUSTER_REGEX(_, coresPerWorker, _) =>
        coresPerWorker.toInt
      case master if master.startsWith("local") =>
        SparkContext.numDriverCores(master, conf)
      case _ =>
        1
    }
  }

  def newDaemonCachedThreadPool(
      prefix: String,
      maxThreadNumber: Int,
      keepAliveSeconds: Int = 60): ThreadPoolExecutor = {
    // We want to utilize the ThreadUtils class' ThreadPoolExecutor creation
    // which gives us important Hadoop config variables that are needed for the
    // Unity Catalog authentication
    ThreadUtils.newDaemonCachedThreadPool(prefix, maxThreadNumber, keepAliveSeconds)
  }

  def newDaemonSingleThreadScheduledExecutor(threadName: String): ScheduledExecutorService = {
    ThreadUtils.newDaemonSingleThreadScheduledExecutor(threadName)
  }

  def postEvent(sc: SparkContext, sparkEvent: SparkListenerEvent): Unit = {
    sc.listenerBus.post(sparkEvent)
  }

  def getSparkHadoopUtilConf: Configuration = SparkHadoopUtil.get.conf

}

/**
 * This class is to only be used to throw errors specific to the
 * RAPIDS Accelerator or errors mirroring Spark where a raw
 * AnalysisException is thrown directly rather than via an error
 * utility class (this should be rare).
 */
class RapidsAnalysisException(msg: String) extends AnalysisException(msg)




© 2015 - 2024 Weber Informatics LLC | Privacy Policy