![JAR search and dependency download from the Maven repository](/logo.png)
io.trino.sql.planner.optimizations.ExpressionEquivalence Maven / Gradle / Ivy
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.planner.optimizations;
import com.google.common.collect.ComparisonChain;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Ordering;
import io.trino.Session;
import io.trino.metadata.FunctionManager;
import io.trino.metadata.Metadata;
import io.trino.spi.function.CatalogSchemaFunctionName;
import io.trino.spi.type.Type;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TypeAnalyzer;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.relational.CallExpression;
import io.trino.sql.relational.ConstantExpression;
import io.trino.sql.relational.InputReferenceExpression;
import io.trino.sql.relational.LambdaDefinitionExpression;
import io.trino.sql.relational.RowExpression;
import io.trino.sql.relational.RowExpressionVisitor;
import io.trino.sql.relational.SpecialForm;
import io.trino.sql.relational.SpecialForm.Form;
import io.trino.sql.relational.VariableReferenceExpression;
import io.trino.sql.tree.Expression;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName;
import static io.trino.spi.function.OperatorType.EQUAL;
import static io.trino.spi.function.OperatorType.IS_DISTINCT_FROM;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.sql.relational.SpecialForm.Form.AND;
import static io.trino.sql.relational.SpecialForm.Form.OR;
import static io.trino.sql.relational.SqlToRowExpressionTranslator.translate;
import static java.lang.Integer.min;
import static java.util.Objects.requireNonNull;
public class ExpressionEquivalence
{
/**
 * Total order over row expressions. It sorts the arguments of {@code =} and
 * {@code IS DISTINCT FROM} calls and of {@code AND}/{@code OR} special forms, so that
 * argument order does not affect the equality check on canonicalized expressions.
 */
private static final Ordering<RowExpression> ROW_EXPRESSION_ORDERING = Ordering.from(new RowExpressionComparator());
// used when translating AST expressions to row expressions
private final Metadata metadata;
// used when translating AST expressions to row expressions
private final FunctionManager functionManager;
// supplies the expression types required by the translator
private final TypeAnalyzer typeAnalyzer;
// stateless visitor shared by every comparison performed by this instance
private final CanonicalizationVisitor canonicalizationVisitor;
/**
 * Creates an equivalence checker.
 *
 * @param metadata metadata handed to the row-expression translator
 * @param functionManager function manager handed to the row-expression translator
 * @param typeAnalyzer analyzer that computes the expression types needed for translation
 * @throws NullPointerException if any argument is {@code null}
 */
public ExpressionEquivalence(Metadata metadata, FunctionManager functionManager, TypeAnalyzer typeAnalyzer)
{
    // the canonicalization visitor carries no state, so a single instance can be reused
    this.canonicalizationVisitor = new CanonicalizationVisitor();
    this.metadata = requireNonNull(metadata, "metadata is null");
    this.functionManager = requireNonNull(functionManager, "functionManager is null");
    this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null");
}
/**
 * Checks whether two expressions are equivalent after canonicalization.
 * <p>
 * Both expressions are translated to row expressions using one shared symbol-to-input layout.
 * They are then canonicalized and compared with {@code equals}. Canonicalization sorts the
 * arguments of {@code =} / {@code IS DISTINCT FROM}, and flattens, deduplicates and sorts
 * {@code AND}/{@code OR}.
 * This is a structural check: expressions that are semantically equal but canonicalize to
 * different trees are reported as not equivalent.
 *
 * @param session session used for type analysis and translation
 * @param leftExpression first expression to compare
 * @param rightExpression second expression to compare
 * @param types types of the symbols the expressions may reference
 * @return {@code true} if both expressions have equal canonical forms
 */
public boolean areExpressionsEquivalent(Session session, Expression leftExpression, Expression rightExpression, TypeProvider types)
{
    // give every known symbol a distinct input id; both sides must use the same layout
    Map<Symbol, Integer> symbolInput = new HashMap<>();
    int inputId = 0;
    for (Entry<Symbol, Type> entry : types.allTypes().entrySet()) {
        symbolInput.put(entry.getKey(), inputId);
        inputId++;
    }

    RowExpression leftRowExpression = toRowExpression(session, leftExpression, symbolInput, types);
    RowExpression rightRowExpression = toRowExpression(session, rightExpression, symbolInput, types);

    RowExpression canonicalizedLeft = leftRowExpression.accept(canonicalizationVisitor, null);
    RowExpression canonicalizedRight = rightRowExpression.accept(canonicalizationVisitor, null);

    return canonicalizedLeft.equals(canonicalizedRight);
}
/**
 * Translates an AST expression into a row expression.
 *
 * @param session session used for type analysis and translation
 * @param expression expression to translate
 * @param symbolInput mapping from each symbol to the input id it is translated to
 * @param types types of the symbols the expression may reference
 * @return the translated row expression
 */
private RowExpression toRowExpression(Session session, Expression expression, Map<Symbol, Integer> symbolInput, TypeProvider types)
{
    return translate(
            expression,
            typeAnalyzer.getTypes(session, types, expression),
            symbolInput,
            metadata,
            functionManager,
            session,
            // NOTE(review): presumably the translator's "optimize" flag; confirm against SqlToRowExpressionTranslator
            false);
}
/**
 * Rewrites a row expression into a canonical form so that trivially equivalent expressions
 * compare equal:
 * <ul>
 * <li>the arguments of {@code =} and {@code IS DISTINCT FROM} calls are sorted;</li>
 * <li>{@code AND}/{@code OR} special forms are handled as follows: nested forms of the same
 * kind are flattened, then the arguments are deduplicated and sorted. A form left with a
 * single distinct argument is replaced by that argument.</li>
 * </ul>
 * Calls, special forms and lambda bodies are rebuilt with canonicalized children.
 * Constants, input references and variable references are returned unchanged.
 * The visitor holds no state.
 */
private static class CanonicalizationVisitor
        implements RowExpressionVisitor<RowExpression, Void>
{
    @Override
    public RowExpression visitCall(CallExpression call, Void context)
    {
        // canonicalize the children first, so that the sort below compares canonical forms
        call = new CallExpression(
                call.getResolvedFunction(),
                call.getArguments().stream()
                        .map(expression -> expression.accept(this, context))
                        .collect(toImmutableList()));

        CatalogSchemaFunctionName callName = call.getResolvedFunction().getSignature().getName();

        if (callName.equals(builtinFunctionName(EQUAL)) || callName.equals(builtinFunctionName(IS_DISTINCT_FROM))) {
            // both operators are symmetric in their arguments, so a fixed order is safe
            return new CallExpression(
                    call.getResolvedFunction(),
                    ROW_EXPRESSION_ORDERING.sortedCopy(call.getArguments()));
        }

        return call;
    }

    @Override
    public RowExpression visitSpecialForm(SpecialForm specialForm, Void context)
    {
        specialForm = new SpecialForm(
                specialForm.getForm(),
                specialForm.getType(),
                specialForm.getArguments().stream()
                        .map(expression -> expression.accept(this, context))
                        .collect(toImmutableList()),
                specialForm.getFunctionDependencies());

        if (specialForm.getForm() == AND || specialForm.getForm() == OR) {
            // if we have nested calls (of the same type) flatten them
            List<RowExpression> flattenedArguments = flattenNestedCallArgs(specialForm);

            // only consider distinct arguments
            Set<RowExpression> distinctArguments = ImmutableSet.copyOf(flattenedArguments);
            if (distinctArguments.size() == 1) {
                return Iterables.getOnlyElement(distinctArguments);
            }

            // canonicalize the argument order (i.e., sort them)
            List<RowExpression> sortedArguments = ROW_EXPRESSION_ORDERING.sortedCopy(distinctArguments);
            return new SpecialForm(specialForm.getForm(), BOOLEAN, sortedArguments, specialForm.getFunctionDependencies());
        }

        return specialForm;
    }

    /**
     * Returns the arguments of {@code specialForm}, preserving their order. Any argument that
     * is itself a {@link SpecialForm} of the same {@link Form} is replaced, recursively, by its
     * own flattened arguments.
     *
     * @param specialForm special form whose argument list is flattened
     * @return the flattened arguments
     */
    public static List<RowExpression> flattenNestedCallArgs(SpecialForm specialForm)
    {
        Form form = specialForm.getForm();
        ImmutableList.Builder<RowExpression> newArguments = ImmutableList.builder();
        for (RowExpression argument : specialForm.getArguments()) {
            if (argument instanceof SpecialForm && form == ((SpecialForm) argument).getForm()) {
                // same special form type, so flatten the args
                newArguments.addAll(flattenNestedCallArgs((SpecialForm) argument));
            }
            else {
                newArguments.add(argument);
            }
        }
        return newArguments.build();
    }

    @Override
    public RowExpression visitConstant(ConstantExpression constant, Void context)
    {
        return constant;
    }

    @Override
    public RowExpression visitInputReference(InputReferenceExpression node, Void context)
    {
        return node;
    }

    @Override
    public RowExpression visitLambda(LambdaDefinitionExpression lambda, Void context)
    {
        // only the body is canonicalized; the argument types and names are kept as they are
        return new LambdaDefinitionExpression(lambda.getArgumentTypes(), lambda.getArguments(), lambda.getBody().accept(this, context));
    }

    @Override
    public RowExpression visitVariableReference(VariableReferenceExpression reference, Void context)
    {
        return reference;
    }
}
private static class RowExpressionComparator
implements Comparator
{
private final Comparator
© 2015 - 2025 Weber Informatics LLC | Privacy Policy