// Source: com.hazelcast.org.apache.calcite.rel.metadata.RelMdPredicates (Maven / Gradle / Ivy)
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.hazelcast.org.apache.calcite.rel.metadata;
import com.hazelcast.org.apache.calcite.linq4j.Linq4j;
import com.hazelcast.org.apache.calcite.linq4j.Ord;
import com.hazelcast.org.apache.calcite.plan.RelOptCluster;
import com.hazelcast.org.apache.calcite.plan.RelOptPredicateList;
import com.hazelcast.org.apache.calcite.plan.RelOptUtil;
import com.hazelcast.org.apache.calcite.plan.RexImplicationChecker;
import com.hazelcast.org.apache.calcite.plan.Strong;
import com.hazelcast.org.apache.calcite.plan.hep.HepRelVertex;
import com.hazelcast.org.apache.calcite.plan.volcano.RelSubset;
import com.hazelcast.org.apache.calcite.rel.RelNode;
import com.hazelcast.org.apache.calcite.rel.core.Aggregate;
import com.hazelcast.org.apache.calcite.rel.core.Exchange;
import com.hazelcast.org.apache.calcite.rel.core.Filter;
import com.hazelcast.org.apache.calcite.rel.core.Intersect;
import com.hazelcast.org.apache.calcite.rel.core.Join;
import com.hazelcast.org.apache.calcite.rel.core.JoinRelType;
import com.hazelcast.org.apache.calcite.rel.core.Minus;
import com.hazelcast.org.apache.calcite.rel.core.Project;
import com.hazelcast.org.apache.calcite.rel.core.Sort;
import com.hazelcast.org.apache.calcite.rel.core.TableModify;
import com.hazelcast.org.apache.calcite.rel.core.TableScan;
import com.hazelcast.org.apache.calcite.rel.core.Union;
import com.hazelcast.org.apache.calcite.rex.RexBuilder;
import com.hazelcast.org.apache.calcite.rex.RexCall;
import com.hazelcast.org.apache.calcite.rex.RexExecutor;
import com.hazelcast.org.apache.calcite.rex.RexInputRef;
import com.hazelcast.org.apache.calcite.rex.RexLiteral;
import com.hazelcast.org.apache.calcite.rex.RexNode;
import com.hazelcast.org.apache.calcite.rex.RexPermuteInputsShuttle;
import com.hazelcast.org.apache.calcite.rex.RexSimplify;
import com.hazelcast.org.apache.calcite.rex.RexUnknownAs;
import com.hazelcast.org.apache.calcite.rex.RexUtil;
import com.hazelcast.org.apache.calcite.rex.RexVisitorImpl;
import com.hazelcast.org.apache.calcite.sql.SqlKind;
import com.hazelcast.org.apache.calcite.sql.SqlOperator;
import com.hazelcast.org.apache.calcite.sql.fun.SqlStdOperatorTable;
import com.hazelcast.org.apache.calcite.util.BitSets;
import com.hazelcast.org.apache.calcite.util.Bug;
import com.hazelcast.org.apache.calcite.util.BuiltInMethod;
import com.hazelcast.org.apache.calcite.util.ImmutableBitSet;
import com.hazelcast.org.apache.calcite.util.Util;
import com.hazelcast.org.apache.calcite.util.mapping.Mapping;
import com.hazelcast.org.apache.calcite.util.mapping.MappingType;
import com.hazelcast.org.apache.calcite.util.mapping.Mappings;
import com.hazelcast.com.google.common.collect.ImmutableList;
import com.hazelcast.com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
/**
 * Utility to infer Predicates that are applicable above a RelNode.
 *
 * <p>This is currently used by
 * {@link com.hazelcast.org.apache.calcite.rel.rules.JoinPushTransitivePredicatesRule} to
 * infer <em>Predicates</em> that can be inferred from one side of a Join
 * to the other.
 *
 * <p>The PullUp Strategy is sound but not complete. Here are some of the
 * limitations:
 *
 * <ol>
 * <li>For Aggregations we only PullUp predicates that only contain
 * Grouping Keys. This can be extended to infer predicates on Aggregation
 * expressions from expressions on the aggregated columns. For e.g.
 * <pre>
 * select a, max(b) from R1 where b &gt; 7
 *   &rarr; max(b) &gt; 7 or max(b) is null
 * </pre>
 *
 * <li>For Projections we only look at columns that are projected without
 * any function applied. So:
 * <pre>
 * select a from R1 where a &gt; 7
 *   &rarr; "a &gt; 7" is pulled up from the Projection.
 * select a + 1 from R1 where a + 1 &gt; 7
 *   &rarr; "a + 1 &gt; 7" is not pulled up
 * </pre>
 *
 * <li>There are several restrictions on Joins:
 *   <ul>
 *   <li>We only pullUp inferred predicates for now. Pulling up existing
 *   predicates causes an explosion of duplicates. The existing predicates
 *   are pushed back down as new predicates. Once we have rules to eliminate
 *   duplicate Filter conditions, we should pullUp all predicates.
 *
 *   <li>For Left Outer: we infer new predicates from the left and set them
 *   as applicable on the Right side. No predicates are pulledUp.
 *
 *   <li>Right Outer Joins are handled in an analogous manner.
 *
 *   <li>For Full Outer Joins no predicates are pulledUp or inferred.
 *   </ul>
 * </ol>
 */
public class RelMdPredicates
implements MetadataHandler {
/**
 * Metadata provider that exposes this handler through reflection on
 * {@link BuiltInMethod#PREDICATES}.
 */
public static final RelMetadataProvider SOURCE = ReflectiveRelMetadataProvider
    .reflectiveSource(BuiltInMethod.PREDICATES.method, new RelMdPredicates());

/** Shared immutable empty list of predicates. */
private static final List<RexNode> EMPTY_LIST = ImmutableList.of();
/**
 * Returns the metadata definition served by this handler.
 *
 * @return {@link BuiltInMetadata.Predicates#DEF}
 */
public MetadataDef<BuiltInMetadata.Predicates> getDef() {
  return BuiltInMetadata.Predicates.DEF;
}
/**
 * Catch-all implementation for
 * {@link BuiltInMetadata.Predicates#getPredicates()},
 * invoked using reflection.
 *
 * <p>This fallback reports no predicates for the relational expression.
 *
 * @param rel Relational expression
 * @param mq Metadata query (unused here)
 * @return {@link RelOptPredicateList#EMPTY}
 *
 * @see com.hazelcast.org.apache.calcite.rel.metadata.RelMetadataQuery#getPulledUpPredicates(RelNode)
 */
public RelOptPredicateList getPredicates(RelNode rel, RelMetadataQuery mq) {
return RelOptPredicateList.EMPTY;
}
/**
 * Infers predicates for a {@link HepRelVertex} by delegating to the
 * relational expression the vertex currently wraps.
 *
 * @param rel Vertex whose current expression is queried
 * @param mq Metadata query used for the delegated lookup
 * @return Predicates pulled up from the vertex's current expression
 */
public RelOptPredicateList getPredicates(HepRelVertex rel,
    RelMetadataQuery mq) {
  final RelNode current = rel.getCurrentRel();
  return mq.getPulledUpPredicates(current);
}
/**
 * Infers predicates for a table scan.
 *
 * <p>No predicates are reported for a scan.
 *
 * @param table Table scan
 * @param mq Metadata query (unused here)
 * @return {@link RelOptPredicateList#EMPTY}
 */
public RelOptPredicateList getPredicates(TableScan table,
RelMetadataQuery mq) {
return RelOptPredicateList.EMPTY;
}
/**
 * Infers predicates for a project.
 *
 * <ol>
 * <li>Create a mapping from input to projection. Map only positions that
 * directly reference an input column.
 *
 * <li>Expressions that only contain above columns are retained in the
 * Project's pullExpressions list.
 *
 * <li>For e.g. expression 'a + e = 9' below will not be pulled up because 'e'
 * is not in the projection list.
 *
 * <blockquote><pre>
 * inputPullUpExprs:      {a &gt; 7, b + c &lt; 10, a + e = 9}
 * projectionExprs:       {a, b, c, e / 2}
 * projectionPullupExprs: {a &gt; 7, b + c &lt; 10}
 * </pre></blockquote>
 * </ol>
 *
 * <p>Constant projections also contribute predicates: a NULL literal yields
 * {@code IS NULL} on the output column, and any other constant yields an
 * equality (or {@code IS NOT DISTINCT FROM} when either side is nullable)
 * between the output column and the constant.
 *
 * @param project Project whose output predicates are inferred
 * @param mq Metadata query used to obtain the input's predicates
 * @return Predicates that hold on the project's output
 */
public RelOptPredicateList getPredicates(Project project,
    RelMetadataQuery mq) {
  final RelNode input = project.getInput();
  final RexBuilder rexBuilder = project.getCluster().getRexBuilder();
  final RelOptPredicateList inputInfo = mq.getPulledUpPredicates(input);
  final List<RexNode> projectPullUpPredicates = new ArrayList<>();

  final ImmutableBitSet.Builder columnsMappedBuilder = ImmutableBitSet.builder();
  final Mapping m = Mappings.create(MappingType.PARTIAL_FUNCTION,
      input.getRowType().getFieldCount(),
      project.getRowType().getFieldCount());

  for (Ord<RexNode> expr : Ord.zip(project.getProjects())) {
    if (expr.e instanceof RexInputRef) {
      // Plain column reference: remember input column -> output position.
      final int sIdx = ((RexInputRef) expr.e).getIndex();
      m.set(sIdx, expr.i);
      columnsMappedBuilder.set(sIdx);
      // Project can also generate constants. We need to include them.
    } else if (RexLiteral.isNullLiteral(expr.e)) {
      projectPullUpPredicates.add(
          rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL,
              rexBuilder.makeInputRef(project, expr.i)));
    } else if (RexUtil.isConstant(expr.e)) {
      final List<RexNode> args =
          ImmutableList.of(rexBuilder.makeInputRef(project, expr.i), expr.e);
      // Plain "=" is only used when neither operand is nullable.
      final SqlOperator op = args.get(0).getType().isNullable()
          || args.get(1).getType().isNullable()
          ? SqlStdOperatorTable.IS_NOT_DISTINCT_FROM
          : SqlStdOperatorTable.EQUALS;
      projectPullUpPredicates.add(rexBuilder.makeCall(op, args));
    }
  }

  // Go over childPullUpPredicates. If a predicate only contains columns in
  // 'columnsMapped' construct a new predicate based on mapping.
  final ImmutableBitSet columnsMapped = columnsMappedBuilder.build();
  for (RexNode r : inputInfo.pulledUpPredicates) {
    RexNode r2 = projectPredicate(rexBuilder, input, r, columnsMapped);
    if (!r2.isAlwaysTrue()) {
      r2 = r2.accept(new RexPermuteInputsShuttle(m, input));
      projectPullUpPredicates.add(r2);
    }
  }
  return RelOptPredicateList.of(rexBuilder, projectPullUpPredicates);
}
/**
 * Converts a predicate on a particular set of columns into a predicate on
 * a subset of those columns, weakening if necessary.
 *
 * <p>If not possible to simplify, returns {@code true}, which is the weakest
 * possible predicate.
 *
 * <p>Examples:
 * <ol>
 * <li>The predicate {@code $7 = $9} on columns [7]
 * becomes {@code $7 is not null}
 *
 * <li>The predicate {@code $7 = $9 + $11} on columns [7, 9]
 * becomes {@code $7 is not null or $9 is not null}
 *
 * <li>The predicate {@code $7 = $9 and $9 = 5} on columns [7] becomes
 * {@code $7 = 5}
 *
 * <li>The predicate
 * {@code $7 = $9 and ($9 = $1 or $9 = $2) and $1 > 3 and $2 > 10}
 * on columns [7] becomes {@code $7 > 3}
 * </ol>
 *
 * <p>We currently only handle examples 1 and 2.
 *
 * @param rexBuilder Rex builder
 * @param input Input relational expression
 * @param r Predicate expression
 * @param columnsMapped Columns which the final predicate can reference
 * @return Predicate expression narrowed to reference only certain columns
 */
private RexNode projectPredicate(final RexBuilder rexBuilder, RelNode input,
    RexNode r, ImmutableBitSet columnsMapped) {
  final ImmutableBitSet rCols = RelOptUtil.InputFinder.bits(r);
  if (columnsMapped.contains(rCols)) {
    // All required columns are present. No need to weaken.
    return r;
  }
  if (columnsMapped.intersects(rCols)) {
    final List<RexNode> list = new ArrayList<>();
    for (int c : columnsMapped.intersect(rCols)) {
      // Emit "c IS NOT NULL" only for nullable columns for which
      // Strong.isNull reports the predicate as null when c is null.
      if (input.getRowType().getFieldList().get(c).getType().isNullable()
          && Strong.isNull(r, ImmutableBitSet.of(c))) {
        list.add(
            rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL,
                rexBuilder.makeInputRef(input, c)));
      }
    }
    if (!list.isEmpty()) {
      return RexUtil.composeDisjunction(rexBuilder, list);
    }
  }
  // Cannot weaken to anything non-trivial
  return rexBuilder.makeLiteral(true);
}
/**
 * Infers predicates for a Filter: the predicates pulled up from the input,
 * combined with the deterministic conjuncts of the filter condition.
 *
 * @param filter Filter whose output predicates are inferred
 * @param mq Metadata query used to obtain the input's predicates
 * @return Union of the input's predicates and the filter's own conjuncts
 */
public RelOptPredicateList getPredicates(Filter filter, RelMetadataQuery mq) {
  final RelNode input = filter.getInput();
  final RexBuilder rexBuilder = filter.getCluster().getRexBuilder();
  final RelOptPredicateList inputInfo = mq.getPulledUpPredicates(input);

  // Fall back to the empty list if the input reported nothing.
  final RelOptPredicateList fromInput =
      Util.first(inputInfo, RelOptPredicateList.EMPTY);

  // Only the deterministic conjuncts of the condition are pulled up.
  final RelOptPredicateList fromCondition =
      RelOptPredicateList.of(rexBuilder,
          RexUtil.retainDeterministic(
              RelOptUtil.conjunctions(filter.getCondition())));

  return fromInput.union(rexBuilder, fromCondition);
}
/**
 * Infers predicates for a {@link com.hazelcast.org.apache.calcite.rel.core.Join} (including
 * {@code SemiJoin}).
 *
 * <p>The predicates pulled up from each input are composed into one
 * conjunction per side and handed to
 * {@link JoinConditionBasedPredicateInference}; equality-based inference is
 * not requested.
 *
 * @param join Join whose predicates are inferred
 * @param mq Metadata query used to obtain both inputs' predicates
 * @return Predicate list produced by the inference helper
 */
public RelOptPredicateList getPredicates(Join join, RelMetadataQuery mq) {
  final RelOptCluster cluster = join.getCluster();
  final RexBuilder rexBuilder = cluster.getRexBuilder();
  // Use the planner's executor when it has one, else the default.
  final RexExecutor executor =
      Util.first(cluster.getPlanner().getExecutor(), RexUtil.EXECUTOR);

  final RelNode leftInput = join.getInput(0);
  final RelNode rightInput = join.getInput(1);
  final RelOptPredicateList leftInfo = mq.getPulledUpPredicates(leftInput);
  final RelOptPredicateList rightInfo = mq.getPulledUpPredicates(rightInput);

  final RexNode leftConjunction =
      RexUtil.composeConjunction(rexBuilder, leftInfo.pulledUpPredicates);
  final RexNode rightConjunction =
      RexUtil.composeConjunction(rexBuilder, rightInfo.pulledUpPredicates);
  final RexSimplify simplifier =
      new RexSimplify(rexBuilder, RelOptPredicateList.EMPTY, executor);

  return new JoinConditionBasedPredicateInference(join, leftConjunction,
      rightConjunction, simplifier)
      .inferPredicates(false);
}
/**
 * Infers predicates for an Aggregate.
 *
 * <p>Pulls up predicates that only contains references to columns in the
 * GroupSet. For e.g.
 *
 * <blockquote><pre>
 * inputPullUpExprs : { a &gt; 7, b + c &lt; 10, a + e = 9}
 * groupSet         : { a, b}
 * pulledUpExprs    : { a &gt; 7}
 * </pre></blockquote>
 *
 * @param agg Aggregate whose output predicates are inferred
 * @param mq Metadata query used to obtain the input's predicates
 * @return Input predicates over group keys, re-expressed on output columns;
 *     empty when the group set is empty
 */
public RelOptPredicateList getPredicates(Aggregate agg, RelMetadataQuery mq) {
  final RelNode input = agg.getInput();
  final RexBuilder rexBuilder = agg.getCluster().getRexBuilder();
  final RelOptPredicateList inputInfo = mq.getPulledUpPredicates(input);
  final List<RexNode> aggPullUpPredicates = new ArrayList<>();

  final ImmutableBitSet groupKeys = agg.getGroupSet();
  if (groupKeys.isEmpty()) {
    // "GROUP BY ()" can convert an empty relation to a non-empty relation, so
    // it is not valid to pull up predicates. In particular, consider the
    // predicate "false": it is valid on all input rows (trivially - there are
    // no rows!) but not on the output (there is one row).
    return RelOptPredicateList.EMPTY;
  }

  // Group keys take the leading output positions, in ascending key order.
  final Mapping m = Mappings.create(MappingType.PARTIAL_FUNCTION,
      input.getRowType().getFieldCount(), agg.getRowType().getFieldCount());
  int i = 0;
  for (int j : groupKeys) {
    m.set(j, i++);
  }

  for (RexNode r : inputInfo.pulledUpPredicates) {
    final ImmutableBitSet rCols = RelOptUtil.InputFinder.bits(r);
    if (groupKeys.contains(rCols)) {
      aggPullUpPredicates.add(r.accept(new RexPermuteInputsShuttle(m, input)));
    }
  }
  return RelOptPredicateList.of(rexBuilder, aggPullUpPredicates);
}
/**
 * Infers predicates for a Union.
 *
 * <p>A predicate reported by every input is pulled up as-is. The remaining
 * ("residual") predicates of each input are combined into one conjunction
 * per input; the disjunction of those conjunctions is simplified and added
 * unless it is always true. If any input reports no predicates, nothing is
 * pulled up.
 *
 * @param union Union whose output predicates are inferred
 * @param mq Metadata query used to obtain each input's predicates
 * @return Predicates that hold on the union's output
 */
public RelOptPredicateList getPredicates(Union union, RelMetadataQuery mq) {
  final RexBuilder rexBuilder = union.getCluster().getRexBuilder();

  // Predicates common to all inputs seen so far.
  Set<RexNode> finalPredicates = new HashSet<>();
  // One conjunction of non-common predicates per input, indexed by input.
  final List<RexNode> finalResidualPredicates = new ArrayList<>();

  for (Ord<RelNode> input : Ord.zip(union.getInputs())) {
    final RelOptPredicateList info = mq.getPulledUpPredicates(input.e);
    if (info.pulledUpPredicates.isEmpty()) {
      return RelOptPredicateList.EMPTY;
    }
    final Set<RexNode> predicates = new HashSet<>();
    final List<RexNode> residualPredicates = new ArrayList<>();
    for (RexNode pred : info.pulledUpPredicates) {
      if (input.i == 0) {
        // Everything from the first input starts out as a common candidate.
        predicates.add(pred);
        continue;
      }
      if (finalPredicates.contains(pred)) {
        predicates.add(pred);
      } else {
        residualPredicates.add(pred);
      }
    }
    // Add new residual predicates
    finalResidualPredicates.add(
        RexUtil.composeConjunction(rexBuilder, residualPredicates));
    // Add those that are not part of the final set to residual
    for (RexNode e : finalPredicates) {
      if (!predicates.contains(e)) {
        // This node was in previous union inputs, but it is not in this one
        for (int j = 0; j < input.i; j++) {
          finalResidualPredicates.set(j,
              RexUtil.composeConjunction(rexBuilder,
                  Arrays.asList(finalResidualPredicates.get(j), e)));
        }
      }
    }
    // Final predicates
    finalPredicates = predicates;
  }

  final List<RexNode> predicates = new ArrayList<>(finalPredicates);
  final RelOptCluster cluster = union.getCluster();
  final RexExecutor executor =
      Util.first(cluster.getPlanner().getExecutor(), RexUtil.EXECUTOR);
  final RexNode disjunctivePredicate =
      new RexSimplify(rexBuilder, RelOptPredicateList.EMPTY, executor)
          .simplifyUnknownAs(
              rexBuilder.makeCall(SqlStdOperatorTable.OR, finalResidualPredicates),
              RexUnknownAs.FALSE);
  if (!disjunctivePredicate.isAlwaysTrue()) {
    predicates.add(disjunctivePredicate);
  }
  return RelOptPredicateList.of(rexBuilder, predicates);
}
/**
 * Infers predicates for a Intersect.
 *
 * <p>Collects the predicates of all inputs. A predicate is skipped when one
 * already collected implies it, and predicates it implies are dropped when
 * it is added.
 *
 * @param intersect Intersect whose output predicates are inferred
 * @param mq Metadata query used to obtain each input's predicates
 * @return Predicates that hold on the intersect's output
 */
public RelOptPredicateList getPredicates(Intersect intersect, RelMetadataQuery mq) {
  final RexBuilder rexBuilder = intersect.getCluster().getRexBuilder();
  final RexExecutor executor =
      Util.first(intersect.getCluster().getPlanner().getExecutor(), RexUtil.EXECUTOR);
  final RexImplicationChecker rexImplicationChecker =
      new RexImplicationChecker(rexBuilder, executor, intersect.getRowType());

  // Owned HashSet, edited in place: Collectors.toSet() does not guarantee
  // that the set it returns is mutable.
  final Set<RexNode> finalPredicates = new HashSet<>();

  for (Ord<RelNode> input : Ord.zip(intersect.getInputs())) {
    final RelOptPredicateList info = mq.getPulledUpPredicates(input.e);
    if (info == null || info.pulledUpPredicates.isEmpty()) {
      continue;
    }
    for (RexNode pred : info.pulledUpPredicates) {
      if (finalPredicates.stream().anyMatch(
          finalPred -> rexImplicationChecker.implies(finalPred, pred))) {
        // There's already a stricter predicate in finalPredicates,
        // thus no need to count this one.
        continue;
      }
      // Remove looser predicate and add this one into finalPredicates
      finalPredicates.removeIf(
          finalPred -> rexImplicationChecker.implies(pred, finalPred));
      finalPredicates.add(pred);
    }
  }
  return RelOptPredicateList.of(rexBuilder, finalPredicates);
}
/**
 * Infers predicates for a Minus: the predicates of its first input.
 *
 * @param minus Minus whose output predicates are inferred
 * @param mq Metadata query used for the delegated lookup
 * @return Predicates pulled up from input 0
 */
public RelOptPredicateList getPredicates(Minus minus, RelMetadataQuery mq) {
  final RelNode firstInput = minus.getInput(0);
  return mq.getPulledUpPredicates(firstInput);
}
/**
 * Infers predicates for a Sort: the predicates of its input, unchanged.
 *
 * @param sort Sort whose output predicates are inferred
 * @param mq Metadata query used for the delegated lookup
 * @return Predicates pulled up from the sort's input
 */
public RelOptPredicateList getPredicates(Sort sort, RelMetadataQuery mq) {
  return mq.getPulledUpPredicates(sort.getInput());
}
/**
 * Infers predicates for a TableModify: the predicates of its input,
 * unchanged.
 *
 * @param tableModify TableModify whose output predicates are inferred
 * @param mq Metadata query used for the delegated lookup
 * @return Predicates pulled up from the input
 */
public RelOptPredicateList getPredicates(TableModify tableModify, RelMetadataQuery mq) {
  final RelNode source = tableModify.getInput();
  return mq.getPulledUpPredicates(source);
}
/**
 * Infers predicates for an Exchange: the predicates of its input,
 * unchanged.
 *
 * @param exchange Exchange whose output predicates are inferred
 * @param mq Metadata query used for the delegated lookup
 * @return Predicates pulled up from the exchange's input
 */
public RelOptPredicateList getPredicates(Exchange exchange,
    RelMetadataQuery mq) {
  return mq.getPulledUpPredicates(exchange.getInput());
}
/**
 * Infers predicates for a {@link RelSubset}.
 *
 * <p>While {@link Bug#CALCITE_1048_FIXED} is false, no predicates are
 * reported. Otherwise the predicate lists of all member expressions are
 * combined with {@code union}.
 *
 * @param r Subset whose predicates are inferred
 * @param mq Metadata query used to obtain each member's predicates
 * @return Combined predicate list, or the empty list if none was found
 *
 * @see RelMetadataQuery#getPulledUpPredicates(RelNode)
 */
public RelOptPredicateList getPredicates(RelSubset r,
    RelMetadataQuery mq) {
  if (Bug.CALCITE_1048_FIXED) {
    final RexBuilder rexBuilder = r.getCluster().getRexBuilder();
    RelOptPredicateList merged = null;
    for (RelNode member : r.getRels()) {
      final RelOptPredicateList memberPredicates =
          mq.getPulledUpPredicates(member);
      if (memberPredicates == null) {
        continue;
      }
      if (merged == null) {
        merged = memberPredicates;
      } else {
        merged = merged.union(rexBuilder, memberPredicates);
      }
    }
    return Util.first(merged, RelOptPredicateList.EMPTY);
  }
  return RelOptPredicateList.EMPTY;
}
/**
 * Utility to infer predicates from one side of the join that apply on the
 * other side.
 *
 * <p>Contract is:
 * <ul>
 * <li>initialize with a {@link com.hazelcast.org.apache.calcite.rel.core.Join} and
 * optional predicates applicable on its left and right subtrees.
 *
 * <li>you can
 * then ask it for equivalentPredicate(s) given a predicate.
 * </ul>
 *
 * <p>So for:
 * <ol>
 * <li>'<code>R1(x) join R2(y) on x = y</code>' a call for
 * equivalentPredicates on '<code>x &gt; 7</code>' will return
 * '<code>[y &gt; 7]</code>'
 * <li>'<code>R1(x) join R2(y) on x = y join R3(z) on y = z</code>' a call for
 * equivalentPredicates on the second join '<code>x &gt; 7</code>' will return
 * (the expected result is not stated in the original documentation)
 * </ol>
 */
static class JoinConditionBasedPredicateInference {
/** Join being analysed. */
final Join joinRel;
/** Number of system fields of the join. */
final int nSysFields;
/** Number of fields of the left input. */
final int nFieldsLeft;
/** Number of fields of the right input. */
final int nFieldsRight;
/** Positions of the left input's fields in the join's field numbering. */
final ImmutableBitSet leftFieldsBitSet;
/** Positions of the right input's fields in the join's field numbering. */
final ImmutableBitSet rightFieldsBitSet;
/** Positions of all fields: system, left and right. */
final ImmutableBitSet allFieldsBitSet;
/** For each field position, the set of positions known to be equal to it. */
SortedMap<Integer, BitSet> equivalence;
/** Fields referenced by each conjunct of the child predicates. */
final Map<RexNode, ImmutableBitSet> exprFields;
/** The child predicates and each of their conjuncts. */
final Set<RexNode> allExprs;
/** Join-condition calls of the form {@code col_x = col_y}. */
final Set<RexNode> equalityPredicates;
/** Left input's predicates in join field numbering, or null. */
final RexNode leftChildPredicates;
/** Right input's predicates in join field numbering, or null. */
final RexNode rightChildPredicates;
/** Simplifier applied to inferred predicates. */
final RexSimplify simplify;
/**
 * Creates the inference helper for a join.
 *
 * <p>Shifts each side's predicates into the join's field numbering, records
 * the fields used by every conjunct, and builds the column equivalence
 * classes from the {@code col_x = col_y} conjuncts of the join condition.
 *
 * @param joinRel Join to analyse
 * @param leftPredicates Predicates on the left input, in the left input's
 *     own field numbering; may be null
 * @param rightPredicates Predicates on the right input, in the right
 *     input's own field numbering; may be null
 * @param simplify Simplifier applied to inferred predicates
 */
JoinConditionBasedPredicateInference(Join joinRel, RexNode leftPredicates,
    RexNode rightPredicates, RexSimplify simplify) {
  super();
  this.joinRel = joinRel;
  this.simplify = simplify;
  nFieldsLeft = joinRel.getLeft().getRowType().getFieldList().size();
  nFieldsRight = joinRel.getRight().getRowType().getFieldList().size();
  nSysFields = joinRel.getSystemFieldList().size();
  leftFieldsBitSet = ImmutableBitSet.range(nSysFields,
      nSysFields + nFieldsLeft);
  rightFieldsBitSet = ImmutableBitSet.range(nSysFields + nFieldsLeft,
      nSysFields + nFieldsLeft + nFieldsRight);
  allFieldsBitSet = ImmutableBitSet.range(0,
      nSysFields + nFieldsLeft + nFieldsRight);

  exprFields = new HashMap<>();
  allExprs = new HashSet<>();

  if (leftPredicates == null) {
    leftChildPredicates = null;
  } else {
    final Mappings.TargetMapping leftMapping = Mappings.createShiftMapping(
        nSysFields + nFieldsLeft, nSysFields, 0, nFieldsLeft);
    leftChildPredicates = leftPredicates.accept(
        new RexPermuteInputsShuttle(leftMapping, joinRel.getInput(0)));

    allExprs.add(leftChildPredicates);
    for (RexNode r : RelOptUtil.conjunctions(leftChildPredicates)) {
      exprFields.put(r, RelOptUtil.InputFinder.bits(r));
      allExprs.add(r);
    }
  }
  if (rightPredicates == null) {
    rightChildPredicates = null;
  } else {
    final Mappings.TargetMapping rightMapping = Mappings.createShiftMapping(
        nSysFields + nFieldsLeft + nFieldsRight,
        nSysFields + nFieldsLeft, 0, nFieldsRight);
    rightChildPredicates = rightPredicates.accept(
        new RexPermuteInputsShuttle(rightMapping, joinRel.getInput(1)));

    allExprs.add(rightChildPredicates);
    for (RexNode r : RelOptUtil.conjunctions(rightChildPredicates)) {
      exprFields.put(r, RelOptUtil.InputFinder.bits(r));
      allExprs.add(r);
    }
  }

  // Every field starts out equivalent only to itself.
  equivalence = new TreeMap<>();
  equalityPredicates = new HashSet<>();
  for (int i = 0; i < nSysFields + nFieldsLeft + nFieldsRight; i++) {
    equivalence.put(i, BitSets.of(i));
  }

  // Only process equivalences found in the join conditions. Processing
  // Equivalences from the left or right side infer predicates that are
  // already present in the Tree below the join.
  final RexBuilder rexBuilder = joinRel.getCluster().getRexBuilder();
  final List<RexNode> exprs =
      RelOptUtil.conjunctions(
          compose(rexBuilder, ImmutableList.of(joinRel.getCondition())));

  final EquivalenceFinder eF = new EquivalenceFinder();
  exprs.forEach(input -> input.accept(eF));

  equivalence = BitSets.closure(equivalence);
}
/**
 * Infers predicates for the join and packages them per join type.
 *
 * <p>The PullUp Strategy is sound but not complete.
 * <ol>
 * <li>We only pullUp inferred predicates for now. Pulling up existing
 * predicates causes an explosion of duplicates. The existing predicates are
 * pushed back down as new predicates. Once we have rules to eliminate
 * duplicate Filter conditions, we should pullUp all predicates.
 *
 * <li>For Left Outer: we infer new predicates from the left and set them as
 * applicable on the Right side. No predicates are pulledUp.
 *
 * <li>Right Outer Joins are handled in an analogous manner.
 *
 * <li>For Full Outer Joins no predicates are pulledUp or inferred.
 * </ol>
 *
 * <p>What this method passes as the pulled-up list:
 * <ul>
 * <li>SEMI: left child predicates plus the inferred predicates that only
 * reference left fields;
 * <li>INNER: both child predicates, the deterministic conjuncts of the join
 * condition, and all inferred predicates;
 * <li>LEFT: the left child predicates;
 * <li>RIGHT: the right child predicates;
 * <li>any other type: the empty predicate list is returned.
 * </ul>
 *
 * @param includeEqualityInference Whether child conjuncts that were also
 *     recorded as join-condition equalities take part in inference
 * @return Predicate list for the join
 */
public RelOptPredicateList inferPredicates(
    boolean includeEqualityInference) {
  final List<RexNode> inferredPredicates = new ArrayList<>();
  final Set<RexNode> allExprs = new HashSet<>(this.allExprs);
  final JoinRelType joinType = joinRel.getJoinType();

  // Infer from the left child's predicates. For LEFT joins only predicates
  // on right fields may be inferred.
  if (joinType == JoinRelType.SEMI
      || joinType == JoinRelType.INNER
      || joinType == JoinRelType.LEFT) {
    infer(leftChildPredicates, allExprs, inferredPredicates,
        includeEqualityInference,
        joinType == JoinRelType.LEFT ? rightFieldsBitSet
            : allFieldsBitSet);
  }
  // Infer from the right child's predicates. For RIGHT joins only
  // predicates on left fields may be inferred.
  if (joinType == JoinRelType.SEMI
      || joinType == JoinRelType.INNER
      || joinType == JoinRelType.RIGHT) {
    infer(rightChildPredicates, allExprs, inferredPredicates,
        includeEqualityInference,
        joinType == JoinRelType.RIGHT ? leftFieldsBitSet
            : allFieldsBitSet);
  }

  // Shuttles that shift single-sided predicates from the join's field
  // numbering back to the corresponding input's numbering.
  final Mappings.TargetMapping rightMapping = Mappings.createShiftMapping(
      nSysFields + nFieldsLeft + nFieldsRight,
      0, nSysFields + nFieldsLeft, nFieldsRight);
  final RexPermuteInputsShuttle rightPermute =
      new RexPermuteInputsShuttle(rightMapping, joinRel);
  final Mappings.TargetMapping leftMapping = Mappings.createShiftMapping(
      nSysFields + nFieldsLeft, 0, nSysFields, nFieldsLeft);
  final RexPermuteInputsShuttle leftPermute =
      new RexPermuteInputsShuttle(leftMapping, joinRel);

  final List<RexNode> leftInferredPredicates = new ArrayList<>();
  final List<RexNode> rightInferredPredicates = new ArrayList<>();
  for (RexNode iP : inferredPredicates) {
    final ImmutableBitSet iPBitSet = RelOptUtil.InputFinder.bits(iP);
    if (leftFieldsBitSet.contains(iPBitSet)) {
      leftInferredPredicates.add(iP.accept(leftPermute));
    } else if (rightFieldsBitSet.contains(iPBitSet)) {
      rightInferredPredicates.add(iP.accept(rightPermute));
    }
  }

  final RexBuilder rexBuilder = joinRel.getCluster().getRexBuilder();
  switch (joinType) {
  case SEMI: {
    final Iterable<RexNode> pulledUpPredicates = Iterables.concat(
        RelOptUtil.conjunctions(leftChildPredicates),
        leftInferredPredicates);
    return RelOptPredicateList.of(rexBuilder, pulledUpPredicates,
        leftInferredPredicates, rightInferredPredicates);
  }
  case INNER: {
    final Iterable<RexNode> pulledUpPredicates = Iterables.concat(
        RelOptUtil.conjunctions(leftChildPredicates),
        RelOptUtil.conjunctions(rightChildPredicates),
        RexUtil.retainDeterministic(
            RelOptUtil.conjunctions(joinRel.getCondition())),
        inferredPredicates);
    return RelOptPredicateList.of(rexBuilder, pulledUpPredicates,
        leftInferredPredicates, rightInferredPredicates);
  }
  case LEFT:
    return RelOptPredicateList.of(rexBuilder,
        RelOptUtil.conjunctions(leftChildPredicates),
        leftInferredPredicates, rightInferredPredicates);
  case RIGHT:
    // NOTE(review): unlike LEFT, this passes the un-shifted
    // inferredPredicates instead of leftInferredPredicates. The two look
    // equivalent only when nSysFields is 0 - confirm intent before changing.
    return RelOptPredicateList.of(rexBuilder,
        RelOptUtil.conjunctions(rightChildPredicates),
        inferredPredicates, EMPTY_LIST);
  default:
    assert inferredPredicates.isEmpty();
    return RelOptPredicateList.EMPTY;
  }
}
/**
 * Returns the left input's predicates as stored by the constructor (shifted
 * into the join's field numbering), or null if none were supplied.
 */
public RexNode left() {
return leftChildPredicates;
}
/**
 * Returns the right input's predicates as stored by the constructor (shifted
 * into the join's field numbering), or null if none were supplied.
 */
public RexNode right() {
return rightChildPredicates;
}
/**
 * Derives new predicates from the conjuncts of {@code predicates} by
 * substituting equivalent columns.
 *
 * <p>A candidate is kept only if both its raw and its simplified form
 * reference nothing outside {@code inferringFields}, are not already in
 * {@code allExprs}, and are not trivially true. The simplified form is the
 * one recorded.
 *
 * @param predicates Predicates whose conjuncts are the inference sources
 * @param allExprs Expressions already known; accepted predicates are added
 * @param inferredPredicates Output list receiving accepted predicates
 * @param includeEqualityInference If false, conjuncts recorded as
 *     join-condition equalities are skipped
 * @param inferringFields Fields an inferred predicate may reference
 */
private void infer(RexNode predicates, Set<RexNode> allExprs,
    List<RexNode> inferredPredicates, boolean includeEqualityInference,
    ImmutableBitSet inferringFields) {
  for (RexNode r : RelOptUtil.conjunctions(predicates)) {
    if (!includeEqualityInference
        && equalityPredicates.contains(r)) {
      continue;
    }
    for (Mapping m : mappings(r)) {
      final RexNode tr = r.accept(
          new RexPermuteInputsShuttle(m, joinRel.getInput(0),
              joinRel.getInput(1)));
      // Filter predicates may already be simplified, so we work with the
      // simplified RexNode as well. This also helps to avoid duplicates in
      // the resulting pulledUpPredicates.
      final RexNode simplifiedTarget =
          simplify.simplifyFilterPredicates(RelOptUtil.conjunctions(tr));
      if (checkTarget(inferringFields, allExprs, tr)
          && checkTarget(inferringFields, allExprs, simplifiedTarget)) {
        inferredPredicates.add(simplifiedTarget);
        allExprs.add(simplifiedTarget);
      }
    }
  }
}
/**
 * Returns every column substitution that the equivalence classes allow for
 * the fields referenced by {@code predicate}.
 *
 * @param predicate A conjunct registered in {@code exprFields}
 * @return Iterable over the substitutions; empty if the predicate
 *     references no fields
 */
Iterable<Mapping> mappings(final RexNode predicate) {
  final ImmutableBitSet fields = exprFields.get(predicate);
  if (fields.isEmpty()) {
    return Collections.emptyList();
  }
  return () -> new ExprsItr(fields);
}
/**
 * Decides whether a candidate predicate is worth recording.
 *
 * @param inferringFields Fields the candidate may reference
 * @param allExprs Expressions already known
 * @param tr Candidate predicate
 * @return true if {@code tr} references only {@code inferringFields}, is not
 *     already in {@code allExprs}, and is not trivially true
 */
private boolean checkTarget(ImmutableBitSet inferringFields,
    Set<RexNode> allExprs, RexNode tr) {
  return inferringFields.contains(RelOptUtil.InputFinder.bits(tr))
      && !allExprs.contains(tr)
      && !isAlwaysTrue(tr);
}
/**
 * Records that two field positions are equal, in both directions.
 *
 * @param p1 First field position
 * @param p2 Second field position
 */
private void markAsEquivalent(int p1, int p2) {
  final BitSet equalToFirst = equivalence.get(p1);
  equalToFirst.set(p2);
  final BitSet equalToSecond = equivalence.get(p2);
  equalToSecond.set(p1);
}
/**
 * Composes the non-null expressions of {@code exprs} into a conjunction.
 *
 * @param rexBuilder Rex builder
 * @param exprs Expressions to AND together; null elements are ignored
 * @return Result of {@code RexUtil.composeConjunction} on the non-null
 *     elements
 */
@Nonnull RexNode compose(RexBuilder rexBuilder, Iterable<RexNode> exprs) {
  final Iterable<RexNode> nonNullExprs =
      Linq4j.asEnumerable(exprs).where(Objects::nonNull);
  return RexUtil.composeConjunction(rexBuilder, nonNullExprs);
}
/**
 * Find expressions of the form 'col_x = col_y'.
 *
 * <p>Each such call marks its two columns as equivalent and is added to
 * {@code equalityPredicates}.
 */
class EquivalenceFinder extends RexVisitorImpl<Void> {
  protected EquivalenceFinder() {
    super(true);
  }

  @Override public Void visitCall(RexCall call) {
    if (call.getOperator().getKind() == SqlKind.EQUALS) {
      final int lPos = pos(call.getOperands().get(0));
      final int rPos = pos(call.getOperands().get(1));
      // pos() yields -1 for anything other than a plain input reference.
      if (lPos != -1 && rPos != -1) {
        markAsEquivalent(lPos, rPos);
        equalityPredicates.add(call);
      }
    }
    return null;
  }
}
/**
 * Given an expression returns all the possible substitutions.
 *
 * <p>For example, for an expression 'a + b + c' and the following
 * equivalences:
 * <pre>
 * a : {a, b}
 * b : {a, b}
 * c : {c, e}
 * </pre>
 *
 * <p>The following Mappings will be returned:
 * <pre>
 * {a &rarr; a, b &rarr; a, c &rarr; c}
 * {a &rarr; a, b &rarr; a, c &rarr; e}
 * {a &rarr; a, b &rarr; b, c &rarr; c}
 * {a &rarr; a, b &rarr; b, c &rarr; e}
 * {a &rarr; b, b &rarr; a, c &rarr; c}
 * {a &rarr; b, b &rarr; a, c &rarr; e}
 * {a &rarr; b, b &rarr; b, c &rarr; c}
 * {a &rarr; b, b &rarr; b, c &rarr; e}
 * </pre>
 *
 * <p>which imply the following inferences:
 * <pre>
 * a + a + c
 * a + a + e
 * a + b + c
 * a + b + e
 * b + a + c
 * b + a + e
 * b + b + c
 * b + b + e
 * </pre>
 *
 * <p>The same {@link Mapping} instance is returned by every call to
 * {@link #next()} and is updated in place, so use each mapping before
 * advancing. {@link #hasNext()} can be called repeatedly without advancing.
 */
class ExprsItr implements Iterator<Mapping> {
  /** Field positions referenced by the expression, ascending. */
  final int[] columns;
  /** Equivalence set of each entry of {@link #columns}. */
  final BitSet[] columnSets;
  /** Per column, the bit index from which the next substitute is searched. */
  final int[] iterationIdx;
  /** Mapping to hand out next; null once the iteration is exhausted. */
  Mapping nextMapping;
  /** True until the first mapping has been computed. */
  boolean firstCall;
  /** True while {@link #nextMapping} is computed but not yet returned. */
  private boolean pending;
  /** True once no further mapping exists. */
  private boolean exhausted;

  ExprsItr(ImmutableBitSet fields) {
    nextMapping = null;
    final int fieldCount = fields.cardinality();
    columns = new int[fieldCount];
    columnSets = new BitSet[fieldCount];
    iterationIdx = new int[fieldCount];
    int j = 0;
    for (int i = fields.nextSetBit(0); i >= 0; i = fields.nextSetBit(i + 1)) {
      columns[j] = i;
      columnSets[j] = equivalence.get(i);
      iterationIdx[j] = 0;
      j++;
    }
    firstCall = true;
  }

  @Override public boolean hasNext() {
    if (exhausted) {
      return false;
    }
    if (!pending) {
      // Advance at most once per element, however often hasNext() is called.
      if (firstCall) {
        initializeMapping();
        firstCall = false;
      } else {
        computeNextMapping(iterationIdx.length - 1);
      }
      if (nextMapping == null) {
        exhausted = true;
      } else {
        pending = true;
      }
    }
    return !exhausted;
  }

  @Override public Mapping next() {
    if (!hasNext()) {
      throw new NoSuchElementException();
    }
    pending = false;
    return nextMapping;
  }

  @Override public void remove() {
    throw new UnsupportedOperationException();
  }

  /**
   * Advances the substitution at {@code level}; when that column's
   * equivalence set is used up, resets it to its first member and carries
   * over to the previous level. Running out at level 0 ends the iteration.
   */
  private void computeNextMapping(int level) {
    if (level < 0) {
      // No columns to vary: nothing follows the initial mapping.
      nextMapping = null;
      return;
    }
    final int t = columnSets[level].nextSetBit(iterationIdx[level]);
    if (t < 0) {
      if (level == 0) {
        nextMapping = null;
      } else {
        final int tmp = columnSets[level].nextSetBit(0);
        nextMapping.set(columns[level], tmp);
        iterationIdx[level] = tmp + 1;
        computeNextMapping(level - 1);
      }
    } else {
      nextMapping.set(columns[level], t);
      iterationIdx[level] = t + 1;
    }
  }

  /**
   * Builds the first mapping, sending every column to the first member of
   * its equivalence set; leaves {@link #nextMapping} null if a set is empty.
   */
  private void initializeMapping() {
    nextMapping = Mappings.create(MappingType.PARTIAL_FUNCTION,
        nSysFields + nFieldsLeft + nFieldsRight,
        nSysFields + nFieldsLeft + nFieldsRight);
    for (int i = 0; i < columnSets.length; i++) {
      final BitSet c = columnSets[i];
      final int t = c.nextSetBit(iterationIdx[i]);
      if (t < 0) {
        nextMapping = null;
        return;
      }
      nextMapping.set(columns[i], t);
      iterationIdx[i] = t + 1;
    }
  }
}
/**
 * Returns the field index of a plain input reference.
 *
 * @param expr Expression to inspect
 * @return Index of the referenced field if {@code expr} is a
 *     {@link RexInputRef}, otherwise -1
 */
private int pos(RexNode expr) {
  return expr instanceof RexInputRef
      ? ((RexInputRef) expr).getIndex()
      : -1;
}
/**
 * Tells whether a predicate is trivially true.
 *
 * <p>An {@code EQUALS} call is judged solely by whether both operands are
 * references to the same input field; any other expression is judged by its
 * own {@code isAlwaysTrue()}.
 *
 * @param predicate Predicate to test
 * @return Whether the predicate is considered always true
 */
private boolean isAlwaysTrue(RexNode predicate) {
  if (!(predicate instanceof RexCall)) {
    return predicate.isAlwaysTrue();
  }
  final RexCall call = (RexCall) predicate;
  if (call.getOperator().getKind() != SqlKind.EQUALS) {
    return predicate.isAlwaysTrue();
  }
  final int leftIndex = pos(call.getOperands().get(0));
  final int rightIndex = pos(call.getOperands().get(1));
  return leftIndex != -1 && leftIndex == rightIndex;
}
}
}