All Downloads are FREE. Search and download functionalities are using the official Maven repository.

no.ks.kes.test.KesTestSetup.kt Maven / Gradle / Ivy

The newest version!
package no.ks.kes.test

import no.ks.kes.lib.*
import no.ks.kes.serdes.jackson.JacksonCmdSerdes
import no.ks.kes.serdes.jackson.JacksonEventSerdes
import no.ks.kes.serdes.jackson.JacksonSagaStateSerdes
import java.nio.charset.StandardCharsets
import java.time.Instant
import java.util.*
import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicLong
import kotlin.reflect.KClass

private val LOG = mu.KotlinLogging.logger { }

suspend fun withKes(eventSerdes: EventSerdes, cmdSerdes: CmdSerdes, context: suspend (KesTestSetup) -> Unit) {
    KesTestSetup(eventSerdes, cmdSerdes).use {
        context.invoke(it)
    }
}

suspend fun withKes(cmds: Set>>, events: Set>>, context: suspend (KesTestSetup) -> Unit) {
    withKes(eventSerdes = JacksonEventSerdes(events = events), cmdSerdes = JacksonCmdSerdes(cmds), context)
}

class KesTestSetup(val eventSerdes: EventSerdes, val cmdSerdes: CmdSerdes) : AutoCloseable {
    val eventStream = TestEventStream()

    val subscriberFactory: EventSubscriberFactory by lazy {
        TestEventSubscriberFactory(eventSerdes, eventStream)
    }

    val aggregateRepository: AggregateRepository by lazy {
        TestAggregateRepository(eventSerdes, eventStream)
    }

    fun createCommandQueue(cmdHandlers: Set> = emptySet()) = TestCommandQueue(cmdHandlers, cmdSerdes)
    fun createSagaRepository(commandQueue: TestCommandQueue, sagaStateSerdes: SagaStateSerdes = JacksonSagaStateSerdes()) =
            TestSagaRepository(sagaStateSerdes = sagaStateSerdes) { aggregateId, command ->
        commandQueue.addToQueue(aggregateId, command)
    }

    val projectionRepository = TestProjectionRepository()

    override fun close() {
        eventStream.close()
    }
}

data class AggregateKey(val type: String, val aggregateId: UUID)
class TestEventStream : AutoCloseable {
    private val stream: MutableMap>> = mutableMapOf()
    private val listeners = mutableSetOf()

    fun add(aggregateKey: AggregateKey, eventWrappers: List>) {
        stream.getOrDefault(aggregateKey, emptyList()).apply {
            stream[aggregateKey] = this.plus(eventWrappers)
        }.also {
            eventWrappers.forEach { event ->
                LOG.debug { "Sender event" }
                listeners.forEach { listener ->
                    listener.eventAdded(event)
                }
            }
        }
    }

    fun get(aggregateKey: AggregateKey): List>? = stream[aggregateKey]

    fun eventCount(): Long = stream.map { it.value.size }.sum().toLong()

    fun addListener(eventListener: EventListener) {
        listeners.add(eventListener)
    }

    fun removeListener(eventListener: EventListener) {
        listeners.remove(eventListener)
    }

    override fun close() {
        listeners.clear()
    }
}

interface EventListener {
    fun eventAdded(eventWrapper: Event<*>)
}

class TestEventSubscription(private val factory: TestEventSubscriberFactory,
                            private val onEvent: (EventWrapper>) -> Unit,
                            private val closeHandler: (TestEventSubscription) -> Unit
) : EventSubscription, EventListener, AutoCloseable {
    private val lastProcessedEvent = AtomicLong(-1)
    override fun lastProcessedEvent(): Long = lastProcessedEvent.get()
    override fun eventAdded(eventWrapper: Event<*>) {
        EventUpgrader.upgrade(eventWrapper.eventData).run {
            onEvent.invoke(EventWrapper(
                Event(
                    aggregateId = eventWrapper.aggregateId,
                    eventData = this
                ),
                    eventNumber = lastProcessedEvent.getAndIncrement(),
                    serializationId = factory.getSerializationId(this::class as KClass>)
            ))
        }
    }

    override fun close() {
        closeHandler.invoke(this)
    }
}

class TestEventSubscriberFactory(private val serdes: EventSerdes, private val testEventStream: TestEventStream) : EventSubscriberFactory {
    override fun getSerializationId(eventDataClass: KClass>): String = serdes.getSerializationId(eventDataClass)

    override fun createSubscriber(
        hwmId: String,
        fromEvent: Long,
        onEvent: (EventWrapper>) -> Unit,
        onError: (Exception) -> Unit,
        onLive: () -> Unit
    ): TestEventSubscription {
        return TestEventSubscription(factory = this, onEvent = onEvent) {
            testEventStream.removeListener(it)
        }.also { testEventStream.addListener(it) }
    }

}

internal class TestAggregateRepository(private val eventSerdes: EventSerdes, private val testEventStream: TestEventStream) : AggregateRepository() {

    override fun getSerializationId(eventDataClass: KClass>) = eventSerdes.getSerializationId(eventDataClass)

    override fun append(aggregateType: String, aggregateId: UUID, expectedEventNumber: ExpectedEventNumber, eventWrappers: List>) {
        AggregateKey(aggregateType, aggregateId).run {
            addEvent(this, eventWrappers)
        }
    }

    override fun  read(aggregateId: UUID, aggregateType: String, applicator: (state: A?, event: EventWrapper<*>) -> A?): AggregateReadResult =
            testEventStream.get(AggregateKey(aggregateType, aggregateId))?.fold(null as A? to null as Long?, { a, e ->
                applicator.invoke(a.first, EventWrapper(Event(aggregateId,e.eventData,null), getEventIndex(), eventSerdes.getSerializationId(e.eventData::class))) to getEventIndex()
            })?.let {
                when {
                    //when the aggregate stream has events, but applying these did not lead to a initialized state
                    it.first == null && it.second != null -> AggregateReadResult.UninitializedAggregate(it.second!!)

                    //when the aggregate stream has events, and applying these has lead to a initialized state
                    it.first != null && it.second != null -> AggregateReadResult.InitializedAggregate(it.first!!, it.second!!)

                    //when the aggregate stream has no events
                    else -> error("Nothing to read")
                }
            } ?: AggregateReadResult.NonExistingAggregate


    private fun getEventIndex() = testEventStream.eventCount()

    private fun addEvent(aggregateKey: AggregateKey, eventWrappers: List>) {
        testEventStream.add(aggregateKey, eventWrappers)
    }

}

class TestHwmTrackerRepository : HwmTrackerRepository {

    private val subscriptionIndex: MutableMap = mutableMapOf()
    override fun current(subscriber: String): Long? = subscriptionIndex[subscriber]?.get()

    override fun getOrInit(subscriber: String): Long = subscriptionIndex.getOrDefault(subscriber, AtomicLong(-1))
            .apply {
                subscriptionIndex[subscriber] = this
            }.let {
                it.get()
            }

    override fun update(subscriber: String, hwm: Long) {
        subscriptionIndex[subscriber]?.apply {
            this.set(hwm)
        } ?: error("""Can not update hwm "$subscriber" that has not been initialized""")
    }
}

class TestProjectionRepository : ProjectionRepository {
    override val hwmTracker: HwmTrackerRepository = TestHwmTrackerRepository()

    override fun transactionally(runnable: () -> Unit) {
        runnable.invoke()
    }
}


internal data class CmdStatus(val aggregateId: UUID, val serializationId: String, val errorId: UUID? = null, val retries: Int = 0, val nextExecution: Instant, val deserializedCommd: String)
class TestCommandQueue(cmdHandlers: Set>, private val cmdSerdes: CmdSerdes) : CommandQueue(cmdHandlers) {

    private val cmdIdSequence = AtomicLong(0L)
    private val commandQueue = mutableMapOf()
    private val commandsAwaiting
        get() = commandQueue.filterValues { Instant.now().isAfter(it.nextExecution) }.toSortedMap(compareBy { it })

    override fun delete(cmdId: Long) {
        commandQueue.minusAssign(cmdId)
    }

    override fun incrementAndSetError(cmdId: Long, errorId: UUID) {
        commandQueue[cmdId]?.apply {
            commandQueue[cmdId] = copy(errorId = errorId)
        }
    }

    override fun incrementAndSetNextExecution(cmdId: Long, nextExecution: Instant) {
        commandQueue[cmdId]?.apply {
            commandQueue[cmdId] = copy(nextExecution = nextExecution, retries = retries + 1)
        }
    }

    override fun nextCmd(): CmdWrapper>? = commandsAwaiting.toList().firstOrNull()?.let {
        CmdWrapper(id = it.first, retries = it.second.retries, cmd = cmdSerdes.deserialize(it.second.deserializedCommd.toByteArray(), it.second.serializationId))
    }.also { LOG.debug { "nextCmd is: $it" } }


    override fun transactionally(runnable: () -> Unit) {
        runnable.invoke()
    }

    fun addToQueue(aggregateId: UUID, cmd: Cmd<*>) {
        commandQueue.plusAssign(cmdIdSequence.incrementAndGet() to CmdStatus(
                aggregateId = aggregateId,
                serializationId = cmdSerdes.getSerializationId(cmd::class as KClass>),
                nextExecution = Instant.now(),
                deserializedCommd = cmdSerdes.serialize(cmd).toString(StandardCharsets.UTF_8)
        ))
    }
}

class TestSagaRepository(private val sagaStateSerdes: SagaStateSerdes, private val addCommandToQueue: (UUID, Cmd<*>) -> Unit) : SagaRepository {

    private val scheduledExecutor = Executors.newScheduledThreadPool(5)
    private val sagaStates = mutableMapOf()

    private data class SagaKey(val correlationId: UUID, val serializationId: String)

    private val timeouts = mutableMapOf()

    private data class TimeoutKey(val serializationId: String, val correlationId: UUID, val timeoutId: String)
    private data class TimeoutEntry(val timeout: Instant, val error: Int = 0)

    override fun  getSagaState(correlationId: UUID, serializationId: String, sagaStateClass: KClass): T? =
            sagaStates[SagaKey(correlationId, serializationId)]?.let {
                sagaStateSerdes.deserialize(it, sagaStateClass)
            }


    override fun update(states: Set) {
        LOG.info { "Updating sagas: $states" }
        states.filterIsInstance()
                .forEach {
                    sagaStates.plusAssign(SagaKey(it.correlationId, it.serializationId) to sagaStateSerdes.serialize(it.newState))
                }

        states.filterIsInstance().forEach {
            it.timeouts.forEach { timeout ->
                timeouts.plusAssign(TimeoutKey(correlationId = it.correlationId, serializationId = it.serializationId, timeoutId = timeout.timeoutId) to TimeoutEntry(timeout = Instant.now()))

            }
        }

        states.filterIsInstance().forEach { sagaUpdate ->
            SagaKey(correlationId = sagaUpdate.correlationId, serializationId = sagaUpdate.serializationId).let { sagaKey ->
                if (sagaStates.containsKey(sagaKey)) {
                    sagaUpdate?.newState?.apply {
                        sagaStates[sagaKey] = sagaStateSerdes.serialize(this)
                    }
                }
            }
        }
        states.flatMap { it.commands }.forEach {
            addCommandToQueue.invoke(it.aggregateId, it)
        }
    }

    override fun getReadyTimeouts(): SagaRepository.Timeout? =
            timeouts.filter { entry -> entry.value.error == 0 && entry.value.timeout.isBefore(Instant.now()) }.toList().firstOrNull()?.let {
                SagaRepository.Timeout(it.first.correlationId, it.first.serializationId, it.first.timeoutId)
            }

    override fun deleteTimeout(timeout: SagaRepository.Timeout) {
        timeouts.minusAssign(TimeoutKey(serializationId = timeout.sagaSerializationId, correlationId = timeout.sagaCorrelationId, timeoutId = timeout.timeoutId))
    }

    override fun transactionally(runnable: () -> Unit) {
        runnable.invoke()
    }

    override val hwmTracker: HwmTrackerRepository = TestHwmTrackerRepository()
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy