All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.flink.table.codegen.CodeGenUtils.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.codegen

import java.lang.reflect.Method
import java.lang.{Boolean => JBoolean, Byte => JByte, Character => JChar, Double => JDouble, Float => JFloat, Integer => JInt, Long => JLong, Short => JShort}
import java.math.{BigDecimal => JBigDecimal}
import java.sql.{Date, Time, Timestamp}
import java.util.concurrent.atomic.AtomicInteger

import org.apache.calcite.avatica.util.ByteString
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.sql.SqlOperator
import org.apache.calcite.sql.`type`.SqlTypeName.{ROW => _, _}
import org.apache.calcite.sql.fun.SqlStdOperatorTable._
import org.apache.calcite.util.BuiltInMethod
import org.apache.commons.lang3.StringEscapeUtils
import org.apache.flink.api.common.InvalidProgramException
import org.apache.flink.api.common.functions.{FlatJoinFunction, FlatMapFunction, MapFunction}
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo._
import org.apache.flink.streaming.api.functions.ProcessFunction
import org.apache.flink.table.api.types._
import org.apache.flink.table.codegen.CodeGeneratorContext.{BASE_ROW_UTIL, BINARY_STRING}
import org.apache.flink.table.codegen.GeneratedExpression.NEVER_NULL
import org.apache.flink.table.codegen.calls.ScalarOperators._
import org.apache.flink.table.codegen.calls.{BinaryStringCallGen, BuiltInMethods, CurrentTimePointCallGen, FunctionGenerator}
import org.apache.flink.table.dataformat._
import org.apache.flink.table.errorcode.TableErrors
import org.apache.flink.table.functions.sql.ScalarSqlFunctions
import org.apache.flink.table.functions.sql.internal.{SqlRuntimeFilterBuilderFunction, SqlRuntimeFilterFunction, SqlThrowExceptionFunction}
import org.apache.flink.table.typeutils.TypeCheckUtils.{isNumeric, isTemporal, isTimeInterval}
import org.apache.flink.table.typeutils._
import org.apache.flink.table.util.Logging.CODE_LOG
import org.apache.flink.types.Row
import org.apache.flink.util.StringUtils

import org.codehaus.commons.compiler.{CompileException, ICookable}
import org.codehaus.janino.SimpleCompiler

import scala.collection.mutable.ListBuffer

object CodeGenUtils {

  // Shared counter used to make generated identifiers unique.
  private val nameCounter = new AtomicInteger

  /**
    * Returns `name` followed by a `$` and the next value of the shared counter.
    *
    * @param name prefix of the generated identifier
    * @return the suffixed identifier
    */
  def newName(name: String): String = {
    val id = nameCounter.getAndIncrement
    s"$name$$$id"
  }

  /**
    * Suffixes every given name with `$` and one shared counter value, so all returned
    * names carry the same id.
    *
    * @param names distinct name prefixes
    * @return the suffixed names, in input order
    * @throws IllegalArgumentException if `names` contains duplicates
    */
  def newNames(names: Seq[String]): Seq[String] = {
    require(names.distinct.length == names.length, "Duplicated names")
    val sharedId = nameCounter.getAndIncrement
    for (name <- names) yield s"$name$$$sharedId"
  }

  /**
    * Returns the canonical name of the runtime class of `T`, taken from the implicit Manifest.
    */
  def className[T](implicit m: Manifest[T]): String = {
    val runtimeClass = m.runtimeClass
    runtimeClass.getCanonicalName
  }

  /**
    * Returns true for STRING, array, map, row and generic types; false for everything else.
    */
  def needCopyForType(t: InternalType): Boolean = t match {
    case DataTypes.STRING | _: ArrayType | _: MapType | _: RowType | _: GenericType[_] => true
    case _ => false
  }

  /**
    * Returns true only when the given internal type is STRING.
    */
  def needCloneRefForType(t: InternalType): Boolean =
    DataTypes.STRING == t

  /**
    * Returns true only when the external type information derived from the given data type
    * is the BinaryString type info instance.
    */
  def needCloneRefForDataType(t: DataType): Boolean = {
    val externalTypeInfo = TypeConverters.createExternalTypeInfoFromDataType(t)
    externalTypeInfo match {
      case BinaryStringTypeInfo.INSTANCE => true
      case _ => false
    }
  }

  /**
    * Returns the Java type term used for the unboxed representation of the given type.
    *
    * The primitive term matters for generated casts: primitives can be cast directly,
    * e.g. `float a = 1.0f; byte b = (byte) a;`, whereas boxed values have to be unboxed
    * first, e.g. `Float a = 1.0f; Byte b = (byte)(float) a;`.
    *
    * Types without a primitive term fall back to [[boxedTypeTermForType]].
    */
  def primitiveTypeTermForType(t: InternalType): String = t match {
    // int-backed: INT, DATE, TIME and month intervals
    case DataTypes.INT | _: DateType | DataTypes.TIME | DataTypes.INTERVAL_MONTHS => "int"
    // long-backed: LONG, TIMESTAMP and millisecond intervals
    case DataTypes.LONG | _: TimestampType | DataTypes.INTERVAL_MILLIS => "long"
    case DataTypes.SHORT => "short"
    case DataTypes.BYTE => "byte"
    case DataTypes.FLOAT => "float"
    case DataTypes.DOUBLE => "double"
    case DataTypes.BOOLEAN => "boolean"
    case DataTypes.CHAR => "char"
    case other => boxedTypeTermForType(other)
  }

  /**
    * Returns whether the type is temporal; for now only temporal types use a primitive
    * as their representation.
    */
  def isInternalPrimitive(tpe: InternalType): Boolean = isTemporal(tpe)

  /**
    * Returns the Java type term of the external (user-facing) boxed class for a data type.
    *
    * Any other [[InternalType]] is delegated to [[boxedTypeTermForType]]. The match has no
    * catch-all, so a data type covered by none of the cases raises a `scala.MatchError`.
    */
  def externalBoxedTermForType(t: DataType): String = {
    def nameOf(clazz: Class[_]): String = clazz.getCanonicalName

    t match {
      case DataTypes.STRING => nameOf(classOf[String])
      case _: DecimalType => nameOf(classOf[JBigDecimal])
      case arrayType: ArrayType =>
        val elementTerm = if (arrayType.isPrimitive) {
          primitiveTypeTermForType(arrayType.getElementInternalType)
        } else {
          externalBoxedTermForType(arrayType.getElementType)
        }
        s"$elementTerm[]"
      case _: RowType => nameOf(classOf[Row])
      case _: MapType => nameOf(classOf[java.util.Map[_, _]])
      case _: TimestampType if t != DataTypes.INTERVAL_MILLIS => nameOf(classOf[Timestamp])
      case _: DateType if t != DataTypes.INTERVAL_MONTHS => nameOf(classOf[Date])
      case DataTypes.TIME => nameOf(classOf[Time])
      case internal: InternalType => boxedTypeTermForType(internal)
      case wrapped: TypeInfoWrappedDataType =>
        wrapped.getTypeInfo match {
          // From PrimitiveArrayTypeInfo we would get class "int[]", scala reflections
          // does not seem to like this, so the correct term is spelled out manually.
          case INT_PRIMITIVE_ARRAY_TYPE_INFO => "int[]"
          case LONG_PRIMITIVE_ARRAY_TYPE_INFO => "long[]"
          case SHORT_PRIMITIVE_ARRAY_TYPE_INFO => "short[]"
          case BYTE_PRIMITIVE_ARRAY_TYPE_INFO => "byte[]"
          case FLOAT_PRIMITIVE_ARRAY_TYPE_INFO => "float[]"
          case DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO => "double[]"
          case BOOLEAN_PRIMITIVE_ARRAY_TYPE_INFO => "boolean[]"
          case CHAR_PRIMITIVE_ARRAY_TYPE_INFO => "char[]"
          case _ => nameOf(wrapped.getTypeInfo.getTypeClass)
        }
    }
  }

  /**
    * Returns the Java type term of the boxed (object) representation of an internal type.
    *
    * The match has no catch-all, so a type covered by none of the cases raises a
    * `scala.MatchError`.
    */
  def boxedTypeTermForType(t: InternalType): String = t match {
    // DATE and TIME share the Integer box, TIMESTAMP shares the Long box
    case DataTypes.INT | _: DateType | DataTypes.TIME => classOf[JInt].getCanonicalName
    case DataTypes.LONG | _: TimestampType => classOf[JLong].getCanonicalName
    case DataTypes.SHORT => classOf[JShort].getCanonicalName
    case DataTypes.BYTE => classOf[JByte].getCanonicalName
    case DataTypes.FLOAT => classOf[JFloat].getCanonicalName
    case DataTypes.DOUBLE => classOf[JDouble].getCanonicalName
    case DataTypes.BOOLEAN => classOf[JBoolean].getCanonicalName
    case DataTypes.CHAR => classOf[JChar].getCanonicalName

    // internal data formats
    case DataTypes.STRING => BINARY_STRING
    case DataTypes.BYTE_ARRAY => "byte[]"
    case _: DecimalType => classOf[Decimal].getCanonicalName
    case _: ArrayType => classOf[BaseArray].getCanonicalName
    case _: MapType => classOf[BaseMap].getCanonicalName
    case _: RowType => classOf[BaseRow].getCanonicalName

    case generic: GenericType[_] =>
      val typeClass = generic.getTypeInfo.getTypeClass
      typeClass.getCanonicalName
  }

  /**
    * Returns the Java source text of the placeholder value assigned to a term of the given
    * type (for example when the value is null). Types not listed yield `"null"`.
    */
  def primitiveDefaultValue(t: InternalType): String = t match {
    case DataTypes.INT | DataTypes.BYTE | DataTypes.SHORT | _: DateType | DataTypes.TIME => "-1"
    case DataTypes.LONG | _: TimestampType => "-1L"
    case DataTypes.FLOAT => "-1.0f"
    case DataTypes.DOUBLE => "-1.0d"
    case DataTypes.BOOLEAN => "false"
    case DataTypes.STRING => s"$BINARY_STRING.EMPTY_UTF8"
    case DataTypes.CHAR => "'\\0'"
    case _ => "null"
  }

  /**
    * Checks whether `clazz` is already the internal representation of data type `t`, in which
    * case no data structure converter is needed.
    *
    * `Object` and `Row` are never internal (a Row class can only infer GenericType[Row]).
    * Any subclass of `BaseRow` is internal; otherwise the class must equal the type class of
    * the internal type information created for `t`.
    */
  def isInternalClass(clazz: Class[_], t: DataType): Boolean = {
    if (clazz == classOf[Object] || clazz == classOf[Row]) {
      false
    } else if (classOf[BaseRow].isAssignableFrom(clazz)) {
      true
    } else {
      clazz == TypeConverters.createInternalTypeInfoFromDataType(t).getTypeClass
    }
  }

  /**
    * Returns the method name qualified with the canonical name of its declaring class,
    * in the form `declaringClass.methodName`.
    */
  def qualifyMethod(method: Method): String = {
    val owner = method.getDeclaringClass.getCanonicalName
    s"$owner.${method.getName}"
  }

  /**
    * Returns the fully qualified Java term of an enum constant, in the form
    * `enumClass.CONSTANT`.
    *
    * The declaring class is used rather than `getClass`: a constant with its own body is an
    * instance of an anonymous subclass, whose canonical name is null.
    *
    * @param enum the enum constant to qualify
    * @return canonical name of the declaring enum class, a dot, and the constant name
    */
  def qualifyEnum(enum: Enum[_]): String =
    enum.getDeclaringClass.getCanonicalName + "." + enum.name()

  /**
    * Builds the Java call expression that converts an internal DATE, TIME or TIMESTAMP
    * value to its string form.
    *
    * @param t          internal type of the value; other types raise a `scala.MatchError`
    * @param resultTerm term holding the value to convert
    * @param zoneTerm   time zone term, only used for timestamps
    * @return the generated call expression
    */
  def internalToStringCode(t: InternalType,
                           resultTerm: String,
                           zoneTerm: String): String =
    t match {
      case DataTypes.DATE =>
        val dateToString = qualifyMethod(BuiltInMethod.UNIX_DATE_TO_STRING.method)
        s"$dateToString($resultTerm)"
      case DataTypes.TIME =>
        val timeToString = qualifyMethod(BuiltInMethods.UNIX_TIME_TO_STRING)
        s"$timeToString($resultTerm)"
      case _: TimestampType =>
        val timestampToString = qualifyMethod(BuiltInMethods.TIMESTAMP_TO_STRING)
        s"$timestampToString($resultTerm, 3, $zoneTerm)"
    }

  /**
    * Returns whether `term` equals the qualified term of the given enum constant.
    */
  def compareEnum(term: String, enum: Enum[_]): Boolean = {
    val qualified = qualifyEnum(enum)
    term == qualified
  }

  /**
    * Resolves the enum constant referenced by a generated expression.
    *
    * The constant name is the part of the result term after the last dot; the enum class is
    * the type class of the expression's result type, which must be a GenericType.
    */
  def getEnum(genExpr: GeneratedExpression): Enum[_] = {
    val constantName = genExpr.resultTerm.split('.').last
    val enumType = genExpr.resultType.asInstanceOf[GenericType[_]]
    enumValueOf(enumType.getTypeInfo.getTypeClass, constantName)
  }

  /**
    * Looks up the constant named `stringValue` in the enum class `cls`.
    *
    * @param cls         class that is treated as an enum class
    * @param stringValue name of the constant
    * @return the matching enum constant
    */
  def enumValueOf[T <: Enum[T]](cls: Class[_], stringValue: String): Enum[_] = {
    val enumClass = cls.asInstanceOf[Class[T]]
    val constant = Enum.valueOf(enumClass, stringValue)
    constant.asInstanceOf[Enum[_]]
  }

  // ----------------------------------------------------------------------------------------------

  /**
    * Ensures that the expression has a numeric result type.
    *
    * @param genExpr      expression whose result type is checked
    * @param operatorName operator name passed on to the error message
    * @throws CodeGenException if the result type is not numeric
    */
  def requireNumeric(genExpr: GeneratedExpression, operatorName: String): Unit = {
    val actualType = genExpr.resultType
    if (!TypeCheckUtils.isNumeric(actualType)) {
      val reason = s"Numeric expression type expected, but was '$actualType'."
      throw new CodeGenException(
        TableErrors.INST.sqlCodeGenOperatorParamError(reason, operatorName))
    }
  }

  /**
    * Ensures that the expression has a comparable result type.
    *
    * @param genExpr      expression whose result type is checked
    * @param operatorName operator name passed on to the error message
    * @throws CodeGenException if the result type is not comparable
    */
  def requireComparable(genExpr: GeneratedExpression, operatorName: String): Unit = {
    val actualType = genExpr.resultType
    if (!TypeCheckUtils.isComparable(actualType)) {
      val reason = s"Comparable type expected, but was '$actualType'."
      throw new CodeGenException(
        TableErrors.INST.sqlCodeGenOperatorParamError(reason, operatorName))
    }
  }

  /**
    * Ensures that the expression has a string result type.
    *
    * @param genExpr      expression whose result type is checked
    * @param operatorName operator name passed on to the error message
    * @throws CodeGenException if the result type is not a string type
    */
  def requireString(genExpr: GeneratedExpression, operatorName: String): Unit = {
    val isString = TypeCheckUtils.isString(genExpr.resultType)
    if (!isString) {
      val reason = "String expression type expected."
      throw new CodeGenException(
        TableErrors.INST.sqlCodeGenOperatorParamError(reason, operatorName))
    }
  }

  /**
    * Ensures that the expression has a boolean result type.
    *
    * @param genExpr      expression whose result type is checked
    * @param operatorName operator name passed on to the error message
    * @throws CodeGenException if the result type is not boolean
    */
  def requireBoolean(genExpr: GeneratedExpression, operatorName: String): Unit = {
    val isBoolean = TypeCheckUtils.isBoolean(genExpr.resultType)
    if (!isBoolean) {
      val reason = "Boolean expression type expected."
      throw new CodeGenException(
        TableErrors.INST.sqlCodeGenOperatorParamError(reason, operatorName))
    }
  }

  /**
    * Ensures that the expression has a temporal result type.
    *
    * @param genExpr      expression whose result type is checked
    * @param operatorName operator name passed on to the error message
    * @throws CodeGenException if the result type is not temporal
    */
  def requireTemporal(genExpr: GeneratedExpression, operatorName: String): Unit = {
    val temporal = TypeCheckUtils.isTemporal(genExpr.resultType)
    if (!temporal) {
      val reason = "Temporal expression type expected."
      throw new CodeGenException(
        TableErrors.INST.sqlCodeGenOperatorParamError(reason, operatorName))
    }
  }

  /**
    * Ensures that the expression has a time interval result type.
    *
    * @param genExpr      expression whose result type is checked
    * @param operatorName operator name passed on to the error message
    * @throws CodeGenException if the result type is not a time interval
    */
  def requireTimeInterval(genExpr: GeneratedExpression, operatorName: String): Unit = {
    val interval = TypeCheckUtils.isTimeInterval(genExpr.resultType)
    if (!interval) {
      val reason = "Interval expression type expected."
      throw new CodeGenException(
        TableErrors.INST.sqlCodeGenOperatorParamError(reason, operatorName))
    }
  }

  /**
    * Ensures that the expression has an array result type.
    *
    * @param genExpr      expression whose result type is checked
    * @param operatorName operator name passed on to the error message
    * @throws CodeGenException if the result type is not an array
    */
  def requireArray(genExpr: GeneratedExpression, operatorName: String): Unit = {
    val isArray = TypeCheckUtils.isArray(genExpr.resultType)
    if (!isArray) {
      val reason = "Array expression type expected."
      throw new CodeGenException(
        TableErrors.INST.sqlCodeGenOperatorParamError(reason, operatorName))
    }
  }

  /**
    * Ensures that the expression has a map result type.
    *
    * @param genExpr      expression whose result type is checked
    * @param operatorName operator name passed on to the error message
    * @throws CodeGenException if the result type is not a map
    */
  def requireMap(genExpr: GeneratedExpression, operatorName: String): Unit =
    if (!TypeCheckUtils.isMap(genExpr.resultType)) {
      throw new CodeGenException(
        TableErrors.INST.sqlCodeGenOperatorParamError(
          // the message used to say "Array", copied from requireArray
          "Map expression type expected.",
          operatorName))
    }

  /**
    * Ensures that the expression has an integer result type.
    *
    * @param genExpr      expression whose result type is checked
    * @param operatorName operator name passed on to the error message
    * @throws CodeGenException if the result type is not an integer
    */
  def requireInteger(genExpr: GeneratedExpression, operatorName: String): Unit = {
    val isInteger = TypeCheckUtils.isInteger(genExpr.resultType)
    if (!isInteger) {
      val reason = "Integer expression type expected."
      throw new CodeGenException(
        TableErrors.INST.sqlCodeGenOperatorParamError(reason, operatorName))
    }
  }

  /**
    * Ensures that the expression has a list result type.
    *
    * @param genExpr      expression whose result type is checked
    * @param operatorName operator name passed on to the error message
    * @throws CodeGenException if the result type is not a list
    */
  def requireList(genExpr: GeneratedExpression, operatorName: String): Unit = {
    val isList = TypeCheckUtils.isList(genExpr.resultType)
    if (!isList) {
      val reason = "List expression type expected."
      throw new CodeGenException(
        TableErrors.INST.sqlCodeGenOperatorParamError(reason, operatorName))
    }
  }

  /**
    * Generates a literal expression that represents SQL NULL for the given type.
    *
    * The result term is the type's default value cast to its primitive type term and the
    * null term is the constant `"true"`.
    *
    * @param resultType type of the null literal
    * @param nullCheck  whether null checking is enabled
    * @throws CodeGenException if `nullCheck` is disabled
    */
  def generateNullLiteral(
      resultType: InternalType,
      nullCheck: Boolean): GeneratedExpression = {
    val placeholder = primitiveDefaultValue(resultType)
    val typeTerm = primitiveTypeTermForType(resultType)
    if (!nullCheck) {
      throw new CodeGenException("Null literals are not allowed if nullCheck is disabled.")
    }
    val castedTerm = s"(($typeTerm)$placeholder)"
    GeneratedExpression(castedTerm, "true", "", resultType, literal = true)
  }

  /**
    * Generates a literal expression that is never null.
    *
    * @param literalType  type of the literal
    * @param literalCode  Java source text of the literal, cast to the primitive type term
    * @param literalValue value stored in the expression as its literal value
    * @param nullCheck    not used by this method
    * @return literal expression whose null term is the constant `"false"`
    */
  def generateNonNullLiteral(
      literalType: InternalType,
      literalCode: String,
      literalValue: Any,
      nullCheck: Boolean): GeneratedExpression = {
    val typeTerm = primitiveTypeTermForType(literalType)
    val castedTerm = s"(($typeTerm)$literalCode)"
    GeneratedExpression(
      castedTerm, "false", "", literalType, literal = true, literalValue = literalValue)
  }

  /**
    * Generates the expression for a literal, dispatching on the SQL type name of
    * `literalRelDataType`.
    *
    * A null `literalValue` is delegated to [[generateNullLiteral]]. Decimal, string and
    * binary literals are registered with `ctx` as reusable members and referenced by their
    * field term; all other supported types are rendered inline as Java literals.
    *
    * @param ctx                 code generator context that receives reusable members
    * @param literalRelDataType  Calcite type of the literal (type name, precision, scale)
    * @param literalInternalType internal type of the resulting expression
    * @param literalValue        literal value, or null for a null literal
    * @param nullCheck           whether null checking is enabled
    * @return the generated literal expression
    * @throws CodeGenException if the type is not supported, an interval value is out of
    *                          range, or a null literal is requested with `nullCheck` disabled
    */
  def generateLiteral(
      ctx: CodeGeneratorContext,
      literalRelDataType: RelDataType,
      literalInternalType: InternalType,
      literalValue: Any,
      nullCheck: Boolean): GeneratedExpression = {
    if (literalValue == null) {
      return generateNullLiteral(literalInternalType, nullCheck)
    }
    // non-null values
    literalRelDataType.getSqlTypeName match {

      case BOOLEAN =>
        generateNonNullLiteral(literalInternalType, literalValue.toString, literalValue, nullCheck)

      case TINYINT =>
        val decimal = BigDecimal(literalValue.asInstanceOf[JBigDecimal])
        generateNonNullLiteral(
          literalInternalType,
          decimal.byteValue().toString,
          decimal.byteValue(), nullCheck)

      case SMALLINT =>
        val decimal = BigDecimal(literalValue.asInstanceOf[JBigDecimal])
        generateNonNullLiteral(
          literalInternalType,
          decimal.shortValue().toString,
          decimal.shortValue(), nullCheck)

      case INTEGER =>
        val decimal = BigDecimal(literalValue.asInstanceOf[JBigDecimal])
        generateNonNullLiteral(
          literalInternalType,
          decimal.intValue().toString,
          decimal.intValue(), nullCheck)

      case BIGINT =>
        val decimal = BigDecimal(literalValue.asInstanceOf[JBigDecimal])
        generateNonNullLiteral(
          literalInternalType,
          decimal.longValue().toString + "L",
          decimal.longValue(), nullCheck)

      case FLOAT =>
        val floatValue = literalValue.asInstanceOf[JBigDecimal].floatValue()
        floatValue match {
          // NaN is never == to itself, so a `case Float.NaN` pattern cannot match;
          // test with isNaN instead
          case f if f.isNaN => generateNonNullLiteral(
            literalInternalType, "java.lang.Float.NaN", Float.NaN, nullCheck)
          case Float.NegativeInfinity =>
            generateNonNullLiteral(
              literalInternalType,
              "java.lang.Float.NEGATIVE_INFINITY",
              Float.NegativeInfinity, nullCheck)
          case Float.PositiveInfinity => generateNonNullLiteral(
            literalInternalType,
            "java.lang.Float.POSITIVE_INFINITY",
            Float.PositiveInfinity, nullCheck)
          case _ => generateNonNullLiteral(
            literalInternalType,
            floatValue.toString + "f",
            floatValue,
            nullCheck)
        }

      case DOUBLE =>
        val doubleValue = literalValue.asInstanceOf[JBigDecimal].doubleValue()
        doubleValue match {
          // same as for FLOAT: NaN has to be detected with isNaN
          case d if d.isNaN => generateNonNullLiteral(
            literalInternalType, "java.lang.Double.NaN", Double.NaN, nullCheck)
          case Double.NegativeInfinity =>
            generateNonNullLiteral(
              literalInternalType,
              "java.lang.Double.NEGATIVE_INFINITY",
              Double.NegativeInfinity, nullCheck)
          case Double.PositiveInfinity =>
            generateNonNullLiteral(
              literalInternalType,
              "java.lang.Double.POSITIVE_INFINITY",
              Double.PositiveInfinity, nullCheck)
          case _ => generateNonNullLiteral(
            literalInternalType, doubleValue.toString + "d", doubleValue, nullCheck)
        }
      case DECIMAL =>
        // the decimal is created once as a reusable member and referenced by its field term
        val precision = literalRelDataType.getPrecision
        val scale = literalRelDataType.getScale
        val fieldTerm = newName("decimal")
        val fieldDecimal =
          s"""
             |${classOf[Decimal].getCanonicalName} $fieldTerm =
             |    ${Decimal.Ref.castFrom}("${literalValue.toString}", $precision, $scale);
             |""".stripMargin
        ctx.addReusableMember(fieldDecimal)
        generateNonNullLiteral(
          literalInternalType,
          fieldTerm,
          Decimal.fromBigDecimal(literalValue.asInstanceOf[JBigDecimal], precision, scale),
          nullCheck)

      case VARCHAR | CHAR =>
        // the Java-escaped text is registered as a reusable string constant
        val escapedValue = StringEscapeUtils.ESCAPE_JAVA.translate(literalValue.toString)
        val field = ctx.addReusableStringConstants(escapedValue)
        generateNonNullLiteral(
          literalInternalType,
          field,
          BinaryString.fromString(escapedValue),
          nullCheck)
      case VARBINARY | BINARY =>
        val bytesVal = literalValue.asInstanceOf[ByteString].getBytes
        val fieldTerm = ctx.addReusableObject(bytesVal, "binary",
                                              bytesVal.getClass.getCanonicalName)
        generateNonNullLiteral(
          literalInternalType,
          fieldTerm,
          BinaryString.fromBytes(bytesVal),
          nullCheck)
      case SYMBOL =>
        generateSymbol(literalValue.asInstanceOf[Enum[_]])

      case DATE =>
        generateNonNullLiteral(literalInternalType, literalValue.toString, literalValue, nullCheck)

      case TIME =>
        generateNonNullLiteral(literalInternalType, literalValue.toString, literalValue, nullCheck)

      case TIMESTAMP =>
        // Hack
        // Currently, in RexLiteral/SqlLiteral(Calcite), TimestampString has no time zone.
        // TimeString, DateString TimestampString are treated as UTC time/(unix time)
        // when they are converted/formatted/validated
        // Here, we adjust millis before Calcite solve TimeZone perfectly
        // Note: only the generated code uses the adjusted millis; the stored literal value
        // stays the unadjusted input.
        val millis = literalValue.asInstanceOf[Long]
        val adjustedValue = millis - ctx.getTableConfig.getTimeZone.getOffset(millis)
        generateNonNullLiteral(
          literalInternalType, adjustedValue.toString + "L", literalValue, nullCheck)
      case typeName if YEAR_INTERVAL_TYPES.contains(typeName) =>
        val decimal = BigDecimal(literalValue.asInstanceOf[JBigDecimal])
        if (decimal.isValidInt) {
          generateNonNullLiteral(
            literalInternalType,
            decimal.intValue().toString,
            decimal.intValue(), nullCheck)
        } else {
          throw new CodeGenException(
            s"Decimal '$decimal' can not be converted to interval of months.")
        }

      case typeName if DAY_INTERVAL_TYPES.contains(typeName) =>
        val decimal = BigDecimal(literalValue.asInstanceOf[JBigDecimal])
        if (decimal.isValidLong) {
          generateNonNullLiteral(
            literalInternalType,
            decimal.longValue().toString + "L",
            decimal.longValue(), nullCheck)
        } else {
          throw new CodeGenException(
            s"Decimal '$decimal' can not be converted to interval of milliseconds.")
        }

      case t =>
        throw new CodeGenException(s"Type not supported: $t")
    }
  }

  /**
    * Wraps `code` into a never-null expression, cast to the primitive type term of `t`.
    *
    * @param t         type of the field
    * @param code      Java source text producing the field value
    * @param nullCheck not used by this method
    */
  def generateNonNullField(
      t: InternalType,
      code: String,
      nullCheck: Boolean): GeneratedExpression = {
    val typeTerm = primitiveTypeTermForType(t)
    val castedTerm = s"(($typeTerm) $code)"
    GeneratedExpression(castedTerm, "false", "", t)
  }

  /**
    * Generates a never-null expression that refers to the given enum constant by its
    * qualified term, typed as a GenericType of the enum's declaring class.
    */
  def generateSymbol(enum: Enum[_]): GeneratedExpression = {
    val symbolTerm = qualifyEnum(enum)
    GeneratedExpression(symbolTerm, "false", "", new GenericType(enum.getDeclaringClass))
  }

  /**
    * Generates an expression that reads the current processing time from the timer service
    * of the given context term into a reusable `long` field.
    *
    * @param contextTerm term of the context object that provides `timerService()`
    * @param ctx         code generator context that owns the reusable field
    * @return a never-null TIMESTAMP expression
    */
  def generateProctimeTimestamp(
    contextTerm: String,
    ctx: CodeGeneratorContext): GeneratedExpression = {
    val proctimeTerm = ctx.newReusableField("result", "long")
    val assignment =
      s"$proctimeTerm = $contextTerm.timerService().currentProcessingTime();".trim
    GeneratedExpression(proctimeTerm, NEVER_NULL, assignment, DataTypes.TIMESTAMP)
  }

  /**
    * Generates the current-timestamp expression by delegating to a
    * `CurrentTimePointCallGen` with no operands and TIMESTAMP as the result type.
    */
  def generateCurrentTimestamp(
      ctx: CodeGeneratorContext): GeneratedExpression = {
    val callGen = new CurrentTimePointCallGen(false)
    callGen.generate(ctx, Seq(), DataTypes.TIMESTAMP, false)
  }

  /**
    * Generates an expression that reads the rowtime timestamp from the given context term.
    *
    * The generated code throws a RuntimeException when the timestamp is null, so the null
    * term is always set to false afterwards.
    *
    * @param contextTerm term of the context object that provides `timestamp()`
    * @param ctx         code generator context that owns the reusable fields
    * @return a ROWTIME_INDICATOR expression
    */
  def generateRowtimeAccess(
      contextTerm: String,
      ctx: CodeGeneratorContext): GeneratedExpression = {
    val fieldNames = Seq("result", "isNull")
    val fieldTypes = Seq("Long", "boolean")
    val Seq(timestampTerm, isNullTerm) = ctx.newReusableFields(fieldNames, fieldTypes)

    val readTimestamp =
      s"""
         |$timestampTerm = $contextTerm.timestamp();
         |if ($timestampTerm == null) {
         |  throw new RuntimeException("Rowtime timestamp is null. Please make sure that a " +
         |    "proper TimestampAssigner is defined and the stream environment uses the EventTime " +
         |    "time characteristic.");
         |}
         |$isNullTerm = false;
       """.stripMargin.trim

    GeneratedExpression(timestampTerm, isNullTerm, readTimestamp, DataTypes.ROWTIME_INDICATOR)
  }

  /**
    * Generates the access to field `index` of the input `inputTerm`, reusing the access
    * expression registered in `ctx` when the same input field was accessed before.
    *
    * @param ctx           code generator context holding the reusable unboxing expressions
    * @param inputType     type of the input
    * @param inputTerm     term of the input
    * @param index         index of the accessed field
    * @param nullableInput whether the input itself may be null
    * @param nullCheck     whether null checking is enabled
    * @param fieldCopy     whether the accessed field should be copied
    * @return an expression carrying only the result and null terms, with empty code
    */
  def generateInputAccess(
      ctx: CodeGeneratorContext,
      inputType: InternalType,
      inputTerm: String,
      index: Int,
      nullableInput: Boolean,
      nullCheck: Boolean,
      fieldCopy: Boolean = false): GeneratedExpression = {
    // reuse the already generated access and unboxing code if there is one,
    // otherwise generate it now and register it for later reuse
    val accessExpr = ctx.getReusableInputUnboxingExprs(inputTerm, index).getOrElse {
      val generated =
        if (nullableInput) {
          generateNullableInputFieldAccess(ctx, inputType, inputTerm, index, nullCheck, fieldCopy)
        } else {
          generateFieldAccess(ctx, inputType, inputTerm, index, nullCheck, fieldCopy)
        }
      ctx.addReusableInputUnboxingExprs(inputTerm, index, generated)
      generated
    }
    // hide the generated code as it will be executed only once
    GeneratedExpression(accessExpr.resultTerm, accessExpr.nullTerm, "", accessExpr.resultType)
  }

  /**
    * Generates a field access expression. The difference between this method and the
    * overload without `fieldCopy` is that, when `fieldCopy` is true, the result of the
    * access is passed through `copyResultIfNeeded`.
    *
    * NOTE: Please set `fieldCopy` to true when the result will be buffered.
    */
  def generateFieldAccess(
      ctx: CodeGeneratorContext,
      inputType: InternalType,
      inputTerm: String,
      index: Int,
      nullCheck: Boolean,
      fieldCopy: Boolean): GeneratedExpression = {
    val accessExpr = generateFieldAccess(ctx, inputType, inputTerm, index, nullCheck)
    if (!fieldCopy) accessExpr else accessExpr.copyResultIfNeeded(ctx, fieldCopy)
  }

  /**
    * Generates the code that reads a field from an input.
    *
    * For a row input, field `index` is read from the row into reusable fields; with
    * `nullCheck` enabled the read is guarded by `isNullAt`. Any other input type is treated
    * as the value itself: the input term is cast to its boxed type and unboxed.
    *
    * @param ctx       code generator context that owns the reusable fields
    * @param inputType type of the input
    * @param inputTerm term of the input
    * @param index     index of the field; only used for row inputs
    * @param nullCheck whether null checking is enabled
    */
  def generateFieldAccess(
      ctx: CodeGeneratorContext,
      inputType: InternalType,
      inputTerm: String,
      index: Int,
      nullCheck: Boolean): GeneratedExpression =
    inputType match {
      case rowType: RowType =>
        val fieldType = rowType.getFieldTypes()(index).toInternalType
        val typeTerm = primitiveTypeTermForType(fieldType)
        val placeholder = primitiveDefaultValue(fieldType)
        val readCode = baseRowFieldReadAccess(ctx, index.toString, inputTerm, fieldType)
        val Seq(valueTerm, isNullTerm) = ctx.newReusableFields(
          Seq("field", "isNull"),
          Seq(typeTerm, "boolean"))
        val accessCode =
          if (!nullCheck) {
            // the field is read unconditionally when null checking is disabled
            s"""
             |$isNullTerm = false;
             |$valueTerm = $readCode;
           """.stripMargin
          } else {
            s"""
               |$isNullTerm = $inputTerm.isNullAt($index);
               |$valueTerm = $placeholder;
               |if (!$isNullTerm) {
               |  $valueTerm = $readCode;
               |}
             """.stripMargin.trim
          }
        GeneratedExpression(valueTerm, isNullTerm, accessCode, fieldType)

      case _ =>
        val boxedTerm = boxedTypeTermForType(inputType)
        val castedInput = s"($boxedTerm) $inputTerm"
        generateInputFieldUnboxing(inputType, castedInput, nullCheck, ctx)
    }

  /**
    * Generates a field access for an input that may itself be null.
    *
    * The generated code first sets the result to the type's default value and the null
    * term to true, and only runs the field access when the input term is not null.
    *
    * @param ctx       code generator context that owns the reusable fields
    * @param inputType type of the input
    * @param inputTerm term of the input
    * @param index     index of the field; only used for row inputs
    * @param nullCheck whether null checking is enabled
    * @param fieldCopy whether the accessed field should be copied
    */
  def generateNullableInputFieldAccess(
      ctx: CodeGeneratorContext,
      inputType: InternalType,
      inputTerm: String,
      index: Int,
      nullCheck: Boolean,
      fieldCopy: Boolean = false): GeneratedExpression = {

    // for rows the accessed field determines the result type, otherwise the input itself
    val accessedType = inputType match {
      case rowType: RowType => rowType.getFieldTypes()(index).toInternalType
      case _ => inputType
    }
    val typeTerm = primitiveTypeTermForType(accessedType)
    val placeholder = primitiveDefaultValue(accessedType)

    val Seq(valueTerm, isNullTerm) = ctx.newReusableFields(
      Seq("result", "isNull"),
      Seq(typeTerm, "boolean"))
    val inner = generateFieldAccess(ctx, inputType, inputTerm, index, nullCheck, fieldCopy)

    val guardedCode =
      s"""
         |$valueTerm = $placeholder;
         |$isNullTerm = true;
         |if ($inputTerm != null) {
         |  ${inner.code}
         |  $valueTerm = ${inner.resultTerm};
         |  $isNullTerm = ${inner.nullTerm};
         |}
         |""".stripMargin.trim

    GeneratedExpression(valueTerm, isNullTerm, guardedCode, accessedType)
  }

  /**
   * Converts the external boxed format to an internal, mostly primitive field representation.
   * Wrapper types are auto-unboxed to their corresponding primitive type (Integer -> int).
   *
   * @param fieldType type of field
   * @param fieldTerm expression term of field to be unboxed
   * @param nullCheck whether to check null
   * @param ctx code generator context that owns the reusable fields
   * @return internal unboxed field representation
   */
  def generateInputFieldUnboxing(
      fieldType: InternalType,
      fieldTerm: String,
      nullCheck: Boolean,
      ctx: CodeGeneratorContext): GeneratedExpression = {

    val typeTerm = primitiveTypeTermForType(fieldType)
    val placeholder = primitiveDefaultValue(fieldType)

    val Seq(valueTerm, isNullTerm) = ctx.newReusableFields(
      Seq("result", "isNull"),
      Seq(typeTerm, "boolean"))

    val unboxingCode =
      if (!nullCheck) {
        // without null checking only the value is assigned; the null term is left untouched
        s"""
           |$valueTerm = $fieldTerm;
           |""".stripMargin.trim
      } else {
        s"""
           |$isNullTerm = $fieldTerm == null;
           |$valueTerm = $placeholder;
           |if (!$isNullTerm) {
           |  $valueTerm = $fieldTerm;
           |}
           |""".stripMargin.trim
      }

    GeneratedExpression(valueTerm, isNullTerm, unboxingCode, fieldType)
  }

  def generateCallExpression(
      ctx: CodeGeneratorContext,
      operator: SqlOperator,
      operands: Seq[GeneratedExpression],
      resultType: InternalType,
      nullCheck: Boolean): GeneratedExpression = {
    operator match {
      // arithmetic
      case PLUS if isNumeric(resultType) =>
        val left = operands.head
        val right = operands(1)
        requireNumeric(left, operator.getName)
        requireNumeric(right, operator.getName)
        generateArithmeticOperator(ctx, "+", nullCheck, resultType, left, right)

      case PLUS | DATETIME_PLUS if isTemporal(resultType) =>
        val left = operands.head
        val right = operands(1)
        requireTemporal(left, operator.getName)
        requireTemporal(right, operator.getName)
        generateTemporalPlusMinus(ctx, plus = true, nullCheck, resultType, left, right)

      case MINUS if isNumeric(resultType) =>
        val left = operands.head
        val right = operands(1)
        requireNumeric(left, operator.getName)
        requireNumeric(right, operator.getName)
        generateArithmeticOperator(ctx, "-", nullCheck, resultType, left, right)

      case MINUS | MINUS_DATE if isTemporal(resultType) =>
        val left = operands.head
        val right = operands(1)
        requireTemporal(left, operator.getName)
        requireTemporal(right, operator.getName)
        generateTemporalPlusMinus(ctx, plus = false, nullCheck, resultType, left, right)

      case MULTIPLY if isNumeric(resultType) =>
        val left = operands.head
        val right = operands(1)
        requireNumeric(left, operator.getName)
        requireNumeric(right, operator.getName)
        generateArithmeticOperator(ctx, "*", nullCheck, resultType, left, right)

      case MULTIPLY if isTimeInterval(resultType) =>
        val left = operands.head
        val right = operands(1)
        requireTimeInterval(left, operator.getName)
        requireNumeric(right, operator.getName)
        generateArithmeticOperator(ctx, "*", nullCheck, resultType, left, right)

      case ScalarSqlFunctions.DIVIDE | DIVIDE_INTEGER if isNumeric(resultType) =>
        val left = operands.head
        val right = operands(1)
        requireNumeric(left, operator.getName)
        requireNumeric(right, operator.getName)
        generateArithmeticOperator(ctx, "/", nullCheck, resultType, left, right)

      case MOD if isNumeric(resultType) =>
        val left = operands.head
        val right = operands(1)
        requireNumeric(left, operator.getName)
        requireNumeric(right, operator.getName)
        generateArithmeticOperator(ctx, "%", nullCheck, resultType, left, right)

      case UNARY_MINUS if isNumeric(resultType) =>
        val operand = operands.head
        requireNumeric(operand, operator.getName)
        generateUnaryArithmeticOperator(ctx, "-", nullCheck, resultType, operand)

      case UNARY_MINUS if isTimeInterval(resultType) =>
        val operand = operands.head
        requireTimeInterval(operand, operator.getName)
        generateUnaryIntervalPlusMinus(ctx, plus = false, nullCheck, operand)

      case UNARY_PLUS if isNumeric(resultType) =>
        val operand = operands.head
        requireNumeric(operand, operator.getName)
        generateUnaryArithmeticOperator(ctx, "+", nullCheck, resultType, operand)

      case UNARY_PLUS if isTimeInterval(resultType) =>
        val operand = operands.head
        requireTimeInterval(operand, operator.getName)
        generateUnaryIntervalPlusMinus(ctx, plus = true, nullCheck, operand)

      // comparison
      case EQUALS =>
        val left = operands.head
        val right = operands(1)
        generateEquals(ctx, nullCheck, left, right)

      case NOT_EQUALS =>
        val left = operands.head
        val right = operands(1)
        generateNotEquals(ctx, nullCheck, left, right)

      case GREATER_THAN =>
        val left = operands.head
        val right = operands(1)
        requireComparable(left, operator.getName)
        requireComparable(right, operator.getName)
        generateComparison(ctx, ">", nullCheck, left, right)

      case GREATER_THAN_OR_EQUAL =>
        val left = operands.head
        val right = operands(1)
        requireComparable(left, operator.getName)
        requireComparable(right, operator.getName)
        generateComparison(ctx, ">=", nullCheck, left, right)

      case LESS_THAN =>
        val left = operands.head
        val right = operands(1)
        requireComparable(left, operator.getName)
        requireComparable(right, operator.getName)
        generateComparison(ctx, "<", nullCheck, left, right)

      case LESS_THAN_OR_EQUAL =>
        val left = operands.head
        val right = operands(1)
        requireComparable(left, operator.getName)
        requireComparable(right, operator.getName)
        generateComparison(ctx, "<=", nullCheck, left, right)

      case IS_NULL =>
        val operand = operands.head
        generateIsNull(nullCheck, operand)

      case IS_NOT_NULL =>
        val operand = operands.head
        generateIsNotNull(nullCheck, operand)

      // logic
      case AND =>
        operands.reduceLeft { (left: GeneratedExpression, right: GeneratedExpression) =>
          requireBoolean(left, operator.getName)
          requireBoolean(right, operator.getName)
          generateAnd(nullCheck, left, right)
        }

      case OR =>
        operands.reduceLeft { (left: GeneratedExpression, right: GeneratedExpression) =>
          requireBoolean(left, operator.getName)
          requireBoolean(right, operator.getName)
          generateOr(nullCheck, left, right)
        }

      case NOT =>
        val operand = operands.head
        requireBoolean(operand, operator.getName)
        generateNot(ctx, nullCheck, operand)

      case CASE =>
        generateIfElse(ctx, nullCheck, operands, resultType)

      case IS_TRUE =>
        val operand = operands.head
        requireBoolean(operand, operator.getName)
        generateIsTrue(operand)

      case IS_NOT_TRUE =>
        val operand = operands.head
        requireBoolean(operand, operator.getName)
        generateIsNotTrue(operand)

      case IS_FALSE =>
        val operand = operands.head
        requireBoolean(operand, operator.getName)
        generateIsFalse(operand)

      case IS_NOT_FALSE =>
        val operand = operands.head
        requireBoolean(operand, operator.getName)
        generateIsNotFalse(operand)

      case IN =>
        val left = operands.head
        val right = operands.tail
        generateIn(ctx, left, right, nullCheck)

      case NOT_IN =>
        val left = operands.head
        val right = operands.tail
        generateNot(ctx, nullCheck, generateIn(ctx, left, right, nullCheck))

      // casting
      case CAST =>
        val operand = operands.head
        generateCast(ctx, nullCheck, operand, resultType)

      // Reinterpret
      case REINTERPRET =>
        val operand = operands.head
        generateReinterpret(ctx, nullCheck, operand, resultType)

      // as / renaming
      case AS =>
        operands.head

      // rows
      case ROW =>
        generateRow(ctx, resultType, operands, nullCheck)

      // arrays
      case ARRAY_VALUE_CONSTRUCTOR =>
        generateArray(ctx, resultType, operands, nullCheck)

      // maps
      case MAP_VALUE_CONSTRUCTOR =>
        generateMap(ctx, resultType, operands, nullCheck)

      case ITEM =>
        operands.head.resultType match {
          case t: InternalType if TypeCheckUtils.isArray(t) =>
            val array = operands.head
            val index = operands(1)
            requireInteger(index, operator.getName)
            generateArrayElementAt(ctx, array, index, nullCheck)

          case t: InternalType if TypeCheckUtils.isMap(t) =>
            val key = operands(1)
            generateMapGet(ctx, operands.head, key, nullCheck)

          case _ => throw new CodeGenException("Expect an array or a map.")
        }

      case CARDINALITY =>
        operands.head.resultType match {
          case t: InternalType if TypeCheckUtils.isArray(t) =>
            val array = operands.head
            generateArrayCardinality(ctx, nullCheck, array)

          case t: InternalType if TypeCheckUtils.isMap(t) =>
            val map = operands.head
            generateMapCardinality(ctx, nullCheck, map)

          case _ => throw new CodeGenException("Expect an array or a map.")
        }

      case ELEMENT =>
        val array = operands.head
        requireArray(array, operator.getName)
        generateArrayElement(ctx, array, nullCheck)

      case DOT =>
        generateDOT(ctx, operands, nullCheck)

      case func: SqlRuntimeFilterFunction => generateRuntimeFilter(ctx, operands, func)

      case func: SqlRuntimeFilterBuilderFunction =>
        generateRuntimeFilterBuilder(ctx, operands, func)

      case _: SqlThrowExceptionFunction =>
        val nullValue = generateNullLiteral(resultType, nullCheck)
        val code =
          s"""
             |${nullValue.code}
             |org.apache.flink.util.ExceptionUtils.rethrow(
             |  new RuntimeException(${operands.head.resultTerm}.toString()));
             |""".stripMargin
        GeneratedExpression(nullValue.resultTerm, nullValue.nullTerm, code, resultType)

      case ScalarSqlFunctions.PROCTIME =>
        // attribute is proctime indicator.
        // We use a null literal and generate a timestamp when we need it.
        generateNullLiteral(DataTypes.PROCTIME_INDICATOR, nullCheck)

      // advanced scalar functions
      case sqlOperator: SqlOperator =>
        BinaryStringCallGen.generateCallExpression(ctx, operator, operands, resultType).getOrElse{
          FunctionGenerator.getCallGenerator(
            sqlOperator,
            operands.map(expr => expr.resultType),
            resultType).getOrElse(
            throw new CodeGenException(TableErrors.INST.sqlCodeGenUnsupportedScalaFunc(
              s"$sqlOperator(${operands.map(_.resultType).mkString(",")})")))
            .generate(ctx, operands, resultType, nullCheck)
        }

      // unknown or invalid
      case call@_ =>
        throw new CodeGenException(
          TableErrors.INST.sqlCodeGenUnsupportedCall(
            s"$call${operands.map(_.resultType).mkString(",")}"))
    }
  }

  // ----------------------------------------------------------------------------------------------

  /**
    * Convenience overload: applies [[isReference(t:*]] to the result type of the
    * given generated expression.
    */
  def isReference(genExpr: GeneratedExpression): Boolean = {
    val exprType = genExpr.resultType
    isReference(exprType)
  }

  /**
    * Tells whether the given type is considered a reference type.
    *
    * INT, LONG, SHORT, BYTE, FLOAT, DOUBLE, BOOLEAN and CHAR are reported as
    * non-reference (`false`); every other type yields `true`.
    */
  def isReference(t: InternalType): Boolean = {
    val isListedPrimitive = t match {
      case DataTypes.INT | DataTypes.LONG | DataTypes.SHORT | DataTypes.BYTE => true
      case DataTypes.FLOAT | DataTypes.DOUBLE => true
      case DataTypes.BOOLEAN | DataTypes.CHAR => true
      case _ => false
    }
    !isListedPrimitive
  }

  /**
    * Overload for a numeric field position: renders `pos` as a string and delegates to
    * the variant that takes the position as a term.
    */
  def baseRowFieldReadAccess(
      ctx: CodeGeneratorContext,
      pos: Int,
      rowTerm: String,
      fieldType: InternalType) : String = {
    val posTerm = pos.toString
    baseRowFieldReadAccess(ctx, posTerm, rowTerm, fieldType)
  }

  /**
    * Builds the Java expression that reads the field at `pos` from the row `rowTerm`,
    * choosing the getter according to `fieldType`.
    *
    * Two cases register reusable state on `ctx` as a side effect: reading a STRING adds
    * a reusable binary-string member, and reading a generic type adds a reusable type
    * serializer.
    *
    * @param ctx       code generator context that receives reusable members/serializers
    * @param pos       term (literal or expression) denoting the field position
    * @param rowTerm   term of the row being read
    * @param fieldType type of the field to read
    * @return the generated read-access expression
    * @throws CodeGenException if no getter is known for `fieldType`
    */
  def baseRowFieldReadAccess(
      ctx: CodeGeneratorContext,
      pos: String,
      rowTerm: String,
      fieldType: InternalType) : String =
    fieldType match {
      case DataTypes.INT => s"$rowTerm.getInt($pos)"
      case DataTypes.LONG => s"$rowTerm.getLong($pos)"
      case DataTypes.SHORT => s"$rowTerm.getShort($pos)"
      case DataTypes.BYTE => s"$rowTerm.getByte($pos)"
      case DataTypes.FLOAT => s"$rowTerm.getFloat($pos)"
      case DataTypes.DOUBLE => s"$rowTerm.getDouble($pos)"
      case DataTypes.BOOLEAN => s"$rowTerm.getBoolean($pos)"
      case DataTypes.STRING =>
        // The generated getter is handed a reusable holder object registered on the context.
        val reuse = newName("reuseBString")
        ctx.addReusableMember(s"$BINARY_STRING $reuse = new $BINARY_STRING();")
        s"$rowTerm.getBinaryString($pos, $reuse)"
      case dt: DecimalType => s"$rowTerm.getDecimal($pos, ${dt.precision()}, ${dt.scale()})"
      case DataTypes.CHAR => s"$rowTerm.getChar($pos)"
      // Timestamp values are read as long, date and time values as int.
      case _: TimestampType => s"$rowTerm.getLong($pos)"
      case _: DateType => s"$rowTerm.getInt($pos)"
      case DataTypes.TIME => s"$rowTerm.getInt($pos)"
      case DataTypes.BYTE_ARRAY => s"$rowTerm.getByteArray($pos)"
      case _: ArrayType => s"$rowTerm.getBaseArray($pos)"
      case _: MapType  => s"$rowTerm.getBaseMap($pos)"
      case rt: RowType =>
        s"$rowTerm.getBaseRow($pos, ${rt.getArity})"

      case gt: GenericType[_] =>
        s"""
           |(${gt.getTypeClass.getCanonicalName})
           |  $rowTerm.getGeneric($pos, ${ctx.addReusableTypeSerializer(fieldType)})
         """.stripMargin.trim

      // Fail with a descriptive error instead of an opaque scala.MatchError.
      case _ =>
        throw new CodeGenException("Fail to find base row field read method of InternalType "
            + fieldType + ".")
    }

  /**
    * Builds the statement that writes a null at `pos` through the writer `writerTerm`.
    *
    * Decimals whose precision is not compact (per `Decimal.isCompact`) are written with
    * `writeDecimal(pos, null, precision, scale)`; everything else uses `setNullAt(pos)`.
    */
  def binaryWriterWriteNull(pos: Int, writerTerm: String, t: InternalType): String = {
    val plainNull = s"$writerTerm.setNullAt($pos)"
    t match {
      case dec: DecimalType =>
        if (Decimal.isCompact(dec.precision())) {
          plainNull
        } else {
          s"$writerTerm.writeDecimal($pos, null, ${dec.precision()}, ${dec.scale()})"
        }
      case _ => plainNull
    }
  }

  /**
    * Builds the statement that sets the field at `pos` of the row `rowTerm` to null.
    *
    * Decimals whose precision is not compact (per `Decimal.isCompact`) are nulled with
    * `setDecimal(pos, null, precision, scale)`; everything else uses `setNullAt(pos)`.
    */
  def binaryRowSetNull(pos: Int, rowTerm: String, t: InternalType): String = {
    val plainNull = s"$rowTerm.setNullAt($pos)"
    t match {
      case dec: DecimalType =>
        if (Decimal.isCompact(dec.precision())) {
          plainNull
        } else {
          s"$rowTerm.setDecimal($pos, null, ${dec.precision()}, ${dec.scale()})"
        }
      case _ => plainNull
    }
  }

  /**
    * Builds the statement that stores `fieldValTerm` into position `pos` of the binary
    * row `binaryRowTerm`, choosing the setter according to `fieldType`.
    *
    * @throws CodeGenException if `fieldType` has no setter listed here
    */
  def binaryRowFieldSetAccess(
      pos: Int,
      binaryRowTerm: String,
      fieldType: InternalType,
      fieldValTerm: String)
    : String = {
    // Every setter except the decimal one shares the "(pos, value)" call shape.
    def call(method: String): String = s"$binaryRowTerm.$method($pos, $fieldValTerm)"

    fieldType match {
      case DataTypes.INT => call("setInt")
      case DataTypes.LONG => call("setLong")
      case DataTypes.SHORT => call("setShort")
      case DataTypes.BYTE => call("setByte")
      case DataTypes.FLOAT => call("setFloat")
      case DataTypes.DOUBLE => call("setDouble")
      case DataTypes.BOOLEAN => call("setBoolean")
      case DataTypes.CHAR => call("setChar")
      // Date and time values are stored as int, timestamps as long.
      case _: DateType => call("setInt")
      case DataTypes.TIME => call("setInt")
      case _: TimestampType => call("setLong")
      case dec: DecimalType =>
        s"$binaryRowTerm.setDecimal($pos, $fieldValTerm, ${dec.precision()}, ${dec.scale()})"
      case unsupported =>
        throw new CodeGenException("Fail to find binary row field setter method of InternalType "
            + unsupported + ".")
    }
  }

  /**
    * Builds the statement that writes `fieldValTerm` at position `pos` through the writer
    * `writerTerm`, choosing the write method according to `fieldType`.
    *
    * Array, map, row and generic types register a reusable type serializer on `ctx` as a
    * side effect; the first three are written through `BASE_ROW_UTIL` helper calls with
    * the serializer cast to the matching serializer class.
    *
    * @param ctx          code generator context that receives reusable serializers
    * @param pos          position of the field to write
    * @param fieldValTerm term holding the value to write
    * @param writerTerm   term of the writer
    * @param fieldType    type of the field to write
    * @return the generated write statement (without trailing semicolon)
    * @throws CodeGenException if no write method is known for `fieldType`
    */
  def binaryWriterWriteField(
      ctx: CodeGeneratorContext,
      pos: Int,
      fieldValTerm: String,
      writerTerm: String,
      fieldType: InternalType): String =
    fieldType match {
      case DataTypes.INT => s"$writerTerm.writeInt($pos, $fieldValTerm)"
      case DataTypes.LONG => s"$writerTerm.writeLong($pos, $fieldValTerm)"
      case DataTypes.SHORT => s"$writerTerm.writeShort($pos, $fieldValTerm)"
      case DataTypes.BYTE => s"$writerTerm.writeByte($pos, $fieldValTerm)"
      case DataTypes.FLOAT => s"$writerTerm.writeFloat($pos, $fieldValTerm)"
      case DataTypes.DOUBLE => s"$writerTerm.writeDouble($pos, $fieldValTerm)"
      case DataTypes.BOOLEAN => s"$writerTerm.writeBoolean($pos, $fieldValTerm)"
      case DataTypes.STRING => s"$writerTerm.writeBinaryString($pos, $fieldValTerm)"
      case d: DecimalType =>
        s"$writerTerm.writeDecimal($pos, $fieldValTerm, ${d.precision()}, ${d.scale()})"
      case DataTypes.CHAR => s"$writerTerm.writeChar($pos, $fieldValTerm)"
      // Date and time values are written as int, timestamps as long.
      case _: DateType => s"$writerTerm.writeInt($pos, $fieldValTerm)"
      case DataTypes.TIME => s"$writerTerm.writeInt($pos, $fieldValTerm)"
      case _: TimestampType => s"$writerTerm.writeLong($pos, $fieldValTerm)"
      case DataTypes.BYTE_ARRAY => s"$writerTerm.writeByteArray($pos, $fieldValTerm)"
      case _: ArrayType =>
        s"$BASE_ROW_UTIL.writeBaseArray($writerTerm, $pos, $fieldValTerm, " +
          s"(${classOf[BaseArraySerializer].getCanonicalName}) " +
          s"${ctx.addReusableTypeSerializer(fieldType)})"

      case _: MapType =>
        s"$BASE_ROW_UTIL.writeBaseMap($writerTerm, $pos, $fieldValTerm, " +
          s"(${classOf[BaseMapSerializer].getCanonicalName}) " +
          s"${ctx.addReusableTypeSerializer(fieldType)})"

      case _: RowType =>
        s"$BASE_ROW_UTIL.writeBaseRow($writerTerm, $pos, $fieldValTerm, " +
          s"(${classOf[BaseRowSerializer[_]].getCanonicalName}) " +
          s"${ctx.addReusableTypeSerializer(fieldType)})"

      case _: GenericType[_] => s"$writerTerm.writeGeneric($pos, $fieldValTerm, " +
        s"${ctx.addReusableTypeSerializer(fieldType)})"

      // Fail with a descriptive error instead of an opaque scala.MatchError.
      case _ =>
        throw new CodeGenException("Fail to find binary writer write method of InternalType "
            + fieldType + ".")
    }

  /**
    * Builds the statement that marks the element at `pos` of the array `term` as null.
    *
    * The generated method is `setNull<Kind>`, where the kind depends on `t`: the listed
    * primitive types use their own kind, TIME and date types use `Int`, and every other
    * type (including timestamps) uses `Long`.
    */
  def baseArraySetNull(
      pos: Int,
      term: String,
      t: InternalType): String = {
    val kind = t match {
      case DataTypes.BOOLEAN => "Boolean"
      case DataTypes.BYTE => "Byte"
      case DataTypes.CHAR => "Char"
      case DataTypes.SHORT => "Short"
      case DataTypes.INT => "Int"
      case DataTypes.LONG => "Long"
      case DataTypes.FLOAT => "Float"
      case DataTypes.DOUBLE => "Double"
      case DataTypes.TIME => "Int"
      case _: DateType => "Int"
      // Timestamps and all remaining types share the long-sized null setter.
      case _ => "Long"
    }
    s"$term.setNull$kind($pos)"
  }

  /**
    * Builds the statement that updates position `pos` of the row `rowTerm` with
    * `fieldValTerm`.
    *
    * The listed primitive, date, time and timestamp types use their typed setter; any
    * other type falls back to `setNonPrimitiveValue`.
    */
  def boxedWrapperRowFieldUpdateAccess(
      pos: Int,
      fieldValTerm: String,
      rowTerm: String,
      fieldType: InternalType): String = {
    val setterName = fieldType match {
      case DataTypes.INT => "setInt"
      case DataTypes.LONG => "setLong"
      case DataTypes.SHORT => "setShort"
      case DataTypes.BYTE => "setByte"
      case DataTypes.FLOAT => "setFloat"
      case DataTypes.DOUBLE => "setDouble"
      case DataTypes.BOOLEAN => "setBoolean"
      case DataTypes.CHAR => "setChar"
      case _: DateType => "setInt"
      case DataTypes.TIME => "setInt"
      case _: TimestampType => "setLong"
      case _ => "setNonPrimitiveValue"
    }
    s"$rowTerm.$setterName($pos, $fieldValTerm)"
  }

  // ----------------------------------------------------------------------------------------------

  /**
    * Compiles the generated source `code` with a Janino `SimpleCompiler` whose parent
    * class loader is `cl`, then loads and returns the class called `name`.
    *
    * @param cl   parent class loader for the compiler; must not be null
    * @param name name of the class to load after compilation
    * @param code generated source code to compile
    * @tparam T   expected type of the loaded class (unchecked cast)
    * @return the loaded class
    * @throws IllegalArgumentException if `cl` is null
    * @throws InvalidProgramException  if compiling `code` fails for any reason
    */
  @throws(classOf[CompileException])
  def compile[T](cl: ClassLoader, name: String, code: String): Class[T] = {
    CODE_LOG.debug(s"Compiling: $name \n\n Code:\n$code")
    require(cl != null, "Classloader must not be null.")
    val compiler = new SimpleCompiler()
    compiler.setParentClassLoader(cl)
    try {
      compiler.cook(code)
    } catch {
      // Deliberately broad: every compilation failure is reported as an invalid program.
      case t: Throwable =>
        // Report the offending code through the logger instead of printing to stdout,
        // so that the output obeys the logging configuration.
        CODE_LOG.error(
          s"Failed to compile generated class $name. Code:\n${CodeFormatter.format(code)}")
        throw new InvalidProgramException("Table program cannot be compiled. " +
            "This is a bug. Please file an issue.", t)
    }
    compiler.getClassLoader.loadClass(name).asInstanceOf[Class[T]]
  }

  /**
    * Turns on source-level debugging for Janino-compiled code (comparable to
    * "gcc -g") by setting the corresponding system property to "true".
    */
  def enableCodeGenerateDebug(): Unit = {
    val debugFlagKey = ICookable.SYSTEM_PROPERTY_SOURCE_DEBUGGING_ENABLE
    System.setProperty(debugFlagKey, "true")
  }

  /**
    * Turns off source-level debugging for Janino-compiled code by setting the
    * corresponding system property to "false".
    */
  def disableCodeGenerateDebug(): Unit = {
    val debugFlagKey = ICookable.SYSTEM_PROPERTY_SOURCE_DEBUGGING_ENABLE
    System.setProperty(debugFlagKey, "false")
  }

  /**
    * Sets the system property naming the directory used for source debugging of
    * generated code.
    *
    * @param path directory path
    * @throws RuntimeException if `StringUtils.isNullOrWhitespaceOnly(path)` holds
    */
  def setCodeGenerateTmpDir(path: String): Unit = {
    // Reject unusable paths up front.
    if (StringUtils.isNullOrWhitespaceOnly(path)) {
      throw new RuntimeException("code generate tmp dir can't be empty")
    }
    System.setProperty(ICookable.SYSTEM_PROPERTY_SOURCE_DEBUGGING_DIR, path)
  }

  // ----------------------------------------------------------------------------------------------

  /**
    * Builds an `info(...)` logging statement for generated code.
    *
    * @param logTerm term of the logger
    * @param format  message format; emitted inside double quotes as a string literal
    * @param argTerm term passed as the format argument
    * @return the statement, terminated with a semicolon
    */
  def genLogInfo(logTerm: String, format: String, argTerm: String): String = {
    val quotedFormat = "\"" + format + "\""
    s"$logTerm.info($quotedFormat, $argTerm);"
  }

  /**
    * Packs a sequence of code snippets into function bodies of bounded length and
    * produces matching function definitions and call statements.
    *
    * Snippets are appended (newline separated) to the current body while the combined
    * length stays within `limitLength`; otherwise the current body is closed and the
    * snippet starts a new one. A single snippet longer than the limit gets a body of its
    * own. At least one body is always produced, even for an empty `codeBuffer`.
    *
    * @param codeBuffer sequence of code snippets that need to be split
    * @param limitLength maximum length of one split body
    * @param subFunctionName name prefix of the generated functions (suffixed `_index`)
    * @param subFunctionModifier modifier placed in front of each function definition
    * @param fieldStatementLength extra length added to the total snippet length when
    *                             deciding whether a single body still counts as split
    * @param defineParams  parameter list for the function definitions
    * @param callingParams argument list for the function calls
    * @return definitions, bodies and calls, plus a flag that is true when more than one
    *         body was produced or the total length exceeds `limitLength`
    */
  def generateSplitFunctionCalls(
    codeBuffer: Seq[String],
    limitLength: Int,
    subFunctionName: String,
    subFunctionModifier: String,
    fieldStatementLength: Int,
    defineParams: String = "",
    callingParams: String = ""): GeneratedSplittableExpression = {

    val bodies = new ListBuffer[String]()
    val lastChunk = codeBuffer.foldLeft("") { (chunk, snippet) =>
      if (chunk.isEmpty) {
        // Nothing pending: the snippet opens a new chunk regardless of its length.
        snippet
      } else if (chunk.length + snippet.length <= limitLength) {
        s"$chunk\n$snippet"
      } else {
        // The snippet does not fit: close the pending chunk and start over.
        bodies += chunk
        snippet
      }
    }
    bodies += lastChunk

    val functionNames = bodies.indices.map(index => s"${subFunctionName}_$index")
    val defines = functionNames.map(fn => s"$subFunctionModifier $fn($defineParams)")
    val callings = functionNames.map(fn => s"$fn($callingParams);")

    val totalLength = codeBuffer.map(_.length).sum + fieldStatementLength
    val isSplit = defines.length > 1 || (defines.length == 1 && totalLength > limitLength)

    GeneratedSplittableExpression(defines, bodies, callings, isSplit)
  }

  /**
    * Returns the parameter declaration list used for split sub-functions of the given
    * function class, or an empty string for any class other than `FlatMapFunction`,
    * `MapFunction`, `FlatJoinFunction` and `ProcessFunction`.
    */
  def getDefineParamsByFunctionClass(clazz: Class[_]): String = {
    // Collector parameter shared by all collecting function kinds.
    def collectorParam: String =
      s"org.apache.flink.util.Collector ${CodeGeneratorContext.DEFAULT_COLLECTOR_TERM}"

    clazz match {
      case _ if clazz == classOf[FlatMapFunction[_, _]] =>
        s"Object _in1, $collectorParam"
      case _ if clazz == classOf[MapFunction[_, _]] =>
        "Object _in1"
      case _ if clazz == classOf[FlatJoinFunction[_, _, _]] =>
        s"Object _in1, Object _in2, $collectorParam"
      case _ if clazz == classOf[ProcessFunction[_, _]] =>
        "Object _in1, org.apache.flink.streaming.api.functions.ProcessFunction.Context " +
          s"${CodeGeneratorContext.DEFAULT_CONTEXT_TERM}, $collectorParam"
      case _ =>
        ""
    }
  }

  /**
    * Returns the argument list used when calling split sub-functions of the given
    * function class, or an empty string for any class other than `FlatMapFunction`,
    * `MapFunction`, `FlatJoinFunction` and `ProcessFunction`.
    */
  def getCallingParamsByFunctionClass(clazz: Class[_]): String = {
    def input1 = CodeGeneratorContext.DEFAULT_INPUT1_TERM
    def input2 = CodeGeneratorContext.DEFAULT_INPUT2_TERM
    def context = CodeGeneratorContext.DEFAULT_CONTEXT_TERM
    def collector = CodeGeneratorContext.DEFAULT_COLLECTOR_TERM

    clazz match {
      case _ if clazz == classOf[FlatMapFunction[_, _]] =>
        s"$input1, $collector"
      case _ if clazz == classOf[MapFunction[_, _]] =>
        s"$input1"
      case _ if clazz == classOf[FlatJoinFunction[_, _, _]] =>
        s"$input1, $input2, $collector"
      case _ if clazz == classOf[ProcessFunction[_, _]] =>
        s"$input1, $context, $collector"
      case _ =>
        ""
    }
  }

  // ----------------------------------------------------------------------------------------------

  /**
    * Casts a numeric expression to another numeric type with a larger range.
    * This function must be in sync with [[NumericOrDefaultReturnTypeInference]].
    *
    * @param expr       expression whose result term is to be cast
    * @param targetType type to cast to
    * @return the (possibly cast) result term; the unchanged term when the types are
    *         equal; `null` when the combination is not one of the supported casts
    */
  def getNumericCastedResultTerm(expr: GeneratedExpression, targetType: InternalType): String = {
    val term = expr.resultTerm
    val decimalClass = classOf[Decimal].getCanonicalName

    (expr.resultType, targetType) match {
      case (from, to) if from == to => term

      // integral sources, grouped by target type
      case (_: ByteType, _: ShortType) => s"(short) $term"
      case (_: ByteType | _: ShortType, _: IntType) => s"(int) $term"
      case (_: ByteType | _: ShortType | _: IntType, _: LongType) => s"(long) $term"
      case (_: ByteType | _: ShortType | _: IntType | _: LongType, dt: DecimalType) =>
        s"$decimalClass.castFrom($term, ${dt.precision}, ${dt.scale})"
      case (_: ByteType | _: ShortType | _: IntType | _: LongType, _: FloatType) =>
        s"(float) $term"

      // integral and float sources -> double
      case (_: ByteType | _: ShortType | _: IntType | _: LongType | _: FloatType, _: DoubleType) =>
        s"(double) $term"

      // decimal -> other numeric types
      case (_: DecimalType, dt: DecimalType) =>
        s"$decimalClass.castToDecimal($term, ${dt.precision}, ${dt.scale})"
      case (_: DecimalType, _: FloatType) => s"$decimalClass.castToFloat($term)"
      case (_: DecimalType, _: DoubleType) => s"$decimalClass.castToDouble($term)"

      // unsupported combination
      case _ => null
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy