package org.apache.flink.table.runtime.rank
import java.lang.{Integer => JInt}
import java.util
import java.util.function.Supplier
import java.util.{Comparator, Collection => JCollection, HashMap => JHashMap, Map => JMap}
import org.apache.flink.api.common.state.MapStateDescriptor
import org.apache.flink.api.java.tuple.Tuple2
import org.apache.flink.api.common.typeinfo.Types
import org.apache.flink.api.java.functions.KeySelector
import org.apache.flink.api.java.typeutils.TupleTypeInfo
import org.apache.flink.runtime.state.keyed.KeyedMapState
import org.apache.flink.runtime.state.{FunctionInitializationContext, FunctionSnapshotContext}
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction
import org.apache.flink.table.api.TableConfig
import org.apache.flink.table.codegen.{Compiler, GeneratedSorter}
import org.apache.flink.table.plan.util.RankRange
import org.apache.flink.table.dataformat.BaseRow
import org.apache.flink.table.runtime.aggregate.CollectionBaseRowComparator
import org.apache.flink.table.runtime.functions.ExecutionContext
import org.apache.flink.table.runtime.functions.ProcessFunction.OnTimerContext
import org.apache.flink.table.runtime.sort.RecordComparator
import org.apache.flink.table.typeutils.BaseRowTypeInfo
import org.apache.flink.table.util.LRUMap.RemovalListener
import org.apache.flink.table.util.{LRUMap, Logging}
import org.apache.flink.util.Collector
/**
 * Base class for rank functions that keep, for every partition key, a sorted map and a
 * row-key map on the heap next to a keyed map state holding the same records.
 *
 * @param inputRowType type of the input rows; used as the record part of the state value
 * @param rowKeyType type of the row key; its field types define the key type of the map state
 * @param gSorter generated sorter whose comparator is compiled in `open` to order sort keys
 * @param sortKeySelector extracts the sort key from a record
 * @param outputArity not referenced in the code shown here; presumably consumed by the
 *                    parent class (TODO confirm)
 * @param rankRange not referenced in the code shown here; presumably consumed by the
 *                  parent class (TODO confirm)
 * @param cacheSize cache budget; divided by `getMaxSortMapSize` in `open` to obtain the
 *                  number of partition keys kept in the LRU cache
 * @param generateRetraction not referenced in the code shown here; presumably consumed by
 *                           the parent class (TODO confirm)
 * @param tableConfig not referenced in the code shown here; presumably consumed by the
 *                    parent class (TODO confirm)
 */
abstract class AbstractUpdateRankFunction(
inputRowType: BaseRowTypeInfo,
rowKeyType: BaseRowTypeInfo,
gSorter: GeneratedSorter,
sortKeySelector: KeySelector[BaseRow, BaseRow],
outputArity: Int,
rankRange: RankRange,
cacheSize: Long,
generateRetraction: Boolean,
tableConfig: TableConfig)
extends AbstractRankFunction(
with CheckpointedFunction
with Compiler[RecordComparator]
with Logging {
// a map state stores mapping from row key to record which is in topN
// in tuple2, f0 is the record row, f1 is the index in the list of the same sort_key
// the f1 is used to preserve the record order in the same sort_key
// the state is additionally scoped by the partition key: reads and writes pass it explicitly
protected var dataState: KeyedMapState[BaseRow, BaseRow, Tuple2[BaseRow, JInt]] = _
// a sorted map stores mapping from sort key to rowkey list
// `sortedMap` is the instance of the partition key being processed (set by initHeapStates),
// `kvSortedMap` caches one such instance per partition key
protected var sortedMap: SortedMap[BaseRow] = _
protected var kvSortedMap: JMap[BaseRow, SortedMap[BaseRow]] = _
// a HashMap stores mapping from rowkey to record, a heap mirror to dataState
// `rowKeyMap` is the instance of the partition key being processed (set by initHeapStates),
// `kvRowKeyMap` is the LRU cache holding one such instance per partition key
protected var rowKeyMap: JMap[BaseRow, RankRow] = _
protected var kvRowKeyMap: LRUMap[BaseRow, JMap[BaseRow, RankRow]] = _
// comparator over sort keys, built in `open` from the compiled generated sorter
protected var sortKeyComparator: Comparator[BaseRow] = _
/**
 * Sets up the per-partition caches, obtains the keyed map state, registers a metric and
 * compiles the generated sorter into `sortKeyComparator`.
 *
 * @param ctx execution context from which the keyed map state is obtained
 */
override def open(ctx: ExecutionContext): Unit = {
// number of partition keys to cache: cacheSize / getMaxSortMapSize, but never below 1
val lruCacheSize: Int = Math.max(1, (cacheSize / getMaxSortMapSize).toInt)
// make sure the cached map is in a fixed size, avoid OOM
kvSortedMap = new JHashMap[BaseRow, SortedMap[BaseRow]](
kvRowKeyMap = new LRUMap[BaseRow, JMap[BaseRow, RankRow]](
new CacheRemovalListener())
LOG.info("Top{} operator is using LRU caches key-size: {}", getMaxSortMapSize, lruCacheSize)
// state value is a tuple of (input row, int); the int is the inner rank of the record
val valueTypeInfo = new TupleTypeInfo[Tuple2[BaseRow, JInt]](inputRowType, Types.INT)
val mapStateDescriptor = new MapStateDescriptor[BaseRow, Tuple2[BaseRow, JInt]](
new BaseRowTypeInfo(rowKeyType.getFieldTypes: _*)
dataState = ctx.getKeyedMapState(mapStateDescriptor)
// metrics
registerMetric(kvSortedMap.size() * getMaxSortMapSize)
// compile the generated comparator source and wrap it into the sort-key comparator
val name = gSorter.comparator.name
val code = gSorter.comparator.code
LOG.debug(s"Compiling Sorter: $name \n\n Code:\n$code")
val clazz = compile(getRuntimeContext.getUserCodeClassLoader, name, code)
// the source string is dropped once compiled (presumably to release memory - confirm)
gSorter.comparator.code = null
LOG.debug("Instantiating Sorter.")
val comparator = clazz.newInstance()
comparator.init(gSorter.serializers, gSorter.comparators)
sortKeyComparator = new CollectionBaseRowComparator(comparator)
/**
 * Points `sortedMap` and `rowKeyMap` at the heap structures of the current partition key.
 *
 * Both are looked up in the caches first. When no sorted map is cached, fresh structures
 * are created, registered in both caches and refilled from `dataState`; the inner rank
 * stored with each record is used to restore the order of row keys sharing a sort key.
 */
def initHeapStates(): Unit = {
// every call counts as one cache request
requestCount += 1
val partitionKey = executionContext.currentKey()
sortedMap = kvSortedMap.get(partitionKey)
rowKeyMap = kvRowKeyMap.get(partitionKey)
if (sortedMap == null) {
// nothing cached for this partition key: start from empty structures
sortedMap = new SortedMap(sortKeyComparator, LinkedHashSetSupplier)
rowKeyMap = new JHashMap[BaseRow, RankRow]()
kvSortedMap.put(partitionKey, sortedMap)
kvRowKeyMap.put(partitionKey, rowKeyMap)
// restore sorted map
val iter = dataState.iterator(partitionKey)
if (iter != null) {
// a temp map associate sort key to tuple2
// (sort key -> TreeMap ordered by inner rank -> row key)
val tempSortedMap = new JHashMap[BaseRow, util.TreeMap[Int, BaseRow]]()
while (iter.hasNext) {
val entry = iter.next()
val rowkey = entry.getKey
val recordAndInnerRank = entry.getValue
val record = recordAndInnerRank.f0
val innerRank = recordAndInnerRank.f1
// rows loaded from state are in sync with it, hence not dirty
rowKeyMap.put(rowkey, RankRow(record, innerRank, dirty = false))
// insert into temp sort map to preserve the record order in the same sort key
val sortKey = sortKeySelector.getKey(record)
var treeMap = tempSortedMap.get(sortKey)
if (treeMap == null) {
treeMap = new util.TreeMap[Int, BaseRow]()
tempSortedMap.put(sortKey, treeMap)
treeMap.put(innerRank, rowkey)
// build sorted map from the temp map
val tempIter = tempSortedMap.entrySet().iterator()
while (tempIter.hasNext) {
val entry = tempIter.next()
val sortKey = entry.getKey
val treeMap = entry.getValue
// TreeMap iteration is in ascending inner-rank order
val treeMapIter = treeMap.entrySet().iterator()
while (treeMapIter.hasNext) {
val treeMapEntry = treeMapIter.next()
val innerRank = treeMapEntry.getKey
val recordRowKey = treeMapEntry.getValue
// the value returned by put is expected to equal the stored inner rank
val size = sortedMap.put(sortKey, recordRowKey)
if (innerRank != size) {
LOG.warn("Failed to build sorted map from state, this may result in wrong result." +
s" The sort key is $sortKey, partition key is $partitionKey," +
s" treeMap is $treeMap. The expected inner rank is $innerRank," +
s" but current size is $size")
// NOTE(review): presumably the cache-hit branch of the `sortedMap == null` check - confirm
} else {
hitCount += 1
/**
 * Renumbers the inner ranks of the row keys stored under `oldSortKey`, counting from 1 in
 * the iteration order of the collection, and marks every row whose rank changed as dirty.
 * Does nothing when `sortedMap` has no collection for that sort key.
 *
 * @param oldSortKey sort key whose row-key collection is renumbered
 */
def updateInnerRank(oldSortKey: BaseRow): Unit = {
val list = sortedMap.get(oldSortKey)
if (list != null) {
val iter = list.iterator()
// inner ranks are 1-based
var innerRank = 1
while (iter.hasNext) {
val rowkey = iter.next()
// assumes every row key in sortedMap is also present in rowKeyMap - TODO confirm
val row = rowKeyMap.get(rowkey)
if (row.innerRank != innerRank) {
row.innerRank = innerRank
// flag the row so that synchronizeState writes it back to dataState
row.dirty = true
innerRank += 1
/**
 * Timer callback. When `needToCleanupState(timestamp)` holds, the current partition key is
 * read for the cache cleanup.
 *
 * @param timestamp timestamp handed to `needToCleanupState`
 * @param ctx timer context; not used in the code shown here
 * @param out output collector; not used in the code shown here
 */
override def onTimer(
timestamp: Long,
ctx: OnTimerContext,
out: Collector[BaseRow]): Unit = {
if (needToCleanupState(timestamp)) {
val partitionKey = executionContext.currentKey()
// cleanup cache
// NOTE(review): presumably the cached entries of partitionKey are dropped here - confirm
/**
 * Checkpoint hook: walks every entry of the `kvRowKeyMap` cache, reading the partition key
 * and the row-key map cached for it.
 *
 * @param ctx snapshot context; not used in the code shown here
 */
override def snapshotState(ctx: FunctionSnapshotContext): Unit = {
val iter = kvRowKeyMap.entrySet().iterator()
while (iter.hasNext) {
val entry = iter.next()
val partitionKey = entry.getKey
val currentRowKeyMap = entry.getValue
// NOTE(review): presumably the dirty rows of each cached partition are flushed to
// dataState from here (see synchronizeState) - confirm
/**
 * Writes every dirty row of `curRowKeyMap` to `dataState` under the current key of the
 * execution context and clears its dirty flag; rows that are not dirty are left untouched.
 *
 * @param curRowKeyMap row-key map whose dirty rows are written to state
 */
def synchronizeState(curRowKeyMap: JMap[BaseRow, RankRow]): Unit = {
// the state write is scoped by whatever key the execution context currently holds
val currentKey = executionContext.currentKey()
val iter = curRowKeyMap.entrySet().iterator()
while (iter.hasNext) {
val entry = iter.next()
val key = entry.getKey
val rankRow = entry.getValue
if (rankRow.dirty) {
// should update state
// stored value is the tuple (record, inner rank)
dataState.add(currentKey, key, Tuple2.of(rankRow.row, rankRow.innerRank))
rankRow.dirty = false
/**
 * Intentionally does nothing.
 *
 * @param ctx initialization context; unused
 */
override def initializeState(ctx: FunctionInitializationContext): Unit = {
// nothing to do
/**
 * Supplier passed to `SortedMap` in `initHeapStates`; every call hands out a new, empty
 * `java.util.LinkedHashSet`, which iterates in insertion order.
 */
protected object LinkedHashSetSupplier extends Supplier[JCollection[BaseRow]] {
override def get(): JCollection[BaseRow] = new util.LinkedHashSet[BaseRow]()
/**
 * Removal listener handed to the `kvRowKeyMap` LRU cache in `open`.
 */
protected class CacheRemovalListener extends RemovalListener[BaseRow, JMap[BaseRow, RankRow]] {
/**
 * Called with the entry being removed from the cache; reads the current key of the
 * execution context together with the removed partition key and its row-key map.
 *
 * @param eldest cache entry (partition key and its row-key map) being removed
 */
override def onRemoval(eldest: JMap.Entry[BaseRow, JMap[BaseRow, RankRow]]): Unit = {
// NOTE(review): presumably kept so the current key can be restored afterwards - confirm
val previousKey = executionContext.currentKey()
val partitionKey = eldest.getKey
val currentRowKeyMap = eldest.getValue
/**
 * Mutable heap entry for one record.
 *
 * @param row the record
 * @param innerRank position of the record among the records sharing its sort key
 * @param dirty true when the entry still has to be written to the state
 */
class RankRow(var row: BaseRow, var innerRank: Int, var dirty: Boolean)
/** Companion providing a factory so that a `RankRow` can be created without `new`. */
object RankRow {
/** Creates a [[RankRow]] from the given record, inner rank and dirty flag. */
def apply(row: BaseRow, innerRank: Int, dirty: Boolean): RankRow = {
new RankRow(row, innerRank, dirty)
