commonMain.utils.io.serialization.utils.kt Maven / Gradle / Ivy
/*
* Copyright 2019-2022 Mamoe Technologies and contributors.
*
* 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
* Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
*
* https://github.com/mamoe/mirai/blob/dev/LICENSE
*/
@file:JvmName("SerializationUtils")
@file:JvmMultifileClass
@file:Suppress("NOTHING_TO_INLINE")
package net.mamoe.mirai.internal.utils.io.serialization
import kotlinx.io.core.*
import kotlinx.io.streams.asInput
import kotlinx.serialization.DeserializationStrategy
import kotlinx.serialization.SerializationStrategy
import kotlinx.serialization.descriptors.SerialDescriptor
import net.mamoe.mirai.internal.message.contextualBugReportException
import net.mamoe.mirai.internal.network.protocol.data.jce.RequestDataVersion2
import net.mamoe.mirai.internal.network.protocol.data.jce.RequestDataVersion3
import net.mamoe.mirai.internal.network.protocol.data.jce.RequestPacket
import net.mamoe.mirai.internal.network.protocol.data.proto.OidbSso
import net.mamoe.mirai.internal.utils.io.JceStruct
import net.mamoe.mirai.internal.utils.io.ProtoBuf
import net.mamoe.mirai.internal.utils.io.serialization.tars.Tars
import net.mamoe.mirai.internal.utils.io.serialization.tars.internal.DebugLogger
import net.mamoe.mirai.internal.utils.io.serialization.tars.internal.TarsDecoder
import net.mamoe.mirai.internal.utils.printStructure
import net.mamoe.mirai.utils.*
import java.io.ByteArrayOutputStream
import java.io.PrintStream
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
/**
 * Alias for the kotlinx.serialization ProtoBuf format, so it can be referenced without clashing
 * with the project's own `ProtoBuf` type that this file also imports.
 */
internal typealias KtProtoBuf = kotlinx.serialization.protobuf.ProtoBuf
/**
 * Decodes this byte array as a [RequestPacket] (`UniRequest`) and deserializes the Tars/JCE entry
 * selected by [name] with [deserializer].
 *
 * @param name key of the entry in the request's data map; when `null`, the map must hold exactly one entry.
 * @see readUniPacket
 */
internal fun <T : JceStruct> ByteArray.loadWithUniPacket(
    deserializer: DeserializationStrategy<T>,
    name: String? = null,
): T = this.read { readUniPacket(deserializer, name) }
/**
 * Deserializes a Tars/JCE structure from the window `this[offset until offset + length]`.
 *
 * Some payloads start with a 4-byte length prefix (read with `toInt`). The prefix is recognised
 * when its value equals either the whole window length (prefix counts itself) or the window
 * length minus 4 (prefix excludes itself); in both cases it is skipped before decoding.
 *
 * @param offset index of the first byte of the window.
 * @param length number of bytes in the window; defaults to everything after [offset].
 */
internal fun <T : JceStruct> ByteArray.loadAs(
    deserializer: DeserializationStrategy<T>,
    offset: Int = 0,
    length: Int = size - offset,
): T {
    // Check the window, not the whole array: callers (e.g. readJceStruct via useBytes' `arrayLength`)
    // may hand in an array larger than the data, and probing a "prefix" in a window shorter than
    // 4 bytes would read bytes that do not belong to the payload (or run past the array end).
    if (length >= 4) {
        val possibleLength = this.toInt(offset = offset)
        if (possibleLength == length || possibleLength == length - 4) {
            return doLoadAs(
                deserializer,
                offset = offset + 4,
                length = length - 4,
            )
        }
    }
    return doLoadAs(deserializer, offset, length)
}
/**
 * Decodes a Tars/JCE structure from `this[offset until offset + length]` with [Tars.UTF_8].
 *
 * If that first attempt throws, the same bytes are decoded again while a [DebugLogger] writes a
 * hex dump and decoder trace into an in-memory log:
 * - if the traced retry succeeds, its value is returned and a bug-report warning (containing the
 *   value, the log and the original exception) is sent to [TarsDecoder.logger];
 * - if the retry (or the warning) fails as well, a bug-report exception combining both failures
 *   is thrown.
 */
private fun <T : JceStruct> ByteArray.doLoadAs(
    deserializer: DeserializationStrategy<T>,
    offset: Int,
    length: Int,
): T {
    try {
        return this.inputStream(offset = offset, length = length).asInput().use { input ->
            Tars.UTF_8.load(deserializer, input)
        }
    } catch (originalException: Exception) {
        // Retry with tracing enabled so the report can include the raw bytes and the decoder trace.
        val log = ByteArrayOutputStream()
        try {
            val value = PrintStream(log).use { stream ->
                stream.println("\nData: ")
                stream.println(this.toUHexString(offset = offset, length = length))
                stream.println("Trace:")
                this.inputStream(offset = offset, length = length).asInput().use { input ->
                    Tars.UTF_8.load(deserializer, input, debugLogger = DebugLogger(stream))
                }
            }
            // The traced retry succeeded although the plain decode failed: report it, still return the value.
            // Message text: "解析" = "decoding"; "启用 debug 模式后解析正常" = "decoded fine once debug mode was enabled".
            return value.also {
                TarsDecoder.logger.warning(
                    contextualBugReportException(
                        "解析 " + deserializer.descriptor.serialName,
                        "启用 debug 模式后解析正常: $value \n\n${log.toByteArray().decodeToString()}",
                        originalException,
                    ),
                )
            }
        } catch (secondFailure: Exception) {
            // Both attempts failed: throw a bug report carrying the dump/trace and both exceptions.
            throw contextualBugReportException(
                "解析 " + deserializer.descriptor.serialName,
                log.toByteArray().decodeToString(),
                ExceptionCollector.compressExceptions(originalException, secondFailure),
            )
        }
    }
}
/**
 * Serializes [struct] with [serializer] using the [Tars.UTF_8] configuration and appends the
 * encoded bytes to this builder.
 */
internal fun <T : JceStruct> BytePacketBuilder.writeJceStruct(
    serializer: SerializationStrategy<T>,
    struct: T,
) {
    Tars.UTF_8.dumpTo(serializer, struct, this)
}
/**
 * Reads [length] bytes from this packet and decodes them as a Tars/JCE structure.
 *
 * The bytes go through [ByteArray.loadAs], so a leading 4-byte value that looks like a length
 * prefix is skipped.
 *
 * @param length number of bytes to consume; defaults to everything remaining.
 */
internal fun <T : JceStruct> ByteReadPacket.readJceStruct(
    deserializer: DeserializationStrategy<T>,
    length: Int = this.remaining.toInt(),
): T = this.useBytes(n = length) { data, arrayLength ->
    // `data` may be larger than the bytes read; only the first `arrayLength` bytes are payload.
    data.loadAs(deserializer, offset = 0, length = arrayLength)
}
/**
 * Writes a [RequestPacket] (with `requestId = 0`) whose `sBuffer` holds [body], encoded with
 * [serializer], under the key [name].
 *
 * NOTE(review): `sBuffer` is always built by `jceRequestSBuffer(name, serializer, body)`, i.e. as a
 * [RequestDataVersion3] map, even if [version] is not 3 — confirm that callers only rely on version 3.
 *
 * @param name key of [body] in the `sBuffer` map; defaults to [funcName].
 */
internal fun <T : JceStruct> BytePacketBuilder.writeJceRequestPacket(
    version: Int = 3,
    servantName: String,
    funcName: String,
    name: String = funcName,
    serializer: SerializationStrategy<T>,
    body: T,
) = writeJceStruct(
    RequestPacket.serializer(),
    RequestPacket(
        requestId = 0,
        version = version.toShort(),
        servantName = servantName,
        funcName = funcName,
        sBuffer = jceRequestSBuffer(name, serializer, body),
    ),
)
/**
 * Decodes this packet as a [RequestPacket] (`UniRequest`), decodes its `sBuffer` map according to
 * the packet version, selects the entry named [name] and deserializes it as a Tars/JCE structure.
 *
 * Each map entry is wrapped in a one-byte struct head and a one-byte struct tail, which are
 * stripped before decoding.
 *
 * @param name key of the entry to decode; when `null`, the map must hold exactly one entry.
 */
internal fun <T : JceStruct> ByteReadPacket.readUniPacket(
    deserializer: DeserializationStrategy<T>,
    name: String? = null,
): T = decodeUniRequestPacketAndDeserialize(name) { entry ->
    entry.read {
        discardExact(1) // skip the one-byte struct head
        this.readJceStruct(deserializer, length = (this.remaining - 1).toInt()) // leave the struct tail unread
    }
}
/**
 * Decodes this packet as a [RequestPacket] (`UniRequest`), decodes its `sBuffer` map according to
 * the packet version, selects the entry named [name] and deserializes it as a protobuf message.
 *
 * Each map entry is wrapped in a one-byte struct head and a one-byte struct tail, which are
 * stripped before decoding.
 *
 * @param name key of the entry to decode; when `null`, the map must hold exactly one entry.
 */
internal fun <T : ProtoBuf> ByteReadPacket.readUniPacket(
    deserializer: DeserializationStrategy<T>,
    name: String? = null,
): T = decodeUniRequestPacketAndDeserialize(name) { entry ->
    entry.read {
        discardExact(1) // skip the one-byte struct head
        this.readProtoBuf(deserializer, (this.remaining - 1).toInt()) // leave the struct tail unread
    }
}
/**
 * Returns the value of the only entry in this map.
 *
 * @throws NoSuchElementException if the map is empty.
 * @throws IllegalArgumentException if the map has more than one entry.
 */
private fun <K, V> Map<K, V>.singleValue(): V = this.entries.single().value
/**
 * Decodes this packet as a [RequestPacket] (`UniRequest`), extracts one raw entry from its
 * `sBuffer` map and returns the result of passing that entry's bytes to [block].
 *
 * The `sBuffer` layout depends on [RequestPacket.version] (`null` is treated as `3`):
 * - version 2: [RequestDataVersion2], a map of maps; the selected outer entry must hold exactly one inner entry;
 * - version 3: [RequestDataVersion3], a flat map.
 *
 * @param name key of the (outer) entry to extract; when `null`, the map must hold exactly one entry.
 * @throws IllegalStateException if [name] is absent from the map or the version is neither 2 nor 3.
 */
internal fun <R> ByteReadPacket.decodeUniRequestPacketAndDeserialize(name: String? = null, block: (ByteArray) -> R): R {
    val request = this.readJceStruct(RequestPacket.serializer())

    // Select the entry by name, or require exactly one entry when no name is given.
    fun <V> Map<String, V>.pick(): V =
        if (name == null) singleValue() else getOrElse(name) { error("cannot find $name") }

    val payload: ByteArray = when (request.version?.toInt() ?: 3) {
        2 -> request.sBuffer.loadAs(RequestDataVersion2.serializer()).map.pick().singleValue()
        3 -> request.sBuffer.loadAs(RequestDataVersion3.serializer()).map.pick()
        else -> error("unsupported version ${request.version}")
    }
    return block(payload)
}
/**
 * Encodes this Tars/JCE structure with [serializer] using the [Tars.UTF_8] configuration.
 */
internal fun <T : JceStruct> T.toByteArray(
    serializer: SerializationStrategy<T>,
): ByteArray = Tars.UTF_8.encodeToByteArray(serializer, this)
/**
 * Encodes [v] as protobuf with [serializer] and appends the bytes to this builder.
 */
internal fun <T : ProtoBuf> BytePacketBuilder.writeProtoBuf(serializer: SerializationStrategy<T>, v: T) {
    this.writeFully(v.toByteArray(serializer))
}
/**
 * Wraps [v] (encoded as protobuf with [serializer]) in an [OidbSso.OIDBSSOPkg] and writes that
 * package, protobuf-encoded, to this builder.
 *
 * @param command value for the package's `command` field.
 * @param serviceType value for the package's `serviceType` field.
 * @param clientVersion value for the package's `clientVersion` field; defaults to `"android 8.4.8"`.
 */
internal fun <T : ProtoBuf> BytePacketBuilder.writeOidb(
    command: Int = 0,
    serviceType: Int = 0,
    serializer: SerializationStrategy<T>,
    v: T,
    clientVersion: String = "android 8.4.8",
) {
    writeProtoBuf(
        OidbSso.OIDBSSOPkg.serializer(),
        OidbSso.OIDBSSOPkg(
            command = command,
            serviceType = serviceType,
            clientVersion = clientVersion,
            bodybuffer = v.toByteArray(serializer),
        ),
    )
}
/**
 * Encodes this message as protobuf with [serializer] using [KtProtoBuf].
 */
internal fun <T : ProtoBuf> T.toByteArray(serializer: SerializationStrategy<T>): ByteArray {
    return KtProtoBuf.encodeToByteArray(serializer, this)
}
/**
 * Decodes the protobuf message stored in this array from [offset] to the end, using [deserializer].
 *
 * @param offset index of the first byte of the message; must be `0` or a valid index of this array.
 * @throws IllegalArgumentException if [offset] is non-zero and outside `0..lastIndex`.
 */
internal fun <T : ProtoBuf> ByteArray.loadAs(deserializer: DeserializationStrategy<T>, offset: Int = 0): T {
    if (offset == 0) {
        return KtProtoBuf.decodeFromByteArray(deserializer, this)
    }
    // Lower bound is 0 so negative offsets are rejected here instead of failing inside copyOfRange.
    require(offset in 0..this.lastIndex) { "invalid offset: $offset" }
    return KtProtoBuf.decodeFromByteArray(deserializer, this.copyOfRange(offset, this.size))
}
/**
 * Decodes this array as an [OidbSso.OIDBSSOPkg] and then decodes its `bodybuffer` with [deserializer].
 *
 * The package's `result` field is not inspected here; [readOidbSsoPkg] is the variant that reports
 * non-zero results as failures.
 *
 * @param log when `true`, prints the decoded package via `printStructure("OIDB")`.
 */
internal fun <T : ProtoBuf> ByteArray.loadOidb(deserializer: DeserializationStrategy<T>, log: Boolean = false): T {
    val oidb = loadAs(OidbSso.OIDBSSOPkg.serializer())
    if (log) {
        oidb.printStructure("OIDB")
    }
    return oidb.bodybuffer.loadAs(deserializer)
}
/**
 * Reads [length] bytes from this packet and decodes them as a protobuf message with [serializer].
 *
 * @param length number of bytes to consume; defaults to everything remaining.
 */
internal fun <T : ProtoBuf> ByteReadPacket.readProtoBuf(
    serializer: DeserializationStrategy<T>,
    length: Int = this.remaining.toInt(),
): T = KtProtoBuf.decodeFromByteArray(serializer, this.readBytes(length))
/**
 * Either a successfully decoded OIDB body of type [T], or the [OidbSso.OIDBSSOPkg] of a failed response.
 *
 * Both cases share the single [v] slot; failures are wrapped in [Failure] so they can be told apart
 * from a body. Instances are created only through [success] and [failure].
 */
@Suppress("NON_PUBLIC_PRIMARY_CONSTRUCTOR_OF_INLINE_CLASS")
@JvmInline
internal value class OidbBodyOrFailure<T : ProtoBuf> private constructor(
    private val v: Any,
) {
    /** Wrapper marking the failure case; holds the OIDB package of the failed response. */
    internal class Failure(
        val oidb: OidbSso.OIDBSSOPkg,
    )

    /**
     * Invokes [onSuccess] with the decoded body, or [onFailure] with the failed OIDB package, and
     * returns that lambda's result. Each lambda receives its argument both as receiver and as parameter.
     */
    inline fun <R> fold(
        onSuccess: T.(T) -> R,
        onFailure: OidbSso.OIDBSSOPkg.(OidbSso.OIDBSSOPkg) -> R,
    ): R {
        contract {
            callsInPlace(onSuccess, InvocationKind.AT_MOST_ONCE)
            callsInPlace(onFailure, InvocationKind.AT_MOST_ONCE)
        }
        @Suppress("UNCHECKED_CAST")
        return if (v is Failure) {
            onFailure(v.oidb, v.oidb)
        } else {
            // Anything that is not a Failure was stored by `success`, hence is a T.
            val t = v as T
            onSuccess(t, t)
        }
    }

    companion object {
        /** Wraps a successfully decoded body. */
        fun <T : ProtoBuf> success(t: T): OidbBodyOrFailure<T> = OidbBodyOrFailure(t)

        /** Wraps the OIDB package of a failed response. */
        fun <T : ProtoBuf> failure(oidb: OidbSso.OIDBSSOPkg): OidbBodyOrFailure<T> = OidbBodyOrFailure(Failure(oidb))
    }
}
/**
 * Reads [length] bytes from this packet as an [OidbSso.OIDBSSOPkg].
 *
 * If the package's `result` is `0`, its `bodybuffer` is decoded with [serializer] and returned as a
 * success; otherwise the whole package is returned as a failure and the body is not decoded.
 *
 * @param length number of bytes to consume; defaults to everything remaining.
 */
internal inline fun <T : ProtoBuf> ByteReadPacket.readOidbSsoPkg(
    serializer: DeserializationStrategy<T>,
    length: Int = this.remaining.toInt(),
): OidbBodyOrFailure<T> {
    val oidb = readBytes(length).loadAs(OidbSso.OIDBSSOPkg.serializer())
    return if (oidb.result == 0) {
        OidbBodyOrFailure.success(oidb.bodybuffer.loadAs(serializer))
    } else {
        OidbBodyOrFailure.failure(oidb)
    }
}
/**
 * Builds the [RequestPacket.sBuffer] for a single entry: a [RequestDataVersion3] map holding
 * [jceStruct] (encoded with [serializer] and wrapped in a one-byte struct head and tail) under [name].
 */
internal fun <T : JceStruct> jceRequestSBuffer(
    name: String,
    serializer: SerializationStrategy<T>,
    jceStruct: T,
): ByteArray {
    return RequestDataVersion3(
        mapOf(
            name to JCE_STRUCT_HEAD_OF_TAG_0 + jceStruct.toByteArray(serializer) + JCE_STRUCT_TAIL_OF_TAG_0,
        ),
    ).toByteArray(RequestDataVersion3.serializer())
}
/**
 * Builds a [RequestPacket.sBuffer] that may hold several named entries.
 *
 * Inside [block], call `"name"(serializer, struct)` once per entry; the collected entries are
 * encoded as a [RequestDataVersion3] map. [block] runs exactly once.
 */
internal inline fun jceRequestSBuffer(block: JceRequestSBufferBuilder.() -> Unit): ByteArray {
    contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
    val builder = JceRequestSBufferBuilder()
    builder.block()
    return builder.complete()
}
/**
 * Receiver of the builder lambda of `jceRequestSBuffer { ... }`; collects named Tars/JCE entries
 * and encodes them as a [RequestDataVersion3] map.
 */
internal class JceRequestSBufferBuilder {
    /** Entry name to struct bytes wrapped in a one-byte struct head and tail; keeps insertion order. */
    val map: MutableMap<String, ByteArray> = LinkedHashMap()

    /** Adds (or replaces) the entry named by this string with [jceStruct] encoded by [serializer]. */
    operator fun <T : JceStruct> String.invoke(
        serializer: SerializationStrategy<T>,
        jceStruct: T,
    ) {
        map[this] = JCE_STRUCT_HEAD_OF_TAG_0 + jceStruct.toByteArray(serializer) + JCE_STRUCT_TAIL_OF_TAG_0
    }

    /** Encodes the collected entries as a [RequestDataVersion3]. */
    fun complete(): ByteArray = RequestDataVersion3(map).toByteArray(RequestDataVersion3.serializer())
}
/**
 * One-byte head written before each struct stored in a request `sBuffer` map.
 * Presumably the Tars head byte `(tag shl 4) or type` for STRUCT_BEGIN (type 10) at tag 0 —
 * NOTE(review): confirm against the Tars encoding spec.
 */
private val JCE_STRUCT_HEAD_OF_TAG_0: ByteArray = byteArrayOf(0x0A)

/**
 * One-byte tail written after each struct stored in a request `sBuffer` map; presumably STRUCT_END
 * (type 11) at tag 0 — NOTE(review): confirm against the Tars encoding spec.
 */
private val JCE_STRUCT_TAIL_OF_TAG_0: ByteArray = byteArrayOf(0x0B)
/**
 * Returns the annotation of type [A] attached to the element at [elementIndex] of this descriptor,
 * or `null` if there is none.
 *
 * @throws IllegalStateException if the element carries more than one annotation of type [A].
 */
internal inline fun <reified A : Annotation> SerialDescriptor.findAnnotation(elementIndex: Int): A? {
    val candidates = getElementAnnotations(elementIndex).filterIsInstance<A>()
    return when (candidates.size) {
        0 -> null
        1 -> candidates[0]
        else -> throw IllegalStateException("There are duplicate annotations of type ${A::class} in the descriptor $this")
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy