All downloads are free. Search and download functionality uses the official Maven repository.

io.trino.sql.ir.IrUtils Maven / Gradle / Ivy

There is a newer version: 465
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.sql.ir;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.graph.SuccessorsFunction;
import com.google.common.graph.Traverser;
import io.trino.spi.type.Type;
import io.trino.sql.planner.DeterminismEvaluator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;

import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Stream;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Predicates.not;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Streams.stream;
import static io.trino.sql.ir.Booleans.FALSE;
import static io.trino.sql.ir.Booleans.TRUE;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toList;

/**
 * Static helpers for building, flattening and combining boolean IR
 * {@link Expression}s, in particular {@link Logical} AND/OR expressions.
 */
public final class IrUtils
{
    private IrUtils() {}

    /**
     * Checks that {@code expression.type()} equals {@code expected}.
     *
     * @throws IllegalArgumentException if the types differ
     */
    static void validateType(Type expected, Expression expression)
    {
        checkArgument(expected.equals(expression.type()), "Expected '%s' type but found '%s' for expression: %s", expected, expression.type(), expression);
    }

    /**
     * Flattens nested {@code AND} expressions into the list of their terms.
     * An expression that is not an {@code AND} is returned as a single-element list.
     */
    public static List<Expression> extractConjuncts(Expression expression)
    {
        return extractPredicates(Logical.Operator.AND, expression);
    }

    /**
     * Flattens nested {@code OR} expressions into the list of their terms.
     * An expression that is not an {@code OR} is returned as a single-element list.
     */
    public static List<Expression> extractDisjuncts(Expression expression)
    {
        return extractPredicates(Logical.Operator.OR, expression);
    }

    /**
     * Flattens {@code expression} using its own operator.
     */
    public static List<Expression> extractPredicates(Logical expression)
    {
        return extractPredicates(expression.operator(), expression);
    }

    /**
     * Flattens {@code expression} with respect to {@code operator}: terms of
     * (possibly nested) {@link Logical} expressions using that operator are
     * collected in traversal order; every other expression is kept as-is.
     *
     * @return an immutable list of the collected terms
     */
    public static List<Expression> extractPredicates(Logical.Operator operator, Expression expression)
    {
        ImmutableList.Builder<Expression> resultBuilder = ImmutableList.builder();
        extractPredicates(operator, expression, resultBuilder);
        return resultBuilder.build();
    }

    /**
     * Recursive worker for {@link #extractPredicates(Logical.Operator, Expression)}.
     */
    private static void extractPredicates(Logical.Operator operator, Expression expression, ImmutableList.Builder<Expression> resultBuilder)
    {
        if (expression instanceof Logical logical && logical.operator() == operator) {
            // Same operator: descend so that nested terms end up in one flat list
            for (Expression term : logical.terms()) {
                extractPredicates(operator, term, resultBuilder);
            }
        }
        else {
            resultBuilder.add(expression);
        }
    }

    /**
     * Varargs form of {@link #and(Collection)}.
     */
    public static Expression and(Expression... expressions)
    {
        return and(Arrays.asList(expressions));
    }

    /**
     * Builds an {@code AND} of the given expressions; see
     * {@link #logicalExpression(Logical.Operator, Collection)} for the empty
     * and single-element cases.
     */
    public static Expression and(Collection<Expression> expressions)
    {
        return logicalExpression(Logical.Operator.AND, expressions);
    }

    /**
     * Varargs form of {@link #or(Collection)}.
     */
    public static Expression or(Expression... expressions)
    {
        return or(Arrays.asList(expressions));
    }

    /**
     * Builds an {@code OR} of the given expressions; see
     * {@link #logicalExpression(Logical.Operator, Collection)} for the empty
     * and single-element cases.
     */
    public static Expression or(Collection<Expression> expressions)
    {
        return logicalExpression(Logical.Operator.OR, expressions);
    }

    /**
     * Combines {@code expressions} with {@code operator} without any simplification.
     *
     * @return {@code TRUE} (for AND) or {@code FALSE} (for OR) when the collection is empty,
     * the sole element when it has exactly one, otherwise a new {@link Logical}
     * @throws NullPointerException if {@code operator} or {@code expressions} is null
     */
    public static Expression logicalExpression(Logical.Operator operator, Collection<Expression> expressions)
    {
        requireNonNull(operator, "operator is null");
        requireNonNull(expressions, "expressions is null");

        if (expressions.isEmpty()) {
            // The empty case yields the constant that leaves the operator's result unchanged
            return switch (operator) {
                case AND -> TRUE;
                case OR -> FALSE;
            };
        }

        if (expressions.size() == 1) {
            return Iterables.getOnlyElement(expressions);
        }

        return new Logical(operator, ImmutableList.copyOf(expressions));
    }

    /**
     * Delegates to {@link #combineConjuncts(Collection)} when {@code operator}
     * is {@code AND}, and to {@link #combineDisjuncts(Collection)} otherwise.
     */
    public static Expression combinePredicates(Logical.Operator operator, Collection<Expression> expressions)
    {
        if (operator == Logical.Operator.AND) {
            return combineConjuncts(expressions);
        }

        return combineDisjuncts(expressions);
    }

    /**
     * Varargs form of {@link #combineConjuncts(Collection)}.
     */
    public static Expression combineConjuncts(Expression... expressions)
    {
        return combineConjuncts(Arrays.asList(expressions));
    }

    /**
     * Combines the expressions into a simplified conjunction: nested ANDs are
     * flattened, {@code TRUE} terms are dropped, duplicate deterministic terms
     * are removed, and the whole result is {@code FALSE} if any term is {@code FALSE}.
     *
     * @throws NullPointerException if {@code expressions} is null
     */
    public static Expression combineConjuncts(Collection<Expression> expressions)
    {
        requireNonNull(expressions, "expressions is null");

        List<Expression> conjuncts = expressions.stream()
                .flatMap(e -> extractConjuncts(e).stream())
                .filter(e -> !e.equals(TRUE))
                .collect(toList());

        conjuncts = removeDuplicates(conjuncts);

        if (conjuncts.contains(FALSE)) {
            return FALSE;
        }

        return and(conjuncts);
    }

    /**
     * Same as {@link #combineConjuncts(Collection)} but keeps duplicate terms.
     *
     * @throws NullPointerException if {@code expressions} is null
     */
    public static Expression combineConjunctsWithDuplicates(Collection<Expression> expressions)
    {
        requireNonNull(expressions, "expressions is null");

        List<Expression> conjuncts = expressions.stream()
                .flatMap(e -> extractConjuncts(e).stream())
                .filter(e -> !e.equals(TRUE))
                .collect(toList());

        if (conjuncts.contains(FALSE)) {
            return FALSE;
        }

        return and(conjuncts);
    }

    /**
     * Varargs form of {@link #combineDisjuncts(Collection)}.
     */
    public static Expression combineDisjuncts(Expression... expressions)
    {
        return combineDisjuncts(Arrays.asList(expressions));
    }

    /**
     * {@link #combineDisjunctsWithDefault(Collection, Expression)} with
     * {@code FALSE} as the result for an empty disjunction.
     */
    public static Expression combineDisjuncts(Collection<Expression> expressions)
    {
        return combineDisjunctsWithDefault(expressions, FALSE);
    }

    /**
     * Combines the expressions into a simplified disjunction: nested ORs are
     * flattened, {@code FALSE} terms are dropped, duplicate deterministic terms
     * are removed, and the whole result is {@code TRUE} if any term is {@code TRUE}.
     *
     * @param emptyDefault returned when no terms remain after simplification
     * @throws NullPointerException if {@code expressions} is null
     */
    public static Expression combineDisjunctsWithDefault(Collection<Expression> expressions, Expression emptyDefault)
    {
        requireNonNull(expressions, "expressions is null");

        List<Expression> disjuncts = expressions.stream()
                .flatMap(e -> extractDisjuncts(e).stream())
                .filter(e -> !e.equals(FALSE))
                .collect(toList());

        disjuncts = removeDuplicates(disjuncts);

        if (disjuncts.contains(TRUE)) {
            return TRUE;
        }

        return disjuncts.isEmpty() ? emptyDefault : or(disjuncts);
    }

    /**
     * Keeps only the conjuncts of {@code expression} that
     * {@link DeterminismEvaluator#isDeterministic} accepts.
     */
    public static Expression filterDeterministicConjuncts(Expression expression)
    {
        return filterConjuncts(expression, DeterminismEvaluator::isDeterministic);
    }

    /**
     * Keeps only the conjuncts of {@code expression} that
     * {@link DeterminismEvaluator#isDeterministic} rejects.
     */
    public static Expression filterNonDeterministicConjuncts(Expression expression)
    {
        return filterConjuncts(expression, not(DeterminismEvaluator::isDeterministic));
    }

    /**
     * Extracts the conjuncts of {@code expression}, keeps those matching
     * {@code predicate}, and recombines them with {@link #combineConjuncts(Collection)}.
     */
    public static Expression filterConjuncts(Expression expression, Predicate<Expression> predicate)
    {
        List<Expression> conjuncts = extractConjuncts(expression).stream()
                .filter(predicate)
                .collect(toList());

        return combineConjuncts(conjuncts);
    }

    /**
     * Returns a function that rewrites an expression {@code e} into
     * {@code e OR (s1 IS NULL AND s2 IS NULL ...)}, adding one such null-check
     * disjunct per scope in {@code nullSymbolScopes}. The symbols of a disjunct
     * are the unique symbols of {@code e} accepted by that scope; a scope that
     * matches no symbol contributes nothing.
     */
    @SafeVarargs
    public static Function<Expression, Expression> expressionOrNullSymbols(Predicate<Symbol>... nullSymbolScopes)
    {
        return expression -> {
            ImmutableList.Builder<Expression> resultDisjunct = ImmutableList.builder();
            resultDisjunct.add(expression);

            for (Predicate<Symbol> nullSymbolScope : nullSymbolScopes) {
                List<Symbol> symbols = SymbolsExtractor.extractUnique(expression).stream()
                        .filter(nullSymbolScope)
                        .collect(toImmutableList());

                if (symbols.isEmpty()) {
                    continue;
                }

                ImmutableList.Builder<Expression> nullConjuncts = ImmutableList.builder();
                for (Symbol symbol : symbols) {
                    nullConjuncts.add(new IsNull(symbol.toSymbolReference()));
                }

                resultDisjunct.add(and(nullConjuncts.build()));
            }

            return or(resultDisjunct.build());
        };
    }

    /**
     * Removes duplicate deterministic expressions. Preserves the relative order
     * of the expressions in the list. Non-deterministic expressions are always
     * kept, even when repeated.
     */
    private static List<Expression> removeDuplicates(List<Expression> expressions)
    {
        Set<Expression> seen = new HashSet<>();

        ImmutableList.Builder<Expression> result = ImmutableList.builder();
        for (Expression expression : expressions) {
            // Set.add returns false for an already-seen deterministic expression
            if (!DeterminismEvaluator.isDeterministic(expression) || seen.add(expression)) {
                result.add(expression);
            }
        }

        return result.build();
    }

    /**
     * Streams {@code node} and its descendants (as given by
     * {@code Expression.children()}) in depth-first pre-order.
     *
     * @throws NullPointerException if {@code node} is null
     */
    public static Stream<Expression> preOrder(Expression node)
    {
        return stream(
                Traverser.forTree((SuccessorsFunction<Expression>) Expression::children)
                        .depthFirstPreOrder(requireNonNull(node, "node is null")));
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy