All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.trino.sql.planner.iterative.rule.PushProjectionThroughExchange Maven / Gradle / Ivy

There is a newer version: 465
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.spi.type.Type;
import io.trino.sql.planner.PartitioningScheme;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.SymbolReference;

import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import static io.trino.matching.Capture.newCapture;
import static io.trino.sql.planner.ExpressionSymbolInliner.inlineSymbols;
import static io.trino.sql.planner.iterative.rule.Util.restrictOutputs;
import static io.trino.sql.planner.plan.Patterns.exchange;
import static io.trino.sql.planner.plan.Patterns.project;
import static io.trino.sql.planner.plan.Patterns.source;

/**
 * Transforms:
 * 
 *  Project(x = e1, y = e2)
 *    Exchange()
 *      Source(a, b, c)
 *  
* to: *
 *  Exchange()
 *    Project(x = e1, y = e2)
 *      Source(a, b, c)
 *  
* Or if Exchange needs symbols from Source for partitioning, ordering or as hash symbol to: *
 *  Project(x, y)
 *    Exchange()
 *      Project(x = e1, y = e2, a)
 *        Source(a, b, c)
 *  
* To avoid looping this optimizer will not be fired if upper Project contains just symbol references. */ public class PushProjectionThroughExchange implements Rule { private static final Capture CHILD = newCapture(); private static final Pattern PATTERN = project() .matching(project -> !isSymbolToSymbolProjection(project)) .with(source().matching(exchange().capturedAs(CHILD))); @Override public Pattern getPattern() { return PATTERN; } @Override public Result apply(ProjectNode project, Captures captures, Context context) { ExchangeNode exchange = captures.get(CHILD); Set partitioningColumns = exchange.getPartitioningScheme().getPartitioning().getColumns(); ImmutableList.Builder newSourceBuilder = ImmutableList.builder(); ImmutableList.Builder> inputsBuilder = ImmutableList.builder(); for (int i = 0; i < exchange.getSources().size(); i++) { Map outputToInputMap = mapExchangeOutputToInput(exchange, i); Assignments.Builder projections = Assignments.builder(); ImmutableList.Builder inputs = ImmutableList.builder(); // Need to retain the partition keys for the exchange partitioningColumns.stream() .map(outputToInputMap::get) .forEach(inputSymbol -> { projections.putIdentity(inputSymbol); inputs.add(inputSymbol); }); // Need to retain the hash symbol for the exchange exchange.getPartitioningScheme().getHashColumn() .map(outputToInputMap::get) .ifPresent(inputSymbol -> { projections.putIdentity(inputSymbol); inputs.add(inputSymbol); }); if (exchange.getOrderingScheme().isPresent()) { // Need to retain ordering columns for the exchange exchange.getOrderingScheme().get().getOrderBy().stream() // Do not duplicate symbols in inputs list .filter(symbol -> !partitioningColumns.contains(symbol)) .map(outputToInputMap::get) .forEach(inputSymbol -> { projections.putIdentity(inputSymbol); inputs.add(inputSymbol); }); } ImmutableSet.Builder outputBuilder = ImmutableSet.builder(); partitioningColumns.forEach(outputBuilder::add); exchange.getPartitioningScheme().getHashColumn().ifPresent(outputBuilder::add); exchange.getOrderingScheme().ifPresent(orderingScheme -> outputBuilder.addAll(orderingScheme.getOrderBy())); Set partitioningHashAndOrderingOutputs = outputBuilder.build(); Map translationMap = outputToInputMap.entrySet().stream() .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); for (Map.Entry projection : project.getAssignments().entrySet()) { // Skip identity projection if symbol is in outputs already if (partitioningHashAndOrderingOutputs.contains(projection.getKey())) { continue; } Expression translatedExpression = inlineSymbols(translationMap, projection.getValue()); Type type = context.getSymbolAllocator().getTypes().get(projection.getKey()); Symbol symbol = context.getSymbolAllocator().newSymbol(translatedExpression, type); projections.put(symbol, translatedExpression); inputs.add(symbol); } newSourceBuilder.add(new ProjectNode(context.getIdAllocator().getNextId(), exchange.getSources().get(i), projections.build())); inputsBuilder.add(inputs.build()); } // Construct the output symbols in the same order as the sources ImmutableList.Builder outputBuilder = ImmutableList.builder(); partitioningColumns.forEach(outputBuilder::add); exchange.getPartitioningScheme().getHashColumn().ifPresent(outputBuilder::add); if (exchange.getOrderingScheme().isPresent()) { exchange.getOrderingScheme().get().getOrderBy().stream() // Do not duplicate symbols in outputs list (for consistency with inputs lists) .filter(symbol -> !partitioningColumns.contains(symbol)) .forEach(outputBuilder::add); } Set partitioningHashAndOrderingOutputs = ImmutableSet.copyOf(outputBuilder.build()); for (Map.Entry projection : project.getAssignments().entrySet()) { // Do not add output for identity projection if symbol is in outputs already if (partitioningHashAndOrderingOutputs.contains(projection.getKey())) { continue; } outputBuilder.add(projection.getKey()); } // outputBuilder contains all partition and hash symbols so simply swap the output layout PartitioningScheme partitioningScheme = new PartitioningScheme( exchange.getPartitioningScheme().getPartitioning(), outputBuilder.build(), exchange.getPartitioningScheme().getHashColumn(), exchange.getPartitioningScheme().isReplicateNullsAndAny(), exchange.getPartitioningScheme().getBucketToPartition(), exchange.getPartitioningScheme().getPartitionCount()); PlanNode result = new ExchangeNode( exchange.getId(), exchange.getType(), exchange.getScope(), partitioningScheme, newSourceBuilder.build(), inputsBuilder.build(), exchange.getOrderingScheme()); // we need to strip unnecessary symbols (hash, partitioning columns). return Result.ofPlanNode(restrictOutputs(context.getIdAllocator(), result, ImmutableSet.copyOf(project.getOutputSymbols())).orElse(result)); } private static boolean isSymbolToSymbolProjection(ProjectNode project) { return project.getAssignments().getExpressions().stream().allMatch(SymbolReference.class::isInstance); } private static Map mapExchangeOutputToInput(ExchangeNode exchange, int sourceIndex) { ImmutableMap.Builder outputToInputMap = ImmutableMap.builder(); for (int i = 0; i < exchange.getOutputSymbols().size(); i++) { outputToInputMap.put(exchange.getOutputSymbols().get(i), exchange.getInputs().get(sourceIndex).get(i)); } return outputToInputMap.buildOrThrow(); } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy