All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.flink.table.expressions.aggregations.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.flink.table.expressions

import org.apache.calcite.rex.RexNode
import org.apache.calcite.sql.SqlAggFunction
import org.apache.calcite.sql.fun._
import org.apache.calcite.tools.RelBuilder
import org.apache.calcite.tools.RelBuilder.AggCall
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.table.functions.AggregateFunction
import org.apache.flink.table.functions.utils.AggSqlFunction
import org.apache.flink.table.typeutils.TypeCheckUtils
import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.api.java.typeutils.MultisetTypeInfo
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
import org.apache.flink.table.validate.{ValidationFailure, ValidationResult, ValidationSuccess}

/**
  * Base class of the aggregation expressions declared in this file.
  *
  * An aggregation is not translated through the regular `toRexNode` path (which throws
  * here); instead it is turned into a Calcite `AggCall` via [[toAggCall]].
  */
sealed abstract class Aggregation extends Expression {

  /**
    * Builds the Calcite `AggCall` that corresponds to this aggregation.
    *
    * @param name       passed on as the name of the resulting aggregate call
    * @param isDistinct whether the aggregate is flagged as DISTINCT
    * @param relBuilder builder used to create the call
    */
  private[flink] def toAggCall(
      name: String,
      isDistinct: Boolean = false
  )(implicit relBuilder: RelBuilder): AggCall

  /**
    * Provides the `SqlAggFunction` that represents this aggregation.
    */
  private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder): SqlAggFunction

  /**
    * Always fails in this base class: an aggregate has no `RexNode` form here.
    */
  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
    throw new UnsupportedOperationException("Aggregate cannot be transformed to RexNode")
  }

  /** Generic label for an aggregation. */
  override def toString: String = "Aggregate"

}

/**
  * Distinct modifier around an aggregation, e.g. `'a.count.distinct`, which is equivalent
  * to `COUNT(DISTINCT a)`.
  *
  * @param child expression the modifier is applied to; [[validateInput]] only accepts an
  *              [[Aggregation]] here
  */
case class DistinctAgg(child: Expression) extends Aggregation {

  /** Returns a new distinct wrapper around the same child. */
  def distinct: Expression = DistinctAgg(child)

  /** The modifier does not alter the type produced by the wrapped expression. */
  override private[flink] def resultType: TypeInformation[_] = child.resultType

  /**
    * Validates this node and then the wrapped aggregation.
    *
    * A failure reported by the inherited validation is returned as-is instead of being
    * discarded. Otherwise the child must be an [[Aggregation]], whose own validation result
    * is returned; any other child yields a `ValidationFailure`.
    */
  override private[flink] def validateInput(): ValidationResult = {
    super.validateInput() match {
      case inheritedFailure: ValidationFailure =>
        inheritedFailure
      case _ =>
        child match {
          case agg: Aggregation =>
            agg.validateInput()
          case _ =>
            ValidationFailure(s"Distinct modifier cannot be applied to $child! " +
              s"It can only be applied to an aggregation expression, for example, " +
              s"'a.count.distinct which is equivalent with COUNT(DISTINCT a).")
        }
    }
  }

  /**
    * Delegates to the wrapped aggregation and always requests a distinct call.
    *
    * @param name       passed on as the name of the resulting aggregate call
    * @param isDistinct ignored; the child is always asked for a distinct aggregate
    */
  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = true)(implicit relBuilder: RelBuilder): AggCall = {
    // NOTE(review): the cast relies on validateInput() having rejected non-aggregation
    // children beforehand; otherwise it fails with a ClassCastException.
    child.asInstanceOf[Aggregation].toAggCall(name, isDistinct = true)
  }

  /** Returns the `SqlAggFunction` of the wrapped aggregation. */
  override private[flink] def getSqlAggFunction()(
      implicit relBuilder: RelBuilder): SqlAggFunction = {
    child.asInstanceOf[Aggregation].getSqlAggFunction()
  }

  override private[flink] def children = Seq(child)
}

/**
  * `sum` aggregation over a single operand.
  *
  * @param child the expression whose values are summed; it must be numeric to pass
  *              validation
  */
case class Sum(child: Expression) extends Aggregation {

  override def toString: String = s"sum($child)"

  override private[flink] def children: Seq[Expression] = Seq(child)

  /** The sum has the same type as its operand. */
  override private[flink] def resultType: TypeInformation[_] = child.resultType

  /** Accepts numeric operands only. */
  override private[flink] def validateInput(): ValidationResult =
    TypeCheckUtils.assertNumericExpr(child.resultType, "sum")

  /**
    * Creates the Calcite aggregate call using `SqlStdOperatorTable.SUM`.
    *
    * @param name       passed on as the name of the resulting aggregate call
    * @param isDistinct forwarded as the distinct flag
    */
  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val operand = child.toRexNode
    relBuilder.aggregateCall(SqlStdOperatorTable.SUM, isDistinct, false, null, name, operand)
  }

  /**
    * Builds a `SqlSumAggFunction` whose return type is the nullable relational type derived
    * from [[resultType]].
    */
  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) = {
    val flinkTypeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
    val sumType = flinkTypeFactory.createTypeFromTypeInfo(resultType, isNullable = true)
    new SqlSumAggFunction(sumType)
  }
}

/**
  * `sum0` aggregation over a single operand, backed by `SqlStdOperatorTable.SUM0`.
  *
  * @param child the expression whose values are summed; it must be numeric to pass
  *              validation
  */
case class Sum0(child: Expression) extends Aggregation {

  override def toString: String = s"sum0($child)"

  override private[flink] def children: Seq[Expression] = Seq(child)

  /** The result has the same type as the operand. */
  override private[flink] def resultType: TypeInformation[_] = child.resultType

  /** Accepts numeric operands only. */
  override private[flink] def validateInput(): ValidationResult =
    TypeCheckUtils.assertNumericExpr(child.resultType, "sum0")

  /**
    * Creates the Calcite aggregate call for this aggregation.
    *
    * @param name       passed on as the name of the resulting aggregate call
    * @param isDistinct forwarded as the distinct flag
    */
  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val operand = child.toRexNode
    relBuilder.aggregateCall(SqlStdOperatorTable.SUM0, isDistinct, false, null, name, operand)
  }

  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) = {
    SqlStdOperatorTable.SUM0
  }
}

/**
  * `min` aggregation over a single operand, backed by `SqlStdOperatorTable.MIN`.
  *
  * @param child the expression to take the minimum of; its type must be orderable to pass
  *              validation
  */
case class Min(child: Expression) extends Aggregation {

  override def toString: String = s"min($child)"

  override private[flink] def children: Seq[Expression] = Seq(child)

  /** The minimum has the same type as the operand. */
  override private[flink] def resultType: TypeInformation[_] = child.resultType

  /** Accepts orderable operands only. */
  override private[flink] def validateInput(): ValidationResult =
    TypeCheckUtils.assertOrderableExpr(child.resultType, "min")

  /**
    * Creates the Calcite aggregate call for this aggregation.
    *
    * @param name       passed on as the name of the resulting aggregate call
    * @param isDistinct forwarded as the distinct flag
    */
  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val operand = child.toRexNode
    relBuilder.aggregateCall(SqlStdOperatorTable.MIN, isDistinct, false, null, name, operand)
  }

  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) =
    SqlStdOperatorTable.MIN
}

/**
  * `max` aggregation over a single operand, backed by `SqlStdOperatorTable.MAX`.
  *
  * @param child the expression to take the maximum of; its type must be orderable to pass
  *              validation
  */
case class Max(child: Expression) extends Aggregation {

  override def toString: String = s"max($child)"

  override private[flink] def children: Seq[Expression] = Seq(child)

  /** The maximum has the same type as the operand. */
  override private[flink] def resultType: TypeInformation[_] = child.resultType

  /** Accepts orderable operands only. */
  override private[flink] def validateInput(): ValidationResult =
    TypeCheckUtils.assertOrderableExpr(child.resultType, "max")

  /**
    * Creates the Calcite aggregate call for this aggregation.
    *
    * @param name       passed on as the name of the resulting aggregate call
    * @param isDistinct forwarded as the distinct flag
    */
  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val operand = child.toRexNode
    relBuilder.aggregateCall(SqlStdOperatorTable.MAX, isDistinct, false, null, name, operand)
  }

  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) =
    SqlStdOperatorTable.MAX
}

/**
  * `count` aggregation over a single operand, backed by `SqlStdOperatorTable.COUNT`.
  *
  * No type check is defined here for the operand; the result is always a long.
  *
  * @param child the expression whose values are counted
  */
case class Count(child: Expression) extends Aggregation {

  override def toString: String = s"count($child)"

  override private[flink] def children: Seq[Expression] = Seq(child)

  // Independent of the operand type, a count is reported as LONG.
  override private[flink] def resultType = BasicTypeInfo.LONG_TYPE_INFO

  /**
    * Creates the Calcite aggregate call for this aggregation.
    *
    * @param name       passed on as the name of the resulting aggregate call
    * @param isDistinct forwarded as the distinct flag
    */
  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val operand = child.toRexNode
    relBuilder.aggregateCall(SqlStdOperatorTable.COUNT, isDistinct, false, null, name, operand)
  }

  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) =
    SqlStdOperatorTable.COUNT
}

/**
  * `avg` aggregation over a single operand, backed by `SqlStdOperatorTable.AVG`.
  *
  * @param child the expression whose values are averaged; it must be numeric to pass
  *              validation
  */
case class Avg(child: Expression) extends Aggregation {

  override def toString: String = s"avg($child)"

  override private[flink] def children: Seq[Expression] = Seq(child)

  /** The average is reported with the same type as the operand. */
  override private[flink] def resultType: TypeInformation[_] = child.resultType

  /** Accepts numeric operands only. */
  override private[flink] def validateInput(): ValidationResult =
    TypeCheckUtils.assertNumericExpr(child.resultType, "avg")

  /**
    * Creates the Calcite aggregate call for this aggregation.
    *
    * @param name       passed on as the name of the resulting aggregate call
    * @param isDistinct forwarded as the distinct flag
    */
  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val operand = child.toRexNode
    relBuilder.aggregateCall(SqlStdOperatorTable.AVG, isDistinct, false, null, name, operand)
  }

  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) =
    SqlStdOperatorTable.AVG
}

/**
  * `collect` aggregation: aggregates the values of its operand into a multiset.
  *
  * @param child the expression whose values are collected
  */
case class Collect(child: Expression) extends Aggregation {

  override def toString: String = s"collect($child)"

  override private[flink] def children: Seq[Expression] = Seq(child)

  /** A multiset type whose element type is the operand's type. */
  override private[flink] def resultType: TypeInformation[_] = {
    val elementType = child.resultType
    MultisetTypeInfo.getInfoFor(elementType)
  }

  /**
    * Creates the Calcite aggregate call using `SqlStdOperatorTable.COLLECT`.
    *
    * @param name       passed on as the name of the resulting aggregate call
    * @param isDistinct forwarded as the distinct flag
    */
  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val operand = child.toRexNode
    relBuilder.aggregateCall(SqlStdOperatorTable.COLLECT, isDistinct, false, null, name, operand)
  }

  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) =
    SqlStdOperatorTable.COLLECT
}

/**
  * `stddev_pop` aggregation over a single operand, backed by
  * `SqlStdOperatorTable.STDDEV_POP`.
  *
  * @param child the aggregated expression; it must be numeric to pass validation
  */
case class StddevPop(child: Expression) extends Aggregation {

  override def toString: String = s"stddev_pop($child)"

  override private[flink] def children: Seq[Expression] = Seq(child)

  /** The result is reported with the same type as the operand. */
  override private[flink] def resultType: TypeInformation[_] = child.resultType

  /** Accepts numeric operands only. */
  override private[flink] def validateInput(): ValidationResult =
    TypeCheckUtils.assertNumericExpr(child.resultType, "stddev_pop")

  /**
    * Creates the Calcite aggregate call for this aggregation.
    *
    * @param name       passed on as the name of the resulting aggregate call
    * @param isDistinct forwarded as the distinct flag
    */
  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val operand = child.toRexNode
    relBuilder.aggregateCall(
      SqlStdOperatorTable.STDDEV_POP, isDistinct, false, null, name, operand)
  }

  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) = {
    SqlStdOperatorTable.STDDEV_POP
  }
}

/**
  * `stddev_samp` aggregation over a single operand, backed by
  * `SqlStdOperatorTable.STDDEV_SAMP`.
  *
  * @param child the aggregated expression; it must be numeric to pass validation
  */
case class StddevSamp(child: Expression) extends Aggregation {

  override def toString: String = s"stddev_samp($child)"

  override private[flink] def children: Seq[Expression] = Seq(child)

  /** The result is reported with the same type as the operand. */
  override private[flink] def resultType: TypeInformation[_] = child.resultType

  /** Accepts numeric operands only. */
  override private[flink] def validateInput(): ValidationResult =
    TypeCheckUtils.assertNumericExpr(child.resultType, "stddev_samp")

  /**
    * Creates the Calcite aggregate call for this aggregation.
    *
    * @param name       passed on as the name of the resulting aggregate call
    * @param isDistinct forwarded as the distinct flag
    */
  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val operand = child.toRexNode
    relBuilder.aggregateCall(
      SqlStdOperatorTable.STDDEV_SAMP, isDistinct, false, null, name, operand)
  }

  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) = {
    SqlStdOperatorTable.STDDEV_SAMP
  }
}

/**
  * `var_pop` aggregation over a single operand, backed by `SqlStdOperatorTable.VAR_POP`.
  *
  * @param child the aggregated expression; it must be numeric to pass validation
  */
case class VarPop(child: Expression) extends Aggregation {

  override def toString: String = s"var_pop($child)"

  override private[flink] def children: Seq[Expression] = Seq(child)

  /** The result is reported with the same type as the operand. */
  override private[flink] def resultType: TypeInformation[_] = child.resultType

  /** Accepts numeric operands only. */
  override private[flink] def validateInput(): ValidationResult =
    TypeCheckUtils.assertNumericExpr(child.resultType, "var_pop")

  /**
    * Creates the Calcite aggregate call for this aggregation.
    *
    * @param name       passed on as the name of the resulting aggregate call
    * @param isDistinct forwarded as the distinct flag
    */
  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val operand = child.toRexNode
    relBuilder.aggregateCall(SqlStdOperatorTable.VAR_POP, isDistinct, false, null, name, operand)
  }

  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) = {
    SqlStdOperatorTable.VAR_POP
  }
}

/**
  * `var_samp` aggregation over a single operand, backed by `SqlStdOperatorTable.VAR_SAMP`.
  *
  * @param child the aggregated expression; it must be numeric to pass validation
  */
case class VarSamp(child: Expression) extends Aggregation {

  override def toString: String = s"var_samp($child)"

  override private[flink] def children: Seq[Expression] = Seq(child)

  /** The result is reported with the same type as the operand. */
  override private[flink] def resultType: TypeInformation[_] = child.resultType

  /** Accepts numeric operands only. */
  override private[flink] def validateInput(): ValidationResult =
    TypeCheckUtils.assertNumericExpr(child.resultType, "var_samp")

  /**
    * Creates the Calcite aggregate call for this aggregation.
    *
    * @param name       passed on as the name of the resulting aggregate call
    * @param isDistinct forwarded as the distinct flag
    */
  override private[flink] def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val operand = child.toRexNode
    relBuilder.aggregateCall(SqlStdOperatorTable.VAR_SAMP, isDistinct, false, null, name, operand)
  }

  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) = {
    SqlStdOperatorTable.VAR_SAMP
  }
}

/**
  * Aggregation that is evaluated by an [[AggregateFunction]] instance.
  *
  * @param aggregateFunction the function performing the aggregation
  * @param resultTypeInfo    type information reported as the result type of this call
  * @param accTypeInfo       type information handed to the generated `AggSqlFunction`
  *                          together with the result type
  * @param args              argument expressions of the call
  */
case class AggFunctionCall(
    aggregateFunction: AggregateFunction[_, _],
    resultTypeInfo: TypeInformation[_],
    accTypeInfo: TypeInformation[_],
    args: Seq[Expression])
  extends Aggregation {

  override def toString: String = s"${aggregateFunction.getClass.getSimpleName}($args)"

  override private[flink] def children: Seq[Expression] = args

  override def resultType: TypeInformation[_] = resultTypeInfo

  /**
    * Checks that the argument types match one of the function's `accumulate` signatures.
    *
    * On a mismatch the failure message lists the actual argument types next to the
    * signatures found on the function (each with its first parameter dropped).
    */
  override def validateInput(): ValidationResult = {
    val signature = children.map(_.resultType)
    // look for a signature that matches the input types
    val foundSignature = getAccumulateMethodSignature(aggregateFunction, signature)
    if (foundSignature.nonEmpty) {
      ValidationSuccess
    } else {
      val actual = signatureToString(signature)
      val expected = getMethodSignatures(aggregateFunction, "accumulate")
        .map(_.drop(1))
        .map(signatureToString)
        .mkString(", ")
      ValidationFailure(s"Given parameters do not match any signature. \n" +
        s"Actual: $actual \n" +
        s"Expected: $expected")
    }
  }

  /**
    * Creates the Calcite aggregate call with all arguments as operands.
    *
    * @param name       passed on as the name of the resulting aggregate call
    * @param isDistinct forwarded as the distinct flag
    */
  override def toAggCall(
      name: String, isDistinct: Boolean = false)(implicit relBuilder: RelBuilder): AggCall = {
    val sqlFunction = this.getSqlAggFunction()
    val operands = args.map(_.toRexNode)
    relBuilder.aggregateCall(sqlFunction, isDistinct, false, null, name, operands: _*)
  }

  /**
    * Wraps the aggregate function into an `AggSqlFunction`, using the type factory of the
    * given builder (which must be a `FlinkTypeFactory`).
    */
  override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) = {
    val flinkTypeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
    AggSqlFunction(
      aggregateFunction.functionIdentifier,
      aggregateFunction.toString,
      aggregateFunction,
      resultType,
      accTypeInfo,
      flinkTypeFactory,
      aggregateFunction.requiresOver)
  }

  /**
    * Unlike the base class, this aggregation can be expressed as a plain call of its
    * `AggSqlFunction` on the translated arguments.
    */
  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
    val sqlFunction = this.getSqlAggFunction()
    val operands = args.map(_.toRexNode)
    relBuilder.call(sqlFunction, operands: _*)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy