All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.flink.table.expressions.arithmetic.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.flink.table.expressions

import org.apache.calcite.rex.RexNode
import org.apache.calcite.sql.SqlOperator
import org.apache.calcite.sql.fun.SqlStdOperatorTable
import org.apache.calcite.tools.RelBuilder
import org.apache.flink.table.functions.sql.ScalarSqlFunctions
import org.apache.flink.table.plan.logical.LogicalExprVisitor
import org.apache.flink.table.api.types.{DataTypes, DecimalType, InternalType}
import org.apache.flink.table.typeutils.TypeCheckUtils._
import org.apache.flink.table.typeutils.TypeCoercion
import org.apache.flink.table.validate._

import scala.collection.JavaConversions._

/**
  * Base class for binary arithmetic expressions.
  *
  * Subclasses supply the Calcite [[SqlOperator]] the expression is translated to. By default
  * both operands must be numeric, and the result type is the wider of the two operand types as
  * determined by [[TypeCoercion.widerTypeOf]].
  */
abstract class BinaryArithmetic extends BinaryExpression {

  /** The Calcite operator used when translating this expression into a [[RexNode]]. */
  private[flink] def sqlOperator: SqlOperator

  /** Translates this expression into a call of [[sqlOperator]] on the translated children. */
  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
    relBuilder.call(sqlOperator, children.map(_.toRexNode))
  }

  /**
    * The wider of the two operand types.
    *
    * @throws RuntimeException if [[TypeCoercion.widerTypeOf]] finds no common type for the
    *                          operands. Validation is expected to reject such inputs first,
    *                          so the message names the expression and both operand types to
    *                          make any violation of that expectation diagnosable.
    */
  override private[flink] def resultType: InternalType = {
    val leftType = left.resultType
    val rightType = right.resultType
    TypeCoercion.widerTypeOf(leftType, rightType) match {
      case Some(widerType) => widerType
      case None =>
        throw new RuntimeException(
          s"Cannot derive a common result type for $this from operand types " +
            s"$leftType and $rightType. This should never happen.")
    }
  }

  /** Accepts the expression only when both operands are numeric. */
  override private[flink] def validateInput(): ValidationResult = {
    if (!isNumeric(left.resultType) ||
        !isNumeric(right.resultType)) {
      ValidationFailure(s"$this requires both operands to be numeric, but was " +
        s"$left : ${left.resultType} and $right : ${right.resultType}")
    } else {
      ValidationSuccess
    }
  }
}

/**
  * Addition: `left + right`.
  *
  * Besides numeric addition, this expression also covers:
  *  - string concatenation when either operand is a string (the other side is cast to STRING),
  *  - adding two intervals of the same type,
  *  - adding an interval and a time point, in either operand order.
  */
case class Plus(left: Expression, right: Expression) extends BinaryArithmetic {

  private[flink] val sqlOperator = SqlStdOperatorTable.PLUS

  override def toString = s"($left + $right)"

  /** Picks CONCAT, interval PLUS, DATETIME_PLUS or plain numeric PLUS from the operand types. */
  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode =
    (left.resultType, right.resultType) match {
      case (lhsType, _) if isString(lhsType) =>
        // string on the left: concatenate, rendering the right side as a string
        relBuilder.call(
          SqlStdOperatorTable.CONCAT,
          left.toRexNode,
          Cast(right, DataTypes.STRING).toRexNode)

      case (_, rhsType) if isString(rhsType) =>
        // string on the right: concatenate, rendering the left side as a string
        relBuilder.call(
          SqlStdOperatorTable.CONCAT,
          Cast(left, DataTypes.STRING).toRexNode,
          right.toRexNode)

      case (lhsType, rhsType) if isTimeInterval(lhsType) && lhsType == rhsType =>
        relBuilder.call(SqlStdOperatorTable.PLUS, left.toRexNode, right.toRexNode)

      case (lhsType, rhsType) if isTimeInterval(lhsType) && isTemporal(rhsType) =>
        // Calcite cannot apply INTERVAL + DATETIME with the interval on the left,
        // so the operands are swapped before building DATETIME_PLUS.
        relBuilder.call(SqlStdOperatorTable.DATETIME_PLUS, right.toRexNode, left.toRexNode)

      case (lhsType, rhsType) if isTemporal(lhsType) && isTemporal(rhsType) =>
        relBuilder.call(SqlStdOperatorTable.DATETIME_PLUS, left.toRexNode, right.toRexNode)

      case _ =>
        // purely numeric addition: fall back to the generic binary-operator translation
        super.toRexNode
    }

  /** Accepts string, same-type interval, interval/time-point or numeric operand pairs. */
  override private[flink] def validateInput(): ValidationResult = {
    val lhsType = left.resultType
    val rhsType = right.resultType

    val isConcat = isString(lhsType) || isString(rhsType)
    val isIntervalSum = isTimeInterval(lhsType) && lhsType == rhsType
    val isShiftedTimePoint =
      (isTimePoint(lhsType) && isTimeInterval(rhsType)) ||
        (isTimeInterval(lhsType) && isTimePoint(rhsType))
    val isNumericSum = isNumeric(lhsType) && isNumeric(rhsType)

    if (isConcat || isIntervalSum || isShiftedTimePoint || isNumericSum) {
      ValidationSuccess
    } else {
      ValidationFailure(
        s"$this requires Numeric, String, Intervals of same type, " +
        s"or Interval and a time point input, " +
        s"get $left : $lhsType and $right : $rhsType")
    }
  }

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}

/**
  * Arithmetic negation: `-(child)`.
  *
  * The result has the same type as the operand, which must be numeric or an interval.
  */
case class UnaryMinus(child: Expression) extends UnaryExpression {

  override def toString = s"-($child)"

  override private[flink] def resultType = child.resultType

  /** Translates to Calcite's UNARY_MINUS applied to the translated operand. */
  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode =
    relBuilder.call(SqlStdOperatorTable.UNARY_MINUS, child.toRexNode)

  /** Accepts numeric or interval operands only. */
  override private[flink] def validateInput(): ValidationResult = {
    val childType = child.resultType
    if (isNumeric(childType) || isTimeInterval(childType)) {
      ValidationSuccess
    } else {
      ValidationFailure(s"$this requires Numeric, or Interval input, get $childType")
    }
  }

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}

/**
  * Subtraction: `left - right`, translated to Calcite's MINUS operator.
  *
  * In addition to the numeric operands accepted by [[BinaryArithmetic]], validation also
  * admits intervals of the same type and interval/time-point combinations in either order.
  */
case class Minus(left: Expression, right: Expression) extends BinaryArithmetic {

  private[flink] val sqlOperator = SqlStdOperatorTable.MINUS

  override def toString = s"($left - $right)"

  override private[flink] def validateInput(): ValidationResult = {
    val lhsType = left.resultType
    val rhsType = right.resultType

    val isIntervalDifference = isTimeInterval(lhsType) && lhsType == rhsType
    val mixesTimePointAndInterval =
      (isTimePoint(lhsType) && isTimeInterval(rhsType)) ||
        (isTimeInterval(lhsType) && isTimePoint(rhsType))

    // anything else must satisfy the generic numeric-only check
    if (isIntervalDifference || mixesTimePointAndInterval) ValidationSuccess
    else super.validateInput()
  }

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}

/**
  * Division: `left / right`, translated to `ScalarSqlFunctions.DIVIDE`.
  *
  * The result stays DECIMAL when the widened operand type is a [[DecimalType]]. In every other
  * case the result is DOUBLE.
  */
case class Div(left: Expression, right: Expression) extends BinaryArithmetic {

  private[flink] val sqlOperator = ScalarSqlFunctions.DIVIDE

  override def toString = s"($left / $right)"

  override private[flink] def resultType: InternalType =
    super.resultType match {
      case decimalType: DecimalType => decimalType
      case _ => DataTypes.DOUBLE
    }

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}

/**
  * Multiplication: `left * right`, translated to Calcite's MULTIPLY operator.
  *
  * It inherits the numeric-only validation and wider-type result of [[BinaryArithmetic]].
  */
case class Mul(left: Expression, right: Expression) extends BinaryArithmetic {

  private[flink] val sqlOperator = SqlStdOperatorTable.MULTIPLY

  override def toString = s"($left * $right)"

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}

/**
  * Remainder: `left % right`, translated to Calcite's MOD operator.
  *
  * It inherits the numeric-only validation and wider-type result of [[BinaryArithmetic]].
  */
case class Mod(left: Expression, right: Expression) extends BinaryArithmetic {

  private[flink] val sqlOperator = SqlStdOperatorTable.MOD

  override def toString = s"($left % $right)"

  override def accept[T](logicalExprVisitor: LogicalExprVisitor[T]): T =
    logicalExprVisitor.visit(this)
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy