All Downloads are FREE. Search and download functionalities are using the official Maven repository.

nl.vroste.zio.kinesis.client.zionative.Consumer.scala Maven / Gradle / Ivy

The newest version!
package nl.vroste.zio.kinesis.client.zionative

import nl.vroste.zio.kinesis.client.Util.ZStreamExtensions
import nl.vroste.zio.kinesis.client.serde.Deserializer
import nl.vroste.zio.kinesis.client.zionative.FetchMode.{ EnhancedFanOut, Polling }
import nl.vroste.zio.kinesis.client.zionative.Fetcher.EndOfShard
import nl.vroste.zio.kinesis.client.zionative.LeaseCoordinator.AcquiredLease
import nl.vroste.zio.kinesis.client.zionative.fetcher.{ EnhancedFanOutFetcher, PollingFetcher }
import nl.vroste.zio.kinesis.client.zionative.leasecoordinator.{ DefaultLeaseCoordinator, LeaseCoordinationSettings }
import nl.vroste.zio.kinesis.client.zionative.leaserepository.DynamoDbLeaseRepository
import nl.vroste.zio.kinesis.client.{ HttpClientBuilder, Record, Util, _ }
import software.amazon.awssdk.services.kinesis.model.{
  KmsThrottlingException,
  LimitExceededException,
  ProvisionedThroughputExceededException
}
import zio._
import zio.aws.cloudwatch.CloudWatch
import zio.aws.kinesis.Kinesis
import zio.aws.kinesis.model._
import zio.aws.kinesis.model.primitives.{ SequenceNumber, ShardId, Timestamp }
import zio.stream.ZStream

import java.time.Instant
import scala.jdk.CollectionConverters._

/**
 * Position of a record within a shard, precise enough to address sub-records of aggregated records.
 *
 * @param sequenceNumber
 *   Kinesis sequence number of the (possibly aggregated) record
 * @param subSequenceNumber
 *   Index of the sub-record within an aggregated record. The consumer stages `0` for records that are not aggregated.
 */
final case class ExtendedSequenceNumber(
  sequenceNumber: String,
  subSequenceNumber: Long
)

/**
 * How the consumer fetches records from Kinesis shards: by polling with GetRecords ([[FetchMode.Polling]]) or via
 * enhanced fan-out subscriptions ([[FetchMode.EnhancedFanOut]]).
 */
sealed trait FetchMode
object FetchMode {

  /**
   * Fetches records in a polling manner
   *
   * @param batchSize
   *   The maximum number of records to retrieve in one call to GetRecords. Note that Kinesis defines limits in terms of
   *   the maximum size in bytes of this call, so you need to take into account the distribution of data size of your
   *   records (i.e. avg and max).
   * @param pollSchedule
   *   Schedule for polling. The default schedule repeats immediately when there are more records available
   *   (millisBehindLatest > 0), otherwise it polls at a fixed interval of 1 second
   * @param throttlingBackoff
   *   When getting a ProvisionedThroughputExceededException or KmsThrottlingException, schedule to apply for backoff.
   *   Although zio-kinesis will make no more than 5 calls to GetRecords per second (the AWS limit), some limits depend
   *   on the size of the records being fetched.
   * @param retrySchedule
   *   Schedule for retrying in case of non-throttling related issues
   * @param bufferNrBatches
   *   The number of fetched batches (chunks) to buffer. A buffer allows downstream to process the records while a new
   *   poll call is being made concurrently. A batch will contain up to `batchSize` records. Prefer powers of 2 for this
   *   value for performance reasons.
   */
  final case class Polling(
    batchSize: Int = 1000,
    pollSchedule: Schedule[Any, GetRecordsResponse.ReadOnly, Any] = Polling.dynamicSchedule(1.second),
    throttlingBackoff: Schedule[Any, Any, (Duration, Long)] = Util.exponentialBackoff(5.seconds, 30.seconds),
    retrySchedule: Schedule[Any, Any, (Duration, Long)] = Util.exponentialBackoff(1.second, 1.minute),
    bufferNrBatches: Int = 2
  ) extends FetchMode

  object Polling {

    /**
     * Creates a polling schedule that immediately repeats when there are more records available (millisBehindLatest is
     * present and non-zero), otherwise polls at a fixed interval.
     *
     * The two schedules are combined with `||` (union), so the shortest delay wins: the zero-delay `recurWhile` branch
     * while the shard is behind, the fixed `interval` once it has caught up.
     *
     * @param interval
     *   Fixed interval for polling when no more records are currently available
     */
    def dynamicSchedule(interval: Duration): Schedule[Any, GetRecordsResponse.ReadOnly, Any] =
      (Schedule.recurWhile[Boolean](identity) || Schedule.fixed(interval))
        // A missing millisBehindLatest is treated as "caught up"
        .contramap((_: GetRecordsResponse.ReadOnly).millisBehindLatest.getOrElse(0) != 0)
  }

  /**
   * Fetch data using enhanced fanout
   *
   * @param deregisterConsumerAtShutdown
   *   Passed to `EnhancedFanOutFetcher`; presumably controls whether the enhanced fan-out stream consumer registered for
   *   this worker is deregistered when the consumer shuts down (confirm in `EnhancedFanOutFetcher`)
   * @param maxSubscriptionsPerSecond
   *   Passed to `EnhancedFanOutFetcher`; presumably an upper bound on the rate of shard subscription calls (confirm in
   *   `EnhancedFanOutFetcher`)
   * @param retrySchedule
   *   Schedule for retrying in case of connection issues
   */
  final case class EnhancedFanOut(
    deregisterConsumerAtShutdown: Boolean = true,
    maxSubscriptionsPerSecond: Int = 10,
    retrySchedule: Schedule[Any, Any, (Duration, Long)] = Util.exponentialBackoff(5.seconds, 1.minute)
  ) extends FetchMode
}

object Consumer {

  /**
   * Creates a stream that emits streams for each Kinesis shard that this worker holds a lease for.
   *
   * Upon initialization, a lease table is created if it does not yet exist. The shards of the stream are looked up
   * (using DescribeStream, falling back to ListShards when the description does not contain all shards) and handed to
   * the lease coordinator, which claims leases for them (see MULTIPLE WORKERS).
   *
   * CHECKPOINTING
   *
   * Clients should periodically checkpoint their progress using the `Checkpointer`. Each processed record may be staged
   * with the Checkpointer to ensure that when the stream is interrupted, the last staged record will be checkpointed.
   *
   * When the stream is interrupted, the last staged checkpoint for each shard will be checkpointed and leases for that
   * shard are released.
   *
   * Checkpointing may fail by two (expected) causes:
   *   - Connection failures
   *   - Shard lease taken by another worker (see MULTIPLE WORKERS)
   *
   * In both cases, clients should end the shard stream by catching the error and continuing with an empty ZStream.
   *
   * MULTIPLE WORKERS
   *
   * Upon initialization, the Consumer will check existing leases to see how many workers are currently active. It will
   * immediately steal its fair share of leases from other workers, in such a way that all workers end up with a new
   * fair share of leases. When it is the only active worker, it will take all leases. The leases per worker are
   * randomized to reduce the chance of lease stealing contention.
   *
   * This procedure is safe against multiple workers initializing concurrently.
   *
   * Leases are periodically renewed and leases of other workers are refreshed. When another worker's lease has not been
   * updated for some time, it is considered expired and the worker considered a zombie. The new fair share of leases
   * for all workers is then determined and this worker will try to claim some of the expired leases.
   *
   * SHARD STREAMS
   *
   * A shard stream starts at the shard's checkpoint when one exists, otherwise at `initialPosition`. Records at or
   * before the checkpoint (including already checkpointed sub-records of aggregated records) are skipped. A shard
   * stream completes when the end of the shard is reached (the shard is then marked as ended and its child shards are
   * reported to the lease coordinator) or when the lease for that shard is lost.
   *
   * CONNECTION FAILURES
   *
   * The consumer will keep on running when there are connection issues to Kinesis or DynamoDB. An exponential backoff
   * schedule (user-customizable) is applied to retry in case of such failures. When leases expire due to being unable
   * to renew them under these circumstances, the shard lease is released and the shard stream is ended. When the
   * connection is restored, the lease coordinator will try to take leases again to get to the target number of leases.
   *
   * DIAGNOSTIC EVENTS
   *
   * An optional function `emitDiagnostic` can be passed to be called when interesting events happen in the Consumer.
   * This is useful for logging and for metrics.
   *
   * @param streamIdentifier
   *   Stream to consume from. Either just the name or the whole arn.
   * @param applicationName
   *   Name of the application. This is used as the table name for lease coordination (DynamoDB)
   * @param deserializer
   *   Record deserializer
   * @param workerIdentifier
   *   Identifier of this worker, used for lease coordination
   * @param fetchMode
   *   How to fetch records: Polling or EnhancedFanOut, including config parameters
   * @param leaseCoordinationSettings
   *   Config parameters for lease coordination
   * @param initialPosition
   *   When no checkpoint exists yet for a shard, start processing from this position
   * @param emitDiagnostic
   *   Function that is called for events happening in the Consumer. For diagnostics / metrics.
   * @param shardAssignmentStrategy
   *   How to assign shards to this worker
   * @tparam R
   *   ZIO environment required by the `deserializer`. It is captured when a shard stream is created and provided to
   *   that stream, so the emitted shard streams themselves require no environment.
   * @tparam T
   *   Record type
   * @return
   *   Stream of tuples of (shard ID, shard stream, checkpointer). When a shard stream terminates, the last staged
   *   checkpoint is checkpointed and the shard lease is released.
   */
  def shardedStream[R, T](
    streamIdentifier: StreamIdentifier,
    applicationName: String,
    deserializer: Deserializer[R, T],
    workerIdentifier: String = "worker1",
    fetchMode: FetchMode = FetchMode.Polling(),
    leaseCoordinationSettings: LeaseCoordinationSettings = LeaseCoordinationSettings(),
    initialPosition: InitialPosition = InitialPosition.TrimHorizon,
    emitDiagnostic: DiagnosticEvent => UIO[Unit] = _ => ZIO.unit,
    shardAssignmentStrategy: ShardAssignmentStrategy = ShardAssignmentStrategy.balanced()
  ): ZStream[
    Kinesis with LeaseRepository with R,
    Throwable,
    (
      String,
      ZStream[
        Any,
        Throwable,
        Record[T]
      ],
      Checkpointer
    )
  ] = {

    // Converts one Kinesis record into consumer records: an aggregated record is decoded into its sub-records (each
    // tagged with its index as sub-sequence number), any other record yields exactly one record.
    def toRecords(
      shardId: String,
      r: zio.aws.kinesis.model.Record.ReadOnly
    ): ZIO[R, Throwable, Chunk[Record[T]]] =
      // approximateArrivalTimestamp is optional in the AWS model. Fail with a descriptive, typed error instead of
      // throwing from Option.get, which would surface as a ZIO defect.
      ZIO
        .fromOption(r.approximateArrivalTimestamp.toOption)
        .orElseFail(
          new IllegalStateException(
            s"Record ${r.sequenceNumber} in shard ${shardId} has no approximate arrival timestamp"
          )
        )
        .flatMap { arrivalTimestamp =>
          val dataChunk = r.data

          if (ProtobufAggregation.isAggregatedRecord(dataChunk))
            for {
              aggregatedRecord <- ZIO.fromTry(ProtobufAggregation.decodeAggregatedRecord(dataChunk))
              // Sequenced with `<-`: binding the effect with `=` would build it without ever running it
              _                <- ZIO.logDebug(s"Found aggregated record with ${aggregatedRecord.getRecordsCount} sub records")
              records          <- ZIO.foreach(aggregatedRecord.getRecordsList.asScala.zipWithIndex.toSeq) {
                                    case (subRecord, subSequenceNr) =>
                                      val data = Chunk.fromByteBuffer(subRecord.getData.asReadOnlyByteBuffer())

                                      deserializer
                                        .deserialize(data)
                                        .map { data =>
                                          Record(
                                            shardId,
                                            r.sequenceNumber,
                                            arrivalTimestamp,
                                            data,
                                            aggregatedRecord.getPartitionKeyTable(subRecord.getPartitionKeyIndex.toInt),
                                            r.encryptionType.toOption,
                                            Some(subSequenceNr.toLong),
                                            if (subRecord.hasExplicitHashKeyIndex)
                                              Some(aggregatedRecord.getExplicitHashKeyTable(subRecord.getExplicitHashKeyIndex.toInt))
                                            else None,
                                            aggregated = true
                                          )
                                        }
                                  }
            } yield Chunk.fromIterable(records)
          else
            deserializer
              .deserialize(dataChunk)
              .map { data =>
                Record(
                  shardId,
                  r.sequenceNumber,
                  arrivalTimestamp,
                  data,
                  r.partitionKey,
                  r.encryptionType.toOption,
                  subSequenceNumber = None,
                  explicitHashKey = None,
                  aggregated = false
                )
              }
              .map(Chunk.single)
        }

    // Creates the fetcher for the configured fetch mode, addressing the stream by the ARN from its description
    def makeFetcher(
      streamDescription: StreamDescription.ReadOnly
    ): ZIO[Scope with Kinesis, Throwable, Fetcher] =
      fetchMode match {
        case c: Polling        => PollingFetcher.make(StreamIdentifier.fromARN(streamDescription.streamARN), c, emitDiagnostic)
        case c: EnhancedFanOut =>
          EnhancedFanOutFetcher.make(
            StreamIdentifier.fromARN(streamDescription.streamARN),
            workerIdentifier,
            c,
            emitDiagnostic
          )
      }

    // All shards of the stream keyed by shard ID; fails when the stream has no shards at all
    val listShards: ZIO[Kinesis, Throwable, Map[ShardId, Shard.ReadOnly]] = Kinesis
      .listShards(ListShardsRequest(streamName = streamIdentifier.name, streamARN = streamIdentifier.arn))
      .mapError(_.toThrowable)
      .runCollect
      .map(_.map(l => (l.shardId, l)).toMap)
      .flatMap { shards =>
        if (shards.isEmpty) ZIO.fail(new Exception("No shards in stream!"))
        else ZIO.succeed(shards)
      }

    // Describes the stream once (in a scoped fiber) and uses the result both for the fetcher and the initial shard list
    def createDependencies: ZIO[
      Kinesis with Scope with LeaseRepository,
      Throwable,
      (Fetcher, LeaseCoordinator)
    ] =
      Kinesis
        .describeStream(DescribeStreamRequest(streamName = streamIdentifier.name, streamARN = streamIdentifier.arn))
        .mapError(_.toThrowable)
        .map(_.streamDescription)
        .forkScoped // joined later
        .flatMap { streamDescriptionFib =>
          // The description only contains all shards when hasMoreShards is false; otherwise list them explicitly
          val fetchInitialShards = streamDescriptionFib.join.flatMap { streamDescription =>
            if (!streamDescription.hasMoreShards)
              ZIO.succeed(streamDescription.shards.map(s => s.shardId -> s).toMap)
            else
              listShards
          }

          streamDescriptionFib.join.flatMap(makeFetcher) zipPar (
            // Fetch shards and initialize the lease coordinator at the same time
            // When we have the shards, we inform the lease coordinator. When the lease table
            // still has to be created, we have the shards in time for lease claiming begins.
            // If not in time, the next cycle of takeLeases will take care of it
            // When the lease table already exists, the updateShards call will not provide
            // additional information to the lease coordinator, and the list of leases is used
            // as the list of shards.
            for {
              env              <- ZIO.environment[Kinesis]
              leaseCoordinator <- DefaultLeaseCoordinator
                                    .make(
                                      applicationName,
                                      workerIdentifier,
                                      emitDiagnostic,
                                      leaseCoordinationSettings,
                                      fetchInitialShards.provideEnvironment(env),
                                      listShards.provideEnvironment(env),
                                      shardAssignmentStrategy,
                                      initialPosition
                                    )
              _                <- ZIO.logInfo("Lease coordinator created")
            } yield leaseCoordinator
          )
        }

    ZStream.logAnnotate("worker", workerIdentifier) *>
      ZStream.unwrapScoped {
        createDependencies.map { case (fetcher, leaseCoordinator) =>
          leaseCoordinator.acquiredLeases.collect { case AcquiredLease(shardId, leaseLost) =>
            (shardId, leaseLost)
          }
            .mapZIOParUnordered(leaseCoordinationSettings.maxParallelLeaseAcquisitions) { case (shardId, leaseLost) =>
              for {
                checkpointer    <- leaseCoordinator.makeCheckpointer(shardId)
                env             <- ZIO.environment[R]
                checkpointOpt   <- leaseCoordinator.getCheckpointForShard(shardId)
                startingPosition = checkpointOpt
                                     .map(checkpointToStartingPosition(_, initialPosition))
                                     .getOrElse(InitialPosition.toStartingPosition(initialPosition))
                shardStream      = fetcher
                                     .shardRecordStream(ShardId(shardId), startingPosition)
                                     .catchAll {
                                       case Left(e)                        =>
                                         ZStream.fromZIO(
                                           ZIO.logErrorCause(s"Shard stream ${shardId} failed", Cause.fail(e))
                                         ) *>
                                           ZStream.fail(e)
                                       case Right(EndOfShard(childShards)) =>
                                         // End of shard is a normal completion: record it and hand over the children
                                         ZStream.fromZIO(
                                           ZIO.logDebug(
                                             s"Found end of shard for ${shardId}. " +
                                               s"Child shards are ${childShards.map(_.shardId).mkString(", ")}"
                                           ) *>
                                             checkpointer.markEndOfShard() *>
                                             leaseCoordinator.childShardsDetected(childShards)
                                         ) *> ZStream.empty
                                     }
                                     .mapChunksZIO {
                                       _.mapZIO(record => toRecords(shardId, record))
                                         .map(_.flatten)
                                     }
                                     // AT_SEQUENCE_NUMBER re-reads the checkpointed record; skip everything up to and
                                     // including the checkpoint
                                     .dropWhile(r => !checkpointOpt.forall(aggregatedRecordIsAfterCheckpoint(r, _)))
                                     .mapChunksZIO { chunk =>
                                       chunk.lastOption
                                         .fold(ZIO.unit) { r =>
                                           val extendedSequenceNumber =
                                             ExtendedSequenceNumber(
                                               r.sequenceNumber,
                                               r.subSequenceNumber.getOrElse(0L)
                                             )
                                           checkpointer.setMaxSequenceNumber(extendedSequenceNumber)
                                         }
                                         .as(chunk)
                                     }
                                     .terminateOnPromiseCompleted(leaseLost)
              } yield (
                shardId,
                shardStream.ensuring {
                  checkpointer.checkpointAndRelease.catchAll {
                    case Left(e)               =>
                      ZIO.logWarning(s"Error in checkpoint and release: ${e}").unit
                    case Right(ShardLeaseLost) =>
                      ZIO.unit // This is fine during shutdown
                  }
                }.provideEnvironment(env),
                checkpointer
              )
            }
        }
      }
  }

  /**
   * Apply an effectful function to each record in a stream
   *
   * This is the easiest way to consume Kinesis records from a stream, while benefiting from all of Consumer's features
   * like parallel streaming, checkpointing and resharding.
   *
   * Simply provide an effectful function that is applied to each record and the rest is taken care of. All shard
   * streams are processed concurrently, and a record is staged for checkpointing only after `recordProcessor` has
   * succeeded for it.
   *
   * @param streamIdentifier
   *   Stream to consume from. Either just the name or the whole arn.
   * @param applicationName
   *   Name of the application. This is used as the table name for lease coordination (DynamoDB)
   * @param deserializer
   *   Record deserializer
   * @param workerIdentifier
   *   Identifier of this worker, used for lease coordination
   * @param fetchMode
   *   How to fetch records: Polling or EnhancedFanOut, including config parameters
   * @param leaseCoordinationSettings
   *   Config parameters for lease coordination
   * @param initialPosition
   *   When no checkpoint exists yet for a shard, start processing from this position
   * @param emitDiagnostic
   *   Function that is called for events happening in the Consumer. For diagnostics / metrics.
   * @param shardAssignmentStrategy
   *   How to assign shards to this worker
   * @param checkpointBatchSize
   *   Maximum number of records before checkpointing
   * @param checkpointDuration
   *   Maximum interval before checkpointing
   * @param recordProcessor
   *   A function for processing a `Record[T]`. A failure fails the whole consumer.
   * @tparam R
   *   ZIO environment type required by the `deserializer`
   * @tparam RC
   *   ZIO environment type required by the `recordProcessor`
   * @tparam T
   *   Type of record values
   * @return
   *   A ZIO that completes with Unit when record processing is stopped or fails when the consumer stream fails
   */
  def consumeWith[R, RC, T](
    streamIdentifier: StreamIdentifier,
    applicationName: String,
    deserializer: Deserializer[R, T],
    workerIdentifier: String = "worker1",
    fetchMode: FetchMode = FetchMode.Polling(),
    leaseCoordinationSettings: LeaseCoordinationSettings = LeaseCoordinationSettings(),
    initialPosition: InitialPosition = InitialPosition.TrimHorizon,
    emitDiagnostic: DiagnosticEvent => UIO[Unit] = _ => ZIO.unit,
    shardAssignmentStrategy: ShardAssignmentStrategy = ShardAssignmentStrategy.balanced(),
    checkpointBatchSize: Long = 200,
    checkpointDuration: Duration = 5.minutes
  )(
    recordProcessor: Record[T] => RIO[RC, Unit]
  ): ZIO[
    R with RC with Kinesis with LeaseRepository,
    Throwable,
    Unit
  ] =
    shardedStream(
      streamIdentifier,
      applicationName,
      deserializer,
      workerIdentifier,
      fetchMode,
      leaseCoordinationSettings,
      initialPosition,
      emitDiagnostic,
      shardAssignmentStrategy
    ).flatMapPar(Int.MaxValue) { case (_, shardStream, checkpointer) =>
      shardStream
        .tap(record => recordProcessor(record) *> checkpointer.stage(record))
        .viaFunction(
          checkpointer.checkpointBatched[RC](nr = checkpointBatchSize, interval = checkpointDuration)
        )
    }.runDrain

  /** Matches the Kinesis exceptions that signal throttling. */
  private[zionative] val isThrottlingException: PartialFunction[Throwable, Unit] = {
    case _: KmsThrottlingException                 => ()
    case _: ProvisionedThroughputExceededException => ()
    case _: LimitExceededException                 => ()
  }

  /**
   * Restricts `schedule` to throttling errors. The combination uses `&&` (intersection), so it recurs only while the
   * error is a throttling exception AND `schedule` itself still recurs.
   */
  private[zionative] def retryOnThrottledWithSchedule[R, A](
    schedule: Schedule[R, Throwable, A]
  ): Schedule[R, Throwable, (Throwable, A)] =
    Schedule.recurWhile[Throwable](isThrottlingException.isDefinedAt) && schedule

  /**
   * Converts a child shard (as reported at the end of a parent shard) into a `Shard`.
   *
   * The sequence number range is a placeholder starting at "0", presumably because a `ChildShard` does not carry one.
   *
   * @throws IllegalArgumentException
   *   when the child shard does not have exactly one or two parent shards
   */
  private[client] def childShardToShard(s: ChildShard.ReadOnly): Shard.ReadOnly = {
    val parentShards = s.parentShards

    val shard = Shard(
      s.shardId,
      hashKeyRange = s.hashKeyRange.asEditable,
      sequenceNumberRange = SequenceNumberRange(SequenceNumber("0"), None)
    )

    // Two parents result from a shard merge, one parent from a split
    val shardWithParents =
      if (parentShards.size == 2)
        shard.copy(parentShardId = Some(parentShards.head), adjacentParentShardId = Some(parentShards(1)))
      else if (parentShards.size == 1)
        shard.copy(parentShardId = Some(parentShards.head))
      else
        throw new IllegalArgumentException(s"Unexpected nr of parent shards: ${parentShards.size}")

    shardWithParents.asReadOnly
  }

  /**
   * Layer providing Kinesis, a DynamoDB-backed `LeaseRepository` and CloudWatch, built from the default AWS config on
   * top of an HTTP client created with `HttpClientBuilder.make()`.
   */
  val defaultEnvironment: ZLayer[Any, Throwable, Kinesis with LeaseRepository with CloudWatch] =
    HttpClientBuilder.make() >>> zio.aws.core.config.AwsConfig.default >>>
      (kinesisAsyncClientLayer() ++ (dynamoDbAsyncClientLayer() >>> DynamoDbLeaseRepository.live) ++ cloudWatchAsyncClientLayer())

  /** Where to start reading a shard for which no checkpoint exists yet. */
  sealed trait InitialPosition
  object InitialPosition {
    case object Latest                         extends InitialPosition
    case object TrimHorizon                    extends InitialPosition
    case class AtTimestamp(timestamp: Instant) extends InitialPosition

    /** Converts an initial position into the equivalent Kinesis starting position. */
    def toStartingPosition(p: InitialPosition): StartingPosition =
      p match {
        case InitialPosition.Latest                 => StartingPosition(ShardIteratorType.LATEST)
        case InitialPosition.TrimHorizon            => StartingPosition(ShardIteratorType.TRIM_HORIZON)
        case InitialPosition.AtTimestamp(timestamp) =>
          StartingPosition(ShardIteratorType.AT_TIMESTAMP, timestamp = Some(Timestamp(timestamp)))
      }
  }

  /**
   * Converts a stored checkpoint into the position to start reading a shard from.
   *
   * A special `AtTimestamp` checkpoint takes its timestamp from the `InitialPosition`, which must then be
   * `AtTimestamp`. A sequence number checkpoint starts AT (not after) that sequence number; the shard stream skips
   * already checkpointed records with [[aggregatedRecordIsAfterCheckpoint]].
   *
   * @throws IllegalArgumentException
   *   for any checkpoint / initial position combination not covered above
   */
  private[zionative] val checkpointToStartingPosition
    : (Either[SpecialCheckpoint, ExtendedSequenceNumber], InitialPosition) => StartingPosition = {
    case (Left(SpecialCheckpoint.TrimHorizon), _)                                      => StartingPosition(ShardIteratorType.TRIM_HORIZON)
    case (Left(SpecialCheckpoint.Latest), _)                                           => StartingPosition(ShardIteratorType.LATEST)
    case (Left(SpecialCheckpoint.AtTimestamp), InitialPosition.AtTimestamp(timestamp)) =>
      StartingPosition(ShardIteratorType.AT_TIMESTAMP, timestamp = Some(Timestamp(timestamp)))
    case (Right(s), _)                                                                 =>
      StartingPosition(
        ShardIteratorType.AT_SEQUENCE_NUMBER,
        sequenceNumber = Some(SequenceNumber(s.sequenceNumber))
      )
    case other                                                                         =>
      throw new IllegalArgumentException(s"${other} is not a valid checkpoint as starting position")
  }

  /**
   * Whether `record` comes strictly after `checkpoint` and therefore still needs processing.
   *
   * Special checkpoints never filter anything out. Otherwise the sequence numbers are compared numerically. A record
   * with the same sequence number as the checkpoint is only after it when it is a sub-record of an aggregated record
   * with a higher sub-sequence number.
   */
  private[zionative] def aggregatedRecordIsAfterCheckpoint(
    record: Record[_],
    checkpoint: Either[SpecialCheckpoint, ExtendedSequenceNumber]
  ): Boolean =
    checkpoint match {
      case Left(_)                                                          => true
      case Right(ExtendedSequenceNumber(sequenceNumber, subSequenceNumber)) =>
        val recordSequenceNr     = BigInt(record.sequenceNumber)
        val checkpointSequenceNr = BigInt(sequenceNumber)

        (recordSequenceNr > checkpointSequenceNr) ||
        (recordSequenceNr == checkpointSequenceNr && record.subSequenceNumber.exists(_ > subSequenceNumber))
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy