All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.flink.table.plan.util.FlinkRexUtil.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.flink.table.plan.util

import org.apache.calcite.plan.{RelOptPredicateList, RelOptUtil}
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rex._
import org.apache.calcite.sql.fun.SqlStdOperatorTable
import org.apache.calcite.sql.fun.SqlStdOperatorTable._
import org.apache.calcite.sql.{SqlKind, SqlOperator}
import org.apache.calcite.util.{ControlFlowException, Util}

import com.google.common.base.Function
import com.google.common.collect.{ImmutableList, Lists}

import java.lang.Iterable
import java.util

import scala.collection.JavaConversions._
import scala.collection.mutable

/**
  * Utility methods concerning [[RexNode]].
  */
object FlinkRexUtil {

  /**
    * Similar to [[RexUtil#toCnf(RexBuilder, Int, RexNode)]]; it lets you
    * specify a threshold in the number of nodes that can be created out of
    * the conversion. However, if the threshold is a negative number,
    * this method uses a default threshold that is double
    * the number of RexCalls in the given node.
    *
    * <p>If the number of resulting RexCalls exceeds that threshold,
    * stops conversion and returns the original expression.
    *
    * <p>Leaf nodes (e.g. RexInputRef) in the expression do not count towards the threshold.
    *
    * <p>We strongly discourage the use of [[RexUtil#toCnf(RexBuilder, RexNode)]] and
    * [[RexUtil#toCnf(RexBuilder, Int, RexNode)]], because there are many bad cases when using
    * [[RexUtil#toCnf(RexBuilder, RexNode)]], such as the predicate in TPC-DS q41.sql, which is
    * converted to an extremely complex expression (including 736450 RexCalls); and we can not
    * give an appropriate value for `maxCnfNodeCount` when using
    * [[RexUtil#toCnf(RexBuilder, Int, RexNode)]].
    *
    * @param rexBuilder      builder used to create the AND/OR/NOT calls of the result
    * @param maxCnfNodeCount maximum number of RexCalls allowed in the result; a negative value
    *                        selects the default (twice the number of RexCalls in `rex`)
    * @param rex             the expression to convert
    * @return the CNF form of `rex`, or `rex` itself if the threshold was exceeded
    */
  def toCnf(rexBuilder: RexBuilder, maxCnfNodeCount: Int, rex: RexNode): RexNode = {
    val maxCnfNodeCnt = if (maxCnfNodeCount < 0) {
      getNumberOfRexCall(rex) * 2
    } else {
      maxCnfNodeCount
    }
    new CnfHelper(rexBuilder, maxCnfNodeCnt).toCnf(rex)
  }

  /**
    * Get the number of RexCall in the given node.
    *
    * Only `visitCall` is counted, so leaf nodes (input refs, literals, ...) are ignored.
    */
  private def getNumberOfRexCall(rex: RexNode): Int = {
    // The visitor API is callback based, hence the local mutable counter.
    var numberOfNodes = 0
    rex.accept(new RexVisitorImpl[Unit](true) {
      override def visitCall(call: RexCall): Unit = {
        numberOfNodes += 1
        super.visitCall(call)
      }
    })
    numberOfNodes
  }

  /** Helps [[toCnf]] */
  private class CnfHelper(rexBuilder: RexBuilder, maxNodeCount: Int) {

    /** Exception to catch when we pass the limit. */
    @SuppressWarnings(Array("serial"))
    private class OverflowError extends ControlFlowException {
    }

    // Single shared instance that is thrown whenever the node limit is exceeded.
    @SuppressWarnings(Array("ThrowableInstanceNeverThrown"))
    private val INSTANCE = new OverflowError

    /** Wraps its input in a NOT call; used to push a negation down onto each operand. */
    private val ADD_NOT = new Function[RexNode, RexNode]() {
      override def apply(input: RexNode): RexNode =
        rexBuilder.makeCall(input.getType, SqlStdOperatorTable.NOT, ImmutableList.of(input))
    }

    /**
      * Converts `rex` to CNF. If the conversion exceeds `maxNodeCount` RexCalls,
      * the overflow is swallowed and the original expression is returned unchanged.
      */
    def toCnf(rex: RexNode): RexNode = try {
      toCnf2(rex)
    } catch {
      case e: OverflowError =>
        Util.swallow(e, null)
        rex
    }

    /** Recursive worker of [[toCnf]]; throws `OverflowError` once the limit is passed. */
    private def toCnf2(rex: RexNode): RexNode = {
      rex.getKind match {
        case SqlKind.AND =>
          // Convert every conjunct and flatten nested ANDs into a single conjunction.
          val cnfOperands: util.List[RexNode] = Lists.newArrayList()
          val operands = RexUtil.flattenAnd(rex.asInstanceOf[RexCall].operands)
          operands.foreach { node =>
            val cnf = toCnf2(node)
            cnf.getKind match {
              case SqlKind.AND =>
                cnfOperands.addAll(cnf.asInstanceOf[RexCall].operands)
              case _ =>
                cnfOperands.add(cnf)
            }
          }
          val node = and(cnfOperands)
          checkCnfRexCallCount(node)
          node
        case SqlKind.OR =>
          // Distribute OR over AND: (h1 AND h2) OR (t1 AND t2)
          //   -> (h1 OR t1) AND (h1 OR t2) AND (h2 OR t1) AND (h2 OR t2)
          val operands = RexUtil.flattenOr(rex.asInstanceOf[RexCall].operands)
          val head = operands.head
          val headCnf = toCnf2(head)
          val headCnfs: util.List[RexNode] = RelOptUtil.conjunctions(headCnf)
          val tail = or(Util.skip(operands))
          val tailCnf: RexNode = toCnf2(tail)
          val tailCnfs: util.List[RexNode] = RelOptUtil.conjunctions(tailCnf)
          val list: util.List[RexNode] = Lists.newArrayList()
          headCnfs.foreach { h =>
            tailCnfs.foreach { t =>
              list.add(or(ImmutableList.of(h, t)))
            }
          }
          val node = and(list)
          checkCnfRexCallCount(node)
          node
        case SqlKind.NOT =>
          val arg = rex.asInstanceOf[RexCall].operands.head
          arg.getKind match {
            case SqlKind.NOT =>
              // NOT(NOT(x)) -> x
              toCnf2(arg.asInstanceOf[RexCall].operands.head)
            case SqlKind.OR =>
              // De Morgan: NOT(a OR b) -> NOT(a) AND NOT(b)
              val operands = arg.asInstanceOf[RexCall].operands
              toCnf2(and(Lists.transform(RexUtil.flattenOr(operands), ADD_NOT)))
            case SqlKind.AND =>
              // De Morgan: NOT(a AND b) -> NOT(a) OR NOT(b)
              val operands = arg.asInstanceOf[RexCall].operands
              toCnf2(or(Lists.transform(RexUtil.flattenAnd(operands), ADD_NOT)))
            case _ => rex
          }
        case _ => rex
      }
    }

    /** Throws the shared `OverflowError` if `node` holds more RexCalls than allowed. */
    private def checkCnfRexCallCount(node: RexNode): Unit = {
      // TODO use more efficient solution to get number of RexCall in CNF node
      if (maxNodeCount >= 0 && getNumberOfRexCall(node) > maxNodeCount) {
        throw INSTANCE
      }
    }

    private def and(nodes: Iterable[_ <: RexNode]): RexNode =
      RexUtil.composeConjunction(rexBuilder, nodes, false)

    private def or(nodes: Iterable[_ <: RexNode]): RexNode =
      RexUtil.composeDisjunction(rexBuilder, nodes)
  }

  /**
    * Merges same expressions and then simplifies the result expression by [[RexSimplify]].
    *
    * Examples for merging same expressions:
    * 1. a = b AND b = a -> a = b
    * 2. a = b OR b = a -> a = b
    * 3. (a > b AND c < 10) AND b < a -> a > b AND c < 10
    * 4. (a > b OR c < 10) OR b < a -> a > b OR c < 10
    * 5. a = a, a >= a, a <= a -> true (only if a is NOT NULL)
    * 6. a <> a, a > a, a < a -> false (only if a is NOT NULL)
    *
    * @param rexBuilder builder used to create the rewritten expressions
    * @param expr       the expression to simplify
    * @return the simplified expression; `expr` itself if it is already always true/false
    */
  def simplify(rexBuilder: RexBuilder, expr: RexNode): RexNode = {
    if (expr.isAlwaysTrue || expr.isAlwaysFalse) {
      expr
    } else {
      // 1. normalize comparisons whose operands are merely swapped (a = b vs. b = a)
      val exprShuttle = new EquivalentExprShuttle(rexBuilder)
      val equiExpr = expr.accept(exprShuttle)
      // 2. replace duplicated operands of AND/OR with the neutral literal
      val exprMerger = new SameExprMerger(rexBuilder)
      val sameExprMerged = exprMerger.mergeSameExpr(equiExpr)
      // 3. fold comparisons of a field with itself
      val binaryComparisonExprReduced = sameExprMerged.accept(
        new BinaryComparisonExprReducer(rexBuilder))

      // 4. let Calcite clean up the literals introduced above
      val rexSimplify = new RexSimplify(rexBuilder, RelOptPredicateList.EMPTY, true, RexUtil.EXECUTOR)
      rexSimplify.simplify(binaryComparisonExprReduced)
    }
  }

  /**
    * Folds a binary comparison of an input field with itself into a boolean literal.
    *
    * The folding is only applied when the field is NOT NULL: for a nullable field `a`,
    * `a = a` (and the other comparisons) evaluates to UNKNOWN when `a` is NULL, so replacing
    * it with TRUE/FALSE would change the result (e.g. `WHERE a = a` must drop NULL rows and
    * `NOT (a <> a)` must not become TRUE). Nullable cases are left to [[RexSimplify]].
    */
  private class BinaryComparisonExprReducer(rexBuilder: RexBuilder) extends RexShuttle {
    override def visitCall(call: RexCall): RexNode = {
      val kind = call.getOperator.getKind
      if (!kind.belongsTo(SqlKind.BINARY_COMPARISON)) {
        super.visitCall(call)
      } else {
        val operand0 = call.getOperands.get(0)
        val operand1 = call.getOperands.get(1)
        (operand0, operand1) match {
          case (op0: RexInputRef, op1: RexInputRef)
            if op0.getIndex == op1.getIndex &&
              !op0.getType.isNullable && !op1.getType.isNullable =>
            kind match {
              case SqlKind.EQUALS | SqlKind.LESS_THAN_OR_EQUAL | SqlKind.GREATER_THAN_OR_EQUAL =>
                rexBuilder.makeLiteral(true)
              case SqlKind.NOT_EQUALS | SqlKind.LESS_THAN | SqlKind.GREATER_THAN =>
                rexBuilder.makeLiteral(false)
              case _ => super.visitCall(call)
            }
          case _ => super.visitCall(call)
        }
      }
    }
  }

  /**
    * Replaces repeated operands of AND with TRUE and repeated operands of OR with FALSE.
    * Two expressions are considered the same when their `toString` digests are equal.
    */
  private class SameExprMerger(rexBuilder: RexBuilder) extends RexShuttle {
    // digest -> first expression seen with that digest, within the current AND/OR scope
    private val sameExprMap = mutable.HashMap[String, RexNode]()

    /** Returns `equiExpr` if `expr` was already seen in the current scope, else `expr`. */
    private def mergeSameExpr(expr: RexNode, equiExpr: RexLiteral): RexNode = {
      if (sameExprMap.contains(expr.toString)) {
        equiExpr
      } else {
        sameExprMap.put(expr.toString, expr)
        expr
      }
    }

    /** Merges same expressions in `expr`, first per AND/OR call, then across the flattened tree. */
    def mergeSameExpr(expr: RexNode): RexNode = {
      // merges same expressions in the operands of AND and OR
      // e.g. a = b AND a = b -> a = b AND true
      //      a = b OR a = b -> a = b OR false
      val newExpr1 = expr.accept(this)

      // merges same expressions in conjunctions
      // e.g. (a > b AND c < 10) AND a > b -> a > b AND c < 10 AND true
      sameExprMap.clear()
      val newConjunctions = RelOptUtil.conjunctions(newExpr1).map {
        ex => mergeSameExpr(ex, rexBuilder.makeLiteral(true))
      }
      val newExpr2 = newConjunctions.size match {
        case 0 => newExpr1 // true AND true
        case 1 => newConjunctions.head
        case _ => rexBuilder.makeCall(AND, newConjunctions: _*)
      }

      // merges same expressions in disjunctions
      // e.g. (a > b OR c < 10) OR a > b -> a > b OR c < 10 OR false
      sameExprMap.clear()
      val newDisjunctions = RelOptUtil.disjunctions(newExpr2).map {
        ex => mergeSameExpr(ex, rexBuilder.makeLiteral(false))
      }
      val newExpr3 = newDisjunctions.size match {
        case 0 => newExpr2 // false OR false
        case 1 => newDisjunctions.head
        case _ => rexBuilder.makeCall(OR, newDisjunctions: _*)
      }
      newExpr3
    }

    override def visitCall(call: RexCall): RexNode = {
      val newCall = call.getOperator match {
        case AND | OR =>
          // Each AND/OR call is its own scope for duplicate detection.
          sameExprMap.clear()
          val newOperands = call.getOperands.map { op =>
            // TRUE is neutral for AND, FALSE is neutral for OR
            val value = if (call.getOperator == AND) true else false
            mergeSameExpr(op, rexBuilder.makeLiteral(value))
          }
          call.clone(call.getType, newOperands)
        case _ => call
      }
      super.visitCall(newCall)
    }
  }

  /**
    * Adjust the condition's field indices according to mapOldToNewIndex.
    *
    * @param c The condition to be adjusted.
    * @param fieldsOldToNewIndexMapping A map containing the mapping the old field indices to new
    *   field indices.
    * @param rowType The row type of the new output.
    * @return Return new condition with new field indices.
    */
  private[flink] def adjustInputRefs(
      c: RexNode,
      fieldsOldToNewIndexMapping: Map[Int, Int],
      rowType: RelDataType): RexNode = c.accept(
    new RexShuttle() {

      override def visitInputRef(inputRef: RexInputRef): RexNode = {
        assert(fieldsOldToNewIndexMapping.contains(inputRef.getIndex))
        val newIndex = fieldsOldToNewIndexMapping(inputRef.getIndex)
        val ref = RexInputRef.of(newIndex, rowType)
        if (ref.getIndex == inputRef.getIndex && (ref.getType eq inputRef.getType)) {
          // re-use old object, to prevent needless expr cloning
          inputRef
        } else {
          ref
        }
      }
    })

  /**
    * Normalizes binary comparisons so that a comparison and its operand-swapped twin
    * (e.g. `a > b` and `b < a`) end up with the same textual form.
    * Equivalence is decided on the `toString` digest of the calls.
    */
  private class EquivalentExprShuttle(rexBuilder: RexBuilder) extends RexShuttle {
    // digest -> comparison call already seen in its original operand order
    private val equiExprMap = mutable.HashMap[String, RexNode]()

    override def visitCall(call: RexCall): RexNode = {
      call.getOperator match {
        case EQUALS | NOT_EQUALS | GREATER_THAN | LESS_THAN |
             GREATER_THAN_OR_EQUAL | LESS_THAN_OR_EQUAL =>
          val swapped = swapOperands(call)
          if (equiExprMap.contains(swapped.toString)) {
            // the swapped form was seen before: use it so both occurrences look the same
            swapped
          } else {
            equiExprMap.put(call.toString, call)
            call
          }
        case _ => super.visitCall(call)
      }
    }

    /** Returns the equivalent comparison with operands swapped and the operator mirrored. */
    private def swapOperands(call: RexCall): RexCall = {
      val newOp = call.getOperator match {
        case EQUALS | NOT_EQUALS => call.getOperator
        case GREATER_THAN => LESS_THAN
        case GREATER_THAN_OR_EQUAL => LESS_THAN_OR_EQUAL
        case LESS_THAN => GREATER_THAN
        case LESS_THAN_OR_EQUAL => GREATER_THAN_OR_EQUAL
        case _ => throw new IllegalArgumentException(s"Unsupported operator: ${call.getOperator}")
      }
      val operands = call.getOperands
      rexBuilder.makeCall(newOp, operands.last, operands.head).asInstanceOf[RexCall]
    }
  }

  /**
    * Returns whether a given expression has dynamic function.
    *
    * @param e Expression
    * @return true if tree has dynamic function, false otherwise
    */
  def hasDynamicFunction(e: RexNode): Boolean = try {
    val visitor = new RexVisitorImpl[Void](true) {
      override def visitCall(call: RexCall): Void = {
        if (call.getOperator.isDynamicFunction) {
          // abort the traversal as soon as one dynamic function is found
          throw Util.FoundOne.NULL
        }
        super.visitCall(call)
      }
    }
    e.accept(visitor)
    false
  } catch {
    case ex: Util.FoundOne =>
      Util.swallow(ex, null)
      true
  }

  /**
    * Return true if the given RexNode is null or does not have
    * non-deterministic `SqlOperator` and dynamic function `SqlOperator`.
    */
  def isDeterministicOperator(rex: RexNode): Boolean = {
    if (rex == null) {
      true
    } else {
      RexUtil.isDeterministic(rex) && !FlinkRexUtil.hasDynamicFunction(rex)
    }
  }

  /**
    * Return true if the given operator is null or is deterministic and none dynamic function.
    */
  def isDeterministicOperator(op: SqlOperator): Boolean = {
    if (op == null) {
      true
    } else {
      op.isDeterministic && !op.isDynamicFunction
    }
  }

}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy