All downloads are free. Search and download functionality uses the official Maven repository.

io.trino.sql.planner.rowpattern.LogicalIndexExtractor Maven / Gradle / Ivy

There is a newer version: 465
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.sql.planner.rowpattern;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.trino.Session;
import io.trino.metadata.Metadata;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.function.FunctionKind;
import io.trino.spi.type.Type;
import io.trino.sql.analyzer.ExpressionAnalyzer;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.rowpattern.ir.IrLabel;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.ExpressionRewriter;
import io.trino.sql.tree.ExpressionTreeRewriter;
import io.trino.sql.tree.FunctionCall;
import io.trino.sql.tree.Identifier;
import io.trino.sql.tree.LabelDereference;
import io.trino.sql.tree.LongLiteral;
import io.trino.sql.tree.ProcessingMode;
import io.trino.sql.tree.QualifiedName;
import io.trino.sql.tree.SymbolReference;

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Set;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static io.trino.sql.analyzer.ExpressionAnalyzer.isPatternRecognitionFunction;
import static io.trino.sql.analyzer.ExpressionTreeUtils.extractExpressions;
import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL;
import static io.trino.sql.tree.ProcessingMode.Mode.FINAL;
import static java.lang.Math.toIntExact;
import static java.util.Locale.ENGLISH;
import static java.util.Objects.requireNonNull;

/**
 * Rewriter for expressions specific to row pattern recognition.
 * Removes label-prefixed symbol references from the expression and replaces them with symbols.
 * Removes row pattern navigation functions (PREV, NEXT, FIRST and LAST) from the expression.
 * Removes pattern special functions CLASSIFIER() and MATCH_NUMBER() and replaces them with symbols.
 * Reallocates all symbols in the expression to avoid unwanted optimizations when the expression is compiled.
 * For each of the symbols creates a value accessor (ValuePointer).
 * Returns new symbols as expected "input descriptor", upon which the rewritten expression will be compiled, along with value accessors.
 * Value accessors are ordered the same way as the corresponding symbols so that they can be used to provide actual values to the compiled expression.
 * Each time the compiled expression will be executed, a single-row input will be prepared with the use of the value accessors,
 * following the symbols layout.
 * 

* Aggregate functions in pattern recognition expressions are handled in special way. Similarly to column references, CLASSIFIER() and MATCH_NUMBER() * calls, they are replaced with a single symbol in the resulting expression, backed with a ValuePointer. The ValuePointer is an instance of the * AggregationValuePointer class. It captures the aggregate function, the descriptor of the aggregated set of rows, and a list of arguments. * The expressions in the arguments list are rewritten so that they do not contain any pattern-recognition-specific elements. */ public class LogicalIndexExtractor { public static ExpressionAndValuePointers rewrite(Expression expression, Map> subsets, SymbolAllocator symbolAllocator, Session session, Metadata metadata) { ImmutableList.Builder layout = ImmutableList.builder(); ImmutableList.Builder valuePointers = ImmutableList.builder(); ImmutableSet.Builder classifierSymbols = ImmutableSet.builder(); ImmutableSet.Builder matchNumberSymbols = ImmutableSet.builder(); Visitor visitor = new Visitor(subsets, layout, valuePointers, classifierSymbols, matchNumberSymbols, symbolAllocator, session, metadata); Expression rewritten = ExpressionTreeRewriter.rewriteWith(visitor, expression, LogicalIndexContext.DEFAULT); return new ExpressionAndValuePointers(rewritten, layout.build(), valuePointers.build(), classifierSymbols.build(), matchNumberSymbols.build()); } private LogicalIndexExtractor() {} private static class Visitor extends ExpressionRewriter { private final Map> subsets; private final ImmutableList.Builder layout; private final ImmutableList.Builder valuePointers; private final ImmutableSet.Builder classifierSymbols; private final ImmutableSet.Builder matchNumberSymbols; private final SymbolAllocator symbolAllocator; private final Session session; private final Metadata metadata; public Visitor( Map> subsets, ImmutableList.Builder layout, ImmutableList.Builder valuePointers, ImmutableSet.Builder classifierSymbols, ImmutableSet.Builder 
matchNumberSymbols, SymbolAllocator symbolAllocator, Session session, Metadata metadata) { this.subsets = requireNonNull(subsets, "subsets is null"); this.layout = requireNonNull(layout, "layout is null"); this.valuePointers = requireNonNull(valuePointers, "valuePointers is null"); this.classifierSymbols = requireNonNull(classifierSymbols, "classifierSymbols is null"); this.matchNumberSymbols = requireNonNull(matchNumberSymbols, "matchNumberSymbols is null"); this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null"); this.session = requireNonNull(session, "session is null"); this.metadata = requireNonNull(metadata, "metadata is null"); } @Override protected Expression rewriteExpression(Expression node, LogicalIndexContext context, ExpressionTreeRewriter treeRewriter) { return treeRewriter.defaultRewrite(node, context); } @Override public Expression rewriteLabelDereference(LabelDereference node, LogicalIndexContext context, ExpressionTreeRewriter treeRewriter) { Symbol referenced = Symbol.from(node.getReference().orElseThrow()); Symbol reallocated = symbolAllocator.newSymbol(referenced); layout.add(reallocated); Set labels = subsets.get(irLabel(node.getLabel())); if (labels == null) { labels = ImmutableSet.of(irLabel(node.getLabel())); } valuePointers.add(new ScalarValuePointer(context.withLabels(labels).toLogicalIndexPointer(), referenced)); return reallocated.toSymbolReference(); } @Override public Expression rewriteSymbolReference(SymbolReference node, LogicalIndexContext context, ExpressionTreeRewriter treeRewriter) { // symbol reference with no label prefix is implicitly prefixed with a universal row pattern variable (matches every label) // it is encoded as empty label set Symbol reallocated = symbolAllocator.newSymbol(Symbol.from(node)); layout.add(reallocated); valuePointers.add(new ScalarValuePointer(context.withLabels(ImmutableSet.of()).toLogicalIndexPointer(), Symbol.from(node))); return reallocated.toSymbolReference(); } @Override 
public Expression rewriteFunctionCall(FunctionCall node, LogicalIndexContext context, ExpressionTreeRewriter treeRewriter) { if (isPatternRecognitionFunction(node)) { QualifiedName name = node.getName(); String functionName = name.getSuffix().toUpperCase(ENGLISH); return switch (functionName) { case "FIRST", "LAST", "PREV", "NEXT" -> rewritePatternNavigationFunction(node, context, treeRewriter); case "CLASSIFIER" -> rewriteClassifierFunction(node, context); case "MATCH_NUMBER" -> rewriteMatchNumberFunction(); default -> throw new UnsupportedOperationException("unsupported pattern recognition function type: " + node.getName()); }; } ResolvedFunction resolvedFunction = metadata.decodeFunction(node.getName()); if (resolvedFunction.getFunctionKind() == FunctionKind.AGGREGATE) { Type type = resolvedFunction.getSignature().getReturnType(); Symbol aggregationSymbol = symbolAllocator.newSymbol(node, type); layout.add(aggregationSymbol); Symbol classifierSymbol = symbolAllocator.newSymbol("classifier", VARCHAR); Symbol matchNumberSymbol = symbolAllocator.newSymbol("match_number", BIGINT); List rewrittenArguments = AggregateArgumentsRewriter.rewrite(node.getArguments(), classifierSymbol, matchNumberSymbol); AggregationValuePointer descriptor = new AggregationValuePointer( resolvedFunction, new AggregatedSetDescriptor( extractLabels(node), node.getProcessingMode().isEmpty() || node.getProcessingMode().get().getMode() != FINAL), rewrittenArguments, classifierSymbol, matchNumberSymbol); valuePointers.add(descriptor); return aggregationSymbol.toSymbolReference(); } return super.rewriteFunctionCall(node, context, treeRewriter); } /** * Extract labels to identify rows to which the aggregation should be applied. * It is assumed that all arguments of the aggregation apply consistently to the same set of labels. 
* * @param node a `FunctionCall` for the aggregate function * @return set of `IrLabel`s corresponding to primary row pattern variables or empty set for the universal row pattern variable */ private Set extractLabels(FunctionCall node) { if (node.getArguments().isEmpty()) { return ImmutableSet.of(); } List labeledDereferences = extractExpressions(node.getArguments(), LabelDereference.class); if (!labeledDereferences.isEmpty()) { IrLabel label = irLabel(labeledDereferences.get(0).getLabel()); Set labels = subsets.get(label); if (labels == null) { labels = ImmutableSet.of(label); } return labels; } Optional classifierCall = extractExpressions(node.getArguments(), FunctionCall.class).stream() .filter(ExpressionAnalyzer::isPatternRecognitionFunction) .filter(function -> function.getName().getSuffix().toUpperCase(ENGLISH).equals("CLASSIFIER")) .findFirst(); if (classifierCall.isPresent()) { FunctionCall classifier = classifierCall.get(); if (!classifier.getArguments().isEmpty()) { IrLabel label = irLabel(((Identifier) getOnlyElement(classifier.getArguments())).getCanonicalValue()); Set labels = subsets.get(label); if (labels == null) { labels = ImmutableSet.of(label); } return labels; } } return ImmutableSet.of(); } private Expression rewritePatternNavigationFunction(FunctionCall node, LogicalIndexContext context, ExpressionTreeRewriter treeRewriter) { String functionName = node.getName().getSuffix().toUpperCase(ENGLISH); Expression argument = node.getArguments().get(0); Optional processingMode = node.getProcessingMode(); OptionalInt offset = OptionalInt.empty(); if (node.getArguments().size() > 1) { offset = OptionalInt.of(toIntExact(((LongLiteral) node.getArguments().get(1)).getParsedValue())); } return switch (functionName) { case "PREV" -> treeRewriter.rewrite(argument, context.withPhysicalOffset(-offset.orElse(1))); case "NEXT" -> treeRewriter.rewrite(argument, context.withPhysicalOffset(offset.orElse(1))); case "FIRST" -> treeRewriter.rewrite(argument, 
context.withLogicalOffset( processingMode.isEmpty() || processingMode.get().getMode() != FINAL, false, offset.orElse(0))); case "LAST" -> treeRewriter.rewrite(argument, context.withLogicalOffset( processingMode.isEmpty() || processingMode.get().getMode() != FINAL, true, offset.orElse(0))); default -> throw new UnsupportedOperationException("unsupported pattern navigation function type: " + node.getName()); }; } private Expression rewriteClassifierFunction(FunctionCall node, LogicalIndexContext context) { Symbol classifierSymbol = symbolAllocator.newSymbol("classifier", VARCHAR); layout.add(classifierSymbol); Set labels = ImmutableSet.of(); if (!node.getArguments().isEmpty()) { IrLabel label = irLabel(((Identifier) getOnlyElement(node.getArguments())).getCanonicalValue()); labels = subsets.get(label); if (labels == null) { labels = ImmutableSet.of(label); } } // pass the new symbol as input symbol. It will be used to identify classifier function. valuePointers.add(new ScalarValuePointer(context.withLabels(labels).toLogicalIndexPointer(), classifierSymbol)); classifierSymbols.add(classifierSymbol); return classifierSymbol.toSymbolReference(); } private Expression rewriteMatchNumberFunction() { Symbol matchNumberSymbol = symbolAllocator.newSymbol("match_number", BIGINT); layout.add(matchNumberSymbol); // pass default LogicalIndexPointer. It will not be accessed. match_number() is constant in the context of a match. // pass the new symbol as input symbol. It will be used to identify match number function. 
valuePointers.add(new ScalarValuePointer(LogicalIndexContext.DEFAULT.toLogicalIndexPointer(), matchNumberSymbol)); matchNumberSymbols.add(matchNumberSymbol); return matchNumberSymbol.toSymbolReference(); } private IrLabel irLabel(String label) { return new IrLabel(label); } } private static class LogicalIndexContext { public static final LogicalIndexContext DEFAULT = new LogicalIndexContext(ImmutableSet.of(), true, true, 0, 0); private final Set label; private final boolean running; private final boolean last; private final int logicalOffset; private final int physicalOffset; private LogicalIndexContext(Set label, boolean running, boolean last, int logicalOffset, int physicalOffset) { this.label = requireNonNull(label, "label is null"); this.running = running; this.last = last; this.logicalOffset = logicalOffset; this.physicalOffset = physicalOffset; } public LogicalIndexContext withPhysicalOffset(int physicalOffset) { return new LogicalIndexContext(this.label, this.running, this.last, this.logicalOffset, physicalOffset); } public LogicalIndexContext withLogicalOffset(boolean running, boolean last, int logicalOffset) { return new LogicalIndexContext(this.label, running, last, logicalOffset, this.physicalOffset); } public LogicalIndexContext withLabels(Set labels) { return new LogicalIndexContext(labels, this.running, this.last, this.logicalOffset, this.physicalOffset); } public LogicalIndexPointer toLogicalIndexPointer() { return new LogicalIndexPointer(label, last, running, logicalOffset, physicalOffset); } } public static class ExpressionAndValuePointers { public static final ExpressionAndValuePointers TRUE = new ExpressionAndValuePointers(TRUE_LITERAL, ImmutableList.of(), ImmutableList.of(), ImmutableSet.of(), ImmutableSet.of()); private final Expression expression; private final List layout; private final List valuePointers; private final Set classifierSymbols; private final Set matchNumberSymbols; @JsonCreator public ExpressionAndValuePointers(Expression 
expression, List layout, List valuePointers, Set classifierSymbols, Set matchNumberSymbols) { this.expression = requireNonNull(expression, "expression is null"); this.layout = requireNonNull(layout, "layout is null"); this.valuePointers = requireNonNull(valuePointers, "valuePointers is null"); checkArgument(layout.size() == valuePointers.size(), "layout and valuePointers sizes don't match"); this.classifierSymbols = requireNonNull(classifierSymbols, "classifierSymbols is null"); this.matchNumberSymbols = requireNonNull(matchNumberSymbols, "matchNumberSymbols is null"); } @JsonProperty public Expression getExpression() { return expression; } @JsonProperty public List getLayout() { return layout; } @JsonProperty public List getValuePointers() { return valuePointers; } @JsonProperty public Set getClassifierSymbols() { return classifierSymbols; } @JsonProperty public Set getMatchNumberSymbols() { return matchNumberSymbols; } public List getInputSymbols() { ImmutableList.Builder inputSymbols = ImmutableList.builder(); for (ValuePointer valuePointer : valuePointers) { if (valuePointer instanceof ScalarValuePointer pointer) { Symbol symbol = pointer.getInputSymbol(); if (!classifierSymbols.contains(symbol) && !matchNumberSymbols.contains(symbol)) { inputSymbols.add(symbol); } } else if (valuePointer instanceof AggregationValuePointer) { inputSymbols.addAll(((AggregationValuePointer) valuePointer).getInputSymbols()); } else { throw new UnsupportedOperationException("unexpected ValuePointer type: " + valuePointer.getClass().getSimpleName()); } } return inputSymbols.build(); } @Override public boolean equals(Object obj) { if (this == obj) { return true; } if ((obj == null) || (getClass() != obj.getClass())) { return false; } ExpressionAndValuePointers o = (ExpressionAndValuePointers) obj; return Objects.equals(expression, o.expression) && Objects.equals(layout, o.layout) && Objects.equals(valuePointers, o.valuePointers) && Objects.equals(classifierSymbols, 
o.classifierSymbols) && Objects.equals(matchNumberSymbols, o.matchNumberSymbols); } @Override public int hashCode() { return Objects.hash(expression, layout, valuePointers, classifierSymbols, matchNumberSymbols); } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy