/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.flink.table.runtime.aggregate

import java.util.{ArrayList => JArrayList, List => JList}
import java.lang.{Long => JLong}
import org.apache.flink.api.common.state.{MapStateDescriptor, ValueStateDescriptor}
import org.apache.flink.api.java.typeutils.ListTypeInfo
import org.apache.flink.runtime.state.keyed.{KeyedMapState, KeyedValueState}
import org.apache.flink.table.api.types.{DataTypes, InternalType, TypeConverters}
import org.apache.flink.table.api.{TableConfig, Types}
import org.apache.flink.table.codegen.GeneratedAggsHandleFunction
import org.apache.flink.table.dataformat.{BaseRow, JoinedRow}
import org.apache.flink.table.runtime.functions.ProcessFunction.{Context, OnTimerContext}
import org.apache.flink.table.runtime.functions.{AggsHandleFunction, ExecutionContext}
import org.apache.flink.table.typeutils.{BaseRowTypeInfo, TypeUtils}
import org.apache.flink.table.util.Logging
import org.apache.flink.util.{Collector, Preconditions}

/**
  * Process function for the ROWS clause of an event-time bounded OVER window.
  * Per key it maintains the aggregate over the `precedingOffset` most recent
  * rows: each row is buffered until the watermark passes its timestamp, then it
  * is accumulated and, once the bound is reached, the oldest buffered row is
  * retracted.
  *
  * @param genAggsHandler  generated aggregate helper function
  * @param accTypes        accumulator types of the aggregation
  * @param inputFieldTypes field types of the input row
  * @param precedingOffset preceding offset: the number of rows covered by the
  *                        window, including the current row
  * @param rowTimeIdx      index of the event-time attribute in the input row
  * @param tableConfig     table configuration, e.g. for state retention
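  *
  * A minimal construction sketch (illustrative only; in practice the planner
  * generates the aggs handler and instantiates this class, so `genFunction`
  * and the literal arguments below are assumptions):
  * {{{
  *   val overFunction = new RowTimeBoundedRowsOver(
  *     genFunction,                              // code-generated aggs handler
  *     accTypes = Seq(DataTypes.LONG),           // e.g. one SUM accumulator
  *     inputFieldTypes = Seq(DataTypes.STRING, DataTypes.LONG),
  *     precedingOffset = 2,                      // ROWS BETWEEN 1 PRECEDING AND CURRENT ROW
  *     rowTimeIdx = 1,                           // second field is the row time
  *     tableConfig = new TableConfig)
  * }}}
  */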
class RowTimeBoundedRowsOver(
    genAggsHandler: GeneratedAggsHandleFunction,
    accTypes: Seq[InternalType],
    inputFieldTypes: Seq[InternalType],
    precedingOffset: Long,
    rowTimeIdx: Int,
    tableConfig: TableConfig)
  extends ProcessFunctionWithCleanupState[BaseRow, BaseRow](tableConfig)
  with Logging {

  // precedingOffset is a primitive Long, so a null check would be a no-op;
  // validate that the bound is positive instead
  Preconditions.checkArgument(precedingOffset > 0)

  private var output: JoinedRow = _

  // the state which keeps the last triggering timestamp
  private var lastTriggeringTsState: KeyedValueState[BaseRow, JLong] = _

  // the state which keeps the count of processed rows (capped at precedingOffset)
  private var counterState: KeyedValueState[BaseRow, JLong] = _

  // the state which is used to materialize the accumulators for incremental calculation
  private var accState: KeyedValueState[BaseRow, BaseRow] = _

  // the state which keeps all the data that has not expired.
  // The map key is the timestamp; the map value is a list containing all rows
  // that arrived with that timestamp.
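  // For example, for one key the map may look like
  //   { 1000L -> [row1, row2], 1003L -> [row3] }
  // meaning two rows arrived with timestamp 1000 and one with timestamp 1003.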
  private var inputState: KeyedMapState[BaseRow, JLong, JList[BaseRow]] = _

  private var function: AggsHandleFunction = _

  override def open(ctx: ExecutionContext): Unit = {
    super.open(ctx)
    LOG.debug(s"Compiling AggregateHelper: ${genAggsHandler.name} \n\n" +
        s"Code:\n${genAggsHandler.code}")
    function = genAggsHandler.newInstance(ctx.getRuntimeContext.getUserCodeClassLoader)
    function.open(ctx)

    output = new JoinedRow()

    val lastTriggeringTsDescriptor = new ValueStateDescriptor[JLong](
      "lastTriggeringTsState",
      Types.LONG)
    lastTriggeringTsState = ctx.getKeyedValueState(lastTriggeringTsDescriptor)

    val dataCountStateDescriptor = new ValueStateDescriptor[JLong](
      "processedCountState",
      Types.LONG)
    counterState = ctx.getKeyedValueState(dataCountStateDescriptor)

    val accTypeInfo = new BaseRowTypeInfo(
      accTypes.map(TypeConverters.createExternalTypeInfoFromDataType): _*)
    val accStateDesc = new ValueStateDescriptor[BaseRow]("accState", accTypeInfo)
    accState = ctx.getKeyedValueState(accStateDesc)

    // input elements are all binary rows, since they come from the network
    val inputType = new BaseRowTypeInfo(
      inputFieldTypes.map(TypeConverters.createExternalTypeInfoFromDataType): _*)
    val rowListTypeInfo = new ListTypeInfo[BaseRow](inputType)
    val inputStateDesc = new MapStateDescriptor[JLong, JList[BaseRow]](
      "inputState",
      Types.LONG,
      rowListTypeInfo)
    inputState = ctx.getKeyedMapState(inputStateDesc)

    initCleanupTimeState("RowTimeBoundedRowsOverCleanupTime")
  }

  override def processElement(
      input: BaseRow,
      ctx: Context,
      out: Collector[BaseRow]): Unit = {

    val currentKey = executionContext.currentKey()
    // register state-cleanup timer
    registerProcessingCleanupTimer(ctx, ctx.timerService().currentProcessingTime())

    // the timestamp that triggers the window calculation for this row
    val triggeringTs = input.getLong(rowTimeIdx)

    var lastTriggeringTs = lastTriggeringTsState.get(currentKey)
    if (lastTriggeringTs == null) {
      lastTriggeringTs = 0L
    }

    // drop late data: a row is late if its timestamp is not larger than the
    // last triggering timestamp; otherwise buffer it and make sure an
    // event-time timer is registered for its timestamp
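    // e.g. three rows with row time 1000 for the same key end up buffered as
    //   inputState(key) = { 1000L -> [r1, r2, r3] }
    // with a single event-time timer registered for timestamp 1000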
    if (triggeringTs > lastTriggeringTs) {
      val data = inputState.get(currentKey, triggeringTs)
      if (null != data) {
        data.add(input)
        inputState.add(currentKey, triggeringTs, data)
      } else {
        val data = new JArrayList[BaseRow]
        data.add(input)
        inputState.add(currentKey, triggeringTs, data)
        // register event time timer
        ctx.timerService.registerEventTimeTimer(triggeringTs)
      }
    }
  }

  override def onTimer(
      timestamp: Long,
      ctx: OnTimerContext,
      out: Collector[BaseRow]): Unit = {

    val currentKey = executionContext.currentKey()

    if (isProcessingTimeTimer(ctx)) {
      if (needToCleanupState(timestamp)) {

        val keysIt = inputState.iterator(currentKey)
        var lastProcessedTime = lastTriggeringTsState.get(currentKey)
        if (lastProcessedTime == null) {
          lastProcessedTime = 0L
        }

        // is there data left that has not been processed yet?
        var noRecordsToProcess = true
        while (keysIt.hasNext && noRecordsToProcess) {
          if (keysIt.next().getKey > lastProcessedTime) {
            noRecordsToProcess = false
          }
        }

        if (noRecordsToProcess) {
          // We clean the state
          cleanupState(inputState, accState, counterState, lastTriggeringTsState)
          function.cleanup()
        } else {
          // There are records left to process because a watermark has not been
          // received yet. This only happens if the input stream has stopped,
          // so we leave the state as it is and schedule a new cleanup timer.
          registerProcessingCleanupTimer(ctx, ctx.timerService().currentProcessingTime())
        }
      }
      return
    }

    // the event-time timer fired: the watermark has passed `timestamp`, so all
    // rows with this timestamp have arrived; fetch them from state for the calculation
    val inputs: JList[BaseRow] = inputState.get(currentKey, timestamp)

    if (null != inputs) {

      var dataCount = counterState.get(currentKey)
      if (dataCount == null) {
        dataCount = 0L
      }

      var accumulators = accState.get(currentKey)
      if (accumulators == null) {
        accumulators = function.createAccumulators()
      }
      // restore the accumulators into the aggregate handler first
      function.setAccumulators(accumulators)

      var retractList: JList[BaseRow] = null
      var retractTs: Long = Long.MaxValue
      var retractCnt: Int = 0
      var i = 0
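
      // ROWS semantics: while fewer than `precedingOffset` rows have been
      // accumulated, the window just grows; after that, every newly
      // accumulated row is paired with the retraction of the oldest buffered
      // row, so the accumulator always covers the `precedingOffset` most
      // recent rows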

      while (i < inputs.size) {
        val input = inputs.get(i)
        var retractRow: BaseRow = null
        if (dataCount >= precedingOffset) {
          if (null == retractList) {
            // find the smallest timestamp
            retractTs = Long.MaxValue
            val dataTimestampIt = inputState.iterator(currentKey)
            while (dataTimestampIt.hasNext) {
              val data = dataTimestampIt.next
              val dataTs = data.getKey
              if (dataTs < retractTs) {
                retractTs = dataTs
                // get the oldest rows to retract them
                retractList = data.getValue
              }
            }
          }

          retractRow = retractList.get(retractCnt)
          retractCnt += 1

          // if all rows of the oldest timestamp have been retracted,
          // remove the entry from the state
          if (retractList.size == retractCnt) {
            inputState.remove(currentKey, retractTs)
            retractList = null
            retractCnt = 0
          }
        } else {
          dataCount += 1
        }

        // retract old row from accumulators
        if (null != retractRow) {
          function.retract(retractRow)
        }

        // accumulate current row
        function.accumulate(input)

        // prepare output row
        output.replace(input, function.getValue)
        out.collect(output)

        i += 1
      }

      // update all states: trim the partially retracted row list, then persist
      // the counter and the accumulators
      if (inputState.contains(currentKey, retractTs)) {
        if (retractCnt > 0) {
          retractList.subList(0, retractCnt).clear()
          inputState.add(currentKey, retractTs, retractList)
        }
      }
      counterState.put(currentKey, dataCount)
      // update the value of accumulators for future incremental computation
      accumulators = function.getAccumulators
      accState.put(currentKey, accumulators)
    }

    lastTriggeringTsState.put(currentKey, timestamp)

    // update cleanup timer
    registerProcessingCleanupTimer(ctx, ctx.timerService().currentProcessingTime())
  }

  override def close(): Unit = {
    function.close()
  }
}
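
// Worked example (illustrative): with precedingOffset = 2 and a SUM aggregate
// over a field v, the rows (ts=1, v=10), (ts=2, v=20), (ts=3, v=30) of one key
// produce, as the watermark passes each timestamp:
//   ts=1 -> SUM = 10           (window: row1)
//   ts=2 -> SUM = 10 + 20 = 30 (window: row1, row2)
//   ts=3 -> SUM = 20 + 30 = 50 (window: row2, row3; row1 retracted)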