All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.flink.table.codegen.agg.DeclarativeAggCodeGen.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.flink.table.codegen.agg

import org.apache.calcite.tools.RelBuilder
import org.apache.flink.table.api.functions.DeclarativeAggregateFunction
import org.apache.flink.table.api.types.{DataType, DataTypes, InternalType}
import org.apache.flink.table.codegen.CodeGenUtils.primitiveTypeTermForType
import org.apache.flink.table.codegen.{CodeGeneratorContext, ExprCodeGenerator, GeneratedExpression}
import org.apache.flink.table.codegen.agg.AggsHandlerCodeGenerator._
import org.apache.flink.table.expressions._
import org.apache.flink.table.plan.util.AggregateInfo

/**
  * Generates code for aggregate functions that are specified using expressions
  * (see [[DeclarativeAggregateFunction]]).
  * The aggregate buffer is embedded inside of a larger shared aggregation buffer.
  *
  * Every buffer attribute of the function is mirrored by two generated member variables,
  * a value term (`agg<aggIndex>_<attributeName>`) and a null flag (`<valueTerm>_isNull`).
  * The members themselves are registered in the context by [[setAccumulator]].
  *
  * @param ctx the code gen context
  * @param aggInfo  the aggregate information
  * @param filterExpression filter argument access expression, none if no filter
  * @param mergedAccOffset the mergedAcc may come from local aggregate,
  *                        this is the first buffer offset in the row
  * @param aggBufferOffset  the offset in the buffers of this aggregate
  * @param aggBufferSize  the total size of aggregate buffers
  * @param inputTypes   the input field type infos
  * @param constantExprs  the constant expressions
  * @param relBuilder  the rel builder to translate expressions to calcite rex nodes
  */
class DeclarativeAggCodeGen(
    ctx: CodeGeneratorContext,
    aggInfo: AggregateInfo,
    filterExpression: Option[Expression],
    mergedAccOffset: Int,
    aggBufferOffset: Int,
    aggBufferSize: Int,
    inputTypes: Seq[InternalType],
    constantExprs: Seq[GeneratedExpression],
    relBuilder: RelBuilder)
  extends AggCodeGen {

  val function: DeclarativeAggregateFunction =
    aggInfo.function.asInstanceOf[DeclarativeAggregateFunction]

  /** Types of the buffer fields of this aggregate. */
  val bufferTypes: Array[DataType] = aggInfo.externalAccTypes

  /** Positions of this aggregate's buffer fields, starting at `aggBufferOffset`. */
  val bufferIndexes: Array[Int] = Array.range(
    aggBufferOffset, aggBufferOffset + bufferTypes.length)

  /** Names of the generated member variables holding the buffer values. */
  val bufferTerms: Array[String] = function.aggBufferAttributes
    .map(a => s"agg${aggInfo.aggIndex}_${a.name}").toArray

  /** Names of the generated member variables holding the buffer null flags. */
  val bufferNullTerms: Array[String] = bufferTerms.map(_ + "_isNull")

  val argIndexes: Array[Int] = aggInfo.argIndexes

  /** Argument types, looked up in the input field types followed by the constant types. */
  val argTypes: Array[InternalType] = {
    val types = inputTypes ++ constantExprs.map(_.resultType)
    argIndexes.map(types(_))
  }

  /**
    * Generates one expression per initial value expression of the function.
    *
    * @param generator the expression code generator to use
    * @return the generated initial value expressions, in declaration order
    */
  def createAccumulator(generator: ExprCodeGenerator): Seq[GeneratedExpression] = {
    generateInitialValues(generator)
  }

  /**
    * Registers the buffer member variables (value and null flag) in the context and generates
    * the code that copies every buffer field, accessed at its position in `bufferIndexes`,
    * into these members.
    *
    * @param generator the expression code generator to use
    * @return the generated assignment code
    */
  def setAccumulator(generator: ExprCodeGenerator): String = {
    val aggBufferAccesses = function.aggBufferAttributes.zipWithIndex
      .map { case (attr, index) =>
        ResolvedAggInputReference(
          attr.name, bufferIndexes(index), bufferTypes(index).toInternalType)
      }
      .map(expr => generator.generateExpression(expr.toRexNode(relBuilder)))

    val setters = aggBufferAccesses.zipWithIndex.map {
      case (access, index) =>
        val typeTerm = primitiveTypeTermForType(access.resultType)
        val memberName = bufferTerms(index)
        val memberNullTerm = bufferNullTerms(index)
        ctx.addReusableMember(s"private $typeTerm $memberName;")
        ctx.addReusableMember(s"private boolean $memberNullTerm;")
        s"""
           |${access.copyResultTermToTargetIfChanged(ctx, memberName)};
           |$memberNullTerm = ${access.nullTerm};
         """.stripMargin
    }

    setters.mkString("\n")
  }

  /**
    * Generates the code that assigns the initial values of the function to the buffer
    * member variables.
    *
    * @param generator the expression code generator to use
    * @return the generated reset code
    */
  override def resetAccumulator(generator: ExprCodeGenerator): String = {
    val codes = generateInitialValues(generator).zipWithIndex.map {
      case (init, index) =>
        val memberName = bufferTerms(index)
        val memberNullTerm = bufferNullTerms(index)
        s"""
           |${init.code}
           |$memberName = ${init.resultTerm};
           |$memberNullTerm = ${init.nullTerm};
         """.stripMargin
    }
    codes.mkString("\n")
  }

  /**
    * Exposes the buffer member variables as generated expressions. No code is attached to
    * them, only the value term, the null term and the internal type of each buffer field.
    *
    * @param generator the expression code generator (not used by this implementation)
    * @return one expression per buffer field
    */
  def getAccumulator(generator: ExprCodeGenerator): Seq[GeneratedExpression] = {
    bufferTypes.zipWithIndex.map { case (bufferType, index) =>
      GeneratedExpression(
        bufferTerms(index), bufferNullTerms(index), "", bufferType.toInternalType)
    }
  }

  /**
    * Generates the code that applies the accumulate expressions of the function to the buffer
    * member variables. If a filter expression is present, the code is wrapped into an `if`
    * on the filter result.
    *
    * @param generator the expression code generator to use
    * @return the generated accumulate code
    */
  def accumulate(generator: ExprCodeGenerator): String = {
    val resolver = inputResolver(generator)
    // the update expressions are generated before the filter expression
    val updates = generateBufferUpdates(generator, function.accumulateExpressions, resolver)
    guardWithFilter(generator, updates)
  }

  /**
    * Generates the code that applies the retract expressions of the function to the buffer
    * member variables. If a filter expression is present, the code is wrapped into an `if`
    * on the filter result.
    *
    * @param generator the expression code generator to use
    * @return the generated retract code
    */
  def retract(generator: ExprCodeGenerator): String = {
    val resolver = inputResolver(generator)
    // the update expressions are generated before the filter expression
    val updates = generateBufferUpdates(generator, function.retractExpressions, resolver)
    guardWithFilter(generator, updates)
  }

  /**
    * Generates the code that applies the merge expressions of the function to the buffer
    * member variables. Input references are resolved against the merged accumulator
    * (see `mergedAccOffset`). The filter expression is not applied here.
    *
    * @param generator the expression code generator to use
    * @return the generated merge code
    */
  def merge(generator: ExprCodeGenerator): String = {
    generateBufferUpdates(
      generator, function.mergeExpressions, resolveReference(isMerge = true))
  }

  /**
    * Generates the expression that computes the result value of the aggregate from the
    * buffer member variables.
    *
    * @param generator the expression code generator to use
    * @return the generated value expression
    */
  def getValue(generator: ExprCodeGenerator): GeneratedExpression = {
    val resolvedGetValueExpression = function.getValueExpression
      .postOrderTransform(resolveReference())
    generator.generateExpression(resolvedGetValueExpression.toRexNode(relBuilder))
  }

  /** Generates the initial value expressions of the function, in declaration order. */
  private def generateInitialValues(generator: ExprCodeGenerator): Seq[GeneratedExpression] = {
    function.initialValuesExpressions
      .map(expr => generator.generateExpression(expr.toRexNode(relBuilder)))
  }

  /**
    * Chooses the reference resolver for accumulate/retract: the distinct-merge variant is used
    * when the first input term of the generator starts with the distinct key term.
    */
  private def inputResolver(generator: ExprCodeGenerator)
  : PartialFunction[Expression, Expression] = {
    resolveReference(isDistinctMerge = generator.input1Term.startsWith(DISTINCT_KEY_TERM))
  }

  /**
    * Resolves and generates the given update expressions and emits, for the i-th expression,
    * the code that stores its result and null flag into the i-th buffer member variables.
    */
  private def generateBufferUpdates(
      generator: ExprCodeGenerator,
      expressions: Seq[Expression],
      resolver: PartialFunction[Expression, Expression]): String = {
    val exprs = expressions
      .map(_.postOrderTransform(resolver))
      .map(_.toRexNode(relBuilder)) // rex nodes
      .map(generator.generateExpression) // generated expressions

    val codes = exprs.zipWithIndex.map { case (expr, index) =>
      s"""
         |${expr.code}
         |${expr.copyResultTermToTargetIfChanged(ctx, bufferTerms(index))};
         |${bufferNullTerms(index)} = ${expr.nullTerm};
       """.stripMargin
    }

    codes.mkString("\n")
  }

  /** Wraps the given code into an `if` on the filter result, if a filter is defined. */
  private def guardWithFilter(generator: ExprCodeGenerator, body: String): String = {
    filterExpression match {
      case Some(expr) =>
        val generated = generator.generateExpression(expr.toRexNode(relBuilder))
        s"""
           |if (${generated.resultTerm}) {
           |  $body
           |}
         """.stripMargin
      case None =>
        body
    }
  }

  /**
    * Returns the given lookup result if it denotes a valid position and fails with a
    * descriptive message otherwise, instead of letting a later array access fail on -1.
    */
  private def requireResolved(index: Int, reference: String, scope: String): Int = {
    require(
      index >= 0,
      s"Cannot resolve reference '$reference' against the $scope of " +
        s"${function.getClass.getName}.")
    index
  }

  /**
    * Resolves the given expression to a [[NamedExpression]].
    *
    * @param isMerge this is called from merge() method
    * @param isDistinctMerge this is called from distinct merge method
    * @throws IllegalArgumentException if a reference is neither an operand nor a buffer
    *                                  attribute of the function
    */
  private def resolveReference(isMerge: Boolean = false, isDistinctMerge: Boolean = false)
  : PartialFunction[Expression, Expression] = {
    case input: UnresolvedFieldReference =>
      // We always use UnresolvedFieldReference to represent reference of input field.
      // In non-merge case, the input is the operand of the aggregate function.
      // In merge case, the input is the aggregate buffers sent by local aggregate.
      if (isMerge) {
        val localIndex = requireResolved(
          function.inputAggBufferAttributes.indexOf(input),
          input.name,
          "input aggregate buffer attributes")
        // in merge case, the input1 is mergedAcc
        ResolvedAggInputReference(
          input.name,
          mergedAccOffset + bufferIndexes(localIndex),
          bufferTypes(localIndex).toInternalType)
      } else {
        val localIndex = requireResolved(
          function.operands.indexOf(input), input.name, "operands")
        val inputIndex = argIndexes(localIndex)
        // index to constant
        if (inputIndex >= inputTypes.length) {
          val constant = constantExprs(inputIndex - inputTypes.length)
          // constant is reused as member variable
          ResolvedAggLocalReference(constant.resultTerm, constant.nullTerm, constant.resultType)
        } else if (isDistinctMerge) {
          // this is called from distinct merge
          if (function.inputCount == 1) {
            // the distinct key is a BoxedValue
            ResolvedDistinctKeyReference(input.name, argTypes(0))
          } else {
            // the distinct key is a BaseRow
            ResolvedAggInputReference(input.name, localIndex, argTypes(localIndex))
          }
        } else {
          // the input is the inputRow
          ResolvedAggInputReference(input.name, inputIndex, argTypes(localIndex))
        }
      }

    case buffer: UnresolvedAggBufferReference =>
      val localIndex = requireResolved(
        function.aggBufferAttributes.indexOf(buffer),
        buffer.toString,
        "aggregate buffer attributes")
      // buffer access is reused as member variable
      ResolvedAggLocalReference(
        bufferTerms(localIndex),
        bufferNullTerms(localIndex),
        bufferTypes(localIndex).toInternalType)
  }

  /**
    * Checks whether the methods required by the caller are available.
    * This implementation performs no check.
    */
  def checkNeededMethods(
      needAccumulate: Boolean = false,
      needRetract: Boolean = false,
      needMerge: Boolean = false,
      needReset: Boolean = false): Unit = {
    // skip the check for DeclarativeAggregateFunction for now
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy