![JAR search and dependency download from the Maven repository](/logo.png)
io.trino.sql.planner.optimizations.SymbolMapper Maven / Gradle / Ivy
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.planner.optimizations;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.spi.connector.SortOrder;
import io.trino.sql.planner.OrderingScheme;
import io.trino.sql.planner.PartitioningScheme;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.AggregationNode.Aggregation;
import io.trino.sql.planner.plan.DataOrganizationSpecification;
import io.trino.sql.planner.plan.DistinctLimitNode;
import io.trino.sql.planner.plan.GroupIdNode;
import io.trino.sql.planner.plan.LimitNode;
import io.trino.sql.planner.plan.MergeProcessorNode;
import io.trino.sql.planner.plan.MergeWriterNode;
import io.trino.sql.planner.plan.PatternRecognitionNode;
import io.trino.sql.planner.plan.PatternRecognitionNode.Measure;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.RowNumberNode;
import io.trino.sql.planner.plan.StatisticAggregations;
import io.trino.sql.planner.plan.StatisticsWriterNode;
import io.trino.sql.planner.plan.TableExecuteNode;
import io.trino.sql.planner.plan.TableFinishNode;
import io.trino.sql.planner.plan.TableFunctionNode.PassThroughColumn;
import io.trino.sql.planner.plan.TableFunctionNode.PassThroughSpecification;
import io.trino.sql.planner.plan.TableFunctionProcessorNode;
import io.trino.sql.planner.plan.TableWriterNode;
import io.trino.sql.planner.plan.TopNNode;
import io.trino.sql.planner.plan.TopNRankingNode;
import io.trino.sql.planner.plan.WindowNode;
import io.trino.sql.planner.rowpattern.AggregationValuePointer;
import io.trino.sql.planner.rowpattern.LogicalIndexExtractor.ExpressionAndValuePointers;
import io.trino.sql.planner.rowpattern.ScalarValuePointer;
import io.trino.sql.planner.rowpattern.ValuePointer;
import io.trino.sql.planner.rowpattern.ir.IrLabel;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.ExpressionRewriter;
import io.trino.sql.tree.ExpressionTreeRewriter;
import io.trino.sql.tree.SymbolReference;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.trino.sql.planner.plan.AggregationNode.groupingSets;
import static java.util.Objects.requireNonNull;
public class SymbolMapper
{
/**
 * Function that resolves a symbol to its replacement; applied by {@link #map(Symbol)}.
 */
private final Function<Symbol, Symbol> mappingFunction;

/**
 * Creates a mapper that delegates every symbol lookup to the given function.
 *
 * @param mappingFunction function applied to each symbol being mapped
 * @throws NullPointerException if {@code mappingFunction} is null
 */
public SymbolMapper(Function<Symbol, Symbol> mappingFunction)
{
    this.mappingFunction = requireNonNull(mappingFunction, "mappingFunction is null");
}
/**
 * Creates a mapper backed by the given symbol-to-symbol map.
 * <p>
 * A symbol is resolved by following the chain of entries in {@code mapping} until reaching a symbol
 * that is either absent from the map or mapped to itself. The map is captured by reference and
 * consulted on every lookup, so later changes to it are visible to the returned mapper.
 * <p>
 * The map must not contain a cycle of two or more distinct symbols: the lookup loop would never
 * terminate for a symbol on such a cycle.
 *
 * @param mapping replacements keyed by the symbol being replaced
 * @return a mapper resolving symbols through {@code mapping}
 */
public static SymbolMapper symbolMapper(Map<Symbol, Symbol> mapping)
{
    return new SymbolMapper(symbol -> {
        // follow the replacement chain to its end
        while (mapping.containsKey(symbol) && !mapping.get(symbol).equals(symbol)) {
            symbol = mapping.get(symbol);
        }
        return symbol;
    });
}
/**
 * Creates a mapper that resolves known symbols through {@code mapping} and allocates a new symbol
 * for every symbol it has not seen before.
 * <p>
 * For a symbol already present in {@code mapping}, the replacement chain is followed to its end and
 * the resulting symbol is additionally recorded as mapping to itself. For an unknown symbol, a new
 * symbol is obtained from {@code symbolAllocator.newSymbol(symbol)}, and both
 * {@code symbol -> newSymbol} and {@code newSymbol -> newSymbol} are recorded.
 * <p>
 * The returned mapper writes to {@code mapping} on every lookup, so the map must be mutable.
 *
 * @param mapping replacements known so far; updated in place by the returned mapper
 * @param symbolAllocator source of new symbols for symbols missing from {@code mapping}
 * @return a mapper that reallocates unknown symbols
 */
public static SymbolMapper symbolReallocator(Map<Symbol, Symbol> mapping, SymbolAllocator symbolAllocator)
{
    return new SymbolMapper(symbol -> {
        if (mapping.containsKey(symbol)) {
            while (mapping.containsKey(symbol) && !mapping.get(symbol).equals(symbol)) {
                symbol = mapping.get(symbol);
            }
            // do not remap the symbol further
            mapping.put(symbol, symbol);
            return symbol;
        }
        Symbol newSymbol = symbolAllocator.newSymbol(symbol);
        mapping.put(symbol, newSymbol);
        // do not remap the symbol further
        mapping.put(newSymbol, newSymbol);
        return newSymbol;
    });
}
/**
 * Returns the canonical mapping for the symbol, as computed by the mapping function
 * this mapper was constructed with.
 *
 * @param symbol the symbol to resolve
 * @return the result of applying the mapping function to {@code symbol}
 */
public Symbol map(Symbol symbol)
{
    Symbol canonical = mappingFunction.apply(symbol);
    return canonical;
}
/**
 * Maps every symbol of the list, keeping order and duplicates.
 *
 * @param symbols symbols to map
 * @return an immutable list with one mapped symbol per input symbol
 */
public List<Symbol> map(List<Symbol> symbols)
{
    return symbols.stream()
            .map(this::map)
            .collect(toImmutableList());
}
/**
 * Maps every symbol of the list and removes duplicates from the mapped result.
 * <p>
 * Encounter order is kept; when several input symbols map to the same symbol, only the first
 * occurrence remains.
 *
 * @param symbols symbols to map
 * @return an immutable list of distinct mapped symbols
 */
public List<Symbol> mapAndDistinct(List<Symbol> symbols)
{
    return symbols.stream()
            .map(this::map)
            .distinct()
            .collect(toImmutableList());
}
/**
 * Rewrites an expression, replacing each symbol reference with a reference to the mapped symbol.
 *
 * @param expression expression to rewrite
 * @return the expression produced by the tree rewriter
 */
public Expression map(Expression expression)
{
    return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<>()
    {
        @Override
        public Expression rewriteSymbolReference(SymbolReference node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
        {
            Symbol canonical = map(Symbol.from(node));
            return canonical.toSymbolReference();
        }
    }, expression);
}
/**
 * Rewrites the symbols of an aggregation node on top of the given source, reusing the id of
 * the original node for the result.
 *
 * @param node aggregation node to rewrite
 * @param source source of the rewritten node
 * @return the rewritten aggregation node
 */
public AggregationNode map(AggregationNode node, PlanNode source)
{
    PlanNodeId unchangedId = node.getId();
    return map(node, source, unchangedId);
}
/**
 * Rewrites the symbols of an aggregation node on top of the given source.
 * <p>
 * Output symbols and aggregations are mapped entry by entry; grouping keys are mapped and
 * deduplicated; the hash symbol and the group id symbol are mapped when present. Because the
 * aggregations are collected with {@code buildOrThrow()}, the build fails if two output symbols
 * map to the same symbol.
 *
 * @param node aggregation node to rewrite
 * @param source source of the rewritten node
 * @param newNodeId id assigned to the rewritten node
 * @return the rewritten aggregation node
 */
public AggregationNode map(AggregationNode node, PlanNode source, PlanNodeId newNodeId)
{
    ImmutableMap.Builder<Symbol, Aggregation> aggregations = ImmutableMap.builder();
    for (Entry<Symbol, Aggregation> entry : node.getAggregations().entrySet()) {
        aggregations.put(map(entry.getKey()), map(entry.getValue()));
    }
    return new AggregationNode(
            newNodeId,
            source,
            aggregations.buildOrThrow(),
            groupingSets(
                    mapAndDistinct(node.getGroupingKeys()),
                    node.getGroupingSetCount(),
                    node.getGlobalGroupingSets()),
            // NOTE(review): always an empty list regardless of the original node; presumably the
            // pre-grouped symbols are intentionally dropped — confirm against AggregationNode
            ImmutableList.of(),
            node.getStep(),
            node.getHashSymbol().map(this::map),
            node.getGroupIdSymbol().map(this::map));
}
/**
 * Rewrites a single aggregation: its arguments, filter, ordering scheme and mask are mapped,
 * while the resolved function and the distinct flag are carried over unchanged.
 *
 * @param aggregation aggregation to rewrite
 * @return a new aggregation referring to the mapped symbols
 */
public Aggregation map(Aggregation aggregation)
{
    // The pieces are mapped in the same order as the constructor arguments.
    List<Expression> mappedArguments = aggregation.getArguments().stream()
            .map(this::map)
            .collect(toImmutableList());
    Optional<Symbol> mappedFilter = aggregation.getFilter().map(this::map);
    Optional<OrderingScheme> mappedOrdering = aggregation.getOrderingScheme().map(this::map);
    Optional<Symbol> mappedMask = aggregation.getMask().map(this::map);

    return new Aggregation(
            aggregation.getResolvedFunction(),
            mappedArguments,
            aggregation.isDistinct(),
            mappedFilter,
            mappedOrdering,
            mappedMask);
}
/**
 * Rewrites the symbols of a group id node on top of the given source.
 * <p>
 * Each grouping set is mapped symbol by symbol and deduplicated while keeping the order of first
 * occurrence. For every mapped output symbol, the mapped grouping column of the first original
 * output that produced it is recorded; later outputs mapping to the same symbol do not overwrite it.
 *
 * @param node group id node to rewrite
 * @param source source of the rewritten node
 * @return the rewritten group id node, keeping the id of {@code node}
 */
public GroupIdNode map(GroupIdNode node, PlanNode source)
{
    Map<Symbol, Symbol> newGroupingMappings = new HashMap<>();
    ImmutableList.Builder<List<Symbol>> newGroupingSets = ImmutableList.builder();
    for (List<Symbol> groupingSet : node.getGroupingSets()) {
        // LinkedHashSet drops duplicates produced by the mapping but preserves encounter order
        Set<Symbol> newGroupingSet = new LinkedHashSet<>();
        for (Symbol output : groupingSet) {
            Symbol newOutput = map(output);
            newGroupingMappings.putIfAbsent(
                    newOutput,
                    map(node.getGroupingColumns().get(output)));
            newGroupingSet.add(newOutput);
        }
        newGroupingSets.add(ImmutableList.copyOf(newGroupingSet));
    }
    return new GroupIdNode(
            node.getId(),
            source,
            newGroupingSets.build(),
            newGroupingMappings,
            mapAndDistinct(node.getAggregationArguments()),
            map(node.getGroupIdSymbol()));
}
/**
 * Rewrites the symbols of a window node on top of the given source.
 * <p>
 * Window functions are rewritten first (output symbol, arguments and frame), then the
 * specification together with its pre-sorted prefix. The hash symbol and the pre-partitioned
 * inputs are mapped as well. Because the functions are collected with {@code buildOrThrow()},
 * the build fails if two function output symbols map to the same symbol.
 *
 * @param node window node to rewrite
 * @param source source of the rewritten node
 * @return the rewritten window node, keeping the id of {@code node}
 */
public WindowNode map(WindowNode node, PlanNode source)
{
    ImmutableMap.Builder<Symbol, WindowNode.Function> newFunctions = ImmutableMap.builder();
    node.getWindowFunctions().forEach((symbol, function) -> {
        List<Expression> newArguments = function.getArguments().stream()
                .map(this::map)
                .collect(toImmutableList());
        WindowNode.Frame newFrame = map(function.getFrame());
        newFunctions.put(map(symbol), new WindowNode.Function(function.getResolvedFunction(), newArguments, newFrame, function.isIgnoreNulls()));
    });
    SpecificationWithPreSortedPrefix newSpecification = mapAndDistinct(node.getSpecification(), node.getPreSortedOrderPrefix());
    return new WindowNode(
            node.getId(),
            source,
            newSpecification.specification(),
            newFunctions.buildOrThrow(),
            node.getHashSymbol().map(this::map),
            node.getPrePartitionedInputs().stream()
                    .map(this::map)
                    .collect(toImmutableSet()),
            newSpecification.preSorted());
}
/**
 * Rewrites a window frame: the start and end values and the sort keys coerced for the start and
 * end comparisons are mapped when present. The frame type, bound types and the original start
 * and end values are carried over unchanged.
 */
private WindowNode.Frame map(WindowNode.Frame frame)
{
    // Mapped in constructor-argument order.
    Optional<Symbol> startValue = frame.getStartValue().map(this::map);
    Optional<Symbol> startSortKey = frame.getSortKeyCoercedForFrameStartComparison().map(this::map);
    Optional<Symbol> endValue = frame.getEndValue().map(this::map);
    Optional<Symbol> endSortKey = frame.getSortKeyCoercedForFrameEndComparison().map(this::map);

    return new WindowNode.Frame(
            frame.getType(),
            frame.getStartType(),
            startValue,
            startSortKey,
            frame.getEndType(),
            endValue,
            endSortKey,
            frame.getOriginalStartValue(),
            frame.getOriginalEndValue());
}
/**
 * Rewrites a data organization specification together with the length of its pre-sorted prefix.
 * <p>
 * The ordering scheme, when present, is rewritten by {@code map(orderingScheme, preSorted)}, which
 * also yields the pre-sorted prefix length to use for the result. When there is no ordering scheme,
 * the given {@code preSorted} value is returned as is. Partitioning symbols are mapped and
 * deduplicated.
 *
 * @param specification specification to rewrite
 * @param preSorted pre-sorted prefix length of the original ordering
 * @return the rewritten specification paired with the resulting pre-sorted prefix length
 */
private SpecificationWithPreSortedPrefix mapAndDistinct(DataOrganizationSpecification specification, int preSorted)
{
    Optional<OrderingSchemeWithPreSortedPrefix> newOrderingScheme = specification.getOrderingScheme()
            .map(orderingScheme -> map(orderingScheme, preSorted));
    return new SpecificationWithPreSortedPrefix(
            new DataOrganizationSpecification(
                    mapAndDistinct(specification.getPartitionBy()),
                    newOrderingScheme.map(OrderingSchemeWithPreSortedPrefix::orderingScheme)),
            newOrderingScheme.map(OrderingSchemeWithPreSortedPrefix::preSorted).orElse(preSorted));
}
/**
 * Rewrites a data organization specification: partitioning symbols are mapped and deduplicated,
 * and the ordering scheme is mapped when present.
 *
 * @param specification specification to rewrite
 * @return a new specification referring to the mapped symbols
 */
public DataOrganizationSpecification mapAndDistinct(DataOrganizationSpecification specification)
{
    // Partitioning is mapped before ordering, matching the constructor-argument order.
    List<Symbol> partitionBy = mapAndDistinct(specification.getPartitionBy());
    Optional<OrderingScheme> ordering = specification.getOrderingScheme().map(this::map);
    return new DataOrganizationSpecification(partitionBy, ordering);
}
/**
 * Rewrites the symbols of a pattern recognition node on top of the given source.
 * <p>
 * The specification (with its pre-sorted prefix), window functions, measures and variable
 * definitions are rewritten, as are the hash symbol, the pre-partitioned inputs and the common
 * base frame. Rows-per-match, skip-to label and position, the initial flag, the pattern and the
 * subsets are carried over unchanged. Functions, measures and variable definitions are collected
 * with {@code buildOrThrow()}, so the build fails if two keys of the same map end up equal.
 *
 * @param node pattern recognition node to rewrite
 * @param source source of the rewritten node
 * @return the rewritten pattern recognition node, keeping the id of {@code node}
 */
public PatternRecognitionNode map(PatternRecognitionNode node, PlanNode source)
{
    SpecificationWithPreSortedPrefix newSpecification = mapAndDistinct(node.getSpecification(), node.getPreSortedOrderPrefix());

    ImmutableMap.Builder<Symbol, WindowNode.Function> newFunctions = ImmutableMap.builder();
    node.getWindowFunctions().forEach((symbol, function) -> {
        List<Expression> newArguments = function.getArguments().stream()
                .map(this::map)
                .collect(toImmutableList());
        WindowNode.Frame newFrame = map(function.getFrame());
        newFunctions.put(map(symbol), new WindowNode.Function(function.getResolvedFunction(), newArguments, newFrame, function.isIgnoreNulls()));
    });

    ImmutableMap.Builder<Symbol, Measure> newMeasures = ImmutableMap.builder();
    node.getMeasures().forEach((symbol, measure) -> {
        ExpressionAndValuePointers newExpression = map(measure.getExpressionAndValuePointers());
        newMeasures.put(map(symbol), new Measure(newExpression, measure.getType()));
    });

    // labels are kept as they are; only the definitions are rewritten
    ImmutableMap.Builder<IrLabel, ExpressionAndValuePointers> newVariableDefinitions = ImmutableMap.builder();
    node.getVariableDefinitions().forEach((label, expression) -> newVariableDefinitions.put(label, map(expression)));

    return new PatternRecognitionNode(
            node.getId(),
            source,
            newSpecification.specification(),
            node.getHashSymbol().map(this::map),
            node.getPrePartitionedInputs().stream()
                    .map(this::map)
                    .collect(toImmutableSet()),
            newSpecification.preSorted(),
            newFunctions.buildOrThrow(),
            newMeasures.buildOrThrow(),
            node.getCommonBaseFrame().map(this::map),
            node.getRowsPerMatch(),
            node.getSkipToLabel(),
            node.getSkipToPosition(),
            node.isInitial(),
            node.getPattern(),
            node.getSubsets(),
            newVariableDefinitions.buildOrThrow());
}
/**
 * Rewrites the value pointers of an {@code ExpressionAndValuePointers}, leaving its expression,
 * layout, classifier symbols and match number symbols untouched.
 * <p>
 * For a scalar value pointer, the input symbol is mapped unless it is one of the classifier or
 * match number symbols, in which case the pointer is reused as is. For an aggregation value
 * pointer, symbol references in the arguments are mapped, except references to the pointer's own
 * classifier symbol or match number symbol.
 */
private ExpressionAndValuePointers map(ExpressionAndValuePointers expressionAndValuePointers)
{
    // Map only the input symbols of ValuePointers. These are the symbols produced by the source node.
    // Other symbols present in the ExpressionAndValuePointers structure are synthetic unique symbols
    // with no outer usage or dependencies.
    ImmutableList.Builder<ValuePointer> newValuePointers = ImmutableList.builder();
    for (ValuePointer valuePointer : expressionAndValuePointers.getValuePointers()) {
        if (valuePointer instanceof ScalarValuePointer scalarValuePointer) {
            Symbol inputSymbol = scalarValuePointer.getInputSymbol();
            if (expressionAndValuePointers.getClassifierSymbols().contains(inputSymbol) || expressionAndValuePointers.getMatchNumberSymbols().contains(inputSymbol)) {
                newValuePointers.add(scalarValuePointer);
            }
            else {
                newValuePointers.add(new ScalarValuePointer(scalarValuePointer.getLogicalIndexPointer(), map(inputSymbol)));
            }
        }
        else {
            AggregationValuePointer aggregationValuePointer = (AggregationValuePointer) valuePointer;
            List<Expression> newArguments = aggregationValuePointer.getArguments().stream()
                    .map(expression -> ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<Void>()
                    {
                        @Override
                        public Expression rewriteSymbolReference(SymbolReference node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
                        {
                            Symbol referenced = Symbol.from(node);
                            if (referenced.equals(aggregationValuePointer.getClassifierSymbol()) || referenced.equals(aggregationValuePointer.getMatchNumberSymbol())) {
                                return node;
                            }
                            // resolves to the enclosing mapper's map(Expression)
                            return map(node);
                        }
                    }, expression))
                    .collect(toImmutableList());
            newValuePointers.add(new AggregationValuePointer(
                    aggregationValuePointer.getFunction(),
                    aggregationValuePointer.getSetDescriptor(),
                    newArguments,
                    aggregationValuePointer.getClassifierSymbol(),
                    aggregationValuePointer.getMatchNumberSymbol()));
        }
    }
    return new ExpressionAndValuePointers(
            expressionAndValuePointers.getExpression(),
            expressionAndValuePointers.getLayout(),
            newValuePointers.build(),
            expressionAndValuePointers.getClassifierSymbols(),
            expressionAndValuePointers.getMatchNumberSymbols());
}
public TableFunctionProcessorNode map(TableFunctionProcessorNode node, PlanNode source)
{
// rewrite and deduplicate pass-through specifications
// note: Potentially, pass-through symbols from different sources might be recognized as semantically identical, and rewritten
// to the same symbol. Currently, we retrieve the first occurrence of a symbol, and skip all the following occurrences.
// For better performance, we could pick the occurrence with "isPartitioningColumn" property, since the pass-through mechanism
// is more efficient for partitioning columns which are guaranteed to be constant within partition.
// TODO choose a partitioning column to be retrieved while deduplicating
ImmutableList.Builder newPassThroughSpecifications = ImmutableList.builder();
Set newPassThroughSymbols = new HashSet<>();
for (PassThroughSpecification specification : node.getPassThroughSpecifications()) {
ImmutableList.Builder newColumns = ImmutableList.builder();
for (PassThroughColumn column : specification.columns()) {
Symbol newSymbol = map(column.symbol());
if (newPassThroughSymbols.add(newSymbol)) {
newColumns.add(new PassThroughColumn(newSymbol, column.isPartitioningColumn()));
}
}
newPassThroughSpecifications.add(new PassThroughSpecification(specification.declaredAsPassThrough(), newColumns.build()));
}
// rewrite required symbols without deduplication. the table function expects specific input layout
List> newRequiredSymbols = node.getRequiredSymbols().stream()
.map(this::map)
.collect(toImmutableList());
// rewrite and deduplicate marker mapping
Optional
© 2015 - 2025 Weber Informatics LLC | Privacy Policy