org.apache.flink.table.planner.plan.metadata.FlinkRelMdColumnInterval.scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.planner.plan.metadata

import org.apache.flink.table.api.TableException
import org.apache.flink.table.planner.plan.metadata.FlinkMetadata.ColumnInterval
import org.apache.flink.table.planner.plan.nodes.calcite.{Expand, Rank, TableAggregate, WindowAggregate}
import org.apache.flink.table.planner.plan.nodes.physical.batch._
import org.apache.flink.table.planner.plan.nodes.physical.stream._
import org.apache.flink.table.planner.plan.schema.FlinkPreparingTableBase
import org.apache.flink.table.planner.plan.stats._
import org.apache.flink.table.planner.plan.utils.{AggregateUtil, ColumnIntervalUtil, FlinkRelOptUtil, RankUtil}
import org.apache.flink.table.runtime.operators.rank.{ConstantRankRange, VariableRankRange}
import org.apache.flink.util.Preconditions

import org.apache.calcite.plan.volcano.RelSubset
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core._
import org.apache.calcite.rel.metadata._
import org.apache.calcite.rel.{RelNode, SingleRel}
import org.apache.calcite.rex._
import org.apache.calcite.sql.SqlKind._
import org.apache.calcite.sql.`type`.SqlTypeName
import org.apache.calcite.sql.{SqlBinaryOperator, SqlKind}
import org.apache.calcite.util.Util

import java.math.{BigDecimal => JBigDecimal}

import scala.collection.JavaConversions._

/**
 * FlinkRelMdColumnInterval supplies a default implementation of
 * [[FlinkRelMetadataQuery.getColumnInterval]] for the standard logical algebra.
 */
class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {

  override def getDef: MetadataDef[ColumnInterval] = FlinkMetadata.ColumnInterval.DEF

  /**
   * Gets interval of the given column on TableScan.
   *
   * @param ts    TableScan RelNode
   * @param mq    RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column on TableScan
   */
  def getColumnInterval(ts: TableScan, mq: RelMetadataQuery, index: Int): ValueInterval = {
    val relOptTable = ts.getTable.asInstanceOf[FlinkPreparingTableBase]
    val fieldNames = relOptTable.getRowType.getFieldNames
    Preconditions.checkArgument(index >= 0 && index < fieldNames.size())
    val fieldName = fieldNames.get(index)
    val statistic = relOptTable.getStatistic
    val colStats = statistic.getColumnStats(fieldName)
    if (colStats != null) {
      val minValue = colStats.getMinValue
      val maxValue = colStats.getMaxValue
      val min = colStats.getMin
      val max = colStats.getMax

      Preconditions.checkArgument(
        (minValue == null && maxValue == null) || (max == null && min == null))
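      // i.e. ColumnStats provides the bounds either as Comparables
      // (getMinValue/getMaxValue) or as Numbers (getMin/getMax), never both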

      if (minValue != null || maxValue != null) {
        ValueInterval(convertNumberToBigDecimal(minValue), convertNumberToBigDecimal(maxValue))
      } else if (max != null || min != null) {
        ValueInterval(convertNumberToBigDecimal(min), convertNumberToBigDecimal(max))
      } else {
        null
      }
    } else {
      null
    }
  }

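  // Normalizes numeric bounds to java.math.BigDecimal so that interval
  // endpoints built from different numeric types compare consistently.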
  private def convertNumberToBigDecimal(number: Number): Number = {
    if (number != null) {
      new JBigDecimal(number.toString)
    } else {
      number
    }
  }

  private def convertNumberToBigDecimal(comparable: Comparable[_]): Comparable[_] = {
    if (comparable != null && comparable.isInstanceOf[Number]) {
      new JBigDecimal(comparable.toString)
    } else {
      comparable
    }
  }

  /**
   * Gets interval of the given column on Values.
   *
   * @param values Values RelNode
   * @param mq     RelMetadataQuery instance
   * @param index  the index of the given column
   * @return interval of the given column on Values
   */
  def getColumnInterval(values: Values, mq: RelMetadataQuery, index: Int): ValueInterval = {
    val tuples = values.tuples
    if (tuples.isEmpty) {
      EmptyValueInterval
    } else {
      val literalValues = tuples.map {
        t => FlinkRelOptUtil.getLiteralValueByBroadType(t.get(index))
      }.filter(_ != null)
      if (literalValues.isEmpty) {
        EmptyValueInterval
      } else {
        literalValues.map(v => ValueInterval(v, v)).reduceLeft(ValueInterval.union)
      }
    }
  }

  /**
   * Gets interval of the given column on Snapshot.
   *
   * @param snapshot Snapshot RelNode
   * @param mq       RelMetadataQuery instance
   * @param index    the index of the given column
   * @return interval of the given column on Snapshot
   */
  def getColumnInterval(snapshot: Snapshot, mq: RelMetadataQuery, index: Int): ValueInterval = null

  /**
   * Gets interval of the given column on Project.
   *
   * Note: only simple RexNodes are supported, e.g. RexInputRef.
   *
   * @param project Project RelNode
   * @param mq      RelMetadataQuery instance
   * @param index   the index of the given column
   * @return interval of the given column on Project
   */
  def getColumnInterval(project: Project, mq: RelMetadataQuery, index: Int): ValueInterval = {
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    val projects = project.getProjects
    Preconditions.checkArgument(index >= 0 && index < projects.size())
    projects.get(index) match {
      case inputRef: RexInputRef => fmq.getColumnInterval(project.getInput, inputRef.getIndex)
      case literal: RexLiteral =>
        val literalValue = FlinkRelOptUtil.getLiteralValueByBroadType(literal)
        if (literalValue == null) {
          ValueInterval.empty
        } else {
          ValueInterval(literalValue, literalValue)
        }
      case rexCall: RexCall =>
        getRexNodeInterval(rexCall, project, mq)
      case _ => null
    }
  }

  /**
   * Gets interval of the given column on Filter.
   *
   * @param filter Filter RelNode
   * @param mq     RelMetadataQuery instance
   * @param index  the index of the given column
   * @return interval of the given column on Filter
   */
  def getColumnInterval(filter: Filter, mq: RelMetadataQuery, index: Int): ValueInterval = {
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    val inputValueInterval = fmq.getColumnInterval(filter.getInput, index)
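    // e.g. for a filter `a > 10` the interval of column `a` from the input is
    // tightened to (10, +inf); SEARCH/SARG conditions are expanded into plain
    // comparisons first so the interval util only sees simple predicates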
    ColumnIntervalUtil.getColumnIntervalWithFilter(
      Option(inputValueInterval),
      RexUtil.expandSearch(
        filter.getCluster.getRexBuilder,
        null,
        filter.getCondition),
      index,
      filter.getCluster.getRexBuilder)
  }

  /**
   * Gets interval of the given column on Calc.
   *
   * @param calc  Calc RelNode
   * @param mq    RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column on Calc
   */
  def getColumnInterval(calc: Calc, mq: RelMetadataQuery, index: Int): ValueInterval = {
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    val rexProgram = calc.getProgram
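    // RexProgram.split() expands the local refs into a pair of (project
    // expressions, filter expressions); only the project at the requested
    // output index is needed here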
    val project = rexProgram.split().left.get(index)
    getColumnIntervalOfCalc(calc, fmq, project)
  }

  /**
   * Calculates the interval of the column that results from the given RexNode in a Calc.
   * Note that this method is called by the method above and recurses to unwrap
   * "AS" RexCalls.
   */
  private def getColumnIntervalOfCalc(
      calc: Calc,
      mq: RelMetadataQuery,
      project: RexNode): ValueInterval = {
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    project match {
      case call: RexCall if call.getKind == SqlKind.AS =>
        getColumnIntervalOfCalc(calc, fmq, call.getOperands.head)

      case inputRef: RexInputRef =>
        val program = calc.getProgram
        val sourceFieldIndex = inputRef.getIndex
        val inputValueInterval = fmq.getColumnInterval(calc.getInput, sourceFieldIndex)
        val condition = program.getCondition
        if (condition != null) {
          val predicate = program.expandLocalRef(program.getCondition)
          ColumnIntervalUtil.getColumnIntervalWithFilter(
            Option(inputValueInterval),
            predicate,
            sourceFieldIndex,
            calc.getCluster.getRexBuilder)
        } else {
          inputValueInterval
        }

      case literal: RexLiteral =>
        val literalValue = FlinkRelOptUtil.getLiteralValueByBroadType(literal)
        if (literalValue == null) {
          ValueInterval.empty
        } else {
          ValueInterval(literalValue, literalValue)
        }

      case rexCall: RexCall =>
        getRexNodeInterval(rexCall, calc, mq)

      case _ => null
    }
  }

  private def getRexNodeInterval(
      rexNode: RexNode,
      baseNode: SingleRel,
      mq: RelMetadataQuery): ValueInterval = {
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    rexNode match {
      case inputRef: RexInputRef =>
        fmq.getColumnInterval(baseNode.getInput, inputRef.getIndex)

      case literal: RexLiteral =>
        val literalValue = FlinkRelOptUtil.getLiteralValueByBroadType(literal)
        if (literalValue == null) {
          ValueInterval.empty
        } else {
          ValueInterval(literalValue, literalValue)
        }

      case caseCall: RexCall if caseCall.getKind == SqlKind.CASE =>
        // compute all the possible result values of this CASE WHEN clause;
        // the union of their intervals is the interval of the call
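        // e.g. CASE WHEN c1 THEN a WHEN c2 THEN b ELSE d END has operands
        // [c1, a, c2, b, d]: the results are the odd-indexed operands plus the
        // last one, so the interval is union(I(a), I(b), I(d))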
        val operands = caseCall.getOperands
        val operandCount = operands.size()
        val possibleValueIntervals = operands.indices
          // keep only the result expressions: the odd indices (THEN branches)
          // and the last operand (the ELSE branch)
          .filter(i => i % 2 != 0 || i == operandCount - 1)
          .map(operands(_))
          .map(getRexNodeInterval(_, baseNode, mq))
        possibleValueIntervals.reduceLeft(ValueInterval.union)

      case searchCall: RexCall if searchCall.getKind == SqlKind.SEARCH =>
        val expanded = RexUtil.expandSearch(
          baseNode.getCluster.getRexBuilder, null, searchCall)
        getRexNodeInterval(expanded, baseNode, mq)

      // TODO supports ScalarSqlFunctions.IF
      // TODO supports CAST

      case rexCall: RexCall if rexCall.op.isInstanceOf[SqlBinaryOperator] =>
        val leftValueInterval = getRexNodeInterval(rexCall.operands.get(0), baseNode, mq)
        val rightValueInterval = getRexNodeInterval(rexCall.operands.get(1), baseNode, mq)
        ColumnIntervalUtil.getValueIntervalOfRexCall(rexCall, leftValueInterval, rightValueInterval)

      case _ => null
    }
  }

  /**
   * Gets interval of the given column on Exchange.
   *
   * @param exchange Exchange RelNode
   * @param mq       RelMetadataQuery instance
   * @param index    the index of the given column
   * @return interval of the given column on Exchange
   */
  def getColumnInterval(exchange: Exchange, mq: RelMetadataQuery, index: Int): ValueInterval = {
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    fmq.getColumnInterval(exchange.getInput, index)
  }

  /**
   * Gets interval of the given column on Sort.
   *
   * @param sort  Sort RelNode
   * @param mq    RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column on Sort
   */
  def getColumnInterval(sort: Sort, mq: RelMetadataQuery, index: Int): ValueInterval = {
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    fmq.getColumnInterval(sort.getInput, index)
  }

  /**
   * Gets interval of the given column on Expand.
   *
   * @param expand expand RelNode
   * @param mq     RelMetadataQuery instance
   * @param index  the index of the given column
   * @return interval of the given column on Expand
   */
  def getColumnInterval(
      expand: Expand,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    val intervals = expand.projects.flatMap { project =>
      project(index) match {
        case inputRef: RexInputRef =>
          Some(fmq.getColumnInterval(expand.getInput, inputRef.getIndex))
        case l: RexLiteral if l.getTypeName eq SqlTypeName.DECIMAL =>
          val v = l.getValueAs(classOf[JBigDecimal])
          Some(ValueInterval(v, v))
        case l: RexLiteral if l.getValue == null =>
          None
        case p =>
          throw new TableException(s"Column interval can't handle $p type in expand.")
      }
    }
    if (intervals.contains(null)) {
      // null union any value interval is null
      null
    } else {
      intervals.reduce((a, b) => ValueInterval.union(a, b))
    }
  }

  /**
   * Gets interval of the given column on Rank.
   *
   * @param rank        [[Rank]] instance to analyze
   * @param mq          RelMetadataQuery instance
   * @param index       the index of the given column
   * @return interval of the given column on Rank
   */
  def getColumnInterval(
      rank: Rank,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    val rankFunColumnIndex = RankUtil.getRankNumberColumnIndex(rank).getOrElse(-1)
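    // e.g. a Top-10 rank (ConstantRankRange(1, 10)) bounds the rank-number
    // column to [1, 10]; for a VariableRankRange the upper bound comes from the
    // rank-end column's own interval, with 1 as the lower bound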
    if (index == rankFunColumnIndex) {
      rank.rankRange match {
        case r: ConstantRankRange =>
          ValueInterval(JBigDecimal.valueOf(r.getRankStart), JBigDecimal.valueOf(r.getRankEnd))
        case v: VariableRankRange =>
          val interval = fmq.getColumnInterval(rank.getInput, v.getRankEndIndex)
          interval match {
            case hasUpper: WithUpper =>
              val lower = JBigDecimal.valueOf(1)
              ValueInterval(lower, hasUpper.upper, includeUpper = hasUpper.includeUpper)
            case _ => null
          }
      }
    } else {
      fmq.getColumnInterval(rank.getInput, index)
    }
  }

  /**
   * Gets interval of the given column on Aggregates.
   *
   * @param aggregate Aggregate RelNode
   * @param mq        RelMetadataQuery instance
   * @param index     the index of the given column
   * @return interval of the given column on Aggregate
   */
  def getColumnInterval(aggregate: Aggregate, mq: RelMetadataQuery, index: Int): ValueInterval =
    estimateColumnIntervalOfAggregate(aggregate, mq, index)

  /**
   * Gets interval of the given column on TableAggregates.
   *
   * @param aggregate TableAggregate RelNode
   * @param mq        RelMetadataQuery instance
   * @param index     the index of the given column
   * @return interval of the given column on TableAggregate
   */
  def getColumnInterval(
      aggregate: TableAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval =
    estimateColumnIntervalOfAggregate(aggregate, mq, index)

  /**
   * Gets interval of the given column on batch group aggregate.
   *
   * @param aggregate batch group aggregate RelNode
   * @param mq        RelMetadataQuery instance
   * @param index     the index of the given column
   * @return interval of the given column on batch group aggregate
   */
  def getColumnInterval(
      aggregate: BatchPhysicalGroupAggregateBase,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = estimateColumnIntervalOfAggregate(aggregate, mq, index)

  /**
   * Gets interval of the given column on stream group aggregate.
   *
   * @param aggregate stream group aggregate RelNode
   * @param mq        RelMetadataQuery instance
   * @param index     the index of the given column
   * @return interval of the given column on stream group Aggregate
   */
  def getColumnInterval(
      aggregate: StreamPhysicalGroupAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = estimateColumnIntervalOfAggregate(aggregate, mq, index)

  /**
   * Gets interval of the given column on stream group table aggregate.
   *
   * @param aggregate stream group table aggregate RelNode
   * @param mq        RelMetadataQuery instance
   * @param index     the index of the given column
   * @return interval of the given column on stream group TableAggregate
   */
  def getColumnInterval(
      aggregate: StreamPhysicalGroupTableAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = estimateColumnIntervalOfAggregate(aggregate, mq, index)

  /**
   * Gets interval of the given column on stream local group aggregate.
   *
   * @param aggregate stream local group aggregate RelNode
   * @param mq        RelMetadataQuery instance
   * @param index     the index of the given column
   * @return interval of the given column on stream local group Aggregate
   */
  def getColumnInterval(
      aggregate: StreamPhysicalLocalGroupAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = estimateColumnIntervalOfAggregate(aggregate, mq, index)

  /**
   * Gets interval of the given column on stream global group aggregate.
   *
   * @param aggregate stream global group aggregate RelNode
   * @param mq        RelMetadataQuery instance
   * @param index     the index of the given column
   * @return interval of the given column on stream global group Aggregate
   */
  def getColumnInterval(
      aggregate: StreamPhysicalGlobalGroupAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = estimateColumnIntervalOfAggregate(aggregate, mq, index)

  /**
   * Gets interval of the given column on window aggregate.
   *
   * @param agg   window aggregate RelNode
   * @param mq    RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column on window Aggregate
   */
  def getColumnInterval(
      agg: WindowAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = estimateColumnIntervalOfAggregate(agg, mq, index)

  /**
   * Gets interval of the given column on batch window aggregate.
   *
   * @param agg   batch window aggregate RelNode
   * @param mq    RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column on batch window Aggregate
   */
  def getColumnInterval(
      agg: BatchPhysicalWindowAggregateBase,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = estimateColumnIntervalOfAggregate(agg, mq, index)

  /**
   * Gets interval of the given column on stream window aggregate.
   *
   * @param agg   stream window aggregate RelNode
   * @param mq    RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column on stream window Aggregate
   */
  def getColumnInterval(
      agg: StreamPhysicalGroupWindowAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = estimateColumnIntervalOfAggregate(agg, mq, index)

  /**
   * Gets interval of the given column on stream window table aggregate.
   *
   * @param agg   stream window table aggregate RelNode
   * @param mq    RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column on stream window Aggregate
   */
  def getColumnInterval(
      agg: StreamPhysicalGroupWindowTableAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = estimateColumnIntervalOfAggregate(agg, mq, index)

  private def estimateColumnIntervalOfAggregate(
      aggregate: SingleRel,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    val input = aggregate.getInput
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    val groupSet = aggregate match {
      case agg: StreamPhysicalGroupAggregate => agg.grouping
      case agg: StreamPhysicalLocalGroupAggregate => agg.grouping
      case agg: StreamPhysicalGlobalGroupAggregate => agg.grouping
      case agg: StreamPhysicalIncrementalGroupAggregate => agg.partialAggGrouping
      case agg: StreamPhysicalGroupWindowAggregate => agg.grouping
      case agg: BatchPhysicalGroupAggregateBase => agg.grouping ++ agg.auxGrouping
      case agg: Aggregate => AggregateUtil.checkAndGetFullGroupSet(agg)
      case agg: BatchPhysicalLocalSortWindowAggregate =>
        // grouping + assignTs + auxGrouping
        agg.grouping ++ Array(agg.inputTimeFieldIndex) ++ agg.auxGrouping
      case agg: BatchPhysicalLocalHashWindowAggregate =>
        // grouping + assignTs + auxGrouping
        agg.grouping ++ Array(agg.inputTimeFieldIndex) ++ agg.auxGrouping
      case agg: BatchPhysicalWindowAggregateBase => agg.grouping ++ agg.auxGrouping
      case agg: TableAggregate => agg.getGroupSet.toArray
      case agg: StreamPhysicalGroupTableAggregate => agg.grouping
      case agg: StreamPhysicalGroupWindowTableAggregate => agg.grouping
    }

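    // the aggregate's output row is laid out as [group keys..., agg call results...]:
    // indexes below groupSet.length map back to input columns via groupSet,
    // the rest map to aggregate call outputs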
    if (index < groupSet.length) {
      // estimates group keys according to the input relNodes.
      val sourceFieldIndex = groupSet(index)
      fmq.getColumnInterval(input, sourceFieldIndex)
    } else {
      def getAggCallFromLocalAgg(
          index: Int,
          aggCalls: Seq[AggregateCall],
          inputType: RelDataType): AggregateCall = {
        val outputIndexToAggCallIndexMap = AggregateUtil.getOutputIndexToAggCallIndexMap(
          aggCalls, inputType)
        if (outputIndexToAggCallIndexMap.containsKey(index)) {
          val realIndex = outputIndexToAggCallIndexMap.get(index)
          aggCalls(realIndex)
        } else {
          null
        }
      }

      def getAggCallIndexInLocalAgg(
          index: Int,
          globalAggCalls: Seq[AggregateCall],
          inputRowType: RelDataType): Integer = {
        val outputIndexToAggCallIndexMap = AggregateUtil.getOutputIndexToAggCallIndexMap(
          globalAggCalls, inputRowType)

        outputIndexToAggCallIndexMap.foreach {
          case (k, v) => if (v == index) {
            return k
          }
        }
        null.asInstanceOf[Integer]
      }

      // group keys were already handled above (index >= groupSet.length here),
      // so the index must refer to an aggregate call output
      locally {
        val aggCallIndex = index - groupSet.length
        val aggCall = aggregate match {
          case agg: StreamPhysicalGroupAggregate if agg.aggCalls.length > aggCallIndex =>
            agg.aggCalls(aggCallIndex)
          case agg: StreamPhysicalGlobalGroupAggregate
            if agg.aggCalls.length > aggCallIndex =>
            val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
              aggCallIndex, agg.aggCalls, agg.localAggInputRowType)
            if (aggCallIndexInLocalAgg != null) {
              return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg)
            } else {
              null
            }
          case agg: StreamPhysicalLocalGroupAggregate =>
            getAggCallFromLocalAgg(aggCallIndex, agg.aggCalls, agg.getInput.getRowType)
          case agg: StreamPhysicalIncrementalGroupAggregate
            if agg.partialAggCalls.length > aggCallIndex =>
            agg.partialAggCalls(aggCallIndex)
          case agg: StreamPhysicalGroupWindowAggregate if agg.aggCalls.length > aggCallIndex =>
            agg.aggCalls(aggCallIndex)
          case agg: BatchPhysicalLocalHashAggregate =>
            getAggCallFromLocalAgg(aggCallIndex, agg.getAggCallList, agg.getInput.getRowType)
          case agg: BatchPhysicalHashAggregate if agg.isMerge =>
            val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
              aggCallIndex, agg.getAggCallList, agg.aggInputRowType)
            if (aggCallIndexInLocalAgg != null) {
              return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg)
            } else {
              null
            }
          case agg: BatchPhysicalLocalSortAggregate =>
            getAggCallFromLocalAgg(aggCallIndex, agg.getAggCallList, agg.getInput.getRowType)
          case agg: BatchPhysicalSortAggregate if agg.isMerge =>
            val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
              aggCallIndex, agg.getAggCallList, agg.aggInputRowType)
            if (aggCallIndexInLocalAgg != null) {
              return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg)
            } else {
              null
            }
          case agg: BatchPhysicalGroupAggregateBase if agg.getAggCallList.length > aggCallIndex =>
            agg.getAggCallList(aggCallIndex)
          case agg: Aggregate =>
            val (_, aggCalls) = AggregateUtil.checkAndSplitAggCalls(agg)
            if (aggCalls.length > aggCallIndex) {
              aggCalls(aggCallIndex)
            } else {
              null
            }
          case agg: BatchPhysicalWindowAggregateBase if agg.getAggCallList.length > aggCallIndex =>
            agg.getAggCallList(aggCallIndex)
          case _ => null
        }

        if (aggCall != null) {
          aggCall.getAggregation.getKind match {
            case SUM | SUM0 =>
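              // a sum over inputs with a non-negative lower bound can only grow,
              // so it keeps that lower bound and is unbounded above; symmetrically,
              // a non-positive upper bound yields a left-semi-infinite interval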
              val inputInterval = fmq.getColumnInterval(input, aggCall.getArgList.get(0))
              if (inputInterval != null) {
                inputInterval match {
                  case withLower: WithLower if withLower.lower.isInstanceOf[Number] =>
                    if (withLower.lower.asInstanceOf[Number].doubleValue() >= 0.0) {
                      RightSemiInfiniteValueInterval(withLower.lower, withLower.includeLower)
                    } else {
                      null.asInstanceOf[ValueInterval]
                    }
                  case withUpper: WithUpper if withUpper.upper.isInstanceOf[Number] =>
                    if (withUpper.upper.asInstanceOf[Number].doubleValue() <= 0.0) {
                      LeftSemiInfiniteValueInterval(withUpper.upper, withUpper.includeUpper)
                    } else {
                      null
                    }
                  case _ => null
                }
              } else {
                null
              }
            case COUNT =>
              RightSemiInfiniteValueInterval(JBigDecimal.valueOf(0), includeLower = true)
            // TODO add more built-in agg functions
            case _ => null
          }
        } else {
          null
        }
      }
    }
  }

  /**
   * Gets interval of the given column on calcite window.
   *
   * @param window Window RelNode
   * @param mq     RelMetadataQuery instance
   * @param index  the index of the given column
   * @return interval of the given column on window
   */
  def getColumnInterval(
      window: Window,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    getColumnIntervalOfOverAgg(window, mq, index)
  }

  /**
   * Gets interval of the given column on batch over aggregate.
   *
   * @param agg    batch over aggregate RelNode
   * @param mq     RelMetadataQuery instance
   * @param index  the index of the given column
   * @return interval of the given column on batch over aggregate.
   */
  def getColumnInterval(
      agg: BatchPhysicalOverAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = getColumnIntervalOfOverAgg(agg, mq, index)

  /**
   * Gets interval of the given column on stream over aggregate.
   *
   * @param agg    stream over aggregate RelNode
   * @param mq     RelMetadataQuery instance
   * @param index  the index of the given column
   * @return interval of the given column on stream over aggregate.
   */
  def getColumnInterval(
      agg: StreamPhysicalOverAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = getColumnIntervalOfOverAgg(agg, mq, index)

  private def getColumnIntervalOfOverAgg(
      overAgg: SingleRel,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    val input = overAgg.getInput
    val fieldsCountOfInput = input.getRowType.getFieldCount
    if (index < fieldsCountOfInput) {
      fmq.getColumnInterval(input, index)
    } else {
      // cannot estimate the column interval of aggregate function calls.
      null
    }
  }

  /**
   * Gets interval of the given column on Join.
   *
   * @param join  Join RelNode
   * @param mq    RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column on Join
   */
  def getColumnInterval(join: Join, mq: RelMetadataQuery, index: Int): ValueInterval = {
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    val joinCondition = join.getCondition
    val nLeftColumns = join.getLeft.getRowType.getFieldCount
    val inputValueInterval = if (index < nLeftColumns) {
      fmq.getColumnInterval(join.getLeft, index)
    } else {
      fmq.getColumnInterval(join.getRight, index - nLeftColumns)
    }
    // TODO if the column at the index position is an equi-join key in an inner
    // join, its interval is the origin interval intersected with the interval
    // of the paired join key. For example, if the join is an inner join with
    // condition l.A = r.A, the value interval of l.A is the intersection of
    // l.A's interval with r.A's.
    if (joinCondition == null || joinCondition.isAlwaysTrue) {
      inputValueInterval
    } else {
      ColumnIntervalUtil.getColumnIntervalWithFilter(
        Option(inputValueInterval),
        joinCondition,
        index,
        join.getCluster.getRexBuilder)
    }
  }

  /**
   * Gets interval of the given column on Union.
   *
   * @param union Union RelNode
   * @param mq    RelMetadataQuery instance
   * @param index the index of the given column
   * @return interval of the given column on Union
   */
  def getColumnInterval(union: Union, mq: RelMetadataQuery, index: Int): ValueInterval = {
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    val subIntervals = union
      .getInputs
      .map(fmq.getColumnInterval(_, index))
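    // the interval of a UNION column is the union of the inputs' intervals,
    // e.g. union([1, 5], [3, 10]) = [1, 10]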
    subIntervals.reduceLeft(ValueInterval.union)
  }

  /**
   * Gets interval of the given column on RelSubset.
   *
   * @param subset RelSubset to analyze
   * @param mq     RelMetadataQuery instance
   * @param index  the index of the given column
   * @return if the best relNode exists, delegate to it, else delegate to the original relNode
   */
  def getColumnInterval(subset: RelSubset, mq: RelMetadataQuery, index: Int): ValueInterval = {
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    val rel = Util.first(subset.getBest, subset.getOriginal)
    fmq.getColumnInterval(rel, index)
  }

  /**
   * Catch-all rule when none of the others apply.
   *
   * @param rel   RelNode to analyze
   * @param mq    RelMetadataQuery instance
   * @param index the index of the given column
   * @return Always returns null
   */
  def getColumnInterval(rel: RelNode, mq: RelMetadataQuery, index: Int): ValueInterval = null

}

object FlinkRelMdColumnInterval {

  private val INSTANCE = new FlinkRelMdColumnInterval

  val SOURCE: RelMetadataProvider = ReflectiveRelMetadataProvider.reflectiveSource(
    FlinkMetadata.ColumnInterval.METHOD, INSTANCE)
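
  // A minimal sketch of how this SOURCE is typically registered with the
  // planner's metadata chain (assumption: modeled on Flink's
  // FlinkDefaultRelMetadataProvider; not part of this file):
  //
  //   ChainedRelMetadataProvider.of(ImmutableList.of(
  //     FlinkRelMdColumnInterval.SOURCE,
  //     // ...other Flink metadata handlers...
  //     DefaultRelMetadataProvider.INSTANCE))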

}



