All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.trino.sql.planner.iterative.rule.ImplementIntersectAll Maven / Gradle / Ivy

There is a newer version: 465
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.metadata.ResolvedFunction;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.IntersectNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionCall;

import static com.google.common.base.Preconditions.checkState;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes;
import static io.trino.sql.planner.plan.Patterns.Intersect.distinct;
import static io.trino.sql.planner.plan.Patterns.intersect;
import static io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL;
import static java.util.Objects.requireNonNull;

/**
 * Implement INTERSECT ALL using union, window and filter.
 * 

* Transforms: *

 * - Intersect all
 *   output: a, b
 *     - Source1 (a1, b1)
 *     - Source2 (a2, b2)
 *     - Source3 (a3, b3)
 * 
* Into: *
 * - Project (prune helper symbols)
 *   output: a, b
 *     - Filter (row_number <= least(least(count1, count2), count3))
 *         - Window (partition by a, b)
 *           count1 <- count(marker1)
 *           count2 <- count(marker2)
 *           count3 <- count(marker3)
 *           row_number <- row_number()
 *               - Union
 *                 output: a, b, marker1, marker2, marker3
 *                   - Project (marker1 <- true, marker2 <- null, marker3 <- null)
 *                       - Source1 (a1, b1)
 *                   - Project (marker1 <- null, marker2 <- true, marker3 <- null)
 *                       - Source2 (a2, b2)
 *                   - Project (marker1 <- null, marker2 <- null, marker3 <- true)
 *                       - Source3 (a3, b3)
 * 
*/ public class ImplementIntersectAll implements Rule { private static final Pattern PATTERN = intersect() .with(distinct().equalTo(false)); private final Metadata metadata; public ImplementIntersectAll(Metadata metadata) { this.metadata = requireNonNull(metadata, "metadata is null"); } @Override public Pattern getPattern() { return PATTERN; } @Override public Result apply(IntersectNode node, Captures captures, Context context) { SetOperationNodeTranslator translator = new SetOperationNodeTranslator(context.getSession(), metadata, context.getSymbolAllocator(), context.getIdAllocator()); SetOperationNodeTranslator.TranslationResult result = translator.makeSetContainmentPlanForAll(node); // compute expected multiplicity for every row checkState(result.getCountSymbols().size() > 0, "IntersectNode translation result has no count symbols"); ResolvedFunction least = metadata.resolveBuiltinFunction("least", fromTypes(BIGINT, BIGINT)); Expression minCount = result.getCountSymbols().get(0).toSymbolReference(); for (int i = 1; i < result.getCountSymbols().size(); i++) { minCount = new FunctionCall(least.toQualifiedName(), ImmutableList.of(minCount, result.getCountSymbols().get(i).toSymbolReference())); } // filter rows so that expected number of rows remains Expression removeExtraRows = new ComparisonExpression(LESS_THAN_OR_EQUAL, result.getRowNumberSymbol().toSymbolReference(), minCount); FilterNode filter = new FilterNode( context.getIdAllocator().getNextId(), result.getPlanNode(), removeExtraRows); // prune helper symbols ProjectNode project = new ProjectNode( context.getIdAllocator().getNextId(), filter, Assignments.identity(node.getOutputSymbols())); return Result.ofPlanNode(project); } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy