All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.flink.table.plan.metadata.FlinkRelMdRowCount.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.plan.metadata

import org.apache.flink.table.expressions.ExpressionUtils._
import org.apache.flink.table.plan.logical.{LogicalWindow, SlidingGroupWindow, TumblingGroupWindow}
import org.apache.flink.table.plan.nodes.calcite.{Expand, LogicalWindowAggregate, Rank}
import org.apache.flink.table.plan.nodes.logical.FlinkLogicalWindowAggregate
import org.apache.flink.table.plan.nodes.physical.batch._
import org.apache.flink.table.plan.stats.ValueInterval
import org.apache.flink.table.plan.util.AggregateUtil._
import org.apache.flink.table.plan.util.FlinkRelMdUtil._
import org.apache.flink.table.plan.util.{FlinkRelMdUtil, FlinkRelOptUtil}
import org.apache.flink.table.util.NodeResourceUtil

import org.apache.calcite.adapter.enumerable.EnumerableLimit
import org.apache.calcite.plan.volcano.RelSubset
import org.apache.calcite.rel.core._
import org.apache.calcite.rel.metadata._
import org.apache.calcite.rel.{RelNode, SingleRel}
import org.apache.calcite.rex.{RexLiteral, RexNode}
import org.apache.calcite.util._

import java.lang.Double

import scala.collection.JavaConversions._

class FlinkRelMdRowCount private extends MetadataHandler[BuiltInMetadata.RowCount] {

  def getDef: MetadataDef[BuiltInMetadata.RowCount] = BuiltInMetadata.RowCount.DEF

  def getRowCount(rel: Expand, mq: RelMetadataQuery): Double = rel.estimateRowCount(mq)

  def getRowCount(rel: Rank, mq: RelMetadataQuery): Double = rel.estimateRowCount(mq)

  def getRowCount(rel: Aggregate, mq: RelMetadataQuery): Double = {
    getRowCountOfAgg(rel, rel.getGroupSet, rel.getGroupSets.size(), mq)._1
  }

  /**
    * Get output rowCount and input rowCount of agg
    *
    * @param rel           agg relNode
    * @param groupSet      agg groupSet
    * @param groupSetsSize agg groupSets count
    * @param mq            metadata query
    * @return a tuple, the first element is output rowCount, second one is input rowCount
    */
  private def getRowCountOfAgg(
      rel: SingleRel,
      groupSet: ImmutableBitSet,
      groupSetsSize: Int,
      mq: RelMetadataQuery): (Double, Double) = {
    val childRowCount = mq.getRowCount(rel.getInput)
    if (groupSet.cardinality() == 0) {
      return (1.0, childRowCount)
    }

    // rowCount is the cardinality of the group by columns
    val distinctRowCount = mq.getDistinctRowCount(rel.getInput, groupSet, null)
    val groupCount = groupSet.cardinality()
    val d: Double = if (distinctRowCount == null) {
      NumberUtil.multiply(childRowCount,
        FlinkRelMdUtil.getAggregationRatioIfNdvUnavailable(groupCount))
    } else {
      NumberUtil.min(distinctRowCount, childRowCount)
    }

    if (d != null) {
      // Grouping sets multiply
      (d * groupSetsSize, childRowCount)
    } else {
      (null, childRowCount)
    }
  }

  def getRowCount(rel: BatchExecGroupAggregateBase, mq: RelMetadataQuery): Double = {
    getRowCountOfBatchExecAgg(rel, mq)
  }

  private def getRowCountOfBatchExecAgg(rel: SingleRel, mq: RelMetadataQuery): Double = {
    val (grouping, isFinal, isMerge) = rel match {
      case agg: BatchExecGroupAggregateBase =>
        (ImmutableBitSet.of(agg.getGrouping: _*), agg.isFinal, agg.isMerge)
      case windowAgg: BatchExecWindowAggregateBase =>
        (ImmutableBitSet.of(windowAgg.getGrouping: _*), windowAgg.isFinal, windowAgg.isMerge)
      case _ => throw new IllegalArgumentException(s"Unknown node type ${rel.getRelTypeName}!")
    }
    val ndvOfGroupKeysOnGlobalAgg: Double = if (grouping.isEmpty) {
      1.0
    } else {
      // rowCount is the cardinality of the group by columns
      val distinctRowCount = mq.getDistinctRowCount(rel.getInput, grouping, null)
      val childRowCount = mq.getRowCount(rel.getInput)
      if (distinctRowCount == null) {
        if (isFinal && isMerge) {
          // Avoid apply aggregation ratio twice when calculate row count of global agg
          // which has local agg.
          childRowCount
        } else {
          NumberUtil.multiply(childRowCount, getAggregationRatioIfNdvUnavailable(grouping.length))
        }
      } else {
        NumberUtil.min(distinctRowCount, childRowCount)
      }
    }
    if (isFinal) {
      ndvOfGroupKeysOnGlobalAgg
    } else {
      val childRowCount = mq.getRowCount(rel.getInput)
      val tableConfig = FlinkRelOptUtil.getTableConfig(rel)
      val nParallelism = NodeResourceUtil.calOperatorParallelism(childRowCount, tableConfig.getConf)
      if (nParallelism == 1) {
        ndvOfGroupKeysOnGlobalAgg
      } else if (grouping.isEmpty) {
        // output rowcount of local agg is parallelism for agg which has no group keys
        nParallelism.toDouble
      } else {
        val distinctRowCount = mq.getDistinctRowCount(rel.getInput, grouping, null)
        if (distinctRowCount == null) {
          ndvOfGroupKeysOnGlobalAgg
        } else {
          getRowCountOfLocalAgg(nParallelism, childRowCount, ndvOfGroupKeysOnGlobalAgg)
        }
      }
    }
  }

  def getRowCount(rel: FlinkLogicalWindowAggregate, mq: RelMetadataQuery): Double = {
    getRowCountOfWindowAgg(rel, rel.getWindow, rel.getGroupSet, mq)
  }

  def getRowCount(rel: LogicalWindowAggregate, mq: RelMetadataQuery): Double = {
    getRowCountOfWindowAgg(rel, rel.getWindow, rel.getGroupSet, mq)
  }

  def getRowCount(rel: BatchExecWindowAggregateBase, mq: RelMetadataQuery): Double = {
    val ndvOfGroupKeys = getRowCountOfBatchExecAgg(rel, mq)
    val inputRowCount = mq.getRowCount(rel.getInput)
    estimateRowCountOfWindowAgg(ndvOfGroupKeys, inputRowCount, rel.getWindow)
  }

  private def getRowCountOfWindowAgg(
    windowAgg: SingleRel,
    window: LogicalWindow,
    grouping: ImmutableBitSet,
    mq: RelMetadataQuery): Double = {
    val (ndvOfGroupKeys, inputRowCount) = getRowCountOfAgg(windowAgg, grouping, 1, mq)
    estimateRowCountOfWindowAgg(ndvOfGroupKeys, inputRowCount, window)
  }

  private def estimateRowCountOfWindowAgg(
      ndv: Double,
      inputRowCount: Double,
      window: LogicalWindow): Double = {
    if (ndv == null) {
      null
    } else {
      // simply assume expand factor of TumblingWindow/SessionWindow/SlideWindowWithoutOverlap is 2
      // SlideWindowWithOverlap is 4.
      // Introduce expand factor here to distinguish output rowCount of normal agg with all kinds of
      // window aggregates.
      val expandFactorOfTumblingWindow = 2D
      val expandFactorOfNoOverLapSlidingWindow = 2D
      val expandFactorOfOverLapSlidingWindow = 4D
      val expandFactorOfSessionWindow = 2D
      window match {
        case TumblingGroupWindow(_, _, size) if isTimeIntervalLiteral(size) =>
          Math.min(expandFactorOfTumblingWindow * ndv, inputRowCount)
        case SlidingGroupWindow(_, _, size, slide) if isTimeIntervalLiteral(size) =>
          val sizeV = asLong(size)
          val slideV = asLong(slide)
          if (sizeV > slideV) {
            // only slideWindow which has overlap may generates more records than input
            expandFactorOfOverLapSlidingWindow * ndv
          } else {
            Math.min(expandFactorOfNoOverLapSlidingWindow * ndv, inputRowCount)
          }
        case _ => Math.min(expandFactorOfSessionWindow * ndv, inputRowCount)
      }
    }
  }

  def getRowCount(rel: BatchExecOverAggregate, mq: RelMetadataQuery): Double =
    getRowCountOfOverWindow(rel, mq)

  def getRowCount(rel: Window, mq: RelMetadataQuery): Double =
    getRowCountOfOverWindow(rel, mq)

  private def getRowCountOfOverWindow(overWindow: SingleRel, mq: RelMetadataQuery): Double =
    mq.getRowCount(overWindow.getInput)

  def getRowCount(join: Join, mq: RelMetadataQuery): Double = {
    val leftChild = join.getLeft
    val rightChild = join.getRight
    val leftRowCount = mq.getRowCount(leftChild)
    val rightRowCount = mq.getRowCount(rightChild)
    if (leftRowCount == null || rightRowCount == null) {
      return null
    }

    val joinInfo = JoinInfo.of(leftChild, rightChild, join.getCondition)
    if (joinInfo.leftSet().nonEmpty) {
      val innerJoinRowCount = getEquiInnerJoinRowCount(join, mq, leftRowCount, rightRowCount)
      require(innerJoinRowCount != null)
      // Make sure outputRowCount won't be too small based on join type.
      join.getJoinType match {
        case JoinRelType.INNER => innerJoinRowCount
        case JoinRelType.LEFT =>
          // All rows from left side should be in the result.
          math.max(leftRowCount, innerJoinRowCount)
        case JoinRelType.RIGHT =>
          // All rows from right side should be in the result.
          math.max(rightRowCount, innerJoinRowCount)
        case JoinRelType.FULL =>
          // T(A FULL JOIN B) = T(A LEFT JOIN B) + T(A RIGHT JOIN B) - T(A INNER JOIN B)
          math.max(leftRowCount, innerJoinRowCount) +
            math.max(rightRowCount, innerJoinRowCount) - innerJoinRowCount
      }
    } else {
      val rexBuilder = join.getCluster.getRexBuilder
      val crossJoin = copyJoinWithNewCondition(join, rexBuilder.makeLiteral(true))
      val selectivity = mq.getSelectivity(crossJoin, join.getCondition)
      (leftRowCount * rightRowCount) * selectivity
    }
  }

  private def getEquiInnerJoinRowCount(
      join: Join,
      mq: RelMetadataQuery,
      leftRowCount: Double,
      rightRowCount: Double): Double = {
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    val leftChild = join.getLeft
    val rightChild = join.getRight
    val rexBuilder = join.getCluster.getRexBuilder
    val condition = join.getCondition
    val joinInfo = JoinInfo.of(leftChild, rightChild, condition)
    // the leftKeys length equals to rightKeys, so it's ok to only check leftKeys length
    require(joinInfo.leftKeys.nonEmpty)

    val joinKeyDisjoint = joinInfo.leftKeys.zip(joinInfo.rightKeys).exists {
      case (leftKey, rightKey) =>
        val leftInterval = fmq.getColumnInterval(leftChild, leftKey)
        val rightInterval = fmq.getColumnInterval(rightChild, rightKey)
        if (leftInterval != null && rightInterval != null) {
          !ValueInterval.isIntersected(leftInterval, rightInterval)
        } else {
          false
        }
    }
    // One of the join key pairs is disjoint, thus the two sides of join is disjoint.
    if (joinKeyDisjoint) {
      return 0D
    }

    val leftKeySet = joinInfo.leftSet()
    val rightKeySet = joinInfo.rightSet()
    val leftNdv = fmq.getDistinctRowCount(leftChild, leftKeySet, null)
    val rightNdv = fmq.getDistinctRowCount(rightChild, rightKeySet, null)
    // estimate selectivity of non-equi
    val selectivityOfNonEquiPred: Double = if (joinInfo.isEqui) {
      1D
    } else {
      val nonEquiPred = joinInfo.getRemaining(rexBuilder)
      val equiPred = RelMdUtil.minusPreds(rexBuilder, condition, nonEquiPred)
      val joinWithOnlyEquiPred = copyJoinWithNewCondition(join, equiPred)
      fmq.getSelectivity(joinWithOnlyEquiPred, nonEquiPred)
    }

    if (leftNdv != null && rightNdv != null) {
      // selectivity of equi part is 1 / Max(leftNdv, rightNdv)
      val selectivityOfEquiPred = Math.min(1D, 1D / Math.max(leftNdv, rightNdv))
      return leftRowCount * rightRowCount * selectivityOfEquiPred * selectivityOfNonEquiPred
    }

    val leftKeysAreUnique = fmq.areColumnsUnique(leftChild, leftKeySet)
    val rightKeysAreUnique = fmq.areColumnsUnique(rightChild, rightKeySet)
    if (leftKeysAreUnique != null && rightKeysAreUnique != null &&
      (leftKeysAreUnique || rightKeysAreUnique)) {
      val outputRowCount = if (leftKeysAreUnique && rightKeysAreUnique) {
        // if both leftKeys and rightKeys are both unique,
        // rowCount = Min(leftRowCount) * selectivity of non-equi
        Math.min(leftRowCount, rightRowCount) * selectivityOfNonEquiPred
      } else if (leftKeysAreUnique) {
        rightRowCount * selectivityOfNonEquiPred
      } else {
        leftRowCount * selectivityOfNonEquiPred
      }
      return outputRowCount
    }

    // if joinCondition has no ndv stats and no uniqueKeys stats,
    // rowCount = (leftRowCount + rightRowCount) * join condition selectivity
    val crossJoin = copyJoinWithNewCondition(join, rexBuilder.makeLiteral(true))
    val selectivity = fmq.getSelectivity(crossJoin, condition)
    (leftRowCount + rightRowCount) * selectivity
  }

  private def copyJoinWithNewCondition(join: Join, newCondition: RexNode): Join = {
    join.copy(
      join.getTraitSet,
      newCondition,
      join.getLeft,
      join.getRight,
      join.getJoinType,
      join.isSemiJoinDone)
  }

  def getRowCount(rel: SemiJoin, mq: RelMetadataQuery): Double = {
    val semiJoinSelectivity = FlinkRelMdUtil.makeSemiJoinSelectivityRexNode(mq, rel)
    NumberUtil.multiply(
      mq.getSelectivity(rel.getLeft, semiJoinSelectivity),
      mq.getRowCount(rel.getLeft))
  }

  /** Catch-all implementation for
    * [[BuiltInMetadata.RowCount#getRowCount()]],
    * invoked using reflection.
    *
    * @see org.apache.calcite.rel.metadata.RelMetadataQuery#getRowCount(RelNode)
    */
  def getRowCount(rel: RelNode, mq: RelMetadataQuery): Double = rel.estimateRowCount(mq)

  def getRowCount(subset: RelSubset, mq: RelMetadataQuery): Double = {
    if (!Bug.CALCITE_1048_FIXED) {
      return mq.getRowCount(Util.first(subset.getBest, subset.getOriginal))
    }
    val v = subset.getRels.foldLeft(null.asInstanceOf[Double]) {
      (min, r) =>
        try {
          NumberUtil.min(min, mq.getRowCount(r))
        } catch {
          // ignore this rel; there will be other, non-cyclic ones
          case e: CyclicMetadataException => min
          case e: Throwable =>
            e.printStackTrace()
            min
        }
    }
    // if set is empty, estimate large
    Util.first(v, 1e6d)
  }

  def getRowCount(rel: Union, mq: RelMetadataQuery): Double = {
    val rowCounts = rel.getInputs.map(mq.getRowCount)
    if (rowCounts.contains(null)) {
      null
    } else {
      rowCounts.foldLeft(0D)(_ + _)
    }
  }

  def getRowCount(rel: Intersect, mq: RelMetadataQuery): Double = {
    rel.getInputs.foldLeft(null.asInstanceOf[Double])((res, r) => {
      val partialRowCount = mq.getRowCount(r)
      if (res == null || (partialRowCount != null && partialRowCount < res)) {
        partialRowCount
      } else {
        res
      }
    })
  }

  def getRowCount(rel: Minus, mq: RelMetadataQuery): Double = {
    rel.getInputs.foldLeft(null.asInstanceOf[Double])((res, r) => {
      val partialRowCount = mq.getRowCount(r)
      if (res == null || (partialRowCount != null && partialRowCount < res)) {
        partialRowCount
      } else {
        res
      }
    })
  }

  def getRowCount(rel: Filter, mq: RelMetadataQuery): Double =
    RelMdUtil.estimateFilteredRows(rel.getInput, rel.getCondition, mq)

  def getRowCount(rel: Calc, mq: RelMetadataQuery): Double =
    RelMdUtil.estimateFilteredRows(rel.getInput, rel.getProgram, mq)

  def getRowCount(rel: Project, mq: RelMetadataQuery): Double = mq.getRowCount(rel.getInput)

  def getRowCount(rel: Sort, mq: RelMetadataQuery): Double = {
    var rowCount = mq.getRowCount(rel.getInput)
    if (rowCount == null) {
      return null
    }
    val offset = if (rel.offset == null) 0 else RexLiteral.intValue(rel.offset)
    rowCount = Math.max(rowCount - offset, 0D)
    if (rel.fetch != null) {
      val limit = RexLiteral.intValue(rel.fetch)
      if (limit < rowCount) {
        return limit.toDouble
      }
    }
    rowCount
  }

  def getRowCount(rel: EnumerableLimit, mq: RelMetadataQuery): Double = {
    var rowCount: Double = mq.getRowCount(rel.getInput)
    if (rowCount == null) {
      return null
    }
    val offset = if (rel.offset == null) 0 else RexLiteral.intValue(rel.offset)
    rowCount = Math.max(rowCount - offset, 0D)
    if (rel.fetch != null) {
      val limit = RexLiteral.intValue(rel.fetch)
      if (limit < rowCount) {
        return limit.toDouble
      }
    }
    rowCount
  }

  def getRowCount(rel: SingleRel, mq: RelMetadataQuery): Double = mq.getRowCount(rel.getInput)

  def getRowCount(rel: TableScan, mq: RelMetadataQuery): Double = rel.estimateRowCount(mq)

  def getRowCount(rel: Values, mq: RelMetadataQuery): Double = rel.estimateRowCount(mq)

}

object FlinkRelMdRowCount {

  private val INSTANCE = new FlinkRelMdRowCount

  val SOURCE: RelMetadataProvider = ReflectiveRelMetadataProvider.reflectiveSource(
    BuiltInMethod.ROW_COUNT.method, INSTANCE)

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy