All downloads are free. Search and download functionality uses the official Maven repository.

no.ks.kes.lib.Saga.kt Maven / Gradle / Ivy

package no.ks.kes.lib

import mu.KotlinLogging
import java.time.Instant
import java.util.*
import kotlin.reflect.KClass

/**
 * Base class for sagas: processes that keep a piece of [STATE] per correlation id and react to events by
 * creating that state ([init]), updating it ([apply]) or scheduling and later handling timeouts ([timeout]).
 *
 * Subclasses register handlers through the protected registration functions below; [getConfiguration] turns the
 * registrations into a validated, indexed [ValidatedSagaConfiguration].
 *
 * @param STATE the type of the saga state.
 * @param stateClass runtime class of [STATE]; handed to the state provider when an existing saga state is loaded.
 * @param serializationId identifier of this saga; used as the `serializationId` of the repository operations it produces.
 */
abstract class Saga<STATE : Any>(private val stateClass: KClass<STATE>, val serializationId: String) {

    // NOTE(review): kept as `var` for source compatibility with subclasses; it is never reassigned in this class.
    /** Registered `init` handlers, each paired with the event class it reacts to. */
    protected var eventInitializers =
        mutableListOf<Pair<KClass<EventData<*>>, EventHandler.Initializer<EventData<*>, STATE>>>()

    /** Registered `apply` handlers (including the scheduling half of every `timeout` registration), by event class. */
    protected val eventApplicators =
        mutableListOf<Pair<KClass<EventData<*>>, EventHandler.Applicator<EventData<*>, STATE>>>()

    /** Handlers run by [ValidatedSagaConfiguration.handleTimeout], keyed by the event class that scheduled the timeout. */
    protected val timeoutApplicators =
        mutableListOf<Pair<KClass<EventData<*>>, (s: ApplyContext<STATE>) -> ApplyContext<STATE>>>()

    /**
     * Validates the registered handlers and indexes them by event serialization id.
     *
     * @param serializationIdFunction maps an event class to the id its handlers are stored under; handlers are only
     *   found for event wrappers whose `serializationId` equals the id produced here.
     * @throws IllegalStateException if a deprecated event is handled or an event type has more than one
     *   `init` or more than one `apply`/`timeout` handler.
     */
    internal fun getConfiguration(serializationIdFunction: (KClass<EventData<*>>) -> String): ValidatedSagaConfiguration<STATE> =
        ValidatedSagaConfiguration(
            stateClass = stateClass,
            sagaSerializationId = serializationId,
            eventInitializers = eventInitializers,
            eventApplicators = eventApplicators,
            timeoutApplicators = timeoutApplicators,
            serializationIdFunction = serializationIdFunction
        )

    /**
     * Registers an initializer for events of type [E]; it runs only when no saga state exists yet for the event's
     * correlation id, and must call `setState` on the [InitContext].
     *
     * @param correlationId computes the saga correlation id from the event data and aggregate id; defaults to the aggregate id.
     * @param initializer creates the initial state and may dispatch commands.
     */
    protected inline fun <reified E : EventData<*>> init(
        crossinline correlationId: (E, UUID) -> UUID = { _: EventData<*>, aggregateId: UUID -> aggregateId },
        crossinline initializer: InitContext<STATE>.(E, UUID) -> Unit
    ) =
        initWrapped<E>(
            { correlationId.invoke(it.event.eventData, it.event.aggregateId) },
            { w: EventWrapper<E> -> initializer.invoke(this, w.event.eventData, w.event.aggregateId) }
        )

    /**
     * Registers a handler for events of type [E] that runs only when a saga state already exists for the event's
     * correlation id.
     *
     * @param correlationId computes the saga correlation id from the event data and aggregate id; defaults to the aggregate id.
     * @param handler may replace the state, dispatch commands or schedule timeouts.
     */
    protected inline fun <reified E : EventData<*>> apply(
        crossinline correlationId: (E, UUID) -> UUID = { _: EventData<*>, aggregateId: UUID -> aggregateId },
        crossinline handler: ApplyContext<STATE>.(E, UUID) -> Unit
    ) =
        applyWrapped<E>(
            { correlationId.invoke(it.event.eventData, it.event.aggregateId) },
            { w: EventWrapper<E> -> handler.invoke(this, w.event.eventData, w.event.aggregateId) }
        )

    /**
     * Registers a timeout for events of type [E]. When such an event is applied to an existing saga, a [Timeout]
     * is scheduled at [timeoutAt]; when that timeout is handled, [handler] runs against the saga state.
     *
     * This occupies the `apply` slot for [E], so a saga cannot register both `apply` and `timeout` for one event type.
     */
    protected inline fun <reified E : EventData<*>> timeout(
        crossinline correlationId: (E, UUID) -> UUID = { _: EventData<*>, aggregateId: UUID -> aggregateId },
        crossinline timeoutAt: (E) -> Instant,
        crossinline handler: ApplyContext<STATE>.() -> Unit
    ) {
        timeoutWrapped<E>(
            { correlationId.invoke(it.event.eventData, it.event.aggregateId) },
            { timeoutAt.invoke(it.event.eventData) },
            handler
        )
    }

    /**
     * Wrapper-level variant of [timeout].
     *
     * The scheduled [Timeout] uses the triggering event's serialization id as its `timeoutId`; the same id keys the
     * timeout handler in [ValidatedSagaConfiguration].
     */
    @Suppress("UNCHECKED_CAST")
    protected inline fun <reified E : EventData<*>> timeoutWrapped(
        crossinline correlationId: (EventWrapper<E>) -> UUID = { it.event.aggregateId },
        crossinline timeoutAt: (EventWrapper<E>) -> Instant,
        crossinline handler: ApplyContext<STATE>.() -> Unit
    ) {
        // Scheduling half: applying the event records a timeout on the apply context.
        eventApplicators.add(E::class as KClass<EventData<*>> to EventHandler.Applicator(
            correlationId = { correlationId.invoke(it as EventWrapper<E>) },
            handler = { e, p -> p.apply { timeouts.add(Timeout(timeoutAt.invoke(e as EventWrapper<E>), e.serializationId)) } }
        ))

        // Firing half: run the user handler and hand the same context back.
        timeoutApplicators.add(E::class as KClass<EventData<*>> to { context -> handler.invoke(context); context })
    }

    /** Wrapper-level variant of [init]: the handlers receive the whole [EventWrapper]. */
    @Suppress("UNCHECKED_CAST")
    protected inline fun <reified E : EventData<*>> initWrapped(
        crossinline correlationId: (EventWrapper<E>) -> UUID = { it.event.aggregateId },
        noinline handler: InitContext<STATE>.(EventWrapper<E>) -> Unit
    ) {
        eventInitializers.add(E::class as KClass<EventData<*>> to EventHandler.Initializer(
            correlationId = { correlationId.invoke(it as EventWrapper<E>) },
            handler = { e, context -> handler.invoke(context, e as EventWrapper<E>); context }
        ))
    }

    /** Wrapper-level variant of [apply]: the handlers receive the whole [EventWrapper]. */
    @Suppress("UNCHECKED_CAST")
    protected inline fun <reified E : EventData<*>> applyWrapped(
        crossinline correlationId: (EventWrapper<E>) -> UUID = { it.event.aggregateId },
        crossinline handler: ApplyContext<STATE>.(EventWrapper<E>) -> Unit
    ) {
        eventApplicators.add(E::class as KClass<EventData<*>> to EventHandler.Applicator(
            correlationId = { correlationId.invoke(it as EventWrapper<E>) },
            handler = { e, context -> handler.invoke(context, e as EventWrapper<E>); context }
        ))
    }

    /**
     * A saga's handlers, validated and indexed by event serialization id.
     *
     * Construction fails with [IllegalStateException] if a deprecated event is handled, or if an event type has
     * more than one `init` handler or more than one `apply`/`timeout` handler.
     */
    class ValidatedSagaConfiguration<STATE : Any>(
        private val stateClass: KClass<STATE>,
        val sagaSerializationId: String,
        serializationIdFunction: (KClass<EventData<*>>) -> String,
        eventInitializers: List<Pair<KClass<EventData<*>>, EventHandler.Initializer<EventData<*>, STATE>>>,
        eventApplicators: List<Pair<KClass<EventData<*>>, EventHandler.Applicator<EventData<*>, STATE>>>,
        timeoutApplicators: List<Pair<KClass<EventData<*>>, (s: ApplyContext<STATE>) -> ApplyContext<STATE>>>
    ) {
        private val eventInitializers: Map<String, EventHandler.Initializer<EventData<*>, STATE>>
        private val eventApplicators: Map<String, EventHandler.Applicator<EventData<*>, STATE>>

        // Duplicate timeouts need no separate check: every timeout registration also adds an applicator.
        private val timeoutApplicators: Map<String, (ApplyContext<STATE>) -> ApplyContext<STATE>> =
            timeoutApplicators.associate { serializationIdFunction.invoke(it.first) to it.second }

        init {
            // `it` is already the event's KClass, so its own simpleName is the event name
            // (`it::class` would be the class of the KClass object itself).
            val deprecatedEvents = (eventApplicators.map { it.first } + eventInitializers.map { it.first })
                .filter { it.deprecated }
                .map { it.simpleName ?: it.toString() }
                .distinct()
            check(deprecatedEvents.isEmpty()) { "Saga $sagaSerializationId handles deprecated event(s) ${deprecatedEvents}, please update the saga configuraton" }

            // Duplicates must be detected on the lists; converting to a map would silently keep only the last one.
            val applicatorIds = eventApplicators.map { serializationIdFunction.invoke(it.first) }
            val duplicateEventApplicators = applicatorIds.groupBy { it }.filterValues { it.size > 1 }.keys.toList()
            check(duplicateEventApplicators.isEmpty()) { "There are multiple \"apply/timeout\" configurations for event-type(s) $duplicateEventApplicators in the configuration of $sagaSerializationId, only a single \"apply/timeout\" handler is allowed for each event type" }
            this.eventApplicators = applicatorIds.zip(eventApplicators.map { it.second }).toMap()

            val initializerIds = eventInitializers.map { serializationIdFunction.invoke(it.first) }
            val duplicateEventInitializers = initializerIds.groupBy { it }.filterValues { it.size > 1 }.keys.toList()
            check(duplicateEventInitializers.isEmpty()) { "There are multiple \"init\" configurations for event-type(s) $duplicateEventInitializers in the configuration of $sagaSerializationId, only a single \"init\" handler is allowed for each event type" }
            this.eventInitializers = initializerIds.zip(eventInitializers.map { it.second }).toMap()
        }

        /**
         * Routes [wrapper] to this saga's handlers.
         *
         * Every handler registered for the event (`init` and/or `apply`) must compute the same correlation id.
         * If [stateProvider] returns no state for that id, the `init` handler (if any) yields an
         * [SagaRepository.Operation.Insert]. Otherwise the `apply`/`timeout` handler (if any) yields an
         * [SagaRepository.Operation.SagaUpdate].
         *
         * @return the operation to persist. Returns `null` if this saga does not handle the event, if no handler
         *   fits the current situation (e.g. only `apply` is registered but no state exists yet), or if the
         *   applied handler neither set a state, dispatched commands, nor scheduled timeouts.
         * @throws IllegalStateException if the handlers disagree on the correlation id, or an `init` handler did not set a state.
         */
        @Suppress("UNCHECKED_CAST")
        fun handleEvent(
            wrapper: EventWrapper<EventData<*>>,
            stateProvider: (correlationId: UUID, stateClass: KClass<*>) -> Any?
        ): SagaRepository.Operation? {
            val initializer = eventInitializers[wrapper.serializationId]
            val applicator = eventApplicators[wrapper.serializationId]

            // Look the handlers up individually: merging the two maps would let the applicator shadow the
            // initializer for the same event and hide a correlation-id mismatch between them.
            val correlationIds = listOfNotNull<EventHandler<EventData<*>>>(initializer, applicator)
                .map { it.correlationId.invoke(wrapper) }
                .distinct()

            val correlationId = when {
                //this saga does not handle this event
                correlationIds.isEmpty() -> return null
                //each handler in a saga must produce the same correlation id
                correlationIds.size > 1 -> error("applying the event ${wrapper.serializationId} to the event-handlers in $sagaSerializationId produced non-identical correlation-ids, please verify the saga configuration")
                else -> correlationIds.single()
            }

            //let's see if there's a state for this saga
            val sagaState = stateProvider.invoke(correlationId, stateClass) as STATE?

            return if (sagaState == null) {
                //non existing saga state, attempting initialization
                initializer?.let {
                    val context = it.handler.invoke(wrapper, InitContext())
                    SagaRepository.Operation.Insert(
                        correlationId = correlationId,
                        serializationId = sagaSerializationId,
                        newState = checkNotNull(context.newState) {
                            "The init handler for event ${wrapper.serializationId} in $sagaSerializationId did not set a saga state"
                        },
                        commands = context.commands
                    )
                }
            } else {
                //pre-existing state, applying
                applicator?.let {
                    val context = it.handler.invoke(wrapper, ApplyContext(sagaState))
                    if (context.newState == null && context.commands.isEmpty() && context.timeouts.isEmpty())
                        null
                    else
                        SagaRepository.Operation.SagaUpdate(
                            correlationId = correlationId,
                            serializationId = sagaSerializationId,
                            newState = context.newState,
                            commands = context.commands,
                            timeouts = context.timeouts.toSet()
                        )
                }
            }
        }

        /**
         * Runs the timeout handler registered for [timeout]'s `timeoutId` against the saga's current state.
         *
         * @return `null` if the timeout belongs to another saga or no timeout handler is registered for its id.
         *   Otherwise a [SagaRepository.Operation.SagaUpdate] is returned, even if the handler changed nothing.
         * @throws IllegalStateException if [stateProvider] has no state for the timeout's correlation id.
         */
        @Suppress("UNCHECKED_CAST")
        internal fun handleTimeout(
            timeout: SagaRepository.Timeout,
            stateProvider: (correlationId: UUID, stateClass: KClass<*>) -> Any?
        ): SagaRepository.Operation.SagaUpdate? {
            if (timeout.sagaSerializationId != sagaSerializationId) return null
            val timeoutHandler = timeoutApplicators[timeout.timeoutId] ?: return null

            val state = stateProvider.invoke(timeout.sagaCorrelationId, stateClass)
                ?: error("A timeout was triggered, but the saga-repository does not contain the saga state: $timeout")
            val context = timeoutHandler.invoke(ApplyContext(state as STATE))

            return SagaRepository.Operation.SagaUpdate(
                correlationId = timeout.sagaCorrelationId,
                serializationId = timeout.sagaSerializationId,
                newState = context.newState,
                commands = context.commands,
                timeouts = context.timeouts.toSet()
            )
        }
    }

    /** A registered handler, together with the function that derives the saga correlation id from an event. */
    sealed class EventHandler<E : EventData<*>> {
        abstract val correlationId: (EventWrapper<E>) -> UUID

        /** Handler applied to an existing saga state. */
        data class Applicator<E : EventData<*>, S : Any>(
            override val correlationId: (EventWrapper<E>) -> UUID,
            val handler: (e: EventWrapper<E>, context: ApplyContext<S>) -> ApplyContext<S>
        ) : EventHandler<E>()

        /** Handler that creates the saga state when none exists yet. */
        data class Initializer<E : EventData<*>, S : Any>(
            override val correlationId: (EventWrapper<E>) -> UUID,
            val handler: (e: EventWrapper<E>, InitContext<S>) -> InitContext<S>
        ) : EventHandler<E>()
    }

    /** A timeout to be triggered at [triggerAt]; [timeoutId] selects the timeout handler to run. */
    data class Timeout(val triggerAt: Instant, val timeoutId: String)

    /** Context passed to `apply`/`timeout` handlers: the current [state] plus the changes the handler requests. */
    class ApplyContext<S>(val state: S) {
        val commands = mutableListOf<Cmd<*>>()
        var newState: S? = null
        val timeouts = mutableListOf<Timeout>()

        /** Queues [cmd] to be emitted with the resulting repository operation. */
        fun dispatch(cmd: Cmd<*>) {
            commands.add(cmd)
        }

        /** Replaces the saga state with [state]. */
        fun setState(state: S) {
            newState = state
        }
    }

    /** Context passed to `init` handlers: the handler must set the initial state and may dispatch commands. */
    class InitContext<S> {
        var newState: S? = null
        val commands = mutableListOf<Cmd<*>>()

        /** Queues [cmd] to be emitted with the resulting insert operation. */
        fun dispatch(cmd: Cmd<*>) {
            commands.add(cmd)
        }

        /** Sets the initial saga state. */
        fun setState(state: S) {
            newState = state
        }
    }
}