All downloads are free. Search and download functionality uses the official Maven repository.

io.trino.orc.OrcWriteValidation Maven / Gradle / Ivy

The newest version!
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.orc;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSortedMap;
import com.google.common.collect.Iterables;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.airlift.slice.XxHash64;
import io.trino.orc.metadata.ColumnMetadata;
import io.trino.orc.metadata.CompressionKind;
import io.trino.orc.metadata.OrcColumnId;
import io.trino.orc.metadata.OrcType;
import io.trino.orc.metadata.PostScript.HiveWriterVersion;
import io.trino.orc.metadata.RowGroupIndex;
import io.trino.orc.metadata.StripeInformation;
import io.trino.orc.metadata.statistics.BinaryStatisticsBuilder;
import io.trino.orc.metadata.statistics.BooleanStatisticsBuilder;
import io.trino.orc.metadata.statistics.ColumnStatistics;
import io.trino.orc.metadata.statistics.DateStatisticsBuilder;
import io.trino.orc.metadata.statistics.DoubleStatisticsBuilder;
import io.trino.orc.metadata.statistics.IntegerStatistics;
import io.trino.orc.metadata.statistics.IntegerStatisticsBuilder;
import io.trino.orc.metadata.statistics.LongDecimalStatisticsBuilder;
import io.trino.orc.metadata.statistics.NoOpBloomFilterBuilder;
import io.trino.orc.metadata.statistics.ShortDecimalStatisticsBuilder;
import io.trino.orc.metadata.statistics.StatisticsBuilder;
import io.trino.orc.metadata.statistics.StatisticsHasher;
import io.trino.orc.metadata.statistics.StringStatistics;
import io.trino.orc.metadata.statistics.StringStatisticsBuilder;
import io.trino.orc.metadata.statistics.StripeStatistics;
import io.trino.orc.metadata.statistics.TimeMicrosStatisticsBuilder;
import io.trino.orc.metadata.statistics.TimestampStatisticsBuilder;
import io.trino.spi.Page;
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
import io.trino.spi.block.ColumnarMap;
import io.trino.spi.block.RowBlock;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.CharType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.LongTimestamp;
import io.trino.spi.type.LongTimestampWithTimeZone;
import io.trino.spi.type.MapType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;

import java.time.ZoneId;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.SortedMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static io.airlift.slice.SizeOf.instanceSize;
import static io.trino.orc.OrcWriteValidation.OrcWriteValidationMode.BOTH;
import static io.trino.orc.OrcWriteValidation.OrcWriteValidationMode.DETAILED;
import static io.trino.orc.OrcWriteValidation.OrcWriteValidationMode.HASHED;
import static io.trino.orc.metadata.OrcColumnId.ROOT_COLUMN;
import static io.trino.orc.metadata.OrcMetadataReader.maxStringTruncateToValidRange;
import static io.trino.orc.metadata.OrcMetadataReader.minStringTruncateToValidRange;
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
import static io.trino.spi.block.ColumnarArray.toColumnarArray;
import static io.trino.spi.block.ColumnarMap.toColumnarMap;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.DateTimeEncoding.unpackMillisUtc;
import static io.trino.spi.type.DateType.DATE;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.RealType.REAL;
import static io.trino.spi.type.SmallintType.SMALLINT;
import static io.trino.spi.type.TimeType.TIME_MICROS;
import static io.trino.spi.type.TimestampType.TIMESTAMP_MICROS;
import static io.trino.spi.type.TimestampType.TIMESTAMP_MILLIS;
import static io.trino.spi.type.TimestampType.TIMESTAMP_NANOS;
import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MICROS;
import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MILLIS;
import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_NANOS;
import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_MILLISECOND;
import static io.trino.spi.type.TinyintType.TINYINT;
import static io.trino.spi.type.UuidType.UUID;
import static io.trino.spi.type.VarbinaryType.VARBINARY;
import static java.lang.Math.floorDiv;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static java.util.function.Function.identity;

public class OrcWriteValidation
{
    /**
     * Selects how recorded row group statistics are compared with the statistics read back.
     */
    public enum OrcWriteValidationMode
    {
        /** Compare only the hash computed over the row group statistics. */
        HASHED,
        /** Compare the statistics of every column individually. */
        DETAILED,
        /** Perform both the per-column comparison and the hash comparison. */
        BOTH
    }

    private final List version;
    private final CompressionKind compression;
    private final ZoneId timeZone;
    private final int rowGroupMaxRowCount;
    private final List columnNames;
    private final Map metadata;
    private final WriteChecksum checksum;
    private final Map> rowGroupStatistics;
    private final Map stripeStatistics;
    private final Optional> fileStatistics;
    private final int stringStatisticsLimitInBytes;

    /**
     * Captures everything recorded while writing a file so it can later be compared with what is read back.
     * All arguments are stored as given; no copies are made.
     */
    private OrcWriteValidation(
            List<Integer> version,
            CompressionKind compression,
            ZoneId timeZone,
            int rowGroupMaxRowCount,
            List<String> columnNames,
            Map<String, Slice> metadata,
            WriteChecksum checksum,
            Map<Long, List<RowGroupStatistics>> rowGroupStatistics,
            Map<Long, StripeStatistics> stripeStatistics,
            Optional<ColumnMetadata<ColumnStatistics>> fileStatistics,
            int stringStatisticsLimitInBytes)
    {
        this.version = version;
        this.compression = compression;
        this.timeZone = timeZone;
        this.rowGroupMaxRowCount = rowGroupMaxRowCount;
        this.columnNames = columnNames;
        this.metadata = metadata;
        this.checksum = checksum;
        this.rowGroupStatistics = rowGroupStatistics;
        this.stripeStatistics = stripeStatistics;
        this.fileStatistics = fileStatistics;
        this.stringStatisticsLimitInBytes = stringStatisticsLimitInBytes;
    }

    /**
     * Returns the recorded file version.
     */
    public List<Integer> getVersion()
    {
        return version;
    }

    /**
     * Returns the compression kind recorded for this validation.
     */
    public CompressionKind getCompression()
    {
        return this.compression;
    }

    /**
     * Returns the time zone recorded for this validation.
     */
    public ZoneId getTimeZone()
    {
        return this.timeZone;
    }

    /**
     * Checks that {@code actualTimeZone} equals the recorded time zone.
     *
     * @throws OrcCorruptionException if the two time zones differ
     */
    public void validateTimeZone(OrcDataSourceId orcDataSourceId, ZoneId actualTimeZone)
            throws OrcCorruptionException
    {
        boolean sameZone = timeZone.equals(actualTimeZone);
        if (sameZone) {
            return;
        }
        throw new OrcCorruptionException(orcDataSourceId, "Unexpected time zone");
    }

    /**
     * Returns the maximum number of rows per row group recorded for this validation.
     */
    public int getRowGroupMaxRowCount()
    {
        return this.rowGroupMaxRowCount;
    }

    /**
     * Returns the recorded column names.
     */
    public List<String> getColumnNames()
    {
        return columnNames;
    }

    /**
     * Returns the recorded user metadata.
     */
    public Map<String, Slice> getMetadata()
    {
        return metadata;
    }

    /**
     * Checks that {@code actualMetadata} equals the recorded user metadata.
     *
     * @throws OrcCorruptionException if the metadata maps differ
     */
    public void validateMetadata(OrcDataSourceId orcDataSourceId, Map<String, Slice> actualMetadata)
            throws OrcCorruptionException
    {
        if (!metadata.equals(actualMetadata)) {
            throw new OrcCorruptionException(orcDataSourceId, "Unexpected metadata");
        }
    }

    /**
     * Returns the checksum recorded for the written data.
     */
    public WriteChecksum getChecksum()
    {
        return this.checksum;
    }

    public void validateFileStatistics(OrcDataSourceId orcDataSourceId, Optional> actualFileStatistics)
            throws OrcCorruptionException
    {
        // file stats will be absent when no rows are written
        if (fileStatistics.isEmpty()) {
            if (actualFileStatistics.isPresent()) {
                throw new OrcCorruptionException(orcDataSourceId, "Write validation failed: unexpected file statistics");
            }
            return;
        }
        if (actualFileStatistics.isEmpty()) {
            throw new OrcCorruptionException(orcDataSourceId, "Write validation failed: expected file statistics");
        }

        validateColumnStatisticsEquivalent(orcDataSourceId, "file", actualFileStatistics.get(), fileStatistics.get());
    }

    public void validateStripeStatistics(OrcDataSourceId orcDataSourceId, List actualStripes, List> actualStripeStatistics)
            throws OrcCorruptionException
    {
        requireNonNull(actualStripes, "actualStripes is null");
        requireNonNull(actualStripeStatistics, "actualStripeStatistics is null");

        if (actualStripeStatistics.size() != stripeStatistics.size()) {
            throw new OrcCorruptionException(orcDataSourceId, "Write validation failed: unexpected number of columns in stripe statistics");
        }

        for (int stripeIndex = 0; stripeIndex < actualStripes.size(); stripeIndex++) {
            long stripeOffset = actualStripes.get(stripeIndex).getOffset();
            StripeStatistics actual = actualStripeStatistics.get(stripeIndex).get();
            validateStripeStatistics(orcDataSourceId, stripeOffset, actual.getColumnStatistics());
        }
    }

    /**
     * Validates the column statistics of the stripe at {@code stripeOffset} against the recorded ones.
     *
     * @throws OrcCorruptionException if no stripe was recorded at that offset or the statistics differ
     */
    public void validateStripeStatistics(OrcDataSourceId orcDataSourceId, long stripeOffset, ColumnMetadata<ColumnStatistics> actual)
            throws OrcCorruptionException
    {
        StripeStatistics expected = stripeStatistics.get(stripeOffset);
        if (expected == null) {
            throw new OrcCorruptionException(orcDataSourceId, "Unexpected stripe at offset %s", stripeOffset);
        }
        validateColumnStatisticsEquivalent(orcDataSourceId, "Stripe at " + stripeOffset, actual, expected.getColumnStatistics());
    }

    public void validateRowGroupStatistics(OrcDataSourceId orcDataSourceId, long stripeOffset, Map> actualRowGroupStatistics)
            throws OrcCorruptionException
    {
        requireNonNull(actualRowGroupStatistics, "actualRowGroupStatistics is null");
        List expectedRowGroupStatistics = rowGroupStatistics.get(stripeOffset);
        if (expectedRowGroupStatistics == null) {
            throw new OrcCorruptionException(orcDataSourceId, "Unexpected stripe at offset %s", stripeOffset);
        }

        int rowGroupCount = expectedRowGroupStatistics.size();
        for (Entry> entry : actualRowGroupStatistics.entrySet()) {
            if (entry.getValue().size() != rowGroupCount) {
                throw new OrcCorruptionException(orcDataSourceId, "Unexpected row group count stripe in at offset %s", stripeOffset);
            }
        }

        for (int rowGroupIndex = 0; rowGroupIndex < expectedRowGroupStatistics.size(); rowGroupIndex++) {
            RowGroupStatistics expectedRowGroup = expectedRowGroupStatistics.get(rowGroupIndex);
            if (expectedRowGroup.getValidationMode() != HASHED) {
                Map expectedStatistics = expectedRowGroup.getColumnStatistics();
                Set actualColumns = actualRowGroupStatistics.keySet().stream()
                        .map(StreamId::getColumnId)
                        .collect(Collectors.toSet());
                if (!expectedStatistics.keySet().equals(actualColumns)) {
                    throw new OrcCorruptionException(orcDataSourceId, "Unexpected column in row group %s in stripe at offset %s", rowGroupIndex, stripeOffset);
                }
                for (Entry> entry : actualRowGroupStatistics.entrySet()) {
                    ColumnStatistics actual = entry.getValue().get(rowGroupIndex).getColumnStatistics();
                    ColumnStatistics expected = expectedStatistics.get(entry.getKey().getColumnId());
                    validateColumnStatisticsEquivalent(orcDataSourceId, "Row group " + rowGroupIndex + " in stripe at offset " + stripeOffset, actual, expected);
                }
            }

            if (expectedRowGroup.getValidationMode() != DETAILED) {
                RowGroupStatistics actualRowGroup = buildActualRowGroupStatistics(rowGroupIndex, actualRowGroupStatistics);
                if (expectedRowGroup.getHash() != actualRowGroup.getHash()) {
                    throw new OrcCorruptionException(orcDataSourceId, "Checksum mismatch for row group %s in stripe at offset %s", rowGroupIndex, stripeOffset);
                }
            }
        }
    }

    private static RowGroupStatistics buildActualRowGroupStatistics(int rowGroupIndex, Map> actualRowGroupStatistics)
    {
        return new RowGroupStatistics(
                BOTH,
                actualRowGroupStatistics.entrySet()
                        .stream()
                        .collect(Collectors.toMap(entry -> entry.getKey().getColumnId(), entry -> entry.getValue().get(rowGroupIndex).getColumnStatistics())));
    }

    /**
     * Validates the statistics of a single row group against the recorded row group statistics.
     * <p>
     * Column zero of {@code actual} is excluded from both the detailed comparison and the hash.
     *
     * @param orcDataSourceId identifies the data source in error messages
     * @param stripeOffset offset of the stripe containing the row group
     * @param rowGroupIndex index of the row group within the stripe
     * @param actual per-column statistics as read, including column zero
     * @throws OrcCorruptionException if the stripe or row group is unknown, a recorded column is missing,
     *         or the statistics or hash differ
     */
    public void validateRowGroupStatistics(
            OrcDataSourceId orcDataSourceId,
            long stripeOffset,
            int rowGroupIndex,
            ColumnMetadata<ColumnStatistics> actual)
            throws OrcCorruptionException
    {
        List<RowGroupStatistics> rowGroups = rowGroupStatistics.get(stripeOffset);
        if (rowGroups == null) {
            throw new OrcCorruptionException(orcDataSourceId, "Unexpected stripe at offset %s", stripeOffset);
        }
        if (rowGroupIndex < 0 || rowGroups.size() <= rowGroupIndex) {
            throw new OrcCorruptionException(orcDataSourceId, "Unexpected row group %s in stripe at offset %s", rowGroupIndex, stripeOffset);
        }

        RowGroupStatistics expectedRowGroup = rowGroups.get(rowGroupIndex);

        if (expectedRowGroup.getValidationMode() != HASHED) {
            Map<OrcColumnId, ColumnStatistics> expectedByColumnIndex = expectedRowGroup.getColumnStatistics();

            // new writer does not write row group stats for column zero (table row column)
            ImmutableList.Builder<ColumnStatistics> expectedColumns = ImmutableList.builder();
            for (int column = 1; column < actual.size(); column++) {
                ColumnStatistics expectedColumn = expectedByColumnIndex.get(new OrcColumnId(column));
                // report a missing column as corruption instead of failing with an NPE while building the list
                if (expectedColumn == null) {
                    throw new OrcCorruptionException(orcDataSourceId, "Unexpected column %s in row group %s in stripe at offset %s", column, rowGroupIndex, stripeOffset);
                }
                expectedColumns.add(expectedColumn);
            }
            ColumnMetadata<ColumnStatistics> expected = new ColumnMetadata<>(expectedColumns.build());
            ColumnMetadata<ColumnStatistics> actualWithoutRoot = new ColumnMetadata<>(actual.stream()
                    .skip(1)
                    .collect(toImmutableList()));

            validateColumnStatisticsEquivalent(orcDataSourceId, "Row group " + rowGroupIndex + " in stripe at offset " + stripeOffset, actualWithoutRoot, expected);
        }

        if (expectedRowGroup.getValidationMode() != DETAILED) {
            // hash columns 1..n-1 of the actual statistics the same way the expected hash was computed
            RowGroupStatistics actualRowGroup = new RowGroupStatistics(BOTH, IntStream.range(1, actual.size()).mapToObj(OrcColumnId::new).collect(toImmutableMap(identity(), actual::get)));
            if (expectedRowGroup.getHash() != actualRowGroup.getHash()) {
                throw new OrcCorruptionException(orcDataSourceId, "Checksum mismatch for row group %s in stripe at offset %s", rowGroupIndex, stripeOffset);
            }
        }
    }

    /**
     * Creates a {@link StatisticsValidation} that accumulates column statistics from the pages passed to it.
     *
     * @param orcTypes ORC types of the file; the root type's field count must equal {@code readTypes.size()}
     * @param readTypes types of all top-level columns
     * @throws IllegalArgumentException if not every top-level column is being read
     */
    public StatisticsValidation createWriteStatisticsBuilder(ColumnMetadata<OrcType> orcTypes, List<Type> readTypes)
    {
        checkArgument(readTypes.size() == orcTypes.get(ROOT_COLUMN).getFieldCount(), "statistics validation requires all columns to be read");
        return new StatisticsValidation(readTypes);
    }

    /**
     * Compares two per-column statistics collections column by column.
     *
     * @param name label used in error messages, e.g. {@code "file"} or {@code "Stripe at 3"}
     * @throws OrcCorruptionException if the column counts differ or any column's statistics differ
     */
    private static void validateColumnStatisticsEquivalent(
            OrcDataSourceId orcDataSourceId,
            String name,
            ColumnMetadata<ColumnStatistics> actualColumnStatistics,
            ColumnMetadata<ColumnStatistics> expectedColumnStatistics)
            throws OrcCorruptionException
    {
        requireNonNull(name, "name is null");
        requireNonNull(actualColumnStatistics, "actualColumnStatistics is null");
        requireNonNull(expectedColumnStatistics, "expectedColumnStatistics is null");
        if (actualColumnStatistics.size() != expectedColumnStatistics.size()) {
            throw new OrcCorruptionException(orcDataSourceId, "Write validation failed: unexpected number of columns in %s statistics", name);
        }
        for (int column = 0; column < actualColumnStatistics.size(); column++) {
            OrcColumnId columnId = new OrcColumnId(column);
            ColumnStatistics actual = actualColumnStatistics.get(columnId);
            ColumnStatistics expected = expectedColumnStatistics.get(columnId);
            validateColumnStatisticsEquivalent(orcDataSourceId, name + " column " + column, actual, expected);
        }
    }

    /**
     * Compares the statistics of a single column with the recorded ones.
     * <p>
     * Checks run in a fixed order:
     * <ol>
     * <li>value count</li>
     * <li>boolean</li>
     * <li>integer</li>
     * <li>double</li>
     * <li>string</li>
     * <li>date</li>
     * <li>decimal</li>
     * <li>bloom filter</li>
     * </ol>
     * The first mismatch raises an exception naming the failing kind.
     *
     * @param name label used in error messages
     * @throws OrcCorruptionException on the first mismatching statistics kind
     */
    private static void validateColumnStatisticsEquivalent(
            OrcDataSourceId orcDataSourceId,
            String name,
            ColumnStatistics actualColumnStatistics,
            ColumnStatistics expectedColumnStatistics)
            throws OrcCorruptionException
    {
        requireNonNull(name, "name is null");
        requireNonNull(actualColumnStatistics, "actualColumnStatistics is null");
        requireNonNull(expectedColumnStatistics, "expectedColumnStatistics is null");

        ColumnStatistics actual = actualColumnStatistics;
        ColumnStatistics expected = expectedColumnStatistics;

        if (actual.getNumberOfValues() != expected.getNumberOfValues()) {
            throw new OrcCorruptionException(orcDataSourceId, "Write validation failed: unexpected number of values in %s statistics", name);
        }
        if (!Objects.equals(actual.getBooleanStatistics(), expected.getBooleanStatistics())) {
            throw new OrcCorruptionException(orcDataSourceId, "Write validation failed: unexpected boolean counts in %s statistics", name);
        }
        if (!isIntegerStatisticsEquivalent(actual.getIntegerStatistics(), expected.getIntegerStatistics())) {
            throw new OrcCorruptionException(orcDataSourceId, "Write validation failed: unexpected integer range in %s statistics", name);
        }
        if (!Objects.equals(actual.getDoubleStatistics(), expected.getDoubleStatistics())) {
            throw new OrcCorruptionException(orcDataSourceId, "Write validation failed: unexpected double range in %s statistics", name);
        }
        if (!isStringStatisticsEquivalent(actual.getStringStatistics(), expected.getStringStatistics())) {
            throw new OrcCorruptionException(orcDataSourceId, "Write validation failed: unexpected string range in %s statistics", name);
        }
        if (!Objects.equals(actual.getDateStatistics(), expected.getDateStatistics())) {
            throw new OrcCorruptionException(orcDataSourceId, "Write validation failed: unexpected date range in %s statistics", name);
        }
        if (!Objects.equals(actual.getDecimalStatistics(), expected.getDecimalStatistics())) {
            throw new OrcCorruptionException(orcDataSourceId, "Write validation failed: unexpected decimal range in %s statistics", name);
        }
        if (!Objects.equals(actual.getBloomFilter(), expected.getBloomFilter())) {
            throw new OrcCorruptionException(orcDataSourceId, "Write validation failed: unexpected bloom filter in %s statistics", name);
        }
    }

    // Integer statistics match when equal, or when min and max agree and the sums agree wherever both are present.
    private static boolean isIntegerStatisticsEquivalent(IntegerStatistics actual, IntegerStatistics expected)
    {
        if (Objects.equals(actual, expected)) {
            return true;
        }
        if (actual == null || expected == null) {
            return false;
        }
        if (!Objects.equals(actual.getMin(), expected.getMin()) || !Objects.equals(actual.getMax(), expected.getMax())) {
            return false;
        }
        // The sum of the integer stats depends on the order of how we merge them.
        // It is possible the sum can overflow with one order but not in another.
        // Ignore the validation of sum if one of the two sums is null.
        return actual.getSum() == null
                || expected.getSum() == null
                || Objects.equals(actual.getSum(), expected.getSum());
    }

    // String statistics match when none were recorded, or when they agree after truncating the recorded min/max.
    private static boolean isStringStatisticsEquivalent(StringStatistics actual, StringStatistics expected)
    {
        if (expected == null) {
            return true;
        }
        StringStatistics truncatedExpected = new StringStatistics(
                minStringTruncateToValidRange(expected.getMin(), HiveWriterVersion.ORC_HIVE_8732),
                maxStringTruncateToValidRange(expected.getMax(), HiveWriterVersion.ORC_HIVE_8732),
                expected.getSum());
        if (Objects.equals(actual, truncatedExpected)) {
            return true;
        }
        // expectedStringStatistics (or the min/max of it) could be null while the actual one might not because
        // expectedStringStatistics is calculated by merging all row group stats in the stripe but the actual one is by scanning each row in the stripe on disk.
        // Merging row group stats can produce nulls given we have string stats limit.
        if (actual == null || actual.getSum() != truncatedExpected.getSum()) {
            return false;
        }
        if (truncatedExpected.getMax() != null && !Objects.equals(actual.getMax(), truncatedExpected.getMax())) {
            return false;
        }
        return truncatedExpected.getMin() == null || Objects.equals(actual.getMin(), truncatedExpected.getMin());
    }

    /**
     * Checksum of written data. It consists of:
     * <ul>
     * <li>the total row count</li>
     * <li>a single stripe hash</li>
     * <li>one hash per column</li>
     * </ul>
     * Instances are immutable; the column hashes are copied on construction.
     */
    public static class WriteChecksum
    {
        private final long totalRowCount;
        private final long stripeHash;
        private final List<Long> columnHashes;

        /**
         * @param totalRowCount total number of rows
         * @param stripeHash combined hash of the stripes
         * @param columnHashes one hash per column; copied, must not be null or contain nulls
         */
        public WriteChecksum(long totalRowCount, long stripeHash, List<Long> columnHashes)
        {
            this.totalRowCount = totalRowCount;
            this.stripeHash = stripeHash;
            // defensive copy so later changes to the caller's list cannot alter this checksum
            this.columnHashes = List.copyOf(requireNonNull(columnHashes, "columnHashes is null"));
        }

        public long getTotalRowCount()
        {
            return totalRowCount;
        }

        public long getStripeHash()
        {
            return stripeHash;
        }

        /**
         * Returns the per-column hashes as an unmodifiable list.
         */
        public List<Long> getColumnHashes()
        {
            return columnHashes;
        }
    }

    public static class WriteChecksumBuilder
    {
        private final List validationHashes;
        private long totalRowCount;
        private final List columnHashes;
        private final XxHash64 stripeHash = new XxHash64();

        private final byte[] longBuffer = new byte[Long.BYTES];
        private final Slice longSlice = Slices.wrappedBuffer(longBuffer);

        private WriteChecksumBuilder(List types)
        {
            this.validationHashes = types.stream()
                    .map(ValidationHash::createValidationHash)
                    .collect(toImmutableList());

            ImmutableList.Builder columnHashes = ImmutableList.builder();
            for (Type ignored : types) {
                columnHashes.add(new XxHash64());
            }
            this.columnHashes = columnHashes.build();
        }

        public static WriteChecksumBuilder createWriteChecksumBuilder(ColumnMetadata orcTypes, List readTypes)
        {
            checkArgument(readTypes.size() == orcTypes.get(ROOT_COLUMN).getFieldCount(), "checksum requires all columns to be read");
            return new WriteChecksumBuilder(readTypes);
        }

        public void addStripe(int rowCount)
        {
            longSlice.setInt(0, rowCount);
            stripeHash.update(longBuffer, 0, Integer.BYTES);
        }

        public void addPage(Page page)
        {
            requireNonNull(page, "page is null");
            checkArgument(page.getChannelCount() == columnHashes.size(), "invalid page");

            for (int channel = 0; channel < columnHashes.size(); channel++) {
                ValidationHash validationHash = validationHashes.get(channel);
                Block block = page.getBlock(channel);
                XxHash64 xxHash64 = columnHashes.get(channel);
                for (int position = 0; position < block.getPositionCount(); position++) {
                    long hash = validationHash.hash(block, position);
                    longSlice.setLong(0, hash);
                    xxHash64.update(longBuffer);
                }
            }
            totalRowCount += page.getPositionCount();
        }

        public WriteChecksum build()
        {
            return new WriteChecksum(
                    totalRowCount,
                    stripeHash.hash(),
                    columnHashes.stream()
                            .map(XxHash64::hash)
                            .collect(toImmutableList()));
        }
    }

    public class StatisticsValidation
    {
        private final List types;
        private List columnStatisticsValidations;
        private long rowCount;

        private StatisticsValidation(List types)
        {
            this.types = requireNonNull(types, "types is null");
            columnStatisticsValidations = types.stream()
                    .map(ColumnStatisticsValidation::new)
                    .collect(toImmutableList());
        }

        public void reset()
        {
            rowCount = 0;
            columnStatisticsValidations = types.stream()
                    .map(ColumnStatisticsValidation::new)
                    .collect(toImmutableList());
        }

        public void addPage(Page page)
        {
            rowCount += page.getPositionCount();
            for (int channel = 0; channel < columnStatisticsValidations.size(); channel++) {
                columnStatisticsValidations.get(channel).addBlock(page.getBlock(channel));
            }
        }

        public Optional> build()
        {
            if (rowCount == 0) {
                return Optional.empty();
            }
            ImmutableList.Builder statisticsBuilders = ImmutableList.builder();
            statisticsBuilders.add(new ColumnStatistics(rowCount, 0, null, null, null, null, null, null, null, null, null, null));
            columnStatisticsValidations.forEach(validation -> validation.build(statisticsBuilders));
            return Optional.of(new ColumnMetadata<>(statisticsBuilders.build()));
        }
    }

    private class ColumnStatisticsValidation
    {
        private final Type type;
        private final StatisticsBuilder statisticsBuilder;
        private final Function> fieldExtractor;
        private final List fieldBuilders;

        /**
         * Creates the statistics validation for {@code type}, recursively covering nested array, map, and row fields.
         *
         * @throws TrinoException with NOT_SUPPORTED if the type has no statistics mapping
         */
        private ColumnStatisticsValidation(Type type)
        {
            this.type = requireNonNull(type, "type is null");

            if (type instanceof ArrayType) {
                statisticsBuilder = new CountStatisticsBuilder();
                fieldExtractor = block -> ImmutableList.of(toColumnarArray(block).getElementsBlock());
                fieldBuilders = ImmutableList.of(new ColumnStatisticsValidation(Iterables.getOnlyElement(type.getTypeParameters())));
            }
            else if (type instanceof MapType) {
                statisticsBuilder = new CountStatisticsBuilder();
                fieldExtractor = block -> {
                    ColumnarMap columnarMap = toColumnarMap(block);
                    return ImmutableList.of(columnarMap.getKeysBlock(), columnarMap.getValuesBlock());
                };
                fieldBuilders = type.getTypeParameters().stream()
                        .map(ColumnStatisticsValidation::new)
                        .collect(toImmutableList());
            }
            else if (type instanceof RowType) {
                statisticsBuilder = new CountStatisticsBuilder();
                fieldExtractor = block -> RowBlock.getRowFieldsFromBlock(block.getLoadedBlock());
                fieldBuilders = type.getTypeParameters().stream()
                        .map(ColumnStatisticsValidation::new)
                        .collect(toImmutableList());
            }
            else {
                // any other supported type has no nested fields
                statisticsBuilder = createPrimitiveStatisticsBuilder(type);
                fieldExtractor = _ -> ImmutableList.of();
                fieldBuilders = ImmutableList.of();
            }
        }

        // Chooses the statistics builder for a type without nested fields; throws for unsupported types.
        private StatisticsBuilder createPrimitiveStatisticsBuilder(Type columnType)
        {
            if (BOOLEAN.equals(columnType)) {
                return new BooleanStatisticsBuilder();
            }
            if (TINYINT.equals(columnType)) {
                // NOTE(review): TINYINT tracks counts only (no integer range), unlike the other integral
                // types — presumably mirrors what the writer records; confirm before changing
                return new CountStatisticsBuilder();
            }
            if (SMALLINT.equals(columnType) || INTEGER.equals(columnType) || BIGINT.equals(columnType)) {
                return new IntegerStatisticsBuilder(new NoOpBloomFilterBuilder());
            }
            if (DOUBLE.equals(columnType) || REAL.equals(columnType)) {
                return new DoubleStatisticsBuilder(new NoOpBloomFilterBuilder());
            }
            if (columnType instanceof VarcharType || columnType instanceof CharType) {
                return new StringStatisticsBuilder(stringStatisticsLimitInBytes, new NoOpBloomFilterBuilder());
            }
            if (VARBINARY.equals(columnType) || UUID.equals(columnType)) {
                return new BinaryStatisticsBuilder();
            }
            if (TIME_MICROS.equals(columnType)) {
                return new TimeMicrosStatisticsBuilder(new NoOpBloomFilterBuilder());
            }
            if (DATE.equals(columnType)) {
                return new DateStatisticsBuilder(new NoOpBloomFilterBuilder());
            }
            if (TIMESTAMP_MILLIS.equals(columnType) || TIMESTAMP_MICROS.equals(columnType)) {
                return new TimestampStatisticsBuilder(this::timestampMicrosToMillis);
            }
            if (TIMESTAMP_NANOS.equals(columnType)) {
                return new TimestampStatisticsBuilder(this::timestampNanosToMillis);
            }
            if (TIMESTAMP_TZ_MILLIS.equals(columnType)) {
                return new TimestampStatisticsBuilder(this::timestampTzShortToMillis);
            }
            if (TIMESTAMP_TZ_MICROS.equals(columnType) || TIMESTAMP_TZ_NANOS.equals(columnType)) {
                return new TimestampStatisticsBuilder(this::timestampTzLongToMillis);
            }
            if (columnType instanceof DecimalType decimalType) {
                if (decimalType.isShort()) {
                    return new ShortDecimalStatisticsBuilder(decimalType.getScale());
                }
                return new LongDecimalStatisticsBuilder();
            }
            throw new TrinoException(NOT_SUPPORTED, format("Unsupported Hive type: %s", columnType));
        }

        /**
         * Reads the {@code long} value at {@code position} and converts it from epoch microseconds
         * to epoch milliseconds. {@code floorDiv} is used so that pre-epoch values round down
         * rather than toward zero.
         */
        private long timestampMicrosToMillis(Type blockType, Block block, int position)
        {
            long epochMicros = blockType.getLong(block, position);
            return floorDiv(epochMicros, MICROSECONDS_PER_MILLISECOND);
        }

        /**
         * Reads the {@link LongTimestamp} at {@code position} and converts it to epoch milliseconds.
         * Only the epoch-microsecond component is consulted; any finer-grained part does not affect
         * the millisecond result.
         */
        private long timestampNanosToMillis(Type blockType, Block block, int position)
        {
            LongTimestamp timestamp = (LongTimestamp) blockType.getObject(block, position);
            long epochMicros = timestamp.getEpochMicros();
            return floorDiv(epochMicros, MICROSECONDS_PER_MILLISECOND);
        }

        /**
         * Reads the packed short timestamp-with-time-zone value at {@code position} and extracts
         * its UTC epoch milliseconds via {@code unpackMillisUtc}.
         */
        private long timestampTzShortToMillis(Type blockType, Block block, int position)
        {
            long packedValue = blockType.getLong(block, position);
            return unpackMillisUtc(packedValue);
        }

        /**
         * Reads the {@link LongTimestampWithTimeZone} at {@code position} and returns its epoch
         * milliseconds.
         */
        private long timestampTzLongToMillis(Type blockType, Block block, int position)
        {
            LongTimestampWithTimeZone timestampWithTimeZone = (LongTimestampWithTimeZone) blockType.getObject(block, position);
            return timestampWithTimeZone.getEpochMillis();
        }

        /**
         * Adds {@code block} to this column's statistics, then routes each nested field block
         * produced by {@code fieldExtractor} to the field validation at the same index.
         */
        private void addBlock(Block block)
        {
            statisticsBuilder.addBlock(type, block);

            // fieldExtractor is expected to return nested blocks in the same order as fieldBuilders
            // (e.g. key then value for maps, declared field order for rows)
            List<Block> fields = fieldExtractor.apply(block);
            for (int i = 0; i < fieldBuilders.size(); i++) {
                fieldBuilders.get(i).addBlock(fields.get(i));
            }
        }

        /**
         * Appends this column's statistics to {@code output}, followed by the statistics of its
         * nested fields, recursively (pre-order: parent before children).
         */
        private void build(ImmutableList.Builder<ColumnStatistics> output)
        {
            output.add(statisticsBuilder.buildColumnStatistics());
            // lambda parameter deliberately not named after the fieldBuilders field to avoid shadowing
            fieldBuilders.forEach(fieldBuilder -> fieldBuilder.build(output));
        }
    }

    /**
     * A {@link StatisticsBuilder} that only counts non-null values. In this file it is used for
     * ARRAY, MAP and ROW columns, whose nested fields get their own statistics.
     */
    private static class CountStatisticsBuilder
            implements StatisticsBuilder
    {
        // number of non-null positions seen across all blocks added so far
        private long rowCount;

        @Override
        public void addBlock(Type type, Block block)
        {
            int positionCount = block.getPositionCount();
            for (int current = 0; current < positionCount; current++) {
                if (block.isNull(current)) {
                    continue;
                }
                rowCount++;
            }
        }

        /**
         * Returns statistics carrying only the non-null value count; every other slot is zero or null.
         */
        @Override
        public ColumnStatistics buildColumnStatistics()
        {
            return new ColumnStatistics(rowCount, 0, null, null, null, null, null, null, null, null, null, null);
        }
    }

    /**
     * Statistics recorded for one row group. What is retained depends on the validation mode:
     * HASHED keeps only a hash, DETAILED keeps only the per-column statistics, and BOTH keeps both.
     */
    private static class RowGroupStatistics
    {
        private static final int INSTANCE_SIZE = instanceSize(RowGroupStatistics.class);

        private final OrcWriteValidationMode validationMode;
        // sorted by column id so iteration order (and therefore the hash) is deterministic; empty in HASHED mode
        private final SortedMap<OrcColumnId, ColumnStatistics> columnStatistics;
        // 0 in DETAILED mode, where no hash is computed
        private final long hash;

        public RowGroupStatistics(OrcWriteValidationMode validationMode, Map<OrcColumnId, ColumnStatistics> columnStatistics)
        {
            this.validationMode = requireNonNull(validationMode, "validationMode is null");
            requireNonNull(columnStatistics, "columnStatistics is null");

            switch (validationMode) {
                case HASHED -> {
                    this.columnStatistics = ImmutableSortedMap.of();
                    // a sorted copy is still needed so the hash does not depend on the input map's order
                    this.hash = hashColumnStatistics(ImmutableSortedMap.copyOf(columnStatistics));
                }
                case DETAILED -> {
                    this.columnStatistics = ImmutableSortedMap.copyOf(columnStatistics);
                    this.hash = 0;
                }
                case BOTH -> {
                    this.columnStatistics = ImmutableSortedMap.copyOf(columnStatistics);
                    this.hash = hashColumnStatistics(this.columnStatistics);
                }
                default -> throw new IllegalArgumentException("Unsupported validation mode");
            }
        }

        /**
         * Hashes the entry count followed by each (column id, statistics) pair in key order.
         */
        private static long hashColumnStatistics(SortedMap<OrcColumnId, ColumnStatistics> columnStatistics)
        {
            StatisticsHasher statisticsHasher = new StatisticsHasher();
            statisticsHasher.putInt(columnStatistics.size());
            for (Entry<OrcColumnId, ColumnStatistics> entry : columnStatistics.entrySet()) {
                statisticsHasher.putInt(entry.getKey().getId())
                        .putOptionalHashable(entry.getValue());
            }
            return statisticsHasher.hash();
        }

        public OrcWriteValidationMode getValidationMode()
        {
            return validationMode;
        }

        /**
         * Returns the per-column statistics; fails verification in HASHED mode, where they are not kept.
         */
        public Map<OrcColumnId, ColumnStatistics> getColumnStatistics()
        {
            verify(validationMode != HASHED, "columnStatistics are not available in HASHED mode");
            return columnStatistics;
        }

        /**
         * Returns the statistics hash, or 0 in DETAILED mode.
         */
        public long getHash()
        {
            return hash;
        }
    }

    public static class OrcWriteValidationBuilder
    {
        private static final int INSTANCE_SIZE = instanceSize(OrcWriteValidationBuilder.class);

        private final OrcWriteValidationMode validationMode;

        private List version;
        private CompressionKind compression;
        private ZoneId timeZone;
        private int rowGroupMaxRowCount;
        private int stringStatisticsLimitInBytes;
        private List columnNames;
        private final Map metadata = new HashMap<>();
        private final WriteChecksumBuilder checksum;
        private List currentRowGroupStatistics = new ArrayList<>();
        private final Map> rowGroupStatisticsByStripe = new HashMap<>();
        private final Map stripeStatistics = new HashMap<>();
        private Optional> fileStatistics = Optional.empty();
        private long retainedSize = INSTANCE_SIZE;

        public OrcWriteValidationBuilder(OrcWriteValidationMode validationMode, List types)
        {
            this.validationMode = validationMode;
            this.checksum = new WriteChecksumBuilder(types);
        }

        public long getRetainedSize()
        {
            return retainedSize;
        }

        public OrcWriteValidationBuilder setVersion(List version)
        {
            this.version = ImmutableList.copyOf(version);
            return this;
        }

        public void setCompression(CompressionKind compression)
        {
            this.compression = compression;
        }

        public void setTimeZone(ZoneId timeZone)
        {
            this.timeZone = timeZone;
        }

        public void setRowGroupMaxRowCount(int rowGroupMaxRowCount)
        {
            this.rowGroupMaxRowCount = rowGroupMaxRowCount;
        }

        public OrcWriteValidationBuilder setStringStatisticsLimitInBytes(int stringStatisticsLimitInBytes)
        {
            this.stringStatisticsLimitInBytes = stringStatisticsLimitInBytes;
            return this;
        }

        public OrcWriteValidationBuilder setColumnNames(List columnNames)
        {
            this.columnNames = ImmutableList.copyOf(requireNonNull(columnNames, "columnNames is null"));
            return this;
        }

        public OrcWriteValidationBuilder addMetadataProperty(String key, Slice value)
        {
            metadata.put(key, value);
            return this;
        }

        public OrcWriteValidationBuilder addStripe(int rowCount)
        {
            checksum.addStripe(rowCount);
            return this;
        }

        public OrcWriteValidationBuilder addPage(Page page)
        {
            checksum.addPage(page);
            return this;
        }

        public void addRowGroupStatistics(Map columnStatistics)
        {
            RowGroupStatistics rowGroupStatistics = new RowGroupStatistics(validationMode, columnStatistics);
            currentRowGroupStatistics.add(rowGroupStatistics);

            retainedSize += RowGroupStatistics.INSTANCE_SIZE;
            if (validationMode != HASHED) {
                for (ColumnStatistics statistics : rowGroupStatistics.getColumnStatistics().values()) {
                    retainedSize += Integer.BYTES + statistics.getRetainedSizeInBytes();
                }
            }
        }

        public void addStripeStatistics(long stripStartOffset, StripeStatistics columnStatistics)
        {
            stripeStatistics.put(stripStartOffset, columnStatistics);
            rowGroupStatisticsByStripe.put(stripStartOffset, currentRowGroupStatistics);
            currentRowGroupStatistics = new ArrayList<>();
        }

        public void setFileStatistics(Optional> fileStatistics)
        {
            this.fileStatistics = fileStatistics;
        }

        public OrcWriteValidation build()
        {
            return new OrcWriteValidation(
                    version,
                    compression,
                    timeZone,
                    rowGroupMaxRowCount,
                    columnNames,
                    metadata,
                    checksum.build(),
                    rowGroupStatisticsByStripe,
                    stripeStatistics,
                    fileStatistics,
                    stringStatisticsLimitInBytes);
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy