// io.trino.orc.stream.LongInputStreamV2 (Maven / Gradle / Ivy)
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.orc.stream;
import io.trino.orc.OrcCorruptionException;
import io.trino.orc.checkpoint.LongStreamCheckpoint;
import io.trino.orc.checkpoint.LongStreamV2Checkpoint;
import java.io.IOException;
import java.io.InputStream;
import static java.lang.Math.min;
import static java.lang.Math.toIntExact;
/**
 * Reads integers encoded with ORC run-length encoding version 2 (RLEv2).
 * <p>
 * The two most significant bits of each run header byte select one of four encodings
 * (short repeat, direct, patched base, delta). A whole run is decoded into an internal
 * buffer, from which the {@code next} and {@code skip} methods serve values.
 * <p>
 * See {@code org.apache.orc.impl.RunLengthIntegerWriterV2} for description of various lightweight compression techniques.
 */
// This comes from the Apache Hive ORC code
public class LongInputStreamV2
        implements LongInputStream
{
    private static final int MIN_REPEAT_SIZE = 3;
    private static final int MAX_LITERAL_SIZE = 512;

    // declaration order matters: ordinal() is compared against the 2-bit encoding tag of the run header
    private enum EncodingType
    {
        SHORT_REPEAT, DIRECT, PATCHED_BASE, DELTA
    }

    private final LongBitPacker packer = new LongBitPacker();
    private final OrcInputStream input;
    private final boolean signed;
    // decoded values of the current run; sized for the largest run the 9-bit length field can describe
    private final long[] literals = new long[MAX_LITERAL_SIZE];
    // number of valid entries in literals
    private int numLiterals;
    // number of entries in literals already handed out to the caller
    private int used;
    private final boolean skipCorrupt;
    // checkpoint of input taken at the start of the run currently held in literals
    private long lastReadInputCheckpoint;

    /**
     * @param input the RLEv2 encoded stream
     * @param signed whether the encoded values are signed (zigzag / signed-vint encoded)
     * @param skipCorrupt when true, patched-base runs whose patch width plus patch gap width
     * exceeds 64 bits are decoded anyway instead of being rejected as corrupt
     */
    public LongInputStreamV2(OrcInputStream input, boolean signed, boolean skipCorrupt)
    {
        this.input = input;
        this.signed = signed;
        this.skipCorrupt = skipCorrupt;
        lastReadInputCheckpoint = input.getCheckpoint();
    }

    // This comes from the Apache Hive ORC code
    /**
     * Decodes the next run from the input and appends its values to {@code literals}.
     */
    private void readValues()
            throws IOException
    {
        // remember where this run starts so seekToCheckpoint can reuse the decoded buffer
        lastReadInputCheckpoint = input.getCheckpoint();

        // read the first 2 bits and determine the encoding type
        int firstByte = readByte(input);
        int enc = (firstByte >>> 6) & 0x03;
        if (EncodingType.SHORT_REPEAT.ordinal() == enc) {
            readShortRepeatValues(firstByte);
        }
        else if (EncodingType.DIRECT.ordinal() == enc) {
            readDirectValues(firstByte);
        }
        else if (EncodingType.PATCHED_BASE.ordinal() == enc) {
            readPatchedBaseValues(firstByte);
        }
        else {
            readDeltaValues(firstByte);
        }
    }

    // This comes from the Apache Hive ORC code
    /**
     * Decodes a DELTA run: a first value followed by either one fixed delta or a bit-packed list of delta magnitudes.
     */
    private void readDeltaValues(int firstByte)
            throws IOException
    {
        // extract the number of fixed bits
        int fixedBits = (firstByte >>> 1) & 0x1f;
        if (fixedBits != 0) {
            fixedBits = LongDecode.decodeBitWidth(fixedBits);
        }

        // extract the blob run length (number of values that follow the first value)
        int length = (firstByte & 0x01) << 8;
        length |= readByte(input);

        // read the first value stored as vint
        long firstVal = LongDecode.readVInt(signed, input);

        // store first value to result buffer
        literals[numLiterals++] = firstVal;

        if (fixedBits == 0) {
            // all values have a fixed delta, stored as a signed vint (deltas can be negative even
            // if all numbers are positive); each value is the previous one plus that delta
            long fixedDelta = LongDecode.readSignedVInt(input);
            for (int i = 0; i < length; i++) {
                literals[numLiterals] = literals[numLiterals - 1] + fixedDelta;
                numLiterals++;
            }
        }
        else {
            long deltaBase = LongDecode.readSignedVInt(input);
            // add delta base and first value
            literals[numLiterals++] = firstVal + deltaBase;
            long prevVal = literals[numLiterals - 1];
            length -= 1;

            // unpack the remaining delta magnitudes in place, then accumulate them onto the previous
            // value: a negative delta base means a decreasing sequence, otherwise an increasing one
            packer.unpack(literals, numLiterals, length, fixedBits, input);
            while (length > 0) {
                if (deltaBase < 0) {
                    literals[numLiterals] = prevVal - literals[numLiterals];
                }
                else {
                    literals[numLiterals] = prevVal + literals[numLiterals];
                }
                prevVal = literals[numLiterals];
                length--;
                numLiterals++;
            }
        }
    }

    // This comes from the Apache Hive ORC code
    /**
     * Decodes a PATCHED_BASE run: bit-packed values offset by a base, where selected positions
     * get extra high bits from a separate patch list.
     *
     * @throws OrcCorruptionException if the patch list is malformed or the stream ends early
     */
    private void readPatchedBaseValues(int firstByte)
            throws IOException
    {
        // extract the number of fixed bits
        int fb = LongDecode.decodeBitWidth((firstByte >>> 1) & 0b1_1111);

        // extract the run length of data blob
        int length = (firstByte & 0b1) << 8;
        length |= readByte(input);
        // runs are always one off
        length += 1;

        // extract the number of bytes occupied by base
        int thirdByte = readByte(input);
        int baseWidth = (thirdByte >>> 5) & 0b0111;
        // base width is one off
        baseWidth += 1;

        // extract patch width
        int patchWidth = LongDecode.decodeBitWidth(thirdByte & 0b1_1111);

        // read fourth byte and extract patch gap width
        int fourthByte = readByte(input);
        int patchGapWidth = (fourthByte >>> 5) & 0b0111;
        // patch gap width is one off
        patchGapWidth += 1;

        // extract the length of the patch list
        int patchListLength = fourthByte & 0b1_1111;

        // read the next base width number of bytes to extract base value
        long base = bytesToLongBE(input, baseWidth);
        long mask = (1L << ((baseWidth * 8) - 1));
        // if MSB of base value is 1 then base is negative value else positive
        if ((base & mask) != 0) {
            base = base & ~mask;
            base = -base;
        }

        // unpack the data blob
        long[] unpacked = new long[length];
        packer.unpack(unpacked, 0, length, fb, input);

        // unpack the patch blob; each entry holds a gap in its high bits and a patch in its low patchWidth bits
        long[] unpackedPatch = new long[patchListLength];
        if ((patchWidth + patchGapWidth) > 64 && !skipCorrupt) {
            throw new OrcCorruptionException(input.getOrcDataSourceId(), "Invalid RLEv2 encoded stream");
        }
        int bitSize = LongDecode.getClosestFixedBits(patchWidth + patchGapWidth);
        packer.unpack(unpackedPatch, 0, patchListLength, bitSize, input);

        // apply the patch directly when decoding the packed data
        long patchMask = ((1L << patchWidth) - 1);
        int patchIndex = 0;
        long currentPatch = 0;
        // index of the next value to patch; -1 never matches, so an empty patch list patches nothing
        long actualGap = -1;
        if (patchListLength > 0) {
            long currentGap = unpackedPatch[patchIndex] >>> patchWidth;
            currentPatch = unpackedPatch[patchIndex] & patchMask;
            actualGap = 0;
            // special case: gap is >255 then patch value will be 0.
            // if gap is <=255 then patch value cannot be 0
            while (currentGap == 255 && currentPatch == 0) {
                actualGap += 255;
                patchIndex++;
                checkPatchIndex(patchIndex, patchListLength);
                currentGap = unpackedPatch[patchIndex] >>> patchWidth;
                currentPatch = unpackedPatch[patchIndex] & patchMask;
            }
            // add the left over gap
            actualGap += currentGap;
        }

        // unpack data blob, patch it (if required), add base to get final result
        for (int i = 0; i < unpacked.length; i++) {
            if (i == actualGap) {
                // extract the patch value
                long patchedValue = unpacked[i] | (currentPatch << fb);
                // add base to patched value
                literals[numLiterals++] = base + patchedValue;
                // increment the patch to point to next entry in patch list
                patchIndex++;
                if (patchIndex < patchListLength) {
                    // read the next gap and patch
                    long currentGap = unpackedPatch[patchIndex] >>> patchWidth;
                    currentPatch = unpackedPatch[patchIndex] & patchMask;
                    actualGap = 0;
                    // special case: gap is >255 then patch will be 0. if gap is
                    // <=255 then patch cannot be 0
                    while (currentGap == 255 && currentPatch == 0) {
                        actualGap += 255;
                        patchIndex++;
                        checkPatchIndex(patchIndex, patchListLength);
                        currentGap = unpackedPatch[patchIndex] >>> patchWidth;
                        currentPatch = unpackedPatch[patchIndex] & patchMask;
                    }
                    // add the left over gap
                    actualGap += currentGap;
                    // next gap is relative to the current gap
                    actualGap += i;
                }
            }
            else {
                // no patching required. add base to unpacked value to get final value
                literals[numLiterals++] = base + unpacked[i];
            }
        }
    }

    /**
     * Fails when a run of 255-gap filler entries would step past the end of the patch list.
     */
    private void checkPatchIndex(int patchIndex, int patchListLength)
            throws IOException
    {
        if (patchIndex >= patchListLength) {
            throw new OrcCorruptionException(input.getOrcDataSourceId(), "Invalid RLEv2 encoded stream");
        }
    }

    // This comes from the Apache Hive ORC code
    /**
     * Decodes a DIRECT run: {@code length} bit-packed values, zigzag-decoded when the stream is signed.
     */
    private void readDirectValues(int firstByte)
            throws IOException
    {
        // extract the number of fixed bits
        int fixedBits = LongDecode.decodeBitWidth((firstByte >>> 1) & 0b1_1111);

        // extract the run length
        int length = (firstByte & 0b1) << 8;
        length |= readByte(input);
        // runs are one off
        length += 1;

        // write the unpacked values and zigzag decode to result buffer
        packer.unpack(literals, numLiterals, length, fixedBits, input);
        if (signed) {
            for (int i = 0; i < length; i++) {
                literals[numLiterals] = LongDecode.zigzagDecode(literals[numLiterals]);
                numLiterals++;
            }
        }
        else {
            numLiterals += length;
        }
    }

    // This comes from the Apache Hive ORC code
    /**
     * Decodes a SHORT_REPEAT run: one fixed-width big-endian value repeated 3 to 10 times.
     */
    private void readShortRepeatValues(int firstByte)
            throws IOException
    {
        // read the number of bytes occupied by the value
        int size = (firstByte >>> 3) & 0b0111;
        // #bytes are one off
        size += 1;

        // read the run length
        int length = firstByte & 0x07;
        // run lengths values are stored only after MIN_REPEAT value is met
        length += MIN_REPEAT_SIZE;

        // read the repeated value which is stored using fixed bytes
        long val = bytesToLongBE(input, size);
        if (signed) {
            val = LongDecode.zigzagDecode(val);
        }

        // repeat the value for length times
        for (int i = 0; i < length; i++) {
            literals[numLiterals++] = val;
        }
    }

    /**
     * Reads {@code n} bytes from {@code source} in big-endian order and combines them into a long.
     *
     * @throws OrcCorruptionException if the stream ends before {@code n} bytes were read
     */
    private long bytesToLongBE(InputStream source, int n)
            throws IOException
    {
        long out = 0;
        for (int shift = (n - 1) * 8; shift >= 0; shift -= 8) {
            // widen to long before shifting, otherwise shifts of 32 bits or more would wrap around
            out |= ((long) readByte(source)) << shift;
        }
        return out;
    }

    /**
     * Reads one unsigned byte from {@code source}.
     *
     * @throws OrcCorruptionException at end of stream, instead of letting {@code -1} leak into decoded values
     */
    private int readByte(InputStream source)
            throws IOException
    {
        int value = source.read();
        if (value < 0) {
            throw new OrcCorruptionException(input.getOrcDataSourceId(), "Read past end of RLE integer");
        }
        return value;
    }

    /**
     * Decodes the next run into the buffer once every buffered value has been consumed.
     */
    private void refillIfExhausted()
            throws IOException
    {
        if (used == numLiterals) {
            numLiterals = 0;
            used = 0;
            readValues();
        }
    }

    /**
     * Returns the next decoded value.
     */
    @Override
    public long next()
            throws IOException
    {
        refillIfExhausted();
        return literals[used++];
    }

    /**
     * Copies the next {@code items} decoded values into {@code values}, starting at index 0.
     */
    @Override
    public void next(long[] values, int items)
            throws IOException
    {
        int offset = 0;
        while (items > 0) {
            refillIfExhausted();
            int chunkSize = min(numLiterals - used, items);
            System.arraycopy(literals, used, values, offset, chunkSize);
            used += chunkSize;
            offset += chunkSize;
            items -= chunkSize;
        }
    }

    /**
     * Copies the next {@code items} decoded values into {@code values}, starting at index 0.
     *
     * @throws OrcCorruptionException if a decoded value does not fit in an {@code int}
     */
    @Override
    public void next(int[] values, int items)
            throws IOException
    {
        int offset = 0;
        while (items > 0) {
            refillIfExhausted();
            int chunkSize = min(numLiterals - used, items);
            for (int i = 0; i < chunkSize; i++) {
                long literal = literals[used + i];
                int value = (int) literal;
                if (literal != value) {
                    throw new OrcCorruptionException(input.getOrcDataSourceId(), "Decoded value out of range for a 32bit number");
                }
                values[offset + i] = value;
            }
            used += chunkSize;
            offset += chunkSize;
            items -= chunkSize;
        }
    }

    /**
     * Copies the next {@code items} decoded values into {@code values}, starting at index 0.
     *
     * @throws OrcCorruptionException if a decoded value does not fit in a {@code short}
     */
    @Override
    public void next(short[] values, int items)
            throws IOException
    {
        int offset = 0;
        while (items > 0) {
            refillIfExhausted();
            int chunkSize = min(numLiterals - used, items);
            for (int i = 0; i < chunkSize; i++) {
                long literal = literals[used + i];
                short value = (short) literal;
                if (literal != value) {
                    throw new OrcCorruptionException(input.getOrcDataSourceId(), "Decoded value out of range for a 16bit number");
                }
                values[offset + i] = value;
            }
            used += chunkSize;
            offset += chunkSize;
            items -= chunkSize;
        }
    }

    /**
     * Positions the reader at {@code checkpoint}, which must be a {@link LongStreamV2Checkpoint}.
     * When the checkpoint refers to the run currently buffered, only the read position inside
     * the buffer moves; otherwise the input is repositioned and the offset is skipped.
     */
    @Override
    public void seekToCheckpoint(LongStreamCheckpoint checkpoint)
            throws IOException
    {
        LongStreamV2Checkpoint v2Checkpoint = (LongStreamV2Checkpoint) checkpoint;

        // if the checkpoint is within the current buffer, just adjust the pointer
        if (lastReadInputCheckpoint == v2Checkpoint.getInputStreamCheckpoint() && v2Checkpoint.getOffset() <= numLiterals) {
            used = v2Checkpoint.getOffset();
        }
        else {
            // otherwise, discard the buffer and start over
            input.seekToCheckpoint(v2Checkpoint.getInputStreamCheckpoint());
            numLiterals = 0;
            used = 0;
            skip(v2Checkpoint.getOffset());
        }
    }

    /**
     * Discards the next {@code items} decoded values.
     */
    @Override
    public void skip(long items)
            throws IOException
    {
        while (items > 0) {
            refillIfExhausted();
            int consume = toIntExact(min(items, numLiterals - used));
            used += consume;
            items -= consume;
        }
    }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy