All Downloads are FREE. Search and download functionalities are using the official Maven repository.
Please wait. This can take a few minutes...
Downloading a project requires significant resources. Please understand that we need to cover our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download or modify it as often as you want.
io.trino.sql.ir.IrUtils Maven / Gradle / Ivy
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.ir;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.graph.SuccessorsFunction;
import com.google.common.graph.Traverser;
import io.trino.spi.type.Type;
import io.trino.sql.planner.DeterminismEvaluator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Stream;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Predicates.not;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Streams.stream;
import static io.trino.sql.ir.Booleans.FALSE;
import static io.trino.sql.ir.Booleans.TRUE;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toList;
/**
 * Static helpers for building, flattening, simplifying and traversing IR {@link Expression} trees.
 */
public final class IrUtils
{
    private IrUtils() {}

    /**
     * Checks that {@code expression} has exactly the {@code expected} type.
     *
     * @throws IllegalArgumentException if {@code expression.type()} is not equal to {@code expected}
     */
    static void validateType(Type expected, Expression expression)
    {
        checkArgument(expected.equals(expression.type()), "Expected '%s' type but found '%s' for expression: %s", expected, expression.type(), expression);
    }

    /**
     * Flattens nested {@code AND} nodes of {@code expression} into their leaf terms.
     * An expression that is not an {@code AND} yields a single-element list containing itself.
     */
    public static List<Expression> extractConjuncts(Expression expression)
    {
        return extractPredicates(Logical.Operator.AND, expression);
    }

    /**
     * Flattens nested {@code OR} nodes of {@code expression} into their leaf terms.
     * An expression that is not an {@code OR} yields a single-element list containing itself.
     */
    public static List<Expression> extractDisjuncts(Expression expression)
    {
        return extractPredicates(Logical.Operator.OR, expression);
    }

    /**
     * Flattens {@code expression} using its own operator.
     */
    public static List<Expression> extractPredicates(Logical expression)
    {
        return extractPredicates(expression.operator(), expression);
    }

    /**
     * Recursively flattens every {@link Logical} node whose operator is {@code operator}.
     * Any other expression, including a {@link Logical} with a different operator, is treated as a leaf.
     *
     * @return an immutable list of the leaf terms in left-to-right order
     */
    public static List<Expression> extractPredicates(Logical.Operator operator, Expression expression)
    {
        ImmutableList.Builder<Expression> resultBuilder = ImmutableList.builder();
        extractPredicates(operator, expression, resultBuilder);
        return resultBuilder.build();
    }

    // Appends the leaves of expression (w.r.t. operator) to resultBuilder, depth first, left to right.
    private static void extractPredicates(Logical.Operator operator, Expression expression, ImmutableList.Builder<Expression> resultBuilder)
    {
        if (expression instanceof Logical logical && logical.operator() == operator) {
            for (Expression term : logical.terms()) {
                extractPredicates(operator, term, resultBuilder);
            }
        }
        else {
            resultBuilder.add(expression);
        }
    }

    /**
     * Varargs form of {@link #and(Collection)}.
     */
    public static Expression and(Expression... expressions)
    {
        return and(Arrays.asList(expressions));
    }

    /**
     * Builds the conjunction of {@code expressions} without flattening or deduplication.
     * See {@link #logicalExpression} for the empty and single-element cases.
     */
    public static Expression and(Collection<Expression> expressions)
    {
        return logicalExpression(Logical.Operator.AND, expressions);
    }

    /**
     * Varargs form of {@link #or(Collection)}.
     */
    public static Expression or(Expression... expressions)
    {
        return or(Arrays.asList(expressions));
    }

    /**
     * Builds the disjunction of {@code expressions} without flattening or deduplication.
     * See {@link #logicalExpression} for the empty and single-element cases.
     */
    public static Expression or(Collection<Expression> expressions)
    {
        return logicalExpression(Logical.Operator.OR, expressions);
    }

    /**
     * Combines {@code expressions} with {@code operator}.
     * <ul>
     * <li>An empty collection yields the operator's identity ({@code TRUE} for AND, {@code FALSE} for OR).</li>
     * <li>A single element is returned as-is.</li>
     * <li>Otherwise the result is a new {@link Logical} over an immutable copy of {@code expressions}.</li>
     * </ul>
     *
     * @throws NullPointerException if {@code operator} or {@code expressions} is null
     */
    public static Expression logicalExpression(Logical.Operator operator, Collection<Expression> expressions)
    {
        requireNonNull(operator, "operator is null");
        requireNonNull(expressions, "expressions is null");

        if (expressions.isEmpty()) {
            return switch (operator) {
                case AND -> TRUE;
                case OR -> FALSE;
            };
        }
        if (expressions.size() == 1) {
            return Iterables.getOnlyElement(expressions);
        }
        return new Logical(operator, ImmutableList.copyOf(expressions));
    }

    /**
     * Dispatches to {@link #combineConjuncts(Collection)} or {@link #combineDisjuncts(Collection)} based on {@code operator}.
     */
    public static Expression combinePredicates(Logical.Operator operator, Collection<Expression> expressions)
    {
        // Exhaustive switch (like logicalExpression) so a new operator cannot silently fall into the OR branch.
        return switch (operator) {
            case AND -> combineConjuncts(expressions);
            case OR -> combineDisjuncts(expressions);
        };
    }

    /**
     * Varargs form of {@link #combineConjuncts(Collection)}.
     */
    public static Expression combineConjuncts(Expression... expressions)
    {
        return combineConjuncts(Arrays.asList(expressions));
    }

    /**
     * Builds a simplified conjunction of {@code expressions}.
     * <ul>
     * <li>Nested ANDs are flattened and {@code TRUE} terms are dropped.</li>
     * <li>Duplicate deterministic terms are removed, keeping the first occurrence.</li>
     * <li>The result is {@code FALSE} if any term is {@code FALSE}, and {@code TRUE} if no term remains.</li>
     * </ul>
     */
    public static Expression combineConjuncts(Collection<Expression> expressions)
    {
        List<Expression> conjuncts = flattenConjuncts(expressions);
        // FALSE absorbs the conjunction; checking before deduplication is equivalent because
        // removeDuplicates always keeps the first occurrence of every term.
        if (conjuncts.contains(FALSE)) {
            return FALSE;
        }
        return and(removeDuplicates(conjuncts));
    }

    /**
     * Like {@link #combineConjuncts(Collection)} but keeps duplicate terms.
     */
    public static Expression combineConjunctsWithDuplicates(Collection<Expression> expressions)
    {
        List<Expression> conjuncts = flattenConjuncts(expressions);
        if (conjuncts.contains(FALSE)) {
            return FALSE;
        }
        return and(conjuncts);
    }

    // Flattens every AND in expressions into its terms and drops TRUE terms (the identity of AND).
    private static List<Expression> flattenConjuncts(Collection<Expression> expressions)
    {
        requireNonNull(expressions, "expressions is null");
        return expressions.stream()
                .flatMap(expression -> extractConjuncts(expression).stream())
                .filter(conjunct -> !conjunct.equals(TRUE))
                .collect(toList());
    }

    /**
     * Varargs form of {@link #combineDisjuncts(Collection)}.
     */
    public static Expression combineDisjuncts(Expression... expressions)
    {
        return combineDisjuncts(Arrays.asList(expressions));
    }

    /**
     * Builds a simplified disjunction of {@code expressions}, yielding {@code FALSE} if no term remains.
     *
     * @see #combineDisjunctsWithDefault(Collection, Expression)
     */
    public static Expression combineDisjuncts(Collection<Expression> expressions)
    {
        return combineDisjunctsWithDefault(expressions, FALSE);
    }

    /**
     * Builds a simplified disjunction of {@code expressions}.
     * <ul>
     * <li>Nested ORs are flattened and {@code FALSE} terms are dropped.</li>
     * <li>Duplicate deterministic terms are removed, keeping the first occurrence.</li>
     * <li>The result is {@code TRUE} if any term is {@code TRUE}.</li>
     * </ul>
     *
     * @param emptyDefault returned when no term remains after dropping {@code FALSE} terms
     */
    public static Expression combineDisjunctsWithDefault(Collection<Expression> expressions, Expression emptyDefault)
    {
        requireNonNull(expressions, "expressions is null");

        List<Expression> disjuncts = expressions.stream()
                .flatMap(expression -> extractDisjuncts(expression).stream())
                .filter(disjunct -> !disjunct.equals(FALSE))
                .collect(toList());
        // TRUE absorbs the disjunction; deduplication never removes a first occurrence, so check it first.
        if (disjuncts.contains(TRUE)) {
            return TRUE;
        }
        disjuncts = removeDuplicates(disjuncts);
        return disjuncts.isEmpty() ? emptyDefault : or(disjuncts);
    }

    /**
     * Keeps only the conjuncts of {@code expression} that {@link DeterminismEvaluator#isDeterministic} accepts.
     */
    public static Expression filterDeterministicConjuncts(Expression expression)
    {
        return filterConjuncts(expression, DeterminismEvaluator::isDeterministic);
    }

    /**
     * Keeps only the conjuncts of {@code expression} that {@link DeterminismEvaluator#isDeterministic} rejects.
     */
    public static Expression filterNonDeterministicConjuncts(Expression expression)
    {
        return filterConjuncts(expression, not(DeterminismEvaluator::isDeterministic));
    }

    /**
     * Keeps the conjuncts of {@code expression} that satisfy {@code predicate}.
     * The surviving conjuncts are recombined with {@link #combineConjuncts(Collection)}, so the result is
     * {@code TRUE} when none survive.
     */
    public static Expression filterConjuncts(Expression expression, Predicate<Expression> predicate)
    {
        List<Expression> conjuncts = extractConjuncts(expression).stream()
                .filter(predicate)
                .collect(toList());
        return combineConjuncts(conjuncts);
    }

    /**
     * Returns a function that widens an expression {@code e} to
     * {@code e OR (s1 IS NULL AND s2 IS NULL ...) OR ...}.
     * Each scope contributes one disjunct: the conjunction of {@code IS NULL} checks over the symbols
     * referenced by {@code e} that the scope accepts. Scopes that accept no referenced symbol contribute
     * nothing, and with no contributing scope the function returns {@code e} itself.
     */
    @SafeVarargs
    public static Function<Expression, Expression> expressionOrNullSymbols(Predicate<Symbol>... nullSymbolScopes)
    {
        return expression -> {
            // The referenced symbols do not depend on the scope, so extract them once per expression.
            Collection<Symbol> referencedSymbols = SymbolsExtractor.extractUnique(expression);

            ImmutableList.Builder<Expression> resultDisjunct = ImmutableList.builder();
            resultDisjunct.add(expression);

            for (Predicate<Symbol> nullSymbolScope : nullSymbolScopes) {
                List<Symbol> symbols = referencedSymbols.stream()
                        .filter(nullSymbolScope)
                        .collect(toImmutableList());
                if (symbols.isEmpty()) {
                    continue;
                }

                ImmutableList.Builder<Expression> nullConjuncts = ImmutableList.builder();
                for (Symbol symbol : symbols) {
                    nullConjuncts.add(new IsNull(symbol.toSymbolReference()));
                }
                resultDisjunct.add(and(nullConjuncts.build()));
            }

            return or(resultDisjunct.build());
        };
    }

    /**
     * Removes duplicate deterministic expressions. Preserves the relative order
     * of the expressions in the list.
     */
    private static List<Expression> removeDuplicates(List<Expression> expressions)
    {
        Set<Expression> seen = new HashSet<>();
        ImmutableList.Builder<Expression> result = ImmutableList.builder();
        for (Expression expression : expressions) {
            // Non-deterministic expressions are always kept, presumably because repeated evaluations may differ.
            // For deterministic ones, Set.add returns false on a repeat, which drops it.
            if (!DeterminismEvaluator.isDeterministic(expression) || seen.add(expression)) {
                result.add(expression);
            }
        }
        return result.build();
    }

    /**
     * Returns {@code node} and all of its descendants (reached through {@link Expression#children()})
     * in depth-first pre-order, so every node precedes its children.
     *
     * @throws NullPointerException if {@code node} is null
     */
    public static Stream<Expression> preOrder(Expression node)
    {
        return stream(
                Traverser.forTree((SuccessorsFunction<Expression>) Expression::children)
                        .depthFirstPreOrder(requireNonNull(node, "node is null")));
    }
}