All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.flink.table.plan.ProjectionTranslator.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.plan

import org.apache.flink.table.api.types.{RowType}
import org.apache.flink.table.api.{OverWindow, TableEnvironment}
import org.apache.flink.table.expressions._
import org.apache.flink.table.plan.logical.{LogicalNode, Project}

import scala.collection.mutable
import scala.collection.mutable.ListBuffer

object ProjectionTranslator {

  /**
    * Extracts and deduplicates all aggregation and window property expressions (zero, one, or more)
    * from the given expressions.
    *
    * @param exprs    a list of expressions to extract
    * @param tableEnv the TableEnvironment, used to generate a unique name for every newly
    *                 extracted expression
    * @return a Tuple2, the first field contains the extracted and deduplicated aggregations,
    *         and the second field contains the extracted and deduplicated window properties.
    *         Both map each extracted expression to its generated name.
    */
  def extractAggregationsAndProperties(
      exprs: Seq[Expression],
      tableEnv: TableEnvironment): (Map[Expression, String], Map[Expression, String]) = {
    exprs.foldLeft((Map[Expression, String](), Map[Expression, String]())) {
      (names, expr) => identifyAggregationsAndProperties(expr, tableEnv, names._1, names._2)
    }
  }

  /**
    * Identifies and deduplicates aggregation functions and window properties.
    *
    * Recursively walks `exp`. Every [[Aggregation]] or [[WindowProperty]] that is not yet a key
    * of `aggNames` / `propNames` is added with a fresh name from
    * `tableEnv.createUniqueAttributeName()`; expressions that are already present keep their
    * existing name. The arguments of an aggregation or window property are not searched.
    *
    * @return the (possibly extended) aggregation map and window property map
    */
  private def identifyAggregationsAndProperties(
      exp: Expression,
      tableEnv: TableEnvironment,
      aggNames: Map[Expression, String],
      propNames: Map[Expression, String]) : (Map[Expression, String], Map[Expression, String]) = {

    // Visits one constructor field of an expression: expressions are searched recursively and
    // Seq-valued fields (e.g. the elements of a RowConstructor, ArrayConstructor or
    // MapConstructor) element-wise; any other value is skipped. Searching Seq fields keeps this
    // in sync with replaceAggregationsAndProperties, which rewrites the elements of those
    // constructors and looks every aggregation / window property up in the maps built here.
    def visitChild(
        names: (Map[Expression, String], Map[Expression, String]),
        child: Any): (Map[Expression, String], Map[Expression, String]) = child match {
      case e: Expression => identifyAggregationsAndProperties(e, tableEnv, names._1, names._2)
      case elements: Seq[_] => elements.foldLeft(names)(visitChild)
      case _ => names
    }

    exp match {
      case agg: Aggregation =>
        if (aggNames.contains(agg)) {
          (aggNames, propNames)
        } else {
          (aggNames + (agg -> tableEnv.createUniqueAttributeName()), propNames)
        }
      case prop: WindowProperty =>
        if (propNames.contains(prop)) {
          (aggNames, propNames)
        } else {
          (aggNames, propNames + (prop -> tableEnv.createUniqueAttributeName()))
        }
      case _: LeafExpression =>
        (aggNames, propNames)
      case u: UnaryExpression =>
        identifyAggregationsAndProperties(u.child, tableEnv, aggNames, propNames)
      case b: BinaryExpression =>
        val afterLeft =
          identifyAggregationsAndProperties(b.left, tableEnv, aggNames, propNames)
        identifyAggregationsAndProperties(b.right, tableEnv, afterLeft._1, afterLeft._2)

      // Functions calls
      case Call(_, args) =>
        args.foldLeft((aggNames, propNames))(visitChild)
      case ScalarFunctionCall(_, args) =>
        args.foldLeft((aggNames, propNames))(visitChild)

      // General expression (including row / array / map constructors)
      case e: Expression =>
        e.productIterator.foldLeft((aggNames, propNames))(visitChild)
    }
  }

  /**
    * Replaces expressions with deduplicated aggregations and properties.
    *
    * Each aggregation / window property is replaced by a reference to its extracted name. Within
    * one call, a name that is referenced a second time is wrapped in an [[Alias]] with a fresh
    * unique name, so the projected names stay distinct. Every resulting expression is wrapped in
    * an [[UnresolvedAlias]].
    *
    * @param exprs     a list of expressions to replace
    * @param tableEnv  the TableEnvironment
    * @param aggNames  the deduplicated aggregations
    * @param propNames the deduplicated properties
    * @return a list of replaced expressions
    */
  def replaceAggregationsAndProperties(
      exprs: Seq[Expression],
      tableEnv: TableEnvironment,
      aggNames: Map[Expression, String],
      propNames: Map[Expression, String]): Seq[NamedExpression] = {
    // shared by all expressions of the list: tracks which extracted names are already projected
    val projectedNames = new mutable.HashSet[String]
    exprs
      .map((exp: Expression) =>
        replaceAggregationsAndProperties(exp, tableEnv, aggNames, propNames, projectedNames))
      .map(UnresolvedAlias)
  }

  /**
    * Recursively replaces aggregations and window properties in `exp` by field references to
    * the names in `aggNames` / `propNames`. Every aggregation / window property found must be a
    * key of the corresponding map, otherwise `Map.apply` throws a `NoSuchElementException`.
    */
  private def replaceAggregationsAndProperties(
      exp: Expression,
      tableEnv: TableEnvironment,
      aggNames: Map[Expression, String],
      propNames: Map[Expression, String],
      projectedNames: mutable.HashSet[String]) : Expression = {

    def replace(e: Expression): Expression =
      replaceAggregationsAndProperties(e, tableEnv, aggNames, propNames, projectedNames)

    // Rewrites one constructor field of a general expression: expressions are replaced,
    // Seq-valued fields are replaced element-wise (mirroring identifyAggregationsAndProperties),
    // and any other value is passed through unchanged instead of failing with a MatchError.
    def replaceChild(child: Any): AnyRef = child match {
      case e: Expression => replace(e)
      case elements: Seq[_] => elements.map(replaceChild)
      case other => other.asInstanceOf[AnyRef]
    }

    // References the extracted field `name`. If `name` was already projected, the reference is
    // re-aliased with a fresh unique name so the output does not contain duplicate names.
    def reference(name: String): Expression =
      if (projectedNames.add(name)) {
        UnresolvedFieldReference(name)
      } else {
        Alias(UnresolvedFieldReference(name), tableEnv.createUniqueAttributeName())
      }

    exp match {
      case agg: Aggregation =>
        reference(aggNames(agg))
      case prop: WindowProperty =>
        reference(propNames(prop))
      // NOTE(review): any extraNames of the original Alias (third field) are not carried over
      // to the new Alias — confirm this is intended.
      case Alias(agg: Aggregation, name, _) =>
        Alias(UnresolvedFieldReference(aggNames(agg)), name)
      case Alias(prop: WindowProperty, name, _) =>
        Alias(UnresolvedFieldReference(propNames(prop)), name)
      case l: LeafExpression =>
        l
      case u: UnaryExpression =>
        u.makeCopy(Array(replace(u.child)))
      case b: BinaryExpression =>
        b.makeCopy(Array(replace(b.left), replace(b.right)))

      // Functions calls
      case c @ Call(name, args) =>
        c.makeCopy(Array(name, args.map(replace)))
      case sfc @ ScalarFunctionCall(clazz, args) =>
        sfc.makeCopy(Array(clazz, args.map(replace)))

      // row / array / map constructors
      case c: RowConstructor =>
        c.makeCopy(Array(c.elements.map(replace)))
      case c: ArrayConstructor =>
        c.makeCopy(Array(c.elements.map(replace)))
      case c: MapConstructor =>
        c.makeCopy(Array(c.elements.map(replace)))

      // General expression
      case e: Expression =>
        e.makeCopy(e.productIterator.map(replaceChild).toArray)
    }
  }

  /**
    * Expands an UnresolvedFieldReference("*") to parent's full project list.
    *
    * A `Flattening(expr)` is resolved against `parent`. If its result type is a [[RowType]], it
    * expands to one [[GetCompositeField]] per field, in field order; otherwise it is replaced by
    * `expr` itself. All other expressions are kept as they are.
    *
    * @throws RuntimeException if the simulated projection used to resolve a flattened
    *                          expression produces no output
    */
  def expandProjectList(
      exprs: Seq[Expression],
      parent: LogicalNode,
      tableEnv: TableEnvironment)
    : Seq[Expression] = {

    val projectList = new ListBuffer[Expression]

    exprs.foreach {
      case n: UnresolvedFieldReference if n.name == "*" =>
        projectList ++= parent.output.map(a => UnresolvedFieldReference(a.name))

      case Flattening(unresolved) =>
        // simulate a simple project to resolve fields using current parent
        val project = Project(Seq(UnresolvedAlias(unresolved)), parent).validate(tableEnv)
        val resolvedExpr = project
          .output
          .headOption
          .getOrElse(throw new RuntimeException("Could not find resolved composite."))
        // NOTE(review): the value returned by validateInput() is discarded, so this call only
        // has an effect if it throws on invalid input — confirm.
        resolvedExpr.validateInput()
        resolvedExpr.resultType match {
          case ct: RowType =>
            val fieldNames = ct.getFieldNames()
            (0 until ct.getArity).foreach { idx =>
              projectList += GetCompositeField(unresolved, fieldNames(idx))
            }
          case _ =>
            projectList += unresolved
        }

      case e: Expression =>
        projectList += e
    }
    projectList
  }

  /**
    * Resolves every [[UnresolvedOverCall]] in `exprs` against the given over windows.
    *
    * @param exprs       the expressions to resolve
    * @param overWindows the over windows that can be referenced by alias
    * @param tEnv        the TableEnvironment passed to every created [[OverCall]]
    * @return the expressions with matching over calls resolved; calls whose alias matches no
    *         window are left unresolved
    */
  def resolveOverWindows(
      exprs: Seq[Expression],
      overWindows: Array[OverWindow],
      tEnv: TableEnvironment): Seq[Expression] = {

    exprs.map(e => replaceOverCall(e, overWindows, tEnv))
  }

  /**
    * Find and replace UnresolvedOverCall with OverCall
    *
    * An [[UnresolvedOverCall]] is replaced using the first window in `overWindows` whose alias
    * equals the call's alias; if there is none, the call is returned unchanged.
    *
    * @param expr        the expression to check
    * @param overWindows the over windows that can be referenced by alias
    * @param tableEnv    the TableEnvironment passed to every created [[OverCall]]
    * @return an expression with correct resolved OverCall
    */
  private def replaceOverCall(
    expr: Expression,
    overWindows: Array[OverWindow],
    tableEnv: TableEnvironment): Expression = {

    def replace(e: Expression): Expression = replaceOverCall(e, overWindows, tableEnv)

    expr match {
      case u: UnresolvedOverCall =>
        overWindows
          .find(_.alias.equals(u.alias))
          .map { window =>
            OverCall(
              u.agg,
              window.partitionBy,
              window.orderBy,
              window.preceding,
              window.following,
              tableEnv)
          }
          .getOrElse(u)

      case u: UnaryExpression =>
        u.makeCopy(Array(replace(u.child)))

      case b: BinaryExpression =>
        b.makeCopy(Array(replace(b.left), replace(b.right)))

      // Functions calls
      case c @ Call(name, args) =>
        c.makeCopy(Array(name, args.map(replace)))

      // Scala functions
      case sfc @ ScalarFunctionCall(clazz, args) =>
        sfc.makeCopy(Array(clazz, args.map(replace)))

      // Row / array / map constructors
      case c: RowConstructor =>
        c.makeCopy(Array(c.elements.map(replace)))
      case c: ArrayConstructor =>
        c.makeCopy(Array(c.elements.map(replace)))
      case c: MapConstructor =>
        c.makeCopy(Array(c.elements.map(replace)))

      // Other expressions
      case e: Expression => e
    }
  }


  /**
    * Extract all field references from the given expressions.
    *
    * Each distinct [[UnresolvedFieldReference]] is returned wrapped in an [[UnresolvedAlias]].
    * References that only occur inside a [[WindowProperty]] are ignored. The result is built
    * from a `Set`, so its order is unspecified.
    *
    * @param exprs a list of expressions to extract
    * @return a list of field references extracted from the given expressions
    */
  def extractFieldReferences(exprs: Seq[Expression]): Seq[NamedExpression] = {
    exprs.foldLeft(Set[NamedExpression]()) {
      (fieldReferences, expr) => identifyFieldReferences(expr, fieldReferences)
    }.toSeq
  }

  /** Recursively adds the field references found in `expr` to `fieldReferences`. */
  private def identifyFieldReferences(
      expr: Expression,
      fieldReferences: Set[NamedExpression]): Set[NamedExpression] = {

    // Visits one constructor field: expressions are searched recursively, Seq-valued fields
    // element-wise, and any other value is skipped.
    def visitChild(refs: Set[NamedExpression], child: Any): Set[NamedExpression] = child match {
      case e: Expression => identifyFieldReferences(e, refs)
      case elements: Seq[_] => elements.foldLeft(refs)(visitChild)
      case _ => refs
    }

    expr match {
      case f: UnresolvedFieldReference =>
        fieldReferences + UnresolvedAlias(f)

      case b: BinaryExpression =>
        identifyFieldReferences(b.right, identifyFieldReferences(b.left, fieldReferences))

      // Functions calls
      case Call(_, args) =>
        args.foldLeft(fieldReferences)(visitChild)
      case ScalarFunctionCall(_, args) =>
        args.foldLeft(fieldReferences)(visitChild)

      case AggFunctionCall(_, _, _, args) =>
        args.foldLeft(fieldReferences)(visitChild)

      // row / array / map constructors
      case c: RowConstructor =>
        c.elements.foldLeft(fieldReferences)(visitChild)
      case ArrayConstructor(args) =>
        args.foldLeft(fieldReferences)(visitChild)
      case c: MapConstructor =>
        c.elements.foldLeft(fieldReferences)(visitChild)

      // ignore fields from window property
      case _: WindowProperty =>
        fieldReferences

      // keep this case after all unwanted unary expressions
      case u: UnaryExpression =>
        identifyFieldReferences(u.child, fieldReferences)

      // General expression
      case e: Expression =>
        e.productIterator.foldLeft(fieldReferences)(visitChild)
    }
  }

  /**
    * Find and replace UDAGG function Call to AggFunctionCall
    *
    * Every [[Call]] is looked up in the table environment's function catalog. If the lookup
    * yields an [[AggFunctionCall]], [[Aggregation]] or [[AbstractWindowProperty]], that result
    * replaces the call; otherwise the call is kept and its arguments are processed recursively.
    *
    * @param field    the expression to check
    * @param tableEnv the TableEnvironment
    * @return an expression with correct AggFunctionCall type for UDAGG functions
    */
  def replaceAggFunctionCall(field: Expression, tableEnv: TableEnvironment): Expression = {

    def replace(e: Expression): Expression = replaceAggFunctionCall(e, tableEnv)

    field match {
      case l: LeafExpression => l

      case u: UnaryExpression =>
        u.makeCopy(Array(replace(u.child)))

      case b: BinaryExpression =>
        b.makeCopy(Array(replace(b.left), replace(b.right)))

      // Functions calls
      case c @ Call(name, args) =>
        tableEnv.getFunctionCatalog.lookupFunction(name, args) match {
          case a: AggFunctionCall => a
          case a: Aggregation => a
          case p: AbstractWindowProperty => p
          case _ => c.makeCopy(Array(name, args.map(replace)))
        }

      // Scala functions
      case sfc @ ScalarFunctionCall(clazz, args) =>
        sfc.makeCopy(Array(clazz, args.map(replace)))

      // Row / array / map constructors
      case c: RowConstructor =>
        c.makeCopy(Array(c.elements.map(replace)))
      case c: ArrayConstructor =>
        c.makeCopy(Array(c.elements.map(replace)))
      case c: MapConstructor =>
        c.makeCopy(Array(c.elements.map(replace)))

      // Other expressions
      case e: Expression => e
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy