All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.trino.sql.planner.iterative.rule.PushPartialAggregationThroughExchange Maven / Gradle / Ivy

There is a newer version: 465
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.function.AggregationFunctionMetadata;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Lambda;
import io.trino.sql.planner.Partitioning;
import io.trino.sql.planner.PartitioningScheme;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.optimizations.SymbolMapper;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.SystemSessionProperties.preferPartialAggregation;
import static io.trino.sql.planner.plan.AggregationNode.Step.FINAL;
import static io.trino.sql.planner.plan.AggregationNode.Step.PARTIAL;
import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE;
import static io.trino.sql.planner.plan.ExchangeNode.Type.GATHER;
import static io.trino.sql.planner.plan.ExchangeNode.Type.REPARTITION;
import static io.trino.sql.planner.plan.Patterns.aggregation;
import static io.trino.sql.planner.plan.Patterns.exchange;
import static io.trino.sql.planner.plan.Patterns.source;
import static java.util.Objects.requireNonNull;

/**
 * Optimizer rule that matches an {@link AggregationNode} whose source is an
 * {@link ExchangeNode} without an ordering scheme.
 * <p>
 * Depending on the aggregation step, the rule either:
 * <ul>
 * <li>splits a {@code SINGLE} aggregation into a {@code FINAL} aggregation on top of a
 * {@code PARTIAL} aggregation (see {@link #split}), or</li>
 * <li>pushes a {@code PARTIAL} aggregation underneath every source of the exchange
 * (see {@link #pushPartial}).</li>
 * </ul>
 * Aggregations in any other step are left unchanged.
 */
public class PushPartialAggregationThroughExchange
        implements Rule<AggregationNode>
{
    private static final Capture<ExchangeNode> EXCHANGE_NODE = Capture.newCapture();

    // Matches an aggregation directly on top of an exchange that has no ordering scheme;
    // the exchange is captured so that apply() can inspect and rewrite it.
    private static final Pattern<AggregationNode> PATTERN = aggregation()
            .with(source().matching(
                    exchange()
                            .matching(node -> node.getOrderingScheme().isEmpty())
                            .capturedAs(EXCHANGE_NODE)));

    private final PlannerContext plannerContext;

    /**
     * @param plannerContext planner context used to look up function metadata and types; must not be null
     * @throws NullPointerException if {@code plannerContext} is null
     */
    public PushPartialAggregationThroughExchange(PlannerContext plannerContext)
    {
        this.plannerContext = requireNonNull(plannerContext, "plannerContext is null");
    }

    @Override
    public Pattern<AggregationNode> getPattern()
    {
        return PATTERN;
    }

    /**
     * Applies the rule to an aggregation sitting on top of the captured exchange.
     *
     * @param aggregationNode the matched aggregation
     * @param captures captures holding the matched {@link ExchangeNode}
     * @param context rule context providing the session and the id/symbol allocators
     * @return the rewritten plan, or {@link Result#empty()} when the rule does not apply
     * @throws IllegalStateException if a distributed aggregation with both empty and non-empty
     *         grouping sets must be split but its functions are not decomposable
     */
    @Override
    public Result apply(AggregationNode aggregationNode, Captures captures, Context context)
    {
        ExchangeNode exchangeNode = captures.get(EXCHANGE_NODE);

        boolean decomposable = aggregationNode.isDecomposable(context.getSession(), plannerContext.getMetadata());

        if (aggregationNode.getStep() == SINGLE &&
                aggregationNode.hasEmptyGroupingSet() &&
                aggregationNode.hasNonEmptyGroupingSet() &&
                exchangeNode.getType() == REPARTITION) {
            // single-step aggregation w/ empty grouping sets in a partitioned stage, so we need a partial that will produce
            // the default intermediates for the empty grouping set that will be routed to the appropriate final aggregation.
            // TODO: technically, AddExchanges generates a broken plan that this rule "fixes"
            checkState(
                    decomposable,
                    "Distributed aggregation with empty grouping set requires partial but functions are not decomposable");
            return Result.ofPlanNode(split(aggregationNode, context));
        }

        if (!decomposable || !preferPartialAggregation(context.getSession())) {
            return Result.empty();
        }

        // partial aggregation can only be pushed through exchange that doesn't change
        // the cardinality of the stream (i.e., gather or repartition)
        if ((exchangeNode.getType() != GATHER && exchangeNode.getType() != REPARTITION) ||
                exchangeNode.getPartitioningScheme().isReplicateNullsAndAny()) {
            return Result.empty();
        }

        if (exchangeNode.getType() == REPARTITION) {
            // if partitioning columns are not a subset of grouping keys,
            // we can't push this through
            List<Symbol> partitioningColumns = exchangeNode.getPartitioningScheme()
                    .getPartitioning()
                    .getArguments()
                    .stream()
                    .filter(Partitioning.ArgumentBinding::isVariable)
                    .map(Partitioning.ArgumentBinding::getColumn)
                    .collect(Collectors.toList());

            if (!aggregationNode.getGroupingKeys().containsAll(partitioningColumns)) {
                return Result.empty();
            }
        }

        // currently, we only support plans that don't use pre-computed hash functions
        if (aggregationNode.getHashSymbol().isPresent() || exchangeNode.getPartitioningScheme().getHashColumn().isPresent()) {
            return Result.empty();
        }

        return switch (aggregationNode.getStep()) {
            // Split it into a FINAL on top of a PARTIAL. The new PARTIAL still sits directly on the
            // exchange, so it can be pushed down by the PARTIAL case when the rule matches it.
            case SINGLE -> Result.ofPlanNode(split(aggregationNode, context));
            // Push it underneath each branch of the exchange
            case PARTIAL -> Result.ofPlanNode(pushPartial(aggregationNode, exchangeNode, context));
            // Any other step is left untouched
            default -> Result.empty();
        };
    }

    /**
     * Replaces {@code aggregation -> exchange -> sources} with a new exchange whose every source is
     * a copy of the partial aggregation, rewritten in terms of that source's symbols and projected
     * back to the aggregation's original output symbols.
     *
     * @param aggregation the partial aggregation to push down
     * @param exchange the exchange currently feeding the aggregation
     * @param context rule context providing the id allocator
     * @return a new exchange producing the aggregation's output symbols
     */
    private PlanNode pushPartial(AggregationNode aggregation, ExchangeNode exchange, Context context)
    {
        List<PlanNode> partials = new ArrayList<>();
        for (int i = 0; i < exchange.getSources().size(); i++) {
            PlanNode source = exchange.getSources().get(i);

            // Map each exchange output symbol to the corresponding input symbol of this source
            SymbolMapper.Builder mappingsBuilder = SymbolMapper.builder();
            for (int outputIndex = 0; outputIndex < exchange.getOutputSymbols().size(); outputIndex++) {
                Symbol output = exchange.getOutputSymbols().get(outputIndex);
                Symbol input = exchange.getInputs().get(i).get(outputIndex);
                if (!output.equals(input)) {
                    mappingsBuilder.put(output, input);
                }
            }

            SymbolMapper symbolMapper = mappingsBuilder.build();
            AggregationNode mappedPartial = symbolMapper.map(aggregation, source, context.getIdAllocator().getNextId());

            // Project the mapped aggregation's outputs back onto the original output symbols so that
            // every branch of the new exchange exposes exactly the same symbols
            Assignments.Builder assignments = Assignments.builder();

            for (Symbol output : aggregation.getOutputSymbols()) {
                Symbol input = symbolMapper.map(output);
                assignments.put(output, input.toSymbolReference());
            }
            partials.add(new ProjectNode(context.getIdAllocator().getNextId(), mappedPartial, assignments.build()));
        }

        for (PlanNode node : partials) {
            verify(aggregation.getOutputSymbols().equals(node.getOutputSymbols()));
        }

        // Since this exchange source is now guaranteed to have the same symbols as the inputs to the partial
        // aggregation, we don't need to rewrite symbols in the partitioning function
        PartitioningScheme originalScheme = exchange.getPartitioningScheme();
        PartitioningScheme partitioning = new PartitioningScheme(
                originalScheme.getPartitioning(),
                aggregation.getOutputSymbols(),
                originalScheme.getHashColumn(),
                originalScheme.isReplicateNullsAndAny(),
                originalScheme.getBucketToPartition(),
                originalScheme.getPartitionCount());

        return new ExchangeNode(
                context.getIdAllocator().getNextId(),
                exchange.getType(),
                exchange.getScope(),
                partitioning,
                partials,
                ImmutableList.copyOf(Collections.nCopies(partials.size(), aggregation.getOutputSymbols())),
                Optional.empty());
    }

    /**
     * Splits an aggregation into a {@code FINAL} aggregation on top of a {@code PARTIAL} aggregation.
     * The partial aggregation reads from the original source and writes one freshly allocated
     * intermediate symbol per aggregate; the final aggregation keeps the original node id and
     * output symbols and consumes those intermediate symbols.
     *
     * @param node the aggregation to split
     * @param context rule context providing the session and the id/symbol allocators
     * @return the {@code FINAL} aggregation whose source is the new {@code PARTIAL} aggregation
     * @throws IllegalStateException if any aggregate has an ordering scheme
     */
    private PlanNode split(AggregationNode node, Context context)
    {
        // No exchange is created here: the PARTIAL aggregation is placed directly on the original source
        ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> intermediateAggregation = ImmutableMap.builder();
        ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> finalAggregation = ImmutableMap.builder();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) {
            AggregationNode.Aggregation originalAggregation = entry.getValue();
            ResolvedFunction resolvedFunction = originalAggregation.getResolvedFunction();
            AggregationFunctionMetadata functionMetadata = plannerContext.getMetadata().getAggregationFunctionMetadata(context.getSession(), resolvedFunction);
            List<Type> intermediateTypes = functionMetadata.getIntermediateTypes().stream()
                    .map(plannerContext.getTypeManager()::getType)
                    .collect(toImmutableList());
            // A single intermediate type is used as-is; several are wrapped into an anonymous row
            Type intermediateType = intermediateTypes.size() == 1 ? intermediateTypes.get(0) : RowType.anonymous(intermediateTypes);
            Symbol intermediateSymbol = context.getSymbolAllocator().newSymbol(resolvedFunction.signature().getName().getFunctionName(), intermediateType);

            checkState(originalAggregation.getOrderingScheme().isEmpty(), "Aggregate with ORDER BY does not support partial aggregation");
            intermediateAggregation.put(
                    intermediateSymbol,
                    new AggregationNode.Aggregation(
                            resolvedFunction,
                            originalAggregation.getArguments(),
                            originalAggregation.isDistinct(),
                            originalAggregation.getFilter(),
                            originalAggregation.getOrderingScheme(),
                            originalAggregation.getMask()));

            // rewrite final aggregation in terms of intermediate function:
            // its arguments are the intermediate symbol followed by the original lambda arguments
            finalAggregation.put(
                    entry.getKey(),
                    new AggregationNode.Aggregation(
                            resolvedFunction,
                            ImmutableList.<Expression>builder()
                                    .add(intermediateSymbol.toSymbolReference())
                                    .addAll(originalAggregation.getArguments().stream()
                                            .filter(Lambda.class::isInstance)
                                            .collect(toImmutableList()))
                                    .build(),
                            false,
                            Optional.empty(),
                            Optional.empty(),
                            Optional.empty()));
        }

        PlanNode partial = new AggregationNode(
                context.getIdAllocator().getNextId(),
                node.getSource(),
                intermediateAggregation.buildOrThrow(),
                node.getGroupingSets(),
                // preGroupedSymbols reflect properties of the input. Splitting the aggregation and pushing partial aggregation
                // through the exchange may or may not preserve these properties. Hence, it is safest to drop preGroupedSymbols here.
                ImmutableList.of(),
                PARTIAL,
                node.getHashSymbol(),
                node.getGroupIdSymbol());

        return new AggregationNode(
                node.getId(),
                partial,
                finalAggregation.buildOrThrow(),
                node.getGroupingSets(),
                // preGroupedSymbols reflect properties of the input. Splitting the aggregation and pushing partial aggregation
                // through the exchange may or may not preserve these properties. Hence, it is safest to drop preGroupedSymbols here.
                ImmutableList.of(),
                FINAL,
                node.getHashSymbol(),
                node.getGroupIdSymbol());
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy