All downloads are free. Search and download functionality uses the official Maven repository.

io.trino.sql.planner.iterative.rule.ExtractCommonPredicatesExpressionRewriter Maven / Gradle / Ivy

There is a newer version: 468
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import io.trino.metadata.Metadata;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.ExpressionRewriter;
import io.trino.sql.tree.ExpressionTreeRewriter;
import io.trino.sql.tree.LogicalExpression;

import java.util.Collection;
import java.util.List;
import java.util.Set;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.sql.ExpressionUtils.combinePredicates;
import static io.trino.sql.ExpressionUtils.extractPredicates;
import static io.trino.sql.planner.DeterminismEvaluator.isDeterministic;
import static io.trino.sql.tree.LogicalExpression.Operator.OR;
import static java.util.Collections.emptySet;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toSet;

public final class ExtractCommonPredicatesExpressionRewriter
{
    public static Expression extractCommonPredicates(Metadata metadata, Expression expression)
    {
        return ExpressionTreeRewriter.rewriteWith(new Visitor(metadata), expression, NodeContext.ROOT_NODE);
    }

    private ExtractCommonPredicatesExpressionRewriter() {}

    private static class Visitor
            extends ExpressionRewriter
    {
        private final Metadata metadata;

        public Visitor(Metadata metadata)
        {
            this.metadata = requireNonNull(metadata, "metadata is null");
        }

        @Override
        protected Expression rewriteExpression(Expression node, NodeContext context, ExpressionTreeRewriter treeRewriter)
        {
            if (context.isRootNode()) {
                return treeRewriter.rewrite(node, NodeContext.NOT_ROOT_NODE);
            }

            return null;
        }

        @Override
        public Expression rewriteLogicalExpression(LogicalExpression node, NodeContext context, ExpressionTreeRewriter treeRewriter)
        {
            Expression expression = combinePredicates(
                    metadata,
                    node.getOperator(),
                    extractPredicates(node.getOperator(), node).stream()
                            .map(subExpression -> treeRewriter.rewrite(subExpression, NodeContext.NOT_ROOT_NODE))
                            .collect(toImmutableList()));

            if (!(expression instanceof LogicalExpression)) {
                return expression;
            }

            Expression simplified = extractCommonPredicates((LogicalExpression) expression);

            // Prefer AND LogicalBinaryExpression at the root if possible
            if (context.isRootNode() && simplified instanceof LogicalExpression && ((LogicalExpression) simplified).getOperator() == OR) {
                return distributeIfPossible((LogicalExpression) simplified);
            }

            return simplified;
        }

        private Expression extractCommonPredicates(LogicalExpression node)
        {
            List> subPredicates = getSubPredicates(node);

            Set commonPredicates = ImmutableSet.copyOf(subPredicates.stream()
                    .map(this::filterDeterministicPredicates)
                    .reduce(Sets::intersection)
                    .orElse(emptySet()));

            List> uncorrelatedSubPredicates = subPredicates.stream()
                    .map(predicateList -> removeAll(predicateList, commonPredicates))
                    .collect(toImmutableList());

            LogicalExpression.Operator flippedOperator = node.getOperator().flip();

            List uncorrelatedPredicates = uncorrelatedSubPredicates.stream()
                    .map(predicate -> combinePredicates(metadata, flippedOperator, predicate))
                    .collect(toImmutableList());
            Expression combinedUncorrelatedPredicates = combinePredicates(metadata, node.getOperator(), uncorrelatedPredicates);

            return combinePredicates(metadata, flippedOperator, ImmutableList.builder()
                    .addAll(commonPredicates)
                    .add(combinedUncorrelatedPredicates)
                    .build());
        }

        private static List> getSubPredicates(LogicalExpression expression)
        {
            return extractPredicates(expression.getOperator(), expression).stream()
                    .map(predicate -> predicate instanceof LogicalExpression ?
                            extractPredicates((LogicalExpression) predicate) : ImmutableList.of(predicate))
                    .collect(toImmutableList());
        }

        /**
         * Applies the boolean distributive property.
         * 

* For example: * ( A & B ) | ( C & D ) => ( A | C ) & ( A | D ) & ( B | C ) & ( B | D) *

* Returns the original expression if the expression is non-deterministic or if the distribution will * expand the expression by too much. */ private Expression distributeIfPossible(LogicalExpression expression) { if (!isDeterministic(expression, metadata)) { // Do not distribute boolean expressions if there are any non-deterministic elements // TODO: This can be optimized further if non-deterministic elements are not repeated return expression; } List> subPredicates = getSubPredicates(expression).stream() .map(ImmutableSet::copyOf) .collect(toList()); int originalBaseExpressions = subPredicates.stream() .mapToInt(Set::size) .sum(); int newBaseExpressions; try { newBaseExpressions = Math.multiplyExact(subPredicates.stream() .mapToInt(Set::size) .reduce(Math::multiplyExact) .getAsInt(), subPredicates.size()); } catch (ArithmeticException e) { // Integer overflow from multiplication means there are too many expressions return expression; } if (newBaseExpressions > originalBaseExpressions * 2) { // Do not distribute boolean expressions if it would create 2x more base expressions // (e.g. A, B, C, D from the above example). This is just an arbitrary heuristic to // avoid cross product expression explosion. return expression; } Set> crossProduct = Sets.cartesianProduct(subPredicates); return combinePredicates( metadata, expression.getOperator().flip(), crossProduct.stream() .map(expressions -> combinePredicates(metadata, expression.getOperator(), expressions)) .collect(toImmutableList())); } private Set filterDeterministicPredicates(List predicates) { return predicates.stream() .filter(expression -> isDeterministic(expression, metadata)) .collect(toSet()); } private static List removeAll(Collection collection, Collection elementsToRemove) { return collection.stream() .filter(element -> !elementsToRemove.contains(element)) .collect(toImmutableList()); } } private enum NodeContext { ROOT_NODE, NOT_ROOT_NODE; boolean isRootNode() { return this == ROOT_NODE; } } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy