com.netease.arctic.spark.sql.execution.ExtendedV2ExistingTableWriteExec.scala Maven / Gradle / Ivy
The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.arctic.execution
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, PhysicalWriteInfoImpl, WriterCommitMessage}
import org.apache.spark.sql.execution.datasources.v2.{StreamWriterCommitProgress, V2CommandExec}
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.util.LongAccumulator
import org.apache.spark.{SparkException, TaskContext}
import scala.util.control.NonFatal
/**
 * Physical-plan mixin that writes the rows produced by [[query]] through a DataSource V2
 * [[BatchWrite]].
 *
 * Implementations supply the per-task write logic via [[writingTask]]. [[writeWithV2]] runs
 * the Spark job, commits the collected writer commit messages, and aborts the write if
 * anything fails.
 *
 * @tparam W the [[DataWriter]] type used by [[writingTask]]
 */
trait ExtendedV2ExistingTableWriteExec[W <: DataWriter[InternalRow]] extends V2CommandExec
  with UnaryExecNode with Serializable {

  /** Per-task write logic invoked for each input partition of the write job. */
  def writingTask: WritingSparkTask[W]

  /** Plan producing the rows to write; it is this node's only child. */
  def query: SparkPlan

  /**
   * Progress of the last write performed by [[writeWithV2]]: the total of the per-task
   * `numRows` values. Set only after `batchWrite.commit` returns; `None` before that.
   */
  var commitProgress: Option[StreamWriterCommitProgress] = None

  override def child: SparkPlan = query

  /** This command produces no output attributes. */
  override def output: Seq[Attribute] = Nil

  /**
   * Executes [[query]], writes every resulting partition with [[writingTask]] in one Spark job,
   * and then commits the write through `batchWrite`.
   *
   * Each task's commit message is stored by partition index and forwarded to
   * `batchWrite.onDataWriterCommit` as it arrives. If the job or the commit fails,
   * `batchWrite.abort` is called with the messages collected so far.
   *
   * @param batchWrite the V2 batch write to drive
   * @return always `Nil`
   * @throws SparkException "Writing job aborted." wrapping a non-fatal failure after a successful
   *         abort, or "Writing job failed." wrapping the original failure when the abort itself
   *         fails; fatal failures are rethrown unwrapped when the abort succeeds
   */
  protected def writeWithV2(batchWrite: BatchWrite): Seq[InternalRow] = {
    val rdd: RDD[InternalRow] = {
      val tempRdd = query.execute()
      // SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single
      // partition rdd to make sure we at least set up one write task to write the metadata.
      if (tempRdd.partitions.length == 0) {
        sparkContext.parallelize(Seq.empty[InternalRow], 1)
      } else {
        tempRdd
      }
    }
    // Introduce a local val so the task closure below captures only the writing task,
    // not this whole plan node.
    val task = writingTask
    val writerFactory = batchWrite.createBatchWriterFactory(
      PhysicalWriteInfoImpl(rdd.getNumPartitions))
    val useCommitCoordinator = batchWrite.useCommitCoordinator
    // One slot per partition, filled by the result handler; slots of tasks that have not
    // finished stay null, and that is what `abort` sees on failure.
    val messages = new Array[WriterCommitMessage](rdd.partitions.length)
    // Only updated inside the runJob result handler below.
    val totalNumRowsAccumulator = new LongAccumulator()

    logInfo(s"Start processing data source write support: $batchWrite. " +
      s"The input RDD has ${messages.length} partitions.")

    try {
      sparkContext.runJob(
        rdd,
        (context: TaskContext, iter: Iterator[InternalRow]) =>
          task.run(writerFactory, context, iter, useCommitCoordinator),
        rdd.partitions.indices,
        (index, result: DataWritingSparkTaskResult) => {
          val commitMessage = result.writerCommitMessage
          messages(index) = commitMessage
          totalNumRowsAccumulator.add(result.numRows)
          batchWrite.onDataWriterCommit(commitMessage)
        }
      )

      logInfo(s"Data source write support $batchWrite is committing.")
      batchWrite.commit(messages)
      logInfo(s"Data source write support $batchWrite committed.")
      commitProgress = Some(StreamWriterCommitProgress(totalNumRowsAccumulator.value))
    } catch {
      // Deliberately catches every Throwable so the write is always aborted; fatal errors are
      // rethrown unwrapped after a successful abort (see the match below).
      case cause: Throwable =>
        logError(s"Data source write support $batchWrite is aborting.")
        try {
          batchWrite.abort(messages)
        } catch {
          case t: Throwable =>
            // Log the abort failure itself; otherwise it only survives as a suppressed
            // exception on `cause`.
            logError(s"Data source write support $batchWrite failed to abort.", t)
            cause.addSuppressed(t)
            // NOTE(review): unlike the path below, this also wraps fatal `cause`s; kept as-is to
            // preserve the exception type callers observe.
            throw new SparkException("Writing job failed.", cause)
        }
        logError(s"Data source write support $batchWrite aborted.")
        cause match {
          // Only wrap non fatal exceptions.
          case NonFatal(e) => throw new SparkException("Writing job aborted.", e)
          case _ => throw cause
        }
    }
    Nil
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy