All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.paimon.spark.commands.MergeIntoPaimonTable.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.paimon.spark.commands

import org.apache.paimon.spark.SparkTable
import org.apache.paimon.spark.catalyst.analysis.PaimonRelation
import org.apache.paimon.spark.leafnode.PaimonLeafRunnableCommand
import org.apache.paimon.spark.schema.{PaimonMetadataColumn, SparkSystemColumns}
import org.apache.paimon.spark.schema.PaimonMetadataColumn.{FILE_PATH, FILE_PATH_COLUMN, ROW_INDEX, ROW_INDEX_COLUMN}
import org.apache.paimon.spark.util.{EncoderUtils, SparkRowUtils}
import org.apache.paimon.table.FileStoreTable
import org.apache.paimon.table.sink.CommitMessage
import org.apache.paimon.types.RowKind

import org.apache.spark.sql.{Dataset, Row, SparkSession}
import org.apache.spark.sql.PaimonUtils._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, BasePredicate, Expression, Literal, UnsafeProjection}
import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.functions.{col, lit, monotonically_increasing_id, sum}
import org.apache.spark.sql.types.{ByteType, StructField, StructType}

import scala.collection.mutable

/**
 * Command for Merge Into.
 *
 * The command first verifies that no target row is matched by more than one source row, then
 * builds the commit messages with a strategy chosen by `withPrimaryKeys` and
 * `deletionVectorsEnabled`, and finally commits them through the writer.
 *
 * @param v2Table
 *   the Spark table wrapper of the merge target; its schema is the written table schema
 * @param targetTable
 *   logical plan of the merge target
 * @param sourceTable
 *   logical plan of the merge source
 * @param mergeCondition
 *   the join condition between source and target
 * @param matchedActions
 *   actions evaluated for rows that have both a source and a target side
 * @param notMatchedActions
 *   actions evaluated for source rows without a target row
 * @param notMatchedBySourceActions
 *   actions evaluated for target rows without a source row
 */
case class MergeIntoPaimonTable(
    v2Table: SparkTable,
    targetTable: LogicalPlan,
    sourceTable: LogicalPlan,
    mergeCondition: Expression,
    matchedActions: Seq[MergeAction],
    notMatchedActions: Seq[MergeAction],
    notMatchedBySourceActions: Seq[MergeAction])
  extends PaimonLeafRunnableCommand
  with PaimonCommand {

  import MergeIntoPaimonTable._

  override val table: FileStoreTable = v2Table.getTable.asInstanceOf[FileStoreTable]

  lazy val relation: DataSourceV2Relation = PaimonRelation.getPaimonRelation(targetTable)

  lazy val tableSchema: StructType = v2Table.schema

  private lazy val writer = PaimonSparkWriter(table)

  // The part of the merge condition that `getExpressionOnlyRelated` extracts for the target table
  // (if any), together with the target plan pre-filtered by that part.
  private lazy val (targetOnlyCondition, filteredTargetPlan): (Option[Expression], LogicalPlan) = {
    val filtersOnlyTarget = getExpressionOnlyRelated(mergeCondition, targetTable)
    (
      filtersOnlyTarget,
      filtersOnlyTarget
        .map(Filter.apply(_, targetTable))
        .getOrElse(targetTable))
  }

  /**
   * Executes the merge and commits the result.
   *
   * @return
   *   always an empty sequence; the command produces no result rows
   */
  override def run(sparkSession: SparkSession): Seq[Row] = {
    // Avoid that more than one source rows match the same target row.
    checkMatchRationality(sparkSession)
    val commitMessages = if (withPrimaryKeys) {
      performMergeForPkTable(sparkSession)
    } else {
      performMergeForNonPkTable(sparkSession)
    }
    writer.commit(commitMessages)
    Seq.empty[Row]
  }

  /**
   * Merge for tables with primary keys: every changed row, including deleted ones, is handed to the
   * writer together with its row kind.
   */
  private def performMergeForPkTable(sparkSession: SparkSession): Seq[CommitMessage] = {
    writer.write(
      constructChangedRows(
        sparkSession,
        createDataset(sparkSession, filteredTargetPlan),
        remainDeletedRow = true))
  }

  /**
   * Merge for tables without primary keys.
   *
   * With deletion vectors enabled, deleted/updated rows are recorded in deletion vectors and the
   * inserted/updated rows are written as new data. Otherwise the files touched by the merge are
   * rewritten: their surviving rows are written again and the touched files are reported as deleted.
   */
  private def performMergeForNonPkTable(sparkSession: SparkSession): Seq[CommitMessage] = {
    // Step1: get the candidate data splits which are filtered by Paimon Predicate.
    val candidateDataSplits =
      findCandidateDataSplits(targetOnlyCondition.getOrElse(TrueLiteral), relation.output)
    val dataFilePathToMeta = candidateFileMap(candidateDataSplits)

    if (deletionVectorsEnabled) {
      // Step2: generate dataset that should contains ROW_KIND, FILE_PATH, ROW_INDEX columns
      val metadataCols = Seq(FILE_PATH, ROW_INDEX)
      val filteredRelation = createDataset(
        sparkSession,
        createNewScanPlan(
          candidateDataSplits,
          targetOnlyCondition.getOrElse(TrueLiteral),
          relation,
          metadataCols))
      val ds = constructChangedRows(
        sparkSession,
        filteredRelation,
        remainDeletedRow = true,
        metadataCols = metadataCols)

      // The dataset is consumed twice below (deletion vectors and data write), so cache it and
      // always release it afterwards.
      ds.cache()
      try {
        // Fail fast if the row kind column is missing from the merged output.
        val hasRowKindColumn = ds.queryExecution.analyzed.output
          .exists(attr => sparkSession.sessionState.conf.resolver(attr.name, ROW_KIND_COL))
        if (!hasRowKindColumn) {
          throw new RuntimeException("Can not find _row_kind_ column.")
        }

        // Step3: filter rows that should be marked as DELETED in Deletion Vector mode.
        val dvDS = ds.where(
          s"$ROW_KIND_COL = ${RowKind.DELETE.toByteValue} or $ROW_KIND_COL = ${RowKind.UPDATE_AFTER.toByteValue}")
        val deletionVectors = collectDeletionVectors(dataFilePathToMeta, dvDS, sparkSession)
        val indexCommitMsg = writer.persistDeletionVectors(deletionVectors)

        // Step4: filter rows that should be written as the inserted/updated data.
        val toWriteDS = ds
          .where(
            s"$ROW_KIND_COL = ${RowKind.INSERT.toByteValue} or $ROW_KIND_COL = ${RowKind.UPDATE_AFTER.toByteValue}")
          .drop(FILE_PATH_COLUMN, ROW_INDEX_COLUMN)
        val addCommitMessage = writer.write(toWriteDS)

        // Step5: commit index and data commit messages
        addCommitMessage ++ indexCommitMsg
      } finally {
        ds.unpersist()
      }
    } else {
      // Only this branch needs the plain target/source datasets, so they are created here.
      val targetDS = createDataset(sparkSession, filteredTargetPlan)
      val sourceDS = createDataset(sparkSession, sourceTable)

      val touchedFilePathsSet = mutable.Set.empty[String]
      // Only UPDATE and DELETE actions modify existing target rows, i.e. touch existing files.
      def hasUpdate(actions: Seq[MergeAction]): Boolean = {
        actions.exists {
          case _: UpdateAction | _: DeleteAction => true
          case _ => false
        }
      }
      if (hasUpdate(matchedActions)) {
        touchedFilePathsSet ++= findTouchedFiles(
          targetDS.join(sourceDS, toColumn(mergeCondition), "inner"),
          sparkSession)
      }
      if (hasUpdate(notMatchedBySourceActions)) {
        touchedFilePathsSet ++= findTouchedFiles(
          targetDS.join(sourceDS, toColumn(mergeCondition), "left_anti"),
          sparkSession)
      }

      val targetFilePaths: Array[String] = findTouchedFiles(targetDS, sparkSession)
      val touchedFilePaths: Array[String] = touchedFilePathsSet.toArray
      // Look the paths up in the hash set; scanning the array for every target file would be
      // quadratic in the number of files.
      val unTouchedFilePaths = targetFilePaths.filterNot(touchedFilePathsSet.contains)

      val (touchedFiles, touchedFileRelation) =
        createNewRelation(touchedFilePaths, dataFilePathToMeta, relation)
      val (_, unTouchedFileRelation) =
        createNewRelation(unTouchedFilePaths, dataFilePathToMeta, relation)

      // Add FILE_TOUCHED_COL to mark the row as coming from the touched file, if the row has not been
      // modified and was from touched file, it should be kept too.
      val targetDSWithFileTouchedCol = createDataset(sparkSession, touchedFileRelation)
        .withColumn(FILE_TOUCHED_COL, lit(true))
        .union(createDataset(sparkSession, unTouchedFileRelation)
          .withColumn(FILE_TOUCHED_COL, lit(false)))

      val toWriteDS =
        constructChangedRows(sparkSession, targetDSWithFileTouchedCol).drop(ROW_KIND_COL)
      val addCommitMessage = writer.write(toWriteDS)
      val deletedCommitMessage = buildDeletedCommitMessage(touchedFiles)

      addCommitMessage ++ deletedCommitMessage
    }
  }

  /**
   * Get a Dataset where each of Row has an additional column called _row_kind_.
   *
   * The source is full-outer-joined with `targetDataset` on the merge condition. Each joined row is
   * then classified (matched / not matched / not matched by source) and projected by the first
   * merge action whose condition holds, see `MergeIntoProcessor.processPartition`.
   *
   * @param targetDataset
   *   the target side of the join
   * @param remainDeletedRow
   *   if true, rows hit by a DELETE action are emitted with row kind DELETE; otherwise they are
   *   emitted as no-op rows (which the processor drops), unless `deletionVectorEnabled` is set
   * @param deletionVectorEnabled
   *   has the same effect as `remainDeletedRow` on DELETE actions
   * @param metadataCols
   *   metadata columns to append after the row kind column in the output
   * @return
   *   rows with the table schema, followed by the row kind column and the metadata columns
   */
  private def constructChangedRows(
      sparkSession: SparkSession,
      targetDataset: Dataset[Row],
      remainDeletedRow: Boolean = false,
      deletionVectorEnabled: Boolean = false,
      metadataCols: Seq[PaimonMetadataColumn] = Seq.empty): Dataset[Row] = {
    // Marker columns: after the full outer join a null marker tells which side is missing.
    val targetDS = targetDataset
      .withColumn(TARGET_ROW_COL, lit(true))

    val sourceDS = createDataset(sparkSession, sourceTable)
      .withColumn(SOURCE_ROW_COL, lit(true))

    val joinedDS = sourceDS.join(targetDS, toColumn(mergeCondition), "fullOuter")
    val joinedPlan = joinedDS.queryExecution.analyzed

    def resolveOnJoinedPlan(exprs: Seq[Expression]): Seq[Expression] = {
      resolveExpressions(sparkSession)(exprs, joinedPlan)
    }

    val targetOutput = filteredTargetPlan.output
    // True for joined rows without a source side (target row not matched by any source row).
    val targetRowNotMatched = resolveOnJoinedPlan(
      Seq(toExpression(sparkSession, col(SOURCE_ROW_COL).isNull))).head
    // True for joined rows without a target side (source row not matched by any target row).
    val sourceRowNotMatched = resolveOnJoinedPlan(
      Seq(toExpression(sparkSession, col(TARGET_ROW_COL).isNull))).head
    // An action without an explicit condition always applies.
    val matchedExprs = matchedActions.map(_.condition.getOrElse(TrueLiteral))
    val notMatchedExprs = notMatchedActions.map(_.condition.getOrElse(TrueLiteral))
    val notMatchedBySourceExprs = notMatchedBySourceActions.map(_.condition.getOrElse(TrueLiteral))
    val noopOutput = targetOutput :+ Alias(Literal(NOOP_ROW_KIND_VALUE), ROW_KIND_COL)()
    val keepOutput = targetOutput :+ Alias(Literal(RowKind.INSERT.toByteValue), ROW_KIND_COL)()

    val resolver = sparkSession.sessionState.conf.resolver
    val metadataAttributes = metadataCols.flatMap {
      metadataCol => joinedPlan.output.find(attr => resolver(metadataCol.name, attr.name))
    }
    // Builds, per action, the output expressions: data columns, row kind, then metadata columns.
    def processMergeActions(actions: Seq[MergeAction]): Seq[Seq[Expression]] = {
      val columnExprs = actions.map {
        case UpdateAction(_, assignments) =>
          assignments.map(_.value) :+ Literal(RowKind.UPDATE_AFTER.toByteValue)
        case DeleteAction(_) =>
          if (remainDeletedRow || deletionVectorEnabled) {
            targetOutput :+ Literal(RowKind.DELETE.toByteValue)
          } else {
            // If RowKind = NOOP_ROW_KIND_VALUE, then these rows will be dropped in MergeIntoProcessor.processPartition by default.
            // If these rows still need to be remained, set MergeIntoProcessor.remainNoopRow true.
            noopOutput
          }
        case InsertAction(_, assignments) =>
          assignments.map(_.value) :+ Literal(RowKind.INSERT.toByteValue)
      }
      columnExprs.map(exprs => exprs ++ metadataAttributes)
    }

    val matchedOutputs = processMergeActions(matchedActions)
    val notMatchedBySourceOutputs = processMergeActions(notMatchedBySourceActions)
    val notMatchedOutputs = processMergeActions(notMatchedActions)
    // Output schema: table columns, the row kind column, then the requested metadata columns.
    val outputFields = mutable.ArrayBuffer(tableSchema.fields: _*)
    outputFields += StructField(ROW_KIND_COL, ByteType)
    outputFields ++= metadataCols.map(_.toStructField)
    val outputSchema = StructType(outputFields.toSeq)

    val joinedRowEncoder = EncoderUtils.encode(joinedPlan.schema)
    val outputEncoder = EncoderUtils.encode(outputSchema).resolveAndBind()

    val processor = MergeIntoProcessor(
      joinedPlan.output,
      targetRowNotMatched,
      sourceRowNotMatched,
      matchedExprs,
      matchedOutputs,
      notMatchedBySourceExprs,
      notMatchedBySourceOutputs,
      notMatchedExprs,
      notMatchedOutputs,
      noopOutput,
      keepOutput,
      joinedRowEncoder,
      outputEncoder
    )
    joinedDS.mapPartitions(processor.processPartition)(outputEncoder).toDF()
  }

  /**
   * Verifies that no target row is matched by more than one source row.
   *
   * Only runs when there are matched actions. Every target row gets a generated id, the target is
   * inner-joined with the source on the merge condition, and ids that appear more than once are
   * counted.
   *
   * @throws RuntimeException
   *   if at least one target row matches several source rows
   */
  private def checkMatchRationality(sparkSession: SparkSession): Unit = {
    if (matchedActions.nonEmpty) {
      val targetDS = createDataset(sparkSession, filteredTargetPlan)
        .withColumn(ROW_ID_COL, monotonically_increasing_id())
      val sourceDS = createDataset(sparkSession, sourceTable)
      val count = sourceDS
        .join(targetDS, toColumn(mergeCondition), "inner")
        .select(col(ROW_ID_COL), lit(1).as("one"))
        .groupBy(ROW_ID_COL)
        .agg(sum("one").as("count"))
        .filter("count > 1")
        .count()
      if (count > 0) {
        throw new RuntimeException(
          "Can't execute this MergeInto when there are some target rows that each of " +
            "them match more then one source rows. It may lead to an unexpected result.")
      }
    }
  }
}

/** Constants and the per-partition row processor used by the [[MergeIntoPaimonTable]] command. */
object MergeIntoPaimonTable {
  // Names of the helper columns that are added to intermediate datasets during the merge.
  private val ROW_ID_COL = "_row_id_"
  private val SOURCE_ROW_COL = "_source_row_"
  private val TARGET_ROW_COL = "_target_row_"
  private val FILE_TOUCHED_COL = "_file_touched_col_"
  // +I, +U, -U, -D
  private val ROW_KIND_COL: String = SparkSystemColumns.ROW_KIND_COL
  // Row kind marker for rows that must not be emitted; such rows are dropped in processPartition.
  private val NOOP_ROW_KIND_VALUE: Byte = (-1).toByte

  /**
   * Turns the rows of the source/target join into output rows carrying a row kind.
   *
   * @param joinedAttributes
   *   attributes of the joined plan; all predicates and projections are bound against them
   * @param targetRowHasNoMatch
   *   predicate selecting joined rows that are handled by the not-matched-by-source clauses
   * @param sourceRowHasNoMatch
   *   predicate selecting joined rows that are handled by the not-matched clauses
   * @param matchedConditions
   *   conditions of the matched clauses, aligned by position with `matchedOutputs`
   * @param matchedOutputs
   *   output expressions of the matched clauses
   * @param notMatchedBySourceConditions
   *   conditions of the not-matched-by-source clauses, aligned with `notMatchedBySourceOutputs`
   * @param notMatchedBySourceOutputs
   *   output expressions of the not-matched-by-source clauses
   * @param notMatchedConditions
   *   conditions of the not-matched clauses, aligned with `notMatchedOutputs`
   * @param notMatchedOutputs
   *   output expressions of the not-matched clauses
   * @param noopCopyOutput
   *   output expressions for rows no clause applies to; expected to carry the no-op row kind
   * @param keepOutput
   *   output expressions for untouched rows of touched files, which must be kept
   * @param joinedRowEncoder
   *   encoder of the joined input rows
   * @param outputRowEncoder
   *   encoder of the produced rows; its schema must contain the row kind column
   */
  private case class MergeIntoProcessor(
      joinedAttributes: Seq[Attribute],
      targetRowHasNoMatch: Expression,
      sourceRowHasNoMatch: Expression,
      matchedConditions: Seq[Expression],
      matchedOutputs: Seq[Seq[Expression]],
      notMatchedBySourceConditions: Seq[Expression],
      notMatchedBySourceOutputs: Seq[Seq[Expression]],
      notMatchedConditions: Seq[Expression],
      notMatchedOutputs: Seq[Seq[Expression]],
      noopCopyOutput: Seq[Expression],
      keepOutput: Seq[Expression],
      joinedRowEncoder: ExpressionEncoder[Row],
      outputRowEncoder: ExpressionEncoder[Row]
  ) extends Serializable {

    // Position of the row kind column in the output rows.
    private val rowKindColumnIndex: Int = outputRowEncoder.schema.fieldIndex(ROW_KIND_COL)

    // Position of the file-touched marker in the joined rows; compared against -1 below, which
    // presumably is what SparkRowUtils.getFieldIndex returns when the column is absent.
    private val fileTouchedColumnIndex: Int =
      SparkRowUtils.getFieldIndex(joinedRowEncoder.schema, FILE_TOUCHED_COL)

    private def generateProjection(exprs: Seq[Expression]): UnsafeProjection = {
      UnsafeProjection.create(exprs, joinedAttributes)
    }

    private def generatePredicate(expr: Expression): BasePredicate = {
      GeneratePredicate.generate(expr, joinedAttributes)
    }

    /** Whether the row carries a file-touched marker that is set to true. */
    private def fromTouchedFile(row: InternalRow): Boolean = {
      // The marker is null for joined rows without a target side, so check for null explicitly
      // instead of relying on how a null boolean happens to be read from the row.
      fileTouchedColumnIndex != -1 &&
      !row.isNullAt(fileTouchedColumnIndex) &&
      row.getBoolean(fileTouchedColumnIndex)
    }

    /** Whether the output row is a no-op row that must not be emitted. */
    private def unusedRow(row: InternalRow): Boolean = {
      row.getByte(rowKindColumnIndex) == NOOP_ROW_KIND_VALUE
    }

    /**
     * Processes one partition of joined rows.
     *
     * Each row is routed to the not-matched-by-source, not-matched or matched clauses and
     * projected by the first clause whose condition holds. Rows no clause applies to are kept if
     * they come from a touched file and dropped otherwise, as are all other no-op rows.
     *
     * @param rowIterator
     *   joined rows of the partition
     * @return
     *   the output rows, lazily computed from `rowIterator`
     */
    def processPartition(rowIterator: Iterator[Row]): Iterator[Row] = {
      val targetRowHasNoMatchPred = generatePredicate(targetRowHasNoMatch)
      val sourceRowHasNoMatchPred = generatePredicate(sourceRowHasNoMatch)
      // Pair every clause predicate with its projection once per partition rather than once per
      // row, so the per-row path does not allocate a new zipped sequence.
      val matchedClauses =
        matchedConditions.map(generatePredicate).zip(matchedOutputs.map(generateProjection))
      val notMatchedBySourceClauses = notMatchedBySourceConditions
        .map(generatePredicate)
        .zip(notMatchedBySourceOutputs.map(generateProjection))
      val notMatchedClauses =
        notMatchedConditions.map(generatePredicate).zip(notMatchedOutputs.map(generateProjection))
      val noopCopyProj = generateProjection(noopCopyOutput)
      val keepProj = generateProjection(keepOutput)
      val outputProj = UnsafeProjection.create(outputRowEncoder.schema)

      def processRow(inputRow: InternalRow): InternalRow = {
        // Applies the projection of the first clause whose predicate accepts the row.
        def applyClauses(clauses: Seq[(BasePredicate, UnsafeProjection)]): InternalRow = {
          clauses
            .collectFirst {
              case (predicate, projection) if predicate.eval(inputRow) =>
                projection.apply(inputRow)
            }
            .getOrElse {
              // keep the row if it is from touched file and not be matched
              if (fromTouchedFile(inputRow)) {
                keepProj.apply(inputRow)
              } else {
                noopCopyProj.apply(inputRow)
              }
            }
        }

        if (targetRowHasNoMatchPred.eval(inputRow)) {
          applyClauses(notMatchedBySourceClauses)
        } else if (sourceRowHasNoMatchPred.eval(inputRow)) {
          applyClauses(notMatchedClauses)
        } else {
          applyClauses(matchedClauses)
        }
      }

      val toRow = joinedRowEncoder.createSerializer()
      val fromRow = outputRowEncoder.createDeserializer()
      rowIterator
        .map(toRow)
        .map(processRow)
        .filterNot(unusedRow)
        .map(notDeletedInternalRow => fromRow(outputProj(notDeletedInternalRow)))
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy