All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.sql.rapids.execution.GpuBroadcastExchangeExec.scala Maven / Gradle / Ivy

There is a newer version: 24.10.1
Show newest version
/*
 * Copyright (c) 2019-2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.rapids.execution

import java.io._
import java.util.UUID
import java.util.concurrent._

import scala.collection.mutable
import scala.concurrent.ExecutionContext
import scala.ref.WeakReference
import scala.util.control.NonFatal

import ai.rapids.cudf.{HostMemoryBuffer, JCudfSerialization, NvtxColor, NvtxRange}
import ai.rapids.cudf.JCudfSerialization.HostConcatResult
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.GpuMetric._
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.lore.{GpuLoreDumpRDD, SimpleRDD}
import com.nvidia.spark.rapids.lore.GpuLore.LORE_DUMP_RDD_TAG
import com.nvidia.spark.rapids.shims.{ShimBroadcastExchangeLike, ShimUnaryExecNode, SparkShimImpl}

import org.apache.spark.SparkException
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.launcher.SparkLauncher
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning}
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange}
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec.MAX_BROADCAST_TABLE_BYTES
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}

/**
 * Class that is used to broadcast results (a contiguous host batch) to executors.
 *
 * This is instantiated in the driver, serialized to an output stream provided by Spark
 * to broadcast, and deserialized on the executor. Both the driver's and executor's copies
 * are cleaned via GC. Because Spark closes `AutoCloseable` broadcast results after spilling
 * to disk, this class does not subclass `AutoCloseable`. Instead we implement a `closeInternal`
 * method only to be triggered via GC.
 *
 * @param data HostConcatResult populated for a broadcast that has column, otherwise it is null.
 *             It is transient because we want the executor to deserialize its `data` from Spark's
 *             torrent-backed input stream.
 * @param output used to find the schema for this broadcast batch
 * @param numRows number of rows for this broadcast batch
 * @param dataLen size in bytes for this broadcast batch
 */
// scalastyle:off no.finalize
@SerialVersionUID(100L)
class SerializeConcatHostBuffersDeserializeBatch(
    @transient var data: HostConcatResult,
    output: Seq[Attribute],
    var numRows: Int,
    var dataLen: Long)
  extends Serializable with Logging {
  @transient private var dataTypes = output.map(_.dataType).toArray

  // used for memoization of deserialization to GPU on Executor
  @transient private var batchInternal: SpillableColumnarBatch = null

  private def maybeGpuBatch: Option[SpillableColumnarBatch] = Option(batchInternal)

  def batch: SpillableColumnarBatch = this.synchronized {
    maybeGpuBatch.getOrElse {
      withResource(new NvtxRange("broadcast manifest batch", NvtxColor.PURPLE)) { _ =>
        val spillable =
          if (data == null || data.getTableHeader.getNumColumns == 0) {
            // If `data` is null or there are no columns, this is a rows-only batch
            SpillableColumnarBatch(
              new ColumnarBatch(Array.empty, numRows),
              SpillPriorities.ACTIVE_BATCHING_PRIORITY)
          } else if (data.getTableHeader.getNumRows == 0) {
            // If we have columns but no rows, we can use the emptyBatchFromTypes optimization
            SpillableColumnarBatch(
              GpuColumnVector.emptyBatchFromTypes(dataTypes),
              SpillPriorities.ACTIVE_BATCHING_PRIORITY)
          } else {
            // Regular GPU batch with rows/cols
            SpillableColumnarBatch(
              data.toContiguousTable,
              dataTypes,
              SpillPriorities.ACTIVE_BATCHING_PRIORITY)
          }
        // At this point we no longer need the host data and should not need to touch it again.
        // Note that we don't close this using `withResources` around the creation of the
        // `SpillableColumnarBatch`. That is because if a retry exception is thrown we want to
        // still be able to recreate this broadcast batch, so we can't close the host data
        // until we are at this line.
        data.safeClose()
        data = null
        batchInternal = spillable
        spillable
      }
    }
  }

  /**
   * Create host columnar batches from either serialized buffers or device columnar batch. This
   * method can be safely called in both driver node and executor nodes. For now, it is used on
   * the driver side for reusing GPU broadcast results in the CPU.
   *
   * NOTE: The caller is responsible to release these host columnar batches.
   */
  def hostBatch: ColumnarBatch = this.synchronized {
    maybeGpuBatch.map { spillable =>
      withResource(spillable.getColumnarBatch()) { batch =>
        val hostColumns: Array[ColumnVector] = GpuColumnVector
          .extractColumns(batch)
          .safeMap(_.copyToHost())
        new ColumnarBatch(hostColumns, numRows)
      }
    }.getOrElse {
      withResource(new NvtxRange("broadcast manifest batch", NvtxColor.PURPLE)) { _ =>
        if (data == null) {
          new ColumnarBatch(Array.empty, numRows)
        } else {
          val header = data.getTableHeader
          val buffer = data.getHostBuffer
          val hostColumns = SerializedHostTableUtils.buildHostColumns(
            header, buffer, dataTypes)
          val rowCount = header.getNumRows
          new ColumnarBatch(hostColumns.toArray, rowCount)
        }
      }
    }
  }

  private def writeObject(out: ObjectOutputStream): Unit = {
    doWriteObject(out)
  }

  private def readObject(in: ObjectInputStream): Unit = {
    doReadObject(in)
  }

  /**
   * doWriteObject is invoked from both the driver, when it is trying to write
   * a collected broadcast result on an stream to torrent broadcast to executors, and also
   * when the executor MemoryStore evicts a "broadcast_[id]" block to make room in host memory.
   *
   * The driver will have `data` populated on construction and the executor will deserialize
   * the object and, as part of the deserialization, invoke `doReadObject`.
   * This will populate `data` before any task has had a chance to call `.batch` on this class.
   *
   * If `batchInternal` is defined we are in the executor, and there is no work to be done.
   * This broadcast has been materialized on the GPU/RapidsBufferCatalog, and it is completely
   * managed by the plugin.
   *
   * Public for unit tests.
   *
   * @param out the stream to write to
   */
  def doWriteObject(out: ObjectOutputStream): Unit = this.synchronized {
    maybeGpuBatch.map {
      case justRows: JustRowsColumnarBatch =>
        JCudfSerialization.writeRowsToStream(out, justRows.numRows())
      case scb: SpillableColumnarBatch =>
        val table = withResource(scb.getColumnarBatch()) { cb =>
          GpuColumnVector.from(cb)
        }
        withResource(table) { _ =>
          JCudfSerialization.writeToStream(table, out, 0, table.getRowCount)
        }
        out.writeObject(dataTypes)
    }.getOrElse {
      if (data == null || data.getTableHeader.getNumColumns == 0) {
        JCudfSerialization.writeRowsToStream(out, numRows)
      } else if (numRows == 0) {
        // We didn't get any data back, but we need to write out an empty table that matches
        withResource(GpuColumnVector.emptyHostColumns(dataTypes)) { hostVectors =>
          JCudfSerialization.writeToStream(hostVectors, out, 0, 0)
        }
        out.writeObject(dataTypes)
      } else {
        val headers = Array(data.getTableHeader)
        val buffers = Array(data.getHostBuffer)
        JCudfSerialization.writeConcatedStream(headers, buffers, out)
        out.writeObject(dataTypes)
      }
    }
  }

  /**
   * Deserializes a broadcast result in the host into `data`, `numRows` and `dataLen`.
   *
   * Public for unit tests.
   */
  def doReadObject(in: ObjectInputStream): Unit = this.synchronized {
    // no-op if we already have `batchInternal` or `data` set
    if (batchInternal == null && data == null) {
      withResource(new NvtxRange("DeserializeBatch", NvtxColor.PURPLE)) { _ =>
        val (header, buffer) = SerializedHostTableUtils.readTableHeaderAndBuffer(in)
        withResource(buffer) { _ =>
          dataTypes = if (header.getNumColumns > 0) {
            in.readObject().asInstanceOf[Array[DataType]]
          } else {
            Array.empty
          }
          // for a rowsOnly broadcast, null out the `data` member.
          val rowsOnly = dataTypes.isEmpty
          numRows = header.getNumRows
          dataLen = header.getDataLen
          data = if (!rowsOnly) {
            JCudfSerialization.concatToHostBuffer(Array(header), Array(buffer))
          } else {
            null
          }
        }
      }
    }
  }

  def dataSize: Long = dataLen

  /**
   * This method is meant to only be called from `finalize` and it is not a regular
   * AutoCloseable.close because we do not want Spark to close `batchInternal` when it spills
   * the broadcast block's host torrent data.
   *
   * Reference: https://github.com/NVIDIA/spark-rapids/issues/8602
   *
   * Public for tests.
   */
  def closeInternal(): Unit = this.synchronized {
    Seq(data, batchInternal).safeClose()
    data = null
    batchInternal = null
  }

  @scala.annotation.nowarn("msg=method finalize in class Object is deprecated")
  override def finalize(): Unit = {
    super.finalize()
    closeInternal()
  }
}
// scalastyle:on no.finalize

// scalastyle:off no.finalize
/**
 * Object used for executors to serialize a result for their partition that will be collected
 * on the driver to be broadcasted out as part of the exchange.
 * @param batch - GPU batch to be serialized and sent to the driver.
 */
@SerialVersionUID(100L)
class SerializeBatchDeserializeHostBuffer(batch: ColumnarBatch)
  extends Serializable with AutoCloseable {
  @transient private var columns = GpuColumnVector.extractBases(batch).map(_.copyToHost())
  @transient var header: JCudfSerialization.SerializedTableHeader = null
  @transient var buffer: HostMemoryBuffer = null
  @transient private var numRows = batch.numRows()

  private def writeObject(out: ObjectOutputStream): Unit = {
    withResource(new NvtxRange("SerializeBatch", NvtxColor.PURPLE)) { _ =>
      if (buffer != null) {
        throw new IllegalStateException("Cannot re-serialize a batch this way...")
      } else {
        JCudfSerialization.writeToStream(columns, out, 0, numRows)
        // In this case an RDD, we want to close the batch once it is serialized out or we will
        // leak GPU memory (technically it will just wait for GC to release it and probably
        // not a lot because this is used for a broadcast that really should be small)
        // In the case of broadcast the life cycle of the object is tied to GC and there is no clean
        // way to separate the two right now.  So we accept the leak.
        columns.safeClose()
        columns = null
      }
    }
  }

  private def readObject(in: ObjectInputStream): Unit = {
    withResource(new NvtxRange("HostDeserializeBatch", NvtxColor.PURPLE)) { _ =>
      val (h, b) = SerializedHostTableUtils.readTableHeaderAndBuffer(in)
      // buffer will only be cleaned up on GC, so cannot warn about leaks
      b.noWarnLeakExpected()
      header = h
      buffer = b
      numRows = h.getNumRows
    }
  }

  def dataSize: Long = {
    JCudfSerialization.getSerializedSizeInBytes(columns, 0, numRows)
  }

  override def close(): Unit = {
    columns.safeClose()
    columns = null
    buffer.safeClose()
    buffer = null
  }

  @scala.annotation.nowarn("msg=method finalize in class Object is deprecated")
  override def finalize(): Unit = {
    super.finalize()
    close()
  }
}

class GpuBroadcastMeta(
    exchange: BroadcastExchangeExec,
    conf: RapidsConf,
    parent: Option[RapidsMeta[_, _, _]],
    rule: DataFromReplacementRule) extends
  SparkPlanMeta[BroadcastExchangeExec](exchange, conf, parent, rule) with Logging {

  override def tagPlanForGpu(): Unit = {
    if (!TrampolineUtil.isSupportedRelation(exchange.mode)) {
      willNotWorkOnGpu(
        s"unsupported BroadcastMode: ${exchange.mode}. " +
          s"GPU supports only IdentityBroadcastMode and HashedRelationBroadcastMode")
    }
    def isSupported(rm: RapidsMeta[_, _, _]): Boolean = rm.wrapped match {
      case _: BroadcastHashJoinExec => true
      case _: BroadcastNestedLoopJoinExec => true
      case _ => false
    }
    if (parent.isDefined) {
      if (!parent.exists(isSupported)) {
        willNotWorkOnGpu("BroadcastExchange only works on the GPU if being used " +
            "with a GPU version of BroadcastHashJoinExec or BroadcastNestedLoopJoinExec")
      }
    }
  }

  override def convertToGpu(): GpuExec = {
    GpuBroadcastExchangeExec(exchange.mode, childPlans.head.convertIfNeeded())(
      exchange.canonicalized.asInstanceOf[BroadcastExchangeExec])
  }
}

abstract class GpuBroadcastExchangeExecBase(
    mode: BroadcastMode,
    child: SparkPlan) extends ShimBroadcastExchangeLike with ShimUnaryExecNode with GpuExec {

  override val outputRowsLevel: MetricsLevel = ESSENTIAL_LEVEL
  override val outputBatchesLevel: MetricsLevel = MODERATE_LEVEL
  override lazy val additionalMetrics = Map(
    "dataSize" -> createSizeMetric(ESSENTIAL_LEVEL, "data size"),
    COLLECT_TIME -> createNanoTimingMetric(ESSENTIAL_LEVEL, DESCRIPTION_COLLECT_TIME),
    BUILD_TIME -> createNanoTimingMetric(ESSENTIAL_LEVEL, DESCRIPTION_BUILD_TIME),
    "broadcastTime" -> createNanoTimingMetric(ESSENTIAL_LEVEL, "time to broadcast"))

  override def outputPartitioning: Partitioning = BroadcastPartitioning(mode)

  // For now all broadcasts produce a single batch. We might need to change that at some point
  override def outputBatching: CoalesceGoal = RequireSingleBatch

  @transient
  protected val timeout: Long = SQLConf.get.broadcastTimeout

  // prior to Spark 3.5.0, runId is defined as `def` rather than `val` so
  // produces a new ID on each reference. We override with a `val` so that
  // the value is assigned once.
  override val runId: UUID = UUID.randomUUID

  @transient
  lazy val relationFuture: Future[Broadcast[Any]] = {
    // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
    val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
    val numOutputBatches = gpuLongMetric(NUM_OUTPUT_BATCHES)
    val numOutputRows = gpuLongMetric(NUM_OUTPUT_ROWS)
    val dataSize = gpuLongMetric("dataSize")
    val collectTime = gpuLongMetric(COLLECT_TIME)
    val buildTime = gpuLongMetric(BUILD_TIME)
    val broadcastTime = gpuLongMetric("broadcastTime")

    SQLExecution.withThreadLocalCaptured[Broadcast[Any]](
        session, GpuBroadcastExchangeExecBase.executionContext) {
      try {
        // Setup a job group here so later it may get cancelled by groupId if necessary.
        sparkContext.setJobGroup(runId.toString, s"broadcast exchange (runId ${runId})",
          interruptOnCancel = true)
        val broadcastResult = {
          val collected =
            withResource(new NvtxWithMetrics("broadcast collect", NvtxColor.GREEN,
              collectTime)) { _ =>
              val childRdd = child.executeColumnar()

              // collect batches from the executors
              val data = childRdd.map(withResource(_) { cb =>
                new SerializeBatchDeserializeHostBuffer(cb)
              })
              data.collect()
            }
          withResource(new NvtxWithMetrics("broadcast build", NvtxColor.DARK_GREEN,
            buildTime)) { _ =>
            val emptyRelation = if (collected.isEmpty) {
              SparkShimImpl.tryTransformIfEmptyRelation(mode)
            } else {
              None
            }
            emptyRelation.getOrElse {
              GpuBroadcastExchangeExecBase.makeBroadcastBatch(
                collected, output, numOutputBatches, numOutputRows, dataSize)
            }
          }
        }
        val broadcasted =
          withResource(new NvtxWithMetrics("broadcast", NvtxColor.CYAN,
              broadcastTime)) { _ =>
            // Broadcast the relation
            sparkContext.broadcast(broadcastResult)
        }
        SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
        promise.success(broadcasted)
        broadcasted
      } catch {
        // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw
        // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult
        // will catch this exception and re-throw the wrapped fatal throwable.
        case oe: OutOfMemoryError =>
          val ex = createOutOfMemoryException(oe)
          promise.failure(ex)
          throw ex
        case e if !NonFatal(e) =>
          val ex = new Exception(e)
          promise.failure(ex)
          throw ex
        case e: Throwable =>
          promise.failure(e)
          throw e
      }
    }
  }


  protected def createOutOfMemoryException(oe: OutOfMemoryError) = {
    new Exception(
      new OutOfMemoryError("Not enough memory to build and broadcast the table to all " +
        "worker nodes. As a workaround, you can either disable broadcast by setting " +
        s"${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1 or increase the spark " +
        s"driver memory by setting ${SparkLauncher.DRIVER_MEMORY} to a higher value.")
        .initCause(oe.getCause))
  }

  override protected def doPrepare(): Unit = {
    // Materialize the future.
    relationFuture
  }

  override protected def doExecute(): RDD[InternalRow] = {
    throw new UnsupportedOperationException(
      "GpuBroadcastExchange does not support the execute() code path.")
  }

  override protected[sql] def doExecuteBroadcast[T](): Broadcast[T] = {
    try {
      relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[Broadcast[T]]
    } catch {
      case ex: TimeoutException =>
        logError(s"Could not execute broadcast in $timeout secs.", ex)
        if (!relationFuture.isDone) {
          sparkContext.cancelJobGroup(runId.toString)
          relationFuture.cancel(true)
        }
        throw new SparkException(s"Could not execute broadcast in $timeout secs. " +
          s"You can increase the timeout for broadcasts via ${SQLConf.BROADCAST_TIMEOUT.key} or " +
          s"disable broadcast join by setting ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1",
          ex)
    }
  }

  final def executeColumnarBroadcast[T](): Broadcast[T] = {
    if (isCanonicalizedPlan) {
      throw new IllegalStateException("A canonicalized plan is not supposed to be executed.")
    }
    try {
      val ret = relationFuture.get(timeout, TimeUnit.SECONDS)
      doLoreDump(ret)
      ret.asInstanceOf[Broadcast[T]]
    } catch {
      case ex: TimeoutException =>
        logError(s"Could not execute broadcast in $timeout secs.", ex)
        if (!relationFuture.isDone) {
          sparkContext.cancelJobGroup(runId.toString)
          relationFuture.cancel(true)
        }
        throw new SparkException(s"Could not execute broadcast in $timeout secs. " +
          s"You can increase the timeout for broadcasts via ${SQLConf.BROADCAST_TIMEOUT.key} or " +
          s"disable broadcast join by setting ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1",
          ex)
    }
  }

  // We have to do this explicitly here rather than similar to the general version one in
  // [[GpuExec]] since in adaptive execution, the broadcast value has already been calculated
  // before we tag this plan to dump.
  private def doLoreDump(result: Broadcast[Any]): Unit = {
    val inner = new SimpleRDD(session.sparkContext, result, schema)
    getTagValue(LORE_DUMP_RDD_TAG).foreach { info =>
      val rdd = new GpuLoreDumpRDD(info, inner)
      rdd.saveMeta()
      rdd.foreach(_.close())
    }
  }

  override def runtimeStatistics: Statistics = {
    Statistics(
      sizeInBytes = metrics("dataSize").value,
      rowCount = Some(metrics(GpuMetric.NUM_OUTPUT_ROWS).value))
  }

  override protected def internalDoExecuteColumnar(): RDD[ColumnarBatch] = {
    throw new IllegalStateException(s"Internal Error ${this.getClass} has column support" +
        s" mismatch:\n$this")
  }
}

object GpuBroadcastExchangeExecBase {
  val executionContext = ExecutionContext.fromExecutorService(
    org.apache.spark.util.ThreadUtils.newDaemonCachedThreadPool("gpu-broadcast-exchange",
      SQLConf.get.getConf(StaticSQLConf.BROADCAST_EXCHANGE_MAX_THREAD_THRESHOLD)))

  protected def checkRowLimit(numRows: Int) = {
    // Spark restricts the size of broadcast relations to be less than 512000000 rows and we
    // enforce the same limit
    // scalastyle:off line.size.limit
    // https://github.com/apache/spark/blob/v3.1.1/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala#L586
    // scalastyle:on line.size.limit
    if (numRows >= 512000000) {
      throw new SparkException(
        s"Cannot broadcast the table with 512 million or more rows: $numRows rows")
    }
  }

  protected def checkSizeLimit(sizeInBytes: Long) = {
    // Spark restricts the size of broadcast relations to be less than 8GB
    if (sizeInBytes >= MAX_BROADCAST_TABLE_BYTES) {
      throw new SparkException(
        s"Cannot broadcast the table that is larger than" +
            s"${MAX_BROADCAST_TABLE_BYTES >> 30}GB: ${sizeInBytes >> 30} GB")
    }
  }

  /**
   * Concatenate deserialized host buffers into a single HostConcatResult that is then
   * passed to a `SerializeConcatHostBuffersDeserializeBatch`.
   *
   * This result will in turn be broadcasted from the driver to the executors.
   */
  def makeBroadcastBatch(
      buffers: Array[SerializeBatchDeserializeHostBuffer],
      output: Seq[Attribute],
      numOutputBatches: GpuMetric,
      numOutputRows: GpuMetric,
      dataSize: GpuMetric): SerializeConcatHostBuffersDeserializeBatch = {
    val rowsOnly = buffers.isEmpty || buffers.head.header.getNumColumns == 0
    var numRows = 0
    var dataLen: Long = 0
    val hostConcatResult = if (rowsOnly) {
      numRows = withResource(buffers) { _ =>
        require(output.isEmpty,
          "Rows-only broadcast resolved had non-empty " +
              s"output ${output.mkString(",")}")
        buffers.map(_.header.getNumRows).sum
      }
      checkRowLimit(numRows)
      null
    } else {
      val hostConcatResult = withResource(buffers) { _ =>
        JCudfSerialization.concatToHostBuffer(
          buffers.map(_.header), buffers.map(_.buffer))
      }
      closeOnExcept(hostConcatResult) { _ =>
        checkRowLimit(hostConcatResult.getTableHeader.getNumRows)
        checkSizeLimit(hostConcatResult.getTableHeader.getDataLen)
      }
      // this result will be GC'ed later, so we mark it as such
      hostConcatResult.getHostBuffer.noWarnLeakExpected()
      numRows = hostConcatResult.getTableHeader.getNumRows
      dataLen = hostConcatResult.getTableHeader.getDataLen
      hostConcatResult
    }
    numOutputBatches += 1
    numOutputRows += numRows
    dataSize += dataLen

    // create the batch we will broadcast out
    new SerializeConcatHostBuffersDeserializeBatch(
      hostConcatResult, output, numRows, dataLen)
  }
}

case class GpuBroadcastExchangeExec(
    mode: BroadcastMode,
    child: SparkPlan)
    (val cpuCanonical: BroadcastExchangeExec)
    extends GpuBroadcastExchangeExecBase(mode, child) {
  override def otherCopyArgs: Seq[AnyRef] = Seq(cpuCanonical)

  private var _isGpuPlanningComplete = false

  /**
   * Returns true if this node and children are finished being optimized by the RAPIDS Accelerator.
   */
  def isGpuPlanningComplete: Boolean = _isGpuPlanningComplete

  /**
   * Method to call after all RAPIDS Accelerator optimizations have been applied
   * to indicate this node and its children are done being planned by the RAPIDS Accelerator.
   * Some optimizations, such as AQE exchange reuse fixup, need to know when a node will no longer
   * be updated so it can be tracked for reuse.
   */
  def markGpuPlanningComplete(): Unit = {
    if (!_isGpuPlanningComplete) {
      _isGpuPlanningComplete = true
      ExchangeMappingCache.trackExchangeMapping(cpuCanonical, this)
    }
  }

  override def doCanonicalize(): SparkPlan = {
    GpuBroadcastExchangeExec(mode.canonicalized, child.canonicalized)(cpuCanonical)
  }
}

/** Caches the mappings from canonical CPU exchanges to the GPU exchanges that replaced them */
object ExchangeMappingCache extends Logging {
  // Cache is a mapping from CPU broadcast plan to GPU broadcast plan. The cache should not
  // artificially hold onto unused plans, so we make both the keys and values weak. The values
  // point to their corresponding keys, so the keys will not be collected unless the value
  // can be collected. The values will be held during normal Catalyst planning until those
  // plans are no longer referenced, allowing both the key and value to be reaped at that point.
  private val cache = new mutable.WeakHashMap[Exchange, WeakReference[Exchange]]

  /** Try to find a recent GPU exchange that has replaced the specified CPU canonical plan. */
  def findGpuExchangeReplacement(cpuCanonical: Exchange): Option[Exchange] = {
    cache.get(cpuCanonical).flatMap(_.get)
  }

  /** Add a GPU exchange to the exchange cache */
  def trackExchangeMapping(cpuCanonical: Exchange, gpuExchange: Exchange): Unit = {
    val old = findGpuExchangeReplacement(cpuCanonical)
    if (!old.exists(_.asInstanceOf[GpuBroadcastExchangeExec].isGpuPlanningComplete)) {
      cache.put(cpuCanonical, WeakReference(gpuExchange))
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy