All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.trino.sql.planner.sanity.ValidateDependenciesChecker Maven / Gradle / Ivy

There is a newer version: 465
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.sql.planner.sanity;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.trino.Session;
import io.trino.execution.warnings.WarningCollector;
import io.trino.sql.PlannerContext;
import io.trino.sql.ir.Expression;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.plan.AdaptivePlanNode;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.AggregationNode.Aggregation;
import io.trino.sql.planner.plan.ApplyNode;
import io.trino.sql.planner.plan.AssignUniqueId;
import io.trino.sql.planner.plan.CorrelatedJoinNode;
import io.trino.sql.planner.plan.DistinctLimitNode;
import io.trino.sql.planner.plan.DynamicFilterSourceNode;
import io.trino.sql.planner.plan.EnforceSingleRowNode;
import io.trino.sql.planner.plan.ExceptNode;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.ExplainAnalyzeNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.GroupIdNode;
import io.trino.sql.planner.plan.IndexJoinNode;
import io.trino.sql.planner.plan.IndexSourceNode;
import io.trino.sql.planner.plan.IntersectNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.LimitNode;
import io.trino.sql.planner.plan.MarkDistinctNode;
import io.trino.sql.planner.plan.MergeProcessorNode;
import io.trino.sql.planner.plan.MergeWriterNode;
import io.trino.sql.planner.plan.OffsetNode;
import io.trino.sql.planner.plan.OutputNode;
import io.trino.sql.planner.plan.PatternRecognitionNode;
import io.trino.sql.planner.plan.PatternRecognitionNode.Measure;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanVisitor;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.RefreshMaterializedViewNode;
import io.trino.sql.planner.plan.RemoteSourceNode;
import io.trino.sql.planner.plan.RowNumberNode;
import io.trino.sql.planner.plan.SampleNode;
import io.trino.sql.planner.plan.SemiJoinNode;
import io.trino.sql.planner.plan.SetOperationNode;
import io.trino.sql.planner.plan.SimpleTableExecuteNode;
import io.trino.sql.planner.plan.SortNode;
import io.trino.sql.planner.plan.SpatialJoinNode;
import io.trino.sql.planner.plan.StatisticAggregationsDescriptor;
import io.trino.sql.planner.plan.StatisticsWriterNode;
import io.trino.sql.planner.plan.TableDeleteNode;
import io.trino.sql.planner.plan.TableExecuteNode;
import io.trino.sql.planner.plan.TableFinishNode;
import io.trino.sql.planner.plan.TableFunctionNode;
import io.trino.sql.planner.plan.TableFunctionNode.PassThroughColumn;
import io.trino.sql.planner.plan.TableFunctionNode.PassThroughSpecification;
import io.trino.sql.planner.plan.TableFunctionProcessorNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.plan.TableUpdateNode;
import io.trino.sql.planner.plan.TableWriterNode;
import io.trino.sql.planner.plan.TopNNode;
import io.trino.sql.planner.plan.TopNRankingNode;
import io.trino.sql.planner.plan.UnionNode;
import io.trino.sql.planner.plan.UnnestNode;
import io.trino.sql.planner.plan.ValuesNode;
import io.trino.sql.planner.plan.WindowNode;
import io.trino.sql.planner.rowpattern.ExpressionAndValuePointers;

import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.trino.sql.planner.SymbolsExtractor.extractUnique;
import static io.trino.sql.planner.optimizations.IndexJoinOptimizer.IndexKeyTracer;

/**
 * Ensures that all dependencies (i.e., symbols in expressions) for a plan node are provided by its source nodes
 */
public final class ValidateDependenciesChecker
        implements PlanSanityChecker.Checker
{
    /**
     * {@link PlanSanityChecker.Checker} entry point. Only {@code plan} is inspected; the session,
     * planner context and warning collector are part of the overridden signature but are not used.
     */
    @Override
    public void validate(PlanNode plan, Session session, PlannerContext plannerContext, WarningCollector warningCollector)
    {
        // the dependency check needs nothing beyond the plan tree itself
        validate(plan);
    }

    /**
     * Walks {@code plan} with the dependency-checking visitor. Any violation is reported as an
     * exception thrown by the visitor's checks.
     */
    public static void validate(PlanNode plan)
    {
        Visitor visitor = new Visitor();
        // the root has no enclosing scope, so validation starts with an empty set of bound symbols
        plan.accept(visitor, ImmutableSet.of());
    }

    private static class Visitor
            extends PlanVisitor>
    {
        /**
         * Fallback for node types that have no dedicated visit method. Such nodes are rejected
         * rather than silently accepted.
         */
        @Override
        protected Void visitPlan(PlanNode node, Set boundSymbols)
        {
            String nodeClassName = node.getClass().getName();
            throw new UnsupportedOperationException("not yet implemented: " + nodeClassName);
        }

        /**
         * Validates only the plan returned by {@code getCurrentPlan()}; no node-level check is done here.
         */
        @Override
        public Void visitAdaptivePlanNode(AdaptivePlanNode node, Set boundSymbols)
        {
            node.getCurrentPlan().accept(this, boundSymbols);
            return null;
        }

        /**
         * No node-level check; only the wrapped source subtree is validated.
         */
        @Override
        public Void visitExplainAnalyze(ExplainAnalyzeNode node, Set boundSymbols)
        {
            node.getSource().accept(this, boundSymbols);
            return null;
        }

        /**
         * Checks that grouping keys and every symbol extracted from each aggregation are available
         * from the source (the set built by createInputs from the source and the bound symbols).
         */
        @Override
        public Void visitAggregation(AggregationNode node, Set boundSymbols)
        {
            // validate the subtree first so problems deeper in the plan surface first
            PlanNode child = node.getSource();
            child.accept(this, boundSymbols);

            Set available = createInputs(child, boundSymbols);
            List childOutputs = child.getOutputSymbols();

            checkDependencies(available, node.getGroupingKeys(), "Invalid node. Grouping key symbols (%s) not in source plan output (%s)", node.getGroupingKeys(), childOutputs);

            for (Aggregation candidate : node.getAggregations().values()) {
                Set referenced = extractUnique(candidate);
                checkDependencies(available, referenced, "Invalid node. Aggregation dependencies (%s) not in source plan output (%s)", referenced, childOutputs);
            }
            return null;
        }

        /**
         * Checks that the node's input symbols are produced by its source. Only the source's own
         * outputs are accepted here; the bound symbols are not consulted.
         */
        @Override
        public Void visitGroupId(GroupIdNode node, Set boundSymbols)
        {
            PlanNode child = node.getSource();
            child.accept(this, boundSymbols);

            List childOutputs = child.getOutputSymbols();
            checkDependencies(childOutputs, node.getInputSymbols(), "Invalid node. Grouping symbols (%s) not in source plan output (%s)", node.getInputSymbols(), childOutputs);
            return null;
        }

        /**
         * Checks that the distinct symbols are produced by the source. Only the source's own
         * outputs are accepted here; the bound symbols are not consulted.
         */
        @Override
        public Void visitMarkDistinct(MarkDistinctNode node, Set boundSymbols)
        {
            PlanNode child = node.getSource();
            child.accept(this, boundSymbols);

            List childOutputs = child.getOutputSymbols();
            checkDependencies(childOutputs, node.getDistinctSymbols(), "Invalid node. Mark distinct symbols (%s) not in source plan output (%s)", node.getDistinctSymbols(), childOutputs);
            return null;
        }

        /**
         * Validates a pattern recognition node. The following must all be available from the source
         * (the set built by createInputs from the source and the bound symbols):
         * partition-by and order-by symbols, the end value of the common base frame,
         * window function dependencies, and the input symbols of measures and variable definitions.
         */
        @Override
        public Void visitPatternRecognition(PatternRecognitionNode node, Set boundSymbols)
        {
            PlanNode source = node.getSource();
            source.accept(this, boundSymbols); // visit child

            Set inputs = createInputs(source, boundSymbols);

            checkDependencies(inputs, node.getPartitionBy(), "Invalid node. Partition by symbols (%s) not in source plan output (%s)", node.getPartitionBy(), node.getSource().getOutputSymbols());
            if (node.getOrderingScheme().isPresent()) {
                checkDependencies(
                        inputs,
                        node.getOrderingScheme().get().orderBy(),
                        "Invalid node. Order by symbols (%s) not in source plan output (%s)",
                        node.getOrderingScheme().get().orderBy(), node.getSource().getOutputSymbols());
            }

            // Only the end value of the common base frame is checked, and it is checked exactly once.
            node.getCommonBaseFrame()
                    .flatMap(WindowNode.Frame::getEndValue)
                    .ifPresent(value -> checkDependencies(inputs, ImmutableList.of(value), "Invalid node. Frame end symbol (%s) not in source plan output (%s)", value, node.getSource().getOutputSymbols()));

            for (WindowNode.Function function : node.getWindowFunctions().values()) {
                Set dependencies = extractUnique(function);
                checkDependencies(inputs, dependencies, "Invalid node. Window function dependencies (%s) not in source plan output (%s)", dependencies, node.getSource().getOutputSymbols());
            }

            Set measuresSymbols = node.getMeasures().values().stream()
                    .map(Measure::getExpressionAndValuePointers)
                    .map(ExpressionAndValuePointers::getInputSymbols)
                    .flatMap(Collection::stream)
                    .collect(toImmutableSet());
            checkDependencies(inputs, measuresSymbols, "Invalid node. Symbols used in measures (%s) not in source plan output (%s)", measuresSymbols, node.getSource().getOutputSymbols());

            Set variableDefinitionsSymbols = node.getVariableDefinitions().values().stream()
                    .map(ExpressionAndValuePointers::getInputSymbols)
                    .flatMap(Collection::stream)
                    .collect(toImmutableSet());
            // the message names variable definitions, which is what this check covers
            checkDependencies(inputs, variableDefinitionsSymbols, "Invalid node. Symbols used in variable definitions (%s) not in source plan output (%s)", variableDefinitionsSymbols, node.getSource().getOutputSymbols());

            return null;
        }

        /**
         * Validates each table argument source of a table function. Sources and table argument
         * properties are paired by position. Required columns, partitioning, ordering and
         * pass-through symbols must be available from the corresponding source.
         */
        @Override
        public Void visitTableFunction(TableFunctionNode node, Set boundSymbols)
        {
            int sourceCount = node.getSources().size();
            for (int index = 0; index < sourceCount; index++) {
                PlanNode tableSource = node.getSources().get(index);
                tableSource.accept(this, boundSymbols);

                Set available = createInputs(tableSource, boundSymbols);
                TableFunctionNode.TableArgumentProperties properties = node.getTableArgumentProperties().get(index);
                List tableSourceOutputs = tableSource.getOutputSymbols();

                checkDependencies(
                        available,
                        properties.requiredColumns(),
                        "Invalid node. Required input symbols from source %s (%s) not in source plan output (%s)",
                        properties.argumentName(),
                        properties.requiredColumns(),
                        tableSourceOutputs);

                properties.specification().ifPresent(specification -> checkDependencies(
                        available,
                        specification.partitionBy(),
                        "Invalid node. Partition by symbols for source %s (%s) not in source plan output (%s)",
                        properties.argumentName(),
                        specification.partitionBy(),
                        tableSourceOutputs));
                properties.specification()
                        .flatMap(specification -> specification.orderingScheme())
                        .ifPresent(orderingScheme -> checkDependencies(
                                available,
                                orderingScheme.orderBy(),
                                "Invalid node. Order by symbols for source %s (%s) not in source plan output (%s)",
                                properties.argumentName(),
                                orderingScheme.orderBy(),
                                tableSourceOutputs));

                Set passThrough = properties.passThroughSpecification().columns().stream()
                        .map(PassThroughColumn::symbol)
                        .collect(toImmutableSet());
                checkDependencies(
                        available,
                        passThrough,
                        "Invalid node. Pass-through symbols for source %s (%s) not in source plan output (%s)",
                        properties.argumentName(),
                        passThrough,
                        tableSourceOutputs);
            }
            return null;
        }

        /**
         * Validates a table function processor against its optional source. A processor without a
         * source is accepted as-is. Otherwise, pass-through, required, marker, partition-by and
         * order-by symbols must be available from the source.
         */
        @Override
        public Void visitTableFunctionProcessor(TableFunctionProcessorNode node, Set boundSymbols)
        {
            PlanNode child = node.getSource().orElse(null);
            if (child == null) {
                return null;
            }
            child.accept(this, boundSymbols);

            Set available = createInputs(child, boundSymbols);
            List childOutputs = child.getOutputSymbols();

            Set passThrough = node.getPassThroughSpecifications().stream()
                    .map(PassThroughSpecification::columns)
                    .flatMap(Collection::stream)
                    .map(PassThroughColumn::symbol)
                    .collect(toImmutableSet());
            checkDependencies(available, passThrough, "Invalid node. Pass-through symbols (%s) not in source plan output (%s)", passThrough, childOutputs);

            Set required = node.getRequiredSymbols().stream()
                    .flatMap(Collection::stream)
                    .collect(toImmutableSet());
            checkDependencies(available, required, "Invalid node. Required symbols (%s) not in source plan output (%s)", required, childOutputs);

            // both the keys and the values of the marker mapping must be available
            node.getMarkerSymbols().ifPresent(markers -> {
                checkDependencies(available, markers.keySet(), "Invalid node. Source symbols (%s) not in source plan output (%s)", markers.keySet(), childOutputs);
                checkDependencies(available, markers.values(), "Invalid node. Source marker symbols (%s) not in source plan output (%s)", markers.values(), childOutputs);
            });

            node.getSpecification().ifPresent(specification -> checkDependencies(
                    available,
                    specification.partitionBy(),
                    "Invalid node. Partition by symbols (%s) not in source plan output (%s)",
                    specification.partitionBy(),
                    childOutputs));
            node.getSpecification()
                    .flatMap(specification -> specification.orderingScheme())
                    .ifPresent(ordering -> checkDependencies(
                            available,
                            ordering.orderBy(),
                            "Invalid node. Order by symbols (%s) not in source plan output (%s)",
                            ordering.orderBy(),
                            childOutputs));

            return null;
        }

        /**
         * Validates a window node. The following must be available from the source:
         * partitioning and ordering symbols, explicit frame bound symbols, the coerced sort keys
         * used for frame bound comparison, and window function dependencies.
         */
        @Override
        public Void visitWindow(WindowNode node, Set boundSymbols)
        {
            PlanNode child = node.getSource();
            child.accept(this, boundSymbols);

            Set available = createInputs(child, boundSymbols);
            List childOutputs = child.getOutputSymbols();

            checkDependencies(available, node.getPartitionBy(), "Invalid node. Partition by symbols (%s) not in source plan output (%s)", node.getPartitionBy(), childOutputs);
            node.getOrderingScheme().ifPresent(orderingScheme -> checkDependencies(
                    available,
                    orderingScheme.orderBy(),
                    "Invalid node. Order by symbols (%s) not in source plan output (%s)",
                    orderingScheme.orderBy(),
                    childOutputs));

            // collect, frame by frame, start/end bound symbols and the coerced sort keys compared against them
            ImmutableList.Builder frameBoundBuilder = ImmutableList.builder();
            ImmutableList.Builder comparisonBuilder = ImmutableList.builder();
            for (WindowNode.Frame frame : node.getFrames()) {
                frame.getStartValue().ifPresent(frameBoundBuilder::add);
                frame.getEndValue().ifPresent(frameBoundBuilder::add);
                frame.getSortKeyCoercedForFrameStartComparison().ifPresent(comparisonBuilder::add);
                frame.getSortKeyCoercedForFrameEndComparison().ifPresent(comparisonBuilder::add);
            }

            List frameBounds = frameBoundBuilder.build();
            checkDependencies(available, frameBounds, "Invalid node. Frame bounds (%s) not in source plan output (%s)", frameBounds, childOutputs);

            List comparisonSymbols = comparisonBuilder.build();
            checkDependencies(available, comparisonSymbols, "Invalid node. Symbols for frame bound comparison (%s) not in source plan output (%s)", comparisonSymbols, childOutputs);

            for (WindowNode.Function windowFunction : node.getWindowFunctions().values()) {
                Set referenced = extractUnique(windowFunction);
                checkDependencies(available, referenced, "Invalid node. Window function dependencies (%s) not in source plan output (%s)", referenced, childOutputs);
            }
            return null;
        }

        /**
         * Checks that partition-by and order-by symbols are available from the source.
         */
        @Override
        public Void visitTopNRanking(TopNRankingNode node, Set boundSymbols)
        {
            PlanNode child = node.getSource();
            child.accept(this, boundSymbols);

            Set available = createInputs(child, boundSymbols);
            List childOutputs = child.getOutputSymbols();
            checkDependencies(available, node.getPartitionBy(), "Invalid node. Partition by symbols (%s) not in source plan output (%s)", node.getPartitionBy(), childOutputs);

            List orderBy = node.getOrderingScheme().orderBy();
            checkDependencies(available, orderBy, "Invalid node. Order by symbols (%s) not in source plan output (%s)", orderBy, childOutputs);
            return null;
        }

        /**
         * Checks that partition-by symbols are produced by the source. Only the source's own
         * outputs are accepted here; the bound symbols are not consulted.
         */
        @Override
        public Void visitRowNumber(RowNumberNode node, Set boundSymbols)
        {
            PlanNode child = node.getSource();
            child.accept(this, boundSymbols);

            List childOutputs = child.getOutputSymbols();
            checkDependencies(childOutputs, node.getPartitionBy(), "Invalid node. Partition by symbols (%s) not in source plan output (%s)", node.getPartitionBy(), childOutputs);
            return null;
        }

        /**
         * Checks that the filter's output symbols and the symbols referenced by its predicate are
         * available from the source.
         */
        @Override
        public Void visitFilter(FilterNode node, Set boundSymbols)
        {
            PlanNode child = node.getSource();
            child.accept(this, boundSymbols);

            Set available = createInputs(child, boundSymbols);
            List childOutputs = child.getOutputSymbols();
            checkDependencies(available, node.getOutputSymbols(), "Invalid node. Output symbols (%s) not in source plan output (%s)", node.getOutputSymbols(), childOutputs);

            Set predicateSymbols = extractUnique(node.getPredicate());
            checkDependencies(available, predicateSymbols, "Invalid node. Predicate dependencies (%s) not in source plan output (%s)", predicateSymbols, childOutputs);
            return null;
        }

        /**
         * No node-level check; only the source subtree is validated.
         */
        @Override
        public Void visitSample(SampleNode node, Set boundSymbols)
        {
            node.getSource().accept(this, boundSymbols);
            return null;
        }

        /**
         * Checks that every assignment expression references only symbols in the set built by
         * createInputs from the source and the bound symbols.
         */
        @Override
        public Void visitProject(ProjectNode node, Set boundSymbols)
        {
            PlanNode child = node.getSource();
            child.accept(this, boundSymbols);

            Set available = createInputs(child, boundSymbols);
            for (Expression assignedExpression : node.getAssignments().getExpressions()) {
                Set referenced = extractUnique(assignedExpression);
                // the message reports the whole available set, not just the child's outputs
                checkDependencies(available, referenced, "Invalid node. Expression dependencies (%s) not in source plan output (%s)", referenced, available);
            }
            return null;
        }

        /**
         * Checks that output symbols and order-by symbols are available from the source.
         */
        @Override
        public Void visitTopN(TopNNode node, Set boundSymbols)
        {
            PlanNode child = node.getSource();
            child.accept(this, boundSymbols);

            Set available = createInputs(child, boundSymbols);
            List childOutputs = child.getOutputSymbols();
            checkDependencies(available, node.getOutputSymbols(), "Invalid node. Output symbols (%s) not in source plan output (%s)", node.getOutputSymbols(), childOutputs);

            List orderBy = node.getOrderingScheme().orderBy();
            checkDependencies(available, orderBy, "Invalid node. Order by dependencies (%s) not in source plan output (%s)", orderBy, childOutputs);
            return null;
        }

        /**
         * Checks that output symbols and order-by symbols are available from the source.
         */
        @Override
        public Void visitSort(SortNode node, Set boundSymbols)
        {
            PlanNode child = node.getSource();
            child.accept(this, boundSymbols);

            Set available = createInputs(child, boundSymbols);
            List childOutputs = child.getOutputSymbols();
            checkDependencies(available, node.getOutputSymbols(), "Invalid node. Output symbols (%s) not in source plan output (%s)", node.getOutputSymbols(), childOutputs);

            List sortKeys = node.getOrderingScheme().orderBy();
            checkDependencies(available, sortKeys, "Invalid node. Order by dependencies (%s) not in source plan output (%s)", sortKeys, childOutputs);
            return null;
        }

        /**
         * Checks that every output column is produced by the source. Only the source's own outputs
         * are accepted here; the bound symbols are not consulted.
         */
        @Override
        public Void visitOutput(OutputNode node, Set boundSymbols)
        {
            PlanNode child = node.getSource();
            child.accept(this, boundSymbols);

            List childOutputs = child.getOutputSymbols();
            checkDependencies(childOutputs, node.getOutputSymbols(), "Invalid node. Output column dependencies (%s) not in source plan output (%s)", node.getOutputSymbols(), childOutputs);
            return null;
        }

        /**
         * No node-level check; only the source subtree is validated.
         */
        @Override
        public Void visitOffset(OffsetNode node, Set boundSymbols)
        {
            node.getSource().accept(this, boundSymbols);
            return null;
        }

        /**
         * Validates a limit node.
         * <ul>
         * <li>The optional ties-resolving ordering is checked against the set built by createInputs
         * (source plus bound symbols).</li>
         * <li>Pre-sorted inputs are checked against the source's own outputs only.</li>
         * </ul>
         */
        @Override
        public Void visitLimit(LimitNode node, Set boundSymbols)
        {
            PlanNode child = node.getSource();
            child.accept(this, boundSymbols);

            node.getTiesResolvingScheme().ifPresent(scheme -> checkDependencies(
                    createInputs(child, boundSymbols),
                    scheme.orderBy(),
                    "Invalid node. Ties resolving dependencies (%s) not in source plan output (%s)",
                    scheme.orderBy(),
                    child.getOutputSymbols()));

            List childOutputs = child.getOutputSymbols();
            checkDependencies(
                    childOutputs,
                    node.getPreSortedInputs(),
                    "Invalid node. Pre-sorted input column dependencies (%s) not in source plan output (%s)",
                    node.getPreSortedInputs(),
                    childOutputs);
            return null;
        }

        /**
         * Checks that every output symbol is produced by the source. Only the source's own outputs
         * are accepted here; the bound symbols are not consulted.
         */
        @Override
        public Void visitDistinctLimit(DistinctLimitNode node, Set boundSymbols)
        {
            PlanNode child = node.getSource();
            child.accept(this, boundSymbols);

            List childOutputs = child.getOutputSymbols();
            checkDependencies(childOutputs, node.getOutputSymbols(), "Invalid node. Output column dependencies (%s) not in source plan output (%s)", node.getOutputSymbols(), childOutputs);
            return null;
        }

        /**
         * Validates a join.
         * <ul>
         * <li>Each equi-join clause must take its left symbol from the left side and its right
         * symbol from the right side.</li>
         * <li>The optional filter may reference symbols from either side.</li>
         * <li>A cross join must output every symbol of both sides.</li>
         * </ul>
         */
        @Override
        public Void visitJoin(JoinNode node, Set boundSymbols)
        {
            PlanNode left = node.getLeft();
            PlanNode right = node.getRight();
            left.accept(this, boundSymbols);
            right.accept(this, boundSymbols);

            Set leftAvailable = createInputs(left, boundSymbols);
            Set rightAvailable = createInputs(right, boundSymbols);
            Set available = ImmutableSet.builder()
                    .addAll(leftAvailable)
                    .addAll(rightAvailable)
                    .build();

            for (JoinNode.EquiJoinClause criterion : node.getCriteria()) {
                checkArgument(leftAvailable.contains(criterion.getLeft()), "Symbol from join clause (%s) not in left source (%s)", criterion.getLeft(), left.getOutputSymbols());
                checkArgument(rightAvailable.contains(criterion.getRight()), "Symbol from join clause (%s) not in right source (%s)", criterion.getRight(), right.getOutputSymbols());
            }

            if (node.getFilter().isPresent()) {
                Set filterSymbols = extractUnique(node.getFilter().get());
                checkArgument(
                        available.containsAll(filterSymbols),
                        "Symbol from filter (%s) not in sources (%s)",
                        filterSymbols,
                        available);
            }

            if (node.isCrossJoin()) {
                Set sideOutputs = ImmutableSet.builder()
                        .addAll(left.getOutputSymbols())
                        .addAll(right.getOutputSymbols())
                        .build();
                checkDependencies(node.getOutputSymbols(), sideOutputs, "Cross join output symbols (%s) must contain all of the source symbols (%s)", node.getOutputSymbols(), sideOutputs);
            }
            return null;
        }

        /**
         * Validates a semi join.
         * <ul>
         * <li>Each join symbol must be produced by its own side.</li>
         * <li>The set built by createInputs from this node and the bound symbols must contain every
         * source symbol and the semi join output symbol.</li>
         * </ul>
         */
        @Override
        public Void visitSemiJoin(SemiJoinNode node, Set boundSymbols)
        {
            node.getSource().accept(this, boundSymbols);
            node.getFilteringSource().accept(this, boundSymbols);

            checkArgument(node.getSource().getOutputSymbols().contains(node.getSourceJoinSymbol()), "Symbol from semi join clause (%s) not in source (%s)", node.getSourceJoinSymbol(), node.getSource().getOutputSymbols());
            // report the filtering-side join symbol, since that is the one being checked
            checkArgument(node.getFilteringSource().getOutputSymbols().contains(node.getFilteringSourceJoinSymbol()), "Symbol from semi join clause (%s) not in filtering source (%s)", node.getFilteringSourceJoinSymbol(), node.getFilteringSource().getOutputSymbols());

            Set outputs = createInputs(node, boundSymbols);
            checkArgument(outputs.containsAll(node.getSource().getOutputSymbols()), "Semi join output symbols (%s) must contain all of the source symbols (%s)", node.getOutputSymbols(), node.getSource().getOutputSymbols());
            checkArgument(outputs.contains(node.getSemiJoinOutput()),
                    "Semi join output symbols (%s) must contain join result (%s)",
                    node.getOutputSymbols(),
                    node.getSemiJoinOutput());

            return null;
        }

        /**
         * Validates a spatial join. The filter may reference symbols from either side, and the
         * output must list all left-side symbols before any right-side symbol.
         */
        @Override
        public Void visitSpatialJoin(SpatialJoinNode node, Set boundSymbols)
        {
            PlanNode left = node.getLeft();
            PlanNode right = node.getRight();
            left.accept(this, boundSymbols);
            right.accept(this, boundSymbols);

            Set available = ImmutableSet.builder()
                    .addAll(createInputs(left, boundSymbols))
                    .addAll(createInputs(right, boundSymbols))
                    .build();

            Set filterSymbols = extractUnique(node.getFilter());
            checkArgument(
                    available.containsAll(filterSymbols),
                    "Symbol from filter (%s) not in sources (%s)",
                    filterSymbols,
                    available);

            checkLeftOutputSymbolsBeforeRight(left.getOutputSymbols(), node.getOutputSymbols());
            return null;
        }

        /**
         * Verifies that no symbol from {@code leftSymbols} appears in {@code outputSymbols} after the
         * first symbol that is not a left symbol. In other words, all left-side outputs precede all
         * other outputs.
         *
         * @throws IllegalStateException if a left symbol follows a non-left symbol
         */
        private void checkLeftOutputSymbolsBeforeRight(List<Symbol> leftSymbols, List<Symbol> outputSymbols)
        {
            int leftMaxPosition = -1;
            // typed so that get() yields an Integer comparable with leftMaxPosition
            Optional<Integer> rightMinPosition = Optional.empty();
            Set<Symbol> leftSymbolsSet = new HashSet<>(leftSymbols);
            for (int i = 0; i < outputSymbols.size(); i++) {
                Symbol symbol = outputSymbols.get(i);
                if (leftSymbolsSet.contains(symbol)) {
                    leftMaxPosition = i;
                }
                else if (rightMinPosition.isEmpty()) {
                    // only the first non-left position matters
                    rightMinPosition = Optional.of(i);
                }
            }
            checkState(rightMinPosition.isEmpty() || rightMinPosition.get() > leftMaxPosition, "Not all left output symbols are before right output symbols");
        }

        /**
         * Validates an index join. Each clause must take its probe symbol from the probe side and
         * its index symbol from the index side. The index-side lookup symbols must be traceable
         * through the index source, i.e. IndexKeyTracer must return a non-empty mapping.
         */
        @Override
        public Void visitIndexJoin(IndexJoinNode node, Set boundSymbols)
        {
            PlanNode probe = node.getProbeSource();
            PlanNode index = node.getIndexSource();
            probe.accept(this, boundSymbols);
            index.accept(this, boundSymbols);

            Set probeAvailable = createInputs(probe, boundSymbols);
            Set indexAvailable = createInputs(index, boundSymbols);
            for (IndexJoinNode.EquiJoinClause criterion : node.getCriteria()) {
                checkArgument(probeAvailable.contains(criterion.getProbe()), "Probe symbol from index join clause (%s) not in probe source (%s)", criterion.getProbe(), probe.getOutputSymbols());
                checkArgument(indexAvailable.contains(criterion.getIndex()), "Index symbol from index join clause (%s) not in index source (%s)", criterion.getIndex(), index.getOutputSymbols());
            }

            Set lookupSymbols = node.getCriteria().stream()
                    .map(IndexJoinNode.EquiJoinClause::getIndex)
                    .collect(toImmutableSet());
            Map traced = IndexKeyTracer.trace(index, lookupSymbols);
            checkArgument(!traced.isEmpty(), "Index lookup symbols are not traceable to index source: %s", lookupSymbols);
            return null;
        }

        /**
         * Leaf check for an index source.
         * <ul>
         * <li>Lookup symbols must be among the outputs.</li>
         * <li>Every output must have an assignment.</li>
         * </ul>
         */
        @Override
        public Void visitIndexSource(IndexSourceNode node, Set boundSymbols)
        {
            checkDependencies(
                    node.getOutputSymbols(),
                    node.getLookupSymbols(),
                    "Lookup symbols must be part of output symbols");
            checkDependencies(
                    node.getAssignments().keySet(),
                    node.getOutputSymbols(),
                    "Assignments must contain mappings for output symbols");
            return null;
        }

        @Override
        public Void visitDynamicFilterSource(DynamicFilterSourceNode node, Set boundSymbols)
        {
            node.getSource().accept(this, boundSymbols);

            // The symbols the dynamic filters are built from (the map's values) must be output by this node.
            Collection filterSymbols = node.getDynamicFilters().values();
            checkDependencies(node.getOutputSymbols(), filterSymbols, "Dynamic filter symbols must be part of output symbols");
            return null;
        }

        @Override
        public Void visitTableScan(TableScanNode node, Set boundSymbols)
        {
            // A table scan reads no symbols from other nodes, so there are no dependencies to check.
            return null;
        }

        @Override
        public Void visitValues(ValuesNode node, Set boundSymbols)
        {
            // VALUES has no child, so every symbol its expressions reference (as collected by
            // extractUnique) must be bound by an enclosing correlated scope.
            Set referenced = extractUnique(node);
            checkDependencies(boundSymbols, referenced, "Invalid node. Expression correlated dependencies (%s) not satisfied by (%s)", referenced, boundSymbols);
            return null;
        }

        @Override
        public Void visitUnnest(UnnestNode node, Set boundSymbols)
        {
            PlanNode source = node.getSource();
            source.accept(this, boundSymbols); // visit child

            // Both the replicated symbols and the inputs of every unnest mapping must be produced
            // by the source. Build the set once and pass the built set (not its builder) to the
            // error message. A builder has no meaningful toString, so passing it would print an
            // object identity instead of the missing dependencies.
            Set required = ImmutableSet.builder()
                    .addAll(node.getReplicateSymbols())
                    .addAll(node.getMappings().stream()
                            .map(UnnestNode.Mapping::getInput)
                            .toList())
                    .build();

            checkDependencies(source.getOutputSymbols(), required, "Invalid node. Dependencies (%s) not in source plan output (%s)", required, source.getOutputSymbols());

            return null;
        }

        @Override
        public Void visitRemoteSource(RemoteSourceNode node, Set boundSymbols)
        {
            // Nothing is validated for a remote source here; no child plan is visited.
            return null;
        }

        @Override
        public Void visitExchange(ExchangeNode node, Set boundSymbols)
        {
            // Each source is paired positionally with its input layout; check it, then recurse into it.
            int sourceIndex = 0;
            for (PlanNode subplan : node.getSources()) {
                checkDependencies(subplan.getOutputSymbols(), node.getInputs().get(sourceIndex), "EXCHANGE subplan must provide all of the necessary symbols");
                subplan.accept(this, boundSymbols); // visit child
                sourceIndex++;
            }

            // The partitioning scheme may only lay out symbols this exchange actually outputs.
            checkDependencies(node.getOutputSymbols(), node.getPartitioningScheme().getOutputLayout(), "EXCHANGE must provide all of the necessary symbols for partition function");
            return null;
        }

        @Override
        public Void visitRefreshMaterializedView(RefreshMaterializedViewNode node, Set boundSymbols)
        {
            // No symbol dependencies are checked for this node.
            return null;
        }

        @Override
        public Void visitTableWriter(TableWriterNode node, Set boundSymbols)
        {
            // Only the child plan is validated; the writer itself adds no checks here.
            node.getSource().accept(this, boundSymbols);
            return null;
        }

        @Override
        public Void visitTableExecute(TableExecuteNode node, Set boundSymbols)
        {
            // Delegate validation to the child plan; no node-level checks.
            node.getSource().accept(this, boundSymbols);
            return null;
        }

        @Override
        public Void visitMergeWriter(MergeWriterNode node, Set boundSymbols)
        {
            // Delegate validation to the child plan; no node-level checks.
            node.getSource().accept(this, boundSymbols);
            return null;
        }

        @Override
        public Void visitMergeProcessor(MergeProcessorNode node, Set boundSymbols)
        {
            PlanNode source = node.getSource();
            source.accept(this, boundSymbols); // visit child

            // The row-id and merge-row symbols consumed by this node must come from its source.
            List sourceOutputs = source.getOutputSymbols();
            checkArgument(sourceOutputs.contains(node.getRowIdSymbol()), "Invalid node. rowId symbol (%s) is not in source plan output (%s)", node.getRowIdSymbol(), sourceOutputs);
            checkArgument(sourceOutputs.contains(node.getMergeRowSymbol()), "Invalid node. Merge row symbol (%s) is not in source plan output (%s)", node.getMergeRowSymbol(), sourceOutputs);

            return null;
        }

        @Override
        public Void visitSimpleTableExecuteNode(SimpleTableExecuteNode node, Set boundSymbols)
        {
            // No symbol dependencies are checked for this node.
            return null;
        }

        @Override
        public Void visitTableDelete(TableDeleteNode node, Set boundSymbols)
        {
            // No symbol dependencies are checked for this node.
            return null;
        }

        @Override
        public Void visitTableUpdate(TableUpdateNode node, Set boundSymbols)
        {
            // No symbol dependencies are checked for this node.
            return null;
        }

        @Override
        public Void visitStatisticsWriterNode(StatisticsWriterNode node, Set boundSymbols)
        {
            node.getSource().accept(this, boundSymbols); // visit child

            // Every symbol referenced by the descriptor (grouping, column statistics and table
            // statistics) must be produced by the source plan.
            StatisticAggregationsDescriptor descriptor = node.getDescriptor();
            ImmutableSet.Builder dependencyBuilder = ImmutableSet.builder();
            dependencyBuilder.addAll(descriptor.getGrouping().values());
            dependencyBuilder.addAll(descriptor.getColumnStatistics().values());
            dependencyBuilder.addAll(descriptor.getTableStatistics().values());
            Set dependencies = dependencyBuilder.build();

            List sourceOutputs = node.getSource().getOutputSymbols();
            checkDependencies(sourceOutputs, dependencies, "Invalid node. Dependencies (%s) not in source plan output (%s)", dependencies, sourceOutputs);
            return null;
        }

        @Override
        public Void visitTableFinish(TableFinishNode node, Set boundSymbols)
        {
            // Only the child plan is validated; no node-level checks.
            node.getSource().accept(this, boundSymbols);
            return null;
        }

        @Override
        public Void visitUnion(UnionNode node, Set boundSymbols)
        {
            // UNION shares the generic set-operation validation.
            return visitSetOperation(node, boundSymbols);
        }

        // Shared validation for UNION / INTERSECT / EXCEPT: each source must supply the symbols
        // listed in its positional output layout, and each source is then validated recursively.
        private Void visitSetOperation(SetOperationNode node, Set boundSymbols)
        {
            String nodeKind = node.getClass().getSimpleName();
            int sourceIndex = 0;
            for (PlanNode subplan : node.getSources()) {
                checkDependencies(subplan.getOutputSymbols(), node.sourceOutputLayout(sourceIndex), "%s subplan must provide all of the necessary symbols", nodeKind);
                subplan.accept(this, boundSymbols); // visit child
                sourceIndex++;
            }
            return null;
        }

        @Override
        public Void visitIntersect(IntersectNode node, Set boundSymbols)
        {
            // INTERSECT shares the generic set-operation validation.
            return visitSetOperation(node, boundSymbols);
        }

        @Override
        public Void visitExcept(ExceptNode node, Set boundSymbols)
        {
            // EXCEPT shares the generic set-operation validation.
            return visitSetOperation(node, boundSymbols);
        }

        @Override
        public Void visitEnforceSingleRow(EnforceSingleRowNode node, Set boundSymbols)
        {
            // Only the child plan is validated; no node-level checks.
            node.getSource().accept(this, boundSymbols);
            return null;
        }

        @Override
        public Void visitAssignUniqueId(AssignUniqueId node, Set boundSymbols)
        {
            // Only the child plan is validated; no node-level checks.
            node.getSource().accept(this, boundSymbols);
            return null;
        }

        @Override
        public Void visitApply(ApplyNode node, Set boundSymbols)
        {
            // Inside the subquery, the correlation symbols are bound in addition to the outer bindings.
            Set subqueryBoundSymbols = ImmutableSet.builder()
                    .addAll(boundSymbols)
                    .addAll(node.getCorrelation())
                    .build();

            node.getInput().accept(this, boundSymbols); // visit child
            node.getSubquery().accept(this, subqueryBoundSymbols); // visit child

            checkDependencies(node.getInput().getOutputSymbols(), node.getCorrelation(), "APPLY input must provide all the necessary correlation symbols for subquery");

            // Symbols the subquery assignments may reference: outputs of both sides plus outer bindings.
            Set available = ImmutableSet.builder()
                    .addAll(createInputs(node.getSubquery(), boundSymbols))
                    .addAll(createInputs(node.getInput(), boundSymbols))
                    .build();

            // IN and quantified comparisons reference a value and a reference symbol; EXISTS references none.
            List referenced = node.getSubqueryAssignments().values().stream()
                    .flatMap(setExpression -> switch (setExpression) {
                        case ApplyNode.In in -> Stream.of(in.value(), in.reference());
                        case ApplyNode.QuantifiedComparison comparison -> Stream.of(comparison.value(), comparison.reference());
                        case ApplyNode.Exists unused -> Stream.empty();
                    })
                    .toList();

            checkDependencies(available, referenced, "Invalid node. Expression dependencies (%s) not in source plan output (%s)", referenced, available);
            return null;
        }

        @Override
        public Void visitCorrelatedJoin(CorrelatedJoinNode node, Set boundSymbols)
        {
            // Inside the subquery, the correlation symbols are bound in addition to the outer bindings.
            Set subqueryBoundSymbols = ImmutableSet.builder()
                    .addAll(boundSymbols)
                    .addAll(node.getCorrelation())
                    .build();

            node.getInput().accept(this, boundSymbols); // visit child
            node.getSubquery().accept(this, subqueryBoundSymbols); // visit child

            checkDependencies(node.getInput().getOutputSymbols(), node.getCorrelation(), "Correlated JOIN input must provide all the necessary correlation symbols for subquery");

            // The join filter may reference outputs of either side as well as outer bindings.
            Set available = ImmutableSet.builder()
                    .addAll(createInputs(node.getInput(), boundSymbols))
                    .addAll(createInputs(node.getSubquery(), boundSymbols))
                    .build();
            Set filterSymbols = extractUnique(node.getFilter());
            checkDependencies(available, filterSymbols, "filter symbols (%s) not in sources (%s)", filterSymbols, available);

            return null;
        }

        // Returns the symbols visible to a consumer of `source`: its outputs followed by the outer bindings.
        private static ImmutableSet createInputs(PlanNode source, Set boundSymbols)
        {
            ImmutableSet.Builder visible = ImmutableSet.builder();
            visible.addAll(source.getOutputSymbols());
            visible.addAll(boundSymbols);
            return visible.build();
        }
    }

    // Fails with IllegalArgumentException (formatted from message/parameters) unless every
    // element of `required` is present in `inputs`.
    private static void checkDependencies(Collection inputs, Collection required, String message, Object... parameters)
    {
        // Copy into a set so containsAll uses set lookups regardless of the input collection type.
        Set available = ImmutableSet.copyOf(inputs);
        checkArgument(available.containsAll(required), message, parameters);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy