All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.azure.cosmos.spark.TransientIOErrorsRetryingIterator.scala Maven / Gradle / Ivy

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.cosmos.spark

import com.azure.cosmos.CosmosException
import com.azure.cosmos.implementation.spark.OperationContextAndListenerTuple
import com.azure.cosmos.models.FeedResponse
import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait
import com.azure.cosmos.util.{CosmosPagedFlux, CosmosPagedIterable}
import reactor.core.scheduler.Schedulers

import java.util.concurrent.{ExecutorService, SynchronousQueue, ThreadPoolExecutor, TimeUnit, TimeoutException}
import java.util.concurrent.atomic.{AtomicLong, AtomicReference}
import scala.util.Random
import scala.util.control.Breaks
import scala.concurrent.{Await, ExecutionContext, Future}
import com.azure.cosmos.implementation.{ChangeFeedSparkRowItem, OperationCancelledException, SparkBridgeImplementationInternal}


// scalastyle:off underscore.import
import scala.collection.JavaConverters._
// scalastyle:on underscore.import

// This iterator exists to allow adding more extensive retries for transient
// IO errors when draining a query or change feed query
// The Java SDK has built-in retry policies - but those retry policies are
// pretty restrictive - only allowing retries for 30 seconds for example in eventual consistency mode
// The built-in SDK policies are optimized for OLTP scenarios - assuming that client machines
// are not resource/CPU exhausted etc. In Spark scenarios it is much more common
// that executors at least temporarily have pegged CPU, often queries are very IO intense,
// use large page sizes etc. The retry policy against transient IO errors needs to be more robust
// as a consequence for Spark scenarios.
// The iterator below allows retries based on the continuation token of the previously received response
// because we know that IO errors cannot happen iterating over documents of one page it is safe
// to use the continuation token to keep draining on the retry
// TODO @fabianm - we should still have a discussion whether it would be worth to allow tweaking
//  the retry policy of the SDK. But having the Spark specific retries for now to get some experience
//  can help making the right decisions if/how to expose this in the SDK
private class TransientIOErrorsRetryingIterator[TSparkRow]
(
  val cosmosPagedFluxFactory: String => CosmosPagedFlux[TSparkRow],
  val pageSize: Int,
  val pagePrefetchBufferSize: Int,
  val operationContextAndListener: Option[OperationContextAndListenerTuple],
  val endLsn: Option[Long]
) extends BufferedIterator[TSparkRow] with BasicLoggingTrait with AutoCloseable {

  private[spark] var maxRetryIntervalInMs = CosmosConstants.maxRetryIntervalForTransientFailuresInMs
  private[spark] var maxRetryCount = CosmosConstants.maxRetryCountForTransientFailures
  private val maxPageRetrievalTimeout = scala.concurrent.duration.FiniteDuration(
    5 + CosmosConstants.readOperationEndToEndTimeoutInSeconds,
    scala.concurrent.duration.SECONDS)

  private val rnd = Random
  // scalastyle:off null
  private val lastContinuationToken = new AtomicReference[String](null)
  // scalastyle:on null
  private val retryCount = new AtomicLong(0)
  private lazy val operationContextString = operationContextAndListener match {
    case Some(o) => if (o.getOperationContext != null) {
      o.getOperationContext.toString
    } else {
      "n/a"
    }
    case None => "n/a"
  }

  private[spark] var currentFeedResponseIterator: Option[BufferedIterator[FeedResponse[TSparkRow]]] = None
  private[spark] var currentItemIterator: Option[BufferedIterator[TSparkRow]] = None
  private val lastPagedFlux = new AtomicReference[Option[CosmosPagedFlux[TSparkRow]]](None)
  override def hasNext: Boolean = {
    executeWithRetry("hasNextInternal", () => hasNextInternal)
  }

  /***
   * Checks whether more records exists - this will potentially trigger I/O operations and retries
   * @return true (more records exist), false (no more records exist), None (unknown call should be repeated)
   */
  private def hasNextInternal: Boolean = {
    var returnValue: Option[Boolean] = None

    while (returnValue.isEmpty) {
      returnValue = hasNextInternalCore
    }

    returnValue.get
  }

  /***
   * Checks whether more records exists - this will potentially trigger I/O operations and retries
   * @return true (more records exist), false (no more records exist), None (unknown call should be repeated)
   */
  private def hasNextInternalCore: Option[Boolean] = {
    if (hasBufferedNext) {
      Some(true)
    } else {
      val feedResponseIterator = currentFeedResponseIterator match {
        case Some(existing) => existing
        case None =>
          val newPagedFlux = Some(cosmosPagedFluxFactory.apply(lastContinuationToken.get))
          lastPagedFlux.getAndSet(newPagedFlux) match {
            case Some(oldPagedFlux) => {
              logInfo(s"Attempting to cancel oldPagedFlux, Context: $operationContextString")
              oldPagedFlux.cancelOn(Schedulers.boundedElastic()).onErrorComplete().subscribe().dispose()
            }
            case None =>
          }
          currentFeedResponseIterator = Some(
            new CosmosPagedIterable[TSparkRow](
              newPagedFlux.get,
              pageSize,
              pagePrefetchBufferSize
            )
            .iterableByPage()
            .iterator
            .asScala
            .buffered
          )

          currentFeedResponseIterator.get
      }

      val hasNext: Boolean = try {
        Await.result(
          Future {
            feedResponseIterator.hasNext
          }(TransientIOErrorsRetryingIterator.executionContext),
          maxPageRetrievalTimeout)
      } catch {

        case endToEndTimeoutException: OperationCancelledException =>
          val message = s"End-to-end timeout hit when trying to retrieve the next page. Continuation" +
            s"token: $lastContinuationToken, Context: $operationContextString"

          logError(message, throwable = endToEndTimeoutException)

          throw endToEndTimeoutException
        case timeoutException: TimeoutException =>

          val message = s"Attempting to retrieve the next page timed out. Continuation" +
            s"token: $lastContinuationToken, Context: $operationContextString"

          logError(message, timeoutException)

          val exception = new OperationCancelledException(
            message,
            null
          );
          exception.setStackTrace(timeoutException.getStackTrace());
          throw exception

        case other: Throwable => throw other
      }

      if (hasNext) {
        val feedResponse = feedResponseIterator.next()
        if (operationContextAndListener.isDefined) {
          operationContextAndListener.get.getOperationListener.feedResponseProcessedListener(
            operationContextAndListener.get.getOperationContext,
            feedResponse)
        }
        val iteratorCandidate = feedResponse.getResults.iterator().asScala.buffered
        lastContinuationToken.set(feedResponse.getContinuationToken)

        if (iteratorCandidate.hasNext && validateNextLsn(iteratorCandidate)) {
          currentItemIterator = Some(iteratorCandidate)
          Some(true)
        } else {
          // empty page interleaved
          // need to get attempt to get next FeedResponse to determine whether more records exist
          None
        }
      } else {
        Some(false)
      }
    }
  }

  private def hasBufferedNext: Boolean = {
    currentItemIterator match {
      case Some(iterator) => if (iterator.hasNext && validateNextLsn(iterator)) {
        true
      } else {
        currentItemIterator = None
        false
      }
      case None => false
    }
  }

  override def next(): TSparkRow = {
    currentItemIterator.get.next()
  }

  override def head(): TSparkRow = {
    currentItemIterator.get.head
  }

  private[spark] def executeWithRetry[T](methodName: String, func: () => T): T = {
    val loop = new Breaks()
    var returnValue: Option[T] = None

    loop.breakable {
      while (true) {
        val retryIntervalInMs = rnd.nextInt(maxRetryIntervalInMs)

        try {
          returnValue = Some(func())
          retryCount.set(0)
          loop.break
        }
        catch {
          case cosmosException: CosmosException =>
            if (Exceptions.canBeTransientFailure(cosmosException.getStatusCode, cosmosException.getSubStatusCode)) {
              val retryCountSnapshot = retryCount.incrementAndGet()
              if (retryCountSnapshot > maxRetryCount) {
                logError(
                  s"Too many transient failure retry attempts in TransientIOErrorsRetryingIterator.$methodName",
                  cosmosException)
                throw cosmosException
              } else {
                logWarning(
                  s"Transient failure handled in TransientIOErrorsRetryingIterator.$methodName -" +
                    s" will be retried (attempt#$retryCountSnapshot) in ${retryIntervalInMs}ms",
                  cosmosException)
              }
            } else {
              throw cosmosException
            }
          case other: Throwable => throw other
        }

        currentItemIterator = None
        currentFeedResponseIterator = None
        Thread.sleep(retryIntervalInMs)
      }
    }

    returnValue.get
  }

  private[this] def validateNextLsn(itemIterator: BufferedIterator[TSparkRow]): Boolean = {
    this.endLsn match {
      case None =>
        // Only relevant in change feed
        // In batch mode endLsn is cleared - we will always continue reading until the change feed is
        // completely drained so all partitions return 304
        true
      case Some(endLsn) =>
        // In streaming mode we only continue until we hit the endOffset's continuation Lsn
        if (itemIterator.isEmpty) {
          return false
        }
        val node = itemIterator.head.asInstanceOf[ChangeFeedSparkRowItem]
        assert(node.lsn != null, "Change feed responses must have _lsn property.")
        assert(node.lsn != "", "Change feed responses must have non empty _lsn.")
        val nextLsn = SparkBridgeImplementationInternal.toLsn(node.lsn)

        nextLsn <= endLsn
    }
  }

  //  Correct way to cancel a flux and dispose it
  //  https://github.com/reactor/reactor-core/blob/main/reactor-core/src/test/java/reactor/core/publisher/scenarios/FluxTests.java#L837
  override def close(): Unit = {
    lastPagedFlux.getAndSet(None) match {
      case Some(oldPagedFlux) => oldPagedFlux.cancelOn(Schedulers.boundedElastic()).onErrorComplete().subscribe().dispose()
      case None =>
    }
  }
}

private object TransientIOErrorsRetryingIterator extends BasicLoggingTrait {
  private val maxConcurrency = SparkUtils.getNumberOfHostCPUCores

  val executorService: ExecutorService = new ThreadPoolExecutor(
    maxConcurrency,
    maxConcurrency,
    0L,
    TimeUnit.MILLISECONDS,
    // A synchronous queue does not have any internal capacity, not even a capacity of one.
    new SynchronousQueue(),
    SparkUtils.daemonThreadFactory(),
    // if all worker threads are busy,
    // this policy makes the caller thread execute the task.
    // This provides a simple feedback control mechanism that will slow down the rate that new tasks are submitted.
    new ThreadPoolExecutor.CallerRunsPolicy()
  )

  val executionContext = ExecutionContext.fromExecutorService(executorService)
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy