All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.trino.sql.planner.iterative.rule.PushAggregationIntoTableScan Maven / Gradle / Ivy

There is a newer version: 465
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.TableHandle;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.connector.AggregationApplicationResult;
import io.trino.spi.connector.Assignment;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.SortItem;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.Variable;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.predicate.TupleDomain;
import io.trino.sql.PlannerContext;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.ConnectorExpressionTranslator;
import io.trino.sql.planner.OrderingScheme;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.TableScanNode;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.stream.IntStream;

import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static io.trino.SystemSessionProperties.isAllowPushdownIntoConnectors;
import static io.trino.matching.Capture.newCapture;
import static io.trino.sql.ir.optimizer.IrExpressionOptimizer.newOptimizer;
import static io.trino.sql.planner.iterative.rule.Rules.deriveTableStatisticsForPushdown;
import static io.trino.sql.planner.plan.Patterns.Aggregation.step;
import static io.trino.sql.planner.plan.Patterns.aggregation;
import static io.trino.sql.planner.plan.Patterns.source;
import static io.trino.sql.planner.plan.Patterns.tableScan;
import static java.util.Objects.requireNonNull;

/**
 * Planner rule that offers an aggregation sitting directly on top of a table scan
 * to the connector through {@code Metadata.applyAggregation}.
 * <p>
 * The rule only matches aggregations that:
 * <ul>
 * <li>run in {@link AggregationNode.Step#SINGLE} step,</li>
 * <li>take only plain {@link Reference}s as arguments,</li>
 * <li>have at most one grouping set,</li>
 * <li>carry no aggregation masks, and</li>
 * <li>have a {@link TableScanNode} as their source.</li>
 * </ul>
 * When the connector accepts the pushdown, the rule produces a {@link ProjectNode}
 * over a new {@link TableScanNode} built from the handle returned by the connector.
 */
public class PushAggregationIntoTableScan
        implements Rule<AggregationNode>
{
    private static final Capture<TableScanNode> TABLE_SCAN = newCapture();

    private static final Pattern<AggregationNode> PATTERN =
            aggregation()
                    .with(step().equalTo(AggregationNode.Step.SINGLE))
                    // skip arguments that are, for instance, lambda expressions
                    .matching(PushAggregationIntoTableScan::allArgumentsAreSimpleReferences)
                    .matching(node -> node.getGroupingSets().getGroupingSetCount() <= 1)
                    .matching(PushAggregationIntoTableScan::hasNoMasks)
                    .with(source().matching(tableScan().capturedAs(TABLE_SCAN)));

    private final PlannerContext plannerContext;

    /**
     * @param plannerContext planner context used to reach metadata and to translate/optimize expressions; must not be null
     */
    public PushAggregationIntoTableScan(PlannerContext plannerContext)
    {
        this.plannerContext = requireNonNull(plannerContext, "plannerContext is null");
    }

    @Override
    public Pattern<AggregationNode> getPattern()
    {
        return PATTERN;
    }

    /**
     * The rule is active only when the session allows pushdown into connectors.
     */
    @Override
    public boolean isEnabled(Session session)
    {
        return isAllowPushdownIntoConnectors(session);
    }

    /**
     * Returns true when every argument of every aggregation in {@code node} is a {@link Reference}.
     */
    private static boolean allArgumentsAreSimpleReferences(AggregationNode node)
    {
        return node.getAggregations()
                .values().stream()
                .flatMap(aggregation -> aggregation.getArguments().stream())
                .allMatch(Reference.class::isInstance);
    }

    /**
     * Returns true when none of the aggregations in {@code node} has a mask.
     */
    private static boolean hasNoMasks(AggregationNode node)
    {
        return node.getAggregations()
                .values().stream()
                .allMatch(aggregation -> aggregation.getMask().isEmpty());
    }

    @Override
    public Result apply(AggregationNode node, Captures captures, Context context)
    {
        return pushAggregationIntoTableScan(plannerContext, context, node, captures.get(TABLE_SCAN), node.getAggregations(), node.getGroupingSets().getGroupingKeys())
                .map(Rule.Result::ofPlanNode)
                .orElseGet(Rule.Result::empty);
    }

    /**
     * Asks the connector behind {@code tableScan} to evaluate the given aggregations and grouping keys.
     *
     * @param plannerContext planner context providing metadata access and expression translation
     * @param context rule context supplying the session, symbol allocator, id allocator and stats provider
     * @param aggregationNode the node being replaced; here it is only passed on to
     * {@code deriveTableStatisticsForPushdown}
     * @param tableScan the table scan the aggregation reads from
     * @param aggregations aggregations keyed by their output symbol; every aggregation argument
     * must be a {@link Reference} (arguments are cast without a check)
     * @param groupingKeys grouping symbols; each is looked up by name in the scan's assignments
     * @return a {@link ProjectNode} over a new {@link TableScanNode} when the connector accepts the pushdown;
     * {@link Optional#empty()} when there is nothing to push (no grouping keys and no aggregations)
     * or when the connector returns no result
     */
    public static Optional<PlanNode> pushAggregationIntoTableScan(
            PlannerContext plannerContext,
            Context context,
            PlanNode aggregationNode,
            TableScanNode tableScan,
            Map<Symbol, AggregationNode.Aggregation> aggregations,
            List<Symbol> groupingKeys)
    {
        Session session = context.getSession();

        if (groupingKeys.isEmpty() && aggregations.isEmpty()) {
            // Global aggregation with no aggregate functions. No point to push this down into connector.
            return Optional.empty();
        }

        // Connector-facing view of the scan assignments: symbol name -> column handle
        Map<String, ColumnHandle> assignments = tableScan.getAssignments()
                .entrySet().stream()
                .collect(toImmutableMap(entry -> entry.getKey().name(), Entry::getValue));

        // Snapshot the entries once so that functions and output symbols below share the same order
        List<Entry<Symbol, AggregationNode.Aggregation>> aggregationsList = ImmutableList.copyOf(aggregations.entrySet());

        List<AggregateFunction> aggregateFunctions = aggregationsList.stream()
                .map(Entry::getValue)
                .map(PushAggregationIntoTableScan::toAggregateFunction)
                .collect(toImmutableList());

        List<Symbol> aggregationOutputSymbols = aggregationsList.stream()
                .map(Entry::getKey)
                .collect(toImmutableList());

        List<ColumnHandle> groupByColumns = groupingKeys.stream()
                .map(groupByColumn -> assignments.get(groupByColumn.name()))
                .collect(toImmutableList());

        Optional<AggregationApplicationResult<TableHandle>> aggregationPushdownResult = plannerContext.getMetadata().applyAggregation(
                session,
                tableScan.getTable(),
                aggregateFunctions,
                assignments,
                ImmutableList.of(groupByColumns));

        if (aggregationPushdownResult.isEmpty()) {
            return Optional.empty();
        }

        AggregationApplicationResult<TableHandle> result = aggregationPushdownResult.get();

        // The new scan outputs should be the symbols associated with grouping columns plus the symbols associated with aggregations.
        ImmutableList.Builder<Symbol> newScanOutputs = ImmutableList.builder();
        newScanOutputs.addAll(tableScan.getOutputSymbols());

        ImmutableBiMap.Builder<Symbol, ColumnHandle> newScanAssignments = ImmutableBiMap.builder();
        newScanAssignments.putAll(tableScan.getAssignments());

        // Connector variable name -> freshly allocated symbol, used when translating the returned projections
        Map<String, Symbol> variableMappings = new HashMap<>();

        for (Assignment assignment : result.getAssignments()) {
            Symbol symbol = context.getSymbolAllocator().newSymbol(assignment.getVariable(), assignment.getType());

            newScanOutputs.add(symbol);
            newScanAssignments.put(symbol, assignment.getColumn());
            variableMappings.put(assignment.getVariable(), symbol);
        }

        List<Expression> newProjections = result.getProjections().stream()
                .map(expression -> {
                    Expression translated = ConnectorExpressionTranslator.translate(session, expression, plannerContext, variableMappings);
                    // ConnectorExpressionTranslator may or may not preserve optimized form of expressions during round-trip. Avoid potential optimizer loop
                    // by ensuring expression is optimized.
                    return newOptimizer(plannerContext).process(translated, session, ImmutableMap.of()).orElse(translated);
                })
                .collect(toImmutableList());

        // Projections are paired with aggregation output symbols by position, so the counts must agree
        verify(aggregationOutputSymbols.size() == newProjections.size());

        Assignments.Builder assignmentBuilder = Assignments.builder();
        IntStream.range(0, aggregationOutputSymbols.size())
                .forEach(index -> assignmentBuilder.put(aggregationOutputSymbols.get(index), newProjections.get(index)));

        ImmutableBiMap<Symbol, ColumnHandle> scanAssignments = newScanAssignments.build();
        ImmutableBiMap<ColumnHandle, Symbol> columnHandleToSymbol = scanAssignments.inverse();
        // projections assignmentBuilder should have both agg and group by so we add all the group bys as symbol references
        groupingKeys
                .forEach(groupBySymbol -> {
                    // if the connector returned a new mapping from oldColumnHandle to newColumnHandle, groupBy needs to point to
                    // new columnHandle's symbol reference, otherwise it will continue pointing at oldColumnHandle.
                    ColumnHandle originalColumnHandle = assignments.get(groupBySymbol.name());
                    ColumnHandle groupByColumnHandle = result.getGroupingColumnMapping().getOrDefault(originalColumnHandle, originalColumnHandle);
                    assignmentBuilder.put(groupBySymbol, columnHandleToSymbol.get(groupByColumnHandle).toSymbolReference());
                });

        return Optional.of(
                new ProjectNode(
                        context.getIdAllocator().getNextId(),
                        new TableScanNode(
                                context.getIdAllocator().getNextId(),
                                result.getHandle(),
                                newScanOutputs.build(),
                                scanAssignments,
                                TupleDomain.all(),
                                deriveTableStatisticsForPushdown(context.getStatsProvider(), session, result.isPrecalculateStatistics(), aggregationNode),
                                tableScan.isUpdateTarget(),
                                // table scan partitioning might have changed with new table handle
                                Optional.empty()),
                        assignmentBuilder.build()));
    }

    /**
     * Converts a planner aggregation into the SPI {@link AggregateFunction} handed to the connector.
     * <p>
     * Each argument is cast to {@link Reference} and exposed as a {@link Variable} typed with the
     * matching argument type from the bound signature. The ordering scheme (if any) becomes the
     * sort items, and the filter symbol (if any) becomes a {@link Variable}.
     */
    private static AggregateFunction toAggregateFunction(AggregationNode.Aggregation aggregation)
    {
        BoundSignature signature = aggregation.getResolvedFunction().signature();

        ImmutableList.Builder<ConnectorExpression> arguments = ImmutableList.builder();
        for (int i = 0; i < aggregation.getArguments().size(); i++) {
            Reference argument = (Reference) aggregation.getArguments().get(i);
            arguments.add(new Variable(argument.name(), signature.getArgumentTypes().get(i)));
        }

        Optional<OrderingScheme> orderingScheme = aggregation.getOrderingScheme();
        Optional<List<SortItem>> sortBy = orderingScheme.map(OrderingScheme::toSortItems);

        Optional<ConnectorExpression> filter = aggregation.getFilter()
                .map(symbol -> new Variable(symbol.name(), symbol.type()));

        return new AggregateFunction(
                signature.getName().getFunctionName(),
                signature.getReturnType(),
                arguments.build(),
                sortBy.orElse(ImmutableList.of()),
                aggregation.isDistinct(),
                filter);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy