All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.sql.rapids.GpuFileFormatDataWriter.scala Maven / Gradle / Ivy

There is a newer version: 24.10.1
Show newest version
/*
 * Copyright (c) 2019-2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.rapids

import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import scala.util.hashing.{MurmurHash3 => ScalaMurmur3Hash}

import ai.rapids.cudf.{OrderByArg, Table}
import com.nvidia.spark.TimingUtils
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.RmmRapidsRetryIterator.withRetryNoSplit
import com.nvidia.spark.rapids.shims.GpuFileFormatDataWriterShim
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.TaskAttemptContext

import org.apache.spark.internal.Logging
import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Cast, Concat, Expression, HiveHash, Literal, Murmur3Hash, NullsFirst, ScalaUDF, UnsafeProjection}
import org.apache.spark.sql.connector.write.DataWriter
import org.apache.spark.sql.execution.datasources.{BucketingUtils, PartitioningUtils, WriteTaskResult}
import org.apache.spark.sql.rapids.GpuFileFormatDataWriter._
import org.apache.spark.sql.rapids.GpuFileFormatWriter.GpuConcurrentOutputWriterSpec
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.SerializableConfiguration

object GpuFileFormatDataWriter {
  private def ceilingDiv(num: Long, divisor: Long) = {
    ((num + divisor - 1) / divisor).toInt
  }

  def shouldSplitToFitMaxRecordsPerFile(
      maxRecordsPerFile: Long, recordsInFile: Long, numRowsInBatch: Long): Boolean = {
    maxRecordsPerFile > 0 && (recordsInFile + numRowsInBatch) > maxRecordsPerFile
  }

  private def getSplitIndexes(maxRecordsPerFile: Long, recordsInFile: Long, numRows: Long) = {
    (maxRecordsPerFile > 0, recordsInFile < maxRecordsPerFile) match {
      case (false, _) => IndexedSeq.empty
      case (true, false) =>
        (1 until ceilingDiv(numRows, maxRecordsPerFile))
            .map(i => (i * maxRecordsPerFile).toInt)
      case (true, true) => {
        val filledUp = maxRecordsPerFile - recordsInFile
        val remain = numRows - filledUp
        (0 until ceilingDiv(remain, maxRecordsPerFile))
            .map(i => (filledUp + i * maxRecordsPerFile).toInt)
      }
    }
  }

  /**
   * Split a table into parts if recordsInFile + batch row count would go above
   * maxRecordsPerFile and make the splits spillable.
   *
   * The logic to find out what the splits should be is delegated to getSplitIndexes.
   *
   * The input batch is closed in case of error or in case we have to split it.
   * It is not closed if it wasn't split.
   *
   * @param batch ColumnarBatch to split (and close)
   * @param maxRecordsPerFile max rowcount per file
   * @param recordsInFile row count in the file so far
   * @return array of SpillableColumnarBatch splits
   */
  def splitToFitMaxRecordsAndClose(
      batch: ColumnarBatch,
      maxRecordsPerFile: Long,
      recordsInFile: Long): Array[SpillableColumnarBatch] = {
    val (types, splitIndexes) = closeOnExcept(batch) { _ =>
      val splitIndexes = getSplitIndexes(maxRecordsPerFile, recordsInFile, batch.numRows())
      (GpuColumnVector.extractTypes(batch), splitIndexes)
    }
    if (splitIndexes.isEmpty) {
      // this should never happen, as `splitToFitMaxRecordsAndClose` is called when
      // splits should already happen, but making it more efficient in that case
      Array(SpillableColumnarBatch(batch, SpillPriorities.ACTIVE_ON_DECK_PRIORITY))
    } else {
      // actually split it
      val tbl = withResource(batch) { _ =>
        GpuColumnVector.from(batch)
      }
      val cts = withResource(tbl) { _ =>
        tbl.contiguousSplit(splitIndexes: _*)
      }
      withResource(cts) { _ =>
        cts.safeMap(ct =>
          SpillableColumnarBatch(ct, types, SpillPriorities.ACTIVE_ON_DECK_PRIORITY))
      }
    }
  }
}

/**
 * Abstract class for writing out data in a single Spark task using the GPU.
 * This is the GPU version of `org.apache.spark.sql.execution.datasources.FileFormatDataWriter`.
 */
abstract class GpuFileFormatDataWriter(
    description: GpuWriteJobDescription,
    taskAttemptContext: TaskAttemptContext,
    committer: FileCommitProtocol) extends DataWriter[ColumnarBatch] {

  protected class WriterAndStatus {
    var writer: ColumnarOutputWriter = _

    /** Number of records in current file. */
    var recordsInFile: Long = 0

    /**
     * File counter for writing current partition or bucket. For same partition or bucket,
     * we may have more than one file, due to number of records limit per file.
     */
    var fileCounter: Int = 0

    final def release(): Unit = {
      if (writer != null) {
        try {
          writer.close()
          statsTrackers.foreach(_.closeFile(writer.path()))
        } finally {
          writer = null
        }
      }
    }
  }

  /**
   * Max number of files a single task writes out due to file size. In most cases the number of
   * files written should be very small. This is just a safe guard to protect some really bad
   * settings, e.g. maxRecordsPerFile = 1.
   */
  protected val MAX_FILE_COUNTER: Int = 1000 * 1000
  protected val updatedPartitions: mutable.Set[String] = mutable.Set[String]()
  protected var currentWriterStatus: WriterAndStatus = new WriterAndStatus()

  /** Trackers for computing various statistics on the data as it's being written out. */
  protected val statsTrackers: Seq[ColumnarWriteTaskStatsTracker] =
    description.statsTrackers.map(_.newTaskInstance())

  /** Release resources of a WriterStatus. */
  protected final def releaseOutWriter(status: WriterAndStatus): Unit = {
    status.release()
  }

  protected final def writeUpdateMetricsAndClose(scb: SpillableColumnarBatch,
      writerStatus: WriterAndStatus): Unit = {
    writerStatus.recordsInFile += writerStatus.writer.writeSpillableAndClose(scb, statsTrackers)
  }

  /** Release all resources. Public for testing */
  def releaseResources(): Unit = {
    // Release current writer by default, as this is the only resource to be released.
    releaseOutWriter(currentWriterStatus)
  }

  /** Write an iterator of column batch. */
  def writeWithIterator(iterator: Iterator[ColumnarBatch]): Unit = {
    while (iterator.hasNext) {
      write(iterator.next())
    }
  }

  /** Writes a columnar batch of records */
  def write(batch: ColumnarBatch): Unit

  /**
   * Returns the summary of relative information which
   * includes the list of partition strings written out. The list of partitions is sent back
   * to the driver and used to update the catalog. Other information will be sent back to the
   * driver too and used to e.g. update the metrics in UI.
   */
  override def commit(): WriteTaskResult = {
    releaseResources()
    val (taskCommitMessage, taskCommitTime) = TimingUtils.timeTakenMs {
      committer.commitTask(taskAttemptContext)
    }
    val summary = GpuFileFormatDataWriterShim.createWriteSummary(
      updatedPartitions = updatedPartitions.toSet,
      stats = statsTrackers.map(_.getFinalStats(taskCommitTime))
    )
    WriteTaskResult(taskCommitMessage, summary)
  }

  override def abort(): Unit = {
    try {
      releaseResources()
    } finally {
      committer.abortTask(taskAttemptContext)
    }
  }

  override def close(): Unit = {}
}

/** GPU data writer for empty partitions */
class GpuEmptyDirectoryDataWriter(
    description: GpuWriteJobDescription,
    taskAttemptContext: TaskAttemptContext,
    committer: FileCommitProtocol
) extends GpuFileFormatDataWriter(description, taskAttemptContext, committer) {
  override def write(batch: ColumnarBatch): Unit = {
    batch.close()
  }
}

/** Writes data to a single directory (used for non-dynamic-partition writes). */
class GpuSingleDirectoryDataWriter(
    description: GpuWriteJobDescription,
    taskAttemptContext: TaskAttemptContext,
    committer: FileCommitProtocol)
  extends GpuFileFormatDataWriter(description, taskAttemptContext, committer) {
  // Initialize currentWriter and statsTrackers
  newOutputWriter()

  @scala.annotation.nowarn(
    "msg=method newTaskTempFile in class FileCommitProtocol is deprecated"
  )
  private def newOutputWriter(): Unit = {
    currentWriterStatus.recordsInFile = 0
    val fileCounter = currentWriterStatus.fileCounter
    releaseResources()

    val ext = description.outputWriterFactory.getFileExtension(taskAttemptContext)
    val currentPath = committer.newTaskTempFile(
      taskAttemptContext,
      None,
      f"-c$fileCounter%03d" + ext)

    currentWriterStatus.writer = description.outputWriterFactory.newInstance(
      path = currentPath,
      dataSchema = description.dataColumns.toStructType,
      context = taskAttemptContext)

    statsTrackers.foreach(_.newFile(currentPath))
  }

  override def write(batch: ColumnarBatch): Unit = {
    val maxRecordsPerFile = description.maxRecordsPerFile
    val recordsInFile = currentWriterStatus.recordsInFile
    if (!shouldSplitToFitMaxRecordsPerFile(
        maxRecordsPerFile, recordsInFile, batch.numRows())) {
      writeUpdateMetricsAndClose(
        SpillableColumnarBatch(batch, SpillPriorities.ACTIVE_ON_DECK_PRIORITY),
        currentWriterStatus)
    } else {
      val partBatches = splitToFitMaxRecordsAndClose(
        batch, maxRecordsPerFile, recordsInFile)
      val needNewWriterForFirstPart = recordsInFile >= maxRecordsPerFile
      closeOnExcept(partBatches) { _ =>
        partBatches.zipWithIndex.foreach { case (partBatch, partIx) =>
          if (partIx > 0 || needNewWriterForFirstPart) {
            currentWriterStatus.fileCounter += 1
            val fileCounter = currentWriterStatus.fileCounter
            assert(fileCounter <= MAX_FILE_COUNTER,
              s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
            newOutputWriter()
          }
          // null out the entry so that we don't double close
          partBatches(partIx) = null
          writeUpdateMetricsAndClose(partBatch, currentWriterStatus)
        }
      }
    }
  }
}

/**
 * Dynamic partition writer with single writer, meaning only one writer is opened at any time
 * for writing, meaning this single function can write to multiple directories (partitions)
 * or files (bucketing). The data to be written are required to be sorted on partition and/or
 * bucket column(s) before writing.
 */
class GpuDynamicPartitionDataSingleWriter(
    description: GpuWriteJobDescription,
    taskAttemptContext: TaskAttemptContext,
    committer: FileCommitProtocol)
  extends GpuFileFormatDataWriter(description, taskAttemptContext, committer) {
  /** Wrapper class to index a unique concurrent output writer. */
  protected class WriterIndex(
      var partitionPath: Option[String],
      var bucketId: Option[Int]) extends Product2[Option[String], Option[Int]] {

    override def hashCode(): Int = ScalaMurmur3Hash.productHash(this)

    override def equals(obj: Any): Boolean = {
      if (obj.isInstanceOf[WriterIndex]) {
        val otherWI = obj.asInstanceOf[WriterIndex]
        partitionPath == otherWI.partitionPath && bucketId == otherWI.bucketId
      } else {
        false
      }
    }

    override def _1: Option[String] = partitionPath
    override def _2: Option[Int] = bucketId
    override def canEqual(that: Any): Boolean = that.isInstanceOf[WriterIndex]
  }

  /**
   * A case class to hold the batch, the optional partition path and the optional bucket
   * ID for a split group. All the rows in the batch belong to the group defined by the
   * partition path and the bucket ID.
   */
  private case class SplitPack(split: SpillableColumnarBatch, path: Option[String],
      bucketId: Option[Int]) extends AutoCloseable {
    override def close(): Unit = {
      split.safeClose()
    }
  }
  /**
   * The index for current writer. Intentionally make the index mutable and reusable.
   * Avoid JVM GC issue when many short-living `WriterIndex` objects are created
   * if switching between concurrent writers frequently.
   */
  private val currentWriterId: WriterIndex = new WriterIndex(None, None)

  /** Flag saying whether or not the data to be written out is partitioned. */
  protected val isPartitioned: Boolean = description.partitionColumns.nonEmpty

  /** Flag saying whether or not the data to be written out is bucketed. */
  protected val isBucketed: Boolean = description.bucketSpec.isDefined

  assert(isPartitioned || isBucketed,
    s"""GpuDynamicPartitionWriteTask should be used for writing out data that's either
       |partitioned or bucketed. In this case neither is true.
       |GpuWriteJobDescription: $description
     """.stripMargin)

  // All data is sorted ascending with default null ordering
  private val nullsSmallest = Ascending.defaultNullOrdering == NullsFirst

  /** Extracts the partition values out of an input batch. */
  private lazy val getPartitionColumnsAsBatch: ColumnarBatch => ColumnarBatch = {
    val expressions = GpuBindReferences.bindGpuReferences(
      description.partitionColumns,
      description.allColumns)
    cb => {
      GpuProjectExec.project(cb, expressions)
    }
  }

  private lazy val getBucketIdColumnAsBatch: ColumnarBatch => ColumnarBatch = {
    val expressions = GpuBindReferences.bindGpuReferences(
      Seq(description.bucketSpec.get.bucketIdExpression),
      description.allColumns)
    cb => {
      GpuProjectExec.project(cb, expressions)
    }
  }

  /**
   * Expression that given partition columns builds a path string like: col1=val/col2=val/...
   * This is used after we pull the unique partition values back to the host.
   */
  private lazy val partitionPathExpression: Expression = Concat(
    description.partitionColumns.zipWithIndex.flatMap { case (c, i) =>
      val partitionName = ScalaUDF(
        ExternalCatalogUtils.getPartitionPathString _,
        StringType,
        Seq(Literal(c.name), Cast(c, StringType, Option(description.timeZoneId))))
      if (i == 0) Seq(partitionName) else Seq(Literal(Path.SEPARATOR), partitionName)
    })

  /** Evaluates the `partitionPathExpression` above on a row of `partitionValues` and returns
   * the partition string.
   */
  private lazy val getPartitionPath: InternalRow => String = {
    val proj = UnsafeProjection.create(Seq(partitionPathExpression), description.partitionColumns)
    row => proj(row).getString(0)
  }

  /** Extracts the output values of an input batch. */
  protected lazy val getDataColumnsAsBatch: ColumnarBatch => ColumnarBatch = {
    val expressions = GpuBindReferences.bindGpuReferences(
      description.dataColumns,
      description.allColumns)
    cb => {
      GpuProjectExec.project(cb, expressions)
    }
  }

  protected def getKeysBatch(cb: ColumnarBatch): ColumnarBatch = {
    val keysBatch = withResource(getPartitionColumnsAsBatch(cb)) { partCb =>
      if (isBucketed) {
        withResource(getBucketIdColumnAsBatch(cb)) { bucketIdCb =>
          GpuColumnVector.combineColumns(partCb, bucketIdCb)
        }
      } else {
        GpuColumnVector.incRefCounts(partCb)
      }
    }
    require(keysBatch.numCols() > 0, "No sort key is specified")
    keysBatch
  }

  protected def genGetBucketIdFunc(keyHostCb: ColumnarBatch): Int => Option[Int] = {
    if (isBucketed) {
      // The last column is the bucket id column
      val bucketIdCol = keyHostCb.column(keyHostCb.numCols() - 1)
      i => Some(bucketIdCol.getInt(i))
    } else {
      _ => None
    }
  }

  protected def genGetPartitionPathFunc(keyHostCb: ColumnarBatch): Int => Option[String] = {
    if (isPartitioned) {
      // Use the existing code to convert each row into a path. It would be nice to do this
      // on the GPU, but the data should be small and there are things we cannot easily
      // support on the GPU right now
      import scala.collection.JavaConverters._
      val partCols = description.partitionColumns.indices.map(keyHostCb.column)
      val iter = new ColumnarBatch(partCols.toArray, keyHostCb.numRows()).rowIterator()
        .asScala.map(getPartitionPath)
      _ => Some(iter.next)
    } else {
      _ => None
    }
  }

  // distinct value sorted the same way the input data is sorted.
  private def distinctAndSort(t: Table): Table = {
    val columnIds = 0 until t.getNumberOfColumns
    withResource(t.groupBy(columnIds: _*).aggregate()) { distinct =>
      distinct.orderBy(columnIds.map(OrderByArg.asc(_, nullsSmallest)): _*)
    }
  }

  // Get the split indexes for t given the keys we want to split on
  private def splitIndexes(t: Table, keys: Table): Array[Int] = {
    val nullsSmallestArray = Array.fill[Boolean](t.getNumberOfColumns)(nullsSmallest)
    val desc = Array.fill[Boolean](t.getNumberOfColumns)(false)
    withResource(t.upperBound(nullsSmallestArray, keys, desc)) { cv =>
      GpuColumnVector.toIntArray(cv)
    }
  }

  // Convert a table to a ColumnarBatch on the host, so we can iterate through it.
  protected def copyToHostAsBatch(input: Table, colTypes: Array[DataType]): ColumnarBatch = {
    withResource(GpuColumnVector.from(input, colTypes)) { tmp =>
      new ColumnarBatch(GpuColumnVector.extractColumns(tmp).safeMap(_.copyToHost()), tmp.numRows())
    }
  }

  /**
   * Split a batch according to the sorted keys (partitions + bucket ids).
   * Returns a tuple with an array of the splits as `ContiguousTable`'s, an array of
   * paths and bucket ids to use to write each partition and(or) bucket file.
   */
  private def splitBatchByKeyAndClose(batch: ColumnarBatch): Array[SplitPack] = {
    val (keysCb, dataCb) = withResource(batch) { _ =>
      closeOnExcept(getDataColumnsAsBatch(batch)) { data =>
        (getKeysBatch(batch), data)
      }
    }
    val (keyHostCb, splitIds) = closeOnExcept(dataCb) { _ =>
      val (splitIds, distinctKeysTbl, keysCbTypes) = withResource(keysCb) { _ =>
        val keysCbTypes = GpuColumnVector.extractTypes(keysCb)
        withResource(GpuColumnVector.from(keysCb)) { keysTable =>
          closeOnExcept(distinctAndSort(keysTable)) { distinctKeysTbl =>
            (splitIndexes(keysTable, distinctKeysTbl), distinctKeysTbl, keysCbTypes)
          }
        }
      }
      withResource(distinctKeysTbl) { _ =>
        (copyToHostAsBatch(distinctKeysTbl, keysCbTypes), splitIds)
      }
    }
    val splits = closeOnExcept(keyHostCb) { _ =>
      val scbOutput = closeOnExcept(dataCb)( _ =>
        SpillableColumnarBatch(dataCb, SpillPriorities.ACTIVE_ON_DECK_PRIORITY))
      withRetryNoSplit(scbOutput) { scb =>
        withResource(scb.getColumnarBatch()) { outCb =>
          withResource(GpuColumnVector.from(outCb)) { outputColumnsTbl =>
            withResource(outputColumnsTbl) { _ =>
              outputColumnsTbl.contiguousSplit(splitIds: _*)
            }
          }
        }
      }
    }
    // Build the split result
    withResource(splits) { _ =>
      withResource(keyHostCb) { _ =>
        val getBucketId = genGetBucketIdFunc(keyHostCb)
        val getNextPartPath = genGetPartitionPathFunc(keyHostCb)
        val outDataTypes = description.dataColumns.map(_.dataType).toArray
        (0 until keyHostCb.numRows()).safeMap { idx =>
          val split = splits(idx)
          splits(idx) = null
          closeOnExcept(split) { _ =>
            SplitPack(
              SpillableColumnarBatch(split, outDataTypes,
                SpillPriorities.ACTIVE_BATCHING_PRIORITY),
              getNextPartPath(idx), getBucketId(idx))
          }
        }.toArray
      }
    }
  }

  /**
   * Create a new writer according to the given writer id, and update the given
   * writer status. It also closes the old writer in the writer status by default.
   */
  protected final def renewOutWriter(newWriterId: WriterIndex, curWriterStatus: WriterAndStatus,
      closeOldWriter: Boolean = true): Unit = {
    if (closeOldWriter) {
      releaseOutWriter(curWriterStatus)
    }
    curWriterStatus.recordsInFile = 0
    curWriterStatus.writer = newWriter(newWriterId.partitionPath, newWriterId.bucketId,
      curWriterStatus.fileCounter)
  }

  /**
   * Set up a writer to the given writer status for the given writer id.
   * It will create a new one if needed. This is used when seeing a new partition
   * and(or) a new bucket id.
   */
  protected def setupCurrentWriter(newWriterId: WriterIndex, curWriterStatus: WriterAndStatus,
      closeOldWriter: Boolean = true): Unit = {
    renewOutWriter(newWriterId, curWriterStatus, closeOldWriter)
  }

  /**
   * Opens a new OutputWriter given a partition key and/or a bucket id.
   * If bucket id is specified, we will append it to the end of the file name, but before the
   * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet
   *
   * @param partDir     the partition directory
   * @param bucketId    the bucket which all tuples being written by this OutputWriter belong to,
   *                    currently does not support `bucketId`, it's always None
   * @param fileCounter integer indicating the number of files to be written to `partDir`
   */
  @scala.annotation.nowarn(
    "msg=method newTaskTempFile.* in class FileCommitProtocol is deprecated"
  )
  def newWriter(partDir: Option[String], bucketId: Option[Int],
      fileCounter: Int): ColumnarOutputWriter = {
    partDir.foreach(updatedPartitions.add)
    // Currently will be empty
    val bucketIdStr = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("")

    // This must be in a form that matches our bucketing format. See BucketingUtils.
    val ext = f"$bucketIdStr.c$fileCounter%03d" +
      description.outputWriterFactory.getFileExtension(taskAttemptContext)

    val customPath = partDir.flatMap { dir =>
      description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
    }

    val currentPath = if (customPath.isDefined) {
      committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext)
    } else {
      committer.newTaskTempFile(taskAttemptContext, partDir, ext)
    }

    val outWriter = description.outputWriterFactory.newInstance(
      path = currentPath,
      dataSchema = description.dataColumns.toStructType,
      context = taskAttemptContext)

    statsTrackers.foreach(_.newFile(currentPath))
    outWriter
  }

  protected final def writeBatchPerMaxRecordsAndClose(scb: SpillableColumnarBatch,
      writerId: WriterIndex, writerStatus: WriterAndStatus): Unit = {
    val maxRecordsPerFile = description.maxRecordsPerFile
    val recordsInFile = writerStatus.recordsInFile

    if (!shouldSplitToFitMaxRecordsPerFile(maxRecordsPerFile, recordsInFile, scb.numRows())) {
      writeUpdateMetricsAndClose(scb, writerStatus)
    } else {
      val batch = withRetryNoSplit(scb) { scb =>
        scb.getColumnarBatch()
      }
      val splits = splitToFitMaxRecordsAndClose(batch, maxRecordsPerFile, recordsInFile)
      withResource(splits) { _ =>
        val needNewWriterForFirstPart = recordsInFile >= maxRecordsPerFile
        splits.zipWithIndex.foreach { case (part, partIx) =>
          if (partIx > 0 || needNewWriterForFirstPart) {
            writerStatus.fileCounter += 1
            assert(writerStatus.fileCounter <= MAX_FILE_COUNTER,
              s"File counter ${writerStatus.fileCounter} is beyond max value $MAX_FILE_COUNTER")
            // will create a new file, so close the old writer
            renewOutWriter(writerId, writerStatus)
          }
          splits(partIx) = null
          writeUpdateMetricsAndClose(part, writerStatus)
        }
      }
    }
  }

  /**
   * Called just before updating the current writer status when seeing a new partition
   * or a bucket.
   *
   * @param curWriterId the current writer index
   */
  protected def preUpdateCurrentWriterStatus(curWriterId: WriterIndex): Unit ={}

  override def write(batch: ColumnarBatch): Unit = {
    // The input batch that is entirely sorted, so split it up by partitions and (or)
    // bucket ids, and write the split batches one by one.
    withResource(splitBatchByKeyAndClose(batch)) { splitPacks =>
      splitPacks.zipWithIndex.foreach { case (SplitPack(sp, partPath, bucketId), i) =>
        val hasDiffPart = partPath != currentWriterId.partitionPath
        val hasDiffBucket = bucketId != currentWriterId.bucketId
        if (hasDiffPart || hasDiffBucket) {
          preUpdateCurrentWriterStatus(currentWriterId)
          if (hasDiffPart) {
            currentWriterId.partitionPath = partPath
            statsTrackers.foreach(_.newPartition())
          }
          if (hasDiffBucket) {
            currentWriterId.bucketId = bucketId
          }
          currentWriterStatus.fileCounter = 0
          setupCurrentWriter(currentWriterId, currentWriterStatus)
        }
        splitPacks(i) = null
        writeBatchPerMaxRecordsAndClose(sp, currentWriterId, currentWriterStatus)
      }
    }
  }
}

/**
 * Dynamic partition writer with concurrent writers, meaning multiple concurrent
 * writers are opened for writing.
 *
 * The process has the following steps:
 *  - Step 1: Maintain a map of output writers per each partition columns. Keep all
 *    writers opened; Cache the inputted batches by splitting them into sub-groups and
 *    each partition holds a list of spillable sub-groups; Find and write the max pending
 *    partition data if the total caches exceed the limitation.
 *  - Step 2: If number of concurrent writers exceeds limit, fall back to sort-based write
 *    (`GpuDynamicPartitionDataSingleWriter`), sort rest of batches on partition.
 *    Write batch by batch, and eagerly close the writer when finishing
 *    Caller is expected to call `writeWithIterator()` instead of `write()` to write records.
 *    Note: when fall back to `GpuDynamicPartitionDataSingleWriter`, the single writer should
 *    restore un-closed writers and should handle un-flushed spillable caches.
 */
class GpuDynamicPartitionDataConcurrentWriter(
    description: GpuWriteJobDescription,
    taskAttemptContext: TaskAttemptContext,
    committer: FileCommitProtocol,
    spec: GpuConcurrentOutputWriterSpec)
  extends GpuDynamicPartitionDataSingleWriter(description, taskAttemptContext, committer)
  with Logging {

  /** Wrapper class for status and caches of a unique concurrent output writer. */
  private class WriterStatusWithBatches extends WriterAndStatus with AutoCloseable {
    // caches for this partition or writer
    val tableCaches: ListBuffer[SpillableColumnarBatch] = ListBuffer()

    // current device bytes for the above caches
    var deviceBytes: Long = 0

    override def close(): Unit = try {
      releaseOutWriter(this)
    } finally {
      tableCaches.safeClose()
      tableCaches.clear()
    }
  }

  // Keep all the unclosed writers, key is a partition path and(or) bucket id.
  // Note: if fall back to sort-based mode, also use the opened writers in the map.
  private val concurrentWriters = mutable.HashMap[WriterIndex, WriterStatusWithBatches]()

  private val partitionFlushSize =
    if (description.concurrentWriterPartitionFlushSize <= 0) {
      // if the property is equal or less than 0, use default value given by the
      // writer factory
      description.outputWriterFactory.partitionFlushSize(taskAttemptContext)
    } else {
      // if the property is greater than 0, use the property value
      description.concurrentWriterPartitionFlushSize
    }

  // Pending split batches that are not cached for the concurrent write because
  // there are too many open writers, and it is going to fall back to the sorted
  // sequential write.
  private val pendingBatches: mutable.Queue[SpillableColumnarBatch] = mutable.Queue.empty

  override def writeWithIterator(iterator: Iterator[ColumnarBatch]): Unit = {
    // 1: try concurrent writer
    while (iterator.hasNext && pendingBatches.isEmpty) {
      // concurrent write and update the `concurrentWriters` map.
      this.write(iterator.next())
    }

    // 2: fall back to single write if the input is not all consumed.
    if (pendingBatches.nonEmpty || iterator.hasNext) {
      // sort the all the pending batches and ones in `iterator`
      val pendingCbsIter = new Iterator[ColumnarBatch] {
        override def hasNext: Boolean = pendingBatches.nonEmpty

        override def next(): ColumnarBatch = {
          if (!hasNext) {
            throw new NoSuchElementException()
          }
          withResource(pendingBatches.dequeue())(_.getColumnarBatch())
        }
      }
      val sortIter = GpuOutOfCoreSortIterator(pendingCbsIter ++ iterator,
        new GpuSorter(spec.sortOrder, spec.output), GpuSortExec.targetSize(spec.batchSize),
        NoopMetric, NoopMetric, NoopMetric, NoopMetric)
      while (sortIter.hasNext) {
        // write with sort-based sequential writer
        super.write(sortIter.next())
      }
    }
  }

  /** This is for the fallback case, used to clean the writers map. */
  override def preUpdateCurrentWriterStatus(curWriterId: WriterIndex): Unit = {
    concurrentWriters.remove(curWriterId)
  }

  /** This is for the fallback case, try to find the writer from cache first. */
  override def setupCurrentWriter(newWriterId: WriterIndex, writerStatus: WriterAndStatus,
      closeOldWriter: Boolean): Unit = {
    if (closeOldWriter) {
      releaseOutWriter(writerStatus)
    }
    val oOpenStatus = concurrentWriters.get(newWriterId)
    if (oOpenStatus.isDefined) {
      val openStatus = oOpenStatus.get
      writerStatus.writer = openStatus.writer
      writerStatus.recordsInFile = openStatus.recordsInFile
      writerStatus.fileCounter = openStatus.fileCounter
    } else {
      super.setupCurrentWriter(newWriterId, writerStatus, closeOldWriter = false)
    }
  }

  /**
   * The write path of concurrent writers
   *
   * @param cb the columnar batch to be written
   */
  override def write(cb: ColumnarBatch): Unit = {
    if (cb.numRows() == 0) {
      // TODO https://github.com/NVIDIA/spark-rapids/issues/6453
      // To solve above issue, I assume that an empty batch will be wrote for saving metadata.
      // If the assumption it's true, this concurrent writer should write the metadata here,
      // and should not run into below splitting and caching logic
      cb.close()
      return
    }

    // Split the batch and cache the result, along with opening the writers.
    splitBatchToCacheAndClose(cb)
    // Write the cached batches
    val writeFunc: (WriterIndex, WriterStatusWithBatches) => Unit =
      if (pendingBatches.nonEmpty) {
        // Flush all the caches before going into sorted sequential write
        writeOneCacheAndClose
      } else {
        // Still the concurrent write, so write out only partitions that size > threshold.
        (wi, ws) =>
          if (ws.deviceBytes > partitionFlushSize) {
            writeOneCacheAndClose(wi, ws)
          }
      }
    concurrentWriters.foreach { case (writerIdx, writerStatus) =>
      writeFunc(writerIdx, writerStatus)
    }
  }

  private def writeOneCacheAndClose(writerId: WriterIndex,
      status: WriterStatusWithBatches): Unit = {
    assert(status.tableCaches.nonEmpty)
    // Concat tables if needed
    val scbToWrite = GpuBatchUtils.concatSpillBatchesAndClose(status.tableCaches.toSeq).get
    status.tableCaches.clear()
    status.deviceBytes = 0
    writeBatchPerMaxRecordsAndClose(scbToWrite, writerId, status)
  }

  private def splitBatchToCacheAndClose(batch: ColumnarBatch): Unit = {
    // Split batch to groups by sort columns, [partition and(or) bucket id column].
    val (keysAndGroups, keyTypes) = withResource(batch) { _ =>
      val (opBatch, keyTypes) = withResource(getKeysBatch(batch)) { keysBatch =>
        val combinedCb = GpuColumnVector.combineColumns(keysBatch, batch)
        (combinedCb, GpuColumnVector.extractTypes(keysBatch))
      }
      withResource(opBatch) { _ =>
        withResource(GpuColumnVector.from(opBatch)) { opTable =>
          (opTable.groupBy(keyTypes.indices: _*).contiguousSplitGroupsAndGenUniqKeys(),
            keyTypes)
        }
      }
    }
    // Copy keys table to host and make group batches spillable
    val (keyHostCb, groups) = withResource(keysAndGroups) { _ =>
      // groups number should equal to uniq keys number
      assert(keysAndGroups.getGroups.length == keysAndGroups.getUniqKeyTable.getRowCount)
      closeOnExcept(copyToHostAsBatch(keysAndGroups.getUniqKeyTable, keyTypes)) { keyHostCb =>
        keysAndGroups.closeUniqKeyTable()
        val allTypes = description.allColumns.map(_.dataType).toArray
        val allColsIds = allTypes.indices.map(_ + keyTypes.length)
        val gps = keysAndGroups.getGroups.safeMap { gp =>
          withResource(gp.getTable) { gpTable =>
            withResource(new Table(allColsIds.map(gpTable.getColumn): _*)) { allTable =>
              SpillableColumnarBatch(GpuColumnVector.from(allTable, allTypes),
                SpillPriorities.ACTIVE_BATCHING_PRIORITY)
            }
          }
        }
        (keyHostCb, gps)
      }
    }
    // Cache the result to either the map or the pending queue.
    withResource(groups) { _ =>
      withResource(keyHostCb) { _ =>
        val getBucketId = genGetBucketIdFunc(keyHostCb)
        val getNextPartPath = genGetPartitionPathFunc(keyHostCb)
        var idx = 0
        while (idx < groups.length && concurrentWriters.size < spec.maxWriters) {
          val writerId = new WriterIndex(getNextPartPath(idx), getBucketId(idx))
          val writerStatus =
            concurrentWriters.getOrElseUpdate(writerId, new WriterStatusWithBatches)
          if (writerStatus.writer == null) {
            // a new partition or bucket, so create a writer
            renewOutWriter(writerId, writerStatus, closeOldWriter = false)
          }
          withResource(groups(idx)) { gp =>
            groups(idx) = null
            withResource(gp.getColumnarBatch()) { cb =>
              val dataScb = SpillableColumnarBatch(getDataColumnsAsBatch(cb),
                SpillPriorities.ACTIVE_BATCHING_PRIORITY)
              writerStatus.tableCaches.append(dataScb)
              writerStatus.deviceBytes += dataScb.sizeInBytes
            }
          }
          idx += 1
        }
        if (idx < groups.length) {
          // The open writers number reaches the limit, and still some partitions are
          // not cached. Append to the queue for the coming fallback to the sorted
          // sequential write.
          groups.drop(idx).foreach(g => pendingBatches.enqueue(g))
          // Set to null to avoid double close
          (idx until groups.length).foreach(groups(_) = null)
          logInfo(s"Number of concurrent writers ${concurrentWriters.size} reaches " +
            "the threshold. Fall back from concurrent writers to sort-based sequential" +
            " writer.")
        }
      }
    }
  }

  /** Release all resources. */
  override def releaseResources(): Unit = {
    pendingBatches.safeClose()
    pendingBatches.clear()

    // write all caches
    concurrentWriters.foreach { case (wi, ws) =>
      if (ws.tableCaches.nonEmpty) {
        writeOneCacheAndClose(wi, ws)
      }
    }

    // close all resources
    concurrentWriters.values.toSeq.safeClose()
    concurrentWriters.clear()
    super.releaseResources()
  }
}

/**
 * Bucketing specification for all the write tasks.
 * This is the GPU version of `org.apache.spark.sql.execution.datasources.WriterBucketSpec`
 * @param bucketIdExpression Expression to calculate bucket id based on bucket column(s).
 * @param bucketFileNamePrefix Prefix of output file name based on bucket id.
 */
case class GpuWriterBucketSpec(
  bucketIdExpression: GpuExpression,
  bucketFileNamePrefix: Int => String)

/**
 * A shared job description for all the GPU write tasks.
 * This is the GPU version of `org.apache.spark.sql.execution.datasources.WriteJobDescription`.
 */
class GpuWriteJobDescription(
    val uuid: String, // prevent collision between different (appending) write jobs
    val serializableHadoopConf: SerializableConfiguration,
    val outputWriterFactory: ColumnarOutputWriterFactory,
    val allColumns: Seq[Attribute],
    val dataColumns: Seq[Attribute],
    val partitionColumns: Seq[Attribute],
    val bucketSpec: Option[GpuWriterBucketSpec],
    val path: String,
    val customPartitionLocations: Map[TablePartitionSpec, String],
    val maxRecordsPerFile: Long,
    val timeZoneId: String,
    val statsTrackers: Seq[ColumnarWriteJobStatsTracker],
    val concurrentWriterPartitionFlushSize: Long)
  extends Serializable {

  assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns),
    s"""
         |All columns: ${allColumns.mkString(", ")}
         |Partition columns: ${partitionColumns.mkString(", ")}
         |Data columns: ${dataColumns.mkString(", ")}
       """.stripMargin)
}

object BucketIdMetaUtils {
  // Tag for the bucketing write using Spark Murmur3Hash
  def tagForBucketingWrite(meta: RapidsMeta[_, _, _], bucketSpec: Option[BucketSpec],
      outputColumns: Seq[Attribute]): Unit = {
    bucketSpec.foreach { bSpec =>
      // Create a Murmur3Hash expression to leverage the overriding types check.
      val expr = Murmur3Hash(
        bSpec.bucketColumnNames.map(n => outputColumns.find(_.name == n).get),
        GpuHashPartitioningBase.DEFAULT_HASH_SEED)
      val hashMeta = GpuOverrides.wrapExpr(expr, meta.conf, None)
      hashMeta.tagForGpu()
      if(!hashMeta.canThisBeReplaced) {
        meta.willNotWorkOnGpu(s"Murmur3 hashing for generating bucket IDs can not run" +
          s" on GPU. Details: ${hashMeta.explain(all=false)}")
      }
    }
  }

  def tagForBucketingHiveWrite(meta: RapidsMeta[_, _, _], bucketSpec: Option[BucketSpec],
      outputColumns: Seq[Attribute]): Unit = {
    bucketSpec.foreach { bSpec =>
      // Create a HiveHash expression to leverage the overriding types check.
      val expr = HiveHash(bSpec.bucketColumnNames.map(n => outputColumns.find(_.name == n).get))
      val hashMeta = GpuOverrides.wrapExpr(expr, meta.conf, None)
      hashMeta.tagForGpu()
      if (!hashMeta.canThisBeReplaced) {
        meta.willNotWorkOnGpu(s"Hive hashing for generating bucket IDs can not run" +
          s" on GPU. Details: ${hashMeta.explain(all = false)}")
      }
    }
  }

  def getWriteBucketSpecForHive(
      bucketCols: Seq[Attribute],
      numBuckets: Int): GpuWriterBucketSpec = {
    // Hive bucketed table: use "HiveHash" and bitwise-and as bucket id expression.
    // "bitwise-and" is used to handle the case of a wrong bucket id when
    // the hash value is negative.
    val hashId = GpuBitwiseAnd(GpuHiveHash(bucketCols), GpuLiteral(Int.MaxValue))
    val bucketIdExpression = GpuPmod(hashId, GpuLiteral(numBuckets))

    // The bucket file name prefix is following Hive, Presto and Trino conversion, then
    // Hive bucketed tables written by Plugin can be read by other SQL engines.
    val fileNamePrefix = (bucketId: Int) => f"$bucketId%05d_0_"
    GpuWriterBucketSpec(bucketIdExpression, fileNamePrefix)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy