com.hazelcast.org.apache.calcite.rel.rules.JoinToMultiJoinRule Maven / Gradle / Ivy
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.hazelcast.org.apache.calcite.rel.rules;
import com.hazelcast.org.apache.calcite.plan.RelOptRuleCall;
import com.hazelcast.org.apache.calcite.plan.RelOptUtil;
import com.hazelcast.org.apache.calcite.plan.RelRule;
import com.hazelcast.org.apache.calcite.rel.RelNode;
import com.hazelcast.org.apache.calcite.rel.core.Join;
import com.hazelcast.org.apache.calcite.rel.core.JoinRelType;
import com.hazelcast.org.apache.calcite.rel.logical.LogicalJoin;
import com.hazelcast.org.apache.calcite.rel.type.RelDataTypeField;
import com.hazelcast.org.apache.calcite.rex.RexBuilder;
import com.hazelcast.org.apache.calcite.rex.RexInputRef;
import com.hazelcast.org.apache.calcite.rex.RexNode;
import com.hazelcast.org.apache.calcite.rex.RexUtil;
import com.hazelcast.org.apache.calcite.rex.RexVisitorImpl;
import com.hazelcast.org.apache.calcite.tools.RelBuilderFactory;
import com.hazelcast.org.apache.calcite.util.ImmutableBitSet;
import com.hazelcast.org.apache.calcite.util.ImmutableIntList;
import com.hazelcast.org.apache.calcite.util.Pair;
import com.hazelcast.com.google.common.collect.ImmutableMap;
import com.hazelcast.org.checkerframework.checker.nullness.qual.Nullable;
import org.immutables.value.Value;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static java.util.Objects.requireNonNull;
/**
* Planner rule to flatten a tree of
* {@link com.hazelcast.org.apache.calcite.rel.logical.LogicalJoin}s
* into a single {@link MultiJoin} with N inputs.
*
* An input is not flattened if
* the input is a null generating input in an outer join, i.e., either input in
* a full outer join, the right hand side of a left outer join, or the left hand
* side of a right outer join.
*
*
* <p>Join conditions are also pulled up from the inputs into the topmost
* {@link MultiJoin},
* unless the input corresponds to a null generating input in an outer join.
*
*
* <p>Outer join information is also stored in the {@link MultiJoin}. A
* boolean flag indicates if the join is a full outer join, and in the case of
* left and right outer joins, the join type and outer join conditions are
* stored in arrays in the {@link MultiJoin}. This outer join information is
* associated with the null generating input in the outer join. So, in the case
* of a left outer join between A and B, the information is associated with B,
* not A.
*
*
* <p>Here are examples of the {@link MultiJoin}s constructed after this rule
* has been applied to the following join trees.
*
*
* - A JOIN B → MJ(A, B)
*
*
- A JOIN B JOIN C → MJ(A, B, C)
*
*
- A LEFT JOIN B → MJ(A, B), left outer join on input#1
*
*
- A RIGHT JOIN B → MJ(A, B), right outer join on input#0
*
*
- A FULL JOIN B → MJ[full](A, B)
*
*
- A LEFT JOIN (B JOIN C) → MJ(A, MJ(B, C)), left outer join on
* input#1 in the outermost MultiJoin
*
*
- (A JOIN B) LEFT JOIN C → MJ(A, B, C), left outer join on input#2
*
*
- (A LEFT JOIN B) JOIN C → MJ(MJ(A, B), C), left outer join on input#1
* of the inner MultiJoin TODO
*
*
- A LEFT JOIN (B FULL JOIN C) → MJ(A, MJ[full](B, C)), left outer join
* on input#1 in the outermost MultiJoin
*
*
- (A LEFT JOIN B) FULL JOIN (C RIGHT JOIN D) →
* MJ[full](MJ(A, B), MJ(C, D)), left outer join on input #1 in the first
* inner MultiJoin and right outer join on input#0 in the second inner
* MultiJoin
*
*
* The constructor is parameterized to allow any sub-class of
* {@link com.hazelcast.org.apache.calcite.rel.core.Join}, not just
* {@link com.hazelcast.org.apache.calcite.rel.logical.LogicalJoin}.
*
* @see com.hazelcast.org.apache.calcite.rel.rules.FilterMultiJoinMergeRule
* @see com.hazelcast.org.apache.calcite.rel.rules.ProjectMultiJoinMergeRule
* @see CoreRules#JOIN_TO_MULTI_JOIN
*/
@Value.Enclosing
public class JoinToMultiJoinRule
extends RelRule
implements TransformationRule {
/** Creates a JoinToMultiJoinRule. */
protected JoinToMultiJoinRule(Config config) {
super(config);
}
@Deprecated // to be removed before 2.0
public JoinToMultiJoinRule(Class extends Join> clazz) {
this(Config.DEFAULT.withOperandFor(clazz));
}
@Deprecated // to be removed before 2.0
public JoinToMultiJoinRule(Class extends Join> joinClass,
RelBuilderFactory relBuilderFactory) {
this(Config.DEFAULT.withRelBuilderFactory(relBuilderFactory)
.as(Config.class)
.withOperandFor(joinClass));
}
//~ Methods ----------------------------------------------------------------
@Override public boolean matches(RelOptRuleCall call) {
final Join origJoin = call.rel(0);
return origJoin.getJoinType().projectsRight();
}
@Override public void onMatch(RelOptRuleCall call) {
final Join origJoin = call.rel(0);
final RelNode left = call.rel(1);
final RelNode right = call.rel(2);
// combine the children MultiJoin inputs into an array of inputs
// for the new MultiJoin
final List<@Nullable ImmutableBitSet> projFieldsList = new ArrayList<>();
final List joinFieldRefCountsList = new ArrayList<>();
final List newInputs =
combineInputs(
origJoin,
left,
right,
projFieldsList,
joinFieldRefCountsList);
// combine the outer join information from the left and right
// inputs, and include the outer join information from the current
// join, if it's a left/right outer join
final List> joinSpecs = new ArrayList<>();
combineOuterJoins(
origJoin,
newInputs,
left,
right,
joinSpecs);
// pull up the join filters from the children MultiJoinRels and
// combine them with the join filter associated with this LogicalJoin to
// form the join filter for the new MultiJoin
List<@Nullable RexNode> newJoinFilters = combineJoinFilters(origJoin, left, right);
// add on the join field reference counts for the join condition
// associated with this LogicalJoin
final ImmutableMap newJoinFieldRefCountsMap =
addOnJoinFieldRefCounts(newInputs,
origJoin.getRowType().getFieldCount(),
origJoin.getCondition(),
joinFieldRefCountsList);
List<@Nullable RexNode> newPostJoinFilters =
combinePostJoinFilters(origJoin, left, right);
final RexBuilder rexBuilder = origJoin.getCluster().getRexBuilder();
RelNode multiJoin =
new MultiJoin(
origJoin.getCluster(),
newInputs,
RexUtil.composeConjunction(rexBuilder, newJoinFilters),
origJoin.getRowType(),
origJoin.getJoinType() == JoinRelType.FULL,
Pair.right(joinSpecs),
Pair.left(joinSpecs),
projFieldsList,
newJoinFieldRefCountsMap,
RexUtil.composeConjunction(rexBuilder, newPostJoinFilters, true));
call.transformTo(multiJoin);
}
/**
* Combines the inputs into a LogicalJoin into an array of inputs.
*
* @param join original join
* @param left left input into join
* @param right right input into join
* @param projFieldsList returns a list of the new combined projection
* fields
* @param joinFieldRefCountsList returns a list of the new combined join
* field reference counts
* @return combined left and right inputs in an array
*/
private static List combineInputs(
Join join,
RelNode left,
RelNode right,
List<@Nullable ImmutableBitSet> projFieldsList,
List joinFieldRefCountsList) {
final List newInputs = new ArrayList<>();
// leave the null generating sides of an outer join intact; don't
// pull up those children inputs into the array we're constructing
if (canCombine(left, join.getJoinType().generatesNullsOnLeft())) {
final MultiJoin leftMultiJoin = (MultiJoin) left;
for (int i = 0; i < left.getInputs().size(); i++) {
newInputs.add(leftMultiJoin.getInput(i));
projFieldsList.add(leftMultiJoin.getProjFields().get(i));
joinFieldRefCountsList.add(
requireNonNull(leftMultiJoin.getJoinFieldRefCountsMap().get(i)).toIntArray());
}
} else {
newInputs.add(left);
projFieldsList.add(null);
joinFieldRefCountsList.add(
new int[left.getRowType().getFieldCount()]);
}
if (canCombine(right, join.getJoinType().generatesNullsOnRight())) {
final MultiJoin rightMultiJoin = (MultiJoin) right;
for (int i = 0; i < right.getInputs().size(); i++) {
newInputs.add(rightMultiJoin.getInput(i));
projFieldsList.add(
rightMultiJoin.getProjFields().get(i));
joinFieldRefCountsList.add(
requireNonNull(rightMultiJoin.getJoinFieldRefCountsMap().get(i)).toIntArray());
}
} else {
newInputs.add(right);
projFieldsList.add(null);
joinFieldRefCountsList.add(
new int[right.getRowType().getFieldCount()]);
}
return newInputs;
}
/**
* Combines the outer join conditions and join types from the left and right
* join inputs. If the join itself is either a left or right outer join,
* then the join condition corresponding to the join is also set in the
* position corresponding to the null-generating input into the join. The
* join type is also set.
*
* @param joinRel join rel
* @param combinedInputs the combined inputs to the join
* @param left left child of the joinrel
* @param right right child of the joinrel
* @param joinSpecs the list where the join types and conditions will be
* copied
*/
private static void combineOuterJoins(
Join joinRel,
@SuppressWarnings("unused") List combinedInputs,
RelNode left,
RelNode right,
List> joinSpecs) {
JoinRelType joinType = joinRel.getJoinType();
boolean leftCombined =
canCombine(left, joinType.generatesNullsOnLeft());
boolean rightCombined =
canCombine(right, joinType.generatesNullsOnRight());
switch (joinType) {
case LEFT:
if (leftCombined) {
copyOuterJoinInfo(
(MultiJoin) left,
joinSpecs,
0,
null,
null);
} else {
joinSpecs.add(Pair.of(JoinRelType.INNER, (@Nullable RexNode) null));
}
joinSpecs.add(Pair.of(joinType, joinRel.getCondition()));
break;
case RIGHT:
joinSpecs.add(Pair.of(joinType, joinRel.getCondition()));
if (rightCombined) {
copyOuterJoinInfo(
(MultiJoin) right,
joinSpecs,
left.getRowType().getFieldCount(),
right.getRowType().getFieldList(),
joinRel.getRowType().getFieldList());
} else {
joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null));
}
break;
default:
if (leftCombined) {
copyOuterJoinInfo(
(MultiJoin) left,
joinSpecs,
0,
null,
null);
} else {
joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null));
}
if (rightCombined) {
copyOuterJoinInfo(
(MultiJoin) right,
joinSpecs,
left.getRowType().getFieldCount(),
right.getRowType().getFieldList(),
joinRel.getRowType().getFieldList());
} else {
joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null));
}
}
}
/**
* Copies outer join data from a source MultiJoin to a new set of arrays.
* Also adjusts the conditions to reflect the new position of an input if
* that input ends up being shifted to the right.
*
* @param multiJoin the source MultiJoin
* @param destJoinSpecs the list where the join types and conditions will
* be copied
* @param adjustmentAmount if > 0, the amount the RexInputRefs in the join
* conditions need to be adjusted by
* @param srcFields the source fields that the original join conditions
* are referencing
* @param destFields the destination fields that the new join conditions
*/
private static void copyOuterJoinInfo(
MultiJoin multiJoin,
List> destJoinSpecs,
int adjustmentAmount,
@Nullable List srcFields,
@Nullable List destFields) {
final List> srcJoinSpecs =
Pair.zip(
multiJoin.getJoinTypes(),
multiJoin.getOuterJoinConditions());
if (adjustmentAmount == 0) {
destJoinSpecs.addAll(srcJoinSpecs);
} else {
assert srcFields != null;
assert destFields != null;
int nFields = srcFields.size();
int[] adjustments = new int[nFields];
for (int idx = 0; idx < nFields; idx++) {
adjustments[idx] = adjustmentAmount;
}
for (Pair src
: srcJoinSpecs) {
destJoinSpecs.add(
Pair.of(
src.left,
src.right == null
? null
: src.right.accept(
new RelOptUtil.RexInputConverter(
multiJoin.getCluster().getRexBuilder(),
srcFields, destFields, adjustments))));
}
}
}
/**
* Combines the join filters from the left and right inputs (if they are
* MultiJoinRels) with the join filter in the joinrel into a single AND'd
* join filter, unless the inputs correspond to null generating inputs in an
* outer join.
*
* @param join Join
* @param left Left input of the join
* @param right Right input of the join
* @return combined join filters AND-ed together
*/
private static List<@Nullable RexNode> combineJoinFilters(
Join join,
RelNode left,
RelNode right) {
JoinRelType joinType = join.getJoinType();
// AND the join condition if this isn't a left or right outer join;
// in those cases, the outer join condition is already tracked
// separately
final List<@Nullable RexNode> filters = new ArrayList<>();
if ((joinType != JoinRelType.LEFT) && (joinType != JoinRelType.RIGHT)) {
filters.add(join.getCondition());
}
if (canCombine(left, joinType.generatesNullsOnLeft())) {
filters.add(((MultiJoin) left).getJoinFilter());
}
// Need to adjust the RexInputs of the right child, since
// those need to shift over to the right
if (canCombine(right, joinType.generatesNullsOnRight())) {
MultiJoin multiJoin = (MultiJoin) right;
filters.add(
shiftRightFilter(join, left, multiJoin,
multiJoin.getJoinFilter()));
}
return filters;
}
/**
* Returns whether an input can be merged into a given relational expression
* without changing semantics.
*
* @param input input into a join
* @param nullGenerating true if the input is null generating
* @return true if the input can be combined into a parent MultiJoin
*/
private static boolean canCombine(RelNode input, boolean nullGenerating) {
return input instanceof MultiJoin
&& !((MultiJoin) input).isFullOuterJoin()
&& !((MultiJoin) input).containsOuter()
&& !nullGenerating;
}
/**
* Shifts a filter originating from the right child of the LogicalJoin to the
* right, to reflect the filter now being applied on the resulting
* MultiJoin.
*
* @param joinRel the original LogicalJoin
* @param left the left child of the LogicalJoin
* @param right the right child of the LogicalJoin
* @param rightFilter the filter originating from the right child
* @return the adjusted right filter
*/
private static @Nullable RexNode shiftRightFilter(
Join joinRel,
RelNode left,
MultiJoin right,
@Nullable RexNode rightFilter) {
if (rightFilter == null) {
return null;
}
int nFieldsOnLeft = left.getRowType().getFieldList().size();
int nFieldsOnRight = right.getRowType().getFieldList().size();
int[] adjustments = new int[nFieldsOnRight];
for (int i = 0; i < nFieldsOnRight; i++) {
adjustments[i] = nFieldsOnLeft;
}
rightFilter =
rightFilter.accept(
new RelOptUtil.RexInputConverter(
joinRel.getCluster().getRexBuilder(),
right.getRowType().getFieldList(),
joinRel.getRowType().getFieldList(),
adjustments));
return rightFilter;
}
/**
* Adds on to the existing join condition reference counts the references
* from the new join condition.
*
* @param multiJoinInputs inputs into the new MultiJoin
* @param nTotalFields total number of fields in the MultiJoin
* @param joinCondition the new join condition
* @param origJoinFieldRefCounts existing join condition reference counts
*
* @return Map containing the new join condition
*/
private static ImmutableMap addOnJoinFieldRefCounts(
List multiJoinInputs,
int nTotalFields,
RexNode joinCondition,
List origJoinFieldRefCounts) {
// count the input references in the join condition
int[] joinCondRefCounts = new int[nTotalFields];
joinCondition.accept(new InputReferenceCounter(joinCondRefCounts));
// first, make a copy of the ref counters
final Map refCountsMap = new HashMap<>();
int nInputs = multiJoinInputs.size();
int currInput = 0;
for (int[] origRefCounts : origJoinFieldRefCounts) {
refCountsMap.put(
currInput,
origRefCounts.clone());
currInput++;
}
// add on to the counts for each input into the MultiJoin the
// reference counts computed for the current join condition
currInput = -1;
int startField = 0;
int nFields = 0;
for (int i = 0; i < nTotalFields; i++) {
if (joinCondRefCounts[i] == 0) {
continue;
}
while (i >= (startField + nFields)) {
startField += nFields;
currInput++;
assert currInput < nInputs;
nFields =
multiJoinInputs.get(currInput).getRowType().getFieldCount();
}
final int key = currInput;
int[] refCounts = requireNonNull(refCountsMap.get(key),
() -> "refCountsMap.get(currInput) for " + key);
refCounts[i - startField] += joinCondRefCounts[i];
}
final ImmutableMap.Builder builder =
ImmutableMap.builder();
for (Map.Entry entry : refCountsMap.entrySet()) {
builder.put(entry.getKey(), ImmutableIntList.of(entry.getValue()));
}
return builder.build();
}
/**
* Combines the post-join filters from the left and right inputs (if they
* are MultiJoinRels) into a single AND'd filter.
*
* @param joinRel the original LogicalJoin
* @param left left child of the LogicalJoin
* @param right right child of the LogicalJoin
* @return combined post-join filters AND'd together
*/
private static List<@Nullable RexNode> combinePostJoinFilters(
Join joinRel,
RelNode left,
RelNode right) {
final List<@Nullable RexNode> filters = new ArrayList<>();
if (right instanceof MultiJoin) {
final MultiJoin multiRight = (MultiJoin) right;
filters.add(
shiftRightFilter(joinRel, left, multiRight,
multiRight.getPostJoinFilter()));
}
if (left instanceof MultiJoin) {
filters.add(((MultiJoin) left).getPostJoinFilter());
}
return filters;
}
//~ Inner Classes ----------------------------------------------------------
/**
* Visitor that keeps a reference count of the inputs used by an expression.
*/
private static class InputReferenceCounter extends RexVisitorImpl {
private final int[] refCounts;
InputReferenceCounter(int[] refCounts) {
super(true);
this.refCounts = refCounts;
}
@Override public Void visitInputRef(RexInputRef inputRef) {
refCounts[inputRef.getIndex()]++;
return null;
}
}
/** Rule configuration. */
@Value.Immutable
public interface Config extends RelRule.Config {
Config DEFAULT = ImmutableJoinToMultiJoinRule.Config.of()
.withOperandFor(LogicalJoin.class);
@Override default JoinToMultiJoinRule toRule() {
return new JoinToMultiJoinRule(this);
}
/** Defines an operand tree for the given classes. */
default Config withOperandFor(Class extends Join> joinClass) {
return withOperandSupplier(b0 ->
b0.operand(joinClass).inputs(
b1 -> b1.operand(RelNode.class).anyInputs(),
b2 -> b2.operand(RelNode.class).anyInputs()))
.as(Config.class);
}
}
}