
// Source listing: commonMain.org.brightify.hyperdrive.krpc.extension.SessionNodeExtension.kt (Maven / Gradle / Ivy)
package org.brightify.hyperdrive.krpc.extension
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.withContext
import kotlinx.serialization.KSerializer
import kotlinx.serialization.Serializable
import kotlinx.serialization.builtins.serializer
import kotlinx.serialization.modules.SerializersModule
import org.brightify.hyperdrive.Logger
import org.brightify.hyperdrive.krpc.extension.session.DefaultSession
import org.brightify.hyperdrive.krpc.RPCTransport
import org.brightify.hyperdrive.krpc.SerializedPayload
import org.brightify.hyperdrive.krpc.application.RPCNode
import org.brightify.hyperdrive.krpc.application.RPCNodeExtension
import org.brightify.hyperdrive.krpc.description.*
import org.brightify.hyperdrive.krpc.error.RPCErrorSerializer
import org.brightify.hyperdrive.krpc.protocol.ascension.PayloadSerializer
import org.brightify.hyperdrive.krpc.session.Session
import org.brightify.hyperdrive.krpc.session.SessionContextKeyRegistry
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.coroutineContext
import kotlin.reflect.KClass
/**
 * RPC node extension that owns a [DefaultSession] and makes it available to calls.
 *
 * The extension also implements [ContextUpdateService] by delegating [update] and [clear] to the
 * session, and it registers that service in [providedServices].
 *
 * Incoming calls are run with the session added to the coroutine context; outgoing calls wait for
 * `session.awaitCompletedContextSync()` first. Calls addressed to [ContextUpdateService] itself are
 * exempt from both (see [withSessionIfNeeded] and [withCompletedContextSyncIfNeeded]).
 *
 * NOTE(review): several declarations below have no generic type arguments (`List`, `Set>`,
 * `MutableSharedFlow>>()`, and `PAYLOAD`/`RESPONSE`/`RESULT` used without being declared). They look
 * like they were lost from this copy of the file — restore them from the upstream source.
 */
class SessionNodeExtension internal constructor(
session: DefaultSession,
// Plugins are notified in list order after binding and after every context modification.
private val plugins: List,
): ContextUpdateService, RPCNodeExtension {
companion object {
val logger = Logger()
// NOTE(review): not referenced anywhere in this class; presumably a retry bound used elsewhere — confirm.
const val maximumRejections = 10
}
/** Identifier under which this extension is registered (`"builtin:Session"`). */
object Identifier: RPCNodeExtension.Identifier {
override val uniqueIdentifier: String = "builtin:Session"
override val extensionClass = SessionNodeExtension::class
}
/**
 * Callbacks invoked by [SessionNodeExtension]. Both methods have empty default implementations,
 * so implementers override only what they need.
 */
interface Plugin {
/** Called once per plugin after the session has been bound in [bind]. */
suspend fun onBindComplete(session: Session) { }
/** Called for every set of modified keys collected in [whileConnected]. */
suspend fun onContextChanged(session: Session, modifiedKeys: Set>) { }
}
/**
 * Creates [SessionNodeExtension] instances, each with a fresh [DefaultSession] built from the
 * given serializer factory and key registry. The same [plugins] list is passed to every instance.
 */
class Factory(
private val sessionContextKeyRegistry: SessionContextKeyRegistry,
private val payloadSerializerFactory: PayloadSerializer.Factory,
private val plugins: List = emptyList(),
): RPCNodeExtension.Factory {
override val identifier = Identifier
override val isRequiredOnOtherSide = true
override fun create(): SessionNodeExtension {
return SessionNodeExtension(
DefaultSession(payloadSerializerFactory, sessionContextKeyRegistry),
plugins
)
}
}
// Concrete session used for calls that are not part of the public Session type (bind, update, clear).
private val _session: DefaultSession = session
/** The same session instance, exposed through the [Session] type. */
val session: Session = _session
override val providedServices: List = listOf(
ContextUpdateService.Descriptor.describe(this)
)
// Receives every modified-key set before plugins are notified; no collector is visible in this class.
private val modifiedKeysFlow = MutableSharedFlow>>()
/** Binds the session to [transport] and [contract], then calls [Plugin.onBindComplete] on each plugin. */
override suspend fun bind(transport: RPCTransport, contract: RPCNode.Contract) {
_session.bind(transport, contract)
for (plugin in plugins) {
plugin.onBindComplete(session)
}
}
/** Collects `session.observeModifications()` and forwards each emission to the plugins. */
override suspend fun whileConnected() {
session.observeModifications()
.collect {
notifyPluginsContextChanged(it)
}
}
/** Returns [context] with the session added to it. */
override suspend fun enhanceParallelWorkContext(context: CoroutineContext): CoroutineContext {
return context + session
}
/** [ContextUpdateService.update], delegated to the session. */
override suspend fun update(request: ContextUpdateRequest): ContextUpdateResult = _session.update(request)
/** [ContextUpdateService.clear], delegated to the session. */
override suspend fun clear() = _session.clear()
/** Emits [modifiedKeys] to [modifiedKeysFlow], then calls [Plugin.onContextChanged] on each plugin in order. */
private suspend fun notifyPluginsContextChanged(modifiedKeys: Set>) {
modifiedKeysFlow.emit(modifiedKeys)
for (plugin in plugins) {
plugin.onContextChanged(session, modifiedKeys)
}
}
// The four incoming interceptors below only wrap the default (super) behaviour in withSessionIfNeeded.
override suspend fun interceptIncomingSingleCall(
payload: PAYLOAD,
call: RunnableCallDescription.Single,
next: suspend (PAYLOAD) -> RESPONSE,
): RESPONSE = withSessionIfNeeded(call) {
super.interceptIncomingSingleCall(payload, call, next)
}
override suspend fun interceptIncomingUpstreamCall(
payload: PAYLOAD,
stream: Flow,
call: RunnableCallDescription.ColdUpstream,
next: suspend (PAYLOAD, Flow) -> RESPONSE,
): RESPONSE = withSessionIfNeeded(call) {
super.interceptIncomingUpstreamCall(payload, stream, call, next)
}
override suspend fun interceptIncomingDownstreamCall(
payload: PAYLOAD,
call: RunnableCallDescription.ColdDownstream,
next: suspend (PAYLOAD) -> Flow,
): Flow = withSessionIfNeeded(call) {
super.interceptIncomingDownstreamCall(payload, call, next)
}
override suspend fun interceptIncomingBistreamCall(
payload: PAYLOAD,
stream: Flow,
call: RunnableCallDescription.ColdBistream,
next: suspend (PAYLOAD, Flow) -> Flow,
): Flow = withSessionIfNeeded(call) {
super.interceptIncomingBistreamCall(payload, stream, call, next)
}
/**
 * Runs [block] for an incoming [call].
 *
 * For calls whose service id equals [ContextUpdateService.Descriptor.identifier] the block runs
 * in the caller's context unchanged. For every other call the block runs with the session added
 * to the coroutine context, and the result is returned only after
 * `session.awaitCompletedContextSync()` has completed.
 */
@OptIn(ExperimentalContracts::class)
private suspend fun withSessionIfNeeded(call: RunnableCallDescription<*>, block: suspend () -> RESULT): RESULT {
// Declares to the compiler that `block` is invoked exactly once, in place.
contract {
callsInPlace(block, InvocationKind.EXACTLY_ONCE)
}
// Calls to ContextUpdateService itself run without the session injected into the coroutine context.
return if (call.identifier.serviceId == ContextUpdateService.Descriptor.identifier) {
block()
} else {
// TODO: Replace `session` with an immutable copy
withContext(coroutineContext + session) {
val result = block()
// Wait for the context sync to finish before handing the result back.
session.awaitCompletedContextSync()
result
}
}
}
// The four outgoing interceptors below only wrap the default (super) behaviour in
// withCompletedContextSyncIfNeeded.
override suspend fun interceptOutgoingSingleCall(
payload: PAYLOAD,
call: SingleCallDescription,
next: suspend (PAYLOAD) -> RESPONSE,
): RESPONSE = withCompletedContextSyncIfNeeded(call) {
super.interceptOutgoingSingleCall(payload, call, next)
}
override suspend fun interceptOutgoingUpstreamCall(
payload: PAYLOAD,
stream: Flow,
call: ColdUpstreamCallDescription,
next: suspend (PAYLOAD, Flow) -> RESPONSE,
): RESPONSE = withCompletedContextSyncIfNeeded(call) {
super.interceptOutgoingUpstreamCall(payload, stream, call, next)
}
override suspend fun interceptOutgoingDownstreamCall(
payload: PAYLOAD,
call: ColdDownstreamCallDescription,
next: suspend (PAYLOAD) -> Flow,
): Flow = withCompletedContextSyncIfNeeded(call) {
super.interceptOutgoingDownstreamCall(payload, call, next)
}
override suspend fun interceptOutgoingBistreamCall(
payload: PAYLOAD,
stream: Flow,
call: ColdBistreamCallDescription,
next: suspend (PAYLOAD, Flow) -> Flow,
): Flow = withCompletedContextSyncIfNeeded(call) {
super.interceptOutgoingBistreamCall(payload, stream, call, next)
}
/**
 * Runs [block] for an outgoing [call], first awaiting `session.awaitCompletedContextSync()`
 * unless the call targets [ContextUpdateService].
 */
private suspend fun withCompletedContextSyncIfNeeded(call: CallDescription<*>, block: suspend () -> RESULT): RESULT {
// A ContextUpdateService call must not wait for the context sync to complete.
return if (call.identifier.serviceId == ContextUpdateService.Descriptor.identifier) {
block()
} else {
session.awaitCompletedContextSync()
block()
}
}
}
/**
 * Coroutine context element that carries one [contribution] value.
 *
 * The element's [key] is a [Key] built from the class the contribution was registered under, so
 * two contributions registered under the same class share a key and contributions registered
 * under different classes can coexist in one context.
 *
 * @param T type of the carried value.
 * @property contribution the carried value.
 * @param contributionClass class used to build this element's [key].
 */
class RPCContribution<T: Any>(
    val contribution: T,
    contributionClass: KClass<T>,
): CoroutineContext.Element {
    /**
     * Context key for an [RPCContribution] registered under [contributionClass].
     * Being a data class, two keys are equal when their classes are equal.
     */
    data class Key<T: Any>(val contributionClass: KClass<T>): CoroutineContext.Key<RPCContribution<T>>

    override val key: CoroutineContext.Key<*> = Key(contributionClass)
}
/**
 * Runs [block] with [module] contributed to the coroutine context under the
 * [SerializersModule] class, and returns the block's result.
 *
 * @param module serializers module to make available while [block] runs.
 * @param block suspending work to run.
 * @return whatever [block] returns.
 */
suspend fun <RESULT> withContributed(module: SerializersModule, block: suspend () -> RESULT): RESULT {
    return withContribution(module, SerializersModule::class, block)
}
/**
 * Looks up the value contributed under [contributionClass] in this context.
 *
 * @param contributionClass class the contribution was registered under.
 * @return the contributed value, or `null` when the context has no [RPCContribution] for that class.
 */
internal fun <T: Any> CoroutineContext.contribution(contributionClass: KClass<T>): T? {
    return get(RPCContribution.Key(contributionClass))?.contribution
}
/**
 * Looks up the value contributed under `T::class` in the current coroutine context.
 *
 * `T` is reified because its class object is needed to build the lookup key.
 *
 * @return the contributed value, or `null` when none was contributed for `T`.
 */
internal suspend inline fun <reified T: Any> contextContribution(): T? {
    return coroutineContext.contribution(T::class)
}
/**
 * Runs [block] in the current coroutine context extended with an [RPCContribution] that
 * carries [contribution] under [contributionClass].
 *
 * @param contribution value to make available while [block] runs.
 * @param contributionClass class under which the value is registered and later looked up.
 * @param block suspending work to run.
 * @return whatever [block] returns.
 */
internal suspend fun <T: Any, RESULT> withContribution(
    contribution: T,
    contributionClass: KClass<T>,
    block: suspend () -> RESULT,
): RESULT {
    return withContext(coroutineContext + RPCContribution(contribution, contributionClass)) {
        block()
    }
}
/**
 * Serializable pairing of a serialized context value with its revision.
 *
 * @property revision revision number stored alongside the value
 *   (presumably compared against `oldRevision` in update requests — confirm in the session implementation).
 * @property value the item's value as a [SerializedPayload].
 */
@Serializable
class ContextItemDto(
    val revision: Int,
    val value: SerializedPayload,
)
// String alias used for context keys in the DTOs of this file.
// NOTE(review): presumably holds a key's `qualifiedName` — confirm against the session implementation.
typealias KeyDto = String
/**
 * Serializable request passed to [ContextUpdateService.update].
 *
 * @property modifications one [Modification] per context key; empty by default.
 */
@Serializable
class ContextUpdateRequest(
    val modifications: Map<KeyDto, Modification> = emptyMap(),
) {
    /**
     * A single per-key entry of the request. Every variant exposes the revision it was
     * based on through [oldRevisionOrNull].
     */
    @Serializable
    sealed class Modification {
        /** The `oldRevision` of the concrete variant, or `null` when the variant carries none. */
        abstract val oldRevisionOrNull: Int?

        /** Carries only an optional [oldRevision]; no new value. */
        @Serializable
        class Required(val oldRevision: Int? = null): Modification() {
            override val oldRevisionOrNull: Int?
                get() = oldRevision
        }

        /** Carries the [newItem] to store, with an optional [oldRevision]. */
        @Serializable
        class Set(val oldRevision: Int? = null, val newItem: ContextItemDto): Modification() {
            override val oldRevisionOrNull: Int?
                get() = oldRevision
        }

        /** Carries a mandatory [oldRevision] and no value. */
        @Serializable
        class Remove(val oldRevision: Int): Modification() {
            override val oldRevisionOrNull: Int?
                get() = oldRevision
        }
    }
}
/**
 * Serializable outcome of [ContextUpdateService.update]: either [Accepted] or [Rejected].
 */
@Serializable
sealed class ContextUpdateResult {
    /** The request was accepted; carries no data. */
    @Serializable
    object Accepted: ContextUpdateResult()

    /**
     * The request was rejected.
     *
     * @property rejectedModifications one [Reason] per rejected key; empty by default.
     */
    @Serializable
    class Rejected(
        val rejectedModifications: Map<KeyDto, Reason> = emptyMap(),
    ): ContextUpdateResult() {
        /** Why a single key's modification was rejected. */
        @Serializable
        sealed class Reason {
            /** Rejection reason that carries no replacement item. */
            @Serializable
            object Removed: Reason()

            /** Rejection reason that carries the item's [newItem]. */
            @Serializable
            class Updated(val newItem: ContextItemDto): Reason()
        }
    }
}
/**
 * Service with two single calls, [update] and [clear], together with its transport-backed
 * [Client] and its [Descriptor].
 */
interface ContextUpdateService {
    /** Submits [request] and returns the resulting [ContextUpdateResult]. */
    suspend fun update(request: ContextUpdateRequest): ContextUpdateResult

    /** Sends the `clear` call; it takes no input and returns nothing. */
    suspend fun clear()

    /** Implementation that forwards both calls over [transport] as single calls. */
    class Client(
        private val transport: RPCTransport,
    ): ContextUpdateService {
        override suspend fun update(request: ContextUpdateRequest): ContextUpdateResult {
            return transport.singleCall(Descriptor.Call.update, request)
        }

        override suspend fun clear() {
            return transport.singleCall(Descriptor.Call.clear, Unit)
        }
    }

    /** Describes the service: its identifier and its two calls. */
    object Descriptor: ServiceDescriptor<ContextUpdateService> {
        // Runtime identifier sent with every call. It says "ContextSyncService" although the
        // interface is named ContextUpdateService; the string must stay as it is.
        const val identifier = "builtin:hyperdrive.ContextSyncService"

        /** Builds a [ServiceDescription] that routes `update` and `clear` to [service]. */
        override fun describe(service: ContextUpdateService): ServiceDescription {
            return ServiceDescription(
                identifier,
                listOf(
                    Call.update.calling { request ->
                        service.update(request)
                    },
                    // The request of `clear` is Unit and is not used.
                    Call.clear.calling { _ ->
                        service.clear()
                    }
                )
            )
        }

        /** Call descriptions shared by [Client] and [describe]. */
        object Call {
            val update = SingleCallDescription(
                ServiceCallIdentifier(identifier, "update"),
                ContextUpdateRequest.serializer(),
                ContextUpdateResult.serializer(),
                RPCErrorSerializer(),
            )
            val clear = SingleCallDescription(
                ServiceCallIdentifier(identifier, "clear"),
                Unit.serializer(),
                Unit.serializer(),
                RPCErrorSerializer(),
            )
        }
    }
}
/**
 * Session context key identified only by its [qualifiedName], whose value is kept as a raw
 * [SerializedPayload] and serialized with [SerializedPayload.serializer].
 *
 * @property qualifiedName name identifying the key.
 */
data class UnsupportedKey(override val qualifiedName: String): Session.Context.Key<SerializedPayload> {
    override val serializer: KSerializer<SerializedPayload> = SerializedPayload.serializer()
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy