All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.trino.sql.planner.PlanCopier Maven / Gradle / Ivy

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.sql.planner;

import io.trino.sql.planner.iterative.GroupReference;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.optimizations.UnaliasSymbolReferences;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.ApplyNode;
import io.trino.sql.planner.plan.CorrelatedJoinNode;
import io.trino.sql.planner.plan.DynamicFilterSourceNode;
import io.trino.sql.planner.plan.EnforceSingleRowNode;
import io.trino.sql.planner.plan.ExceptNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.GroupIdNode;
import io.trino.sql.planner.plan.IntersectNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.LimitNode;
import io.trino.sql.planner.plan.OffsetNode;
import io.trino.sql.planner.plan.PatternRecognitionNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.SampleNode;
import io.trino.sql.planner.plan.SimplePlanRewriter;
import io.trino.sql.planner.plan.SortNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.plan.TopNNode;
import io.trino.sql.planner.plan.UnionNode;
import io.trino.sql.planner.plan.UnnestNode;
import io.trino.sql.planner.plan.ValuesNode;
import io.trino.sql.planner.plan.WindowNode;

import java.util.List;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;

/**
 * Clones a plan and assigns new PlanNodeIds to the copied PlanNodes.
 * Also replaces all symbols in the copied plan with new symbols.
 * The original and copied plans can be safely used in different
 * branches of a plan.
 */
public final class PlanCopier
{
    private PlanCopier() {}

    /**
     * Copies {@code plan}, using {@link Lookup#noLookup()} to resolve any
     * {@link GroupReference} nodes encountered during the copy.
     *
     * @param plan root of the plan to copy
     * @param fields symbols handed to symbol reallocation together with the copied plan
     * @param symbolAllocator allocator used to create the replacement symbols
     * @param idAllocator allocator used to assign ids to the copied nodes
     * @return the result of {@link UnaliasSymbolReferences#reallocateSymbols} applied to the copied plan
     * @throws UnsupportedOperationException if the plan contains a node type the copier does not handle
     */
    public static NodeAndMappings copyPlan(PlanNode plan, List<Symbol> fields, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator)
    {
        return copyPlan(plan, fields, symbolAllocator, idAllocator, Lookup.noLookup());
    }

    /**
     * Copies {@code plan}, assigning a new id from {@code idAllocator} to every copied node,
     * and then replaces the symbols in the copy with newly allocated ones.
     *
     * @param plan root of the plan to copy
     * @param fields symbols handed to symbol reallocation together with the copied plan
     * @param symbolAllocator allocator used to create the replacement symbols
     * @param idAllocator allocator used to assign ids to the copied nodes
     * @param lookup used to resolve {@link GroupReference} nodes encountered during the copy
     * @return the result of {@link UnaliasSymbolReferences#reallocateSymbols} applied to the copied plan
     * @throws UnsupportedOperationException if the plan contains a node type the copier does not handle
     */
    public static NodeAndMappings copyPlan(PlanNode plan, List<Symbol> fields, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, Lookup lookup)
    {
        // Two phases: first duplicate the node structure (fresh ids, node attributes passed through
        // unchanged), then reallocate every symbol in the duplicated tree.
        PlanNode copy = SimplePlanRewriter.rewriteWith(new Copier(idAllocator, lookup), plan, null);
        return new UnaliasSymbolReferences().reallocateSymbols(copy, fields, symbolAllocator);
    }

    /**
     * Rewriter that re-creates each visited node with a fresh id and recursively copied sources.
     * Node attributes (symbols, expressions, assignments, ...) are reused as-is; symbol
     * replacement happens afterwards in {@link #copyPlan}.
     * Every visitor must copy its sources through {@code context.rewrite}. Otherwise the copy would
     * share subtrees (and node ids) with the original plan.
     */
    private static class Copier
            extends SimplePlanRewriter<Void>
    {
        private final PlanNodeIdAllocator idAllocator;
        private final Lookup lookup;

        private Copier(PlanNodeIdAllocator idAllocator, Lookup lookup)
        {
            this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
            this.lookup = requireNonNull(lookup, "lookup is null");
        }

        /**
         * Copies every source of a (possibly multi-source) node, preserving their order.
         */
        private static List<PlanNode> copySources(PlanNode node, RewriteContext<Void> context)
        {
            return node.getSources().stream()
                    .map(context::rewrite)
                    .collect(toImmutableList());
        }

        // Fallback for node types without a dedicated visitor: refuse rather than produce a partial copy.
        @Override
        protected PlanNode visitPlan(PlanNode node, RewriteContext<Void> context)
        {
            throw new UnsupportedOperationException("plan copying not implemented for " + node.getClass().getSimpleName());
        }

        // Group references are replaced by the copy of the node they resolve to.
        @Override
        public PlanNode visitGroupReference(GroupReference node, RewriteContext<Void> context)
        {
            return context.rewrite(lookup.resolve(node));
        }

        @Override
        public PlanNode visitAggregation(AggregationNode node, RewriteContext<Void> context)
        {
            return AggregationNode.builderFrom(node)
                    .setId(idAllocator.getNextId())
                    .setSource(context.rewrite(node.getSource()))
                    .build();
        }

        @Override
        public PlanNode visitFilter(FilterNode node, RewriteContext<Void> context)
        {
            return new FilterNode(idAllocator.getNextId(), context.rewrite(node.getSource()), node.getPredicate());
        }

        @Override
        public PlanNode visitProject(ProjectNode node, RewriteContext<Void> context)
        {
            return new ProjectNode(idAllocator.getNextId(), context.rewrite(node.getSource()), node.getAssignments());
        }

        @Override
        public PlanNode visitTopN(TopNNode node, RewriteContext<Void> context)
        {
            return new TopNNode(idAllocator.getNextId(), context.rewrite(node.getSource()), node.getCount(), node.getOrderingScheme(), node.getStep());
        }

        @Override
        public PlanNode visitOffset(OffsetNode node, RewriteContext<Void> context)
        {
            return new OffsetNode(idAllocator.getNextId(), context.rewrite(node.getSource()), node.getCount());
        }

        @Override
        public PlanNode visitLimit(LimitNode node, RewriteContext<Void> context)
        {
            return new LimitNode(idAllocator.getNextId(), context.rewrite(node.getSource()), node.getCount(), node.getTiesResolvingScheme(), node.isPartial(), node.getPreSortedInputs());
        }

        @Override
        public PlanNode visitSample(SampleNode node, RewriteContext<Void> context)
        {
            return new SampleNode(idAllocator.getNextId(), context.rewrite(node.getSource()), node.getSampleRatio(), node.getSampleType());
        }

        // Leaf node: only the id changes.
        @Override
        public PlanNode visitTableScan(TableScanNode node, RewriteContext<Void> context)
        {
            return new TableScanNode(
                    idAllocator.getNextId(),
                    node.getTable(),
                    node.getOutputSymbols(),
                    node.getAssignments(),
                    node.getEnforcedConstraint(),
                    node.getStatistics(),
                    node.isUpdateTarget(),
                    node.getUseConnectorNodePartitioning());
        }

        // Leaf node: only the id changes.
        @Override
        public PlanNode visitValues(ValuesNode node, RewriteContext<Void> context)
        {
            return new ValuesNode(idAllocator.getNextId(), node.getOutputSymbols(), node.getRowCount(), node.getRows());
        }

        @Override
        public PlanNode visitJoin(JoinNode node, RewriteContext<Void> context)
        {
            return new JoinNode(
                    idAllocator.getNextId(),
                    node.getType(),
                    context.rewrite(node.getLeft()),
                    context.rewrite(node.getRight()),
                    node.getCriteria(),
                    node.getLeftOutputSymbols(),
                    node.getRightOutputSymbols(),
                    node.isMaySkipOutputDuplicates(),
                    node.getFilter(),
                    node.getLeftHashSymbol(),
                    node.getRightHashSymbol(),
                    node.getDistributionType(),
                    node.isSpillable(),
                    node.getDynamicFilters(),
                    node.getReorderJoinStatsAndCost());
        }

        @Override
        public PlanNode visitDynamicFilterSource(DynamicFilterSourceNode node, RewriteContext<Void> context)
        {
            // The source must be copied like in every other visitor. Reusing node.getSource() directly
            // would make the copy share the original subtree (same node ids, unresolved group references).
            return new DynamicFilterSourceNode(idAllocator.getNextId(), context.rewrite(node.getSource()), node.getDynamicFilters());
        }

        @Override
        public PlanNode visitSort(SortNode node, RewriteContext<Void> context)
        {
            return new SortNode(idAllocator.getNextId(), context.rewrite(node.getSource()), node.getOrderingScheme(), node.isPartial());
        }

        @Override
        public PlanNode visitWindow(WindowNode node, RewriteContext<Void> context)
        {
            return new WindowNode(idAllocator.getNextId(), context.rewrite(node.getSource()), node.getSpecification(), node.getWindowFunctions(), node.getHashSymbol(), node.getPrePartitionedInputs(), node.getPreSortedOrderPrefix());
        }

        @Override
        public PlanNode visitPatternRecognition(PatternRecognitionNode node, RewriteContext<Void> context)
        {
            return new PatternRecognitionNode(
                    idAllocator.getNextId(),
                    context.rewrite(node.getSource()),
                    node.getSpecification(),
                    node.getHashSymbol(),
                    node.getPrePartitionedInputs(),
                    node.getPreSortedOrderPrefix(),
                    node.getWindowFunctions(),
                    node.getMeasures(),
                    node.getCommonBaseFrame(),
                    node.getRowsPerMatch(),
                    node.getSkipToLabels(),
                    node.getSkipToPosition(),
                    node.isInitial(),
                    node.getPattern(),
                    node.getVariableDefinitions());
        }

        @Override
        public PlanNode visitUnion(UnionNode node, RewriteContext<Void> context)
        {
            return new UnionNode(idAllocator.getNextId(), copySources(node, context), node.getSymbolMapping(), node.getOutputSymbols());
        }

        @Override
        public PlanNode visitIntersect(IntersectNode node, RewriteContext<Void> context)
        {
            return new IntersectNode(idAllocator.getNextId(), copySources(node, context), node.getSymbolMapping(), node.getOutputSymbols(), node.isDistinct());
        }

        @Override
        public PlanNode visitExcept(ExceptNode node, RewriteContext<Void> context)
        {
            return new ExceptNode(idAllocator.getNextId(), copySources(node, context), node.getSymbolMapping(), node.getOutputSymbols(), node.isDistinct());
        }

        @Override
        public PlanNode visitUnnest(UnnestNode node, RewriteContext<Void> context)
        {
            return new UnnestNode(idAllocator.getNextId(), context.rewrite(node.getSource()), node.getReplicateSymbols(), node.getMappings(), node.getOrdinalitySymbol(), node.getJoinType());
        }

        @Override
        public PlanNode visitGroupId(GroupIdNode node, RewriteContext<Void> context)
        {
            return new GroupIdNode(idAllocator.getNextId(), context.rewrite(node.getSource()), node.getGroupingSets(), node.getGroupingColumns(), node.getAggregationArguments(), node.getGroupIdSymbol());
        }

        @Override
        public PlanNode visitEnforceSingleRow(EnforceSingleRowNode node, RewriteContext<Void> context)
        {
            return new EnforceSingleRowNode(idAllocator.getNextId(), context.rewrite(node.getSource()));
        }

        @Override
        public PlanNode visitApply(ApplyNode node, RewriteContext<Void> context)
        {
            return new ApplyNode(idAllocator.getNextId(), context.rewrite(node.getInput()), context.rewrite(node.getSubquery()), node.getSubqueryAssignments(), node.getCorrelation(), node.getOriginSubquery());
        }

        @Override
        public PlanNode visitCorrelatedJoin(CorrelatedJoinNode node, RewriteContext<Void> context)
        {
            return new CorrelatedJoinNode(idAllocator.getNextId(), context.rewrite(node.getInput()), context.rewrite(node.getSubquery()), node.getCorrelation(), node.getType(), node.getFilter(), node.getOriginSubquery());
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy