All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.trino.sql.planner.iterative.rule.PushAggregationThroughOuterJoin Maven / Gradle / Ivy

There is a newer version: 465
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.Session;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.spi.type.Type;
import io.trino.sql.ir.Coalesce;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Row;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.optimizations.SymbolMapper;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.AggregationNode.Aggregation;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.JoinType;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.ValuesNode;

import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.SystemSessionProperties.isPushAggregationThroughOuterJoin;
import static io.trino.matching.Capture.newCapture;
import static io.trino.sql.planner.optimizations.DistinctOutputQueryUtil.isDistinct;
import static io.trino.sql.planner.optimizations.SymbolMapper.symbolMapper;
import static io.trino.sql.planner.plan.AggregationNode.globalAggregation;
import static io.trino.sql.planner.plan.AggregationNode.singleAggregation;
import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet;
import static io.trino.sql.planner.plan.Patterns.aggregation;
import static io.trino.sql.planner.plan.Patterns.join;
import static io.trino.sql.planner.plan.Patterns.source;

/**
 * This optimizer pushes aggregations below outer joins when: the aggregation
 * is on top of the outer join, it groups by all columns in the outer table, and
 * the outer rows are guaranteed to be distinct.
 * 

* When the aggregation is pushed down, we still need to perform aggregations * on the null values that come out of the absent values in an outer * join. We add a cross join with a row of aggregations on null literals, * and coalesce the aggregation that results from the left outer join with * the result of the aggregation over nulls. *

* Example: *

 * - Filter ("nationkey" > "avg")
 *  - Aggregate(Group by: all columns from the left table, aggregation:
 *    avg("n2.nationkey"))
 *      - LeftJoin("regionkey" = "regionkey")
 *          - AssignUniqueId (nation)
 *              - Tablescan (nation)
 *          - Tablescan (nation)
 * 
*

* Is rewritten to: *

 * - Filter ("nationkey" > "avg")
 *  - project(regionkey, coalesce("avg", "avg_over_null")
 *      - CrossJoin
 *          - LeftJoin("regionkey" = "regionkey")
 *              - AssignUniqueId (nation)
 *                  - Tablescan (nation)
 *              - Aggregate(Group by: regionkey, aggregation:
 *                avg(nationkey))
 *                  - Tablescan (nation)
 *          - Aggregate
 *            avg(null_literal)
 *              - Values (null_literal)
 * 
*/ public class PushAggregationThroughOuterJoin implements Rule { private static final Capture JOIN = newCapture(); private static final Pattern PATTERN = aggregation() .with(source().matching(join().capturedAs(JOIN))); @Override public Pattern getPattern() { return PATTERN; } @Override public boolean isEnabled(Session session) { return isPushAggregationThroughOuterJoin(session); } @Override public Result apply(AggregationNode aggregation, Captures captures, Context context) { // This rule doesn't deal with AggregationNode's hash symbol. Hash symbols are not yet present at this stage of optimization. checkArgument(aggregation.getHashSymbol().isEmpty(), "unexpected hash symbol"); JoinNode join = captures.get(JOIN); if (join.getFilter().isPresent() || !(join.getType() == JoinType.LEFT || join.getType() == JoinType.RIGHT) || !groupsOnAllColumns(aggregation, getOuterTable(join).getOutputSymbols()) || !isDistinct(context.getLookup().resolve(getOuterTable(join)), context.getLookup()::resolve) || !isAggregationOnSymbols(aggregation, getInnerTable(join))) { return Result.empty(); } List groupingKeys = join.getCriteria().stream() .map(join.getType() == JoinType.RIGHT ? JoinNode.EquiJoinClause::getLeft : JoinNode.EquiJoinClause::getRight) .collect(toImmutableList()); AggregationNode rewrittenAggregation = AggregationNode.builderFrom(aggregation) .setSource(getInnerTable(join)) .setGroupingSets(singleGroupingSet(groupingKeys)) .setPreGroupedSymbols(ImmutableList.of()) .build(); JoinNode rewrittenJoin; if (join.getType() == JoinType.LEFT) { rewrittenJoin = new JoinNode( join.getId(), join.getType(), join.getLeft(), rewrittenAggregation, join.getCriteria(), join.getLeft().getOutputSymbols(), ImmutableList.copyOf(rewrittenAggregation.getAggregations().keySet()), // there are no duplicate rows possible since outer rows were guaranteed to be distinct false, join.getFilter(), join.getLeftHashSymbol(), join.getRightHashSymbol(), join.getDistributionType(), join.isSpillable(), join.getDynamicFilters(), join.getReorderJoinStatsAndCost()); } else { rewrittenJoin = new JoinNode( join.getId(), join.getType(), rewrittenAggregation, join.getRight(), join.getCriteria(), ImmutableList.copyOf(rewrittenAggregation.getAggregations().keySet()), join.getRight().getOutputSymbols(), // there are no duplicate rows possible since outer rows were guaranteed to be distinct false, join.getFilter(), join.getLeftHashSymbol(), join.getRightHashSymbol(), join.getDistributionType(), join.isSpillable(), join.getDynamicFilters(), join.getReorderJoinStatsAndCost()); } Optional resultNode = coalesceWithNullAggregation(rewrittenAggregation, rewrittenJoin, context.getSymbolAllocator(), context.getIdAllocator()); if (resultNode.isEmpty()) { return Result.empty(); } return Result.ofPlanNode(resultNode.get()); } private static PlanNode getInnerTable(JoinNode join) { checkState(join.getType() == JoinType.LEFT || join.getType() == JoinType.RIGHT, "expected LEFT or RIGHT JOIN"); PlanNode innerNode; if (join.getType() == JoinType.LEFT) { innerNode = join.getRight(); } else { innerNode = join.getLeft(); } return innerNode; } private static PlanNode getOuterTable(JoinNode join) { checkState(join.getType() == JoinType.LEFT || join.getType() == JoinType.RIGHT, "expected LEFT or RIGHT JOIN"); PlanNode outerNode; if (join.getType() == JoinType.LEFT) { outerNode = join.getLeft(); } else { outerNode = join.getRight(); } return outerNode; } private static boolean groupsOnAllColumns(AggregationNode node, List columns) { return node.getGroupingSetCount() == 1 && new HashSet<>(node.getGroupingKeys()).equals(new HashSet<>(columns)); } // When the aggregation is done after the join, there will be a null value that gets aggregated over // where rows did not exist in the inner table. For some aggregate functions, such as count, the result // of an aggregation over a single null row is one or zero rather than null. In order to ensure correct results, // we add a coalesce function with the output of the new outer join and the aggregation performed over a single // null row. private Optional coalesceWithNullAggregation(AggregationNode aggregationNode, PlanNode outerJoin, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) { // Create an aggregation node over a row of nulls. MappedAggregationInfo aggregationOverNullInfo = createAggregationOverNull( aggregationNode, symbolAllocator, idAllocator); AggregationNode aggregationOverNull = aggregationOverNullInfo.getAggregation(); Map sourceAggregationToOverNullMapping = aggregationOverNullInfo.getSymbolMapping(); // Do a cross join with the aggregation over null JoinNode crossJoin = new JoinNode( idAllocator.getNextId(), JoinType.INNER, outerJoin, aggregationOverNull, ImmutableList.of(), outerJoin.getOutputSymbols(), aggregationOverNull.getOutputSymbols(), false, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), ImmutableMap.of(), Optional.empty()); // Add coalesce expressions for all aggregation functions Assignments.Builder assignmentsBuilder = Assignments.builder(); for (Symbol symbol : outerJoin.getOutputSymbols()) { if (aggregationNode.getAggregations().containsKey(symbol)) { assignmentsBuilder.put(symbol, new Coalesce(symbol.toSymbolReference(), sourceAggregationToOverNullMapping.get(symbol).toSymbolReference())); } else { assignmentsBuilder.putIdentity(symbol); } } return Optional.of(new ProjectNode(idAllocator.getNextId(), crossJoin, assignmentsBuilder.build())); } private MappedAggregationInfo createAggregationOverNull(AggregationNode referenceAggregation, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) { // Create a values node that consists of a single row of nulls. // Map the output symbols from the referenceAggregation's source // to symbol references for the new values node. ImmutableList.Builder nullSymbols = ImmutableList.builder(); ImmutableList.Builder nullLiterals = ImmutableList.builder(); ImmutableMap.Builder sourcesSymbolMappingBuilder = ImmutableMap.builder(); for (Symbol sourceSymbol : referenceAggregation.getSource().getOutputSymbols()) { Type type = sourceSymbol.type(); nullLiterals.add(new Constant(type, null)); Symbol nullSymbol = symbolAllocator.newSymbol("null", type); nullSymbols.add(nullSymbol); sourcesSymbolMappingBuilder.put(sourceSymbol, nullSymbol); } ValuesNode nullRow = new ValuesNode( idAllocator.getNextId(), nullSymbols.build(), ImmutableList.of(new Row(nullLiterals.build()))); // For each aggregation function in the reference node, create a corresponding aggregation function // that points to the nullRow. Map the symbols from the aggregations in referenceAggregation to the // symbols in these new aggregations. ImmutableMap.Builder aggregationsSymbolMappingBuilder = ImmutableMap.builder(); ImmutableMap.Builder aggregationsOverNullBuilder = ImmutableMap.builder(); SymbolMapper mapper = symbolMapper(sourcesSymbolMappingBuilder.buildOrThrow()); for (Map.Entry entry : referenceAggregation.getAggregations().entrySet()) { Symbol aggregationSymbol = entry.getKey(); Aggregation overNullAggregation = mapper.map(entry.getValue()); Symbol overNullSymbol = symbolAllocator.newSymbol(overNullAggregation.getResolvedFunction().signature().getName().getFunctionName(), aggregationSymbol.type()); aggregationsOverNullBuilder.put(overNullSymbol, overNullAggregation); aggregationsSymbolMappingBuilder.put(aggregationSymbol, overNullSymbol); } Map aggregationsSymbolMapping = aggregationsSymbolMappingBuilder.buildOrThrow(); // create an aggregation node whose source is the null row. AggregationNode aggregationOverNullRow = singleAggregation( idAllocator.getNextId(), nullRow, aggregationsOverNullBuilder.buildOrThrow(), globalAggregation()); return new MappedAggregationInfo(aggregationOverNullRow, aggregationsSymbolMapping); } private static boolean isAggregationOnSymbols(AggregationNode aggregationNode, PlanNode source) { Set sourceSymbols = ImmutableSet.copyOf(source.getOutputSymbols()); return aggregationNode.getAggregations().values().stream() .allMatch(aggregation -> sourceSymbols.containsAll(SymbolsExtractor.extractUnique(aggregation))); } private static class MappedAggregationInfo { private final AggregationNode aggregationNode; private final Map symbolMapping; public MappedAggregationInfo(AggregationNode aggregationNode, Map symbolMapping) { this.aggregationNode = aggregationNode; this.symbolMapping = symbolMapping; } public Map getSymbolMapping() { return symbolMapping; } public AggregationNode getAggregation() { return aggregationNode; } } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy