All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.trino.sql.planner.optimizations.PredicatePushDown Maven / Gradle / Ivy

There is a newer version: 465
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.sql.planner.optimizations;

import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Streams;
import io.trino.Session;
import io.trino.cost.TableStatsProvider;
import io.trino.execution.querystats.PlanOptimizersStatsCollector;
import io.trino.execution.warnings.WarningCollector;
import io.trino.metadata.Metadata;
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.sql.planner.DomainTranslator;
import io.trino.sql.planner.EffectivePredicateExtractor;
import io.trino.sql.planner.EqualityInference;
import io.trino.sql.planner.ExpressionInterpreter;
import io.trino.sql.planner.LiteralEncoder;
import io.trino.sql.planner.NoOpSymbolResolver;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.TypeAnalyzer;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.AssignUniqueId;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.DynamicFilterId;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.GroupIdNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.MarkDistinctNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.SampleNode;
import io.trino.sql.planner.plan.SemiJoinNode;
import io.trino.sql.planner.plan.SimplePlanRewriter;
import io.trino.sql.planner.plan.SortNode;
import io.trino.sql.planner.plan.SpatialJoinNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.plan.UnionNode;
import io.trino.sql.planner.plan.UnnestNode;
import io.trino.sql.planner.plan.WindowNode;
import io.trino.sql.tree.BetweenPredicate;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.NodeRef;
import io.trino.sql.tree.NotExpression;
import io.trino.sql.tree.NullLiteral;
import io.trino.sql.tree.SymbolReference;
import io.trino.sql.tree.TryExpression;
import io.trino.sql.util.AstUtils;

import java.util.ArrayList;
import java.util.Collection;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.trino.SystemSessionProperties.isEnableDynamicFiltering;
import static io.trino.SystemSessionProperties.isPredicatePushdownUseTableProperties;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.RealType.REAL;
import static io.trino.sql.DynamicFilters.createDynamicFilterExpression;
import static io.trino.sql.ExpressionUtils.combineConjuncts;
import static io.trino.sql.ExpressionUtils.extractConjuncts;
import static io.trino.sql.ExpressionUtils.filterDeterministicConjuncts;
import static io.trino.sql.ExpressionUtils.isEffectivelyLiteral;
import static io.trino.sql.planner.DeterminismEvaluator.isDeterministic;
import static io.trino.sql.planner.ExpressionSymbolInliner.inlineSymbols;
import static io.trino.sql.planner.SymbolsExtractor.extractUnique;
import static io.trino.sql.planner.iterative.rule.CanonicalizeExpressionRewriter.canonicalizeExpression;
import static io.trino.sql.planner.iterative.rule.UnwrapCastInComparison.unwrapCasts;
import static io.trino.sql.planner.plan.JoinNode.Type.FULL;
import static io.trino.sql.planner.plan.JoinNode.Type.INNER;
import static io.trino.sql.planner.plan.JoinNode.Type.LEFT;
import static io.trino.sql.planner.plan.JoinNode.Type.RIGHT;
import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL;
import static io.trino.sql.tree.ComparisonExpression.Operator.EQUAL;
import static io.trino.sql.tree.ComparisonExpression.Operator.GREATER_THAN;
import static io.trino.sql.tree.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL;
import static io.trino.sql.tree.ComparisonExpression.Operator.IS_DISTINCT_FROM;
import static io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN;
import static io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL;
import static java.util.Objects.requireNonNull;

public class PredicatePushDown
        implements PlanOptimizer
{
    // Comparison operators listed as supported for dynamic filtering (per the constant's name;
    // the consuming code is outside this part of the file). Typed so membership checks are type-safe.
    private static final Set<ComparisonExpression.Operator> DYNAMIC_FILTERING_SUPPORTED_COMPARISONS = ImmutableSet.of(
            EQUAL,
            GREATER_THAN,
            GREATER_THAN_OR_EQUAL,
            LESS_THAN,
            LESS_THAN_OR_EQUAL);

    // Collaborators and flags captured by the constructor; optimize() forwards all four to the Rewriter it creates.
    private final PlannerContext plannerContext;
    private final TypeAnalyzer typeAnalyzer;
    private final boolean useTableProperties;
    private final boolean dynamicFiltering;

    /**
     * Creates the optimizer and stores the collaborators later handed to the plan rewriter.
     *
     * @param plannerContext planner context; must not be null
     * @param typeAnalyzer type analyzer; must not be null
     * @param useTableProperties flag forwarded to the rewriter by {@code optimize}
     * @param dynamicFiltering flag forwarded to the rewriter by {@code optimize}
     * @throws NullPointerException if {@code plannerContext} or {@code typeAnalyzer} is null
     */
    public PredicatePushDown(
            PlannerContext plannerContext,
            TypeAnalyzer typeAnalyzer,
            boolean useTableProperties,
            boolean dynamicFiltering)
    {
        // Validate first (same order and messages as the assignments would produce), then assign
        requireNonNull(plannerContext, "plannerContext is null");
        requireNonNull(typeAnalyzer, "typeAnalyzer is null");

        this.plannerContext = plannerContext;
        this.typeAnalyzer = typeAnalyzer;
        this.useTableProperties = useTableProperties;
        this.dynamicFiltering = dynamicFiltering;
    }

    /**
     * Rewrites {@code plan} with a fresh {@code Rewriter}, starting from the trivial predicate {@code TRUE_LITERAL}.
     *
     * @throws NullPointerException if {@code plan}, {@code session}, {@code types} or {@code idAllocator} is null
     */
    @Override
    public PlanNode optimize(
            PlanNode plan,
            Session session,
            TypeProvider types,
            SymbolAllocator symbolAllocator,
            PlanNodeIdAllocator idAllocator,
            WarningCollector warningCollector,
            PlanOptimizersStatsCollector planOptimizersStatsCollector,
            TableStatsProvider tableStatsProvider)
    {
        requireNonNull(plan, "plan is null");
        requireNonNull(session, "session is null");
        requireNonNull(types, "types is null");
        requireNonNull(idAllocator, "idAllocator is null");

        // The rewriter carries all per-query state; the root of the plan inherits no predicate
        Rewriter rewriter = new Rewriter(
                symbolAllocator,
                idAllocator,
                plannerContext,
                typeAnalyzer,
                session,
                types,
                useTableProperties,
                dynamicFiltering);
        return SimplePlanRewriter.rewriteWith(rewriter, plan, TRUE_LITERAL);
    }

    private static class Rewriter
            extends SimplePlanRewriter
    {
        // Per-rewrite state; every field below is assigned exactly once in the constructor.
        private final SymbolAllocator symbolAllocator;
        private final PlanNodeIdAllocator idAllocator;
        private final PlannerContext plannerContext;
        // Obtained from plannerContext.getMetadata() in the constructor
        private final Metadata metadata;
        private final TypeAnalyzer typeAnalyzer;
        private final Session session;
        private final TypeProvider types;
        private final ExpressionEquivalence expressionEquivalence;
        private final boolean dynamicFiltering;
        private final LiteralEncoder literalEncoder;
        private final EffectivePredicateExtractor effectivePredicateExtractor;

        private Rewriter(
                SymbolAllocator symbolAllocator,
                PlanNodeIdAllocator idAllocator,
                PlannerContext plannerContext,
                TypeAnalyzer typeAnalyzer,
                Session session,
                TypeProvider types,
                boolean useTableProperties,
                boolean dynamicFiltering)
        {
            this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null");
            this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
            this.plannerContext = requireNonNull(plannerContext, "plannerContext is null");
            this.metadata = plannerContext.getMetadata();
            this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null");
            this.session = requireNonNull(session, "session is null");
            this.types = requireNonNull(types, "types is null");
            this.expressionEquivalence = new ExpressionEquivalence(plannerContext.getMetadata(), plannerContext.getFunctionManager(), typeAnalyzer);
            this.dynamicFiltering = dynamicFiltering;

            this.effectivePredicateExtractor = new EffectivePredicateExtractor(
                    new DomainTranslator(plannerContext),
                    plannerContext,
                    useTableProperties && isPredicatePushdownUseTableProperties(session));
            this.literalEncoder = new LiteralEncoder(plannerContext);
        }

        /**
         * Fallback for node types without a dedicated visit method: children are rewritten with
         * {@code TRUE_LITERAL} as their inherited predicate, and the predicate inherited by this node
         * (unless it is {@code TRUE_LITERAL}) is applied in a new FilterNode on top of the result.
         */
        @Override
        public PlanNode visitPlan(PlanNode node, RewriteContext context)
        {
            PlanNode rewrittenChildren = context.defaultRewrite(node, TRUE_LITERAL);
            if (context.get().equals(TRUE_LITERAL)) {
                // Nothing was inherited, so no filter is needed
                return rewrittenChildren;
            }
            // The predicate cannot be pushed any further; keep it directly above this node
            return new FilterNode(idAllocator.getNextId(), rewrittenChildren, context.get());
        }

        /**
         * Pushes the inherited predicate into every source of the exchange.
         * For each source, the predicate is rewritten by replacing the exchange's output symbols with
         * references to that source's input symbols (matched by position). A new ExchangeNode is created
         * only when at least one source was returned as a different instance; otherwise {@code node} is returned.
         */
        @Override
        public PlanNode visitExchange(ExchangeNode node, RewriteContext context)
        {
            boolean modified = false;
            ImmutableList.Builder<PlanNode> builder = ImmutableList.builder();
            for (int i = 0; i < node.getSources().size(); i++) {
                // Positional mapping: output symbol at `index` -> reference to source i's input symbol at `index`
                Map<Symbol, SymbolReference> outputsToInputs = new HashMap<>();
                for (int index = 0; index < node.getInputs().get(i).size(); index++) {
                    outputsToInputs.put(
                            node.getOutputSymbols().get(index),
                            node.getInputs().get(i).get(index).toSymbolReference());
                }

                Expression sourcePredicate = inlineSymbols(outputsToInputs, context.get());
                PlanNode source = node.getSources().get(i);
                PlanNode rewrittenSource = context.rewrite(source, sourcePredicate);
                // Identity comparison: an unchanged subtree comes back as the same instance
                if (rewrittenSource != source) {
                    modified = true;
                }
                builder.add(rewrittenSource);
            }

            if (modified) {
                return new ExchangeNode(
                        node.getId(),
                        node.getType(),
                        node.getScope(),
                        node.getPartitioningScheme(),
                        builder.build(),
                        node.getInputs(),
                        node.getOrderingScheme());
            }

            return node;
        }

        /**
         * Pushes below the window node only those inherited conjuncts that are deterministic and reference
         * nothing but partitioning symbols; all remaining conjuncts are kept in a FilterNode above it.
         */
        @Override
        public PlanNode visitWindow(WindowNode node, RewriteContext context)
        {
            List<Symbol> partitionSymbols = node.getPartitionBy();

            // TODO: This could be broader. We can push down conjuncts if they are constant for all rows in a window partition.
            // The simplest way to guarantee this is if the conjuncts are deterministic functions of the partitioning symbols.
            // This can leave out cases where they're both functions of some set of common expressions and the partitioning
            // function is injective, but that's a rare case. The majority of window nodes are expected to be partitioned by
            // pre-projected symbols.
            Predicate<Expression> isSupported = conjunct ->
                    isDeterministic(conjunct, metadata) &&
                            partitionSymbols.containsAll(extractUnique(conjunct));

            // true -> can be pushed below the window, false -> must stay above it
            Map<Boolean, List<Expression>> conjuncts = extractConjuncts(context.get()).stream().collect(Collectors.partitioningBy(isSupported));

            PlanNode rewrittenNode = context.defaultRewrite(node, combineConjuncts(metadata, conjuncts.get(true)));

            if (!conjuncts.get(false).isEmpty()) {
                rewrittenNode = new FilterNode(idAllocator.getNextId(), rewrittenNode, combineConjuncts(metadata, conjuncts.get(false)));
            }

            return rewrittenNode;
        }

        /**
         * Pushes inherited conjuncts through a projection.
         * A conjunct is pushed when every symbol it references is produced by a deterministic assignment of
         * this node and it passes {@code isInliningCandidate}; such conjuncts have the assignments inlined,
         * are canonicalized and have casts unwrapped before being handed to the child. Everything else is
         * kept in a FilterNode above the rewritten projection.
         */
        @Override
        public PlanNode visitProject(ProjectNode node, RewriteContext context)
        {
            Set<Symbol> deterministicSymbols = node.getAssignments().entrySet().stream()
                    .filter(entry -> isDeterministic(entry.getValue(), metadata))
                    .map(Map.Entry::getKey)
                    .collect(Collectors.toSet());

            Predicate<Expression> deterministic = conjunct -> deterministicSymbols.containsAll(extractUnique(conjunct));

            Map<Boolean, List<Expression>> conjuncts = extractConjuncts(context.get()).stream().collect(Collectors.partitioningBy(deterministic));

            // Push down conjuncts from the inherited predicate that only depend on deterministic assignments with
            // certain limitations.
            List<Expression> deterministicConjuncts = conjuncts.get(true);

            // We partition the expressions in the deterministicConjuncts into two lists, and only inline the
            // expressions that are in the inlining targets list.
            Map<Boolean, List<Expression>> inlineConjuncts = deterministicConjuncts.stream()
                    .collect(Collectors.partitioningBy(expression -> isInliningCandidate(expression, node)));

            List<Expression> inlinedDeterministicConjuncts = inlineConjuncts.get(true).stream()
                    .map(entry -> inlineSymbols(node.getAssignments().getMap(), entry))
                    .map(conjunct -> canonicalizeExpression(conjunct, typeAnalyzer.getTypes(session, types, conjunct), plannerContext, session)) // normalize expressions to a form that unwrapCasts understands
                    .map(conjunct -> unwrapCasts(session, plannerContext, typeAnalyzer, types, conjunct))
                    .collect(Collectors.toList());

            PlanNode rewrittenNode = context.defaultRewrite(node, combineConjuncts(metadata, inlinedDeterministicConjuncts));

            // All deterministic conjuncts that contains non-inlining targets, and non-deterministic conjuncts,
            // if any, will be in the filter node.
            // Copy before appending: the lists produced by Collectors.partitioningBy carry no mutability guarantee.
            List<Expression> nonInliningConjuncts = new ArrayList<>(inlineConjuncts.get(false));
            nonInliningConjuncts.addAll(conjuncts.get(false));

            if (!nonInliningConjuncts.isEmpty()) {
                rewrittenNode = new FilterNode(idAllocator.getNextId(), rewrittenNode, combineConjuncts(metadata, nonInliningConjuncts));
            }

            return rewrittenNode;
        }

        /**
         * Decides whether {@code expression} may have the projection's assignments inlined into it.
         *
         * @param expression conjunct to examine; must not contain a TryExpression (verified)
         * @param node projection whose output symbols and assignments are consulted
         * @return true when every output symbol of {@code node} referenced by the expression is referenced
         *         exactly once, or is assigned an effectively-literal expression or a plain symbol reference
         */
        private boolean isInliningCandidate(Expression expression, ProjectNode node)
        {
            // TryExpressions should not be pushed down. However they are now being handled as lambda
            // passed to a FunctionCall now and should not affect predicate push down. So we want to make
            // sure the conjuncts are not TryExpressions.
            verify(AstUtils.preOrder(expression).noneMatch(TryExpression.class::isInstance));

            // candidate symbols for inlining are
            //   1. references to simple constants or symbol references
            //   2. references to complex expressions that appear only once
            // which come from the node, as opposed to an enclosing scope.
            Set<Symbol> childOutputSet = ImmutableSet.copyOf(node.getOutputSymbols());
            // Number of occurrences of each of the node's output symbols within the expression
            Map<Symbol, Long> dependencies = SymbolsExtractor.extractAll(expression).stream()
                    .filter(childOutputSet::contains)
                    .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));

            return dependencies.entrySet().stream()
                    .allMatch(entry -> entry.getValue() == 1
                            || isEffectivelyLiteral(plannerContext, session, node.getAssignments().get(entry.getKey()))
                            || node.getAssignments().get(entry.getKey()) instanceof SymbolReference);
        }

        /**
         * Pushes below the GroupIdNode the inherited conjuncts that reference only common grouping columns,
         * after rewriting those symbols to the symbols they are mapped to by {@code getGroupingColumns()}.
         * All other conjuncts are kept in a FilterNode above the rewritten node.
         */
        @Override
        public PlanNode visitGroupId(GroupIdNode node, RewriteContext context)
        {
            // Grouping-column entries restricted to the common grouping columns, with values as symbol references
            Map<Symbol, SymbolReference> commonGroupingSymbolMapping = node.getGroupingColumns().entrySet().stream()
                    .filter(entry -> node.getCommonGroupingColumns().contains(entry.getKey()))
                    .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference()));

            Predicate<Expression> pushdownEligiblePredicate = conjunct -> commonGroupingSymbolMapping.keySet().containsAll(extractUnique(conjunct));

            Map<Boolean, List<Expression>> conjuncts = extractConjuncts(context.get()).stream().collect(Collectors.partitioningBy(pushdownEligiblePredicate));

            // Push down conjuncts from the inherited predicate that apply to common grouping symbols
            PlanNode rewrittenNode = context.defaultRewrite(node, inlineSymbols(commonGroupingSymbolMapping, combineConjuncts(metadata, conjuncts.get(true))));

            // All other conjuncts, if any, will be in the filter node.
            if (!conjuncts.get(false).isEmpty()) {
                rewrittenNode = new FilterNode(idAllocator.getNextId(), rewrittenNode, combineConjuncts(metadata, conjuncts.get(false)));
            }

            return rewrittenNode;
        }

        /**
         * Pushes below the MarkDistinctNode the inherited conjuncts that reference only its distinct symbols;
         * the remaining conjuncts are kept in a FilterNode above the rewritten node.
         */
        @Override
        public PlanNode visitMarkDistinct(MarkDistinctNode node, RewriteContext context)
        {
            Set<Symbol> pushDownableSymbols = ImmutableSet.copyOf(node.getDistinctSymbols());
            // true -> references only distinct symbols (pushed down), false -> stays above
            Map<Boolean, List<Expression>> conjuncts = extractConjuncts(context.get()).stream()
                    .collect(Collectors.partitioningBy(conjunct -> pushDownableSymbols.containsAll(extractUnique(conjunct))));

            PlanNode rewrittenNode = context.defaultRewrite(node, combineConjuncts(metadata, conjuncts.get(true)));

            if (!conjuncts.get(false).isEmpty()) {
                rewrittenNode = new FilterNode(idAllocator.getNextId(), rewrittenNode, combineConjuncts(metadata, conjuncts.get(false)));
            }
            return rewrittenNode;
        }

        /**
         * Passes the inherited predicate through the sort node unchanged to its children.
         */
        @Override
        public PlanNode visitSort(SortNode node, RewriteContext context)
        {
            // The whole inherited predicate is handed to the default child rewrite as-is
            PlanNode rewrittenSort = context.defaultRewrite(node, context.get());
            return rewrittenSort;
        }

        /**
         * Pushes the inherited predicate into every source of the union, rewriting it per source with
         * {@code node.sourceSymbolMap(i)}. A new UnionNode is created only when at least one source was
         * returned as a different instance; otherwise {@code node} is returned.
         */
        @Override
        public PlanNode visitUnion(UnionNode node, RewriteContext context)
        {
            boolean modified = false;
            ImmutableList.Builder<PlanNode> builder = ImmutableList.builder();
            for (int i = 0; i < node.getSources().size(); i++) {
                Expression sourcePredicate = inlineSymbols(node.sourceSymbolMap(i), context.get());
                PlanNode source = node.getSources().get(i);
                PlanNode rewrittenSource = context.rewrite(source, sourcePredicate);
                // Identity comparison: an unchanged subtree comes back as the same instance
                if (rewrittenSource != source) {
                    modified = true;
                }
                builder.add(rewrittenSource);
            }

            if (modified) {
                return new UnionNode(node.getId(), builder.build(), node.getSymbolMapping(), node.getOutputSymbols());
            }

            return node;
        }

        /**
         * Merges this filter's predicate with the inherited predicate and pushes the conjunction into the source.
         * The original {@code node} is returned only when the rewrite produced a FilterNode whose predicate is
         * equivalent (per {@code areExpressionsEquivalent}) to this node's predicate and whose source is the very
         * same instance as this node's source; in every other case the rewritten plan is returned.
         */
        @Override
        public PlanNode visitFilter(FilterNode node, RewriteContext context)
        {
            PlanNode rewrittenPlan = context.rewrite(node.getSource(), combineConjuncts(metadata, node.getPredicate(), context.get()));
            // The predicate was pushed down entirely: the filter disappears
            if (!(rewrittenPlan instanceof FilterNode rewrittenFilterNode)) {
                return rewrittenPlan;
            }

            // A filter remains, but either its predicate or its source differs from the original
            if (!areExpressionsEquivalent(rewrittenFilterNode.getPredicate(), node.getPredicate())
                    || node.getSource() != rewrittenFilterNode.getSource()) {
                return rewrittenPlan;
            }

            // Nothing effectively changed; keep the original instance so callers can detect "unmodified"
            return node;
        }

        /**
         * Pushes predicates through a join.
         * <p>
         * Steps, in order:
         * <ol>
         * <li>attempt to normalize an outer join to an inner join using the inherited predicate;</li>
         * <li>split the inherited, effective and join predicates into left, right, post-join and join parts
         *     according to the join type (a FULL join pushes nothing to either side);</li>
         * <li>turn join conjuncts accepted by {@code joinEqualityExpression} into equi-join clauses, adding
         *     projections for expressions that are not already output symbols of the respective side;</li>
         * <li>derive dynamic filters and add their predicates to the left side;</li>
         * <li>rewrite both sources and, if anything changed, build a new JoinNode;</li>
         * <li>apply the post-join predicate in a FilterNode and restore the original output symbols
         *     with an identity projection if they differ.</li>
         * </ol>
         */
        @Override
        public PlanNode visitJoin(JoinNode node, RewriteContext context)
        {
            Expression inheritedPredicate = context.get();

            // See if we can rewrite outer joins in terms of a plain inner join
            node = tryNormalizeToOuterToInnerJoin(node, inheritedPredicate);

            Expression leftEffectivePredicate = effectivePredicateExtractor.extract(session, node.getLeft(), types, typeAnalyzer);
            Expression rightEffectivePredicate = effectivePredicateExtractor.extract(session, node.getRight(), types, typeAnalyzer);
            Expression joinPredicate = extractJoinPredicate(node);

            Expression leftPredicate;
            Expression rightPredicate;
            Expression postJoinPredicate;
            Expression newJoinPredicate;

            switch (node.getType()) {
                case INNER -> {
                    InnerJoinPushDownResult innerJoinPushDownResult = processInnerJoin(
                            inheritedPredicate,
                            leftEffectivePredicate,
                            rightEffectivePredicate,
                            joinPredicate,
                            node.getLeft().getOutputSymbols(),
                            node.getRight().getOutputSymbols());
                    leftPredicate = innerJoinPushDownResult.getLeftPredicate();
                    rightPredicate = innerJoinPushDownResult.getRightPredicate();
                    postJoinPredicate = innerJoinPushDownResult.getPostJoinPredicate();
                    newJoinPredicate = innerJoinPushDownResult.getJoinPredicate();
                }
                case LEFT -> {
                    // For a LEFT join the left side is the outer side and the right side is the inner side
                    OuterJoinPushDownResult leftOuterJoinPushDownResult = processLimitedOuterJoin(
                            inheritedPredicate,
                            leftEffectivePredicate,
                            rightEffectivePredicate,
                            joinPredicate,
                            node.getLeft().getOutputSymbols(),
                            node.getRight().getOutputSymbols());
                    leftPredicate = leftOuterJoinPushDownResult.getOuterJoinPredicate();
                    rightPredicate = leftOuterJoinPushDownResult.getInnerJoinPredicate();
                    postJoinPredicate = leftOuterJoinPushDownResult.getPostJoinPredicate();
                    newJoinPredicate = leftOuterJoinPushDownResult.getJoinPredicate();
                }
                case RIGHT -> {
                    // Mirror image of LEFT: arguments and result accessors are swapped
                    OuterJoinPushDownResult rightOuterJoinPushDownResult = processLimitedOuterJoin(
                            inheritedPredicate,
                            rightEffectivePredicate,
                            leftEffectivePredicate,
                            joinPredicate,
                            node.getRight().getOutputSymbols(),
                            node.getLeft().getOutputSymbols());
                    leftPredicate = rightOuterJoinPushDownResult.getInnerJoinPredicate();
                    rightPredicate = rightOuterJoinPushDownResult.getOuterJoinPredicate();
                    postJoinPredicate = rightOuterJoinPushDownResult.getPostJoinPredicate();
                    newJoinPredicate = rightOuterJoinPushDownResult.getJoinPredicate();
                }
                case FULL -> {
                    // Nothing is pushed to either side; the inherited predicate stays above the join
                    leftPredicate = TRUE_LITERAL;
                    rightPredicate = TRUE_LITERAL;
                    postJoinPredicate = inheritedPredicate;
                    newJoinPredicate = joinPredicate;
                }
                default -> throw new UnsupportedOperationException("Unsupported join type: " + node.getType());
            }

            newJoinPredicate = simplifyExpression(newJoinPredicate);

            // Create identity projections for all existing symbols
            Assignments.Builder leftProjections = Assignments.builder();
            leftProjections.putAll(node.getLeft()
                    .getOutputSymbols().stream()
                    .collect(toImmutableMap(key -> key, Symbol::toSymbolReference)));

            Assignments.Builder rightProjections = Assignments.builder();
            rightProjections.putAll(node.getRight()
                    .getOutputSymbols().stream()
                    .collect(toImmutableMap(key -> key, Symbol::toSymbolReference)));

            // Create new projections for the new join clauses
            List<JoinNode.EquiJoinClause> equiJoinClauses = new ArrayList<>();
            ImmutableList.Builder<Expression> joinFilterBuilder = ImmutableList.builder();
            for (Expression conjunct : extractConjuncts(newJoinPredicate)) {
                if (joinEqualityExpression(conjunct, node.getLeft().getOutputSymbols(), node.getRight().getOutputSymbols())) {
                    ComparisonExpression equality = (ComparisonExpression) conjunct;

                    // "Aligned" means the comparison's left operand is computed from the join's left side
                    boolean alignedComparison = node.getLeft().getOutputSymbols().containsAll(extractUnique(equality.getLeft()));
                    Expression leftExpression = (alignedComparison) ? equality.getLeft() : equality.getRight();
                    Expression rightExpression = (alignedComparison) ? equality.getRight() : equality.getLeft();

                    Symbol leftSymbol = symbolForExpression(leftExpression);
                    if (!node.getLeft().getOutputSymbols().contains(leftSymbol)) {
                        leftProjections.put(leftSymbol, leftExpression);
                    }

                    Symbol rightSymbol = symbolForExpression(rightExpression);
                    if (!node.getRight().getOutputSymbols().contains(rightSymbol)) {
                        rightProjections.put(rightSymbol, rightExpression);
                    }

                    equiJoinClauses.add(new JoinNode.EquiJoinClause(leftSymbol, rightSymbol));
                }
                else {
                    joinFilterBuilder.add(conjunct);
                }
            }

            List<Expression> joinFilter = joinFilterBuilder.build();
            DynamicFiltersResult dynamicFiltersResult = createDynamicFilters(node, equiJoinClauses, joinFilter, session, idAllocator);
            Map<DynamicFilterId, Symbol> dynamicFilters = dynamicFiltersResult.getDynamicFilters();
            // Dynamic filter predicates are pushed down the left side together with the left predicate
            leftPredicate = combineConjuncts(metadata, leftPredicate, combineConjuncts(metadata, dynamicFiltersResult.getPredicates()));

            PlanNode leftSource;
            PlanNode rightSource;
            // Order-insensitive comparison of the new clauses against the node's existing criteria
            boolean equiJoinClausesUnmodified = ImmutableSet.copyOf(equiJoinClauses).equals(ImmutableSet.copyOf(node.getCriteria()));
            if (!equiJoinClausesUnmodified) {
                leftSource = context.rewrite(new ProjectNode(idAllocator.getNextId(), node.getLeft(), leftProjections.build()), leftPredicate);
                rightSource = context.rewrite(new ProjectNode(idAllocator.getNextId(), node.getRight(), rightProjections.build()), rightPredicate);
            }
            else {
                leftSource = context.rewrite(node.getLeft(), leftPredicate);
                rightSource = context.rewrite(node.getRight(), rightPredicate);
            }

            Optional<Expression> newJoinFilter = Optional.of(combineConjuncts(metadata, joinFilter));
            if (newJoinFilter.get().equals(TRUE_LITERAL)) {
                newJoinFilter = Optional.empty();
            }

            if (node.getType() == INNER && newJoinFilter.isPresent() && equiJoinClauses.isEmpty()) {
                // if we do not have any equi conjunct we do not pushdown non-equality condition into
                // inner join, so we plan execution as nested-loops-join followed by filter instead
                // hash join.
                // todo: remove the code when we have support for filter function in nested loop join
                postJoinPredicate = combineConjuncts(metadata, postJoinPredicate, newJoinFilter.get());
                newJoinFilter = Optional.empty();
            }

            boolean filtersEquivalent =
                    newJoinFilter.isPresent() == node.getFilter().isPresent() &&
                            (newJoinFilter.isEmpty() || areExpressionsEquivalent(newJoinFilter.get(), node.getFilter().get()));

            // Rebuild the join only if a source, the filter, the dynamic filters or the criteria changed
            PlanNode output = node;
            if (leftSource != node.getLeft() ||
                    rightSource != node.getRight() ||
                    !filtersEquivalent ||
                    !dynamicFilters.equals(node.getDynamicFilters()) ||
                    !equiJoinClausesUnmodified) {
                leftSource = new ProjectNode(idAllocator.getNextId(), leftSource, leftProjections.build());
                rightSource = new ProjectNode(idAllocator.getNextId(), rightSource, rightProjections.build());

                output = new JoinNode(
                        node.getId(),
                        node.getType(),
                        leftSource,
                        rightSource,
                        equiJoinClauses,
                        leftSource.getOutputSymbols(),
                        rightSource.getOutputSymbols(),
                        node.isMaySkipOutputDuplicates(),
                        newJoinFilter,
                        node.getLeftHashSymbol(),
                        node.getRightHashSymbol(),
                        node.getDistributionType(),
                        node.isSpillable(),
                        dynamicFilters,
                        node.getReorderJoinStatsAndCost());
            }

            if (!postJoinPredicate.equals(TRUE_LITERAL)) {
                output = new FilterNode(idAllocator.getNextId(), output, postJoinPredicate);
            }

            // Restore the original output symbols if the rewritten plan exposes a different list
            if (!node.getOutputSymbols().equals(output.getOutputSymbols())) {
                output = new ProjectNode(idAllocator.getNextId(), output, Assignments.identity(node.getOutputSymbols()));
            }

            return output;
        }

        // TODO: collect min/max ranges for inequality dynamic filters (https://github.com/trinodb/trino/issues/5754)
        // TODO: support for complex inequalities, e.g. left < right + 10 (https://github.com/trinodb/trino/issues/5755)
        /**
         * Derives the dynamic filters of a join from its equi-join clauses and join filter clauses.
         * <p>
         * An empty result is returned unless the join is INNER or RIGHT, dynamic filtering is enabled
         * for the session and the {@code dynamicFiltering} flag is set.
         * Each clause is normalized so that the left operand of the comparison refers to the left source
         * of the join and the right operand is the build symbol.
         *
         * @param node join node; its existing dynamic filter IDs are reused for matching build symbols
         * @param equiJoinClauses equality clauses of the join
         * @param joinFilterClauses join filter conjuncts; only those accepted by
         * {@code joinDynamicFilteringExpression} contribute a dynamic filter
         * @param session session used to check whether dynamic filtering is enabled
         * @param idAllocator source of IDs for newly created dynamic filters
         * @return mapping from dynamic filter ID to build symbol, plus the dynamic filter predicates
         * created for the probe expressions
         */
        private DynamicFiltersResult createDynamicFilters(
                JoinNode node,
                List<JoinNode.EquiJoinClause> equiJoinClauses,
                List<Expression> joinFilterClauses,
                Session session,
                PlanNodeIdAllocator idAllocator)
        {
            if ((node.getType() != INNER && node.getType() != RIGHT) || !isEnableDynamicFiltering(session) || !dynamicFiltering) {
                return new DynamicFiltersResult(ImmutableMap.of(), ImmutableList.of());
            }

            List<DynamicFilterExpression> clauses = Streams.concat(
                            equiJoinClauses
                                    .stream()
                                    .map(clause -> new DynamicFilterExpression(
                                            new ComparisonExpression(EQUAL, clause.getLeft().toSymbolReference(), clause.getRight().toSymbolReference()))),
                            joinFilterClauses.stream()
                                    .flatMap(Rewriter::tryConvertBetweenIntoComparisons)
                                    .filter(clause -> joinDynamicFilteringExpression(clause, node.getLeft().getOutputSymbols(), node.getRight().getOutputSymbols()))
                                    .map(expression -> {
                                        // NOT (a IS DISTINCT FROM b) is treated as an equality that also allows nulls
                                        if (expression instanceof NotExpression notExpression) {
                                            ComparisonExpression comparison = (ComparisonExpression) notExpression.getValue();
                                            return new DynamicFilterExpression(new ComparisonExpression(EQUAL, comparison.getLeft(), comparison.getRight()), true);
                                        }
                                        return new DynamicFilterExpression((ComparisonExpression) expression);
                                    })
                                    .map(expression -> {
                                        // Flip the comparison when its left operand does not come from the left source of the join
                                        ComparisonExpression comparison = expression.getComparison();
                                        Expression leftExpression = comparison.getLeft();
                                        Expression rightExpression = comparison.getRight();
                                        boolean alignedComparison = node.getLeft().getOutputSymbols().containsAll(extractUnique(leftExpression));
                                        return new DynamicFilterExpression(
                                                new ComparisonExpression(
                                                        alignedComparison ? comparison.getOperator() : comparison.getOperator().flip(),
                                                        alignedComparison ? leftExpression : rightExpression,
                                                        alignedComparison ? rightExpression : leftExpression),
                                                expression.isNullAllowed());
                                    }))
                    .collect(toImmutableList());

            // New equiJoinClauses could potentially not contain symbols used in current dynamic filters.
            // Since we use PredicatePushdown to push dynamic filters themselves,
            // instead of separate ApplyDynamicFilters rule we derive dynamic filters within PredicatePushdown itself.
            // Even if equiJoinClauses.equals(node.getCriteria), current dynamic filters may not match equiJoinClauses

            // Collect build symbols:
            Set<Symbol> buildSymbols = clauses.stream()
                    .map(DynamicFilterExpression::getComparison)
                    .map(ComparisonExpression::getRight)
                    .map(Symbol::from)
                    .collect(toImmutableSet());

            // Allocate new dynamic filter IDs for each build symbol that does not have one yet:
            BiMap<Symbol, DynamicFilterId> buildSymbolToDynamicFilter = HashBiMap.create(node.getDynamicFilters()).inverse();
            for (Symbol buildSymbol : buildSymbols) {
                buildSymbolToDynamicFilter.computeIfAbsent(
                        buildSymbol,
                        key -> new DynamicFilterId("df_" + idAllocator.getNextId().toString()));
            }

            // Multiple probe symbols may depend on a single build symbol / dynamic filter ID:
            List<Expression> predicates = clauses
                    .stream()
                    .map(clause -> {
                        ComparisonExpression comparison = clause.getComparison();
                        Expression probeExpression = comparison.getLeft();
                        Symbol buildSymbol = Symbol.from(comparison.getRight());
                        // we can take type of buildSymbol instead probeExpression as comparison expression must have the same type on both sides
                        Type type = symbolAllocator.getTypes().get(buildSymbol);
                        DynamicFilterId id = requireNonNull(buildSymbolToDynamicFilter.get(buildSymbol), () -> "missing dynamic filter for symbol " + buildSymbol);
                        return createDynamicFilterExpression(metadata, id, type, probeExpression, comparison.getOperator(), clause.isNullAllowed());
                    })
                    .collect(toImmutableList());
            // Return a mapping from dynamic filter IDs to the corresponding build symbols:
            return new DynamicFiltersResult(buildSymbolToDynamicFilter.inverse(), predicates);
        }

        /**
         * Expands a BETWEEN predicate into the pair of comparisons {@code value >= min} and
         * {@code value <= max}. Any other expression is returned unchanged as a single-element stream.
         *
         * @param clause expression to expand
         * @return the two range comparisons for a BETWEEN predicate, otherwise just {@code clause}
         */
        private static Stream<Expression> tryConvertBetweenIntoComparisons(Expression clause)
        {
            if (clause instanceof BetweenPredicate between) {
                return Stream.of(
                        new ComparisonExpression(GREATER_THAN_OR_EQUAL, between.getValue(), between.getMin()),
                        new ComparisonExpression(LESS_THAN_OR_EQUAL, between.getValue(), between.getMax()));
            }
            return Stream.of(clause);
        }

        /**
         * A comparison considered for dynamic filtering, together with a flag recording
         * whether nulls are allowed for it.
         */
        private static class DynamicFilterExpression
        {
            private final ComparisonExpression comparison;
            private final boolean nullAllowed;

            // Creates an expression for which nulls are not allowed
            private DynamicFilterExpression(ComparisonExpression comparison)
            {
                this.comparison = requireNonNull(comparison, "comparison is null");
                this.nullAllowed = false;
            }

            private DynamicFilterExpression(ComparisonExpression comparison, boolean nullAllowed)
            {
                this.comparison = requireNonNull(comparison, "comparison is null");
                this.nullAllowed = nullAllowed;
            }

            public boolean isNullAllowed()
            {
                return nullAllowed;
            }

            public ComparisonExpression getComparison()
            {
                return comparison;
            }
        }

        /**
         * Immutable holder for the outcome of dynamic filter derivation: the dynamic filters keyed by
         * their ID and the predicates generated for them. Both collections are defensively copied.
         */
        private static class DynamicFiltersResult
        {
            private final Map<DynamicFilterId, Symbol> dynamicFilters;
            private final List<Expression> predicates;

            public DynamicFiltersResult(Map<DynamicFilterId, Symbol> dynamicFilters, List<Expression> predicates)
            {
                this.dynamicFilters = ImmutableMap.copyOf(dynamicFilters);
                this.predicates = ImmutableList.copyOf(predicates);
            }

            public Map<DynamicFilterId, Symbol> getDynamicFilters()
            {
                return dynamicFilters;
            }

            public List<Expression> getPredicates()
            {
                return predicates;
            }
        }

        /**
         * Pushes the inherited predicate through a spatial join.
         * <p>
         * A LEFT spatial join is first turned into an INNER one when the inherited predicate allows it.
         * The predicate is then split into parts pushed to the left source, the right source, the join
         * filter and a filter placed above the join. The join node is only rebuilt when a source or the
         * join filter actually changed.
         *
         * @param node spatial join to rewrite; only INNER and LEFT types are supported
         * @param context rewrite context carrying the inherited predicate
         * @return the rewritten plan, possibly wrapped in a {@link FilterNode} for the post-join predicate
         * @throws IllegalArgumentException if the join type is neither INNER nor LEFT
         */
        @Override
        public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext<Expression> context)
        {
            Expression inheritedPredicate = context.get();

            // See if we can rewrite left join in terms of a plain inner join
            if (node.getType() == SpatialJoinNode.Type.LEFT && canConvertOuterToInner(node.getRight().getOutputSymbols(), inheritedPredicate)) {
                node = new SpatialJoinNode(
                        node.getId(),
                        SpatialJoinNode.Type.INNER,
                        node.getLeft(),
                        node.getRight(),
                        node.getOutputSymbols(),
                        node.getFilter(),
                        node.getLeftPartitionSymbol(),
                        node.getRightPartitionSymbol(),
                        node.getKdbTree());
            }

            Expression leftEffectivePredicate = effectivePredicateExtractor.extract(session, node.getLeft(), types, typeAnalyzer);
            Expression rightEffectivePredicate = effectivePredicateExtractor.extract(session, node.getRight(), types, typeAnalyzer);
            Expression joinPredicate = node.getFilter();

            Expression leftPredicate;
            Expression rightPredicate;
            Expression postJoinPredicate;
            Expression newJoinPredicate;

            switch (node.getType()) {
                case INNER -> {
                    InnerJoinPushDownResult innerJoinPushDownResult = processInnerJoin(
                            inheritedPredicate,
                            leftEffectivePredicate,
                            rightEffectivePredicate,
                            joinPredicate,
                            node.getLeft().getOutputSymbols(),
                            node.getRight().getOutputSymbols());
                    leftPredicate = innerJoinPushDownResult.getLeftPredicate();
                    rightPredicate = innerJoinPushDownResult.getRightPredicate();
                    postJoinPredicate = innerJoinPushDownResult.getPostJoinPredicate();
                    newJoinPredicate = innerJoinPushDownResult.getJoinPredicate();
                }
                case LEFT -> {
                    // For a LEFT join the left source is the outer side and the right source the inner side
                    OuterJoinPushDownResult leftOuterJoinPushDownResult = processLimitedOuterJoin(
                            inheritedPredicate,
                            leftEffectivePredicate,
                            rightEffectivePredicate,
                            joinPredicate,
                            node.getLeft().getOutputSymbols(),
                            node.getRight().getOutputSymbols());
                    leftPredicate = leftOuterJoinPushDownResult.getOuterJoinPredicate();
                    rightPredicate = leftOuterJoinPushDownResult.getInnerJoinPredicate();
                    postJoinPredicate = leftOuterJoinPushDownResult.getPostJoinPredicate();
                    newJoinPredicate = leftOuterJoinPushDownResult.getJoinPredicate();
                }
                default -> throw new IllegalArgumentException("Unsupported spatial join type: " + node.getType());
            }

            newJoinPredicate = simplifyExpression(newJoinPredicate);
            verify(!newJoinPredicate.equals(BooleanLiteral.FALSE_LITERAL), "Spatial join predicate is missing");

            PlanNode leftSource = context.rewrite(node.getLeft(), leftPredicate);
            PlanNode rightSource = context.rewrite(node.getRight(), rightPredicate);

            PlanNode output = node;
            if (leftSource != node.getLeft() ||
                    rightSource != node.getRight() ||
                    !areExpressionsEquivalent(newJoinPredicate, joinPredicate)) {
                // Create identity projections for all existing symbols
                Assignments.Builder leftProjections = Assignments.builder();
                leftProjections.putAll(node.getLeft()
                        .getOutputSymbols().stream()
                        .collect(toImmutableMap(key -> key, Symbol::toSymbolReference)));

                Assignments.Builder rightProjections = Assignments.builder();
                rightProjections.putAll(node.getRight()
                        .getOutputSymbols().stream()
                        .collect(toImmutableMap(key -> key, Symbol::toSymbolReference)));

                leftSource = new ProjectNode(idAllocator.getNextId(), leftSource, leftProjections.build());
                rightSource = new ProjectNode(idAllocator.getNextId(), rightSource, rightProjections.build());

                output = new SpatialJoinNode(
                        node.getId(),
                        node.getType(),
                        leftSource,
                        rightSource,
                        node.getOutputSymbols(),
                        newJoinPredicate,
                        node.getLeftPartitionSymbol(),
                        node.getRightPartitionSymbol(),
                        node.getKdbTree());
            }

            // Whatever could not be pushed into the join or below it is applied on top
            if (!postJoinPredicate.equals(TRUE_LITERAL)) {
                output = new FilterNode(idAllocator.getNextId(), output, postJoinPredicate);
            }

            return output;
        }

        /**
         * Returns the symbol for an expression: the referenced symbol when the expression is a
         * {@link SymbolReference}, otherwise a newly allocated symbol typed after the expression.
         */
        private Symbol symbolForExpression(Expression expression)
        {
            if (!(expression instanceof SymbolReference)) {
                // Not a plain symbol reference: allocate a fresh symbol carrying the expression's type
                return symbolAllocator.newSymbol(expression, typeAnalyzer.getType(session, symbolAllocator.getTypes(), expression));
            }

            return Symbol.from(expression);
        }

        /**
         * Splits the predicates around a one-sided outer join into the parts that can be pushed to the
         * outer side, pushed to the inner side, kept in the join condition, or applied after the join.
         * <p>
         * Non-deterministic conjuncts are never pushed down: those of the inherited predicate stay above
         * the join and those of the join predicate stay in the join condition.
         *
         * @param inheritedPredicate predicate inherited from above the join
         * @param outerEffectivePredicate effective predicate of the outer side
         * @param innerEffectivePredicate effective predicate of the inner side
         * @param joinPredicate current join condition
         * @param outerSymbols symbols produced by the outer side
         * @param innerSymbols symbols produced by the inner side
         * @return the outer-side, inner-side, join and post-join predicates
         * @throws IllegalArgumentException if an effective predicate references symbols outside its own side
         */
        private OuterJoinPushDownResult processLimitedOuterJoin(
                Expression inheritedPredicate,
                Expression outerEffectivePredicate,
                Expression innerEffectivePredicate,
                Expression joinPredicate,
                Collection<Symbol> outerSymbols,
                Collection<Symbol> innerSymbols)
        {
            checkArgument(outerSymbols.containsAll(extractUnique(outerEffectivePredicate)), "outerEffectivePredicate must only contain symbols from outerSymbols");
            checkArgument(innerSymbols.containsAll(extractUnique(innerEffectivePredicate)), "innerEffectivePredicate must only contain symbols from innerSymbols");

            ImmutableList.Builder<Expression> outerPushdownConjuncts = ImmutableList.builder();
            ImmutableList.Builder<Expression> innerPushdownConjuncts = ImmutableList.builder();
            ImmutableList.Builder<Expression> postJoinConjuncts = ImmutableList.builder();
            ImmutableList.Builder<Expression> joinConjuncts = ImmutableList.builder();

            // Strip out non-deterministic conjuncts
            extractConjuncts(inheritedPredicate).stream()
                    .filter(expression -> !isDeterministic(expression, metadata))
                    .forEach(postJoinConjuncts::add);
            inheritedPredicate = filterDeterministicConjuncts(metadata, inheritedPredicate);

            outerEffectivePredicate = filterDeterministicConjuncts(metadata, outerEffectivePredicate);
            innerEffectivePredicate = filterDeterministicConjuncts(metadata, innerEffectivePredicate);
            extractConjuncts(joinPredicate).stream()
                    .filter(expression -> !isDeterministic(expression, metadata))
                    .forEach(joinConjuncts::add);
            joinPredicate = filterDeterministicConjuncts(metadata, joinPredicate);

            // Generate equality inferences
            EqualityInference inheritedInference = new EqualityInference(metadata, inheritedPredicate);
            EqualityInference outerInference = new EqualityInference(metadata, inheritedPredicate, outerEffectivePredicate);

            Set<Symbol> innerScope = ImmutableSet.copyOf(innerSymbols);
            Set<Symbol> outerScope = ImmutableSet.copyOf(outerSymbols);

            EqualityInference.EqualityPartition equalityPartition = inheritedInference.generateEqualitiesPartitionedBy(outerScope);
            Expression outerOnlyInheritedEqualities = combineConjuncts(metadata, equalityPartition.getScopeEqualities());
            EqualityInference potentialNullSymbolInference = new EqualityInference(metadata, outerOnlyInheritedEqualities, outerEffectivePredicate, innerEffectivePredicate, joinPredicate);

            // Push outer and join equalities into the inner side. For example:
            // SELECT * FROM nation LEFT OUTER JOIN region ON nation.regionkey = region.regionkey and nation.name = region.name WHERE nation.name = 'blah'

            EqualityInference potentialNullSymbolInferenceWithoutInnerInferred = new EqualityInference(metadata, outerOnlyInheritedEqualities, outerEffectivePredicate, joinPredicate);
            innerPushdownConjuncts.addAll(potentialNullSymbolInferenceWithoutInnerInferred.generateEqualitiesPartitionedBy(innerScope).getScopeEqualities());

            // TODO: we can further improve simplifying the equalities by considering other relationships from the outer side
            EqualityInference.EqualityPartition joinEqualityPartition = new EqualityInference(metadata, joinPredicate).generateEqualitiesPartitionedBy(innerScope);
            innerPushdownConjuncts.addAll(joinEqualityPartition.getScopeEqualities());
            joinConjuncts.addAll(joinEqualityPartition.getScopeComplementEqualities())
                    .addAll(joinEqualityPartition.getScopeStraddlingEqualities());

            // Add the equalities from the inferences back in
            outerPushdownConjuncts.addAll(equalityPartition.getScopeEqualities());
            postJoinConjuncts.addAll(equalityPartition.getScopeComplementEqualities());
            postJoinConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities());

            // See if we can push inherited predicates down
            EqualityInference.nonInferrableConjuncts(metadata, inheritedPredicate).forEach(conjunct -> {
                Expression outerRewritten = outerInference.rewrite(conjunct, outerScope);
                if (outerRewritten != null) {
                    outerPushdownConjuncts.add(outerRewritten);

                    // A conjunct can only be pushed down into an inner side if it can be rewritten in terms of the outer side
                    Expression innerRewritten = potentialNullSymbolInference.rewrite(outerRewritten, innerScope);
                    if (innerRewritten != null) {
                        innerPushdownConjuncts.add(innerRewritten);
                    }
                }
                else {
                    postJoinConjuncts.add(conjunct);
                }
            });

            // See if we can push down any outer effective predicates to the inner side
            EqualityInference.nonInferrableConjuncts(metadata, outerEffectivePredicate)
                    .map(conjunct -> potentialNullSymbolInference.rewrite(conjunct, innerScope))
                    .filter(Objects::nonNull)
                    .forEach(innerPushdownConjuncts::add);

            // See if we can push down join predicates to the inner side
            EqualityInference.nonInferrableConjuncts(metadata, joinPredicate).forEach(conjunct -> {
                Expression innerRewritten = potentialNullSymbolInference.rewrite(conjunct, innerScope);
                if (innerRewritten != null) {
                    innerPushdownConjuncts.add(innerRewritten);
                }
                else {
                    joinConjuncts.add(conjunct);
                }
            });

            return new OuterJoinPushDownResult(combineConjuncts(metadata, outerPushdownConjuncts.build()),
                    combineConjuncts(metadata, innerPushdownConjuncts.build()),
                    combineConjuncts(metadata, joinConjuncts.build()),
                    combineConjuncts(metadata, postJoinConjuncts.build()));
        }

        /**
         * Value holder for the four predicates produced when pushing predicates through an outer join.
         */
        private static class OuterJoinPushDownResult
        {
            // Predicate for the outer side of the join
            private final Expression outerJoinPredicate;
            // Predicate for the inner side of the join
            private final Expression innerJoinPredicate;
            // Predicate kept as the join condition
            private final Expression joinPredicate;
            // Predicate to apply after the join
            private final Expression postJoinPredicate;

            private OuterJoinPushDownResult(
                    Expression outerJoinPredicate,
                    Expression innerJoinPredicate,
                    Expression joinPredicate,
                    Expression postJoinPredicate)
            {
                this.outerJoinPredicate = outerJoinPredicate;
                this.innerJoinPredicate = innerJoinPredicate;
                this.joinPredicate = joinPredicate;
                this.postJoinPredicate = postJoinPredicate;
            }

            private Expression getOuterJoinPredicate()
            {
                return this.outerJoinPredicate;
            }

            private Expression getInnerJoinPredicate()
            {
                return this.innerJoinPredicate;
            }

            public Expression getJoinPredicate()
            {
                return this.joinPredicate;
            }

            private Expression getPostJoinPredicate()
            {
                return this.postJoinPredicate;
            }
        }

        /**
         * Splits the predicates around an inner join into the parts that can be pushed to the left
         * source, pushed to the right source, or must remain in the join condition.
         * <p>
         * Non-deterministic conjuncts of both the inherited predicate and the join predicate are kept
         * in the join condition. The post-join predicate of the result is always {@code TRUE_LITERAL}.
         *
         * @param inheritedPredicate predicate inherited from above the join
         * @param leftEffectivePredicate effective predicate of the left source
         * @param rightEffectivePredicate effective predicate of the right source
         * @param joinPredicate current join condition
         * @param leftSymbols symbols produced by the left source
         * @param rightSymbols symbols produced by the right source
         * @return the left, right, join and post-join predicates
         * @throws IllegalArgumentException if an effective predicate references symbols outside its own side
         */
        private InnerJoinPushDownResult processInnerJoin(
                Expression inheritedPredicate,
                Expression leftEffectivePredicate,
                Expression rightEffectivePredicate,
                Expression joinPredicate,
                Collection<Symbol> leftSymbols,
                Collection<Symbol> rightSymbols)
        {
            checkArgument(leftSymbols.containsAll(extractUnique(leftEffectivePredicate)), "leftEffectivePredicate must only contain symbols from leftSymbols");
            checkArgument(rightSymbols.containsAll(extractUnique(rightEffectivePredicate)), "rightEffectivePredicate must only contain symbols from rightSymbols");

            ImmutableList.Builder<Expression> leftPushDownConjuncts = ImmutableList.builder();
            ImmutableList.Builder<Expression> rightPushDownConjuncts = ImmutableList.builder();
            ImmutableList.Builder<Expression> joinConjuncts = ImmutableList.builder();

            // Strip out non-deterministic conjuncts
            extractConjuncts(inheritedPredicate).stream()
                    .filter(expression -> !isDeterministic(expression, metadata))
                    .forEach(joinConjuncts::add);
            inheritedPredicate = filterDeterministicConjuncts(metadata, inheritedPredicate);

            extractConjuncts(joinPredicate).stream()
                    .filter(expression -> !isDeterministic(expression, metadata))
                    .forEach(joinConjuncts::add);
            joinPredicate = filterDeterministicConjuncts(metadata, joinPredicate);

            leftEffectivePredicate = filterDeterministicConjuncts(metadata, leftEffectivePredicate);
            rightEffectivePredicate = filterDeterministicConjuncts(metadata, rightEffectivePredicate);

            ImmutableSet<Symbol> leftScope = ImmutableSet.copyOf(leftSymbols);
            ImmutableSet<Symbol> rightScope = ImmutableSet.copyOf(rightSymbols);

            // Attempt to simplify the effective left/right predicates with the predicate we're pushing down
            // This, effectively, inlines any constants derived from such predicate
            EqualityInference predicateInference = new EqualityInference(metadata, inheritedPredicate);
            Expression simplifiedLeftEffectivePredicate = predicateInference.rewrite(leftEffectivePredicate, leftScope);
            Expression simplifiedRightEffectivePredicate = predicateInference.rewrite(rightEffectivePredicate, rightScope);

            // Generate equality inferences
            EqualityInference allInference = new EqualityInference(metadata, inheritedPredicate, leftEffectivePredicate, rightEffectivePredicate, joinPredicate, simplifiedLeftEffectivePredicate, simplifiedRightEffectivePredicate);
            EqualityInference allInferenceWithoutLeftInferred = new EqualityInference(metadata, inheritedPredicate, rightEffectivePredicate, joinPredicate, simplifiedRightEffectivePredicate);
            EqualityInference allInferenceWithoutRightInferred = new EqualityInference(metadata, inheritedPredicate, leftEffectivePredicate, joinPredicate, simplifiedLeftEffectivePredicate);

            // Add equalities from the inference back in
            leftPushDownConjuncts.addAll(allInferenceWithoutLeftInferred.generateEqualitiesPartitionedBy(leftScope).getScopeEqualities());
            rightPushDownConjuncts.addAll(allInferenceWithoutRightInferred.generateEqualitiesPartitionedBy(rightScope).getScopeEqualities());
            joinConjuncts.addAll(allInference.generateEqualitiesPartitionedBy(leftScope).getScopeStraddlingEqualities()); // scope straddling equalities get dropped in as part of the join predicate

            // Sort through conjuncts in inheritedPredicate that were not used for inference
            EqualityInference.nonInferrableConjuncts(metadata, inheritedPredicate).forEach(conjunct -> {
                Expression leftRewrittenConjunct = allInference.rewrite(conjunct, leftScope);
                if (leftRewrittenConjunct != null) {
                    leftPushDownConjuncts.add(leftRewrittenConjunct);
                }

                Expression rightRewrittenConjunct = allInference.rewrite(conjunct, rightScope);
                if (rightRewrittenConjunct != null) {
                    rightPushDownConjuncts.add(rightRewrittenConjunct);
                }

                // Drop predicate after join only if unable to push down to either side
                if (leftRewrittenConjunct == null && rightRewrittenConjunct == null) {
                    joinConjuncts.add(conjunct);
                }
            });

            // See if we can push the right effective predicate to the left side
            EqualityInference.nonInferrableConjuncts(metadata, simplifiedRightEffectivePredicate)
                    .map(conjunct -> allInference.rewrite(conjunct, leftScope))
                    .filter(Objects::nonNull)
                    .forEach(leftPushDownConjuncts::add);

            // See if we can push the left effective predicate to the right side
            EqualityInference.nonInferrableConjuncts(metadata, simplifiedLeftEffectivePredicate)
                    .map(conjunct -> allInference.rewrite(conjunct, rightScope))
                    .filter(Objects::nonNull)
                    .forEach(rightPushDownConjuncts::add);

            // See if we can push any parts of the join predicates to either side
            EqualityInference.nonInferrableConjuncts(metadata, joinPredicate).forEach(conjunct -> {
                Expression leftRewritten = allInference.rewrite(conjunct, leftScope);
                if (leftRewritten != null) {
                    leftPushDownConjuncts.add(leftRewritten);
                }

                Expression rightRewritten = allInference.rewrite(conjunct, rightScope);
                if (rightRewritten != null) {
                    rightPushDownConjuncts.add(rightRewritten);
                }

                if (leftRewritten == null && rightRewritten == null) {
                    joinConjuncts.add(conjunct);
                }
            });

            return new InnerJoinPushDownResult(
                    combineConjuncts(metadata, leftPushDownConjuncts.build()),
                    combineConjuncts(metadata, rightPushDownConjuncts.build()),
                    combineConjuncts(metadata, joinConjuncts.build()),
                    TRUE_LITERAL);
        }

        /**
         * Value holder for the four predicates produced when pushing predicates through an inner join.
         */
        private static class InnerJoinPushDownResult
        {
            // Predicate for the left source of the join
            private final Expression leftPredicate;
            // Predicate for the right source of the join
            private final Expression rightPredicate;
            // Predicate kept as the join condition
            private final Expression joinPredicate;
            // Predicate to apply after the join
            private final Expression postJoinPredicate;

            private InnerJoinPushDownResult(
                    Expression leftPredicate,
                    Expression rightPredicate,
                    Expression joinPredicate,
                    Expression postJoinPredicate)
            {
                this.leftPredicate = leftPredicate;
                this.rightPredicate = rightPredicate;
                this.joinPredicate = joinPredicate;
                this.postJoinPredicate = postJoinPredicate;
            }

            private Expression getLeftPredicate()
            {
                return this.leftPredicate;
            }

            private Expression getRightPredicate()
            {
                return this.rightPredicate;
            }

            private Expression getJoinPredicate()
            {
                return this.joinPredicate;
            }

            private Expression getPostJoinPredicate()
            {
                return this.postJoinPredicate;
            }
        }

        /**
         * Builds a single predicate for a join by combining the expression form of every equi-join
         * clause with the join filter, when one is present.
         *
         * @param joinNode join whose criteria and filter are combined
         * @return the conjunction of all equi-join clauses and the optional filter
         */
        private Expression extractJoinPredicate(JoinNode joinNode)
        {
            ImmutableList.Builder<Expression> builder = ImmutableList.builder();
            for (JoinNode.EquiJoinClause equiJoinClause : joinNode.getCriteria()) {
                builder.add(equiJoinClause.toExpression());
            }
            joinNode.getFilter().ifPresent(builder::add);
            return combineConjuncts(metadata, builder.build());
        }

        /**
         * Tries to replace an outer join with a more restrictive join type, using
         * {@code canConvertOuterToInner} to test each side against the inherited predicate.
         * <ul>
         * <li>INNER joins are returned as is.</li>
         * <li>FULL joins become INNER when the check passes for both sides, LEFT when it passes only
         * for the left side symbols, RIGHT when it passes only for the right side symbols.</li>
         * <li>LEFT joins become INNER when the check passes for the right side symbols, and RIGHT joins
         * when it passes for the left side symbols.</li>
         * </ul>
         * When no conversion applies the original node instance is returned.
         *
         * @param node join to normalize
         * @param inheritedPredicate predicate inherited from above the join
         * @return the same node, or a copy that differs only in its join type
         * @throws IllegalArgumentException if the join type is not INNER, LEFT, RIGHT or FULL
         */
        private JoinNode tryNormalizeToOuterToInnerJoin(JoinNode node, Expression inheritedPredicate)
        {
            checkArgument(EnumSet.of(INNER, RIGHT, LEFT, FULL).contains(node.getType()), "Unsupported join type: %s", node.getType());

            JoinNode.Type currentType = node.getType();
            if (currentType == INNER) {
                return node;
            }

            JoinNode.Type targetType;
            if (currentType == FULL) {
                boolean leftSideConvertible = canConvertOuterToInner(node.getLeft().getOutputSymbols(), inheritedPredicate);
                boolean rightSideConvertible = canConvertOuterToInner(node.getRight().getOutputSymbols(), inheritedPredicate);
                if (leftSideConvertible && rightSideConvertible) {
                    targetType = INNER;
                }
                else if (leftSideConvertible) {
                    targetType = LEFT;
                }
                else if (rightSideConvertible) {
                    targetType = RIGHT;
                }
                else {
                    return node;
                }
            }
            else {
                // LEFT or RIGHT: test the symbols of the side opposite to the join type
                boolean convertible = currentType == LEFT
                        ? canConvertOuterToInner(node.getRight().getOutputSymbols(), inheritedPredicate)
                        : canConvertOuterToInner(node.getLeft().getOutputSymbols(), inheritedPredicate);
                if (!convertible) {
                    return node;
                }
                targetType = INNER;
            }

            // Copy of the node that differs only in the join type
            return new JoinNode(
                    node.getId(),
                    targetType,
                    node.getLeft(),
                    node.getRight(),
                    node.getCriteria(),
                    node.getLeftOutputSymbols(),
                    node.getRightOutputSymbols(),
                    node.isMaySkipOutputDuplicates(),
                    node.getFilter(),
                    node.getLeftHashSymbol(),
                    node.getRightHashSymbol(),
                    node.getDistributionType(),
                    node.isSpillable(),
                    node.getDynamicFilters(),
                    node.getReorderJoinStatsAndCost());
        }

        /**
         * Tells whether the inherited predicate cancels the null-extension of an outer join, so that
         * the join can be treated as an inner join with respect to the given symbols.
         *
         * @param innerSymbolsForOuterJoin symbols of the inner (null-extended) side of the outer join
         * @param inheritedPredicate predicate inherited from above the join
         * @return {@code true} if some deterministic conjunct evaluates to FALSE or NULL when all of the
         * given symbols are bound to NULL
         */
        private boolean canConvertOuterToInner(List<Symbol> innerSymbolsForOuterJoin, Expression inheritedPredicate)
        {
            Set<Symbol> innerSymbols = ImmutableSet.copyOf(innerSymbolsForOuterJoin);
            for (Expression conjunct : extractConjuncts(inheritedPredicate)) {
                if (isDeterministic(conjunct, metadata)) {
                    // Ignore a conjunct for this test if we cannot deterministically get responses from it
                    Object response = nullInputEvaluator(innerSymbols, conjunct);
                    if (response == null || response instanceof NullLiteral || Boolean.FALSE.equals(response)) {
                        // If there is a single conjunct that returns FALSE or NULL given all NULL inputs for the inner side symbols of an outer join
                        // then this conjunct removes all effects of the outer join, and effectively turns this into an equivalent of an inner join.
                        // So, let's just rewrite this join as an INNER join
                        return true;
                    }
                }
            }
            return false;
        }

        // Temporary implementation for joins because the SimplifyExpressions optimizers cannot run properly on join clauses
        private Expression simplifyExpression(Expression expression)
        {
            Map, Type> expressionTypes = typeAnalyzer.getTypes(session, symbolAllocator.getTypes(), expression);
            ExpressionInterpreter optimizer = new ExpressionInterpreter(expression, plannerContext, session, expressionTypes);
            return literalEncoder.toExpression(optimizer.optimize(NoOpSymbolResolver.INSTANCE), expressionTypes.get(NodeRef.of(expression)));
        }

        /**
         * Delegates to {@code expressionEquivalence} to decide whether the two expressions are
         * equivalent, using the current session and types.
         */
        private boolean areExpressionsEquivalent(Expression leftExpression, Expression rightExpression)
        {
            boolean equivalent = expressionEquivalence.areExpressionsEquivalent(session, leftExpression, rightExpression, types);
            return equivalent;
        }

        /**
         * Evaluates an expression's response to binding the specified input symbols to NULL
         */
        private Object nullInputEvaluator(Collection nullSymbols, Expression expression)
        {
            Map, Type> expressionTypes = typeAnalyzer.getTypes(session, symbolAllocator.getTypes(), expression);
            return new ExpressionInterpreter(expression, plannerContext, session, expressionTypes)
                    .optimize(symbol -> nullSymbols.contains(symbol) ? null : symbol.toSymbolReference());
        }

        /**
         * Tests whether {@code expression} passes {@code joinComparisonExpression} when {@code EQUAL}
         * is the only accepted comparison operator.
         *
         * @param expression candidate join clause
         * @param leftSymbols symbols of the left join input
         * @param rightSymbols symbols of the right join input
         * @return the result of {@code joinComparisonExpression} restricted to {@code EQUAL}
         */
        private boolean joinEqualityExpression(Expression expression, Collection leftSymbols, Collection rightSymbols)
        {
            return joinComparisonExpression(expression, leftSymbols, rightSymbols, ImmutableSet.of(EQUAL));
        }

        /**
         * Tests whether {@code expression} is a join clause that dynamic filtering can use.
         * <p>
         * Two shapes are accepted:
         * <ul>
         * <li>{@code NOT (a IS DISTINCT FROM b)} spanning both join sides, unless either operand is of
         * type REAL or DOUBLE;</li>
         * <li>a comparison spanning both join sides whose operator is in
         * {@code DYNAMIC_FILTERING_SUPPORTED_COMPARISONS}.</li>
         * </ul>
         * In both cases one operand must be a bare symbol reference belonging to {@code rightSymbols}.
         *
         * @param expression candidate join clause
         * @param leftSymbols symbols of the left join input
         * @param rightSymbols symbols of the right join input
         * @return {@code true} if the clause is usable for dynamic filtering
         */
        private boolean joinDynamicFilteringExpression(Expression expression, Collection leftSymbols, Collection rightSymbols)
        {
            ComparisonExpression comparison;
            if (expression instanceof NotExpression negation) {
                Expression negated = negation.getValue();
                if (!joinComparisonExpression(negated, leftSymbols, rightSymbols, ImmutableSet.of(IS_DISTINCT_FROM))) {
                    return false;
                }
                comparison = (ComparisonExpression) negated;

                // IS NOT DISTINCT FROM on REAL or DOUBLE is excluded so that NaN values never have to be handled
                Set operandTypes = ImmutableSet.of(
                        typeAnalyzer.getType(session, types, comparison.getLeft()),
                        typeAnalyzer.getType(session, types, comparison.getRight()));
                if (operandTypes.contains(REAL) || operandTypes.contains(DOUBLE)) {
                    return false;
                }
            }
            else if (joinComparisonExpression(expression, leftSymbols, rightSymbols, DYNAMIC_FILTERING_SUPPORTED_COMPARISONS)) {
                comparison = (ComparisonExpression) expression;
            }
            else {
                return false;
            }

            // DynamicFilterSourceOperator collects column values only (not expressions),
            // so the build side operand has to be a plain symbol reference
            if (comparison.getRight() instanceof SymbolReference && rightSymbols.contains(Symbol.from(comparison.getRight()))) {
                return true;
            }
            return comparison.getLeft() instanceof SymbolReference && rightSymbols.contains(Symbol.from(comparison.getLeft()));
        }

        /**
         * Tests whether {@code expression} is a deterministic comparison, using one of the given operators,
         * whose two operands draw their symbols from opposite join inputs.
         * <p>
         * Each operand must reference at least one symbol; one operand's symbols must all come from
         * {@code leftSymbols} and the other operand's symbols must all come from {@code rightSymbols}
         * (in either orientation).
         *
         * @param expression candidate join clause
         * @param leftSymbols symbols of the left join input
         * @param rightSymbols symbols of the right join input
         * @param operators comparison operators that are acceptable
         * @return {@code true} if the expression is such a comparison
         */
        private boolean joinComparisonExpression(Expression expression, Collection leftSymbols, Collection rightSymbols, Set operators)
        {
            // At this point in time, our join predicates need to be deterministic
            if (!(expression instanceof ComparisonExpression comparison) || !isDeterministic(expression, metadata)) {
                return false;
            }
            if (!operators.contains(comparison.getOperator())) {
                return false;
            }

            Set firstOperandSymbols = extractUnique(comparison.getLeft());
            Set secondOperandSymbols = extractUnique(comparison.getRight());
            if (firstOperandSymbols.isEmpty() || secondOperandSymbols.isEmpty()) {
                // An operand without symbols does not tie the comparison to a join side
                return false;
            }

            boolean sameOrientation = leftSymbols.containsAll(firstOperandSymbols) && rightSymbols.containsAll(secondOperandSymbols);
            return sameOrientation
                    || (rightSymbols.containsAll(firstOperandSymbols) && leftSymbols.containsAll(secondOperandSymbols));
        }

        /**
         * Pushes the inherited predicate through a semi join.
         * <p>
         * When the semi join output symbol appears as a top-level conjunct of the inherited predicate
         * the join is handled by {@code visitFilteringSemiJoin}; otherwise by {@code visitNonFilteringSemiJoin}.
         *
         * @param node semi join to rewrite
         * @param context rewrite context carrying the inherited predicate
         * @return the rewritten plan
         */
        @Override
        public PlanNode visitSemiJoin(SemiJoinNode node, RewriteContext<Expression> context)
        {
            Expression inheritedPredicate = context.get();
            if (extractConjuncts(inheritedPredicate).contains(node.getSemiJoinOutput().toSymbolReference())) {
                return visitFilteringSemiJoin(node, context);
            }
            return visitNonFilteringSemiJoin(node, context);
        }

        /**
         * Rewrites a semi join whose output symbol is not a top-level conjunct of the inherited predicate.
         * <p>
         * The filtering source is rewritten with {@code TRUE_LITERAL}. Conjuncts of the inherited predicate
         * that can be expressed purely in terms of the source's output symbols are pushed into the source;
         * everything else is kept in a {@code FilterNode} placed above the semi join.
         *
         * @param node semi join to rewrite
         * @param context rewrite context carrying the inherited predicate
         * @return the original node if nothing changed, otherwise the rebuilt plan
         */
        private PlanNode visitNonFilteringSemiJoin(SemiJoinNode node, RewriteContext<Expression> context)
        {
            Expression inheritedPredicate = context.get();
            List<Expression> sourceConjuncts = new ArrayList<>();
            List<Expression> postJoinConjuncts = new ArrayList<>();

            // TODO: see if there are predicates that can be inferred from the semi join output

            PlanNode rewrittenFilteringSource = context.defaultRewrite(node.getFilteringSource(), TRUE_LITERAL);

            // Push inheritedPredicates down to the source if they don't involve the semi join output
            ImmutableSet sourceScope = ImmutableSet.copyOf(node.getSource().getOutputSymbols());
            EqualityInference inheritedInference = new EqualityInference(metadata, inheritedPredicate);
            EqualityInference.nonInferrableConjuncts(metadata, inheritedPredicate).forEach(conjunct -> {
                Expression rewrittenConjunct = inheritedInference.rewrite(conjunct, sourceScope);
                // Since each source row is reflected exactly once in the output, ok to push non-deterministic predicates down
                if (rewrittenConjunct != null) {
                    sourceConjuncts.add(rewrittenConjunct);
                }
                else {
                    postJoinConjuncts.add(conjunct);
                }
            });

            // Add the inherited equality predicates back in
            EqualityInference.EqualityPartition equalityPartition = inheritedInference.generateEqualitiesPartitionedBy(sourceScope);
            sourceConjuncts.addAll(equalityPartition.getScopeEqualities());
            postJoinConjuncts.addAll(equalityPartition.getScopeComplementEqualities());
            postJoinConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities());

            PlanNode rewrittenSource = context.rewrite(node.getSource(), combineConjuncts(metadata, sourceConjuncts));

            // Only build a new node when one of the children actually changed
            PlanNode output = node;
            if (rewrittenSource != node.getSource() || rewrittenFilteringSource != node.getFilteringSource()) {
                output = new SemiJoinNode(
                        node.getId(),
                        rewrittenSource,
                        rewrittenFilteringSource,
                        node.getSourceJoinSymbol(),
                        node.getFilteringSourceJoinSymbol(),
                        node.getSemiJoinOutput(),
                        node.getSourceHashSymbol(),
                        node.getFilteringSourceHashSymbol(),
                        node.getDistributionType(),
                        Optional.empty());
            }
            if (!postJoinConjuncts.isEmpty()) {
                output = new FilterNode(idAllocator.getNextId(), output, combineConjuncts(metadata, postJoinConjuncts));
            }
            return output;
        }

        /**
         * Rewrites a semi join whose output symbol is a top-level conjunct of the inherited predicate.
         * <p>
         * Conjuncts of the inherited predicate and of both children's effective predicates are routed to the
         * source, to the filtering source, or to a {@code FilterNode} kept above the join, using equality
         * inference built from the inherited predicate, both effective predicates and the join equality.
         * When the node has no dynamic filter id and dynamic filtering is enabled, a new id is allocated
         * and a matching dynamic filter expression is pushed to the source.
         *
         * @param node semi join to rewrite
         * @param context rewrite context carrying the inherited predicate
         * @return the original node if nothing changed, otherwise the rebuilt plan
         */
        private PlanNode visitFilteringSemiJoin(SemiJoinNode node, RewriteContext<Expression> context)
        {
            Expression inheritedPredicate = context.get();
            Expression deterministicInheritedPredicate = filterDeterministicConjuncts(metadata, inheritedPredicate);
            Expression sourceEffectivePredicate = filterDeterministicConjuncts(metadata, effectivePredicateExtractor.extract(session, node.getSource(), types, typeAnalyzer));
            Expression filteringSourceEffectivePredicate = filterDeterministicConjuncts(metadata, effectivePredicateExtractor.extract(session, node.getFilteringSource(), types, typeAnalyzer));
            Expression joinExpression = new ComparisonExpression(
                    EQUAL,
                    node.getSourceJoinSymbol().toSymbolReference(),
                    node.getFilteringSourceJoinSymbol().toSymbolReference());

            List<Symbol> sourceSymbols = node.getSource().getOutputSymbols();
            List<Symbol> filteringSourceSymbols = node.getFilteringSource().getOutputSymbols();

            List<Expression> sourceConjuncts = new ArrayList<>();
            List<Expression> filteringSourceConjuncts = new ArrayList<>();
            List<Expression> postJoinConjuncts = new ArrayList<>();

            // Generate equality inferences
            EqualityInference allInference = new EqualityInference(metadata, deterministicInheritedPredicate, sourceEffectivePredicate, filteringSourceEffectivePredicate, joinExpression);
            EqualityInference allInferenceWithoutSourceInferred = new EqualityInference(metadata, deterministicInheritedPredicate, filteringSourceEffectivePredicate, joinExpression);
            EqualityInference allInferenceWithoutFilteringSourceInferred = new EqualityInference(metadata, deterministicInheritedPredicate, sourceEffectivePredicate, joinExpression);

            // Push inheritedPredicates down to the source if they don't involve the semi join output
            Set<Symbol> sourceScope = ImmutableSet.copyOf(sourceSymbols);
            EqualityInference.nonInferrableConjuncts(metadata, inheritedPredicate).forEach(conjunct -> {
                Expression rewrittenConjunct = allInference.rewrite(conjunct, sourceScope);
                // Since each source row is reflected exactly once in the output, ok to push non-deterministic predicates down
                if (rewrittenConjunct != null) {
                    sourceConjuncts.add(rewrittenConjunct);
                }
                else {
                    postJoinConjuncts.add(conjunct);
                }
            });

            // Push inheritedPredicates down to the filtering source if possible
            Set<Symbol> filterScope = ImmutableSet.copyOf(filteringSourceSymbols);
            EqualityInference.nonInferrableConjuncts(metadata, deterministicInheritedPredicate).forEach(conjunct -> {
                Expression rewrittenConjunct = allInference.rewrite(conjunct, filterScope);
                // We cannot push non-deterministic predicates to filtering side. Each filtering side row have to be
                // logically reevaluated for each source row.
                if (rewrittenConjunct != null) {
                    filteringSourceConjuncts.add(rewrittenConjunct);
                }
            });

            // move effective predicate conjuncts source <-> filter
            // See if we can push the filtering source effective predicate to the source side
            EqualityInference.nonInferrableConjuncts(metadata, filteringSourceEffectivePredicate)
                    .map(conjunct -> allInference.rewrite(conjunct, sourceScope))
                    .filter(Objects::nonNull)
                    .forEach(sourceConjuncts::add);

            // See if we can push the source effective predicate to the filtering source side
            EqualityInference.nonInferrableConjuncts(metadata, sourceEffectivePredicate)
                    .map(conjunct -> allInference.rewrite(conjunct, filterScope))
                    .filter(Objects::nonNull)
                    .forEach(filteringSourceConjuncts::add);

            // Add equalities from the inference back in.
            // Each side uses the inference that leaves out that side's own effective predicate.
            sourceConjuncts.addAll(allInferenceWithoutSourceInferred.generateEqualitiesPartitionedBy(sourceScope).getScopeEqualities());
            filteringSourceConjuncts.addAll(allInferenceWithoutFilteringSourceInferred.generateEqualitiesPartitionedBy(filterScope).getScopeEqualities());

            // Add dynamic filtering predicate
            Optional<DynamicFilterId> dynamicFilterId = node.getDynamicFilterId();
            if (dynamicFilterId.isEmpty() && isEnableDynamicFiltering(session) && dynamicFiltering) {
                dynamicFilterId = Optional.of(new DynamicFilterId("df_" + idAllocator.getNextId().toString()));
                Symbol sourceSymbol = node.getSourceJoinSymbol();
                sourceConjuncts.add(createDynamicFilterExpression(
                        metadata,
                        dynamicFilterId.get(),
                        symbolAllocator.getTypes().get(sourceSymbol),
                        sourceSymbol.toSymbolReference(),
                        EQUAL));
            }

            PlanNode rewrittenSource = context.rewrite(node.getSource(), combineConjuncts(metadata, sourceConjuncts));
            PlanNode rewrittenFilteringSource = context.rewrite(node.getFilteringSource(), combineConjuncts(metadata, filteringSourceConjuncts));

            // Rebuild the node when a child changed or a dynamic filter id was newly assigned
            PlanNode output = node;
            if (rewrittenSource != node.getSource() || rewrittenFilteringSource != node.getFilteringSource() || !dynamicFilterId.equals(node.getDynamicFilterId())) {
                output = new SemiJoinNode(
                        node.getId(),
                        rewrittenSource,
                        rewrittenFilteringSource,
                        node.getSourceJoinSymbol(),
                        node.getFilteringSourceJoinSymbol(),
                        node.getSemiJoinOutput(),
                        node.getSourceHashSymbol(),
                        node.getFilteringSourceHashSymbol(),
                        node.getDistributionType(),
                        dynamicFilterId);
            }
            if (!postJoinConjuncts.isEmpty()) {
                output = new FilterNode(idAllocator.getNextId(), output, combineConjuncts(metadata, postJoinConjuncts));
            }
            return output;
        }

        /**
         * Pushes the inherited predicate below an aggregation where possible.
         * <p>
         * Aggregations with an empty grouping set are delegated to {@code visitPlan}. Otherwise:
         * <ul>
         * <li>non-deterministic conjuncts stay above the aggregation;</li>
         * <li>deterministic conjuncts that reference the group id symbol stay above the aggregation;</li>
         * <li>remaining deterministic conjuncts are pushed down when they can be rewritten in terms of
         * the grouping keys, and stay above otherwise.</li>
         * </ul>
         * When the source changes, the aggregation is rebuilt with its pre-grouped symbols cleared.
         *
         * @param node aggregation to rewrite
         * @param context rewrite context carrying the inherited predicate
         * @return the original node if nothing changed, otherwise the rebuilt plan
         */
        @Override
        public PlanNode visitAggregation(AggregationNode node, RewriteContext<Expression> context)
        {
            if (node.hasEmptyGroupingSet()) {
                // TODO: in case of grouping sets, we should be able to push the filters over grouping keys below the aggregation
                // and also preserve the filter above the aggregation if it has an empty grouping set
                return visitPlan(node, context);
            }

            Expression inheritedPredicate = context.get();

            EqualityInference equalityInference = new EqualityInference(metadata, inheritedPredicate);

            List<Expression> pushdownConjuncts = new ArrayList<>();
            List<Expression> postAggregationConjuncts = new ArrayList<>();

            // Strip out non-deterministic conjuncts
            extractConjuncts(inheritedPredicate).stream()
                    .filter(expression -> !isDeterministic(expression, metadata))
                    .forEach(postAggregationConjuncts::add);
            inheritedPredicate = filterDeterministicConjuncts(metadata, inheritedPredicate);

            // Sort non-equality predicates by those that can be pushed down and those that cannot
            Set groupingKeys = ImmutableSet.copyOf(node.getGroupingKeys());
            EqualityInference.nonInferrableConjuncts(metadata, inheritedPredicate).forEach(conjunct -> {
                if (node.getGroupIdSymbol().isPresent() && extractUnique(conjunct).contains(node.getGroupIdSymbol().get())) {
                    // aggregation operator synthesizes outputs for group ids corresponding to the global grouping set (i.e., ()), so we
                    // need to preserve any predicates that evaluate the group id to run after the aggregation
                    // TODO: we should be able to infer if conditions on grouping() correspond to global grouping sets to determine whether
                    // we need to do this for each specific case
                    postAggregationConjuncts.add(conjunct);
                }
                else {
                    Expression rewrittenConjunct = equalityInference.rewrite(conjunct, groupingKeys);
                    if (rewrittenConjunct != null) {
                        pushdownConjuncts.add(rewrittenConjunct);
                    }
                    else {
                        postAggregationConjuncts.add(conjunct);
                    }
                }
            });

            // Add the equality predicates back in
            EqualityInference.EqualityPartition equalityPartition = equalityInference.generateEqualitiesPartitionedBy(groupingKeys);
            pushdownConjuncts.addAll(equalityPartition.getScopeEqualities());
            postAggregationConjuncts.addAll(equalityPartition.getScopeComplementEqualities());
            postAggregationConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities());

            PlanNode rewrittenSource = context.rewrite(node.getSource(), combineConjuncts(metadata, pushdownConjuncts));

            // Only build a new aggregation when the source actually changed
            PlanNode output = node;
            if (rewrittenSource != node.getSource()) {
                output = AggregationNode.builderFrom(node)
                        .setSource(rewrittenSource)
                        .setPreGroupedSymbols(ImmutableList.of())
                        .build();
            }
            if (!postAggregationConjuncts.isEmpty()) {
                output = new FilterNode(idAllocator.getNextId(), output, combineConjuncts(metadata, postAggregationConjuncts));
            }
            return output;
        }

        /**
         * Pushes the inherited predicate below an unnest where possible.
         * <p>
         * For RIGHT and FULL join types nothing is pushed: the whole predicate is kept in a
         * {@code FilterNode} above the unchanged node. Otherwise non-deterministic conjuncts stay above
         * the unnest, and deterministic conjuncts are pushed to the source when they can be rewritten in
         * terms of the replicated symbols.
         *
         * @param node unnest to rewrite
         * @param context rewrite context carrying the inherited predicate
         * @return the original node if nothing changed, otherwise the rebuilt plan
         */
        @Override
        public PlanNode visitUnnest(UnnestNode node, RewriteContext<Expression> context)
        {
            Expression inheritedPredicate = context.get();
            if (node.getJoinType() == RIGHT || node.getJoinType() == FULL) {
                return new FilterNode(idAllocator.getNextId(), node, inheritedPredicate);
            }

            //TODO for LEFT or INNER join type, push down UnnestNode's filter on replicate symbols
            EqualityInference equalityInference = new EqualityInference(metadata, inheritedPredicate);

            List<Expression> pushdownConjuncts = new ArrayList<>();
            List<Expression> postUnnestConjuncts = new ArrayList<>();

            // Strip out non-deterministic conjuncts
            extractConjuncts(inheritedPredicate).stream()
                    .filter(expression -> !isDeterministic(expression, metadata))
                    .forEach(postUnnestConjuncts::add);
            inheritedPredicate = filterDeterministicConjuncts(metadata, inheritedPredicate);

            // Sort non-equality predicates by those that can be pushed down and those that cannot
            Set replicatedSymbols = ImmutableSet.copyOf(node.getReplicateSymbols());
            EqualityInference.nonInferrableConjuncts(metadata, inheritedPredicate).forEach(conjunct -> {
                Expression rewrittenConjunct = equalityInference.rewrite(conjunct, replicatedSymbols);
                if (rewrittenConjunct != null) {
                    pushdownConjuncts.add(rewrittenConjunct);
                }
                else {
                    postUnnestConjuncts.add(conjunct);
                }
            });

            // Add the equality predicates back in
            EqualityInference.EqualityPartition equalityPartition = equalityInference.generateEqualitiesPartitionedBy(replicatedSymbols);
            pushdownConjuncts.addAll(equalityPartition.getScopeEqualities());
            postUnnestConjuncts.addAll(equalityPartition.getScopeComplementEqualities());
            postUnnestConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities());

            PlanNode rewrittenSource = context.rewrite(node.getSource(), combineConjuncts(metadata, pushdownConjuncts));

            // Only build a new unnest when the source actually changed
            PlanNode output = node;
            if (rewrittenSource != node.getSource()) {
                output = new UnnestNode(node.getId(), rewrittenSource, node.getReplicateSymbols(), node.getMappings(), node.getOrdinalitySymbol(), node.getJoinType(), node.getFilter());
            }
            if (!postUnnestConjuncts.isEmpty()) {
                output = new FilterNode(idAllocator.getNextId(), output, combineConjuncts(metadata, postUnnestConjuncts));
            }
            return output;
        }

        /**
         * Rewrites a sample node by handing the node and the inherited predicate, unchanged, to
         * {@code context.defaultRewrite}; no conjuncts are split off or kept back in this method.
         */
        @Override
        public PlanNode visitSample(SampleNode node, RewriteContext context)
        {
            return context.defaultRewrite(node, context.get());
        }

        /**
         * Terminates predicate push down at a table scan.
         * <p>
         * The inherited predicate is simplified; if the result is {@code TRUE_LITERAL} the scan is returned
         * as is, otherwise the scan is wrapped in a {@code FilterNode} carrying the simplified predicate.
         *
         * @param node table scan reached by the push down
         * @param context rewrite context carrying the inherited predicate
         * @return the scan itself, or a filter on top of it
         */
        @Override
        public PlanNode visitTableScan(TableScanNode node, RewriteContext<Expression> context)
        {
            Expression predicate = simplifyExpression(context.get());
            if (TRUE_LITERAL.equals(predicate)) {
                // Nothing left to filter on
                return node;
            }
            return new FilterNode(idAllocator.getNextId(), node, predicate);
        }

        /**
         * Rewrites an assign-unique-id node by passing the inherited predicate on via
         * {@code context.defaultRewrite}.
         * <p>
         * Fails with an {@code IllegalStateException} (through {@code checkState}) when the inherited
         * predicate references the node's id column.
         */
        @Override
        public PlanNode visitAssignUniqueId(AssignUniqueId node, RewriteContext context)
        {
            boolean predicateUsesIdColumn = extractUnique(context.get()).contains(node.getIdColumn());
            checkState(!predicateUsesIdColumn, "UniqueId in predicate is not yet supported");
            return context.defaultRewrite(node, context.get());
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy