All downloads are free. Search and download functionality uses the official Maven repository.

io.trino.sql.planner.iterative.rule.ExtractCommonPredicatesExpressionRewriter Maven / Gradle / Ivy

There is a newer version: 465
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.ExpressionRewriter;
import io.trino.sql.ir.ExpressionTreeRewriter;
import io.trino.sql.ir.Logical;
import io.trino.sql.planner.DeterminismEvaluator;

import java.util.Collection;
import java.util.List;
import java.util.Set;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.sql.ir.IrUtils.combinePredicates;
import static io.trino.sql.ir.IrUtils.extractPredicates;
import static io.trino.sql.ir.Logical.Operator.OR;
import static io.trino.sql.planner.DeterminismEvaluator.isDeterministic;
import static java.util.Collections.emptySet;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toSet;

public final class ExtractCommonPredicatesExpressionRewriter
{
    public static Expression extractCommonPredicates(Expression expression)
    {
        return ExpressionTreeRewriter.rewriteWith(new Visitor(), expression, NodeContext.ROOT_NODE);
    }

    private ExtractCommonPredicatesExpressionRewriter() {}

    private static class Visitor
            extends ExpressionRewriter
    {
        @Override
        protected Expression rewriteExpression(Expression node, NodeContext context, ExpressionTreeRewriter treeRewriter)
        {
            if (context.isRootNode()) {
                return treeRewriter.rewrite(node, NodeContext.NOT_ROOT_NODE);
            }

            return null;
        }

        @Override
        public Expression rewriteLogical(Logical node, NodeContext context, ExpressionTreeRewriter treeRewriter)
        {
            Expression expression = combinePredicates(
                    node.operator(),
                    extractPredicates(node.operator(), node).stream()
                            .map(subExpression -> treeRewriter.rewrite(subExpression, NodeContext.NOT_ROOT_NODE))
                            .collect(toImmutableList()));

            if (!(expression instanceof Logical)) {
                return expression;
            }

            Expression simplified = extractCommonPredicates((Logical) expression);

            // Prefer AND LogicalBinaryExpression at the root if possible
            if (context.isRootNode() && simplified instanceof Logical && ((Logical) simplified).operator() == OR) {
                return distributeIfPossible((Logical) simplified);
            }

            return simplified;
        }

        private Expression extractCommonPredicates(Logical node)
        {
            List> subPredicates = getSubPredicates(node);

            Set commonPredicates = ImmutableSet.copyOf(subPredicates.stream()
                    .map(this::filterDeterministicPredicates)
                    .reduce(Sets::intersection)
                    .orElse(emptySet()));

            List> uncorrelatedSubPredicates = subPredicates.stream()
                    .map(predicateList -> removeAll(predicateList, commonPredicates))
                    .collect(toImmutableList());

            Logical.Operator flippedOperator = node.operator().flip();

            List uncorrelatedPredicates = uncorrelatedSubPredicates.stream()
                    .map(predicate -> combinePredicates(flippedOperator, predicate))
                    .collect(toImmutableList());
            Expression combinedUncorrelatedPredicates = combinePredicates(node.operator(), uncorrelatedPredicates);

            return combinePredicates(flippedOperator, ImmutableList.builder()
                    .addAll(commonPredicates)
                    .add(combinedUncorrelatedPredicates)
                    .build());
        }

        private static List> getSubPredicates(Logical expression)
        {
            return extractPredicates(expression.operator(), expression).stream()
                    .map(predicate -> predicate instanceof Logical ?
                            extractPredicates((Logical) predicate) : ImmutableList.of(predicate))
                    .collect(toImmutableList());
        }

        /**
         * Applies the boolean distributive property.
         * 

* For example: * ( A & B ) | ( C & D ) => ( A | C ) & ( A | D ) & ( B | C ) & ( B | D) *

* Returns the original expression if the expression is non-deterministic or if the distribution will * expand the expression by too much. */ private Expression distributeIfPossible(Logical expression) { if (!isDeterministic(expression)) { // Do not distribute boolean expressions if there are any non-deterministic elements // TODO: This can be optimized further if non-deterministic elements are not repeated return expression; } List> subPredicates = getSubPredicates(expression).stream() .map(ImmutableSet::copyOf) .collect(toList()); int originalBaseExpressions = subPredicates.stream() .mapToInt(Set::size) .sum(); int newBaseExpressions; try { newBaseExpressions = Math.multiplyExact(subPredicates.stream() .mapToInt(Set::size) .reduce(Math::multiplyExact) .getAsInt(), subPredicates.size()); } catch (ArithmeticException e) { // Integer overflow from multiplication means there are too many expressions return expression; } if (newBaseExpressions > originalBaseExpressions * 2) { // Do not distribute boolean expressions if it would create 2x more base expressions // (e.g. A, B, C, D from the above example). This is just an arbitrary heuristic to // avoid cross product expression explosion. return expression; } Set> crossProduct = Sets.cartesianProduct(subPredicates); return combinePredicates( expression.operator().flip(), crossProduct.stream() .map(expressions -> combinePredicates(expression.operator(), expressions)) .collect(toImmutableList())); } private Set filterDeterministicPredicates(List predicates) { return predicates.stream() .filter(DeterminismEvaluator::isDeterministic) .collect(toSet()); } private static List removeAll(Collection collection, Collection elementsToRemove) { return collection.stream() .filter(element -> !elementsToRemove.contains(element)) .collect(toImmutableList()); } } private enum NodeContext { ROOT_NODE, NOT_ROOT_NODE; boolean isRootNode() { return this == ROOT_NODE; } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy