org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.scala

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */


package org.apache.flink.table.functions.utils

import org.apache.flink.api.common.functions.InvalidTypesException
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.common.typeutils.CompositeType
import org.apache.flink.api.java.typeutils.{PojoField, PojoTypeInfo, TypeExtractor}
import org.apache.flink.table.api.dataview._
import org.apache.flink.table.api.{TableException, ValidationException}
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.dataview.{ListViewTypeInfo, MapViewTypeInfo}
import org.apache.flink.table.functions._
import org.apache.flink.table.plan.schema.FlinkTableFunctionImpl
import org.apache.flink.table.typeutils.FieldInfoUtils

import com.google.common.primitives.Primitives
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.sql.`type`.SqlOperandTypeChecker.Consistency
import org.apache.calcite.sql.`type`._
import org.apache.calcite.sql.{SqlCallBinding, SqlFunction, SqlOperandCountRange, SqlOperator}

import java.lang.reflect.{Method, Modifier}
import java.lang.{Integer => JInt, Long => JLong}
import java.sql.{Date, Time, Timestamp}
import java.util

import scala.collection.mutable

object UserDefinedFunctionUtils {

  // ----------------------------------------------------------------------------------------------
  // Utilities for user-defined methods
  // ----------------------------------------------------------------------------------------------

  /**
    * Returns the signature of the eval method matching the given signature of [[TypeInformation]].
    * Elements of the signature can be null (they act as wildcards).
    */
  def getEvalMethodSignature(
      function: UserDefinedFunction,
      signature: Seq[TypeInformation[_]])
    : Option[Array[Class[_]]] = {

    getUserDefinedMethod(function, "eval", typeInfoToClass(signature)).map(_.getParameterTypes)
  }
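
  // Example (illustrative sketch; `HashCode` is a hypothetical ScalarFunction with
  // `def eval(s: String): Int`):
  //   getEvalMethodSignature(new HashCode, Seq(BasicTypeInfo.STRING_TYPE_INFO))
  //   // -> Some(Array(classOf[String]))
  //   getEvalMethodSignature(new HashCode, Seq(null))  // null elements act as wildcards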

  /**
    * Returns the signature of the accumulate method matching the given signature
    * of [[TypeInformation]]. Elements of the signature can be null (they act as wildcards).
    */
  def getAccumulateMethodSignature(
      function: UserDefinedAggregateFunction[_, _],
      signature: Seq[TypeInformation[_]])
  : Option[Array[Class[_]]] = {
    val accType = TypeExtractor.createTypeInfo(
      function, classOf[UserDefinedAggregateFunction[_, _]], function.getClass, 1)
    val input = (Array(accType) ++ signature).toSeq
    getUserDefinedMethod(
      function,
      "accumulate",
      typeInfoToClass(input)).map(_.getParameterTypes)
  }

  def getParameterTypes(
      function: UserDefinedFunction,
      signature: Array[Class[_]]): Array[TypeInformation[_]] = {
    signature.map { c =>
      try {
        TypeExtractor.getForClass(c)
      } catch {
        case _: InvalidTypesException =>
          throw new ValidationException(
            s"Parameter types of function '${function.getClass.getCanonicalName}' cannot be " +
              s"automatically determined. Please provide type information manually.")
      }
    }
  }
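
  // Example (illustrative): getParameterTypes(f, Array(classOf[String], classOf[Int])) yields
  // Array(STRING_TYPE_INFO, INT_TYPE_INFO); classes that the TypeExtractor cannot analyze
  // (e.g. generic parameters) trigger the ValidationException above instead.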

  /**
    * Returns the user-defined method matching the given name and signature.
    *
    * @param function        function instance
    * @param methodName      method name
    * @param methodSignature an array of raw Java classes. We compare the raw Java classes, not
    *                        the TypeInformation. TypeInformation does not matter at runtime
    *                        (e.g. within a MapFunction).
    */
  def getUserDefinedMethod(
      function: UserDefinedFunction,
      methodName: String,
      methodSignature: Array[Class[_]])
    : Option[Method] = {

    val methods = checkAndExtractMethods(function, methodName)

    val filtered = methods
      // go over all the methods and filter out matching methods
      .filter {
        case cur if !cur.isVarArgs =>
          val signatures = cur.getParameterTypes
          // match parameters of signature to actual parameters
          methodSignature.length == signatures.length &&
            signatures.zipWithIndex.forall { case (clazz, i) =>
              parameterTypeApplicable(methodSignature(i), clazz)
          }
        case cur if cur.isVarArgs =>
          val signatures = cur.getParameterTypes
          methodSignature.zipWithIndex.forall {
            // non-varargs
            case (clazz, i) if i < signatures.length - 1  =>
              parameterTypeApplicable(clazz, signatures(i))
            // varargs
            case (clazz, i) if i >= signatures.length - 1 =>
              parameterTypeApplicable(clazz, signatures.last.getComponentType)
          } || (methodSignature.isEmpty && signatures.length == 1) // empty varargs
    }

    // if a fixed-arity method matches, prefer it over varargs methods
    // (mirroring Java's overload resolution)
    val fixedMethodsCount = filtered.count(!_.isVarArgs)
    val found = filtered.filter { cur =>
      fixedMethodsCount > 0 && !cur.isVarArgs ||
      fixedMethodsCount == 0 && cur.isVarArgs
    }
    val maximallySpecific = if (found.length > 1) {
      implicit val methodOrdering = new scala.Ordering[Method] {
        override def compare(x: Method, y: Method): Int = {
          def specificThan(left: Method, right: Method) = {
            // left parameter type is more specific than right parameter type
            left.getParameterTypes.zip(right.getParameterTypes).forall {
              case (leftParameterType, rightParameterType) =>
                parameterTypeApplicable(leftParameterType, rightParameterType)
            } &&
            // non-equal
            left.getParameterTypes.zip(right.getParameterTypes).exists {
              case (leftParameterType, rightParameterType) =>
                !parameterTypeEquals(leftParameterType, rightParameterType)
            }
          }

          if (specificThan(x, y)) {
            1
          } else if (specificThan(y, x)) {
            -1
          } else {
            0
          }
        }
      }

      val max = found.max
      found.filter(methodOrdering.compare(max, _) == 0)
    } else {
      found
    }

    // check if there is a Scala varargs annotation
    if (maximallySpecific.isEmpty &&
      methods.exists { method =>
        val signatures = method.getParameterTypes
        signatures.zipWithIndex.forall {
          case (clazz, i) if i < signatures.length - 1 =>
            parameterTypeApplicable(methodSignature(i), clazz)
          case (clazz, i) if i == signatures.length - 1 =>
            clazz.getName.equals("scala.collection.Seq")
        }
      }) {
      throw new ValidationException(
        s"Scala-style variable arguments in '$methodName' methods are not supported. Please " +
          s"add a @scala.annotation.varargs annotation.")
    } else if (maximallySpecific.length > 1) {
      throw new ValidationException(
        s"Found multiple '$methodName' methods which match the signature.")
    }
    maximallySpecific.headOption
  }
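
  // Resolution sketch (illustrative, assuming a function that declares both
  // `def eval(a: Int, b: Int)` and `@varargs def eval(a: Int, b: Int*)`):
  // for the call signature (Int, Int) the fixed-arity overload wins because fixed methods are
  // preferred over varargs; if several fixed overloads remain, the maximally specific one
  // (by parameter applicability) is chosen, and a remaining tie raises a ValidationException.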

  /**
    * Checks if a given method exists in the given function
    */
  def ifMethodExistInFunction(method: String, function: UserDefinedFunction): Boolean = {
    val methods = function
      .getClass
      .getMethods
      .filter {
        m => m.getName == method
      }
    methods.nonEmpty
  }

  /**
    * Extracts methods and throws a [[ValidationException]] if no implementation
    * can be found or the implementation does not match the requirements.
    */
  def checkAndExtractMethods(
      function: UserDefinedFunction,
      methodName: String): Array[Method] = {
    val methods = function
      .getClass
      .getMethods
      .filter { m =>
        val modifiers = m.getModifiers
        m.getName == methodName &&
          Modifier.isPublic(modifiers) &&
          !Modifier.isAbstract(modifiers) &&
          !(function.isInstanceOf[TableFunction[_]] && Modifier.isStatic(modifiers))
      }

    if (methods.isEmpty) {
      throw new ValidationException(
        s"Function class '${function.getClass.getCanonicalName}' does not implement at least " +
          s"one method named '$methodName' which is public, not abstract and " +
          s"(in case of table functions) not static.")
    }

    methods
  }

  def getMethodSignatures(
      function: UserDefinedFunction,
      methodName: String): Array[Array[Class[_]]] = {
    checkAndExtractMethods(function, methodName).map(_.getParameterTypes)
  }

  // ----------------------------------------------------------------------------------------------
  // Utilities for SQL functions
  // ----------------------------------------------------------------------------------------------

  /**
    * Creates [[SqlFunction]] for a [[ScalarFunction]]
    *
    * @param name function name
    * @param function scalar function
    * @param typeFactory type factory
    * @return the ScalarSqlFunction
    */
  def createScalarSqlFunction(
      name: String,
      displayName: String,
      function: ScalarFunction,
      typeFactory: FlinkTypeFactory)
    : SqlFunction = {
    new ScalarSqlFunction(name, displayName, function, typeFactory)
  }

  /**
    * Creates [[SqlFunction]] for a [[TableFunction]]
    *
    * @param name function name
    * @param tableFunction table function
    * @param resultType the type information of returned table
    * @param typeFactory type factory
    * @return the TableSqlFunction
    */
  def createTableSqlFunction(
      name: String,
      displayName: String,
      tableFunction: TableFunction[_],
      resultType: TypeInformation[_],
      typeFactory: FlinkTypeFactory)
    : SqlFunction = {
    val (fieldNames, fieldIndexes, _) = UserDefinedFunctionUtils.getFieldInfo(resultType)
    val function = new FlinkTableFunctionImpl(resultType, fieldIndexes, fieldNames)
    new TableSqlFunction(name, displayName, tableFunction, resultType, typeFactory, function)
  }

  /**
    * Creates a [[SqlFunction]] for an [[AggregateFunction]].
    *
    * @param name function name
    * @param aggFunction aggregate function
    * @param resultType the type information of the aggregation result
    * @param accTypeInfo the type information of the accumulator
    * @param typeFactory type factory
    * @return the AggSqlFunction
    */
  def createAggregateSqlFunction(
      name: String,
      displayName: String,
      aggFunction: UserDefinedAggregateFunction[_, _],
      resultType: TypeInformation[_],
      accTypeInfo: TypeInformation[_],
      typeFactory: FlinkTypeFactory)
  : SqlFunction = {
    // check that a qualified accumulate method exists before creating the SQL function
    checkAndExtractMethods(aggFunction, "accumulate")

    AggSqlFunction(
      name,
      displayName,
      aggFunction,
      resultType,
      accTypeInfo,
      typeFactory)
  }

  /**
    * Creates a [[SqlOperandTypeChecker]] for SQL validation of
    * eval functions (scalar and table functions).
    */
  def createEvalOperandTypeChecker(
      name: String,
      function: UserDefinedFunction)
    : SqlOperandTypeChecker = {

    val methods = checkAndExtractMethods(function, "eval")

    new SqlOperandTypeChecker {
      override def getAllowedSignatures(op: SqlOperator, opName: String): String = {
        s"$opName[${signaturesToString(function, "eval")}]"
      }

      override def getOperandCountRange: SqlOperandCountRange = {
        var min = 254
        var max = -1
        var isVarargs = false
        methods.foreach( m => {
          var len = m.getParameterTypes.length
          if (len > 0 && m.isVarArgs && m.getParameterTypes()(len - 1).isArray) {
            isVarargs = true
            len = len - 1
          }
          max = Math.max(len, max)
          min = Math.min(len, min)
        })
        if (isVarargs) {
          // if eval method is varargs, set max to -1 to skip length check in Calcite
          max = -1
        }

        SqlOperandCountRanges.between(min, max)
      }

      override def checkOperandTypes(
          callBinding: SqlCallBinding,
          throwOnFailure: Boolean)
        : Boolean = {
        val operandTypeInfo = getOperandTypeInfo(callBinding)

        val foundSignature = getEvalMethodSignature(function, operandTypeInfo)

        if (foundSignature.isEmpty) {
          if (throwOnFailure) {
            throw new ValidationException(
              s"Given parameters of function '$name' do not match any signature. \n" +
                s"Actual: ${signatureToString(operandTypeInfo)} \n" +
                s"Expected: ${signaturesToString(function, "eval")}")
          } else {
            false
          }
        } else {
          true
        }
      }

      override def isOptional(i: Int): Boolean = false

      override def getConsistency: Consistency = Consistency.NONE

    }
  }
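
  // Operand-count sketch (illustrative): for overloads `def eval(a: Int)` and
  // `@varargs def eval(a: Int, b: String*)` the checker reports between(1, -1):
  // min is the smallest parameter count (the trailing varargs array is not counted) and
  // max is set to -1 so that Calcite skips the upper-bound check for the varargs overload.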

  /**
    * Creates a [[SqlOperandTypeInference]] for the SQL validation of eval functions
    * (scalar and table functions).
    */
  def createEvalOperandTypeInference(
    name: String,
    function: UserDefinedFunction,
    typeFactory: FlinkTypeFactory)
  : SqlOperandTypeInference = {

    new SqlOperandTypeInference {
      override def inferOperandTypes(
          callBinding: SqlCallBinding,
          returnType: RelDataType,
          operandTypes: Array[RelDataType]): Unit = {

        val operandTypeInfo = getOperandTypeInfo(callBinding)

        val foundSignature = getEvalMethodSignature(function, operandTypeInfo)
          .getOrElse(throw new ValidationException(
            s"Given parameters of function '$name' do not match any signature. \n" +
              s"Actual: ${signatureToString(operandTypeInfo)} \n" +
              s"Expected: ${signaturesToString(function, "eval")}"))

        val inferredTypes = function match {
          case sf: ScalarFunction =>
            sf.getParameterTypes(foundSignature)
              .map(typeFactory.createTypeFromTypeInfo(_, isNullable = true))
          case tf: TableFunction[_] =>
            tf.getParameterTypes(foundSignature)
              .map(typeFactory.createTypeFromTypeInfo(_, isNullable = true))
          case _ => throw new TableException("Unsupported function.")
        }

        for (i <- operandTypes.indices) {
          if (i < inferredTypes.length - 1) {
            operandTypes(i) = inferredTypes(i)
          } else if (null != inferredTypes.last.getComponentType) {
            // the last parameter is an array (varargs); use its component type
            operandTypes(i) = inferredTypes.last.getComponentType
          } else {
            operandTypes(i) = inferredTypes.last
          }
        }
      }
    }
  }
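
  // Inference sketch (illustrative): for `@varargs def eval(a: Int, b: String*)` called with
  // four operands, the inferred operand types are (INTEGER, VARCHAR, VARCHAR, VARCHAR);
  // operands beyond the fixed parameters receive the component type of the trailing array.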

  // ----------------------------------------------------------------------------------------------
  // Utilities for user-defined functions
  // ----------------------------------------------------------------------------------------------

  /**
    * Removes data view fields ([[MapView]] / [[ListView]]) from the accumulator type information.
    *
    * @param index index of the aggregate function
    * @param acc accumulator instance
    * @param accType accumulator type information; only POJO types are supported
    * @param isStateBackedDataViews whether the data views are backed by a state backend
    * @return the cleaned accumulator type information together with the data view specs
    *         (each spec carries the state name, the field, and its type information)
    */
  def removeStateViewFieldsFromAccTypeInfo[ACC](
      index: Int,
      acc: ACC,
      accType: TypeInformation[_],
      isStateBackedDataViews: Boolean)
    : (TypeInformation[_], Option[Seq[DataViewSpec[_]]]) = {

    /** Recursively checks if composite type includes a data view type. */
    def includesDataView(ct: CompositeType[_]): Boolean = {
      (0 until ct.getArity).exists(i =>
        ct.getTypeAt(i) match {
          case nestedCT: CompositeType[_] => includesDataView(nestedCT)
          case t: TypeInformation[_] if t.getTypeClass == classOf[ListView[_]] => true
          case t: TypeInformation[_] if t.getTypeClass == classOf[MapView[_, _]] => true
          case _ => false
        }
      )
    }

    accType match {
      case pojoType: PojoTypeInfo[_] if pojoType.getArity > 0 =>
        val arity = pojoType.getArity
        val newPojoFields = new util.ArrayList[PojoField]()
        val accumulatorSpecs = new mutable.ArrayBuffer[DataViewSpec[_]]
        for (i <- 0 until arity) {
          val pojoField = pojoType.getPojoFieldAt(i)
          val field = pojoField.getField
          val fieldName = field.getName
          field.setAccessible(true)

          pojoField.getTypeInformation match {
            case ct: CompositeType[_] if includesDataView(ct) =>
              throw new TableException(
                "MapView and ListView only supported at first level of accumulators of Pojo type.")
            case map: MapViewTypeInfo[_, _] =>
              val mapView = field.get(acc).asInstanceOf[MapView[_, _]]
              if (mapView != null) {
                val keyTypeInfo = mapView.keyType
                val valueTypeInfo = mapView.valueType
                val newTypeInfo = if (keyTypeInfo != null && valueTypeInfo != null) {
                  new MapViewTypeInfo(keyTypeInfo, valueTypeInfo)
                } else {
                  map
                }

                // create a map view spec with a unique id (used as the state name)
                val spec = MapViewSpec(
                  "agg" + index + "$" + fieldName,
                  field,
                  newTypeInfo)

                accumulatorSpecs += spec
                if (!isStateBackedDataViews) {
                  // add data view field if it is not backed by a state backend.
                  // data view fields which are backed by state backend are not serialized.
                  newPojoFields.add(new PojoField(field, newTypeInfo))
                }
              }

            case list: ListViewTypeInfo[_] =>
              val listView = field.get(acc).asInstanceOf[ListView[_]]
              if (listView != null) {
                val elementTypeInfo = listView.elementType
                val newTypeInfo = if (elementTypeInfo != null) {
                  new ListViewTypeInfo(elementTypeInfo)
                } else {
                  list
                }

                // create a list view spec with a unique id (used as the state name)
                val spec = ListViewSpec(
                  "agg" + index + "$" + fieldName,
                  field,
                  newTypeInfo)

                accumulatorSpecs += spec
                if (!isStateBackedDataViews) {
                  // add data view field if it is not backed by a state backend.
                  // data view fields which are backed by state backend are not serialized.
                  newPojoFields.add(new PojoField(field, newTypeInfo))
                }
              }

            case _ => newPojoFields.add(pojoField)
          }
        }
        (new PojoTypeInfo(accType.getTypeClass, newPojoFields), Some(accumulatorSpecs))
      case ct: CompositeType[_] if includesDataView(ct) =>
        throw new TableException(
          "MapView and ListView only supported in accumulators of POJO type.")
      case _ => (accType, None)
    }
  }
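
  // Example (illustrative sketch; `DistinctAcc` is a hypothetical accumulator POJO with the
  // fields `var distinct: MapView[String, JLong]` and `var count: JLong`):
  //   val accType = TypeExtractor.createTypeInfo(classOf[DistinctAcc])
  //   val (cleanedType, specs) =
  //     removeStateViewFieldsFromAccTypeInfo(0, new DistinctAcc, accType, true)
  //   // with state-backed data views, `cleanedType` keeps only `count` and `specs` contains one
  //   // MapViewSpec named "agg0$distinct" from which the state descriptor is created later.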

  /**
    * Internal method of [[ScalarFunction#getResultType()]] that does some pre-checking and uses
    * [[TypeExtractor]] as default return type inference.
    */
  def getResultTypeOfScalarFunction(
      function: ScalarFunction,
      signature: Array[Class[_]])
    : TypeInformation[_] = {

    val userDefinedTypeInfo = function.getResultType(signature)
    if (userDefinedTypeInfo != null) {
      userDefinedTypeInfo
    } else {
      try {
        TypeExtractor.getForClass(getResultTypeClassOfScalarFunction(function, signature))
      } catch {
        case _: InvalidTypesException =>
          throw new ValidationException(
            s"Return type of scalar function '${function.getClass.getCanonicalName}' cannot be " +
              s"automatically determined. Please provide type information manually.")
      }
    }
  }

  /**
    * Returns the return type of the evaluation method matching the given signature.
    */
  def getResultTypeClassOfScalarFunction(
      function: ScalarFunction,
      signature: Array[Class[_]])
    : Class[_] = {
    // find method for signature
    val evalMethod = checkAndExtractMethods(function, "eval")
      .find(m => signature.sameElements(m.getParameterTypes))
      .getOrElse(throw new IllegalArgumentException("Given signature is invalid."))
    evalMethod.getReturnType
  }

  // ----------------------------------------------------------------------------------------------
  // Miscellaneous
  // ----------------------------------------------------------------------------------------------

  /**
    * Returns field names, field indices, and field types for a given [[TypeInformation]].
    *
    * Field names are automatically extracted for
    * [[org.apache.flink.api.common.typeutils.CompositeType]].
    *
    * @param inputType The TypeInformation to extract the field names, positions, and types from.
    * @return A tuple of three arrays holding the field names, field indices, and field types.
    */
  def getFieldInfo(inputType: TypeInformation[_])
    : (Array[String], Array[Int], Array[TypeInformation[_]]) = {

    (FieldInfoUtils.getFieldNames(inputType),
      FieldInfoUtils.getFieldIndices(inputType),
      FieldInfoUtils.getFieldTypes(inputType))
  }
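
  // Example (illustrative): for a RowTypeInfo with fields ("name" -> STRING, "age" -> INT)
  // this returns (Array("name", "age"), Array(0, 1), Array(STRING_TYPE_INFO, INT_TYPE_INFO)).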

  /**
    * Prints one signature consisting of classes.
    */
  def signatureToString(signature: Array[Class[_]]): String =
    signature.map { clazz =>
      if (clazz == null) {
        "null"
      } else {
        clazz.getCanonicalName
      }
    }.mkString("(", ", ", ")")

  /**
    * Prints one signature consisting of TypeInformation.
    */
  def signatureToString(signature: Seq[TypeInformation[_]]): String = {
    signatureToString(typeInfoToClass(signature))
  }

  /**
    * Prints all signatures of methods with given name in a class.
    */
  def signaturesToString(function: UserDefinedFunction, name: String): String = {
    getMethodSignatures(function, name).map(signatureToString).mkString(", ")
  }
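
  // e.g. for overloads `eval(a: Int, b: String)` and `eval(a: Long)` this prints
  // "(int, java.lang.String), (long)" (names come from Class#getCanonicalName;
  // the order of the overloads follows reflection and is not guaranteed).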

  /**
    * Extracts type classes of [[TypeInformation]] in a null-aware way.
    */
  def typeInfoToClass(typeInfos: Seq[TypeInformation[_]]): Array[Class[_]] =
    typeInfos.map { typeInfo =>
      if (typeInfo == null) {
        null
      } else {
        typeInfo.getTypeClass
      }
    }.toArray

  /**
    * Compares parameter candidate classes with expected classes. If true, the parameters match.
    * Candidate can be null (acts as a wildcard).
    */
  private def parameterTypeApplicable(candidate: Class[_], expected: Class[_]): Boolean =
    parameterTypeEquals(candidate, expected) ||
      ((expected != null && expected.isAssignableFrom(candidate)) ||
        expected.isPrimitive && Primitives.wrap(expected).isAssignableFrom(candidate))

  private def parameterTypeEquals(candidate: Class[_], expected: Class[_]): Boolean =
    candidate == null ||
      candidate == expected ||
      expected.isPrimitive && Primitives.wrap(expected) == candidate ||
      // time types
      candidate == classOf[Date] && (expected == classOf[Int] || expected == classOf[JInt]) ||
      candidate == classOf[Time] && (expected == classOf[Int] || expected == classOf[JInt]) ||
      candidate == classOf[Timestamp] &&
        (expected == classOf[Long] || expected == classOf[JLong]) ||
      // arrays
      (candidate.isArray && expected.isArray &&
        (candidate.getComponentType == expected.getComponentType))
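
  // Applicability examples (illustrative):
  //   parameterTypeApplicable(null, classOf[String])             // true: null candidate is a wildcard
  //   parameterTypeApplicable(classOf[JInt], classOf[Int])       // true: boxed Integer matches int
  //   parameterTypeApplicable(classOf[String], classOf[Object])  // true: widening to a supertype
  //   parameterTypeEquals(classOf[Timestamp], classOf[Long])     // true: TIMESTAMP may arrive as long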

  def getOperandTypeInfo(callBinding: SqlCallBinding): Seq[TypeInformation[_]] = {
    val operandTypes = for (i <- 0 until callBinding.getOperandCount)
      yield callBinding.getOperandType(i)
    operandTypes.map { operandType =>
      if (operandType.getSqlTypeName == SqlTypeName.NULL) {
        null
      } else {
        FlinkTypeFactory.toTypeInfo(operandType)
      }
    }
  }
}