All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.trino.plugin.hive.HivePageSink Maven / Gradle / Ivy

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.plugin.hive;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Streams;
import com.google.common.primitives.Ints;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import io.airlift.json.JsonCodec;
import io.airlift.log.Logger;
import io.airlift.slice.Slice;
import io.trino.metastore.HiveType;
import io.trino.plugin.hive.HiveWritableTableHandle.BucketInfo;
import io.trino.plugin.hive.acid.AcidTransaction;
import io.trino.plugin.hive.util.HiveBucketing;
import io.trino.spi.Page;
import io.trino.spi.PageIndexer;
import io.trino.spi.PageIndexerFactory;
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
import io.trino.spi.block.IntArrayBlockBuilder;
import io.trino.spi.connector.ConnectorMergeSink;
import io.trino.spi.connector.ConnectorPageSink;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.type.Type;
import it.unimi.dsi.fastutil.objects.Object2IntMap;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;

import java.io.Closeable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static io.airlift.concurrent.MoreFutures.toCompletableFuture;
import static io.airlift.slice.Slices.wrappedBuffer;
import static io.trino.plugin.hive.HiveErrorCode.HIVE_TOO_MANY_OPEN_PARTITIONS;
import static io.trino.plugin.hive.HiveErrorCode.HIVE_WRITER_CLOSE_ERROR;
import static io.trino.spi.type.IntegerType.INTEGER;
import static java.lang.Math.min;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toList;

public class HivePageSink
        implements ConnectorPageSink, ConnectorMergeSink
{
    private static final Logger LOG = Logger.get(HivePageSink.class);
    private static final int MAX_PAGE_POSITIONS = 4096;

    private final HiveWriterFactory writerFactory;

    private final boolean isTransactional;
    private final int[] dataColumnInputIndex; // ordinal of columns (not counting sample weight column)
    private final int[] partitionColumnsInputIndex; // ordinal of columns (not counting sample weight column)

    private final int[] bucketColumns;
    private final HiveBucketFunction bucketFunction;

    private final HiveWriterPagePartitioner pagePartitioner;

    private final int maxOpenWriters;
    private final ListeningExecutorService writeVerificationExecutor;

    private final JsonCodec partitionUpdateCodec;

    private final List writers = new ArrayList<>();

    private final long targetMaxFileSize;
    private final long idleWriterMinFileSize;
    private final List closedWriterRollbackActions = new ArrayList<>();
    private final List partitionUpdates = new ArrayList<>();
    private final List> verificationTasks = new ArrayList<>();
    private final List activeWriters = new ArrayList<>();

    private final boolean isMergeSink;
    private long writtenBytes;
    private long memoryUsage;
    private long validationCpuNanos;
    private long currentOpenWriters;

    public HivePageSink(
            HiveWriterFactory writerFactory,
            List inputColumns,
            AcidTransaction acidTransaction,
            Optional bucketInfo,
            PageIndexerFactory pageIndexerFactory,
            int maxOpenWriters,
            ListeningExecutorService writeVerificationExecutor,
            JsonCodec partitionUpdateCodec,
            ConnectorSession session)
    {
        this.writerFactory = requireNonNull(writerFactory, "writerFactory is null");

        requireNonNull(inputColumns, "inputColumns is null");

        requireNonNull(pageIndexerFactory, "pageIndexerFactory is null");

        this.isTransactional = acidTransaction.isTransactional();
        this.maxOpenWriters = maxOpenWriters;
        this.writeVerificationExecutor = requireNonNull(writeVerificationExecutor, "writeVerificationExecutor is null");
        this.partitionUpdateCodec = requireNonNull(partitionUpdateCodec, "partitionUpdateCodec is null");

        this.isMergeSink = acidTransaction.isMerge();
        requireNonNull(bucketInfo, "bucketInfo is null");
        this.pagePartitioner = new HiveWriterPagePartitioner(
                inputColumns,
                bucketInfo.isPresent(),
                pageIndexerFactory);

        // determine the input index of the partition columns and data columns
        // and determine the input index and type of bucketing columns
        ImmutableList.Builder partitionColumns = ImmutableList.builder();
        ImmutableList.Builder dataColumnsInputIndex = ImmutableList.builder();
        Object2IntMap dataColumnNameToIdMap = new Object2IntOpenHashMap<>();
        Map dataColumnNameToTypeMap = new HashMap<>();
        for (int inputIndex = 0; inputIndex < inputColumns.size(); inputIndex++) {
            HiveColumnHandle column = inputColumns.get(inputIndex);
            if (column.isPartitionKey()) {
                partitionColumns.add(inputIndex);
            }
            else {
                dataColumnsInputIndex.add(inputIndex);
                dataColumnNameToIdMap.put(column.getName(), inputIndex);
                dataColumnNameToTypeMap.put(column.getName(), column.getHiveType());
            }
        }
        this.partitionColumnsInputIndex = Ints.toArray(partitionColumns.build());
        this.dataColumnInputIndex = Ints.toArray(dataColumnsInputIndex.build());

        if (bucketInfo.isPresent()) {
            HiveBucketing.BucketingVersion bucketingVersion = bucketInfo.get().bucketingVersion();
            int bucketCount = bucketInfo.get().bucketCount();
            bucketColumns = bucketInfo.get().bucketedBy().stream()
                    .mapToInt(dataColumnNameToIdMap::getInt)
                    .toArray();
            List bucketColumnTypes = bucketInfo.get().bucketedBy().stream()
                    .map(dataColumnNameToTypeMap::get)
                    .collect(toList());
            bucketFunction = new HiveBucketFunction(bucketingVersion, bucketCount, bucketColumnTypes);
        }
        else {
            bucketColumns = null;
            bucketFunction = null;
        }

        this.targetMaxFileSize = HiveSessionProperties.getTargetMaxFileSize(session).toBytes();
        this.idleWriterMinFileSize = HiveSessionProperties.getIdleWriterMinFileSize(session).toBytes();
    }

    @Override
    public long getCompletedBytes()
    {
        return writtenBytes;
    }

    @Override
    public long getMemoryUsage()
    {
        return memoryUsage;
    }

    @Override
    public long getValidationCpuNanos()
    {
        return validationCpuNanos;
    }

    @Override
    public CompletableFuture> finish()
    {
        return toCompletableFuture(isMergeSink ? doMergeSinkFinish() : doInsertSinkFinish());
    }

    private ListenableFuture> doMergeSinkFinish()
    {
        ImmutableList.Builder resultSlices = ImmutableList.builder();
        for (HiveWriter writer : writers) {
            if (writer == null) {
                continue;
            }
            writer.commit();
            MergeFileWriter mergeFileWriter = (MergeFileWriter) writer.getFileWriter();
            PartitionUpdateAndMergeResults results = mergeFileWriter.getPartitionUpdateAndMergeResults(writer.getPartitionUpdate());
            resultSlices.add(wrappedBuffer(PartitionUpdateAndMergeResults.CODEC.toJsonBytes(results)));
        }
        List result = resultSlices.build();
        writtenBytes = writers.stream()
                .filter(Objects::nonNull)
                .mapToLong(HiveWriter::getWrittenBytes)
                .sum();
        return Futures.immediateFuture(result);
    }

    private ListenableFuture> doInsertSinkFinish()
    {
        for (int writerIndex = 0; writerIndex < writers.size(); writerIndex++) {
            closeWriter(writerIndex);
        }
        writers.clear();

        List result = ImmutableList.copyOf(partitionUpdates);

        if (verificationTasks.isEmpty()) {
            return Futures.immediateFuture(result);
        }

        try {
            List> futures = writeVerificationExecutor.invokeAll(verificationTasks).stream()
                    .map(future -> (ListenableFuture) future)
                    .collect(toList());
            return Futures.transform(Futures.allAsList(futures), input -> result, directExecutor());
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    @Override
    public void abort()
    {
        List rollbackActions = Streams.concat(
                        writers.stream()
                                // writers can contain nulls if an exception is thrown when doAppend expands the writer list
                                .filter(Objects::nonNull)
                                .map(writer -> writer::rollback),
                        closedWriterRollbackActions.stream())
                .collect(toImmutableList());
        RuntimeException rollbackException = null;
        for (Closeable rollbackAction : rollbackActions) {
            try {
                rollbackAction.close();
            }
            catch (Throwable t) {
                if (rollbackException == null) {
                    rollbackException = new TrinoException(HIVE_WRITER_CLOSE_ERROR, "Error rolling back write to Hive");
                }
                rollbackException.addSuppressed(t);
            }
        }
        if (rollbackException != null) {
            throw rollbackException;
        }
    }

    @Override
    public CompletableFuture appendPage(Page page)
    {
        int writeOffset = 0;
        while (writeOffset < page.getPositionCount()) {
            Page chunk = page.getRegion(writeOffset, min(page.getPositionCount() - writeOffset, MAX_PAGE_POSITIONS));
            writeOffset += chunk.getPositionCount();
            writePage(chunk);
        }
        return NOT_BLOCKED;
    }

    private void writePage(Page page)
    {
        int[] writerIndexes = getWriterIndexes(page);

        // position count for each writer
        int[] sizes = new int[writers.size()];
        for (int index : writerIndexes) {
            sizes[index]++;
        }

        // record which positions are used by which writer
        int[][] writerPositions = new int[writers.size()][];
        int[] counts = new int[writers.size()];

        for (int position = 0; position < page.getPositionCount(); position++) {
            int index = writerIndexes[position];

            int count = counts[index];
            if (count == 0) {
                writerPositions[index] = new int[sizes[index]];
            }
            writerPositions[index][count] = position;
            counts[index] = count + 1;
        }

        // invoke the writers
        Page dataPage = getDataPage(page);
        for (int index = 0; index < writerPositions.length; index++) {
            int[] positions = writerPositions[index];
            if (positions == null) {
                continue;
            }

            // If write is partitioned across multiple writers, filter page using dictionary blocks
            Page pageForWriter = dataPage;
            if (positions.length != dataPage.getPositionCount()) {
                verify(positions.length == counts[index]);
                pageForWriter = pageForWriter.getPositions(positions, 0, positions.length);
            }

            HiveWriter writer = writers.get(index);
            verify(writer != null, "Expected writer at index %s", index);

            long currentWritten = writer.getWrittenBytes();
            long currentMemory = writer.getMemoryUsage();

            writer.append(pageForWriter);

            writtenBytes += (writer.getWrittenBytes() - currentWritten);
            memoryUsage += (writer.getMemoryUsage() - currentMemory);
            // Mark this writer as active (i.e. not idle)
            activeWriters.set(index, true);
        }
    }

    private void closeWriter(int writerIndex)
    {
        HiveWriter writer = writers.get(writerIndex);
        if (writer == null) {
            return;
        }

        long currentWritten = writer.getWrittenBytes();
        long currentMemory = writer.getMemoryUsage();

        closedWriterRollbackActions.add(writer.commit());

        writtenBytes += (writer.getWrittenBytes() - currentWritten);
        memoryUsage -= currentMemory;
        validationCpuNanos += writer.getValidationCpuNanos();

        writers.set(writerIndex, null);
        currentOpenWriters--;

        PartitionUpdate partitionUpdate = writer.getPartitionUpdate();
        partitionUpdates.add(wrappedBuffer(partitionUpdateCodec.toJsonBytes(partitionUpdate)));
    }

    @Override
    public void closeIdleWriters()
    {
        // For transactional tables we don't want to split output files because there is an explicit or implicit bucketing
        // and file names have no random component (e.g. bucket_00000)
        if (bucketFunction != null || isTransactional) {
            return;
        }

        for (int writerIndex = 0; writerIndex < writers.size(); writerIndex++) {
            HiveWriter writer = writers.get(writerIndex);
            if (activeWriters.get(writerIndex) || writer == null || writer.getWrittenBytes() <= idleWriterMinFileSize) {
                activeWriters.set(writerIndex, false);
                continue;
            }
            LOG.debug("Closing writer %s with %s bytes written", writerIndex, writer.getWrittenBytes());
            closeWriter(writerIndex);
        }
    }

    private int[] getWriterIndexes(Page page)
    {
        Page partitionColumns = extractColumns(page, partitionColumnsInputIndex);
        Block bucketBlock = buildBucketBlock(page);
        int[] writerIndexes = pagePartitioner.partitionPage(partitionColumns, bucketBlock);

        // expand writers list to new size
        while (writers.size() <= pagePartitioner.getMaxIndex()) {
            writers.add(null);
            activeWriters.add(false);
        }

        // create missing writers
        for (int position = 0; position < page.getPositionCount(); position++) {
            int writerIndex = writerIndexes[position];
            HiveWriter writer = writers.get(writerIndex);
            if (writer != null) {
                // if current file not too big continue with the current writer
                // for transactional tables we don't want to split output files because there is an explicit or implicit bucketing
                // and file names have no random component (e.g. bucket_00000)
                if (bucketFunction != null || isTransactional || writer.getWrittenBytes() <= targetMaxFileSize) {
                    continue;
                }
                // close current writer
                closeWriter(writerIndex);
            }

            OptionalInt bucketNumber = OptionalInt.empty();
            if (bucketBlock != null) {
                bucketNumber = OptionalInt.of(INTEGER.getInt(bucketBlock, position));
            }

            writer = writerFactory.createWriter(partitionColumns, position, bucketNumber);

            writers.set(writerIndex, writer);
            currentOpenWriters++;
            memoryUsage += writer.getMemoryUsage();
        }
        verify(writers.size() == pagePartitioner.getMaxIndex() + 1);

        if (currentOpenWriters > maxOpenWriters) {
            throw new TrinoException(HIVE_TOO_MANY_OPEN_PARTITIONS, format("Exceeded limit of %s open writers for partitions/buckets", maxOpenWriters));
        }

        return writerIndexes;
    }

    private Page getDataPage(Page page)
    {
        if (isMergeSink) {
            return page;
        }
        Block[] blocks = new Block[dataColumnInputIndex.length];
        for (int i = 0; i < dataColumnInputIndex.length; i++) {
            int dataColumn = dataColumnInputIndex[i];
            blocks[i] = page.getBlock(dataColumn);
        }
        return new Page(page.getPositionCount(), blocks);
    }

    private Block buildBucketBlock(Page page)
    {
        if (bucketFunction == null) {
            return null;
        }

        IntArrayBlockBuilder bucketColumnBuilder = new IntArrayBlockBuilder(null, page.getPositionCount());
        Page bucketColumnsPage = extractColumns(page, bucketColumns);
        for (int position = 0; position < page.getPositionCount(); position++) {
            int bucket = bucketFunction.getBucket(bucketColumnsPage, position);
            INTEGER.writeInt(bucketColumnBuilder, bucket);
        }
        return bucketColumnBuilder.build();
    }

    private static Page extractColumns(Page page, int[] columns)
    {
        Block[] blocks = new Block[columns.length];
        for (int i = 0; i < columns.length; i++) {
            int dataColumn = columns[i];
            blocks[i] = page.getBlock(dataColumn);
        }
        return new Page(page.getPositionCount(), blocks);
    }

    @Override
    public void storeMergedRows(Page page)
    {
        checkArgument(isMergeSink, "isMergeSink is false");
        appendPage(page);
    }

    private static class HiveWriterPagePartitioner
    {
        private final PageIndexer pageIndexer;

        public HiveWriterPagePartitioner(
                List inputColumns,
                boolean bucketed,
                PageIndexerFactory pageIndexerFactory)
        {
            requireNonNull(inputColumns, "inputColumns is null");
            requireNonNull(pageIndexerFactory, "pageIndexerFactory is null");

            List partitionColumnTypes = inputColumns.stream()
                    .filter(HiveColumnHandle::isPartitionKey)
                    .map(HiveColumnHandle::getType)
                    .collect(toList());

            if (bucketed) {
                partitionColumnTypes.add(INTEGER);
            }

            this.pageIndexer = pageIndexerFactory.createPageIndexer(partitionColumnTypes);
        }

        public int[] partitionPage(Page partitionColumns, Block bucketBlock)
        {
            if (bucketBlock != null) {
                Block[] blocks = new Block[partitionColumns.getChannelCount() + 1];
                for (int i = 0; i < partitionColumns.getChannelCount(); i++) {
                    blocks[i] = partitionColumns.getBlock(i);
                }
                blocks[blocks.length - 1] = bucketBlock;
                partitionColumns = new Page(partitionColumns.getPositionCount(), blocks);
            }
            return pageIndexer.indexPage(partitionColumns);
        }

        public int getMaxIndex()
        {
            return pageIndexer.getMaxIndex();
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy