All downloads are free. Search and download functionality uses the official Maven repository.

io.prestosql.sql.planner.iterative.rule.TransformCorrelatedInPredicateToJoin Maven / Gradle / Ivy

There is a newer version: 350
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.prestosql.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.prestosql.matching.Captures;
import io.prestosql.matching.Pattern;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.ResolvedFunction;
import io.prestosql.sql.planner.PlanNodeIdAllocator;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.SymbolAllocator;
import io.prestosql.sql.planner.SymbolsExtractor;
import io.prestosql.sql.planner.iterative.Lookup;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.plan.AggregationNode;
import io.prestosql.sql.planner.plan.ApplyNode;
import io.prestosql.sql.planner.plan.AssignUniqueId;
import io.prestosql.sql.planner.plan.Assignments;
import io.prestosql.sql.planner.plan.FilterNode;
import io.prestosql.sql.planner.plan.JoinNode;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.PlanVisitor;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.tree.BooleanLiteral;
import io.prestosql.sql.tree.Cast;
import io.prestosql.sql.tree.ComparisonExpression;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.InPredicate;
import io.prestosql.sql.tree.IsNotNullPredicate;
import io.prestosql.sql.tree.IsNullPredicate;
import io.prestosql.sql.tree.LongLiteral;
import io.prestosql.sql.tree.NotExpression;
import io.prestosql.sql.tree.NullLiteral;
import io.prestosql.sql.tree.QualifiedName;
import io.prestosql.sql.tree.SearchedCaseExpression;
import io.prestosql.sql.tree.SymbolReference;
import io.prestosql.sql.tree.WhenClause;
import io.prestosql.sql.util.AstUtils;

import javax.annotation.Nullable;

import java.util.List;
import java.util.Optional;
import java.util.Set;

import static com.google.common.collect.Iterables.getOnlyElement;
import static io.prestosql.matching.Pattern.nonEmpty;
import static io.prestosql.spi.type.BigintType.BIGINT;
import static io.prestosql.spi.type.BooleanType.BOOLEAN;
import static io.prestosql.sql.ExpressionUtils.and;
import static io.prestosql.sql.ExpressionUtils.or;
import static io.prestosql.sql.analyzer.TypeSignatureTranslator.toSqlType;
import static io.prestosql.sql.planner.plan.AggregationNode.singleGroupingSet;
import static io.prestosql.sql.planner.plan.Patterns.Apply.correlation;
import static io.prestosql.sql.planner.plan.Patterns.applyNode;
import static java.util.Objects.requireNonNull;

/**
 * Replaces correlated ApplyNode with InPredicate expression with SemiJoin
 * 

* Transforms: *

 * - Apply (output: a in B.b)
 *    - input: some plan A producing symbol a
 *    - subquery: some plan B producing symbol b, using symbols from A
 * 
* Into: *
 * - Project (output: CASE WHEN (countmatches > 0) THEN true WHEN (countnullmatches > 0) THEN null ELSE false END)
 *   - Aggregate (countmatches=count(*) where a, b not null; countnullmatches where a,b null but buildSideKnownNonNull is not null)
 *     grouping by (A'.*)
 *     - LeftJoin on (A and B correlation condition)
 *       - AssignUniqueId (A')
 *         - A
 * 
*

* * @see TransformCorrelatedScalarAggregationToJoin */ public class TransformCorrelatedInPredicateToJoin implements Rule { private static final Pattern PATTERN = applyNode() .with(nonEmpty(correlation())); private final ResolvedFunction countFunction; public TransformCorrelatedInPredicateToJoin(Metadata metadata) { countFunction = metadata.resolveFunction(QualifiedName.of("count"), ImmutableList.of()); } @Override public Pattern getPattern() { return PATTERN; } @Override public Result apply(ApplyNode apply, Captures captures, Context context) { Assignments subqueryAssignments = apply.getSubqueryAssignments(); if (subqueryAssignments.size() != 1) { return Result.empty(); } Expression assignmentExpression = getOnlyElement(subqueryAssignments.getExpressions()); if (!(assignmentExpression instanceof InPredicate)) { return Result.empty(); } InPredicate inPredicate = (InPredicate) assignmentExpression; Symbol inPredicateOutputSymbol = getOnlyElement(subqueryAssignments.getSymbols()); return apply(apply, inPredicate, inPredicateOutputSymbol, context.getLookup(), context.getIdAllocator(), context.getSymbolAllocator()); } private Result apply( ApplyNode apply, InPredicate inPredicate, Symbol inPredicateOutputSymbol, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator) { Optional decorrelated = new DecorrelatingVisitor(lookup, apply.getCorrelation()) .decorrelate(apply.getSubquery()); if (decorrelated.isEmpty()) { return Result.empty(); } PlanNode projection = buildInPredicateEquivalent( apply, inPredicate, inPredicateOutputSymbol, decorrelated.get(), idAllocator, symbolAllocator); return Result.ofPlanNode(projection); } private PlanNode buildInPredicateEquivalent( ApplyNode apply, InPredicate inPredicate, Symbol inPredicateOutputSymbol, Decorrelated decorrelated, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator) { Expression correlationCondition = and(decorrelated.getCorrelatedPredicates()); PlanNode decorrelatedBuildSource = 
decorrelated.getDecorrelatedNode(); AssignUniqueId probeSide = new AssignUniqueId( idAllocator.getNextId(), apply.getInput(), symbolAllocator.newSymbol("unique", BIGINT)); Symbol buildSideKnownNonNull = symbolAllocator.newSymbol("buildSideKnownNonNull", BIGINT); ProjectNode buildSide = new ProjectNode( idAllocator.getNextId(), decorrelatedBuildSource, Assignments.builder() .putIdentities(decorrelatedBuildSource.getOutputSymbols()) .put(buildSideKnownNonNull, bigint(0)) .build()); Symbol probeSideSymbol = Symbol.from(inPredicate.getValue()); Symbol buildSideSymbol = Symbol.from(inPredicate.getValueList()); Expression joinExpression = and( or( new IsNullPredicate(probeSideSymbol.toSymbolReference()), new ComparisonExpression(ComparisonExpression.Operator.EQUAL, probeSideSymbol.toSymbolReference(), buildSideSymbol.toSymbolReference()), new IsNullPredicate(buildSideSymbol.toSymbolReference())), correlationCondition); JoinNode leftOuterJoin = leftOuterJoin(idAllocator, probeSide, buildSide, joinExpression); Symbol matchConditionSymbol = symbolAllocator.newSymbol("matchConditionSymbol", BOOLEAN); Expression matchCondition = and( isNotNull(probeSideSymbol), isNotNull(buildSideSymbol)); Symbol nullMatchConditionSymbol = symbolAllocator.newSymbol("nullMatchConditionSymbol", BOOLEAN); Expression nullMatchCondition = and( isNotNull(buildSideKnownNonNull), not(matchCondition)); ProjectNode preProjection = new ProjectNode( idAllocator.getNextId(), leftOuterJoin, Assignments.builder() .putIdentities(leftOuterJoin.getOutputSymbols()) .put(matchConditionSymbol, matchCondition) .put(nullMatchConditionSymbol, nullMatchCondition) .build()); Symbol countMatchesSymbol = symbolAllocator.newSymbol("countMatches", BIGINT); Symbol countNullMatchesSymbol = symbolAllocator.newSymbol("countNullMatches", BIGINT); AggregationNode aggregation = new AggregationNode( idAllocator.getNextId(), preProjection, ImmutableMap.builder() .put(countMatchesSymbol, countWithFilter(matchConditionSymbol)) 
.put(countNullMatchesSymbol, countWithFilter(nullMatchConditionSymbol)) .build(), singleGroupingSet(probeSide.getOutputSymbols()), ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), Optional.empty()); // TODO since we care only about "some count > 0", we could have specialized node instead of leftOuterJoin that does the job without materializing join results SearchedCaseExpression inPredicateEquivalent = new SearchedCaseExpression( ImmutableList.of( new WhenClause(isGreaterThan(countMatchesSymbol, 0), booleanConstant(true)), new WhenClause(isGreaterThan(countNullMatchesSymbol, 0), booleanConstant(null))), Optional.of(booleanConstant(false))); return new ProjectNode( idAllocator.getNextId(), aggregation, Assignments.builder() .putIdentities(apply.getInput().getOutputSymbols()) .put(inPredicateOutputSymbol, inPredicateEquivalent) .build()); } private static JoinNode leftOuterJoin(PlanNodeIdAllocator idAllocator, AssignUniqueId probeSide, ProjectNode buildSide, Expression joinExpression) { return new JoinNode( idAllocator.getNextId(), JoinNode.Type.LEFT, probeSide, buildSide, ImmutableList.of(), probeSide.getOutputSymbols(), buildSide.getOutputSymbols(), Optional.of(joinExpression), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), ImmutableMap.of(), Optional.empty()); } private AggregationNode.Aggregation countWithFilter(Symbol filter) { return new AggregationNode.Aggregation( countFunction, ImmutableList.of(), false, Optional.of(filter), Optional.empty(), Optional.empty()); /* mask */ } private static Expression isGreaterThan(Symbol symbol, long value) { return new ComparisonExpression( ComparisonExpression.Operator.GREATER_THAN, symbol.toSymbolReference(), bigint(value)); } private static Expression not(Expression booleanExpression) { return new NotExpression(booleanExpression); } private static Expression isNotNull(Symbol symbol) { return new IsNotNullPredicate(symbol.toSymbolReference()); } private static Expression 
bigint(long value) { return new Cast(new LongLiteral(String.valueOf(value)), toSqlType(BIGINT)); } private static Expression booleanConstant(@Nullable Boolean value) { if (value == null) { return new Cast(new NullLiteral(), toSqlType(BOOLEAN)); } return new BooleanLiteral(value.toString()); } private static class DecorrelatingVisitor extends PlanVisitor, PlanNode> { private final Lookup lookup; private final Set correlation; public DecorrelatingVisitor(Lookup lookup, Iterable correlation) { this.lookup = requireNonNull(lookup, "lookup is null"); this.correlation = ImmutableSet.copyOf(requireNonNull(correlation, "correlation is null")); } public Optional decorrelate(PlanNode reference) { return lookup.resolve(reference).accept(this, reference); } @Override public Optional visitProject(ProjectNode node, PlanNode reference) { if (isCorrelatedShallowly(node)) { // TODO: handle correlated projection return Optional.empty(); } Optional result = decorrelate(node.getSource()); return result.map(decorrelated -> { Assignments.Builder assignments = Assignments.builder() .putAll(node.getAssignments()); // Pull up all symbols used by a filter (except correlation) decorrelated.getCorrelatedPredicates().stream() .flatMap(AstUtils::preOrder) .filter(SymbolReference.class::isInstance) .map(SymbolReference.class::cast) .filter(symbolReference -> !correlation.contains(Symbol.from(symbolReference))) .forEach(symbolReference -> assignments.putIdentity(Symbol.from(symbolReference))); return new Decorrelated( decorrelated.getCorrelatedPredicates(), new ProjectNode( node.getId(), decorrelated.getDecorrelatedNode(), assignments.build())); }); } @Override public Optional visitFilter(FilterNode node, PlanNode reference) { Optional result = decorrelate(node.getSource()); return result.map(decorrelated -> new Decorrelated( ImmutableList.builder() .addAll(decorrelated.getCorrelatedPredicates()) // No need to retain uncorrelated conditions, predicate push down will push them back 
.add(node.getPredicate()) .build(), decorrelated.getDecorrelatedNode())); } @Override protected Optional visitPlan(PlanNode node, PlanNode reference) { if (isCorrelatedRecursively(node)) { return Optional.empty(); } else { return Optional.of(new Decorrelated(ImmutableList.of(), reference)); } } private boolean isCorrelatedRecursively(PlanNode node) { if (isCorrelatedShallowly(node)) { return true; } return node.getSources().stream() .map(lookup::resolve) .anyMatch(this::isCorrelatedRecursively); } private boolean isCorrelatedShallowly(PlanNode node) { return SymbolsExtractor.extractUniqueNonRecursive(node).stream().anyMatch(correlation::contains); } } private static class Decorrelated { private final List correlatedPredicates; private final PlanNode decorrelatedNode; public Decorrelated(List correlatedPredicates, PlanNode decorrelatedNode) { this.correlatedPredicates = ImmutableList.copyOf(requireNonNull(correlatedPredicates, "correlatedPredicates is null")); this.decorrelatedNode = requireNonNull(decorrelatedNode, "decorrelatedNode is null"); } public List getCorrelatedPredicates() { return correlatedPredicates; } public PlanNode getDecorrelatedNode() { return decorrelatedNode; } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy