All downloads are free. Search and download functionality uses the official Maven repository.

com.hazelcast.org.apache.calcite.adapter.enumerable.EnumerableMergeJoin Maven / Gradle / Ivy

There is a newer version: 5.5.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to you under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.hazelcast.org.apache.calcite.adapter.enumerable;

import com.hazelcast.org.apache.calcite.adapter.java.JavaTypeFactory;
import com.hazelcast.org.apache.calcite.linq4j.EnumerableDefaults;
import com.hazelcast.org.apache.calcite.linq4j.tree.BlockBuilder;
import com.hazelcast.org.apache.calcite.linq4j.tree.Expression;
import com.hazelcast.org.apache.calcite.linq4j.tree.Expressions;
import com.hazelcast.org.apache.calcite.linq4j.tree.ParameterExpression;
import com.hazelcast.org.apache.calcite.plan.DeriveMode;
import com.hazelcast.org.apache.calcite.plan.RelOptCluster;
import com.hazelcast.org.apache.calcite.plan.RelOptCost;
import com.hazelcast.org.apache.calcite.plan.RelOptPlanner;
import com.hazelcast.org.apache.calcite.plan.RelTraitSet;
import com.hazelcast.org.apache.calcite.rel.RelCollation;
import com.hazelcast.org.apache.calcite.rel.RelCollationTraitDef;
import com.hazelcast.org.apache.calcite.rel.RelCollations;
import com.hazelcast.org.apache.calcite.rel.RelFieldCollation;
import com.hazelcast.org.apache.calcite.rel.RelNode;
import com.hazelcast.org.apache.calcite.rel.core.CorrelationId;
import com.hazelcast.org.apache.calcite.rel.core.Join;
import com.hazelcast.org.apache.calcite.rel.core.JoinRelType;
import com.hazelcast.org.apache.calcite.rel.metadata.RelMdCollation;
import com.hazelcast.org.apache.calcite.rel.metadata.RelMetadataQuery;
import com.hazelcast.org.apache.calcite.rel.type.RelDataType;
import com.hazelcast.org.apache.calcite.rex.RexNode;
import com.hazelcast.org.apache.calcite.rex.RexUtil;
import com.hazelcast.org.apache.calcite.util.BuiltInMethod;
import com.hazelcast.org.apache.calcite.util.ImmutableBitSet;
import com.hazelcast.org.apache.calcite.util.ImmutableIntList;
import com.hazelcast.org.apache.calcite.util.Pair;
import com.hazelcast.org.apache.calcite.util.Util;
import com.hazelcast.org.apache.calcite.util.mapping.Mappings;

import com.hazelcast.com.google.common.collect.ImmutableList;
import com.hazelcast.com.google.common.collect.ImmutableSet;

import com.hazelcast.org.checkerframework.checker.nullness.qual.Nullable;

import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static com.hazelcast.org.apache.calcite.rel.RelCollations.containsOrderless;

import static java.util.Objects.requireNonNull;

/** Implementation of {@link com.hazelcast.org.apache.calcite.rel.core.Join} in
 * {@link EnumerableConvention enumerable calling convention} using
 * a merge algorithm. */
public class EnumerableMergeJoin extends Join implements EnumerableRel {
  protected EnumerableMergeJoin(
      RelOptCluster cluster,
      RelTraitSet traits,
      RelNode left,
      RelNode right,
      RexNode condition,
      Set variablesSet,
      JoinRelType joinType) {
    super(cluster, traits, ImmutableList.of(), left, right, condition, variablesSet, joinType);
    assert getConvention() instanceof EnumerableConvention;
    final List leftCollations = getCollations(left.getTraitSet());
    final List rightCollations = getCollations(right.getTraitSet());

    // If the join keys are not distinct, the sanity check doesn't apply.
    // e.g. t1.a=t2.b and t1.a=t2.c
    boolean isDistinct = Util.isDistinct(joinInfo.leftKeys)
        && Util.isDistinct(joinInfo.rightKeys);

    if (!RelCollations.collationsContainKeysOrderless(leftCollations, joinInfo.leftKeys)
        || !RelCollations.collationsContainKeysOrderless(rightCollations, joinInfo.rightKeys)) {
      if (isDistinct) {
        throw new RuntimeException("wrong collation in left or right input");
      }
    }

    final List collations =
        traits.getTraits(RelCollationTraitDef.INSTANCE);
    assert collations != null && collations.size() > 0;
    ImmutableIntList rightKeys = joinInfo.rightKeys
        .incr(left.getRowType().getFieldCount());
    // Currently it has very limited ability to represent the equivalent traits
    // due to the flaw of RelCompositeTrait, so the following case is totally
    // legit, but not yet supported:
    // SELECT * FROM foo JOIN bar ON foo.a = bar.c AND foo.b = bar.d;
    // MergeJoin has collation on [a, d], or [b, c]
    if (!RelCollations.collationsContainKeysOrderless(collations, joinInfo.leftKeys)
        && !RelCollations.collationsContainKeysOrderless(collations, rightKeys)
        && !RelCollations.keysContainCollationsOrderless(joinInfo.leftKeys, collations)
        && !RelCollations.keysContainCollationsOrderless(rightKeys, collations)) {
      if (isDistinct) {
        throw new RuntimeException("wrong collation for mergejoin");
      }
    }
    if (!isMergeJoinSupported(joinType)) {
      throw new UnsupportedOperationException(
          "EnumerableMergeJoin unsupported for join type " + joinType);
    }
  }

  public static boolean isMergeJoinSupported(JoinRelType joinType) {
    return EnumerableDefaults.isMergeJoinSupported(EnumUtils.toLinq4jJoinType(joinType));
  }

  private static RelCollation getCollation(RelTraitSet traits) {
    return requireNonNull(traits.getCollation(),
        () -> "no collation trait in " + traits);
  }

  private static List getCollations(RelTraitSet traits) {
    return requireNonNull(traits.getTraits(RelCollationTraitDef.INSTANCE),
        () -> "no collation trait in " + traits);
  }

  @Deprecated // to be removed before 2.0
  EnumerableMergeJoin(RelOptCluster cluster, RelTraitSet traits, RelNode left,
      RelNode right, RexNode condition, ImmutableIntList leftKeys,
      ImmutableIntList rightKeys, Set variablesSet,
      JoinRelType joinType) {
    this(cluster, traits, left, right, condition, variablesSet, joinType);
  }

  @Deprecated // to be removed before 2.0
  EnumerableMergeJoin(RelOptCluster cluster, RelTraitSet traits, RelNode left,
      RelNode right, RexNode condition, ImmutableIntList leftKeys,
      ImmutableIntList rightKeys, JoinRelType joinType,
      Set variablesStopped) {
    this(cluster, traits, left, right, condition, leftKeys, rightKeys,
        CorrelationId.setOf(variablesStopped), joinType);
  }

  /**
   * Pass collations through can have three cases:
   * 1. If sort keys are equal to either left join keys, or right join keys,
   * collations can be pushed to both join sides with correct mappings.
   * For example, for the query
   *    select * from foo join bar on foo.a=bar.b order by foo.a desc
   * after traits pass through it will be equivalent to
   *    select * from
   *        (select * from foo order by foo.a desc)
   *        join
   *        (select * from bar order by bar.b desc)
   *
   * 2. If sort keys are sub-set of either left join keys, or right join keys,
   * collations have to be extended to cover all joins keys before passing through,
   * because merge join requires all join keys are sorted.
   * For example, for the query
   *    select * from foo join bar
   *        on foo.a=bar.b and foo.c=bar.d
   *        order by foo.a desc
   * after traits pass through it will be equivalent to
   *    select * from
   *        (select * from foo order by foo.a desc, foo.c)
   *        join
   *        (select * from bar order by bar.b desc, bar.d)
   *
   * 3. If sort keys are super-set of either left join keys, or right join keys,
   * but not both, collations can be completely passed to the join key whose join
   * keys match the prefix of collations. Meanwhile, partial mapped collations can
   * be passed to another join side to make sure join keys are sorted.
   * For example, for the query
   *    select * from foo join bar
   *        on foo.a=bar.b and foo.c=bar.d
   *        order by foo.a desc, foo.c desc, foo.e
   * after traits pass through it will be equivalent to
   *    select * from
   *        (select * from foo order by foo.a desc, foo.c desc, foo.e)
   *        join
   *        (select * from bar order by bar.b desc, bar.d desc)
   */
  @Override public @Nullable Pair> passThroughTraits(
      final RelTraitSet required) {
    // Required collation keys can be subset or superset of merge join keys.
    RelCollation collation = getCollation(required);
    int leftInputFieldCount = left.getRowType().getFieldCount();

    List reqKeys = RelCollations.ordinals(collation);
    List leftKeys = joinInfo.leftKeys.toIntegerList();
    List rightKeys =
        joinInfo.rightKeys.incr(leftInputFieldCount).toIntegerList();

    ImmutableBitSet reqKeySet = ImmutableBitSet.of(reqKeys);
    ImmutableBitSet leftKeySet = ImmutableBitSet.of(joinInfo.leftKeys);
    ImmutableBitSet rightKeySet = ImmutableBitSet.of(joinInfo.rightKeys)
        .shift(leftInputFieldCount);

    if (reqKeySet.equals(leftKeySet)) {
      // if sort keys equal to left join keys, we can pass through all collations directly.
      Mappings.TargetMapping mapping = buildMapping(true);
      RelCollation rightCollation = collation.apply(mapping);
      return Pair.of(
          required, ImmutableList.of(required,
          required.replace(rightCollation)));
    } else if (containsOrderless(leftKeys, collation)) {
      // if sort keys are subset of left join keys, we can extend collations to make sure all join
      // keys are sorted.
      collation = extendCollation(collation, leftKeys);
      Mappings.TargetMapping mapping = buildMapping(true);
      RelCollation rightCollation = collation.apply(mapping);
      return Pair.of(
          required, ImmutableList.of(required.replace(collation),
              required.replace(rightCollation)));
    } else if (containsOrderless(collation, leftKeys)
        && reqKeys.stream().allMatch(i -> i < leftInputFieldCount)) {
      // if sort keys are superset of left join keys, and left join keys is prefix of sort keys
      // (order not matter), also sort keys are all from left join input.
      Mappings.TargetMapping mapping = buildMapping(true);
      RelCollation rightCollation =
          RexUtil.apply(
              mapping,
              intersectCollationAndJoinKey(collation, joinInfo.leftKeys));
      return Pair.of(
          required, ImmutableList.of(required,
              required.replace(rightCollation)));
    } else if (reqKeySet.equals(rightKeySet)) {
      // if sort keys equal to right join keys, we can pass through all collations directly.
      RelCollation rightCollation = RelCollations.shift(collation, -leftInputFieldCount);
      Mappings.TargetMapping mapping = buildMapping(false);
      RelCollation leftCollation = rightCollation.apply(mapping);
      return Pair.of(
          required, ImmutableList.of(
          required.replace(leftCollation),
          required.replace(rightCollation)));
    } else if (containsOrderless(rightKeys, collation)) {
      // if sort keys are subset of right join keys, we can extend collations to make sure all join
      // keys are sorted.
      collation = extendCollation(collation, rightKeys);
      RelCollation rightCollation = RelCollations.shift(collation, -leftInputFieldCount);
      Mappings.TargetMapping mapping = buildMapping(false);
      RelCollation leftCollation = RexUtil.apply(mapping, rightCollation);
      return Pair.of(
          required, ImmutableList.of(
              required.replace(leftCollation),
              required.replace(rightCollation)));
    } else if (containsOrderless(collation, rightKeys)
        && reqKeys.stream().allMatch(i -> i >= leftInputFieldCount)) {
      // if sort keys are superset of right join keys, and right join keys is prefix of sort keys
      // (order not matter), also sort keys are all from right join input.
      RelCollation rightCollation = RelCollations.shift(collation, -leftInputFieldCount);
      Mappings.TargetMapping mapping = buildMapping(false);
      RelCollation leftCollation =
          RexUtil.apply(
              mapping,
              intersectCollationAndJoinKey(rightCollation, joinInfo.rightKeys));
      return Pair.of(
          required, ImmutableList.of(
              required.replace(leftCollation),
              required.replace(rightCollation)));
    }

    return null;
  }

  @Override public @Nullable Pair> deriveTraits(
      final RelTraitSet childTraits, final int childId) {
    final int keyCount = joinInfo.leftKeys.size();
    RelCollation collation = getCollation(childTraits);
    final int colCount = collation.getFieldCollations().size();
    if (colCount < keyCount || keyCount == 0) {
      return null;
    }

    if (colCount > keyCount) {
      collation = RelCollations.of(collation.getFieldCollations().subList(0, keyCount));
    }

    ImmutableIntList sourceKeys = childId == 0 ? joinInfo.leftKeys : joinInfo.rightKeys;
    ImmutableBitSet keySet = ImmutableBitSet.of(sourceKeys);
    ImmutableBitSet childCollationKeys = ImmutableBitSet.of(
        RelCollations.ordinals(collation));
    if (!childCollationKeys.equals(keySet)) {
      return null;
    }

    Mappings.TargetMapping mapping = buildMapping(childId == 0);
    RelCollation targetCollation = collation.apply(mapping);

    if (childId == 0) {
      // traits from left child
      RelTraitSet joinTraits = getTraitSet().replace(collation);
      // Forget about the equiv keys for the moment
      return Pair.of(joinTraits,
          ImmutableList.of(childTraits,
          right.getTraitSet().replace(targetCollation)));
    } else {
      // traits from right child
      assert childId == 1;
      RelTraitSet joinTraits = getTraitSet().replace(targetCollation);
      // Forget about the equiv keys for the moment
      return Pair.of(joinTraits,
          ImmutableList.of(joinTraits,
          childTraits.replace(collation)));
    }
  }

  @Override public DeriveMode getDeriveMode() {
    return DeriveMode.BOTH;
  }

  private Mappings.TargetMapping buildMapping(boolean left2Right) {
    ImmutableIntList sourceKeys = left2Right ? joinInfo.leftKeys : joinInfo.rightKeys;
    ImmutableIntList targetKeys = left2Right ? joinInfo.rightKeys : joinInfo.leftKeys;
    Map keyMap = new HashMap<>();
    for (int i = 0; i < joinInfo.leftKeys.size(); i++) {
      keyMap.put(sourceKeys.get(i), targetKeys.get(i));
    }

    Mappings.TargetMapping mapping = Mappings.target(keyMap,
        (left2Right ? left : right).getRowType().getFieldCount(),
        (left2Right ? right : left).getRowType().getFieldCount());
    return mapping;
  }

  /**
   * This function extends collation by appending new collation fields defined on keys.
   */
  private static RelCollation extendCollation(RelCollation collation, List keys) {
    List fieldsForNewCollation = new ArrayList<>(keys.size());
    fieldsForNewCollation.addAll(collation.getFieldCollations());

    ImmutableBitSet keysBitset = ImmutableBitSet.of(keys);
    ImmutableBitSet colKeysBitset = ImmutableBitSet.of(collation.getKeys());
    ImmutableBitSet exceptBitset = keysBitset.except(colKeysBitset);
    for (Integer i : exceptBitset) {
      fieldsForNewCollation.add(new RelFieldCollation(i));
    }
    return RelCollations.of(fieldsForNewCollation);
  }

  /**
   * This function will remove collations that are not defined on join keys.
   * For example:
   *    select * from
   *    foo join bar
   *    on foo.a = bar.a and foo.c=bar.c
   *    order by bar.a, bar.c, bar.b;
   *
   * The collation [bar.a, bar.c, bar.b] can be pushed down to bar. However, only
   * [a, c] can be pushed down to foo. This function will help create [a, c] for foo by removing
   * b from the required collation, because b is not defined on join keys.
   *
   * @param collation collation defined on the JOIN
   * @param joinKeys  the join keys
   */
  private static RelCollation intersectCollationAndJoinKey(
      RelCollation collation, ImmutableIntList joinKeys) {
    List fieldCollations = new ArrayList<>();
    for (RelFieldCollation rf : collation.getFieldCollations()) {
      if (joinKeys.contains(rf.getFieldIndex())) {
        fieldCollations.add(rf);
      }
    }
    return RelCollations.of(fieldCollations);
  }

  public static EnumerableMergeJoin create(RelNode left, RelNode right,
      RexNode condition, ImmutableIntList leftKeys,
      ImmutableIntList rightKeys, JoinRelType joinType) {
    final RelOptCluster cluster = right.getCluster();
    RelTraitSet traitSet = cluster.traitSetOf(EnumerableConvention.INSTANCE);
    if (traitSet.isEnabled(RelCollationTraitDef.INSTANCE)) {
      final RelMetadataQuery mq = cluster.getMetadataQuery();
      final List collations =
          RelMdCollation.mergeJoin(mq, left, right, leftKeys, rightKeys, joinType);
      traitSet = traitSet.replaceIfs(RelCollationTraitDef.INSTANCE, () -> collations);
    }
    return new EnumerableMergeJoin(cluster, traitSet, left, right, condition,
        ImmutableSet.of(), joinType);
  }

  @Override public EnumerableMergeJoin copy(RelTraitSet traitSet,
      RexNode condition, RelNode left, RelNode right, JoinRelType joinType,
      boolean semiJoinDone) {
    return new EnumerableMergeJoin(getCluster(), traitSet, left, right,
        condition, variablesSet, joinType);
  }

  @Override public @Nullable RelOptCost computeSelfCost(RelOptPlanner planner,
      RelMetadataQuery mq) {
    // We assume that the inputs are sorted. The price of sorting them has
    // already been paid. The cost of the join is therefore proportional to the
    // input and output size.
    final double rightRowCount = right.estimateRowCount(mq);
    final double leftRowCount = left.estimateRowCount(mq);
    final double rowCount = mq.getRowCount(this);
    final double d = leftRowCount + rightRowCount + rowCount;
    return planner.getCostFactory().makeCost(d, 0, 0);
  }

  @Override public Result implement(EnumerableRelImplementor implementor, Prefer pref) {
    BlockBuilder builder = new BlockBuilder();
    final Result leftResult =
        implementor.visitChild(this, 0, (EnumerableRel) left, pref);
    final Expression leftExpression =
        builder.append("left", leftResult.block);
    final ParameterExpression left_ =
        Expressions.parameter(leftResult.physType.getJavaRowType(), "left");
    final Result rightResult =
        implementor.visitChild(this, 1, (EnumerableRel) right, pref);
    final Expression rightExpression =
        builder.append("right", rightResult.block);
    final ParameterExpression right_ =
        Expressions.parameter(rightResult.physType.getJavaRowType(), "right");
    final JavaTypeFactory typeFactory = implementor.getTypeFactory();
    final PhysType physType =
        PhysTypeImpl.of(typeFactory, getRowType(), pref.preferArray());
    final List leftExpressions = new ArrayList<>();
    final List rightExpressions = new ArrayList<>();
    for (Pair pair : Pair.zip(joinInfo.leftKeys, joinInfo.rightKeys)) {
      RelDataType leftType = left.getRowType().getFieldList().get(pair.left).getType();
      RelDataType rightType = right.getRowType().getFieldList().get(pair.right).getType();
      final RelDataType keyType = requireNonNull(
          typeFactory.leastRestrictive(ImmutableList.of(leftType, rightType)),
          () -> "leastRestrictive returns null for " + leftType + " and " + rightType);
      final Type keyClass = typeFactory.getJavaClass(keyType);
      leftExpressions.add(
          EnumUtils.convert(
              leftResult.physType.fieldReference(left_, pair.left), keyClass));
      rightExpressions.add(
          EnumUtils.convert(
              rightResult.physType.fieldReference(right_, pair.right), keyClass));
    }
    Expression predicate = Expressions.constant(null);
    if (!joinInfo.nonEquiConditions.isEmpty()) {
      final RexNode nonEquiCondition = RexUtil.composeConjunction(
          getCluster().getRexBuilder(), joinInfo.nonEquiConditions, true);
      if (nonEquiCondition != null) {
        predicate = EnumUtils.generatePredicate(implementor, getCluster().getRexBuilder(),
            left, right, leftResult.physType, rightResult.physType, nonEquiCondition);
      }
    }
    final PhysType leftKeyPhysType =
        leftResult.physType.project(joinInfo.leftKeys, JavaRowFormat.LIST);
    final PhysType rightKeyPhysType =
        rightResult.physType.project(joinInfo.rightKeys, JavaRowFormat.LIST);

    // Generate the appropriate key Comparator (keys must be sorted in ascending order, nulls last).
    final int keysSize = joinInfo.leftKeys.size();
    final List fieldCollations = new ArrayList<>(keysSize);
    for (int i = 0; i < keysSize; i++) {
      fieldCollations.add(
          new RelFieldCollation(i, RelFieldCollation.Direction.ASCENDING,
              RelFieldCollation.NullDirection.LAST));
    }
    final RelCollation collation = RelCollations.of(fieldCollations);
    final Expression comparator = leftKeyPhysType.generateComparator(collation);

    return implementor.result(
        physType,
        builder.append(
            Expressions.call(
                BuiltInMethod.MERGE_JOIN.method,
                Expressions.list(
                    leftExpression,
                    rightExpression,
                    Expressions.lambda(
                        leftKeyPhysType.record(leftExpressions), left_),
                    Expressions.lambda(
                        rightKeyPhysType.record(rightExpressions), right_),
                    predicate,
                    EnumUtils.joinSelector(joinType,
                        physType,
                        ImmutableList.of(
                            leftResult.physType, rightResult.physType)),
                    Expressions.constant(EnumUtils.toLinq4jJoinType(joinType)),
                    comparator))).toBlock());
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy