All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.gu.contentapi.firehose.kinesis.SingleEventProcessor.scala Maven / Gradle / Ivy

package com.gu.contentapi.firehose.kinesis

import com.gu.thrift.serializer.ThriftDeserializer
import com.typesafe.scalalogging.LazyLogging
import com.twitter.scrooge.{ ThriftStruct, ThriftStructCodec }
import software.amazon.kinesis.lifecycle.ShutdownReason
import software.amazon.kinesis.lifecycle.events.{ InitializationInput, LeaseLostInput, ProcessRecordsInput, ShardEndedInput, ShutdownRequestedInput }
import software.amazon.kinesis.processor.{ RecordProcessorCheckpointer, ShardRecordProcessor }

import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import java.util.Base64
import java.util.concurrent.atomic.{ AtomicInteger, AtomicLong }
import scala.jdk.CollectionConverters._
import scala.concurrent.duration._
import scala.util.control.NonFatal
import scala.util.{ Failure, Success, Try }

/**
 * Base KCL 2.x shard record processor for streams whose records carry Thrift-serialized events.
 *
 * Each record's payload is deserialized into an `EventT`; records that fail to deserialize are
 * logged (together with a Base64 dump of their payload) and skipped. The decoded events of every
 * batch are passed to [[processEvents]], after which the processor checkpoints if either
 * `maxCheckpointBatchSize` records have been consumed or `checkpointInterval` has elapsed since
 * the last checkpoint.
 *
 * @tparam EventT the Thrift struct type carried by the stream; an implicit `ThriftStructCodec`
 *                for it must be available where the processor is constructed.
 */
abstract class EventProcessor[EventT <: ThriftStruct: ThriftStructCodec]
  extends ShardRecordProcessor
  with LazyLogging {

  /**
   * Maximum time allowed between checkpoints. A non-finite value (e.g. `Duration.Inf`) disables
   * the time-based trigger, leaving only the batch-size one.
   */
  val checkpointInterval: Duration

  /** Number of consumed records after which a checkpoint is taken regardless of elapsed time. */
  val maxCheckpointBatchSize: Int

  private[this] var shardId: String = _

  /* Use atomic to prevent any concurrent access issues */
  private[this] val lastCheckpointedAt = new AtomicLong(System.nanoTime())
  private[this] val recordsProcessedSinceCheckpoint = new AtomicInteger()

  /** Remembers the shard this processor is bound to, for use in log messages. */
  override def initialize(input: InitializationInput): Unit = {
    this.shardId = input.shardId()
    logger.info(s"Initialized an event processor for shard $shardId")
  }

  /**
   * Decodes a batch of records, hands the successfully decoded events to [[processEvents]] and
   * checkpoints if either checkpoint trigger has fired.
   */
  override def processRecords(input: ProcessRecordsInput): Unit = {
    val records = input.records().asScala

    // .toSeq is required on Scala 2.13: flatMap over the converted Java list yields a
    // mutable.Buffer, which is not directly compatible with Seq.
    val events = records.flatMap(record => deserializeRecord(record.data())).toSeq

    processEvents(events)

    // Count every consumed record, including undecodable ones, so the batch-size trigger tracks
    // progress through the stream rather than only the number of valid events.
    recordsProcessedSinceCheckpoint.addAndGet(records.size)

    if (shouldCheckpointNow) {
      checkpoint(input.checkpointer())
    }
  }

  /**
   * Handles the decoded events of one batch.
   *
   * @param events the successfully deserialized events, in record order; may be empty
   */
  protected def processEvents(events: Seq[EventT]): Unit

  /**
   * Deserializes a single record payload, logging and discarding it on failure.
   *
   * @param payload the record's data buffer
   * @return the decoded event, or `None` if the payload could not be deserialized
   */
  private def deserializeRecord(payload: ByteBuffer): Option[EventT] = {
    // Keep an independent view of the bytes as received, since deserialization may move
    // `payload`'s position before failing.
    val original = payload.duplicate()
    ThriftDeserializer.deserialize(payload) match {
      case Success(event) => Some(event)
      case Failure(e) =>
        logger.error(s"deserialization of event buffer failed: ${e.getMessage}", e)
        val encoded = Base64.getEncoder.encode(original)
        val b64string = new String(encoded.array(), StandardCharsets.ISO_8859_1)
        logger.error(s"Offending binary content: $b64string")
        None
    }
  }

  /* Checkpoint after every X seconds or every Y records */
  private def shouldCheckpointNow: Boolean =
    recordsProcessedSinceCheckpoint.get() >= maxCheckpointBatchSize || checkpointIntervalElapsed

  /** True once more than `checkpointInterval` has passed since the last checkpoint. */
  private def checkpointIntervalElapsed: Boolean =
    checkpointInterval match {
      // Compare elapsed time, not raw timestamps: System.nanoTime() readings can overflow, and
      // only the difference between two readings is meaningful.
      case interval: FiniteDuration =>
        System.nanoTime() - lastCheckpointedAt.get() > interval.toNanos
      // Infinite/undefined durations have no nanosecond value (toNanos would throw): never fire.
      case _ => false
    }

  /** Checkpoints at the latest record handed to this processor and resets both triggers. */
  private def checkpoint(checkpointer: RecordProcessorCheckpointer): Unit = {
    /* Store our latest position in the stream */
    checkpointer.checkpoint()

    /* Reset the counters */
    lastCheckpointedAt.set(System.nanoTime())
    recordsProcessedSinceCheckpoint.set(0)
  }

  /**
   * Checkpoints during shard end or a requested shutdown. Failures are logged rather than
   * rethrown so that the logging and [[shutdown]] hook that follow still run.
   *
   * @param checkpointer the checkpointer supplied with the lifecycle event
   * @param context      short description of the lifecycle event, used in the log message
   */
  private def checkpointBeforeShutdown(
      checkpointer: RecordProcessorCheckpointer,
      context: String
  ): Unit =
    try checkpointer.checkpoint()
    catch {
      case NonFatal(e) => logger.error(s"Failed to checkpoint shard $shardId $context", e)
    }

  /** Called when this worker has lost the lease; the KCL no longer allows checkpointing here. */
  override def leaseLost(leaseLostInput: LeaseLostInput): Unit = {
    logger.info(s"Shutdown event processor for shard $shardId because lease was lost")
    shutdown(ShutdownReason.LEASE_LOST)
  }

  /**
   * Called once the shard has been completely read. The KCL requires a checkpoint here; without
   * one it raises an error and the shard's children are not processed.
   */
  override def shardEnded(shardEndedInput: ShardEndedInput): Unit = {
    checkpointBeforeShutdown(shardEndedInput.checkpointer(), "at shard end")
    logger.info(s"Shutdown event processor for shard $shardId because the shard ended")
    shutdown(ShutdownReason.SHARD_END)
  }

  /** Called when the scheduler is shutting down; checkpoints while the lease is still held. */
  override def shutdownRequested(shutdownRequestedInput: ShutdownRequestedInput): Unit = {
    checkpointBeforeShutdown(shutdownRequestedInput.checkpointer(), "on requested shutdown")
    logger.info(s"Shutdown event processor for shard $shardId because shutdown was requested")
    shutdown(ShutdownReason.REQUESTED)
  }

  /**
   * Override this method if you want to be informed of shutdown events. The default
   * implementation does nothing.
   *
   * @param reason ShutdownReason indicating why the shutdown occurred
   */
  def shutdown(reason: ShutdownReason): Unit = {}
}

/**
 * An [[EventProcessor]] that handles the decoded events of each batch one at a time, in the
 * order in which they appear in the batch.
 *
 * @tparam EventT the Thrift struct type carried by the stream
 */
trait SingleEventProcessor[EventT <: ThriftStruct] extends EventProcessor[EventT] {

  /** Dispatches every event of the batch, in order, to [[processEvent]]. */
  override protected def processEvents(events: Seq[EventT]): Unit =
    events.foreach(event => processEvent(event))

  /**
   * Handles a single decoded event.
   *
   * @param eventWithSize the decoded event; despite its name this is just the `EventT` value,
   *                      with no size information attached
   */
  protected def processEvent(eventWithSize: EventT): Unit

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy