All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.paimon.spark.commands.PaimonSparkWriter.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.paimon.spark.commands

import org.apache.paimon.CoreOptions.WRITE_ONLY
import org.apache.paimon.codegen.CodeGenUtils
import org.apache.paimon.crosspartition.{IndexBootstrap, KeyPartOrRow}
import org.apache.paimon.data.serializer.InternalSerializers
import org.apache.paimon.deletionvectors.DeletionVector
import org.apache.paimon.deletionvectors.append.AppendDeletionFileMaintainer
import org.apache.paimon.index.{BucketAssigner, SimpleHashBucketAssigner}
import org.apache.paimon.io.{CompactIncrement, DataIncrement, IndexIncrement}
import org.apache.paimon.manifest.{FileKind, IndexManifestEntry}
import org.apache.paimon.spark.{SparkRow, SparkTableWrite, SparkTypeUtils}
import org.apache.paimon.spark.schema.SparkSystemColumns.{BUCKET_COL, ROW_KIND_COL}
import org.apache.paimon.spark.util.SparkRowUtils
import org.apache.paimon.table.BucketMode._
import org.apache.paimon.table.FileStoreTable
import org.apache.paimon.table.sink._
import org.apache.paimon.types.{RowKind, RowType}
import org.apache.paimon.utils.SerializationUtils

import org.apache.spark.{Partitioner, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions._

import java.io.IOException
import java.util.Collections.singletonMap

import scala.collection.JavaConverters._

case class PaimonSparkWriter(table: FileStoreTable) {

  private lazy val tableSchema = table.schema

  private lazy val rowType = table.rowType()

  private lazy val bucketMode = table.bucketMode

  @transient private lazy val serializer = new CommitMessageSerializer

  val writeBuilder: BatchWriteBuilder = table.newBatchWriteBuilder()

  def writeOnly(): PaimonSparkWriter = {
    PaimonSparkWriter(table.copy(singletonMap(WRITE_ONLY.key(), "true")))
  }

  def write(data: DataFrame): Seq[CommitMessage] = {
    val sparkSession = data.sparkSession
    import sparkSession.implicits._

    val withInitBucketCol = bucketMode match {
      case CROSS_PARTITION if !data.schema.fieldNames.contains(ROW_KIND_COL) =>
        data
          .withColumn(ROW_KIND_COL, lit(RowKind.INSERT.toByteValue))
          .withColumn(BUCKET_COL, lit(-1))
      case _ => data.withColumn(BUCKET_COL, lit(-1))
    }
    val rowKindColIdx = SparkRowUtils.getFieldIndex(withInitBucketCol.schema, ROW_KIND_COL)
    val bucketColIdx = SparkRowUtils.getFieldIndex(withInitBucketCol.schema, BUCKET_COL)
    val encoderGroupWithBucketCol = EncoderSerDeGroup(withInitBucketCol.schema)

    def newWrite(): SparkTableWrite = new SparkTableWrite(writeBuilder, rowType, rowKindColIdx)

    def sparkParallelism = {
      val defaultParallelism = sparkSession.sparkContext.defaultParallelism
      val numShufflePartitions = sparkSession.sessionState.conf.numShufflePartitions
      Math.max(defaultParallelism, numShufflePartitions)
    }

    def writeWithoutBucket(dataFrame: DataFrame): Dataset[Array[Byte]] = {
      dataFrame.mapPartitions {
        iter =>
          {
            val write = newWrite()
            try {
              iter.foreach(row => write.write(row))
              write.finish()
            } finally {
              write.close()
            }
          }
      }
    }

    def writeWithBucket(dataFrame: DataFrame): Dataset[Array[Byte]] = {
      dataFrame.mapPartitions {
        iter =>
          {
            val write = newWrite()
            try {
              iter.foreach(row => write.write(row, row.getInt(bucketColIdx)))
              write.finish()
            } finally {
              write.close()
            }
          }
      }
    }

    def writeWithBucketProcessor(
        dataFrame: DataFrame,
        processor: BucketProcessor[Row]): Dataset[Array[Byte]] = {
      val repartitioned = repartitionByPartitionsAndBucket(
        dataFrame
          .mapPartitions(processor.processPartition)(encoderGroupWithBucketCol.encoder)
          .toDF())
      writeWithBucket(repartitioned)
    }

    def writeWithBucketAssigner(
        dataFrame: DataFrame,
        funcFactory: () => Row => Int): Dataset[Array[Byte]] = {
      dataFrame.mapPartitions {
        iter =>
          {
            val assigner = funcFactory.apply()
            val write = newWrite()
            try {
              iter.foreach(row => write.write(row, assigner.apply(row)))
              write.finish()
            } finally {
              write.close()
            }
          }
      }
    }

    val written: Dataset[Array[Byte]] = bucketMode match {
      case CROSS_PARTITION =>
        // Topology: input -> bootstrap -> shuffle by key hash -> bucket-assigner -> shuffle by partition & bucket
        val rowType = SparkTypeUtils.toPaimonType(withInitBucketCol.schema).asInstanceOf[RowType]
        val assignerParallelism = Option(table.coreOptions.dynamicBucketAssignerParallelism)
          .map(_.toInt)
          .getOrElse(sparkParallelism)
        val bootstrapped = bootstrapAndRepartitionByKeyHash(
          withInitBucketCol,
          assignerParallelism,
          rowKindColIdx,
          rowType)

        val globalDynamicBucketProcessor =
          GlobalDynamicBucketProcessor(
            table,
            rowType,
            assignerParallelism,
            encoderGroupWithBucketCol)
        val repartitioned = repartitionByPartitionsAndBucket(
          sparkSession.createDataFrame(
            bootstrapped.mapPartitions(globalDynamicBucketProcessor.processPartition),
            withInitBucketCol.schema))

        writeWithBucket(repartitioned)

      case HASH_DYNAMIC =>
        val assignerParallelism = Option(table.coreOptions.dynamicBucketAssignerParallelism)
          .map(_.toInt)
          .getOrElse(sparkParallelism)
        val numAssigners = Option(table.coreOptions.dynamicBucketInitialBuckets)
          .map(initialBuckets => Math.min(initialBuckets.toInt, assignerParallelism))
          .getOrElse(assignerParallelism)

        def partitionByKey(): DataFrame = {
          repartitionByKeyPartitionHash(
            sparkSession,
            withInitBucketCol,
            assignerParallelism,
            numAssigners)
        }

        if (table.snapshotManager().latestSnapshot() == null) {
          // bootstrap mode
          // Topology: input -> shuffle by special key & partition hash -> bucket-assigner
          writeWithBucketAssigner(
            partitionByKey(),
            () => {
              val extractor = new RowPartitionKeyExtractor(table.schema)
              val assigner =
                new SimpleHashBucketAssigner(
                  numAssigners,
                  TaskContext.getPartitionId(),
                  table.coreOptions.dynamicBucketTargetRowNum)
              row => {
                val sparkRow = new SparkRow(rowType, row)
                assigner.assign(
                  extractor.partition(sparkRow),
                  extractor.trimmedPrimaryKey(sparkRow).hashCode)
              }
            }
          )
        } else {
          // Topology: input -> shuffle by special key & partition hash -> bucket-assigner -> shuffle by partition & bucket
          writeWithBucketProcessor(
            partitionByKey(),
            DynamicBucketProcessor(
              table,
              bucketColIdx,
              assignerParallelism,
              numAssigners,
              encoderGroupWithBucketCol)
          )
        }

      case BUCKET_UNAWARE =>
        // Topology: input ->
        writeWithoutBucket(data)

      case HASH_FIXED =>
        // Topology: input -> bucket-assigner -> shuffle by partition & bucket
        writeWithBucketProcessor(
          withInitBucketCol,
          CommonBucketProcessor(table, bucketColIdx, encoderGroupWithBucketCol))

      case _ =>
        throw new UnsupportedOperationException(s"Spark doesn't support $bucketMode mode.")
    }

    written
      .collect()
      .map(deserializeCommitMessage(serializer, _))
      .toSeq
  }

  /**
   * Write all the deletion vectors to the index files. If it's in unaware mode, one index file maps
   * deletion vectors; else, one index file will contains all deletion vector with the same
   * partition and bucket.
   */
  def persistDeletionVectors(deletionVectors: Dataset[SparkDeletionVectors]): Seq[CommitMessage] = {
    val sparkSession = deletionVectors.sparkSession
    import sparkSession.implicits._
    val snapshotId = table.snapshotManager().latestSnapshotId();
    val serializedCommits = deletionVectors
      .groupByKey(_.partitionAndBucket)
      .mapGroups {
        (_, iter: Iterator[SparkDeletionVectors]) =>
          val indexHandler = table.store().newIndexFileHandler()
          var dvIndexFileMaintainer: AppendDeletionFileMaintainer = null
          while (iter.hasNext) {
            val sdv: SparkDeletionVectors = iter.next()
            if (dvIndexFileMaintainer == null) {
              val partition = SerializationUtils.deserializeBinaryRow(sdv.partition)
              dvIndexFileMaintainer = if (bucketMode == BUCKET_UNAWARE) {
                AppendDeletionFileMaintainer.forUnawareAppend(indexHandler, snapshotId, partition)
              } else {
                AppendDeletionFileMaintainer.forBucketedAppend(
                  indexHandler,
                  snapshotId,
                  partition,
                  sdv.bucket)
              }
            }
            if (dvIndexFileMaintainer == null) {
              throw new RuntimeException("can't create the dv maintainer.")
            }

            sdv.dataFileAndDeletionVector.foreach {
              case (dataFileName, dv) =>
                dvIndexFileMaintainer.notifyNewDeletionVector(
                  dataFileName,
                  DeletionVector.deserializeFromBytes(dv))
            }
          }
          val indexEntries = dvIndexFileMaintainer.persist()

          val (added, deleted) = indexEntries.asScala.partition(_.kind() == FileKind.ADD)

          val commitMessage = new CommitMessageImpl(
            dvIndexFileMaintainer.getPartition,
            dvIndexFileMaintainer.getBucket,
            DataIncrement.emptyIncrement(),
            CompactIncrement.emptyIncrement(),
            new IndexIncrement(added.map(_.indexFile).asJava, deleted.map(_.indexFile).asJava)
          )
          val serializer = new CommitMessageSerializer
          serializer.serialize(commitMessage)
      }
    serializedCommits
      .collect()
      .map(deserializeCommitMessage(serializer, _))
  }

  def buildCommitMessageFromIndexManifestEntry(
      indexManifestEntries: Seq[IndexManifestEntry]): Seq[CommitMessage] = {
    indexManifestEntries
      .groupBy(entry => (entry.partition(), entry.bucket()))
      .map {
        case ((partition, bucket), entries) =>
          val (added, removed) = entries.partition(_.kind() == FileKind.ADD)
          new CommitMessageImpl(
            partition,
            bucket,
            DataIncrement.emptyIncrement(),
            CompactIncrement.emptyIncrement(),
            new IndexIncrement(added.map(_.indexFile()).asJava, removed.map(_.indexFile()).asJava))
      }
      .toSeq
  }

  def commit(commitMessages: Seq[CommitMessage]): Unit = {
    val tableCommit = writeBuilder.newCommit()
    try {
      tableCommit.commit(commitMessages.toList.asJava)
    } catch {
      case e: Throwable => throw new RuntimeException(e);
    } finally {
      tableCommit.close()
    }
  }

  /** Boostrap and repartition for cross partition mode. */
  private def bootstrapAndRepartitionByKeyHash(
      data: DataFrame,
      parallelism: Int,
      rowKindColIdx: Int,
      rowType: RowType): RDD[(KeyPartOrRow, Array[Byte])] = {
    val numSparkPartitions = data.rdd.getNumPartitions
    val primaryKeys = table.schema().primaryKeys()
    val bootstrapType = IndexBootstrap.bootstrapType(table.schema())
    data.rdd
      .mapPartitions {
        iter =>
          {
            val sparkPartitionId = TaskContext.getPartitionId()
            val keyPartProject = CodeGenUtils.newProjection(bootstrapType, primaryKeys)
            val rowProject = CodeGenUtils.newProjection(rowType, primaryKeys)
            val bootstrapSer = InternalSerializers.create(bootstrapType)
            val rowSer = InternalSerializers.create(rowType)
            val bootstrapIterator = new IndexBootstrap(table)
              .bootstrap(numSparkPartitions, sparkPartitionId)
              .toCloseableIterator
            TaskContext.get().addTaskCompletionListener[Unit](_ => bootstrapIterator.close())

            bootstrapIterator.asScala
              .map(
                row => {
                  val bytes: Array[Byte] =
                    SerializationUtils.serializeBinaryRow(bootstrapSer.toBinaryRow(row))
                  (Math.abs(keyPartProject(row).hashCode()), (KeyPartOrRow.KEY_PART, bytes))
                }) ++ iter.map(
              r => {
                val sparkRow =
                  new SparkRow(rowType, r, SparkRowUtils.getRowKind(r, rowKindColIdx))
                val bytes: Array[Byte] =
                  SerializationUtils.serializeBinaryRow(rowSer.toBinaryRow(sparkRow))
                (Math.abs(rowProject(sparkRow).hashCode()), (KeyPartOrRow.ROW, bytes))
              })
          }
      }
      .partitionBy(ModPartitioner(parallelism))
      .map(_._2)
  }

  /** Repartition for dynamic bucket mode. */
  private def repartitionByKeyPartitionHash(
      sparkSession: SparkSession,
      data: DataFrame,
      parallelism: Int,
      numAssigners: Int): DataFrame = {
    sparkSession.createDataFrame(
      data.rdd
        .mapPartitions(
          iterator => {
            val rowPartitionKeyExtractor = new RowPartitionKeyExtractor(tableSchema)
            iterator.map(
              row => {
                val sparkRow = new SparkRow(rowType, row)
                val partitionHash = rowPartitionKeyExtractor.partition(sparkRow).hashCode
                val keyHash = rowPartitionKeyExtractor.trimmedPrimaryKey(sparkRow).hashCode
                (
                  BucketAssigner
                    .computeHashKey(partitionHash, keyHash, parallelism, numAssigners),
                  row)
              })
          })
        .partitionBy(ModPartitioner(parallelism))
        .map(_._2),
      data.schema
    )
  }

  private def repartitionByPartitionsAndBucket(df: DataFrame): DataFrame = {
    val partitionCols = tableSchema.partitionKeys().asScala.map(col).toSeq
    df.repartition(partitionCols ++ Seq(col(BUCKET_COL)): _*)
  }

  private def deserializeCommitMessage(
      serializer: CommitMessageSerializer,
      bytes: Array[Byte]): CommitMessage = {
    try {
      serializer.deserialize(serializer.getVersion, bytes)
    } catch {
      case e: IOException =>
        throw new RuntimeException("Failed to deserialize CommitMessage's object", e)
    }
  }

  private case class ModPartitioner(partitions: Int) extends Partitioner {
    override def numPartitions: Int = partitions
    override def getPartition(key: Any): Int = key.asInstanceOf[Int] % numPartitions
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy