All downloads are free. Search and download functionality uses the official Maven repository.

io.trino.operator.aggregation.partial.SkipAggregationBuilder Maven / Gradle / Ivy

There is a newer version: 465
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.operator.aggregation.partial;

import com.google.common.util.concurrent.ListenableFuture;
import io.trino.memory.context.LocalMemoryContext;
import io.trino.operator.CompletedWork;
import io.trino.operator.Work;
import io.trino.operator.WorkProcessor;
import io.trino.operator.aggregation.AggregatorFactory;
import io.trino.operator.aggregation.GroupedAggregator;
import io.trino.operator.aggregation.builder.HashAggregationBuilder;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import jakarta.annotation.Nullable;

import java.util.List;
import java.util.Optional;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;

/**
 * {@link HashAggregationBuilder} that does not aggregate input rows at all.
 * It passes the input pages, augmented with initial accumulator state, to the output.
 * It can only be used at the partial aggregation step as it relies on rows being aggregated at the final step.
 * <p>
 * The builder holds at most one input page at a time: {@link #processPage(Page)} stores the page,
 * {@link #isFull()} then reports {@code true}, and {@link #buildResult()} converts the stored page
 * into an output page and releases it.
 */
public class SkipAggregationBuilder
        implements HashAggregationBuilder
{
    private final LocalMemoryContext memoryContext;
    private final List<GroupedAggregator> groupedAggregators;
    // The single pending input page; null when nothing is waiting to be flushed.
    @Nullable
    private Page currentPage;
    // Input channels copied verbatim to the output: the group-by channels followed by
    // the optional input hash channel.
    private final int[] hashChannels;

    /**
     * @param groupByChannels input channel indexes of the grouping columns, copied to the output in this order
     * @param inputHashChannel optional input channel index that is appended after the group-by channels
     * @param aggregatorFactories factories used to create one {@link GroupedAggregator} per aggregation
     * @param memoryContext memory context updated by {@link #updateMemory()}
     * @throws NullPointerException if any argument is null
     */
    public SkipAggregationBuilder(
            List<Integer> groupByChannels,
            Optional<Integer> inputHashChannel,
            List<AggregatorFactory> aggregatorFactories,
            LocalMemoryContext memoryContext)
    {
        this.memoryContext = requireNonNull(memoryContext, "memoryContext is null");
        requireNonNull(aggregatorFactories, "aggregatorFactories is null");
        requireNonNull(groupByChannels, "groupByChannels is null");
        requireNonNull(inputHashChannel, "inputHashChannel is null");

        this.groupedAggregators = aggregatorFactories.stream()
                .map(AggregatorFactory::createGroupedAggregator)
                .collect(toImmutableList());

        // Build the channel list in a local array so the lambda below captures a local
        // rather than a field of the object under construction.
        int groupByCount = groupByChannels.size();
        int[] channels = new int[groupByCount + (inputHashChannel.isPresent() ? 1 : 0)];
        for (int i = 0; i < groupByCount; i++) {
            channels[i] = groupByChannels.get(i);
        }
        inputHashChannel.ifPresent(channelIndex -> channels[groupByCount] = channelIndex);
        this.hashChannels = channels;
    }

    /**
     * Stores {@code page} as the pending page. No aggregation work is done here, so the
     * returned {@link Work} is a {@link CompletedWork}.
     *
     * @throws IllegalArgumentException if a previously supplied page has not yet been
     *         flushed through {@link #buildResult()}
     */
    @Override
    public Work<?> processPage(Page page)
    {
        checkArgument(currentPage == null);
        currentPage = page;
        return new CompletedWork<>();
    }

    /**
     * Converts the pending page, if any, into an output page and clears it.
     *
     * @return a processor over the single output page, or an empty processor when no page is pending
     */
    @Override
    public WorkProcessor<Page> buildResult()
    {
        if (currentPage == null) {
            return WorkProcessor.of();
        }

        Page result = buildOutputPage(currentPage);
        currentPage = null;
        return WorkProcessor.of(result);
    }

    /**
     * @return {@code true} while a page is pending; {@link #processPage(Page)} rejects
     *         further input until {@link #buildResult()} has been called
     */
    @Override
    public boolean isFull()
    {
        return currentPage != null;
    }

    /**
     * Sets the memory context to the size of the pending page. When no page is pending,
     * the previously set value is left untouched.
     */
    @Override
    public void updateMemory()
    {
        // NOTE(review): the memory context is not reset after buildResult() clears the page;
        // confirm the owning operator accounts for that.
        if (currentPage != null) {
            memoryContext.setBytes(currentPage.getSizeInBytes());
        }
    }

    /**
     * Does nothing.
     */
    @Override
    public void close()
    {
    }

    /**
     * @throws UnsupportedOperationException always
     */
    @Override
    public ListenableFuture<Void> startMemoryRevoke()
    {
        throw new UnsupportedOperationException("startMemoryRevoke not supported for SkipAggregationBuilder");
    }

    /**
     * @throws UnsupportedOperationException always
     */
    @Override
    public void finishMemoryRevoke()
    {
        throw new UnsupportedOperationException("finishMemoryRevoke not supported for SkipAggregationBuilder");
    }

    // Feeds the page to the aggregators, collects one block per aggregator and
    // assembles the output page.
    private Page buildOutputPage(Page page)
    {
        populateInitialAccumulatorState(page);

        BlockBuilder[] outputBuilders = serializeAccumulatorState(page.getPositionCount());

        return constructOutputPage(page, outputBuilders);
    }

    // Passes the page to every aggregator with group id == position, so each row
    // forms its own group and no rows are combined.
    private void populateInitialAccumulatorState(Page page)
    {
        int positionCount = page.getPositionCount();
        int[] groupIds = new int[positionCount];
        for (int position = 0; position < positionCount; position++) {
            groupIds[position] = position;
        }

        for (GroupedAggregator groupedAggregator : groupedAggregators) {
            groupedAggregator.processPage(positionCount, groupIds, page);
        }
    }

    // Creates one block builder per aggregator and calls evaluate for every group id
    // in [0, positionCount), which matches the ids assigned in populateInitialAccumulatorState.
    private BlockBuilder[] serializeAccumulatorState(int positionCount)
    {
        BlockBuilder[] outputBuilders = new BlockBuilder[groupedAggregators.size()];
        for (int i = 0; i < outputBuilders.length; i++) {
            outputBuilders[i] = groupedAggregators.get(i).getType().createBlockBuilder(null, positionCount);
        }

        for (int position = 0; position < positionCount; position++) {
            for (int i = 0; i < groupedAggregators.size(); i++) {
                GroupedAggregator groupedAggregator = groupedAggregators.get(i);
                BlockBuilder output = outputBuilders[i];
                groupedAggregator.evaluate(position, output);
            }
        }
        return outputBuilders;
    }

    // Output layout: the blocks at hashChannels (taken as-is from the input page),
    // followed by one built block per aggregator.
    private Page constructOutputPage(Page page, BlockBuilder[] outputBuilders)
    {
        Block[] outputBlocks = new Block[hashChannels.length + outputBuilders.length];
        for (int i = 0; i < hashChannels.length; i++) {
            outputBlocks[i] = page.getBlock(hashChannels[i]);
        }
        for (int i = 0; i < outputBuilders.length; i++) {
            outputBlocks[hashChannels.length + i] = outputBuilders[i].build();
        }
        return new Page(page.getPositionCount(), outputBlocks);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy