All downloads are free. The search and download functionality uses the official Maven repository.

com.ocadotechnology.sttp.oauth2.cache.cats.CachingAccessTokenProvider.scala Maven / Gradle / Ivy

package org.polyvariant.sttp.oauth2.cache.cats

import cats.data.OptionT
import cats.effect.kernel.Clock
import cats.effect.kernel.Concurrent
import cats.effect.kernel.MonadCancelThrow
import cats.effect.std.Semaphore
import cats.syntax.all._
import org.polyvariant.sttp.oauth2.AccessTokenProvider
import org.polyvariant.sttp.oauth2.ClientCredentialsToken
import org.polyvariant.sttp.oauth2.Secret
import org.polyvariant.sttp.oauth2.cache.ExpiringCache
import org.polyvariant.sttp.oauth2.cache.cats.CachingAccessTokenProvider.TokenWithExpirationTime
import org.polyvariant.sttp.oauth2.common.Scope

import java.time.Instant
import scala.concurrent.duration.Duration

/** An [[AccessTokenProvider]] that answers from `tokenCache` when it holds a token for the
  * requested scope, and otherwise asks `delegate` for a new token and stores it in the cache.
  *
  * Requests that miss the cache go through `semaphore`, so the semaphore limits how many
  * callers can reach `delegate` at the same time.
  *
  * @param delegate   provider queried when the cache has no token for the scope
  * @param semaphore  guards the "check cache again, then fetch and store" path
  * @param tokenCache cache of tokens keyed by the requested scope
  */
final class CachingAccessTokenProvider[F[_]: MonadCancelThrow: Clock](
  delegate: AccessTokenProvider[F],
  semaphore: Semaphore[F],
  tokenCache: ExpiringCache[F, Option[Scope], TokenWithExpirationTime]
) extends AccessTokenProvider[F] {

  /** Returns the cached token for `scope` if there is one; otherwise acquires a permit and
    * obtains the token via [[acquireToken]].
    */
  override def requestToken(scope: Option[Scope]): F[ClientCredentialsToken.AccessTokenResponse] = {
    // Holding the permit prevents concurrent token fetches from the external service.
    def underPermit: F[ClientCredentialsToken.AccessTokenResponse] =
      semaphore.permit.surround(acquireToken(scope))

    getFromCache(scope).getOrElseF(underPermit)
  }

  /** Second cache lookup, meant to run while the permit is held: another caller may have
    * filled the cache while this one was waiting, in which case no fetch is needed.
    */
  private def acquireToken(scope: Option[Scope]): F[ClientCredentialsToken.AccessTokenResponse] =
    getFromCache(scope).value.flatMap {
      case Some(cached) => cached.pure[F]
      case None         => fetchAndSaveToken(scope)
    }

  /** Looks `scope` up in the cache and, on a hit, converts the entry to a response whose
    * `expiresIn` is computed against the current clock reading.
    */
  private def getFromCache(scope: Option[Scope]): OptionT[F, ClientCredentialsToken.AccessTokenResponse] =
    OptionT(tokenCache.get(scope)).semiflatMap { cached =>
      Clock[F].realTimeInstant.map(now => cached.toAccessTokenResponse(now))
    }

  /** Requests a token from `delegate`, stores it in the cache under `scope` with the
    * expiration instant derived from the clock reading taken after the response arrived,
    * and returns the delegate's response unchanged.
    */
  private def fetchAndSaveToken(scope: Option[Scope]): F[ClientCredentialsToken.AccessTokenResponse] =
    delegate.requestToken(scope).flatTap { token =>
      Clock[F]
        .realTimeInstant
        .map(now => TokenWithExpirationTime.from(token, now))
        .flatMap(entry => tokenCache.put(scope, entry, entry.expirationTime))
    }

}

/** Constructors for [[CachingAccessTokenProvider]] and the value type stored in its cache. */
object CachingAccessTokenProvider {

  /** Builds a provider around `delegate` and the given `tokenCache`, creating a fresh
    * single-permit semaphore for it.
    */
  def apply[F[_]: Concurrent: Clock](
    delegate: AccessTokenProvider[F],
    tokenCache: ExpiringCache[F, Option[Scope], TokenWithExpirationTime]
  ): F[CachingAccessTokenProvider[F]] =
    for {
      singlePermit <- Semaphore[F](n = 1)
    } yield new CachingAccessTokenProvider[F](delegate, singlePermit, tokenCache)

  /** Builds a provider around `delegate` backed by a newly created `CatsRefExpiringCache`. */
  def refCacheInstance[F[_]: Concurrent: Clock](delegate: AccessTokenProvider[F]): F[CachingAccessTokenProvider[F]] =
    CatsRefExpiringCache[F, Option[Scope], TokenWithExpirationTime]
      .flatMap(cache => CachingAccessTokenProvider(delegate, cache))

  /** Cached form of a token: the same fields as an access token response, except that the
    * relative `expiresIn` is replaced by the absolute instant `expirationTime`.
    */
  final case class TokenWithExpirationTime(
    accessToken: Secret[String],
    domain: Option[String],
    expirationTime: Instant,
    scope: Option[Scope]
  ) {

    /** Converts back to a response whose `expiresIn` is the time from `now` to `expirationTime`. */
    def toAccessTokenResponse(now: Instant): ClientCredentialsToken.AccessTokenResponse = {
      val remaining = java.time.Duration.between(now, expirationTime)
      ClientCredentialsToken.AccessTokenResponse(
        accessToken,
        domain,
        Duration.fromNanos(remaining.toNanos),
        scope
      )
    }

  }

  object TokenWithExpirationTime {

    /** Creates the cached form of `token`, with `expirationTime` set to `now` plus the token's `expiresIn`. */
    def from(token: ClientCredentialsToken.AccessTokenResponse, now: Instant): TokenWithExpirationTime =
      TokenWithExpirationTime(
        accessToken = token.accessToken,
        domain = token.domain,
        expirationTime = now.plusNanos(token.expiresIn.toNanos),
        scope = token.scope
      )

  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy