All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.trino.operator.aggregation.GroupedAggregator Maven / Gradle / Ivy

There is a newer version: 465
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.operator.aggregation;

import com.google.common.primitives.Ints;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.Type;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.AggregationNode.Step;

import java.util.List;
import java.util.Optional;
import java.util.OptionalInt;

import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Objects.requireNonNull;

/**
 * Drives a {@link GroupedAccumulator} for a single aggregation over grouped input.
 * <p>
 * How pages are consumed depends on the current {@link Step}:
 * <ul>
 *   <li>When the step takes raw input, the argument channels are projected from the page. An
 *       {@link AggregationMask} is then built from them and from the optional mask channel, and
 *       the selected rows are added to the accumulator.</li>
 *   <li>Otherwise the page must carry a single channel of intermediate state. That channel is
 *       merged into the accumulator.</li>
 * </ul>
 * Results are written either as intermediate values (partial output) or as final values.
 * <p>
 * The step can be changed after construction via {@link #setSpillOutput()}, so instances are not
 * immutable.
 */
public class GroupedAggregator
{
    private final GroupedAccumulator accumulator;
    // Not final: setSpillOutput() replaces it with the partial-output variant of the step.
    private AggregationNode.Step step;
    private final Type intermediateType;
    private final Type finalType;
    private final int[] inputChannels;
    private final OptionalInt maskChannel;
    private final AggregationMaskBuilder maskBuilder;

    /**
     * Creates an aggregator wrapping {@code accumulator}.
     *
     * @param accumulator the grouped accumulator that holds per-group state
     * @param step the aggregation step. It decides whether input is raw or intermediate, and
     *        whether output is partial or final.
     * @param intermediateType type produced when the step outputs partial results, and the spill type
     * @param finalType type produced when the step outputs final results
     * @param inputChannels page channels holding the aggregation input. When the step does not
     *        take raw input, exactly one channel (the intermediate state) is required.
     * @param maskChannel optional page channel whose block is passed to {@code maskBuilder}
     *        (used only for raw input)
     * @param maskBuilder builds the row mask applied to raw input
     * @throws NullPointerException if any argument is {@code null}
     * @throws IllegalArgumentException if the step does not take raw input and
     *         {@code inputChannels} does not contain exactly one channel
     */
    public GroupedAggregator(
            GroupedAccumulator accumulator,
            Step step,
            Type intermediateType,
            Type finalType,
            List<Integer> inputChannels,
            OptionalInt maskChannel,
            AggregationMaskBuilder maskBuilder)
    {
        this.accumulator = requireNonNull(accumulator, "accumulator is null");
        this.step = requireNonNull(step, "step is null");
        this.intermediateType = requireNonNull(intermediateType, "intermediateType is null");
        this.finalType = requireNonNull(finalType, "finalType is null");
        this.inputChannels = Ints.toArray(requireNonNull(inputChannels, "inputChannels is null"));
        this.maskChannel = requireNonNull(maskChannel, "maskChannel is null");
        this.maskBuilder = requireNonNull(maskBuilder, "maskBuilder is null");
        checkArgument(
                step.isInputRaw() || inputChannels.size() == 1,
                "expected 1 input channel for intermediate aggregation, but got %s",
                inputChannels.size());
    }

    /**
     * Returns the accumulator's estimated size, as reported by
     * {@link GroupedAccumulator#getEstimatedSize()}.
     */
    public long getEstimatedSize()
    {
        return accumulator.getEstimatedSize();
    }

    /**
     * Returns the output type for the current step: the intermediate type when the step outputs
     * partial results, otherwise the final type.
     */
    public Type getType()
    {
        if (step.isOutputPartial()) {
            return intermediateType;
        }
        return finalType;
    }

    /**
     * Adds one page of input to the accumulator.
     *
     * @param groupCount group count forwarded to {@link GroupedAccumulator#setGroupCount(long)}
     *        before any input is added
     * @param groupIds group ids for the page's positions, passed through to the accumulator
     * @param page the input page. With raw input, {@code inputChannels} (and the optional mask
     *        channel) are read from it. Otherwise only {@code inputChannels[0]} is read, as
     *        intermediate state.
     */
    public void processPage(int groupCount, int[] groupIds, Page page)
    {
        accumulator.setGroupCount(groupCount);

        if (step.isInputRaw()) {
            Page arguments = page.getColumns(inputChannels);
            Optional<Block> maskBlock = Optional.empty();
            if (maskChannel.isPresent()) {
                maskBlock = Optional.of(page.getBlock(maskChannel.getAsInt()).getLoadedBlock());
            }
            AggregationMask mask = maskBuilder.buildAggregationMask(arguments, maskBlock);

            // Nothing selected: skip loading the argument blocks entirely.
            if (mask.isSelectNone()) {
                return;
            }
            // Unwrap any LazyBlock values before evaluating the accumulator
            arguments = arguments.getLoadedPage();
            accumulator.addInput(groupIds, arguments, mask);
        }
        else {
            accumulator.addIntermediate(groupIds, page.getBlock(inputChannels[0]));
        }
    }

    /**
     * Delegates to {@link GroupedAccumulator#prepareFinal()}.
     */
    public void prepareFinal()
    {
        accumulator.prepareFinal();
    }

    /**
     * Writes the result for {@code groupId} to {@code output}. The value is intermediate when the
     * current step outputs partial results, and final otherwise.
     *
     * @param groupId the group whose result is written
     * @param output the builder receiving the value
     */
    public void evaluate(int groupId, BlockBuilder output)
    {
        if (step.isOutputPartial()) {
            accumulator.evaluateIntermediate(groupId, output);
        }
        else {
            accumulator.evaluateFinal(groupId, output);
        }
    }

    /**
     * Replaces the current step with {@code AggregationNode.Step.partialOutput(step)}.
     * <p>
     * This is presumably so that {@link #getType()} and {@link #evaluate(int, BlockBuilder)}
     * produce intermediate values suitable for spilling. NOTE(review): confirm that
     * {@code partialOutput} always yields a step whose {@code isOutputPartial()} is true.
     */
    // todo this should return a new GroupedAggregator instead of modifying the existing object
    public void setSpillOutput()
    {
        step = AggregationNode.Step.partialOutput(step);
    }

    /**
     * Returns the type used for spilled state, which is always the intermediate type.
     */
    public Type getSpillType()
    {
        return intermediateType;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy