All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.scala Maven / Gradle / Ivy

There is a newer version: 1.5.1
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */


package org.apache.flink.table.functions.utils

import org.apache.flink.api.common.functions.InvalidTypesException
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils._
import org.apache.flink.table.api.functions._
import org.apache.flink.table.api.scala._
import org.apache.flink.table.api.{TableEnvironment, TableException, ValidationException}
import org.apache.flink.table.calcite.{FlinkTypeFactory, FlinkTypeSystem}
import org.apache.flink.table.codegen.CodeGenUtils
import org.apache.flink.table.codegen.CodeGenUtils.{boxedTypeTermForType, externalBoxedTermForType}
import org.apache.flink.table.dataformat.{BaseRow, BinaryString, Decimal}
import org.apache.flink.table.errorcode.TableErrors
import org.apache.flink.table.expressions._
import org.apache.flink.table.functions.{AggregateFunction, ScalarFunction, TableFunction}
import org.apache.flink.table.plan.logical._
import org.apache.flink.table.plan.schema.DeferredTypeFlinkTableFunction
import org.apache.flink.table.types.{BooleanType, ByteArrayType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, GenericType, IntType, InternalType, LongType, RowType, ShortType, StringType, TimeType, TimestampType, TypeInfoWrappedDataType}
import org.apache.flink.table.typeutils.TypeUtils
import org.apache.flink.types.Row
import org.apache.flink.util.InstantiationUtil

import com.google.common.primitives.Primitives
import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFactory}
import org.apache.calcite.rex.{RexLiteral, RexNode}
import org.apache.calcite.sql.`type`.SqlTypeName
import org.apache.calcite.sql.{SqlFunction, SqlOperatorBinding}
import org.apache.commons.codec.binary.Base64

import java.lang.reflect.{Method, Modifier}
import java.lang.{Integer => JInt, Long => JLong}
import java.sql.{Date, Time, Timestamp}

import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.language.postfixOps

object UserDefinedFunctionUtils {

  /**
    * Verifies that a user-defined function class can be instantiated via reflection.
    *
    * @param clazz the function class to validate
    * @throws ValidationException if the class is non-public, not a proper class
    *                             (abstract, interface, primitive), or a non-static inner class
    */
  def checkForInstantiation(clazz: Class[_]): Unit = {
    val className = clazz.getCanonicalName
    if (!InstantiationUtil.isPublic(clazz)) {
      throw new ValidationException(s"Function class $className is not public.")
    }
    if (!InstantiationUtil.isProperClass(clazz)) {
      throw new ValidationException(
        s"Function class $className is no proper class," +
          " it is either abstract, an interface, or a primitive type.")
    }
    if (InstantiationUtil.isNonStaticInnerClass(clazz)) {
      throw new ValidationException(
        s"The class $className is an inner class, but" +
          " not statically accessible.")
    }
  }

  /**
    * Check whether this is a Scala object. It is forbidden to use [[TableFunction]] implemented
    * by a Scala object, since concurrent risks.
    */
  def checkNotSingleton(clazz: Class[_]): Unit = {
    // TODO it is not a good way to check singleton. Maybe improve it further.
    // A Scala `object` is compiled to a class exposing a static MODULE$ field.
    val isScalaObject = clazz.getFields.exists(_.getName == "MODULE$")
    if (isScalaObject) {
      throw new ValidationException(
        s"TableFunction implemented by class ${clazz.getCanonicalName} " +
          s"is a Scala object, it is forbidden since concurrent risks.")
    }
  }

  // ----------------------------------------------------------------------------------------------
  // Utilities for user-defined methods
  // ----------------------------------------------------------------------------------------------

  /**
    * Raises a [[ValidationException]] describing a mismatch between the actual argument
    * types and the signatures declared by the function's "eval" methods.
    *
    * Declared to return [[Method]] only so it can appear in `getOrElse` positions;
    * it never returns normally.
    */
  def throwValidationException(
      name: String,
      func: UserDefinedFunction,
      parameters: Array[InternalType]): Method = {
    val actual = signatureToString(parameters)
    val expected = signaturesToString(func, "eval")
    throw new ValidationException(
      s"Given parameters of function '$name' do not match any signature. \n" +
          s"Actual: $actual \n" +
          s"Expected: $expected")
  }

  /**
    * Expands a method signature to `expectedLength` parameter classes.
    *
    * For a varargs method the trailing array parameter is replaced by its component type,
    * repeated as often as needed to cover the actual argument count. For a fixed-arity
    * method the last declared class is simply repeated.
    *
    * @param isVarArgs         whether the matched method is a varargs method
    * @param matchingSignature raw parameter classes of the matched method
    * @param expectedLength    number of actual call arguments
    * @return one class per actual argument
    */
  private[table] def getParamClassesConsiderVarArgs(
      isVarArgs: Boolean,
      matchingSignature: Array[Class[_]],
      expectedLength: Int): Array[Class[_]] = {
    // Array.tabulate replaces the original `var` + mutable ArrayBuffer loop
    Array.tabulate[Class[_]](expectedLength) { i =>
      if (i < matchingSignature.length - 1) {
        matchingSignature(i)
      } else if (isVarArgs) {
        // flatten the varargs array parameter to its component type
        matchingSignature.last.getComponentType
      } else {
        // last argument is not an array type
        matchingSignature.last
      }
    }
  }

  /**
    * Resolves the "eval" method matching the expected argument types and returns its
    * parameter classes expanded to the actual argument count (varargs-aware).
    *
    * @throws ValidationException if no eval method matches the expected types
    */
  def getEvalMethodSignature(
      func: CustomTypeDefinedFunction,
      expectedTypes: Array[InternalType]): Array[Class[_]] = {
    val method = getEvalUserDefinedMethod(func, expectedTypes) match {
      case Some(m) => m
      case None =>
        throwValidationException(func.getClass.getCanonicalName, func, expectedTypes)
    }
    getParamClassesConsiderVarArgs(method.isVarArgs, method.getParameterTypes, expectedTypes.length)
  }

  /**
    * Determines the user-defined input [[DataType]]s of an [[AggregateFunction]]'s
    * matching "accumulate" method. Generic types reported by the function are not
    * trusted and are replaced by the corresponding expected internal type.
    *
    * @throws ValidationException if no accumulate method matches the expected types
    */
  def getAggUserDefinedInputTypes(
      func: AggregateFunction[_, _],
      externalAccType: DataType,
      expectedTypes: Array[InternalType]): Array[DataType] = {
    val accMethod = getAggFunctionUDIMethod(
      func, "accumulate", externalAccType, expectedTypes).getOrElse(
      throwValidationException(func.getClass.getCanonicalName, func, expectedTypes)
    )
    // drop the leading accumulator parameter; only the value arguments remain
    val valueParamClasses = getParamClassesConsiderVarArgs(
      accMethod.isVarArgs,
      accMethod.getParameterTypes.drop(1),
      expectedTypes.length)
    val declaredTypes = func.getUserDefinedInputTypes(valueParamClasses)
    declaredTypes.zipWithIndex.map { case (declared, i) =>
      // we don't trust GenericType: fall back to the expected internal type
      if (TypeUtils.isGeneric(declared)) expectedTypes(i) else declared
    }
  }

  /**
    * Returns signatures of accumulate methods matching the given signature of [[InternalType]].
    * Elements of the signature can be null (act as a wildcard).
    */
  def getAccumulateMethodSignature(
      function: AggregateFunction[_, _],
      expectedTypes: Seq[InternalType])
      : Option[Array[Class[_]]] = {
    val accType = getAccumulatorTypeOfAggregateFunction(function)
    getAggFunctionUDIMethod(function, "accumulate", accType, expectedTypes)
      .map(method => method.getParameterTypes)
  }

  /**
    * Maps each raw parameter class of a signature to its [[InternalType]].
    *
    * @throws ValidationException if a class cannot be handled by the type extractor
    */
  def getParameterTypes(
      function: UserDefinedFunction,
      signature: Array[Class[_]]): Array[InternalType] = {
    signature.map { clazz =>
      try {
        TypeExtractor.getForClass(clazz).toInternalType
      } catch {
        // the binding is unused, so it is deliberately anonymous
        case _: InvalidTypesException =>
          throw new ValidationException(
            s"Parameter types of function '${function.getClass.getCanonicalName}' cannot be " +
              s"automatically determined. Please provide type information manually.")
      }
    }
  }

  /**
    * Looks up the "eval" method of a function matching the expected internal types.
    */
  def getEvalUserDefinedMethod(
      function: CustomTypeDefinedFunction,
      expectedTypes: Seq[InternalType])
    : Option[Method] = {
    val rawClasses = typesToClasses(expectedTypes)
    getUserDefinedMethod(
      function,
      "eval",
      rawClasses,
      expectedTypes.toArray,
      classes => function.getParameterTypes(classes))
  }

  /**
    * Looks up an aggregate method (e.g. "accumulate") whose first parameter is the
    * accumulator, followed by the given value argument types.
    */
  def getAggFunctionUDIMethod(
      function: AggregateFunction[_, _],
      methodName: String,
      accType: DataType,
      expectedTypes: Seq[InternalType])
      : Option[Method] = {
    // full call signature: accumulator first, then the value arguments
    val fullSignature: Seq[DataType] = accType +: expectedTypes
    val internalTypes =
      fullSignature.map(t => if (t == null) null else t.toInternalType).toArray
    getUserDefinedMethod(
      function,
      methodName,
      typesToClasses(fullSignature),
      internalTypes,
      classes => Array(accType) ++ function.getUserDefinedInputTypes(classes.drop(1)))
  }

  /**
    * Get method without match DateType.
    *
    * Convenience overload that derives the internal types from the given signature and
    * uses null (wildcard) for every user-provided parameter data type.
    */
  def getUserDefinedMethod(
      function: UserDefinedFunction,
      methodName: String,
      signature: Seq[DataType])
  : Option[Method] = {
    val internalTypes =
      signature.map(t => if (t == null) null else t.toInternalType).toArray
    getUserDefinedMethod(
      function,
      methodName,
      typesToClasses(signature),
      internalTypes,
      // no user-provided parameter types: wildcard (null) for every position
      classes => Array.fill[DataType](classes.length)(null))
  }

  /**
    * Returns user defined method matching the given name and signature.
    *
    * @param function                function instance
    * @param methodName              method name
    * @param methodSignature         an array of raw Java classes. We compare the raw Java classes
    *                                not the DateType. DateType does not matter during runtime (e.g.
    *                                within a MapFunction)
    * @param internalTypes           internal data types of methodSignature
    * @param parameterTypes          user provided parameter data types, usually comes from invoking
    *                                CustomTypeDefinedFunction#getParameterTypes
    * @param parameterClassEquals    function ((expect, reflected) -> Boolean) to decide if the
    *                                provided expect parameter class is equals to reflection method
    *                                signature class. The expect class comes from param
    *                                [methodSignature].
    *
    * @param parameterDataTypeEquals function ((expect, dataType) -> Boolean) to decide if the
    *                                provided expect parameter data type is equals to type in
    *                                [parameterTypes].
    */
  def getUserDefinedMethod(
      function: UserDefinedFunction,
      methodName: String,
      methodSignature: Array[Class[_]],
      internalTypes: Array[InternalType],
      parameterTypes: Array[Class[_]] => Array[DataType],
      parameterClassEquals: (Class[_], Class[_]) => Boolean = parameterClassEquals,
      parameterDataTypeEquals: (InternalType, DataType) => Boolean = parameterDataTypeEquals)
    : Option[Method] = {

    val methods = checkAndExtractMethods(function, methodName)

    // counts signature positions that are unresolved apply() arguments (classOf[Object]);
    // used below to break ties between multiple matching overloads
    var applyCnt = 0
    val filtered = methods
      // go over all the methods and filter out matching methods
      .filter {
        case cur if !cur.isVarArgs =>
          val signatures = cur.getParameterTypes
          val dataTypes = parameterTypes(signatures)
          // match parameters of signature to actual parameters
          methodSignature.length == signatures.length &&
            signatures.zipWithIndex.forall { case (clazz, i) => {
              if (methodSignature(i) == classOf[Object]) {
                // The element of the method signature comes from the Table API's apply().
                // We can not decide the type here. It is an Unresolved Expression.
                // Actually, we do not have to decide the type here, any method of the overrides
                // which matches the arguments count will do the job.
                // So here we choose any method is correct.
                applyCnt += 1
              }
              parameterClassEquals(methodSignature(i), clazz) ||
                  parameterDataTypeEquals(internalTypes(i), dataTypes(i))
            }
          }
        case cur if cur.isVarArgs =>
          val signatures = cur.getParameterTypes
          val dataTypes = parameterTypes(signatures)
          methodSignature.zipWithIndex.forall {
            // non-varargs: positions before the trailing array parameter match normally
            case (clazz, i) if i < signatures.length - 1  =>
              parameterClassEquals(clazz, signatures(i)) ||
                  parameterDataTypeEquals(internalTypes(i), dataTypes(i))
            // varargs: remaining positions match against the array's component type
            case (clazz, i) if i >= signatures.length - 1 =>
              parameterClassEquals(clazz, signatures.last.getComponentType) ||
                  parameterDataTypeEquals(internalTypes(i), dataTypes(i))
          } || (methodSignature.isEmpty && signatures.length == 1) // empty varargs
    }

    // if there is a fixed method, compiler will call this method preferentially
    val fixedMethodsCount = filtered.count(!_.isVarArgs)
    val found = filtered.filter { cur =>
      fixedMethodsCount > 0 && !cur.isVarArgs ||
      fixedMethodsCount == 0 && cur.isVarArgs
    }.filter { cur =>
      // filter abstract methods
      // NOTE(review): Modifier.isVolatile tests bit 0x0040, which on a Method is the
      // compiler-generated BRIDGE flag, so this actually drops bridge methods
      // (abstract methods were already excluded by checkAndExtractMethods) — confirm intent.
      !Modifier.isVolatile(cur.getModifiers)
    }

    // check if there is a Scala varargs annotation:
    // a trailing scala.collection.Seq parameter indicates Scala-style varargs, which
    // cannot be invoked reflectively like Java varargs — reject with a helpful message
    if (found.isEmpty &&
      methods.exists { method =>
        val signatures = method.getParameterTypes
        val dataTypes = parameterTypes(signatures)
        if (!method.isVarArgs && signatures.length != methodSignature.length) {
          false
        } else if (method.isVarArgs && signatures.length > methodSignature.length + 1) {
          false
        } else {
          signatures.zipWithIndex.forall {
            case (clazz, i) if i < signatures.length - 1 =>
              parameterClassEquals(methodSignature(i), clazz) ||
                parameterDataTypeEquals(internalTypes(i), dataTypes(i))
            case (clazz, i) if i == signatures.length - 1 =>
              clazz.getName.equals("scala.collection.Seq")
          }
        }
      }) {
      throw new ValidationException(
        s"Scala-style variable arguments in '$methodName' methods are not supported. Please " +
          s"add a @scala.annotation.varargs annotation.")
    } else if (found.length > 1) {
      if (applyCnt > 0) {
        // As we can not decide type while apply() exists, so choose any one is correct
        return found.headOption
      }
      throw new ValidationException(
        s"Found multiple '$methodName' methods which match the signature.")
    }
    found.headOption
  }

  /**
    * Check if a given method exists in the given function.
    *
    * @param method   name of the method to look for
    * @param function function instance whose public methods are scanned
    * @return true if at least one public method with that name exists
    */
  def ifMethodExistInFunction(method: String, function: UserDefinedFunction): Boolean = {
    // `exists` short-circuits on the first hit instead of materializing a filtered array
    function.getClass.getMethods.exists(_.getName == method)
  }

  /**
    * Extracts methods and throws a [[ValidationException]] if no implementation
    * can be found, or implementation does not match the requirements.
    *
    * A candidate must be public, not abstract, and — for table functions — not static.
    */
  def checkAndExtractMethods(
      function: UserDefinedFunction,
      methodName: String): Array[Method] = {
    // one-line predicate: name, visibility and modifier requirements for a candidate
    def isCandidate(m: Method): Boolean = {
      val mods = m.getModifiers
      m.getName == methodName &&
        Modifier.isPublic(mods) &&
        !Modifier.isAbstract(mods) &&
        !(function.isInstanceOf[TableFunction[_]] && Modifier.isStatic(mods))
    }

    val found = function.getClass.getMethods.filter(isCandidate)
    if (found.isEmpty) {
      throw new ValidationException(
        s"Function class '${function.getClass.getCanonicalName}' does not implement at least " +
          s"one method named '$methodName' which is public, not abstract and " +
          s"(in case of table functions) not static.")
    }
    found
  }

  /**
    * Returns the raw parameter classes of every method with the given name.
    *
    * @throws ValidationException if no suitable method exists
    */
  def getMethodSignatures(
      function: UserDefinedFunction,
      methodName: String): Array[Array[Class[_]]] = {
    for (method <- checkAndExtractMethods(function, methodName)) yield method.getParameterTypes
  }

  // ----------------------------------------------------------------------------------------------
  // Utilities for SQL functions
  // ----------------------------------------------------------------------------------------------

  /**
    * Create [[SqlFunction]] for a [[ScalarFunction]]
    *
    * @param name function name
    * @param displayName name shown in plans and error messages
    * @param function scalar function
    * @param typeFactory type factory
    * @return the ScalarSqlFunction
    */
  def createScalarSqlFunction(
      name: String,
      displayName: String,
      function: ScalarFunction,
      typeFactory: FlinkTypeFactory)
    : SqlFunction = {
    new ScalarSqlFunction(name, displayName, function, typeFactory)
  }

  /**
    * Create [[SqlFunction]] instance for a [[TableFunction]].
    *
    * There are 2 cases that we should invoke this function to create TableFunction instance:
    * case 1: The function instance is initiated from a full class name, e.g. from a persisted
    * catalog.
    * case 2: Some builtin table functions registered in BuiltInFunctionCatalog.
    *
    * In total, the implicitResultType will only affect TableAPI result type inference when
    * we do expression validation, which would be used as a fallback type of
    * CustomTypeDefinedFunction#getResultType. See [[getResultTypeOfCTDFunction]] for details.
    *
    * @param name function name
    * @param displayName name shown in plans and error messages
    * @param tableFunction table function instance
    * @param typeFactory type factory, default to be [new FlinkTypeFactory(new FlinkTypeSystem())]
    * @return [[TableSqlFunction]] instance of the [[TableFunction]]
    */
  def createTableSqlFunction(
      name: String,
      displayName: String,
      tableFunction: TableFunction[_],
      typeFactory: FlinkTypeFactory = new FlinkTypeFactory(new FlinkTypeSystem()))
    : TableSqlFunction = {
    // derive the implicit result type from the table function's type parameter
    val extractedResultType = extractResultTypeFromTableFunction(tableFunction)
    createTableSqlFunction(name, displayName, tableFunction, extractedResultType, typeFactory)
  }

  /**
    * Create [[SqlFunction]] for a [[TableFunction]].
    *
    * Caution that the implicitResultType is only expect to be passed explicitly by Scala implicit
    * type inference.
    *
    * The entrance in BatchTableEnvironment.scala and StreamTableEnvironment.scala
    * {{{
    *   def registerFunction[T: TypeInformation](name: String, tf: TableFunction[T])
    * }}}
    *
    * The implicitResultType would be inferred from type `T`.
    *
    * For all the other cases, please use
    * createTableSqlFunction (String, String, TableFunction, FlinkTypeFactory) instead.
    *
    * @param name function name
    * @param displayName name shown in plans and error messages
    * @param tableFunction table function
    * @param implicitResultType the implicit type information of returned table
    * @param typeFactory type factory
    * @return the TableSqlFunction
    */
  def createTableSqlFunction(
      name: String,
      displayName: String,
      tableFunction: TableFunction[_],
      implicitResultType: DataType,
      typeFactory: FlinkTypeFactory)
    : TableSqlFunction = {
    // the exact result type is not known yet, so wrap it into a deferred-type function
    val deferredFunction = new DeferredTypeFlinkTableFunction(tableFunction, implicitResultType)
    new TableSqlFunction(
      name, displayName, tableFunction, implicitResultType, typeFactory, deferredFunction)
  }

  /**
    * Create [[SqlFunction]] for an [[AggregateFunction]]
    *
    * @param name function name
    * @param displayName name shown in plans and error messages
    * @param aggFunction aggregate function
    * @param externalResultType external result type of the aggregate
    * @param externalAccType external accumulator type of the aggregate
    * @param typeFactory type factory
    * @return the TableSqlFunction
    */
  def createAggregateSqlFunction(
      name: String,
      displayName: String,
      aggFunction: AggregateFunction[_, _],
      externalResultType: DataType,
      externalAccType: DataType,
      typeFactory: FlinkTypeFactory)
  : SqlFunction = {
    //check if a qualified accumulate method exists before create Sql function
    checkAndExtractMethods(aggFunction, "accumulate")

    AggSqlFunction(
      name,
      displayName,
      aggFunction,
      externalResultType,
      externalAccType,
      typeFactory,
      aggFunction.requiresOver)
  }

  // ----------------------------------------------------------------------------------------------
  // Utilities for user-defined functions
  // ----------------------------------------------------------------------------------------------

  /**
    * Tries to infer the DataType of an AggregateFunction's return type.
    *
    * Resolution order: explicitly declared result type, then the implicitly extracted
    * type, finally reflection on the function's type parameters.
    *
    * @param aggregateFunction The AggregateFunction for which the return type is inferred.
    * @param extractedType The implicitly inferred type of the result type.
    *
    * @return The inferred result type of the AggregateFunction.
    * @throws TableException if the generic result type cannot be extracted
    */
  def getResultTypeOfAggregateFunction(
      aggregateFunction: AggregateFunction[_, _],
      extractedType: DataType = null)
    : DataType = {

    val resultType = aggregateFunction.getResultType
    if (resultType != null) {
      resultType
    } else if (extractedType != null) {
      extractedType
    } else {
      try {
        extractTypeFromAggregateFunction(aggregateFunction, 0)
      } catch {
        case ite: InvalidTypesException =>
          // BUG FIX: the message previously lacked the `s` interpolator, so the literal
          // text "${aggregateFunction.getClass}" was printed instead of the class name.
          throw new TableException(
            s"Cannot infer generic type of ${aggregateFunction.getClass}. " +
              "You can override AggregateFunction.getResultType() to specify the type.",
            ite
          )
      }
    }
  }

  /**
    * Tries to infer the InternalType of an AggregateFunction's accumulator type.
    *
    * Resolution order: explicitly declared accumulator type, then the implicitly
    * extracted type, finally reflection on the function's type parameters.
    *
    * @param aggregateFunction The AggregateFunction for which the accumulator type is inferred.
    * @param extractedType The implicitly inferred type of the accumulator type.
    *
    * @return The inferred accumulator type of the AggregateFunction.
    * @throws TableException if the generic accumulator type cannot be extracted
    */
  def getAccumulatorTypeOfAggregateFunction(
    aggregateFunction: AggregateFunction[_, _],
    extractedType: DataType = null): DataType = {

    val accType = aggregateFunction.getAccumulatorType
    if (accType != null) {
      accType
    } else if (extractedType != null) {
      extractedType
    } else {
      try {
        extractTypeFromAggregateFunction(aggregateFunction, 1)
      } catch {
        case ite: InvalidTypesException =>
          // BUG FIX: the message previously lacked the `s` interpolator, so the literal
          // text "${aggregateFunction.getClass}" was printed instead of the class name.
          throw new TableException(
            s"Cannot infer generic type of ${aggregateFunction.getClass}. " +
              "You can override AggregateFunction.getAccumulatorType() to specify the type.",
            ite
          )
      }
    }
  }

  /**
    * Checks whether the extracted result class of the aggregate function is Flink's
    * internal representation of the given type; returns false when extraction fails.
    */
  def isResultInternalOfAggregateFunction(
      aggregateFunction: AggregateFunction[_, _], t: InternalType): Boolean = {
    try {
      val resultClass = extractTypeFromAggregateFunction(aggregateFunction, 0).getTypeClass
      CodeGenUtils.isInternalClass(resultClass, t)
    } catch {
      case _: InvalidTypesException => false
    }
  }

  /**
    * Checks whether the extracted accumulator class of the aggregate function is Flink's
    * internal representation of the given type; returns false when extraction fails.
    */
  def isAccInternalOfAggregateFunction(
      aggregateFunction: AggregateFunction[_, _], t: InternalType): Boolean = {
    try {
      val accClass = extractTypeFromAggregateFunction(aggregateFunction, 1).getTypeClass
      CodeGenUtils.isInternalClass(accClass, t)
    } catch {
      case _: InvalidTypesException => false
    }
  }

  /**
    * Returns the boxed type term (canonical class name) for the accumulator of the
    * aggregate function. When the accumulator type cannot be extracted from the type
    * parameters, falls back to a term derived from the given [[DataType]]: the internal
    * boxed term if the accumulator is internal, the external boxed term otherwise.
    */
  def getAccBoxedTypeTermOfAggregateFunction(
      aggregateFunction: AggregateFunction[_, _], t: DataType): String = {
    try {
      extractTypeFromAggregateFunction(aggregateFunction, 1).getTypeClass.getCanonicalName
    } catch {
      case _: InvalidTypesException =>
        if (isAccInternalOfAggregateFunction(aggregateFunction, t.toInternalType)) {
          boxedTypeTermForType(t.toInternalType)
        } else {
          externalBoxedTermForType(t)
        }
    }
  }

  /**
    * Internal method to extract a type from an AggregateFunction's type parameters.
    *
    * @param aggregateFunction The AggregateFunction for which the type is extracted.
    * @param parameterTypePos The position of the type parameter for which the type is extracted:
    *                         0 is the result type, 1 is the accumulator type (see callers
    *                         [[getResultTypeOfAggregateFunction]] and
    *                         [[getAccumulatorTypeOfAggregateFunction]]).
    *
    * @return The extracted type.
    */
  @throws(classOf[InvalidTypesException])
  private def extractTypeFromAggregateFunction(
      aggregateFunction: AggregateFunction[_, _],
      parameterTypePos: Int): DataType = {

    // wrap the extracted TypeInformation so it can be used as a DataType
    new TypeInfoWrappedDataType(TypeExtractor.createTypeInfo(
      aggregateFunction,
      classOf[AggregateFunction[_, _]],
      aggregateFunction.getClass,
      parameterTypePos))
  }

  /**
    * Determines the result [[DataType]] of a scalar function call: the user-declared
    * result type when non-null, otherwise the type extracted from the matching eval method.
    */
  def getResultTypeOfScalarFunction(
      function: ScalarFunction,
      arguments: Array[AnyRef],
      argTypes: Array[InternalType]): DataType = {
    val evalSignature = getEvalMethodSignature(function, argTypes)
    val declaredType = function.getResultType(arguments, evalSignature)
    // fall back to reflection-based extraction only when nothing was declared
    Option(declaredType).getOrElse(extractTypeFromScalarFunc(function, argTypes))
  }

  /**
    * Extracts the result [[DataType]] of a scalar function from the return class of the
    * eval method matching the given argument types.
    *
    * @throws ValidationException if the return class cannot be analyzed automatically
    */
  private[flink] def extractTypeFromScalarFunc(
      function: ScalarFunction,
      argTypes: Array[InternalType]): DataType = {
    try {
      val resultClass = getResultTypeClassOfScalarFunction(function, argTypes)
      TypeExtractor.getForClass(resultClass)
    } catch {
      case _: InvalidTypesException =>
        throw new ValidationException(
          s"Return type of scalar function '${function.getClass.getCanonicalName}' cannot be " +
              s"automatically determined. Please provide type information manually.")
    }
  }

  /**
    * Returns the return type of the evaluation method matching the given signature.
    *
    * @throws IllegalArgumentException if no eval method matches the signature
    */
  def getResultTypeClassOfScalarFunction(
      function: ScalarFunction,
      argTypes: Array[InternalType]): Class[_] = {
    // find method for signature
    getEvalUserDefinedMethod(function, argTypes) match {
      case Some(method) => method.getReturnType
      case None => throw new IllegalArgumentException("Given signature is invalid.")
    }
  }

  /**
    * Returns the Java class used to represent values of the given internal type when
    * invoking a Python scalar function.
    *
    * @param returnType the declared internal return type
    * @return the boxed Java representation class
    * @throws TableException if the internal type has no Python representation mapping
    */
  def getResultTypeClassOfPythonScalarFunction(returnType: InternalType): Class[_] = {
    returnType match {
      case _: StringType => classOf[org.apache.flink.table.dataformat.BinaryString]
      case _: BooleanType => classOf[java.lang.Boolean]
      case _: ByteType => classOf[java.lang.Byte]
      case _: ShortType => classOf[java.lang.Short]
      case _: IntType => classOf[java.lang.Integer]
      case _: LongType => classOf[java.lang.Long]
      case _: FloatType => classOf[java.lang.Float]
      case _: DoubleType => classOf[java.lang.Double]
      case _: DecimalType => classOf[java.math.BigDecimal]
      // date/time are represented by their internal int/long encodings
      case _: DateType => classOf[java.lang.Integer]
      case _: TimeType => classOf[java.lang.Integer]
      case _: TimestampType => classOf[java.lang.Long]
      case _: ByteArrayType => classOf[Array[Byte]]
      case other =>
        // BUG FIX: the match was non-exhaustive, so an unsupported type surfaced as an
        // opaque MatchError; raise a descriptive exception instead.
        throw new TableException(
          s"Unsupported return type for Python scalar function: $other")
    }
  }

  // ----------------------------------------------------------------------------------------------
  // Miscellaneous
  // ----------------------------------------------------------------------------------------------

  /**
    * Returns field names and field positions for a given [[DataType]].
    *
    * Field names are automatically extracted for [[RowType]].
    *
    * @param inputType The DataType to extract the field names and positions from.
    * @return A tuple of two arrays holding the field names and corresponding field positions.
    */
  def getFieldInfo(inputType: DataType)
    : (Array[String], Array[Int], Array[InternalType]) = {
    val names = TableEnvironment.getFieldNames(inputType)
    val indices = TableEnvironment.getFieldIndices(inputType)
    val types = TableEnvironment.getFieldTypes(inputType)
    (names, indices, types)
  }

  /**
    * Prints one signature consisting of classes.
    * Null entries (wildcards) are rendered as the literal "null".
    */
  def signatureToString(signature: Array[Class[_]]): String = {
    val rendered = signature.map {
      case null => "null"
      case clazz => clazz.getCanonicalName
    }
    rendered.mkString("(", ", ", ")")
  }

  /**
    * Prints one signature consisting of DataType.
    */
  def signatureToString(signature: Seq[DataType]): String = {
    // delegate to the class-based overload after null-aware class extraction
    signatureToString(typesToClasses(signature))
  }

  /**
    * Prints all signatures of methods with given name in a class.
    */
  def signaturesToString(function: UserDefinedFunction, name: String): String = {
    val rendered = getMethodSignatures(function, name).map(signatureToString)
    rendered.mkString(", ")
  }

  /**
    * Extracts type classes of [[DataType]] in a null-aware way.
    * Null entries pass through unchanged (they act as wildcards).
    */
  def typesToClasses(types: Seq[DataType]): Array[Class[_]] = {
    types.map {
      case null => null
      case t => TypeUtils.getExternalClassForType(t)
    }.toArray
  }

  /**
    * Compares parameter candidate classes with expected classes. If true, the parameters match.
    * Candidate can be null (acts as a wildcard).
    *
    * Beyond plain equality, boxed/primitive pairs, SQL time types against their int/long
    * internal encodings, and internal/external representation pairs (BinaryString/String,
    * BaseRow/Row, Decimal/BigDecimal) are all considered equal.
    */
  private def parameterClassEquals(candidate: Class[_], expected: Class[_]): Boolean =
  candidate == null ||
    candidate == expected ||
    expected == classOf[Object] ||
    candidate == classOf[Object]  ||  // Special case when we don't know the type
    expected.isPrimitive && Primitives.wrap(expected) == candidate ||
    // SQL date/time/timestamp may be passed as their int/long internal representation
    candidate == classOf[Date] && (expected == classOf[Int] || expected == classOf[JInt])  ||
    candidate == classOf[Time] && (expected == classOf[Int] || expected == classOf[JInt]) ||
    candidate == classOf[Timestamp] && (expected == classOf[Long] || expected == classOf[JLong]) ||
    // internal <-> external representations are interchangeable
    candidate == classOf[BinaryString] && expected == classOf[String] ||
    candidate == classOf[String] && expected == classOf[BinaryString] ||
    classOf[BaseRow].isAssignableFrom(candidate) && expected == classOf[Row] ||
    candidate == classOf[Row] && classOf[BaseRow].isAssignableFrom(expected) ||
    candidate == classOf[Decimal] && expected == classOf[BigDecimal] ||
    candidate == classOf[BigDecimal] && expected == classOf[Decimal] ||
    // NOTE(review): `candidate.getComponentType.isInstanceOf[Object]` is true whenever
    // getComponentType returns non-null (which it does for every array), so this clause
    // effectively only checks "both are arrays and the expected component is Object" —
    // confirm whether a component-type comparison was intended here.
    (candidate.isArray &&
      expected.isArray &&
      candidate.getComponentType.isInstanceOf[Object] &&
      expected.getComponentType == classOf[Object])

  /**
    * Compares a call-site internal type with a user-provided parameter data type.
    * They match when their internal representations are equal, or when both resolve
    * to the same external type class.
    */
  private def parameterDataTypeEquals(internal: InternalType, parameterType: DataType): Boolean = {
    // There is a special equal to GenericType. We need rewrite type extract to BaseRow etc...
    parameterType.toInternalType == internal || internal.getTypeClass == parameterType.getTypeClass
  }

  /**
    * Serializes a user-defined function instance into a URL-safe Base64 string.
    */
  @throws[Exception]
  def serialize(function: UserDefinedFunction): String = {
    Base64.encodeBase64URLSafeString(InstantiationUtil.serializeObject(function))
  }

  /**
    * Deserializes a function previously encoded by [[serialize]], resolving classes via
    * the current thread's context class loader.
    */
  @throws[Exception]
  def deserialize(data: String): UserDefinedFunction = {
    val bytes = Base64.decodeBase64(data)
    InstantiationUtil.deserializeObject[UserDefinedFunction](
      bytes, Thread.currentThread.getContextClassLoader)
  }

  /**
    * Creates a [[LogicalTableFunctionCall]] by parsing a String expression.
    *
    * @param tableEnv The table environment to lookup the function.
    * @param udtf a String expression of a TableFunctionCall, such as "split(c)"
    * @return A LogicalTableFunctionCall.
    */
  def createLogicalFunctionCall(
      tableEnv: TableEnvironment,
      udtf: String): LogicalTableFunctionCall = {

    // captures the outermost alias names found while unwrapping, e.g. "split(c) as (a, b)"
    var alias: Option[Seq[String]] = None

    // unwrap an Expression until we get a TableFunctionCall
    def unwrap(expr: Expression): TableFunctionCall = expr match {
      case Alias(child, name, extraNames) =>
        // remember the alias, then keep unwrapping the aliased child
        alias = Some(Seq(name) ++ extraNames)
        unwrap(child)
      case Call(name, args) =>
        // resolve the call against the function catalog and unwrap the result
        val function = tableEnv.chainedFunctionCatalog.lookupFunction(name, args)
        unwrap(function)
      case c: TableFunctionCall => c
      case _ =>
        throw new TableException(
          "Table(TableEnv, String) constructor only accept String that " +
            "define table function followed by some Alias.")
    }

    val functionCall: LogicalTableFunctionCall = unwrap(ExpressionParser.parseExpression(udtf))
      .as(alias).toLogicalTableFunctionCall(child = null)
    functionCall
  }

  /**
    * Extracts the internal operand types of a SQL call binding.
    * Operands with SQL type NULL are mapped to null.
    */
  def getOperandType(callBinding: SqlOperatorBinding): Seq[InternalType] = {
    (0 until callBinding.getOperandCount).map { i =>
      val operandType = callBinding.getOperandType(i)
      if (operandType.getSqlTypeName == SqlTypeName.NULL) {
        null
      } else {
        FlinkTypeFactory.toInternalType(operandType)
      }
    }
  }

  /**
    * Determines the result type of a [[CustomTypeDefinedFunction]]: the type declared by
    * getResultType for the actual (literal) arguments when non-null, otherwise the lazily
    * computed implicit result type.
    */
  private[table] def getResultTypeOfCTDFunction(
      func: CustomTypeDefinedFunction,
      params: Array[Expression],
      getImplicitResultType: () => DataType): DataType = {
    // literal arguments are passed through as values, everything else as null
    val arguments: Array[AnyRef] = params.map {
      case literal: Literal => literal.value.asInstanceOf[AnyRef]
      case _ => null
    }
    val signature = params.map { param =>
      if (param.valid) {
        param.resultType
      } else {
        // apply() can not decide type
        new GenericType(new Object().getClass)
      }
    }
    val argClasses = getEvalMethodSignature(func, signature)
    // getImplicitResultType is only evaluated when no user-defined type was declared
    Option(func.getResultType(arguments, argClasses)).getOrElse(getImplicitResultType())
  }

  /**
    * Transform the rex nodes to Objects
    * Only literal rex nodes will be transformed, non-literal rex nodes will be
    * translated to nulls.
    *
    * @param rexNodes actual parameters of the function
    * @return A Array of the Objects
    */
  /**
    * Transform the rex nodes to Objects.
    * Only literal rex nodes are transformed; non-literal rex nodes are
    * translated to nulls.
    *
    * @param rexNodes actual parameters of the function
    * @return an Array of the Objects
    */
  private[table] def transformRexNodes(
      rexNodes: java.util.List[RexNode]): Array[AnyRef] = {
    // Calcite delivers integral literals as Long and fractional literals as
    // Double; narrow them to the Java representation matching the SQL type.
    def narrow(literal: RexLiteral): AnyRef = {
      val value = literal.getValue2
      literal.getType.getSqlTypeName match {
        case SqlTypeName.INTEGER =>
          Int.box(value.asInstanceOf[Long].toInt)
        case SqlTypeName.SMALLINT =>
          Short.box(value.asInstanceOf[Long].toShort)
        case SqlTypeName.TINYINT =>
          Byte.box(value.asInstanceOf[Long].toByte)
        case SqlTypeName.FLOAT | SqlTypeName.REAL =>
          Float.box(value.asInstanceOf[Double].toFloat)
        case _ =>
          value.asInstanceOf[AnyRef]
      }
    }

    rexNodes.map {
      case literal: RexLiteral => narrow(literal)
      case _ => null
    }.toArray
  }

  /**
    * Builds a [[RelDataType]] from the given result type, projected and
    * renamed according to the field indexes and names.
    */
  private[table] def buildRelDataType(
      typeFactory: RelDataTypeFactory,
      resultType: InternalType,
      fieldNames: Array[String],
      fieldIndexes: Array[Int]): RelDataType = {

    if (fieldIndexes.length != fieldNames.length) {
      throw new TableException(
        "Number of field indexes and field names must be equal.")
    }

    // field names must not repeat
    if (fieldNames.distinct.length != fieldNames.length) {
      throw new TableException(
        "Table field names must be unique.")
    }

    val fieldTypes: Array[InternalType] = resultType match {
      case rowType: RowType =>
        if (fieldNames.length != rowType.getArity) {
          throw new TableException(
            s"Arity of type (" + rowType.getFieldNames.deep + ") " +
                "not equal to number of field names " + fieldNames.deep + ".")
        }
        rowType match {
          case _ => fieldIndexes.map(idx => rowType.getInternalTypeAt(idx).toInternalType)
        }
      case atomic =>
        // an atomic type exposes exactly one implicit field at index 0
        if (fieldIndexes.length != 1 || fieldIndexes(0) != 0) {
          throw new TableException(
            "Non-composite input type may have only a single field and its index must be 0.")
        }
        Array(atomic)
    }

    val flinkTypeFactory = typeFactory.asInstanceOf[FlinkTypeFactory]
    val builder = flinkTypeFactory.builder
    fieldNames.zip(fieldTypes).foreach { case (name, tpe) =>
      builder.add(name, flinkTypeFactory.createTypeFromInternalType(tpe, isNullable = true))
    }
    builder.build
  }

  /**
    * Instantiates a user-defined function from its definition string.
    *
    * A definition containing a dot is treated as a fully qualified class name
    * and instantiated reflectively. Otherwise the definition is first treated
    * as a hex-encoded, Base64-serialized object; if deserialization fails it
    * is retried as a class name without a package.
    *
    * @param classLoader class loader used for loading / deserialization
    * @param funcName registered name of the function (for error reporting)
    * @param functionDef class name or serialized representation of the function
    * @return the instantiated function object
    * @throws RuntimeException wrapping the underlying failure if the function
    *                          cannot be created
    */
  def createUserDefinedFunction(
      classLoader: ClassLoader,
      funcName: String,
      functionDef: String): Any = {

    // Reflective instantiation, shared by both code paths.
    def instantiateByClassName(): Any =
      classLoader.loadClass(functionDef).newInstance()

    // Uniform error wrapping, previously duplicated in both catch blocks.
    def wrapFailure(cause: Exception): RuntimeException =
      new RuntimeException(
        TableErrors.INST.sqlCreateUserDefinedFuncError(
          funcName,
          functionDef,
          cause.getClass.getCanonicalName + " : " + cause.getMessage),
        cause)

    val func: Any =
      if (functionDef.contains(".")) {
        // a dot indicates a fully qualified class name
        try {
          instantiateByClassName()
        } catch {
          case e: Exception => throw wrapFailure(e)
        }
      } else {
        try {
          // try deserialize first
          InstantiationUtil.deserializeObject[Object](
            Base64.decodeBase64(hexString2String(functionDef)),
            classLoader)
        } catch {
          case deserializeFailure: Exception =>
            try {
              // It might be a java class without package name
              instantiateByClassName()
            } catch {
              case loadFailure: Exception =>
                // Don't lose the original deserialization failure: the inner
                // exception used to shadow it entirely. Keep it as suppressed
                // for diagnosis while still reporting the class-load error.
                loadFailure.addSuppressed(deserializeFailure)
                throw wrapFailure(loadFailure)
            }
        }
      }

    // TODO: Initiate Hive UDF
    // Check hive and covert to flink udf
//    func match {
//      case _: UDF =>
//        func = new HiveSimpleUDF(new HiveFunctionWrapper[UDF](functionDef))
//      case _: GenericUDF =>
//        func = new HiveGenericUDF(new HiveFunctionWrapper[GenericUDF](functionDef))
//      case _: UDAF =>
//        func = new HiveUDAFFunction(new HiveFunctionWrapper[UDAF](functionDef))
//      case _: GenericUDAFResolver2 =>
//        func = new HiveUDAFFunction(new HiveFunctionWrapper[GenericUDAFResolver2](functionDef))
//      case _: GenericUDTF =>
//        func = new HiveGenericUDTF(new HiveFunctionWrapper[GenericUDTF](functionDef))
//      case _ =>
//    }
    func
  }

  /**
    * Extract implicit type from table function through reflection.
    *
    * Broadly, CustomTypeDefinedFunction#getResultType is considered first;
    * this function should always be treated as a fallback.
    *
    * @return Inferred implicit [[TypeInformation]]; if [[InvalidTypesException]]
    *         is thrown, returns GenericTypeInfo(classOf[AnyRef]) as fallback
    */
  def extractResultTypeFromTableFunction[T](tf: TableFunction[T]): TypeInformation[T] = {
    val inferred: TypeInformation[_] =
      try {
        TypeExtractor.createTypeInfo(tf, classOf[TableFunction[_]], tf.getClass, 0)
      } catch {
        // extraction failed (e.g. generic erasure) — degrade to AnyRef
        case _: InvalidTypesException => new GenericTypeInfo(classOf[AnyRef])
      }
    inferred.asInstanceOf[TypeInformation[T]]
  }

  /**
    * Decodes a hex-encoded string (two hex digits per character) back into
    * the original string, e.g. "4142" becomes "AB".
    */
  private def hexString2String(str: String): String =
    str.grouped(2).map(pair => Integer.parseInt(pair, 16).toChar).mkString
}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy (non-source scrape footer, commented out so the file compiles)