All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.flink.table.plan.util.FlinkRelMdUtil.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.plan.util

import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
import org.apache.flink.table.plan.nodes.calcite.{Expand, LogicalRank, LogicalWindowAggregate, Rank}
import org.apache.flink.table.plan.nodes.logical.{FlinkLogicalRank, FlinkLogicalWindowAggregate}
import org.apache.flink.table.plan.nodes.physical.batch._
import org.apache.flink.table.plan.nodes.physical.stream.StreamExecRank
import org.apache.flink.table.plan.util.FlinkRelOptUtil.{checkAndGetFullGroupSet, checkAndSplitAggCalls}

import org.apache.calcite.avatica.util.TimeUnitRange._
import org.apache.calcite.plan.RelOptUtil
import org.apache.calcite.rel.core._
import org.apache.calcite.rel.metadata.{RelMdUtil, RelMetadataQuery}
import org.apache.calcite.rel.{RelNode, SingleRel}
import org.apache.calcite.rex._
import org.apache.calcite.sql.SqlKind
import org.apache.calcite.sql.`type`.SqlTypeName._
import org.apache.calcite.util.{ImmutableBitSet, NumberUtil}
import com.google.common.collect.ImmutableList

import java.lang.Double
import java.math.BigDecimal
import java.util

import scala.collection.JavaConversions._
import scala.collection.mutable

/**
  * FlinkRelMdUtil provides utility methods used by the metadata provider methods.
  */
object FlinkRelMdUtil {

  /**
    * Creates a RexNode that carries the estimated selectivity of a semi-join or anti-join.
    * The node can be added to a filter to simulate the effect of the join during costing;
    * it has no physical implementation and so must never appear in a real plan.
    *
    * @param mq  instance of metadata query
    * @param rel the semi-join or anti-join of interest
    * @return a call to the artificial selectivity function wrapping the estimated selectivity
    */
  def makeSemiJoinSelectivityRexNode(
      mq: RelMetadataQuery,
      rel: SemiJoin): RexNode =
    makeSemiJoinSelectivityRexNode(
      mq,
      rel.analyzeCondition(),
      rel.getLeft,
      rel.getRight,
      rel.isAnti,
      rel.getCluster.getRexBuilder)

  /**
    * Builds the artificial selectivity call for a semi-join/anti-join from its analyzed
    * condition.
    *
    * The semi-join selectivity is the product of two factors:
    *  - the equi-key part, from `RelMdUtil.computeSemiJoinSelectivity`, or 1 if there are
    *    no equi keys;
    *  - a guess for the remaining non-equi condition.
    *
    * An anti-join uses the complement `1 - semiJoinSelectivity`, replaced by 0.1 when that
    * complement is exactly 0.0.
    *
    * @param mq         instance of metadata query
    * @param joinInfo   analyzed join condition
    * @param left       left input of the join
    * @param right      right input of the join
    * @param isAnti     whether the join is an anti-join
    * @param rexBuilder builder used to create the returned expression
    * @return a call to the artificial selectivity function wrapping the selectivity
    */
  private def makeSemiJoinSelectivityRexNode(
      mq: RelMetadataQuery,
      joinInfo: JoinInfo,
      left: RelNode,
      right: RelNode,
      isAnti: Boolean,
      rexBuilder: RexBuilder): RexNode = {
    val equiSelectivity: Double = if (joinInfo.leftKeys.isEmpty) {
      1D
    } else {
      RelMdUtil.computeSemiJoinSelectivity(mq, left, right, joinInfo.leftKeys, joinInfo.rightKeys)
    }
    val remainingCondition = joinInfo.getRemaining(rexBuilder)
    val semiJoinSelectivity = equiSelectivity * RelMdUtil.guessSelectivity(remainingCondition)

    val selectivity =
      if (!isAnti) {
        semiJoinSelectivity
      } else {
        val complement = 1.0 - semiJoinSelectivity
        // we don't expect that anti-join's selectivity is 0.0, so choose a default value 0.1
        if (complement == 0.0) 0.1 else complement
      }

    val selectivityLiteral = rexBuilder.makeApproxLiteral(new BigDecimal(selectivity))
    rexBuilder.makeCall(RelMdUtil.ARTIFICIAL_SELECTIVITY_FUNC, selectivityLiteral)
  }

  /**
    * Creates a RexNode that stores the selectivity of the NamedProperties part of a
    * predicate on a [[LogicalWindowAggregate]].
    *
    * @param winAgg    window aggregate node
    * @param predicate a RexNode
    * @return the non-NamedProperties conjuncts of `predicate`, conjoined with an artificial
    *         selectivity call standing in for its NamedProperties conjuncts (or `predicate`
    *         itself when there is nothing to split)
    */
  def makeNamePropertiesSelectivityRexNode(
      winAgg: LogicalWindowAggregate,
      predicate: RexNode): RexNode =
    makeNamePropertiesSelectivityRexNode(
      winAgg, checkAndGetFullGroupSet(winAgg), winAgg.getNamedProperties, predicate)

  /**
    * Creates a RexNode that stores the selectivity of the NamedProperties part of a
    * predicate on a [[FlinkLogicalWindowAggregate]].
    *
    * @param winAgg    window aggregate node
    * @param predicate a RexNode
    * @return the non-NamedProperties conjuncts of `predicate`, conjoined with an artificial
    *         selectivity call standing in for its NamedProperties conjuncts (or `predicate`
    *         itself when there is nothing to split)
    */
  def makeNamePropertiesSelectivityRexNode(
      winAgg: FlinkLogicalWindowAggregate,
      predicate: RexNode): RexNode =
    makeNamePropertiesSelectivityRexNode(
      winAgg, checkAndGetFullGroupSet(winAgg), winAgg.getNamedProperties, predicate)

  /**
    * Creates a RexNode that stores the selectivity of the NamedProperties part of a
    * predicate on a global (final) batch window aggregate.
    *
    * @param globalWinAgg global window aggregate node; must be final, since only the final
    *                     aggregate outputs NamedProperties
    * @param predicate    a RexNode
    * @return the non-NamedProperties conjuncts of `predicate`, conjoined with an artificial
    *         selectivity call standing in for its NamedProperties conjuncts (or `predicate`
    *         itself when there is nothing to split)
    */
  def makeNamePropertiesSelectivityRexNode(
      globalWinAgg: BatchExecWindowAggregateBase,
      predicate: RexNode): RexNode = {
    require(globalWinAgg.isFinal, "local window agg does not contain NamedProperties!")
    makeNamePropertiesSelectivityRexNode(
      globalWinAgg,
      globalWinAgg.getGrouping ++ globalWinAgg.getAuxGrouping,
      globalWinAgg.getNamedProperties,
      predicate)
  }

  /**
    * Creates a RexNode that stores the selectivity of the NamedProperties part of a
    * predicate on a window aggregate.
    *
    * The NamedProperties are taken to be the last `namedProperties.size` output columns of
    * `winAgg`. Conjuncts that reference only the columns before them are kept as-is. All
    * other conjuncts are replaced by one artificial selectivity call whose value is
    * `RelMdUtil.guessSelectivity` of their conjunction.
    *
    * @param winAgg          window aggregate node
    * @param fullGrouping    full groupSets (not used by this implementation)
    * @param namedProperties NamedWindowProperty list
    * @param predicate       a RexNode
    * @return `predicate` unchanged if it is null, always true, or there are no
    *         NamedProperties; otherwise the rewritten conjunction
    */
  def makeNamePropertiesSelectivityRexNode(
      winAgg: SingleRel,
      fullGrouping: Array[Int],
      namedProperties: Seq[NamedWindowProperty],
      predicate: RexNode): RexNode = {
    val nothingToSplit =
      predicate == null || predicate.isAlwaysTrue || namedProperties.isEmpty
    if (nothingToSplit) {
      predicate
    } else {
      val rexBuilder = winAgg.getCluster.getRexBuilder
      val firstNamedPropertyIdx = winAgg.getRowType.getFieldCount - namedProperties.size
      val nonPropertyConjuncts = new util.ArrayList[RexNode]
      val propertyConjuncts = new util.ArrayList[RexNode]
      RelOptUtil.splitFilters(
        ImmutableBitSet.range(0, firstNamedPropertyIdx),
        predicate,
        nonPropertyConjuncts,
        propertyConjuncts)
      if (!propertyConjuncts.isEmpty) {
        val propertyPred = RexUtil.composeConjunction(rexBuilder, propertyConjuncts, true)
        val selectivityLiteral =
          rexBuilder.makeApproxLiteral(new BigDecimal(RelMdUtil.guessSelectivity(propertyPred)))
        nonPropertyConjuncts.add(
          rexBuilder.makeCall(RelMdUtil.ARTIFICIAL_SELECTIVITY_FUNC, selectivityLiteral))
      }
      RexUtil.composeConjunction(rexBuilder, nonPropertyConjuncts, true)
    }
  }

  /**
    * Estimates the cardinality (number of distinct values) of a projection expression of
    * a [[Calc]].
    *
    * @param mq   metadata query instance
    * @param calc calc RelNode
    * @param expr projection expression
    * @return the estimated cardinality, or null if it cannot be estimated
    */
  def cardOfCalcExpr(mq: RelMetadataQuery, calc: Calc, expr: RexNode): Double = {
    val visitor = new CardOfCalcExpr(mq, calc)
    expr.accept(visitor)
  }

  /**
    * Visitor that walks over a scalar expression and estimates the cardinality (number of
    * distinct values) of its result over the rows of `calc`.
    * The code is adapted from the corresponding visitor in Calcite's RelMdUtil.
    *
    * Every visit method returns null when no estimate is available.
    *
    * @param mq   metadata query instance
    * @param calc calc relnode
    */
  private class CardOfCalcExpr(
      mq: RelMetadataQuery,
      calc: Calc)
    extends RexVisitorImpl[Double](true) {
    private val program = calc.getProgram

    // The calc's filter condition, expanded from local refs so that it references the calc's
    // input fields; null when the calc has no condition.
    private val condition = if (program.getCondition != null) {
      program.expandLocalRef(program.getCondition)
    } else {
      null
    }

    /** Estimates the cardinality of a sub-expression with this same visitor. */
    private def cardOf(operand: RexNode): Double = operand.accept(this)

    override def visitInputRef(inputRef: RexInputRef): Double = {
      val col = ImmutableBitSet.of(inputRef.getIndex)
      val distinctRowCount = mq.getDistinctRowCount(calc.getInput, col, condition)
      if (distinctRowCount == null) {
        null
      } else {
        // numDistinctVals(domainSize, numSelected) expects the number of selected rows as
        // its second argument, matching what visitCall passes.
        RelMdUtil.numDistinctVals(distinctRowCount, mq.getRowCount(calc))
      }
    }

    override def visitLiteral(literal: RexLiteral): Double = {
      // a literal has exactly one distinct value
      RelMdUtil.numDistinctVals(1D, mq.getRowCount(calc))
    }

    override def visitCall(call: RexCall): Double = {
      val rowCount = mq.getRowCount(calc)
      val operands = call.getOperands
      val distinctRowCount: Double = if (call.isA(SqlKind.MINUS_PREFIX)) {
        cardOf(operands.get(0))
      } else if (call.isA(ImmutableList.of(SqlKind.PLUS, SqlKind.MINUS))) {
        val card0 = cardOf(operands.get(0))
        if (card0 == null) {
          null
        } else {
          val card1 = cardOf(operands.get(1))
          if (card1 == null) {
            null
          } else {
            Math.max(card0, card1)
          }
        }
      } else if (call.isA(ImmutableList.of(SqlKind.TIMES, SqlKind.DIVIDE))) {
        NumberUtil.multiply(cardOf(operands.get(0)), cardOf(operands.get(1)))
      } else if (call.isA(SqlKind.EXTRACT)) {
        val extractUnit = operands.get(0)
        val timeOperand = operands.get(1)
        extractUnit match {
          // go https://www.postgresql.org/docs/9.1/static/
          // functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT to get the definitions of timeunits
          case unit: RexLiteral =>
            val unitValue = unit.getValue
            val timeOperandType = timeOperand.getType.getSqlTypeName
            // assume min time is 1970-01-01 00:00:00, max time is 2100-12-31 23:59:59
            unitValue match {
              case YEAR => 130D // [1970, 2100]
              case MONTH => 12D
              case DAY => 31D
              case HOUR => 24D
              case MINUTE => 60D
              case SECOND => timeOperandType match {
                case TIMESTAMP | TIME => 60 * 1000D // [0.000, 59.999]
                case _ => 60D // [0, 59]
              }
              case QUARTER => 4D
              case WEEK => 53D // [1, 53]
              case MILLISECOND => timeOperandType match {
                case TIMESTAMP | TIME => 60 * 1000D // [0, 59999]
                case _ => 60D // [0, 59]
              }
              case MICROSECOND => timeOperandType match {
                case TIMESTAMP | TIME => 60 * 1000D * 1000D // [0, 59999999]
                case _ => 60D // [0, 59]
              }
              case DOW => 7D // [0, 6]
              case DOY => 366D // [1, 366]
              case EPOCH => timeOperandType match {
                // the number of seconds since 1970-01-01 00:00:00 UTC over ~130 years;
                // computed in Double because 130 * 365 * 24 * 60 * 60 overflows Int
                case TIMESTAMP | TIME => 130D * 365 * 24 * 60 * 60 * 1000
                case _ => 130D * 365 * 24 * 60 * 60
              }
              case DECADE => 13D // The year field divided by 10
              case CENTURY => 2D
              case MILLENNIUM => 2D
              case _ => cardOf(timeOperand)
            }
          case _ => cardOf(timeOperand)
        }
      } else if (operands.size() == 1) {
        cardOf(operands.get(0))
      } else {
        // no specific rule: guess one tenth of the row count
        if (rowCount != null) rowCount / 10 else null
      }
      if (distinctRowCount == null) {
        null
      } else {
        RelMdUtil.numDistinctVals(distinctRowCount, rowCount)
      }
    }

  }

  /**
    * Maps a bitmap of output columns of an [[Aggregate]] to the input columns they
    * reference, and collects the aggregate calls behind any aggregate-call columns.
    *
    * Output columns below the full group set length (groupSet + auxGroupSet) map to the
    * corresponding input column. Aggregate calls are collected only when `groupKey` does
    * not contain every full-group-set position.
    *
    * @param groupKey the original bitmap
    * @param aggRel   the aggregate
    * @return the child keys and the referenced aggregate calls
    */
  def setAggChildKeys(
    groupKey: ImmutableBitSet,
    aggRel: Aggregate): (ImmutableBitSet, Array[AggregateCall]) = {
    val groupSet = aggRel.getGroupSet.toArray
    val (auxGroupSet, otherAggCalls) = checkAndSplitAggCalls(aggRel)
    val fullGroupSet = groupSet ++ auxGroupSet
    val groupCount = fullGroupSet.length
    // does not need to take keys in aggregate call into consideration if groupKey contains all
    // groupSet elements of the aggregate
    val containsAllAggGroupKeys = fullGroupSet.indices.forall(groupKey.get)

    val (groupBits, aggBits) = groupKey.toArray.partition(_ < groupCount)
    val childKeyBuilder = ImmutableBitSet.builder
    groupBits.foreach(bit => childKeyBuilder.set(fullGroupSet(bit)))

    val referencedAggCalls =
      if (containsAllAggGroupKeys) {
        Array.empty[AggregateCall]
      } else {
        // getIndicatorCount return 0 if auxGroupSet is not empty
        val aggCallOffset = groupCount + aggRel.getIndicatorCount
        aggBits.map(bit => otherAggCalls.get(bit - aggCallOffset))
      }
    (childKeyBuilder.build(), referencedAggCalls)
  }

  /**
    * Maps a bitmap of output columns of a batch group aggregate to the input columns they
    * reference, and collects the referenced aggregate calls.
    *
    * Not applicable to an aggregate that is both final and merging, i.e. a global
    * aggregate that has a local aggregate below it.
    *
    * @param groupKey the original bitmap
    * @param aggRel   the aggregate
    * @return the child keys and the referenced aggregate calls
    */
  def setAggChildKeys(
    groupKey: ImmutableBitSet,
    aggRel: BatchExecGroupAggregateBase): (ImmutableBitSet, Array[AggregateCall]) = {
    val isGlobalAggOverLocalAgg = aggRel.isFinal && aggRel.isMerge
    require(!isGlobalAggOverLocalAgg, "Cannot handle global agg which has local agg!")
    setChildKeysOfAgg(groupKey, aggRel)
  }

  /**
    * Maps a bitmap of output columns of a batch window aggregate to the input columns they
    * reference, and collects the referenced aggregate calls.
    *
    * Not applicable to an aggregate that is both final and merging, i.e. a global
    * aggregate that has a local aggregate below it.
    *
    * @param groupKey the original bitmap
    * @param aggRel   the aggregate
    * @return the child keys and the referenced aggregate calls
    */
  def setAggChildKeys(
    groupKey: ImmutableBitSet,
    aggRel: BatchExecWindowAggregateBase): (ImmutableBitSet, Array[AggregateCall]) = {
    val isGlobalAggOverLocalAgg = aggRel.isFinal && aggRel.isMerge
    require(!isGlobalAggOverLocalAgg, "Cannot handle global agg which has local agg!")
    setChildKeysOfAgg(groupKey, aggRel)
  }

  /**
    * Maps a bitmap of output columns of a batch group/window aggregate to the input columns
    * they reference, and collects the aggregate calls behind any aggregate-call columns.
    *
    * The full group set lists the input columns behind the leading output columns:
    * grouping + auxGrouping, or grouping + assignTs + auxGrouping for a local window
    * aggregate. Aggregate calls are collected only when `groupKey` does not contain every
    * full-group-set position.
    *
    * @param groupKey the original bitmap
    * @param aggRel   the aggregate
    * @return the child keys and the referenced aggregate calls
    * @throws IllegalArgumentException if `aggRel` is not a supported aggregate
    */
  private def setChildKeysOfAgg(
    groupKey: ImmutableBitSet,
    aggRel: SingleRel): (ImmutableBitSet, Array[AggregateCall]) = {
    val (aggCalls, fullGroupSet) = aggRel match {
      case sortLocal: BatchExecLocalSortWindowAggregate =>
        // grouping + assignTs + auxGrouping
        (sortLocal.getAggCallList,
          sortLocal.getGrouping ++ Array(sortLocal.inputTimestampIndex) ++ sortLocal.getAuxGrouping)
      case hashLocal: BatchExecLocalHashWindowAggregate =>
        // grouping + assignTs + auxGrouping
        (hashLocal.getAggCallList,
          hashLocal.getGrouping ++ Array(hashLocal.inputTimestampIndex) ++ hashLocal.getAuxGrouping)
      case winAgg: BatchExecWindowAggregateBase =>
        (winAgg.getAggCallList, winAgg.getGrouping ++ winAgg.getAuxGrouping)
      case groupAgg: BatchExecGroupAggregateBase =>
        (groupAgg.getAggCallList, groupAgg.getGrouping ++ groupAgg.getAuxGrouping)
      case _ => throw new IllegalArgumentException(s"Unknown relnode type ${aggRel.getRelTypeName}")
    }
    val groupCount = fullGroupSet.length
    // does not need to take keys in aggregate call into consideration if groupKey contains all
    // groupSet elements of the aggregate
    val containsAllAggGroupKeys = fullGroupSet.indices.forall(groupKey.get)

    val (groupBits, aggBits) = groupKey.toArray.partition(_ < groupCount)
    val childKeyBuilder = ImmutableBitSet.builder
    groupBits.foreach(bit => childKeyBuilder.set(fullGroupSet(bit)))

    val referencedAggCalls =
      if (containsAllAggGroupKeys) {
        Array.empty[AggregateCall]
      } else {
        aggBits.map(bit => aggCalls.get(bit - groupCount))
      }
    (childKeyBuilder.build(), referencedAggCalls)
  }

  /**
    * Maps a bitmap over the output of a global (merging) window aggregate to the
    * corresponding bitmap over the output of its local window aggregate.
    *
    * global win-agg output type: groupSet + auxGroupSet + aggCall + namedProperties
    * local win-agg output type: groupSet + assignTs + auxGroupSet + aggCalls
    *
    * Keys inside the grouping columns are kept; every other key moves one position right
    * to skip `assignTs`.
    *
    * @param groupKey     the original bitmap
    * @param globalWinAgg the global window aggregate; must have a local aggregate below it
    * @return the mapped bitmap
    */
  def setChildKeysOfWinAgg(
      groupKey: ImmutableBitSet,
      globalWinAgg: BatchExecWindowAggregateBase): ImmutableBitSet = {
    require(globalWinAgg.isMerge, "Cannot handle global agg which does not have local agg!")
    val groupingCount = globalWinAgg.getGrouping.length
    val childKeys = groupKey.toArray.map(key => if (key < groupingCount) key else key + 1)
    ImmutableBitSet.of(childKeys: _*)
  }

  /**
    * Splits groupKeys on an Aggregate / BatchExecGroupAggregateBase /
    * BatchExecWindowAggregateBase into keys on the aggregate's input (from its group
    * columns) and the aggregate calls they reference.
    *
    * If the resulting child keys contain the whole groupSet, the auxGroupSet columns are
    * removed from them.
    *
    * @param agg      the aggregate
    * @param groupKey the original bitmap
    * @return the child keys (without aux keys, as above) and the referenced aggregate calls
    * @throws IllegalArgumentException if `agg` is not a supported aggregate type
    */
  private[flink] def splitGroupKeysOnAggregate(
    agg: SingleRel,
    groupKey: ImmutableBitSet): (ImmutableBitSet, Array[AggregateCall]) = {
    // child keys and agg calls from the matching setAggChildKeys overload, together with the
    // aggregate's groupSet and auxGroupSet (all expressed on the aggregate's input)
    val (childKeys, aggCalls, groupSet, auxGroupSet) = agg match {
      case rel: Aggregate =>
        val (aux, _) = FlinkRelOptUtil.checkAndSplitAggCalls(rel)
        val (keys, calls) = setAggChildKeys(groupKey, rel)
        (keys, calls, rel.getGroupSet.toArray, aux)
      case rel: BatchExecGroupAggregateBase =>
        val (keys, calls) = setAggChildKeys(groupKey, rel)
        (keys, calls, rel.getGrouping, rel.getAuxGrouping)
      case rel: BatchExecWindowAggregateBase =>
        val (keys, calls) = setAggChildKeys(groupKey, rel)
        (keys, calls, rel.getGrouping, rel.getAuxGrouping)
      case _ => throw new IllegalArgumentException(
        s"Unknown aggregate type: ${agg.getRelTypeName}.")
    }
    // remove auxGroupSet from the keys if they contain both the full groupSet
    // and (part of) the auxGroupSet
    val childKeysExcludingAux =
      if (childKeys.contains(ImmutableBitSet.of(groupSet: _*))) {
        childKeys.except(ImmutableBitSet.of(auxGroupSet: _*))
      } else {
        childKeys
      }
    (childKeysExcludingAux, aggCalls)
  }

  /**
    * Rewrites a predicate over the output of a global (merging) window aggregate so that it
    * references the output of its local window aggregate.
    *
    * global win-agg output type: groupSet + auxGroupSet + aggCall + namedProperties
    * local win-agg output type: groupSet + assignTs + auxGroupSet + aggCalls
    *
    * The local aggregate emits `assignTs` right after the grouping columns. So every
    * [[RexInputRef]] at or beyond the grouping length (auxGroupSet, aggCalls, ...) is shifted
    * one position right. This matches the key mapping of `setChildKeysOfWinAgg`.
    *
    * @param predicate    a RexNode, may be null
    * @param globalWinAgg the global window aggregate; must have a local aggregate below it
    * @return the shifted predicate, or null if `predicate` is null
    */
  def setChildPredicateOfWinAgg(
      predicate: RexNode,
      globalWinAgg: BatchExecWindowAggregateBase): RexNode = {
    require(globalWinAgg.isMerge, "Cannot handle global agg which does not have local agg!")
    if (predicate == null) {
      null
    } else {
      // skips `assignTs`, which sits between grouping and auxGrouping in the local output
      RexUtil.shift(predicate, globalWinAgg.getGrouping.length, 1)
    }
  }

  /**
    * Splits a predicate on an [[Aggregate]] into a pushable part and a remaining part.
    *
    * Conjuncts that reference only the full group set (groupSet + auxGroupSet) are pushable.
    *
    * @param agg       Aggregate which to analyze
    * @param predicate Predicate which to analyze
    * @return a tuple of (pushable part, rest part), each None when empty.
    *         Note, the pushable part is converted to reference the aggregate's input fields.
    */
  def splitPredicateOnAggregate(
      agg: Aggregate,
      predicate: RexNode): (Option[RexNode], Option[RexNode]) =
    splitPredicateOnAgg(checkAndGetFullGroupSet(agg), agg, predicate)

  /**
    * Splits a predicate on a [[BatchExecGroupAggregateBase]] into a pushable part and a
    * remaining part.
    *
    * Conjuncts that reference only grouping + auxGrouping columns are pushable.
    *
    * @param agg       Aggregate which to analyze
    * @param predicate Predicate which to analyze
    * @return a tuple of (pushable part, rest part), each None when empty.
    *         Note, the pushable part is converted to reference the aggregate's input fields.
    */
  def splitPredicateOnAggregate(
      agg: BatchExecGroupAggregateBase,
      predicate: RexNode): (Option[RexNode], Option[RexNode]) = {
    val fullGrouping = agg.getGrouping ++ agg.getAuxGrouping
    splitPredicateOnAgg(fullGrouping, agg, predicate)
  }

  /**
    * Splits a predicate on a [[BatchExecWindowAggregateBase]] into a pushable part and a
    * remaining part.
    *
    * Conjuncts that reference only grouping + auxGrouping columns are pushable.
    *
    * @param agg       Aggregate which to analyze
    * @param predicate Predicate which to analyze
    * @return a tuple of (pushable part, rest part), each None when empty.
    *         Note, the pushable part is converted to reference the aggregate's input fields.
    */
  def splitPredicateOnAggregate(
      agg: BatchExecWindowAggregateBase,
      predicate: RexNode): (Option[RexNode], Option[RexNode]) = {
    val fullGrouping = agg.getGrouping ++ agg.getAuxGrouping
    splitPredicateOnAgg(fullGrouping, agg, predicate)
  }

  /**
    * Splits `predicate` into the conjuncts that reference only the first `grouping.length`
    * output columns of `agg` and the rest.
    *
    * The pushable conjuncts are rewritten so that output column `i` refers to input column
    * `grouping(i)`.
    *
    * @param grouping  input column behind each leading (grouping) output column
    * @param agg       the aggregate node
    * @param predicate the predicate to split
    * @return a tuple of (pushable part on the input, rest part), each None when empty
    */
  private def splitPredicateOnAgg(
      grouping: Array[Int],
      agg: SingleRel,
      predicate: RexNode): (Option[RexNode], Option[RexNode]) = {
    val pushable = new util.ArrayList[RexNode]
    val notPushable = new util.ArrayList[RexNode]
    RelOptUtil.splitFilters(
      ImmutableBitSet.range(0, grouping.length),
      predicate,
      pushable,
      notPushable)
    val rexBuilder = agg.getCluster.getRexBuilder

    val childPred =
      if (pushable.isEmpty) {
        None
      } else {
        val outputFields = agg.getRowType.getFieldList
        val inputFields = agg.getInput.getRowType.getFieldList
        // output column i + adjustments(i) == grouping(i), i.e. the matching input column
        val adjustments = new Array[Int](outputFields.size)
        for (i <- grouping.indices) {
          adjustments(i) = grouping(i) - i
        }
        val toInputConverter =
          new RelOptUtil.RexInputConverter(rexBuilder, outputFields, inputFields, adjustments)
        val convertedConditions = pushable.map(_.accept(toInputConverter))
        Option(RexUtil.composeConjunction(rexBuilder, convertedConditions, true))
      }

    val restPred =
      if (notPushable.isEmpty) None
      else Option(RexUtil.composeConjunction(rexBuilder, notPushable, true))

    (childPred, restPred)
  }

  /**
    * Returns the index of the rank function column in the output of `rank`, or -1 if the
    * rank node does not output that column.
    *
    * [[LogicalRank]] always outputs the rank function column; the other supported rank
    * nodes decide through their `outputRankFunColumn` flag.
    *
    * @param rank the rank node
    * @return the rank function column index, or -1
    * @throws IllegalArgumentException if `rank` is not a supported [[Rank]] implementation,
    *                                  or its row type is inconsistent with the flag
    */
  def getRankFunColumnIndex(rank: Rank): Int = rank match {
    case _: LogicalRank => getRankFunColumnIndex(rank, outputRankFunColumn = true)
    case r: FlinkLogicalRank => getRankFunColumnIndex(rank, r.outputRankFunColumn)
    case r: BatchExecRank => getRankFunColumnIndex(rank, r.outputRankFunColumn)
    case r: StreamExecRank => getRankFunColumnIndex(rank, r.outputRankFunColumn)
    // fail with a descriptive error instead of an opaque MatchError
    case _ => throw new IllegalArgumentException(
      s"Unknown rank type: ${rank.getRelTypeName}.")
  }

  /**
    * Returns the index of the rank function column, which is the last output column when
    * `outputRankFunColumn` is true, or -1 otherwise.
    *
    * Requires the output to have exactly one more field than the input when the column is
    * output, and the same number of fields otherwise.
    */
  private def getRankFunColumnIndex(rank: Rank, outputRankFunColumn: Boolean): Int = {
    val outputFieldCount = rank.getRowType.getFieldCount
    val inputFieldCount = rank.getInput.getRowType.getFieldCount
    if (!outputRankFunColumn) {
      require(outputFieldCount == inputFieldCount)
      -1
    } else {
      require(outputFieldCount == inputFieldCount + 1)
      outputFieldCount - 1
    }
  }

  /**
    * Splits a predicate on a [[Rank]] into the part that does not reference the rank
    * function column and the part that does.
    *
    * If the predicate is null or always true, or the rank does not output the rank
    * function column, the whole predicate is returned as the non-rank part. A null
    * predicate yields None.
    *
    * @param rank      the rank node
    * @param predicate the predicate to split, may be null
    * @return a tuple of (non-rank part, rank part), each None when empty
    */
  def splitPredicateOnRank(
      rank: Rank,
      predicate: RexNode): (Option[RexNode], Option[RexNode]) = {
    val rankFunColumnIndex = getRankFunColumnIndex(rank)
    if (predicate == null || predicate.isAlwaysTrue || rankFunColumnIndex < 0) {
      // Option(...) rather than Some(...) so that a null predicate yields None, not Some(null)
      (Option(predicate), None)
    } else {
      val rankNodes = new util.ArrayList[RexNode]
      val nonRankNodes = new util.ArrayList[RexNode]
      // the rank function column is the last output column; everything before it is non-rank
      RelOptUtil.splitFilters(
        ImmutableBitSet.range(0, rankFunColumnIndex),
        predicate,
        nonRankNodes,
        rankNodes)
      val rexBuilder = rank.getCluster.getRexBuilder

      def conjunctionOf(nodes: util.List[RexNode]): Option[RexNode] =
        if (nodes.isEmpty) None else Option(RexUtil.composeConjunction(rexBuilder, nodes, true))

      (conjunctionOf(nonRankNodes), conjunctionOf(rankNodes))
    }
  }

  /**
    * Estimates the number of distinct rank values produced by a rank range.
    *
    * @param rankRange the rank range
    * @return `rankEnd - rankStart + 1` for a [[ConstantRankRange]], otherwise a default of 100
    */
  def getRankRangeNdv(rankRange: RankRange): Double = {
    val defaultNdv = 100D // default value now
    rankRange match {
      case constantRange: ConstantRankRange =>
        (constantRange.rankEnd - constantRange.rankStart + 1).toDouble
      case _ => defaultNdv
    }
  }

  /**
    * Splits a column set between left and right sets.
    *
    * Columns below `leftCount` go to the left set unchanged. The remaining columns go to the
    * right set, re-based by subtracting `leftCount`.
    *
    * @param leftCount number of columns on the left side
    * @param columns   the column set to split
    * @return (left columns, right columns)
    */
  def splitColumnsIntoLeftAndRight(
      leftCount: Int,
      columns: ImmutableBitSet): (ImmutableBitSet, ImmutableBitSet) = {
    val (leftBits, rightBits) = columns.toArray.partition(_ < leftCount)
    val rebasedRightBits = rightBits.map(_ - leftCount)
    (ImmutableBitSet.of(leftBits: _*), ImmutableBitSet.of(rebasedRightBits: _*))
  }


  /**
    * Estimates the ratio outputRowCount / inputRowCount of an aggregate when the ndv of its
    * group keys is unavailable.
    *
    * The value of `1.0 - math.exp(-0.1 * groupCount)` increases with groupCount, starting
    * at 0 for groupCount 0 and approaching 1.0. For example, it is about 0.095 for 1,
    * 0.18 for 2, and 0.26 for 3.
    *
    * @param groupingLength grouping keys length of aggregate
    * @return the estimated ratio outputRowCount / inputRowCount of the aggregate
    */
  def getAggregationRatioIfNdvUnavailable(groupingLength: Int): Double = {
    val exponent = -0.1 * groupingLength
    1.0 - math.exp(exponent)
  }

  /**
    * Estimates the outputRowCount of a local aggregate.
    *
    * The output row count of the local agg is `(1 - pow(1 - 1/x, n/m)) * m * x`, capped at
    * n, where x = parallelism, n = inputRowCount and m = globalAggRowCount. It is based on
    * two assumptions:
    * 1. even distribution of all distinct data
    * 2. even distribution of all data in each concurrent local agg worker
    *
    * @param parallelism       number of concurrent workers of the local aggregate
    * @param inputRowCount     rowcount of the input node of the aggregate
    * @param globalAggRowCount rowcount of the output of the global aggregate
    * @return outputRowCount of the local aggregate
    */
  def getRowCountOfLocalAgg(
      parallelism: Int,
      inputRowCount: Double,
      globalAggRowCount: Double): Double = {
    // probability that a given distinct key never reaches one particular worker,
    // when each key occurs n/m times and each occurrence lands on that worker with prob 1/x
    val missProbability = math.pow(1 - 1.0 / parallelism, inputRowCount / globalAggRowCount)
    val estimated = (1 - missProbability) * globalAggRowCount * parallelism
    Math.min(estimated, inputRowCount)
  }

  /**
    * Estimates the new distinctRowCount of the current node after it applies a condition.
    * The estimation is based on one assumption: even distribution of all distinct data.
    *
    * With `ndv = min(distinctRowCount, rowCount)`, each distinct value occurs
    * `rowCount / ndv` times. So it is retained with probability
    * `1 - pow(1 - selectivity, rowCount / ndv)`.
    *
    * @param rowCount         rowcount of node.
    * @param distinctRowCount distinct rowcount of node.
    * @param selectivity      selectivity of condition expression.
    * @return new distinctRowCount. Returns 0 when ndv <= 0: there are no distinct values to
    *         retain, and the formula would otherwise evaluate to NaN (0 / 0, or
    *         pow(1, Infinity) when the selectivity is 0).
    */
  def adaptNdvBasedOnSelectivity(
      rowCount: Double,
      distinctRowCount: Double,
      selectivity: Double): Double = {
    val ndv = Math.min(distinctRowCount, rowCount)
    if (ndv <= 0) {
      0D
    } else {
      (1 - Math.pow(1 - selectivity, rowCount / ndv)) * ndv
    }
  }

  /**
    * Collects, across all projects of an [[Expand]], the [[RexInputRef]] index found at the
    * given column position.
    *
    * A project whose expression at `index` is not a [[RexInputRef]] contributes -1.
    *
    * @param index  the column position inspected in every project
    * @param expand the expand node
    * @return the set of referenced input indices (possibly containing -1)
    */
  def getInputRefIndices(index: Int, expand: Expand): util.Set[Int] = {
    val inputRefs = new util.HashSet[Int]()
    expand.projects.foreach { exprs =>
      val refIndex = exprs.get(index) match {
        case ref: RexInputRef => ref.getIndex
        case _ => -1
      }
      inputRefs.add(refIndex)
    }
    inputRefs
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy