All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.flink.table.runtime.functions.ScalarFunctions.scala Maven / Gradle / Ivy

Go to download

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.flink.table.runtime.functions

import java.lang.{StringBuilder, Long => JLong}
import java.math.{BigDecimal => JBigDecimal}
import java.util.regex.{Matcher, Pattern}

import org.apache.flink.table.utils.EncodingUtils

import org.apache.calcite.avatica.util.DateTimeUtils.{EPOCH_JULIAN, ymdToUnixDate}
import org.apache.calcite.avatica.util.TimeUnitRange
import org.apache.calcite.avatica.util.TimeUnitRange.{YEAR, MONTH}

import scala.annotation.varargs

/**
  * Companion class of the built-in scalar runtime functions. The class itself declares no
  * members; the functions are defined on the object of the same name.
  *
  * NOTE: Before adding a function, check whether Calcite already provides it in
  * [[org.apache.calcite.runtime.SqlFunctions]]. Also make sure the function is implemented
  * efficiently. Sometimes it is better to create a
  * [[org.apache.flink.table.codegen.calls.CallGenerator]] instead, to avoid massive object
  * creation and to reuse instances.
  */
class ScalarFunctions

object ScalarFunctions {

  /**
    * Returns "a" raised to the power of "b". The big decimal exponent is converted to a
    * double before the computation.
    */
  def power(a: Double, b: JBigDecimal): Double = {
    val exponent: Double = b.doubleValue()
    StrictMath.pow(a, exponent)
  }

  /**
    * Computes the hyperbolic cosine of a big decimal value. The value is converted to a
    * double before the computation.
    */
  def cosh(x: JBigDecimal): Double = {
    val value: Double = x.doubleValue()
    Math.cosh(value)
  }

  /**
    * Concatenates all arguments in order into a single string.
    * The result is NULL as soon as any argument is NULL.
    */
  @varargs
  def concat(args: String*): String = {
    val buffer = new StringBuilder
    var idx = 0
    var sawNull = false
    // Stop scanning at the first NULL argument; the partial buffer is discarded in that case.
    while (!sawNull && idx < args.length) {
      val piece = args(idx)
      if (piece == null) {
        sawNull = true
      } else {
        buffer.append(piece)
      }
      idx += 1
    }
    if (sawNull) null else buffer.toString
  }

  /**
    * Joins the arguments into one string, placing the separator between consecutive values.
    * Returns NULL if the separator is NULL.
    *
    * Note: NULL arguments after the separator are skipped, whereas empty strings are kept
    * and therefore still produce a separator.
    */
  @varargs
  def concat_ws(separator: String, args: String*): String = {
    if (separator == null) {
      null
    } else {
      val joined = new StringBuilder
      // True until the first non-NULL value has been written; suppresses a leading separator.
      var first = true
      var idx = 0
      while (idx < args.length) {
        val piece = args(idx)
        if (piece != null) {
          if (!first) {
            joined.append(separator)
          }
          joined.append(piece)
          first = false
        }
        idx += 1
      }
      joined.toString
    }
  }

  /**
    * Computes the natural logarithm of "x".
    *
    * @throws IllegalArgumentException if "x" is zero or negative
    */
  def log(x: Double): Double = {
    // Guard first: the logarithm is only computed for strictly positive input.
    if (x <= 0.0) {
      throw new IllegalArgumentException(s"x of 'log(x)' must be > 0, but x = $x")
    }
    StrictMath.log(x)
  }

  /**
    * Returns Euler's number e raised to the power of "x".
    *
    * The exponential function is defined for every real argument (for example exp(0) = 1 and
    * exp(-1) is roughly 0.368), so zero and negative inputs are accepted and no range check
    * is performed.
    */
  def exp(x: Double): Double = StrictMath.exp(x)

  /**
    * Computes the hyperbolic tangent of a big decimal value. The value is converted to a
    * double before the computation.
    */
  def tanh(x: JBigDecimal): Double = {
    val value: Double = x.doubleValue()
    Math.tanh(value)
  }

  /**
    * Computes the hyperbolic sine of a big decimal value. The value is converted to a
    * double before the computation.
    */
  def sinh(x: JBigDecimal): Double = {
    val value: Double = x.doubleValue()
    Math.sinh(value)
  }

  /**
    * Computes the logarithm of "x" to the given "base" as ln(x) / ln(base).
    *
    * @throws IllegalArgumentException if "x" is zero or negative, or if "base" is not greater
    *                                  than 1; "x" is validated before "base"
    */
  def log(base: Double, x: Double): Double = {
    if (x <= 0.0) {
      throw new IllegalArgumentException(s"x of 'log(base, x)' must be > 0, but x = $x")
    }
    if (base <= 1.0) {
      throw new IllegalArgumentException(s"base of 'log(base, x)' must be > 1, but base = $base")
    }
    val numerator = StrictMath.log(x)
    val denominator = StrictMath.log(base)
    numerator / denominator
  }

  /**
    * Computes the base-2 logarithm of "x" as ln(x) / ln(2).
    *
    * @throws IllegalArgumentException if "x" is zero or negative
    */
  def log2(x: Double): Double = {
    if (x <= 0.0) {
      throw new IllegalArgumentException(s"x of 'log2(x)' must be > 0, but x = $x")
    }
    val lnX = StrictMath.log(x)
    val lnTwo = StrictMath.log(2.0)
    lnX / lnTwo
  }

  /**
    * Returns the string "base" left-padded with the string "pad" to a length of "len"
    * characters. If "base" is longer than "len", the result is "base" shortened to its first
    * "len" characters.
    *
    * Returns NULL if any argument is NULL, if "len" is negative, or if padding is required
    * but "pad" is empty (an empty pad can never fill the missing characters). Returns the
    * empty string if "len" is 0.
    *
    * @param base string to pad
    * @param len  length of the result in characters
    * @param pad  string that is repeated (and cut off if necessary) in front of "base"
    */
  def lpad(base: String, len: Integer, pad: String): String = {
    if (base == null || len == null || pad == null) {
      null
    } else {
      val targetLen: Int = len
      if (targetLen < 0) {
        null
      } else if (targetLen == 0) {
        ""
      } else {
        // Number of leading positions that have to be filled from the pad string.
        val padLen = Math.max(targetLen - base.length, 0)
        if (padLen > 0 && pad.isEmpty) {
          // Nothing to pad with; without this check the fill loop could never make progress.
          null
        } else {
          val data = new Array[Char](targetLen)

          // Fill the padding area by cycling through the pad string.
          var i = 0
          while (i < padLen) {
            data(i) = pad.charAt(i % pad.length)
            i += 1
          }

          // Copy as much of the base as fits behind the padding.
          base.getChars(0, targetLen - padLen, data, padLen)

          new String(data)
        }
      }
    }
  }

  /**
    * Returns the string "base" right-padded with the string "pad" to a length of "len"
    * characters. If "base" is longer than "len", the result is "base" shortened to its first
    * "len" characters.
    *
    * Returns NULL if any argument is NULL, if "len" is negative, or if padding is required
    * but "pad" is empty (an empty pad can never fill the missing characters). Returns the
    * empty string if "len" is 0.
    *
    * @param base string to pad
    * @param len  length of the result in characters
    * @param pad  string that is repeated (and cut off if necessary) behind "base"
    */
  def rpad(base: String, len: Integer, pad: String): String = {
    if (base == null || len == null || pad == null) {
      null
    } else {
      val targetLen: Int = len
      if (targetLen < 0) {
        null
      } else if (targetLen == 0) {
        ""
      } else {
        // Number of characters taken from the base; the rest is filled from the pad string.
        val baseLen = Math.min(base.length, targetLen)
        if (baseLen < targetLen && pad.isEmpty) {
          // Nothing to pad with; without this check the fill loop could never make progress.
          null
        } else {
          val data = new Array[Char](targetLen)

          // Copy the (possibly truncated) base to the front.
          base.getChars(0, baseLen, data, 0)

          // Fill the remaining positions by cycling through the pad string.
          var i = baseLen
          while (i < targetLen) {
            data(i) = pad.charAt((i - baseLen) % pad.length)
            i += 1
          }

          new String(data)
        }
      }
    }
  }


  /**
    * Replaces every substring of "str" that matches the regular expression "regex" with
    * "replacement". The replacement is inserted literally, i.e. '$' and '\' carry no special
    * meaning. Returns NULL if any argument is NULL.
    */
  def regexp_replace(str: String, regex: String, replacement: String): String = {
    val anyNull = str == null || regex == null || replacement == null
    if (anyNull) {
      null
    } else {
      // Quote the replacement so that group references such as "$1" are not interpreted.
      val literalReplacement = Matcher.quoteReplacement(replacement)
      Pattern.compile(regex).matcher(str).replaceAll(literalReplacement)
    }
  }

  /**
    * Searches "str" for the first match of the regular expression "regex" and returns the
    * content of the match group with index "extractIndex" (0 is the whole match).
    * Returns NULL if "str" or "regex" is NULL or if the expression does not match.
    */
  def regexp_extract(str: String, regex: String, extractIndex: Integer): String = {
    if (str == null || regex == null) {
      null
    } else {
      val matcher = Pattern.compile(regex).matcher(str)
      if (matcher.find()) {
        matcher.toMatchResult.group(extractIndex)
      } else {
        null
      }
    }
  }

  /**
    * Returns the first substring of "str" that matches the regular expression "regex".
    * Delegates to the three-argument variant with match group index 0, i.e. the whole match.
    */
  def regexp_extract(str: String, regex: String): String = {
    val wholeMatchGroup = 0
    regexp_extract(str, regex, wholeMatchGroup)
  }

  /**
    * Decodes a base64 string back into the original string.
    * The work is delegated to [[EncodingUtils.decodeBase64ToString]].
    */
  def fromBase64(base64: String): String = {
    EncodingUtils.decodeBase64ToString(base64)
  }

  /**
    * Encodes the input string as base64.
    * The work is delegated to [[EncodingUtils.encodeStringToBase64]].
    */
  def toBase64(string: String): String = {
    EncodingUtils.encodeStringToBase64(string)
  }

  /**
    * Returns the upper-case hexadecimal representation of a long argument.
    * Negative values are rendered as their unsigned 64-bit two's complement form.
    */
  def hex(string: Long): String = {
    val lowerCaseHex = JLong.toHexString(string)
    lowerCaseHex.toUpperCase()
  }

  /**
    * Returns the upper-case hex string of a string argument.
    * The hex encoding itself is delegated to [[EncodingUtils.hex]].
    */
  def hex(string: String): String = {
    val encoded = EncodingUtils.hex(string)
    encoded.toUpperCase()
  }

  /**
    * Generates a random UUID with [[java.util.UUID.randomUUID]] and returns its string form.
    */
  def uuid(): String = {
    val generated = java.util.UUID.randomUUID()
    generated.toString
  }

  /**
    * Returns a string consisting of the base string repeated n times.
    * The work is delegated to [[EncodingUtils.repeat]].
    */
  def repeat(base: String, n: Int): String = {
    EncodingUtils.repeat(base, n)
  }

  // TODO: remove if CALCITE-3199 fixed
  //  https://issues.apache.org/jira/browse/CALCITE-3199
  /**
    * Rounds a date up to the first day of a year or month, depending on "range".
    * The date is shifted by EPOCH_JULIAN and then handed to the non-floor (ceiling) mode of
    * julianDateFloor; the Int result is widened to Long.
    */
  def unixDateCeil(range: TimeUnitRange, date: Int): Long = {
    val julianDay: Int = date + EPOCH_JULIAN
    julianDateFloor(range, julianDay, floor = false)
  }

  /**
    * Rounds a Julian day number to the first day of its year or month and returns the value
    * that `ymdToUnixDate` produces for that first day.
    *
    * The Julian day number is first decomposed into year, month and day with integer
    * arithmetic; the rounding is then applied to these calendar fields.
    *
    * @param range  unit to round to; only YEAR and MONTH are handled
    * @param julian day number to round (the visible caller passes a date plus EPOCH_JULIAN)
    * @param floor  if true, the first day of the current year/month is used; if false, the
    *               first day of the following year/month is used unless the date already is
    *               the first day of the year/month
    * @return result of `ymdToUnixDate` for the selected first-of-year or first-of-month date
    * @throws AssertionError if "range" is neither YEAR nor MONTH
    */
  private def julianDateFloor(range: TimeUnitRange, julian: Int, floor: Boolean): Int = {
    // this shifts the epoch back to astronomical year -4800 instead of the
    // start of the Christian era in year AD 1 of the proleptic Gregorian
    // calendar.
    val j: Int = julian + 32044
    // split the day count into 146097-day, 36524-day, 1461-day and 365-day periods
    // (quotient / remainder pairs g/dg, c/dc, b/db and a/da)
    val g: Int = j / 146097
    val dg: Int = j % 146097
    val c: Int = (dg / 36524 + 1) * 3 / 4
    val dc: Int = dg - c * 36524
    val b: Int = dc / 1461
    val db: Int = dc % 1461
    val a: Int = (db / 365 + 1) * 3 / 4
    val da: Int = db - a * 365
    // integer number of full years elapsed since March 1, 4801 BC
    val y: Int = g * 400 + c * 100 + b * 4 + a
    // integer number of full months elapsed since the last March 1
    val m: Int = (da * 5 + 308) / 153 - 2
    // number of days elapsed since day 1 of the month
    val d: Int = da - (m + 4) * 153 / 5 + 122
    // convert the March-based fields into a calendar year, a 1-based month and a 1-based day
    var year: Int = y - 4800 + (m + 2) / 12
    var month: Int = (m + 2) % 12 + 1
    val day: Int = d + 1
    range match {
      case YEAR =>
        // ceiling: any date after January 1 moves on to January 1 of the next year
        if (!(floor) && (month > 1 || day > 1)) {
          year += 1
        }
        return ymdToUnixDate(year, 1, 1)
      case MONTH =>
        // ceiling: any date after the 1st moves on to the 1st of the next month
        // NOTE(review): for a date in December this passes month 13; presumably
        // ymdToUnixDate maps that to January of the following year - TODO confirm
        if (!(floor) && day > 1) {
          month += 1
        }
        return ymdToUnixDate(year, month, 1)
      case _ =>
        throw new AssertionError(range)
    }
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy