All Downloads are FREE. Search and download functionality uses the official Maven repository.

org.apache.flink.table.plan.util.UpdatingPlanChecker.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.flink.table.plan.util

import org.apache.calcite.plan.hep.HepRelVertex
import org.apache.calcite.plan.volcano.RelSubset
import org.apache.calcite.rel.{RelNode, RelVisitor}
import org.apache.calcite.rex.{RexCall, RexInputRef, RexNode}
import org.apache.calcite.sql.SqlKind
import org.apache.flink.table.expressions.ProctimeAttribute
import org.apache.flink.table.plan.nodes.physical.stream._

import _root_.scala.collection.JavaConversions._
import scala.collection.mutable

object UpdatingPlanChecker {

  /**
    * Validates that the plan produces only append changes.
    *
    * @param plan root of the physical plan to check
    * @return false if any [[StreamPhysicalRel]] reachable from `plan` reports
    *         `producesUpdates`, true otherwise
    */
  def isAppendOnly(plan: RelNode): Boolean = {
    val appendOnlyValidator = new AppendOnlyValidator
    appendOnlyValidator.go(plan)

    appendOnlyValidator.isAppendOnly
  }

  /**
    * Extracts the unique key fields of the table produced by the plan.
    *
    * @param plan root of the physical plan
    * @return the names of the unique key fields, or None if no unique key can be derived
    */
  def getUniqueKeyFields(plan: RelNode): Option[Array[String]] = {
    getUniqueKeyGroups(plan).map(_.map(_._1).toArray)
  }

  /**
    * Extracts the unique keys and their key groups of the table produced by the plan.
    *
    * @param plan root of the physical plan
    * @return (key field name, group name) pairs, where equivalent key fields share the same
    *         group name; None if no unique key can be derived
    */
  def getUniqueKeyGroups(plan: RelNode): Option[Seq[(String, String)]] = {
    val keyExtractor = new UniqueKeyExtractor
    keyExtractor.visit(plan)
  }

  /**
    * Walks a plan and sets `isAppendOnly` to false when it finds a [[StreamPhysicalRel]] that
    * produces updates. Planner wrapper nodes ([[HepRelVertex]], [[RelSubset]]) are unwrapped
    * before inspection.
    */
  private class AppendOnlyValidator extends RelVisitor {

    var isAppendOnly = true

    override def visit(node: RelNode, ordinal: Int, parent: RelNode): Unit = {
      node match {
        case s: StreamPhysicalRel if s.producesUpdates =>
          // result is decided; the inputs of this node are not visited
          isAppendOnly = false
        case hep: HepRelVertex =>
          visit(hep.getCurrentRel, ordinal, parent) // remove wrapper node
        case rs: RelSubset =>
          visit(rs.getOriginal, ordinal, parent) // remove wrapper node
        case _ =>
          // delegate to RelVisitor's default visit, which traverses the node's inputs
          super.visit(node, ordinal, parent)
      }
    }
  }

  /** Identifies unique key fields in the output of a RelNode. */
  private class UniqueKeyExtractor {

    // visit() returns a sequence of tuples: the first element is the name of a key field, the
    // second is a group name that is shared by all equivalent key fields. The group names are
    // used to identify same keys, for example: select('pk as pk1, 'pk as pk2), both pk1 and pk2
    // belong to the same group. A group name is itself one of the output field names and is
    // chosen deterministically (e.g. the lexicographically smallest window property, or the
    // smaller root when join predicates merge two groups). A node can have keys if it
    // generates the keys by itself or it forwards keys from its input(s). None means that no
    // unique key could be derived.
    def visit(node: RelNode): Option[Seq[(String, String)]] = {
      node match {
        case c: StreamExecCalc =>
          // a Calc can only forward keys that its input already has
          visit(node.getInput(0)).flatMap(inputKeys => keysOfCalc(c, inputKeys))

        case _: StreamExecOverAggregate =>
          // keys are always forwarded by Over aggregate
          visit(node.getInput(0))

        case a: StreamExecGroupAggregate =>
          // the grouping keys are the leading output fields and form the key
          val groupKeys = a.getRowType.getFieldNames.take(a.getGroupings.length)
          Some(groupKeys.map(e => (e, e)))

        case w: StreamExecGroupWindowAggregate =>
          // get grouping keys
          val groupKeys =
            w.getRowType.getFieldNames.take(w.getGroupings.length).toArray
          // proctime is not a valid key
          val windowProperties = w.getWindowProperties
            .filterNot(_.property.isInstanceOf[ProctimeAttribute])
            .map(_.name)
          // we have only a unique key if at least one window property is selected
          if (windowProperties.nonEmpty) {
            val windowId = windowProperties.min
            Some(groupKeys.map(e => (e, e)) ++ windowProperties.map(e => (e, windowId)))
          } else {
            None
          }

        case j: StreamExecJoin =>
          // output of join has keys only if left and right both contain key(s)
          (visit(j.getLeft), visit(j.getRight)) match {
            case (Some(lInKeys), Some(rInKeys)) => keysOfJoin(j, lInKeys, rInKeys)
            case _ => None
          }

        case hep: HepRelVertex =>
          visit(hep.getCurrentRel) // remove wrapper node

        case rs: RelSubset =>
          visit(rs.getOriginal) // remove wrapper node

        case _ =>
          // anything else (including null or any other kind of rel) does not forward keys,
          // so we can stop. Previously only StreamPhysicalRel was matched here and any other
          // node raised a MatchError.
          None
      }
    }

    /**
      * Tracks the keys of a Calc's input through the Calc's projection.
      *
      * @param calc      the Calc node
      * @param inputKeys (key field name, group name) pairs of the Calc's input
      * @return the (possibly renamed) keys if every input key group is still represented in
      *         the output, None otherwise
      */
    private def keysOfCalc(
        calc: StreamExecCalc,
        inputKeys: Seq[(String, String)]): Option[Seq[(String, String)]] = {
      val program = calc.getProgram
      val inNames = calc.getInput.getRowType.getFieldNames

      // (input field name, output field name) for every output field that forwards an input
      val inOutNames = program.getNamedProjects
        .map { p =>
          val inIndex = program.expandLocalRef(p.left) match {
            // output field is forwarded input field
            case i: RexInputRef => i.getIndex
            // output field is renamed input field
            case a: RexCall if a.getKind == SqlKind.AS =>
              a.getOperands.get(0) match {
                case ref: RexInputRef => ref.getIndex
                case _ => -1
              }
            // output field is not forwarded from input
            case _: RexNode => -1
          }
          (inIndex, p.right)
        }
        // filter all non-forwarded fields
        .filter(_._1 >= 0)
        // resolve names of input fields
        .map { case (inIndex, outName) => (inNames.get(inIndex), outName) }

      // keep only forwarded fields that are input key fields
      val inputKeyNames = inputKeys.map(_._1).toSet
      val inputKeysAndOutput = inOutNames.filter(io => inputKeyNames.contains(io._1))

      // input key field name -> input group name
      val inputKeysMap = inputKeys.toMap
      // input group name -> output group name. `toMap` keeps the last entry per group, which
      // after `sorted.reverse` is the smallest (input name, output name) pair of that group.
      val inOutGroups = inputKeysAndOutput.sorted.reverse
        .map(e => (inputKeysMap(e._1), e._2))
        .toMap

      // get output keys
      val outputKeys = inputKeysAndOutput
        .map(io => (io._2, inOutGroups(inputKeysMap(io._1))))

      // all key groups must be preserved (but possibly renamed); if some group was projected
      // away, the remaining fields are no longer unique
      if (outputKeys.map(_._2).distinct.length == inputKeys.map(_._2).distinct.length) {
        Some(outputKeys)
      } else {
        None
      }
    }

    /**
      * Derives the output keys of a non-window join from the keys of both of its inputs.
      *
      * If right field names equal left field names, calcite renames the right field names.
      * For example, T1(pk, a) join T2(pk, b) renames T2(pk, b) to T2(pk0, b). All right-side
      * names are therefore translated to their names in the join output first.
      *
      * @param join    the join node
      * @param lInKeys (key field name, group name) pairs of the left input
      * @param rInKeys (key field name, group name) pairs of the right input
      * @return output keys of the join
      */
    private def keysOfJoin(
        join: StreamExecJoin,
        lInKeys: Seq[(String, String)],
        rInKeys: Seq[(String, String)]): Option[Seq[(String, String)]] = {
      val lInNames: Seq[String] = join.getLeft.getRowType.getFieldNames
      val rInNames: Seq[String] = join.getRight.getRowType.getFieldNames
      val joinNames = join.getRowType.getFieldNames

      // right input field name -> field name in the join output
      val rInNamesToJoinNamesMap = rInNames
        .zip(joinNames.subList(lInNames.size, joinNames.length))
        .toMap

      // equi-join key names, expressed with the field names of the join output
      val lJoinKeys: Seq[String] = join.joinInfo.leftKeys
        .map(i => lInNames(i))
      val rJoinKeys: Seq[String] = join.joinInfo.rightKeys
        .map(i => rInNames(i))
        .map(rInNamesToJoinNamesMap(_))

      val rKeysInJoinNames: Seq[(String, String)] = rInKeys
        .map(e => (rInNamesToJoinNamesMap(e._1), rInNamesToJoinNamesMap(e._2)))
      val inKeys: Seq[(String, String)] = lInKeys ++ rKeysInJoinNames

      // key groups from both sides are merged by the join equi-predicates
      getOutputKeysForNonWindowJoin(
        joinNames,
        inKeys,
        lJoinKeys.zip(rJoinKeys)
      )
    }

    /**
      * Get output keys for non-window join according to its inputs.
      *
      * Groups are tracked with a union-find over field names. Each equi-join predicate
      * (l, r) merges the groups of `l` and `r`. The lexicographically smaller group name
      * becomes the name of the merged group.
      *
      * @param inNames  Field names of join
      * @param inKeys   Input keys of join, as (key field name, group name) pairs expressed
      *                 with the field names of the join output
      * @param joinKeys JoinKeys of join, as (left field name, right field name) pairs
      * @return all join fields whose group contains an input key, each paired with its
      *         (merged) group name
      */
    def getOutputKeysForNonWindowJoin(
      inNames: Seq[String],
      inKeys: Seq[(String, String)],
      joinKeys: Seq[(String, String)])
    : Option[Seq[(String, String)]] = {

      // field name -> parent in the union-find forest; a group root maps to itself
      val nameToGroups = mutable.HashMap.empty[String, String]

      // merge two groups; the lexicographically smaller root becomes the merged root
      def merge(nameA: String, nameB: String): Unit = {
        val ga = findGroup(nameA)
        val gb = findGroup(nameB)
        if (ga != gb) {
          if (ga < gb) {
            nameToGroups(gb) = ga
          } else {
            nameToGroups(ga) = gb
          }
        }
      }

      // find the group (root) of x, then point every name on the path directly at it
      def findGroup(x: String): String = {
        var root = x
        while (nameToGroups(root) != root) {
          root = nameToGroups(root)
        }

        var cur = x
        while (nameToGroups(cur) != root) {
          val next = nameToGroups(cur)
          nameToGroups(cur) = root
          cur = next
        }
        root
      }

      // init groups: every field is its own group, then attach each key field to its group
      inNames.foreach { e => nameToGroups(e) = e }
      inKeys.foreach { case (key, group) => nameToGroups(key) = group }
      // merge groups connected by equi-join predicates
      joinKeys.foreach { case (l, r) => merge(l, r) }
      // make sure all names point to their group name directly
      inNames.foreach(findGroup)

      val outputGroups = inKeys.map(e => nameToGroups(e._1)).toSet
      Some(
        inNames
          .filter(e => outputGroups.contains(nameToGroups(e)))
          .map(e => (e, nameToGroups(e)))
      )
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy