All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.flink.table.runtime.rank.RetractRankFunction.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.flink.table.runtime.rank

import java.lang.Long
import java.util
import java.util.{List => JList}
import org.apache.calcite.sql.SqlKind
import org.apache.flink.api.common.functions.Comparator
import org.apache.flink.api.common.state.{MapStateDescriptor, ValueStateDescriptor}
import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.api.java.functions.KeySelector
import org.apache.flink.api.java.typeutils.{ListTypeInfo, SortedMapTypeInfo}
import org.apache.flink.runtime.state.keyed.{KeyedMapState, KeyedValueState}
import org.apache.flink.table.api.TableConfig
import org.apache.flink.table.codegen.{Compiler, GeneratedSorter}
import org.apache.flink.table.plan.util.RankRange
import org.apache.flink.table.dataformat.BaseRow
import org.apache.flink.table.dataformat.util.BaseRowUtil
import org.apache.flink.table.runtime.aggregate.LazyBaseRowComparator
import org.apache.flink.table.runtime.functions.ExecutionContext
import org.apache.flink.table.runtime.functions.ProcessFunction.Context
import org.apache.flink.table.runtime.sort.RecordComparator
import org.apache.flink.table.typeutils.BaseRowTypeInfo
import org.apache.flink.table.util.{Logging, StateUtil}
import org.apache.flink.util.Collector

/**
  * Rank function for inputs that contain both accumulate and retract messages.
  *
  * Two pieces of keyed state are maintained per partition key:
  *  - `dataState`: a map from sort key to the list of records having that sort key;
  *  - `treeMap`: a sorted map from sort key to the number of records having that sort key.
  *
  * Both are updated for every input row and must stay consistent with each other: the count
  * stored for a sort key is expected to equal the size of the record list stored for it.
  *
  * Only `ROW_NUMBER` is implemented. `RANK` and `DENSE_RANK` are left as `???` and therefore
  * throw [[scala.NotImplementedError]] when a row is processed.
  *
  * @param inputRowType       type of the input rows; also used as element type of the stored lists
  * @param sortKeyType        type of the sort key produced by `sortKeySelector`
  * @param gSorter            generated sorter providing the sort key comparator; the reference is
  *                           dropped (set to null) at the end of `open`
  * @param sortKeySelector    extracts the sort key from an input row
  * @param outputArity        forwarded to [[AbstractRankFunction]]
  * @param rankKind           which rank function to compute
  * @param rankRange          forwarded to [[AbstractRankFunction]]
  * @param generateRetraction forwarded to [[AbstractRankFunction]]
  * @param tableConfig        forwarded to [[AbstractRankFunction]]
  */
class RetractRankFunction(
    inputRowType: BaseRowTypeInfo,
    sortKeyType: BaseRowTypeInfo,
    var gSorter: GeneratedSorter,
    sortKeySelector: KeySelector[BaseRow, BaseRow],
    outputArity: Int,
    rankKind: SqlKind,
    rankRange: RankRange,
    generateRetraction: Boolean,
    tableConfig: TableConfig)
  extends AbstractRankFunction(
    tableConfig,
    rankRange,
    inputRowType,
    inputRowType.getArity,
    outputArity,
    generateRetraction)
  with Compiler[RecordComparator]
  with Logging {

  // flag to skip records with non-exist error instead to fail, true by default.
  private val lenient: Boolean = true

  // a map state stores mapping from sort key to records list
  @transient
  private var dataState: KeyedMapState[BaseRow, BaseRow, JList[BaseRow]] = _

  // a sorted map stores mapping from sort key to records count
  @transient
  private var treeMap: KeyedValueState[BaseRow, util.SortedMap[BaseRow, Long]] = _

  // Built eagerly from gSorter so that gSorter itself can be released in open().
  private val sortKeyComparator: Comparator[BaseRow] = new LazyBaseRowComparator(
    gSorter.comparator.name, gSorter.comparator.code, gSorter.serializers, gSorter.comparators)

  /**
    * Registers the two keyed states ("data-state" and "sorted-map") and releases the
    * generated sorter, which is no longer needed once the comparator has been created.
    */
  override def open(ctx: ExecutionContext): Unit = {
    super.open(ctx)

    val recordListTypeInfo = new ListTypeInfo[BaseRow](inputRowType)
    val mapStateDescriptor = new MapStateDescriptor[BaseRow, JList[BaseRow]](
      "data-state",
      new BaseRowTypeInfo(sortKeyType.getFieldTypes: _*),
      recordListTypeInfo)
    dataState = ctx.getKeyedMapState(mapStateDescriptor)

    val valueStateDescriptor = new ValueStateDescriptor[util.SortedMap[BaseRow, Long]](
      "sorted-map",
      new SortedMapTypeInfo(
        new BaseRowTypeInfo(sortKeyType.getFieldTypes: _*),
        BasicTypeInfo.LONG_TYPE_INFO,
        sortKeyComparator))
    treeMap = ctx.getKeyedValueState(valueStateDescriptor)

    gSorter = null
  }

  /**
    * Processes one input row.
    *
    * For an accumulate message the count of its sort key is incremented, the resulting rank
    * changes are emitted and the row is appended to the record list of its sort key.
    *
    * For a retract message the rank changes are emitted first, then the count of its sort key
    * is decremented (the entry is dropped when it reaches zero) and the row is removed from the
    * record list of its sort key.
    *
    * @throws RuntimeException if a retract message refers to a sort key that is unknown while
    *                          the sorted map is not empty
    */
  override def processElement(
    inputBaseRow: BaseRow,
    ctx: Context,
    out: Collector[BaseRow]): Unit = {

    initRankEnd(inputBaseRow)

    val currentKey = executionContext.currentKey()
    var sortedMap = treeMap.get(currentKey)
    if (sortedMap == null) {
      sortedMap = new util.TreeMap[BaseRow, Long](sortKeyComparator)
    }

    val sortKey = sortKeySelector.getKey(inputBaseRow)

    if (BaseRowUtil.isAccumulateMsg(inputBaseRow)) {
      // update sortedMap
      if (sortedMap.containsKey(sortKey)) {
        sortedMap.put(sortKey, sortedMap.get(sortKey) + 1)
      } else {
        sortedMap.put(sortKey, 1L)
      }

      // emit
      rankKind match {
        case SqlKind.ROW_NUMBER =>
          emitRecordsWithRowNumber(sortedMap, sortKey, inputBaseRow, out)

        case SqlKind.RANK => ???
        case SqlKind.DENSE_RANK => ???
      }

      // update data state
      var inputs = dataState.get(currentKey, sortKey)
      if (inputs == null) {
        // the sort key is never seen
        inputs = new util.ArrayList[BaseRow]()
      }
      inputs.add(inputBaseRow)
      dataState.add(currentKey, sortKey, inputs)
    } else {
      // retract input: emit updates first
      val removedFromDataState: Boolean = rankKind match {
        case SqlKind.ROW_NUMBER =>
          retractRecordWithRowNumber(sortedMap, sortKey, inputBaseRow, out)

        case SqlKind.RANK => ???
        case SqlKind.DENSE_RANK => ???
      }

      // and then update sortedMap
      if (sortedMap.containsKey(sortKey)) {
        val count = sortedMap.get(sortKey) - 1
        if (count == 0) {
          sortedMap.remove(sortKey)
        } else {
          sortedMap.put(sortKey, count)
        }
      } else if (sortedMap.isEmpty) {
        handleClearedState()
      } else {
        throw new RuntimeException(
          s"Can not retract a non-existent record: ${inputBaseRow.toString}. " +
            s"This should never happen.")
      }

      // retractRecordWithRowNumber(...) stops scanning as soon as isInRankEnd(...) is false,
      // so a row located after that point has not been removed from the data state yet.
      // Remove it here; otherwise the record list would no longer match the count above.
      if (!removedFromDataState) {
        removeFromDataState(sortKey, inputBaseRow)
      }
    }

    treeMap.put(currentKey, sortedMap)
  }

  /**
    * Reacts to a record list that is missing from the data state: logs a warning when
    * `lenient` is set, fails otherwise.
    */
  private def handleClearedState(): Unit = {
    if (lenient) {
      LOG.warn(StateUtil.STATE_CLEARED_WARN_MSG)
    } else {
      throw new RuntimeException(StateUtil.STATE_CLEARED_WARN_MSG)
    }
  }

  /**
    * Removes the first stored record under `sortKey` that equals `inputRow` (ignoring the
    * header) from the data state of the current key. Does nothing if no such record is stored.
    */
  private def removeFromDataState(sortKey: BaseRow, inputRow: BaseRow): Unit = {
    val currentKey = executionContext.currentKey()
    val inputs = dataState.get(currentKey, sortKey)
    if (inputs != null) {
      val inputIter = inputs.iterator()
      var removed = false
      while (!removed && inputIter.hasNext) {
        if (equaliser.equalsWithoutHeader(inputIter.next(), inputRow)) {
          inputIter.remove()
          removed = true
        }
      }
      if (removed) {
        if (inputs.isEmpty) {
          dataState.remove(currentKey, sortKey)
        } else {
          dataState.add(currentKey, sortKey, inputs)
        }
      }
    }
  }

  // ------------- ROW_NUMBER-------------------------------

  /**
    * Emits the changes caused by retracting `inputRow` and removes it from the data state.
    *
    * The sorted map is walked in order while `isInRankEnd(curRank)` holds. Entries before the
    * retracted row only advance the rank counter. Once the row is found it is deleted, and
    * every following record is retracted at its old rank and collected at its new rank, which
    * is one lower.
    *
    * @return true if the row was found and removed from the data state; false if the scan
    *         ended without removing it, in which case the caller has to remove it
    */
  private def retractRecordWithRowNumber(
    sortedMap: util.SortedMap[BaseRow, Long],
    sortKey: BaseRow,
    inputRow: BaseRow,
    out: Collector[BaseRow]): Boolean = {

    val iterator = sortedMap.entrySet().iterator()
    var curRank = 0L
    var rowRemoved = false
    val currentKey = executionContext.currentKey()
    while (iterator.hasNext && isInRankEnd(curRank)) {
      val entry = iterator.next()
      val key = entry.getKey
      if (!rowRemoved && key.equals(sortKey)) {
        val inputs = dataState.get(currentKey, key)
        if (inputs == null) {
          // Skip the data if its state is cleared because of state ttl.
          handleClearedState()
        } else {
          val inputIter = inputs.iterator()
          while (inputIter.hasNext && isInRankEnd(curRank)) {
            curRank += 1
            val prevRow = inputIter.next()
            if (!rowRemoved && equaliser.equalsWithoutHeader(prevRow, inputRow)) {
              delete(out, prevRow, curRank)
              // the deleted row no longer occupies a rank
              curRank -= 1
              rowRemoved = true
              inputIter.remove()
            } else if (rowRemoved) {
              // curRank is the new rank; the old rank was one higher
              retract(out, prevRow, curRank + 1)
              collect(out, prevRow, curRank)
            }
          }
          if (inputs.isEmpty) {
            dataState.remove(currentKey, key)
          } else {
            dataState.add(currentKey, key, inputs)
          }
        }
      } else if (rowRemoved) {
        val inputs = dataState.get(currentKey, key)
        if (inputs == null) {
          // Skip the data if its state is cleared because of state ttl.
          handleClearedState()
        } else {
          var i = 0
          while (i < inputs.size() && isInRankEnd(curRank)) {
            curRank += 1
            val prevRow = inputs.get(i)
            retract(out, prevRow, curRank + 1)
            collect(out, prevRow, curRank)
            i += 1
          }
        }
      } else {
        curRank += entry.getValue
      }
    }

    rowRemoved
  }

  /**
    * Emits the changes caused by accumulating `inputRow`.
    *
    * The sorted map (which already includes the new row in its counts) is walked in order
    * while `isInRankEnd(curRank)` holds. The new row is collected at the rank reached after
    * adding the count of its sort key, i.e. behind the existing records with the same sort
    * key. Every record with a later sort key is retracted at its old rank and collected at its
    * new rank, which is one higher.
    */
  private def emitRecordsWithRowNumber(
      sortedMap: util.SortedMap[BaseRow, Long],
      sortKey: BaseRow,
      inputRow: BaseRow,
      out: Collector[BaseRow]): Unit = {

    val iterator = sortedMap.entrySet().iterator()
    var curRank = 0L
    var rowEmitted = false
    val currentKey = executionContext.currentKey()
    while (iterator.hasNext && isInRankEnd(curRank)) {
      val entry = iterator.next()
      if (!rowEmitted && entry.getKey.equals(sortKey)) {
        curRank += entry.getValue
        collect(out, inputRow, curRank)
        rowEmitted = true
      } else if (rowEmitted) {
        val inputs = dataState.get(currentKey, entry.getKey)
        if (inputs == null) {
          // Skip the data if its state is cleared because of state ttl.
          handleClearedState()
        } else {
          var i = 0
          while (i < inputs.size() && isInRankEnd(curRank)) {
            curRank += 1
            val prevRow = inputs.get(i)
            // curRank is the new rank; the old rank was one lower
            retract(out, prevRow, curRank - 1)
            collect(out, prevRow, curRank)
            i += 1
          }
        }
      } else {
        curRank += entry.getValue
      }
    }
  }

  // just let it go, retract rank has no interest in this
  override def getMaxSortMapSize: scala.Long = 0L

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy