All Downloads are FREE. Search and download functionalities are using the official Maven repository.

commonMain.org.brightify.hyperdrive.krpc.extension.SessionNodeExtension.kt Maven / Gradle / Ivy

package org.brightify.hyperdrive.krpc.extension

import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.withContext
import kotlinx.serialization.KSerializer
import kotlinx.serialization.Serializable
import kotlinx.serialization.builtins.serializer
import kotlinx.serialization.modules.SerializersModule
import org.brightify.hyperdrive.Logger
import org.brightify.hyperdrive.krpc.extension.session.DefaultSession
import org.brightify.hyperdrive.krpc.RPCTransport
import org.brightify.hyperdrive.krpc.SerializedPayload
import org.brightify.hyperdrive.krpc.application.RPCNode
import org.brightify.hyperdrive.krpc.application.RPCNodeExtension
import org.brightify.hyperdrive.krpc.description.*
import org.brightify.hyperdrive.krpc.error.RPCErrorSerializer
import org.brightify.hyperdrive.krpc.protocol.ascension.PayloadSerializer
import org.brightify.hyperdrive.krpc.session.Session
import org.brightify.hyperdrive.krpc.session.SessionContextKeyRegistry
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.coroutineContext
import kotlin.reflect.KClass

/**
 * Node extension that owns a [DefaultSession], exposes it read-only as [session] and implements
 * [ContextUpdateService] by delegating [update] and [clear] to that session.
 *
 * Incoming calls, except those addressed to [ContextUpdateService] itself, run with [session] added to their
 * coroutine context. Outgoing calls, with the same exception, first wait for
 * `session.awaitCompletedContextSync()`.
 *
 * NOTE(review): several declarations in this listing appear to have lost their generic type arguments
 * (for example `List`, `Set>`, `MutableSharedFlow>>()` and the `fun  intercept…` headers). Confirm against the
 * original file before compiling.
 */
class SessionNodeExtension internal constructor(
    session: DefaultSession,
    private val plugins: List,
): ContextUpdateService, RPCNodeExtension {
    companion object {
        val logger = Logger()
        // Not referenced anywhere else in this class; presumably consumed by the session implementation — confirm.
        const val maximumRejections = 10
    }

    /** Identifies this extension by the unique string `builtin:Session` and its class. */
    object Identifier: RPCNodeExtension.Identifier {
        override val uniqueIdentifier: String = "builtin:Session"
        override val extensionClass = SessionNodeExtension::class
    }

    /**
     * Hooks invoked by the extension. Both have empty default bodies, so implementers override only what they need.
     */
    interface Plugin {
        /** Called from [bind], once for each plugin, after the session has been bound. */
        suspend fun onBindComplete(session: Session) { }

        /** Called for every set of modified keys observed while [whileConnected] is collecting. */
        suspend fun onContextChanged(session: Session, modifiedKeys: Set>) { }
    }

    /**
     * Creates [SessionNodeExtension] instances, each with a freshly constructed [DefaultSession] and the same
     * plugin list.
     */
    class Factory(
        private val sessionContextKeyRegistry: SessionContextKeyRegistry,
        private val payloadSerializerFactory: PayloadSerializer.Factory,
        private val plugins: List = emptyList(),
    ): RPCNodeExtension.Factory {
        override val identifier = Identifier

        override val isRequiredOnOtherSide = true

        override fun create(): SessionNodeExtension {
            return SessionNodeExtension(
                DefaultSession(payloadSerializerFactory, sessionContextKeyRegistry),
                plugins
            )
        }
    }

    // The concrete session is kept private; callers only see it through the `Session` type.
    private val _session: DefaultSession = session
    val session: Session = _session

    // The only service this extension provides is ContextUpdateService, implemented by this instance.
    override val providedServices: List = listOf(
        ContextUpdateService.Descriptor.describe(this)
    )

    // Receives every modified-key set before the plugins are notified. No collector is visible in this class.
    private val modifiedKeysFlow = MutableSharedFlow>>()

    /** Binds the session to [transport] and [contract], then notifies every plugin in list order. */
    override suspend fun bind(transport: RPCTransport, contract: RPCNode.Contract) {
        _session.bind(transport, contract)

        for (plugin in plugins) {
            plugin.onBindComplete(session)
        }
    }

    /** Collects `session.observeModifications()` and forwards each emission to the plugins. */
    override suspend fun whileConnected() {
        session.observeModifications()
            .collect {
                notifyPluginsContextChanged(it)
            }
    }

    /** Returns [context] with [session] added to it. */
    override suspend fun enhanceParallelWorkContext(context: CoroutineContext): CoroutineContext {
        return context + session
    }

    /** [ContextUpdateService.update], delegated to the underlying session. */
    override suspend fun update(request: ContextUpdateRequest): ContextUpdateResult = _session.update(request)

    /** [ContextUpdateService.clear], delegated to the underlying session. */
    override suspend fun clear() = _session.clear()

    /** Emits [modifiedKeys] to the internal flow first, then calls each plugin sequentially in list order. */
    private suspend fun notifyPluginsContextChanged(modifiedKeys: Set>) {
        modifiedKeysFlow.emit(modifiedKeys)

        for (plugin in plugins) {
            plugin.onContextChanged(session, modifiedKeys)
        }
    }

    // The four incoming interceptors below differ only in call shape: each runs the default
    // (super) interception inside `withSessionIfNeeded`.

    override suspend fun  interceptIncomingSingleCall(
        payload: PAYLOAD,
        call: RunnableCallDescription.Single,
        next: suspend (PAYLOAD) -> RESPONSE,
    ): RESPONSE = withSessionIfNeeded(call) {
        super.interceptIncomingSingleCall(payload, call, next)
    }

    override suspend fun  interceptIncomingUpstreamCall(
        payload: PAYLOAD,
        stream: Flow,
        call: RunnableCallDescription.ColdUpstream,
        next: suspend (PAYLOAD, Flow) -> RESPONSE,
    ): RESPONSE = withSessionIfNeeded(call) {
        super.interceptIncomingUpstreamCall(payload, stream, call, next)
    }

    override suspend fun  interceptIncomingDownstreamCall(
        payload: PAYLOAD,
        call: RunnableCallDescription.ColdDownstream,
        next: suspend (PAYLOAD) -> Flow,
    ): Flow = withSessionIfNeeded(call) {
        super.interceptIncomingDownstreamCall(payload, call, next)
    }

    override suspend fun  interceptIncomingBistreamCall(
        payload: PAYLOAD,
        stream: Flow,
        call: RunnableCallDescription.ColdBistream,
        next: suspend (PAYLOAD, Flow) -> Flow,
    ): Flow = withSessionIfNeeded(call) {
        super.interceptIncomingBistreamCall(payload, stream, call, next)
    }

    /**
     * Runs [block] for an incoming [call].
     *
     * Calls addressed to [ContextUpdateService] run [block] directly. Every other call runs it with [session]
     * added to the coroutine context and, once [block] has returned, waits for
     * `session.awaitCompletedContextSync()` before handing back the result.
     */
    @OptIn(ExperimentalContracts::class)
    private suspend fun  withSessionIfNeeded(call: RunnableCallDescription<*>, block: suspend () -> RESULT): RESULT {
        contract {
            callsInPlace(block, InvocationKind.EXACTLY_ONCE)
        }
        // The session is not injected into the coroutine context when the call targets ContextUpdateService itself.
        return if (call.identifier.serviceId == ContextUpdateService.Descriptor.identifier) {
            block()
        } else {
            // TODO: Replace `session` with an immutable copy
            withContext(coroutineContext + session) {
                val result = block()
                session.awaitCompletedContextSync()
                result
            }
        }
    }

    // The four outgoing interceptors below mirror the incoming ones: each runs the default
    // (super) interception inside `withCompletedContextSyncIfNeeded`.

    override suspend fun  interceptOutgoingSingleCall(
        payload: PAYLOAD,
        call: SingleCallDescription,
        next: suspend (PAYLOAD) -> RESPONSE,
    ): RESPONSE = withCompletedContextSyncIfNeeded(call) {
        super.interceptOutgoingSingleCall(payload, call, next)
    }

    override suspend fun  interceptOutgoingUpstreamCall(
        payload: PAYLOAD,
        stream: Flow,
        call: ColdUpstreamCallDescription,
        next: suspend (PAYLOAD, Flow) -> RESPONSE,
    ): RESPONSE = withCompletedContextSyncIfNeeded(call) {
        super.interceptOutgoingUpstreamCall(payload, stream, call, next)
    }

    override suspend fun  interceptOutgoingDownstreamCall(
        payload: PAYLOAD,
        call: ColdDownstreamCallDescription,
        next: suspend (PAYLOAD) -> Flow,
    ): Flow = withCompletedContextSyncIfNeeded(call) {
        super.interceptOutgoingDownstreamCall(payload, call, next)
    }

    override suspend fun  interceptOutgoingBistreamCall(
        payload: PAYLOAD,
        stream: Flow,
        call: ColdBistreamCallDescription,
        next: suspend (PAYLOAD, Flow) -> Flow,
    ): Flow = withCompletedContextSyncIfNeeded(call) {
        super.interceptOutgoingBistreamCall(payload, stream, call, next)
    }

    /**
     * Runs [block] for an outgoing [call], first waiting for `session.awaitCompletedContextSync()` unless the call
     * targets [ContextUpdateService].
     */
    private suspend fun  withCompletedContextSyncIfNeeded(call: CallDescription<*>, block: suspend () -> RESULT): RESULT {
        // ContextUpdateService's own calls must not wait for the context sync; presumably the sync is carried by
        // these very calls, so waiting here would never finish — confirm.
        return if (call.identifier.serviceId == ContextUpdateService.Descriptor.identifier) {
            block()
        } else {
            session.awaitCompletedContextSync()
            block()
        }
    }
}

/**
 * Coroutine-context element that carries a single [contribution] value, keyed by the class it was registered under.
 *
 * [Key] is a data class, so two contributions registered under the same class have equal keys. Adding the
 * second one to a context therefore replaces the first.
 *
 * @param T type of the carried value.
 * @property contribution the value made available through the coroutine context.
 * @param contributionClass class used to build the lookup [Key] for this element.
 */
class RPCContribution<T: Any>(
    val contribution: T,
    contributionClass: KClass<T>,
): CoroutineContext.Element {
    /** Context key of an [RPCContribution]; keys built from the same [contributionClass] compare equal. */
    data class Key<T: Any>(val contributionClass: KClass<T>): CoroutineContext.Key<RPCContribution<T>>

    override val key: CoroutineContext.Key<*> = Key(contributionClass)
}

/**
 * Runs [block] with [module] registered in the coroutine context as the [SerializersModule] contribution.
 *
 * @param module serializers module to make available while [block] runs.
 * @param block suspending work to execute.
 * @return whatever [block] returns.
 */
suspend fun <RESULT> withContributed(module: SerializersModule, block: suspend () -> RESULT): RESULT {
    return withContribution(module, SerializersModule::class, block)
}

/**
 * Looks up the value contributed under [contributionClass] in this context.
 *
 * @return the contributed value, or `null` when the context holds no [RPCContribution] for that class.
 */
internal fun <T: Any> CoroutineContext.contribution(contributionClass: KClass<T>): T? {
    return get(RPCContribution.Key(contributionClass))?.contribution
}

/**
 * Looks up the value contributed under class [T] in the current coroutine context.
 *
 * [T] is reified so the class token can be taken from the type argument at the call site.
 *
 * @return the contributed value, or `null` when nothing was contributed for [T].
 */
internal suspend inline fun <reified T: Any> contextContribution(): T? {
    return coroutineContext.contribution(T::class)
}

/**
 * Runs [block] with [contribution] added to the current coroutine context under [contributionClass].
 *
 * @param contribution value to expose while [block] runs.
 * @param contributionClass class under which the value is registered and can later be looked up.
 * @param block suspending work to execute inside the extended context.
 * @return whatever [block] returns.
 */
internal suspend fun <T: Any, RESULT> withContribution(
    contribution: T,
    contributionClass: KClass<T>,
    block: suspend () -> RESULT,
): RESULT {
    return withContext(coroutineContext + RPCContribution(contribution, contributionClass)) {
        block()
    }
}

/**
 * Revisioned, serialized value exchanged in context update requests and results.
 *
 * @property revision revision number attached to the value.
 * @property value the value itself, carried as a [SerializedPayload].
 */
@Serializable
class ContextItemDto(val revision: Int, val value: SerializedPayload)

/** Transfer form of a context key, a plain [String]; presumably the key's qualified name — confirm. */
typealias KeyDto = String

/**
 * Request carrying a batch of context modifications.
 *
 * @property modifications requested changes keyed by [KeyDto]; empty by default.
 */
@Serializable
class ContextUpdateRequest(
    val modifications: Map<KeyDto, Modification> = emptyMap(),
) {
    /**
     * A single requested change. Every variant exposes the revision it was based on through
     * [oldRevisionOrNull], which is `null` when the variant carries no old revision.
     */
    @Serializable
    sealed class Modification {
        // Implemented with getter-only overrides below, so the subclasses store nothing beyond `oldRevision`.
        abstract val oldRevisionOrNull: Int?

        /**
         * Carries only an optional old revision.
         * NOTE(review): its exact meaning is defined by the session implementation — confirm.
         */
        @Serializable
        class Required(val oldRevision: Int? = null): Modification() {
            override val oldRevisionOrNull: Int?
                get() = oldRevision
        }

        /** Carries [newItem] for the key, optionally together with the revision it is based on. */
        @Serializable
        class Set(val oldRevision: Int? = null, val newItem: ContextItemDto): Modification() {
            override val oldRevisionOrNull: Int?
                get() = oldRevision
        }

        /** Removal of the key; unlike the other variants, the old revision is mandatory here. */
        @Serializable
        class Remove(val oldRevision: Int): Modification() {
            override val oldRevisionOrNull: Int?
                get() = oldRevision
        }
    }
}

/** Outcome of a [ContextUpdateRequest]: either [Accepted] or [Rejected] with a reason for each key. */
@Serializable
sealed class ContextUpdateResult {
    /** The request was accepted; carries no further data. */
    @Serializable
    object Accepted: ContextUpdateResult()

    /**
     * The request was rejected.
     *
     * @property rejectedModifications reason for the rejection, keyed by [KeyDto]; empty by default.
     */
    @Serializable
    class Rejected(
        val rejectedModifications: Map<KeyDto, Reason> = emptyMap(),
    ): ContextUpdateResult() {
        /** Why a single modification was rejected. */
        @Serializable
        sealed class Reason {
            /** Presumably: the item no longer exists on the answering side — confirm. */
            @Serializable
            object Removed: Reason()

            /** Carries [newItem]; presumably the item the answering side currently holds — confirm. */
            @Serializable
            class Updated(val newItem: ContextItemDto): Reason()
        }
    }
}

/**
 * Service used to exchange context modifications, together with its transport-backed [Client] and the
 * [Descriptor] that exposes an implementation as callable RPC endpoints.
 */
interface ContextUpdateService {
    /** Submits [request] and returns whether it was accepted or rejected. */
    suspend fun update(request: ContextUpdateRequest): ContextUpdateResult

    /** Issues the `clear` call; takes no arguments and returns nothing. */
    suspend fun clear()

    /** Implementation that forwards each method as a single call over [transport]. */
    class Client(
        private val transport: RPCTransport,
    ): ContextUpdateService {
        override suspend fun update(request: ContextUpdateRequest): ContextUpdateResult {
            return transport.singleCall(Descriptor.Call.update, request)
        }

        override suspend fun clear() {
            return transport.singleCall(Descriptor.Call.clear, Unit)
        }
    }

    /** Describes the service: its identifier and the two single calls, `update` and `clear`. */
    object Descriptor: ServiceDescriptor<ContextUpdateService> {
        const val identifier = "builtin:hyperdrive.ContextSyncService"

        /** Builds a [ServiceDescription] whose calls are routed to [service]. */
        override fun describe(service: ContextUpdateService): ServiceDescription {
            return ServiceDescription(
                identifier,
                listOf(
                    Call.update.calling { request ->
                        service.update(request)
                    },
                    // `clear` has a Unit request, so the lambda argument is intentionally ignored.
                    Call.clear.calling { _ ->
                        service.clear()
                    }
                )
            )
        }

        /** Call descriptions shared by [Client] and [describe]. */
        object Call {
            val update = SingleCallDescription(
                ServiceCallIdentifier(identifier, "update"),
                ContextUpdateRequest.serializer(),
                ContextUpdateResult.serializer(),
                RPCErrorSerializer(),
            )

            val clear = SingleCallDescription(
                ServiceCallIdentifier(identifier, "clear"),
                Unit.serializer(),
                Unit.serializer(),
                RPCErrorSerializer(),
            )
        }
    }
}

/**
 * Context key identified only by its [qualifiedName], whose values stay in their [SerializedPayload] form and
 * are handled with `SerializedPayload.serializer()`.
 *
 * NOTE(review): presumably a stand-in for keys that are not known to the local key registry — confirm.
 *
 * @property qualifiedName name identifying the key.
 */
data class UnsupportedKey(override val qualifiedName: String): Session.Context.Key<SerializedPayload> {
    override val serializer: KSerializer<SerializedPayload> = SerializedPayload.serializer()
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy