All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.trino.sql.planner.iterative.rule.SetOperationNodeTranslator Maven / Gradle / Ivy

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.Session;
import io.trino.metadata.Metadata;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.type.Type;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.DataOrganizationSpecification;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.SetOperationNode;
import io.trino.sql.planner.plan.UnionNode;
import io.trino.sql.planner.plan.WindowNode;

import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.Iterables.concat;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes;
import static io.trino.sql.ir.Booleans.TRUE;
import static io.trino.sql.planner.plan.AggregationNode.singleAggregation;
import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet;
import static io.trino.sql.planner.plan.FrameBoundType.UNBOUNDED_FOLLOWING;
import static io.trino.sql.planner.plan.FrameBoundType.UNBOUNDED_PRECEDING;
import static io.trino.sql.planner.plan.WindowFrameType.ROWS;
import static java.util.Objects.requireNonNull;

/**
 * Rewrites a non-union {@link SetOperationNode} into a plan that tells, for every
 * distinct combination of output values, how many rows each source contributed.
 * <p>
 * Each source is wrapped in a projection that appends one boolean marker column per
 * source ({@code TRUE} in the column belonging to that source, a null constant in the
 * others). The projected sources are unioned and {@code count(marker)} is computed for
 * each marker, either with an aggregation ({@link #makeSetContainmentPlanForDistinct})
 * or with a window ({@link #makeSetContainmentPlanForAll}).
 */
public class SetOperationNodeTranslator
{
    private static final String MARKER = "marker";
    private final SymbolAllocator symbolAllocator;
    private final PlanNodeIdAllocator idAllocator;
    private final ResolvedFunction countFunction;
    private final ResolvedFunction rowNumberFunction;

    /**
     * Resolves the built-in {@code count(boolean)} and {@code row_number()} functions
     * up front so that the translation methods can reuse them.
     *
     * @param session currently unused by this class; kept for caller compatibility
     * @param metadata used to resolve the built-in functions; must not be null
     * @param symbolAllocator source of fresh symbols; must not be null
     * @param idAllocator source of fresh plan node ids; must not be null
     */
    public SetOperationNodeTranslator(Session session, Metadata metadata, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator)
    {
        this.symbolAllocator = requireNonNull(symbolAllocator, "SymbolAllocator is null");
        this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
        requireNonNull(metadata, "metadata is null");
        this.countFunction = metadata.resolveBuiltinFunction("count", fromTypes(BOOLEAN));
        this.rowNumberFunction = metadata.resolveBuiltinFunction("row_number", ImmutableList.of());
    }

    /**
     * Builds {@code Aggregation(Union(Project(source_i)...))}, grouping by the node's
     * output symbols and counting each per-source marker.
     *
     * @param node the set operation to translate; must not be a {@link UnionNode}
     * @return the aggregation node together with one count symbol per source, in source order
     * @throws IllegalArgumentException if {@code node} is a {@link UnionNode}
     */
    public TranslationResult makeSetContainmentPlanForDistinct(SetOperationNode node)
    {
        checkArgument(!(node instanceof UnionNode), "Cannot simplify a UnionNode");
        List<Symbol> markers = allocateSymbols(node.getSources().size(), MARKER, BOOLEAN);
        // identity projection for all the fields in each of the sources plus marker columns
        List<PlanNode> withMarkers = appendMarkers(markers, node.getSources(), node);

        // add a union over all the rewritten sources. The outputs of the union have the same name as the
        // original intersect node
        List<Symbol> outputs = node.getOutputSymbols();
        UnionNode union = union(withMarkers, ImmutableList.copyOf(concat(outputs, markers)));

        // add count aggregations
        List<Symbol> aggregationOutputs = allocateSymbols(markers.size(), "count", BIGINT);
        AggregationNode aggregation = computeCounts(union, outputs, markers, aggregationOutputs);

        return new TranslationResult(aggregation, aggregationOutputs);
    }

    /**
     * Builds {@code Project(Window(Union(Project(source_i)...)))}. The window is organized
     * by the node's output symbols and computes a count per marker plus a row number; the
     * final projection keeps the outputs, the counts and the row number, dropping the markers.
     *
     * @param node the set operation to translate; must not be a {@link UnionNode}
     * @return the projection together with the count symbols (in source order) and the row number symbol
     * @throws IllegalArgumentException if {@code node} is a {@link UnionNode}
     */
    public TranslationResult makeSetContainmentPlanForAll(SetOperationNode node)
    {
        checkArgument(!(node instanceof UnionNode), "Cannot simplify a UnionNode");
        List<Symbol> markers = allocateSymbols(node.getSources().size(), MARKER, BOOLEAN);
        // identity projection for all the fields in each of the sources plus marker columns
        List<PlanNode> withMarkers = appendMarkers(markers, node.getSources(), node);

        // add a union over all the rewritten sources
        List<Symbol> outputs = node.getOutputSymbols();
        UnionNode union = union(withMarkers, ImmutableList.copyOf(concat(outputs, markers)));

        // add counts and row number
        List<Symbol> countOutputs = allocateSymbols(markers.size(), "count", BIGINT);
        Symbol rowNumberSymbol = symbolAllocator.newSymbol("row_number", BIGINT);
        WindowNode window = appendCounts(union, outputs, markers, countOutputs, rowNumberSymbol);

        // prune markers
        ProjectNode project = new ProjectNode(
                idAllocator.getNextId(),
                window,
                Assignments.identity(ImmutableList.copyOf(concat(outputs, countOutputs, ImmutableList.of(rowNumberSymbol)))));

        return new TranslationResult(project, countOutputs, Optional.of(rowNumberSymbol));
    }

    /**
     * Allocates {@code count} fresh symbols that share the same name hint and type.
     */
    private List<Symbol> allocateSymbols(int count, String nameHint, Type type)
    {
        ImmutableList.Builder<Symbol> symbolsBuilder = ImmutableList.builder();
        for (int i = 0; i < count; i++) {
            symbolsBuilder.add(symbolAllocator.newSymbol(nameHint, type));
        }
        return symbolsBuilder.build();
    }

    /**
     * Wraps every source in a marker-appending projection; source {@code i} uses
     * {@code node.sourceSymbolMap(i)} as its column mapping and owns marker {@code i}.
     */
    private List<PlanNode> appendMarkers(List<Symbol> markers, List<PlanNode> nodes, SetOperationNode node)
    {
        ImmutableList.Builder<PlanNode> result = ImmutableList.builder();
        for (int i = 0; i < nodes.size(); i++) {
            result.add(appendMarkers(idAllocator, symbolAllocator, nodes.get(i), i, markers, node.sourceSymbolMap(i)));
        }
        return result.build();
    }

    /**
     * Projects {@code source} onto freshly allocated symbols: one per entry of
     * {@code projections} (named and typed after the entry key, assigned the entry value),
     * followed by one boolean column per marker. Only the marker at {@code markerIndex}
     * is {@code TRUE}; every other marker is a null boolean constant.
     */
    private static PlanNode appendMarkers(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, PlanNode source, int markerIndex, List<Symbol> markers, Map<Symbol, ? extends Expression> projections)
    {
        Assignments.Builder assignments = Assignments.builder();
        // add existing intersect symbols to projection
        for (Map.Entry<Symbol, ? extends Expression> entry : projections.entrySet()) {
            Symbol symbol = symbolAllocator.newSymbol(entry.getKey().name(), entry.getKey().type());
            assignments.put(symbol, entry.getValue());
        }

        // add extra marker fields to the projection
        for (int i = 0; i < markers.size(); ++i) {
            Expression expression = (i == markerIndex) ? TRUE : new Constant(BOOLEAN, null);
            assignments.put(symbolAllocator.newSymbol(markers.get(i).name(), BOOLEAN), expression);
        }

        return new ProjectNode(idAllocator.getNextId(), source, assignments.build());
    }

    /**
     * Creates a union over {@code nodes}, pairing the i-th output symbol of every source
     * with {@code outputs.get(i)}. The pairing is purely positional, so each source is
     * expected to expose its columns in the same order as {@code outputs}.
     */
    private UnionNode union(List<PlanNode> nodes, List<Symbol> outputs)
    {
        ImmutableListMultimap.Builder<Symbol, Symbol> outputsToInputs = ImmutableListMultimap.builder();
        for (PlanNode source : nodes) {
            List<Symbol> sourceOutputs = source.getOutputSymbols();
            for (int i = 0; i < sourceOutputs.size(); i++) {
                outputsToInputs.put(outputs.get(i), sourceOutputs.get(i));
            }
        }

        return new UnionNode(idAllocator.getNextId(), nodes, outputsToInputs.build(), outputs);
    }

    /**
     * Adds a single aggregation over {@code sourceNode}, grouped by {@code originalColumns},
     * that stores {@code count(markers[i])} in {@code aggregationOutputs[i]}.
     */
    private AggregationNode computeCounts(UnionNode sourceNode, List<Symbol> originalColumns, List<Symbol> markers, List<Symbol> aggregationOutputs)
    {
        ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();

        for (int i = 0; i < markers.size(); i++) {
            Symbol output = aggregationOutputs.get(i);
            aggregations.put(output, new AggregationNode.Aggregation(
                    countFunction,
                    ImmutableList.of(markers.get(i).toSymbolReference()),
                    false,
                    Optional.empty(),
                    Optional.empty(),
                    Optional.empty()));
        }

        return singleAggregation(idAllocator.getNextId(),
                sourceNode,
                aggregations.buildOrThrow(),
                singleGroupingSet(originalColumns));
    }

    /**
     * Adds a window over {@code sourceNode}, organized by {@code originalColumns} with no
     * ordering supplied, that stores {@code count(markers[i])} in {@code countOutputs[i]}
     * and a row number in {@code rowNumberSymbol}. All functions share one frame:
     * ROWS from UNBOUNDED PRECEDING to UNBOUNDED FOLLOWING.
     */
    private WindowNode appendCounts(UnionNode sourceNode, List<Symbol> originalColumns, List<Symbol> markers, List<Symbol> countOutputs, Symbol rowNumberSymbol)
    {
        ImmutableMap.Builder<Symbol, WindowNode.Function> functions = ImmutableMap.builder();
        WindowNode.Frame defaultFrame = new WindowNode.Frame(ROWS, UNBOUNDED_PRECEDING, Optional.empty(), Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty());

        for (int i = 0; i < markers.size(); i++) {
            Symbol output = countOutputs.get(i);
            functions.put(output, new WindowNode.Function(
                    countFunction,
                    ImmutableList.of(markers.get(i).toSymbolReference()),
                    defaultFrame,
                    false));
        }

        functions.put(rowNumberSymbol, new WindowNode.Function(
                rowNumberFunction,
                ImmutableList.of(),
                defaultFrame,
                false));

        return new WindowNode(
                idAllocator.getNextId(),
                sourceNode,
                new DataOrganizationSpecification(originalColumns, Optional.empty()),
                functions.buildOrThrow(),
                Optional.empty(),
                ImmutableSet.of(),
                0);
    }

    /**
     * Outcome of a translation: the root of the generated plan, the per-source count
     * symbols and, only for the "ALL" variant, the row number symbol.
     */
    public static class TranslationResult
    {
        private final PlanNode planNode;
        private final List<Symbol> countSymbols;
        private final Optional<Symbol> rowNumberSymbol;

        /**
         * Creates a result that carries no row number symbol.
         */
        public TranslationResult(PlanNode planNode, List<Symbol> countSymbols)
        {
            this(planNode, countSymbols, Optional.empty());
        }

        /**
         * Creates a result; {@code countSymbols} is copied into an immutable list.
         *
         * @throws NullPointerException if any argument is null
         */
        public TranslationResult(PlanNode planNode, List<Symbol> countSymbols, Optional<Symbol> rowNumberSymbol)
        {
            this.planNode = requireNonNull(planNode, "planNode is null");
            this.countSymbols = ImmutableList.copyOf(requireNonNull(countSymbols, "countSymbols is null"));
            this.rowNumberSymbol = requireNonNull(rowNumberSymbol, "rowNumberSymbol is null");
        }

        /**
         * @return the root node of the translated plan
         */
        public PlanNode getPlanNode()
        {
            return this.planNode;
        }

        /**
         * @return an immutable list with one count symbol per source
         */
        public List<Symbol> getCountSymbols()
        {
            return countSymbols;
        }

        /**
         * @return the row number symbol
         * @throws IllegalStateException if this result was created without a row number symbol
         */
        public Symbol getRowNumberSymbol()
        {
            checkState(rowNumberSymbol.isPresent(), "rowNumberSymbol is empty");
            return rowNumberSymbol.get();
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy