// org.apache.flink.table.plan.metadata.FlinkRelMdUniqueGroups.scala Maven / Gradle / Ivy
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.plan.metadata
import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
import org.apache.flink.table.plan.metadata.FlinkMetadata.UniqueGroups
import org.apache.flink.table.plan.nodes.calcite.{Expand, LogicalWindowAggregate, Rank}
import org.apache.flink.table.plan.nodes.logical.FlinkLogicalWindowAggregate
import org.apache.flink.table.plan.nodes.physical.batch._
import org.apache.flink.table.plan.util.FlinkRelMdUtil
import org.apache.flink.table.plan.util.FlinkRelMdUtil.splitColumnsIntoLeftAndRight
import org.apache.flink.table.plan.util.FlinkRelOptUtil.checkAndSplitAggCalls
import org.apache.calcite.plan.volcano.RelSubset
import org.apache.calcite.rel.core._
import org.apache.calcite.rel.metadata._
import org.apache.calcite.rel.{RelNode, SingleRel}
import org.apache.calcite.rex._
import org.apache.calcite.sql.SqlKind
import org.apache.calcite.util.{Bug, ImmutableBitSet, Util}
import java.util
import scala.collection.JavaConversions._
import scala.collection.mutable
class FlinkRelMdUniqueGroups private extends MetadataHandler[UniqueGroups] {
/** Returns the metadata definition served by this handler, `FlinkMetadata.UniqueGroups.DEF`. */
override def getDef: MetadataDef[UniqueGroups] = {
  FlinkMetadata.UniqueGroups.DEF
}
/**
 * Unique groups of a table scan: the smallest unique key of the scan that is fully contained
 * in `columns`, or `columns` itself when the scan reports no unique key or none is contained.
 *
 * @param ts      the table scan
 * @param mq      metadata query used to fetch the unique keys of `ts`
 * @param columns requested columns; each index must be below the scan's field count
 * @return the first unique key with the lowest cardinality that `columns` contains,
 *         otherwise `columns`
 */
def getUniqueGroups(
    ts: TableScan,
    mq: RelMetadataQuery,
    columns: ImmutableBitSet): ImmutableBitSet = {
  val uniqueKeys = mq.getUniqueKeys(ts)
  if (uniqueKeys == null || uniqueKeys.isEmpty) {
    // no unique key is known, so the requested columns are returned untouched
    columns
  } else {
    require(columns.forall(_ < ts.getRowType.getFieldCount))
    // keep the iteration order of the key set so that ties resolve to the first key seen
    val coveredKeys = uniqueKeys.toSeq.filter(key => columns.contains(key))
    coveredKeys
      .reduceLeftOption { (smallest, key) =>
        // only a strictly smaller key replaces the current candidate
        if (smallest.cardinality() > key.cardinality()) key else smallest
      }
      .getOrElse(columns)
  }
}
/**
 * Unique groups of an [[Expand]].
 *
 * The expand-id column is set aside. Every other requested column for which
 * `FlinkRelMdUtil.getInputRefIndices` reports exactly one non-negative index is mapped to that
 * input column; the input's unique groups over those columns are translated back to output
 * indices and united with all requested columns that were not mapped.
 *
 * @param expand  the expand node
 * @param mq      metadata query, reused as a [[FlinkRelMetadataQuery]]
 * @param columns requested output columns
 * @return the reduced column set, or `columns` when nothing could be mapped to the input
 * @throws IllegalArgumentException if the input reports an index that was not requested
 */
def getUniqueGroups(
    expand: Expand,
    mq: RelMetadataQuery,
    columns: ImmutableBitSet): ImmutableBitSet = {
  val columnList = columns.toList
  val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
  val columnsSkipExpandId = columnList.filter(_ != expand.expandIdIndex)
  if (columnsSkipExpandId.isEmpty) {
    return columns
  }
  // mapping input column index to output index for non-null value columns
  val mapInputToOutput = new util.HashMap[Int, Int]()
  columnsSkipExpandId.foreach { column =>
    val inputRefs = FlinkRelMdUtil.getInputRefIndices(column, expand)
    if (inputRefs.size() == 1 && inputRefs.head >= 0) {
      mapInputToOutput.put(inputRefs.head, column)
    }
  }
  if (mapInputToOutput.isEmpty) {
    return columns
  }
  // requested columns that are not the image of a mapped input column are kept as they are
  val leftColumns = columnList.filterNot(mapInputToOutput.values().contains)
  val inputUniqueGroups = fmq.getUniqueGroups(
    expand.getInput, ImmutableBitSet.of(mapInputToOutput.keys.toSeq: _*))
  val outputUniqueGroups = inputUniqueGroups.map { k =>
    // Fail fast on an unexpected index, like the project/aggregate/window helpers do. A bare
    // `get` on a HashMap[Int, Int] would unbox the missing entry to 0 and silently select
    // output column 0.
    if (!mapInputToOutput.containsKey(k)) {
      throw new IllegalArgumentException(s"Illegal index: $k")
    }
    mapInputToOutput.get(k)
  }
  ImmutableBitSet.of(outputUniqueGroups.toSeq: _*).union(ImmutableBitSet.of(leftColumns))
}
/**
 * Unique groups of a [[Rank]].
 *
 * The rank-function column (index from `FlinkRelMdUtil.getRankFunColumnIndex`) is separated
 * from the other requested columns. The remaining columns are resolved against the rank's
 * input, and the rank-function column is added back if it was requested.
 *
 * @param rank    the rank node
 * @param mq      metadata query, reused as a [[FlinkRelMetadataQuery]]
 * @param columns requested output columns
 * @return `columns` when only the rank-function column was requested; otherwise the input's
 *         unique groups, plus the rank-function column when it was requested
 */
def getUniqueGroups(
    rank: Rank,
    mq: RelMetadataQuery,
    columns: ImmutableBitSet): ImmutableBitSet = {
  val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
  val rankFunColumnIndex = FlinkRelMdUtil.getRankFunColumnIndex(rank)
  val (rankFunCols, inputCols) = columns.toList.partition(_ == rankFunColumnIndex)
  if (inputCols.isEmpty) {
    columns
  } else {
    val inputUniqueGroups = fmq.getUniqueGroups(rank.getInput, ImmutableBitSet.of(inputCols))
    if (rankFunCols.isEmpty) {
      inputUniqueGroups
    } else {
      inputUniqueGroups.union(ImmutableBitSet.of(rankFunColumnIndex))
    }
  }
}
/**
 * Unique groups of a [[Filter]]: forwards the request, with the same columns, to the
 * filter's input.
 *
 * @param filter  the filter node
 * @param mq      metadata query, reused as a [[FlinkRelMetadataQuery]]
 * @param columns requested columns
 * @return the unique groups reported for the input
 */
def getUniqueGroups(
    filter: Filter,
    mq: RelMetadataQuery,
    columns: ImmutableBitSet): ImmutableBitSet =
  FlinkRelMetadataQuery.reuseOrCreate(mq).getUniqueGroups(filter.getInput, columns)
/**
 * Unique groups of a [[Project]]: hands the projection expressions and the input to
 * `getUniqueGroupsOfProject`.
 *
 * @param project the project node
 * @param mq      metadata query
 * @param columns requested output columns
 * @return the result of `getUniqueGroupsOfProject`
 */
def getUniqueGroups(
    project: Project,
    mq: RelMetadataQuery,
    columns: ImmutableBitSet): ImmutableBitSet =
  getUniqueGroupsOfProject(project.getProjects, project.getInput, mq, columns)
/**
 * Unique groups of a [[Calc]]: expands the local references of the program's project list
 * and hands the resulting expressions to `getUniqueGroupsOfProject`.
 *
 * @param calc    the calc node
 * @param mq      metadata query
 * @param columns requested output columns
 * @return the result of `getUniqueGroupsOfProject`
 */
def getUniqueGroups(
    calc: Calc,
    mq: RelMetadataQuery,
    columns: ImmutableBitSet): ImmutableBitSet = {
  val program = calc.getProgram
  val expandedProjects = program.getProjectList.map(localRef => program.expandLocalRef(localRef))
  getUniqueGroupsOfProject(expandedProjects, calc.getInput, mq, columns)
}
/**
 * Shared implementation for the Project and Calc handlers.
 *
 * Each requested output column is classified by its projection expression:
 *  - a bare input reference, or an AS call whose first operand is one, is mapped to the
 *    referenced input column (only the first output column per input column is recorded);
 *  - a literal is ignored;
 *  - anything else is kept as an output column that cannot be reduced.
 *
 * @param projects projection expressions, indexed by output column
 * @param input    input of the projecting node
 * @param mq       metadata query, reused as a [[FlinkRelMetadataQuery]]
 * @param columns  requested output columns; each index must be below `projects.size`
 * @return when at least one input reference was found: the input's unique groups translated
 *         to output indices, united with the non-reference, non-literal columns; otherwise
 *         the non-literal columns, or the first requested column if all are literals
 * @throws IllegalArgumentException if the input reports an index that was not requested
 */
private def getUniqueGroupsOfProject(
    projects: util.List[RexNode],
    input: RelNode,
    mq: RelMetadataQuery,
    columns: ImmutableBitSet): ImmutableBitSet = {

  // Input index forwarded by a bare reference or by an AS call wrapping one; None otherwise.
  def forwardedInputIndex(expr: RexNode): Option[Integer] = expr match {
    case ref: RexInputRef =>
      Some(Integer.valueOf(ref.getIndex))
    case call: RexCall if call.getKind.equals(SqlKind.AS) =>
      call.getOperands.head match {
        case ref: RexInputRef => Some(Integer.valueOf(ref.getIndex))
        case _ => None
      }
    case _ =>
      None
  }

  // input column index -> first requested output column that forwards it
  val mapInToOutRefPos = new mutable.HashMap[Integer, Integer]()
  // requested output columns that are neither forwarded input references nor literals
  val outNonRefOrConstantCols = new mutable.ArrayBuffer[Integer]()

  columns.foreach { column =>
    require(column < projects.size)
    val expr = projects.get(column)
    forwardedInputIndex(expr) match {
      case Some(inputIndex) =>
        if (!mapInToOutRefPos.contains(inputIndex)) {
          mapInToOutRefPos.put(inputIndex, column)
        }
      case None =>
        if (!expr.isInstanceOf[RexLiteral]) {
          outNonRefOrConstantCols += column
        }
    }
  }

  if (mapInToOutRefPos.nonEmpty) {
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    val inputColumns = ImmutableBitSet.of(mapInToOutRefPos.keys.toList)
    val inputUniqueGroups = fmq.getUniqueGroups(input, inputColumns)
    val outputUniqueGroups = inputUniqueGroups.asList.map { k =>
      mapInToOutRefPos.get(k) match {
        case Some(outputIndex) => outputIndex
        case None => throw new IllegalArgumentException(s"Illegal index: $k")
      }
    }
    ImmutableBitSet.of(outputUniqueGroups).union(ImmutableBitSet.of(outNonRefOrConstantCols))
  } else if (outNonRefOrConstantCols.nonEmpty) {
    // no input reference was found, so these are exactly the non-constant requested columns
    ImmutableBitSet.of(outNonRefOrConstantCols)
  } else {
    // all columns are constant, return first column
    val columnList = columns.toList
    ImmutableBitSet.of(columnList.head)
  }
}
/**
 * Unique groups of an [[Exchange]]: forwards the request, with the same columns, to the
 * exchange's input.
 *
 * @param exchange the exchange node
 * @param mq       metadata query, reused as a [[FlinkRelMetadataQuery]]
 * @param columns  requested columns
 * @return the unique groups reported for the input
 */
def getUniqueGroups(
    exchange: Exchange,
    mq: RelMetadataQuery,
    columns: ImmutableBitSet): ImmutableBitSet =
  FlinkRelMetadataQuery.reuseOrCreate(mq).getUniqueGroups(exchange.getInput, columns)
/**
 * Unique groups of a [[SetOp]]: no reduction is attempted; the requested columns are
 * returned as they are.
 *
 * @param rel     the set operation (not inspected)
 * @param mq      metadata query (not used)
 * @param columns requested columns
 * @return `columns`
 */
def getUniqueGroups(
    rel: SetOp,
    mq: RelMetadataQuery,
    columns: ImmutableBitSet): ImmutableBitSet = {
  columns
}
/**
 * Unique groups of a [[Sort]]: forwards the request, with the same columns, to the sort's
 * input.
 *
 * @param sort    the sort node
 * @param mq      metadata query, reused as a [[FlinkRelMetadataQuery]]
 * @param columns requested columns
 * @return the unique groups reported for the input
 */
def getUniqueGroups(
    sort: Sort,
    mq: RelMetadataQuery,
    columns: ImmutableBitSet): ImmutableBitSet =
  FlinkRelMetadataQuery.reuseOrCreate(mq).getUniqueGroups(sort.getInput, columns)
/**
 * Unique groups of a [[Correlate]]: no reduction is attempted; the requested columns are
 * returned as they are.
 *
 * @param rel     the correlate node (not inspected)
 * @param mq      metadata query (not used)
 * @param columns requested columns
 * @return `columns`
 */
def getUniqueGroups(
    rel: Correlate,
    mq: RelMetadataQuery,
    columns: ImmutableBitSet): ImmutableBitSet = {
  columns
}
/**
 * Unique groups of a [[BatchExecCorrelate]]: no reduction is attempted; the requested
 * columns are returned as they are.
 *
 * @param rel     the batch correlate node (not inspected)
 * @param mq      metadata query (not used)
 * @param columns requested columns
 * @return `columns`
 */
def getUniqueGroups(
    rel: BatchExecCorrelate,
    mq: RelMetadataQuery,
    columns: ImmutableBitSet): ImmutableBitSet = {
  columns
}
/**
 * Unique groups of a [[Join]] without system fields.
 *
 * The requested columns are split into a left and a right part and each part is resolved
 * against its own input. The result is then chosen as follows:
 *  1. INNER / LEFT join (no nulls generated on the left): if the left unique groups contain
 *     the non-empty left join keys and the right join keys are unique on the right input,
 *     only the left unique groups are returned.
 *  2. INNER / RIGHT join (no nulls generated on the right): if the right unique groups
 *     contain the non-empty right join keys and the left join keys are unique on the left
 *     input, only the right unique groups (shifted by the left field count) are returned.
 *  3. Otherwise the union of both sides is returned.
 *
 * @param join    the join node; its system field list must be empty
 * @param mq      metadata query, reused as a [[FlinkRelMetadataQuery]]
 * @param columns requested output columns of the join
 * @return the reduced column set in join output indices
 */
def getUniqueGroups(
    join: Join,
    mq: RelMetadataQuery,
    columns: ImmutableBitSet): ImmutableBitSet = {
  require(join.getSystemFieldList.isEmpty)
  val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)

  val leftFieldCount = join.getLeft.getRowType.getFieldCount
  val (leftColumns, rightColumns) = splitColumnsIntoLeftAndRight(leftFieldCount, columns)
  val leftUniqueGroups = fmq.getUniqueGroups(join.getLeft, leftColumns)
  val rightUniqueGroups = fmq.getUniqueGroups(join.getRight, rightColumns)

  val joinType = join.getJoinType
  val joinInfo = join.analyzeCondition()
  val leftJoinKeys = ImmutableBitSet.of(joinInfo.leftKeys)
  val rightJoinKeys = ImmutableBitSet.of(joinInfo.rightKeys)

  // true only when the uniqueness of `keys` on `side` is known and positive
  def keysAreUnique(side: RelNode, keys: ImmutableBitSet): Boolean = {
    val unique = fmq.areColumnsUnique(side, keys)
    unique != null && unique
  }

  // for INNER and LEFT join, returns leftUniqueGroups if the join keys of RHS are unique
  val leftSideSuffices = leftJoinKeys.nonEmpty &&
    leftUniqueGroups.contains(leftJoinKeys) &&
    !joinType.generatesNullsOnLeft() &&
    keysAreUnique(join.getRight, rightJoinKeys)

  if (leftSideSuffices) {
    leftUniqueGroups
  } else {
    // right-side indices expressed as join output indices
    val outputRightUniqueGroups =
      rightUniqueGroups.asList.map(c => Integer.valueOf(c + leftFieldCount))

    // for INNER and RIGHT join, returns rightUniqueGroups if the join keys of LHS are unique
    val rightSideSuffices = rightJoinKeys.nonEmpty &&
      rightUniqueGroups.contains(rightJoinKeys) &&
      !joinType.generatesNullsOnRight() &&
      keysAreUnique(join.getLeft, leftJoinKeys)

    if (rightSideSuffices) {
      ImmutableBitSet.of(outputRightUniqueGroups)
    } else {
      leftUniqueGroups.union(ImmutableBitSet.of(outputRightUniqueGroups))
    }
  }
}
/**
 * Unique groups of a [[SemiJoin]] without system fields: forwards the request, with the same
 * columns, to the left input.
 *
 * @param semiJoin the semi join; its system field list must be empty
 * @param mq       metadata query, reused as a [[FlinkRelMetadataQuery]]
 * @param columns  requested columns
 * @return the unique groups reported for the left input
 */
def getUniqueGroups(
    semiJoin: SemiJoin,
    mq: RelMetadataQuery,
    columns: ImmutableBitSet): ImmutableBitSet = {
  require(semiJoin.getSystemFieldList.isEmpty)
  FlinkRelMetadataQuery.reuseOrCreate(mq).getUniqueGroups(semiJoin.getLeft, columns)
}
/**
 * Unique groups of an [[Aggregate]]: converts the group set to an array of input indices
 * and hands it to `getUniqueGroupsOfAggregate`.
 *
 * @param agg     the aggregate node
 * @param mq      metadata query
 * @param columns requested output columns
 * @return the result of `getUniqueGroupsOfAggregate`
 */
def getUniqueGroups(
    agg: Aggregate,
    mq: RelMetadataQuery,
    columns: ImmutableBitSet): ImmutableBitSet = {
  val groupKeys: Array[Int] = agg.getGroupSet.map(_.intValue()).toArray
  val outputFieldCount = agg.getRowType.getFieldCount
  getUniqueGroupsOfAggregate(outputFieldCount, groupKeys, agg.getInput, mq, columns)
}
/**
 * Unique groups of a [[BatchExecGroupAggregateBase]]: hands the node's grouping to
 * `getUniqueGroupsOfAggregate`.
 *
 * @param agg     the batch group aggregate
 * @param mq      metadata query
 * @param columns requested output columns
 * @return the result of `getUniqueGroupsOfAggregate`
 */
def getUniqueGroups(
    agg: BatchExecGroupAggregateBase,
    mq: RelMetadataQuery,
    columns: ImmutableBitSet): ImmutableBitSet = {
  val outputFieldCount = agg.getRowType.getFieldCount
  getUniqueGroupsOfAggregate(outputFieldCount, agg.getGrouping, agg.getInput, mq, columns)
}
/**
 * Shared implementation for the aggregate handlers.
 *
 * A requested output column whose index is below `grouping.length` is treated as the
 * grouping column at that position and mapped to `grouping(column)` on the input. The input's
 * unique groups over those columns are translated back to output indices. The remaining
 * requested columns are added back, unless the mapped input columns are exactly the grouping
 * columns, in which case they are dropped.
 *
 * @param outputFiledCount field count of the aggregate output; every requested column must
 *                         be below it
 * @param grouping         input indices of the grouping columns
 * @param input            input of the aggregate
 * @param mq               metadata query, reused as a [[FlinkRelMetadataQuery]]
 * @param columns          requested output columns
 * @return `columns` when no grouping column was requested, otherwise the reduced set
 * @throws IllegalArgumentException if the input reports an index that was not requested
 */
private def getUniqueGroupsOfAggregate(
    outputFiledCount: Int,
    grouping: Array[Int],
    input: RelNode,
    mq: RelMetadataQuery,
    columns: ImmutableBitSet): ImmutableBitSet = {
  val requestedCols = columns.toList

  // input grouping index -> requested output column
  val groupingInToOutMap = new mutable.HashMap[Integer, Integer]()
  for (column <- requestedCols) {
    require(column < outputFiledCount)
    if (column < grouping.length) {
      groupingInToOutMap.put(grouping(column), column)
    }
  }

  if (groupingInToOutMap.nonEmpty) {
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    val inputColumns = ImmutableBitSet.of(groupingInToOutMap.keys.toList)
    val inputUniqueGroups = fmq.getUniqueGroups(input, inputColumns)
    val uniqueGroupsFromGrouping = inputUniqueGroups.asList.map { k =>
      groupingInToOutMap.get(k) match {
        case Some(outputIndex) => outputIndex
        case None => throw new IllegalArgumentException(s"Illegal index: $k")
      }
    }

    val allGroupingRequested = inputColumns.toArray.sorted.sameElements(grouping.sorted)
    val nonGroupingCols = if (allGroupingRequested) {
      // if values of inputColumns are grouping columns, nonGroupingCols can be dropped.
      // (because grouping columns are unique.)
      Seq.empty[Integer]
    } else {
      val groupingOutColumns: Set[Integer] = groupingInToOutMap.values.toSet
      requestedCols.filterNot(column => groupingOutColumns.contains(column))
    }
    ImmutableBitSet.of(uniqueGroupsFromGrouping).union(ImmutableBitSet.of(nonGroupingCols))
  } else {
    columns
  }
}
/**
 * Unique groups of a [[LogicalWindowAggregate]]: collects the grouping, the named window
 * properties and the auxiliary group set (first element of `checkAndSplitAggCalls`) and hands
 * them to `getUniqueGroupsOfWindow`.
 *
 * @param window  the logical window aggregate; when `window.indicator` is set the auxiliary
 *                group set must be empty
 * @param mq      metadata query
 * @param columns requested output columns
 * @return the result of `getUniqueGroupsOfWindow`
 */
def getUniqueGroups(
    window: LogicalWindowAggregate,
    mq: RelMetadataQuery,
    columns: ImmutableBitSet): ImmutableBitSet = {
  val groupKeys = window.getGroupSet.map(_.intValue()).toArray
  val windowProperties = window.getNamedProperties
  val auxGroupKeys = checkAndSplitAggCalls(window)._1
  // an indicator aggregate must not carry auxiliary grouping columns
  require(!window.indicator || auxGroupKeys.isEmpty)
  getUniqueGroupsOfWindow(window, groupKeys, auxGroupKeys, windowProperties, mq, columns)
}
/**
 * Unique groups of a [[FlinkLogicalWindowAggregate]]: collects the grouping, the named
 * window properties and the auxiliary group set (first element of `checkAndSplitAggCalls`)
 * and hands them to `getUniqueGroupsOfWindow`.
 *
 * @param window  the Flink logical window aggregate
 * @param mq      metadata query
 * @param columns requested output columns
 * @return the result of `getUniqueGroupsOfWindow`
 */
def getUniqueGroups(
    window: FlinkLogicalWindowAggregate,
    mq: RelMetadataQuery,
    columns: ImmutableBitSet): ImmutableBitSet = {
  val groupKeys = window.getGroupSet.map(_.intValue()).toArray
  val windowProperties = window.getNamedProperties
  val auxGroupKeys = checkAndSplitAggCalls(window)._1
  getUniqueGroupsOfWindow(window, groupKeys, auxGroupKeys, windowProperties, mq, columns)
}
/**
 * Unique groups of a [[BatchExecWindowAggregateBase]]: hands the node's grouping, auxiliary
 * grouping and named window properties to `getUniqueGroupsOfWindow`.
 *
 * @param window  the batch window aggregate
 * @param mq      metadata query
 * @param columns requested output columns
 * @return the result of `getUniqueGroupsOfWindow`
 */
def getUniqueGroups(
    window: BatchExecWindowAggregateBase,
    mq: RelMetadataQuery,
    columns: ImmutableBitSet): ImmutableBitSet = {
  val groupKeys = window.getGrouping
  val windowProperties = window.getNamedProperties
  val auxGroupKeys = window.getAuxGrouping
  getUniqueGroupsOfWindow(window, groupKeys, auxGroupKeys, windowProperties, mq, columns)
}
/**
 * Shared implementation for the window aggregate handlers.
 *
 * A requested output column whose index is below `grouping.length` is treated as the
 * grouping column at that position and mapped to `grouping(column)` on the input. The input's
 * unique groups over those columns are translated back to output indices. If `columns`
 * equals the bit set built from `grouping ++ auxGrouping`, only those translated columns are
 * returned; otherwise every requested column that is not a mapped grouping column is added
 * back.
 *
 * NOTE(review): the equality test compares requested output indices with the input indices
 * in `grouping` / `auxGrouping`, while the aggregate helper compares input indices on both
 * sides. Confirm this is intended.
 *
 * @param window          the window aggregate node
 * @param grouping        input indices of the grouping columns
 * @param auxGrouping     indices of the auxiliary grouping columns
 * @param namedProperties named window properties; currently not consulted by this method
 * @param mq              metadata query, reused as a [[FlinkRelMetadataQuery]]
 * @param columns         requested output columns; each must be below the output field count
 * @return `columns` when no grouping column was requested, otherwise the reduced set
 * @throws IllegalArgumentException if the input reports an index that was not requested
 */
private def getUniqueGroupsOfWindow(
    window: SingleRel,
    grouping: Array[Int],
    auxGrouping: Array[Int],
    namedProperties: Seq[NamedWindowProperty],
    mq: RelMetadataQuery,
    columns: ImmutableBitSet): ImmutableBitSet = {
  val fieldCount = window.getRowType.getFieldCount
  val requestedCols = columns.toList

  // input grouping index -> requested output column
  val groupingInToOutMap = new mutable.HashMap[Integer, Integer]()
  for (column <- requestedCols) {
    require(column < fieldCount)
    if (column < grouping.length) {
      groupingInToOutMap.put(grouping(column), column)
    }
  }

  // translates an input index reported by the child back to the requested output column
  def toOutputColumn(i: Integer): Integer = groupingInToOutMap.get(i) match {
    case Some(outputIndex) => outputIndex
    case None => throw new IllegalArgumentException(s"Illegal index: $i")
  }

  if (groupingInToOutMap.nonEmpty) {
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    val inputColumns = ImmutableBitSet.of(groupingInToOutMap.keys.toList)
    val inputUniqueGroups = fmq.getUniqueGroups(window.getInput, inputColumns)
    val uniqueGroupsFromGrouping = inputUniqueGroups.asList.map(i => toOutputColumn(i))

    if (columns.equals(ImmutableBitSet.of(grouping ++ auxGrouping: _*))) {
      ImmutableBitSet.of(uniqueGroupsFromGrouping)
    } else {
      val groupingOutCols: Set[Integer] = groupingInToOutMap.values.toSet
      // TODO drop some nonGroupingCols base on FlinkRelMdColumnUniqueness#areColumnsUnique(window)
      val nonGroupingCols = requestedCols.filterNot(column => groupingOutCols.contains(column))
      ImmutableBitSet.of(uniqueGroupsFromGrouping).union(ImmutableBitSet.of(nonGroupingCols))
    }
  } else {
    columns
  }
}
/**
 * Unique groups of a [[Window]] (over aggregate): hands the output field count and the
 * input to `getUniqueGroupsOfOver`.
 *
 * @param over    the window node
 * @param mq      metadata query
 * @param columns requested output columns
 * @return the result of `getUniqueGroupsOfOver`
 */
def getUniqueGroups(
    over: Window,
    mq: RelMetadataQuery,
    columns: ImmutableBitSet): ImmutableBitSet = {
  val outputFieldCount = over.getRowType.getFieldCount
  getUniqueGroupsOfOver(outputFieldCount, over.getInput, mq, columns)
}
/**
 * Unique groups of a [[BatchExecOverAggregate]]: hands the output field count and the input
 * to `getUniqueGroupsOfOver`.
 *
 * @param over    the batch over aggregate
 * @param mq      metadata query
 * @param columns requested output columns
 * @return the result of `getUniqueGroupsOfOver`
 */
def getUniqueGroups(
    over: BatchExecOverAggregate,
    mq: RelMetadataQuery,
    columns: ImmutableBitSet): ImmutableBitSet = {
  val outputFieldCount = over.getRowType.getFieldCount
  getUniqueGroupsOfOver(outputFieldCount, over.getInput, mq, columns)
}
/**
 * Shared implementation for the over aggregate handlers.
 *
 * Requested columns below the input's field count are resolved against the input; all other
 * requested columns are added to the result unchanged.
 *
 * @param outputFiledCount field count of the over aggregate output; not read by this method
 * @param input            input of the over aggregate
 * @param mq               metadata query, reused as a [[FlinkRelMetadataQuery]]
 * @param columns          requested output columns
 * @return the input's unique groups united with the requested non-input columns
 */
private def getUniqueGroupsOfOver(
    outputFiledCount: Int,
    input: RelNode,
    mq: RelMetadataQuery,
    columns: ImmutableBitSet): ImmutableBitSet = {
  val inputFieldCount = input.getRowType.getFieldCount
  val requestedCols = columns.toList
  val colsFromInput = requestedCols.filter(_ < inputFieldCount)
  val colsAddedByOver = requestedCols.filterNot(_ < inputFieldCount)
  val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
  fmq
    .getUniqueGroups(input, ImmutableBitSet.of(colsFromInput))
    .union(ImmutableBitSet.of(colsAddedByOver))
}
/**
 * Catch-all rule when none of the others apply: the requested columns are returned as they
 * are.
 *
 * @param rel     any relational node without a more specific handler (not inspected)
 * @param mq      metadata query (not used)
 * @param columns requested columns
 * @return `columns`
 */
def getUniqueGroups(
    rel: RelNode,
    mq: RelMetadataQuery,
    columns: ImmutableBitSet): ImmutableBitSet = {
  columns
}
/**
 * Unique groups of a [[RelSubset]]: forwards the request to the subset's best node, or to
 * its original node when there is no best node yet.
 *
 * @param rel     the subset
 * @param mq      metadata query, reused as a [[FlinkRelMetadataQuery]]
 * @param columns requested columns
 * @return the unique groups reported for the chosen node
 * @throws RuntimeException once `Bug.CALCITE_1048_FIXED` is true, as a reminder to revisit
 *                          this method
 */
def getUniqueGroups(
    rel: RelSubset,
    mq: RelMetadataQuery,
    columns: ImmutableBitSet): ImmutableBitSet = {
  if (Bug.CALCITE_1048_FIXED) {
    throw new RuntimeException("CALCITE_1048 is fixed, so check this method again!")
  } else {
    // If the best node is null, the unique groups are derived from the original node instead,
    // because the original node is logically equivalent to the subset.
    val representative = Util.first(rel.getBest, rel.getOriginal)
    FlinkRelMetadataQuery.reuseOrCreate(mq).getUniqueGroups(representative, columns)
  }
}
}
/**
 * Companion object: owns the single handler instance and exposes the reflective metadata
 * provider built from it.
 */
object FlinkRelMdUniqueGroups {

  /** The only handler instance, created when this object is initialised. */
  private val INSTANCE: FlinkRelMdUniqueGroups = new FlinkRelMdUniqueGroups()

  /** Reflective provider for `FlinkMetadata.UniqueGroups.METHOD`, backed by the handler. */
  val SOURCE: RelMetadataProvider =
    ReflectiveRelMetadataProvider.reflectiveSource(FlinkMetadata.UniqueGroups.METHOD, INSTANCE)
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy