io.trino.orc.stream.DecimalInputStream Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of trino-orc Show documentation
Show all versions of trino-orc Show documentation
Trino - ORC file format support
The newest version!
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.orc.stream;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.orc.OrcCorruptionException;
import io.trino.orc.checkpoint.DecimalStreamCheckpoint;
import java.io.IOException;
import static com.google.common.base.Verify.verify;
import static io.trino.orc.checkpoint.InputStreamCheckpoint.decodeCompressedBlockOffset;
import static io.trino.orc.checkpoint.InputStreamCheckpoint.decodeDecompressedOffset;
/**
 * Decodes ORC decimal values stored as unbounded base-128 varints.
 * <p>
 * Every encoded byte contributes its low 7 bits to the value, least-significant group first, and a
 * set high bit ({@code 0x80}) means that another byte of the same value follows. The lowest payload
 * bit is the sign of a zig-zag encoding (see {@link #emitLongDecimal}).
 * <p>
 * Values are decoded with a fast path that reads whole 8-byte (and trailing 4-byte) words while
 * enough bytes remain in the current block, and a byte-at-a-time tail path that also handles
 * values straddling two chunks.
 */
public class DecimalInputStream
        implements ValueInputStream<DecimalStreamCheckpoint>
{
    // the continuation bit (0x80) of every byte in a long / int word
    private static final long LONG_MASK = 0x80_80_80_80_80_80_80_80L;
    private static final int INT_MASK = 0x80_80_80_80;

    private final OrcChunkLoader chunkLoader;

    private Slice block = Slices.EMPTY_SLICE; // reference to current (decoded) data block. Can be either a reference to the buffer or to current chunk
    private int blockOffset; // position within current data block
    private long lastCheckpoint; // checkpoint recorded when the current block was loaded or the stream was last repositioned

    public DecimalInputStream(OrcChunkLoader chunkLoader)
    {
        this.chunkLoader = chunkLoader;
    }

    /**
     * Repositions the stream at the given checkpoint.
     * <p>
     * When the checkpoint lies inside the block that is already loaded, only the position within
     * that block changes. Otherwise the chunk loader is repositioned and the current block is
     * discarded, so the next read loads a fresh chunk.
     */
    @Override
    public void seekToCheckpoint(DecimalStreamCheckpoint checkpoint)
            throws IOException
    {
        long newCheckpoint = checkpoint.getInputStreamCheckpoint();

        // if checkpoint starts at the same compressed position...
        // (we're checking for an empty block because empty blocks signify that we possibly read all the data in the existing
        // buffer, so last checkpoint is no longer valid)
        if (block.length() > 0 && decodeCompressedBlockOffset(newCheckpoint) == decodeCompressedBlockOffset(lastCheckpoint)) {
            // and decompressed position is within our block, reposition in the block directly
            int offsetInBlock = decodeDecompressedOffset(newCheckpoint) - decodeDecompressedOffset(lastCheckpoint);
            if (offsetInBlock >= 0 && offsetInBlock < block.length()) {
                blockOffset = offsetInBlock;
                // do not change last checkpoint because we have not moved positions
                return;
            }
        }

        chunkLoader.seekToCheckpoint(newCheckpoint);
        lastCheckpoint = newCheckpoint;
        block = Slices.EMPTY_SLICE;
        blockOffset = 0;
    }

    /**
     * Decodes {@code batchSize} 128-bit decimals into {@code result}.
     * <p>
     * Value {@code i} is stored as {@code result[2 * i]} (upper 64 bits) and
     * {@code result[2 * i + 1]} (lower 64 bits); result must have at least batchSize * 2 capacity.
     *
     * @throws OrcCorruptionException if an encoded value needs more than 128 bits
     */
    @SuppressWarnings("PointlessBitwiseExpression")
    public void nextLongDecimal(long[] result, int batchSize)
            throws IOException
    {
        verify(result.length >= batchSize * 2);

        int count = 0;
        while (count < batchSize) {
            if (blockOffset == block.length()) {
                advance();
            }

            while (blockOffset <= block.length() - 20) { // we'll read 2 longs + 1 int
                long low;
                long middle = 0;
                int high = 0;

                // low bits
                long current = block.getLong(blockOffset);
                // ~current & LONG_MASK marks every byte whose continuation bit is clear; the trailing
                // zero count locates the first such byte, giving the value's byte length (8 if none)
                int zeros = Long.numberOfTrailingZeros(~current & LONG_MASK);
                int end = (zeros + 1) / 8;
                blockOffset += end;

                boolean negative = (current & 1) == 1;
                // compact the 7-bit payload groups: byte k moves from bit 8k down to bit 7k
                low = (current & 0x7F_00_00_00_00_00_00_00L) >>> 7;
                low |= (current & 0x7F_00_00_00_00_00_00L) >>> 6;
                low |= (current & 0x7F_00_00_00_00_00L) >>> 5;
                low |= (current & 0x7F_00_00_00_00L) >>> 4;
                low |= (current & 0x7F_00_00_00) >>> 3;
                low |= (current & 0x7F_00_00) >>> 2;
                low |= (current & 0x7F_00) >>> 1;
                low |= (current & 0x7F) >>> 0;
                // drop payload bits belonging to bytes after the end of this value
                low = low & ((1L << (end * 7)) - 1);

                // middle bits (only when all 8 low bytes carried a continuation bit)
                if (zeros == 64) {
                    current = block.getLong(blockOffset);
                    zeros = Long.numberOfTrailingZeros(~current & LONG_MASK);
                    end = (zeros + 1) / 8;
                    blockOffset += end;

                    middle = (current & 0x7F_00_00_00_00_00_00_00L) >>> 7;
                    middle |= (current & 0x7F_00_00_00_00_00_00L) >>> 6;
                    middle |= (current & 0x7F_00_00_00_00_00L) >>> 5;
                    middle |= (current & 0x7F_00_00_00_00L) >>> 4;
                    middle |= (current & 0x7F_00_00_00) >>> 3;
                    middle |= (current & 0x7F_00_00) >>> 2;
                    middle |= (current & 0x7F_00) >>> 1;
                    middle |= (current & 0x7F) >>> 0;
                    middle = middle & ((1L << (end * 7)) - 1);

                    // high bits (at most 3 bytes are extracted; a 4th byte means the value is too wide)
                    if (zeros == 64) {
                        int last = block.getInt(blockOffset);
                        zeros = Integer.numberOfTrailingZeros(~last & INT_MASK);
                        end = (zeros + 1) / 8;
                        blockOffset += end;

                        high = (last & 0x7F_00_00) >>> 2;
                        high |= (last & 0x7F_00) >>> 1;
                        high |= (last & 0x7F) >>> 0;

                        high = high & ((1 << (end * 7)) - 1);
                        if (end == 4 || high > 0xFF_FF) { // only 127 - (55 + 56) = 16 bits allowed in high
                            throw new OrcCorruptionException(chunkLoader.getOrcDataSourceId(), "Decimal exceeds 128 bits");
                        }
                    }
                }

                emitLongDecimal(result, count, low, middle, high, negative);

                count++;
                if (count == batchSize) {
                    return;
                }
            }

            // handle the tail of the current block
            count = decodeLongDecimalTail(result, count, batchSize);
        }
    }

    /**
     * Decodes 128-bit decimals one byte at a time near the end of the current block, loading the
     * next chunk when a value continues past it.
     * <p>
     * Stops after a value that ends exactly at the end of the block, after completing a value that
     * straddled a chunk boundary, or once {@code batchSize} values have been produced.
     *
     * @return the updated number of values written to {@code result}
     * @throws OrcCorruptionException if an encoded value needs more than 128 bits
     */
    private int decodeLongDecimalTail(long[] result, int count, int batchSize)
            throws IOException
    {
        boolean negative = false;
        long low = 0;
        long middle = 0;
        int high = 0;

        long value;
        boolean last = false;

        if (blockOffset == block.length()) {
            advance();
        }

        int offset = 0;
        while (true) {
            value = block.getByte(blockOffset);
            blockOffset++;

            // bytes 0-7 feed low, 8-15 feed middle, 16-18 feed high
            if (offset == 0) {
                negative = (value & 1) == 1;
                low |= (value & 0x7F);
            }
            else if (offset < 8) {
                low |= (value & 0x7F) << (offset * 7);
            }
            else if (offset < 16) {
                middle |= (value & 0x7F) << ((offset - 8) * 7);
            }
            else if (offset < 19) {
                high = (int) (high | (value & 0x7F) << ((offset - 16) * 7));
            }
            else {
                throw new OrcCorruptionException(chunkLoader.getOrcDataSourceId(), "Decimal exceeds 128 bits");
            }

            offset++;

            if ((value & 0x80) == 0) {
                if (high > 0xFF_FF) { // only 127 - (55 + 56) = 16 bits allowed in high
                    throw new OrcCorruptionException(chunkLoader.getOrcDataSourceId(), "Decimal exceeds 128 bits");
                }

                emitLongDecimal(result, count, low, middle, high, negative);
                count++;

                low = 0;
                middle = 0;
                high = 0;
                offset = 0;

                if (blockOffset == block.length()) {
                    // the last value aligns with the end of the block, so just
                    // reset the block and loop around to optimized decoding
                    break;
                }

                if (last || count == batchSize) {
                    break;
                }
            }
            else if (blockOffset == block.length()) {
                // the value continues into the next chunk
                last = true;
                advance();
            }
        }
        return count;
    }

    /**
     * Assembles a 128-bit value from its 56-bit low group (still carrying the sign bit), 56-bit
     * middle group and 16-bit high group, undoes the zig-zag sign, and stores it as
     * {@code result[2 * offset]} (upper 64 bits) and {@code result[2 * offset + 1]} (lower 64 bits).
     */
    private static void emitLongDecimal(long[] result, int offset, long low, long middle, long high, boolean negative)
    {
        long lower = (low >>> 1) | (middle << 55); // drop the sign bit from low
        long upper = (middle >>> 9) | (high << 47);

        if (negative) {
            // ORC encodes decimals using a zig-zag vint strategy
            // For negative values, the encoded value is given by:
            //     encoded = -value * 2 - 1
            //
            // Therefore,
            //     value = -(encoded + 1) / 2
            //           = -encoded / 2 - 1/2
            //
            // Given the identity -v = ~v + 1 for negating a value using
            // two's complement representation,
            //
            //     value = (~encoded + 1) / 2 - 1/2
            //           = ~encoded / 2 + 1/2 - 1/2
            //           = ~encoded / 2
            //
            // The shift is performed above as the bits are assembled. The negation
            // is performed here.
            lower = ~lower;
            upper = ~upper;
        }

        result[2 * offset] = upper;
        result[2 * offset + 1] = lower;
    }

    /**
     * Decodes {@code batchSize} decimals that fit in a signed 64-bit long into
     * {@code result[0..batchSize)}.
     *
     * @throws OrcCorruptionException if an encoded value does not fit in a long
     */
    @SuppressWarnings("PointlessBitwiseExpression")
    public void nextShortDecimal(long[] result, int batchSize)
            throws IOException
    {
        verify(result.length >= batchSize);

        int count = 0;
        while (count < batchSize) {
            if (blockOffset == block.length()) {
                advance();
            }

            while (blockOffset <= block.length() - 12) { // we'll read 1 longs + 1 int
                long low;
                int high = 0;

                // low bits
                long current = block.getLong(blockOffset);
                // locate the first byte with a clear continuation bit (8 if none of the 8 bytes has one)
                int zeros = Long.numberOfTrailingZeros(~current & LONG_MASK);
                int end = (zeros + 1) / 8;
                blockOffset += end;

                // compact the 7-bit payload groups: byte k moves from bit 8k down to bit 7k
                low = (current & 0x7F_00_00_00_00_00_00_00L) >>> 7;
                low |= (current & 0x7F_00_00_00_00_00_00L) >>> 6;
                low |= (current & 0x7F_00_00_00_00_00L) >>> 5;
                low |= (current & 0x7F_00_00_00_00L) >>> 4;
                low |= (current & 0x7F_00_00_00) >>> 3;
                low |= (current & 0x7F_00_00) >>> 2;
                low |= (current & 0x7F_00) >>> 1;
                low |= (current & 0x7F) >>> 0;
                // drop payload bits belonging to bytes after the end of this value
                low = low & ((1L << (end * 7)) - 1);

                // high bits (only 2 bytes are extracted, so a 3rd or 4th byte is rejected)
                if (zeros == 64) {
                    int last = block.getInt(blockOffset);
                    zeros = Integer.numberOfTrailingZeros(~last & INT_MASK);
                    end = (zeros + 1) / 8;
                    blockOffset += end;

                    high = (last & 0x7F_00) >>> 1;
                    high |= (last & 0x7F) >>> 0;

                    high = high & ((1 << (end * 7)) - 1);
                    if (end >= 3 || high > 0xFF) { // only 63 - (55) = 8 bits allowed in high
                        throw new OrcCorruptionException(chunkLoader.getOrcDataSourceId(), "Decimal does not fit long (invalid table schema?)");
                    }
                }

                emitShortDecimal(result, count, low, high);

                count++;
                if (count == batchSize) {
                    return;
                }
            }

            // handle the tail of the current block
            count = decodeShortDecimalTail(result, count, batchSize);
        }
    }

    /**
     * Decodes 64-bit decimals one byte at a time near the end of the current block, loading the
     * next chunk when a value continues past it.
     * <p>
     * Stops after a value that ends exactly at the end of the block, after completing a value that
     * straddled a chunk boundary, or once {@code batchSize} values have been produced.
     * <p>
     * NOTE(review): this path accepts up to 3 high bytes as long as {@code high <= 0xFF}, while the
     * fast path in {@link #nextShortDecimal} rejects any 3-byte high part, so the two paths differ
     * on zero-padded encodings — confirm which behavior is intended.
     *
     * @return the updated number of values written to {@code result}
     * @throws OrcCorruptionException if an encoded value does not fit in a long
     */
    private int decodeShortDecimalTail(long[] result, int count, int batchSize)
            throws IOException
    {
        long low = 0;
        long high = 0;
        long value;
        boolean last = false;
        int offset = 0;

        if (blockOffset == block.length()) {
            advance();
        }

        while (true) {
            value = block.getByte(blockOffset);
            blockOffset++;

            // bytes 0-7 feed low, 8-10 feed high
            if (offset == 0) {
                low |= (value & 0x7F);
            }
            else if (offset < 8) {
                low |= (value & 0x7F) << (offset * 7);
            }
            else if (offset < 11) {
                high |= (value & 0x7F) << ((offset - 8) * 7);
            }
            else {
                throw new OrcCorruptionException(chunkLoader.getOrcDataSourceId(), "Decimal does not fit long (invalid table schema?)");
            }

            offset++;

            if ((value & 0x80) == 0) {
                if (high > 0xFF) { // only 63 - (55) = 8 bits allowed in high
                    throw new OrcCorruptionException(chunkLoader.getOrcDataSourceId(), "Decimal does not fit long (invalid table schema?)");
                }
                emitShortDecimal(result, count, low, high);
                count++;

                low = 0;
                high = 0;
                offset = 0;

                if (blockOffset == block.length()) {
                    // the last value aligns with the end of the block, so just
                    // reset the block and loop around to optimized decoding
                    break;
                }

                if (last || count == batchSize) {
                    break;
                }
            }
            else if (blockOffset == block.length()) {
                // the value continues into the next chunk
                last = true;
                advance();
            }
        }
        return count;
    }

    /**
     * Assembles a 64-bit value from its 56-bit low group (whose lowest bit is the zig-zag sign) and
     * 8-bit high group, undoes the zig-zag sign, and stores it in {@code result[offset]}.
     */
    private static void emitShortDecimal(long[] result, int offset, long low, long high)
    {
        boolean negative = (low & 1) == 1;
        long value = (low >>> 1) | (high << 55); // drop the sign bit from low
        if (negative) {
            // zig-zag: value = ~encoded / 2 for negative numbers (see emitLongDecimal)
            value = ~value;
        }
        result[offset] = value;
    }

    /**
     * Skips {@code items} values.
     * <p>
     * Values are counted by their terminating bytes (bytes whose continuation bit is clear). Whole
     * 8-byte words are counted with a popcount until the target is within reach, then the remaining
     * bytes are consumed one at a time.
     */
    @Override
    public void skip(long items)
            throws IOException
    {
        if (items == 0) {
            return;
        }

        if (blockOffset == block.length()) {
            advance();
        }

        // long, not int: items is a long and the running count must not overflow before reaching it
        long count = 0;
        while (true) {
            while (blockOffset <= block.length() - Long.BYTES) { // only safe if there's at least one long to read
                long current = block.getLong(blockOffset);
                // one terminating byte per value ending inside this word
                int increment = Long.bitCount(~current & LONG_MASK);
                if (count + increment >= items) {
                    // reached the tail, so bail out and process byte at a time
                    break;
                }
                count += increment;
                blockOffset += Long.BYTES;
            }

            while (blockOffset < block.length()) { // tail -- byte at a time
                byte current = block.getByte(blockOffset);
                blockOffset++;
                if ((current & 0x80) == 0) {
                    count++;
                    if (count == items) {
                        return;
                    }
                }
            }

            advance();
        }
    }

    /**
     * Loads the next chunk from the chunk loader, records the loader's checkpoint for it, and
     * resets the position to the start of the new block.
     */
    private void advance()
            throws IOException
    {
        block = chunkLoader.nextChunk();
        lastCheckpoint = chunkLoader.getLastCheckpoint();
        blockOffset = 0;
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy