All Downloads are FREE. Search and download functionalities are using the official Maven repository.

wvlet.airframe.control.Retry.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package wvlet.airframe.control

import wvlet.airframe.control.ResultClass.Failed
import wvlet.log.LogSupport
import wvlet.airframe.rx.Rx

import java.util.concurrent.TimeUnit
import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Random, Success, Try}

/**
  * Retry logic implementation helper
  */
object Retry extends LogSupport {

  /** Classify `e` as a failure that may be retried. */
  def retryableFailure(e: Throwable) = Failed(isRetryable = true, e)

  /** Classify `e` as a failure that must not be retried. */
  def nonRetryableFailure(e: Throwable) = Failed(isRetryable = false, e)

  /**
    * Create a retry context that waits with exponential back-off between attempts.
    *
    * @param maxRetry
    *   the retry count limit checked by [[RetryContext.canContinue]]
    * @param initialIntervalMillis
    *   the wait before the first retry
    * @param maxIntervalMillis
    *   the upper bound of a single (base) wait
    * @param multiplier
    *   the factor applied to the base wait after each retry
    */
  def withBackOff(
      maxRetry: Int = 3,
      initialIntervalMillis: Int = 100,
      maxIntervalMillis: Int = 15000,
      multiplier: Double = 1.5
  ): RetryContext = {
    defaultRetryContext.withMaxRetry(maxRetry).withBackOff(initialIntervalMillis, maxIntervalMillis, multiplier)
  }

  /**
    * Create an exponential back-off retry context.
    *
    * The retry count is chosen so that the sum of the back-off waits stays within `maxTotalWaitMillis`. This is
    * approximate: each step is rounded in [[RetryPolicy.updateBaseWait]], and extra waits are not counted.
    *
    * @param initialIntervalMillis
    *   the wait before the first retry; must be > 0
    * @param maxTotalWaitMillis
    *   the budget for the sum of all back-off waits
    * @param multiplier
    *   the factor applied to the wait after each retry
    */
  def withBoundedBackoff(
      initialIntervalMillis: Int = 100,
      maxTotalWaitMillis: Int = 180000,
      multiplier: Double = 1.5
  ): RetryContext = {
    require(initialIntervalMillis > 0, s"initialIntervalMillis must be > 0: ${initialIntervalMillis}")

    // Notation: w = initialIntervalMillis, r = multiplier, S = maxTotalWaitMillis.
    // The i-th retry (0-origin) waits about w * r^i, so the total wait of n retries is:
    //   total(n) = w * r^0 + w * r^1 + ... + w * r^(n-1) = w * (1 - r^n) / (1 - r)   (r != 1)
    //   total(n) = w * n                                                              (r == 1)
    // Solving total(N) = S for N:
    //   N = log(1 - S * (1 - r) / w) / log(r)                                         (r != 1)
    //   N = S / w                                                                     (r == 1)
    // r == 1 needs its own branch because the general formula degenerates to 0 / 0 (NaN) there.
    val isConstantInterval = multiplier == 1.0

    def total(n: Int): Double = {
      if (isConstantInterval) initialIntervalMillis.toDouble * n
      else initialIntervalMillis * (1 - math.pow(multiplier, n)) / (1 - multiplier)
    }

    val exactRetries: Double = {
      if (isConstantInterval) maxTotalWaitMillis.toDouble / initialIntervalMillis
      else math.log(1 - (maxTotalWaitMillis * (1 - multiplier) / initialIntervalMillis)) / math.log(multiplier)
    }

    // The solution is real-valued: start from its ceiling and step down until the total wait fits the budget
    var maxRetry = exactRetries.ceil.toInt
    while (maxRetry > 0 && total(maxRetry) > maxTotalWaitMillis) {
      maxRetry -= 1
    }

    // Multiply before truncating. Truncating r^N first can shrink the cap below the waits the retries above
    // need. Double.toInt saturates at Int.MaxValue, so this product cannot overflow.
    val maxIntervalMillis = (initialIntervalMillis * math.pow(multiplier, exactRetries)).toInt
    withBackOff(
      maxRetry = maxRetry.max(0),
      initialIntervalMillis = initialIntervalMillis,
      maxIntervalMillis = maxIntervalMillis,
      multiplier = multiplier
    )
  }

  /**
    * Create a retry context that waits a random fraction of an exponentially growing interval (see [[Jitter]]).
    *
    * @param maxRetry
    *   the retry count limit checked by [[RetryContext.canContinue]]
    * @param initialIntervalMillis
    *   the initial base interval
    * @param maxIntervalMillis
    *   the upper bound of the base interval
    * @param multiplier
    *   the factor applied to the base interval after each retry
    */
  def withJitter(
      maxRetry: Int = 3,
      initialIntervalMillis: Int = 100,
      maxIntervalMillis: Int = 15000,
      multiplier: Double = 1.5
  ): RetryContext = {
    defaultRetryContext.withMaxRetry(maxRetry).withJitter(initialIntervalMillis, maxIntervalMillis, multiplier)
  }

  // Template context shared by the factory methods above: 3 retries with the default Jitter policy
  private val defaultRetryContext: RetryContext = {
    val retryConfig = RetryPolicyConfig()
    RetryContext(
      context = None,
      lastError = NOT_STARTED,
      retryCount = 0,
      maxRetry = 3,
      retryWaitStrategy = new Jitter(retryConfig),
      nextWaitMillis = retryConfig.initialIntervalMillis,
      baseWaitMillis = retryConfig.initialIntervalMillis,
      extraWaitMillis = 0
    )
  }

  /** Thrown when the retry count is exhausted; the last failure is attached as the cause. */
  case class MaxRetryException(retryContext: RetryContext)
      extends Exception(
        s"Reached the max retry count ${retryContext.retryCount}/${retryContext.maxRetry}: ${retryContext.lastError.getMessage}",
        retryContext.lastError
      )

  // Throw this to force retry the execution, bypassing the error classifier
  case class RetryableFailure(e: Throwable) extends Exception(e)

  /** Placeholder for [[RetryContext.lastError]] before any attempt has failed. */
  case object NOT_STARTED extends Exception("Code is not executed")

  // Default beforeRetry action. It must stay a `def`, not a `val`: `defaultRetryContext` above is initialized
  // first and obtains this value through the RetryContext default argument, where a later `val` would still be
  // null.
  private def REPORT_RETRY_COUNT: RetryContext => Unit = { (ctx: RetryContext) =>
    warn(
      f"[${ctx.retryCount}/${ctx.maxRetry}] Execution failed: ${ctx.lastError.getMessage}. Retrying in ${ctx.nextWaitMillis / 1000.0}%.2f sec."
    )
  }

  // Fallback used by retryOn for errors the given partial function does not cover
  private def RETHROW_ALL: Throwable => ResultClass.Failed = { (e: Throwable) => throw e }

  private[control] val noExtraWait = ExtraWait()

  /**
    * Additional wait requested by a failure classification.
    *
    * @param maxExtraWaitMillis
    *   a fixed extra wait, or an upper bound when `factor` is also set; 0 means unset
    * @param factor
    *   extra wait as a fraction of the current wait; 0.0 means unset
    */
  case class ExtraWait(maxExtraWaitMillis: Int = 0, factor: Double = 0.0) {
    require(maxExtraWaitMillis >= 0)
    require(factor >= 0)

    def hasNoWait: Boolean = {
      maxExtraWaitMillis == 0 && factor == 0.0
    }

    /**
      * Compute the extra wait relative to `nextWaitMillis`:
      *   - neither set: 0
      *   - only `factor`: `nextWaitMillis * factor`
      *   - only `maxExtraWaitMillis`: `maxExtraWaitMillis`
      *   - both: `nextWaitMillis * factor`, capped at `maxExtraWaitMillis`
      */
    def extraWaitMillis(nextWaitMillis: Int): Int = {
      if (maxExtraWaitMillis == 0) {
        if (factor == 0.0) {
          0
        } else {
          (nextWaitMillis * factor).toInt
        }
      } else {
        if (factor == 0.0) {
          maxExtraWaitMillis
        } else {
          (nextWaitMillis * factor).toInt.min(maxExtraWaitMillis)
        }
      }
    }
  }

  /**
    * Immutable state and configuration of a retry loop.
    *
    * @param context
    *   optional user value given to runWithContext / runAsyncWithContext
    * @param lastError
    *   the failure that caused the most recent retry ([[NOT_STARTED]] initially)
    * @param retryCount
    *   number of retryable failures observed so far
    * @param maxRetry
    *   limit compared against `retryCount` in [[canContinue]]
    * @param retryWaitStrategy
    *   policy that derives wait times from `baseWaitMillis`
    * @param nextWaitMillis
    *   the wait before the upcoming attempt
    * @param baseWaitMillis
    *   the interval from which the next wait is derived
    * @param extraWaitMillis
    *   one-shot extra wait added to the next wait, then reset to 0
    * @param resultClassifier
    *   classifies successful results (may turn them into failures)
    * @param errorClassifier
    *   classifies exceptions thrown by the body
    * @param beforeRetryAction
    *   invoked by [[nextRetry]] with the updated context (logs a warning by default)
    */
  case class RetryContext(
      context: Option[Any],
      lastError: Throwable,
      retryCount: Int,
      maxRetry: Int,
      retryWaitStrategy: RetryPolicy,
      nextWaitMillis: Int,
      baseWaitMillis: Int,
      extraWaitMillis: Int,
      resultClassifier: Any => ResultClass = ResultClass.ALWAYS_SUCCEED,
      errorClassifier: Throwable => ResultClass.Failed = ResultClass.ALWAYS_RETRY,
      beforeRetryAction: RetryContext => Any = REPORT_RETRY_COUNT
  ) {

    /** Reset the per-run state (error, count, waits) while keeping the configuration. */
    def init(context: Option[Any] = None): RetryContext = {
      this.copy(
        context = context,
        lastError = NOT_STARTED,
        retryCount = 0,
        nextWaitMillis = retryWaitStrategy.retryPolicyConfig.initialIntervalMillis,
        baseWaitMillis = retryWaitStrategy.retryPolicyConfig.initialIntervalMillis,
        extraWaitMillis = 0
      )
    }

    /** True while another attempt is allowed. */
    def canContinue: Boolean = {
      retryCount < maxRetry
    }

    /**
      * Update the retry context, including retry count, last error, next wait time, etc.
      *
      * The pending extra wait is folded into `nextWaitMillis` and then cleared. `beforeRetryAction` is invoked with
      * the updated context, even when the updated context can no longer continue.
      *
      * @param retryReason
      *   the failure that triggered this retry
      * @return
      *   the next retry context
      */
    def nextRetry(retryReason: Throwable): RetryContext = {
      val nextRetryCtx = this.copy(
        lastError = retryReason,
        retryCount = retryCount + 1,
        nextWaitMillis = retryWaitStrategy.nextWait(baseWaitMillis) + extraWaitMillis,
        baseWaitMillis = retryWaitStrategy.updateBaseWait(baseWaitMillis),
        extraWaitMillis = 0
      )
      beforeRetryAction(nextRetryCtx)
      nextRetryCtx
    }

    /** Set the one-shot extra wait, computed from the current `nextWaitMillis`. */
    def withExtraWait(extraWait: ExtraWait): RetryContext = {
      if (extraWait.hasNoWait && this.extraWaitMillis == 0) {
        this
      } else {
        this.copy(extraWaitMillis = extraWait.extraWaitMillis(nextWaitMillis))
      }
    }

    def withRetryWaitStrategy(newRetryWaitStrategy: RetryPolicy): RetryContext = {
      this.copy(retryWaitStrategy = newRetryWaitStrategy)
    }

    def withMaxRetry(newMaxRetry: Int): RetryContext = {
      this.copy(maxRetry = newMaxRetry)
    }

    /** Disable retries: the body runs once. */
    def noRetry: RetryContext = {
      this.copy(maxRetry = 0)
    }

    /** Switch to [[ExponentialBackOff]] with the given parameters. */
    def withBackOff(
        initialIntervalMillis: Int = 100,
        maxIntervalMillis: Int = 15000,
        multiplier: Double = 1.5
    ): RetryContext = {
      val config = RetryPolicyConfig(initialIntervalMillis, maxIntervalMillis, multiplier)
      this.copy(retryWaitStrategy = new ExponentialBackOff(config))
    }

    /** Switch to [[Jitter]] with the given parameters. */
    def withJitter(
        initialIntervalMillis: Int = 100,
        maxIntervalMillis: Int = 15000,
        multiplier: Double = 1.5
    ): RetryContext = {
      val config = RetryPolicyConfig(initialIntervalMillis, maxIntervalMillis, multiplier)
      this.copy(retryWaitStrategy = new Jitter(config))
    }

    // The classifier is stored as Any => ResultClass; the cast is unchecked, so U must match the body's result type
    def withResultClassifier[U](newResultClassifier: U => ResultClass): RetryContext = {
      this.copy(resultClassifier = newResultClassifier.asInstanceOf[Any => ResultClass])
    }

    /**
      * Set a detailed error handler upon Exception. If the given exception is not retryable, just rethrow the
      * exception. Otherwise, consume the exception.
      */
    def withErrorClassifier(errorClassifier: Throwable => ResultClass.Failed): RetryContext = {
      this.copy(errorClassifier = errorClassifier)
    }

    def beforeRetry[U](handler: RetryContext => U): RetryContext = {
      this.copy(beforeRetryAction = handler)
    }

    /**
      * Clear the default beforeRetry action
      */
    def noRetryLogging: RetryContext = {
      this.copy(beforeRetryAction = { (x: RetryContext) => })
    }

    /**
      * Add a partial function that accepts exceptions that need to be retried. Exceptions not covered by the partial
      * function are rethrown.
      *
      * @param errorClassifier
      *   classification for the handled exceptions
      * @return
      *   a context using the given classifier
      */
    def retryOn(errorClassifier: PartialFunction[Throwable, ResultClass.Failed]): RetryContext = {
      this.copy(errorClassifier = { (e: Throwable) => errorClassifier.applyOrElse(e, RETHROW_ALL) })
    }

    /** Run `body` with retries, without a context value or circuit breaker. */
    def run[A](body: => A): A = {
      runInternal(None)(body)
    }

    /** Run `body` with retries, attaching `context` and guarding each attempt with `circuitBreaker`. */
    def runWithContext[A](context: Any, circuitBreaker: CircuitBreaker = CircuitBreaker.alwaysClosed)(body: => A): A = {
      runInternal(Option(context), circuitBreaker)(body)
    }

    // Map an attempt's outcome to a ResultClass; RetryableFailure always forces a retry
    private def classifyResult[A](result: Try[A]): ResultClass = {
      result match {
        case Success(x) =>
          // Test whether the code block execution is succeeded or failed
          resultClassifier(x)
        case Failure(RetryableFailure(e)) =>
          ResultClass.retryableFailure(e)
        case Failure(e) =>
          errorClassifier(e)
      }
    }

    protected def runInternal[A](context: Option[Any], circuitBreaker: CircuitBreaker = CircuitBreaker.alwaysClosed)(
        body: => A
    ): A = {
      var result: Option[A]          = None
      var retryContext: RetryContext = init(context)

      // The first attempt always runs; later attempts run while there is no result and retries remain
      var isFirst: Boolean = true

      while (isFirst || (result.isEmpty && retryContext.canContinue)) {
        isFirst = false

        val ret = Try {
          circuitBreaker.verifyConnection
          body
        }
        classifyResult(ret) match {
          case ResultClass.Succeeded =>
            circuitBreaker.recordSuccess
            // OK. Exit the loop (Succeeded only arises from a Success, so get is safe)
            result = Some(ret.get)
          case ResultClass.Failed(isRetryable, cause, extraWait) if isRetryable =>
            circuitBreaker.recordFailure(cause)
            // Retryable error
            retryContext = retryContext.withExtraWait(extraWait).nextRetry(cause)
            // Wait only if another attempt follows. Otherwise the loop exits and MaxRetryException is thrown
            // right away instead of after a useless sleep.
            if (retryContext.canContinue) {
              Compat.sleep(retryContext.nextWaitMillis)
            }
          case ResultClass.Failed(_, cause, _) =>
            // For regular non-retryable failures, we need to treat them as successful responses
            circuitBreaker.recordSuccess
            // Non-retryable error. Exit the loop by throwing the exception
            throw cause
        }
      }

      result match {
        case Some(a) =>
          a
        case None =>
          throw MaxRetryException(retryContext)
      }
    }

    /** Asynchronous counterpart of [[runWithContext]]; waits between attempts via `Rx.delay`. */
    def runAsyncWithContext[A](context: Any, circuitBreaker: CircuitBreaker = CircuitBreaker.alwaysClosed)(
        body: => Rx[A]
    ): Rx[A] = {
      def loop(retryContext: RetryContext, isFirst: Boolean): Rx[A] = {
        if (!isFirst && !retryContext.canContinue) {
          Rx.exception(MaxRetryException(retryContext))
        } else {
          Rx.fromTry(Try(circuitBreaker.verifyConnection))
            .flatMap(_ => body)
            .transformRx { (ret: Try[A]) =>
              val resultClass = classifyResult(ret)
              resultClass match {
                case ResultClass.Succeeded =>
                  circuitBreaker.recordSuccess
                  // Exit the loop
                  Rx.fromTry(ret)
                case ResultClass.Failed(isRetryable, cause, extraWait) if isRetryable =>
                  // Retryable error
                  circuitBreaker.recordFailure(cause)
                  val nextRetry = retryContext.withExtraWait(extraWait).nextRetry(cause)
                  if (nextRetry.canContinue) {
                    // Add retry wait
                    Rx.delay(nextRetry.nextWaitMillis, TimeUnit.MILLISECONDS)
                      .flatMap(_ => loop(nextRetry, isFirst = false))
                  } else {
                    // No attempt left: fail with MaxRetryException immediately, without delaying
                    loop(nextRetry, isFirst = false)
                  }
                case ResultClass.Failed(_, cause, _) =>
                  // For regular non-retryable failures, we need to treat them as successful responses
                  circuitBreaker.recordSuccess
                  // Non-retryable error. Exit the loop with the exception
                  Rx.exception(cause)
              }
            }
        }
      }

      loop(retryContext = init(Option(context)), isFirst = true)
    }
  }

  /** Parameters shared by the retry wait policies; all values must be non-negative. */
  case class RetryPolicyConfig(
      initialIntervalMillis: Int = 100,
      maxIntervalMillis: Int = 15000,
      multiplier: Double = 1.5
  ) {
    require(initialIntervalMillis >= 0)
    require(maxIntervalMillis >= 0)
    require(multiplier >= 0)
  }

  /** Strategy for computing wait times between attempts. */
  trait RetryPolicy {
    def retryPolicyConfig: RetryPolicyConfig

    /** Grow the base wait by `multiplier`, rounded and capped at `maxIntervalMillis`. */
    def updateBaseWait(waitMillis: Int): Int = {
      // Clamp while still a Long so that a large product cannot wrap around to a negative Int
      math.round(waitMillis * retryPolicyConfig.multiplier).min(retryPolicyConfig.maxIntervalMillis.toLong).toInt
    }

    /** The actual wait derived from the current base wait. */
    def nextWait(baseWaitMillis: Int): Int
  }

  /** Waits exactly the base wait. */
  class ExponentialBackOff(val retryPolicyConfig: RetryPolicyConfig) extends RetryPolicy {
    override def nextWait(baseWaitMillis: Int): Int = {
      baseWaitMillis
    }
  }

  /** Waits a uniformly random fraction, in [0, 1), of the base wait. */
  class Jitter(val retryPolicyConfig: RetryPolicyConfig, rand: Random = new Random()) extends RetryPolicy {
    override def nextWait(baseWaitMillis: Int): Int = {
      (baseWaitMillis.toDouble * rand.nextDouble()).round.toInt
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy