All downloads are free. Search and download functionality uses the official Maven repository.

io.trino.orc.stream.LongInputStreamV2 Maven / Gradle / Ivy

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.orc.stream;

import io.trino.orc.OrcCorruptionException;
import io.trino.orc.checkpoint.LongStreamCheckpoint;
import io.trino.orc.checkpoint.LongStreamV2Checkpoint;

import java.io.IOException;
import java.io.InputStream;

import static java.lang.Math.min;
import static java.lang.Math.toIntExact;

/**
 * Decodes an ORC integer stream written with run length encoding version 2.
 * <p>
 * Every run starts with a header byte whose two highest bits select one of four
 * sub-encodings (short repeat, direct, patched base, delta). A complete run is decoded
 * into an internal buffer of at most {@code MAX_LITERAL_SIZE} values, and the
 * {@code next} and {@code skip} methods consume values from that buffer, decoding the
 * following run whenever the buffer is exhausted.
 * <p>
 * See {@code org.apache.orc.impl.RunLengthIntegerWriterV2} for description of various lightweight compression techniques.
 */
// This comes from the Apache Hive ORC code
public class LongInputStreamV2
        implements LongInputStream
{
    private static final int MIN_REPEAT_SIZE = 3;
    private static final int MAX_LITERAL_SIZE = 512;

    private enum EncodingType
    {
        SHORT_REPEAT, DIRECT, PATCHED_BASE, DELTA
    }

    // Indexed by the two bit encoding id taken from the run header, so the declaration
    // order of EncodingType must match the ids used in the stream (0..3)
    private static final EncodingType[] ENCODING_TYPES = EncodingType.values();

    private final LongBitPacker packer = new LongBitPacker();
    private final OrcInputStream input;
    private final boolean signed;
    // decoded values of the current run; entries [used, numLiterals) are not consumed yet
    private final long[] literals = new long[MAX_LITERAL_SIZE];
    private int numLiterals;
    private int used;
    private final boolean skipCorrupt;
    // input position at which the run currently held in the buffer started
    private long lastReadInputCheckpoint;

    /**
     * @param input stream the encoded runs are read from
     * @param signed whether values are signed (zigzag / signed vint encoded) in the stream
     * @param skipCorrupt when true, a patched base run whose combined patch and gap width
     * exceeds 64 bits is decoded instead of being rejected
     */
    public LongInputStreamV2(OrcInputStream input, boolean signed, boolean skipCorrupt)
    {
        this.input = input;
        this.signed = signed;
        this.skipCorrupt = skipCorrupt;
        lastReadInputCheckpoint = input.getCheckpoint();
    }

    /**
     * Decodes the next run from the input and appends its values to the buffer.
     *
     * @throws OrcCorruptionException if the stream ends before the run header was read
     */
    // This comes from the Apache Hive ORC code
    private void readValues()
            throws IOException
    {
        lastReadInputCheckpoint = input.getCheckpoint();

        // the two highest bits of the first byte select the encoding type
        int firstByte = readUnsignedByte();
        switch (ENCODING_TYPES[(firstByte >>> 6) & 0x03]) {
            case SHORT_REPEAT:
                readShortRepeatValues(firstByte);
                break;
            case DIRECT:
                readDirectValues(firstByte);
                break;
            case PATCHED_BASE:
                readPatchedBaseValues(firstByte);
                break;
            case DELTA:
                readDeltaValues(firstByte);
                break;
        }
    }

    /**
     * Reads a single byte, turning end of stream into a corruption error so that a
     * truncated run can never be mistaken for header bits.
     *
     * @return the byte read, in the range 0 to 255
     * @throws OrcCorruptionException if the end of the stream has been reached
     */
    private int readUnsignedByte()
            throws IOException
    {
        int value = input.read();
        if (value < 0) {
            throw new OrcCorruptionException(input.getOrcDataSourceId(), "Read past end of RLE integer");
        }
        return value;
    }

    /**
     * Reads the raw nine bit length field of a run header: the lowest bit of the first
     * header byte is the high bit, the following byte supplies the low eight bits.
     */
    private int readRunLength(int firstByte)
            throws IOException
    {
        return ((firstByte & 0b1) << 8) | readUnsignedByte();
    }

    /**
     * Fails with a corruption error unless {@code patchIndex} addresses an entry of a
     * patch list holding {@code patchListLength} entries.
     */
    private void verifyPatchAvailable(int patchIndex, int patchListLength)
            throws IOException
    {
        if (patchIndex >= patchListLength) {
            throw new OrcCorruptionException(input.getOrcDataSourceId(), "Invalid RLEv2 encoded stream");
        }
    }

    /**
     * Decodes a delta encoded run: a first value followed either by a constant delta or
     * by a delta base plus bit packed deltas.
     *
     * @param firstByte header byte of the run
     */
    // This comes from the Apache Hive ORC code
    private void readDeltaValues(int firstByte)
            throws IOException
    {
        // extract the number of fixed bits; an encoded width of 0 marks a fixed delta run
        int fixedBits = (firstByte >>> 1) & 0x1f;
        if (fixedBits != 0) {
            fixedBits = LongDecode.decodeBitWidth(fixedBits);
        }

        // extract the blob run length
        int length = readRunLength(firstByte);

        // read the first value stored as vint and store it to the result buffer
        long firstVal = LongDecode.readVInt(signed, input);
        literals[numLiterals++] = firstVal;

        if (fixedBits == 0) {
            // read the fixed delta value stored as vint (deltas can be negative even
            // if all numbers are positive)
            long fixedDelta = LongDecode.readSignedVInt(input);

            // every value is its predecessor plus the fixed delta
            for (int i = 0; i < length; i++) {
                literals[numLiterals] = literals[numLiterals - 1] + fixedDelta;
                numLiterals++;
            }
        }
        else {
            long deltaBase = LongDecode.readSignedVInt(input);
            // the second value is the first value plus the delta base
            long prevVal = firstVal + deltaBase;
            literals[numLiterals++] = prevVal;
            length -= 1;

            // unpack the remaining deltas in place, then turn each one into a value by
            // applying it to its predecessor. a negative delta base marks a decreasing
            // sequence, otherwise the sequence is increasing
            packer.unpack(literals, numLiterals, length, fixedBits, input);
            boolean decreasing = deltaBase < 0;
            int end = numLiterals + length;
            while (numLiterals < end) {
                long delta = literals[numLiterals];
                prevVal = decreasing ? prevVal - delta : prevVal + delta;
                literals[numLiterals++] = prevVal;
            }
        }
    }

    /**
     * Decodes a patched base run: bit packed values relative to a base, where the entries
     * named by the patch list get additional high bits before the base is added.
     *
     * @param firstByte header byte of the run
     * @throws OrcCorruptionException if the header is truncated, if the combined patch and
     * gap width exceeds 64 bits (unless corrupt data is skipped), or if the patch list is
     * empty or ends inside a sequence of gap extension entries
     */
    // This comes from the Apache Hive ORC code
    private void readPatchedBaseValues(int firstByte)
            throws IOException
    {
        // extract the number of fixed bits
        int fb = LongDecode.decodeBitWidth((firstByte >>> 1) & 0b1_1111);

        // extract the run length of data blob; runs are always one off
        int length = readRunLength(firstByte) + 1;

        // third byte: number of bytes occupied by base (one off) and the patch width
        int thirdByte = readUnsignedByte();
        int baseWidth = ((thirdByte >>> 5) & 0b0111) + 1;
        int patchWidth = LongDecode.decodeBitWidth(thirdByte & 0b1_1111);

        // fourth byte: patch gap width (one off) and the length of the patch list
        int fourthByte = readUnsignedByte();
        int patchGapWidth = ((fourthByte >>> 5) & 0b0111) + 1;
        int patchListLength = fourthByte & 0b1_1111;

        // read the next base width number of bytes to extract base value
        long base = bytesToLongBE(input, baseWidth);
        long mask = (1L << ((baseWidth * 8) - 1));
        // if MSB of base value is 1 then base is negative value else positive
        if ((base & mask) != 0) {
            base = -(base & ~mask);
        }

        // unpack the data blob
        long[] unpacked = new long[length];
        packer.unpack(unpacked, 0, length, fb, input);

        if ((patchWidth + patchGapWidth) > 64 && !skipCorrupt) {
            throw new OrcCorruptionException(input.getOrcDataSourceId(), "Invalid RLEv2 encoded stream");
        }
        // the first patch entry is read unconditionally below, so the list must not be empty
        verifyPatchAvailable(0, patchListLength);

        // unpack the patch blob
        long[] unpackedPatch = new long[patchListLength];
        int bitSize = LongDecode.getClosestFixedBits(patchWidth + patchGapWidth);
        packer.unpack(unpackedPatch, 0, patchListLength, bitSize, input);

        // apply the patch directly when decoding the packed data
        long patchMask = ((1L << patchWidth) - 1);
        int patchIndex = 0;
        long currentGap = unpackedPatch[patchIndex] >>> patchWidth;
        long currentPatch = unpackedPatch[patchIndex] & patchMask;
        long actualGap = 0;

        // special case: gap is >255 then patch value will be 0.
        // if gap is <=255 then patch value cannot be 0
        while (currentGap == 255 && currentPatch == 0) {
            actualGap += 255;
            patchIndex++;
            verifyPatchAvailable(patchIndex, patchListLength);
            currentGap = unpackedPatch[patchIndex] >>> patchWidth;
            currentPatch = unpackedPatch[patchIndex] & patchMask;
        }
        // add the left over gap
        actualGap += currentGap;

        // unpack data blob, patch it (if required), add base to get final result
        for (int i = 0; i < unpacked.length; i++) {
            if (i != actualGap) {
                // no patching required. add base to unpacked value to get final value
                literals[numLiterals++] = base + unpacked[i];
                continue;
            }

            // the patch supplies the bits above the fixed width; add base to the patched value
            long patchedValue = unpacked[i] | (currentPatch << fb);
            literals[numLiterals++] = base + patchedValue;

            // advance to the next entry in the patch list, if there is one
            patchIndex++;
            if (patchIndex < patchListLength) {
                // read the next gap and patch
                currentGap = unpackedPatch[patchIndex] >>> patchWidth;
                currentPatch = unpackedPatch[patchIndex] & patchMask;
                actualGap = 0;

                // special case: gap is >255 then patch will be 0. if gap is
                // <=255 then patch cannot be 0
                while (currentGap == 255 && currentPatch == 0) {
                    actualGap += 255;
                    patchIndex++;
                    verifyPatchAvailable(patchIndex, patchListLength);
                    currentGap = unpackedPatch[patchIndex] >>> patchWidth;
                    currentPatch = unpackedPatch[patchIndex] & patchMask;
                }
                // add the left over gap
                actualGap += currentGap;

                // next gap is relative to the current gap
                actualGap += i;
            }
        }
    }

    /**
     * Decodes a direct run: bit packed values of a fixed width, zigzag decoded when the
     * stream is signed.
     *
     * @param firstByte header byte of the run
     */
    // This comes from the Apache Hive ORC code
    private void readDirectValues(int firstByte)
            throws IOException
    {
        // extract the number of fixed bits
        int fixedBits = LongDecode.decodeBitWidth((firstByte >>> 1) & 0b1_1111);

        // extract the run length; runs are one off
        int length = readRunLength(firstByte) + 1;

        // unpack straight into the result buffer, then zigzag decode in place if signed
        packer.unpack(literals, numLiterals, length, fixedBits, input);
        if (signed) {
            int end = numLiterals + length;
            for (int i = numLiterals; i < end; i++) {
                literals[i] = LongDecode.zigzagDecode(literals[i]);
            }
        }
        numLiterals += length;
    }

    /**
     * Decodes a short repeat run: a single value stored in one to eight bytes, repeated
     * {@code MIN_REPEAT_SIZE} to {@code MIN_REPEAT_SIZE + 7} times.
     *
     * @param firstByte header byte of the run
     */
    // This comes from the Apache Hive ORC code
    private void readShortRepeatValues(int firstByte)
            throws IOException
    {
        // read the number of bytes occupied by the value; #bytes are one off
        int size = ((firstByte >>> 3) & 0b0111) + 1;

        // read the run length; run lengths are stored only after MIN_REPEAT value is met
        int length = (firstByte & 0x07) + MIN_REPEAT_SIZE;

        // read the repeated value which is stored using fixed bytes
        long val = bytesToLongBE(input, size);
        if (signed) {
            val = LongDecode.zigzagDecode(val);
        }

        // repeat the value for length times
        for (int i = 0; i < length; i++) {
            literals[numLiterals++] = val;
        }
    }

    /**
     * Read n bytes in big endian order and convert to long.
     *
     * @param stream stream to read from; callers pass this reader's input, whose data
     * source is named in the error raised on truncation
     * @param n number of bytes to read; callers pass values between 1 and 8
     * @throws OrcCorruptionException if the stream ends before {@code n} bytes were read
     */
    private long bytesToLongBE(InputStream stream, int n)
            throws IOException
    {
        long out = 0;
        for (int remaining = n; remaining > 0; remaining--) {
            int next = stream.read();
            if (next < 0) {
                // without this check the -1 would be folded into the value as if it were data
                throw new OrcCorruptionException(input.getOrcDataSourceId(), "Read past end of RLE integer");
            }
            // accumulate in a long so that shifting cannot overflow an int
            out = (out << 8) | next;
        }
        return out;
    }

    /**
     * Decodes the next run into the buffer if every buffered value has been consumed.
     */
    private void refillIfExhausted()
            throws IOException
    {
        if (used == numLiterals) {
            numLiterals = 0;
            used = 0;
            readValues();
        }
    }

    /**
     * Returns the next value of the stream.
     */
    @Override
    public long next()
            throws IOException
    {
        refillIfExhausted();
        return literals[used++];
    }

    /**
     * Copies the next {@code items} values of the stream into {@code values}, starting at index 0.
     */
    @Override
    public void next(long[] values, int items)
            throws IOException
    {
        int offset = 0;
        while (items > 0) {
            refillIfExhausted();

            int chunkSize = min(numLiterals - used, items);
            System.arraycopy(literals, used, values, offset, chunkSize);
            used += chunkSize;
            offset += chunkSize;
            items -= chunkSize;
        }
    }

    /**
     * Copies the next {@code items} values of the stream into {@code values}, starting at index 0.
     *
     * @throws OrcCorruptionException if a decoded value does not fit into an int
     */
    @Override
    public void next(int[] values, int items)
            throws IOException
    {
        int offset = 0;
        while (items > 0) {
            refillIfExhausted();

            int chunkSize = min(numLiterals - used, items);
            for (int i = 0; i < chunkSize; i++) {
                long literal = literals[used + i];
                int value = (int) literal;
                // a narrowing cast that changes the value means the data is out of range
                if (literal != value) {
                    throw new OrcCorruptionException(input.getOrcDataSourceId(), "Decoded value out of range for a 32bit number");
                }
                values[offset + i] = value;
            }
            used += chunkSize;
            offset += chunkSize;
            items -= chunkSize;
        }
    }

    /**
     * Copies the next {@code items} values of the stream into {@code values}, starting at index 0.
     *
     * @throws OrcCorruptionException if a decoded value does not fit into a short
     */
    @Override
    public void next(short[] values, int items)
            throws IOException
    {
        int offset = 0;
        while (items > 0) {
            refillIfExhausted();

            int chunkSize = min(numLiterals - used, items);
            for (int i = 0; i < chunkSize; i++) {
                long literal = literals[used + i];
                short value = (short) literal;
                // a narrowing cast that changes the value means the data is out of range
                if (literal != value) {
                    throw new OrcCorruptionException(input.getOrcDataSourceId(), "Decoded value out of range for a 16bit number");
                }
                values[offset + i] = value;
            }
            used += chunkSize;
            offset += chunkSize;
            items -= chunkSize;
        }
    }

    /**
     * Positions the stream at the given checkpoint, which must be a {@link LongStreamV2Checkpoint}.
     * When the checkpoint refers to the run that is currently buffered only the read position
     * inside the buffer is moved; otherwise the input is repositioned and values are skipped
     * up to the checkpoint offset.
     */
    @Override
    public void seekToCheckpoint(LongStreamCheckpoint checkpoint)
            throws IOException
    {
        LongStreamV2Checkpoint v2Checkpoint = (LongStreamV2Checkpoint) checkpoint;

        // if the checkpoint is within the current buffer, just adjust the pointer
        if (lastReadInputCheckpoint == v2Checkpoint.getInputStreamCheckpoint() && v2Checkpoint.getOffset() <= numLiterals) {
            used = v2Checkpoint.getOffset();
        }
        else {
            // otherwise, discard the buffer and start over
            input.seekToCheckpoint(v2Checkpoint.getInputStreamCheckpoint());
            numLiterals = 0;
            used = 0;
            skip(v2Checkpoint.getOffset());
        }
    }

    /**
     * Discards the next {@code items} values of the stream.
     */
    @Override
    public void skip(long items)
            throws IOException
    {
        while (items > 0) {
            refillIfExhausted();
            int consume = toIntExact(min(items, numLiterals - used));
            used += consume;
            items -= consume;
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy