All Downloads are FREE. Search and download functionality uses the official Maven repository.

org.apache.flink.table.expressions.aggregations.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.flink.table.expressions

import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rex.RexNode
import org.apache.calcite.sql.fun._
import org.apache.calcite.sql.SqlAggFunction
import org.apache.calcite.tools.RelBuilder
import org.apache.calcite.tools.RelBuilder.AggCall
import org.apache.flink.table.api.types.{DataType, DataTypes, InternalType, MultisetType}
import org.apache.flink.table.api.functions.AggregateFunction
import org.apache.flink.table.calcite.{FlinkTypeFactory, FlinkTypeSystem}
import org.apache.flink.table.functions.sql.AggSqlFunctions
import org.apache.flink.table.functions.utils.AggSqlFunction
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
import org.apache.flink.table.plan.logical.LogicalExprVisitor
import org.apache.flink.table.typeutils.TypeCheckUtils
import org.apache.flink.table.validate.{ValidationFailure, ValidationResult, ValidationSuccess}

/**
  * Common parent of the built-in aggregations and of [[AggFunctionCall]].
  *
  * An aggregation is translated to Calcite through [[toAggCall]] rather than through
  * [[toRexNode]], and exposes the Calcite function backing it via [[getSqlAggFunction]].
  */
abstract sealed class Aggregation extends Expression {

  override def toString: String = "Aggregate"

  /**
    * Default behaviour: always throws [[UnsupportedOperationException]]. Subclasses that can
    * also be expressed as a plain call (e.g. [[AggFunctionCall]]) override this.
    */
  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
    val reason = "Aggregate cannot be transformed to RexNode"
    throw new UnsupportedOperationException(reason)
  }

  /**
    * Converts this aggregation into its Calcite counterpart, an [[AggCall]].
    *
    * @param name       alias given to the resulting aggregate call
    * @param isDistinct whether the aggregate is requested as DISTINCT
    */
  private[flink] def toAggCall(
      name: String,
      isDistinct: Boolean = false
  )(implicit relBuilder: RelBuilder): AggCall

  /**
    * The Calcite [[SqlAggFunction]] that implements this aggregation.
    */
  private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder): SqlAggFunction

}

// Built-in aggregations
/**
  * Marks the wrapped aggregation as DISTINCT, e.g. `'a.count.distinct` is translated like
  * `COUNT(DISTINCT a)`. The wrapped expression must itself be an [[Aggregation]].
  */
case class DistinctAgg(child: Expression) extends Aggregation {

  override private[flink] def children: Seq[Expression] = Seq(child)

  override private[flink] def resultType: InternalType = child.resultType

  private[flink] def distinct: Expression = DistinctAgg(child)

  /**
    * Succeeds only when [[child]] is an aggregation, in which case the child's own validation
    * result is returned; otherwise fails with an explanatory message.
    */
  override private[flink] def validateInput(): ValidationResult = {
    // NOTE(review): the result of super.validateInput() is discarded; only the check below
    // decides the outcome.
    super.validateInput()
    val wrapsAggregation = child match {
      case _: Aggregation => true
      case _ => false
    }
    if (wrapsAggregation) {
      child.validateInput()
    } else {
      ValidationFailure(s"Distinct modifier cannot be applied to $child! " +
          s"It can only be applied to an aggregation expression, for example, " +
          s"'a.count.distinct which is equivalent with COUNT(DISTINCT a).")
    }
  }

  /** Delegates to the wrapped aggregation, always forcing `isDistinct = true`. */
  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = true)(implicit relBuilder: RelBuilder) =
    wrapped.toAggCall(name, isDistinct = true)

  /** The wrapped aggregation's Calcite function; DISTINCT does not change it. */
  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) =
    wrapped.getSqlAggFunction()

  /** [[child]] viewed as an aggregation; only valid after a successful validateInput. */
  private def wrapped: Aggregation = child.asInstanceOf[Aggregation]

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}

/**
  * SUM over a numeric [[child]]. The result type is derived with
  * `FlinkTypeSystem.deriveSumType`.
  */
case class Sum(child: Expression) extends Aggregation {

  override private[flink] def children: Seq[Expression] = Seq(child)

  override private[flink] def resultType = FlinkTypeSystem.deriveSumType(child.resultType)

  override private[flink] def validateInput(): ValidationResult =
    TypeCheckUtils.assertNumericExpr(child.resultType, "sum")

  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val operand = child.toRexNode
    relBuilder.aggregateCall(
      SqlStdOperatorTable.SUM,
      isDistinct,
      /* approximate = */ false,
      /* filter = */ null,
      name,
      operand)
  }

  /** A SUM function whose return type is the Flink-derived, nullable [[resultType]]. */
  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) = {
    val flinkTypeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
    val sumType = flinkTypeFactory.createTypeFromInternalType(resultType, isNullable = true)
    new SqlSumAggFunction(sumType)
  }

  override def toString: String = s"sum($child)"

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}

/**
  * SUM0 over a numeric [[child]], mapped to Calcite's `SUM0`. The result type is derived
  * with `FlinkTypeSystem.deriveSumType`.
  */
case class Sum0(child: Expression) extends Aggregation {

  override private[flink] def children: Seq[Expression] = Seq(child)

  override private[flink] def resultType = FlinkTypeSystem.deriveSumType(child.resultType)

  override private[flink] def validateInput(): ValidationResult =
    TypeCheckUtils.assertNumericExpr(child.resultType, "sum0")

  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val operand = child.toRexNode
    relBuilder.aggregateCall(
      SqlStdOperatorTable.SUM0,
      isDistinct,
      /* approximate = */ false,
      /* filter = */ null,
      name,
      operand)
  }

  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) =
    SqlStdOperatorTable.SUM0

  override def toString: String = s"sum0($child)"

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}

/**
  * INCR_SUM over a numeric [[child]], backed by `AggSqlFunctions.INCR_SUM`. The result type
  * is derived with `FlinkTypeSystem.deriveSumType`.
  */
case class IncrSum(child: Expression) extends Aggregation {

  override private[flink] def children: Seq[Expression] = Seq(child)

  override private[flink] def resultType = FlinkTypeSystem.deriveSumType(child.resultType)

  override private[flink] def validateInput(): ValidationResult =
    TypeCheckUtils.assertNumericExpr(child.resultType, "incr_sum")

  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val operand = child.toRexNode
    relBuilder.aggregateCall(
      AggSqlFunctions.INCR_SUM,
      isDistinct,
      /* approximate = */ false,
      /* filter = */ null,
      name,
      operand)
  }

  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) =
    AggSqlFunctions.INCR_SUM

  override def toString: String = s"incr_sum($child)"

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}

/**
  * MIN over an orderable [[child]]; the result has the child's type.
  */
case class Min(child: Expression) extends Aggregation {

  override private[flink] def children: Seq[Expression] = Seq(child)

  override private[flink] def resultType = child.resultType

  override private[flink] def validateInput(): ValidationResult =
    TypeCheckUtils.assertOrderableExpr(child.resultType, "min")

  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val operand = child.toRexNode
    relBuilder.aggregateCall(
      SqlStdOperatorTable.MIN,
      isDistinct,
      /* approximate = */ false,
      /* filter = */ null,
      name,
      operand)
  }

  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) =
    SqlStdOperatorTable.MIN

  override def toString: String = s"min($child)"

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}

/**
  * MAX over an orderable [[child]]; the result has the child's type.
  */
case class Max(child: Expression) extends Aggregation {

  override private[flink] def children: Seq[Expression] = Seq(child)

  override private[flink] def resultType = child.resultType

  override private[flink] def validateInput(): ValidationResult =
    TypeCheckUtils.assertOrderableExpr(child.resultType, "max")

  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val operand = child.toRexNode
    relBuilder.aggregateCall(
      SqlStdOperatorTable.MAX,
      isDistinct,
      /* approximate = */ false,
      /* filter = */ null,
      name,
      operand)
  }

  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) =
    SqlStdOperatorTable.MAX

  override def toString: String = s"max($child)"

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}

/**
  * COUNT aggregation producing a LONG. A child equal to the literal `"*"` denotes COUNT(*):
  * it is then excluded from [[children]] and no operand is passed to Calcite.
  */
case class Count(child: Expression) extends Aggregation {

  /** True when [[child]] is the wildcard literal `"*"`. */
  private[flink] val isWildcard = child.checkEquals(Literal("*"))

  override private[flink] def children: Seq[Expression] =
    if (isWildcard) Seq.empty else Seq(child)

  override private[flink] def resultType = DataTypes.LONG

  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    // Empty for COUNT(*), otherwise exactly the child's RexNode.
    val operands = children.map(_.toRexNode)
    relBuilder.aggregateCall(
      SqlStdOperatorTable.COUNT,
      isDistinct,
      /* approximate = */ false,
      /* filter = */ null,
      name,
      operands: _*)
  }

  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) =
    SqlStdOperatorTable.COUNT

  override def toString: String = s"count($child)"

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}

/**
  * AVG over a numeric [[child]]. The result type is derived with
  * `FlinkTypeSystem.deriveAvgAggType`.
  */
case class Avg(child: Expression) extends Aggregation {

  override private[flink] def children: Seq[Expression] = Seq(child)

  override private[flink] def resultType = FlinkTypeSystem.deriveAvgAggType(child.resultType)

  override private[flink] def validateInput(): ValidationResult =
    TypeCheckUtils.assertNumericExpr(child.resultType, "avg")

  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val operand = child.toRexNode
    relBuilder.aggregateCall(
      SqlStdOperatorTable.AVG,
      isDistinct,
      /* approximate = */ false,
      /* filter = */ null,
      name,
      operand)
  }

  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) =
    SqlStdOperatorTable.AVG

  override def toString: String = s"avg($child)"

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}

/**
  * Collects the values of [[child]] into a multiset, mapped to Calcite's `COLLECT`.
  * The result type is a [[MultisetType]] of the child's type.
  */
case class Collect(child: Expression) extends Aggregation {

  override private[flink] def children: Seq[Expression] = Seq(child)

  override private[flink] def resultType: InternalType = {
    val elementType = child.resultType
    new MultisetType(elementType)
  }

  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val operand = child.toRexNode
    relBuilder.aggregateCall(
      SqlStdOperatorTable.COLLECT,
      isDistinct,
      /* approximate = */ false,
      /* filter = */ null,
      name,
      operand)
  }

  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) =
    SqlStdOperatorTable.COLLECT

  override def toString: String = s"collect($child)"

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}

/**
  * RANK() aggregation: it takes no operands and produces a LONG.
  */
case class Rank() extends Aggregation {

  override private[flink] def children: Seq[Expression] = Seq.empty

  override private[flink] def resultType = DataTypes.LONG

  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val function = SqlStdOperatorTable.RANK
    relBuilder.aggregateCall(
      function,
      isDistinct,
      /* approximate = */ false,
      /* filter = */ null,
      name)
  }

  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) =
    SqlStdOperatorTable.RANK

  override def toString: String = "rank()"

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}

/**
  * DENSE_RANK() aggregation: it takes no operands and produces a LONG.
  */
case class DenseRank() extends Aggregation {

  override private[flink] def children: Seq[Expression] = Seq.empty

  override private[flink] def resultType = DataTypes.LONG

  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val function = SqlStdOperatorTable.DENSE_RANK
    relBuilder.aggregateCall(
      function,
      isDistinct,
      /* approximate = */ false,
      /* filter = */ null,
      name)
  }

  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) =
    SqlStdOperatorTable.DENSE_RANK

  override def toString: String = "dense_rank()"

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}

/**
  * ROW_NUMBER() aggregation: it takes no operands and produces a LONG.
  */
case class RowNumber() extends Aggregation {

  override private[flink] def children: Seq[Expression] = Seq.empty

  override private[flink] def resultType = DataTypes.LONG

  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val function = SqlStdOperatorTable.ROW_NUMBER
    relBuilder.aggregateCall(
      function,
      isDistinct,
      /* approximate = */ false,
      /* filter = */ null,
      name)
  }

  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) =
    SqlStdOperatorTable.ROW_NUMBER

  override def toString: String = "row_number()"

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}

/**
  * LEAD(exp, offset[, default]) aggregation; the result has the type of `exp`.
  *
  * `default` may be omitted (null). [[validateInput]] then replaces it with a NULL literal of
  * `exp`'s type, after which it also appears in [[children]] and in the generated AggCall.
  */
case class Lead(
    exp: Expression,
    offset: Literal = Literal(1),
    var default: Expression = null)
  extends Aggregation {

  /** The offset as a Long, stored by [[validateInput]]; 0 until the offset is accepted. */
  private[flink] var offsetValue: Long = _

  override private[flink] def children: Seq[Expression] =
    Seq(exp, offset) ++ Option(default)

  override private[flink] def resultType = exp.resultType

  /**
    * Fills in a missing default, then requires `default` to have the same type as `exp` and
    * the offset literal to hold an Int or Long (stored in [[offsetValue]]).
    */
  override private[flink] def validateInput(): ValidationResult = {
    if (default == null) {
      default = Null(exp.resultType)
    }

    if (exp.resultType == default.resultType) {
      integralOffset match {
        case Some(value) =>
          offsetValue = value
          ValidationSuccess
        case None =>
          ValidationFailure("Lead offset must be an integer.")
      }
    } else {
      ValidationFailure("Expression and default value must have the same type.")
    }
  }

  /** The offset literal widened to Long when it holds an Int or a Long, otherwise None. */
  private def integralOffset: Option[Long] = offset.value match {
    case i: Int => Some(i.toLong)
    case l: Long => Some(l)
    case _ => None
  }

  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val operands = children.map(_.toRexNode)
    relBuilder.aggregateCall(
      SqlStdOperatorTable.LEAD,
      isDistinct,
      /* approximate = */ false,
      /* filter = */ null,
      name,
      operands: _*)
  }

  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) =
    SqlStdOperatorTable.LEAD

  override def toString: String = s"lead($exp, $offset, $default)"

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)

  def getOffsetValue: Long = offsetValue
}

/**
  * LAG(exp, offset[, default]) aggregation; the result has the type of `exp`.
  *
  * `default` may be omitted (null). [[validateInput]] then replaces it with a NULL literal of
  * `exp`'s type, after which it also appears in [[children]] and in the generated AggCall.
  */
case class Lag(
    exp: Expression,
    offset: Literal = Literal(1),
    var default: Expression = null)
  extends Aggregation {

  /** The offset as a Long, stored by [[validateInput]]; 0 until the offset is accepted. */
  private[flink] var offsetValue: Long = _

  override private[flink] def children: Seq[Expression] =
    Seq(exp, offset) ++ Option(default)

  override private[flink] def resultType = exp.resultType

  /**
    * Fills in a missing default, then requires `default` to have the same type as `exp` and
    * the offset literal to hold an Int or Long (stored in [[offsetValue]]).
    */
  override private[flink] def validateInput(): ValidationResult = {
    if (default == null) {
      default = Null(exp.resultType)
    }

    if (exp.resultType == default.resultType) {
      integralOffset match {
        case Some(value) =>
          offsetValue = value
          ValidationSuccess
        case None =>
          ValidationFailure("Lag offset must be an integer.")
      }
    } else {
      ValidationFailure("Expression and default value must have the same type.")
    }
  }

  /** The offset literal widened to Long when it holds an Int or a Long, otherwise None. */
  private def integralOffset: Option[Long] = offset.value match {
    case i: Int => Some(i.toLong)
    case l: Long => Some(l)
    case _ => None
  }

  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val operands = children.map(_.toRexNode)
    relBuilder.aggregateCall(
      SqlStdOperatorTable.LAG,
      isDistinct,
      /* approximate = */ false,
      /* filter = */ null,
      name,
      operands: _*)
  }

  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) =
    SqlStdOperatorTable.LAG

  override def toString: String = s"lag($exp, $offset, $default)"

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)

  def getOffsetValue: Long = offsetValue
}

/**
  * Calcite `STDDEV_POP` over a numeric [[child]]. The result type is derived with
  * `FlinkTypeSystem.deriveAvgAggType`.
  */
case class StddevPop(child: Expression) extends Aggregation {

  override private[flink] def children: Seq[Expression] = Seq(child)

  override private[flink] def resultType = FlinkTypeSystem.deriveAvgAggType(child.resultType)

  override private[flink] def validateInput(): ValidationResult =
    TypeCheckUtils.assertNumericExpr(child.resultType, "stddev_pop")

  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val operand = child.toRexNode
    relBuilder.aggregateCall(
      SqlStdOperatorTable.STDDEV_POP,
      isDistinct,
      /* approximate = */ false,
      /* filter = */ null,
      name,
      operand)
  }

  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) =
    SqlStdOperatorTable.STDDEV_POP

  override def toString: String = s"stddev_pop($child)"

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}

/**
  * Calcite `STDDEV_SAMP` over a numeric [[child]]. The result type is derived with
  * `FlinkTypeSystem.deriveAvgAggType`.
  */
case class StddevSamp(child: Expression) extends Aggregation {

  override private[flink] def children: Seq[Expression] = Seq(child)

  override private[flink] def resultType = FlinkTypeSystem.deriveAvgAggType(child.resultType)

  override private[flink] def validateInput(): ValidationResult =
    TypeCheckUtils.assertNumericExpr(child.resultType, "stddev_samp")

  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val operand = child.toRexNode
    relBuilder.aggregateCall(
      SqlStdOperatorTable.STDDEV_SAMP,
      isDistinct,
      /* approximate = */ false,
      /* filter = */ null,
      name,
      operand)
  }

  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) =
    SqlStdOperatorTable.STDDEV_SAMP

  override def toString: String = s"stddev_samp($child)"

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}

/**
  * Calcite `STDDEV` over a numeric [[child]]. The result type is derived with
  * `FlinkTypeSystem.deriveAvgAggType`.
  */
case class Stddev(child: Expression) extends Aggregation {

  override private[flink] def children: Seq[Expression] = Seq(child)

  override private[flink] def resultType = FlinkTypeSystem.deriveAvgAggType(child.resultType)

  override private[flink] def validateInput(): ValidationResult =
    TypeCheckUtils.assertNumericExpr(child.resultType, "stddev")

  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val operand = child.toRexNode
    relBuilder.aggregateCall(
      SqlStdOperatorTable.STDDEV,
      isDistinct,
      /* approximate = */ false,
      /* filter = */ null,
      name,
      operand)
  }

  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) =
    SqlStdOperatorTable.STDDEV

  override def toString: String = s"stddev($child)"

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}

/**
  * Calcite `VAR_POP` over a numeric [[child]]. The result type is derived with
  * `FlinkTypeSystem.deriveAvgAggType`.
  */
case class VarPop(child: Expression) extends Aggregation {

  override private[flink] def children: Seq[Expression] = Seq(child)

  override private[flink] def resultType = FlinkTypeSystem.deriveAvgAggType(child.resultType)

  override private[flink] def validateInput(): ValidationResult =
    TypeCheckUtils.assertNumericExpr(child.resultType, "var_pop")

  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val operand = child.toRexNode
    relBuilder.aggregateCall(
      SqlStdOperatorTable.VAR_POP,
      isDistinct,
      /* approximate = */ false,
      /* filter = */ null,
      name,
      operand)
  }

  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) =
    SqlStdOperatorTable.VAR_POP

  override def toString: String = s"var_pop($child)"

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}

/**
  * Calcite `VAR_SAMP` over a numeric [[child]]. The result type is derived with
  * `FlinkTypeSystem.deriveAvgAggType`.
  */
case class VarSamp(child: Expression) extends Aggregation {

  override private[flink] def children: Seq[Expression] = Seq(child)

  override private[flink] def resultType = FlinkTypeSystem.deriveAvgAggType(child.resultType)

  override private[flink] def validateInput(): ValidationResult =
    TypeCheckUtils.assertNumericExpr(child.resultType, "var_samp")

  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val operand = child.toRexNode
    relBuilder.aggregateCall(
      SqlStdOperatorTable.VAR_SAMP,
      isDistinct,
      /* approximate = */ false,
      /* filter = */ null,
      name,
      operand)
  }

  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) =
    SqlStdOperatorTable.VAR_SAMP

  override def toString: String = s"var_samp($child)"

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}

/**
  * Calcite `VARIANCE` over a numeric [[child]]. The result type is derived with
  * `FlinkTypeSystem.deriveAvgAggType`.
  */
case class Variance(child: Expression) extends Aggregation {

  override private[flink] def children: Seq[Expression] = Seq(child)

  override private[flink] def resultType = FlinkTypeSystem.deriveAvgAggType(child.resultType)

  override private[flink] def validateInput(): ValidationResult =
    TypeCheckUtils.assertNumericExpr(child.resultType, "variance")

  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val operand = child.toRexNode
    relBuilder.aggregateCall(
      SqlStdOperatorTable.VARIANCE,
      isDistinct,
      /* approximate = */ false,
      /* filter = */ null,
      name,
      operand)
  }

  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) =
    SqlStdOperatorTable.VARIANCE

  override def toString: String = s"variance($child)"

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}

/**
  * FIRST_VALUE aggregation backed by `AggSqlFunctions.FIRST_VALUE`; the result has the
  * child's type. No input validation is performed.
  */
case class FirstValue(child: Expression) extends Aggregation {

  override private[flink] def children: Seq[Expression] = Seq(child)

  override private[flink] def resultType = child.resultType

  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val operand = child.toRexNode
    relBuilder.aggregateCall(
      AggSqlFunctions.FIRST_VALUE,
      isDistinct,
      /* approximate = */ false,
      /* filter = */ null,
      name,
      operand)
  }

  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) =
    AggSqlFunctions.FIRST_VALUE

  override def toString: String = s"first_value($child)"

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}

/**
  * LAST_VALUE aggregation backed by `AggSqlFunctions.LAST_VALUE`; the result has the
  * child's type. No input validation is performed.
  */
case class LastValue(child: Expression) extends Aggregation {

  override private[flink] def children: Seq[Expression] = Seq(child)

  override private[flink] def resultType = child.resultType

  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val operand = child.toRexNode
    relBuilder.aggregateCall(
      AggSqlFunctions.LAST_VALUE,
      isDistinct,
      /* approximate = */ false,
      /* filter = */ null,
      name,
      operand)
  }

  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) =
    AggSqlFunctions.LAST_VALUE

  override def toString: String = s"last_value($child)"

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}

/**
  * SINGLE_VALUE aggregation. The result has the child's type, and the Calcite function is
  * built with that type as a nullable return type.
  */
case class SingleValue(child: Expression) extends Aggregation {

  override private[flink] def children: Seq[Expression] = Seq(child)

  override private[flink] def resultType = child.resultType

  /** [[resultType]] converted into a nullable Calcite [[RelDataType]]. */
  def getCalciteType(relBuilder: RelBuilder): RelDataType = {
    val flinkTypeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
    flinkTypeFactory.createTypeFromInternalType(resultType, isNullable = true)
  }

  /** Builds a new SqlSingleValueAggFunction on every call. */
  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) = {
    val returnType = getCalciteType(relBuilder)
    new SqlSingleValueAggFunction(returnType)
  }

  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val function = getSqlAggFunction()
    val operand = child.toRexNode
    relBuilder.aggregateCall(
      function,
      isDistinct,
      /* approximate = */ false,
      /* filter = */ null,
      name,
      operand)
  }

  override def toString: String = s"single_value($child)"

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}

/**
  * CONCAT_AGG aggregation backed by `AggSqlFunctions.CONCAT_AGG`; the result has the type
  * of [[child]].
  */
case class ConcatAgg(child: Expression, separator: Expression) extends Aggregation {

  override private[flink] def children: Seq[Expression] = Seq(child, separator)

  override private[flink] def resultType = child.resultType

  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val function = getSqlAggFunction()
    // NOTE(review): the separator is passed first, i.e. reversed relative to `children`.
    // Presumably this is the operand order CONCAT_AGG expects; confirm before changing it.
    val separatorOperand = separator.toRexNode
    val valueOperand = child.toRexNode
    relBuilder.aggregateCall(
      function,
      isDistinct,
      /* approximate = */ false,
      /* filter = */ null,
      name,
      separatorOperand,
      valueOperand)
  }

  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) =
    AggSqlFunctions.CONCAT_AGG

  override def toString: String = s"concat_agg($child, $separator)"

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}

// Aggregate function calls

/**
  * Represents a call to a user-defined [[AggregateFunction]].
  *
  * @param aggregateFunction  the user-defined aggregate function instance
  * @param externalResultType declared result type of the function
  * @param externalAccType    declared accumulator type of the function
  * @param args               argument expressions of the call
  */
case class AggFunctionCall(
    aggregateFunction: AggregateFunction[_, _],
    externalResultType: DataType,
    externalAccType: DataType,
    args: Seq[Expression])
  extends Aggregation {

  override private[flink] def children: Seq[Expression] = args

  override def resultType: InternalType = externalResultType.toInternalType

  /**
    * Succeeds if an `accumulate` method of [[aggregateFunction]] matches the argument types.
    * Otherwise it fails with a message listing the actual argument types and every available
    * `accumulate` signature, without that signature's leading accumulator parameter.
    */
  override def validateInput(): ValidationResult = {
    val signature = children.map(_.resultType)
    // look for a signature that matches the input types
    val foundSignature = getAccumulateMethodSignature(aggregateFunction, signature)
    if (foundSignature.isEmpty) {
      // The first parameter of each accumulate method is the accumulator, which callers do not
      // pass, so strip it from every signature. (Previously `.drop(1)` was applied to the list
      // itself, which removed the first signature entirely and kept the accumulator in the
      // others.) Sort so the message does not depend on reflection order.
      val expected = getMethodSignatures(aggregateFunction, "accumulate")
        .map(_.drop(1))
        .map(signatureToString)
        .sorted
        .mkString(", ")
      ValidationFailure(s"Given parameters do not match any signature. \n" +
                          s"Actual: ${signatureToString(signature)} \n" +
                          s"Expected: $expected")
    } else {
      ValidationSuccess
    }
  }

  override def toString: String = s"${aggregateFunction.getClass.getSimpleName}($args)"

  override def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall =
    relBuilder.aggregateCall(
      this.getSqlAggFunction(),
      isDistinct,
      false,
      null,
      name,
      args.map(_.toRexNode): _*)

  /** Wraps [[aggregateFunction]] into an [[AggSqlFunction]] using the Flink type factory. */
  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) = {
    val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
    AggSqlFunction(
      aggregateFunction.functionIdentifier,
      aggregateFunction.toString,
      aggregateFunction,
      externalResultType,
      externalAccType,
      typeFactory,
      aggregateFunction.requiresOver)
  }

  /** Unlike the built-in aggregations, a UDAGG call can also be expressed as a RexNode call. */
  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode =
    relBuilder.call(this.getSqlAggFunction(), args.map(_.toRexNode): _*)

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy