
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.orc;
import com.facebook.presto.memory.context.AggregatedMemoryContext;
import com.facebook.presto.orc.checkpoint.InvalidCheckpointException;
import com.facebook.presto.orc.checkpoint.StreamCheckpoint;
import com.facebook.presto.orc.metadata.ColumnEncoding;
import com.facebook.presto.orc.metadata.ColumnEncoding.ColumnEncodingKind;
import com.facebook.presto.orc.metadata.DwrfSequenceEncoding;
import com.facebook.presto.orc.metadata.MetadataReader;
import com.facebook.presto.orc.metadata.OrcType;
import com.facebook.presto.orc.metadata.OrcType.OrcTypeKind;
import com.facebook.presto.orc.metadata.PostScript.HiveWriterVersion;
import com.facebook.presto.orc.metadata.RowGroupIndex;
import com.facebook.presto.orc.metadata.Stream;
import com.facebook.presto.orc.metadata.Stream.StreamKind;
import com.facebook.presto.orc.metadata.StripeFooter;
import com.facebook.presto.orc.metadata.StripeInformation;
import com.facebook.presto.orc.metadata.statistics.ColumnStatistics;
import com.facebook.presto.orc.metadata.statistics.HiveBloomFilter;
import com.facebook.presto.orc.stream.InputStreamSource;
import com.facebook.presto.orc.stream.InputStreamSources;
import com.facebook.presto.orc.stream.OrcInputStream;
import com.facebook.presto.orc.stream.ValueInputStream;
import com.facebook.presto.orc.stream.ValueInputStreamSource;
import com.facebook.presto.orc.stream.ValueStreams;
import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
import io.airlift.slice.Slices;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import static com.facebook.presto.orc.checkpoint.Checkpoints.getDictionaryStreamCheckpoint;
import static com.facebook.presto.orc.checkpoint.Checkpoints.getStreamCheckpoints;
import static com.facebook.presto.orc.metadata.ColumnEncoding.ColumnEncodingKind.DICTIONARY;
import static com.facebook.presto.orc.metadata.ColumnEncoding.ColumnEncodingKind.DICTIONARY_V2;
import static com.facebook.presto.orc.metadata.Stream.StreamKind.BLOOM_FILTER;
import static com.facebook.presto.orc.metadata.Stream.StreamKind.BLOOM_FILTER_UTF8;
import static com.facebook.presto.orc.metadata.Stream.StreamKind.DICTIONARY_COUNT;
import static com.facebook.presto.orc.metadata.Stream.StreamKind.DICTIONARY_DATA;
import static com.facebook.presto.orc.metadata.Stream.StreamKind.LENGTH;
import static com.facebook.presto.orc.metadata.Stream.StreamKind.ROW_INDEX;
import static com.facebook.presto.orc.metadata.statistics.ColumnStatistics.mergeColumnStatistics;
import static com.facebook.presto.orc.stream.CheckpointInputStreamSource.createCheckpointStreamSource;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.Iterables.getOnlyElement;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;
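
/**
 * Reads a single ORC stripe: the stripe footer, the row indexes and bloom filter indexes for the
 * included columns, and the value and dictionary streams needed to decode the selected row groups.
 */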
public class StripeReader
{
    private final OrcDataSource orcDataSource;
    private final Optional<OrcDecompressor> decompressor;
    private final List<OrcType> types;
    private final HiveWriterVersion hiveWriterVersion;
    private final Set<Integer> includedOrcColumns;
    private final int rowsInRowGroup;
    private final OrcPredicate predicate;
    private final MetadataReader metadataReader;
    private final Optional<OrcWriteValidation> writeValidation;

    public StripeReader(OrcDataSource orcDataSource,
            Optional<OrcDecompressor> decompressor,
            List<OrcType> types,
            Set<Integer> includedColumns,
            int rowsInRowGroup,
            OrcPredicate predicate,
            HiveWriterVersion hiveWriterVersion,
            MetadataReader metadataReader,
            Optional<OrcWriteValidation> writeValidation)
    {
        this.orcDataSource = requireNonNull(orcDataSource, "orcDataSource is null");
        this.decompressor = requireNonNull(decompressor, "decompressor is null");
        this.types = ImmutableList.copyOf(requireNonNull(types, "types is null"));
        this.includedOrcColumns = getIncludedOrcColumns(types, requireNonNull(includedColumns, "includedColumns is null"));
        this.rowsInRowGroup = rowsInRowGroup;
        this.predicate = requireNonNull(predicate, "predicate is null");
        this.hiveWriterVersion = requireNonNull(hiveWriterVersion, "hiveWriterVersion is null");
        this.metadataReader = requireNonNull(metadataReader, "metadataReader is null");
        this.writeValidation = requireNonNull(writeValidation, "writeValidation is null");
    }
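
    /**
     * Reads the given stripe: row groups whose statistics cannot match the predicate are skipped,
     * and {@code null} is returned when every row group in the stripe is skipped.
     */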
    public Stripe readStripe(StripeInformation stripe, AggregatedMemoryContext systemMemoryUsage)
            throws IOException
    {
        // read the stripe footer
        StripeFooter stripeFooter = readStripeFooter(stripe, systemMemoryUsage);
        List<ColumnEncoding> columnEncodings = stripeFooter.getColumnEncodings();

        // get streams for selected columns
        Map<StreamId, Stream> streams = new HashMap<>();
        boolean hasRowGroupDictionary = false;
        for (Stream stream : stripeFooter.getStreams()) {
            if (includedOrcColumns.contains(stream.getColumn())) {
                streams.put(new StreamId(stream), stream);
                if (stream.getStreamKind() == StreamKind.IN_DICTIONARY) {
                    ColumnEncoding columnEncoding = columnEncodings.get(stream.getColumn());
                    if (columnEncoding.getColumnEncodingKind() == DICTIONARY) {
                        hasRowGroupDictionary = true;
                    }
                    Optional<List<DwrfSequenceEncoding>> additionalSequenceEncodings = columnEncoding.getAdditionalSequenceEncodings();
                    if (additionalSequenceEncodings.isPresent()
                            && additionalSequenceEncodings.get().stream()
                            .map(DwrfSequenceEncoding::getValueEncoding)
                            .anyMatch(encoding -> encoding.getColumnEncodingKind() == DICTIONARY)) {
                        hasRowGroupDictionary = true;
                    }
                }
            }
        }

        // handle stripes with more than one row group or a dictionary
        boolean invalidCheckPoint = false;
        if ((stripe.getNumberOfRows() > rowsInRowGroup) || hasRowGroupDictionary) {
            // determine ranges of the stripe to read
            Map<StreamId, DiskRange> diskRanges = getDiskRanges(stripeFooter.getStreams());
            diskRanges = Maps.filterKeys(diskRanges, Predicates.in(streams.keySet()));

            // read the file regions
            Map<StreamId, OrcInputStream> streamsData = readDiskRanges(stripe.getOffset(), diskRanges, systemMemoryUsage);

            // read the bloom filter for each column
            Map<StreamId, List<HiveBloomFilter>> bloomFilterIndexes = readBloomFilterIndexes(streams, streamsData);

            // read the row index for each column
            Map<StreamId, List<RowGroupIndex>> columnIndexes = readColumnIndexes(streams, streamsData, bloomFilterIndexes);
            if (writeValidation.isPresent()) {
                writeValidation.get().validateRowGroupStatistics(orcDataSource.getId(), stripe.getOffset(), columnIndexes);
            }

            // select the row groups matching the tuple domain
            Set<Integer> selectedRowGroups = selectRowGroups(stripe, columnIndexes);

            // if all row groups are skipped, return null
            if (selectedRowGroups.isEmpty()) {
                // set accounted memory usage to zero
                systemMemoryUsage.close();
                return null;
            }

            // value streams
            Map<StreamId, ValueInputStream<?>> valueStreams = createValueStreams(streams, streamsData, columnEncodings);

            // build the dictionary streams
            InputStreamSources dictionaryStreamSources = createDictionaryStreamSources(streams, valueStreams, columnEncodings);

            // build the row groups
            try {
                List<RowGroup> rowGroups = createRowGroups(
                        stripe.getNumberOfRows(),
                        streams,
                        valueStreams,
                        columnIndexes,
                        selectedRowGroups,
                        columnEncodings);

                return new Stripe(stripe.getNumberOfRows(), columnEncodings, rowGroups, dictionaryStreamSources);
            }
            catch (InvalidCheckpointException e) {
                // The ORC file contains a corrupt checkpoint stream
                // If the file does not have a row group dictionary, treat the stripe as a single row group. Otherwise,
                // we must fail because the length of the row group dictionary is contained in the checkpoint stream.
                if (hasRowGroupDictionary) {
                    throw new OrcCorruptionException(e, orcDataSource.getId(), "Checkpoints are corrupt");
                }
                invalidCheckPoint = true;
            }
        }

        // stripe only has one row group and no dictionary
        ImmutableMap.Builder<StreamId, DiskRange> diskRangesBuilder = ImmutableMap.builder();
        for (Entry<StreamId, DiskRange> entry : getDiskRanges(stripeFooter.getStreams()).entrySet()) {
            StreamId streamId = entry.getKey();
            if (streams.containsKey(streamId)) {
                diskRangesBuilder.put(entry);
            }
        }
        ImmutableMap<StreamId, DiskRange> diskRanges = diskRangesBuilder.build();

        // read the file regions
        Map<StreamId, OrcInputStream> streamsData = readDiskRanges(stripe.getOffset(), diskRanges, systemMemoryUsage);

        long minAverageRowBytes = 0;
        for (Entry<StreamId, Stream> entry : streams.entrySet()) {
            if (entry.getKey().getStreamKind() == ROW_INDEX) {
                List<RowGroupIndex> rowGroupIndexes = metadataReader.readRowIndexes(hiveWriterVersion, streamsData.get(entry.getKey()));
                checkState(rowGroupIndexes.size() == 1 || invalidCheckPoint, "expect a single row group or an invalid check point");
                long totalBytes = 0;
                long totalRows = 0;
                for (RowGroupIndex rowGroupIndex : rowGroupIndexes) {
                    ColumnStatistics columnStatistics = rowGroupIndex.getColumnStatistics();
                    if (columnStatistics.hasMinAverageValueSizeInBytes()) {
                        totalBytes += columnStatistics.getMinAverageValueSizeInBytes() * columnStatistics.getNumberOfValues();
                        totalRows += columnStatistics.getNumberOfValues();
                    }
                }
                if (totalRows > 0) {
                    minAverageRowBytes += totalBytes / totalRows;
                }
            }
        }

        // value streams
        Map<StreamId, ValueInputStream<?>> valueStreams = createValueStreams(streams, streamsData, columnEncodings);

        // build the dictionary streams
        InputStreamSources dictionaryStreamSources = createDictionaryStreamSources(streams, valueStreams, columnEncodings);

        // build the row group
        ImmutableMap.Builder<StreamId, InputStreamSource<?>> builder = ImmutableMap.builder();
        for (Entry<StreamId, ValueInputStream<?>> entry : valueStreams.entrySet()) {
            builder.put(entry.getKey(), new ValueInputStreamSource<>(entry.getValue()));
        }
        RowGroup rowGroup = new RowGroup(0, 0, stripe.getNumberOfRows(), minAverageRowBytes, new InputStreamSources(builder.build()));

        return new Stripe(stripe.getNumberOfRows(), columnEncodings, ImmutableList.of(rowGroup), dictionaryStreamSources);
    }
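
    /**
     * Reads the given stripe-relative disk ranges from the data source and wraps each range in an
     * {@link OrcInputStream} so callers see decompressed data.
     */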
    public Map<StreamId, OrcInputStream> readDiskRanges(long stripeOffset, Map<StreamId, DiskRange> diskRanges, AggregatedMemoryContext systemMemoryUsage)
            throws IOException
    {
        //
        // Note: this code does not use the Java 8 stream APIs to avoid any extra object allocation
        //

        // transform ranges to have an absolute offset in file
        ImmutableMap.Builder<StreamId, DiskRange> diskRangesBuilder = ImmutableMap.builder();
        for (Entry<StreamId, DiskRange> entry : diskRanges.entrySet()) {
            DiskRange diskRange = entry.getValue();
            diskRangesBuilder.put(entry.getKey(), new DiskRange(stripeOffset + diskRange.getOffset(), diskRange.getLength()));
        }
        diskRanges = diskRangesBuilder.build();

        // read ranges
        Map<StreamId, OrcDataSourceInput> streamsData = orcDataSource.readFully(diskRanges);

        // transform streams to OrcInputStream
        ImmutableMap.Builder<StreamId, OrcInputStream> streamsBuilder = ImmutableMap.builder();
        for (Entry<StreamId, OrcDataSourceInput> entry : streamsData.entrySet()) {
            OrcDataSourceInput sourceInput = entry.getValue();
            streamsBuilder.put(entry.getKey(), new OrcInputStream(orcDataSource.getId(), sourceInput.getInput(), decompressor, systemMemoryUsage, sourceInput.getRetainedSizeInBytes()));
        }
        return streamsBuilder.build();
    }
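
    /**
     * Creates a decoded value stream for every selected stream that has data, skipping index streams.
     */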
    private Map<StreamId, ValueInputStream<?>> createValueStreams(Map<StreamId, Stream> streams, Map<StreamId, OrcInputStream> streamsData, List<ColumnEncoding> columnEncodings)
    {
        ImmutableMap.Builder<StreamId, ValueInputStream<?>> valueStreams = ImmutableMap.builder();
        for (Entry<StreamId, Stream> entry : streams.entrySet()) {
            StreamId streamId = entry.getKey();
            Stream stream = entry.getValue();
            ColumnEncodingKind columnEncoding = columnEncodings.get(stream.getColumn())
                    .getColumnEncoding(stream.getSequence())
                    .getColumnEncodingKind();

            // skip index and empty streams
            if (isIndexStream(stream) || stream.getLength() == 0) {
                continue;
            }

            OrcInputStream inputStream = streamsData.get(streamId);
            OrcTypeKind columnType = types.get(stream.getColumn()).getOrcTypeKind();

            valueStreams.put(streamId, ValueStreams.createValueStreams(streamId, inputStream, columnType, columnEncoding, stream.isUseVInts()));
        }
        return valueStreams.build();
    }
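
    /**
     * Builds stream sources for dictionary data and dictionary length streams, positioned at the
     * dictionary checkpoint so the dictionaries can be read independently of the row groups.
     */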
    public InputStreamSources createDictionaryStreamSources(Map<StreamId, Stream> streams, Map<StreamId, ValueInputStream<?>> valueStreams, List<ColumnEncoding> columnEncodings)
    {
        ImmutableMap.Builder<StreamId, InputStreamSource<?>> dictionaryStreamBuilder = ImmutableMap.builder();
        for (Entry<StreamId, Stream> entry : streams.entrySet()) {
            StreamId streamId = entry.getKey();
            Stream stream = entry.getValue();
            int column = stream.getColumn();

            // only process dictionary streams
            ColumnEncodingKind columnEncoding = columnEncodings.get(column)
                    .getColumnEncoding(stream.getSequence())
                    .getColumnEncodingKind();
            if (!isDictionary(stream, columnEncoding)) {
                continue;
            }

            // skip streams without data
            ValueInputStream<?> valueStream = valueStreams.get(streamId);
            if (valueStream == null) {
                continue;
            }

            OrcTypeKind columnType = types.get(stream.getColumn()).getOrcTypeKind();
            StreamCheckpoint streamCheckpoint = getDictionaryStreamCheckpoint(streamId, columnType, columnEncoding);

            InputStreamSource<?> streamSource = createCheckpointStreamSource(valueStream, streamCheckpoint);
            dictionaryStreamBuilder.put(streamId, streamSource);
        }
        return new InputStreamSources(dictionaryStreamBuilder.build());
    }
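
    /**
     * Builds a {@link RowGroup} for each selected row group, using the row index checkpoints to
     * position every value stream at the start of that group.
     */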
    private List<RowGroup> createRowGroups(
            int rowsInStripe,
            Map<StreamId, Stream> streams,
            Map<StreamId, ValueInputStream<?>> valueStreams,
            Map<StreamId, List<RowGroupIndex>> columnIndexes,
            Set<Integer> selectedRowGroups,
            List<ColumnEncoding> encodings)
            throws InvalidCheckpointException
    {
        ImmutableList.Builder<RowGroup> rowGroupBuilder = ImmutableList.builder();

        for (int rowGroupId : selectedRowGroups) {
            Map<StreamId, StreamCheckpoint> checkpoints = getStreamCheckpoints(includedOrcColumns, types, decompressor.isPresent(), rowGroupId, encodings, streams, columnIndexes);
            int rowOffset = rowGroupId * rowsInRowGroup;
            int rowsInGroup = Math.min(rowsInStripe - rowOffset, rowsInRowGroup);
            long minAverageRowBytes = columnIndexes
                    .entrySet()
                    .stream()
                    .mapToLong(e -> e.getValue()
                            .get(rowGroupId)
                            .getColumnStatistics()
                            .getMinAverageValueSizeInBytes())
                    .sum();
            rowGroupBuilder.add(createRowGroup(rowGroupId, rowOffset, rowsInGroup, minAverageRowBytes, valueStreams, checkpoints));
        }

        return rowGroupBuilder.build();
    }
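
    /**
     * Creates a single row group whose stream sources are positioned at the supplied checkpoints.
     */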
    public static RowGroup createRowGroup(int groupId, int rowOffset, int rowCount, long minAverageRowBytes, Map<StreamId, ValueInputStream<?>> valueStreams, Map<StreamId, StreamCheckpoint> checkpoints)
    {
        ImmutableMap.Builder<StreamId, InputStreamSource<?>> builder = ImmutableMap.builder();
        for (Entry<StreamId, StreamCheckpoint> entry : checkpoints.entrySet()) {
            StreamId streamId = entry.getKey();
            StreamCheckpoint checkpoint = entry.getValue();

            // skip streams without data
            ValueInputStream<?> valueStream = valueStreams.get(streamId);
            if (valueStream == null) {
                continue;
            }

            builder.put(streamId, createCheckpointStreamSource(valueStream, checkpoint));
        }
        InputStreamSources rowGroupStreams = new InputStreamSources(builder.build());
        return new RowGroup(groupId, rowOffset, rowCount, minAverageRowBytes, rowGroupStreams);
    }
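
    /**
     * Reads and decodes the stripe footer, which is stored after the index and data sections of the stripe.
     */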
    public StripeFooter readStripeFooter(StripeInformation stripe, AggregatedMemoryContext systemMemoryUsage)
            throws IOException
    {
        long offset = stripe.getOffset() + stripe.getIndexLength() + stripe.getDataLength();
        int tailLength = toIntExact(stripe.getFooterLength());

        // read the footer
        byte[] tailBuffer = new byte[tailLength];
        orcDataSource.readFully(offset, tailBuffer);
        try (InputStream inputStream = new OrcInputStream(orcDataSource.getId(), Slices.wrappedBuffer(tailBuffer).getInput(), decompressor, systemMemoryUsage, tailLength)) {
            return metadataReader.readStripeFooter(types, inputStream);
        }
    }

    static boolean isIndexStream(Stream stream)
    {
        return stream.getStreamKind() == ROW_INDEX || stream.getStreamKind() == DICTIONARY_COUNT || stream.getStreamKind() == BLOOM_FILTER || stream.getStreamKind() == BLOOM_FILTER_UTF8;
    }
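
    /**
     * Reads the bloom filter index of every selected BLOOM_FILTER stream, keyed by stream id.
     */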
    private Map<StreamId, List<HiveBloomFilter>> readBloomFilterIndexes(Map<StreamId, Stream> streams, Map<StreamId, OrcInputStream> streamsData)
            throws IOException
    {
        ImmutableMap.Builder<StreamId, List<HiveBloomFilter>> bloomFilters = ImmutableMap.builder();
        for (Entry<StreamId, Stream> entry : streams.entrySet()) {
            Stream stream = entry.getValue();
            if (stream.getStreamKind() == BLOOM_FILTER) {
                OrcInputStream inputStream = streamsData.get(entry.getKey());
                bloomFilters.put(entry.getKey(), metadataReader.readBloomFilterIndexes(inputStream));
            }
            // TODO: add support for BLOOM_FILTER_UTF8
        }
        return bloomFilters.build();
    }
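
    /**
     * Reads the row group indexes of every selected ROW_INDEX stream, attaching the matching bloom
     * filters (when present) to the column statistics of each row group.
     */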
    private Map<StreamId, List<RowGroupIndex>> readColumnIndexes(Map<StreamId, Stream> streams, Map<StreamId, OrcInputStream> streamsData, Map<StreamId, List<HiveBloomFilter>> bloomFilterIndexes)
            throws IOException
    {
        ImmutableMap.Builder<StreamId, List<RowGroupIndex>> columnIndexes = ImmutableMap.builder();
        for (Entry<StreamId, Stream> entry : streams.entrySet()) {
            Stream stream = entry.getValue();
            if (stream.getStreamKind() == ROW_INDEX) {
                OrcInputStream inputStream = streamsData.get(entry.getKey());
                List<HiveBloomFilter> bloomFilters = bloomFilterIndexes.get(entry.getKey());
                List<RowGroupIndex> rowGroupIndexes = metadataReader.readRowIndexes(hiveWriterVersion, inputStream);
                if (bloomFilters != null && !bloomFilters.isEmpty()) {
                    ImmutableList.Builder<RowGroupIndex> newRowGroupIndexes = ImmutableList.builder();
                    for (int i = 0; i < rowGroupIndexes.size(); i++) {
                        RowGroupIndex rowGroupIndex = rowGroupIndexes.get(i);
                        ColumnStatistics columnStatistics = rowGroupIndex.getColumnStatistics()
                                .withBloomFilter(bloomFilters.get(i));
                        newRowGroupIndexes.add(new RowGroupIndex(rowGroupIndex.getPositions(), columnStatistics));
                    }
                    rowGroupIndexes = newRowGroupIndexes.build();
                }
                columnIndexes.put(entry.getKey(), rowGroupIndexes);
            }
        }
        return columnIndexes.build();
    }
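
    /**
     * Returns the ids of the row groups whose statistics indicate that they may contain rows
     * matching the predicate.
     */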
    private Set<Integer> selectRowGroups(StripeInformation stripe, Map<StreamId, List<RowGroupIndex>> columnIndexes)
    {
        int rowsInStripe = toIntExact(stripe.getNumberOfRows());
        int groupsInStripe = ceil(rowsInStripe, rowsInRowGroup);

        ImmutableSet.Builder<Integer> selectedRowGroups = ImmutableSet.builder();
        int remainingRows = rowsInStripe;
        for (int rowGroup = 0; rowGroup < groupsInStripe; ++rowGroup) {
            int rows = Math.min(remainingRows, rowsInRowGroup);
            Map<Integer, ColumnStatistics> statistics = getRowGroupStatistics(types.get(0), columnIndexes, rowGroup);
            if (predicate.matches(rows, statistics)) {
                selectedRowGroups.add(rowGroup);
            }
            remainingRows -= rows;
        }
        return selectedRowGroups.build();
    }
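
    /**
     * Collects the statistics of the given row group for each top-level field, merging statistics
     * when a field is backed by multiple streams (for example, DWRF maps represented as structs).
     */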
    private static Map<Integer, ColumnStatistics> getRowGroupStatistics(OrcType rootStructType, Map<StreamId, List<RowGroupIndex>> columnIndexes, int rowGroup)
    {
        requireNonNull(rootStructType, "rootStructType is null");
        checkArgument(rootStructType.getOrcTypeKind() == OrcTypeKind.STRUCT);
        requireNonNull(columnIndexes, "columnIndexes is null");
        checkArgument(rowGroup >= 0, "rowGroup is negative");

        Map<Integer, List<ColumnStatistics>> groupedColumnStatistics = new HashMap<>();
        for (Entry<StreamId, List<RowGroupIndex>> entry : columnIndexes.entrySet()) {
            groupedColumnStatistics.computeIfAbsent(entry.getKey().getColumn(), key -> new ArrayList<>())
                    .add(entry.getValue().get(rowGroup).getColumnStatistics());
        }

        ImmutableMap.Builder<Integer, ColumnStatistics> statistics = ImmutableMap.builder();
        for (int ordinal = 0; ordinal < rootStructType.getFieldCount(); ordinal++) {
            List<ColumnStatistics> columnStatistics = groupedColumnStatistics.get(rootStructType.getFieldTypeIndex(ordinal));
            if (columnStatistics != null) {
                if (columnStatistics.size() == 1) {
                    statistics.put(ordinal, getOnlyElement(columnStatistics));
                }
                else {
                    // Merge statistics from different streams
                    // This can happen if map is represented as struct (DWRF only)
                    statistics.put(ordinal, mergeColumnStatistics(columnStatistics));
                }
            }
        }
        return statistics.build();
    }

    private static boolean isDictionary(Stream stream, ColumnEncodingKind columnEncoding)
    {
        return stream.getStreamKind() == DICTIONARY_DATA || (stream.getStreamKind() == LENGTH && (columnEncoding == DICTIONARY || columnEncoding == DICTIONARY_V2));
    }
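
    /**
     * Computes the stripe-relative disk range of every non-empty stream; streams are laid out
     * back-to-back in the order they appear in the stripe footer.
     */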
    private static Map<StreamId, DiskRange> getDiskRanges(List<Stream> streams)
    {
        ImmutableMap.Builder<StreamId, DiskRange> streamDiskRanges = ImmutableMap.builder();
        long stripeOffset = 0;
        for (Stream stream : streams) {
            int streamLength = toIntExact(stream.getLength());
            // ignore zero byte streams
            if (streamLength > 0) {
                streamDiskRanges.put(new StreamId(stream), new DiskRange(stripeOffset, streamLength));
            }
            stripeOffset += streamLength;
        }
        return streamDiskRanges.build();
    }
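
    /**
     * Expands the included top-level columns into the set of all ORC type ids they transitively reference.
     */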
    private static Set<Integer> getIncludedOrcColumns(List<OrcType> types, Set<Integer> includedColumns)
    {
        Set<Integer> includes = new LinkedHashSet<>();

        OrcType root = types.get(0);
        for (int includedColumn : includedColumns) {
            includeOrcColumnsRecursive(types, includes, root.getFieldTypeIndex(includedColumn));
        }

        return includes;
    }

    private static void includeOrcColumnsRecursive(List<OrcType> types, Set<Integer> result, int typeId)
    {
        result.add(typeId);
        OrcType type = types.get(typeId);
        int children = type.getFieldCount();
        for (int i = 0; i < children; ++i) {
            includeOrcColumnsRecursive(types, result, type.getFieldTypeIndex(i));
        }
    }

    /**
     * Ceiling of integer division
     */
    private static int ceil(int dividend, int divisor)
    {
        return ((dividend + divisor) - 1) / divisor;
    }
}