All downloads are free. Search and download functionality uses the official Maven repository.

commonMain.com.microsoft.thrifty.protocol.CompactProtocol.kt Maven / Gradle / Ivy

/*
 * Thrifty
 *
 * Copyright (c) Microsoft Corporation
 *
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the License);
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * THIS CODE IS PROVIDED ON AN  *AS IS* BASIS, WITHOUT WARRANTIES OR
 * CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
 * WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF TITLE,
 * FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
 *
 * See the Apache Version 2.0 License for specific language governing permissions and limitations under the License.
 */
/*
 * This file is derived from the file TCompactProtocol.java, in the Apache
 * Thrift implementation.  The original license header is reproduced
 * below:
 */
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package com.microsoft.thrifty.protocol

import com.microsoft.thrifty.TType
import com.microsoft.thrifty.internal.ProtocolException
import com.microsoft.thrifty.transport.Transport
import okio.ByteString
import okio.ByteString.Companion.toByteString
import okio.EOFException
import okio.IOException

/**
 * An implementation of the Thrift compact binary protocol.
 *
 * Instances of this class are *not* threadsafe.
 */
class CompactProtocol(transport: Transport) : BaseProtocol(transport) {

    // Boolean fields get special treatment - their value is encoded
    // directly in the field header.  As such, when a boolean field
    // header is written, we cache its field ID here until we get the
    // value from the subsequent `writeBool` call.  A value of -1 means
    // that no boolean field header is pending.
    private var booleanFieldId = -1

    // Similarly, when reading, we cache the compact type nibble taken
    // from a boolean field header (it carries the value) until the
    // `readBool` call.  A value of -1 means that nothing is cached.
    private var booleanFieldType: Byte = -1

    // Scratch space shared by the single-byte, double, and varint
    // read/write paths.
    private val buffer = ByteArray(16)

    // Keep track of the most-recently-written and most-recently-read
    // field IDs, used for delta-encoding.  The stacks hold the values
    // belonging to enclosing structs while a nested struct is handled.
    private val writingFields = ShortStack()
    private var lastWritingField: Short = 0
    private val readingFields = ShortStack()
    private var lastReadingField: Short = 0

    /**
     * Writes a message header: the protocol ID byte, one byte combining
     * the protocol version with the message type, the sequence ID as a
     * varint, and finally the message name.
     */
    @Throws(IOException::class)
    override fun writeMessageBegin(name: String, typeId: Byte, seqId: Int) {
        val versionBits = VERSION.toInt() and VERSION_MASK.toInt()
        val typeBits = (typeId.toInt() shl TYPE_SHIFT_AMOUNT) and TYPE_MASK.toInt()

        writeByte(PROTOCOL_ID)
        writeByte((versionBits or typeBits).toByte())
        writeVarint32(seqId)
        writeString(name)
    }

    /** The end of a message has no wire representation; nothing is written. */
    @Throws(IOException::class)
    override fun writeMessageEnd() = Unit

    /**
     * Starts a struct.  Nothing is written; the enclosing struct's last
     * field ID is saved and the delta-encoding baseline is reset to zero.
     */
    @Throws(IOException::class)
    override fun writeStructBegin(structName: String) {
        val enclosingLastField = lastWritingField
        writingFields.push(enclosingLastField)
        lastWritingField = 0
    }

    /**
     * Ends a struct.  Nothing is written; the enclosing struct's last
     * field ID is restored as the delta-encoding baseline.
     */
    @Throws(IOException::class)
    override fun writeStructEnd() {
        val enclosingLastField = writingFields.pop()
        lastWritingField = enclosingLastField
    }

    /**
     * Begins a field.  Non-boolean fields have their header written
     * immediately.  For boolean fields the header is deferred: the field
     * ID is remembered and the header is emitted by the following
     * [writeBool] call, which encodes the value in the header itself.
     *
     * @throws ProtocolException if a boolean field header is already pending.
     */
    @Throws(IOException::class)
    override fun writeFieldBegin(fieldName: String, fieldId: Int, typeId: Byte) {
        if (typeId != TType.BOOL) {
            writeFieldBegin(fieldId, CompactTypes.ttypeToCompact(typeId))
            return
        }

        if (booleanFieldId != -1) {
            throw ProtocolException("Nested invocation of writeFieldBegin")
        }
        booleanFieldId = fieldId
    }

    /**
     * Writes a field header for the given compact type.  When the field
     * ID exceeds the previously written ID by 1 to 15, the difference is
     * packed into the high nibble of a single byte alongside the type;
     * otherwise the type byte is followed by the full field ID.
     */
    @Throws(IOException::class)
    private fun writeFieldBegin(fieldId: Int, compactTypeId: Byte) {
        val delta = fieldId - lastWritingField
        val canDeltaEncode = fieldId > lastWritingField && delta <= 15

        if (canDeltaEncode) {
            writeByte(((delta shl 4) or compactTypeId.toInt()).toByte())
        } else {
            writeByte(compactTypeId)
            writeI16(fieldId.toShort())
        }
        lastWritingField = fieldId.toShort()
    }

    /** The end of a field has no wire representation; nothing is written. */
    @Throws(IOException::class)
    override fun writeFieldEnd() = Unit

    /** Writes the single [TType.STOP] byte that terminates a struct's fields. */
    @Throws(IOException::class)
    override fun writeFieldStop() = writeByte(TType.STOP)

    /**
     * Writes a map header.  An empty map is a single zero byte; otherwise
     * the size is written as a varint, followed by one byte holding the
     * compact key type in the high nibble and the compact value type in
     * the low nibble.
     */
    @Throws(IOException::class)
    override fun writeMapBegin(keyTypeId: Byte, valueTypeId: Byte, mapSize: Int) {
        if (mapSize == 0) {
            writeByte(0.toByte())
            return
        }

        // Resolve both types before writing anything, so that an unknown
        // type fails before any bytes reach the transport.
        val keyNibble = CompactTypes.ttypeToCompact(keyTypeId).toInt() shl 4
        val valueNibble = CompactTypes.ttypeToCompact(valueTypeId).toInt()
        writeVarint32(mapSize)
        writeByte((keyNibble or valueNibble).toByte())
    }

    /** The end of a map has no wire representation; nothing is written. */
    @Throws(IOException::class)
    override fun writeMapEnd() = Unit

    /** Writes a list header; lists and sets share one header encoding. */
    @Throws(IOException::class)
    override fun writeListBegin(elementTypeId: Byte, listSize: Int) =
            writeVectorBegin(elementTypeId, listSize)

    /** The end of a list has no wire representation; nothing is written. */
    @Throws(IOException::class)
    override fun writeListEnd() = Unit

    /** Writes a set header; lists and sets share one header encoding. */
    @Throws(IOException::class)
    override fun writeSetBegin(elementTypeId: Byte, setSize: Int) =
            writeVectorBegin(elementTypeId, setSize)

    /** The end of a set has no wire representation; nothing is written. */
    @Throws(IOException::class)
    override fun writeSetEnd() = Unit

    /**
     * Writes a boolean.  If a boolean field header is pending (see
     * [writeFieldBegin]), the header is written now with the value stored
     * in its type nibble; otherwise the value is written as one byte.
     */
    @Throws(IOException::class)
    override fun writeBool(b: Boolean) {
        val compactValue = when (b) {
            true -> CompactTypes.BOOLEAN_TRUE
            false -> CompactTypes.BOOLEAN_FALSE
        }

        val pendingFieldId = booleanFieldId
        if (pendingFieldId == -1) {
            // Not inside a field (e.g. a collection element): plain byte.
            writeByte(compactValue)
            return
        }

        // Emit the deferred field header, carrying the value itself.
        writeFieldBegin(pendingFieldId, compactValue)
        booleanFieldId = -1
    }

    /**
     * Writes a single byte to the transport, staging it in the shared
     * scratch buffer.
     */
    @Throws(IOException::class)
    override fun writeByte(b: Byte) {
        buffer[0] = b
        transport.write(buffer, 0, 1)
    }

    /** Writes a 16-bit integer as a zigzag-encoded varint. */
    @Throws(IOException::class)
    override fun writeI16(i16: Short) = writeVarint32(intToZigZag(i16.toInt()))

    /** Writes a 32-bit integer as a zigzag-encoded varint. */
    @Throws(IOException::class)
    override fun writeI32(i32: Int) = writeVarint32(intToZigZag(i32))

    /** Writes a 64-bit integer as a zigzag-encoded varint. */
    @Throws(IOException::class)
    override fun writeI64(i64: Long) = writeVarint64(longToZigZag(i64))

    /**
     * Writes a double as the eight bytes of its raw bit pattern,
     * least-significant byte first.
     */
    @Throws(IOException::class)
    override fun writeDouble(dub: Double) {
        val bits = dub.toRawBits()

        // Little-endian: byte i holds bits [8i, 8i + 8).
        for (i in 0 until 8) {
            buffer[i] = ((bits ushr (8 * i)) and 0xFFL).toByte()
        }
        transport.write(buffer, 0, 8)
    }

    /**
     * Writes a string as a varint byte count followed by the bytes
     * produced by [String.encodeToByteArray].
     */
    @Throws(IOException::class)
    override fun writeString(str: String) {
        val encoded = str.encodeToByteArray()
        writeVarint32(encoded.size)
        transport.write(encoded)
    }

    /** Writes binary data as a varint byte count followed by the bytes. */
    @Throws(IOException::class)
    override fun writeBinary(buf: ByteString) {
        val payloadSize = buf.size
        writeVarint32(payloadSize)
        transport.write(buf.toByteArray())
    }

    /**
     * Writes the header shared by lists and sets.  Sizes up to 14 are
     * packed into the high nibble of the type byte; larger sizes put 0xF
     * in the high nibble and follow the byte with the size as a varint.
     */
    @Throws(IOException::class)
    private fun writeVectorBegin(typeId: Byte, size: Int) {
        val compactId = CompactTypes.ttypeToCompact(typeId).toInt()
        val fitsInNibble = size <= 14
        val sizeNibble = if (fitsInNibble) size else 0x0F

        writeByte(((sizeNibble shl 4) or compactId).toByte())
        if (!fitsInNibble) {
            writeVarint32(size)
        }
    }

    /**
     * Writes a 32-bit value as a base-128 varint: seven bits per byte,
     * low-order group first, with the high bit of each byte set while
     * more groups follow.  The bytes are staged in the scratch buffer and
     * written in a single transport call.
     */
    @Throws(IOException::class)
    private fun writeVarint32(num: Int) {
        var n = num
        var index = 0
        while (index < buffer.size) {
            if ((n and 0x7F.inv()) == 0) {
                // Last group: no continuation bit.
                buffer[index] = n.toByte()
                transport.write(buffer, 0, index + 1)
                return
            }
            buffer[index] = ((n and 0x7F) or 0x80).toByte()
            n = n ushr 7
            index++
        }
        throw IllegalArgumentException("Cannot represent $n as a varint in 16 bytes or less")
    }

    /**
     * Writes a 64-bit value as a base-128 varint: seven bits per byte,
     * low-order group first, with the high bit of each byte set while
     * more groups follow.  The bytes are staged in the scratch buffer and
     * written in a single transport call.
     */
    @Throws(IOException::class)
    private fun writeVarint64(num: Long) {
        var n = num
        var index = 0
        while (index < buffer.size) {
            if ((n and 0x7FL.inv()) == 0L) {
                // Last group: no continuation bit.
                buffer[index] = n.toByte()
                transport.write(buffer, 0, index + 1)
                return
            }
            buffer[index] = ((n and 0x7FL) or 0x80L).toByte()
            n = n ushr 7
            index++
        }
        throw IllegalArgumentException("Cannot represent $n as a varint in 16 bytes or less")
    }

    /**
     * Reads a message header: validates the protocol ID byte and the
     * version bits, then extracts the message type, sequence ID and name.
     *
     * @throws ProtocolException if the protocol ID or version is unexpected.
     */
    @Throws(IOException::class)
    override fun readMessageBegin(): MessageMetadata {
        val protocolId = readByte()
        if (protocolId != PROTOCOL_ID) {
            val actual = protocolId.toInt().toString(radix = 16)
            throw ProtocolException("Expected protocol ID ${PROTOCOL_ID.toInt()} but got $actual")
        }

        val versionAndType = readByte().toInt()
        val version = (VERSION_MASK.toInt() and versionAndType).toByte()
        if (version != VERSION) {
            throw ProtocolException("Version mismatch; expected version $VERSION but got $version")
        }

        val typeId = ((versionAndType shr TYPE_SHIFT_AMOUNT) and TYPE_BITS.toInt()).toByte()
        val seqId = readVarint32()
        val name = readString()
        return MessageMetadata(name, typeId, seqId)
    }

    /** The end of a message has no wire representation; nothing is read. */
    @Throws(IOException::class)
    override fun readMessageEnd() = Unit

    /**
     * Starts reading a struct.  Nothing is read; the enclosing struct's
     * last field ID is saved and the delta-decoding baseline is reset.
     */
    @Throws(IOException::class)
    override fun readStructBegin(): StructMetadata {
        val enclosingLastField = lastReadingField
        readingFields.push(enclosingLastField)
        lastReadingField = 0
        return NO_STRUCT
    }

    /**
     * Finishes reading a struct.  Nothing is read; the enclosing struct's
     * last field ID is restored as the delta-decoding baseline.
     */
    @Throws(IOException::class)
    override fun readStructEnd() {
        val enclosingLastField = readingFields.pop()
        lastReadingField = enclosingLastField
    }

    /**
     * Reads a field header.  The low nibble of the header byte is the
     * compact type; the high nibble is either a delta from the previous
     * field ID or zero, in which case the full field ID follows.  For
     * boolean fields the type nibble is cached so that the next
     * [readBool] call can return the value without reading more bytes.
     *
     * @return the field metadata, or the shared end-of-fields marker when
     *   the header byte is [TType.STOP].
     */
    @Throws(IOException::class)
    override fun readFieldBegin(): FieldMetadata {
        val compactId = readByte()
        val typeNibble = (compactId.toInt() and 0x0F).toByte()
        val typeId = CompactTypes.compactToTtype(typeNibble)
        if (compactId == TType.STOP) {
            return END_FIELDS
        }

        val delta = (compactId.toInt() and 0xF0) shr 4
        val fieldId: Short = when (delta) {
            // This is not a field-ID delta - read the entire ID.
            0 -> readI16()
            else -> (lastReadingField + delta).toShort()
        }

        if (typeId == TType.BOOL) {
            // The bool value is encoded in the lower nibble of the header.
            booleanFieldType = typeNibble
        }
        lastReadingField = fieldId
        return FieldMetadata("", typeId, fieldId)
    }

    /** The end of a field has no wire representation; nothing is read. */
    @Throws(IOException::class)
    override fun readFieldEnd() = Unit

    /**
     * Reads a map header: the size as a varint and, for non-empty maps,
     * one byte holding the compact key type in the high nibble and the
     * compact value type in the low nibble.  For an empty map no type
     * byte is present and both types are decoded from zero.
     *
     * @throws ProtocolException if the decoded size is negative, which
     *   can only result from malformed input.
     */
    @Throws(IOException::class)
    override fun readMapBegin(): MapMetadata {
        val size = readVarint32()
        if (size < 0) {
            throw ProtocolException("Negative map size: $size")
        }

        val keyAndValueTypes = if (size == 0) 0 else readByte().toInt()
        val keyType = CompactTypes.compactToTtype(((keyAndValueTypes shr 4) and 0x0F).toByte())
        val valueType = CompactTypes.compactToTtype((keyAndValueTypes and 0x0F).toByte())
        return MapMetadata(keyType, valueType, size)
    }

    /** The end of a map has no wire representation; nothing is read. */
    @Throws(IOException::class)
    override fun readMapEnd() = Unit

    /** Reads a list header; lists and sets share one header encoding. */
    @Throws(IOException::class)
    override fun readListBegin(): ListMetadata = readCollectionBegin(::ListMetadata)

    /** The end of a list has no wire representation; nothing is read. */
    @Throws(IOException::class)
    override fun readListEnd() = Unit

    /** Reads a set header; lists and sets share one header encoding. */
    @Throws(IOException::class)
    override fun readSetBegin(): SetMetadata = readCollectionBegin(::SetMetadata)

    /**
     * Reads the header shared by lists and sets and hands the decoded
     * element type and size to [buildMetadata].
     *
     * The high nibble of the header byte is the size when it is below
     * 0xF; a value of 0xF means the real size follows as a varint.  The
     * low nibble is the compact element type.
     *
     * @param buildMetadata constructs the result from (element type, size).
     * @throws ProtocolException if the varint-encoded size is negative,
     *   which can only result from malformed input.
     */
    @Throws(IOException::class)
    private inline fun <T> readCollectionBegin(buildMetadata: (Byte, Int) -> T): T {
        val sizeAndType = readByte().toInt()
        var size = (sizeAndType shr 4) and 0x0F
        if (size == 0x0F) {
            size = readVarint32()
            if (size < 0) {
                throw ProtocolException("Negative collection size: $size")
            }
        }
        val compactType = (sizeAndType and 0x0F).toByte()
        val ttype = CompactTypes.compactToTtype(compactType)
        return buildMetadata(ttype, size)
    }

    /** The end of a set has no wire representation; nothing is read. */
    @Throws(IOException::class)
    override fun readSetEnd() = Unit

    /**
     * Reads a boolean.  If [readFieldBegin] cached a value from a boolean
     * field header, that value is consumed and returned without touching
     * the transport; otherwise one byte is read.
     */
    @Throws(IOException::class)
    override fun readBool(): Boolean {
        val cached = booleanFieldType
        val compactId = if (cached.toInt() != -1) {
            booleanFieldType = -1
            cached
        } else {
            readByte()
        }
        return compactId == CompactTypes.BOOLEAN_TRUE
    }

    /**
     * Reads a single byte from the transport into the shared scratch
     * buffer and returns it.
     */
    @Throws(IOException::class)
    override fun readByte(): Byte {
        readFully(buffer, 1)
        return buffer[0]
    }

    /** Reads a zigzag-encoded varint and narrows it to 16 bits. */
    @Throws(IOException::class)
    override fun readI16(): Short = zigZagToInt(readVarint32()).toShort()

    /** Reads a zigzag-encoded varint as a 32-bit integer. */
    @Throws(IOException::class)
    override fun readI32(): Int = zigZagToInt(readVarint32())

    /** Reads a zigzag-encoded varint as a 64-bit integer. */
    @Throws(IOException::class)
    override fun readI64(): Long = zigZagToLong(readVarint64())

    /**
     * Reads eight bytes and reassembles them, least-significant byte
     * first, into the bit pattern of a double.
     */
    @Throws(IOException::class)
    override fun readDouble(): Double {
        readFully(buffer, 8)

        // Fold from the most-significant byte (index 7) down to index 0.
        var bits = 0L
        for (i in 7 downTo 0) {
            bits = (bits shl 8) or (buffer[i].toLong() and 0xFFL)
        }
        return Double.fromBits(bits)
    }

    /**
     * Reads a string: a varint byte count followed by that many bytes,
     * decoded with [ByteArray.decodeToString].
     *
     * @throws ProtocolException if the decoded length is negative, which
     *   can only result from malformed input.
     */
    @Throws(IOException::class)
    override fun readString(): String {
        val length = readVarint32()
        if (length < 0) {
            // Fail with a protocol error rather than letting the array
            // allocation below blow up with an unchecked exception.
            throw ProtocolException("Negative string length: $length")
        }
        if (length == 0) {
            return ""
        }
        val bytes = ByteArray(length)
        readFully(bytes, length)
        return bytes.decodeToString()
    }

    /**
     * Reads binary data: a varint byte count followed by that many bytes.
     *
     * @throws ProtocolException if the decoded length is negative, which
     *   can only result from malformed input.
     */
    @Throws(IOException::class)
    override fun readBinary(): ByteString {
        val length = readVarint32()
        if (length < 0) {
            // Fail with a protocol error rather than letting the array
            // allocation below blow up with an unchecked exception.
            throw ProtocolException("Negative binary length: $length")
        }
        if (length == 0) {
            return ByteString.EMPTY
        }
        val bytes = ByteArray(length)
        readFully(bytes, length)
        return bytes.toByteString()
    }

    /**
     * Reads a base-128 varint of at most five bytes (the most a 32-bit
     * value can occupy) and returns the decoded bits.
     *
     * @throws ProtocolException if the continuation bit is still set on
     *   the fifth byte; without this bound the shift amount would wrap
     *   around and silently corrupt the result.
     */
    @Throws(IOException::class)
    private fun readVarint32(): Int {
        var result = 0
        var shift = 0
        // 5 bytes * 7 payload bits = 35
        while (shift < 35) {
            val b = readByte().toInt()
            result = result or ((b and 0x7F) shl shift)
            if ((b and 0x80) == 0) {
                return result
            }
            shift += 7
        }
        throw ProtocolException("Malformed varint32: continuation bit set after 5 bytes")
    }

    /**
     * Reads a base-128 varint of at most ten bytes (the most a 64-bit
     * value can occupy) and returns the decoded bits.
     *
     * @throws ProtocolException if the continuation bit is still set on
     *   the tenth byte; without this bound the shift amount would wrap
     *   around and silently corrupt the result.
     */
    @Throws(IOException::class)
    private fun readVarint64(): Long {
        var result = 0L
        var shift = 0
        // 10 bytes * 7 payload bits = 70
        while (shift < 70) {
            val b = readByte().toInt()
            result = result or ((b and 0x7F).toLong() shl shift)
            if ((b and 0x80) == 0) {
                return result
            }
            shift += 7
        }
        throw ProtocolException("Malformed varint64: continuation bit set after 10 bytes")
    }

    /**
     * Fills the first [count] bytes of [buffer] from the transport,
     * issuing as many reads as necessary.
     *
     * @throws EOFException if the transport reports -1 before [count]
     *   bytes have been read.
     */
    @Throws(IOException::class)
    private fun readFully(buffer: ByteArray, count: Int) {
        var offset = 0
        while (offset < count) {
            val read = transport.read(buffer, offset, count - offset)
            if (read == -1) {
                throw EOFException()
            }
            offset += read
        }
    }

    /**
     * Type codes used by the compact protocol on the wire, together with
     * conversions to and from the standard [TType] codes.
     */
    private object CompactTypes {
        const val BOOLEAN_TRUE: Byte = 0x01
        const val BOOLEAN_FALSE: Byte = 0x02
        const val BYTE: Byte = 0x03
        const val I16: Byte = 0x04
        const val I32: Byte = 0x05
        const val I64: Byte = 0x06
        const val DOUBLE: Byte = 0x07
        const val BINARY: Byte = 0x08
        const val LIST: Byte = 0x09
        const val SET: Byte = 0x0A
        const val MAP: Byte = 0x0B
        const val STRUCT: Byte = 0x0C

        /**
         * Maps a [TType] code to its compact-protocol code.  [TType.BOOL]
         * maps to [BOOLEAN_TRUE].
         *
         * @throws IllegalArgumentException for [TType.VOID] and for
         *   unrecognized codes.
         */
        fun ttypeToCompact(typeId: Byte): Byte = when (typeId) {
            TType.STOP -> TType.STOP
            TType.VOID -> throw IllegalArgumentException("Unexpected VOID type")
            TType.BOOL -> BOOLEAN_TRUE
            TType.BYTE -> BYTE
            TType.DOUBLE -> DOUBLE
            TType.I16 -> I16
            TType.I32 -> I32
            TType.I64 -> I64
            TType.STRING -> BINARY
            TType.STRUCT -> STRUCT
            TType.MAP -> MAP
            TType.SET -> SET
            TType.LIST -> LIST
            else -> throw IllegalArgumentException("Unknown TType ID: $typeId")
        }

        /**
         * Maps a compact-protocol code to its [TType] code.  Both boolean
         * codes map to [TType.BOOL].
         *
         * @throws IllegalArgumentException for unrecognized codes.
         */
        fun compactToTtype(compactId: Byte): Byte = when (compactId) {
            TType.STOP -> TType.STOP
            BOOLEAN_TRUE, BOOLEAN_FALSE -> TType.BOOL
            BYTE -> TType.BYTE
            I16 -> TType.I16
            I32 -> TType.I32
            I64 -> TType.I64
            DOUBLE -> TType.DOUBLE
            BINARY -> TType.STRING
            LIST -> TType.LIST
            SET -> TType.SET
            MAP -> TType.MAP
            STRUCT -> TType.STRUCT
            else -> throw IllegalArgumentException("Unknown compact type ID: $compactId")
        }
    }

    /**
     * A minimal growable LIFO stack of primitive shorts, backed by an
     * array that doubles in size when full.
     */
    private class ShortStack {
        private var stack = ShortArray(16)

        // Index of the topmost element; -1 when the stack is empty.
        private var top = -1

        /** Pushes [value], doubling the backing array first if it is full. */
        fun push(value: Short) {
            if (top + 1 == stack.size) {
                stack = stack.copyOf(stack.size shl 1)
            }
            stack[++top] = value
        }

        /**
         * Removes and returns the most recently pushed value.
         *
         * @throws IllegalStateException if the stack is empty.  The check
         *   happens before any state changes, so the stack stays usable.
         */
        fun pop(): Short {
            check(top >= 0) { "Cannot pop from an empty ShortStack" }
            return stack[top--]
        }
    }

    companion object {
        // Constants, as defined in TCompactProtocol.java

        // First byte of every message header.
        private const val PROTOCOL_ID = 0x82.toByte()

        // Protocol version carried in the second header byte.
        private const val VERSION: Byte = 1

        // Selects the version bits (low five bits) of the second header byte.
        private const val VERSION_MASK: Byte = 0x1F

        // Selects the message-type bits (high three bits) of the second
        // header byte when writing.
        private const val TYPE_MASK = 0xE0.toByte()

        // Selects the message-type bits after they have been shifted down
        // when reading.
        private const val TYPE_BITS: Byte = 0x07

        // Position of the message-type bits within the second header byte.
        private const val TYPE_SHIFT_AMOUNT = 5

        // Shared result of readStructBegin.
        private val NO_STRUCT = StructMetadata("")

        // Shared result of readFieldBegin when the STOP byte is read.
        private val END_FIELDS = FieldMetadata("", TType.STOP, 0.toShort())

        /**
         * Zigzag-encodes a twos-complement int so that values of small
         * magnitude, negative or positive, become small unsigned numbers
         * suitable for varint encoding (0 -> 0, -1 -> 1, 1 -> 2, ...).
         */
        private fun intToZigZag(n: Int): Int = (n shl 1) xor (n shr 31)

        /**
         * Zigzag-encodes a twos-complement long so that values of small
         * magnitude, negative or positive, become small unsigned numbers
         * suitable for varint encoding (0 -> 0, -1 -> 1, 1 -> 2, ...).
         */
        private fun longToZigZag(n: Long): Long = (n shl 1) xor (n shr 63)

        /**
         * Decodes a zigzag-encoded int back to its twos-complement value
         * (0 -> 0, 1 -> -1, 2 -> 1, ...).
         */
        private fun zigZagToInt(n: Int): Int = (n ushr 1) xor -(n and 1)

        /**
         * Decodes a zigzag-encoded long back to its twos-complement value
         * (0 -> 0, 1 -> -1, 2 -> 1, ...).
         */
        private fun zigZagToLong(n: Long): Long = (n ushr 1) xor -(n and 1)
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy