All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.flink.table.expressions.ExpressionUtils.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.expressions

import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Integer => JInteger, Long => JLong, Short => JShort}
import java.math.{BigDecimal => JBigDecimal}
import java.sql.{Date, Time, Timestamp}
import java.time.Duration

import org.apache.flink.streaming.api.windowing.time.{Time => FlinkTime}
import org.apache.flink.table.api.ValidationException
import org.apache.flink.table.api.types.{DataTypes, InternalType}

/**
  * Helpers for inspecting Table API [[Expression]]s (interval literals, time attributes) and for
  * converting interval literals and Scala/Java arrays into other expression forms.
  */
object ExpressionUtils {
  /**
    * Returns the result type of the given expression.
    *
    * @param expr The expression whose result type is requested
    * @return     The result type of `expr`
    */
  def getResultType(expr: Expression): InternalType = {
    expr.resultType
  }

  /** Returns true if `expr` is a literal of type INTERVAL_MILLIS. */
  private[flink] def isTimeIntervalLiteral(expr: Expression): Boolean = expr match {
    case Literal(_, DataTypes.INTERVAL_MILLIS) => true
    case _ => false
  }

  /** Returns true if `expr` is a literal of type INTERVAL_ROWS. */
  private[flink] def isRowCountLiteral(expr: Expression): Boolean = expr match {
    case Literal(_, DataTypes.INTERVAL_ROWS) => true
    case _ => false
  }

  /** Returns true if `expr` is either a rowtime or a proctime attribute. */
  private[flink] def isTimeAttribute(expr: Expression): Boolean =
    isRowtimeAttribute(expr) || isProctimeAttribute(expr)

  /** Returns true if `expr` is a resolved field reference typed ROWTIME_INDICATOR. */
  private[flink] def isRowtimeAttribute(expr: Expression): Boolean = expr match {
    case r: ResolvedFieldReference if r.resultType == DataTypes.ROWTIME_INDICATOR => true
    case _ => false
  }

  /**
    * Returns true if `expr` is a resolved field reference typed PROCTIME_INDICATOR or a
    * [[Proctime]] expression.
    */
  private[flink] def isProctimeAttribute(expr: Expression): Boolean = expr match {
    case r: ResolvedFieldReference if r.resultType == DataTypes.PROCTIME_INDICATOR => true
    case _: Proctime => true
    case _ => false
  }

  /**
    * Converts a millisecond interval literal into a Flink windowing [[FlinkTime]].
    *
    * @throws IllegalArgumentException if `expr` is not a `Long` literal of type INTERVAL_MILLIS
    */
  private[flink] def toTime(expr: Expression): FlinkTime = expr match {
    case Literal(value: Long, DataTypes.INTERVAL_MILLIS) =>
      FlinkTime.milliseconds(value)
    case _ =>
      throw new IllegalArgumentException(
        s"Expected a millisecond interval literal, but got: $expr")
  }

  /**
    * Converts a millisecond interval literal into a [[java.time.Duration]].
    *
    * @throws IllegalArgumentException if `expr` is not a `Long` literal of type INTERVAL_MILLIS
    */
  private[flink] def toDuration(expr: Expression): Duration = expr match {
    case Literal(value: Long, DataTypes.INTERVAL_MILLIS) =>
      Duration.ofMillis(value)
    case _ =>
      throw new IllegalArgumentException(
        s"Expected a millisecond interval literal, but got: $expr")
  }

  /**
    * Extracts the row count from a row-interval literal.
    *
    * @throws IllegalArgumentException if `expr` is not a `Long` literal of type INTERVAL_ROWS
    */
  private[flink] def toLong(expr: Expression): Long = expr match {
    case Literal(value: Long, DataTypes.INTERVAL_ROWS) => value
    case _ =>
      throw new IllegalArgumentException(
        s"Expected a row-count interval literal, but got: $expr")
  }

  /**
    * Turns `expr` into a month interval scaled by `multiplier`.
    *
    * An `Int` literal is folded eagerly into an INTERVAL_MONTHS literal. Any other expression is
    * wrapped as a cast of `expr * multiplier` to INTERVAL_MONTHS.
    *
    * @throws ValidationException if folding a literal overflows the `Int` month count
    */
  private[flink] def toMonthInterval(expr: Expression, multiplier: Int): Expression = expr match {
    case Literal(value: Int, DataTypes.INT) =>
      // Multiply in Long so an out-of-range product is detected instead of silently wrapping.
      // |Int * Int| < 2^62, so this product itself cannot overflow a Long.
      val months = value.toLong * multiplier
      if (months < Int.MinValue || months > Int.MaxValue) {
        intervalOverflow(value, multiplier)
      }
      Literal(months.toInt, DataTypes.INTERVAL_MONTHS)
    case _ =>
      Cast(Mul(expr, Literal(multiplier)), DataTypes.INTERVAL_MONTHS)
  }

  /**
    * Turns `expr` into a millisecond interval scaled by `multiplier`.
    *
    * `Int` and `Long` literals are folded eagerly into an INTERVAL_MILLIS literal. Any other
    * expression is wrapped as a cast of `expr * multiplier` to INTERVAL_MILLIS.
    *
    * @throws ValidationException if folding a literal overflows the `Long` millisecond count
    */
  private[flink] def toMilliInterval(expr: Expression, multiplier: Long): Expression = expr match {
    case Literal(value: Int, DataTypes.INT) =>
      Literal(multiplyMillis(value.toLong, multiplier), DataTypes.INTERVAL_MILLIS)
    case Literal(value: Long, DataTypes.LONG) =>
      Literal(multiplyMillis(value, multiplier), DataTypes.INTERVAL_MILLIS)
    case _ =>
      Cast(Mul(expr, Literal(multiplier)), DataTypes.INTERVAL_MILLIS)
  }

  /** Multiplies two longs, raising a [[ValidationException]] instead of wrapping on overflow. */
  private def multiplyMillis(value: Long, multiplier: Long): Long =
    try {
      Math.multiplyExact(value, multiplier)
    } catch {
      case _: ArithmeticException => intervalOverflow(value, multiplier)
    }

  /** Signals that folding an interval literal would exceed its numeric range. */
  private def intervalOverflow(value: Any, multiplier: Any): Nothing =
    throw new ValidationException(
      s"Interval literal overflow: $value * $multiplier exceeds the supported range.")

  /**
    * Converts an `Int` or `Long` literal into an INTERVAL_ROWS literal holding a `Long`.
    *
    * @throws IllegalArgumentException for any other expression
    */
  private[flink] def toRowInterval(expr: Expression): Expression = expr match {
    case Literal(value: Int, DataTypes.INT) =>
      Literal(value.toLong, DataTypes.INTERVAL_ROWS)
    case Literal(value: Long, DataTypes.LONG) =>
      Literal(value, DataTypes.INTERVAL_ROWS)
    case _ =>
      throw new IllegalArgumentException("Invalid value for row interval literal.")
  }

  /**
    * Converts an `Int`, `Long`, `Float` or `Double` literal into an INTERVAL_RANGE literal.
    * `Int` values are widened to `Long`, and the other numeric types keep their own type.
    *
    * @throws IllegalArgumentException for any other expression
    */
  private[flink] def toRangeInterval(expr: Expression): Expression = expr match {
    case Literal(value: Int, DataTypes.INT) =>
      Literal(value.toLong, DataTypes.INTERVAL_RANGE)
    case Literal(value: Long, DataTypes.LONG) =>
      Literal(value, DataTypes.INTERVAL_RANGE)
    case Literal(value: Float, DataTypes.FLOAT) =>
      Literal(value, DataTypes.INTERVAL_RANGE)
    case Literal(value: Double, DataTypes.DOUBLE) =>
      Literal(value, DataTypes.INTERVAL_RANGE)
    case _ =>
      throw new IllegalArgumentException("Invalid value for range interval literal.")
  }

  /**
    * Converts a JVM array into an [[ArrayConstructor]] of literals.
    *
    * Supported element types are the primitive types, their boxed Java counterparts, `String`,
    * Java and Scala `BigDecimal`, and `java.sql` `Date`/`Time`/`Timestamp`. Scala `BigDecimal`
    * elements are unwrapped to their Java representation. Nested arrays are converted
    * recursively. This requires a non-empty array in which every element is itself an array.
    *
    * @throws ValidationException if the array's element type is not supported
    */
  private[flink] def convertArray(array: Array[_]): Expression = {
    def createArray(): Expression = {
      ArrayConstructor(array.map(Literal(_)))
    }

    array match {
      // primitives
      case _: Array[Boolean] => createArray()
      case _: Array[Byte] => createArray()
      case _: Array[Short] => createArray()
      case _: Array[Int] => createArray()
      case _: Array[Long] => createArray()
      case _: Array[Float] => createArray()
      case _: Array[Double] => createArray()

      // boxed types
      case _: Array[JBoolean] => createArray()
      case _: Array[JByte] => createArray()
      case _: Array[JShort] => createArray()
      case _: Array[JInteger] => createArray()
      case _: Array[JLong] => createArray()
      case _: Array[JFloat] => createArray()
      case _: Array[JDouble] => createArray()

      // others
      case _: Array[String] => createArray()
      case _: Array[JBigDecimal] => createArray()
      case _: Array[Date] => createArray()
      case _: Array[Time] => createArray()
      case _: Array[Timestamp] => createArray()
      case bda: Array[BigDecimal] => ArrayConstructor(bda.map { bd => Literal(bd.bigDecimal) })

      // nested: check every element, not only the head. Otherwise a mixed or null-containing
      // array fails with a ClassCastException/NPE instead of a ValidationException.
      case _ if array.nonEmpty && array.forall(_.isInstanceOf[Array[_]]) =>
        ArrayConstructor(array.map { na => convertArray(na.asInstanceOf[Array[_]]) })

      case _ =>
        throw new ValidationException("Unsupported array type.")
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy