All downloads are free. Search and download functionality uses the official Maven repository.

it.auties.protobuf.stream.ProtobufInputStream Maven / Gradle / Ivy

The newest version!
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc.  All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file or at
// https://developers.google.com/open-source/licenses/bsd

// I'm not sure whether this license header is required, since only two methods in this class are taken from Google's source code,
// but it is included to be safe.

package it.auties.protobuf.stream;

import it.auties.protobuf.exception.ProtobufDeserializationException;
import it.auties.protobuf.model.ProtobufString;
import it.auties.protobuf.model.ProtobufWireType;

import java.io.BufferedInputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.nio.InvalidMarkException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;

public abstract class ProtobufInputStream {
    private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocate(0);

    private int wireType;
    private int index;
    private ProtobufInputStream() {
        this.wireType = -1;
        this.index = -1;
    }

    public static ProtobufInputStream fromBytes(byte[] bytes) {
        return new Bytes(bytes, 0, bytes.length);
    }

    public static ProtobufInputStream fromBytes(byte[] bytes, int offset, int length) {
        return new Bytes(bytes, offset, length);
    }

    public static ProtobufInputStream fromBuffer(ByteBuffer buffer) {
        return new Buffer(buffer, buffer.remaining());
    }

    public static ProtobufInputStream fromStream(InputStream buffer) {
        return new Stream(buffer);
    }

    public boolean readTag() {
        if(isFinished()) {
            return false;
        }

        var rawTag = readInt32Unchecked();
        this.wireType = rawTag & 7;
        this.index = rawTag >>> 3;
        if(index == 0) {
            throw ProtobufDeserializationException.invalidFieldIndex(index);
        }
        return wireType != ProtobufWireType.WIRE_TYPE_END_OBJECT;
    }

    public List readFloatPacked() {
        return switch (wireType) {
            case ProtobufWireType.WIRE_TYPE_LENGTH_DELIMITED -> {
                var results = new ArrayList();
                var input = readLengthDelimited();
                while (!input.isFinished()){
                    results.add(input.readFloatUnchecked());
                }

                yield results;
            }

            case ProtobufWireType.WIRE_TYPE_FIXED32 -> List.of(readFloatUnchecked());
            default -> throw ProtobufDeserializationException.invalidWireType(wireType);
        };
    }

    public List readDoublePacked() {
        return switch (wireType) {
            case ProtobufWireType.WIRE_TYPE_LENGTH_DELIMITED -> {
                var results = new ArrayList();
                var input = readLengthDelimited();
                while (!input.isFinished()){
                    results.add(input.readDoubleUnchecked());
                }

                yield results;
            }

            case ProtobufWireType.WIRE_TYPE_FIXED64 -> List.of(readDoubleUnchecked());
            default -> throw ProtobufDeserializationException.invalidWireType(wireType);
        };
    }

    public List readInt32Packed() {
        return switch (wireType) {
            case ProtobufWireType.WIRE_TYPE_LENGTH_DELIMITED -> {
                var results = new ArrayList();
                var input = readLengthDelimited();
                while (!input.isFinished()){
                    results.add(input.readInt32Unchecked());
                }

                yield results;
            }

            case ProtobufWireType.WIRE_TYPE_VAR_INT -> List.of(readInt32Unchecked());
            default -> throw ProtobufDeserializationException.invalidWireType(wireType);
        };
    }

    public List readInt64Packed() {
        return switch (wireType) {
            case ProtobufWireType.WIRE_TYPE_LENGTH_DELIMITED -> {
                var results = new ArrayList();
                var input = readLengthDelimited();
                while (!input.isFinished()){
                    results.add(input.readInt64Unchecked());
                }

                yield results;
            }

            case ProtobufWireType.WIRE_TYPE_VAR_INT -> List.of(readInt64Unchecked());
            default -> throw ProtobufDeserializationException.invalidWireType(wireType);
        };
    }

    public List readFixed32Packed() {
        return switch (wireType) {
            case ProtobufWireType.WIRE_TYPE_LENGTH_DELIMITED -> {
                var results = new ArrayList();
                var input = readLengthDelimited();
                while (!input.isFinished()){
                    results.add(input.readFixed32Unchecked());
                }

                yield results;
            }

            case ProtobufWireType.WIRE_TYPE_FIXED32 -> List.of(readFixed32Unchecked());
            default -> throw ProtobufDeserializationException.invalidWireType(wireType);
        };
    }

    public List readFixed64Packed() {
        return switch (wireType) {
            case ProtobufWireType.WIRE_TYPE_LENGTH_DELIMITED -> {
                var results = new ArrayList();
                var input = readLengthDelimited();
                while (!input.isFinished()){
                    results.add(input.readFixed64Unchecked());
                }

                yield results;
            }

            case ProtobufWireType.WIRE_TYPE_FIXED64 -> List.of(readFixed64Unchecked());
            default -> throw ProtobufDeserializationException.invalidWireType(wireType);
        };
    }

    public List readBoolPacked(){
        return switch (wireType) {
            case ProtobufWireType.WIRE_TYPE_LENGTH_DELIMITED -> {
                var results = new ArrayList();
                var input = readLengthDelimited();
                while (!input.isFinished()){
                    results.add(input.readBoolUnchecked());
                }

                yield results;
            }

            case ProtobufWireType.WIRE_TYPE_VAR_INT -> List.of(readBoolUnchecked());
            default -> throw ProtobufDeserializationException.invalidWireType(wireType);
        };
    }

    public float readFloat() {
        return Float.intBitsToFloat(readFixed32());
    }

    public float readFloatUnchecked() {
        return Float.intBitsToFloat(readFixed32());
    }

    public double readDouble() {
        return Double.longBitsToDouble(readFixed64());
    }

    public double readDoubleUnchecked() {
        return Double.longBitsToDouble(readFixed64Unchecked());
    }

    public boolean readBool() {
        if(wireType != ProtobufWireType.WIRE_TYPE_VAR_INT) {
            throw ProtobufDeserializationException.invalidWireType(wireType);
        }

        return readBoolUnchecked();
    }

    private boolean readBoolUnchecked() {
        return readInt64() == 1;
    }

    public ProtobufString readString() {
        if(wireType != ProtobufWireType.WIRE_TYPE_LENGTH_DELIMITED) {
            throw ProtobufDeserializationException.invalidWireType(wireType);
        }

        var size = this.readInt32Unchecked();
        if(size < 0) {
            throw ProtobufDeserializationException.negativeLength(size);
        }else {
            return readString(size);
        }
    }

    public int readInt32() {
        if(wireType != ProtobufWireType.WIRE_TYPE_VAR_INT) {
            throw ProtobufDeserializationException.invalidWireType(wireType);
        }

        return readInt32Unchecked();
    }

    // Source: https://github.com/protocolbuffers/protobuf/blob/main/java/core/src/main/java/com/google/protobuf/CodedInputStream.java
    // Fastest implementation I could find
    // Adapted to work with Channels
    private int readInt32Unchecked() {
        mark();
        fspath:
        {
            int x;
            if ((x = readByte()) >= 0) {
                return x;
            } else if ((x ^= (readByte() << 7)) < 0) {
                x ^= (~0 << 7);
            } else if ((x ^= (readByte() << 14)) >= 0) {
                x ^= (~0 << 7) ^ (~0 << 14);
            } else if ((x ^= (readByte() << 21)) < 0) {
                x ^= (~0 << 7) ^ (~0 << 14) ^ (~0 << 21);
            } else {
                int y = readByte();
                x ^= y << 28;
                x ^= (~0 << 7) ^ (~0 << 14) ^ (~0 << 21) ^ (~0 << 28);
                if (y < 0
                        && readByte() < 0
                        && readByte() < 0
                        && readByte() < 0
                        && readByte() < 0
                        && readByte() < 0) {
                    break fspath;
                }
            }
            return x;
        }

        rewind();
        return (int) readVarInt64Slow();
    }

    // Source: https://github.com/protocolbuffers/protobuf/blob/main/java/core/src/main/java/com/google/protobuf/CodedInputStream.java
    // Fastest implementation I could find
    // Adapted to work with Channels
    public long readInt64() {
        if(wireType != ProtobufWireType.WIRE_TYPE_VAR_INT) {
            throw ProtobufDeserializationException.invalidWireType(wireType);
        }

        return readInt64Unchecked();
    }

    private long readInt64Unchecked() {
        mark();
        fspath:
        {
            long x;
            int y;
            if ((y = readByte()) >= 0) {
                return y;
            } else if ((y ^= (readByte() << 7)) < 0) {
                x = y ^ (~0 << 7);
            } else if ((y ^= (readByte() << 14)) >= 0) {
                x = y ^ ((~0 << 7) ^ (~0 << 14));
            } else if ((y ^= (readByte() << 21)) < 0) {
                x = y ^ ((~0 << 7) ^ (~0 << 14) ^ (~0 << 21));
            } else if ((x = y ^ ((long) readByte() << 28)) >= 0L) {
                x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28);
            } else if ((x ^= ((long) readByte() << 35)) < 0L) {
                x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28) ^ (~0L << 35);
            } else if ((x ^= ((long) readByte() << 42)) >= 0L) {
                x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28) ^ (~0L << 35) ^ (~0L << 42);
            } else if ((x ^= ((long) readByte() << 49)) < 0L) {
                x ^=
                        (~0L << 7)
                                ^ (~0L << 14)
                                ^ (~0L << 21)
                                ^ (~0L << 28)
                                ^ (~0L << 35)
                                ^ (~0L << 42)
                                ^ (~0L << 49);
            } else {
                x ^= ((long) readByte() << 56);
                x ^=
                        (~0L << 7)
                                ^ (~0L << 14)
                                ^ (~0L << 21)
                                ^ (~0L << 28)
                                ^ (~0L << 35)
                                ^ (~0L << 42)
                                ^ (~0L << 49)
                                ^ (~0L << 56);
                if (x < 0L) {
                    if (readByte() < 0L) {
                        break fspath;
                    }
                }
            }
            return x;
        }

        rewind();
        return readVarInt64Slow();
    }

    private long readVarInt64Slow() {
        var result = 0L;
        for (int shift = 0; shift < 64; shift += 7) {
            byte b = readByte();
            result |= (long) (b & 0x7F) << shift;
            if ((b & 0x80) == 0) {
                return result;
            }
        }

        throw ProtobufDeserializationException.malformedVarInt();
    }

    public int readFixed32() {
        if(wireType != ProtobufWireType.WIRE_TYPE_FIXED32) {
            throw ProtobufDeserializationException.invalidWireType(wireType);
        }

        return readFixed32Unchecked();
    }

    private int readFixed32Unchecked() {
        return readByte() & 255
                | (readByte() & 255) << 8
                | (readByte() & 255) << 16
                | (readByte() & 255) << 24;
    }

    public long readFixed64() {
        if(wireType != ProtobufWireType.WIRE_TYPE_FIXED64) {
            throw ProtobufDeserializationException.invalidWireType(wireType);
        }

        return readFixed64Unchecked();
    }

    private long readFixed64Unchecked() {
        return (long) readByte() & 255L
                | ((long) readByte() & 255L) << 8
                | ((long) readByte() & 255L) << 16
                | ((long) readByte() & 255L) << 24
                | ((long) readByte() & 255L) << 32
                | ((long) readByte() & 255L) << 40
                | ((long) readByte() & 255L) << 48
                | ((long) readByte() & 255L) << 56;
    }

    public ByteBuffer readBytes() {
        if(wireType != ProtobufWireType.WIRE_TYPE_LENGTH_DELIMITED) {
            throw ProtobufDeserializationException.invalidWireType(wireType);
        }

        var size = this.readInt32Unchecked();
        if(size < 0) {
            throw ProtobufDeserializationException.negativeLength(size);
        }else if(size == 0) {
            return EMPTY_BUFFER;
        }else {
            return readBytes(size);
        }
    }

    public Object readUnknown(boolean allocate) {
        return switch (wireType) {
            case ProtobufWireType.WIRE_TYPE_VAR_INT -> readInt64();
            case ProtobufWireType.WIRE_TYPE_FIXED32 -> readFixed32();
            case ProtobufWireType.WIRE_TYPE_FIXED64 -> readFixed64();
            case ProtobufWireType.WIRE_TYPE_LENGTH_DELIMITED -> readBytes();
            case ProtobufWireType.WIRE_TYPE_START_OBJECT -> readGroup(allocate);
            default -> throw ProtobufDeserializationException.invalidWireType(wireType);
        };
    }

    private Map readGroup(boolean allocate) {
        var group = allocate ? new HashMap() : null;
        var groupIndex = index;
        while (readTag()) {
            var value = readUnknown(allocate);
            if(group != null) {
                group.put(index, value);
            }
        }
        assertGroupClosed(groupIndex);
        return group;
    }

    public void assertGroupOpened(int groupIndex) {
        if((wireType == -1 && !readTag()) || wireType != ProtobufWireType.WIRE_TYPE_START_OBJECT || index != groupIndex) {
            throw ProtobufDeserializationException.invalidStartObject(groupIndex);
        }
    }

    public void assertGroupClosed(int groupIndex) {
        if(wireType != ProtobufWireType.WIRE_TYPE_END_OBJECT) {
            throw ProtobufDeserializationException.malformedGroup();
        }
        if(index != groupIndex) {
            throw ProtobufDeserializationException.invalidEndObject(index, groupIndex);
        }
    }

    public int index() {
        return index;
    }

    public ProtobufInputStream readLengthDelimited() {
        if(wireType != ProtobufWireType.WIRE_TYPE_LENGTH_DELIMITED) {
            throw ProtobufDeserializationException.invalidWireType(wireType);
        }

        var size = this.readInt32Unchecked();
        if(size < 0) {
            throw ProtobufDeserializationException.negativeLength(size);
        }else {
            return subStream(size);
        }
    }

    protected abstract byte readByte();
    protected abstract ByteBuffer readBytes(int size);
    protected abstract ProtobufString readString(int size);
    protected abstract void mark();
    protected abstract void rewind();
    protected abstract boolean isFinished();
    protected abstract ProtobufInputStream subStream(int size);

    private static final class Stream extends ProtobufInputStream {
        private static final int MAX_VAR_INT_SIZE = 10;

        private final InputStream inputStream;
        private final long length;
        private final byte[] buffer;
        private long position;
        private int bufferReadPosition;
        private int bufferWritePosition;
        private int bufferLength;
        private Stream(InputStream inputStream) {
            this.inputStream = inputStream;
            this.length = -1;
            this.buffer = new byte[MAX_VAR_INT_SIZE];
        }

        private Stream(InputStream inputStream, long length, byte[] buffer, int bufferReadPosition, int bufferWritePosition, int bufferLength) {
            this.inputStream = inputStream;
            this.length = length;
            this.buffer = buffer;
            this.bufferReadPosition = bufferReadPosition;
            this.bufferWritePosition = bufferWritePosition;
            this.bufferLength = bufferLength;
        }

        @Override
        public byte readByte() {
            try {
                if(length != -1) {
                    position++;
                }

                if(bufferLength > 0) {
                    bufferLength--;
                    return buffer[bufferReadPosition++];
                }

                var result = (byte) inputStream.read();
                buffer[bufferWritePosition++ % buffer.length] = result;
                return result;
            } catch (IOException exception) {
                throw new UncheckedIOException(exception);
            }
        }

        @Override
        public ByteBuffer readBytes(int size) {
            try {
                return ByteBuffer.wrap(readStreamBytes(size));
            } catch (IOException exception) {
                throw new UncheckedIOException(exception);
            }
        }

        @Override
        public ProtobufString readString(int size) {
            try {
                return ProtobufString.lazy(readStreamBytes(size), 0, size);
            } catch (IOException exception) {
                throw new UncheckedIOException(exception);
            }
        }

        private byte[] readStreamBytes(int size) throws IOException {
            if(length != -1) {
                position += size;
            }

            var result = new byte[size];
            for (int i = 0; i < size; i++) {
                if(bufferLength > 0) {
                    result[i] = buffer[bufferReadPosition++];
                    bufferLength--;
                }else {
                    var entry = (byte) inputStream.read();
                    result[i] = entry;
                    buffer[bufferWritePosition++ % buffer.length] = entry;
                }
            }
            return result;
        }

        @Override
        public void mark() {
            this.bufferReadPosition = 0;
            this.bufferWritePosition = 0;
        }

        @Override
        public void rewind() {
            this.bufferReadPosition = 0;
            this.bufferLength = bufferWritePosition - bufferReadPosition;
            if(length != -1) {
                this.position -= bufferLength;
            }
        }

        @Override
        public boolean isFinished() {
            if (length != -1) {
                return position >= length;
            }

            mark();
            var result = readByte() == -1;
            rewind();
            return result;
        }

        @Override
        public Stream subStream(int size) {
            var result = new Stream(inputStream, size, buffer, bufferReadPosition, bufferWritePosition, bufferLength);
            if(length != -1) {
                position += size;
            }
            return result;
        }
    }

    private static final class Bytes extends ProtobufInputStream {
        private final byte[] buffer;
        private final int offset;
        private final int length;
        private int position;
        private int marker;
        private Bytes(byte[] buffer, int offset, int length) {
            this.buffer = buffer;
            this.offset = offset;
            this.length = length;
        }

        @Override
        public byte readByte() {
            return buffer[offset + position++];
        }

        @Override
        public ByteBuffer readBytes(int size) {
            var result = ByteBuffer.wrap(buffer, offset + position, size);
            position += size;
            return result;
        }

        @Override
        public ProtobufString readString(int size) {
            var result = ProtobufString.lazy(buffer, offset + position, size);
            position += size;
            return result;
        }

        @Override
        public void mark() {
            this.marker = position;
        }

        @Override
        public void rewind() {
            if(marker == -1) {
                throw new InvalidMarkException();
            }

            this.position = marker;
        }

        @Override
        public boolean isFinished() {
            return position >= length;
        }

        @Override
        public Bytes subStream(int size) {
            var result = new Bytes(buffer, offset + position, size);
            position += size;
            return result;
        }
    }

    private static final class Buffer extends ProtobufInputStream {
        private final ByteBuffer buffer;
        private int length;
        private Buffer(ByteBuffer buffer, int length) {
            this.buffer = buffer;
            this.length = length;
        }

        @Override
        public byte readByte() {
            var result = buffer.get();
            length--;
            return result;
        }

        @Override
        public ByteBuffer readBytes(int size) {
            var position = buffer.position();
            var result = buffer.slice(position, position + size);
            buffer.position(position + size);
            length -= size;
            return result;
        }

        @Override
        public ProtobufString readString(int size) {
            var position = buffer.position();
            var result = buffer.slice(position, position + size);
            buffer.position(position + size);
            length -= size;
            return ProtobufString.lazy(result.asReadOnlyBuffer());
        }

        @Override
        public void mark() {
            buffer.mark();
        }

        @Override
        public void rewind() {
            buffer.reset();
        }

        @Override
        public boolean isFinished() {
            return length <= 0;
        }

        @Override
        public Buffer subStream(int size) {
            var result = new Buffer(buffer, size);
            buffer.position(buffer.position() + size);
            length -= size;
            return result;
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy