All downloads are FREE. Search and download functionality uses the official Maven repository.

commonMain.org.brightify.hyperdrive.krpc.protocol.DefaultRPCNode.kt Maven / Gradle / Ivy

package org.brightify.hyperdrive.krpc.protocol

import kotlinx.coroutines.*
import org.brightify.hyperdrive.Logger
import org.brightify.hyperdrive.krpc.RPCConnection
import org.brightify.hyperdrive.krpc.RPCTransport
import org.brightify.hyperdrive.krpc.ServiceRegistry
import org.brightify.hyperdrive.krpc.application.RPCNode
import org.brightify.hyperdrive.krpc.application.RPCNodeExtension
import org.brightify.hyperdrive.krpc.description.RunnableCallDescription
import org.brightify.hyperdrive.krpc.description.ServiceCallIdentifier
import org.brightify.hyperdrive.krpc.error.RPCNotFoundError
import org.brightify.hyperdrive.krpc.impl.DefaultServiceRegistry
import org.brightify.hyperdrive.krpc.impl.MutableConcatServiceRegistry
import org.brightify.hyperdrive.krpc.impl.ProtocolBasedRPCTransport
import org.brightify.hyperdrive.krpc.protocol.ascension.ColdBistreamRunner
import org.brightify.hyperdrive.krpc.protocol.ascension.ColdDownstreamRunner
import org.brightify.hyperdrive.krpc.protocol.ascension.ColdUpstreamRunner
import org.brightify.hyperdrive.krpc.protocol.ascension.PayloadSerializer
import org.brightify.hyperdrive.krpc.protocol.ascension.RPCHandshakePerformer
import org.brightify.hyperdrive.krpc.protocol.ascension.SingleCallRunner
import kotlin.reflect.KClass

/**
 * Thrown by `DefaultRPCNode.Factory.create` when the handshake with the remote party fails.
 *
 * The exception message has the form `"Handshake has failed: <message>"`.
 *
 * @property rpcMesage failure message reported by the handshake. The name is misspelled and is kept
 * only for backward compatibility; prefer [rpcMessage].
 */
class HandshakeFailedException(val rpcMesage: String): Exception("Handshake has failed: $rpcMesage") {
    /** Correctly spelled alias of [rpcMesage]. */
    val rpcMessage: String
        get() = rpcMesage
}

/**
 * [ServiceRegistry] exposing the services contributed by a list of node extensions.
 *
 * At construction time, every entry of each extension's `providedServices` is registered into a
 * private [DefaultServiceRegistry], walking extensions and their services in iteration order.
 * Lookups are delegated to that registry.
 */
class RPCExtensionServiceRegistry(extensions: List): ServiceRegistry {
    private val registry = DefaultServiceRegistry().also { target ->
        extensions.forEach { extension ->
            extension.providedServices.forEach { service -> target.register(service) }
        }
    }

    override fun > getCallById(id: ServiceCallIdentifier, type: KClass): T? =
        registry.getCallById(id, type)
}

/**
 * Default [RPCNode] implementation. It pairs the state negotiated during the handshake ([contract])
 * with the transport handed to the node's extensions ([transport]).
 *
 * Instances are built by [Factory.create]. [run] then drives the protocol and the node's extensions.
 *
 * @property contract negotiated payload serializer, protocol and extensions of this node.
 * @property transport transport each extension is bound to in [run]. When built by [Factory.create],
 * this is an `InterceptorEnabledRPCTransport` wrapping a [ProtocolBasedRPCTransport].
 */
class DefaultRPCNode(
    override val contract: Contract,
    val transport: RPCTransport,
): RPCNode {
    private companion object {
        // NOTE(review): not referenced anywhere in this class.
        val logger = Logger()
    }

    /**
     * Returns the extension registered in [contract] under [identifier], or `null` when none is registered.
     *
     * NOTE(review): the `as? E` cast is unchecked (E is erased at runtime), so it does not verify the
     * concrete extension type. Callers presumably rely on each identifier being paired with the matching
     * extension type.
     */
    override fun  getExtension(identifier: RPCNodeExtension.Identifier): E? {
        @Suppress("UNCHECKED_CAST")
        return contract.extensions[identifier] as? E
    }

    /**
     * State negotiated for a single connection.
     *
     * @property payloadSerializer serializer for the payload format selected during the handshake.
     * @property protocol protocol instance executed by [run].
     * @property extensions instantiated extensions, keyed by the identifier of the factory that created them.
     */
    class Contract(
        override val payloadSerializer: PayloadSerializer,
        internal val protocol: RPCProtocol,
        internal val extensions: Map, RPCNodeExtension>,
    ): RPCNode.Contract

    /**
     * Builds a [DefaultRPCNode] for a connection by performing the handshake and wiring the negotiated
     * protocol, serializers, extensions and interceptors together.
     *
     * @param handshakePerformer negotiates the frame serializer and protocol factory with the remote party.
     * @param payloadSerializerFactory creates the payload serializer for the negotiated format.
     * @param extensionFactories factories for every extension the node loads.
     * @param providedServiceRegistry services this node implements for incoming calls. They are wrapped
     * with the extensions' combined incoming interceptor.
     */
    class Factory(
        private val handshakePerformer: RPCHandshakePerformer,
        private val payloadSerializerFactory: PayloadSerializer.Factory,
        private val extensionFactories: List>,
        private val providedServiceRegistry: ServiceRegistry,
    ) {
        /**
         * Performs the handshake over [connection] and assembles a node from the negotiated settings.
         *
         * On success this:
         * 1. creates the payload serializer for the selected frame serializer's format,
         * 2. instantiates every extension, without checking remote support (see the TODO below),
         * 3. resolves incoming calls against the extensions' own services, concatenated with
         *    [providedServiceRegistry] wrapped in the combined incoming interceptor,
         * 4. creates the selected protocol and wraps its transport with the combined outgoing interceptor.
         *
         * The returned node is not started; call [DefaultRPCNode.run] for that.
         *
         * @throws HandshakeFailedException when the handshake fails. [connection] is cancelled before throwing.
         */
        suspend fun create(connection: RPCConnection): DefaultRPCNode {
            when (val handshakeResult = handshakePerformer.performHandshake(connection)) {
                is RPCHandshakePerformer.HandshakeResult.Success -> {
                    val payloadSerializer = payloadSerializerFactory.create(handshakeResult.selectedFrameSerializer.format)

                    // TODO: Check which extensions are supported by the other party first.
                    val extensions = extensionFactories.associate { it.identifier to it.create() }
                    val extensionList = extensions.values.toList()
                    // The same extension list is passed for both constructor arguments of the interceptor registry.
                    val interceptorRegistry = DefaultRPCInterceptorRegistry(extensionList, extensionList)
                    // NOTE(review): presumably the concatenated registries are consulted in argument order,
                    // so extension-provided services would shadow the user-provided ones. Verify against
                    // MutableConcatServiceRegistry.
                    val extendedServiceRegistry = MutableConcatServiceRegistry(
                        RPCExtensionServiceRegistry(extensionList),
                        InterceptorEnabledServiceRegistry(providedServiceRegistry, interceptorRegistry.combinedIncomingInterceptor())
                    )
                    val implementationRegistry = DefaultRPCImplementationRegistry(payloadSerializer, extendedServiceRegistry)

                    val protocol = handshakeResult.selectedProtocolFactory.create(
                        connection,
                        handshakeResult.selectedFrameSerializer,
                        implementationRegistry,
                    )

                    val transport = ProtocolBasedRPCTransport(protocol, payloadSerializer)

                    // Outgoing calls pass through the extensions' combined outgoing interceptor.
                    val extendedTransport = InterceptorEnabledRPCTransport(transport, interceptorRegistry.combinedOutgoingInterceptor())
                    val contract = Contract(payloadSerializer, protocol, extensions)

                    return DefaultRPCNode(contract, extendedTransport)
                }
                is RPCHandshakePerformer.HandshakeResult.Failed -> {
                    val exception = HandshakeFailedException(handshakeResult.message)
                    // Tear down the connection before surfacing the failure to the caller.
                    connection.cancel("Handshake failed: ${handshakeResult.message}.")
                    throw exception
                }
            }
        }
    }

    /**
     * Runs this node until its protocol completes.
     *
     * Steps, in order:
     * 1. Starts the protocol.
     * 2. Binds every extension to [transport] and [contract].
     * 3. Invokes [onInitializationCompleted].
     * 4. Runs each extension's `whileConnected` concurrently, in a context that every extension
     *    may transform via `enhanceParallelWorkContext`.
     *
     * The background work from step 4 is cancelled once the protocol completes, whether normally or
     * with an error. Everything runs inside [coroutineScope], so this function returns only after all
     * children have finished. A failure in any child cancels the others and is rethrown.
     *
     * @param onInitializationCompleted invoked once, after all extensions are bound and before the
     * background work starts.
     */
    suspend fun run(onInitializationCompleted: suspend () -> Unit): Unit = coroutineScope {
        // We need the protocol to be running before we bind the extensions.
        // NOTE(review): `async` with the default start mode only schedules `contract.protocol.run()`.
        // Whether it is already executing when `bind` is called depends on the dispatcher and on `bind`
        // suspending.
        val runningProtocol = async { contract.protocol.run() }

        val extensions = contract.extensions.values
        for (extension in extensions) {
            extension.bind(transport, contract)
        }

        onInitializationCompleted()

        // Start from this scope's context and let each extension transform it in turn.
        val parallelWorkContext = extensions.fold(coroutineContext) { accumulator, extension ->
            extension.enhanceParallelWorkContext(accumulator)
        }

        val runningParallelWork = launch(parallelWorkContext) {
            extensions.map { extension ->
                async { extension.whileConnected() }
            }.awaitAll()
        }

        // We want the background work to end when the protocol does. The protocol's completion cause
        // (null on normal completion) is forwarded as the cancellation cause.
        runningProtocol.invokeOnCompletion {
            runningParallelWork.cancel("Protocol has completed.", it)
        }
    }
}

/**
 * [RPCImplementationRegistry] that resolves calls against a [ServiceRegistry]. It wraps the matching
 * [RunnableCallDescription] in the callee-side runner for that call's kind.
 *
 * @param payloadSerializer serializer handed to every created callee runner.
 * @param serviceRegistry registry queried for runnable calls by identifier.
 */
class DefaultRPCImplementationRegistry(
    private val payloadSerializer: PayloadSerializer,
    private val serviceRegistry: ServiceRegistry,
): RPCImplementationRegistry {
    /**
     * Returns the callee runner for the call registered under [id].
     *
     * The runner kind follows the call description:
     * - single → [SingleCallRunner.Callee]
     * - cold upstream → [ColdUpstreamRunner.Callee]
     * - cold downstream → [ColdDownstreamRunner.Callee]
     * - cold bistream → [ColdBistreamRunner.Callee]
     *
     * @throws RPCNotFoundError when [serviceRegistry] has no runnable call for [id].
     */
    override fun  callImplementation(id: ServiceCallIdentifier, type: KClass): T {
        val description = serviceRegistry.getCallById(id, RunnableCallDescription::class)
            ?: throw RPCNotFoundError(id)

        val callee: Any = when (description) {
            is RunnableCallDescription.Single<*, *> -> SingleCallRunner.Callee(payloadSerializer, description)
            is RunnableCallDescription.ColdUpstream<*, *, *> -> ColdUpstreamRunner.Callee(payloadSerializer, description)
            is RunnableCallDescription.ColdDownstream<*, *> -> ColdDownstreamRunner.Callee(payloadSerializer, description)
            is RunnableCallDescription.ColdBistream<*, *, *> -> ColdBistreamRunner.Callee(payloadSerializer, description)
        }

        // `type` is not consulted. The cast below is unchecked, so it does not verify that the runner matches it.
        @Suppress("UNCHECKED_CAST")
        return callee as T
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy