All Downloads are FREE. Search and download functionalities are using the official Maven repository.

commonMain.utils.io.serialization.utils.kt Maven / Gradle / Ivy

/*
 * Copyright 2019-2022 Mamoe Technologies and contributors.
 *
 * 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
 * Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
 *
 * https://github.com/mamoe/mirai/blob/dev/LICENSE
 */

@file:JvmName("SerializationUtils")
@file:JvmMultifileClass
@file:Suppress("NOTHING_TO_INLINE")

package net.mamoe.mirai.internal.utils.io.serialization

import kotlinx.io.core.*
import kotlinx.io.streams.asInput
import kotlinx.serialization.DeserializationStrategy
import kotlinx.serialization.SerializationStrategy
import kotlinx.serialization.descriptors.SerialDescriptor
import net.mamoe.mirai.internal.message.contextualBugReportException
import net.mamoe.mirai.internal.network.protocol.data.jce.RequestDataVersion2
import net.mamoe.mirai.internal.network.protocol.data.jce.RequestDataVersion3
import net.mamoe.mirai.internal.network.protocol.data.jce.RequestPacket
import net.mamoe.mirai.internal.network.protocol.data.proto.OidbSso
import net.mamoe.mirai.internal.utils.io.JceStruct
import net.mamoe.mirai.internal.utils.io.ProtoBuf
import net.mamoe.mirai.internal.utils.io.serialization.tars.Tars
import net.mamoe.mirai.internal.utils.io.serialization.tars.internal.DebugLogger
import net.mamoe.mirai.internal.utils.io.serialization.tars.internal.TarsDecoder
import net.mamoe.mirai.internal.utils.printStructure
import net.mamoe.mirai.utils.*
import java.io.ByteArrayOutputStream
import java.io.PrintStream
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract

// Short alias for kotlinx.serialization's ProtoBuf format, so it does not clash with the
// project's own `ProtoBuf` type imported from net.mamoe.mirai.internal.utils.io.
internal typealias KtProtoBuf = kotlinx.serialization.protobuf.ProtoBuf

/**
 * Reads this byte array as a `UniPacket` (see [readUniPacket]) and deserializes the selected
 * entry as a JCE struct with [deserializer].
 *
 * @param name key of the entry to select from the request map; when `null` the map must contain
 * exactly one entry.
 */
internal fun <T : JceStruct> ByteArray.loadWithUniPacket(
    deserializer: DeserializationStrategy<T>,
    name: String? = null,
): T = this.read { readUniPacket(deserializer, name) }

/**
 * Deserializes a Tars/JCE structure from `this[offset until offset + length]`.
 *
 * Some payloads carry a 4-byte length prefix. The first four bytes of the slice are read with
 * `toInt`. If that value equals the slice length (prefix counts itself) or the slice length
 * minus 4 (prefix excludes itself), the prefix is skipped before decoding. Otherwise the whole
 * slice is decoded.
 *
 * @param offset index of the first byte to decode.
 * @param length number of bytes to decode, starting at [offset].
 */
internal fun <T : JceStruct> ByteArray.loadAs(
    deserializer: DeserializationStrategy<T>,
    offset: Int = 0,
    length: Int = size - offset,
): T {
    // Probe for a prefix only when the slice itself holds at least 4 bytes. Checking the whole
    // array size instead would read beyond the slice for a non-zero offset or a short slice.
    // A short slice could also match and produce a negative decode length.
    if (length >= 4) {
        val possibleLength = this.toInt(offset = offset)
        if (possibleLength == length || possibleLength == length - 4) {
            return doLoadAs(
                deserializer,
                offset = offset + 4,
                length = length - 4,
            )
        }
    }

    return doLoadAs(deserializer, offset, length)
}

/**
 * Decodes `this[offset until offset + length]` with [Tars.UTF_8].
 *
 * If the first attempt throws, the same bytes are decoded again with a [DebugLogger]. The logger
 * records a hex dump of the input and a decoding trace.
 * - If the second attempt succeeds, its result is returned. A bug-report warning with the trace
 *   and the original exception is also logged via [TarsDecoder.logger].
 * - If it fails as well, a bug-report exception is thrown. It carries the trace and both
 *   failures.
 */
private fun <T : JceStruct> ByteArray.doLoadAs(
    deserializer: DeserializationStrategy<T>,
    offset: Int,
    length: Int,
): T {
    try {
        return this.inputStream(offset = offset, length = length).asInput().use { input ->
            Tars.UTF_8.load(deserializer, input)
        }
    } catch (originalException: Exception) {
        // Collects the hex dump and the trace produced by the debug retry below.
        val log = ByteArrayOutputStream()
        try {
            val value = PrintStream(log).use { stream ->
                stream.println("\nData: ")
                stream.println(this.toUHexString(offset = offset, length = length))
                stream.println("Trace:")

                // Retry the same bytes with tracing enabled.
                this.inputStream(offset = offset, length = length).asInput().use { input ->
                    Tars.UTF_8.load(deserializer, input, debugLogger = DebugLogger(stream))
                }
            }
            return value.also {
                // The retry succeeded: report the discrepancy but still hand back the result.
                TarsDecoder.logger.warning(
                    contextualBugReportException(
                        "解析 " + deserializer.descriptor.serialName,
                        "启用 debug 模式后解析正常: $value \n\n${log.toByteArray().decodeToString()}",
                        originalException
                    )
                )
            }
        } catch (secondFailure: Exception) {
            throw contextualBugReportException(
                "解析 " + deserializer.descriptor.serialName,
                log.toByteArray().decodeToString(),
                ExceptionCollector.compressExceptions(originalException, secondFailure)
            )
        }
    }
}

/**
 * Serializes [struct] with [Tars.UTF_8] and writes the resulting bytes into this builder.
 */
internal fun <T : JceStruct> BytePacketBuilder.writeJceStruct(
    serializer: SerializationStrategy<T>,
    struct: T,
) {
    Tars.UTF_8.dumpTo(serializer, struct, this)
}

/**
 * Reads [length] bytes from this packet (default: everything remaining) and decodes them as a
 * JCE struct via [ByteArray.loadAs], which also handles an optional 4-byte length prefix.
 */
internal fun <T : JceStruct> ByteReadPacket.readJceStruct(
    deserializer: DeserializationStrategy<T>,
    length: Int = this.remaining.toInt(),
): T {
    // `useBytes` hands over a buffer plus the count of valid bytes in it; only that prefix is decoded.
    return this.useBytes(n = length) { data, arrayLength ->
        data.loadAs(deserializer, offset = 0, length = arrayLength)
    }
}

/**
 * Wraps [body] in a [RequestPacket] (request id 0) and writes it as a JCE struct.
 *
 * The packet's `sBuffer` holds [body] under the key [name], built by [jceRequestSBuffer].
 *
 * @param version value stored in [RequestPacket.version].
 * NOTE(review): `sBuffer` is always encoded as [RequestDataVersion3], whatever [version] is.
 */
internal fun <T : JceStruct> BytePacketBuilder.writeJceRequestPacket(
    version: Int = 3,
    servantName: String,
    funcName: String,
    name: String = funcName,
    serializer: SerializationStrategy<T>,
    body: T,
) = writeJceStruct(
    RequestPacket.serializer(),
    RequestPacket(
        requestId = 0,
        version = version.toShort(),
        servantName = servantName,
        funcName = funcName,
        sBuffer = jceRequestSBuffer(name, serializer, body),
    ),
)

/**
 * Decodes this packet as a [RequestPacket] (i.e. `UniRequest`) and decodes its map according to
 * the request version. It then selects the requested entry and deserializes it as a JCE struct.
 *
 * One leading byte is discarded and the final byte is excluded before decoding. These match the
 * head/tail bytes that [jceRequestSBuffer] wraps around each entry.
 *
 * @param name key of the entry to select; when `null` the map must contain exactly one entry.
 */
internal fun <T : JceStruct> ByteReadPacket.readUniPacket(
    deserializer: DeserializationStrategy<T>,
    name: String? = null,
): T {
    return decodeUniRequestPacketAndDeserialize(name) {
        it.read {
            discardExact(1)
            this.readJceStruct(deserializer, length = (this.remaining - 1).toInt())
        }
    }
}

/**
 * Decodes this packet as a [RequestPacket] (i.e. `UniRequest`) and decodes its map according to
 * the request version. It then selects the requested entry and deserializes it as a protobuf
 * message.
 *
 * One leading byte is discarded and the final byte is excluded before decoding.
 *
 * @param name key of the entry to select; when `null` the map must contain exactly one entry.
 */
internal fun <T : ProtoBuf> ByteReadPacket.readUniPacket(
    deserializer: DeserializationStrategy<T>,
    name: String? = null,
): T {
    return decodeUniRequestPacketAndDeserialize(name) {
        it.read {
            discardExact(1)
            this.readProtoBuf(deserializer, (this.remaining - 1).toInt())
        }
    }
}

/** Returns the value of this map's only entry; throws if the map is empty or has several entries. */
private fun <K, V> Map<K, V>.singleValue(): V = this.entries.single().value

/**
 * Reads a [RequestPacket] from this packet and extracts one raw entry from its `sBuffer` map. It
 * then passes that entry to [block] and returns the result.
 *
 * The map layout depends on [RequestPacket.version] (`null` is treated as 3):
 * - version 2: [RequestDataVersion2], where each named entry is itself a single-entry map;
 * - version 3: [RequestDataVersion3], where each named entry is the raw bytes.
 *
 * @param name key of the entry to select; when `null` the map must contain exactly one entry.
 * @throws IllegalStateException if [name] is absent or the version is unsupported.
 */
internal fun <R> ByteReadPacket.decodeUniRequestPacketAndDeserialize(name: String? = null, block: (ByteArray) -> R): R {
    val request = this.readJceStruct(RequestPacket.serializer())

    // Selects the entry named `name`, or the sole entry when no name was requested.
    fun <V> Map<String, V>.selectEntry(): V =
        if (name == null) singleValue() else getOrElse(name) { error("cannot find $name") }

    val payload: ByteArray = when (request.version?.toInt() ?: 3) {
        2 -> request.sBuffer.loadAs(RequestDataVersion2.serializer()).map.selectEntry().singleValue()
        3 -> request.sBuffer.loadAs(RequestDataVersion3.serializer()).map.selectEntry()
        else -> error("unsupported version ${request.version}")
    }
    return block(payload)
}

/** Serializes this JCE struct with [Tars.UTF_8]. */
internal fun <T : JceStruct> T.toByteArray(
    serializer: SerializationStrategy<T>,
): ByteArray = Tars.UTF_8.encodeToByteArray(serializer, this)

/** Encodes [v] as protobuf (see [KtProtoBuf]) and writes the bytes into this builder. */
internal fun <T : ProtoBuf> BytePacketBuilder.writeProtoBuf(serializer: SerializationStrategy<T>, v: T) {
    this.writeFully(v.toByteArray(serializer))
}

/**
 * Encodes [v] as protobuf and stores it as the `bodybuffer` of an [OidbSso.OIDBSSOPkg]. The
 * package carries the given [command], [serviceType] and [clientVersion] and is written as
 * protobuf into this builder.
 */
internal fun <T : ProtoBuf> BytePacketBuilder.writeOidb(
    command: Int = 0,
    serviceType: Int = 0,
    serializer: SerializationStrategy<T>,
    v: T,
    clientVersion: String = "android 8.4.8",
) {
    this.writeProtoBuf(
        OidbSso.OIDBSSOPkg.serializer(),
        OidbSso.OIDBSSOPkg(
            command = command,
            serviceType = serviceType,
            clientVersion = clientVersion,
            bodybuffer = v.toByteArray(serializer),
        ),
    )
}

/**
 * Encodes this message with kotlinx.serialization's protobuf format ([KtProtoBuf]).
 */
internal fun <T : ProtoBuf> T.toByteArray(serializer: SerializationStrategy<T>): ByteArray {
    return KtProtoBuf.encodeToByteArray(serializer, this)
}

/**
 * Decodes a protobuf message ([KtProtoBuf]) from this array, starting at [offset] and running to
 * the end of the array.
 *
 * @throws IllegalArgumentException if [offset] is non-zero and not a valid index of this array.
 */
internal fun <T : ProtoBuf> ByteArray.loadAs(deserializer: DeserializationStrategy<T>, offset: Int = 0): T {
    if (offset != 0) {
        // Check both bounds so a negative offset fails here with a clear message.
        require(offset in 0..this.lastIndex) { "invalid offset: $offset" }
        return KtProtoBuf.decodeFromByteArray(deserializer, this.copyOfRange(offset, this.size))
    }
    return KtProtoBuf.decodeFromByteArray(deserializer, this)
}

/**
 * Decodes this array as an [OidbSso.OIDBSSOPkg] and deserializes its `bodybuffer` with
 * [deserializer].
 *
 * Unlike [readOidbSsoPkg], the package's `result` field is not checked.
 *
 * @param log when `true`, the decoded package is printed via `printStructure` under the label "OIDB".
 */
internal fun <T : ProtoBuf> ByteArray.loadOidb(deserializer: DeserializationStrategy<T>, log: Boolean = false): T {
    val oidb = loadAs(OidbSso.OIDBSSOPkg.serializer())
    if (log) {
        oidb.printStructure("OIDB")
    }
    return oidb.bodybuffer.loadAs(deserializer)
}

/**
 * Reads [length] bytes from this packet (default: everything remaining) and decodes them as a
 * protobuf message with [KtProtoBuf].
 */
internal fun <T : ProtoBuf> ByteReadPacket.readProtoBuf(
    serializer: DeserializationStrategy<T>,
    length: Int = this.remaining.toInt(),
): T = KtProtoBuf.decodeFromByteArray(serializer, this.readBytes(length))

/**
 * Either a successfully decoded OIDB body of type [T] or, on failure, the raw
 * [OidbSso.OIDBSSOPkg]. Both cases are stored in a single [Any] field; a [Failure] wrapper marks
 * the failure case.
 */
@Suppress("NON_PUBLIC_PRIMARY_CONSTRUCTOR_OF_INLINE_CLASS")
@JvmInline
internal value class OidbBodyOrFailure<T : ProtoBuf> private constructor(
    private val v: Any,
) {
    /** Marker wrapper holding the package of a failed response. */
    internal class Failure(
        val oidb: OidbSso.OIDBSSOPkg,
    )

    /**
     * Invokes [onSuccess] with the decoded body, or [onFailure] with the failed package, and
     * returns its result. Exactly one of the two lambdas runs.
     */
    inline fun <R> fold(
        onSuccess: T.(T) -> R,
        onFailure: OidbSso.OIDBSSOPkg.(OidbSso.OIDBSSOPkg) -> R,
    ): R {
        contract {
            callsInPlace(onSuccess, InvocationKind.AT_MOST_ONCE)
            callsInPlace(onFailure, InvocationKind.AT_MOST_ONCE)
        }
        @Suppress("UNCHECKED_CAST")
        return if (v is Failure) {
            onFailure(v.oidb, v.oidb)
        } else {
            val t = v as T
            onSuccess(t, t)
        }
    }

    companion object {
        fun <T : ProtoBuf> success(t: T): OidbBodyOrFailure<T> = OidbBodyOrFailure(t)
        fun <T : ProtoBuf> failure(oidb: OidbSso.OIDBSSOPkg): OidbBodyOrFailure<T> = OidbBodyOrFailure(Failure(oidb))
    }
}

/**
 * Reads [length] bytes from this packet (default: everything remaining) as an
 * [OidbSso.OIDBSSOPkg].
 *
 * If the package's `result` is 0, its `bodybuffer` is decoded with [serializer] and returned as
 * success. Otherwise the whole package is returned as failure, and the body is not decoded.
 */
internal inline fun <T : ProtoBuf> ByteReadPacket.readOidbSsoPkg(
    serializer: DeserializationStrategy<T>,
    length: Int = this.remaining.toInt(),
): OidbBodyOrFailure<T> {
    val oidb = readBytes(length).loadAs(OidbSso.OIDBSSOPkg.serializer())
    return if (oidb.result == 0) {
        OidbBodyOrFailure.success(oidb.bodybuffer.loadAs(serializer))
    } else {
        OidbBodyOrFailure.failure(oidb)
    }
}

/**
 * Builds [RequestPacket.sBuffer] holding a single entry: [jceStruct] serialized as JCE, wrapped in
 * the tag-0 struct head/tail bytes and stored under [name].
 *
 * The map is encoded as [RequestDataVersion3].
 */
internal fun <T : JceStruct> jceRequestSBuffer(
    name: String,
    serializer: SerializationStrategy<T>,
    jceStruct: T,
): ByteArray {
    return RequestDataVersion3(
        mapOf(
            name to JCE_STRUCT_HEAD_OF_TAG_0 + jceStruct.toByteArray(serializer) + JCE_STRUCT_TAIL_OF_TAG_0,
        ),
    ).toByteArray(RequestDataVersion3.serializer())
}

/**
 * Builds [RequestPacket.sBuffer] with the [JceRequestSBufferBuilder] DSL. Every
 * `"name"(serializer, struct)` call inside [block] adds one entry, and the collected entries are
 * encoded as [RequestDataVersion3].
 */
internal inline fun jceRequestSBuffer(block: JceRequestSBufferBuilder.() -> Unit): ByteArray {
    contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
    val builder = JceRequestSBufferBuilder()
    builder.block()
    return builder.complete()
}

/**
 * Collects named JCE entries for a [RequestDataVersion3] `sBuffer`. Entries keep insertion order
 * because the backing map is a [LinkedHashMap].
 */
internal class JceRequestSBufferBuilder {
    val map: MutableMap<String, ByteArray> = LinkedHashMap()

    /**
     * Stores [jceStruct], serialized as JCE and wrapped in the tag-0 struct head/tail bytes, under
     * this string as key. Any previous entry with the same key is replaced.
     */
    operator fun <T : JceStruct> String.invoke(
        serializer: SerializationStrategy<T>,
        jceStruct: T,
    ) {
        map[this] = JCE_STRUCT_HEAD_OF_TAG_0 + jceStruct.toByteArray(serializer) + JCE_STRUCT_TAIL_OF_TAG_0
    }

    /** Encodes the collected entries as a [RequestDataVersion3] struct. */
    fun complete(): ByteArray = RequestDataVersion3(map).toByteArray(RequestDataVersion3.serializer())
}

// Single bytes wrapped around each serialized struct stored in a request map entry. The UniPacket
// readers strip one byte from each end (`discardExact(1)` / `remaining - 1`).
// NOTE(review): presumably the JCE head bytes for struct-begin / struct-end at tag 0 — confirm.
private val JCE_STRUCT_HEAD_OF_TAG_0 = byteArrayOf(0x0A)
private val JCE_STRUCT_TAIL_OF_TAG_0 = byteArrayOf(0x0B)

/**
 * Returns the annotation of type [A] attached to the element at [elementIndex] of this
 * descriptor, or `null` if there is none.
 *
 * @throws IllegalStateException if more than one annotation of type [A] is attached.
 */
internal inline fun <reified A : Annotation> SerialDescriptor.findAnnotation(elementIndex: Int): A? {
    val candidates = getElementAnnotations(elementIndex).filterIsInstance<A>()
    return when (candidates.size) {
        0 -> null
        1 -> candidates[0]
        else -> throw IllegalStateException("There are duplicate annotations of type ${A::class} in the descriptor $this")
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy