/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.flink.table.expressions

import java.util
import java.math.BigDecimal

import scala.collection.mutable
import com.google.common.collect.ImmutableList
import org.apache.calcite.rel.RelFieldCollation
import org.apache.calcite.rex.RexWindowBound._
import org.apache.calcite.rex.{RexCall, RexFieldCollation, RexNode, RexWindowBound}
import org.apache.calcite.sql._
import org.apache.calcite.sql.`type`.OrdinalReturnTypeInference
import org.apache.calcite.sql.parser.SqlParserPos
import org.apache.calcite.tools.RelBuilder
import org.apache.flink.table.api._
import org.apache.flink.table.api.functions.{ScalarFunction, TableFunction}
import org.apache.flink.table.api.scala.{CURRENT_ROW, UNBOUNDED_ROW}
import org.apache.flink.table.calcite.{FlinkPlannerImpl, FlinkTypeFactory}
import org.apache.flink.table.functions.sql.internal.SqlThrowExceptionFunction
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
import org.apache.flink.table.plan.logical.{LogicalExprVisitor, LogicalNode, LogicalTableFunctionCall}
import org.apache.flink.table.api.types._
import org.apache.flink.table.typeutils.TypeCheckUtils.isTimeInterval
import org.apache.flink.table.typeutils.TypeUtils
import org.apache.flink.table.validate.{ValidationFailure, ValidationResult, ValidationSuccess}

import _root_.scala.collection.JavaConverters._

/**
  * General expression for unresolved function calls. The function can be a built-in
  * scalar function or a user-defined scalar function.
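  *
  * A minimal construction sketch (function and field names are hypothetical;
  * resolution into a concrete function happens later during validation):
  * {{{
  *   val call = Call("myUpper", Seq(UnresolvedFieldReference("name")))
  * }}}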
  */
case class Call(functionName: String, args: Seq[Expression]) extends Expression {

  override private[flink] def children: Seq[Expression] = args

  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
    throw UnresolvedException(s"trying to convert UnresolvedFunction $functionName to RexNode")
  }

  override def toString = s"\\$functionName(${args.mkString(", ")})"

  override private[flink] def resultType =
    throw UnresolvedException(s"calling resultType on UnresolvedFunction $functionName")

  override private[flink] def validateInput(): ValidationResult =
    ValidationFailure(s"Unresolved function call: $functionName")

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}

/**
  * Over call referencing an over window by an alias that is not yet resolved.
  *
  * @param agg The aggregation of the over call.
  * @param alias The alias of the referenced over window.
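  *
  * A sketch of the Table API usage that produces this expression (field and
  * window names are hypothetical): in {{{ 'amount.sum over 'w }}} the
  * aggregation `'amount.sum` becomes `agg` and the window reference `'w`
  * becomes `alias`.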
  */
case class UnresolvedOverCall(agg: Expression, alias: Expression) extends Expression {

  override private[flink] def validateInput() =
    ValidationFailure(s"Over window with alias $alias could not be resolved.")

  override private[flink] def resultType = agg.resultType

  override private[flink] def children = Seq()

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}

/**
  * Over expression that is translated into a Calcite OVER (window) call.
  *
  * @param agg            The aggregation expression of the over call
  * @param partitionBy    The fields by which the over window is partitioned
  * @param orderBy        The fields by which the over window is sorted
  * @param preceding      The lower bound of the window
  * @param following      The upper bound of the window
  * @param tableEnv       The table environment, used for stream-specific validation
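  *
  * A sketch of the Table API usage that eventually yields this expression
  * (field and window names are hypothetical):
  * {{{
  *   table
  *     .window(Over partitionBy 'a orderBy 'rowtime preceding UNBOUNDED_ROW as 'w)
  *     .select('b.sum over 'w)
  * }}}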
  */
case class OverCall(
    agg: Expression,
    partitionBy: Seq[Expression],
    orderBy: Seq[Expression],
    var preceding: Expression,
    var following: Expression,
    tableEnv: TableEnvironment) extends Expression {

  override def toString: String = s"$agg OVER (" +
    s"PARTITION BY (${partitionBy.mkString(", ")}) " +
    s"ORDER BY (${orderBy.mkString(", ")})" +
    s"PRECEDING $preceding " +
    s"FOLLOWING $following)"

  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {

    val rexBuilder = relBuilder.getRexBuilder

    // assemble aggregation
    val operator: SqlAggFunction = agg.asInstanceOf[Aggregation].getSqlAggFunction()
    val aggResultType = relBuilder
      .getTypeFactory.asInstanceOf[FlinkTypeFactory]
      .createTypeFromInternalType(agg.resultType, isNullable = true)

    // assemble exprs by agg children
    val aggExprs = agg.asInstanceOf[Aggregation].children.map(_.toRexNode(relBuilder)).asJava

    // assemble order by key
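    // Order keys may arrive wrapped in DESCENDING / NULLS_FIRST / NULLS_LAST
    // calls (e.g. an ordering such as 'a.desc becomes DESCENDING(a)). collation()
    // strips these wrappers, records their kinds, and returns the bare key;
    // if no null ordering was given, the planner's default null collation is used.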
    def collation(
      node: RexNode, direction: RelFieldCollation.Direction,
      nullDirection: RelFieldCollation.NullDirection, kinds: mutable.Set[SqlKind]): RexNode = {
      node.getKind match {
        case SqlKind.DESCENDING =>
          kinds.add(node.getKind)
          collation(
            node.asInstanceOf[RexCall].getOperands.get(0),
            RelFieldCollation.Direction.DESCENDING, nullDirection, kinds)
        case SqlKind.NULLS_FIRST =>
          kinds.add(node.getKind)
          collation(
            node.asInstanceOf[RexCall].getOperands.get(0), direction,
            RelFieldCollation.NullDirection.FIRST, kinds)
        case SqlKind.NULLS_LAST =>
          kinds.add(node.getKind)
          collation(
            node.asInstanceOf[RexCall].getOperands.get(0), direction,
            RelFieldCollation.NullDirection.LAST, kinds)
        case _ =>
          if (nullDirection == null) {
            // Set the null direction if not specified.
            // Consistent with HIVE/SPARK/MYSQL/BLINK-RUNTIME.
            if (FlinkPlannerImpl.defaultNullCollation.last(
              direction.equals(RelFieldCollation.Direction.DESCENDING))) {
              kinds.add(SqlKind.NULLS_LAST)
            } else {
              kinds.add(SqlKind.NULLS_FIRST)
            }
          }
          node
      }
    }

    val orderKeysBuilder = new ImmutableList.Builder[RexFieldCollation]()
    for (orderExp <- orderBy) {
      val kinds: mutable.Set[SqlKind] = mutable.Set()
      val rexNode = collation(
        orderExp.toRexNode, RelFieldCollation.Direction.ASCENDING, null, kinds)
      val orderKey = new RexFieldCollation(rexNode, kinds.asJava)
      orderKeysBuilder.add(orderKey)
    }
    val orderKeys = orderKeysBuilder.build()

    // assemble partition by keys
    val partitionKeys = partitionBy.map(_.toRexNode).asJava

    // assemble bounds
    var isPhysical: Boolean = false
    var lowerBound, upperBound: RexWindowBound = null
    agg match {
      case _: RowNumber | _: Rank | _: DenseRank =>
        isPhysical = true
        lowerBound = createBound(relBuilder, UNBOUNDED_ROW, SqlKind.PRECEDING)
        upperBound = createBound(relBuilder, CURRENT_ROW, SqlKind.FOLLOWING)
      case _ =>
        isPhysical = preceding.resultType == DataTypes.INTERVAL_ROWS
        lowerBound = createBound(relBuilder, preceding, SqlKind.PRECEDING)
        upperBound = createBound(relBuilder, following, SqlKind.FOLLOWING)
    }

    // build RexOver
    rexBuilder.makeOver(
      aggResultType,
      operator,
      aggExprs,
      partitionKeys,
      orderKeys,
      lowerBound,
      upperBound,
      isPhysical,
      true,
      false,
      false)
  }

  private def createBound(
    relBuilder: RelBuilder,
    bound: Expression,
    sqlKind: SqlKind): RexWindowBound = {

    bound match {
      case _: UnboundedRow | _: UnboundedRange =>
        val unbounded = if (sqlKind == SqlKind.PRECEDING) {
          SqlWindow.createUnboundedPreceding(SqlParserPos.ZERO)
        } else {
          SqlWindow.createUnboundedFollowing(SqlParserPos.ZERO)
        }
        create(unbounded, null)
      case _: CurrentRow | _: CurrentRange =>
        val currentRow = SqlWindow.createCurrentRow(SqlParserPos.ZERO)
        create(currentRow, null)
      case b: Literal =>
        val DECIMAL_PRECISION_NEEDED_FOR_LONG = 19
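        // A signed 64-bit long has at most 19 decimal digits, so DECIMAL(19, 0)
        // can represent any concrete row-count or range bound losslessly.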
        val returnType = relBuilder
          .getTypeFactory.asInstanceOf[FlinkTypeFactory]
          .createTypeFromInternalType(
            DecimalType.of(DECIMAL_PRECISION_NEEDED_FOR_LONG, 0),
            isNullable = true)

        val sqlOperator = new SqlPostfixOperator(
          sqlKind.name,
          sqlKind,
          2,
          new OrdinalReturnTypeInference(0),
          null,
          null)

        val operands: Array[SqlNode] = new Array[SqlNode](1)
        operands(0) = SqlLiteral.createExactNumeric("1", SqlParserPos.ZERO)

        val node = new SqlBasicCall(sqlOperator, operands, SqlParserPos.ZERO)

        val expressions: util.ArrayList[RexNode] = new util.ArrayList[RexNode]()
        b.value match {
          case v: Double =>
            expressions.add(relBuilder.literal(BigDecimal.valueOf(v)))
          case _ => expressions.add(relBuilder.literal(b.value))
        }

        val rexNode = relBuilder.getRexBuilder.makeCall(returnType, sqlOperator, expressions)

        create(node, rexNode)
    }
  }

  override private[flink] def children: Seq[Expression] =
    Seq(agg) ++ orderBy ++ partitionBy ++ Seq(preceding) ++ Seq(following)

  override private[flink] def resultType = agg.resultType

  override private[flink] def validateInput(): ValidationResult = {

    // check that agg expression is aggregation
    agg match {
      case _: Aggregation =>
        ValidationSuccess
      case _ =>
        return ValidationFailure(s"OVER can only be applied on an aggregation.")
    }

    // check partitionBy expression keys are resolved field reference
    partitionBy.foreach {
      case r: ResolvedFieldReference
        if TypeConverters.createExternalTypeInfoFromDataType(r.resultType).isKeyType =>
        ValidationSuccess
      case r: ResolvedFieldReference =>
        return ValidationFailure(s"Invalid PartitionBy expression: $r. " +
          s"Expression must return key type.")
      case r =>
        return ValidationFailure(s"Invalid PartitionBy expression: $r. " +
          s"Expression must be a resolved field reference.")
    }

    // check preceding is valid
    preceding match {
      case _: CurrentRow | _: CurrentRange | _: UnboundedRow | _: UnboundedRange =>
        ValidationSuccess
      case Literal(_: Long, DataTypes.INTERVAL_ROWS) =>
        ValidationSuccess
      case Literal(_: Long, DataTypes.INTERVAL_RANGE) |
           Literal(_: Double, DataTypes.INTERVAL_RANGE) =>
        ValidationSuccess
      case Literal(v: Long, b) if isTimeInterval(b) && v >= 0 =>
        ValidationSuccess
      case Literal(_, b) if isTimeInterval(b) =>
        return ValidationFailure("Preceding time interval must be equal or larger than 0.")
      case Literal(_, _) =>
        return ValidationFailure("Preceding must be a row interval or time interval literal.")
    }

    // check following is valid
    following match {
      case _: CurrentRow | _: CurrentRange | _: UnboundedRow | _: UnboundedRange =>
        ValidationSuccess
      case Literal(_: Long, DataTypes.INTERVAL_ROWS) =>
        ValidationSuccess
      case Literal(_: Long, DataTypes.INTERVAL_RANGE) |
           Literal(_: Double, DataTypes.INTERVAL_RANGE) =>
        ValidationSuccess
      case Literal(v: Long, b) if isTimeInterval(b) && v >= 0 =>
        ValidationSuccess
      case Literal(_, b) if isTimeInterval(b) =>
        return ValidationFailure("Following time interval must be equal or larger than 0.")
      case Literal(_, _) =>
        return ValidationFailure("Following must be a row interval or time interval literal.")
    }

    // check that preceding and following are of same type
    (preceding, following) match {
      case (p: Expression, f: Expression) if p.resultType == f.resultType =>
        ValidationSuccess
      case (_: UnboundedRange, f: Expression) if f.resultType == DataTypes.INTERVAL_MILLIS =>
        ValidationSuccess
      case (_: CurrentRange, f: Expression) if f.resultType == DataTypes.INTERVAL_MILLIS =>
        ValidationSuccess
      case (p: Expression, _: UnboundedRange) if p.resultType == DataTypes.INTERVAL_MILLIS =>
        ValidationSuccess
      case (p: Expression, _: CurrentRange) if p.resultType == DataTypes.INTERVAL_MILLIS =>
        ValidationSuccess
      case _ =>
        return ValidationFailure("Preceding and following must be of same interval type.")
    }

    if (tableEnv.isInstanceOf[StreamTableEnvironment]) {
      validateInputForStream()
    } else {
      ValidationSuccess
    }
  }

  private[flink] def validateInputForStream(): ValidationResult = {
    // check preceding/following range
    preceding match {
      case Literal(_, DataTypes.INTERVAL_RANGE) =>
        return ValidationFailure("Stream table API does not support value range.")
      case Literal(v: Long, DataTypes.INTERVAL_ROWS) if v < 0 =>
        return ValidationFailure("Stream table API does not support negative preceding.")
      case _ =>
        ValidationSuccess
    }
    following match {
      case _: UnboundedRow | _: UnboundedRange =>
        return ValidationFailure("Stream table API does not support unbounded following.")
      case Literal(_, DataTypes.INTERVAL_RANGE) =>
        return ValidationFailure("Stream table API does not support value range.")
      case Literal(v: Long, DataTypes.INTERVAL_ROWS) if v > 0 =>
        return ValidationFailure("Stream table API does not support positive following.")
      case _ =>
        ValidationSuccess
    }

    // check lead/lag offset
    agg match {
      case a: Lead =>
        if (a.getOffsetValue > 0) {
          return ValidationFailure("Stream table API does not support positive offset for lead.")
        }
      case a: Lag =>
        if (a.getOffsetValue < 0) {
          return ValidationFailure("Stream table API does not support negative offset for lag.")
        }
      case _ =>
        ValidationSuccess
    }

    ValidationSuccess
  }

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}

object ScalarFunctionCall {
  def apply(
      scalarFunction: ScalarFunction,
      parameters: Array[Expression]): ScalarFunctionCall =
    new ScalarFunctionCall(scalarFunction, parameters)
}

/**
  * Expression for calling a user-defined scalar function.
  *
  * @param scalarFunction scalar function to be called (might be overloaded)
  * @param parameters actual parameters that determine target evaluation method
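  *
  * A minimal construction sketch (`MyUpper` is a hypothetical user-defined
  * ScalarFunction with a matching eval method):
  * {{{
  *   ScalarFunctionCall(new MyUpper, Seq(UnresolvedFieldReference("name")))
  * }}}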
  */
case class ScalarFunctionCall(
    scalarFunction: ScalarFunction,
    parameters: Seq[Expression])
  extends Expression {

  override private[flink] def children: Seq[Expression] = parameters

  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
    val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
    relBuilder.call(
      createScalarSqlFunction(
        scalarFunction.functionIdentifier,
        scalarFunction.toString,
        scalarFunction,
        typeFactory),
      parameters.map(_.toRexNode): _*)
  }

  override def toString =
    s"${scalarFunction.getClass.getCanonicalName}(${parameters.mkString(", ")})"

  override private[flink] def resultType =
    getResultTypeOfCTDFunction(
      scalarFunction,
      parameters.toArray,
      () => {extractTypeFromScalarFunc(
        scalarFunction, parameters.map(_.resultType).toArray)}).toInternalType

  override private[flink] def validateInput(): ValidationResult = {
    val signature = children.map(_.resultType)
    // look for a signature that matches the input types
    if (getEvalUserDefinedMethod(scalarFunction, signature).isEmpty) {
      ValidationFailure(s"Given parameters do not match any signature. \n" +
        s"Actual: ${signatureToString(signature)} \n" +
        s"Expected: ${signaturesToString(scalarFunction, "eval")}")
    } else {
      ValidationSuccess
    }
  }

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}

/**
  * Expression for calling a user-defined table function with actual parameters.
  *
  * @param functionName function name
  * @param tableFunction user-defined table function
  * @param parameters actual parameters of function
  * @param externalResultType type information of returned table
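  *
  * A minimal construction sketch (`MySplit` is a hypothetical
  * TableFunction[String]; the alias renames the emitted column):
  * {{{
  *   TableFunctionCall("split", new MySplit, Seq(UnresolvedFieldReference("text")),
  *     DataTypes.STRING).as(Some(Seq("word")))
  * }}}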
  */
case class TableFunctionCall(
    functionName: String,
    tableFunction: TableFunction[_],
    parameters: Seq[Expression],
    externalResultType: DataType)
  extends Expression {

  private var aliases: Option[Seq[String]] = None

  override private[flink] def children: Seq[Expression] = parameters

  override private[flink] def resultType: InternalType = externalResultType.toInternalType

  /**
    * Assigns aliases for this table function's returned fields, which the
    * following operator can refer to.
    *
    * @param aliasList aliases for this table function's returned fields
    * @return this table function call
    */
  private[flink] def as(aliasList: Option[Seq[String]]): TableFunctionCall = {
    this.aliases = aliasList
    this
  }

  /**
    * Converts an API class to a logical node for planning.
    */
  private[flink] def toLogicalTableFunctionCall(child: LogicalNode): LogicalTableFunctionCall = {
    val originNames = getFieldInfo(resultType)._1

    // determine the final field names
    val fieldNames = if (aliases.isDefined) {
      val aliasList = aliases.get
      if (aliasList.length != originNames.length) {
        throw new ValidationException(
          s"List of column aliases must have same degree as table; " +
            s"the returned table of function '$functionName' has ${originNames.length} " +
            s"columns (${originNames.mkString(",")}), " +
            s"whereas alias list has ${aliasList.length} columns")
      } else {
        aliasList.toArray
      }
    } else {
      originNames
    }

    LogicalTableFunctionCall(
      functionName,
      tableFunction,
      parameters,
      externalResultType,
      fieldNames,
      child)
  }

  override def toString =
    s"${tableFunction.getClass.getCanonicalName}(${parameters.mkString(", ")})"

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}

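/**
  * Expression that raises a runtime exception with the given message when
  * evaluated.
  *
  * @param msg expression producing the exception message (must be of type String)
  * @param tp  declared result type of this expression
  */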
case class ThrowException(msg: Expression, tp: InternalType) extends UnaryExpression {

  override private[flink] def validateInput(): ValidationResult = {
    if (child.resultType == DataTypes.STRING) {
      ValidationSuccess
    } else {
      ValidationFailure(s"ThrowException operator requires String input, " +
          s"but $child is of type ${child.resultType}")
    }
  }

  override def toString: String = s"ThrowException($msg)"

  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
    relBuilder.call(
      new SqlThrowExceptionFunction(
        tp,
        relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]),
      msg.toRexNode)
  }

  override private[flink] def child = msg

  override private[flink] def resultType = tp

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}



