// it.auties.protobuf.stream.ProtobufInputStream Maven / Gradle / Ivy
// The newest version!
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file or at
// https://developers.google.com/open-source/licenses/bsd
// The license header above may not be strictly necessary, since only two methods in this class are taken from Google's source code,
// but it is included to be safe.
package it.auties.protobuf.stream;
import it.auties.protobuf.exception.ProtobufDeserializationException;
import it.auties.protobuf.model.ProtobufString;
import it.auties.protobuf.model.ProtobufWireType;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.PushbackInputStream;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.nio.InvalidMarkException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
/**
 * Reader for the Protocol Buffers binary wire format.
 *
 * <p>Instances are created through {@link #fromBytes(byte[])}, {@link #fromBytes(byte[], int, int)},
 * {@link #fromBuffer(ByteBuffer)} or {@link #fromStream(InputStream)}. Decoding follows a simple loop:
 * call {@link #readTag()} and, while it returns {@code true}, inspect {@link #index()} and call the
 * {@code read*} method matching the field's type. The checked readers verify that the wire type decoded
 * by the last tag matches the requested type and throw {@link ProtobufDeserializationException} otherwise.
 *
 * <p>Instances are mutable and perform no synchronization; do not share one between threads without
 * external locking.
 */
public abstract class ProtobufInputStream {
    // Returned for zero-length byte fields so that no allocation is needed.
    private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocate(0);

    // Wire type and field number decoded by the last readTag() call; both are -1 before the first tag.
    private int wireType;
    private int index;

    private ProtobufInputStream() {
        this.wireType = -1;
        this.index = -1;
    }

    /**
     * Creates a reader over the whole {@code bytes} array. The array is not copied.
     *
     * @param bytes the encoded message
     * @return a new reader positioned at the first byte
     */
    public static ProtobufInputStream fromBytes(byte[] bytes) {
        return new Bytes(bytes, 0, bytes.length);
    }

    /**
     * Creates a reader over {@code length} bytes of {@code bytes}, starting at {@code offset}.
     * The array is not copied.
     *
     * @param bytes  the array holding the encoded message
     * @param offset index of the first byte of the message
     * @param length number of bytes belonging to the message
     * @return a new reader positioned at {@code offset}
     * @throws IndexOutOfBoundsException if the range does not fit inside {@code bytes}
     */
    public static ProtobufInputStream fromBytes(byte[] bytes, int offset, int length) {
        return new Bytes(bytes, offset, length);
    }

    /**
     * Creates a reader over the remaining bytes of {@code buffer}.
     * The buffer is used directly, so reading advances its position.
     *
     * @param buffer the buffer holding the encoded message between its position and limit
     * @return a new reader positioned at the buffer's current position
     */
    public static ProtobufInputStream fromBuffer(ByteBuffer buffer) {
        return new Buffer(buffer, buffer.remaining());
    }

    /**
     * Creates a reader that pulls bytes from {@code buffer} on demand until it reaches end of stream.
     *
     * @param buffer the input stream holding the encoded message
     * @return a new reader over the stream
     */
    public static ProtobufInputStream fromStream(InputStream buffer) {
        return new Stream(buffer);
    }

    /**
     * Reads the next field tag and stores its wire type and field index.
     *
     * @return {@code false} if the input is exhausted or the tag closes a group, {@code true} otherwise
     * @throws ProtobufDeserializationException if the decoded field index is {@code 0}
     */
    public boolean readTag() {
        if (isFinished()) {
            return false;
        }
        var rawTag = readInt32Unchecked();
        this.wireType = rawTag & 7;
        this.index = rawTag >>> 3;
        if (index == 0) {
            throw ProtobufDeserializationException.invalidFieldIndex(index);
        }
        return wireType != ProtobufWireType.WIRE_TYPE_END_OBJECT;
    }

    /**
     * Reads a repeated {@code float} field, accepting both the packed and the unpacked encoding.
     *
     * @return a mutable list of all packed values, or an immutable single-element list for unpacked input
     * @throws ProtobufDeserializationException if the wire type is neither length-delimited nor fixed32
     */
    public List<Float> readFloatPacked() {
        return switch (wireType) {
            case ProtobufWireType.WIRE_TYPE_LENGTH_DELIMITED -> {
                var results = new ArrayList<Float>();
                var input = readLengthDelimited();
                while (!input.isFinished()) {
                    results.add(input.readFloatUnchecked());
                }
                yield results;
            }
            case ProtobufWireType.WIRE_TYPE_FIXED32 -> List.of(readFloatUnchecked());
            default -> throw ProtobufDeserializationException.invalidWireType(wireType);
        };
    }

    /**
     * Reads a repeated {@code double} field, accepting both the packed and the unpacked encoding.
     *
     * @return a mutable list of all packed values, or an immutable single-element list for unpacked input
     * @throws ProtobufDeserializationException if the wire type is neither length-delimited nor fixed64
     */
    public List<Double> readDoublePacked() {
        return switch (wireType) {
            case ProtobufWireType.WIRE_TYPE_LENGTH_DELIMITED -> {
                var results = new ArrayList<Double>();
                var input = readLengthDelimited();
                while (!input.isFinished()) {
                    results.add(input.readDoubleUnchecked());
                }
                yield results;
            }
            case ProtobufWireType.WIRE_TYPE_FIXED64 -> List.of(readDoubleUnchecked());
            default -> throw ProtobufDeserializationException.invalidWireType(wireType);
        };
    }

    /**
     * Reads a repeated varint-encoded 32-bit field, accepting both the packed and the unpacked encoding.
     *
     * @return a mutable list of all packed values, or an immutable single-element list for unpacked input
     * @throws ProtobufDeserializationException if the wire type is neither length-delimited nor varint
     */
    public List<Integer> readInt32Packed() {
        return switch (wireType) {
            case ProtobufWireType.WIRE_TYPE_LENGTH_DELIMITED -> {
                var results = new ArrayList<Integer>();
                var input = readLengthDelimited();
                while (!input.isFinished()) {
                    results.add(input.readInt32Unchecked());
                }
                yield results;
            }
            case ProtobufWireType.WIRE_TYPE_VAR_INT -> List.of(readInt32Unchecked());
            default -> throw ProtobufDeserializationException.invalidWireType(wireType);
        };
    }

    /**
     * Reads a repeated varint-encoded 64-bit field, accepting both the packed and the unpacked encoding.
     *
     * @return a mutable list of all packed values, or an immutable single-element list for unpacked input
     * @throws ProtobufDeserializationException if the wire type is neither length-delimited nor varint
     */
    public List<Long> readInt64Packed() {
        return switch (wireType) {
            case ProtobufWireType.WIRE_TYPE_LENGTH_DELIMITED -> {
                var results = new ArrayList<Long>();
                var input = readLengthDelimited();
                while (!input.isFinished()) {
                    results.add(input.readInt64Unchecked());
                }
                yield results;
            }
            case ProtobufWireType.WIRE_TYPE_VAR_INT -> List.of(readInt64Unchecked());
            default -> throw ProtobufDeserializationException.invalidWireType(wireType);
        };
    }

    /**
     * Reads a repeated fixed32 field, accepting both the packed and the unpacked encoding.
     *
     * @return a mutable list of all packed values, or an immutable single-element list for unpacked input
     * @throws ProtobufDeserializationException if the wire type is neither length-delimited nor fixed32
     */
    public List<Integer> readFixed32Packed() {
        return switch (wireType) {
            case ProtobufWireType.WIRE_TYPE_LENGTH_DELIMITED -> {
                var results = new ArrayList<Integer>();
                var input = readLengthDelimited();
                while (!input.isFinished()) {
                    results.add(input.readFixed32Unchecked());
                }
                yield results;
            }
            case ProtobufWireType.WIRE_TYPE_FIXED32 -> List.of(readFixed32Unchecked());
            default -> throw ProtobufDeserializationException.invalidWireType(wireType);
        };
    }

    /**
     * Reads a repeated fixed64 field, accepting both the packed and the unpacked encoding.
     *
     * @return a mutable list of all packed values, or an immutable single-element list for unpacked input
     * @throws ProtobufDeserializationException if the wire type is neither length-delimited nor fixed64
     */
    public List<Long> readFixed64Packed() {
        return switch (wireType) {
            case ProtobufWireType.WIRE_TYPE_LENGTH_DELIMITED -> {
                var results = new ArrayList<Long>();
                var input = readLengthDelimited();
                while (!input.isFinished()) {
                    results.add(input.readFixed64Unchecked());
                }
                yield results;
            }
            case ProtobufWireType.WIRE_TYPE_FIXED64 -> List.of(readFixed64Unchecked());
            default -> throw ProtobufDeserializationException.invalidWireType(wireType);
        };
    }

    /**
     * Reads a repeated {@code bool} field, accepting both the packed and the unpacked encoding.
     *
     * @return a mutable list of all packed values, or an immutable single-element list for unpacked input
     * @throws ProtobufDeserializationException if the wire type is neither length-delimited nor varint
     */
    public List<Boolean> readBoolPacked() {
        return switch (wireType) {
            case ProtobufWireType.WIRE_TYPE_LENGTH_DELIMITED -> {
                var results = new ArrayList<Boolean>();
                var input = readLengthDelimited();
                while (!input.isFinished()) {
                    results.add(input.readBoolUnchecked());
                }
                yield results;
            }
            case ProtobufWireType.WIRE_TYPE_VAR_INT -> List.of(readBoolUnchecked());
            default -> throw ProtobufDeserializationException.invalidWireType(wireType);
        };
    }

    /**
     * Reads a {@code float} field, requiring the fixed32 wire type.
     *
     * @throws ProtobufDeserializationException if the last tag's wire type is not fixed32
     */
    public float readFloat() {
        return Float.intBitsToFloat(readFixed32());
    }

    /**
     * Reads four little-endian bytes as a {@code float} without checking the wire type.
     * This is used for packed payloads, whose sub-stream has no tag of its own.
     */
    public float readFloatUnchecked() {
        // Must not go through readFixed32(): inside a packed payload no tag has been read.
        return Float.intBitsToFloat(readFixed32Unchecked());
    }

    /**
     * Reads a {@code double} field, requiring the fixed64 wire type.
     *
     * @throws ProtobufDeserializationException if the last tag's wire type is not fixed64
     */
    public double readDouble() {
        return Double.longBitsToDouble(readFixed64());
    }

    /**
     * Reads eight little-endian bytes as a {@code double} without checking the wire type.
     */
    public double readDoubleUnchecked() {
        return Double.longBitsToDouble(readFixed64Unchecked());
    }

    /**
     * Reads a {@code bool} field, requiring the varint wire type.
     *
     * @return {@code true} for any non-zero varint
     * @throws ProtobufDeserializationException if the last tag's wire type is not varint
     */
    public boolean readBool() {
        if (wireType != ProtobufWireType.WIRE_TYPE_VAR_INT) {
            throw ProtobufDeserializationException.invalidWireType(wireType);
        }
        return readBoolUnchecked();
    }

    // Decodes a varint as a bool without checking the wire type; any non-zero value is true.
    private boolean readBoolUnchecked() {
        return readInt64Unchecked() != 0;
    }

    /**
     * Reads a length-delimited string field.
     *
     * @throws ProtobufDeserializationException if the wire type is not length-delimited or the length is negative
     */
    public ProtobufString readString() {
        if (wireType != ProtobufWireType.WIRE_TYPE_LENGTH_DELIMITED) {
            throw ProtobufDeserializationException.invalidWireType(wireType);
        }
        var size = this.readInt32Unchecked();
        if (size < 0) {
            throw ProtobufDeserializationException.negativeLength(size);
        }
        return readString(size);
    }

    /**
     * Reads a varint-encoded 32-bit field, requiring the varint wire type.
     *
     * @throws ProtobufDeserializationException if the last tag's wire type is not varint
     */
    public int readInt32() {
        if (wireType != ProtobufWireType.WIRE_TYPE_VAR_INT) {
            throw ProtobufDeserializationException.invalidWireType(wireType);
        }
        return readInt32Unchecked();
    }

    // Source: https://github.com/protocolbuffers/protobuf/blob/main/java/core/src/main/java/com/google/protobuf/CodedInputStream.java
    // Fastest implementation I could find.
    // Adapted to use mark()/rewind() instead of random access: when the fast path detects a malformed
    // or over-long varint, it rewinds to the mark and re-decodes with the byte-by-byte slow path.
    private int readInt32Unchecked() {
        mark();
        fastPath:
        {
            int x;
            if ((x = readByte()) >= 0) {
                return x;
            } else if ((x ^= (readByte() << 7)) < 0) {
                x ^= (~0 << 7);
            } else if ((x ^= (readByte() << 14)) >= 0) {
                x ^= (~0 << 7) ^ (~0 << 14);
            } else if ((x ^= (readByte() << 21)) < 0) {
                x ^= (~0 << 7) ^ (~0 << 14) ^ (~0 << 21);
            } else {
                int y = readByte();
                x ^= y << 28;
                x ^= (~0 << 7) ^ (~0 << 14) ^ (~0 << 21) ^ (~0 << 28);
                // Negative int32 values are sign-extended to 10 bytes; skip the upper continuation bytes.
                if (y < 0
                        && readByte() < 0
                        && readByte() < 0
                        && readByte() < 0
                        && readByte() < 0
                        && readByte() < 0) {
                    break fastPath;
                }
            }
            return x;
        }
        rewind();
        return (int) readVarInt64Slow();
    }

    /**
     * Reads a varint-encoded 64-bit field, requiring the varint wire type.
     *
     * @throws ProtobufDeserializationException if the last tag's wire type is not varint
     */
    public long readInt64() {
        if (wireType != ProtobufWireType.WIRE_TYPE_VAR_INT) {
            throw ProtobufDeserializationException.invalidWireType(wireType);
        }
        return readInt64Unchecked();
    }

    // Source: https://github.com/protocolbuffers/protobuf/blob/main/java/core/src/main/java/com/google/protobuf/CodedInputStream.java
    // Fastest implementation I could find.
    // Adapted to use mark()/rewind() instead of random access (see readInt32Unchecked).
    private long readInt64Unchecked() {
        mark();
        fastPath:
        {
            long x;
            int y;
            if ((y = readByte()) >= 0) {
                return y;
            } else if ((y ^= (readByte() << 7)) < 0) {
                x = y ^ (~0 << 7);
            } else if ((y ^= (readByte() << 14)) >= 0) {
                x = y ^ ((~0 << 7) ^ (~0 << 14));
            } else if ((y ^= (readByte() << 21)) < 0) {
                x = y ^ ((~0 << 7) ^ (~0 << 14) ^ (~0 << 21));
            } else if ((x = y ^ ((long) readByte() << 28)) >= 0L) {
                x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28);
            } else if ((x ^= ((long) readByte() << 35)) < 0L) {
                x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28) ^ (~0L << 35);
            } else if ((x ^= ((long) readByte() << 42)) >= 0L) {
                x ^= (~0L << 7) ^ (~0L << 14) ^ (~0L << 21) ^ (~0L << 28) ^ (~0L << 35) ^ (~0L << 42);
            } else if ((x ^= ((long) readByte() << 49)) < 0L) {
                x ^=
                        (~0L << 7)
                                ^ (~0L << 14)
                                ^ (~0L << 21)
                                ^ (~0L << 28)
                                ^ (~0L << 35)
                                ^ (~0L << 42)
                                ^ (~0L << 49);
            } else {
                x ^= ((long) readByte() << 56);
                x ^=
                        (~0L << 7)
                                ^ (~0L << 14)
                                ^ (~0L << 21)
                                ^ (~0L << 28)
                                ^ (~0L << 35)
                                ^ (~0L << 42)
                                ^ (~0L << 49)
                                ^ (~0L << 56);
                if (x < 0L) {
                    // A tenth byte that still has its continuation bit set makes the varint malformed.
                    if (readByte() < 0L) {
                        break fastPath;
                    }
                }
            }
            return x;
        }
        rewind();
        return readVarInt64Slow();
    }

    // Byte-by-byte varint decoder. It is used when the fast path bails out, and throws on varints
    // longer than 10 bytes.
    private long readVarInt64Slow() {
        var result = 0L;
        for (int shift = 0; shift < 64; shift += 7) {
            byte b = readByte();
            result |= (long) (b & 0x7F) << shift;
            if ((b & 0x80) == 0) {
                return result;
            }
        }
        throw ProtobufDeserializationException.malformedVarInt();
    }

    /**
     * Reads a fixed32 field, requiring the fixed32 wire type.
     *
     * @throws ProtobufDeserializationException if the last tag's wire type is not fixed32
     */
    public int readFixed32() {
        if (wireType != ProtobufWireType.WIRE_TYPE_FIXED32) {
            throw ProtobufDeserializationException.invalidWireType(wireType);
        }
        return readFixed32Unchecked();
    }

    // Assembles four bytes in little-endian order.
    private int readFixed32Unchecked() {
        return (readByte() & 255)
                | (readByte() & 255) << 8
                | (readByte() & 255) << 16
                | (readByte() & 255) << 24;
    }

    /**
     * Reads a fixed64 field, requiring the fixed64 wire type.
     *
     * @throws ProtobufDeserializationException if the last tag's wire type is not fixed64
     */
    public long readFixed64() {
        if (wireType != ProtobufWireType.WIRE_TYPE_FIXED64) {
            throw ProtobufDeserializationException.invalidWireType(wireType);
        }
        return readFixed64Unchecked();
    }

    // Assembles eight bytes in little-endian order.
    private long readFixed64Unchecked() {
        return ((long) readByte() & 255L)
                | ((long) readByte() & 255L) << 8
                | ((long) readByte() & 255L) << 16
                | ((long) readByte() & 255L) << 24
                | ((long) readByte() & 255L) << 32
                | ((long) readByte() & 255L) << 40
                | ((long) readByte() & 255L) << 48
                | ((long) readByte() & 255L) << 56;
    }

    /**
     * Reads a length-delimited bytes field.
     *
     * @return the payload; a shared empty buffer when the length is zero
     * @throws ProtobufDeserializationException if the wire type is not length-delimited or the length is negative
     */
    public ByteBuffer readBytes() {
        if (wireType != ProtobufWireType.WIRE_TYPE_LENGTH_DELIMITED) {
            throw ProtobufDeserializationException.invalidWireType(wireType);
        }
        var size = this.readInt32Unchecked();
        if (size < 0) {
            throw ProtobufDeserializationException.negativeLength(size);
        }
        return size == 0 ? EMPTY_BUFFER : readBytes(size);
    }

    /**
     * Reads the value of a field whose type is not known, based only on the last tag's wire type.
     *
     * @param allocate whether group contents should be collected into a map (otherwise they are skipped)
     * @return a {@link Long} for varints and fixed64, an {@link Integer} for fixed32, a {@link ByteBuffer}
     *         for length-delimited values, or a {@code Map<Integer, Object>} (or {@code null} when
     *         {@code allocate} is false) for groups
     * @throws ProtobufDeserializationException for unsupported wire types or malformed groups
     */
    public Object readUnknown(boolean allocate) {
        return switch (wireType) {
            case ProtobufWireType.WIRE_TYPE_VAR_INT -> readInt64();
            case ProtobufWireType.WIRE_TYPE_FIXED32 -> readFixed32();
            case ProtobufWireType.WIRE_TYPE_FIXED64 -> readFixed64();
            case ProtobufWireType.WIRE_TYPE_LENGTH_DELIMITED -> readBytes();
            case ProtobufWireType.WIRE_TYPE_START_OBJECT -> readGroup(allocate);
            default -> throw ProtobufDeserializationException.invalidWireType(wireType);
        };
    }

    // Reads every field until the matching end-group tag. Later values for the same index overwrite earlier ones.
    private Map<Integer, Object> readGroup(boolean allocate) {
        var group = allocate ? new HashMap<Integer, Object>() : null;
        var groupIndex = index;
        while (readTag()) {
            var value = readUnknown(allocate);
            if (group != null) {
                group.put(index, value);
            }
        }
        assertGroupClosed(groupIndex);
        return group;
    }

    /**
     * Verifies that the current tag opens the group {@code groupIndex}, first reading a tag if none has
     * been read yet.
     *
     * @throws ProtobufDeserializationException if the current tag is not a start-group tag for {@code groupIndex}
     */
    public void assertGroupOpened(int groupIndex) {
        if ((wireType == -1 && !readTag()) || wireType != ProtobufWireType.WIRE_TYPE_START_OBJECT || index != groupIndex) {
            throw ProtobufDeserializationException.invalidStartObject(groupIndex);
        }
    }

    /**
     * Verifies that the current tag closes the group {@code groupIndex}.
     *
     * @throws ProtobufDeserializationException if the current tag is not an end-group tag, or closes another group
     */
    public void assertGroupClosed(int groupIndex) {
        if (wireType != ProtobufWireType.WIRE_TYPE_END_OBJECT) {
            throw ProtobufDeserializationException.malformedGroup();
        }
        if (index != groupIndex) {
            throw ProtobufDeserializationException.invalidEndObject(index, groupIndex);
        }
    }

    /**
     * Returns the field index decoded by the last {@link #readTag()} call, or {@code -1} before the first tag.
     */
    public int index() {
        return index;
    }

    /**
     * Reads the length prefix of a length-delimited field and returns a reader limited to its payload.
     * For stream-backed readers, the returned reader shares the underlying stream, so it must be consumed
     * completely before this reader is used again.
     *
     * @throws ProtobufDeserializationException if the wire type is not length-delimited or the length is negative
     */
    public ProtobufInputStream readLengthDelimited() {
        if (wireType != ProtobufWireType.WIRE_TYPE_LENGTH_DELIMITED) {
            throw ProtobufDeserializationException.invalidWireType(wireType);
        }
        var size = this.readInt32Unchecked();
        if (size < 0) {
            throw ProtobufDeserializationException.negativeLength(size);
        }
        return subStream(size);
    }

    // Reads the next raw byte.
    protected abstract byte readByte();

    // Reads the next size raw bytes.
    protected abstract ByteBuffer readBytes(int size);

    // Reads the next size raw bytes as a string.
    protected abstract ProtobufString readString(int size);

    // Remembers the current position so that rewind() can return to it.
    protected abstract void mark();

    // Returns to the position saved by the last mark().
    protected abstract void rewind();

    // Whether no bytes remain for this reader.
    protected abstract boolean isFinished();

    // Returns a reader over the next size bytes and advances this reader past them.
    protected abstract ProtobufInputStream subStream(int size);

    /**
     * Reader backed by an {@link InputStream}. Bytes are pulled on demand. mark()/rewind() record the
     * bytes read since the mark and push them back into a {@link PushbackInputStream} that is shared
     * with every sub-stream, so parent and children always agree on what has been consumed.
     */
    private static final class Stream extends ProtobufInputStream {
        // A varint never spans more than 10 bytes, so that is the most mark()/rewind() must replay.
        private static final int MAX_VAR_INT_SIZE = 10;
        // Room for a replayed varint plus one byte pushed back by isFinished().
        private static final int PUSHBACK_SIZE = MAX_VAR_INT_SIZE + 1;

        private final PushbackInputStream source;
        // Shared with sub-streams; a parent's mark is invalidated whenever it creates a sub-stream.
        private final byte[] markBuffer;
        // Number of bytes this reader may consume, or -1 when unbounded (top-level stream).
        private final long length;
        private long position;
        // Bytes recorded since mark(), or -1 when there is no valid mark.
        private int markLength;

        private Stream(InputStream inputStream) {
            this(new PushbackInputStream(inputStream, PUSHBACK_SIZE), new byte[MAX_VAR_INT_SIZE], -1);
        }

        private Stream(PushbackInputStream source, byte[] markBuffer, long length) {
            this.source = source;
            this.markBuffer = markBuffer;
            this.length = length;
            this.markLength = -1;
        }

        @Override
        public byte readByte() {
            int next;
            try {
                next = source.read();
            } catch (IOException exception) {
                throw new UncheckedIOException(exception);
            }
            // Compare the int result: casting to byte first would confuse a 0xFF data byte with EOF.
            if (next == -1) {
                throw new UncheckedIOException(new EOFException("Unexpected end of stream"));
            }
            var result = (byte) next;
            position++;
            if (markLength >= 0) {
                if (markLength < markBuffer.length) {
                    markBuffer[markLength++] = result;
                } else {
                    // Too many bytes to replay: the mark can no longer be honoured.
                    markLength = -1;
                }
            }
            return result;
        }

        @Override
        public ByteBuffer readBytes(int size) {
            try {
                return ByteBuffer.wrap(readStreamBytes(size));
            } catch (IOException exception) {
                throw new UncheckedIOException(exception);
            }
        }

        @Override
        public ProtobufString readString(int size) {
            try {
                return ProtobufString.lazy(readStreamBytes(size), 0, size);
            } catch (IOException exception) {
                throw new UncheckedIOException(exception);
            }
        }

        // Reads exactly size bytes, failing with EOFException if the stream ends early.
        private byte[] readStreamBytes(int size) throws IOException {
            // Bulk reads are not recorded, so a pending mark can no longer be honoured.
            markLength = -1;
            var result = source.readNBytes(size);
            if (result.length < size) {
                throw new EOFException("Unexpected end of stream");
            }
            position += size;
            return result;
        }

        @Override
        public void mark() {
            this.markLength = 0;
        }

        @Override
        public void rewind() {
            if (markLength < 0) {
                throw new InvalidMarkException();
            }
            try {
                source.unread(markBuffer, 0, markLength);
            } catch (IOException exception) {
                throw new UncheckedIOException(exception);
            }
            position -= markLength;
            // The replayed bytes will be recorded again, so the mark stays at the same logical position.
            markLength = 0;
        }

        @Override
        public boolean isFinished() {
            if (length != -1) {
                return position >= length;
            }
            try {
                // Peek one byte without recording it, so any active mark is unaffected.
                var next = source.read();
                if (next == -1) {
                    return true;
                }
                source.unread(next);
                return false;
            } catch (IOException exception) {
                throw new UncheckedIOException(exception);
            }
        }

        @Override
        public Stream subStream(int size) {
            // The child consumes bytes this reader does not record.
            markLength = -1;
            var result = new Stream(source, markBuffer, size);
            position += size;
            return result;
        }
    }

    /**
     * Reader over a range of a byte array. Reads are bounds-checked against this reader's own range,
     * so a sub-stream cannot run into the bytes that follow its payload.
     */
    private static final class Bytes extends ProtobufInputStream {
        private final byte[] buffer;
        private final int offset;
        private final int length;
        // Position relative to offset.
        private int position;
        // Position saved by mark(), or -1 when mark() has not been called.
        private int marker;

        private Bytes(byte[] buffer, int offset, int length) {
            Objects.checkFromIndexSize(offset, length, buffer.length);
            this.buffer = buffer;
            this.offset = offset;
            this.length = length;
            this.marker = -1;
        }

        @Override
        public byte readByte() {
            Objects.checkIndex(position, length);
            return buffer[offset + position++];
        }

        @Override
        public ByteBuffer readBytes(int size) {
            Objects.checkFromIndexSize(position, size, length);
            var result = ByteBuffer.wrap(buffer, offset + position, size);
            position += size;
            return result;
        }

        @Override
        public ProtobufString readString(int size) {
            Objects.checkFromIndexSize(position, size, length);
            var result = ProtobufString.lazy(buffer, offset + position, size);
            position += size;
            return result;
        }

        @Override
        public void mark() {
            this.marker = position;
        }

        @Override
        public void rewind() {
            if (marker == -1) {
                throw new InvalidMarkException();
            }
            this.position = marker;
        }

        @Override
        public boolean isFinished() {
            return position >= length;
        }

        @Override
        public Bytes subStream(int size) {
            Objects.checkFromIndexSize(position, size, length);
            var result = new Bytes(buffer, offset + position, size);
            position += size;
            return result;
        }
    }

    /**
     * Reader over a {@link ByteBuffer}. The end of the readable range is tracked as an absolute buffer
     * position, so rewinding through {@link ByteBuffer#reset()} keeps {@link #isFinished()} correct.
     */
    private static final class Buffer extends ProtobufInputStream {
        private final ByteBuffer buffer;
        // Absolute buffer position at which this reader's range ends.
        private final int end;

        private Buffer(ByteBuffer buffer, int length) {
            this.buffer = buffer;
            this.end = buffer.position() + length;
        }

        @Override
        public byte readByte() {
            return buffer.get();
        }

        @Override
        public ByteBuffer readBytes(int size) {
            var start = buffer.position();
            // slice(index, length): the second argument is a length, not an end index.
            var result = buffer.slice(start, size);
            buffer.position(start + size);
            return result;
        }

        @Override
        public ProtobufString readString(int size) {
            var start = buffer.position();
            var result = buffer.slice(start, size);
            buffer.position(start + size);
            return ProtobufString.lazy(result.asReadOnlyBuffer());
        }

        @Override
        public void mark() {
            buffer.mark();
        }

        @Override
        public void rewind() {
            buffer.reset();
        }

        @Override
        public boolean isFinished() {
            return buffer.position() >= end;
        }

        @Override
        public Buffer subStream(int size) {
            var start = buffer.position();
            // Give the child its own view so that advancing this buffer does not move the child's position.
            var result = new Buffer(buffer.slice(start, size), size);
            buffer.position(start + size);
            return result;
        }
    }
}