com.azure.cosmos.spark.BulkWriter.scala Maven / Gradle / Ivy
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.cosmos.spark
// scalastyle:off underscore.import
import com.azure.cosmos.implementation.CosmosDaemonThreadFactory
import com.azure.cosmos.{BridgeInternal, CosmosAsyncContainer, CosmosDiagnosticsContext, CosmosEndToEndOperationLatencyPolicyConfigBuilder, CosmosException}
import com.azure.cosmos.implementation.apachecommons.lang.StringUtils
import com.azure.cosmos.implementation.batch.{BatchRequestResponseConstants, BulkExecutorDiagnosticsTracker, ItemBulkOperation}
import com.azure.cosmos.models._
import com.azure.cosmos.spark.BulkWriter.{BulkOperationFailedException, bulkWriterInputBoundedElastic, bulkWriterRequestsBoundedElastic, bulkWriterResponsesBoundedElastic, getThreadInfo, readManyBoundedElastic}
import com.azure.cosmos.spark.diagnostics.DefaultDiagnostics
import reactor.core.Scannable
import reactor.core.publisher.Mono
import reactor.core.scheduler.Scheduler
import java.util
import java.util.Objects
import java.util.concurrent.{ScheduledFuture, ScheduledThreadPoolExecutor}
import scala.collection.concurrent.TrieMap
import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable`
import scala.collection.mutable
import scala.concurrent.duration.Duration
// scalastyle:on underscore.import
import com.azure.cosmos.implementation.ImplementationBridgeHelpers
import com.azure.cosmos.implementation.guava25.base.Preconditions
import com.azure.cosmos.implementation.spark.{OperationContextAndListenerTuple, OperationListener}
import com.azure.cosmos.models.PartitionKey
import com.azure.cosmos.spark.BulkWriter.{DefaultMaxPendingOperationPerCore, emitFailureHandler}
import com.azure.cosmos.spark.diagnostics.{DiagnosticsContext, DiagnosticsLoader, LoggerHelper, SparkTaskContext}
import com.fasterxml.jackson.databind.node.ObjectNode
import org.apache.spark.TaskContext
import reactor.core.Disposable
import reactor.core.publisher.Sinks
import reactor.core.publisher.Sinks.{EmitFailureHandler, EmitResult}
import reactor.core.scala.publisher.SMono.PimpJFlux
import reactor.core.scala.publisher.{SFlux, SMono}
import reactor.core.scheduler.Schedulers
import java.util.UUID
import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger, AtomicLong, AtomicReference}
import java.util.concurrent.locks.ReentrantLock
import java.util.concurrent.{Semaphore, TimeUnit}
// scalastyle:off underscore.import
import scala.collection.JavaConverters._
// scalastyle:on underscore.import
//scalastyle:off null
//scalastyle:off multiple.string.literals
//scalastyle:off file.size.limit
private class BulkWriter
(
container: CosmosAsyncContainer,
partitionKeyDefinition: PartitionKeyDefinition,
writeConfig: CosmosWriteConfig,
diagnosticsConfig: DiagnosticsConfig,
outputMetricsPublisher: OutputMetricsPublisherTrait,
commitAttempt: Long = 1
) extends AsyncItemWriter {
// Logger bound to the diagnostics configuration of this writer instance.
private val log = LoggerHelper.getLogger(diagnosticsConfig, this.getClass)
// Shared flag handed to the diagnostics tracker to toggle verbose logging for re-enqueued retries.
private val verboseLoggingAfterReEnqueueingRetriesEnabled = new AtomicBoolean(false)
private val cpuCount = SparkUtils.getNumberOfHostCPUCores
// Upper bound of operations a single bulk writer buffers.
// There is one bulk writer per spark task/partition and the default config creates
// one executor per core on the executor host - so multiplying the default by cpuCount
// would be too aggressive.
private val maxPendingOperations =
  writeConfig.bulkMaxPendingOperations.getOrElse(DefaultMaxPendingOperationPerCore)
// Either the configured maximum of concurrent partitions per Spark partition of the input data
// or - when nothing is configured - the total number of physical partitions.
// In both cases the value is doubled to leave space for partition splits during ingestion.
private val maxConcurrentPartitions = writeConfig
  .maxConcurrentCosmosPartitions
  .map(configuredMaxConcurrentPartitions => 2 * configuredMaxConcurrentPartitions)
  .getOrElse(2 * ContainerFeedRangesCache.getFeedRanges(container).block().size)
log.logInfo(
  s"BulkWriter instantiated (Host CPU count: $cpuCount, maxPendingOperations: $maxPendingOperations, " +
    s"maxConcurrentPartitions: $maxConcurrentPartitions ...")
// Artificial operation used to signal to the bufferUntil operator that
// the buffer should be flushed. A timer-based scheduler publishes this
// dummy operation every batchIntervalInMs ms. The operation carries no
// payload (object node and operation context are null), is filtered out
// before the readMany call and is therefore never sent to the backend.
private val readManyFlushOperationSingleton = ReadManyOperation(
new CosmosItemIdentity(
new PartitionKey("ReadManyOperation.FlushSingleton"),
"ReadManyOperation.FlushSingleton"
),
null,
null
)
// ---- Writer state -------------------------------------------------------------------------
// scheduleWrite rejects new work (Preconditions.checkState) once this flag is true.
private val closed = new AtomicBoolean(false)
// NOTE(review): lock/pendingTasksCompleted are not used in the visible part of the file -
// presumably they coordinate waiting for outstanding tasks during flush/close; confirm.
private val lock = new ReentrantLock
private val pendingTasksCompleted = lock.newCondition
// Number of retries that have been scheduled but not yet re-submitted (see scheduleRetry).
private val pendingRetries = new AtomicLong(0)
// Operations for which a retry is currently pending - one concurrent set per pipeline stage.
private val pendingBulkWriteRetries = java.util.concurrent.ConcurrentHashMap.newKeySet[CosmosItemOperation]().asScala
private val pendingReadManyRetries = java.util.concurrent.ConcurrentHashMap.newKeySet[ReadManyOperation]().asScala
// Incremented in scheduleWriteInternal for every (re-)submitted write.
private val activeTasks = new AtomicInteger(0)
// NOTE(review): presumably holds the first failure recorded via captureIfFirstFailure
// (defined outside the visible part of the file) - confirm.
private val errorCaptureFirstException = new AtomicReference[Throwable]()
// Unicast sink feeding the input flux passed to container.executeBulkOperations.
private val bulkInputEmitter: Sinks.Many[CosmosItemOperation] = Sinks.many().unicast().onBackpressureBuffer()
// Operations that have been emitted to the respective sink and have not completed yet.
private val activeBulkWriteOperations =java.util.concurrent.ConcurrentHashMap.newKeySet[CosmosItemOperation]().asScala
private val activeReadManyOperations = java.util.concurrent.ConcurrentHashMap.newKeySet[ReadManyOperation]().asScala
// One permit per buffered operation: acquired in scheduleWrite, released by the response
// subscriber unless the operation is being retried.
private val semaphore = new Semaphore(maxPendingOperations)
private val totalScheduledMetrics = new AtomicLong(0)
private val totalSuccessfulIngestionMetrics = new AtomicLong(0)
// ---- Bulk execution options ---------------------------------------------------------------
// End-to-end timeout applied to every bulk operation.
private val maxOperationTimeout = java.time.Duration.ofSeconds(CosmosConstants.batchOperationEndToEndTimeoutInSeconds)
private val endToEndTimeoutPolicy = new CosmosEndToEndOperationLatencyPolicyConfigBuilder(maxOperationTimeout)
.enable(true)
.build
private val cosmosBulkExecutionOptions = new CosmosBulkExecutionOptions(BulkWriter.bulkProcessingThresholds)
// Internal accessor exposing settings that are not part of the public options surface.
private val cosmosBulkExecutionOptionsImpl = ImplementationBridgeHelpers.CosmosBulkExecutionOptionsHelper
.getCosmosBulkExecutionOptionsAccessor
.getImpl(cosmosBulkExecutionOptions)
// Source of the sequence numbers assigned to newly created OperationContext instances.
private val monotonicOperationCounter = new AtomicLong(0)
cosmosBulkExecutionOptionsImpl.setSchedulerOverride(bulkWriterRequestsBoundedElastic)
cosmosBulkExecutionOptionsImpl.setMaxConcurrentCosmosPartitions(maxConcurrentPartitions)
cosmosBulkExecutionOptionsImpl.setCosmosEndToEndLatencyPolicyConfig(endToEndTimeoutPolicy)
// Diagnostics tracker handed to the bulk executor. Every diagnostics context is forwarded to
// the output metrics publisher (with an item count of 0); when verbose logging is enabled the
// context is additionally logged as a warning.
private class ForwardingMetricTracker(val verboseLoggingEnabled: AtomicBoolean) extends BulkExecutorDiagnosticsTracker {
  override def trackDiagnostics(ctx: CosmosDiagnosticsContext): Unit = {
    // ctx may be null - wrap it so the publisher receives None in that case
    val maybeDiagnosticsContext = Option(ctx)
    outputMetricsPublisher.trackWriteOperation(0, maybeDiagnosticsContext)
    maybeDiagnosticsContext
      .filter(_ => verboseLoggingEnabled.get)
      .foreach(diagnosticsContext =>
        BulkWriter.log.logWarning(s"Track bulk operation after re-enqueued retry: ${diagnosticsContext.toJson}"))
  }

  override def verboseLoggingAfterReEnqueueingRetriesEnabled(): Boolean = verboseLoggingEnabled.get()
}
// Route the bulk executor's diagnostics through the metrics publisher of this writer.
cosmosBulkExecutionOptionsImpl.setDiagnosticsTracker(
  new ForwardingMetricTracker(verboseLoggingAfterReEnqueueingRetriesEnabled)
)
ThroughputControlHelper.populateThroughputControlGroupName(cosmosBulkExecutionOptions, writeConfig.throughputControlConfig)
// Optional micro batch tuning - each setting is only applied when explicitly configured.
writeConfig.maxMicroBatchPayloadSizeInBytes.foreach(customMaxMicroBatchPayloadSizeInBytes =>
  cosmosBulkExecutionOptionsImpl.setMaxMicroBatchPayloadSizeInBytes(customMaxMicroBatchPayloadSizeInBytes))
// the initial micro batch size is never allowed to be smaller than 1
writeConfig.initialMicroBatchSize.foreach(customInitialMicroBatchSize =>
  cosmosBulkExecutionOptions.setInitialMicroBatchSize(Math.max(1, customInitialMicroBatchSize)))
// the max micro batch size is clamped to [1, MAX_OPERATIONS_IN_DIRECT_MODE_BATCH_REQUEST]
writeConfig.maxMicroBatchSize.foreach(customMaxMicroBatchSize => {
  val cappedMaxMicroBatchSize =
    Math.min(customMaxMicroBatchSize, BatchRequestResponseConstants.MAX_OPERATIONS_IN_DIRECT_MODE_BATCH_REQUEST)
  cosmosBulkExecutionOptions.setMaxMicroBatchSize(Math.max(1, cappedMaxMicroBatchSize))
})
// Task-level context used in log messages; initializeOperationContext also registers the
// context/listener tuple on the bulk execution options when running inside a Spark task.
private val operationContext = initializeOperationContext()
// Patch helper - only created for the strategies that need it (ItemPatch, ItemBulkUpdate).
// Note: patchConfigs.get throws when no patch config is present for these strategies.
private val cosmosPatchHelperOpt = writeConfig.itemWriteStrategy match {
case ItemWriteStrategy.ItemPatch | ItemWriteStrategy.ItemBulkUpdate =>
Some(new CosmosPatchHelper(diagnosticsConfig, writeConfig.patchConfigs.get))
case _ => None
}
// Unicast sink feeding the readMany stage - only defined for ItemBulkUpdate.
private val readManyInputEmitterOpt: Option[Sinks.Many[ReadManyOperation]] = {
writeConfig.itemWriteStrategy match {
case ItemWriteStrategy.ItemBulkUpdate => Some(Sinks.many().unicast().onBackpressureBuffer())
case _ => None
}
}
// Max micro batch interval of the bulk options in milliseconds; reused as the delay of the
// readMany flush timer and as the maximum age of a readMany buffer.
private val batchIntervalInMs = cosmosBulkExecutionOptionsImpl
.getMaxMicroBatchInterval
.toMillis
// Single-threaded daemon timer (plus the handle of its scheduled task) which periodically
// invokes onFlushReadMany. Only created for the ItemBulkUpdate strategy - None otherwise.
private[this] val flushExecutorHolder: Option[(ScheduledThreadPoolExecutor, ScheduledFuture[_])] = {
  if (writeConfig.itemWriteStrategy == ItemWriteStrategy.ItemBulkUpdate) {
    val flushThreadFactory = new CosmosDaemonThreadFactory("BulkWriterReadManyFlush" + UUID.randomUUID())
    val flushExecutor = new ScheduledThreadPoolExecutor(1, flushThreadFactory)
    // no delayed executions after shutdown, and cancelled tasks are removed from the queue
    flushExecutor.setExecuteExistingDelayedTasksAfterShutdownPolicy(false)
    flushExecutor.setRemoveOnCancelPolicy(true)
    // first flush after one batch interval, subsequent flushes one interval after the previous one finished
    val flushTask: ScheduledFuture[_] = flushExecutor.scheduleWithFixedDelay(
      () => this.onFlushReadMany(),
      batchIntervalInMs,
      batchIntervalInMs,
      TimeUnit.MILLISECONDS)
    Some((flushExecutor, flushTask))
  } else {
    None
  }
}
// Creates the SparkTaskContext used to correlate log output of this writer.
// When running inside a Spark task the context carries the stage, partition and task attempt
// ids and is also registered - together with the diagnostics listener - on the bulk execution
// options. Outside of a Spark task all ids are -1 and nothing is registered.
private def initializeOperationContext(): SparkTaskContext = {
  val correlationActivityId = DiagnosticsContext(UUID.randomUUID(), "BulkWriter").correlationActivityId
  Option(TaskContext.get) match {
    case Some(sparkTask) =>
      val sparkTaskContext = SparkTaskContext(
        correlationActivityId,
        sparkTask.stageId(),
        sparkTask.partitionId(),
        sparkTask.taskAttemptId(),
        "")
      val diagnosticsListener: OperationListener =
        DiagnosticsLoader.getDiagnosticsProvider(diagnosticsConfig).getLogger(this.getClass)
      cosmosBulkExecutionOptionsImpl.setOperationContextAndListenerTuple(
        new OperationContextAndListenerTuple(sparkTaskContext, diagnosticsListener))
      sparkTaskContext
    case None =>
      // not running inside a Spark task
      SparkTaskContext(correlationActivityId, -1, -1, -1, "")
  }
}
// Timer callback: publishes the flush singleton into the readMany sink so that the buffering
// stage can flush a partially filled batch.
// Throws IllegalStateException when no readMany sink exists (i.e. not bulk update); any
// failure while emitting is logged and swallowed.
private def onFlushReadMany(): Unit = {
  val readManyEmitter = this.readManyInputEmitterOpt.getOrElse(
    throw new IllegalStateException("Callback onFlushReadMany should only be scheduled for bulk update."))
  try {
    val emitOutcome = readManyEmitter.tryEmitNext(readManyFlushOperationSingleton)
    if (emitOutcome == EmitResult.OK) {
      log.logInfo("onFlushReadMany Successfully emitted flush")
    } else {
      log.logError(s"Callback invocation 'onFlush' failed with result: $emitOutcome.")
    }
  } catch {
    case t: Throwable =>
      log.logError("Callback invocation 'onFlush' failed.", t)
  }
}
// Subscription of the readMany stage - only started for the ItemBulkUpdate strategy.
private val readManySubscriptionDisposableOpt: Option[Disposable] =
  if (writeConfig.itemWriteStrategy == ItemWriteStrategy.ItemBulkUpdate) {
    Some(createReadManySubscriptionDisposable())
  } else {
    None
  }
// Builds and subscribes the readMany stage used by the ItemBulkUpdate strategy:
//   1. operations emitted into the readMany sink are timestamped and grouped into batches
//      (by size, by age, or when the flush singleton arrives)
//   2. for each batch the current documents are fetched with a single readMany call
//   3. each operation is patched locally against the fetched document and forwarded to the
//      bulk write stage as a replace (with If-Match etag) or - when no etag is present - a create
// Failures are routed to handleReadManyExceptions per operation.
// Returns the Disposable of the subscription.
// Note: calls readManyInputEmitterOpt.get / cosmosPatchHelperOpt.get and therefore must only
// be invoked when both are defined.
private def createReadManySubscriptionDisposable(): Disposable = {
log.logTrace(s"readManySubscriptionDisposable, Context: ${operationContext.toString} $getThreadInfo")
// We start from using the bulk batch size and interval and concurrency
// If in the future, there is a need to separate the configuration, can re-consider
val bulkBatchSize = writeConfig.maxMicroBatchSize match {
case Some(customMaxMicroBatchSize) => Math.min(
BatchRequestResponseConstants.MAX_OPERATIONS_IN_DIRECT_MODE_BATCH_REQUEST,
Math.max(1, customMaxMicroBatchSize))
case None => BatchRequestResponseConstants.MAX_OPERATIONS_IN_DIRECT_MODE_BATCH_REQUEST
}
val batchConcurrency = cosmosBulkExecutionOptionsImpl.getMaxMicroBatchConcurrency
// State of the buffer currently being filled: timestamp of its first real operation
// (-1 when empty) and the number of real operations it contains.
val firstRecordTimeStamp = new AtomicLong(-1)
val currentMicroBatchSize = new AtomicLong(0)
readManyInputEmitterOpt
.get
.asFlux()
.publishOn(readManyBoundedElastic)
.timestamp
// the predicate returns true for the element that closes the current buffer
.bufferUntil(timestampReadManyOperationTuple => {
val timestamp = timestampReadManyOperationTuple.getT1
val readManyOperation = timestampReadManyOperationTuple.getT2
if (readManyOperation eq readManyFlushOperationSingleton) {
// timer-triggered flush: close the buffer only if it contains real operations
log.logTrace(s"FlushSingletonReceived, Context: ${operationContext.toString}")
val currentMicroBatchSizeSnapshot = currentMicroBatchSize.get
if (currentMicroBatchSizeSnapshot > 0) {
firstRecordTimeStamp.set(-1)
currentMicroBatchSize.set(0)
log.logTrace(s"FlushSingletonReceived - flushing batch, Context: ${operationContext.toString}")
true
} else {
// avoid counting flush operations for the micro batch size calculation
log.logTrace(s"FlushSingletonReceived - empty buffer, nothing to flush, Context: ${operationContext.toString}")
false
}
} else {
// real operation: remember when the buffer got its first record, then close the buffer
// once it is full or its first record is at least batchIntervalInMs old
firstRecordTimeStamp.compareAndSet(-1, timestamp)
val age = timestamp - firstRecordTimeStamp.get
val batchSize = currentMicroBatchSize.incrementAndGet
if (batchSize >= bulkBatchSize || age >= batchIntervalInMs) {
log.logTrace(s"BatchIntervalExpired - flushing batch, Context: ${operationContext.toString}")
firstRecordTimeStamp.set(-1)
currentMicroBatchSize.set(0)
true
} else {
false
}
}
})
.subscribeOn(readManyBoundedElastic)
.asScala
.flatMap(timestampReadManyOperationTuples => {
// drop the artificial flush operation(s) and strip the timestamps
val readManyOperations = timestampReadManyOperationTuples
.filter(candidate => !candidate.getT2.equals(readManyFlushOperationSingleton))
.map(tuple => tuple.getT2)
if (readManyOperations.isEmpty) {
Mono.empty()
} else {
// several operations may target the same item - read each identity only once
val cosmosIdentitySet = readManyOperations.map(option => option.cosmosItemIdentity).toSet
// for each batch, use readMany to read items from cosmosdb
val requestOptions = new CosmosReadManyRequestOptions()
val requestOptionsImpl = ImplementationBridgeHelpers
.CosmosReadManyRequestOptionsHelper
.getCosmosReadManyRequestOptionsAccessor.getImpl(requestOptions)
ThroughputControlHelper.populateThroughputControlGroupName(
requestOptionsImpl,
writeConfig.throughputControlConfig)
ImplementationBridgeHelpers
.CosmosAsyncContainerHelper
.getCosmosAsyncContainerAccessor
.readMany(container, cosmosIdentitySet.toList.asJava, requestOptions, classOf[ObjectNode])
.switchIfEmpty(
// For the Java SDK, empty pages will not be returned (this can happen when none of the items exist yet)
// create a fake empty page response
Mono.just(
ImplementationBridgeHelpers
.FeedResponseHelper
.getFeedResponseAccessor
.createFeedResponse(new util.ArrayList[ObjectNode](), null, null)))
.doOnNext(feedResponse => {
// Tracking the bytes read as part of client-side patch (readMany + replace) as bytes written as well
// to have a way to indicate the additional work happening here
outputMetricsPublisher.trackWriteOperation(
0,
Option.apply(feedResponse.getCosmosDiagnostics) match {
case Some(diagnostics) => Option.apply(diagnostics.getDiagnosticsContext)
case None => None
})
// index the documents that were found by their identity (partition key + id)
val resultMap = new TrieMap[CosmosItemIdentity, ObjectNode]()
for (itemNode: ObjectNode <- feedResponse.getResults.asScala) {
resultMap += (
new CosmosItemIdentity(
PartitionKeyHelper.getPartitionKeyPath(itemNode, partitionKeyDefinition),
itemNode.get(CosmosConstants.Properties.Id).asText()) -> itemNode)
}
// It is possible that multiple cosmosPatchBulkUpdateOperations were targeting the same item
// Currently, we are still creating one bulk item operation for each cosmosPatchBulkUpdateOperations
// for easier exception and semaphore handling
// However, a consequence of this is that the generated bulk item operation can fail due to conflicts or pre-condition failure
// If this turns out to be a problem, we can do more optimization here: merge multiple cosmosPatchBulkUpdateOperations into one bulkItemOperation
// But even the above approach can only work within the same batch but not for the whole spark partition processing.
for (readManyOperation <- readManyOperations) {
val cosmosPatchBulkUpdateOperations =
cosmosPatchHelperOpt
.get
.createCosmosPatchBulkUpdateOperations(readManyOperation.objectNode)
// resultMap.get yields None when the item was not returned by readMany
val rootNode =
cosmosPatchHelperOpt
.get
.patchBulkUpdateItem(resultMap.get(readManyOperation.cosmosItemIdentity), cosmosPatchBulkUpdateOperations)
// create bulk item operation
// patched node has an etag -> conditional replace; no etag -> create
// the original (unpatched) object node is kept in the OperationContext so a retry
// can start over from the readMany stage
val etagOpt = Option.apply(rootNode.get(CosmosConstants.Properties.ETag))
val bulkItemOperation = etagOpt match {
case Some(etag) =>
CosmosBulkOperations.getReplaceItemOperation(
readManyOperation.cosmosItemIdentity.getId,
rootNode,
readManyOperation.cosmosItemIdentity.getPartitionKey,
new CosmosBulkItemRequestOptions().setIfMatchETag(etag.asText()),
new OperationContext(
readManyOperation.operationContext.itemId,
readManyOperation.operationContext.partitionKeyValue,
Some(etag.asText()),
readManyOperation.operationContext.attemptNumber,
monotonicOperationCounter.incrementAndGet(),
Some(readManyOperation.objectNode)
))
case None => CosmosBulkOperations.getCreateItemOperation(
rootNode,
readManyOperation.cosmosItemIdentity.getPartitionKey,
new OperationContext(
readManyOperation.operationContext.itemId,
readManyOperation.operationContext.partitionKeyValue,
eTagInput = None,
readManyOperation.operationContext.attemptNumber,
monotonicOperationCounter.incrementAndGet(),
Some(readManyOperation.objectNode)
))
}
this.emitBulkInput(bulkItemOperation)
}
})
.onErrorResume(throwable => {
// a failure of the batch is reported for every operation contained in it
for (readManyOperation <- readManyOperations) {
handleReadManyExceptions(throwable, readManyOperation)
}
Mono.empty()
})
.doFinally(_ => {
for (readManyOperation <- readManyOperations) {
val activeReadManyOperationFound = activeReadManyOperations.remove(readManyOperation)
// for ItemBulkUpdate strategy, each active task includes two stages: ReadMany + BulkWrite
// so we are not going to make task complete here
if (!activeReadManyOperationFound) {
// can't find the read-many operation in list of active operations!
logInfoOrWarning(s"Cannot find active read-many for '"
+ s"${readManyOperation.cosmosItemIdentity.getPartitionKey}/"
+ s"${readManyOperation.cosmosItemIdentity.getId}'. This can happen when "
+ s"retries get re-enqueued.")
if (pendingReadManyRetries.remove(readManyOperation)) {
pendingRetries.decrementAndGet()
}
}
}
})
.`then`(Mono.empty())
}
}, batchConcurrency)
.subscribe()
}
// Handles a failure of the readMany stage for a single operation.
//
// - CosmosException, retryable: the active task is marked complete (the retry is re-queued at
//   the beginning of the flow and counted again there) and a retry is scheduled.
// - CosmosException, not retryable / retries exhausted: a BulkOperationFailedException is
//   captured as first failure, outstanding work is cancelled and the task is marked complete.
// - any other Throwable: captured as first failure, work is cancelled, task is marked complete.
//
// throwable         - the error signalled by the readMany call
// ReadManyOperation - the operation affected by the failure. NOTE: inside this method the
//                     parameter name shadows the ReadManyOperation companion; the name is
//                     kept unchanged to stay compatible with existing callers.
private def handleReadManyExceptions(throwable: Throwable, ReadManyOperation: ReadManyOperation): Unit = {
  throwable match {
    case e: CosmosException =>
      // diagnostics (and their context) may be null - forward None in that case
      outputMetricsPublisher.trackWriteOperation(
        0,
        Option(e.getDiagnostics).flatMap(diagnostics => Option(diagnostics.getDiagnosticsContext)))
      val failedOperation = ReadManyOperation
      val requestOperationContext = failedOperation.operationContext
      if (shouldRetry(e.getStatusCode, e.getSubStatusCode, requestOperationContext)) {
        log.logInfo(s"for itemId=[${requestOperationContext.itemId}], partitionKeyValue=[${requestOperationContext.partitionKeyValue}], " +
          s"encountered status code '${e.getStatusCode}:${e.getSubStatusCode}' in read many, will retry! " +
          s"attemptNumber=${requestOperationContext.attemptNumber}, exceptionMessage=${e.getMessage}, " +
          s"Context: {${operationContext.toString}} $getThreadInfo")
        // the task will be re-queued at the beginning of the flow, so mark it complete here
        markTaskCompletion()
        this.scheduleRetry(
          trackPendingRetryAction = () => pendingReadManyRetries.add(failedOperation),
          clearPendingRetryAction = () => pendingReadManyRetries.remove(failedOperation),
          failedOperation.cosmosItemIdentity.getPartitionKey,
          failedOperation.objectNode,
          requestOperationContext,
          e.getStatusCode)
      } else {
        // Non-retryable exception or has exceeded the max retry count
        log.logError(s"for itemId=[${requestOperationContext.itemId}], partitionKeyValue=[${requestOperationContext.partitionKeyValue}], " +
          s"encountered status code '${e.getStatusCode}:${e.getSubStatusCode}', all retries exhausted! " +
          s"attemptNumber=${requestOperationContext.attemptNumber}, exceptionMessage=${e.getMessage}, " +
          s"Context: {${operationContext.toString}} $getThreadInfo")
        val message = s"All retries exhausted for readMany - " +
          s"statusCode=[${e.getStatusCode}:${e.getSubStatusCode}] " +
          s"itemId=[${requestOperationContext.itemId}], partitionKeyValue=[${requestOperationContext.partitionKeyValue}]"
        val exceptionToBeThrown = new BulkOperationFailedException(e.getStatusCode, e.getSubStatusCode, message, e)
        captureIfFirstFailure(exceptionToBeThrown)
        cancelWork()
        markTaskCompletion()
      }
    case _ => // handle non cosmos exceptions
      log.logError(s"Unexpected failure code path in Bulk ingestion readMany stage, " +
        s"Context: ${operationContext.toString} $getThreadInfo", throwable)
      captureIfFirstFailure(throwable)
      cancelWork()
      markTaskCompletion()
  }
}
// Re-submits a failed operation with an incremented attempt number.
//
// trackPendingRetryAction - registers the operation as pending retry; pendingRetries is
//                           incremented when it returns true
// clearPendingRetryAction - unregisters the operation after re-submission; pendingRetries is
//                           decremented when it returns true
// partitionKey/objectNode - the payload to be written again
// operationContext        - context of the failed attempt (copied with attemptNumber + 1)
// statusCode              - status code of the failure; timeouts are retried after a random
//                           delay, everything else immediately
// The re-submission always runs asynchronously on the bounded elastic scheduler.
private def scheduleRetry(
  trackPendingRetryAction: () => Boolean,
  clearPendingRetryAction: () => Boolean,
  partitionKey: PartitionKey,
  objectNode: ObjectNode,
  operationContext: OperationContext,
  statusCode: Int): Unit = {
  val newlyTracked = trackPendingRetryAction()
  if (newlyTracked) {
    this.pendingRetries.incrementAndGet()
  }

  // deferred so that the submission happens on a background thread
  // and never blocks the thread reporting the failure
  val resubmission = SMono.defer(() => {
    val nextAttemptContext = new OperationContext(
      operationContext.itemId,
      operationContext.partitionKeyValue,
      operationContext.eTag,
      operationContext.attemptNumber + 1,
      operationContext.sequenceNumber)
    scheduleWriteInternal(partitionKey, objectNode, nextAttemptContext)
    val wasTracked = clearPendingRetryAction()
    if (wasTracked) {
      this.pendingRetries.decrementAndGet()
    }
    SMono.empty
  })

  val pacedResubmission =
    if (Exceptions.isTimeout(statusCode)) {
      // back off for a random duration within [min, max) before retrying a request timeout
      val jitteredDelay = Duration(
        BulkWriter.minDelayOn408RequestTimeoutInMs +
          scala.util.Random.nextInt(
            BulkWriter.maxDelayOn408RequestTimeoutInMs - BulkWriter.minDelayOn408RequestTimeoutInMs),
        TimeUnit.MILLISECONDS)
      resubmission.delaySubscription(jitteredDelay, Schedulers.boundedElastic())
    } else {
      resubmission
    }

  pacedResubmission
    .subscribeOn(Schedulers.boundedElastic())
    .subscribe()
}
// Subscription of the bulk write stage: feeds the operations emitted into bulkInputEmitter
// to container.executeBulkOperations and processes every per-item response.
// For each response the subscriber
//   - removes the operation from the active / pending-retry bookkeeping
//   - dispatches to success tracking, handleNonSuccessfulStatusCode or failure capture
//   - releases one semaphore permit unless the operation is being retried
//   - marks the task complete unless the operation was not found among the active operations
// A terminal error of the response flux is captured as first failure and cancels all work.
private val subscriptionDisposable: Disposable = {
log.logTrace(s"subscriptionDisposable, Context: ${operationContext.toString} $getThreadInfo")
val inputFlux = bulkInputEmitter
.asFlux()
.onBackpressureBuffer()
.publishOn(bulkWriterInputBoundedElastic)
.doOnError(t => {
log.logError(s"Input publishing flux failed, Context: ${operationContext.toString} $getThreadInfo", t)
})
val bulkOperationResponseFlux: SFlux[CosmosBulkOperationResponse[Object]] =
container
.executeBulkOperations[Object](
inputFlux,
cosmosBulkExecutionOptions)
.onBackpressureBuffer()
.publishOn(bulkWriterResponsesBoundedElastic)
.doOnError(t => {
log.logError(s"Bulk execution flux failed, Context: ${operationContext.toString} $getThreadInfo", t)
})
.asScala
bulkOperationResponseFlux.subscribe(
resp => {
// set by handleNonSuccessfulStatusCode when a retry has been scheduled -
// in that case the semaphore permit stays with the operation
val isGettingRetried = new AtomicBoolean(false)
val shouldSkipTaskCompletion = new AtomicBoolean(false)
try {
val itemOperation = resp.getOperation
val itemOperationFound = activeBulkWriteOperations.remove(itemOperation)
val pendingRetriesFound = pendingBulkWriteRetries.remove(itemOperation)
if (pendingRetriesFound) {
pendingRetries.decrementAndGet()
}
if (!itemOperationFound) {
// can't find the item operation in list of active operations!
logInfoOrWarning(s"Cannot find active operation for '${itemOperation.getOperationType} " +
s"${itemOperation.getPartitionKeyValue}/${itemOperation.getId}'. This can happen when " +
s"retries get re-enqueued.")
shouldSkipTaskCompletion.set(true)
}
// responses for operations that are neither active nor pending retry are not processed any further
if (pendingRetriesFound || itemOperationFound) {
val context = itemOperation.getContext[OperationContext]
val itemResponse = resp.getResponse
if (resp.getException != null) {
Option(resp.getException) match {
case Some(cosmosException: CosmosException) =>
handleNonSuccessfulStatusCode(
context, itemOperation, itemResponse, isGettingRetried, Some(cosmosException))
case _ =>
// any non-Cosmos exception is fatal for this writer
log.logWarning(
s"unexpected failure: itemId=[${context.itemId}], partitionKeyValue=[" +
s"${context.partitionKeyValue}], encountered , attemptNumber=${context.attemptNumber}, " +
s"exceptionMessage=${resp.getException.getMessage}, " +
s"Context: ${operationContext.toString} $getThreadInfo", resp.getException)
captureIfFirstFailure(resp.getException)
cancelWork()
}
} else if (Option(itemResponse).isEmpty || !itemResponse.isSuccessStatusCode) {
// no exception, but either no response at all or a non-success status code
handleNonSuccessfulStatusCode(context, itemOperation, itemResponse, isGettingRetried, None)
} else {
// no error case
outputMetricsPublisher.trackWriteOperation(1, None)
totalSuccessfulIngestionMetrics.getAndIncrement()
}
}
}
finally {
// release the permit acquired in scheduleWrite - also when processing the response threw
if (!isGettingRetried.get) {
semaphore.release()
}
}
if (!shouldSkipTaskCompletion.get) {
markTaskCompletion()
}
},
errorConsumer = Option.apply(
ex => {
log.logError(s"Unexpected failure code path in Bulk ingestion, " +
s"Context: ${operationContext.toString} $getThreadInfo", ex)
// if there is any failure this closes the bulk.
// at this point bulk api doesn't allow any retrying
// we don't know the list of failed item-operations
// the only way to retry is to keep a dictionary of pending operations outside
// so we know which operations failed and which ones can be retried.
// this is currently a kill scenario.
captureIfFirstFailure(ex)
cancelWork()
markTaskCompletion()
}
)
)
}
// Entry point for new write operations.
// Fails when the writer is already closed (Preconditions.checkState) and - via
// throwIfCapturedExceptionExists - presumably when a failure has been captured before.
// Blocks until a semaphore permit is available; while waiting, every
// activeTasksSemaphoreTimeout minutes it
//   - captures an IllegalStateException when one of the subscriptions has been disposed
//   - calls throwIfProgressStaled with the snapshots taken in the previous interval
//     (NOTE(review): the helper is defined outside the visible part of the file - presumably
//     it throws when the snapshots have not changed for too many intervals; confirm)
// Once the permit is acquired the operation is handed to scheduleWriteInternal.
//
// partitionKeyValue - partition key of the document
// objectNode        - the document to be written
override def scheduleWrite(partitionKeyValue: PartitionKey, objectNode: ObjectNode): Unit = {
Preconditions.checkState(!closed.get())
throwIfCapturedExceptionExists()
// timeout (in minutes) of a single attempt to acquire the semaphore
val activeTasksSemaphoreTimeout = 10
// first attempt (attemptNumber 1) with a fresh sequence number
// note: shadows the writer-level operationContext field within this method
val operationContext = new OperationContext(
getId(objectNode),
partitionKeyValue,
getETag(objectNode),
1,
monotonicOperationCounter.incrementAndGet())
val numberOfIntervalsWithIdenticalActiveOperationSnapshots = new AtomicLong(0)
// Don't clone the activeOperations for the first iteration
// to reduce perf impact before the Semaphore has been acquired
// this means if the semaphore can't be acquired within 10 minutes
// the first attempt will always assume it wasn't stale - so effectively we
// allow staleness for ten additional minutes - which is perfectly fine
var activeBulkWriteOperationsSnapshot = mutable.Set.empty[CosmosItemOperation]
var pendingBulkWriteRetriesSnapshot = mutable.Set.empty[CosmosItemOperation]
var activeReadManyOperationsSnapshot = mutable.Set.empty[ReadManyOperation]
var pendingReadManyRetriesSnapshot = mutable.Set.empty[ReadManyOperation]
log.logTrace(
s"Before TryAcquire ${totalScheduledMetrics.get}, Context: ${operationContext.toString} $getThreadInfo")
while (!semaphore.tryAcquire(activeTasksSemaphoreTimeout, TimeUnit.MINUTES)) {
log.logDebug(s"Not able to acquire semaphore, Context: ${operationContext.toString} $getThreadInfo")
if (subscriptionDisposable.isDisposed ||
(readManySubscriptionDisposableOpt.isDefined && readManySubscriptionDisposableOpt.get.isDisposed)) {
captureIfFirstFailure(
new IllegalStateException("Can't accept any new work - BulkWriter has been disposed already"))
}
throwIfProgressStaled(
"Semaphore acquisition",
activeBulkWriteOperationsSnapshot,
pendingBulkWriteRetriesSnapshot,
activeReadManyOperationsSnapshot,
pendingReadManyRetriesSnapshot,
numberOfIntervalsWithIdenticalActiveOperationSnapshots,
allowRetryOnNewBulkWriterInstance = false)
// refresh the snapshots which the next iteration hands to throwIfProgressStaled
activeBulkWriteOperationsSnapshot = activeBulkWriteOperations.clone()
pendingBulkWriteRetriesSnapshot = pendingBulkWriteRetries.clone()
activeReadManyOperationsSnapshot = activeReadManyOperations.clone()
pendingReadManyRetriesSnapshot = pendingReadManyRetries.clone()
}
val cnt = totalScheduledMetrics.getAndIncrement()
log.logTrace(s"total scheduled $cnt, Context: ${operationContext.toString} $getThreadInfo")
scheduleWriteInternal(partitionKeyValue, objectNode, operationContext)
}
// Counts the operation as active task and routes it to the first stage of its pipeline.
// Used for initial submissions as well as for retries (attemptNumber > 1), so a retry always
// restarts at the beginning of the flow:
//   ItemBulkUpdate   -> readMany stage, followed by the bulk write stage
//   other strategies -> bulk write stage only
private def scheduleWriteInternal(partitionKeyValue: PartitionKey,
                                  objectNode: ObjectNode,
                                  operationContext: OperationContext): Unit = {
  activeTasks.incrementAndGet()

  val isRetry = operationContext.attemptNumber > 1
  if (isRetry) {
    logInfoOrWarning(s"bulk scheduleWrite attemptCnt: ${operationContext.attemptNumber}, " +
      s"Context: ${operationContext.toString} $getThreadInfo")
  }

  if (writeConfig.itemWriteStrategy == ItemWriteStrategy.ItemBulkUpdate) {
    scheduleReadManyInternal(partitionKeyValue, objectNode, operationContext)
  } else {
    scheduleBulkWriteInternal(partitionKeyValue, objectNode, operationContext)
  }
}
// Wraps the document into a ReadManyOperation, registers it as active and
// emits it into the readMany sink (which must be defined, i.e. ItemBulkUpdate only).
private def scheduleReadManyInternal(partitionKeyValue: PartitionKey,
                                     objectNode: ObjectNode,
                                     operationContext: OperationContext): Unit = {
  val itemIdentity = new CosmosItemIdentity(partitionKeyValue, operationContext.itemId)
  val readManyOperation = ReadManyOperation(itemIdentity, objectNode, operationContext)
  // register before emitting so the operation is tracked as soon as the pipeline can see it
  activeReadManyOperations.add(readManyOperation)
  val readManyEmitter = readManyInputEmitterOpt.get
  // For FAIL_NON_SERIALIZED, will keep retry, while for other errors, use the default behavior
  readManyEmitter.emitNext(readManyOperation, emitFailureHandler)
}
// Translates the document into the bulk item operation matching the configured write
// strategy and emits it into the bulk write stage.
//   ItemOverwrite              -> upsert
//   ItemAppend                 -> create
//   ItemOverwriteIfNotModified -> replace with If-Match when an etag is known, create otherwise
//   ItemDelete                 -> delete
//   ItemDeleteIfNotModified    -> delete, with If-Match when an etag is known
//   ItemPatch                  -> patch
// Any other strategy results in a RuntimeException.
private def scheduleBulkWriteInternal(partitionKeyValue: PartitionKey,
                                      objectNode: ObjectNode,
                                      operationContext: OperationContext): Unit = {
  val bulkItemOperation = writeConfig.itemWriteStrategy match {
    case ItemWriteStrategy.ItemOverwrite =>
      CosmosBulkOperations.getUpsertItemOperation(objectNode, partitionKeyValue, operationContext)

    case ItemWriteStrategy.ItemAppend =>
      CosmosBulkOperations.getCreateItemOperation(objectNode, partitionKeyValue, operationContext)

    case ItemWriteStrategy.ItemOverwriteIfNotModified =>
      operationContext.eTag
        .map(eTag =>
          CosmosBulkOperations.getReplaceItemOperation(
            operationContext.itemId,
            objectNode,
            partitionKeyValue,
            new CosmosBulkItemRequestOptions().setIfMatchETag(eTag),
            operationContext))
        .getOrElse(
          CosmosBulkOperations.getCreateItemOperation(objectNode, partitionKeyValue, operationContext))

    case ItemWriteStrategy.ItemDelete =>
      CosmosBulkOperations.getDeleteItemOperation(operationContext.itemId, partitionKeyValue, operationContext)

    case ItemWriteStrategy.ItemDeleteIfNotModified =>
      val deleteRequestOptions = operationContext.eTag.fold(new CosmosBulkItemRequestOptions())(
        eTag => new CosmosBulkItemRequestOptions().setIfMatchETag(eTag))
      CosmosBulkOperations.getDeleteItemOperation(
        operationContext.itemId,
        partitionKeyValue,
        deleteRequestOptions,
        operationContext)

    case ItemWriteStrategy.ItemPatch =>
      getPatchItemOperation(operationContext.itemId, partitionKeyValue, partitionKeyDefinition, objectNode, operationContext)

    case _ =>
      throw new RuntimeException(s"${writeConfig.itemWriteStrategy} not supported")
  }

  this.emitBulkInput(bulkItemOperation)
}
// Registers the operation as active bulk write operation and emits it into the sink feeding
// the bulk execution. The registration happens before the emit, so the operation is already
// tracked when its response reaches the response subscriber.
private[this] def emitBulkInput(bulkItemOperation: CosmosItemOperation): Unit = {
activeBulkWriteOperations.add(bulkItemOperation)
// For FAIL_NON_SERIALIZED, will keep retry, while for other errors, use the default behavior
bulkInputEmitter.emitNext(bulkItemOperation, emitFailureHandler)
}
// Creates the bulk patch operation for a document.
// Requires patch configs and the patch helper to be defined (asserted). The patch operations
// are derived from the document by the patch helper; a configured, non-empty filter is applied
// as filter predicate of the request options.
//
// itemId                 - id of the item to patch
// partitionKey           - partition key of the item
// partitionKeyDefinition - partition key definition handed to the patch helper
// objectNode             - the document the patch operations are created from
// context                - operation context attached to the resulting bulk operation
private[this] def getPatchItemOperation(itemId: String,
                                        partitionKey: PartitionKey,
                                        partitionKeyDefinition: PartitionKeyDefinition,
                                        objectNode: ObjectNode,
                                        context: OperationContext): CosmosItemOperation = {
  assert(writeConfig.patchConfigs.isDefined)
  assert(cosmosPatchHelperOpt.isDefined)

  val effectivePatchConfigs = writeConfig.patchConfigs.get
  val patchOperations = cosmosPatchHelperOpt
    .get
    .createCosmosPatchOperations(itemId, partitionKeyDefinition, objectNode)

  val patchRequestOptions = new CosmosBulkPatchItemRequestOptions()
  // only a defined and non-empty filter becomes the filter predicate
  effectivePatchConfigs.filter
    .filterNot(predicate => StringUtils.isEmpty(predicate))
    .foreach(predicate => patchRequestOptions.setFilterPredicate(predicate))

  CosmosBulkOperations.getPatchItemOperation(itemId, partitionKey, patchOperations, patchRequestOptions, context)
}
//scalastyle:off method.length
//scalastyle:off cyclomatic.complexity
/**
 * Handles a bulk item operation that finished with a non-successful status code or an exception.
 *
 * The effective status/sub-status code is taken from `itemResponse` when it is non-null,
 * otherwise from `responseException`, otherwise it falls back to Timeout/0. Based on these codes
 * the outcome is one of:
 *  - ignored (counted as successful ingestion) when `shouldIgnore` returns true,
 *  - re-scheduled via `scheduleRetry` when `shouldRetry` returns true,
 *  - captured as the terminal failure, after which the remaining work is cancelled.
 *
 * @param context           context of the failed operation (item id, partition key value, attempt number)
 * @param itemOperation     the bulk item operation that failed
 * @param itemResponse      the item response - may be null
 * @param isGettingRetried  set to true when a retry has been scheduled for the operation
 * @param responseException the exception reported for the operation, if any
 */
private[this] def handleNonSuccessfulStatusCode
(
  context: OperationContext,
  itemOperation: CosmosItemOperation,
  itemResponse: CosmosBulkItemResponse,
  isGettingRetried: AtomicBoolean,
  responseException: Option[CosmosException]
) : Unit = {

  val exceptionMessage = responseException.map(_.getMessage).getOrElse("")

  // the response (when present) wins over the exception; without either the failure is treated as timeout
  val responseOpt = Option(itemResponse)
  val effectiveStatusCode: Int = responseOpt
    .map(_.getStatusCode)
    .orElse(responseException.map(_.getStatusCode))
    .getOrElse(CosmosConstants.StatusCodes.Timeout)
  val effectiveSubStatusCode: Int = responseOpt
    .map(_.getSubStatusCode)
    .orElse(responseException.map(_.getSubStatusCode))
    .getOrElse(0)

  log.logDebug(s"encountered item operation response with status code " +
    s"$effectiveStatusCode:$effectiveSubStatusCode, " +
    s"Context: ${operationContext.toString} $getThreadInfo")

  if (shouldIgnore(effectiveStatusCode, effectiveSubStatusCode)) {
    log.logDebug(s"for itemId=[${context.itemId}], partitionKeyValue=[${context.partitionKeyValue}], " +
      s"ignored encountered status code '$effectiveStatusCode:$effectiveSubStatusCode', " +
      s"Context: ${operationContext.toString}")
    totalSuccessfulIngestionMetrics.getAndIncrement()
    // work done
  } else if (shouldRetry(effectiveStatusCode, effectiveSubStatusCode, context)) {
    // requeue
    log.logWarning(s"for itemId=[${context.itemId}], partitionKeyValue=[${context.partitionKeyValue}], " +
      s"encountered status code '$effectiveStatusCode:$effectiveSubStatusCode', will retry! " +
      s"attemptNumber=${context.attemptNumber}, exceptionMessage=$exceptionMessage, " +
      s"Context: {${operationContext.toString}} $getThreadInfo")

    // If the write strategy is patchBulkUpdate, the item of the bulk operation is not the original objectNode.
    // It is computed by reading the item from Cosmos DB and patching it locally.
    // During retry it is important to use the original objectNode (for example a precondition failure
    // requires going through the readMany step again), which is kept in OperationContext.sourceItem.
    val sourceItem = itemOperation match {
      // type arguments are erased at runtime - only the class itself can be checked here
      case _: ItemBulkOperation[_, _] =>
        context.sourceItem.getOrElse(itemOperation.getItem.asInstanceOf[ObjectNode])
      case _ => itemOperation.getItem.asInstanceOf[ObjectNode]
    }

    this.scheduleRetry(
      trackPendingRetryAction = () => pendingBulkWriteRetries.add(itemOperation),
      clearPendingRetryAction = () => pendingBulkWriteRetries.remove(itemOperation),
      itemOperation.getPartitionKeyValue,
      sourceItem,
      context,
      effectiveStatusCode)
    isGettingRetried.set(true)
  } else {
    // fixed: the context section of this message was missing its closing brace
    log.logError(s"for itemId=[${context.itemId}], partitionKeyValue=[${context.partitionKeyValue}], " +
      s"encountered status code '$effectiveStatusCode:$effectiveSubStatusCode', all retries exhausted! " +
      s"attemptNumber=${context.attemptNumber}, exceptionMessage=$exceptionMessage, " +
      s"Context: {${operationContext.toString}} $getThreadInfo")

    val message = s"All retries exhausted for '${itemOperation.getOperationType}' bulk operation - " +
      s"statusCode=[$effectiveStatusCode:$effectiveSubStatusCode] " +
      s"itemId=[${context.itemId}], partitionKeyValue=[${context.partitionKeyValue}]"

    // the original exception (when there is one) is kept as the cause; otherwise the cause is null
    val exceptionToBeThrown = new BulkOperationFailedException(
      effectiveStatusCode,
      effectiveSubStatusCode,
      message,
      responseException.orNull)

    captureIfFirstFailure(exceptionToBeThrown)
    cancelWork()
  }
}
//scalastyle:on method.length
//scalastyle:on cyclomatic.complexity
/**
 * Rethrows the first captured failure, if one has been recorded in `errorCaptureFirstException`.
 * Does nothing when no failure has been captured so far.
 */
private[this] def throwIfCapturedExceptionExists(): Unit = {
  Option(errorCaptureFirstException.get()).foreach { capturedFailure =>
    log.logError(s"throw captured error ${capturedFailure.getMessage}, " +
      s"Context: ${operationContext.toString} $getThreadInfo")
    throw capturedFailure
  }
}
/**
 * Builds a comma separated description of the active operations for log and error messages.
 *
 * At most `BulkWriter.maxItemOperationsToShowInErrorMessage` bulk write operations are listed;
 * readMany operations only fill whatever is left of that budget after subtracting the total
 * number of active bulk write operations.
 *
 * @param activeOperationsSnapshot         snapshot of the active bulk write operations
 * @param activeReadManyOperationsSnapshot snapshot of the active readMany operations
 * @return entries of the form `operation->partitionKey/itemId/eTag(attemptNumber)` joined by ", "
 */
private[this] def getActiveOperationsLog(
  activeOperationsSnapshot: mutable.Set[CosmosItemOperation],
  activeReadManyOperationsSnapshot: mutable.Set[ReadManyOperation]): String = {

  val maxEntries = BulkWriter.maxItemOperationsToShowInErrorMessage

  def describe(ctx: OperationContext): String =
    s"${ctx.partitionKeyValue}/${ctx.itemId}/${ctx.eTag}(${ctx.attemptNumber})"

  // iterators are used (instead of mapping the sets) so that identical descriptions are not collapsed
  val bulkWriteEntries = activeOperationsSnapshot
    .iterator
    .take(maxEntries)
    .map(itemOperation =>
      s"${itemOperation.getOperationType}->${describe(itemOperation.getContext[OperationContext])}")

  // add readMany snapshot logs - a non-positive remaining budget yields no entries
  val readManyEntries = activeReadManyOperationsSnapshot
    .iterator
    .take(maxEntries - activeOperationsSnapshot.size)
    .map(readManyOperation => s"ReadMany->${describe(readManyOperation.operationContext)}")

  (bulkWriteEntries ++ readManyEntries).mkString(", ")
}
/**
 * Checks whether two sets of bulk write operations describe the same work.
 *
 * The sets are considered the same when they have the same size and every operation of
 * `snapshot` has a counterpart in `current` with equal operation type, partition key value,
 * id and item.
 *
 * @param snapshot the operations captured earlier
 * @param current  the operations to compare against
 * @return true when both sets describe the same operations
 */
private[this] def sameBulkWriteOperations
(
  snapshot: mutable.Set[CosmosItemOperation],
  current: mutable.Set[CosmosItemOperation]
): Boolean = {

  def describesSameWork(expected: CosmosItemOperation, candidate: CosmosItemOperation): Boolean =
    expected.getOperationType == candidate.getOperationType &&
      expected.getPartitionKeyValue == candidate.getPartitionKeyValue &&
      Objects.equals(expected.getId, candidate.getId) &&
      Objects.equals(expected.getItem[ObjectNode], candidate.getItem[ObjectNode])

  snapshot.size == current.size &&
    snapshot.forall(expected => current.exists(candidate => describesSameWork(expected, candidate)))
}
/**
 * Checks whether two sets of readMany operations describe the same work.
 *
 * The sets are considered the same when they have the same size and every operation of
 * `snapshot` has a counterpart in `current` with equal item identity and object node.
 *
 * @param snapshot the operations captured earlier
 * @param current  the operations to compare against
 * @return true when both sets describe the same operations
 */
private[this] def sameReadManyOperations
(
  snapshot: mutable.Set[ReadManyOperation],
  current: mutable.Set[ReadManyOperation]
): Boolean = {

  def describesSameWork(expected: ReadManyOperation, candidate: ReadManyOperation): Boolean =
    expected.cosmosItemIdentity == candidate.cosmosItemIdentity &&
      Objects.equals(expected.objectNode, candidate.objectNode)

  snapshot.size == current.size &&
    snapshot.forall(expected => current.exists(candidate => describesSameWork(expected, candidate)))
}
/**
 * Tracks for how many wait intervals the set of outstanding operations has not changed and fails
 * the writer once the allowed time without progress is exceeded.
 *
 * The passed snapshots are compared with the currently active/pending operations. When they are
 * identical the interval counter is incremented, otherwise it is reset to 0. When the resulting
 * time without progress reaches the configured limit a `BulkWriterNoProgressException` is
 * captured and the remaining work is cancelled. Finally any captured failure is rethrown.
 *
 * @param operationName                     name of the calling operation - only used in messages
 * @param activeOperationsSnapshot          snapshot of the active bulk write operations
 * @param pendingRetriesSnapshot            snapshot of the bulk write operations pending retry
 * @param activeReadManyOperationsSnapshot  snapshot of the active readMany operations
 * @param pendingReadManyOperationsSnapshot snapshot of the readMany operations pending retry
 * @param numberOfIntervalsWithIdenticalActiveOperationSnapshots
 *                                          counter of consecutive intervals without progress - updated here
 * @param allowRetryOnNewBulkWriterInstance whether the remaining operations are attached to the
 *                                          exception and the shorter first-attempt limit applies
 */
private[this] def throwIfProgressStaled
(
  operationName: String,
  activeOperationsSnapshot: mutable.Set[CosmosItemOperation],
  pendingRetriesSnapshot: mutable.Set[CosmosItemOperation],
  activeReadManyOperationsSnapshot: mutable.Set[ReadManyOperation],
  pendingReadManyOperationsSnapshot: mutable.Set[ReadManyOperation],
  numberOfIntervalsWithIdenticalActiveOperationSnapshots: AtomicLong,
  allowRetryOnNewBulkWriterInstance: Boolean
): Unit = {

  val operationsLog = getActiveOperationsLog(activeOperationsSnapshot, activeReadManyOperationsSnapshot)

  // readMany operations are only compared when the bulk write operations are unchanged
  val nothingChanged =
    sameBulkWriteOperations(
      pendingRetriesSnapshot ++ activeOperationsSnapshot,
      activeBulkWriteOperations ++ pendingBulkWriteRetries) &&
      sameReadManyOperations(
        pendingReadManyOperationsSnapshot ++ activeReadManyOperationsSnapshot,
        activeReadManyOperations ++ pendingReadManyRetries)

  if (!nothingChanged) {
    numberOfIntervalsWithIdenticalActiveOperationSnapshots.set(0)
    logInfoOrWarning(
      s"$operationName is waiting for active bulkWrite operations: $operationsLog " +
        s"Context: ${operationContext.toString} $getThreadInfo"
    )
  } else {
    numberOfIntervalsWithIdenticalActiveOperationSnapshots.incrementAndGet()
    log.logWarning(
      s"$operationName has been waiting $numberOfIntervalsWithIdenticalActiveOperationSnapshots " +
        s"times for identical set of operations: $operationsLog " +
        s"Context: ${operationContext.toString} $getThreadInfo"
    )
  }

  val secondsWithoutProgress = numberOfIntervalsWithIdenticalActiveOperationSnapshots.get *
    writeConfig.flushCloseIntervalInSeconds

  val retryLimitExceeded = secondsWithoutProgress >= writeConfig.maxRetryNoProgressIntervalInSeconds

  // the shorter limit only applies to the first commit attempt without any outstanding readMany work
  def firstAttemptLimitExceeded: Boolean =
    commitAttempt == 1 &&
      allowRetryOnNewBulkWriterInstance &&
      this.activeReadManyOperations.isEmpty &&
      this.pendingReadManyRetries.isEmpty &&
      secondsWithoutProgress >= writeConfig.maxNoProgressIntervalInSeconds

  if (retryLimitExceeded || firstAttemptLimitExceeded) {
    val noProgressFailure =
      if (activeReadManyOperationsSnapshot.nonEmpty) {
        new BulkWriterNoProgressException(
          s"Stale bulk ingestion as well as readMany operations identified in $operationName - the following active operations have not been " +
            s"completed (first ${BulkWriter.maxItemOperationsToShowInErrorMessage} shown) or progressed after " +
            s"${writeConfig.maxRetryNoProgressIntervalInSeconds} : $operationsLog",
          commitAttempt,
          None)
      } else {
        // remaining operations are handed over ordered by their sequence number
        val retriableRemainingOperations =
          if (!allowRetryOnNewBulkWriterInstance) {
            None
          } else {
            Some(
              (pendingRetriesSnapshot ++ activeOperationsSnapshot)
                .toList
                .sortBy(op => op.getContext[OperationContext].sequenceNumber)
            )
          }

        new BulkWriterNoProgressException(
          s"Stale bulk ingestion identified in $operationName - the following active operations have not been " +
            s"completed (first ${BulkWriter.maxItemOperationsToShowInErrorMessage} shown) or progressed after " +
            s"${writeConfig.maxNoProgressIntervalInSeconds} seconds: $operationsLog",
          commitAttempt,
          retriableRemainingOperations)
      }

    captureIfFirstFailure(noProgressFailure)
    cancelWork()
  }

  throwIfCapturedExceptionExists()
}
// the caller has to ensure that after invoking this method scheduleWrite doesn't get invoked
// scalastyle:off method.length
// scalastyle:off cyclomatic.complexity
/**
 * Waits until all active tasks and pending retries have completed (or a failure has been captured),
 * completes the input emitters and releases the resources held by this writer.
 *
 * While waiting, the outstanding operations are checked once per `writeConfig.flushCloseIntervalInSeconds`
 * via `throwIfProgressStaled`; operations that did not progress are emitted again to the input emitters.
 * A captured failure is rethrown via `throwIfCapturedExceptionExists`.
 * The subscriptions and the flush executor are disposed and `closed` is set in the `finally` block,
 * so this happens regardless of the outcome.
 */
override def flushAndClose(): Unit = {
this.synchronized {
try {
// the waiting/completion logic is skipped when the writer has been closed already
if (!closed.get()) {
log.logInfo(s"flushAndClose invoked, Context: ${operationContext.toString} $getThreadInfo")
log.logInfo(s"completed so far ${totalSuccessfulIngestionMetrics.get()}, " +
s"pending bulkWrite asks ${activeBulkWriteOperations.size}, pending readMany tasks ${activeReadManyOperations.size}," +
s" Context: ${operationContext.toString} $getThreadInfo")
// error handling, if there is any error and the subscription is cancelled
// the remaining tasks will not be processed hence we never reach activeTasks = 0
// once we do error handling we should think how to cover the scenario.
lock.lock()
try {
// counts consecutive wait intervals in which the outstanding operations did not change
val numberOfIntervalsWithIdenticalActiveOperationSnapshots = new AtomicLong(0)
var activeTasksSnapshot = activeTasks.get()
var pendingRetriesSnapshot = pendingRetries.get()
// keep waiting while work is outstanding and no failure has been captured
while ((pendingRetriesSnapshot > 0 || activeTasksSnapshot > 0)
&& errorCaptureFirstException.get == null) {
logInfoOrWarning(
s"Waiting for pending activeTasks $activeTasksSnapshot and/or pendingRetries " +
s"$pendingRetriesSnapshot, Context: ${operationContext.toString} $getThreadInfo")
// snapshots taken before waiting - compared with the live sets afterwards to detect progress
val activeOperationsSnapshot = activeBulkWriteOperations.clone()
val activeReadManyOperationsSnapshot = activeReadManyOperations.clone()
val pendingOperationsSnapshot = pendingBulkWriteRetries.clone()
val pendingReadManyOperationsSnapshot = pendingReadManyRetries.clone()
// NOTE(review): pendingTasksCompleted is presumably a Condition created from `lock` - a false
// result would then mean the wait interval elapsed without a signal; confirm against the declaration
val awaitCompleted = pendingTasksCompleted.await(writeConfig.flushCloseIntervalInSeconds, TimeUnit.SECONDS)
if (!awaitCompleted) {
// updates the no-progress counter and throws when a failure has been captured
throwIfProgressStaled(
"FlushAndClose",
activeOperationsSnapshot,
pendingOperationsSnapshot,
activeReadManyOperationsSnapshot,
pendingReadManyOperationsSnapshot,
numberOfIntervalsWithIdenticalActiveOperationSnapshots,
allowRetryOnNewBulkWriterInstance = true
)
// no progress in at least one interval - emit the still active operations again
if (numberOfIntervalsWithIdenticalActiveOperationSnapshots.get > 0L) {
val buffered = Scannable.from(bulkInputEmitter).scan(Scannable.Attr.BUFFERED)
// the first re-enqueue switches logInfoOrWarning to warning level; afterwards this
// summary is only logged on every third interval without progress
if (verboseLoggingAfterReEnqueueingRetriesEnabled.compareAndSet(false, true)) {
log.logWarning(s"Starting to re-enqueue retries. Enabling verbose logs. "
+ s"Number of intervals with identical pending operations: "
+ s"$numberOfIntervalsWithIdenticalActiveOperationSnapshots Active Bulk Operations: "
+ s"$activeOperationsSnapshot, Active Read-Many Operations: $activeReadManyOperationsSnapshot, "
+ s"PendingRetries: $pendingRetriesSnapshot, Buffered tasks: $buffered "
+ s"Attempt: ${numberOfIntervalsWithIdenticalActiveOperationSnapshots.get} - "
+ s"Context: ${operationContext.toString} $getThreadInfo")
} else if ((numberOfIntervalsWithIdenticalActiveOperationSnapshots.get % 3) == 0) {
log.logWarning(s"Reattempting to re-enqueue retries. Enabling verbose logs. "
+ s"Number of intervals with identical pending operations: "
+ s"$numberOfIntervalsWithIdenticalActiveOperationSnapshots Active Bulk Operations: "
+ s"$activeOperationsSnapshot, Active Read-Many Operations: $activeReadManyOperationsSnapshot, "
+ s"PendingRetries: $pendingRetriesSnapshot, Buffered tasks: $buffered "
+ s"Attempt: ${numberOfIntervalsWithIdenticalActiveOperationSnapshots.get} - "
+ s"Context: ${operationContext.toString} $getThreadInfo")
}
activeOperationsSnapshot.foreach(operation => {
if (activeBulkWriteOperations.contains(operation)) {
// re-validating whether the operation is still active - if so, just re-enqueue another retry
// this is harmless - because all bulkItemOperations from Spark connector are always idempotent
// For FAIL_NON_SERIALIZED, will keep retry, while for other errors, use the default behavior
bulkInputEmitter.emitNext(operation, emitFailureHandler)
log.logWarning(s"Re-enqueued a retry for pending active write task '${operation.getOperationType} "
+ s"(${operation.getPartitionKeyValue}/${operation.getId})' "
+ s"- Attempt: ${numberOfIntervalsWithIdenticalActiveOperationSnapshots.get} - "
+ s"Context: ${operationContext.toString} $getThreadInfo")
}
})
activeReadManyOperationsSnapshot.foreach(operation => {
if (activeReadManyOperations.contains(operation)) {
// re-validating whether the operation is still active - if so, just re-enqueue another retry
// this is harmless - because all bulkItemOperations from Spark connector are always idempotent
// For FAIL_NON_SERIALIZED, will keep retry, while for other errors, use the default behavior
readManyInputEmitterOpt.get.emitNext(operation, emitFailureHandler)
log.logWarning(s"Re-enqueued a retry for pending active read-many task '"
+ s"(${operation.cosmosItemIdentity.getPartitionKey}/${operation.cosmosItemIdentity.getId})' "
+ s"- Attempt: ${numberOfIntervalsWithIdenticalActiveOperationSnapshots.get} - "
+ s"Context: ${operationContext.toString} $getThreadInfo")
}
})
}
}
// refresh the counters that drive the loop condition
activeTasksSnapshot = activeTasks.get()
pendingRetriesSnapshot = pendingRetries.get()
val semaphoreAvailablePermitsSnapshot = semaphore.availablePermits()
if (awaitCompleted) {
logInfoOrWarning(s"Waiting completed for pending activeTasks $activeTasksSnapshot, pendingRetries " +
s"$pendingRetriesSnapshot Context: ${operationContext.toString} $getThreadInfo")
} else {
logInfoOrWarning(s"Waiting interrupted for pending activeTasks $activeTasksSnapshot , pendingRetries " +
s"$pendingRetriesSnapshot - available permits $semaphoreAvailablePermitsSnapshot, " +
s"Context: ${operationContext.toString} $getThreadInfo")
}
}
logInfoOrWarning(s"Waiting completed for pending activeTasks $activeTasksSnapshot, pendingRetries " +
s"$pendingRetriesSnapshot Context: ${operationContext.toString} $getThreadInfo")
} finally {
lock.unlock()
}
logInfoOrWarning(s"invoking bulkInputEmitter.onComplete(), Context: ${operationContext.toString} $getThreadInfo")
// release one permit per task that is still counted as active (none when the count is <= 0)
semaphore.release(Math.max(0, activeTasks.get()))
bulkInputEmitter.emitComplete(BulkWriter.emitFailureHandlerForComplete)
// complete readManyInputEmitter
if (readManyInputEmitterOpt.isDefined) {
readManyInputEmitterOpt.get.emitComplete(BulkWriter.emitFailureHandlerForComplete)
}
// a failure captured while waiting is surfaced here, before the consistency checks below
throwIfCapturedExceptionExists()
assume(activeTasks.get() <= 0)
assume(activeBulkWriteOperations.isEmpty)
assume(activeReadManyOperations.isEmpty)
assume(semaphore.availablePermits() >= maxPendingOperations)
if (totalScheduledMetrics.get() != totalSuccessfulIngestionMetrics.get) {
log.logWarning(s"flushAndClose completed with no error but inconsistent total success and " +
s"scheduled metrics. This indicates that successful completion was only possible after re-enqueueing " +
s"retries. totalSuccessfulIngestionMetrics=${totalSuccessfulIngestionMetrics.get()}, " +
s"totalScheduled=$totalScheduledMetrics, Context: ${operationContext.toString} $getThreadInfo")
} else {
logInfoOrWarning(s"flushAndClose completed with no error. " +
s"totalSuccessfulIngestionMetrics=${totalSuccessfulIngestionMetrics.get()}, " +
s"totalScheduled=$totalScheduledMetrics, Context: ${operationContext.toString} $getThreadInfo")
}
}
} finally {
// cleanup runs on success as well as on failure
subscriptionDisposable.dispose()
readManySubscriptionDisposableOpt match {
case Some(readManySubscriptionDisposable) =>
readManySubscriptionDisposable.dispose()
case _ =>
}
flushExecutorHolder match {
case Some(executorAndFutureTuple) =>
val executor: ScheduledThreadPoolExecutor = executorAndFutureTuple._1
val future: ScheduledFuture[_] = executorAndFutureTuple._2
// cancelling the future and shutting down the executor are attempted independently,
// a failure of either is only logged
try {
future.cancel(true)
log.logDebug(s"Cancelled all future scheduled tasks $getThreadInfo, Context: ${operationContext.toString}")
} catch {
case e: Exception =>
log.logWarning(s"Failed to cancel scheduled tasks $getThreadInfo, Context: ${operationContext.toString}", e)
}
try {
log.logDebug(s"Shutting down the executor service, Context: ${operationContext.toString}")
executor.shutdownNow
log.logDebug(s"Successfully shut down the executor service, Context: ${operationContext.toString}")
} catch {
case e: Exception =>
log.logWarning(s"Failed to shut down the executor service, Context: ${operationContext.toString}", e)
}
case _ =>
}
closed.set(true)
}
}
}
// scalastyle:on method.length
// scalastyle:on cyclomatic.complexity
/**
 * Logs the message at warning level once verbose logging has been enabled
 * (`verboseLoggingAfterReEnqueueingRetriesEnabled`), otherwise at info level.
 *
 * @param msg the message - passed by name
 */
private def logInfoOrWarning(msg: => String): Unit = {
  this.verboseLoggingAfterReEnqueueingRetriesEnabled.get() match {
    case true => log.logWarning(msg)
    case false => log.logInfo(msg)
  }
}
/**
 * Decrements the number of active tasks and signals `pendingTasksCompleted` when no active
 * task is left or a failure has been captured. Runs while holding `lock`.
 */
private def markTaskCompletion(): Unit = {
  lock.lock()
  try {
    val remainingActiveTasks = activeTasks.decrementAndGet()
    val capturedFailure = errorCaptureFirstException.get()
    log.logTrace(s"markTaskCompletion, Active tasks left: $remainingActiveTasks, " +
      s"error: $capturedFailure, Context: ${operationContext.toString} $getThreadInfo")

    val allTasksDone = remainingActiveTasks == 0
    val failureCaptured = capturedFailure != null
    if (allTasksDone || failureCaptured) {
      pendingTasksCompleted.signal()
    }
  } finally {
    lock.unlock()
  }
}
/**
 * Logs the failure and records it in `errorCaptureFirstException` unless another failure has
 * been recorded before. `pendingTasksCompleted` is signalled in either case.
 *
 * @param throwable the failure to capture
 */
private def captureIfFirstFailure(throwable: Throwable): Unit = {
  log.logError(s"capture failure, Context: {${operationContext.toString}} $getThreadInfo", throwable)

  lock.lock()
  try {
    // compareAndSet only succeeds for the very first failure - later ones are just logged above
    errorCaptureFirstException.compareAndSet(null, throwable)
    // signal even if this was not the first failure
    pendingTasksCompleted.signal()
  } finally {
    lock.unlock()
  }
}
/**
 * Cancels the remaining unprocessed work by disposing the bulk write subscription and,
 * when present, the readMany subscription.
 */
private def cancelWork(): Unit = {
  logInfoOrWarning(s"cancelling remaining unprocessed tasks ${activeTasks.get} " +
    s"[bulkWrite tasks ${activeBulkWriteOperations.size}, readMany tasks ${activeReadManyOperations.size} ]" +
    s"Context: ${operationContext.toString}")

  subscriptionDisposable.dispose()
  readManySubscriptionDisposableOpt.foreach(readManySubscription => readManySubscription.dispose())
}
/**
 * Decides whether a non-successful status code can be ignored for the configured write strategy.
 *
 *  - ItemAppend: "resource exists" is ignored
 *  - ItemDelete: "not found" is ignored
 *  - ItemDeleteIfNotModified: "not found" and "precondition failed" are ignored
 *  - ItemOverwriteIfNotModified: "resource exists", "not found" and "precondition failed" are ignored
 *  - any other strategy: nothing is ignored
 *
 * @param statusCode    the status code of the failed operation
 * @param subStatusCode the sub status code of the failed operation
 * @return true when the failure can be ignored
 */
private def shouldIgnore(statusCode: Int, subStatusCode: Int): Boolean = {
  def alreadyExists: Boolean = Exceptions.isResourceExistsException(statusCode)
  def notFound: Boolean = Exceptions.isNotFoundExceptionCore(statusCode, subStatusCode)
  def preconditionFailed: Boolean = Exceptions.isPreconditionFailedException(statusCode)

  writeConfig.itemWriteStrategy match {
    case ItemWriteStrategy.ItemAppend => alreadyExists
    case ItemWriteStrategy.ItemDelete => notFound
    case ItemWriteStrategy.ItemDeleteIfNotModified => notFound || preconditionFailed
    case ItemWriteStrategy.ItemOverwriteIfNotModified => alreadyExists || notFound || preconditionFailed
    case _ => false
  }
}
/**
 * Decides whether a failed operation should be retried.
 *
 * Nothing is retried once `operationContext.attemptNumber` has reached `writeConfig.maxRetryCount`.
 * Otherwise the decision depends on the write strategy and the status codes.
 *
 * @param statusCode       the status code of the failed operation
 * @param subStatusCode    the sub status code of the failed operation
 * @param operationContext the context of the failed operation (shadows the writer level context)
 * @return true when the operation should be retried
 */
private def shouldRetry(statusCode: Int, subStatusCode: Int, operationContext: OperationContext): Boolean = {
  // Gateway mode reports inability to connect due to PoolAcquirePendingLimitException as status code 0
  def transientOrUnableToConnect: Boolean =
    Exceptions.canBeTransientFailure(statusCode, subStatusCode) || statusCode == 0

  val returnValue = operationContext.attemptNumber < writeConfig.maxRetryCount &&
    (writeConfig.itemWriteStrategy match {
      case ItemWriteStrategy.ItemBulkUpdate =>
        this.shouldRetryForItemPatchBulkUpdate(statusCode, subStatusCode)
      // Upsert can return 404/0 in rare cases (when there is a race condition due to TTL expiration)
      case ItemWriteStrategy.ItemOverwrite =>
        transientOrUnableToConnect || Exceptions.isNotFoundExceptionCore(statusCode, subStatusCode)
      case _ =>
        transientOrUnableToConnect
    })

  log.logDebug(s"Should retry statusCode '$statusCode:$subStatusCode' -> $returnValue, " +
    s"Context: ${operationContext.toString} $getThreadInfo")

  returnValue
}
/**
 * Decides whether a failed operation of the ItemBulkUpdate write strategy should be retried.
 *
 * A retry is attempted for failures that can be transient, for status code 0, for
 * "resource exists" and for "precondition failed".
 *
 * @param statusCode    the status code of the failed operation
 * @param subStatusCode the sub status code of the failed operation
 * @return true when the operation should be retried
 */
private def shouldRetryForItemPatchBulkUpdate(statusCode: Int, subStatusCode: Int): Boolean = {
  // All four conditions must be part of ONE expression. Previously the `||` after `statusCode == 0`
  // was missing, so semicolon inference ended the expression there, its result was discarded and
  // only the last two checks determined the return value.
  Exceptions.canBeTransientFailure(statusCode, subStatusCode) ||
    statusCode == 0 || // Gateway mode reports inability to connect due to
                       // PoolAcquirePendingLimitException as status code 0
    Exceptions.isResourceExistsException(statusCode) ||
    Exceptions.isPreconditionFailedException(statusCode)
}
/**
 * Extracts the value of the id property (`CosmosConstants.Properties.Id`) from the document.
 * The property is assumed to exist and to be textual.
 *
 * @param objectNode the document
 * @return the text value of the id property
 */
private def getId(objectNode: ObjectNode): String = {
  val idNode = objectNode.get(CosmosConstants.Properties.Id)
  val hasTextualId = idNode != null && idNode.isTextual
  assume(hasTextualId)
  idNode.textValue()
}
/**
 * Ungraceful close: cancels the remaining work without waiting for it to complete.
 *
 * @param shouldThrow when true an `IllegalStateException` is captured as failure (unless another
 *                    failure has been captured before), so that pending work/flushAndClose fails;
 *                    when false the abort is only logged as warning
 */
override def abort(shouldThrow: Boolean): Unit = {
  if (!shouldThrow) {
    log.logWarning(s"BulkWriter aborted and commit retried, Context: ${operationContext.toString} $getThreadInfo")
  } else {
    log.logError(s"Abort, Context: ${operationContext.toString} $getThreadInfo")
    // signal an exception that will be thrown for any pending work/flushAndClose if no other exception has
    // been registered
    val abortFailure =
      new IllegalStateException(s"The Spark task was aborted, Context: ${operationContext.toString}")
    captureIfFirstFailure(abortFailure)
  }

  cancelWork()
}
/**
 * Context attached to a single item operation.
 *
 * Item id, partition key value, eTag and source item are kept in an `OperationContextCore`;
 * `attemptNumber` and `sequenceNumber` are kept next to it and are not part of
 * `equals`/`hashCode`.
 */
private class OperationContext
(
  itemIdInput: String,
  partitionKeyValueInput: PartitionKey,
  eTagInput: Option[String],
  val attemptNumber: Int,
  val sequenceNumber: Long,
  /** starts from 1 * */
  sourceItemInput: Option[ObjectNode] = None) // for patchBulkUpdate: source item refers to the original objectNode from which SDK constructs the final bulk item operation
{
  private val ctxCore: OperationContextCore = OperationContextCore(
    itemId = itemIdInput,
    partitionKeyValue = partitionKeyValueInput,
    eTag = eTagInput,
    sourceItem = sourceItemInput)

  def itemId: String = ctxCore.itemId

  def partitionKeyValue: PartitionKey = ctxCore.partitionKeyValue

  def eTag: Option[String] = ctxCore.eTag

  def sourceItem: Option[ObjectNode] = ctxCore.sourceItem

  // NOTE(review): this compares the core with `obj` itself. When `obj` is another OperationContext
  // (not an OperationContextCore) the case class equality looks like it can never be true -
  // confirm whether comparing the two cores was intended before relying on value equality here.
  override def equals(obj: Any): Boolean = ctxCore.equals(obj)

  override def hashCode(): Int = ctxCore.hashCode()

  override def toString: String = s"${ctxCore.toString}, attemptNumber = $attemptNumber"
}
/**
 * The part of an operation context that takes part in the generated case class
 * `equals`/`hashCode`/`toString`: item id, partition key value, eTag and the optional source item.
 */
private case class OperationContextCore
(
itemId: String,
partitionKeyValue: PartitionKey,
eTag: Option[String],
sourceItem: Option[ObjectNode] = None) // for patchBulkUpdate: source item refers to the original objectNode from which SDK constructs the final bulk item operation
{
// makes the generated toString start with "OperationContext" instead of "OperationContextCore"
override def productPrefix: String = "OperationContext"
}
/**
 * A readMany operation: the identity of the item, the object node it belongs to and
 * the operation context.
 */
private case class ReadManyOperation(
cosmosItemIdentity: CosmosItemIdentity,
objectNode: ObjectNode,
operationContext: OperationContext)
}
/**
 * Constants, emit failure handlers and schedulers shared by the `BulkWriter` instances.
 */
private object BulkWriter {
  private val log = new DefaultDiagnostics().getLogger(this.getClass)

  //scalastyle:off magic.number
  private val maxDelayOn408RequestTimeoutInMs = 3000
  private val minDelayOn408RequestTimeoutInMs = 500
  private val maxItemOperationsToShowInErrorMessage = 10
  private val BULK_WRITER_REQUESTS_BOUNDED_ELASTIC_THREAD_NAME = "bulk-writer-requests-bounded-elastic"
  private val BULK_WRITER_INPUT_BOUNDED_ELASTIC_THREAD_NAME = "bulk-writer-input-bounded-elastic"
  private val BULK_WRITER_RESPONSES_BOUNDED_ELASTIC_THREAD_NAME = "bulk-writer-responses-bounded-elastic"
  private val READ_MANY_BOUNDED_ELASTIC_THREAD_NAME = "read-many-bounded-elastic"
  private val TTL_FOR_SCHEDULER_WORKER_IN_SECONDS = 60 // same as BoundedElasticScheduler.DEFAULT_TTL_SECONDS
  //scalastyle:on magic.number

  // Sizing of the default buffer:
  // - assume the spark executor VM has 16 CPU cores
  // - assume a cosmos container with 1M RU which is 167 partitions
  // - assume items of size 1KB are ingested
  // - assume the max request size is 1MB
  // Hence 1MB / 1KB items per partition should be buffered, which means
  // 1024 * 167 items should get buffered on a 16 CPU core VM,
  // so per CPU core (1024 * 167 / 16) items should be buffered at most.
  // The targeted buffer was reduced from 2MB per partition and core to 1MB because
  // a few customers were seeing too high CPU usage with the previous setting.
  // The reason is that several customers use documents larger than 1KB, so the
  // buffering needs to be less aggressive.
  val DefaultMaxPendingOperationPerCore: Int = 1024 * 167 / 16

  /** Handler for emitting items: only retries the emit on FAIL_NON_SERIALIZED. */
  val emitFailureHandler: EmitFailureHandler =
    (signalType, emitResult) => {
      val retryEmit = emitResult.equals(EmitResult.FAIL_NON_SERIALIZED)
      if (retryEmit) {
        log.logDebug(s"emitFailureHandler - Signal: ${signalType.toString}, Result: ${emitResult.toString}")
      } else {
        log.logError(s"emitFailureHandler - Signal: ${signalType.toString}, Result: ${emitResult.toString}")
      }
      retryEmit
    }

  /**
   * Handler for emitting the completion signal: only retries the emit on FAIL_NON_SERIALIZED.
   * FAIL_CANCELLED and FAIL_TERMINATED are logged at debug level only, any other failure as error.
   */
  private val emitFailureHandlerForComplete: EmitFailureHandler =
    (signalType, emitResult) => {
      if (emitResult.equals(EmitResult.FAIL_NON_SERIALIZED)) {
        log.logDebug(s"emitFailureHandlerForComplete - Signal: ${signalType.toString}, Result: ${emitResult.toString}")
        true
      } else {
        val alreadyCompleted =
          emitResult.equals(EmitResult.FAIL_CANCELLED) || emitResult.equals(EmitResult.FAIL_TERMINATED)
        if (alreadyCompleted) {
          log.logDebug(s"emitFailureHandlerForComplete - Already completed - Signal: ${signalType.toString}, Result: ${emitResult.toString}")
        } else {
          log.logError(s"emitFailureHandlerForComplete - Signal: ${signalType.toString}, Result: ${emitResult.toString}")
        }
        false
      }
    }

  private val bulkProcessingThresholds = new CosmosBulkExecutionThresholdsState()

  private val maxPendingOperationsPerJVM: Int = DefaultMaxPendingOperationPerCore * SparkUtils.getNumberOfHostCPUCores

  // All custom schedulers use daemon threads and the same worker TTL.
  private def newDaemonBoundedElastic(threadCap: Int, queuedTaskCap: Int, threadName: String): Scheduler =
    Schedulers.newBoundedElastic(
      threadCap,
      queuedTaskCap,
      threadName,
      TTL_FOR_SCHEDULER_WORKER_IN_SECONDS,
      true)

  // Custom bounded elastic scheduler for the bulk writer requests
  val bulkWriterRequestsBoundedElastic: Scheduler = newDaemonBoundedElastic(
    Schedulers.DEFAULT_BOUNDED_ELASTIC_SIZE,
    Schedulers.DEFAULT_BOUNDED_ELASTIC_QUEUESIZE + 2 * maxPendingOperationsPerJVM,
    BULK_WRITER_REQUESTS_BOUNDED_ELASTIC_THREAD_NAME)

  // Custom bounded elastic scheduler to consume input flux
  val bulkWriterInputBoundedElastic: Scheduler = newDaemonBoundedElastic(
    Schedulers.DEFAULT_BOUNDED_ELASTIC_SIZE,
    Schedulers.DEFAULT_BOUNDED_ELASTIC_QUEUESIZE + 2 * maxPendingOperationsPerJVM,
    BULK_WRITER_INPUT_BOUNDED_ELASTIC_THREAD_NAME)

  // Custom bounded elastic scheduler to switch off IO thread to process response.
  val bulkWriterResponsesBoundedElastic: Scheduler = newDaemonBoundedElastic(
    Schedulers.DEFAULT_BOUNDED_ELASTIC_SIZE,
    Schedulers.DEFAULT_BOUNDED_ELASTIC_QUEUESIZE + maxPendingOperationsPerJVM,
    BULK_WRITER_RESPONSES_BOUNDED_ELASTIC_THREAD_NAME)

  // Custom bounded elastic scheduler for readMany - with twice the default thread cap
  val readManyBoundedElastic: Scheduler = newDaemonBoundedElastic(
    2 * Schedulers.DEFAULT_BOUNDED_ELASTIC_SIZE,
    Schedulers.DEFAULT_BOUNDED_ELASTIC_QUEUESIZE + maxPendingOperationsPerJVM,
    READ_MANY_BOUNDED_ELASTIC_THREAD_NAME)

  /**
   * Describes the current thread (name, thread group, daemon flag and id) for log messages.
   * The group is reported as "n/a" when the thread has no thread group.
   */
  def getThreadInfo: String = {
    val current = Thread.currentThread()
    val groupName = Option(current.getThreadGroup).map(_.getName).getOrElse("n/a")
    s"Thread[Name: ${current.getName}, Group: $groupName, IsDaemon: ${current.isDaemon} Id: ${current.getId}]"
  }

  /**
   * CosmosException carrying the status code, sub status code and message of a failed bulk
   * operation, with the given throwable as cause.
   */
  private class BulkOperationFailedException(statusCode: Int, subStatusCode: Int, message:String, cause: Throwable)
    extends CosmosException(statusCode, message, null, cause) {
    BridgeInternal.setSubStatusCode(this, subStatusCode)
  }
}
//scalastyle:on multiple.string.literals
//scalastyle:on null
//scalastyle:on file.size.limit
© 2015 - 2025 Weber Informatics LLC | Privacy Policy