io.trino.parquet.ParquetWriteValidation Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of trino-parquet Show documentation
Show all versions of trino-parquet Show documentation
Trino - Parquet file format support
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.parquet;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.SizeOf;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.airlift.slice.XxHash64;
import io.trino.parquet.metadata.ColumnChunkMetadata;
import io.trino.parquet.metadata.IndexReference;
import io.trino.parquet.metadata.PrunedBlockMetadata;
import io.trino.parquet.reader.RowGroupInfo;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.type.Type;
import org.apache.parquet.column.ColumnDescriptor;
import org.apache.parquet.column.statistics.Statistics;
import org.apache.parquet.format.ColumnChunk;
import org.apache.parquet.format.ColumnMetaData;
import org.apache.parquet.format.RowGroup;
import org.apache.parquet.hadoop.metadata.ColumnPath;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.PrimitiveType;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.airlift.slice.SizeOf.SIZE_OF_INT;
import static io.airlift.slice.SizeOf.estimatedSizeOf;
import static io.airlift.slice.SizeOf.instanceSize;
import static io.airlift.slice.SizeOf.sizeOf;
import static io.trino.parquet.ColumnStatisticsValidation.ColumnStatistics;
import static io.trino.parquet.ParquetMetadataConverter.getPrimitive;
import static io.trino.parquet.ParquetValidationUtils.validateParquet;
import static io.trino.parquet.ParquetWriteValidation.IndexReferenceValidation.fromIndexReference;
import static java.util.Objects.requireNonNull;
public class ParquetWriteValidation
{
// Expected values recorded at write time; the validate* methods compare what is read back against these.
private final String createdBy;
private final Optional<String> timeZoneId;
private final List<ColumnDescriptor> columns;
private final List<RowGroup> rowGroups;
private final WriteChecksum checksum;
private final List<Type> types;
private final List<String> columnNames;

/**
 * Creates a validation holder. Instances are produced by {@code ParquetWriteValidationBuilder#build()}.
 *
 * @param createdBy writer identification string; must be non-null and non-empty
 * @param timeZoneId expected time zone id, possibly empty
 * @param columns expected column descriptors
 * @param rowGroups expected row group metadata
 * @param checksum expected row count and per-column hashes
 * @param types types of the written columns
 * @param columnNames names of the written columns
 * @throws NullPointerException if any argument is null
 * @throws IllegalArgumentException if {@code createdBy} is empty
 */
private ParquetWriteValidation(
        String createdBy,
        Optional<String> timeZoneId,
        List<ColumnDescriptor> columns,
        List<RowGroup> rowGroups,
        WriteChecksum checksum,
        List<Type> types,
        List<String> columnNames)
{
    this.createdBy = requireNonNull(createdBy, "createdBy is null");
    checkArgument(!createdBy.isEmpty(), "createdBy is empty");
    this.timeZoneId = requireNonNull(timeZoneId, "timeZoneId is null");
    // The message names the actual parameter (it previously said "columnPaths")
    this.columns = requireNonNull(columns, "columns is null");
    this.rowGroups = requireNonNull(rowGroups, "rowGroups is null");
    this.checksum = requireNonNull(checksum, "checksum is null");
    this.types = requireNonNull(types, "types is null");
    this.columnNames = requireNonNull(columnNames, "columnNames is null");
}
/**
 * Returns the {@code createdBy} string this validation was constructed with.
 */
public String getCreatedBy()
{
    return this.createdBy;
}
/**
 * Returns the column types this validation was constructed with.
 */
public List<Type> getTypes()
{
    return types;
}
/**
 * Returns the column names this validation was constructed with.
 */
public List<String> getColumnNames()
{
    return columnNames;
}
/**
 * Checks that the time zone id read back equals the expected one (both may be empty).
 *
 * @param dataSourceId identifies the file in the error message
 * @param actualTimeZoneId time zone id found when reading
 * @throws ParquetCorruptionException if the two optionals are not equal
 */
public void validateTimeZone(ParquetDataSourceId dataSourceId, Optional<String> actualTimeZoneId)
        throws ParquetCorruptionException
{
    validateParquet(
            timeZoneId.equals(actualTimeZoneId),
            dataSourceId,
            "Found unexpected time zone %s, expected %s",
            actualTimeZoneId,
            timeZoneId);
}
/**
 * Checks that the schema read back has the same number of columns as expected and that
 * each column descriptor matches the expected descriptor at the same index.
 *
 * @param dataSourceId identifies the file in error messages
 * @param schema schema found when reading
 * @throws ParquetCorruptionException on a count mismatch or the first descriptor mismatch
 */
public void validateColumns(ParquetDataSourceId dataSourceId, MessageType schema)
        throws ParquetCorruptionException
{
    List<ColumnDescriptor> actualColumns = schema.getColumns();
    validateParquet(
            actualColumns.size() == columns.size(),
            dataSourceId,
            "Found columns %s, expected %s",
            actualColumns,
            columns);
    // Sizes are equal at this point, so indexing both lists with the same bound is safe
    for (int columnIndex = 0; columnIndex < columns.size(); columnIndex++) {
        validateColumnDescriptorsSame(actualColumns.get(columnIndex), columns.get(columnIndex), dataSourceId);
    }
}
/**
 * Compares the row group metadata read back against the expected row groups.
 * <p>
 * The number of row groups must match; then, per row group, the row count and the
 * number of column chunks must match, and every column chunk is compared field by field.
 *
 * @param dataSourceId identifies the file in error messages
 * @param rowGroupInfos row groups found when reading
 * @throws ParquetCorruptionException on the first mismatch found
 */
public void validateBlocksMetadata(ParquetDataSourceId dataSourceId, List<RowGroupInfo> rowGroupInfos)
        throws ParquetCorruptionException
{
    validateParquet(
            rowGroupInfos.size() == rowGroups.size(),
            dataSourceId,
            "Number of row groups %d did not match %d",
            rowGroupInfos.size(),
            rowGroups.size());

    for (int rowGroupIndex = 0; rowGroupIndex < rowGroupInfos.size(); rowGroupIndex++) {
        PrunedBlockMetadata block = rowGroupInfos.get(rowGroupIndex).prunedBlockMetadata();
        RowGroup rowGroup = rowGroups.get(rowGroupIndex);
        validateParquet(
                block.getRowCount() == rowGroup.getNum_rows(),
                dataSourceId,
                "Number of rows %d in row group %d did not match %d",
                block.getRowCount(),
                rowGroupIndex,
                rowGroup.getNum_rows());

        List<ColumnChunkMetadata> columnChunkMetaData = block.getColumns();
        validateParquet(
                columnChunkMetaData.size() == rowGroup.getColumnsSize(),
                dataSourceId,
                "Number of columns %d in row group %d did not match %d",
                columnChunkMetaData.size(),
                rowGroupIndex,
                rowGroup.getColumnsSize());

        for (int columnIndex = 0; columnIndex < columnChunkMetaData.size(); columnIndex++) {
            validateColumnChunk(
                    dataSourceId,
                    rowGroupIndex,
                    columnChunkMetaData.get(columnIndex),
                    rowGroup.getColumns().get(columnIndex));
        }
    }
}

// Compares one column chunk read back against the expected chunk, field by field, failing on the first mismatch.
private static void validateColumnChunk(
        ParquetDataSourceId dataSourceId,
        int rowGroupIndex,
        ColumnChunkMetadata actualColumnMetadata,
        ColumnChunk columnChunk)
        throws ParquetCorruptionException
{
    ColumnMetaData expectedColumnMetadata = columnChunk.getMeta_data();
    verifyColumnMetadataMatch(
            actualColumnMetadata.getCodec().getParquetCompressionCodec().equals(expectedColumnMetadata.getCodec()),
            "Compression codec",
            actualColumnMetadata.getCodec(),
            actualColumnMetadata.getPath(),
            rowGroupIndex,
            dataSourceId,
            expectedColumnMetadata.getCodec());
    verifyColumnMetadataMatch(
            actualColumnMetadata.getPrimitiveType().getPrimitiveTypeName().equals(getPrimitive(expectedColumnMetadata.getType())),
            "Type",
            actualColumnMetadata.getPrimitiveType().getPrimitiveTypeName(),
            actualColumnMetadata.getPath(),
            rowGroupIndex,
            dataSourceId,
            expectedColumnMetadata.getType());
    verifyColumnMetadataMatch(
            areEncodingsSame(actualColumnMetadata.getEncodings(), expectedColumnMetadata.getEncodings()),
            "Encodings",
            actualColumnMetadata.getEncodings(),
            actualColumnMetadata.getPath(),
            rowGroupIndex,
            dataSourceId,
            expectedColumnMetadata.getEncodings());
    verifyColumnMetadataMatch(
            areStatisticsSame(actualColumnMetadata.getStatistics(), expectedColumnMetadata.getStatistics()),
            "Statistics",
            actualColumnMetadata.getStatistics(),
            actualColumnMetadata.getPath(),
            rowGroupIndex,
            dataSourceId,
            expectedColumnMetadata.getStatistics());
    verifyColumnMetadataMatch(
            actualColumnMetadata.getFirstDataPageOffset() == expectedColumnMetadata.getData_page_offset(),
            "Data page offset",
            actualColumnMetadata.getFirstDataPageOffset(),
            actualColumnMetadata.getPath(),
            rowGroupIndex,
            dataSourceId,
            expectedColumnMetadata.getData_page_offset());
    verifyColumnMetadataMatch(
            actualColumnMetadata.getDictionaryPageOffset() == expectedColumnMetadata.getDictionary_page_offset(),
            "Dictionary page offset",
            actualColumnMetadata.getDictionaryPageOffset(),
            actualColumnMetadata.getPath(),
            rowGroupIndex,
            dataSourceId,
            expectedColumnMetadata.getDictionary_page_offset());
    verifyColumnMetadataMatch(
            actualColumnMetadata.getValueCount() == expectedColumnMetadata.getNum_values(),
            "Value count",
            actualColumnMetadata.getValueCount(),
            actualColumnMetadata.getPath(),
            rowGroupIndex,
            dataSourceId,
            expectedColumnMetadata.getNum_values());
    verifyColumnMetadataMatch(
            actualColumnMetadata.getTotalUncompressedSize() == expectedColumnMetadata.getTotal_uncompressed_size(),
            "Total uncompressed size",
            actualColumnMetadata.getTotalUncompressedSize(),
            actualColumnMetadata.getPath(),
            rowGroupIndex,
            dataSourceId,
            expectedColumnMetadata.getTotal_uncompressed_size());
    verifyColumnMetadataMatch(
            actualColumnMetadata.getTotalSize() == expectedColumnMetadata.getTotal_compressed_size(),
            "Total size",
            actualColumnMetadata.getTotalSize(),
            actualColumnMetadata.getPath(),
            rowGroupIndex,
            dataSourceId,
            expectedColumnMetadata.getTotal_compressed_size());

    // A null index reference on the read side is accepted without comparison
    IndexReferenceValidation expectedColumnIndexReference = new IndexReferenceValidation(columnChunk.getColumn_index_offset(), columnChunk.getColumn_index_length());
    IndexReference actualColumnIndexReference = actualColumnMetadata.getColumnIndexReference();
    verifyColumnMetadataMatch(
            actualColumnIndexReference == null || fromIndexReference(actualColumnIndexReference).equals(expectedColumnIndexReference),
            "Column index reference",
            actualColumnIndexReference,
            actualColumnMetadata.getPath(),
            rowGroupIndex,
            dataSourceId,
            expectedColumnIndexReference);

    IndexReferenceValidation expectedOffsetIndexReference = new IndexReferenceValidation(columnChunk.getOffset_index_offset(), columnChunk.getOffset_index_length());
    IndexReference actualOffsetIndexReference = actualColumnMetadata.getOffsetIndexReference();
    verifyColumnMetadataMatch(
            actualOffsetIndexReference == null || fromIndexReference(actualOffsetIndexReference).equals(expectedOffsetIndexReference),
            "Offset index reference",
            actualOffsetIndexReference,
            actualColumnMetadata.getPath(),
            rowGroupIndex,
            dataSourceId,
            expectedOffsetIndexReference);
}
/**
 * Compares a checksum computed while reading against the checksum recorded while writing.
 * <p>
 * The total row count, the number of column hashes, and each per-column hash must match.
 *
 * @param dataSourceId identifies the file in error messages
 * @param actualChecksum checksum computed from the data read back
 * @throws ParquetCorruptionException on the first mismatch found
 */
public void validateChecksum(ParquetDataSourceId dataSourceId, WriteChecksum actualChecksum)
        throws ParquetCorruptionException
{
    validateParquet(
            checksum.totalRowCount() == actualChecksum.totalRowCount(),
            dataSourceId,
            "Write validation failed: Expected row count %d, found %d",
            checksum.totalRowCount(),
            actualChecksum.totalRowCount());

    List<Long> expectedHashes = checksum.columnHashes();
    List<Long> actualHashes = actualChecksum.columnHashes();
    // Without this check, extra actual hashes caused an IndexOutOfBoundsException below
    // and missing ones were silently accepted
    validateParquet(
            expectedHashes.size() == actualHashes.size(),
            dataSourceId,
            "Write validation failed: Expected %d column hashes, found %d",
            expectedHashes.size(),
            actualHashes.size());

    for (int columnIndex = 0; columnIndex < actualHashes.size(); columnIndex++) {
        // Unbox explicitly so the comparison is on values, not on Long references
        long expectedHash = expectedHashes.get(columnIndex);
        long actualHash = actualHashes.get(columnIndex);
        validateParquet(
                expectedHash == actualHash,
                dataSourceId,
                "Invalid checksum for column %s: Expected hash %d, found %d",
                columnIndex,
                expectedHash,
                actualHash);
    }
}
/**
 * Total row count plus one hash per column, as produced by {@code WriteChecksumBuilder#build()}.
 *
 * @param totalRowCount number of rows covered by the checksum
 * @param columnHashes one hash value per column; copied into an immutable list
 */
public record WriteChecksum(long totalRowCount, List<Long> columnHashes)
{
    public WriteChecksum
    {
        // Defensive immutable copy; also rejects a null list with a descriptive message
        columnHashes = ImmutableList.copyOf(requireNonNull(columnHashes, "columnHashes is null"));
    }
}
/**
 * Accumulates a {@code WriteChecksum}: a running row count and one {@code XxHash64} per column,
 * each fed with the per-position hashes of that column's blocks.
 */
public static class WriteChecksumBuilder
{
    private final List<ValidationHash> validationHashes;
    private final List<XxHash64> columnHashes;
    // Reusable 8-byte scratch buffer; longSlice wraps it so setLong() fills the bytes handed to XxHash64
    private final byte[] longBuffer = new byte[Long.BYTES];
    private final Slice longSlice = Slices.wrappedBuffer(longBuffer);
    private long totalRowCount;

    private WriteChecksumBuilder(List<Type> types)
    {
        requireNonNull(types, "types is null");
        this.validationHashes = types.stream()
                .map(ValidationHash::createValidationHash)
                .collect(toImmutableList());
        // One independent hasher per column
        this.columnHashes = types.stream()
                .map(type -> new XxHash64())
                .collect(toImmutableList());
    }

    /**
     * Creates a builder with one hash accumulator per entry of {@code readTypes}.
     *
     * @throws NullPointerException if {@code readTypes} is null
     */
    public static WriteChecksumBuilder createWriteChecksumBuilder(List<Type> readTypes)
    {
        return new WriteChecksumBuilder(readTypes);
    }

    /**
     * Folds every position of every channel of {@code page} into the matching column hash
     * and adds the page's position count to the total row count.
     *
     * @throws NullPointerException if {@code page} is null
     * @throws IllegalArgumentException if the page's channel count differs from the column count
     */
    public void addPage(Page page)
    {
        requireNonNull(page, "page is null");
        checkArgument(
                page.getChannelCount() == columnHashes.size(),
                "Invalid page: page channels count %s did not match columns count %s",
                page.getChannelCount(),
                columnHashes.size());

        for (int channel = 0; channel < columnHashes.size(); channel++) {
            ValidationHash validationHash = validationHashes.get(channel);
            Block block = page.getBlock(channel);
            XxHash64 xxHash64 = columnHashes.get(channel);
            for (int position = 0; position < block.getPositionCount(); position++) {
                long hash = validationHash.hash(block, position);
                longSlice.setLong(0, hash);
                xxHash64.update(longBuffer);
            }
        }
        totalRowCount += page.getPositionCount();
    }

    /**
     * Returns the row count and the current value of each column hash.
     */
    public WriteChecksum build()
    {
        return new WriteChecksum(
                totalRowCount,
                columnHashes.stream()
                        .map(XxHash64::hash)
                        .collect(toImmutableList()));
    }
}
/**
 * Compares statistics gathered while reading a row group against the row group's column chunk metadata.
 * <p>
 * For each column chunk the value count must match; when the chunk's Parquet statistics carry
 * a null count, it must equal the gathered {@code nonLeafValuesCount()}.
 *
 * @param dataSourceId identifies the file in error messages
 * @param blockMetaData metadata of the row group being validated
 * @param actualColumnStatistics statistics gathered while reading, one entry per column chunk
 * @throws IllegalArgumentException if the two lists differ in size
 * @throws ParquetCorruptionException on the first count mismatch
 */
public void validateRowGroupStatistics(ParquetDataSourceId dataSourceId, PrunedBlockMetadata blockMetaData, List<ColumnStatistics> actualColumnStatistics)
        throws ParquetCorruptionException
{
    List<ColumnChunkMetadata> columnChunks = blockMetaData.getColumns();
    checkArgument(
            columnChunks.size() == actualColumnStatistics.size(),
            "Column chunk metadata count %s did not match column fields count %s",
            columnChunks.size(),
            actualColumnStatistics.size());

    for (int columnIndex = 0; columnIndex < columnChunks.size(); columnIndex++) {
        ColumnChunkMetadata columnMetaData = columnChunks.get(columnIndex);
        ColumnStatistics columnStatistics = actualColumnStatistics.get(columnIndex);
        long expectedValuesCount = columnMetaData.getValueCount();
        validateParquet(
                expectedValuesCount == columnStatistics.valuesCount(),
                dataSourceId,
                "Invalid values count for column %s: Expected %d, found %d",
                columnIndex,
                expectedValuesCount,
                columnStatistics.valuesCount());

        Statistics<?> parquetStatistics = columnMetaData.getStatistics();
        // The null count is only checked when the file actually recorded one
        if (parquetStatistics.isNumNullsSet()) {
            long expectedNullsCount = parquetStatistics.getNumNulls();
            validateParquet(
                    expectedNullsCount == columnStatistics.nonLeafValuesCount(),
                    dataSourceId,
                    "Invalid nulls count for column %s: Expected %d, found %d",
                    columnIndex,
                    expectedNullsCount,
                    columnStatistics.nonLeafValuesCount());
        }
    }
}
/**
 * Gathers per-column statistics from pages, with one {@code ColumnStatisticsValidation} per type.
 * {@link #reset()} discards what has been gathered and starts over with fresh accumulators.
 */
public static class StatisticsValidation
{
    private final List<Type> types;
    private List<ColumnStatisticsValidation> columnStatisticsValidations;

    private StatisticsValidation(List<Type> types)
    {
        // Copy so that reset() keeps using the types supplied at construction even if the caller's list changes
        this.types = ImmutableList.copyOf(requireNonNull(types, "types is null"));
        this.columnStatisticsValidations = createColumnValidations(this.types);
    }

    /**
     * Creates an instance with one accumulator per entry of {@code readTypes}.
     *
     * @throws NullPointerException if {@code readTypes} is null
     */
    public static StatisticsValidation createStatisticsValidationBuilder(List<Type> readTypes)
    {
        return new StatisticsValidation(readTypes);
    }

    /**
     * Feeds each channel's block of {@code page} to the accumulator of the same index.
     *
     * @throws NullPointerException if {@code page} is null
     * @throws IllegalArgumentException if the page's channel count differs from the column count
     */
    public void addPage(Page page)
    {
        requireNonNull(page, "page is null");
        checkArgument(
                page.getChannelCount() == columnStatisticsValidations.size(),
                "Invalid page: page channels count %s did not match columns count %s",
                page.getChannelCount(),
                columnStatisticsValidations.size());

        for (int channel = 0; channel < columnStatisticsValidations.size(); channel++) {
            ColumnStatisticsValidation columnStatisticsValidation = columnStatisticsValidations.get(channel);
            columnStatisticsValidation.addBlock(page.getBlock(channel));
        }
    }

    /**
     * Replaces all accumulators with fresh ones built from the original types.
     */
    public void reset()
    {
        this.columnStatisticsValidations = createColumnValidations(types);
    }

    /**
     * Returns the statistics of all accumulators, concatenated in column order.
     */
    public List<ColumnStatistics> build()
    {
        return this.columnStatisticsValidations.stream()
                .flatMap(validation -> validation.build().stream())
                .collect(toImmutableList());
    }

    // Shared by the constructor and reset() so both create accumulators the same way.
    private static List<ColumnStatisticsValidation> createColumnValidations(List<Type> types)
    {
        return types.stream()
                .map(ColumnStatisticsValidation::new)
                .collect(toImmutableList());
    }
}
/**
 * Collects what is needed to build a {@code ParquetWriteValidation} and keeps a running estimate
 * of the memory retained by the collected values.
 * <p>
 * Row groups passed to {@link #setRowGroups} are not included in the retained size estimate.
 */
public static class ParquetWriteValidationBuilder
{
    private static final int INSTANCE_SIZE = instanceSize(ParquetWriteValidationBuilder.class);
    private static final int COLUMN_DESCRIPTOR_INSTANCE_SIZE = instanceSize(ColumnDescriptor.class);
    private static final int PRIMITIVE_TYPE_INSTANCE_SIZE = instanceSize(PrimitiveType.class);

    private final List<Type> types;
    private final List<String> columnNames;
    private final WriteChecksumBuilder checksum;

    private String createdBy;
    private Optional<String> timeZoneId = Optional.empty();
    private List<ColumnDescriptor> columns;
    private List<RowGroup> rowGroups;
    private long retainedSize = INSTANCE_SIZE;

    /**
     * @param types types of the written columns
     * @param columnNames names of the written columns, same length as {@code types}
     * @throws NullPointerException if either argument is null
     * @throws IllegalArgumentException if the two lists differ in size
     */
    public ParquetWriteValidationBuilder(List<Type> types, List<String> columnNames)
    {
        this.types = ImmutableList.copyOf(requireNonNull(types, "types is null"));
        this.columnNames = ImmutableList.copyOf(requireNonNull(columnNames, "columnNames is null"));
        checkArgument(
                types.size() == columnNames.size(),
                "Types count %s did not match column names count %s",
                types.size(),
                columnNames.size());
        this.checksum = new WriteChecksumBuilder(types);
        // Type instances contribute 0 per element; only the list overhead is counted for them
        retainedSize += estimatedSizeOf(types, type -> 0)
                + estimatedSizeOf(columnNames, SizeOf::estimatedSizeOf);
    }

    /**
     * Returns the current retained size estimate in bytes.
     */
    public long getRetainedSize()
    {
        return retainedSize;
    }

    /**
     * Records the writer identification string and adds its estimated size to the retained size.
     */
    public void setCreatedBy(String createdBy)
    {
        this.createdBy = createdBy;
        retainedSize += estimatedSizeOf(createdBy);
    }

    /**
     * Records the expected time zone id and, when present, adds its estimated size to the retained size.
     *
     * @throws NullPointerException if {@code timeZoneId} is null
     */
    public void setTimeZone(Optional<String> timeZoneId)
    {
        // Check before assigning so a null argument cannot leave the field holding null
        this.timeZoneId = requireNonNull(timeZoneId, "timeZoneId is null");
        timeZoneId.ifPresent(id -> retainedSize += estimatedSizeOf(id));
    }

    /**
     * Records an immutable copy of the expected column descriptors and adds their estimated size
     * to the retained size.
     *
     * @throws NullPointerException if {@code columns} is null
     */
    public void setColumns(List<ColumnDescriptor> columns)
    {
        this.columns = ImmutableList.copyOf(requireNonNull(columns, "columns is null"));
        retainedSize += estimatedSizeOf(columns, descriptor ->
                COLUMN_DESCRIPTOR_INSTANCE_SIZE
                        + (2 * SIZE_OF_INT) // maxRep, maxDef
                        + estimatedSizeOfStringArray(descriptor.getPath())
                        + PRIMITIVE_TYPE_INSTANCE_SIZE
                        + (3 * SIZE_OF_INT)); // primitive, length, columnOrder
    }

    /**
     * Records an immutable copy of the expected row groups.
     *
     * @throws NullPointerException if {@code rowGroups} is null
     */
    public void setRowGroups(List<RowGroup> rowGroups)
    {
        this.rowGroups = ImmutableList.copyOf(requireNonNull(rowGroups, "rowGroups is null"));
    }

    /**
     * Adds a written page to the running checksum.
     */
    public void addPage(Page page)
    {
        checksum.addPage(page);
    }

    /**
     * Builds the validation from the values collected so far; the {@code ParquetWriteValidation}
     * constructor rejects any value that was never set.
     */
    public ParquetWriteValidation build()
    {
        return new ParquetWriteValidation(
                createdBy,
                timeZoneId,
                columns,
                rowGroups,
                checksum.build(),
                types,
                columnNames);
    }
}
// IndexReference offers neither equals nor toString, so its offset and length are copied
// into this value class for comparison and for readable error messages
static class IndexReferenceValidation
{
    private final long offset;
    private final int length;

    private IndexReferenceValidation(long offset, int length)
    {
        this.offset = offset;
        this.length = length;
    }

    // Copies the offset and length out of the given reference.
    static IndexReferenceValidation fromIndexReference(IndexReference indexReference)
    {
        return new IndexReferenceValidation(indexReference.getOffset(), indexReference.getLength());
    }

    @Override
    public boolean equals(Object o)
    {
        if (o == this) {
            return true;
        }
        if (o == null) {
            return false;
        }
        if (o.getClass() != getClass()) {
            return false;
        }
        IndexReferenceValidation other = (IndexReferenceValidation) o;
        if (offset != other.offset) {
            return false;
        }
        return length == other.length;
    }

    @Override
    public int hashCode()
    {
        return Objects.hash(offset, length);
    }

    @Override
    public String toString()
    {
        return toStringHelper(this)
                .add("offset", offset)
                .add("length", length)
                .toString();
    }
}
/**
 * Throws a corruption exception describing a column chunk metadata mismatch when {@code condition} is false.
 *
 * @param condition result of the comparison; nothing happens when true
 * @param name label of the compared property, used as the start of the message
 * @param actual value read back
 * @param path column the chunk belongs to
 * @param rowGroup index of the row group the chunk belongs to
 * @param dataSourceId identifies the file in the error message
 * @param expected value recorded at write time
 * @throws ParquetCorruptionException if {@code condition} is false
 */
private static <T, U> void verifyColumnMetadataMatch(
        boolean condition,
        String name,
        T actual,
        ColumnPath path,
        int rowGroup,
        ParquetDataSourceId dataSourceId,
        U expected)
        throws ParquetCorruptionException
{
    if (condition) {
        return;
    }
    throw new ParquetCorruptionException(
            dataSourceId,
            "%s [%s] for column %s in row group %d did not match [%s]",
            name,
            actual,
            path,
            rowGroup,
            expected);
}
// Converts the expected encodings with ParquetMetadataConverter#getEncoding, collects them into a set,
// and compares that set with the actual one; order and duplicates in the expected list do not matter.
private static boolean areEncodingsSame(Set<org.apache.parquet.column.Encoding> actual, List<org.apache.parquet.format.Encoding> expected)
{
    Set<org.apache.parquet.column.Encoding> expectedEncodings = expected.stream()
            .map(ParquetMetadataConverter::getEncoding)
            .collect(toImmutableSet());
    return actual.equals(expectedEncodings);
}
private static boolean areStatisticsSame(org.apache.parquet.column.statistics.Statistics> actual, org.apache.parquet.format.Statistics expected)
{
Statistics.Builder expectedStatsBuilder = Statistics.getBuilderForReading(actual.type());
if (expected.isSetNull_count()) {
expectedStatsBuilder.withNumNulls(expected.getNull_count());
}
if (expected.isSetMin_value()) {
expectedStatsBuilder.withMin(expected.getMin_value());
}
if (expected.isSetMax_value()) {
expectedStatsBuilder.withMax(expected.getMax_value());
}
return actual.equals(expectedStatsBuilder.build());
}
/**
 * Checks that a column descriptor read back matches the expected descriptor: path (compared
 * against the lower-cased expected path), max definition level, max repetition level and primitive type.
 *
 * @param actual descriptor read back
 * @param expected descriptor recorded at write time
 * @param dataSourceId identifies the file in error messages
 * @throws ParquetCorruptionException on the first mismatch found
 */
private static void validateColumnDescriptorsSame(ColumnDescriptor actual, ColumnDescriptor expected, ParquetDataSourceId dataSourceId)
        throws ParquetCorruptionException
{
    // Paths are String[]; formatting an array with %s prints its identity ("[Ljava.lang.String;@..."),
    // so messages use list views, which render as "[a, b]" without copying the array
    List<String> actualPath = Arrays.asList(actual.getPath());
    List<String> expectedPath = Arrays.asList(expected.getPath());

    // Column names are lower-cased by MetadataReader#readFooter
    String[] expectedLowerCasePath = Arrays.stream(expected.getPath())
            .map(field -> field.toLowerCase(Locale.ENGLISH))
            .toArray(String[]::new);
    validateParquet(
            Arrays.equals(actual.getPath(), expectedLowerCasePath),
            dataSourceId,
            "Column path %s did not match expected column path %s",
            actualPath,
            expectedPath);

    validateParquet(
            actual.getMaxDefinitionLevel() == expected.getMaxDefinitionLevel(),
            dataSourceId,
            "Column %s max definition level %d did not match expected max definition level %d",
            actualPath,
            actual.getMaxDefinitionLevel(),
            expected.getMaxDefinitionLevel());

    validateParquet(
            actual.getMaxRepetitionLevel() == expected.getMaxRepetitionLevel(),
            dataSourceId,
            "Column %s max repetition level %d did not match expected max repetition level %d",
            actualPath,
            actual.getMaxRepetitionLevel(),
            expected.getMaxRepetitionLevel());

    PrimitiveType actualPrimitiveType = actual.getPrimitiveType();
    PrimitiveType expectedPrimitiveType = expected.getPrimitiveType();
    // We don't use PrimitiveType#equals directly because column names are lower-cased by MetadataReader#readFooter
    validateParquet(
            actualPrimitiveType.getPrimitiveTypeName().equals(expectedPrimitiveType.getPrimitiveTypeName())
                    && actualPrimitiveType.getTypeLength() == expectedPrimitiveType.getTypeLength()
                    && actualPrimitiveType.getRepetition().equals(expectedPrimitiveType.getRepetition())
                    && actualPrimitiveType.getName().equals(expectedPrimitiveType.getName().toLowerCase(Locale.ENGLISH))
                    && Objects.equals(actualPrimitiveType.getLogicalTypeAnnotation(), expectedPrimitiveType.getLogicalTypeAnnotation()),
            dataSourceId,
            "Column %s primitive type %s did not match expected primitive type %s",
            actualPath,
            actualPrimitiveType,
            expectedPrimitiveType);
}
// Estimated retained size of a string array: the array itself plus each of its strings.
private static long estimatedSizeOfStringArray(String[] path)
{
    long total = sizeOf(path);
    for (int index = 0; index < path.length; index++) {
        total += estimatedSizeOf(path[index]);
    }
    return total;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy