All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.flink.table.plan.util.RankUtil.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.flink.table.plan.util

import org.apache.flink.api.java.functions.KeySelector
import org.apache.flink.table.api.dataview.Order
import org.apache.flink.table.api.types.{DataTypes, RowType, TypeConverters}
import org.apache.flink.table.api.{TableConfig, TableConfigOptions, TableException}
import org.apache.flink.table.codegen._
import org.apache.flink.table.dataformat.BaseRow
import org.apache.flink.table.plan.metadata.FlinkRelMetadataQuery
import org.apache.flink.table.plan.nodes.physical.stream.StreamExecRank
import org.apache.flink.table.plan.rules.physical.stream.StreamExecRetractionRules
import org.apache.flink.table.plan.schema.BaseRowSchema
import org.apache.flink.table.runtime.aggregate.SorterHelper
import org.apache.flink.table.runtime.{BinaryRowKeySelector, NullBinaryRowKeySelector}
import org.apache.flink.table.typeutils.BaseRowTypeInfo

import org.apache.calcite.plan.{RelOptCluster, RelOptUtil}
import org.apache.calcite.rel.RelFieldCollation.Direction
import org.apache.calcite.rel.{RelCollation, RelFieldCollation}
import org.apache.calcite.rex._
import org.apache.calcite.sql.SqlKind
import org.apache.calcite.sql.validate.SqlMonotonicity

import java.util

import scala.collection.JavaConversions._

/**
  * A utility object to optimize and generate TopN (rank) operators.
  */
object RankUtil {

  /**
    * A comparison predicate that references the rank field alone on one of its sides.
    *
    * @param rankOnLeftSide true if the rank field reference is the left operand of `pred`
    * @param pred           the comparison call (one of <, <=, >, >=, =)
    */
  case class LimitPredicate(rankOnLeftSide: Boolean, pred: RexCall)

  /**
    * Extracts the TopN offset and fetch bounds from a predicate.
    *
    * The predicate is first converted to conjunctive normal form (CNF). Conjuncts that compare
    * the rank field with another expression are turned into a rank range; all other conjuncts
    * are AND-ed back together as the remaining condition. If no supported rank range can be
    * derived, `(None, Some(predicate))` is returned, i.e. the original predicate is kept as is.
    *
    * @param  predicate           predicate
    * @param  rankFieldIndex      the index of rank field
    * @param  rexBuilder          RexBuilder
    * @param  config              TableConfig
    * @return A Tuple2 of extracted rank range and remaining predicates.
    */
  def extractRankRange(
      predicate: RexNode,
      rankFieldIndex: Int,
      rexBuilder: RexBuilder,
      config: TableConfig): (Option[RankRange], Option[RexNode]) = {

    // Converts the condition to conjunctive normal form (CNF)
    val cnfCondition = FlinkRexUtil.toCnf(rexBuilder,
      config.getConf.getInteger(TableConfigOptions.SQL_OPTIMIZER_CNF_NODES_LIMIT),
      predicate)

    // split the condition into sort limit condition and other condition
    val (limitPreds: Seq[LimitPredicate], otherPreds: Seq[RexNode]) = cnfCondition match {
      case c: RexCall if c.getKind == SqlKind.AND =>
        c.getOperands
          .map(identifyLimitPredicate(_, rankFieldIndex))
          .foldLeft((Seq[LimitPredicate](), Seq[RexNode]())) {
            (preds, analyzed) =>
              analyzed match {
                case Left(limitPred) => (preds._1 :+ limitPred, preds._2)
                case Right(otherPred) => (preds._1, preds._2 :+ otherPred)
              }
          }
      case rex: RexNode =>
        identifyLimitPredicate(rex, rankFieldIndex) match {
          case Left(limitPred) => (Seq(limitPred), Seq())
          case Right(otherPred) => (Seq(), Seq(otherPred))
        }
      case _ =>
        return (None, Some(predicate))
    }

    if (limitPreds.isEmpty) {
      // no valid TopN bounds.
      return (None, Some(predicate))
    }

    val sortBounds = limitPreds.map(computeWindowBoundFromPredicate(_, rexBuilder, config))
    // Only the bound combinations below are supported; any other shape (including a
    // predicate whose bound could not be computed) disables the TopN optimization.
    val rankRange = sortBounds match {
      case Seq(Some(LowerBoundary(x)), Some(UpperBoundary(y))) =>
        ConstantRankRange(x, y)
      case Seq(Some(UpperBoundary(x)), Some(LowerBoundary(y))) =>
        ConstantRankRange(y, x)
      case Seq(Some(LowerBoundary(x))) =>
        // only offset
        ConstantRankRangeWithoutEnd(x)
      case Seq(Some(UpperBoundary(x))) =>
        // rankStart starts from one
        ConstantRankRange(1, x)
      case Seq(Some(BothBoundary(x, y))) =>
        // nth rank
        ConstantRankRange(x, y)
      case Seq(Some(InputRefBoundary(x))) =>
        VariableRankRange(x)
      case _ =>
        // TopN requires at least one rank comparison predicate
        return (None, Some(predicate))
    }

    val remainCondition = otherPreds match {
      case Seq() =>
        None
      case _ =>
        Some(otherPreds.reduceLeft((l, r) => RelOptUtil.andJoinFilters(rexBuilder, l, r)))
    }

    (Some(rankRange), remainCondition)
  }

  /**
    * Analyzes a predicate and identifies whether it is a valid predicate for a TopN.
    * A valid TopN predicate is a comparison predicate (<, <=, >=, >) or equal predicate
    * that accesses the rank field of the input rel node; the rank field reference must be
    * alone on one side of the condition, and the other side must not access the rank field.
    *
    * Examples:
    * - rank <= 10 => valid (Top 10)
    * - rank + 1 <= 10 => invalid: rank is not alone in the condition
    * - rank = 10 => valid (10th)
    * - rank <= rank + 2 => invalid: rank on both sides
    *
    * @return Either a valid limit predicate (Left) or the unchanged non-limit predicate (Right)
    */
  private def identifyLimitPredicate(
      pred: RexNode,
      rankFieldIndex: Int): Either[LimitPredicate, RexNode] = pred match {
    case c: RexCall =>
      c.getKind match {
        case SqlKind.GREATER_THAN |
             SqlKind.GREATER_THAN_OR_EQUAL |
             SqlKind.LESS_THAN |
             SqlKind.LESS_THAN_OR_EQUAL |
             SqlKind.EQUALS =>

          val leftTerm = c.getOperands.head
          val rightTerm = c.getOperands.last

          if (isRankFieldRef(leftTerm, rankFieldIndex) &&
            !accessesRankField(rightTerm, rankFieldIndex)) {
            Left(LimitPredicate(rankOnLeftSide = true, c))
          } else if (isRankFieldRef(rightTerm, rankFieldIndex) &&
            !accessesRankField(leftTerm, rankFieldIndex)) {
            Left(LimitPredicate(rankOnLeftSide = false, c))
          } else {
            Right(pred)
          }

        // not a comparison predicate.
        case _ => Right(pred)
      }
    case _ => Right(pred)
  }

  /**
    * Checks if the expression is exactly a reference to the rank field.
    *
    * @param expr           The expression to check.
    * @param rankFieldIndex The rank field index.
    * @return True, if `expr` is a [[RexInputRef]] pointing at `rankFieldIndex`.
    */
  def isRankFieldRef(expr: RexNode, rankFieldIndex: Int): Boolean = expr match {
    case i: RexInputRef => i.getIndex == rankFieldIndex
    case _ => false
  }

  /**
    * Checks if an expression accesses a rank field, either directly or within any
    * (nested) operand of a call.
    *
    * @param expr The expression to check.
    * @param rankFieldIndex The rank field index.
    * @return True, if the expression accesses the rank field. False otherwise.
    */
  def accessesRankField(expr: RexNode, rankFieldIndex: Int): Boolean = expr match {
    case i: RexInputRef => i.getIndex == rankFieldIndex
    case c: RexCall => c.operands.exists(accessesRankField(_, rankFieldIndex))
    case _ => false
  }


  /** A bound on the rank field derived from a single limit predicate. */
  sealed trait Boundary

  /** Constant inclusive lower bound on the rank. */
  case class LowerBoundary(lower: Long) extends Boundary

  /** Constant inclusive upper bound on the rank. */
  case class UpperBoundary(upper: Long) extends Boundary

  /** Constant inclusive lower and upper bound on the rank (from an equality predicate). */
  case class BothBoundary(lower: Long, upper: Long) extends Boundary

  /** Inclusive upper bound on the rank read from the input field at `inputFieldIndex`. */
  case class InputRefBoundary(inputFieldIndex: Int) extends Boundary

  /** Which side(s) of the rank a comparison predicate bounds. */
  sealed trait BoundDefine

  object Lower extends BoundDefine // defined lower bound
  object Upper extends BoundDefine // defined upper bound
  object Both extends BoundDefine // defined lower and upper bound

  /**
    * Computes the bound that a limit predicate places on the rank field and whether it is
    * an upper bound, a lower bound, or both (equality). Strict comparisons against constants
    * are normalized to inclusive bounds (e.g. `rank < 10` becomes upper bound 9).
    *
    * A non-constant side is only supported when it is a plain input field reference used as
    * an inclusive upper bound (`rank <= field` or `field >= rank`). [[InputRefBoundary]] holds
    * no strictness information and is mapped to the same range as `rank <= field`, so strict,
    * lower-bound and equality predicates against input fields are rejected here instead of
    * being silently widened.
    *
    * @return the boundary, or None if the predicate cannot be used as a TopN bound
    */
  private def computeWindowBoundFromPredicate(
      limitPred: LimitPredicate,
      rexBuilder: RexBuilder,
      config: TableConfig): Option[Boundary] = {

    val kind = limitPred.pred.getKind

    val bound: BoundDefine = kind match {
      case SqlKind.GREATER_THAN | SqlKind.GREATER_THAN_OR_EQUAL if limitPred.rankOnLeftSide =>
        Lower
      case SqlKind.GREATER_THAN | SqlKind.GREATER_THAN_OR_EQUAL if !limitPred.rankOnLeftSide =>
        Upper
      case SqlKind.LESS_THAN | SqlKind.LESS_THAN_OR_EQUAL if limitPred.rankOnLeftSide =>
        Upper
      case SqlKind.LESS_THAN | SqlKind.LESS_THAN_OR_EQUAL if !limitPred.rankOnLeftSide =>
        Lower
      case SqlKind.EQUALS =>
        Both
    }

    val predExpression = if (limitPred.rankOnLeftSide) {
      limitPred.pred.operands.get(1)
    } else {
      limitPred.pred.operands.get(0)
    }

    // With bound == Upper, these two kinds are exactly `rank <= x` and `x >= rank`.
    val isInclusive =
      kind == SqlKind.LESS_THAN_OR_EQUAL || kind == SqlKind.GREATER_THAN_OR_EQUAL

    (predExpression, bound) match {
      case (r: RexInputRef, Upper) if isInclusive =>
        Some(InputRefBoundary(r.getIndex))
      case (_: RexInputRef, _) =>
        // strict, lower or equality bound on an input field: not expressible as a rank range
        None
      case _ =>
        // reduce predicate to constants to compute bounds
        val literal = reduceComparisonPredicate(limitPred, rexBuilder, config)
        if (literal.isEmpty) {
          None
        } else {
          // compute boundary
          val tmpBoundary: Long = literal.get
          val boundary = kind match {
            case SqlKind.LESS_THAN if limitPred.rankOnLeftSide =>
              tmpBoundary - 1
            case SqlKind.LESS_THAN =>
              tmpBoundary + 1
            case SqlKind.GREATER_THAN if limitPred.rankOnLeftSide =>
              tmpBoundary + 1
            case SqlKind.GREATER_THAN =>
              tmpBoundary - 1
            case _ =>
              tmpBoundary
          }
          bound match {
            case Lower => Some(LowerBoundary(boundary))
            case Upper => Some(UpperBoundary(boundary))
            case Both => Some(BothBoundary(boundary, boundary))
          }
        }
    }
  }

  /**
    * Reduces the non-rank side of a limit predicate to a long literal.
    *
    * @param limitPred The limit predicate whose non-rank side is reduced.
    * @param rexBuilder A RexBuilder
    * @param config A TableConfig.
    * @return The reduced value, or None if the non-rank side is not constant, reduces to
    *         something other than a literal, reduces to NULL, or is not a Long value.
    */
  private def reduceComparisonPredicate(
      limitPred: LimitPredicate,
      rexBuilder: RexBuilder,
      config: TableConfig): Option[Long] = {

    val expression = if (limitPred.rankOnLeftSide) {
      limitPred.pred.operands.get(1)
    } else {
      limitPred.pred.operands.get(0)
    }

    if (!RexUtil.isConstant(expression)) {
      return None
    }

    // reduce expression to literal
    val exprReducer = new ExpressionReducer(config)
    val originList = new util.ArrayList[RexNode]()
    originList.add(expression)
    val reduceList = new util.ArrayList[RexNode]()
    exprReducer.reduce(rexBuilder, originList, reduceList)

    // Extract the bound from the reduced literal. A NULL literal (getValue2 returns null)
    // must not be unboxed to 0, and a non-Long value cannot serve as a rank bound.
    reduceList.headOption.flatMap {
      case literal: RexLiteral =>
        literal.getValue2 match {
          case l: java.lang.Long => Some(l.longValue())
          case _ => None
        }
      case _ => None
    }
  }

  /**
    * Converts the direction of a field collation to a sort [[Order]].
    *
    * @throws TableException if the direction is neither (strictly) ascending nor descending
    */
  def getOrderFromFieldCollation(field: RelFieldCollation): Order = {
    field.getDirection match {
      case RelFieldCollation.Direction.ASCENDING
           | RelFieldCollation.Direction.STRICTLY_ASCENDING =>
        Order.ASCENDING

      case RelFieldCollation.Direction.DESCENDING
           | RelFieldCollation.Direction.STRICTLY_DESCENDING =>
        Order.DESCENDING

      case _ =>
        //Shouldn't happen
        throw new TableException(
          "Couldn't get correct sort field direction. Shouldn't happen here.")
    }
  }

  /**
    * Creates the row type of the sort key and a generated sorter over it.
    *
    * The sorter works on the projected sort key row, so its key fields are the positions
    * 0 until n of that row rather than the original input field indices.
    *
    * @return the sort key type info and the generated sorter
    */
  def createSortKeyTypeAndSorter(
      inputSchema: BaseRowSchema,
      fieldCollations: Seq[RelFieldCollation]): (BaseRowTypeInfo, GeneratedSorter) = {

    val (sortFields, sortDirections, nullsIsLast) = SortUtil.getKeysAndOrders(fieldCollations)

    val inputFieldTypes = inputSchema.fieldTypeInfos
    val fieldTypes = sortFields.map(inputFieldTypes(_))
    val sortKeyTypeInfo = new BaseRowTypeInfo(fieldTypes: _*)

    val logicalKeyFields = fieldTypes.indices.toArray

    val sorter = SorterHelper.createSorter(
      sortKeyTypeInfo.getFieldTypes.map(TypeConverters.createInternalTypeFromTypeInfo),
      logicalKeyFields,
      sortDirections,
      nullsIsLast)

    (sortKeyTypeInfo, sorter)
  }

  /**
    * Creates a key selector that projects the sort fields of the input row; a
    * [[NullBinaryRowKeySelector]] is used when there are no sort fields.
    */
  def createSortKeySelector(
      fieldCollations: Seq[RelFieldCollation],
      inputSchema: BaseRowSchema): KeySelector[BaseRow, BaseRow] = {

    val sortFields = fieldCollations.map(_.getFieldIndex).toArray
    if (sortFields.isEmpty) {
      new NullBinaryRowKeySelector
    } else {
      new BinaryRowKeySelector(sortFields, inputSchema.typeInfo())
    }
  }

  /**
    * Generates an extractor for the single sort field of the input row.
    *
    * @throws TableException if there is not exactly one sort field
    */
  def getUnarySortKeyExtractor(
      fieldCollations: Seq[RelFieldCollation],
      inputSchema: BaseRowSchema): GeneratedFieldExtractor = {

    val sortFields = fieldCollations.map(_.getFieldIndex).toArray
    if (sortFields.length != 1) {
      throw new TableException("[rank util] This shouldn't happen. Please file an issue.")
    }

    val inputType: BaseRowTypeInfo = inputSchema.typeInfo()

    FieldAccessCodeGenerator.generateRowFieldExtractor(
      CodeGeneratorContext.apply(new TableConfig, supportReference = false),
      "SimpleFieldExtractor",
      TypeConverters.createInternalTypeFromTypeInfo(inputType).asInstanceOf[RowType],
      sortFields.head)
  }

  /** Creates the row type consisting of the given primary key fields of the input. */
  def createRowKeyType(
      primaryKeys: Array[Int],
      inputSchema: BaseRowSchema): BaseRowTypeInfo = {

    val fieldTypes = primaryKeys.map(inputSchema.fieldTypeInfos(_))
    new BaseRowTypeInfo(fieldTypes: _*)
  }

  /**
    * Creates a key selector that projects the given key fields of the input row; a
    * [[NullBinaryRowKeySelector]] is used when there are no keys.
    */
  def createKeySelector(
      keys: Array[Int],
      inputSchema: BaseRowSchema): KeySelector[BaseRow, BaseRow] = {
    if (keys.isEmpty) {
      new NullBinaryRowKeySelector
    } else {
      new BinaryRowKeySelector(keys, inputSchema.typeInfo())
    }
  }

  /**
    * Chooses the processing strategy for a streaming rank node:
    *
    * - append-only input: [[AppendFastRank]];
    * - updating input that is AccRetract, has no unique keys, or has no unique key that
    *   contains the partition key: [[RetractRank]];
    * - otherwise, if every sort field is decreasing under an ASCENDING sort, increasing under
    *   a DESCENDING sort, or constant: [[UpdateFastRank]] with the first unique key;
    * - otherwise, with exactly one sort field: [[UnaryUpdateRank]] with the first unique key;
    * - otherwise: [[RetractRank]].
    */
  def analyzeRankStrategy(
      cluster: RelOptCluster,
      tableConfig: TableConfig,
      rank: StreamExecRank,
      sortCollation: RelCollation): RankStrategy = {

    val rankInput = rank.getInput
    val fieldCollations = sortCollation.getFieldCollations

    val mono = cluster.getMetadataQuery.asInstanceOf[FlinkRelMetadataQuery]
      .getRelModifiedMonotonicity(rankInput)

    val isMonotonic = if (mono == null) {
      false
    } else {
      if (fieldCollations.isEmpty) {
        false
      } else {
        fieldCollations.forall(collation => {
          val fieldMono = mono.fieldMonotonicities(collation.getFieldIndex)
          val direction = collation.direction
          if ((fieldMono == SqlMonotonicity.DECREASING
            || fieldMono == SqlMonotonicity.STRICTLY_DECREASING)
            && direction == Direction.ASCENDING) {
            // sort field is ascending and its mono is decreasing when arriving rank node
            true
          } else if ((fieldMono == SqlMonotonicity.INCREASING
            || fieldMono == SqlMonotonicity.STRICTLY_INCREASING)
            && direction == Direction.DESCENDING) {
            // sort field is descending and its mono is increasing when arriving rank node
            true
          } else if (fieldMono == SqlMonotonicity.CONSTANT) {
            // sort key is a grouping key of upstream agg, it is monotonic
            true
          } else {
            false
          }
        })
      }
    }

    val isUpdateStream = !UpdatingPlanChecker.isAppendOnly(rankInput)
    val partitionKey = rank.partitionKey

    if (isUpdateStream) {
      val inputIsAccRetract = StreamExecRetractionRules.isAccRetract(rankInput)
      val uniqueKeys = cluster.getMetadataQuery.getUniqueKeys(rankInput)
      if (inputIsAccRetract || uniqueKeys == null || uniqueKeys.isEmpty
        // unique key should contains partition key
        || !uniqueKeys.exists(k => k.contains(partitionKey))) {
        // input is AccRetract or extract the unique keys failed,
        // and we fall back to using retract rank
        RetractRank
      } else {
        if (isMonotonic) {
          //FIXME choose a set of primary key
          UpdateFastRank(uniqueKeys.iterator().next().toArray)
        } else {
          if (fieldCollations.length == 1) {
            // single sort key in update stream scenario (no monotonic)
            // we can utilize unary rank function to speed up processing
            UnaryUpdateRank(uniqueKeys.iterator().next().toArray)
          } else {
            // no other choices, have to use retract rank
            RetractRank
          }
        }
      }
    } else {
      AppendFastRank
    }
  }

  /** Processing strategy of a streaming rank operator, chosen by [[analyzeRankStrategy]]. */
  sealed trait RankStrategy

  /** Strategy for append-only input. */
  case object AppendFastRank extends RankStrategy

  /** Fallback strategy for updating input. */
  case object RetractRank extends RankStrategy

  /** Strategy for updating input with monotonic sort fields, keyed by `primaryKeys`. */
  case class UpdateFastRank(primaryKeys: Array[Int]) extends RankStrategy {
    override def toString: String = {
      "UpdateFastRank" + primaryKeys.mkString("[", ",", "]")
    }
  }

  /** Strategy for updating input with a single sort field, keyed by `primaryKeys`. */
  case class UnaryUpdateRank(primaryKeys: Array[Int]) extends RankStrategy {
    override def toString: String = {
      "UnaryUpdateRank" + primaryKeys.mkString("[", ",", "]")
    }
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy