commonMain.app.cash.zipline.internal.bridge.calls.kt Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of zipline-jvm Show documentation
Show all versions of zipline-jvm Show documentation
Runs Kotlin/JS libraries in Kotlin/JVM and Kotlin/Native programs
The newest version!
/*
* Copyright (C) 2022 Block, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package app.cash.zipline.internal.bridge
import app.cash.zipline.CallResult
import app.cash.zipline.ZiplineApiMismatchException
import app.cash.zipline.ZiplineFunction
import app.cash.zipline.ZiplineScoped
import app.cash.zipline.ZiplineService
import app.cash.zipline.ziplineServiceSerializer
import kotlinx.serialization.KSerializer
import kotlinx.serialization.builtins.ListSerializer
import kotlinx.serialization.builtins.serializer
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.buildClassSerialDescriptor
import kotlinx.serialization.encoding.CompositeDecoder.Companion.DECODE_DONE
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import kotlinx.serialization.encoding.decodeStructure
import kotlinx.serialization.encoding.encodeStructure
import kotlinx.serialization.json.JsonDecoder
/**
 * A single bridged function call, as handled by [RealCallSerializer].
 *
 * The same class describes both directions of a call; which of the optional properties are set
 * depends on the direction, as noted on each property.
 */
internal class InternalCall(
  /** This is not-null, but may refer to a service that is not known by this endpoint. */
  val serviceName: String,

  /** This is non-null for outbound calls. */
  val argsListSerializer: ArgsListSerializer? = null,

  /** This is non-null for suspending outbound calls. */
  val suspendCallbackSerializer: KSerializer<*>? = null,

  /** This is absent for outbound calls. */
  val inboundService: InboundService<*>? = null,

  /**
   * The function being called. If the function is unknown to the receiver, it will synthesize a
   * [ZiplineFunction] instance that always throws [ZiplineApiMismatchException].
   */
  val function: ZiplineFunction<*>,

  /**
   * If this function is suspending, this callback is not null. The function returns an encoded
   * [CancelCallback] and the response is delivered to the [SuspendCallback].
   */
  val suspendCallback: SuspendCallback<Any?>? = null,

  /** The argument values of this call. Outbound calls encode them with [argsListSerializer]. */
  val args: List<*>,
) {
  override fun toString() =
    "Call(receiver=$serviceName, function=${function.signature}, args=$args)"
}
/** This uses [Int] as a placeholder; in practice the element type depends on the argument type. */
private val argsListDescriptor = ListSerializer(Int.serializer()).descriptor

/** This uses [Int] as a placeholder; it doesn't matter 'cause we're only encoding failures. */
internal val failureSuspendCallbackSerializer = ziplineServiceSerializer<SuspendCallback<Int>>()

/** Serialize any cancel callback using pass-by-reference. */
internal val cancelCallbackSerializer = ziplineServiceSerializer<CancelCallback>()
/**
 * Encode and decode calls using `ZiplineFunction.argsListSerializer`.
 *
 * When serializing outbound calls the function instance is a member of the call. To deserialize
 * inbound calls the function must be looked up from [endpoint] using the service and function name.
 *
 * This serializer is weird! Its args serializer is dependent on other properties. Therefore, it
 * (reasonably) assumes that JSON is decoded in the same order it's encoded.
 */
internal class RealCallSerializer(
  private val endpoint: Endpoint,
) : KSerializer<InternalCall> {
  override val descriptor = buildClassSerialDescriptor("RealCall") {
    element("service", String.serializer().descriptor)
    element("function", String.serializer().descriptor)
    element("callback", String.serializer().descriptor)
    element("args", argsListDescriptor)
  }

  /**
   * Writes the service name, the function ID, the suspend callback (only when the call has one),
   * and finally the arguments.
   */
  override fun serialize(encoder: Encoder, value: InternalCall) {
    encoder.encodeStructure(descriptor) {
      encodeStringElement(descriptor, 0, value.serviceName)
      encodeStringElement(descriptor, 1, value.function.id)
      if (value.suspendCallback != null) {
        @Suppress("UNCHECKED_CAST") // We don't declare a type T for the result of this call.
        encodeSerializableElement(
          descriptor,
          2,
          value.suspendCallbackSerializer as KSerializer<Any?>,
          value.suspendCallback,
        )
      }
      // InternalCall documents argsListSerializer as non-null for outbound calls, which are the
      // calls that get encoded here.
      encodeSerializableElement(descriptor, 3, value.argsListSerializer!!, value.args)
    }
  }

  /**
   * Reads a call, resolving the service and function against [endpoint] as their names are
   * decoded. An unknown service or function does not fail decoding; instead the returned call
   * carries placeholders from [unknownService] and [unknownFunction].
   *
   * While decoding, [Endpoint.takeScope] is set to the target service's scope (when that service
   * is a [ZiplineScoped]); the previous value is always restored before returning.
   */
  override fun deserialize(decoder: Decoder): InternalCall {
    val pushedTakeScope = endpoint.takeScope
    try {
      return decoder.decodeStructure(descriptor) {
        var serviceName = ""
        var inboundService: InboundService<*>? = null
        var functionId = ""
        var function: ZiplineFunction<*>? = null
        var suspendCallback: SuspendCallback<Any?>? = null
        var args = listOf<Any?>()
        while (true) {
          when (val index = decodeElementIndex(descriptor)) {
            0 -> {
              serviceName = decodeStringElement(descriptor, index)
              inboundService = endpoint.inboundServices[serviceName]
              endpoint.takeScope = (inboundService?.service as? ZiplineScoped)?.scope
            }

            1 -> {
              functionId = decodeStringElement(descriptor, index)
              function = inboundService?.type?.functionsById?.get(functionId)
            }

            2 -> {
              @Suppress("UNCHECKED_CAST") // We don't declare a type T for the result of this call.
              val serializer = when (function) {
                is SuspendingZiplineFunction<*> -> function.suspendCallbackSerializer
                // We can use any suspend callback if we're only returning failures.
                else -> failureSuspendCallbackSerializer
              } as KSerializer<SuspendCallback<Any?>>
              suspendCallback = decodeSerializableElement(
                descriptor,
                index,
                serializer,
              )
            }

            3 -> {
              val argsListSerializer = when (function) {
                is SuspendingZiplineFunction<*> -> function.argsListSerializer
                is ReturningZiplineFunction<*> -> function.argsListSerializer
                else -> null
              }
              if (argsListSerializer != null) {
                args = decodeSerializableElement(descriptor, index, argsListSerializer)
              } else {
                // Discard args for unknown function.
                (decoder as JsonDecoder).decodeJsonElement()
              }
            }

            DECODE_DONE -> break

            else -> error("Unexpected index: $index")
          }
        }
        return@decodeStructure InternalCall(
          serviceName = serviceName,
          inboundService = inboundService ?: unknownService(),
          function = function ?: unknownFunction(
            functionId,
            suspendCallback,
            when (inboundService) {
              null -> ZiplineApiMismatchException.UNKNOWN_SERVICE
              else -> ZiplineApiMismatchException.UNKNOWN_FUNCTION
            },
          ),
          suspendCallback = suspendCallback,
          args = args,
        )
      }
    } finally {
      endpoint.takeScope = pushedTakeScope
    }
  }

  /** Returns a fake service that implements no functions. */
  private fun unknownService(): InboundService<*> {
    return InboundService<ZiplineService>(
      type = RealZiplineServiceType(
        name = "Unknown",
        functions = listOf(),
      ),
      service = object : ZiplineService {},
      endpoint = endpoint,
    )
  }

  /**
   * Returns a function that always throws [ZiplineApiMismatchException] when called.
   *
   * The result is a suspending function when [suspendCallback] is non-null, and a returning
   * function otherwise.
   */
  private fun <T : ZiplineService> unknownFunction(
    functionId: String,
    suspendCallback: SuspendCallback<Any?>?,
    message: String,
  ): ZiplineFunction<T> {
    if (suspendCallback != null) {
      return object : SuspendingZiplineFunction<T>(
        id = functionId,
        signature = "suspend fun unknownFunction(): kotlin.Unit",
        argSerializers = listOf(),
        // Placeholder; we're only encoding failures.
        resultSerializer = Int.serializer(),
        suspendCallbackSerializer = failureSuspendCallbackSerializer,
      ) {
        override suspend fun callSuspending(service: T, args: List<*>) =
          throw ZiplineApiMismatchException(message)
      }
    } else {
      return object : ReturningZiplineFunction<T>(
        id = functionId,
        signature = "fun unknownFunction(): kotlin.Unit",
        argSerializers = listOf(),
        // Placeholder; we're only encoding failures.
        resultSerializer = Int.serializer(),
      ) {
        override fun call(service: T, args: List<*>) =
          throw ZiplineApiMismatchException(message)
      }
    }
  }
}
/**
 * Encodes and decodes an argument list where each position has its own serializer.
 *
 * The value at position `i` is always handled by `serializers[i]`, so the list being encoded must
 * have exactly as many values as there are serializers.
 */
internal class ArgsListSerializer(
  internal val serializers: List<KSerializer<*>>,
) : KSerializer<List<*>> {
  override val descriptor = argsListDescriptor

  /** Encodes [value] element by element. Fails if its size differs from [serializers]. */
  override fun serialize(encoder: Encoder, value: List<*>) {
    check(value.size == serializers.size) {
      "Expected ${serializers.size} arguments but was ${value.size}"
    }
    encoder.encodeStructure(descriptor) {
      for (i in serializers.indices) {
        @Suppress("UNCHECKED_CAST") // We don't have a type argument T for each parameter.
        encodeSerializableElement(descriptor, i, serializers[i] as KSerializer<Any?>, value[i])
      }
    }
  }

  /**
   * Decodes one element per serializer. Fails if the decoder reports the elements out of order,
   * or reports anything other than the end of the structure after the last expected element.
   */
  override fun deserialize(decoder: Decoder): List<*> {
    return decoder.decodeStructure(descriptor) {
      val result = mutableListOf<Any?>()
      for (i in serializers.indices) {
        val index = decodeElementIndex(descriptor)
        check(index == i) { "Expected argument index $i but was $index" }
        result += decodeSerializableElement(descriptor, i, serializers[i])
      }
      val end = decodeElementIndex(descriptor)
      check(end == DECODE_DONE) {
        "Expected end of arguments after ${serializers.size} elements but was index $end"
      }
      return@decodeStructure result
    }
  }
}
/**
 * Immediate result from invoking a returning or suspending function, that may include either a
 * value (if the call never suspended) or a cancel callback (if it did suspend).
 *
 * Exactly one of [result] and [callback] must be non-null; construction fails otherwise.
 */
internal class ResultOrCallback<R>(
  /** The function returned or failed without suspending. */
  val result: Result<R>? = null,

  /** The function suspended. Only non-null for suspend calls. */
  val callback: CancelCallback? = null,
) {
  init {
    require((callback != null) != (result != null)) {
      "Exactly one of result and callback must be non-null"
    }
  }
}
/**
 * Combination of [ResultOrCallback] and [app.cash.zipline.CallResult].
 *
 * Holds the decoded [value] together with the [json] string and service names that are passed on
 * to [CallResult].
 */
internal class EncodedResultOrCallback(
  val value: ResultOrCallback<*>,
  val json: String,
  serviceNames: List<String>,
) {
  val serviceNames: List<String> = serviceNames.toList() // Defensive copy.

  /** The call result. Null if this is a callback. */
  val callResult: CallResult?
    get() {
      val result = value.result ?: return null
      return CallResult(
        result,
        json,
        serviceNames,
      )
    }
}
/**
 * Encodes and decodes a [ResultOrCallback] as a structure holding exactly one of three elements:
 * `cancelCallback`, `failure`, or `success`.
 *
 * Success values are handled by [successSerializer]; failures by [ThrowableSerializer]; cancel
 * callbacks by [cancelCallbackSerializer].
 */
internal class ResultOrCallbackSerializer<T>(
  internal val successSerializer: KSerializer<T>,
) : KSerializer<ResultOrCallback<T>> {

  override val descriptor: SerialDescriptor = buildClassSerialDescriptor("Result") {
    element("cancelCallback", cancelCallbackSerializer.descriptor)
    element("failure", ThrowableSerializer.descriptor)
    element("success", successSerializer.descriptor)
  }

  /** Writes whichever one of the callback, the failure, or the success value [value] holds. */
  override fun serialize(encoder: Encoder, value: ResultOrCallback<T>) {
    encoder.encodeStructure(descriptor) {
      if (value.callback != null) {
        encodeSerializableElement(descriptor, 0, cancelCallbackSerializer, value.callback)
        return@encodeStructure
      }

      // ResultOrCallback requires exactly one of callback and result, so result is set here.
      val result = value.result!!
      val throwable = result.exceptionOrNull()
      if (throwable != null) {
        encodeSerializableElement(descriptor, 1, ThrowableSerializer, throwable)
        return@encodeStructure
      }

      @Suppress("UNCHECKED_CAST") // We know the value of a success result is a 'T'.
      encodeSerializableElement(descriptor, 2, successSerializer, result.getOrNull() as T)
    }
  }

  /**
   * Reads whichever elements are present and builds a [ResultOrCallback] from them. The
   * [ResultOrCallback] constructor rejects input that carries neither or both of a result and a
   * callback.
   */
  override fun deserialize(decoder: Decoder): ResultOrCallback<T> {
    return decoder.decodeStructure(descriptor) {
      var result: Result<T>? = null
      var callback: CancelCallback? = null
      while (true) {
        when (val index = decodeElementIndex(descriptor)) {
          0 -> callback = decodeSerializableElement(descriptor, 0, cancelCallbackSerializer)

          1 -> result = Result.failure(
            decodeSerializableElement(descriptor, 1, ThrowableSerializer),
          )

          2 -> result = Result.success(
            decodeSerializableElement(descriptor, 2, successSerializer),
          )

          DECODE_DONE -> break

          else -> error("Unexpected index: $index")
        }
      }
      return@decodeStructure ResultOrCallback(result, callback)
    }
  }
}