All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.flink.table.planner.plan.utils.FlinkRelOptUtil.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.flink.table.planner.plan.utils

import org.apache.flink.table.api.TableConfig
import org.apache.flink.table.planner.JBoolean
import org.apache.flink.table.planner.calcite.{FlinkContext, FlinkPlannerImpl, FlinkTypeFactory}
import org.apache.flink.table.planner.plan.`trait`.{MiniBatchInterval, MiniBatchMode}

import org.apache.calcite.config.NullCollation
import org.apache.calcite.plan.RelOptUtil
import org.apache.calcite.rel.RelFieldCollation.{Direction, NullDirection}
import org.apache.calcite.rel.{RelFieldCollation, RelNode}
import org.apache.calcite.rex.{RexBuilder, RexCall, RexInputRef, RexLiteral, RexNode, RexUtil, RexVisitorImpl}
import org.apache.calcite.sql.SqlExplainLevel
import org.apache.calcite.sql.SqlKind._
import org.apache.calcite.sql.`type`.SqlTypeName._
import org.apache.commons.math3.util.ArithmeticUtils

import java.io.{PrintWriter, StringWriter}
import java.math.BigDecimal
import java.sql.{Date, Time, Timestamp}
import java.util.Calendar

import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
import scala.collection.mutable

/**
  * FlinkRelOptUtil provides utility methods for use in optimizing RelNodes.
  */
object FlinkRelOptUtil {

  /**
    * Converts a relational expression to a string.
    * This is different from [[RelOptUtil]]#toString on two points:
    * 1. Generated string by this method is in a tree style
    * 2. Generated string by this method may have more information about RelNode, such as
    * RelNode id, changelog traits and output row type.
    *
    * @param rel                the RelNode to convert
    * @param detailLevel        detailLevel defines detail levels for EXPLAIN PLAN.
    * @param withIdPrefix       whether including ID of RelNode as prefix
    * @param withChangelogTraits  whether including changelog traits of RelNode (only applied to
    *                             StreamPhysicalRel node at present)
    * @param withRowType        whether including output rowType
    * @return explain plan of RelNode
    */
  def toString(
      rel: RelNode,
      detailLevel: SqlExplainLevel = SqlExplainLevel.EXPPLAN_ATTRIBUTES,
      withIdPrefix: Boolean = false,
      withChangelogTraits: Boolean = false,
      withRowType: Boolean = false): String = {
    if (rel != null) {
      // Render the plan into an in-memory buffer using the tree-style writer.
      val buffer = new StringWriter
      val treeWriter = new RelTreeWriterImpl(
        new PrintWriter(buffer),
        detailLevel,
        withIdPrefix,
        withChangelogTraits,
        withRowType,
        withTreeStyle = true)
      rel.explain(treeWriter)
      buffer.toString
    } else {
      // A null RelNode yields a null string rather than an exception.
      null
    }
  }

  /**
    * Gets the digest for a rel tree.
    *
    * The digest of RelNode should contain the result of RelNode#explain method, changelog traits
    * (for StreamPhysicalRel) and RelNode's row type.
    *
    * Row type is part of the digest for the rare occasion that similar
    * expressions have different types, e.g.
    * "WITH
    * t1 AS (SELECT CAST(a as BIGINT) AS a, SUM(b) AS b FROM x GROUP BY CAST(a as BIGINT)),
    * t2 AS (SELECT CAST(a as DOUBLE) AS a, SUM(b) AS b FROM x GROUP BY CAST(a as DOUBLE))
    * SELECT t1.*, t2.* FROM t1, t2 WHERE t1.b = t2.b"
    *
    * the physical plan is:
    * {{{
    *  HashJoin(where=[=(b, b0)], join=[a, b, a0, b0], joinType=[InnerJoin],
    *    isBroadcast=[true], build=[right])
    *  :- HashAggregate(groupBy=[a], select=[a, Final_SUM(sum$0) AS b])
    *  :  +- Exchange(distribution=[hash[a]])
    *  :     +- LocalHashAggregate(groupBy=[a], select=[a, Partial_SUM(b) AS sum$0])
    *  :        +- Calc(select=[CAST(a) AS a, b])
    *  :           +- ScanTable(table=[[builtin, default, x]], fields=[a, b, c])
    *  +- Exchange(distribution=[broadcast])
    *     +- HashAggregate(groupBy=[a], select=[a, Final_SUM(sum$0) AS b])
    *        +- Exchange(distribution=[hash[a]])
    *           +- LocalHashAggregate(groupBy=[a], select=[a, Partial_SUM(b) AS sum$0])
    *              +- Calc(select=[CAST(a) AS a, b])
    *                 +- ScanTable(table=[[builtin, default, x]], fields=[a, b, c])
    * }}}
    *
    * The sub-plans of `HashAggregate(groupBy=[a], select=[a, Final_SUM(sum$0) AS b])`
    * are different because `CAST(a) AS a` has different types, where one is BIGINT type
    * and the other is DOUBLE type.
    *
    * If the result of `RelOptUtil.toString(aggregate, SqlExplainLevel.DIGEST_ATTRIBUTES)`
    * on `HashAggregate(groupBy=[a], select=[a, Final_SUM(sum$0) AS b])` were used as the digest,
    * we would get an incorrect result. So the `explain_` method of `RelWriterImpl` is rewritten
    * to add the row type to the digest value.
    *
    * @param rel rel node tree
    * @return The digest of given rel tree.
    */
  def getDigest(rel: RelNode): String = {
    val digestBuffer = new StringWriter
    val digestWriter = new RelTreeWriterImpl(
      new PrintWriter(digestBuffer),
      explainLevel = SqlExplainLevel.DIGEST_ATTRIBUTES,
      // the node id is not part of the digest, only the RelNode's attributes are
      withIdPrefix = false,
      // changelog traits are part of the digest for StreamPhysicalRel nodes
      withChangelogTraits = true,
      // the row type is part of the digest so that similar expressions
      // with different types do not produce the same digest
      withRowType = true,
      // no tree decoration, only the RelNode's attributes
      withTreeStyle = false)
    rel.explain(digestWriter)
    digestBuffer.toString
  }

  /**
    * Returns the null direction if not specified.
    *
    * @param direction Direction that a field is ordered in.
    * @return default null direction
    */
  def defaultNullDirection(direction: Direction): NullDirection = {
    // Picks the null direction for the given sort direction; any direction that is neither
    // (strictly) ascending nor (strictly) descending maps to UNSPECIFIED.
    def pick(whenAscending: NullDirection, whenDescending: NullDirection): NullDirection = {
      direction match {
        case Direction.ASCENDING | Direction.STRICTLY_ASCENDING => whenAscending
        case Direction.DESCENDING | Direction.STRICTLY_DESCENDING => whenDescending
        case _ => NullDirection.UNSPECIFIED
      }
    }

    FlinkPlannerImpl.defaultNullCollation match {
      // FIRST / LAST do not depend on the sort direction
      case NullCollation.FIRST => NullDirection.FIRST
      case NullCollation.LAST => NullDirection.LAST
      // LOW: nulls first when ascending, last when descending
      case NullCollation.LOW => pick(NullDirection.FIRST, NullDirection.LAST)
      // HIGH: nulls last when ascending, first when descending
      case NullCollation.HIGH => pick(NullDirection.LAST, NullDirection.FIRST)
    }
  }

  /**
    * Creates a field collation with default direction.
    *
    * @param fieldIndex 0-based index of field being sorted
    * @return the field collation with default direction and given field index.
    */
  def ofRelFieldCollation(fieldIndex: Int): RelFieldCollation = {
    // Both the direction and the derived null direction come from the planner defaults.
    val sortDirection = FlinkPlannerImpl.defaultCollationDirection
    val nullsDirection = defaultNullDirection(sortDirection)
    new RelFieldCollation(fieldIndex, sortDirection, nullsDirection)
  }

  /**
    * Creates a field collation.
    *
    * @param fieldIndex    0-based index of field being sorted
    * @param direction     Direction of sorting
    * @param nullDirection Direction of sorting of nulls
    * @return the field collation.
    */
  def ofRelFieldCollation(
      fieldIndex: Int,
      direction: RelFieldCollation.Direction,
      nullDirection: RelFieldCollation.NullDirection): RelFieldCollation =
    // Plain factory: all three arguments are forwarded unchanged to the constructor.
    new RelFieldCollation(fieldIndex, direction, nullDirection)

  /**
    * Gets the [[TableConfig]] from the planner context of the given rel node.
    *
    * @param rel the rel node whose cluster's planner context is unwrapped as [[FlinkContext]]
    * @return the table config held by the unwrapped [[FlinkContext]]
    */
  def getTableConfigFromContext(rel: RelNode): TableConfig = {
    val plannerContext = rel.getCluster.getPlanner.getContext
    val flinkContext = plannerContext.unwrap(classOf[FlinkContext])
    flinkContext.getTableConfig
  }

  /** Get max cnf node limit by context of rel */
  def getMaxCnfNodeCount(rel: RelNode): Int = {
    // The limit is read from the configuration of the table config found in rel's context.
    val configuration = getTableConfigFromContext(rel).getConfiguration
    configuration.getInteger(FlinkRexUtil.TABLE_OPTIMIZER_CNF_NODES_LIMIT)
  }

  /**
    * Gets values of [[RexLiteral]] by its broad type.
    *
    * All number values (TINYINT, SMALLINT, INTEGER, BIGINT, FLOAT, DOUBLE, DECIMAL)
    * will be converted to BigDecimal.
    *
    * @param literal input RexLiteral
    * @return value of the input RexLiteral, or null if the literal is null
    * @throws IllegalArgumentException if the literal's type is not one of the supported types
    */
  def getLiteralValueByBroadType(literal: RexLiteral): Comparable[_] = {
    if (literal.isNull) {
      null
    } else {
      literal.getTypeName match {
        case BOOLEAN => RexLiteral.booleanValue(literal)
        case TINYINT | SMALLINT | INTEGER | BIGINT | FLOAT | DOUBLE | DECIMAL =>
          literal.getValue3.asInstanceOf[BigDecimal]
        case VARCHAR | CHAR => literal.getValueAs(classOf[String])

        // temporal types
        case DATE =>
          new Date(literal.getValueAs(classOf[Calendar]).getTimeInMillis)
        case TIME =>
          new Time(literal.getValueAs(classOf[Calendar]).getTimeInMillis)
        case TIMESTAMP =>
          new Timestamp(literal.getValueAs(classOf[Calendar]).getTimeInMillis)
        case _ =>
          throw new IllegalArgumentException(
            s"Literal type ${literal.getTypeName} is not supported!")
      }
    }
  }

  /**
    * Partitions the [[RexNode]] in two [[RexNode]] according to a predicate.
    * The result is a pair of RexNode: the first RexNode consists of RexNode that satisfy the
    * predicate and the second RexNode consists of RexNode that don't.
    *
    * For a simple condition which is not AND, OR, NOT, it either completely satisfies the
    * predicate or not.
    *
    * For a complex condition AND, partition each operand of the AND recursively, then
    * merge the RexNodes which satisfy the predicate as the first part, and merge the rest as the
    * second part.
    *
    * For a complex condition OR, try to pull up common factors among the ORs first. If the
    * result is not an OR, then the question is simplified to partitioning the factored
    * expression; else the input condition completely satisfies the predicate or not based on
    * whether all its operands satisfy the predicate or not.
    *
    * For a complex condition NOT, it completely satisfies the predicate or not based on whether
    * its operand satisfies the predicate or not.
    *
    * A part that is always true is reported as None.
    *
    * @param expr            the expression to partition
    * @param rexBuilder      rexBuilder
    * @param predicate       the specified predicate on which to partition
    * @return a pair of RexNode: the first RexNode consists of RexNode that satisfy the predicate
    *         and the second RexNode consists of RexNode that don't
    */
  def partition(
      expr: RexNode,
      rexBuilder: RexBuilder,
      predicate: RexNode => JBoolean): (Option[RexNode], Option[RexNode]) = {
    // normalize first so that NOT only appears directly above leaf expressions
    val condition = pushNotToLeaf(expr, rexBuilder)
    val (left: Option[RexNode], right: Option[RexNode]) = condition.getKind match {
      case AND =>
        val (leftExprs, rightExprs) = partition(
          condition.asInstanceOf[RexCall].operands, rexBuilder, predicate)
        if (leftExprs.isEmpty) {
          (None, Option(condition))
        } else {
          val l = RexUtil.composeConjunction(rexBuilder, leftExprs.asJava, false)
          if (rightExprs.isEmpty) {
            (Option(l), None)
          } else {
            val r = RexUtil.composeConjunction(rexBuilder, rightExprs.asJava, false)
            (Option(l), Option(r))
          }
        }
      case OR =>
        val e = RexUtil.pullFactors(rexBuilder, condition)
        e.getKind match {
          case OR =>
            // no common factor could be pulled up: the OR satisfies the predicate
            // only if every operand satisfies it completely
            val (leftExprs, rightExprs) = partition(
              condition.asInstanceOf[RexCall].operands, rexBuilder, predicate)
            if (leftExprs.isEmpty || rightExprs.nonEmpty) {
              (None, Option(condition))
            } else {
              val l = RexUtil.composeDisjunction(rexBuilder, leftExprs.asJava, false)
              (Option(l), None)
            }
          case _ =>
            partition(e, rexBuilder, predicate)
        }
      case NOT =>
        val operand = condition.asInstanceOf[RexCall].operands.head
        partition(operand, rexBuilder, predicate) match {
          case (Some(_), None) => (Option(condition), None)
          case (_, _) => (None, Option(condition))
        }
      case IS_TRUE =>
        val operand = condition.asInstanceOf[RexCall].operands.head
        partition(operand, rexBuilder, predicate)
      case IS_FALSE =>
        // IS_FALSE(x) is partitioned as the negation of x
        val operand = condition.asInstanceOf[RexCall].operands.head
        val newCondition = pushNotToLeaf(operand, rexBuilder, needReverse = true)
        partition(newCondition, rexBuilder, predicate)
      case _ =>
        if (predicate(condition)) {
          (Option(condition), None)
        } else {
          (None, Option(condition))
        }
    }
    (convertRexNodeIfAlwaysTrue(left), convertRexNodeIfAlwaysTrue(right))
  }

  /**
    * Partitions every expression with the single-expression `partition` and collects the
    * satisfying parts and the remaining parts into two separate collections.
    */
  private def partition(
      exprs: Iterable[RexNode],
      rexBuilder: RexBuilder,
      predicate: RexNode => JBoolean): (Iterable[RexNode], Iterable[RexNode]) = {
    val leftExprs = mutable.ListBuffer[RexNode]()
    val rightExprs = mutable.ListBuffer[RexNode]()
    // NOTE(review): a (None, None) result (both parts reduced to always-true) is not matched
    // below and would raise a MatchError; presumably unreachable for the operands passed in
    // here -- confirm.
    exprs.foreach(expr => partition(expr, rexBuilder, predicate) match {
      case (Some(first), Some(second)) =>
        leftExprs += first
        rightExprs += second
      case (None, Some(rest)) =>
        rightExprs += rest
      case (Some(interested), None) =>
        leftExprs += interested
    })
    (leftExprs, rightExprs)
  }

  /** Maps an always-true expression to None and leaves anything else unchanged. */
  private def convertRexNodeIfAlwaysTrue(expr: Option[RexNode]): Option[RexNode] = {
    expr match {
      case Some(rex) if rex.isAlwaysTrue => None
      case _ => expr
    }
  }

  /**
    * Pushes NOT down to the leaves of the expression.
    *
    * @param expr        the expression to rewrite
    * @param rexBuilder  rexBuilder
    * @param needReverse whether the expression currently being visited has to be negated
    * @return the rewritten expression
    */
  private def pushNotToLeaf(
      expr: RexNode,
      rexBuilder: RexBuilder,
      needReverse: Boolean = false): RexNode = (expr.getKind, needReverse) match {
    // a negated AND becomes an OR of negated operands; a plain OR stays an OR
    case (AND, true) | (OR, false) =>
      val convertedExprs = expr.asInstanceOf[RexCall].operands
        .map(pushNotToLeaf(_, rexBuilder, needReverse))
      RexUtil.composeDisjunction(rexBuilder, convertedExprs, false)
    // a plain AND stays an AND; a negated OR becomes an AND of negated operands
    case (AND, false) | (OR, true) =>
      val convertedExprs = expr.asInstanceOf[RexCall].operands
        .map(pushNotToLeaf(_, rexBuilder, needReverse))
      RexUtil.composeConjunction(rexBuilder, convertedExprs, false)
    // NOT flips the pending negation and is itself dropped
    case (NOT, _) =>
      val child = expr.asInstanceOf[RexCall].operands.head
      pushNotToLeaf(child, rexBuilder, !needReverse)
    case (_, true) if expr.isInstanceOf[RexCall] =>
      // prefer the negated form of the call, fall back to wrapping it in NOT
      val negatedExpr = RexUtil.negate(rexBuilder, expr.asInstanceOf[RexCall])
      if (negatedExpr != null) negatedExpr else RexUtil.not(expr)
    case (_, true) => RexUtil.not(expr)
    case (_, false) => expr
  }

  /**
    * An RexVisitor to judge whether the RexNode is related to the specified index InputRef
    */
  class ColumnRelatedVisitor(index: Int) extends RexVisitorImpl[JBoolean](true) {

    override def visitInputRef(inputRef: RexInputRef): JBoolean = inputRef.getIndex == index

    override def visitLiteral(literal: RexLiteral): JBoolean = true

    override def visitCall(call: RexCall): JBoolean = {
      // a call is related only if every operand is; a null visit result counts as unrelated
      call.operands.forall(operand => {
        val isRelated = operand.accept(this)
        isRelated != null && isRelated
      })
    }
  }

  /**
    * An RexVisitor to find whether this is a call on a time indicator field.
    */
  class TimeIndicatorExprFinder extends RexVisitorImpl[Boolean](true) {
    override def visitInputRef(inputRef: RexInputRef): Boolean = {
      FlinkTypeFactory.isTimeIndicatorType(inputRef.getType)
    }
  }

  /**
    * Merge two MiniBatchInterval as a new one.
    *
    * The Merge Logic:  MiniBatchMode: (R: rowtime, P: proctime, N: None), I: Interval
    * Possible values:
    * - (R, I = 0): operators that require watermark (window excluded).
    * - (R, I > 0): window / operators that require watermark with minibatch enabled.
    * - (R, I = -1): existing window aggregate
    * - (P, I > 0): unbounded agg with minibatch enabled.
    * - (N, I = 0): no operator requires watermark, minibatch disabled
    * ------------------------------------------------
    * |    A        |    B        |   merged result
    * ------------------------------------------------
    * | R, I_a == 0 | R, I_b      |  R, gcd(I_a, I_b)
    * ------------------------------------------------
    * | R, I_a == 0 | P, I_b      |  R, I_b
    * ------------------------------------------------
    * | R, I_a > 0  | R, I_b      |  R, gcd(I_a, I_b)
    * ------------------------------------------------
    * | R, I_a > 0  | P, I_b      |  R, I_a
    * ------------------------------------------------
    * | R, I_a = -1 | R, I_b      |  R, I_a
    * ------------------------------------------------
    * | R, I_a = -1 | P, I_b      |  R, I_a
    * ------------------------------------------------
    * | P, I_a      | R, I_b == 0 |  R, I_a
    * ------------------------------------------------
    * | P, I_a      | R, I_b > 0  |  R, I_b
    * ------------------------------------------------
    * | P, I_a      | P, I_b > 0  |  P, I_a
    * ------------------------------------------------
    *
    * If either input equals [[MiniBatchInterval.NO_MINIBATCH]], the result is
    * [[MiniBatchInterval.NO_MINIBATCH]].
    *
    * @param interval1 the first interval (column A of the table above)
    * @param interval2 the second interval (column B of the table above)
    * @return the merged interval
    */
  def mergeMiniBatchInterval(
      interval1: MiniBatchInterval,
      interval2: MiniBatchInterval): MiniBatchInterval = {
    if (interval1 == MiniBatchInterval.NO_MINIBATCH ||
      interval2 == MiniBatchInterval.NO_MINIBATCH) {
      return MiniBatchInterval.NO_MINIBATCH
    }
    interval1.getMode match {
      case MiniBatchMode.None => interval2
      case MiniBatchMode.RowTime =>
        interval2.getMode match {
          case MiniBatchMode.None => interval1
          case MiniBatchMode.RowTime =>
            // NOTE(review): the gcd is taken for every RowTime/RowTime pair, including the
            // I_a = -1 row of the table above -- confirm that the table and code agree.
            val gcd = ArithmeticUtils.gcd(interval1.getInterval, interval2.getInterval)
            new MiniBatchInterval(gcd, MiniBatchMode.RowTime)
          case MiniBatchMode.ProcTime =>
            if (interval1.getInterval == 0) {
              new MiniBatchInterval(interval2.getInterval, MiniBatchMode.RowTime)
            } else {
              interval1
            }
        }
      case MiniBatchMode.ProcTime =>
        interval2.getMode match {
          case MiniBatchMode.None | MiniBatchMode.ProcTime => interval1
          case MiniBatchMode.RowTime =>
            if (interval2.getInterval > 0) {
              interval2
            } else {
              new MiniBatchInterval(interval1.getInterval, MiniBatchMode.RowTime)
            }
        }
    }
  }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy