/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.nvidia.spark.rapids
import java.nio.channels.WritableByteChannel
import java.util.concurrent.ConcurrentHashMap
import scala.collection.mutable
import ai.rapids.cudf.{ColumnVector, Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer, Table}
import com.nvidia.spark.rapids.Arm._
import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableSeq
import com.nvidia.spark.rapids.StorageTier.StorageTier
import com.nvidia.spark.rapids.format.TableMeta
import org.apache.spark.sql.rapids.GpuTaskMetrics
import org.apache.spark.sql.rapids.storage.RapidsStorageUtils
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.vectorized.ColumnarBatch
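// A minimal sizing sketch (illustrative, not part of the original file): the class
// scaladoc below suggests roughly 1 MB of chunked-pack bounce buffer per SM, so a
// caller could size the store from a known SM count instead of taking the default.
// `smCount` is an assumed input; query it through your CUDA bindings of choice.
//
//   val smCount = 128                                  // e.g. an A100-class GPU
//   val bounceBufferSize = smCount * 1024L * 1024L     // 1 MB per SM => 128 MB
//   val store = new RapidsDeviceMemoryStore(
//     chunkedPackBounceBufferSize = bounceBufferSize)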
/**
 * Buffer storage using device memory.
 *
 * @param chunkedPackBounceBufferSize the size of the bounce buffer used by
 *                                    chunked_pack during spill. Defaults to 128 MB,
 *                                    following a rule of thumb of 1 MB per SM.
 * @param hostBounceBufferSize the size of the host bounce buffer used when writing
 *                             device buffers to a spill channel. Defaults to 128 MB.
 */
class RapidsDeviceMemoryStore(
    chunkedPackBounceBufferSize: Long = 128L * 1024 * 1024,
    hostBounceBufferSize: Long = 128L * 1024 * 1024)
extends RapidsBufferStore(StorageTier.DEVICE) {
// The RapidsDeviceMemoryStore handles spillability via ref counting
override protected def spillableOnAdd: Boolean = false
// bounce buffer to be used during chunked pack in GPU to host memory spill
private var chunkedPackBounceBuffer: DeviceMemoryBuffer =
DeviceMemoryBuffer.allocate(chunkedPackBounceBufferSize)
private var hostSpillBounceBuffer: HostMemoryBuffer =
HostMemoryBuffer.allocate(hostBounceBufferSize)
override protected def createBuffer(
other: RapidsBuffer,
catalog: RapidsBufferCatalog,
stream: Cuda.Stream): Option[RapidsBufferBase] = {
val memoryBuffer = withResource(other.getCopyIterator) { copyIterator =>
copyIterator.next()
}
withResource(memoryBuffer) { _ =>
val deviceBuffer = {
memoryBuffer match {
case d: DeviceMemoryBuffer => d
case h: HostMemoryBuffer =>
GpuTaskMetrics.get.readSpillFromHostTime {
closeOnExcept(DeviceMemoryBuffer.allocate(memoryBuffer.getLength)) { deviceBuffer =>
logDebug(s"copying from host $h to device $deviceBuffer")
deviceBuffer.copyFromHostBuffer(h, stream)
deviceBuffer
}
}
case b => throw new IllegalStateException(s"Unrecognized buffer: $b")
}
}
Some(new RapidsDeviceMemoryBuffer(
other.id,
deviceBuffer.getLength,
other.meta,
deviceBuffer,
other.getSpillPriority))
}
}
/**
* Adds a buffer to the device storage. This does NOT take ownership of the
* buffer, so it is the responsibility of the caller to close it.
*
* This function is called only from the RapidsBufferCatalog, under the
* catalog lock.
*
* @param id the RapidsBufferId to use for this buffer
* @param buffer buffer that will be owned by the store
* @param tableMeta metadata describing the buffer layout
* @param initialSpillPriority starting spill priority value for the buffer
* @param needsSync whether the spill framework should stream synchronize while adding
* this device buffer (defaults to true)
* @return the RapidsBuffer instance that was added.
*/
def addBuffer(
id: RapidsBufferId,
buffer: DeviceMemoryBuffer,
tableMeta: TableMeta,
initialSpillPriority: Long,
needsSync: Boolean): RapidsBuffer = {
buffer.incRefCount()
val rapidsBuffer = new RapidsDeviceMemoryBuffer(
id,
buffer.getLength,
tableMeta,
buffer,
initialSpillPriority)
freeOnExcept(rapidsBuffer) { _ =>
logDebug(s"Adding receive side table for: [id=$id, size=${buffer.getLength}, " +
s"uncompressed=${rapidsBuffer.meta.bufferMeta.uncompressedSize}, " +
s"meta_id=${tableMeta.bufferMeta.id}, " +
s"meta_size=${tableMeta.bufferMeta.size}]")
addBuffer(rapidsBuffer, needsSync)
rapidsBuffer
}
}
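  // Usage sketch (hypothetical caller, for illustration): because addBuffer bumps
  // the buffer's refCount rather than taking ownership, the caller closes its own
  // reference and the store keeps a separate one. `id`, `meta` and `priority` are
  // assumed to be in scope.
  //
  //   withResource(DeviceMemoryBuffer.allocate(len)) { buf =>
  //     store.addBuffer(id, buf, meta, priority, needsSync = true)
  //   } // the caller's reference closes here; the store still holds its own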
/**
* Adds a table to the device storage.
*
* This takes ownership of the table.
*
* This function is called only from the RapidsBufferCatalog, under the
* catalog lock.
*
* @param id the RapidsBufferId to use for this table
* @param table table that will be owned by the store
* @param initialSpillPriority starting spill priority value
* @param needsSync whether the spill framework should stream synchronize while adding
* this table (defaults to true)
* @return the RapidsBuffer instance that was added.
*/
def addTable(
id: RapidsBufferId,
table: Table,
initialSpillPriority: Long,
needsSync: Boolean): RapidsBuffer = {
val rapidsTable = new RapidsTable(
id,
table,
initialSpillPriority)
freeOnExcept(rapidsTable) { _ =>
addBuffer(rapidsTable, needsSync)
rapidsTable
}
}
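  // Usage sketch (hypothetical caller): in contrast to addBuffer, addTable takes
  // ownership of the table, so the caller must NOT close it afterwards.
  //
  //   val table: Table = ...   // built by the caller
  //   store.addTable(id, table, priority, needsSync = true)
  //   // no table.close() here -- the store's RapidsTable closes it on free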
  /**
   * A per-column (cuDF `ColumnVector`) event handler that handles calls to
   * `.close()` made while the `ColumnVector` lock is held.
   */
class RapidsDeviceColumnEventHandler
extends ColumnVector.EventHandler {
// Every RapidsTable that references this column has an entry in this map.
// The value represents the number of times (normally 1) that a ColumnVector
// appears in the RapidsTable. This is also the ColumnVector refCount at which
// the column is considered spillable.
// The map is protected via the ColumnVector lock.
private val registration = new mutable.HashMap[RapidsTable, Int]()
    /**
     * Every RapidsTable iterates through its columns and either creates
     * a `RapidsDeviceColumnEventHandler` object, associating it with the column's
     * `eventHandler`, or calls into the existing one, registering itself.
     *
     * The registration has two goals. First, it accounts for repetition of a
     * column in a `RapidsTable`: if a table has the same column repeated, it must
     * adjust the refCount at which the column is considered spillable.
     *
     * Second, it accounts for aliasing: if two tables alias this column, it is
     * marked as non-spillable.
     *
     * @param rapidsTable - the table that is registering itself with this tracker
     * @param repetition - the number of times the column appears in `rapidsTable`
     */
def register(rapidsTable: RapidsTable, repetition: Int): Unit = {
registration.put(rapidsTable, repetition)
}
/**
* This is invoked during `RapidsTable.free` in order to remove the entry
* in `registration`.
* @param rapidsTable - the table that is de-registering itself
*/
def deregister(rapidsTable: RapidsTable): Unit = {
registration.remove(rapidsTable)
}
// called with the cudfCv lock held from cuDF's side
override def onClosed(cudfCv: ColumnVector, refCount: Int): Unit = {
// we only handle spillability if there is a single table registered
// (no aliasing)
if (registration.size == 1) {
val (rapidsTable, spillableRefCount) = registration.head
if (spillableRefCount == refCount) {
rapidsTable.onColumnSpillable(cudfCv)
}
}
}
}
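  // Worked example (illustrative, names assumed): suppose a RapidsTable holds the
  // same ColumnVector twice. register(table, 2) records 2 as the refCount at which
  // the column becomes spillable, because at refCount == 2 the table's own two
  // references are the only ones left.
  //
  //   val handler = new RapidsDeviceColumnEventHandler
  //   cudfCv.setEventHandler(handler)
  //   handler.register(rapidsTable, repetition = 2)
  //   // outside holders close their refs: refCount 4 -> 3 -> 2; onClosed(cudfCv, 2)
  //   // then calls rapidsTable.onColumnSpillable(cudfCv)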
/**
* A `RapidsTable` is the spill store holder of a cuDF `Table`.
*
* The table is not contiguous in GPU memory. Instead, this `RapidsBuffer` instance
* allows us to use the cuDF chunked_pack API to make the table contiguous as the spill
* is happening.
*
* This class owns the cuDF table and will close it when `close` is called.
*
* @param id the `RapidsBufferId` this table is associated with
* @param table the cuDF table that we are managing
* @param spillPriority a starting spill priority
*/
class RapidsTable(
id: RapidsBufferId,
table: Table,
spillPriority: Long)
extends RapidsBufferBase(
id,
null,
spillPriority)
with RapidsBufferChannelWritable {
/** The storage tier for this buffer */
override val storageTier: StorageTier = StorageTier.DEVICE
override val supportsChunkedPacker: Boolean = true
    // This is the current size of the table in (unpacked) batch form. It is used
    // while the table has not migrated to another store.
    private val unpackedSizeInBytes: Long = GpuColumnVector.getTotalDeviceMemoryUsed(table)
    // By default all columns are NOT spillable, since we are not the only owners of
    // the columns (the caller is holding onto a ColumnarBatch that will be closed
    // after instantiation, triggering onClosed callbacks).
    // This map, used as a set, holds the columns that are currently spillable.
    private val columnSpillability = new ConcurrentHashMap[ColumnVector, Boolean]()
private val numDistinctColumns =
(0 until table.getNumberOfColumns).map(table.getColumn).distinct.size
// we register our event callbacks as the very first action to deal with
// spillability
registerOnCloseEventHandler()
/** Release the underlying resources for this buffer. */
override protected def releaseResources(): Unit = {
table.close()
}
private lazy val (cachedMeta, cachedPackedSize) = {
withResource(makeChunkedPacker) { cp =>
(cp.getMeta, cp.getTotalContiguousSize)
}
}
override def meta: TableMeta = cachedMeta
override val memoryUsedBytes: Long = unpackedSizeInBytes
override def getPackedSizeBytes: Long = cachedPackedSize
override def makeChunkedPacker: ChunkedPacker =
new ChunkedPacker(id, table, chunkedPackBounceBuffer)
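    // Spill-path sketch (illustrative): a consumer can drain this table in
    // bounce-buffer-sized chunks via the copy iterator, without materializing a
    // full contiguous copy of the table on the GPU.
    //
    //   withResource(rapidsTable.getCopyIterator) { copyIter =>
    //     while (copyIter.hasNext) {
    //       withResource(copyIter.next()) { chunk => /* stage chunk off the GPU */ }
    //     }
    //   }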
/**
* Mark a column as spillable
*
* @param column the ColumnVector to mark as spillable
*/
def onColumnSpillable(column: ColumnVector): Unit = {
columnSpillability.put(column, true)
updateSpillability()
}
/**
* Update the spillability state of this RapidsTable. This is invoked from
* two places:
*
* - from the onColumnSpillable callback, which is invoked from a
* ColumnVector.EventHandler.onClosed callback.
*
* - after adding a table to the store to mark the table as spillable if
* all columns are spillable.
*/
override def updateSpillability(): Unit = {
setSpillable(this, columnSpillability.size == numDistinctColumns)
}
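    // Example (illustrative): a table with three distinct columns becomes
    // spillable only once all three have reported in, i.e. when
    // columnSpillability.size == numDistinctColumns == 3.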
/**
* Produce a `ColumnarBatch` from our table, and in the process make ourselves
* not spillable.
*
* @param sparkTypes the spark data types the batch should have
*/
override def getColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = {
columnSpillability.clear()
setSpillable(this, false)
GpuColumnVector.from(table, sparkTypes)
}
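    // Usage note (illustrative): after getColumnarBatch, the table stays
    // non-spillable until every returned column is closed back down to its
    // spillable refCount, at which point the onClosed callbacks re-mark them.
    //
    //   withResource(rapidsTable.getColumnarBatch(sparkTypes)) { batch =>
    //     ...                      // table is pinned (non-spillable) in here
    //   }                          // columns closed; spillability is restored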
    /**
     * Get the underlying memory buffer. This is unsupported for a `RapidsTable`:
     * the table is not contiguous in memory, so spill it through the chunked
     * packer instead.
     */
    override def getMemoryBuffer: MemoryBuffer = {
      throw new UnsupportedOperationException(
        "RapidsTable doesn't support getMemoryBuffer")
    }
override def free(): Unit = {
      // remove our handler from the chain of handlers for each column
removeOnCloseEventHandler()
super.free()
}
private def registerOnCloseEventHandler(): Unit = {
val columns = (0 until table.getNumberOfColumns).map(table.getColumn)
      // `columns` could contain duplicates. We need to take this into account when
      // deciding the floor refCount for a duplicated column.
val repetitionPerColumn = new mutable.HashMap[ColumnVector, Int]()
columns.foreach { col =>
val repetitionCount = repetitionPerColumn.getOrElse(col, 0)
repetitionPerColumn(col) = repetitionCount + 1
}
repetitionPerColumn.foreach { case (distinctCv, repetition) =>
// lock the column because we are setting its event handler, and we are inspecting
// its refCount.
distinctCv.synchronized {
val eventHandler = distinctCv.getEventHandler match {
case null =>
val eventHandler = new RapidsDeviceColumnEventHandler
distinctCv.setEventHandler(eventHandler)
eventHandler
case existing: RapidsDeviceColumnEventHandler =>
existing
case other =>
throw new IllegalStateException(
s"Invalid column event handler $other")
}
eventHandler.register(this, repetition)
if (repetition == distinctCv.getRefCount) {
onColumnSpillable(distinctCv)
}
}
}
}
// this method is called from free()
private def removeOnCloseEventHandler(): Unit = {
val distinctColumns =
(0 until table.getNumberOfColumns).map(table.getColumn).distinct
distinctColumns.foreach { distinctCv =>
distinctCv.synchronized {
distinctCv.getEventHandler match {
case eventHandler: RapidsDeviceColumnEventHandler =>
eventHandler.deregister(this)
case t =>
throw new IllegalStateException(
s"Invalid column event handler $t")
}
}
}
}
override def writeToChannel(outputChannel: WritableByteChannel, stream: Cuda.Stream): Long = {
var written: Long = 0L
withResource(getCopyIterator) { copyIter =>
      while (copyIter.hasNext) {
withResource(copyIter.next()) { slice =>
val iter =
new MemoryBufferToHostByteBufferIterator(
slice,
hostSpillBounceBuffer,
stream)
iter.foreach { bb =>
try {
while (bb.hasRemaining) {
written += outputChannel.write(bb)
}
} finally {
RapidsStorageUtils.dispose(bb)
}
}
}
}
written
}
}
}
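  // Usage sketch (hypothetical, for illustration): spilling through the
  // RapidsBufferChannelWritable interface into a file. `path` and `rapidsTable`
  // are assumptions; the bounce buffers above do the device -> host staging.
  //
  //   import java.nio.channels.FileChannel
  //   import java.nio.file.StandardOpenOption.{CREATE, WRITE}
  //   withResource(FileChannel.open(path, CREATE, WRITE)) { channel =>
  //     val bytesWritten = rapidsTable.writeToChannel(channel, Cuda.DEFAULT_STREAM)
  //   }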
class RapidsDeviceMemoryBuffer(
id: RapidsBufferId,
size: Long,
meta: TableMeta,
contigBuffer: DeviceMemoryBuffer,
spillPriority: Long)
extends RapidsBufferBase(id, meta, spillPriority)
with MemoryBuffer.EventHandler
with RapidsBufferChannelWritable {
override val memoryUsedBytes: Long = size
override val storageTier: StorageTier = StorageTier.DEVICE
    // If this require triggers, we are re-adding a `DeviceMemoryBuffer` outside of
    // the catalog lock, which should not be possible. The event handler is set to
    // null when we free the `RapidsDeviceMemoryBuffer`, and if the buffer has not
    // been freed, we take out another handle (in the catalog).
    // TODO: relying on outside locking and addReference/free is not robust,
    // and should be revisited.
    require(contigBuffer.setEventHandler(this) == null,
      "DeviceMemoryBuffer with non-null event handler failed to add!!")
/**
* Override from the MemoryBuffer.EventHandler interface.
*
* If we are being invoked we have the `contigBuffer` lock, as this callback
* is being invoked from `MemoryBuffer.close`
*
* @param refCount - contigBuffer's current refCount
*/
override def onClosed(refCount: Int): Unit = {
      // refCount == 1 means the only remaining reference to `contigBuffer` is the
      // one this RapidsDeviceMemoryBuffer owns
if (refCount == 1) {
// setSpillable is being called here as an extension of `MemoryBuffer.close()`
// we hold the MemoryBuffer lock and we could be called from a Spark task thread
// Since we hold the MemoryBuffer lock, `incRefCount` waits for us. The only other
// call to `setSpillable` is also under this same MemoryBuffer lock (see:
// `getDeviceMemoryBuffer`)
setSpillable(this, true)
}
}
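    // Lifecycle sketch (illustrative, names assumed): spillability toggles with
    // outside borrowers of the underlying buffer.
    //
    //   val rb = store.addBuffer(id, buf, meta, priority, needsSync = true)
    //   buf.close()        // caller's ref gone; the store's ref keeps refCount == 1
    //   val dmb = rb.getDeviceMemoryBuffer // refCount == 2, marked non-spillable
    //   dmb.close()                        // onClosed(1) marks it spillable again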
override protected def releaseResources(): Unit = synchronized {
// we need to disassociate this RapidsBuffer from the underlying buffer
contigBuffer.close()
}
/**
* Get and increase the reference count of the device memory buffer
* in this RapidsBuffer, while making the RapidsBuffer non-spillable.
*
* @note It is the responsibility of the caller to close the DeviceMemoryBuffer
*/
override def getDeviceMemoryBuffer: DeviceMemoryBuffer = synchronized {
contigBuffer.synchronized {
setSpillable(this, false)
contigBuffer.incRefCount()
contigBuffer
}
}
override def getMemoryBuffer: MemoryBuffer = getDeviceMemoryBuffer
override def getColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = {
// calling `getDeviceMemoryBuffer` guarantees that we have marked this RapidsBuffer
// as not spillable and increased its refCount atomically
withResource(getDeviceMemoryBuffer) { buff =>
columnarBatchFromDeviceBuffer(buff, sparkTypes)
}
}
    /**
     * We override free to make sure we don't leave a handler on the underlying
     * contigBuffer, since this `RapidsBuffer` is no longer tracked.
     */
override def free(): Unit = synchronized {
if (isValid) {
// it is going to be invalid when calling super.free()
contigBuffer.setEventHandler(null)
}
super.free()
}
override def writeToChannel(outputChannel: WritableByteChannel, stream: Cuda.Stream): Long = {
var written: Long = 0L
val iter = new MemoryBufferToHostByteBufferIterator(
contigBuffer,
hostSpillBounceBuffer,
stream)
iter.foreach { bb =>
try {
while (bb.hasRemaining) {
written += outputChannel.write(bb)
}
} finally {
RapidsStorageUtils.dispose(bb)
}
}
written
}
}
override def close(): Unit = {
try {
super.close()
} finally {
Seq(chunkedPackBounceBuffer, hostSpillBounceBuffer).safeClose()
chunkedPackBounceBuffer = null
hostSpillBounceBuffer = null
}
}
}