All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.hazelcast.org.apache.calcite.rel.rules.JoinToMultiJoinRule Maven / Gradle / Ivy

There is a newer version: 5.4.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to you under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.hazelcast.org.apache.calcite.rel.rules;

import com.hazelcast.org.apache.calcite.plan.RelOptRuleCall;
import com.hazelcast.org.apache.calcite.plan.RelOptUtil;
import com.hazelcast.org.apache.calcite.plan.RelRule;
import com.hazelcast.org.apache.calcite.rel.RelNode;
import com.hazelcast.org.apache.calcite.rel.core.Join;
import com.hazelcast.org.apache.calcite.rel.core.JoinRelType;
import com.hazelcast.org.apache.calcite.rel.logical.LogicalJoin;
import com.hazelcast.org.apache.calcite.rel.type.RelDataTypeField;
import com.hazelcast.org.apache.calcite.rex.RexBuilder;
import com.hazelcast.org.apache.calcite.rex.RexInputRef;
import com.hazelcast.org.apache.calcite.rex.RexNode;
import com.hazelcast.org.apache.calcite.rex.RexUtil;
import com.hazelcast.org.apache.calcite.rex.RexVisitorImpl;
import com.hazelcast.org.apache.calcite.tools.RelBuilderFactory;
import com.hazelcast.org.apache.calcite.util.ImmutableBitSet;
import com.hazelcast.org.apache.calcite.util.ImmutableIntList;
import com.hazelcast.org.apache.calcite.util.Pair;

import com.hazelcast.com.google.common.collect.ImmutableMap;

import com.hazelcast.org.checkerframework.checker.nullness.qual.Nullable;
import org.immutables.value.Value;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static java.util.Objects.requireNonNull;

/**
 * Planner rule to flatten a tree of
 * {@link com.hazelcast.org.apache.calcite.rel.logical.LogicalJoin}s
 * into a single {@link MultiJoin} with N inputs.
 *
 * 

An input is not flattened if * the input is a null generating input in an outer join, i.e., either input in * a full outer join, the right hand side of a left outer join, or the left hand * side of a right outer join. * *

Join conditions are also pulled up from the inputs into the topmost * {@link MultiJoin}, * unless the input corresponds to a null generating input in an outer join, * *

Outer join information is also stored in the {@link MultiJoin}. A * boolean flag indicates if the join is a full outer join, and in the case of * left and right outer joins, the join type and outer join conditions are * stored in arrays in the {@link MultiJoin}. This outer join information is * associated with the null generating input in the outer join. So, in the case * of a a left outer join between A and B, the information is associated with B, * not A. * *

Here are examples of the {@link MultiJoin}s constructed after this rule * has been applied on following join trees. * *

    *
  • A JOIN B → MJ(A, B) * *
  • A JOIN B JOIN C → MJ(A, B, C) * *
  • A LEFT JOIN B → MJ(A, B), left outer join on input#1 * *
  • A RIGHT JOIN B → MJ(A, B), right outer join on input#0 * *
  • A FULL JOIN B → MJ[full](A, B) * *
  • A LEFT JOIN (B JOIN C) → MJ(A, MJ(B, C))), left outer join on * input#1 in the outermost MultiJoin * *
  • (A JOIN B) LEFT JOIN C → MJ(A, B, C), left outer join on input#2 * *
  • (A LEFT JOIN B) JOIN C → MJ(MJ(A, B), C), left outer join on input#1 * of the inner MultiJoin TODO * *
  • A LEFT JOIN (B FULL JOIN C) → MJ(A, MJ[full](B, C)), left outer join * on input#1 in the outermost MultiJoin * *
  • (A LEFT JOIN B) FULL JOIN (C RIGHT JOIN D) → * MJ[full](MJ(A, B), MJ(C, D)), left outer join on input #1 in the first * inner MultiJoin and right outer join on input#0 in the second inner * MultiJoin *
* *

The constructor is parameterized to allow any sub-class of * {@link com.hazelcast.org.apache.calcite.rel.core.Join}, not just * {@link com.hazelcast.org.apache.calcite.rel.logical.LogicalJoin}.

* * @see com.hazelcast.org.apache.calcite.rel.rules.FilterMultiJoinMergeRule * @see com.hazelcast.org.apache.calcite.rel.rules.ProjectMultiJoinMergeRule * @see CoreRules#JOIN_TO_MULTI_JOIN */ @Value.Enclosing public class JoinToMultiJoinRule extends RelRule implements TransformationRule { /** Creates a JoinToMultiJoinRule. */ protected JoinToMultiJoinRule(Config config) { super(config); } @Deprecated // to be removed before 2.0 public JoinToMultiJoinRule(Class clazz) { this(Config.DEFAULT.withOperandFor(clazz)); } @Deprecated // to be removed before 2.0 public JoinToMultiJoinRule(Class joinClass, RelBuilderFactory relBuilderFactory) { this(Config.DEFAULT.withRelBuilderFactory(relBuilderFactory) .as(Config.class) .withOperandFor(joinClass)); } //~ Methods ---------------------------------------------------------------- @Override public boolean matches(RelOptRuleCall call) { final Join origJoin = call.rel(0); return origJoin.getJoinType().projectsRight(); } @Override public void onMatch(RelOptRuleCall call) { final Join origJoin = call.rel(0); final RelNode left = call.rel(1); final RelNode right = call.rel(2); // combine the children MultiJoin inputs into an array of inputs // for the new MultiJoin final List<@Nullable ImmutableBitSet> projFieldsList = new ArrayList<>(); final List joinFieldRefCountsList = new ArrayList<>(); final List newInputs = combineInputs( origJoin, left, right, projFieldsList, joinFieldRefCountsList); // combine the outer join information from the left and right // inputs, and include the outer join information from the current // join, if it's a left/right outer join final List> joinSpecs = new ArrayList<>(); combineOuterJoins( origJoin, newInputs, left, right, joinSpecs); // pull up the join filters from the children MultiJoinRels and // combine them with the join filter associated with this LogicalJoin to // form the join filter for the new MultiJoin List<@Nullable RexNode> newJoinFilters = combineJoinFilters(origJoin, left, right); // add on the join field reference counts for the join condition // associated with this LogicalJoin final ImmutableMap newJoinFieldRefCountsMap = addOnJoinFieldRefCounts(newInputs, origJoin.getRowType().getFieldCount(), origJoin.getCondition(), joinFieldRefCountsList); List<@Nullable RexNode> newPostJoinFilters = combinePostJoinFilters(origJoin, left, right); final RexBuilder rexBuilder = origJoin.getCluster().getRexBuilder(); RelNode multiJoin = new MultiJoin( origJoin.getCluster(), newInputs, RexUtil.composeConjunction(rexBuilder, newJoinFilters), origJoin.getRowType(), origJoin.getJoinType() == JoinRelType.FULL, Pair.right(joinSpecs), Pair.left(joinSpecs), projFieldsList, newJoinFieldRefCountsMap, RexUtil.composeConjunction(rexBuilder, newPostJoinFilters, true)); call.transformTo(multiJoin); } /** * Combines the inputs into a LogicalJoin into an array of inputs. * * @param join original join * @param left left input into join * @param right right input into join * @param projFieldsList returns a list of the new combined projection * fields * @param joinFieldRefCountsList returns a list of the new combined join * field reference counts * @return combined left and right inputs in an array */ private static List combineInputs( Join join, RelNode left, RelNode right, List<@Nullable ImmutableBitSet> projFieldsList, List joinFieldRefCountsList) { final List newInputs = new ArrayList<>(); // leave the null generating sides of an outer join intact; don't // pull up those children inputs into the array we're constructing if (canCombine(left, join.getJoinType().generatesNullsOnLeft())) { final MultiJoin leftMultiJoin = (MultiJoin) left; for (int i = 0; i < left.getInputs().size(); i++) { newInputs.add(leftMultiJoin.getInput(i)); projFieldsList.add(leftMultiJoin.getProjFields().get(i)); joinFieldRefCountsList.add( requireNonNull(leftMultiJoin.getJoinFieldRefCountsMap().get(i)).toIntArray()); } } else { newInputs.add(left); projFieldsList.add(null); joinFieldRefCountsList.add( new int[left.getRowType().getFieldCount()]); } if (canCombine(right, join.getJoinType().generatesNullsOnRight())) { final MultiJoin rightMultiJoin = (MultiJoin) right; for (int i = 0; i < right.getInputs().size(); i++) { newInputs.add(rightMultiJoin.getInput(i)); projFieldsList.add( rightMultiJoin.getProjFields().get(i)); joinFieldRefCountsList.add( requireNonNull(rightMultiJoin.getJoinFieldRefCountsMap().get(i)).toIntArray()); } } else { newInputs.add(right); projFieldsList.add(null); joinFieldRefCountsList.add( new int[right.getRowType().getFieldCount()]); } return newInputs; } /** * Combines the outer join conditions and join types from the left and right * join inputs. If the join itself is either a left or right outer join, * then the join condition corresponding to the join is also set in the * position corresponding to the null-generating input into the join. The * join type is also set. * * @param joinRel join rel * @param combinedInputs the combined inputs to the join * @param left left child of the joinrel * @param right right child of the joinrel * @param joinSpecs the list where the join types and conditions will be * copied */ private static void combineOuterJoins( Join joinRel, @SuppressWarnings("unused") List combinedInputs, RelNode left, RelNode right, List> joinSpecs) { JoinRelType joinType = joinRel.getJoinType(); boolean leftCombined = canCombine(left, joinType.generatesNullsOnLeft()); boolean rightCombined = canCombine(right, joinType.generatesNullsOnRight()); switch (joinType) { case LEFT: if (leftCombined) { copyOuterJoinInfo( (MultiJoin) left, joinSpecs, 0, null, null); } else { joinSpecs.add(Pair.of(JoinRelType.INNER, (@Nullable RexNode) null)); } joinSpecs.add(Pair.of(joinType, joinRel.getCondition())); break; case RIGHT: joinSpecs.add(Pair.of(joinType, joinRel.getCondition())); if (rightCombined) { copyOuterJoinInfo( (MultiJoin) right, joinSpecs, left.getRowType().getFieldCount(), right.getRowType().getFieldList(), joinRel.getRowType().getFieldList()); } else { joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null)); } break; default: if (leftCombined) { copyOuterJoinInfo( (MultiJoin) left, joinSpecs, 0, null, null); } else { joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null)); } if (rightCombined) { copyOuterJoinInfo( (MultiJoin) right, joinSpecs, left.getRowType().getFieldCount(), right.getRowType().getFieldList(), joinRel.getRowType().getFieldList()); } else { joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null)); } } } /** * Copies outer join data from a source MultiJoin to a new set of arrays. * Also adjusts the conditions to reflect the new position of an input if * that input ends up being shifted to the right. * * @param multiJoin the source MultiJoin * @param destJoinSpecs the list where the join types and conditions will * be copied * @param adjustmentAmount if > 0, the amount the RexInputRefs in the join * conditions need to be adjusted by * @param srcFields the source fields that the original join conditions * are referencing * @param destFields the destination fields that the new join conditions */ private static void copyOuterJoinInfo( MultiJoin multiJoin, List> destJoinSpecs, int adjustmentAmount, @Nullable List srcFields, @Nullable List destFields) { final List> srcJoinSpecs = Pair.zip( multiJoin.getJoinTypes(), multiJoin.getOuterJoinConditions()); if (adjustmentAmount == 0) { destJoinSpecs.addAll(srcJoinSpecs); } else { assert srcFields != null; assert destFields != null; int nFields = srcFields.size(); int[] adjustments = new int[nFields]; for (int idx = 0; idx < nFields; idx++) { adjustments[idx] = adjustmentAmount; } for (Pair src : srcJoinSpecs) { destJoinSpecs.add( Pair.of( src.left, src.right == null ? null : src.right.accept( new RelOptUtil.RexInputConverter( multiJoin.getCluster().getRexBuilder(), srcFields, destFields, adjustments)))); } } } /** * Combines the join filters from the left and right inputs (if they are * MultiJoinRels) with the join filter in the joinrel into a single AND'd * join filter, unless the inputs correspond to null generating inputs in an * outer join. * * @param join Join * @param left Left input of the join * @param right Right input of the join * @return combined join filters AND-ed together */ private static List<@Nullable RexNode> combineJoinFilters( Join join, RelNode left, RelNode right) { JoinRelType joinType = join.getJoinType(); // AND the join condition if this isn't a left or right outer join; // in those cases, the outer join condition is already tracked // separately final List<@Nullable RexNode> filters = new ArrayList<>(); if ((joinType != JoinRelType.LEFT) && (joinType != JoinRelType.RIGHT)) { filters.add(join.getCondition()); } if (canCombine(left, joinType.generatesNullsOnLeft())) { filters.add(((MultiJoin) left).getJoinFilter()); } // Need to adjust the RexInputs of the right child, since // those need to shift over to the right if (canCombine(right, joinType.generatesNullsOnRight())) { MultiJoin multiJoin = (MultiJoin) right; filters.add( shiftRightFilter(join, left, multiJoin, multiJoin.getJoinFilter())); } return filters; } /** * Returns whether an input can be merged into a given relational expression * without changing semantics. * * @param input input into a join * @param nullGenerating true if the input is null generating * @return true if the input can be combined into a parent MultiJoin */ private static boolean canCombine(RelNode input, boolean nullGenerating) { return input instanceof MultiJoin && !((MultiJoin) input).isFullOuterJoin() && !((MultiJoin) input).containsOuter() && !nullGenerating; } /** * Shifts a filter originating from the right child of the LogicalJoin to the * right, to reflect the filter now being applied on the resulting * MultiJoin. * * @param joinRel the original LogicalJoin * @param left the left child of the LogicalJoin * @param right the right child of the LogicalJoin * @param rightFilter the filter originating from the right child * @return the adjusted right filter */ private static @Nullable RexNode shiftRightFilter( Join joinRel, RelNode left, MultiJoin right, @Nullable RexNode rightFilter) { if (rightFilter == null) { return null; } int nFieldsOnLeft = left.getRowType().getFieldList().size(); int nFieldsOnRight = right.getRowType().getFieldList().size(); int[] adjustments = new int[nFieldsOnRight]; for (int i = 0; i < nFieldsOnRight; i++) { adjustments[i] = nFieldsOnLeft; } rightFilter = rightFilter.accept( new RelOptUtil.RexInputConverter( joinRel.getCluster().getRexBuilder(), right.getRowType().getFieldList(), joinRel.getRowType().getFieldList(), adjustments)); return rightFilter; } /** * Adds on to the existing join condition reference counts the references * from the new join condition. * * @param multiJoinInputs inputs into the new MultiJoin * @param nTotalFields total number of fields in the MultiJoin * @param joinCondition the new join condition * @param origJoinFieldRefCounts existing join condition reference counts * * @return Map containing the new join condition */ private static ImmutableMap addOnJoinFieldRefCounts( List multiJoinInputs, int nTotalFields, RexNode joinCondition, List origJoinFieldRefCounts) { // count the input references in the join condition int[] joinCondRefCounts = new int[nTotalFields]; joinCondition.accept(new InputReferenceCounter(joinCondRefCounts)); // first, make a copy of the ref counters final Map refCountsMap = new HashMap<>(); int nInputs = multiJoinInputs.size(); int currInput = 0; for (int[] origRefCounts : origJoinFieldRefCounts) { refCountsMap.put( currInput, origRefCounts.clone()); currInput++; } // add on to the counts for each input into the MultiJoin the // reference counts computed for the current join condition currInput = -1; int startField = 0; int nFields = 0; for (int i = 0; i < nTotalFields; i++) { if (joinCondRefCounts[i] == 0) { continue; } while (i >= (startField + nFields)) { startField += nFields; currInput++; assert currInput < nInputs; nFields = multiJoinInputs.get(currInput).getRowType().getFieldCount(); } final int key = currInput; int[] refCounts = requireNonNull(refCountsMap.get(key), () -> "refCountsMap.get(currInput) for " + key); refCounts[i - startField] += joinCondRefCounts[i]; } final ImmutableMap.Builder builder = ImmutableMap.builder(); for (Map.Entry entry : refCountsMap.entrySet()) { builder.put(entry.getKey(), ImmutableIntList.of(entry.getValue())); } return builder.build(); } /** * Combines the post-join filters from the left and right inputs (if they * are MultiJoinRels) into a single AND'd filter. * * @param joinRel the original LogicalJoin * @param left left child of the LogicalJoin * @param right right child of the LogicalJoin * @return combined post-join filters AND'd together */ private static List<@Nullable RexNode> combinePostJoinFilters( Join joinRel, RelNode left, RelNode right) { final List<@Nullable RexNode> filters = new ArrayList<>(); if (right instanceof MultiJoin) { final MultiJoin multiRight = (MultiJoin) right; filters.add( shiftRightFilter(joinRel, left, multiRight, multiRight.getPostJoinFilter())); } if (left instanceof MultiJoin) { filters.add(((MultiJoin) left).getPostJoinFilter()); } return filters; } //~ Inner Classes ---------------------------------------------------------- /** * Visitor that keeps a reference count of the inputs used by an expression. */ private static class InputReferenceCounter extends RexVisitorImpl { private final int[] refCounts; InputReferenceCounter(int[] refCounts) { super(true); this.refCounts = refCounts; } @Override public Void visitInputRef(RexInputRef inputRef) { refCounts[inputRef.getIndex()]++; return null; } } /** Rule configuration. */ @Value.Immutable public interface Config extends RelRule.Config { Config DEFAULT = ImmutableJoinToMultiJoinRule.Config.of() .withOperandFor(LogicalJoin.class); @Override default JoinToMultiJoinRule toRule() { return new JoinToMultiJoinRule(this); } /** Defines an operand tree for the given classes. */ default Config withOperandFor(Class joinClass) { return withOperandSupplier(b0 -> b0.operand(joinClass).inputs( b1 -> b1.operand(RelNode.class).anyInputs(), b2 -> b2.operand(RelNode.class).anyInputs())) .as(Config.class); } } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy