All downloads are free. Search and download functionality uses the official Maven repository.

io.trino.hive.formats.avro.AvroPagePositionDataWriter Maven / Gradle / Ivy

There is a newer version: 465
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.hive.formats.avro;

import com.google.common.collect.ImmutableMap;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.SqlMap;
import io.trino.spi.block.SqlRow;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.MapType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericDatumWriter;
import org.apache.avro.io.DatumWriter;
import org.apache.avro.io.Encoder;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.function.IntFunction;
import java.util.function.ToIntBiFunction;
import java.util.function.ToLongBiFunction;
import java.util.stream.IntStream;

import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.hive.formats.avro.AvroTypeUtils.SimpleUnionNullIndex;
import static io.trino.hive.formats.avro.AvroTypeUtils.getSimpleNullableUnionNullIndex;
import static io.trino.hive.formats.avro.AvroTypeUtils.isSimpleNullableUnion;
import static io.trino.hive.formats.avro.AvroTypeUtils.lowerCaseAllFieldsForWriter;
import static io.trino.hive.formats.avro.AvroTypeUtils.unwrapNullableUnion;
import static io.trino.hive.formats.avro.AvroTypeUtils.verifyNoCircularReferences;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.RealType.REAL;
import static io.trino.spi.type.SmallintType.SMALLINT;
import static io.trino.spi.type.TinyintType.TINYINT;
import static io.trino.spi.type.VarbinaryType.VARBINARY;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static java.util.Objects.requireNonNull;

public class AvroPagePositionDataWriter
        implements DatumWriter
{
    private final Schema schema;
    private final RecordBlockPositionEncoder pageBlockPositionEncoder;
    // page whose positions are currently being written; null until setPage is called
    private Page page;

    /**
     * Creates a writer that serializes individual positions of Trino pages as Avro records of {@code schema}.
     *
     * @param schema the Avro record schema to write; checked with {@code verifyNoCircularReferences}
     * @param avroTypeManager consulted for per-type encoding overrides when building the field encoders
     * @param channelNames the Avro field name that each page channel supplies, in channel order
     * @param channelTypes the Trino type of each page channel, in channel order
     * @throws AvroTypeException if a channel's Trino type cannot be written with its field's Avro schema
     */
    public AvroPagePositionDataWriter(Schema schema, AvroTypeManager avroTypeManager, List<String> channelNames, List<Type> channelTypes)
            throws AvroTypeException
    {
        this.schema = requireNonNull(schema, "schema is null");
        verifyNoCircularReferences(this.schema);
        this.pageBlockPositionEncoder = new RecordBlockPositionEncoder(this.schema, avroTypeManager, channelNames, channelTypes);
        checkInvariants();
    }

    /**
     * Accepts the schema supplied by the Avro framework. The schema is fixed at construction time, so a
     * different schema instance is only tolerated when, after {@code lowerCaseAllFieldsForWriter}, it equals
     * the configured one.
     */
    @Override
    public void setSchema(Schema schema)
    {
        requireNonNull(schema, "schema is null");
        if (this.schema == schema) {
            // the very same instance: nothing to compare
            return;
        }
        verify(this.schema.equals(lowerCaseAllFieldsForWriter(schema)), "Unable to change schema for this data writer");
    }

    /**
     * Sets the page whose positions subsequent {@link #write} calls encode.
     * <p>
     * The page is validated before any state is changed. A rejected page therefore leaves the writer
     * bound to the previously set page (if any) and to that page's blocks.
     *
     * @throws NullPointerException if {@code page} is null
     * @throws com.google.common.base.VerifyException if the page's channel count differs from the schema's field count
     */
    public void setPage(Page page)
    {
        requireNonNull(page, "page is null");
        // Validate before assigning. Assigning this.page first (as before) left the writer pointing at a rejected
        // page while the channel encoders still held the previous page's blocks.
        verify(page.getChannelCount() == schema.getFields().size(), "Page channel count must equal schema field count");
        pageBlockPositionEncoder.setChannelBlocksFromPage(page);
        this.page = page;
    }

    /**
     * Verifies the writer's structural invariants. The schema must be an Avro record, and once a page is set
     * it must have exactly one channel per record field.
     */
    private void checkInvariants()
    {
        boolean isRecord = schema.getType() == Schema.Type.RECORD;
        verify(isRecord, "Can only write pages to record schema");
        boolean channelsMatchFields = page == null || page.getChannelCount() == schema.getFields().size();
        verify(channelsMatchFields, "Page channel count must equal schema field count");
    }

    /**
     * Encodes the row at {@code position} of the current page as one Avro record.
     *
     * @param position zero-based position within the page set by {@link #setPage}
     * @param encoder destination of the encoded record
     * @throws IllegalStateException if no page has been set
     * @throws NullPointerException if {@code position} is null
     * @throws IndexOutOfBoundsException if {@code position} is negative or not less than the page's position count
     * @throws IOException if a value cannot be encoded, e.g. a null value for a non-nullable field
     */
    @Override
    public void write(Integer position, Encoder encoder)
            throws IOException
    {
        checkWritable();
        requireNonNull(position, "position is null");
        int index = position;
        // negative positions used to slip through this check and fail later inside the block accessors
        if (index < 0 || index >= page.getPositionCount()) {
            throw new IndexOutOfBoundsException("Position %s not within page with position count %s".formatted(index, page.getPositionCount()));
        }
        pageBlockPositionEncoder.encodePositionInEachChannel(index, encoder);
    }

    /**
     * Fails with {@link IllegalStateException} while no page has been set via {@link #setPage}.
     */
    private void checkWritable()
    {
        boolean pageSet = page != null;
        checkState(pageSet, "page must be set before beginning to write positions");
    }

    /**
     * Base class for encoders that write one position of a Trino {@link Block} to an Avro {@link Encoder}.
     * <p>
     * With a null index, the value is written as a simple nullable union. The union branch index comes
     * first, followed by either an Avro null or the encoded value. Without a null index, the value is written
     * directly and null values are rejected.
     */
    private abstract static class BlockPositionEncoder
    {
        // block positions are read from; must be set via setBlock before encode is called
        protected Block block;
        // union branch index of null when the field's schema is a simple nullable union, empty otherwise
        private final Optional<SimpleUnionNullIndex> nullIndex;

        public BlockPositionEncoder(Optional<SimpleUnionNullIndex> nullIndex)
        {
            this.nullIndex = requireNonNull(nullIndex, "nullIdx is null");
        }

        /**
         * Writes the non-null value at {@code position} of {@link #block}.
         */
        abstract void encodeFromBlock(int position, Encoder encoder)
                throws IOException;

        /**
         * Writes the value at {@code position}, preceded by its union branch index when the schema is nullable.
         *
         * @throws IllegalStateException if no block has been set
         * @throws IOException if the value is null but the schema is not nullable, or if encoding fails
         */
        void encode(int position, Encoder encoder)
                throws IOException
        {
            checkState(block != null, "block must be set before calling encode");
            boolean valueIsNull = block.isNull(position);
            if (nullIndex.isEmpty()) {
                if (valueIsNull) {
                    throw new IOException("Can not write null value for non-nullable schema");
                }
                encodeFromBlock(position, encoder);
                return;
            }
            // the non-null branch is taken as 1 ^ nullBranch, i.e. the other of branches 0 and 1
            int nullBranch = nullIndex.get().getIndex();
            if (valueIsNull) {
                encoder.writeIndex(nullBranch);
                encoder.writeNull();
            }
            else {
                encoder.writeIndex(1 ^ nullBranch);
                encodeFromBlock(position, encoder);
            }
        }

        void setBlock(Block block)
        {
            this.block = block;
        }
    }

    /**
     * Creates an encoder for {@code type} written with {@code schema}, without a caller-supplied null index.
     * If {@code schema} is itself a simple nullable union, the four-argument overload determines the null index.
     */
    private static BlockPositionEncoder createBlockPositionEncoder(Schema schema, AvroTypeManager avroTypeManager, Type type)
            throws AvroTypeException
    {
        Optional<SimpleUnionNullIndex> noNullBranch = Optional.empty();
        return createBlockPositionEncoder(schema, avroTypeManager, type, noNullBranch);
    }

    /**
     * Creates an encoder that writes values of Trino {@code type} using Avro {@code schema}.
     * <p>
     * A conversion returned by {@code avroTypeManager.overrideBlockToAvroObject} takes precedence. Otherwise
     * the Avro schema type selects a built-in encoder, provided the Trino type is one it supports. A simple
     * nullable union is unwrapped, and the returned encoder records the index of its null branch.
     *
     * @param nullIdx union branch index of null, when the value is wrapped in a simple nullable union
     * @throws AvroTypeException if the schema is {@code NULL}, a union that is not a simple nullable union,
     *         or not compatible with {@code type}
     */
    private static BlockPositionEncoder createBlockPositionEncoder(Schema schema, AvroTypeManager avroTypeManager, Type type, Optional<SimpleUnionNullIndex> nullIdx)
            throws AvroTypeException
    {
        Optional<BiFunction<Block, Integer, Object>> userDefinedConversion = avroTypeManager.overrideBlockToAvroObject(schema, type);
        if (userDefinedConversion.isPresent()) {
            return new UserDefinedBlockPositionEncoder(nullIdx, schema, userDefinedConversion.get());
        }

        Schema.Type avroType = schema.getType();
        if (avroType == Schema.Type.NULL) {
            throw new AvroTypeException("No null support outside of union");
        }
        if (avroType == Schema.Type.UNION) {
            if (!isSimpleNullableUnion(schema)) {
                throw new AvroTypeException("Unable to make writer for schema with non simple nullable union %s".formatted(schema));
            }
            // recurse on the non-null branch, remembering where the null branch sits in the union
            return createBlockPositionEncoder(unwrapNullableUnion(schema), avroTypeManager, type, Optional.of(getSimpleNullableUnionNullIndex(schema)));
        }

        // a null result means no built-in encoder supports this Avro type for this Trino type
        BlockPositionEncoder encoder = switch (avroType) {
            case BOOLEAN -> BOOLEAN.equals(type) ? new BooleanBlockPositionEncoder(nullIdx) : null;
            case INT -> {
                if (TINYINT.equals(type)) {
                    yield new IntBlockPositionEncoder(nullIdx, TINYINT::getByte);
                }
                if (SMALLINT.equals(type)) {
                    yield new IntBlockPositionEncoder(nullIdx, SMALLINT::getShort);
                }
                if (INTEGER.equals(type)) {
                    yield new IntBlockPositionEncoder(nullIdx, INTEGER::getInt);
                }
                yield null;
            }
            case LONG -> {
                if (TINYINT.equals(type)) {
                    yield new LongBlockPositionEncoder(nullIdx, TINYINT::getByte);
                }
                if (SMALLINT.equals(type)) {
                    yield new LongBlockPositionEncoder(nullIdx, SMALLINT::getShort);
                }
                if (INTEGER.equals(type)) {
                    yield new LongBlockPositionEncoder(nullIdx, INTEGER::getInt);
                }
                if (BIGINT.equals(type)) {
                    yield new LongBlockPositionEncoder(nullIdx, BIGINT::getLong);
                }
                yield null;
            }
            case FLOAT -> REAL.equals(type) ? new FloatBlockPositionEncoder(nullIdx) : null;
            case DOUBLE -> DOUBLE.equals(type) ? new DoubleBlockPositionEncoder(nullIdx) : null;
            case STRING -> VARCHAR.equals(type) ? new StringPositionEncoder(nullIdx) : null;
            case BYTES -> VARBINARY.equals(type) ? new BytesPositionEncoder(nullIdx) : null;
            case FIXED -> VARBINARY.equals(type) ? new FixedBlockPositionEncoder(nullIdx, schema.getFixedSize()) : null;
            case ENUM -> VARCHAR.equals(type) ? new EnumBlockPositionEncoder(nullIdx, schema.getEnumSymbols()) : null;
            case ARRAY -> type instanceof ArrayType arrayType ? new ArrayBlockPositionEncoder(nullIdx, schema, avroTypeManager, arrayType) : null;
            case MAP -> type instanceof MapType mapType ? new MapBlockPositionEncoder(nullIdx, schema, avroTypeManager, mapType) : null;
            case RECORD -> type instanceof RowType rowType ? new RecordBlockPositionEncoder(nullIdx, schema, avroTypeManager, rowType) : null;
            default -> null;
        };
        if (encoder == null) {
            throw new AvroTypeException("Schema and Trino Type mismatch between %s and %s".formatted(schema, type));
        }
        return encoder;
    }

    /**
     * Writes a Trino {@code BOOLEAN} value as an Avro boolean.
     */
    private static class BooleanBlockPositionEncoder
            extends BlockPositionEncoder
    {
        public BooleanBlockPositionEncoder(Optional<SimpleUnionNullIndex> isNullWithIndex)
        {
            super(isNullWithIndex);
        }

        @Override
        void encodeFromBlock(int position, Encoder encoder)
                throws IOException
        {
            boolean value = BOOLEAN.getBoolean(block, position);
            encoder.writeBoolean(value);
        }
    }

    /**
     * Writes a Trino integral value as an Avro int, reading it from the block with the supplied accessor.
     */
    private static class IntBlockPositionEncoder
            extends BlockPositionEncoder
    {
        // reads the value at (block, position) as an int
        private final ToIntBiFunction<Block, Integer> intReader;

        public IntBlockPositionEncoder(Optional<SimpleUnionNullIndex> isNullWithIndex, ToIntBiFunction<Block, Integer> getInt)
        {
            super(isNullWithIndex);
            intReader = requireNonNull(getInt, "getInt is null");
        }

        @Override
        void encodeFromBlock(int position, Encoder encoder)
                throws IOException
        {
            int value = intReader.applyAsInt(block, position);
            encoder.writeInt(value);
        }
    }

    /**
     * Writes a Trino integral value as an Avro long, reading it from the block with the supplied accessor.
     */
    private static class LongBlockPositionEncoder
            extends BlockPositionEncoder
    {
        // reads the value at (block, position) as a long
        private final ToLongBiFunction<Block, Integer> longReader;

        public LongBlockPositionEncoder(Optional<SimpleUnionNullIndex> isNullWithIndex, ToLongBiFunction<Block, Integer> getLong)
        {
            super(isNullWithIndex);
            longReader = requireNonNull(getLong, "getLong is null");
        }

        @Override
        void encodeFromBlock(int position, Encoder encoder)
                throws IOException
        {
            long value = longReader.applyAsLong(block, position);
            encoder.writeLong(value);
        }
    }

    /**
     * Writes a Trino {@code REAL} value as an Avro float.
     */
    private static class FloatBlockPositionEncoder
            extends BlockPositionEncoder
    {
        public FloatBlockPositionEncoder(Optional<SimpleUnionNullIndex> isNullWithIndex)
        {
            super(isNullWithIndex);
        }

        @Override
        void encodeFromBlock(int position, Encoder encoder)
                throws IOException
        {
            float value = REAL.getFloat(block, position);
            encoder.writeFloat(value);
        }
    }

    /**
     * Writes a Trino {@code DOUBLE} value as an Avro double.
     */
    private static class DoubleBlockPositionEncoder
            extends BlockPositionEncoder
    {
        public DoubleBlockPositionEncoder(Optional<SimpleUnionNullIndex> isNullWithIndex)
        {
            super(isNullWithIndex);
        }

        @Override
        void encodeFromBlock(int position, Encoder encoder)
                throws IOException
        {
            double value = DOUBLE.getDouble(block, position);
            encoder.writeDouble(value);
        }
    }

    /**
     * Writes a Trino {@code VARCHAR} value as an Avro string. Also used, non-nullable, for Avro map keys.
     */
    private static class StringPositionEncoder
            extends BlockPositionEncoder
    {
        public StringPositionEncoder(Optional<SimpleUnionNullIndex> isNullWithIndex)
        {
            super(isNullWithIndex);
        }

        @Override
        void encodeFromBlock(int position, Encoder encoder)
                throws IOException
        {
            Slice value = VARCHAR.getSlice(block, position);
            // Length prefix followed by the raw UTF-8 bytes, written straight from the slice's backing array
            // instead of a per-value getBytes() copy. NOTE(review): this long-then-fixed sequence matches Avro's
            // binary string encoding but bypasses Encoder.writeString, so it assumes a binary encoder.
            encoder.writeLong(value.length());
            encoder.writeFixed(value.byteArray(), value.byteArrayOffset(), value.length());
        }
    }

    /**
     * Writes a Trino {@code VARBINARY} value as an Avro bytes value.
     */
    private static class BytesPositionEncoder
            extends BlockPositionEncoder
    {
        public BytesPositionEncoder(Optional<SimpleUnionNullIndex> isNullWithIndex)
        {
            super(isNullWithIndex);
        }

        @Override
        void encodeFromBlock(int position, Encoder encoder)
                throws IOException
        {
            Slice value = VARBINARY.getSlice(block, position);
            // Encoder.writeBytes emits the length prefix itself, giving the same binary output as the former
            // writeLong + writeFixed pair. It reads straight from the slice's backing array, so the per-value
            // getBytes() copy is no longer needed.
            encoder.writeBytes(value.byteArray(), value.byteArrayOffset(), value.length());
        }
    }

    /**
     * Writes a Trino {@code VARBINARY} value as an Avro fixed of exactly {@code fixedSize} bytes.
     */
    private static class FixedBlockPositionEncoder
            extends BlockPositionEncoder
    {
        // byte length declared by the Avro fixed schema
        private final int fixedSize;

        public FixedBlockPositionEncoder(Optional<SimpleUnionNullIndex> nullIdx, int fixedSize)
        {
            super(nullIdx);
            this.fixedSize = fixedSize;
        }

        /**
         * Writes the value's bytes with no length prefix.
         *
         * @throws IOException if the value's length differs from {@code fixedSize}
         */
        @Override
        void encodeFromBlock(int position, Encoder encoder)
                throws IOException
        {
            Slice value = VARBINARY.getSlice(block, position);
            if (value.length() != fixedSize) {
                throw new IOException("Unable to write Avro fixed with size %s from slice of length %s".formatted(fixedSize, value.length()));
            }
            // write straight from the slice's backing array instead of copying it with getBytes() for every value
            encoder.writeFixed(value.byteArray(), value.byteArrayOffset(), value.length());
        }
    }

    /**
     * Writes a Trino {@code VARCHAR} value as an Avro enum by looking up the index of the matching symbol.
     */
    private static class EnumBlockPositionEncoder
            extends BlockPositionEncoder
    {
        // UTF-8 encoded symbol -> index of that symbol in the schema's symbol list
        private final Map<Slice, Integer> symbolToIndex;

        public EnumBlockPositionEncoder(Optional<SimpleUnionNullIndex> nullIdx, List<String> symbols)
        {
            super(nullIdx);
            ImmutableMap.Builder<Slice, Integer> indexBySymbol = ImmutableMap.builder();
            int index = 0;
            for (String symbol : symbols) {
                indexBySymbol.put(Slices.utf8Slice(symbol), index);
                index++;
            }
            symbolToIndex = indexBySymbol.buildOrThrow();
        }

        /**
         * @throws IOException if the value is not one of the enum's symbols
         */
        @Override
        void encodeFromBlock(int position, Encoder encoder)
                throws IOException
        {
            Slice symbol = VARCHAR.getSlice(block, position);
            Integer symbolIndex = symbolToIndex.get(symbol);
            if (symbolIndex != null) {
                encoder.writeEnum(symbolIndex);
                return;
            }
            List<String> knownSymbols = symbolToIndex.keySet().stream()
                    .map(Slice::toStringUtf8)
                    .toList();
            throw new IOException("Unable to write Avro Enum symbol %s. Not found in set %s".formatted(symbol.toStringUtf8(), knownSymbols));
        }
    }

    /**
     * Writes a Trino {@code ARRAY} value as an Avro array. Each element is delegated to an encoder built for
     * the Avro element schema.
     */
    private static class ArrayBlockPositionEncoder
            extends BlockPositionEncoder
    {
        private final ArrayType type;
        private final BlockPositionEncoder elementBlockPositionEncoder;

        public ArrayBlockPositionEncoder(Optional<SimpleUnionNullIndex> nullIdx, Schema schema, AvroTypeManager avroTypeManager, ArrayType type)
                throws AvroTypeException
        {
            super(nullIdx);
            requireNonNull(schema, "schema is null");
            verify(schema.getType() == Schema.Type.ARRAY);
            this.type = requireNonNull(type, "type is null");
            this.elementBlockPositionEncoder = createBlockPositionEncoder(schema.getElementType(), avroTypeManager, this.type.getElementType());
        }

        @Override
        void encodeFromBlock(int position, Encoder encoder)
                throws IOException
        {
            Block elements = type.getObject(block, position);
            int elementCount = elements.getPositionCount();
            elementBlockPositionEncoder.setBlock(elements);
            encoder.writeArrayStart();
            encoder.setItemCount(elementCount);
            for (int element = 0; element < elementCount; element++) {
                encoder.startItem();
                elementBlockPositionEncoder.encode(element, encoder);
            }
            encoder.writeArrayEnd();
        }
    }

    /**
     * Writes a Trino {@code MAP} with {@code VARCHAR} keys as an Avro map. Keys are written as non-nullable
     * strings; values go through an encoder built for the Avro value schema.
     */
    private static class MapBlockPositionEncoder
            extends BlockPositionEncoder
    {
        private final MapType type;
        // non-nullable: a null key makes encode throw
        private final BlockPositionEncoder keyBlockPositionEncoder;
        private final BlockPositionEncoder valueBlockPositionEncoder;

        public MapBlockPositionEncoder(Optional<SimpleUnionNullIndex> nullIdx, Schema schema, AvroTypeManager avroTypeManager, MapType type)
                throws AvroTypeException
        {
            super(nullIdx);
            keyBlockPositionEncoder = new StringPositionEncoder(Optional.empty());
            requireNonNull(schema, "schema is null");
            verify(schema.getType() == Schema.Type.MAP);
            this.type = requireNonNull(type, "type is null");
            Type keyType = this.type.getKeyType();
            if (!VARCHAR.equals(keyType)) {
                throw new AvroTypeException("Avro Maps must have String keys, invalid type: %s".formatted(keyType));
            }
            valueBlockPositionEncoder = createBlockPositionEncoder(schema.getValueType(), avroTypeManager, type.getValueType());
        }

        @Override
        void encodeFromBlock(int position, Encoder encoder)
                throws IOException
        {
            SqlMap map = type.getObject(block, position);
            int entryCount = map.getSize();
            int firstEntry = map.getRawOffset();
            keyBlockPositionEncoder.setBlock(map.getRawKeyBlock());
            valueBlockPositionEncoder.setBlock(map.getRawValueBlock());
            encoder.writeMapStart();
            encoder.setItemCount(entryCount);
            for (int entry = firstEntry; entry < firstEntry + entryCount; entry++) {
                encoder.startItem();
                keyBlockPositionEncoder.encode(entry, encoder);
                valueBlockPositionEncoder.encode(entry, encoder);
            }
            encoder.writeMapEnd();
        }
    }

    /**
     * Writes an Avro record whose fields come from separate channels. The channels are either the blocks of a
     * top-level {@link Page} or the field blocks of a nested {@link RowType} value. Each channel is matched by
     * name to an Avro field, and fields are written in Avro schema order regardless of channel order.
     */
    private static class RecordBlockPositionEncoder
            extends BlockPositionEncoder
    {
        private final RowType type;
        // encoder for each channel, indexed by channel
        private final BlockPositionEncoder[] channelEncoders;
        // fieldToChannel[avro field position] = channel that supplies that field
        private final int[] fieldToChannel;

        // used only for nested row building
        public RecordBlockPositionEncoder(Optional<SimpleUnionNullIndex> nullIdx, Schema schema, AvroTypeManager avroTypeManager, RowType rowType)
                throws AvroTypeException
        {
            this(nullIdx,
                    schema,
                    avroTypeManager,
                    rowType.getFields().stream()
                            .map(RowType.Field::getName)
                            .map(optName -> optName.orElseThrow(() -> new IllegalArgumentException("Unable to use nested anonymous row type for avro writing")))
                            .collect(toImmutableList()),
                    rowType.getFields().stream()
                            .map(RowType.Field::getType)
                            .collect(toImmutableList()));
        }

        // used only for top level page building
        public RecordBlockPositionEncoder(Schema schema, AvroTypeManager avroTypeManager, List<String> channelNames, List<Type> channelTypes)
                throws AvroTypeException
        {
            this(Optional.empty(), schema, avroTypeManager, channelNames, channelTypes);
        }

        private RecordBlockPositionEncoder(Optional<SimpleUnionNullIndex> nullIdx, Schema schema, AvroTypeManager avroTypeManager, List<String> channelNames, List<Type> channelTypes)
                throws AvroTypeException
        {
            super(nullIdx);
            type = RowType.anonymous(requireNonNull(channelTypes, "channelTypes is null"));
            verify(requireNonNull(schema, "schema is null").getType() == Schema.Type.RECORD);
            int fieldCount = schema.getFields().size();
            verify(fieldCount == channelTypes.size(), "Must have channel for each record field");
            verify(requireNonNull(channelNames, "channelNames is null").size() == channelTypes.size(), "Must provide names for all channels");
            fieldToChannel = new int[fieldCount];
            channelEncoders = new BlockPositionEncoder[fieldCount];
            boolean[] fieldAssigned = new boolean[fieldCount];
            for (int channel = 0; channel < channelNames.size(); channel++) {
                String fieldName = channelNames.get(channel);
                // message is built lazily; formatting the whole schema on every successful lookup is wasted work
                Schema.Field avroField = requireNonNull(schema.getField(fieldName), () -> "no field with name %s in schema %s".formatted(fieldName, schema));
                fieldAssigned[avroField.pos()] = true;
                fieldToChannel[avroField.pos()] = channel;
                channelEncoders[channel] = createBlockPositionEncoder(avroField.schema(), avroTypeManager, channelTypes.get(channel));
            }
            // Every field must be claimed by some channel. Because there are as many channels as fields, this also
            // means no field is claimed twice. The former check (channel indexes summing to 0 + 1 + ... + (n - 1))
            // let duplicate names through. For example, names [a, a] for fields [a, b] yield [1, 0] and pass,
            // silently writing field b from channel 0 with an encoder built for field a.
            verify(IntStream.range(0, fieldCount).allMatch(field -> fieldAssigned[field]), "all channels must be accounted for");
        }

        // Used only for nested rows: points each channel encoder at the row's field blocks, then writes the row
        @Override
        void encodeFromBlock(int position, Encoder encoder)
                throws IOException
        {
            SqlRow sqlRow = type.getObject(block, position);
            for (int i = 0; i < channelEncoders.length; i++) {
                channelEncoders[i].setBlock(sqlRow.getRawFieldBlock(i));
            }
            int rawIndex = sqlRow.getRawIndex();
            encodeInternal(i -> rawIndex, encoder);
        }

        /**
         * Points each channel encoder at the corresponding block of {@code page}.
         */
        public void setChannelBlocksFromPage(Page page)
        {
            verify(page.getChannelCount() == channelEncoders.length, "Page must have channels equal to provided type list");
            for (int channel = 0; channel < page.getChannelCount(); channel++) {
                channelEncoders[channel].setBlock(page.getBlock(channel));
            }
        }

        /**
         * Writes the record at {@code position}, reading the same position from every channel.
         */
        public void encodePositionInEachChannel(int position, Encoder encoder)
                throws IOException
        {
            encodeInternal(ignore -> position, encoder);
        }

        // writes the fields in Avro schema order, each from the position channelToPosition gives for its channel
        private void encodeInternal(IntFunction<Integer> channelToPosition, Encoder encoder)
                throws IOException
        {
            for (int channel : fieldToChannel) {
                BlockPositionEncoder channelEncoder = channelEncoders[channel];
                channelEncoder.encode(channelToPosition.apply(channel), encoder);
            }
        }
    }

    /**
     * Writes a value through a user-defined conversion. {@code toAvroGeneric} turns the block position into an
     * Avro generic object, which a {@link GenericDatumWriter} for {@code schema} then encodes.
     */
    private static class UserDefinedBlockPositionEncoder
            extends BlockPositionEncoder
    {
        private final GenericDatumWriter<Object> datumWriter;
        private final BiFunction<Block, Integer, Object> toAvroGeneric;

        public UserDefinedBlockPositionEncoder(Optional<SimpleUnionNullIndex> nullIdx, Schema schema, BiFunction<Block, Integer, Object> toAvroGeneric)
        {
            super(nullIdx);
            Schema writerSchema = requireNonNull(schema, "schema is null");
            datumWriter = new GenericDatumWriter<>(writerSchema);
            this.toAvroGeneric = requireNonNull(toAvroGeneric, "toAvroGeneric is null");
        }

        @Override
        void encodeFromBlock(int position, Encoder encoder)
                throws IOException
        {
            Object avroValue = toAvroGeneric.apply(block, position);
            datumWriter.write(avroValue, encoder);
        }
    }
}