All downloads are free. Search and download functionality uses the official Maven repository.

io.trino.sql.planner.iterative.rule.ExpressionRewriteRuleSet Maven / Gradle / Ivy

There is a newer version: 465
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.spi.function.CatalogSchemaFunctionName;
import io.trino.sql.planner.OrderingScheme;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.AggregationNode.Aggregation;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.PatternRecognitionNode;
import io.trino.sql.planner.plan.PatternRecognitionNode.Measure;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.ValuesNode;
import io.trino.sql.planner.rowpattern.AggregationValuePointer;
import io.trino.sql.planner.rowpattern.LogicalIndexExtractor.ExpressionAndValuePointers;
import io.trino.sql.planner.rowpattern.ScalarValuePointer;
import io.trino.sql.planner.rowpattern.ValuePointer;
import io.trino.sql.planner.rowpattern.ir.IrLabel;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionCall;
import io.trino.sql.tree.OrderBy;
import io.trino.sql.tree.Row;
import io.trino.sql.tree.SortItem;
import io.trino.sql.tree.SortItem.NullOrdering;
import io.trino.sql.tree.SymbolReference;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.trino.metadata.ResolvedFunction.extractFunctionName;
import static io.trino.sql.planner.plan.Patterns.aggregation;
import static io.trino.sql.planner.plan.Patterns.filter;
import static io.trino.sql.planner.plan.Patterns.join;
import static io.trino.sql.planner.plan.Patterns.patternRecognition;
import static io.trino.sql.planner.plan.Patterns.project;
import static io.trino.sql.planner.plan.Patterns.values;
import static io.trino.sql.tree.SortItem.Ordering.ASCENDING;
import static io.trino.sql.tree.SortItem.Ordering.DESCENDING;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

public class ExpressionRewriteRuleSet
{
    /**
     * Rewrites a single expression on behalf of the rules in this set.
     * <p>
     * The rules compare the returned expression with the input using {@code equals} to decide
     * whether the plan node changed, so an implementation that makes no change should return an
     * expression equal to its input.
     */
    @FunctionalInterface
    public interface ExpressionRewriter
    {
        /**
         * Returns the rewritten form of {@code expression}.
         *
         * @param expression the expression to rewrite
         * @param context the context of the rule invocation that requested the rewrite
         * @return the rewritten expression
         */
        Expression rewrite(Expression expression, Rule.Context context);
    }

    // The single rewriter handed to every rule created by this set.
    private final ExpressionRewriter rewriter;

    /**
     * Creates a rule set whose rules all delegate expression rewriting to {@code rewriter}.
     *
     * @param rewriter the rewrite to apply; must not be null
     * @throws NullPointerException if {@code rewriter} is null
     */
    public ExpressionRewriteRuleSet(ExpressionRewriter rewriter)
    {
        requireNonNull(rewriter, "rewriter is null");
        this.rewriter = rewriter;
    }

    public Set> rules()
    {
        return ImmutableSet.of(
                projectExpressionRewrite(),
                aggregationExpressionRewrite(),
                filterExpressionRewrite(),
                joinExpressionRewrite(),
                valuesExpressionRewrite(),
                patternRecognitionExpressionRewrite());
    }

    /**
     * Returns a rule that rewrites the assignment expressions of a {@link ProjectNode}.
     */
    public Rule<?> projectExpressionRewrite()
    {
        return new ProjectExpressionRewrite(rewriter);
    }

    /**
     * Returns a rule that rewrites the arguments, FILTER and ORDER BY of every aggregation in an
     * {@link AggregationNode}.
     */
    public Rule<?> aggregationExpressionRewrite()
    {
        return new AggregationExpressionRewrite(rewriter);
    }

    /**
     * Returns a rule that rewrites the predicate of a {@link FilterNode}.
     */
    public Rule<?> filterExpressionRewrite()
    {
        return new FilterExpressionRewrite(rewriter);
    }

    /**
     * Returns a rule that rewrites the optional filter of a {@link JoinNode}.
     */
    public Rule<?> joinExpressionRewrite()
    {
        return new JoinExpressionRewrite(rewriter);
    }

    /**
     * Returns a rule that rewrites the row expressions of a {@link ValuesNode}.
     */
    public Rule<?> valuesExpressionRewrite()
    {
        return new ValuesExpressionRewrite(rewriter);
    }

    /**
     * Returns a rule that rewrites the MEASURES and DEFINE expressions of a
     * {@link PatternRecognitionNode}.
     */
    public Rule<?> patternRecognitionExpressionRewrite()
    {
        return new PatternRecognitionExpressionRewrite(rewriter);
    }

    /**
     * Applies the rewriter to every assignment of a {@link ProjectNode}. Produces no result when
     * the rewritten assignments are equal to the original ones.
     */
    private static final class ProjectExpressionRewrite
            implements Rule<ProjectNode>
    {
        private final ExpressionRewriter rewriter;

        ProjectExpressionRewrite(ExpressionRewriter rewriter)
        {
            this.rewriter = requireNonNull(rewriter, "rewriter is null");
        }

        @Override
        public Pattern<ProjectNode> getPattern()
        {
            return project();
        }

        @Override
        public Result apply(ProjectNode projectNode, Captures captures, Context context)
        {
            Assignments assignments = projectNode.getAssignments().rewrite(expression -> rewriter.rewrite(expression, context));
            if (projectNode.getAssignments().equals(assignments)) {
                return Result.empty();
            }
            return Result.ofPlanNode(new ProjectNode(projectNode.getId(), projectNode.getSource(), assignments));
        }

        @Override
        public String toString()
        {
            return format("%s(%s)", getClass().getSimpleName(), rewriter);
        }
    }

    /**
     * Applies the rewriter to every aggregation of an {@link AggregationNode}.
     * <p>
     * Each {@link Aggregation} is temporarily expressed as a {@link FunctionCall} carrying its
     * arguments, DISTINCT flag, FILTER symbol and ORDER BY. The rewriter therefore sees it as an
     * ordinary expression, and the rewritten call is converted back into an {@link Aggregation}.
     * The resolved function and the mask are copied from the original aggregation. The rewriter
     * must return a function call with the same function name.
     */
    private static final class AggregationExpressionRewrite
            implements Rule<AggregationNode>
    {
        private final ExpressionRewriter rewriter;

        AggregationExpressionRewrite(ExpressionRewriter rewriter)
        {
            this.rewriter = requireNonNull(rewriter, "rewriter is null");
        }

        @Override
        public Pattern<AggregationNode> getPattern()
        {
            return aggregation();
        }

        @Override
        public Result apply(AggregationNode aggregationNode, Captures captures, Context context)
        {
            boolean anyRewritten = false;
            ImmutableMap.Builder<Symbol, Aggregation> aggregations = ImmutableMap.builder();
            for (Map.Entry<Symbol, Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
                Aggregation aggregation = entry.getValue();
                Aggregation newAggregation = rewriteAggregation(aggregation, context);
                aggregations.put(entry.getKey(), newAggregation);
                if (!aggregation.equals(newAggregation)) {
                    anyRewritten = true;
                }
            }
            if (!anyRewritten) {
                return Result.empty();
            }
            return Result.ofPlanNode(AggregationNode.builderFrom(aggregationNode)
                    .setAggregations(aggregations.buildOrThrow())
                    .build());
        }

        // Round-trips one aggregation through the rewriter via its FunctionCall form.
        private Aggregation rewriteAggregation(Aggregation aggregation, Context context)
        {
            CatalogSchemaFunctionName name = aggregation.getResolvedFunction().getSignature().getName();
            Expression rewritten = rewriter.rewrite(toFunctionCall(aggregation), context);
            // The result is converted back field by field, so it must still be a function call.
            // Fail with a descriptive message instead of a ClassCastException.
            verify(rewritten instanceof FunctionCall, "Aggregation rewritten to a non-function-call expression: %s", rewritten);
            FunctionCall call = (FunctionCall) rewritten;
            CatalogSchemaFunctionName newName = extractFunctionName(call.getName());
            // the resolved function is reused below, so the rewriter must not switch functions
            verify(newName.equals(name), "Aggregation function name changed: %s -> %s", name, newName);
            return new Aggregation(
                    aggregation.getResolvedFunction(),
                    call.getArguments(),
                    call.isDistinct(),
                    call.getFilter().map(Symbol::from),
                    call.getOrderBy().map(OrderingScheme::fromOrderBy),
                    aggregation.getMask());
        }

        // Expresses an aggregation as a FunctionCall. The mask is not included; rewriteAggregation copies it from the original.
        private static FunctionCall toFunctionCall(Aggregation aggregation)
        {
            return new FunctionCall(
                    Optional.empty(),
                    aggregation.getResolvedFunction().toQualifiedName(),
                    Optional.empty(),
                    aggregation.getFilter().map(symbol -> new SymbolReference(symbol.getName())),
                    aggregation.getOrderingScheme().map(AggregationExpressionRewrite::toOrderBy),
                    aggregation.isDistinct(),
                    Optional.empty(),
                    Optional.empty(),
                    aggregation.getArguments());
        }

        // Converts an OrderingScheme into the equivalent ORDER BY clause over symbol references.
        private static OrderBy toOrderBy(OrderingScheme orderingScheme)
        {
            return new OrderBy(orderingScheme.getOrderBy().stream()
                    .map(symbol -> new SortItem(
                            new SymbolReference(symbol.getName()),
                            orderingScheme.getOrdering(symbol).isAscending() ? ASCENDING : DESCENDING,
                            orderingScheme.getOrdering(symbol).isNullsFirst() ? NullOrdering.FIRST : NullOrdering.LAST))
                    .collect(toImmutableList()));
        }

        @Override
        public String toString()
        {
            return format("%s(%s)", getClass().getSimpleName(), rewriter);
        }
    }

    /**
     * Applies the rewriter to the predicate of a {@link FilterNode}. Produces no result when the
     * rewritten predicate is equal to the original one.
     */
    private static final class FilterExpressionRewrite
            implements Rule<FilterNode>
    {
        private final ExpressionRewriter rewriter;

        FilterExpressionRewrite(ExpressionRewriter rewriter)
        {
            this.rewriter = requireNonNull(rewriter, "rewriter is null");
        }

        @Override
        public Pattern<FilterNode> getPattern()
        {
            return filter();
        }

        @Override
        public Result apply(FilterNode filterNode, Captures captures, Context context)
        {
            Expression rewritten = rewriter.rewrite(filterNode.getPredicate(), context);
            if (filterNode.getPredicate().equals(rewritten)) {
                return Result.empty();
            }
            return Result.ofPlanNode(new FilterNode(filterNode.getId(), filterNode.getSource(), rewritten));
        }

        @Override
        public String toString()
        {
            return format("%s(%s)", getClass().getSimpleName(), rewriter);
        }
    }

    /**
     * Applies the rewriter to the optional filter of a {@link JoinNode}. The join criteria are
     * not rewritten. Produces no result when the rewritten filter is equal to the original one,
     * which includes the case where the join has no filter.
     */
    private static final class JoinExpressionRewrite
            implements Rule<JoinNode>
    {
        private final ExpressionRewriter rewriter;

        JoinExpressionRewrite(ExpressionRewriter rewriter)
        {
            this.rewriter = requireNonNull(rewriter, "rewriter is null");
        }

        @Override
        public Pattern<JoinNode> getPattern()
        {
            return join();
        }

        @Override
        public Result apply(JoinNode joinNode, Captures captures, Context context)
        {
            Optional<Expression> filter = joinNode.getFilter().map(expression -> rewriter.rewrite(expression, context));
            if (joinNode.getFilter().equals(filter)) {
                return Result.empty();
            }
            // copy every other property of the join unchanged
            return Result.ofPlanNode(new JoinNode(
                    joinNode.getId(),
                    joinNode.getType(),
                    joinNode.getLeft(),
                    joinNode.getRight(),
                    joinNode.getCriteria(),
                    joinNode.getLeftOutputSymbols(),
                    joinNode.getRightOutputSymbols(),
                    joinNode.isMaySkipOutputDuplicates(),
                    filter,
                    joinNode.getLeftHashSymbol(),
                    joinNode.getRightHashSymbol(),
                    joinNode.getDistributionType(),
                    joinNode.isSpillable(),
                    joinNode.getDynamicFilters(),
                    joinNode.getReorderJoinStatsAndCost()));
        }

        @Override
        public String toString()
        {
            return format("%s(%s)", getClass().getSimpleName(), rewriter);
        }
    }

    /**
     * Applies the rewriter to the rows of a {@link ValuesNode}. Rows that are {@link Row}
     * literals are rewritten item by item. Produces no result when the node has no row
     * expressions or when no row changed.
     */
    private static final class ValuesExpressionRewrite
            implements Rule<ValuesNode>
    {
        private final ExpressionRewriter rewriter;

        ValuesExpressionRewrite(ExpressionRewriter rewriter)
        {
            this.rewriter = requireNonNull(rewriter, "rewriter is null");
        }

        @Override
        public Pattern<ValuesNode> getPattern()
        {
            return values();
        }

        @Override
        public Result apply(ValuesNode valuesNode, Captures captures, Context context)
        {
            // getRows() is an Optional; when it is empty there are no row expressions to rewrite
            if (valuesNode.getRows().isEmpty()) {
                return Result.empty();
            }

            boolean anyRewritten = false;
            ImmutableList.Builder<Expression> rows = ImmutableList.builder();
            for (Expression row : valuesNode.getRows().get()) {
                Expression rewritten;
                if (row instanceof Row) {
                    // preserve the structure of row: rewrite each item, not the row as a whole
                    rewritten = new Row(((Row) row).getItems().stream()
                            .map(item -> rewriter.rewrite(item, context))
                            .collect(toImmutableList()));
                }
                else {
                    rewritten = rewriter.rewrite(row, context);
                }
                if (!row.equals(rewritten)) {
                    anyRewritten = true;
                }
                rows.add(rewritten);
            }
            if (!anyRewritten) {
                return Result.empty();
            }
            return Result.ofPlanNode(new ValuesNode(valuesNode.getId(), valuesNode.getOutputSymbols(), rows.build()));
        }

        @Override
        public String toString()
        {
            return format("%s(%s)", getClass().getSimpleName(), rewriter);
        }
    }

    /**
     * Applies the rewriter to the MEASURES and DEFINE expressions of a
     * {@link PatternRecognitionNode}, including the arguments of aggregations referenced through
     * value pointers. Produces no result when nothing changed.
     */
    private static final class PatternRecognitionExpressionRewrite
            implements Rule<PatternRecognitionNode>
    {
        private final ExpressionRewriter rewriter;

        PatternRecognitionExpressionRewrite(ExpressionRewriter rewriter)
        {
            this.rewriter = requireNonNull(rewriter, "rewriter is null");
        }

        @Override
        public Pattern<PatternRecognitionNode> getPattern()
        {
            return patternRecognition();
        }

        @Override
        public Result apply(PatternRecognitionNode node, Captures captures, Context context)
        {
            boolean anyRewritten = false;

            // rewrite MEASURES expressions
            ImmutableMap.Builder<Symbol, Measure> rewrittenMeasures = ImmutableMap.builder();
            for (Map.Entry<Symbol, Measure> entry : node.getMeasures().entrySet()) {
                ExpressionAndValuePointers pointers = entry.getValue().getExpressionAndValuePointers();
                Optional<ExpressionAndValuePointers> newPointers = rewrite(pointers, context);
                if (newPointers.isPresent()) {
                    anyRewritten = true;
                    rewrittenMeasures.put(entry.getKey(), new Measure(newPointers.get(), entry.getValue().getType()));
                }
                else {
                    rewrittenMeasures.put(entry);
                }
            }

            // rewrite DEFINE expressions
            ImmutableMap.Builder<IrLabel, ExpressionAndValuePointers> rewrittenDefinitions = ImmutableMap.builder();
            for (Map.Entry<IrLabel, ExpressionAndValuePointers> entry : node.getVariableDefinitions().entrySet()) {
                Optional<ExpressionAndValuePointers> newPointers = rewrite(entry.getValue(), context);
                if (newPointers.isPresent()) {
                    anyRewritten = true;
                    rewrittenDefinitions.put(entry.getKey(), newPointers.get());
                }
                else {
                    rewrittenDefinitions.put(entry);
                }
            }

            if (!anyRewritten) {
                return Result.empty();
            }

            return Result.ofPlanNode(new PatternRecognitionNode(
                    node.getId(),
                    node.getSource(),
                    node.getSpecification(),
                    node.getHashSymbol(),
                    node.getPrePartitionedInputs(),
                    node.getPreSortedOrderPrefix(),
                    node.getWindowFunctions(),
                    rewrittenMeasures.buildOrThrow(),
                    node.getCommonBaseFrame(),
                    node.getRowsPerMatch(),
                    node.getSkipToLabel(),
                    node.getSkipToPosition(),
                    node.isInitial(),
                    node.getPattern(),
                    node.getSubsets(),
                    rewrittenDefinitions.buildOrThrow()));
        }

        /**
         * Rewrites the top-level expression and the aggregation arguments held by {@code pointers}.
         * When the top-level expression changes, layout entries (with their parallel value
         * pointers), classifier symbols and match number symbols that the new expression no
         * longer references are dropped.
         *
         * @return the rewritten {@link ExpressionAndValuePointers}, or {@link Optional#empty()} if
         *         neither the expression nor any aggregation argument changed
         */
        private Optional<ExpressionAndValuePointers> rewrite(ExpressionAndValuePointers pointers, Context context)
        {
            boolean rewritten = false;

            List<Symbol> newLayout = pointers.getLayout();
            List<ValuePointer> retainedPointers = pointers.getValuePointers();
            Set<Symbol> newClassifierSymbols = pointers.getClassifierSymbols();
            Set<Symbol> newMatchNumberSymbols = pointers.getMatchNumberSymbols();

            // rewrite top-level expression
            Expression newExpression = rewriter.rewrite(pointers.getExpression(), context);
            if (!pointers.getExpression().equals(newExpression)) {
                rewritten = true;
                // prune unused symbols from layout and value pointers;
                // layout.get(i) and getValuePointers().get(i) describe the same input, so they are filtered together
                Set<Symbol> newSymbols = SymbolsExtractor.extractUnique(newExpression);
                List<Symbol> layout = pointers.getLayout();
                ImmutableList.Builder<Symbol> layoutBuilder = ImmutableList.builder();
                ImmutableList.Builder<ValuePointer> prunedPointersBuilder = ImmutableList.builder();
                for (int i = 0; i < layout.size(); i++) {
                    if (newSymbols.contains(layout.get(i))) {
                        layoutBuilder.add(layout.get(i));
                        prunedPointersBuilder.add(pointers.getValuePointers().get(i));
                    }
                }
                newLayout = layoutBuilder.build();
                retainedPointers = prunedPointersBuilder.build();
                newClassifierSymbols = pointers.getClassifierSymbols().stream()
                        .filter(newSymbols::contains)
                        .collect(toImmutableSet());
                newMatchNumberSymbols = pointers.getMatchNumberSymbols().stream()
                        .filter(newSymbols::contains)
                        .collect(toImmutableSet());
            }

            // process all aggregation arguments in remaining value pointers
            ImmutableList.Builder<ValuePointer> newPointers = ImmutableList.builder();
            for (ValuePointer pointer : retainedPointers) {
                if (pointer instanceof ScalarValuePointer) {
                    newPointers.add(pointer);
                    continue;
                }
                AggregationValuePointer aggregationPointer = (AggregationValuePointer) pointer;

                ImmutableList.Builder<Expression> newArguments = ImmutableList.builder();
                for (Expression argument : aggregationPointer.getArguments()) {
                    Expression newArgument = rewriter.rewrite(argument, context);
                    if (!newArgument.equals(argument)) {
                        rewritten = true;
                    }
                    newArguments.add(newArgument);
                }
                newPointers.add(new AggregationValuePointer(
                        aggregationPointer.getFunction(),
                        aggregationPointer.getSetDescriptor(),
                        newArguments.build(),
                        aggregationPointer.getClassifierSymbol(),
                        aggregationPointer.getMatchNumberSymbol()));
            }

            if (!rewritten) {
                return Optional.empty();
            }
            return Optional.of(new ExpressionAndValuePointers(newExpression, newLayout, newPointers.build(), newClassifierSymbols, newMatchNumberSymbols));
        }

        @Override
        public String toString()
        {
            return format("%s(%s)", getClass().getSimpleName(), rewriter);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy