All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.flink.table.codegen.MatchCodeGenerator.scala Maven / Gradle / Ivy

Go to download

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.codegen

import java.lang.{Long => JLong}
import java.util
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rex._
import org.apache.calcite.sql.SqlAggFunction
import org.apache.calcite.sql.fun.SqlStdOperatorTable._
import org.apache.flink.api.common.functions._
import org.apache.flink.api.common.typeinfo.{SqlTimeTypeInfo, TypeInformation}
import org.apache.flink.cep.functions.PatternProcessFunction
import org.apache.flink.cep.pattern.conditions.{IterativeCondition, RichIterativeCondition}
import org.apache.flink.configuration.Configuration
import org.apache.flink.table.api.dataview.DataViewSpec
import org.apache.flink.table.api.{TableConfig, TableException}
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.codegen.CodeGenUtils.{boxedTypeTermForTypeInfo, newName, primitiveDefaultValue, primitiveTypeTermForTypeInfo}
import org.apache.flink.table.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE}
import org.apache.flink.table.codegen.Indenter.toISC
import org.apache.flink.table.functions.UserDefinedAggregateFunction
import org.apache.flink.table.plan.schema.RowSchema
import org.apache.flink.table.runtime.`match`.{IterativeConditionRunner, PatternProcessFunctionRunner}
import org.apache.flink.table.runtime.aggregate.AggregateUtil
import org.apache.flink.table.typeutils.TimeIndicatorTypeInfo
import org.apache.flink.table.util.MatchUtil.{ALL_PATTERN_VARIABLE, AggregationPatternVariableFinder}
import org.apache.flink.table.utils.EncodingUtils
import org.apache.flink.table.catalog.BasicOperatorTable.{MATCH_PROCTIME, MATCH_ROWTIME}
import org.apache.flink.types.Row
import org.apache.flink.util.Collector
import org.apache.flink.util.MathUtils.checkedDownCast

import org.apache.calcite.util.ImmutableBitSet

import scala.collection.JavaConverters._
import scala.collection.mutable

/**
  * A code generator for generating CEP related functions.
  *
  * Aggregates are generated as follows:
  * 1. all aggregate [[RexCall]]s are grouped by corresponding pattern variable
  * 2. even if the same aggregation is used multiple times in an expression
  *    (e.g. SUM(A.price) > SUM(A.price) + 1) it will be calculated once. To do so [[AggBuilder]]
  *    keeps set of already seen different aggregation calls, and reuses the code to access
  *    appropriate field of aggregation result
  * 3. after translating every expression (either in [[generateCondition]] or in
  *    [[generateOneRowPerMatchExpression]]) there will be generated code for
  *       - [[GeneratedFunction]], which will be an inner class
  *       - said [[GeneratedFunction]] will be instantiated in the ctor and opened/closed
  *         in corresponding methods of top level generated classes
  *       - function that transforms input rows (row by row) into aggregate input rows
  *       - function that calculates aggregates for variable, that uses the previous method
  *    The generated code will look similar to this:
  *
  *
  * {{{
  *
  * public class MatchRecognizePatternProcessFunction$175 extends PatternProcessFunction {
  *
  *     // Class used to calculate aggregates for a single pattern variable
  *     public final class AggFunction_variable$115$151 extends GeneratedAggregations {
  *       ...
  *     }
  *
  *     private final AggFunction_variable$115$151 aggregator_variable$115;
  *
  *     public MatchRecognizePatternProcessFunction$175() {
  *       aggregator_variable$115 = new AggFunction_variable$115$151();
  *     }
  *
  *     public void open() {
  *       aggregator_variable$115.open();
  *       ...
  *     }
  *
  *     // Function to transform incoming row into aggregate specific row. It can e.g calculate
  *     // inner expression of said aggregate
  *     private Row transformRowForAgg_variable$115(Row inAgg) {
  *         ...
  *     }
  *
  *     // Function to calculate all aggregates for a single pattern variable
  *     private Row calculateAgg_variable$115(List input) {
  *       Acc accumulator = aggregator_variable$115.createAccumulator();
  *       for (Row row : input) {
  *         aggregator_variable$115.accumulate(accumulator, transformRowForAgg_variable$115(row));
  *       }
  *
  *       return aggregator_variable$115.getResult(accumulator);
  *     }
  *
  *     @Override
  *     public void processMatch(
  *         Map<String, List<Row>> in1,
  *         Context ctx,
  *         Collector c
  *       ) throws Exception {
  *
  *       // Extract list of rows assigned to a single pattern variable
  *       java.util.List patternEvents$130 = (java.util.List) in1.get("A");
  *       ...
  *
  *       // Calculate aggregates
  *       Row aggRow_variable$110$111 = calculateAgg_variable$110(patternEvents$114);
  *
  *       // Every aggregation (e.g SUM(A.price) and AVG(A.price)) will be extracted to a variable
  *       double result$135 = aggRow_variable$126$127.getField(0);
  *       long result$137 = aggRow_variable$126$127.getField(1);
  *
  *       // Result of aggregation will be used in expression evaluation
  *       out.setField(0, result$135);
  *
  *       long result$140 = result$137 * 2;
  *       out.setField(1, result$140);
  *
  *       double result$144 = result$135 + result$137;
  *       out.setField(2, result$144);
  *
  *       c.collect(out);
  *     }
  *
  *     public void close() {
  *       aggregator_variable$115.close();
  *       ...
  *     }
  *
  * }
  * }}}
  *
  * @param config configuration that determines runtime behavior
  * @param patternNames sorted sequence of pattern variables
  * @param input type information about the first input of the Function
  * @param currentPattern if generating condition the name of pattern, which the condition will
  *                       be applied to
  */
class MatchCodeGenerator(
    config: TableConfig,
    input: TypeInformation[_],
    patternNames: Seq[String],
    currentPattern: Option[String] = None)
  extends CodeGenerator(config, false, input){

  /**
    * Pairs the name of a generated list variable (`resultTerm`) with the Java `code` that
    * fills that list with the events assigned to a pattern variable.
    */
  private case class GeneratedPatternList(resultTerm: String, code: String)

  /**
    * Used to assign unique names for list of events per pattern variable name. Those lists
    * are treated as inputs and are needed by input access code.
    */
  private val reusablePatternLists: mutable.HashMap[String, GeneratedPatternList] = mutable
    .HashMap[String, GeneratedPatternList]()

  /**
    * Used to deduplicate aggregations calculation. The deduplication is performed by
    * [[RexNode.toString]]. Those expressions need to be accessible from splits, if any exist.
    */
  private val reusableAggregationExpr = new mutable.HashMap[String, GeneratedExpression]()

  /**
    * Context information used by Pattern reference variable to index rows mapped to it.
    * Indexes element at offset either from beginning or the end based on the value of first.
    */
  private var offset: Int = 0
  private var first : Boolean = false

  /**
    * Flag that tells whether we are generating expressions inside an aggregate. It determines
    * how the input row is accessed.
    */
  private var isWithinAggExprState: Boolean = false

  /**
    * Name of term in function used to transform input row into aggregate input row.
    */
  private val inputAggRowTerm = "inAgg"

  /** Term for row for key extraction */
  private val keyRowTerm = "keyRow"

  /** Term for list of all pattern names */
  private val patternNamesTerm = "patternNames"

  /**
    * Used to collect all aggregates per pattern variable.
    */
  private val aggregatesPerVariable = new mutable.HashMap[String, AggBuilder]

  /**
    * Installs a new indexing context for pattern variable references. Meant to be called
    * while resolving the logical offset functions LAST/FIRST.
    *
    * @param first  true to index from the beginning of the rows, false to index from the end
    * @param offset distance from the chosen end
    */
  private def updateOffsets(first: Boolean, offset: Int): Unit = {
    this.offset = offset
    this.first = first
  }

  /** Restores the default indexing context of a pattern variable: offset 0, not from start. */
  private def resetOffsets(): Unit = updateOffsets(first = false, offset = 0)

  /**
    * Joins, one per line, the code snippets of all pattern lists registered in
    * `reusablePatternLists`.
    */
  private def reusePatternLists(): String = {
    val listSnippets = for (patternList <- reusablePatternLists.values) yield patternList.code
    listSnippets.mkString("\n")
  }

  /**
    * Registers a member statement declaring a Java `String[]` with all pattern names. Each
    * name is passed through `EncodingUtils.escapeJava` before being emitted as a literal.
    */
  private def addReusablePatternNames() : Unit = {
    val quotedNames = patternNames
      .map(name => "\"" + EncodingUtils.escapeJava(name) + "\"")
      .mkString(", ")
    reusableMemberStatements.add(
      s"private String[] $patternNamesTerm = new String[] { $quotedNames };")
  }

  /**
    * Translates a single pattern definition from the DEFINE clause into a code generated
    * [[IterativeCondition]] and wraps it into an [[IterativeConditionRunner]].
    *
    * @param patternDefinition pattern definition as defined in DEFINE clause
    * @return a code generated condition that can be used in constructing a
    *         [[org.apache.flink.cep.pattern.Pattern]]
    */
  def generateIterativeCondition(patternDefinition: RexNode): IterativeConditionRunner = {
    val conditionExpr = generateCondition(patternDefinition)
    // body of the generated method: evaluate the condition and return its result term
    val filterBody =
      s"""
         |${conditionExpr.code}
         |return ${conditionExpr.resultTerm};
         |""".stripMargin

    val generated = generateMatchFunction(
      "MatchRecognizeCondition",
      classOf[RichIterativeCondition[Row]],
      filterBody,
      conditionExpr.resultType)

    new IterativeConditionRunner(generated.name, generated.code)
  }

  /**
    * Builds a code generated [[PatternProcessFunction]] that turns a found match into the output
    * row described by MEASURES (plus the fields used in PARTITION BY) and wraps it into a
    * [[PatternProcessFunctionRunner]].
    *
    * @param returnType the schema of output row
    * @param partitionKeys keys used for partitioning incoming data, they will be included in the
    *                      output
    * @param measures definitions from MEASURE clause
    * @return a process function that can be applied to [[org.apache.flink.cep.PatternStream]]
    */
  def generateOneRowPerMatchExpression(
      returnType: RowSchema,
      partitionKeys: ImmutableBitSet,
      measures: util.Map[String, RexNode])
    : PatternProcessFunctionRunner = {
    val outputRowExpr = generateOneRowPerMatchExpression(partitionKeys, measures, returnType)
    // body of the generated method: build the output row and hand it to the collector
    val processBody =
      s"""
         |${outputRowExpr.code}
         |$collectorTerm.collect(${outputRowExpr.resultTerm});
         |""".stripMargin

    val generated = generateMatchFunction(
      "MatchRecognizePatternProcessFunction",
      classOf[PatternProcessFunction[Row, Row]],
      processBody,
      outputRowExpr.resultType)

    new PatternProcessFunctionRunner(generated.name, generated.code)
  }

  /**
    * Generates a [[org.apache.flink.api.common.functions.Function]] that can be passed to Java
    * compiler.
    *
    * Only [[RichIterativeCondition]] and [[PatternProcessFunction]] are supported; any other
    * class results in a [[CodeGenException]].
    *
    * @param name Class name of the Function. Does not need to be unique (the final class name is
    *             derived from it via `newName`) but has to be a valid Java class identifier.
    * @param clazz Flink Function to be generated.
    * @param bodyCode code contents of the SAM (Single Abstract Method). Inputs, collector, or
    *                 output record can be accessed via the given term methods.
    * @param returnType expected return type
    * @tparam F Flink Function to be generated.
    * @tparam T Return type of the Flink Function.
    * @return instance of GeneratedFunction
    */
  private def generateMatchFunction[F <: Function, T <: Any](
      name: String,
      clazz: Class[F],
      bodyCode: String,
      returnType: TypeInformation[T])
    : GeneratedFunction[F, T] = {
    val funcName = newName(name)
    val collectorTypeTerm = classOf[Collector[Any]].getCanonicalName
    val (functionClass, signature, inputStatements) =
      if (clazz == classOf[RichIterativeCondition[_]]) {
        val baseClass = classOf[RichIterativeCondition[_]]
        val inputTypeTerm = boxedTypeTermForTypeInfo(input)
        val contextType = classOf[IterativeCondition.Context[_]].getCanonicalName
        // declaration: make variable accessible for separated methods
        reusableMemberStatements.add(s"private $inputTypeTerm $input1Term;")
        (baseClass,
          s"boolean filter(Object _in1, $contextType $contextTerm)",
          List(s"$input1Term = ($inputTypeTerm) _in1;"))
      } else if (clazz == classOf[PatternProcessFunction[_, _]]) {
        val baseClass = classOf[PatternProcessFunction[_, _]]
        // a match maps every pattern variable name to the list of events assigned to it; the
        // type term must be a complete Java type, otherwise the generated class does not compile
        val eventTypeTerm = boxedTypeTermForTypeInfo(input)
        val inputTypeTerm = s"java.util.Map<String, java.util.List<$eventTypeTerm>>"
        val contextTypeTerm = classOf[PatternProcessFunction.Context].getCanonicalName
        // declaration: make variable accessible for separated method
        reusableMemberStatements.add(s"private $inputTypeTerm $input1Term;")
        (baseClass,
          s"void processMatch($inputTypeTerm _in1, $contextTypeTerm $contextTerm, " +
            s"$collectorTypeTerm $collectorTerm)",
          List(s"this.$input1Term = ($inputTypeTerm) _in1;"))
      } else {
        throw new CodeGenException("Unsupported Function.")
      }

    // The reusable code fragments are read only here, so everything registered on this
    // generator up to this point ends up in the generated class.
    val funcCode = j"""
      |public class $funcName extends ${functionClass.getCanonicalName} {
      |
      |  ${reuseMemberCode()}
      |
      |  public $funcName() throws Exception {
      |    ${reuseInitCode()}
      |  }
      |
      |  @Override
      |  public void open(${classOf[Configuration].getCanonicalName} parameters) throws Exception {
      |    ${reuseOpenCode()}
      |  }
      |
      |  @Override
      |  public $signature throws Exception {
      |    ${inputStatements.mkString("\n")}
      |    ${reusePatternLists()}
      |    ${reuseInputUnboxingCode()}
      |    ${reusePerRecordCode()}
      |    $bodyCode
      |  }
      |
      |  @Override
      |  public void close() throws Exception {
      |    ${reuseCloseCode()}
      |  }
      |}
    """.stripMargin

    GeneratedFunction(funcName, returnType, funcCode)
  }

  /**
    * Returns an expression referring to the row used for key extraction: the first element of
    * the first non-empty event list found among the entries of the match. The code locating
    * that row is registered once as a reusable input unboxing expression, hence the returned
    * expression itself carries no code.
    */
  private def generateKeyRow() : GeneratedExpression = {
    val cacheKey = (keyRowTerm, 0)
    val keyRowExpr = reusableInputUnboxingExprs
      .get(cacheKey)
      .getOrElse {
        val eventTypeTerm = boxedTypeTermForTypeInfo(input)
        val nullTerm = newName("isNull")

        val keyCode = j"""
           |$eventTypeTerm $keyRowTerm = null;
           |boolean $nullTerm = true;
           |for (java.util.Map.Entry entry : $input1Term.entrySet()) {
           |  java.util.List value = (java.util.List) entry.getValue();
           |  if (value != null && value.size() > 0) {
           |    $keyRowTerm = ($eventTypeTerm) value.get(0);
           |    $nullTerm = false;
           |    break;
           |  }
           |}
           """.stripMargin

        val created = GeneratedExpression(keyRowTerm, nullTerm, keyCode, input)
        reusableInputUnboxingExprs(cacheKey) = created
        created
      }
    keyRowExpr.copy(code = NO_CODE)
  }

  /**
    * Extracts a partition key from an element of the match, namely from the key row
    * provided by `generateKeyRow`.
    *
    * @param partitionIdx index of the partition key field to be extracted
    * @return generated code for the given key
    */
  private def generatePartitionKeyAccess(
      partitionIdx: Int)
    : GeneratedExpression = {

    val keyRow = generateKeyRow()
    generateFieldAccess(keyRow, partitionIdx)
  }

  /**
    * Generates the expression that assembles the output row of a match and, afterwards,
    * triggers generation of the aggregate functions collected for every pattern variable.
    *
    * @param partitionKeys indices of the partition key fields
    * @param measures definitions from the MEASURES clause, keyed by output field name
    * @param returnType the schema of the output row
    * @return expression producing the output row
    */
  private def generateOneRowPerMatchExpression(
      partitionKeys: ImmutableBitSet,
      measures: util.Map[String, RexNode],
      returnType: RowSchema)
    : GeneratedExpression = {
    // For "ONE ROW PER MATCH" the output consists of the partition columns followed by
    // the columns defined in the measures clause.
    val partitionKeyExprs = partitionKeys.toList.asScala
      .map(keyIdx => generatePartitionKeyAccess(keyIdx))
    val measureFields = returnType.fieldNames.filter(name => measures.containsKey(name))
    val measureExprs = measureFields.map(name => generateExpression(measures.get(name)))
    val resultExprs = partitionKeyExprs ++ measureExprs

    val outputRowExpr = generateResultExpression(
      resultExprs,
      returnType.typeInfo,
      returnType.fieldNames)

    for (aggBuilder <- aggregatesPerVariable.values) {
      aggBuilder.generateAggFunction()
    }
    if (hasCodeSplits) makeReusableInSplits()

    outputRowExpr
  }

  /**
    * Translates a condition by visiting it with this generator and, afterwards, triggers
    * generation of the aggregate functions collected for every pattern variable.
    *
    * @param call the condition to translate
    * @return expression evaluating the condition
    */
  private def generateCondition(call: RexNode): GeneratedExpression = {
    val conditionExpr = call.accept(this)

    for (aggBuilder <- aggregatesPerVariable.values) {
      aggBuilder.generateAggFunction()
    }
    if (hasCodeSplits) makeReusableInSplits()

    conditionExpr
  }

  /**
    * Replaces every deduplicated aggregation expression with the result of passing it to the
    * single-expression `makeReusableInSplits` overload (defined outside this class body).
    */
  private def makeReusableInSplits(): Unit = {
    for (aggKey <- reusableAggregationExpr.keys) {
      val aggExpr = reusableAggregationExpr(aggKey)
      reusableAggregationExpr(aggKey) = makeReusableInSplits(aggExpr)
    }
  }

  /**
    * Translates calls that need special treatment inside MATCH_RECOGNIZE (PREV/NEXT,
    * FIRST/LAST, FINAL, aggregates, MATCH_PROCTIME, MATCH_ROWTIME). Every other call is
    * delegated to the parent implementation.
    */
  override def visitCall(call: RexCall): GeneratedExpression = {
    // reads the literal operand at the given position as an int
    def intLiteralOperand(index: Int): Int = {
      val literal = call.getOperands.get(index).asInstanceOf[RexLiteral]
      checkedDownCast(literal.getValueAs(classOf[JLong]))
    }

    call.getOperator match {
      case PREV | NEXT =>
        // only a count of 0 is accepted here
        if (intLiteralOperand(1) != 0) {
          throw new TableException("Flink does not support physical offsets within partition.")
        }
        updateOffsets(first = false, offset = 0)
        val operandExpr = call.getOperands.get(0).accept(this)
        resetOffsets()
        operandExpr

      case FIRST | LAST =>
        val logicalOffset = intLiteralOperand(1)
        updateOffsets(call.getOperator == FIRST, logicalOffset)
        val operandExpr = call.getOperands.get(0).accept(this)
        resetOffsets()
        operandExpr

      case FINAL =>
        // the operand is translated as is
        call.getOperands.get(0).accept(this)

      case _ : SqlAggFunction =>
        val variable = call
          .accept(new AggregationPatternVariableFinder)
          .getOrElse(throw new TableException("No pattern variable specified in aggregate"))

        // one builder per pattern variable collects all aggregates of that variable
        aggregatesPerVariable
          .getOrElseUpdate(variable, new AggBuilder(variable))
          .generateDeduplicatedAggAccess(call)

      case MATCH_PROCTIME =>
        generateNullLiteral(TimeIndicatorTypeInfo.PROCTIME_INDICATOR)

      case MATCH_ROWTIME =>
        val rowtimeExpr = generateStreamRecordRowtimeAccess()
        rowtimeExpr.copy(resultType = TimeIndicatorTypeInfo.ROWTIME_INDICATOR)

      case _ =>
        super.visitCall(call)
    }
  }

  /**
    * Generates code that stores the result of `currentProcessingTime()`, called on the context
    * term, in a fresh local `long` variable. The expression is marked as never null.
    */
  override private[flink] def generateProctimeTimestamp(): GeneratedExpression = {
    val timestampTerm = newName("result")
    val timestampCode =
      j"""
         |long $timestampTerm = $contextTerm.currentProcessingTime();
         |""".stripMargin

    GeneratedExpression(timestampTerm, NEVER_NULL, timestampCode, SqlTimeTypeInfo.TIMESTAMP)
  }

  /**
    * Translates a pattern field reference. Inside an aggregate the field is read from the
    * aggregate input row; otherwise it is read either directly from the input or from the
    * list of rows mapped to the referenced pattern variable.
    */
  override def visitPatternFieldRef(fieldRef: RexPatternFieldRef): GeneratedExpression = {
    // true when the reference can be served by the input row itself: all-pattern variable,
    // a current pattern is set and the default indexing context is active
    def readsInputDirectly: Boolean =
      fieldRef.getAlpha.equals(ALL_PATTERN_VARIABLE) &&
        currentPattern.isDefined && offset == 0 && !first

    if (isWithinAggExprState) {
      generateFieldAccess(input, inputAggRowTerm, fieldRef.getIndex)
    } else if (readsInputDirectly) {
      generateInputAccess(input, input1Term, fieldRef.getIndex)
    } else {
      generatePatternFieldRef(fieldRef)
    }
  }

  /**
    * Generates the code that collects the events of a pattern variable into a list when a
    * current pattern is set (see `findEventsByPatternName`). Events are read through
    * `getEventsForPattern` on the context term; the input row itself is appended when the
    * requested variable is the current pattern or the all-pattern variable.
    *
    * @param patternName name of the pattern variable whose events are requested
    * @param currentPattern name of the pattern the generated condition is applied to
    * @return name of the generated list together with the code filling it
    */
  private def generateDefinePatternVariableExp(
      patternName: String,
      currentPattern: String)
    : GeneratedPatternList = {
    val listName = newName("patternEvents")
    val eventTypeTerm = boxedTypeTermForTypeInfo(input)
    val eventNameTerm = newName("event")

    // the input row is added to the list only if it belongs to the requested variable
    val addCurrent = if (currentPattern == patternName || patternName == ALL_PATTERN_VARIABLE) {
      j"""
         |$listName.add($input1Term);
         """.stripMargin
    } else {
      ""
    }

    // the list is declared as a member and re-created by the generated code below
    reusableMemberStatements.add(s"java.util.List $listName = new java.util.ArrayList();")
    val listCode = if (patternName == ALL_PATTERN_VARIABLE) {
      // all-pattern variable: gather the events of every pattern name
      addReusablePatternNames()
      val patternTerm = newName("pattern")
      j"""
         |$listName = new java.util.ArrayList();
         |for (String $patternTerm : $patternNamesTerm) {
         |  for ($eventTypeTerm $eventNameTerm :
         |  $contextTerm.getEventsForPattern($patternTerm)) {
         |    $listName.add($eventNameTerm);
         |  }
         |}
         """.stripMargin
    } else {
      // the name ends up inside a Java string literal, hence it is escaped
      val escapedPatternName = EncodingUtils.escapeJava(patternName)
      j"""
         |$listName = new java.util.ArrayList();
         |for ($eventTypeTerm $eventNameTerm :
         |  $contextTerm.getEventsForPattern("$escapedPatternName")) {
         |    $listName.add($eventNameTerm);
         |}
         |""".stripMargin
    }

    val code =
      j"""
         |$listCode
         |$addCurrent
       """.stripMargin

    GeneratedPatternList(listName, code)
  }

  /**
    * Generates the code that obtains the list of events of a pattern variable when no current
    * pattern is set (see `findEventsByPatternName`). The events are looked up in the input
    * map by pattern name; a missing entry is treated as an empty list.
    *
    * @param patternName name of the pattern variable whose events are requested
    * @return name of the generated list together with the code filling it
    */
  private def generateMeasurePatternVariableExp(patternName: String): GeneratedPatternList = {
    val listName = newName("patternEvents")

    // the list is declared as a member and assigned by the generated code below
    reusableMemberStatements.add(s"java.util.List $listName = new java.util.ArrayList();")
    val code = if (patternName == ALL_PATTERN_VARIABLE) {
      // all-pattern variable: concatenate the rows of every pattern name
      addReusablePatternNames()

      val patternTerm = newName("pattern")

      j"""
         |$listName = new java.util.ArrayList();
         |for (String $patternTerm : $patternNamesTerm) {
         |  java.util.List rows = (java.util.List) $input1Term.get($patternTerm);
         |  if (rows != null) {
         |    $listName.addAll(rows);
         |  }
         |}
         """.stripMargin
    } else {
      // the name ends up inside a Java string literal, hence it is escaped
      val escapedPatternName = EncodingUtils.escapeJava(patternName)
      j"""
         |$listName = (java.util.List) $input1Term.get("$escapedPatternName");
         |if ($listName == null) {
         |  $listName = java.util.Collections.emptyList();
         |}
         |""".stripMargin
    }

    GeneratedPatternList(listName, code)
  }

  /**
    * Generates the code that picks a single row out of the list of events of a pattern
    * variable, using the current indexing context (`first`, `offset`). If the list holds no
    * more than `offset` elements, the generated row stays null and its null flag stays true.
    *
    * @param patternFieldAlpha name of the pattern variable
    * @return expression referring to the selected row
    */
  private def findEventByLogicalPosition(
      patternFieldAlpha: String)
    : GeneratedExpression = {
    val rowNameTerm = newName("row")
    val eventTypeTerm = boxedTypeTermForTypeInfo(input)
    val isRowNull = newName("isRowNull")

    val listName = findEventsByPatternName(patternFieldAlpha).resultTerm
    // index counted from the beginning of the list if `first` is set, from the end otherwise
    val resultIndex = if (first) {
      j"""$offset"""
    } else {
      j"""$listName.size() - $offset - 1"""
    }

    val funcCode =
      j"""
         |$eventTypeTerm $rowNameTerm = null;
         |boolean $isRowNull = true;
         |if ($listName.size() > $offset) {
         |  $rowNameTerm = (($eventTypeTerm) $listName.get($resultIndex));
         |  $isRowNull = false;
         |}
         |""".stripMargin

    GeneratedExpression(rowNameTerm, isRowNull, funcCode, input)
  }

  /**
    * Returns the generated list of events for the given pattern variable, generating and
    * caching it in `reusablePatternLists` on first use. Which generation method is used
    * depends on whether a current pattern is set.
    *
    * @param patternFieldAlpha name of the pattern variable
    * @return name of the generated list together with the code filling it
    */
  private def findEventsByPatternName(
      patternFieldAlpha: String)
    : GeneratedPatternList = {
    reusablePatternLists.getOrElseUpdate(
      patternFieldAlpha,
      currentPattern
        .map(pattern => generateDefinePatternVariableExp(patternFieldAlpha, pattern))
        .getOrElse(generateMeasurePatternVariableExp(patternFieldAlpha)))
  }

  /**
    * Generates access to a field of the row selected for a pattern variable under the current
    * indexing context. The row lookup is cached per (variable, first, offset) as a reusable
    * input unboxing expression.
    *
    * @param fieldRef the pattern field reference to translate
    * @return expression accessing the referenced field
    */
  private def generatePatternFieldRef(fieldRef: RexPatternFieldRef): GeneratedExpression = {
    val escapedAlpha = EncodingUtils.escapeJava(fieldRef.getAlpha)
    val cacheKey = (s"$escapedAlpha#$first", offset)

    val rowExpr = reusableInputUnboxingExprs
      .get(cacheKey)
      .getOrElse {
        val located = findEventByLogicalPosition(fieldRef.getAlpha)
        reusableInputUnboxingExprs(cacheKey) = located
        located
      }

    // the lookup code is emitted through the reusable expressions, not through this access
    generateFieldAccess(rowExpr.copy(code = NO_CODE), fieldRef.getIndex)
  }

  class AggBuilder(variable: String) {

    private val aggregates = new mutable.ListBuffer[RexCall]()

    private val variableUID = newName("variable")

    private val rowTypeTerm = "org.apache.flink.types.Row"

    private val calculateAggFuncName = s"calculateAgg_$variableUID"

    /**
      * Returns the expression giving access to the result of the given aggregate call. Calls
      * are deduplicated by their string representation, so the access code for a call is
      * generated only once.
      *
      * @param aggCall the aggregate call
      * @return expression referring to the result of the aggregate
      */
    def generateDeduplicatedAggAccess(aggCall: RexCall): GeneratedExpression = {
      reusableAggregationExpr.get(aggCall.toString) match  {
        case Some(expr) =>
          expr

        case None =>
          // The access code uses the current number of collected aggregates as the field
          // index, so it has to be generated before the call is appended to the list.
          val exp: GeneratedExpression = generateAggAccess(aggCall)
          aggregates += aggCall
          reusableAggregationExpr(aggCall.toString) = exp
          reusablePerRecordStatements += exp.code
          exp.copy(code = NO_CODE)
      }
    }

    /**
      * Registers the per record statements that compute the aggregate row of this variable
      * and extract the result of a single aggregate from it. The field that is read is the
      * one at index `aggregates.size` at the time of the call.
      *
      * @param aggCall the aggregate call whose result is accessed
      * @return expression (without code) referring to the extracted result and its null term
      */
    private def generateAggAccess(aggCall: RexCall): GeneratedExpression = {
      val singleAggResultTerm = newName("result")
      val singleAggNullTerm = newName("nullTerm")
      val singleAggResultType = FlinkTypeFactory.toTypeInfo(aggCall.`type`)
      val primitiveSingleAggResultTypeTerm = primitiveTypeTermForTypeInfo(singleAggResultType)
      val boxedSingleAggResultTypeTerm = boxedTypeTermForTypeInfo(singleAggResultType)

      val allAggRowTerm = s"aggRow_$variableUID"

      // statement calculating all aggregates of this variable over its list of rows
      val rowsForVariableCode = findEventsByPatternName(variable)
      val codeForAgg =
        j"""
           |$rowTypeTerm $allAggRowTerm = $calculateAggFuncName(${rowsForVariableCode.resultTerm});
           |""".stripMargin

      reusablePerRecordStatements += codeForAgg

      // statement extracting the single aggregate; with null checks enabled a null field
      // is replaced by the primitive default value and flagged through the null term
      val defaultValue = primitiveDefaultValue(singleAggResultType)
      val codeForSingleAgg = if (nullCheck) {
        j"""
           |boolean $singleAggNullTerm;
           |$primitiveSingleAggResultTypeTerm $singleAggResultTerm;
           |if ($allAggRowTerm.getField(${aggregates.size}) != null) {
           |  $singleAggResultTerm = ($boxedSingleAggResultTypeTerm) $allAggRowTerm
           |    .getField(${aggregates.size});
           |  $singleAggNullTerm = false;
           |} else {
           |  $singleAggNullTerm = true;
           |  $singleAggResultTerm = $defaultValue;
           |}
           |""".stripMargin
      } else {
        j"""
           |$primitiveSingleAggResultTypeTerm $singleAggResultTerm =
           |    ($boxedSingleAggResultTypeTerm) $allAggRowTerm.getField(${aggregates.size});
           |""".stripMargin
      }

      reusablePerRecordStatements += codeForSingleAgg

      GeneratedExpression(singleAggResultTerm, singleAggNullTerm, NO_CODE, singleAggResultType)
    }

    /**
      * Generates everything needed to calculate the aggregates collected for this variable:
      * the aggregations class (registered as a member statement), the function transforming
      * an input row into an aggregate input row and the function calculating the aggregates.
      */
    def generateAggFunction(): Unit = {
      val matchAgg = extractAggregatesAndExpressions

      // NOTE(review): the meaning of the positional arguments is defined by
      // AggregationCodeGenerator, which is not part of this file - verify against it
      val aggGenerator = new AggregationCodeGenerator(
        config,
        false,
        input,
        None,
        s"AggFunction_$variableUID",
        matchAgg.inputExprs.map(r => FlinkTypeFactory.toTypeInfo(r.getType)),
        matchAgg.aggregations.map(_.aggFunction).toArray,
        matchAgg.aggregations.map(_.inputIndices).toArray,
        matchAgg.aggregations.indices.toArray,
        matchAgg.getDistinctAccMapping,
        isStateBackedDataViews = false,
        partialResults = false,
        Array.emptyIntArray,
        None,
        matchAgg.aggregations.size,
        needRetract = false,
        needMerge = false,
        needReset = false,
        None
      )
      val aggFunc = aggGenerator.generateAggregations

      reusableMemberStatements.add(aggFunc.code)

      val transformFuncName = s"transformRowForAgg_$variableUID"
      val inputTransform: String = generateAggInputExprEvaluation(
        matchAgg.inputExprs,
        transformFuncName)

      generateAggCalculation(aggFunc, transformFuncName, inputTransform)
    }

    /**
      * Converts the collected aggregate calls into aggregation metadata and gathers the
      * distinct operand expressions of all calls. Operands are deduplicated by their string
      * representation; each aggregate refers to its operands by their index in that sequence.
      *
      * @return the aggregations together with the deduplicated input expressions
      */
    private def extractAggregatesAndExpressions: MatchAgg = {
      // insertion ordered, so the index assigned to an expression matches its final position
      val inputRows = new mutable.LinkedHashMap[String, (RexNode, Int)]

      val logicalAggregates = aggregates.map(aggCall => {
        val callsWithIndices = aggCall.operands.asScala.map(innerCall => {
          inputRows.get(innerCall.toString) match {
            case Some(x) =>
              x

            case None =>
              val callWithIndex = (innerCall, inputRows.size)
              inputRows(innerCall.toString) = callWithIndex
              callWithIndex
          }
        })

        val agg = aggCall.getOperator.asInstanceOf[SqlAggFunction]
        LogicalSingleAggCall(agg,
          callsWithIndices.map(_._1.getType),
          callsWithIndices.map(_._2).toArray)
      })

      // shared across all aggregates and handed to every metadata extraction call
      val distinctAccMap: mutable.Map[util.Set[Integer], Integer] = mutable.Map()
      val aggs = logicalAggregates.zipWithIndex.map {
        case (agg, index) =>
          val result = AggregateUtil.extractAggregateCallMetadata(
            agg.function,
            isDistinct = false, // TODO properly set once supported in Calcite
            distinctAccMap,
            new util.ArrayList[Integer](), // TODO properly set once supported in Calcite
            aggregates.length,
            input.getArity,
            agg.inputTypes,
            needRetraction = false,
            config,
            isStateBackedDataViews = false,
            index)

          SingleAggCall(
            result.aggregateFunction,
            agg.exprIndices.toArray,
            result.accumulatorSpecs,
            result.distinctAccIndex)
      }

      MatchAgg(aggs, inputRows.values.map(_._1).toSeq)
    }

    /**
      * Registers the member code that calculates all aggregates of this variable: a field
      * holding the aggregator, the given input transformation function and a function that
      * accumulates every transformed row of a list and returns the resulting row. It also
      * registers the statements that instantiate, open and close the aggregator.
      *
      * @param aggFunc the generated aggregations class
      * @param transformFuncName name of the function transforming a row into aggregate input
      * @param inputTransformFunc code of that transformation function
      */
    private def generateAggCalculation(
        aggFunc: GeneratedAggregationsFunction,
        transformFuncName: String,
        inputTransformFunc: String)
      : Unit = {
      val aggregatorTerm = s"aggregator_$variableUID"
      val code =
        j"""
           |private final ${aggFunc.name} $aggregatorTerm;
           |
           |$inputTransformFunc
           |
           |private $rowTypeTerm $calculateAggFuncName(java.util.List input)
           |    throws Exception {
           |  $rowTypeTerm accumulator = $aggregatorTerm.createAccumulators();
           |  for ($rowTypeTerm row : input) {
           |    $aggregatorTerm.accumulate(accumulator, $transformFuncName(row));
           |  }
           |  $rowTypeTerm result = $aggregatorTerm.createOutputRow();
           |  $aggregatorTerm.setAggregationResults(accumulator, result);
           |  return result;
           |}
         """.stripMargin

      // lifecycle of the aggregator inside the generated class
      reusableInitStatements.add(s"$aggregatorTerm = new ${aggFunc.name}();")
      reusableOpenStatements.add(s"$aggregatorTerm.open(getRuntimeContext());")
      reusableCloseStatements.add(s"$aggregatorTerm.close();")
      reusableMemberStatements.add(code)
    }

    /**
      * Generates the function that transforms an input row into the row consumed by the
      * aggregates: every input expression is evaluated against the input row and stored in
      * the output row at the position of the expression.
      *
      * While the expressions are generated, `isWithinAggExprState` is set so that field
      * references are resolved against the aggregate input row. The flag is always reset,
      * even if the generation of an expression fails, so that a failure cannot leave the
      * generator in the aggregate state.
      *
      * @param inputExprs expressions to evaluate for every input row
      * @param funcName name of the generated function
      * @return code of the generated function
      */
    private def generateAggInputExprEvaluation(
        inputExprs: Seq[RexNode],
        funcName: String)
      : String = {
      val resultTerm = newName("result")

      isWithinAggExprState = true
      val exprs = try {
        inputExprs.zipWithIndex.map {
          case (inputExpr, outputIndex) =>
            val expr = generateExpression(inputExpr)
            s"""
               |${expr.code}
               |if (${expr.nullTerm}) {
               |  $resultTerm.setField($outputIndex, null);
               |} else {
               |  $resultTerm.setField($outputIndex, ${expr.resultTerm});
               |}
         """.stripMargin
        }.mkString("\n")
      } finally {
        isWithinAggExprState = false
      }

      j"""
         |private $rowTypeTerm $funcName($rowTypeTerm $inputAggRowTerm) {
         |  $rowTypeTerm $resultTerm = new $rowTypeTerm(${inputExprs.size});
         |  $exprs
         |  return $resultTerm;
         |}
       """.stripMargin
    }

    /**
      * An aggregate call described by its SQL aggregate function, the types of its operands
      * and the indices of those operands within the deduplicated input expressions.
      */
    private case class LogicalSingleAggCall(
      function: SqlAggFunction,
      inputTypes: Seq[RelDataType],
      exprIndices: Seq[Int]
    )

    /**
      * An aggregate call after metadata extraction: the aggregate function to run, the
      * indices of its inputs, its accumulator data view specs and its distinct accumulator
      * index.
      */
    private case class SingleAggCall(
      aggFunction: UserDefinedAggregateFunction[_, _],
      inputIndices: Array[Int],
      dataViews: Seq[DataViewSpec[_]],
      distinctAccIndex: Int
    )

    /**
      * All aggregations of a single pattern variable together with the expressions that
      * produce their inputs.
      */
    private case class MatchAgg(
      aggregations: Seq[SingleAggCall],
      inputExprs: Seq[RexNode]) {

      /**
        * Groups the positions of the aggregations by their distinct accumulator index.
        *
        * @return pairs of distinct accumulator index and the positions of the aggregations
        *         sharing it
        */
      def getDistinctAccMapping: Array[(Integer, util.List[Integer])] = {
        val aggIndicesByDistinctAcc = mutable.Map[Integer, util.List[Integer]]()
        for ((aggregation, aggIndex) <- aggregations.zipWithIndex) {
          aggIndicesByDistinctAcc
            .getOrElseUpdate(aggregation.distinctAccIndex, new util.ArrayList[Integer]())
            .add(aggIndex)
        }
        aggIndicesByDistinctAcc.toArray
      }
    }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy