All downloads are free. Search and download functionality uses the official Maven repository.

io.trino.sql.planner.SortExpressionExtractor Maven / Gradle / Ivy

There is a newer version: 468
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.sql.planner;

import com.google.common.collect.ImmutableList;
import io.trino.metadata.Metadata;
import io.trino.operator.join.SortedPositionLinks;
import io.trino.sql.ExpressionUtils;
import io.trino.sql.tree.AstVisitor;
import io.trino.sql.tree.BetweenPredicate;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.Node;
import io.trino.sql.tree.SymbolReference;

import java.util.List;
import java.util.Optional;
import java.util.Set;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.trino.sql.tree.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL;
import static io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL;
import static java.util.Collections.singletonList;
import static java.util.Comparator.comparing;
import static java.util.function.Function.identity;
import static java.util.stream.Collectors.toMap;

/**
 * Extracts sort expression to be used for creating {@link SortedPositionLinks} from join filter expression.
 * Currently this class can extract sort and search expressions from filter function conjuncts of shape:
 * 

* {@code A.a < f(B.x, B.y, B.z)} or {@code f(B.x, B.y, B.z) < A.a} *

* where {@code a} is the build side symbol reference and {@code x,y,z} are probe * side symbol references. Any of inequality operators ({@code <,<=,>,>=}) can be used. * Same build side symbol need to be used in all conjuncts. */ public final class SortExpressionExtractor { /* TODO: This class could be extended to handle any expressions like: A.a * sin(A.b) / log(B.x) < cos(B.z) by transforming it to: f(A.a, A.b) < g(B.x, B.z) Where f(...) and g(...) would be some functions/expressions. That would allow us to perform binary search on arbitrary complex expressions by sorting position links according to the result of f(...) function. */ private SortExpressionExtractor() {} public static Optional extractSortExpression(Metadata metadata, Set buildSymbols, Expression filter) { List filterConjuncts = ExpressionUtils.extractConjuncts(filter); SortExpressionVisitor visitor = new SortExpressionVisitor(buildSymbols); List sortExpressionCandidates = ImmutableList.copyOf(filterConjuncts.stream() .filter(expression -> DeterminismEvaluator.isDeterministic(expression, metadata)) .map(visitor::process) .filter(Optional::isPresent) .map(Optional::get) .collect(toMap(SortExpressionContext::getSortExpression, identity(), SortExpressionExtractor::merge)) .values()); // For now heuristically pick sort expression which has most search expressions assigned to it. 
// TODO: make it cost based decision based on symbol statistics return sortExpressionCandidates.stream() .sorted(comparing(context -> -1 * context.getSearchExpressions().size())) .findFirst(); } private static SortExpressionContext merge(SortExpressionContext left, SortExpressionContext right) { checkArgument(left.getSortExpression().equals(right.getSortExpression())); ImmutableList.Builder searchExpressions = ImmutableList.builder(); searchExpressions.addAll(left.getSearchExpressions()); searchExpressions.addAll(right.getSearchExpressions()); return new SortExpressionContext(left.getSortExpression(), searchExpressions.build()); } private static class SortExpressionVisitor extends AstVisitor, Void> { private final Set buildSymbols; public SortExpressionVisitor(Set buildSymbols) { this.buildSymbols = buildSymbols; } @Override protected Optional visitExpression(Expression expression, Void context) { return Optional.empty(); } @Override protected Optional visitComparisonExpression(ComparisonExpression comparison, Void context) { return switch (comparison.getOperator()) { case GREATER_THAN, GREATER_THAN_OR_EQUAL, LESS_THAN, LESS_THAN_OR_EQUAL -> { Optional sortChannel = asBuildSymbolReference(buildSymbols, comparison.getRight()); boolean hasBuildReferencesOnOtherSide = hasBuildSymbolReference(buildSymbols, comparison.getLeft()); if (sortChannel.isEmpty()) { sortChannel = asBuildSymbolReference(buildSymbols, comparison.getLeft()); hasBuildReferencesOnOtherSide = hasBuildSymbolReference(buildSymbols, comparison.getRight()); } if (sortChannel.isPresent() && !hasBuildReferencesOnOtherSide) { yield sortChannel.map(symbolReference -> new SortExpressionContext(symbolReference, singletonList(comparison))); } yield Optional.empty(); } default -> Optional.empty(); }; } @Override protected Optional visitBetweenPredicate(BetweenPredicate node, Void context) { Optional result = visitComparisonExpression(new ComparisonExpression(GREATER_THAN_OR_EQUAL, node.getValue(), 
node.getMin()), context); if (result.isPresent()) { return result; } return visitComparisonExpression(new ComparisonExpression(LESS_THAN_OR_EQUAL, node.getValue(), node.getMax()), context); } } private static Optional asBuildSymbolReference(Set buildLayout, Expression expression) { // Currently we only support symbol as sort expression on build side if (expression instanceof SymbolReference symbolReference) { if (buildLayout.contains(new Symbol(symbolReference.getName()))) { return Optional.of(symbolReference); } } return Optional.empty(); } private static boolean hasBuildSymbolReference(Set buildSymbols, Expression expression) { return new BuildSymbolReferenceFinder(buildSymbols).process(expression); } private static class BuildSymbolReferenceFinder extends AstVisitor { private final Set buildSymbols; public BuildSymbolReferenceFinder(Set buildSymbols) { this.buildSymbols = buildSymbols.stream() .map(Symbol::getName) .collect(toImmutableSet()); } @Override protected Boolean visitNode(Node node, Void context) { for (Node child : node.getChildren()) { if (process(child, context)) { return true; } } return false; } @Override protected Boolean visitSymbolReference(SymbolReference symbolReference, Void context) { return buildSymbols.contains(symbolReference.getName()); } } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy