All downloads are free. Search and download functionality uses the official Maven repository.

io.trino.orc.StripeReader Maven / Gradle / Ivy

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.orc;

import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
import io.airlift.slice.Slice;
import io.trino.memory.context.AggregatedMemoryContext;
import io.trino.orc.checkpoint.InvalidCheckpointException;
import io.trino.orc.checkpoint.StreamCheckpoint;
import io.trino.orc.metadata.ColumnEncoding;
import io.trino.orc.metadata.ColumnEncoding.ColumnEncodingKind;
import io.trino.orc.metadata.ColumnMetadata;
import io.trino.orc.metadata.MetadataReader;
import io.trino.orc.metadata.OrcColumnId;
import io.trino.orc.metadata.OrcType;
import io.trino.orc.metadata.OrcType.OrcTypeKind;
import io.trino.orc.metadata.PostScript.HiveWriterVersion;
import io.trino.orc.metadata.RowGroupIndex;
import io.trino.orc.metadata.Stream;
import io.trino.orc.metadata.StripeFooter;
import io.trino.orc.metadata.StripeInformation;
import io.trino.orc.metadata.statistics.BloomFilter;
import io.trino.orc.metadata.statistics.ColumnStatistics;
import io.trino.orc.stream.InputStreamSource;
import io.trino.orc.stream.InputStreamSources;
import io.trino.orc.stream.OrcChunkLoader;
import io.trino.orc.stream.OrcDataReader;
import io.trino.orc.stream.OrcInputStream;
import io.trino.orc.stream.ValueInputStream;
import io.trino.orc.stream.ValueInputStreamSource;
import io.trino.orc.stream.ValueStreams;

import java.io.IOException;
import java.io.InputStream;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Set;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static io.trino.orc.checkpoint.Checkpoints.getDictionaryStreamCheckpoint;
import static io.trino.orc.checkpoint.Checkpoints.getStreamCheckpoints;
import static io.trino.orc.metadata.ColumnEncoding.ColumnEncodingKind.DICTIONARY;
import static io.trino.orc.metadata.ColumnEncoding.ColumnEncodingKind.DICTIONARY_V2;
import static io.trino.orc.metadata.Stream.StreamKind.BLOOM_FILTER;
import static io.trino.orc.metadata.Stream.StreamKind.BLOOM_FILTER_UTF8;
import static io.trino.orc.metadata.Stream.StreamKind.DICTIONARY_COUNT;
import static io.trino.orc.metadata.Stream.StreamKind.DICTIONARY_DATA;
import static io.trino.orc.metadata.Stream.StreamKind.LENGTH;
import static io.trino.orc.metadata.Stream.StreamKind.ROW_INDEX;
import static io.trino.orc.stream.CheckpointInputStreamSource.createCheckpointStreamSource;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;

public class StripeReader
{
    private final OrcDataSource orcDataSource;
    private final ZoneId legacyFileTimeZone;
    private final Optional decompressor;
    private final ColumnMetadata types;
    private final HiveWriterVersion hiveWriterVersion;
    private final Set includedOrcColumnIds;
    private final OptionalInt rowsInRowGroup;
    private final OrcPredicate predicate;
    private final MetadataReader metadataReader;
    private final Optional writeValidation;

    public StripeReader(
            OrcDataSource orcDataSource,
            ZoneId legacyFileTimeZone,
            Optional decompressor,
            ColumnMetadata types,
            Set readColumns,
            OptionalInt rowsInRowGroup,
            OrcPredicate predicate,
            HiveWriterVersion hiveWriterVersion,
            MetadataReader metadataReader,
            Optional writeValidation)
    {
        this.orcDataSource = requireNonNull(orcDataSource, "orcDataSource is null");
        this.legacyFileTimeZone = requireNonNull(legacyFileTimeZone, "legacyFileTimeZone is null");
        this.decompressor = requireNonNull(decompressor, "decompressor is null");
        this.types = requireNonNull(types, "types is null");
        this.includedOrcColumnIds = getIncludeColumns(requireNonNull(readColumns, "readColumns is null"));
        this.rowsInRowGroup = rowsInRowGroup;
        this.predicate = requireNonNull(predicate, "predicate is null");
        this.hiveWriterVersion = requireNonNull(hiveWriterVersion, "hiveWriterVersion is null");
        this.metadataReader = requireNonNull(metadataReader, "metadataReader is null");
        this.writeValidation = requireNonNull(writeValidation, "writeValidation is null");
    }

    public Stripe readStripe(StripeInformation stripe, AggregatedMemoryContext memoryUsage)
            throws IOException
    {
        // read the stripe footer
        StripeFooter stripeFooter = readStripeFooter(stripe, memoryUsage);
        ColumnMetadata columnEncodings = stripeFooter.getColumnEncodings();
        if (writeValidation.isPresent()) {
            writeValidation.get().validateTimeZone(orcDataSource.getId(), stripeFooter.getTimeZone());
        }
        ZoneId fileTimeZone = stripeFooter.getTimeZone();

        // get streams for selected columns
        Map streams = new HashMap<>();
        for (Stream stream : stripeFooter.getStreams()) {
            if (includedOrcColumnIds.contains(stream.getColumnId()) && isSupportedStreamType(stream, types.get(stream.getColumnId()).getOrcTypeKind())) {
                streams.put(new StreamId(stream), stream);
            }
        }

        // handle stripes with more than one row group
        boolean invalidCheckPoint = false;
        if (rowsInRowGroup.isPresent() && stripe.getNumberOfRows() > rowsInRowGroup.getAsInt()) {
            // determine ranges of the stripe to read
            Map diskRanges = getDiskRanges(stripeFooter.getStreams());
            diskRanges = Maps.filterKeys(diskRanges, Predicates.in(streams.keySet()));

            // read the file regions
            Map streamsData = readDiskRanges(stripe.getOffset(), diskRanges, memoryUsage);

            // read the bloom filter for each column
            Map> bloomFilterIndexes = readBloomFilterIndexes(streams, streamsData);

            // read the row index for each column
            Map> columnIndexes = readColumnIndexes(streams, streamsData, bloomFilterIndexes);
            if (writeValidation.isPresent()) {
                writeValidation.get().validateRowGroupStatistics(orcDataSource.getId(), stripe.getOffset(), columnIndexes);
            }

            // select the row groups matching the tuple domain
            Set selectedRowGroups = selectRowGroups(stripe, columnIndexes);

            // if all row groups are skipped, return null
            if (selectedRowGroups.isEmpty()) {
                // set accounted memory usage to zero
                memoryUsage.close();
                return null;
            }

            // value streams
            Map> valueStreams = createValueStreams(streams, streamsData, columnEncodings);

            // build the dictionary streams
            InputStreamSources dictionaryStreamSources = createDictionaryStreamSources(streams, valueStreams, columnEncodings);

            // build the row groups
            try {
                List rowGroups = createRowGroups(
                        stripe.getNumberOfRows(),
                        streams,
                        valueStreams,
                        columnIndexes,
                        selectedRowGroups,
                        columnEncodings);

                return new Stripe(stripe.getNumberOfRows(), fileTimeZone, columnEncodings, rowGroups, dictionaryStreamSources);
            }
            catch (InvalidCheckpointException e) {
                // The ORC file contains a corrupt checkpoint stream treat the stripe as a single row group.
                invalidCheckPoint = true;
            }
        }

        // stripe only has one row group
        ImmutableMap.Builder diskRangesBuilder = ImmutableMap.builder();
        for (Entry entry : getDiskRanges(stripeFooter.getStreams()).entrySet()) {
            StreamId streamId = entry.getKey();
            if (streams.containsKey(streamId)) {
                diskRangesBuilder.put(entry);
            }
        }
        ImmutableMap diskRanges = diskRangesBuilder.buildOrThrow();

        // read the file regions
        Map streamsData = readDiskRanges(stripe.getOffset(), diskRanges, memoryUsage);

        long minAverageRowBytes = 0;
        for (Entry entry : streams.entrySet()) {
            if (entry.getKey().getStreamKind() == ROW_INDEX) {
                List rowGroupIndexes = metadataReader.readRowIndexes(hiveWriterVersion, new OrcInputStream(streamsData.get(entry.getKey())));
                checkState(rowGroupIndexes.size() == 1 || invalidCheckPoint, "expect a single row group or an invalid check point");
                long totalBytes = 0;
                long totalRows = 0;
                for (RowGroupIndex rowGroupIndex : rowGroupIndexes) {
                    ColumnStatistics columnStatistics = rowGroupIndex.getColumnStatistics();
                    if (columnStatistics.hasMinAverageValueSizeInBytes()) {
                        totalBytes += columnStatistics.getMinAverageValueSizeInBytes() * columnStatistics.getNumberOfValues();
                        totalRows += columnStatistics.getNumberOfValues();
                    }
                }
                if (totalRows > 0) {
                    minAverageRowBytes += totalBytes / totalRows;
                }
            }
        }

        // value streams
        Map> valueStreams = createValueStreams(streams, streamsData, columnEncodings);

        // build the dictionary streams
        InputStreamSources dictionaryStreamSources = createDictionaryStreamSources(streams, valueStreams, columnEncodings);

        // build the row group
        ImmutableMap.Builder> builder = ImmutableMap.builder();
        for (Entry> entry : valueStreams.entrySet()) {
            builder.put(entry.getKey(), new ValueInputStreamSource<>(entry.getValue()));
        }
        RowGroup rowGroup = new RowGroup(0, 0, stripe.getNumberOfRows(), minAverageRowBytes, new InputStreamSources(builder.buildOrThrow()));

        return new Stripe(stripe.getNumberOfRows(), fileTimeZone, columnEncodings, ImmutableList.of(rowGroup), dictionaryStreamSources);
    }

    private static boolean isSupportedStreamType(Stream stream, OrcTypeKind orcTypeKind)
    {
        if (stream.getStreamKind() == BLOOM_FILTER) {
            return switch (orcTypeKind) {
                // non-utf8 bloom filters are not allowed for character types
                case STRING, VARCHAR, CHAR -> false;
                // non-utf8 bloom filters are not supported for timestamp
                case TIMESTAMP, TIMESTAMP_INSTANT -> false;
                default -> true;
            };
        }
        if (stream.getStreamKind() == BLOOM_FILTER_UTF8) {
            // char types require padding for bloom filters, which is not supported
            return orcTypeKind != OrcTypeKind.CHAR;
        }
        return true;
    }

    private Map readDiskRanges(long stripeOffset, Map diskRanges, AggregatedMemoryContext memoryUsage)
            throws IOException
    {
        //
        // Note: this code does not use the stream APIs to avoid any extra object allocation
        //

        // transform ranges to have an absolute offset in file
        ImmutableMap.Builder diskRangesBuilder = ImmutableMap.builder();
        for (Entry entry : diskRanges.entrySet()) {
            DiskRange diskRange = entry.getValue();
            diskRangesBuilder.put(entry.getKey(), new DiskRange(stripeOffset + diskRange.getOffset(), diskRange.getLength()));
        }
        diskRanges = diskRangesBuilder.buildOrThrow();

        // read ranges
        Map streamsData = orcDataSource.readFully(diskRanges);

        // transform streams to OrcInputStream
        ImmutableMap.Builder dataBuilder = ImmutableMap.builder();
        for (Entry entry : streamsData.entrySet()) {
            dataBuilder.put(entry.getKey(), OrcChunkLoader.create(entry.getValue(), decompressor, memoryUsage));
        }
        return dataBuilder.buildOrThrow();
    }

    private Map> createValueStreams(Map streams, Map streamsData, ColumnMetadata columnEncodings)
    {
        ImmutableMap.Builder> valueStreams = ImmutableMap.builder();
        for (Entry entry : streams.entrySet()) {
            StreamId streamId = entry.getKey();
            Stream stream = entry.getValue();
            ColumnEncodingKind columnEncoding = columnEncodings.get(stream.getColumnId()).getColumnEncodingKind();

            // skip index and empty streams
            if (isIndexStream(stream) || stream.getLength() == 0) {
                continue;
            }

            OrcChunkLoader chunkLoader = streamsData.get(streamId);
            OrcTypeKind columnType = types.get(stream.getColumnId()).getOrcTypeKind();

            valueStreams.put(streamId, ValueStreams.createValueStreams(streamId, chunkLoader, columnType, columnEncoding));
        }
        return valueStreams.buildOrThrow();
    }

    private InputStreamSources createDictionaryStreamSources(Map streams, Map> valueStreams, ColumnMetadata columnEncodings)
    {
        ImmutableMap.Builder> dictionaryStreamBuilder = ImmutableMap.builder();
        for (Entry entry : streams.entrySet()) {
            StreamId streamId = entry.getKey();
            Stream stream = entry.getValue();
            OrcColumnId column = stream.getColumnId();

            // only process dictionary streams
            ColumnEncodingKind columnEncoding = columnEncodings.get(column).getColumnEncodingKind();
            if (!isDictionary(stream, columnEncoding)) {
                continue;
            }

            // skip streams without data
            ValueInputStream valueStream = valueStreams.get(streamId);
            if (valueStream == null) {
                continue;
            }

            OrcTypeKind columnType = types.get(stream.getColumnId()).getOrcTypeKind();
            StreamCheckpoint streamCheckpoint = getDictionaryStreamCheckpoint(streamId, columnType, columnEncoding);

            InputStreamSource streamSource = createCheckpointStreamSource(valueStream, streamCheckpoint);
            dictionaryStreamBuilder.put(streamId, streamSource);
        }
        return new InputStreamSources(dictionaryStreamBuilder.buildOrThrow());
    }

    private List createRowGroups(
            int rowsInStripe,
            Map streams,
            Map> valueStreams,
            Map> columnIndexes,
            Set selectedRowGroups,
            ColumnMetadata encodings)
            throws InvalidCheckpointException
    {
        int rowsInRowGroup = this.rowsInRowGroup.orElseThrow(() -> new IllegalStateException("Cannot create row groups if row group info is missing"));
        ImmutableList.Builder rowGroupBuilder = ImmutableList.builder();

        for (int rowGroupId : selectedRowGroups) {
            Map checkpoints = getStreamCheckpoints(includedOrcColumnIds, types, decompressor.isPresent(), rowGroupId, encodings, streams, columnIndexes);
            int rowOffset = rowGroupId * rowsInRowGroup;
            int rowsInGroup = Math.min(rowsInStripe - rowOffset, rowsInRowGroup);
            long minAverageRowBytes = columnIndexes
                    .entrySet()
                    .stream()
                    .mapToLong(e -> e.getValue()
                            .get(rowGroupId)
                            .getColumnStatistics()
                            .getMinAverageValueSizeInBytes())
                    .sum();
            rowGroupBuilder.add(createRowGroup(rowGroupId, rowOffset, rowsInGroup, minAverageRowBytes, valueStreams, checkpoints));
        }

        return rowGroupBuilder.build();
    }

    private static RowGroup createRowGroup(int groupId, int rowOffset, int rowCount, long minAverageRowBytes, Map> valueStreams, Map checkpoints)
    {
        ImmutableMap.Builder> builder = ImmutableMap.builder();
        for (Entry entry : checkpoints.entrySet()) {
            StreamId streamId = entry.getKey();
            StreamCheckpoint checkpoint = entry.getValue();

            // skip streams without data
            ValueInputStream valueStream = valueStreams.get(streamId);
            if (valueStream == null) {
                continue;
            }

            builder.put(streamId, createCheckpointStreamSource(valueStream, checkpoint));
        }
        InputStreamSources rowGroupStreams = new InputStreamSources(builder.buildOrThrow());
        return new RowGroup(groupId, rowOffset, rowCount, minAverageRowBytes, rowGroupStreams);
    }

    private StripeFooter readStripeFooter(StripeInformation stripe, AggregatedMemoryContext memoryUsage)
            throws IOException
    {
        long offset = stripe.getOffset() + stripe.getIndexLength() + stripe.getDataLength();
        int tailLength = toIntExact(stripe.getFooterLength());

        // read the footer
        Slice tailBuffer = orcDataSource.readFully(offset, tailLength);
        try (InputStream inputStream = new OrcInputStream(OrcChunkLoader.create(orcDataSource.getId(), tailBuffer, decompressor, memoryUsage))) {
            return metadataReader.readStripeFooter(types, inputStream, legacyFileTimeZone);
        }
    }

    static boolean isIndexStream(Stream stream)
    {
        return stream.getStreamKind() == ROW_INDEX || stream.getStreamKind() == DICTIONARY_COUNT || stream.getStreamKind() == BLOOM_FILTER || stream.getStreamKind() == BLOOM_FILTER_UTF8;
    }

    private Map> readBloomFilterIndexes(Map streams, Map streamsData)
            throws IOException
    {
        HashMap> bloomFilters = new HashMap<>();
        for (Entry entry : streams.entrySet()) {
            Stream stream = entry.getValue();
            if (stream.getStreamKind() == BLOOM_FILTER_UTF8 || stream.getStreamKind() == BLOOM_FILTER && !bloomFilters.containsKey(stream.getColumnId())) {
                OrcInputStream inputStream = new OrcInputStream(streamsData.get(entry.getKey()));
                bloomFilters.put(stream.getColumnId(), metadataReader.readBloomFilterIndexes(inputStream));
            }
        }
        return ImmutableMap.copyOf(bloomFilters);
    }

    private Map> readColumnIndexes(Map streams, Map streamsData, Map> bloomFilterIndexes)
            throws IOException
    {
        ImmutableMap.Builder> columnIndexes = ImmutableMap.builder();
        for (Entry entry : streams.entrySet()) {
            Stream stream = entry.getValue();
            if (stream.getStreamKind() == ROW_INDEX) {
                OrcInputStream inputStream = new OrcInputStream(streamsData.get(entry.getKey()));
                List bloomFilters = bloomFilterIndexes.get(entry.getKey().getColumnId());
                List rowGroupIndexes = metadataReader.readRowIndexes(hiveWriterVersion, inputStream);
                if (bloomFilters != null && !bloomFilters.isEmpty()) {
                    ImmutableList.Builder newRowGroupIndexes = ImmutableList.builder();
                    for (int i = 0; i < rowGroupIndexes.size(); i++) {
                        RowGroupIndex rowGroupIndex = rowGroupIndexes.get(i);
                        ColumnStatistics columnStatistics = rowGroupIndex.getColumnStatistics()
                                .withBloomFilter(bloomFilters.get(i));
                        newRowGroupIndexes.add(new RowGroupIndex(rowGroupIndex.getPositions(), columnStatistics));
                    }
                    rowGroupIndexes = newRowGroupIndexes.build();
                }
                columnIndexes.put(entry.getKey(), rowGroupIndexes);
            }
        }
        return columnIndexes.buildOrThrow();
    }

    private Set selectRowGroups(StripeInformation stripe, Map> columnIndexes)
    {
        int rowsInRowGroup = this.rowsInRowGroup.orElseThrow(() -> new IllegalStateException("Cannot create row groups if row group info is missing"));

        int rowsInStripe = stripe.getNumberOfRows();
        int groupsInStripe = ceil(rowsInStripe, rowsInRowGroup);

        ImmutableSet.Builder selectedRowGroups = ImmutableSet.builder();
        int remainingRows = rowsInStripe;
        for (int rowGroup = 0; rowGroup < groupsInStripe; ++rowGroup) {
            int rows = Math.min(remainingRows, rowsInRowGroup);
            ColumnMetadata statistics = getRowGroupStatistics(types, columnIndexes, rowGroup);
            if (predicate.matches(rows, statistics)) {
                selectedRowGroups.add(rowGroup);
            }
            remainingRows -= rows;
        }
        return selectedRowGroups.build();
    }

    private static ColumnMetadata getRowGroupStatistics(ColumnMetadata types, Map> columnIndexes, int rowGroup)
    {
        requireNonNull(columnIndexes, "columnIndexes is null");
        checkArgument(rowGroup >= 0, "rowGroup is negative");

        Map> rowGroupIndexesByColumn = columnIndexes.entrySet().stream()
                .collect(toImmutableMap(entry -> entry.getKey().getColumnId().getId(), Entry::getValue));

        List statistics = new ArrayList<>(types.size());
        for (int columnIndex = 0; columnIndex < types.size(); columnIndex++) {
            List rowGroupIndexes = rowGroupIndexesByColumn.get(columnIndex);
            if (rowGroupIndexes != null) {
                statistics.add(rowGroupIndexes.get(rowGroup).getColumnStatistics());
            }
            else {
                statistics.add(null);
            }
        }
        return new ColumnMetadata<>(statistics);
    }

    private static boolean isDictionary(Stream stream, ColumnEncodingKind columnEncoding)
    {
        return stream.getStreamKind() == DICTIONARY_DATA || (stream.getStreamKind() == LENGTH && (columnEncoding == DICTIONARY || columnEncoding == DICTIONARY_V2));
    }

    private static Map getDiskRanges(List streams)
    {
        ImmutableMap.Builder streamDiskRanges = ImmutableMap.builder();
        long stripeOffset = 0;
        for (Stream stream : streams) {
            int streamLength = stream.getLength();
            // ignore zero byte streams
            if (streamLength > 0) {
                streamDiskRanges.put(new StreamId(stream), new DiskRange(stripeOffset, streamLength));
            }
            stripeOffset += streamLength;
        }
        return streamDiskRanges.buildOrThrow();
    }

    private static Set getIncludeColumns(Set includedColumns)
    {
        Set result = new LinkedHashSet<>();
        includeColumnsRecursive(result, includedColumns);
        return result;
    }

    private static void includeColumnsRecursive(Set result, Collection readColumns)
    {
        for (OrcColumn column : readColumns) {
            result.add(column.getColumnId());
            includeColumnsRecursive(result, column.getNestedColumns());
        }
    }

    /**
     * Ceiling of integer division
     */
    private static int ceil(int dividend, int divisor)
    {
        return ((dividend + divisor) - 1) / divisor;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy