All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.flink.table.codegen.MatchCodeGenerator.scala Maven / Gradle / Ivy

There is a newer version: 1.5.1
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.codegen

import org.apache.flink.api.common.functions.Function
import org.apache.flink.cep._
import org.apache.flink.cep.pattern.conditions.{IterativeCondition, RichIterativeCondition}
import org.apache.flink.configuration.Configuration
import org.apache.flink.table.api.scala._
import org.apache.flink.table.api.{TableConfig, TableConfigOptions, TableException}
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.codegen.CodeGenUtils._
import org.apache.flink.table.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE}
import org.apache.flink.table.codegen.Indenter.toISC
import org.apache.flink.table.codegen.MatchCodeGenerator._
import org.apache.flink.table.codegen.agg.AggsHandlerCodeGenerator
import org.apache.flink.table.dataformat.{BaseRow, GenericRow}
import org.apache.flink.table.functions.sql.ProctimeSqlFunction
import org.apache.flink.table.plan.schema.BaseRowSchema
import org.apache.flink.table.plan.util.AggregateUtil
import org.apache.flink.table.plan.util.MatchUtil.AggregationPatternVariableFinder
import org.apache.flink.table.runtime.conversion.DataStructureConverters.genToInternal
import org.apache.flink.table.runtime.functions.{AggsHandleFunction, ExecutionContextImpl}
import org.apache.flink.table.types.{DataTypes, RowType, StringType}
import org.apache.flink.table.typeutils.TypeUtils
import org.apache.flink.table.utils.EncodingUtils
import org.apache.flink.util.Collector
import org.apache.flink.util.MathUtils.checkedDownCast

import org.apache.calcite.rel.RelCollation
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.AggregateCall
import org.apache.calcite.rex._
import org.apache.calcite.sql.SqlAggFunction
import org.apache.calcite.sql.fun.SqlStdOperatorTable._
import org.apache.calcite.tools.RelBuilder

import java.lang.{Long => JLong}

import _root_.scala.collection.JavaConversions._
import _root_.scala.collection.JavaConverters._
import _root_.scala.collection.mutable

/**
  * A code generator for generating CEP related functions.
  *
  * @param ctx the context of the code generator
  * @param nullableInput input(s) can be null.
  * @param nullCheck whether to do null check
  * @param patternNames sorted sequence of pattern variables
  * @param currentPattern if generating condition the name of pattern, which the condition will
  *                       be applied to
  */
class MatchCodeGenerator(
    ctx: CodeGeneratorContext,
    relBuilder: RelBuilder,
    nullableInput: Boolean,
    nullCheck: Boolean,
    patternNames: Seq[String],
    currentPattern: Option[String] = None,
    collectorTerm: String = CodeGeneratorContext.DEFAULT_COLLECTOR_TERM)
  extends ExprCodeGenerator(ctx, nullableInput, nullCheck) {

  /**
    * A generated Java list of events for one pattern variable.
    *
    * @param resultTerm name of the Java member holding the list
    * @param code       Java code that initializes the list
    * @param filled     false if `code` only creates an empty list that is filled event by event
    *                   while the output rows are generated
    */
  private case class GeneratedPatternList(resultTerm: String, code: String, filled: Boolean = true)

  /**
    * A generated Java list of classifiers (pattern names), one entry per event.
    *
    * @param resultTerm name of the Java member holding the list
    * @param code       Java code that initializes the list
    * @param filled     false if `code` only creates an empty list that is filled event by event
    *                   while the output rows are generated
    */
  private case class GeneratedClassifierList(
      resultTerm: String, code: String, filled: Boolean = true)

  /**
    * Used to assign unique names to the lists of events per pattern variable name. Those lists
    * are treated as inputs and are needed by input access code.
    * The key is (running flag, pattern variable name).
    */
  private val reusablePatternLists: mutable.HashMap[(Boolean, String), GeneratedPatternList] =
    mutable.HashMap[(Boolean, String), GeneratedPatternList]()

  /**
    * Classifier lists that were already generated, keyed by the running flag.
    */
  private val reusableClassiferList: mutable.HashMap[Boolean, GeneratedClassifierList] =
    mutable.HashMap[Boolean, GeneratedClassifierList]()

  /**
    * Input unboxing expressions registered while the running flag is set, keyed by
    * (input term, index). Expressions registered otherwise are kept in the code generator context.
    */
  private val reusableInputUnboxingExprs: mutable.Map[(String, Int), GeneratedExpression] =
    mutable.Map[(String, Int), GeneratedExpression]()

  /**
    * Per record statements registered while the running flag is set, in insertion order.
    * Statements registered otherwise are kept in the code generator context.
    */
  private val reusablePerRecordStatements: mutable.LinkedHashSet[String] =
    mutable.LinkedHashSet[String]()

  /**
    * Used to deduplicate aggregation calculations. The deduplication is performed by
    * [[RexNode#toString]]. Those expressions need to be accessible from splits, if such exist.
    */
  private val reusableAggregationExpr = new mutable.HashMap[String, GeneratedExpression]()

  /**
    * Context information used by Pattern reference variable to index rows mapped to it.
    * Indexes element at offset either from beginning or the end based on the value of first.
    */
  private var offset: Int = 0
  private var first : Boolean = false

  // true while generating the output of ONE ROW PER MATCH, false for ALL ROWS PER MATCH
  private var oneRowPerMatch: Boolean = false
  // true while the operand of a RUNNING call is visited (only when oneRowPerMatch is false)
  private var running: Boolean = false

  /**
    * Flag that tells if we generate expressions inside an aggregate. It tells how to access the
    * input row.
    */
  private var isWithinAggExprState: Boolean = false

  /**
    * Used to collect all aggregates per pattern variable.
    * The key is (running flag, pattern variable name).
    */
  private val aggregatesPerVariable = new mutable.HashMap[(Boolean, String), AggBuilder]

  /**
    * Name of term in function used to transform input row into aggregate input row.
    */
  private val inputAggRowTerm = "inAgg"

  // name of the generated member holding the row from which partition keys are read
  private val keyRowTerm = "keyRow"

  /**
    * Name of the generated String[] member that holds the pattern names.
    */
  private val patternNamesTerm = newName("patternNames")

  // boxed Java type term of a single input event
  private lazy val eventTypeTerm = boxedTypeTermForType(input1Type)

  /**
    * Sets the new reference variable indexing context. This should be used when resolving logical
    * offsets = LAST/FIRST
    *
    * @param first  true if indexing from the beginning, false otherwise
    * @param offset offset from either beginning or the end
    */
  private def updateOffsets(first: Boolean, offset: Int): Unit = {
    // Both fields are read together by the logical-position lookups of pattern variables.
    this.offset = offset
    this.first = first
  }

  /** Resets indexing context of Pattern variable. */
  private def resetOffsets(): Unit =
    // Default context: index from the end of the list, without any offset.
    updateOffsets(first = false, offset = 0)

  /** Sets the flag telling whether the expressions generated next are inside RUNNING. */
  private def setRunning(running: Boolean): Unit =
    this.running = running

  /** Clears the running flag. */
  private def resetRunning(): Unit = setRunning(running = false)

  /** Records whether output is generated for ONE ROW PER MATCH (true) or not (false). */
  private def setOneRowPerMatch(oneRowPerMatch: Boolean): Unit =
    this.oneRowPerMatch = oneRowPerMatch

  /** Returns the initialization code of all registered pattern lists, joined by newlines. */
  private def reusePatternLists(): String = {
    val snippets = for (patternList <- reusablePatternLists.values) yield patternList.code
    snippets.mkString("\n")
  }

  /** Returns the initialization code of all registered classifier lists, joined by newlines. */
  private def reuseClassifierLists(): String = {
    val snippets = for (classifierList <- reusableClassiferList.values) yield classifierList.code
    snippets.mkString("\n")
  }

  /** Returns the code of the unboxing expressions kept by this generator, joined by newlines. */
  private def reuseInputUnboxingCode(): String = {
    val snippets = for (unboxingExpr <- reusableInputUnboxingExprs.values) yield unboxingExpr.code
    snippets.mkString("\n")
  }

  /** Returns the per record statements kept by this generator, in insertion order. */
  private def reusePerRecordCode(): String =
    reusablePerRecordStatements.iterator.mkString("\n")

  /**
    * Registers a String[] member that holds all pattern names as escaped Java string literals.
    */
  private def addReusablePatternNames(): Unit = {
    val quotedNames = patternNames
      .map(patternName => "\"" + EncodingUtils.escapeJava(patternName) + "\"")
      .mkString(", ")
    ctx.addReusableMember(
      s"private String[] $patternNamesTerm = new String[] { $quotedNames };")
  }

  /**
    * Generates a [[org.apache.flink.api.common.functions.Function]] that can be passed to Java
    * compiler.
    *
    * Supported function classes are [[RichIterativeCondition]], [[RichPatternSelectFunction]],
    * [[RichPatternFlatSelectFunction]], [[RichPatternTimeoutFunction]] and
    * [[RichPatternFlatTimeoutFunction]].
    *
    * @param name Class name of the Function. Does not have to be unique but has to be a valid
    *             Java class identifier.
    * @param config table config; provides the maximum length of generated code, which decides
    *               whether the input unboxing code is split into separate methods.
    * @param clazz Flink Function to be generated.
    * @param bodyCode code contents of the SAM (Single Abstract Method). Inputs, collector, or
    *                 output record can be accessed via the given term methods.
    * @tparam F Flink Function to be generated.
    * @tparam T Return type of the Flink Function.
    * @return instance of GeneratedFunction
    * @throws CodeGenException if `clazz` is none of the supported function classes
    */
  def generateMatchFunction[F <: Function, T <: Any](
      name: String,
      config: TableConfig,
      clazz: Class[F],
      bodyCode: String)
  : GeneratedClass[_] = {
    val funcName = newName(name)
    val collectorTypeTerm = classOf[Collector[Any]].getCanonicalName
    val inputTypeTerm = boxedTypeTermForType(input1Type)
    // Java type of a match as handed to select/timeout functions: events grouped by pattern name.
    val matchTypeTerm = s"java.util.Map<String, java.util.List<$inputTypeTerm>>"

    // Splits the input unboxing code registered in the context into "inputUnbox" methods with
    // the given parameter list; only the parameter lists differ between the function kinds.
    def splitInputUnboxing(defineParams: String, callingParams: String) =
      generateSplitFunctionCalls(
        ctx.reusableInputUnboxingExprs.values.map(_.code).toSeq,
        config.getConf.getInteger(TableConfigOptions.SQL_CODEGEN_LENGTH_MAX),
        "inputUnbox",
        "private final void",
        ctx.reuseFieldCode().length,
        defineParams = defineParams,
        callingParams = callingParams
      )

    val (functionClass, signature, inputStatements, unboxingCodeSplit) =
      if (clazz == classOf[RichIterativeCondition[_]]) {
        val contextType = classOf[IterativeCondition.Context[_]].getCanonicalName

        // The condition receives the event as Object, hence the cast statement.
        (classOf[RichIterativeCondition[_]],
          s"boolean filter(Object _in1, $contextType $contextTerm)",
          List(s"$inputTypeTerm $input1Term = ($inputTypeTerm) _in1;"),
          splitInputUnboxing(
            defineParams = s"$inputTypeTerm $input1Term, $contextType $contextTerm",
            callingParams = s"$input1Term, $contextTerm"))
      } else if (clazz == classOf[RichPatternSelectFunction[_, _]]) {
        (classOf[RichPatternSelectFunction[_, _]],
          s"Object select($matchTypeTerm $input1Term)",
          List(),
          splitInputUnboxing(
            defineParams = s"$matchTypeTerm $input1Term",
            callingParams = input1Term))
      } else if (clazz == classOf[RichPatternFlatSelectFunction[_, _]]) {
        (classOf[RichPatternFlatSelectFunction[_, _]],
          s"void flatSelect($matchTypeTerm $input1Term, " +
            s"$collectorTypeTerm $collectorTerm)",
          List(),
          splitInputUnboxing(
            defineParams = s"$matchTypeTerm $input1Term, " +
              s"$collectorTypeTerm $collectorTerm",
            callingParams = s"$input1Term, $collectorTerm"))
      } else if (clazz == classOf[RichPatternTimeoutFunction[_, _]]) {
        (classOf[RichPatternTimeoutFunction[_, _]],
          s"Object timeout($matchTypeTerm $input1Term, " +
            "long timeoutTimestamp)",
          List(),
          splitInputUnboxing(
            defineParams = s"$matchTypeTerm $input1Term, " +
              s"long timeoutTimestamp",
            callingParams = s"$input1Term, timeoutTimestamp"))
      } else if (clazz == classOf[RichPatternFlatTimeoutFunction[_, _]]) {
        // The split methods only receive the match and the collector, not the timestamp.
        (classOf[RichPatternFlatTimeoutFunction[_, _]],
          s"void timeout($matchTypeTerm $input1Term, " +
            s"long timeoutTimestamp, $collectorTypeTerm $collectorTerm)",
          List(),
          splitInputUnboxing(
            defineParams = s"$matchTypeTerm $input1Term, " +
              s"$collectorTypeTerm $collectorTerm",
            callingParams = s"$input1Term, $collectorTerm"))
      } else {
        throw new CodeGenException("Unsupported Function.")
      }

    val funcCode = if (unboxingCodeSplit.isSplit) {
      // Unboxing code is too long: call the split methods and emit their definitions.
      j"""
        public class $funcName extends ${functionClass.getCanonicalName} {
          ${ctx.reuseMemberCode()}
          ${ctx.reuseFieldCode()}

          public $funcName(Object[] references) throws Exception {
            ${ctx.reuseInitCode()}
          }

          @Override
          public void open(${classOf[Configuration].getCanonicalName} parameters) throws Exception {
            ${ctx.reuseOpenCode()}
          }

          @Override
          public $signature throws Exception {
            ${inputStatements.mkString("\n")}
            ${reusePatternLists()}
            ${reuseClassifierLists()}
            ${unboxingCodeSplit.callings.mkString("\n")}
            ${ctx.reusePerRecordCode()}
            $bodyCode
          }

          ${
            unboxingCodeSplit.definitions.zip(unboxingCodeSplit.bodies) map {
              case (define, body) =>
                s"""
                   |$define throws Exception {
                   | ${ctx.reusePerRecordCode()}
                   | $body
                   |}
                   """.stripMargin
            } mkString "\n"
          }

          @Override
          public void close() throws Exception {
            ${ctx.reuseCloseCode()}
          }
        }
      """.stripMargin
    } else {
      // Unboxing code is short enough to be inlined into the SAM body.
      j"""
        public class $funcName extends ${functionClass.getCanonicalName} {
          ${ctx.reuseMemberCode()}
          ${ctx.reuseFieldCode()}

          public $funcName(Object[] references) throws Exception {
            ${ctx.reuseInitCode()}
          }

          @Override
          public void open(${classOf[Configuration].getCanonicalName} parameters) throws Exception {
            ${ctx.reuseOpenCode()}
          }

          @Override
          public $signature throws Exception {
            ${inputStatements.mkString("\n")}
            ${reusePatternLists()}
            ${reuseClassifierLists()}
            ${ctx.reusePerRecordCode()}
            ${ctx.reuseInputUnboxingCode()}
            $bodyCode
          }

          @Override
          public void close() throws Exception {
            ${ctx.reuseCloseCode()}
          }
        }
      """.stripMargin
    }

    // Wrap the generated source in the holder matching the requested function class.
    if (clazz == classOf[RichIterativeCondition[_]]) {
      GeneratedIterativeCondition(funcName, funcCode, ctx.references.toArray)
    } else if (clazz == classOf[RichPatternSelectFunction[_, _]]) {
      GeneratedPatternSelectFunction(funcName, funcCode, ctx.references.toArray)
    } else if (clazz == classOf[RichPatternFlatSelectFunction[_, _]]) {
      GeneratedPatternFlatSelectFunction(funcName, funcCode, ctx.references.toArray)
    } else if (clazz == classOf[RichPatternTimeoutFunction[_, _]]) {
      GeneratedPatternTimeoutFunction(funcName, funcCode, ctx.references.toArray)
    } else if (clazz == classOf[RichPatternFlatTimeoutFunction[_, _]]) {
      GeneratedPatternFlatTimeoutFunction(funcName, funcCode, ctx.references.toArray)
    } else {
      throw new CodeGenException("Unsupported Function.")
    }
  }

  /**
    * Generates the expression that builds the output row for "ONE ROW PER MATCH".
    *
    * @param partitionKeys partition keys; every key has to be a [[RexInputRef]]
    * @param measures      measure expressions by output field name
    * @param returnSchema  schema of the output row
    * @return expression producing a [[GenericRow]] of the return schema
    */
  def generateOneRowPerMatchExpression(
      partitionKeys: java.util.List[RexNode],
      measures: java.util.Map[String, RexNode],
      returnSchema: BaseRowSchema): GeneratedExpression = {
    setOneRowPerMatch(oneRowPerMatch = true)

    // Output layout: the partition columns first ...
    val partitionExprs = partitionKeys.asScala.map { case keyRef: RexInputRef =>
      generatePartitionKeyAccess(keyRef)
    }
    // ... followed by the columns defined in the measures clause, in schema order.
    val measureExprs = returnSchema.fieldNames
      .filter(measures.containsKey(_))
      .map(measureName => generateExpression(measures.get(measureName)))
    val outputExprs = partitionExprs ++ measureExprs

    val rowExpression = new ExprCodeGenerator(ctx, nullableInput, nullCheck)
      .bindInput(input1Type, inputTerm = input1Term)
      .generateResultExpression(
        outputExprs,
        new RowType(returnSchema.fieldTypeInfos, returnSchema.fieldNames.toArray),
        classOf[GenericRow])

    // Aggregates were collected while visiting the measures; emit their functions now.
    for (aggBuilder <- aggregatesPerVariable.values) {
      aggBuilder.generateAggFunction()
    }

    rowExpression
  }

  /**
    * Generates the code that emits one output row per matched event ("ALL ROWS PER MATCH").
    * The generated code iterates over all pattern names and their events and passes every
    * result row to the collector.
    *
    * @param partitionKeys partition keys; every key has to be a [[RexInputRef]]
    * @param orderKeys     collation whose field indexes denote the ordering columns
    * @param measures      measure expressions by output field name
    * @param returnSchema  schema of the output rows
    * @return expression that carries only code; its result term is empty and its type is null
    */
  def generateAllRowsPerMatchExpression(
      partitionKeys: java.util.List[RexNode],
      orderKeys: RelCollation,
      measures: java.util.Map[String, RexNode],
      returnSchema: BaseRowSchema): GeneratedExpression = {

    val patternNameTerm = newName("patternName")
    val eventNameTerm = newName("event")
    val eventNameListTerm = newName("eventList")
    val listTypeTerm = classOf[java.util.List[_]].getCanonicalName

    setOneRowPerMatch(oneRowPerMatch = false)

    // For "ALL ROWS PER MATCH", the output columns include:
    // 1) the partition columns;
    // 2) the ordering columns;
    // 3) the columns defined in the measures clause;
    // 4) any remaining columns defined of the input.
    // NOTE(review): the expressions below are concatenated as partition, ordering, remaining
    // input columns and only then measures, i.e. 3) and 4) are swapped compared to the list
    // above - confirm which order the return schema expects.
    val fieldsAccessed = mutable.Set[Int]()
    val resultExprs =
      partitionKeys.asScala.map { case inputRef: RexInputRef =>
        fieldsAccessed += inputRef.getIndex
        generateFieldAccess(ctx, input1Type, eventNameTerm, inputRef.getIndex, nullCheck)
      } ++ orderKeys.getFieldCollations.asScala.map { fieldCollation =>
        fieldsAccessed += fieldCollation.getFieldIndex
        generateFieldAccess(ctx, input1Type, eventNameTerm, fieldCollation.getFieldIndex, nullCheck)
      } ++ (0 until TypeUtils.getArity(input1Type)).filterNot(fieldsAccessed.contains).map { idx =>
        generateFieldAccess(ctx, input1Type, eventNameTerm, idx, nullCheck)
      } ++ returnSchema.fieldNames.filter(measures.containsKey(_)).map { fieldName =>
        generateExpression(measures.get(fieldName))
      }

    val resultCodeGenerator = new ExprCodeGenerator(ctx, nullableInput, nullCheck)
        .bindInput(input1Type, inputTerm = input1Term)
    val resultExpression = resultCodeGenerator.generateResultExpression(
      resultExprs,
      new RowType(
        returnSchema.fieldTypeInfos,
        returnSchema.fieldNames.toArray),
      classOf[GenericRow])

    val resultCode = {
      addReusablePatternNames()

      // Code that appends the current event to every pattern list that is not pre-filled.
      def fillPatternLists(): String = {
        reusablePatternLists.filterNot(_._2.filled).map { patternList =>
          val patternName = patternList._1._2
          val listName = patternList._2.resultTerm
          if (patternName == ALL_PATTERN_VARIABLE) {
            j"""
               |$listName.add($eventNameTerm);
               |""".stripMargin
          } else {
            val escapedPatternName = EncodingUtils.escapeJava(patternName)

            j"""
               |if ($patternNameTerm.equals("$escapedPatternName")) {
               |  $listName.add($eventNameTerm);
               |}
               |""".stripMargin
          }
        }.mkString("\n")
      }

      // Code that appends the current pattern name to every classifier list that is not
      // pre-filled.
      def fillClassifierLists(): String = {
        reusableClassiferList.filterNot(_._2.filled).map { classiferList =>
          val listName = classiferList._2.resultTerm
          j"""
             |$listName.add($patternNameTerm);
             |""".stripMargin
        }.mkString("\n")
      }

      // The lists are read here, after the result expressions above have registered them.
      j"""
         |for (String $patternNameTerm : $patternNamesTerm) {
         |  $listTypeTerm $eventNameListTerm = ($listTypeTerm) $input1Term.get($patternNameTerm);
         |  if ($eventNameListTerm != null) {
         |    for ($eventTypeTerm $eventNameTerm : $eventNameListTerm) {
         |      ${fillPatternLists()}
         |      ${fillClassifierLists()}
         |      ${reuseInputUnboxingCode()}
         |      ${reusePerRecordCode()}
         |      ${resultExpression.code}
         |      $collectorTerm.collect(${resultExpression.resultTerm});
         |    }
         |  }
         |}
         |""".stripMargin
    }

    // Aggregates were collected while visiting the measures; emit their functions now.
    aggregatesPerVariable.values.foreach(_.generateAggFunction())

    GeneratedExpression("", "false", resultCode, null)
  }

  /**
    * Generates the expression for a condition and afterwards the functions of all aggregates
    * that were collected while visiting it.
    *
    * @param call the condition to generate code for
    * @return the generated condition expression
    */
  def generateCondition(call: RexNode): GeneratedExpression = {
    val conditionExpr = call.accept(this)
    for (aggBuilder <- aggregatesPerVariable.values) {
      aggBuilder.generateAggFunction()
    }
    conditionExpr
  }

  /**
    * Generates code for a call, handling the MATCH_RECOGNIZE specific operators (PREV/NEXT,
    * FIRST/LAST, CLASSIFIER, RUNNING/FINAL, aggregates and proctime) and delegating every
    * other operator to the parent generator.
    */
  override def visitCall(call: RexCall): GeneratedExpression = {
    // Generates the first operand with this generator.
    def visitFirstOperand(): GeneratedExpression = call.getOperands.get(0).accept(this)

    // Reads the second operand, which has to be a literal, as an int.
    def secondOperandAsInt(): Int = {
      val literal = call.getOperands.get(1).asInstanceOf[RexLiteral]
      checkedDownCast(literal.getValueAs(classOf[JLong]))
    }

    call.getOperator match {
      case PREV | NEXT =>
        // Only a physical offset of 0 is accepted.
        if (secondOperandAsInt() != 0) {
          throw new TableException("Flink does not support physical offsets within partition.")
        }
        updateOffsets(first = false, 0)
        val operandExpr = visitFirstOperand()
        resetOffsets()
        operandExpr

      case FIRST | LAST =>
        val logicalOffset = secondOperandAsInt()
        updateOffsets(call.getOperator == FIRST, logicalOffset)
        val operandExpr = visitFirstOperand()
        resetOffsets()
        operandExpr

      case CLASSIFIER =>
        findClassifierByLogicalPosition()

      case RUNNING if oneRowPerMatch =>
        // running is the same as final
        visitFirstOperand()

      case RUNNING =>
        setRunning(true)
        val operandExpr = visitFirstOperand()
        resetRunning()
        operandExpr

      case FINAL =>
        visitFirstOperand()

      case _: SqlAggFunction =>
        val variable = call.accept(new AggregationPatternVariableFinder)
          .getOrElse(throw new TableException("No pattern variable specified in aggregate"))

        // One builder per (running, variable); created on first use.
        val aggBuilder =
          aggregatesPerVariable.getOrElseUpdate((running, variable), new AggBuilder(variable))
        aggBuilder.generateDeduplicatedAggAccess(call)

      case ProctimeSqlFunction =>
        MatchCodeGenerator.generateProctimeTimestamp()

      case _ =>
        super.visitCall(call)
    }
  }

  /**
    * Extracts partition keys from any element of the match
    *
    * @param partitionKey partition key to be extracted
    * @return generated code for the given key
    */
  private def generatePartitionKeyAccess(partitionKey: RexInputRef): GeneratedExpression = {
    // The key is read from the shared key row at the index of the referenced input field.
    val keyRowExpr = generateKeyRow()
    val keyIndex = partitionKey.getIndex
    generateFieldAccess(ctx, keyRowExpr.resultType, keyRowExpr.resultTerm, keyIndex, nullCheck)
  }

  /**
    * Returns the expression of the row from which partition keys are read. On first use a
    * member is registered together with unboxing code that picks the first event of the first
    * non-empty event list of the match. The returned expression carries no code, because the
    * code is emitted through the reusable input unboxing expressions of the context.
    */
  private def generateKeyRow(): GeneratedExpression = {
    val keyRowExpr = ctx.getReusableInputUnboxingExprs(keyRowTerm, 0).getOrElse {
      val nullTerm = newName("isNull")

      ctx.addReusableMember(s"$eventTypeTerm $keyRowTerm;")

      val lookupCode =
        j"""
           |boolean $nullTerm = true;
           |for (java.util.Map.Entry entry : $input1Term.entrySet()) {
           |  java.util.List value = (java.util.List) entry.getValue();
           |  if (value != null && value.size() > 0) {
           |    $keyRowTerm = ($eventTypeTerm) value.get(0);
           |    $nullTerm = false;
           |    break;
           |  }
           |}
           |""".stripMargin

      val created = GeneratedExpression(keyRowTerm, nullTerm, lookupCode, input1Type)
      ctx.addReusableInputUnboxingExprs(keyRowTerm, 0, created)
      created
    }
    keyRowExpr.copy(code = NO_CODE)
  }

  /**
    * Generates the access to a field of a pattern variable. Inside an aggregate the field is
    * read from the aggregate input row. A reference to the all-pattern variable without any
    * offset inside a condition reads the current input row. Everything else is resolved
    * through the event lists.
    */
  override def visitPatternFieldRef(fieldRef: RexPatternFieldRef): GeneratedExpression = {
    val fieldIndex = fieldRef.getIndex

    def refersToCurrentInput: Boolean =
      fieldRef.getAlpha.equals(ALL_PATTERN_VARIABLE) &&
        currentPattern.isDefined && offset == 0 && !first

    if (isWithinAggExprState) {
      generateFieldAccess(ctx, input1Type, inputAggRowTerm, fieldIndex, nullCheck)
    } else if (refersToCurrentInput) {
      generateInputAccess(ctx, input1Type, input1Term, fieldIndex, nullableInput, nullCheck)
    } else {
      generatePatternFieldRef(fieldRef)
    }
  }

  /**
    * Generates, for a condition, the list of events of a pattern variable. The events are
    * taken from the condition context; the current input row is appended if the variable is
    * the pattern the condition belongs to or the all-pattern variable.
    *
    * @param patternName    pattern variable whose events are collected
    * @param currentPattern pattern the condition is applied to
    */
  private def generateDefinePatternVariableExp(
      patternName: String,
      currentPattern: String)
  : GeneratedPatternList = {
    val Seq(eventsTerm, eventTerm) = newNames(Seq("patternEvents", "event"))

    ctx.addReusableMember(s"java.util.List $eventsTerm;")

    val coversAllPatterns = patternName == ALL_PATTERN_VARIABLE

    val appendCurrentEvent =
      if (currentPattern == patternName || coversAllPatterns) {
        j"""
           |$eventsTerm.add($input1Term);
           |""".stripMargin
      } else {
        ""
      }

    val collectEvents =
      if (coversAllPatterns) {
        addReusablePatternNames()
        val nameTerm = newName("pattern")

        j"""
           |$eventsTerm = new java.util.ArrayList();
           |for (String $nameTerm : $patternNamesTerm) {
           |  for ($eventTypeTerm $eventTerm :
           |  $contextTerm.getEventsForPattern($nameTerm)) {
           |    $eventsTerm.add($eventTerm);
           |  }
           |}
           |""".stripMargin
      } else {
        val escapedName = EncodingUtils.escapeJava(patternName)
        j"""
           |$eventsTerm = new java.util.ArrayList();
           |for ($eventTypeTerm $eventTerm :
           |  $contextTerm.getEventsForPattern("$escapedName")) {
           |    $eventsTerm.add($eventTerm);
           |}
           |""".stripMargin
      }

    val initCode =
      j"""
         |$collectEvents
         |$appendCurrentEvent
         |""".stripMargin

    GeneratedPatternList(eventsTerm, initCode)
  }

  /**
    * Generates, for a measure, the list of events of a pattern variable. While the running
    * flag is set only an empty list is created and marked as not filled; otherwise the list
    * is taken from the match, for the all-pattern variable by concatenating the lists of all
    * patterns.
    *
    * @param patternName pattern variable whose events are collected
    */
  private def generateMeasurePatternVariableExp(patternName: String): GeneratedPatternList = {
    val Seq(eventsTerm, nameTerm) = newNames(Seq("patternEvents", "pattern"))

    ctx.addReusableMember(s"java.util.List $eventsTerm;")

    if (running) {
      val emptyListCode =
        j"""
           |$eventsTerm = new java.util.ArrayList();
           |""".stripMargin
      GeneratedPatternList(eventsTerm, emptyListCode, filled = false)
    } else if (patternName == ALL_PATTERN_VARIABLE) {
      addReusablePatternNames()

      val allEventsCode =
        j"""
           |$eventsTerm = new java.util.ArrayList();
           |for (String $nameTerm : $patternNamesTerm) {
           |  java.util.List rows = (java.util.List) $input1Term.get($nameTerm);
           |  if (rows != null) {
           |    $eventsTerm.addAll(rows);
           |  }
           |}
           |""".stripMargin
      GeneratedPatternList(eventsTerm, allEventsCode, filled = true)
    } else {
      val escapedName = EncodingUtils.escapeJava(patternName)

      val singlePatternCode =
        j"""
           |$eventsTerm = (java.util.List) $input1Term.get("$escapedName");
           |if ($eventsTerm == null) {
           |  $eventsTerm = java.util.Collections.emptyList();
           |}
           |""".stripMargin
      GeneratedPatternList(eventsTerm, singlePatternCode, filled = true)
    }
  }

  /**
    * Generates the lookup of the event a pattern variable refers to under the current
    * offset context. The generated row is null if the list holds too few events.
    *
    * @param patternFieldAlpha name of the pattern variable
    */
  private def findEventByLogicalPosition(
      patternFieldAlpha: String)
  : GeneratedExpression = {
    val rowTerm = newName("row")

    val eventsTerm = findEventsByPatternName(patternFieldAlpha).resultTerm
    // FIRST counts from the start of the list, everything else from its end.
    val indexCode =
      if (first) s"$offset" else s"$eventsTerm.size() - $offset - 1"

    ctx.addReusableMember(s"$eventTypeTerm $rowTerm;")

    val lookupCode =
      j"""
         |$rowTerm = null;
         |if ($eventsTerm.size() > $offset) {
         |  $rowTerm = (($eventTypeTerm) $eventsTerm.get($indexCode));
         |}
         |""".stripMargin

    GeneratedExpression(rowTerm, "", lookupCode, input1Type)
  }

  /**
    * Returns the event list of a pattern variable for the current running flag, generating
    * and caching it on first use. Inside a condition (a current pattern is defined) the list
    * is built from the condition context, otherwise from the match.
    */
  private def findEventsByPatternName(
      patternFieldAlpha: String): GeneratedPatternList = {
    val cacheKey = (running, patternFieldAlpha)
    reusablePatternLists.getOrElseUpdate(
      cacheKey,
      currentPattern
        .map(pattern => generateDefinePatternVariableExp(patternFieldAlpha, pattern))
        .getOrElse(generateMeasurePatternVariableExp(patternFieldAlpha)))
  }

  /**
    * Generates, for a condition, the list of classifiers: one entry with the pattern name for
    * every event already matched to a pattern up to and including the current one (in the
    * order of `patternNames`), followed by one entry for the current input row.
    *
    * @param currentPattern pattern the condition is applied to
    */
  private def generateDefineClassifierVariableExp(currentPattern: String)
      : GeneratedClassifierList = {
    val Seq(listName, eventNameTerm) = newNames(Seq("classifiers", "event"))

    ctx.addReusableMember(s"java.util.List $listName;")

    // Pattern names end up inside Java string literals of the generated code, so every one of
    // them, including the current pattern, has to be escaped.
    val escapedCurrentPattern = EncodingUtils.escapeJava(currentPattern)
    val addCurrent =
      j"""
         |$listName.add("$escapedCurrentPattern");
         |""".stripMargin

    // If the current pattern is unknown, indexOf yields -1 and no pattern is visited.
    val patternNamesToVisit = patternNames.take(patternNames.indexOf(currentPattern) + 1)

    val listCode = patternNamesToVisit.map { patternName =>
      val escapedPatternName = EncodingUtils.escapeJava(patternName)
      j"""
         |for ($eventTypeTerm $eventNameTerm :
         |  $contextTerm.getEventsForPattern("$escapedPatternName")) {
         |    $listName.add("$escapedPatternName");
         |}
         |""".stripMargin
    }.mkString("\n")

    val code =
      j"""
         |$listName = new java.util.ArrayList();
         |$listCode
         |$addCurrent
         |""".stripMargin

    GeneratedClassifierList(listName, code)
  }

  /**
    * Generates, for a measure, the list of classifiers. For running semantics outside of
    * ONE ROW PER MATCH only an empty list is created and marked as not filled; otherwise the
    * list receives the pattern name once per event of every pattern of the match.
    */
  private def generateMeasureClassifierVariableExp(): GeneratedClassifierList = {
    val Seq(classifiersTerm, nameTerm, eventTerm) =
      newNames(Seq("classifiers", "pattern", "event"))

    ctx.addReusableMember(s"java.util.List $classifiersTerm;")

    val filledPerEvent = !oneRowPerMatch && running
    if (filledPerEvent) {
      val emptyListCode =
        j"""
           |$classifiersTerm = new java.util.ArrayList();
           |""".stripMargin
      GeneratedClassifierList(classifiersTerm, emptyListCode, filled = false)
    } else {
      addReusablePatternNames()

      // The argument of get(...) is the generated loop variable, not a string literal.
      val lookupKey = EncodingUtils.escapeJava(nameTerm)
      val allClassifiersCode =
        j"""
           |$classifiersTerm = new java.util.ArrayList();
           |for (String $nameTerm : $patternNamesTerm) {
           |  java.util.List rows = (java.util.List)
           |    $input1Term.get($lookupKey);
           |  if (rows != null) {
           |    for ($eventTypeTerm $eventTerm : rows) {
           |      $classifiersTerm.add($nameTerm);
           |    }
           |  }
           |}
           |""".stripMargin
      GeneratedClassifierList(classifiersTerm, allClassifiersCode, filled = true)
    }
  }

  /**
    * Generates the lookup of the classifier under the current offset context. The generated
    * value is null (and flagged as such) if the classifier list holds too few entries.
    */
  private def findClassifierByLogicalPosition(): GeneratedExpression = {
    val Seq(isNullTerm, valueTerm) = newNames(Seq("isNull", "classifier"))

    val classifiersTerm = findClassifiers().resultTerm
    // FIRST counts from the start of the list, everything else from its end.
    val indexCode =
      if (first) s"$offset" else s"$classifiersTerm.size() - $offset - 1"

    val classifierType = StringType.INSTANCE
    val classifierTypeTerm = primitiveTypeTermForType(classifierType)
    val internalValue = genToInternal(ctx, classifierType, s"$classifiersTerm.get($indexCode)")

    ctx.addReusableMember(s"$classifierTypeTerm $valueTerm;")

    val lookupCode =
      j"""
         |boolean $isNullTerm = true;
         |$valueTerm = null;
         |if ($classifiersTerm.size() > $offset) {
         |  $valueTerm = ($classifierTypeTerm) $internalValue;
         |  $isNullTerm = false;
         |}
         |""".stripMargin

    GeneratedExpression(valueTerm, isNullTerm, lookupCode, classifierType)
  }

  /**
    * Returns the classifier list for the current running flag, generating and caching it on
    * first use. Inside a condition (a current pattern is defined) the list is built from the
    * condition context, otherwise from the match.
    */
  private def findClassifiers(): GeneratedClassifierList =
    reusableClassiferList.getOrElseUpdate(
      running,
      currentPattern
        .map(pattern => generateDefineClassifierVariableExp(pattern))
        .getOrElse(generateMeasureClassifierVariableExp()))

  /**
    * Generates the access to a field of the event a pattern variable refers to. The event
    * lookup is cached per (variable, first flag, offset) and created on first use.
    */
  private def generatePatternFieldRef(fieldRef: RexPatternFieldRef): GeneratedExpression = {
    val variableName = fieldRef.getAlpha
    val cacheKey = s"${EncodingUtils.escapeJava(variableName)}#$first"

    val eventExpr = getReusableInputUnboxingExprs(cacheKey, offset).getOrElse {
      val located = findEventByLogicalPosition(variableName)
      addReusableInputUnboxingExprs(cacheKey, offset, located)
      located
    }

    generateNullableInputFieldAccess(
      ctx,
      eventExpr.resultType,
      eventExpr.resultTerm,
      fieldRef.getIndex,
      nullCheck)
  }

  /**
    * Registers a per-record statement.
    *
    * Outside of `running` mode the statement goes to the code generator context; in `running`
    * mode it is collected in the local `reusablePerRecordStatements` instead.
    *
    * @param s the statement to register
    */
  private def addReusablePerRecordStatement(s: String): Unit = {
    if (!running) {
      ctx.addPerRecordStatement(s)
    } else {
      reusablePerRecordStatements.add(s)
    }
  }

  /**
    * Looks up a previously registered expression for the given input term and index.
    *
    * In `running` mode the local `reusableInputUnboxingExprs` is consulted, otherwise the
    * lookup is delegated to the code generator context.
    *
    * @param inputTerm first part of the lookup key
    * @param index second part of the lookup key
    * @return the registered expression, if any
    */
  private def getReusableInputUnboxingExprs(inputTerm: String, index: Int)
      : Option[GeneratedExpression] = {
    val key = (inputTerm, index)
    if (running) reusableInputUnboxingExprs.get(key)
    else ctx.getReusableInputUnboxingExprs(inputTerm, index)
  }

  /**
    * Registers an expression under the given input term and index so that later lookups
    * through [[getReusableInputUnboxingExprs]] can reuse it.
    *
    * In `running` mode the expression is stored in the local `reusableInputUnboxingExprs`,
    * otherwise it is handed to the code generator context.
    *
    * @param inputTerm first part of the key
    * @param index second part of the key
    * @param expr the expression to register
    */
  private def addReusableInputUnboxingExprs(
      inputTerm: String, index: Int, expr: GeneratedExpression): Unit = {
    if (running) {
      reusableInputUnboxingExprs.update((inputTerm, index), expr)
    } else {
      ctx.addReusableInputUnboxingExprs(inputTerm, index, expr)
    }
  }

  /**
    * Collects the aggregate calls applied to the rows of one pattern variable and generates
    * the code that evaluates them.
    *
    * Usage is two-phased: [[generateDeduplicatedAggAccess]] registers aggregate calls and
    * returns expressions reading their values, and [[generateAggFunction]] generates the
    * members that compute all registered aggregates.
    *
    * @param variable name of the pattern variable whose rows are aggregated; it is passed to
    *                 `findEventsByPatternName` to obtain those rows
    */
  class AggBuilder(variable: String) {

    // Aggregate calls registered so far, in registration order. The access code generated for
    // a call reads the field at the index the call takes in this buffer.
    private val aggregates = new mutable.ListBuffer[RexCall]()

    // Unique suffix used in the names of all members generated for this builder.
    private val variableUID = newName("variable")

    private val calculateAggFuncName = s"calculateAgg_$variableUID"

    /**
      * Returns an expression reading the value of the given aggregate call, registering the
      * call first if it has not been seen before.
      *
      * Expressions are cached in `reusableAggregationExpr` under the call's string form
      * combined with `running`, so equal calls share one generated access.
      *
      * @param aggCall the aggregate call to access
      * @return expression referring to the aggregate's result; it carries no code itself
      *         because the code is registered as per-record statements
      */
    def generateDeduplicatedAggAccess(aggCall: RexCall): GeneratedExpression = {
      val cacheKey = s"${aggCall.toString}#$running"
      reusableAggregationExpr.get(cacheKey) match {
        case Some(expr) =>
          expr

        case None =>
          // generateAggAccess relies on aggregates.size being the index of the new call,
          // so the call must be appended only afterwards.
          val exp: GeneratedExpression = generateAggAccess(aggCall)
          aggregates += aggCall
          reusableAggregationExpr(cacheKey) = exp
          addReusablePerRecordStatement(exp.code)
          exp.copy(code = NO_CODE)
      }
    }

    /**
      * Generates the per-record statements that compute the row of all aggregates for this
      * variable and extract the value of `aggCall` from it.
      *
      * @param aggCall the aggregate call whose value is extracted
      * @return expression naming the generated result and null variables
      */
    private def generateAggAccess(aggCall: RexCall): GeneratedExpression = {
      val singleAggResultTerm = newName("result")
      val singleAggNullTerm = newName("nullTerm")
      val singleAggResultType = FlinkTypeFactory.toTypeInfo(aggCall.`type`).toInternalType
      val primitiveSingleAggResultTypeTerm = primitiveTypeTermForType(singleAggResultType)
      val boxedSingleAggResultTypeTerm = boxedTypeTermForType(singleAggResultType)

      val allAggRowTerm = s"aggRow_$variableUID"

      val rowsForVariableCode = findEventsByPatternName(variable)
      val codeForAgg =
        j"""
           |$GENERIC_ROW $allAggRowTerm = $calculateAggFuncName(${rowsForVariableCode.resultTerm});
           |""".stripMargin

      addReusablePerRecordStatement(codeForAgg)

      val defaultValue = primitiveDefaultValue(singleAggResultType)
      val codeForSingleAgg = if (nullCheck) {
        j"""
           |boolean $singleAggNullTerm;
           |$primitiveSingleAggResultTypeTerm $singleAggResultTerm;
           |if ($allAggRowTerm.getField(${aggregates.size}) != null) {
           |  $singleAggResultTerm = ($boxedSingleAggResultTypeTerm) $allAggRowTerm
           |    .getField(${aggregates.size});
           |  $singleAggNullTerm = false;
           |} else {
           |  $singleAggNullTerm = true;
           |  $singleAggResultTerm = $defaultValue;
           |}
           |""".stripMargin
      } else {
        // Without null checks the value is unboxed directly. The null flag is still declared
        // (always false) because the returned expression names it as its null term; leaving
        // it undeclared would make any generated code that reads the null term uncompilable.
        j"""
           |boolean $singleAggNullTerm = false;
           |$primitiveSingleAggResultTypeTerm $singleAggResultTerm =
           |    ($boxedSingleAggResultTypeTerm) $allAggRowTerm.getField(${aggregates.size});
           |""".stripMargin
      }

      addReusablePerRecordStatement(codeForSingleAgg)

      GeneratedExpression(singleAggResultTerm, singleAggNullTerm, NO_CODE, singleAggResultType)
    }

    /**
      * Generates the aggregation handler for all registered aggregate calls together with the
      * members that feed it: a row transformation function and the calculate function called
      * by the access code produced in [[generateAggAccess]].
      */
    def generateAggFunction(): Unit = {
      val matchAgg = extractAggregatesAndExpressions

      val aggCalls = matchAgg.aggregations.map(a => AggregateCall.create(
        a.sqlAggFunction,
        false,
        false,
        a.exprIndices,
        -1,
        a.resultType,
        a.sqlAggFunction.getName))

      // No aggregate needs retraction.
      val needRetraction = matchAgg.aggregations.map(_ => false).toArray

      val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
      // Row type of the transformed input: one field per distinct aggregate operand.
      val inputRelType = typeFactory.createStructType(
        matchAgg.inputExprs.map(_.getType),
        matchAgg.inputExprs.indices.map(i => s"TMP$i"))

      val aggInfoList = AggregateUtil.transformToStreamAggregateInfoList(
        aggCalls,
        inputRelType,
        needRetraction,
        needInputCount = false,
        isStateBackendDataViews = false,
        needDistinctInfo = false)

      val inputFieldTypes = matchAgg.inputExprs
        .map(expr => FlinkTypeFactory.toInternalType(expr.getType))

      val aggsHandlerCodeGenerator = new AggsHandlerCodeGenerator(
        CodeGeneratorContext(new TableConfig, supportReference = true),
        relBuilder,
        inputFieldTypes,
        needRetract = false,
        needMerge = false,
        nullCheck = true,
        copyInputField = false)
      val generatedAggsHandler = aggsHandlerCodeGenerator.generateAggsHandler(
        s"AggFunction_$variableUID",
        aggInfoList)

      val generatedTerm = ctx.addReusableObject(generatedAggsHandler, "generatedAggHandler")
      val functionTerm = s"aggregator_$variableUID"

      // The handler is kept in a member that is instantiated from the generated class.
      val declareCode = s"private $AGGS_HANDLE_FUNCTION $functionTerm;"
      val initCode = s"$functionTerm = ($AGGS_HANDLE_FUNCTION) " +
        s"$generatedTerm.newInstance($CURRENT_CLASS_LOADER);"
      ctx.addReusableMember(declareCode, initCode)

      val transformFuncName = s"transformRowForAgg_$variableUID"
      val inputTransform: String = generateAggInputExprEvaluation(
        matchAgg.inputExprs,
        transformFuncName)

      generateAggCalculation(functionTerm, transformFuncName, inputTransform)
    }

    /**
      * Splits the registered aggregate calls into the aggregate functions and the distinct
      * operand expressions they are applied to.
      *
      * Operands are de-duplicated by their string form; every aggregate refers to its
      * operands through their index in the resulting expression list.
      */
    private def extractAggregatesAndExpressions: MatchAgg = {
      // Insertion-ordered so that the stored index equals the position in `values`.
      val inputRows = new mutable.LinkedHashMap[String, (RexNode, Int)]

      val singleAggregates = aggregates.map { aggCall =>
        val callsWithIndices = aggCall.operands.asScala.map(innerCall => {
          inputRows.get(innerCall.toString) match {
            case Some(x) =>
              x

            case None =>
              val callWithIndex = (innerCall, inputRows.size)
              inputRows(innerCall.toString) = callWithIndex
              callWithIndex
          }
        })

        SingleAggCall(
          aggCall.getOperator.asInstanceOf[SqlAggFunction],
          aggCall.`type`,
          callsWithIndices.map(callsWithIndice => Integer.valueOf(callsWithIndice._2)))
      }

      MatchAgg(singleAggregates, inputRows.values.map(_._1).toSeq)
    }

    /**
      * Registers the generated calculate function, which resets the accumulators, accumulates
      * every transformed input row and returns the handler's value, plus the open and close
      * statements of the aggregation handler.
      *
      * @param functionTerm name of the member holding the aggregation handler
      * @param transformFuncName name of the generated row transformation function
      * @param inputTransformFunc source of the row transformation function, emitted in front
      *                           of the calculate function
      */
    private def generateAggCalculation(
        functionTerm: String,
        transformFuncName: String,
        inputTransformFunc: String): Unit = {
      val code =
        j"""
           |$inputTransformFunc
           |
           |private $GENERIC_ROW $calculateAggFuncName(java.util.List input)
           |    throws Exception {
           |  $functionTerm.setAccumulators($functionTerm.createAccumulators());
           |  for ($BASE_ROW row : input) {
           |    $functionTerm.accumulate($transformFuncName(row));
           |  }
           |  $GENERIC_ROW result = ($GENERIC_ROW) $functionTerm.getValue();
           |  return result;
           |}
           |""".stripMargin

      ctx.addReusableMember(code)

      ctx.addReusableOpenStatement(
        s"$functionTerm.open(new $EXECUTION_CONTEXT_IMPL(null, getRuntimeContext()));")
      ctx.addReusableCloseStatement(s"$functionTerm.close();")
    }

    /**
      * Generates the source of a function that evaluates the aggregate operand expressions
      * against one input row and returns them as a new row, one field per expression.
      *
      * @param inputExprs operand expressions, in output field order
      * @param funcName name of the generated function
      * @return Java source of the generated function
      */
    private def generateAggInputExprEvaluation(
        inputExprs: Seq[RexNode],
        funcName: String): String = {
      val resultTerm = newName("result")

      // The flag is only meant to be set while the operand expressions are generated, so it
      // is reset in `finally` to not leak into later code generation if generation fails.
      isWithinAggExprState = true
      val exprs = try {
        inputExprs.zipWithIndex.map {
          case (inputExpr, outputIndex) =>
            val expr = generateExpression(inputExpr)
            s"""
               |  ${expr.code}
               |  if (${expr.nullTerm}) {
               |    $resultTerm.update($outputIndex, null);
               |  } else {
               |    $resultTerm.update($outputIndex, ${expr.resultTerm});
               |  }
         """.stripMargin
        }.mkString("\n")
      } finally {
        isWithinAggExprState = false
      }

      j"""
         |private $GENERIC_ROW $funcName($BASE_ROW $inputAggRowTerm) {
         |  $GENERIC_ROW $resultTerm = new $GENERIC_ROW(${inputExprs.size});
         |  $exprs
         |  return $resultTerm;
         |}
         |""".stripMargin
    }

    /**
      * One aggregate function together with its result type and the indices of its operands
      * in [[MatchAgg.inputExprs]].
      */
    private case class SingleAggCall(
        sqlAggFunction: SqlAggFunction,
        resultType: RelDataType,
        exprIndices: Seq[Integer]
    )

    /**
      * All aggregates of this builder and the distinct operand expressions they refer to.
      */
    private case class MatchAgg(
        aggregations: Seq[SingleAggCall],
        inputExprs: Seq[RexNode]
    )
  }
}

/**
  * Constants and helpers shared by the generated match code.
  */
object MatchCodeGenerator {

  /** The pattern variable name `*`. */
  val ALL_PATTERN_VARIABLE: String = "*"

  /** Class name of [[ExecutionContextImpl]] as used in generated code. */
  val EXECUTION_CONTEXT_IMPL: String = className[ExecutionContextImpl]

  /** Class name of [[AggsHandleFunction]] as used in generated code. */
  val AGGS_HANDLE_FUNCTION: String = className[AggsHandleFunction]

  /** Class name of [[GenericRow]] as used in generated code. */
  val GENERIC_ROW: String = className[GenericRow]

  /** Class name of [[BaseRow]] as used in generated code. */
  val BASE_ROW: String = className[BaseRow]

  /** Java expression yielding the context class loader of the current thread. */
  val CURRENT_CLASS_LOADER: String = "Thread.currentThread().getContextClassLoader()"

  /**
    * Generates an expression that stores `System.currentTimeMillis()` in a fresh `long`
    * variable.
    *
    * @return never-null expression of type `DataTypes.TIMESTAMP`
    */
  def generateProctimeTimestamp(): GeneratedExpression = {
    val timestampTerm = newName("result")
    val timestampCode = s"long $timestampTerm = System.currentTimeMillis();"
    GeneratedExpression(timestampTerm, NEVER_NULL, timestampCode, DataTypes.TIMESTAMP)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy