// All downloads are free. Search and download functionality uses the official Maven repository.
//
// io.trino.sql.planner.iterative.rule.PushDownProjectionsFromPatternRecognition Maven / Gradle / Ivy
//
// There is a newer version: 468
// Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.spi.type.Type;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.PatternRecognitionNode;
import io.trino.sql.planner.plan.PatternRecognitionNode.Measure;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.rowpattern.AggregationValuePointer;
import io.trino.sql.planner.rowpattern.LogicalIndexExtractor.ExpressionAndValuePointers;
import io.trino.sql.planner.rowpattern.ScalarValuePointer;
import io.trino.sql.planner.rowpattern.ValuePointer;
import io.trino.sql.planner.rowpattern.ir.IrLabel;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.SymbolReference;

import java.util.List;
import java.util.Map;
import java.util.Set;

import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static io.trino.sql.planner.iterative.rule.Util.restrictOutputs;
import static io.trino.sql.planner.plan.Patterns.patternRecognition;

/**
 * Aggregate functions in pattern recognition context have special semantics.
 * It is allowed to use `CLASSIFIER()` and `MATCH_NUMBER()` functions
 * in aggregations arguments. Those calls are evaluated at runtime,
 * as they depend on the pattern matching state.
 * 

* As a consequence, some aggregation arguments cannot be pre-projected * and replaced with single symbols. These are the "runtime-evaluated arguments". *

* The purpose of this rule is to identify and pre-project all arguments which * are not runtime-evaluated. *

* Example: * `array_agg(CLASSIFIER(A))` -> the argument `CLASSIFIER(A)` cannot be pre-projected * `avg(A.price + A.tax)` -> we can pre-project the expression `price + tax`, and * replace the argument with a single symbol. The label 'A', which prefixes column * references, has been already extracted from the expression in the LogicalIndexExtractor. */ public class PushDownProjectionsFromPatternRecognition implements Rule { private static final Pattern PATTERN = patternRecognition(); @Override public Pattern getPattern() { return PATTERN; } @Override public Result apply(PatternRecognitionNode node, Captures captures, Context context) { Assignments.Builder assignments = Assignments.builder(); Map rewrittenVariableDefinitions = rewriteVariableDefinitions(node.getVariableDefinitions(), assignments, context); Map rewrittenMeasureDefinitions = rewriteMeasureDefinitions(node.getMeasures(), assignments, context); if (assignments.build().isEmpty()) { return Result.empty(); } assignments.putIdentities(node.getSource().getOutputSymbols()); ProjectNode projectNode = new ProjectNode( context.getIdAllocator().getNextId(), node.getSource(), assignments.build()); PatternRecognitionNode patternRecognitionNode = new PatternRecognitionNode( node.getId(), projectNode, node.getSpecification(), node.getHashSymbol(), node.getPrePartitionedInputs(), node.getPreSortedOrderPrefix(), node.getWindowFunctions(), rewrittenMeasureDefinitions, node.getCommonBaseFrame(), node.getRowsPerMatch(), node.getSkipToLabel(), node.getSkipToPosition(), node.isInitial(), node.getPattern(), node.getSubsets(), rewrittenVariableDefinitions); return Result.ofPlanNode(restrictOutputs(context.getIdAllocator(), patternRecognitionNode, ImmutableSet.copyOf(node.getOutputSymbols())).orElse(patternRecognitionNode)); } private static Map rewriteVariableDefinitions(Map variableDefinitions, Assignments.Builder assignments, Context context) { return variableDefinitions.entrySet().stream() 
.collect(toImmutableMap(Map.Entry::getKey, entry -> rewrite(entry.getValue(), assignments, context))); } private static Map rewriteMeasureDefinitions(Map measureDefinitions, Assignments.Builder assignments, Context context) { return measureDefinitions.entrySet().stream() .collect(toImmutableMap(Map.Entry::getKey, entry -> new Measure(rewrite(entry.getValue().getExpressionAndValuePointers(), assignments, context), entry.getValue().getType()))); } private static ExpressionAndValuePointers rewrite(ExpressionAndValuePointers expression, Assignments.Builder assignments, Context context) { ImmutableList.Builder rewrittenPointers = ImmutableList.builder(); for (ValuePointer valuePointer : expression.getValuePointers()) { if (valuePointer instanceof ScalarValuePointer) { rewrittenPointers.add(valuePointer); } else { AggregationValuePointer aggregationPointer = (AggregationValuePointer) valuePointer; Set runtimeEvaluatedSymbols = ImmutableSet.of(aggregationPointer.getClassifierSymbol(), aggregationPointer.getMatchNumberSymbol()); List argumentTypes = aggregationPointer.getFunction().getSignature().getArgumentTypes(); ImmutableList.Builder rewrittenArguments = ImmutableList.builder(); for (int i = 0; i < aggregationPointer.getArguments().size(); i++) { Expression argument = aggregationPointer.getArguments().get(i); if (argument instanceof SymbolReference || SymbolsExtractor.extractUnique(argument).stream() .anyMatch(runtimeEvaluatedSymbols::contains)) { rewrittenArguments.add(argument); } else { Symbol symbol = context.getSymbolAllocator().newSymbol(argument, argumentTypes.get(i)); assignments.put(symbol, argument); rewrittenArguments.add(symbol.toSymbolReference()); } } rewrittenPointers.add(new AggregationValuePointer( aggregationPointer.getFunction(), aggregationPointer.getSetDescriptor(), rewrittenArguments.build(), aggregationPointer.getClassifierSymbol(), aggregationPointer.getMatchNumberSymbol())); } } return new ExpressionAndValuePointers( expression.getExpression(), 
expression.getLayout(), rewrittenPointers.build(), expression.getClassifierSymbols(), expression.getMatchNumberSymbols()); } }





// © 2015 - 2025 Weber Informatics LLC | Privacy Policy