All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.prophecy.libs.SparkFunctions.scala Maven / Gradle / Ivy

There is a newer version: 6.3.0-3.3.0
Show newest version
 * ====================================================================
 * Prophecy Inc
 * All Rights Reserved.
 * NOTICE:  All information contained herein is, and remains
 * the property of Prophecy Inc, the intellectual and technical concepts contained
 * herein are proprietary to Prophecy Inc and may be covered by U.S. and Foreign Patents,
 * patents in process, and are protected by trade secret or copyright law.
 * Dissemination of this information or reproduction of this material
 * is strictly forbidden unless prior written permission is obtained
 * from Prophecy Inc.
 * ====================================================================
package io.prophecy.libs

import com.fasterxml.jackson.databind.ObjectMapper
import de.greenrobot.common.hash.Murmur3F
import io.prophecy.abinitio.ScalaFunctions._
import org.apache.commons.lang.StringEscapeUtils
import org.json.XML
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.crosssupport.SparkFunctionsCrossSupport
import org.apache.spark.sql.functions.{lit, _}
import org.apache.spark.sql.types._
import org.joda.time.format.DateTimeFormat

import java.math.BigInteger
import java.nio.{ByteBuffer, ByteOrder}
import java.sql.Date
import java.text.{DecimalFormat,            DecimalFormatSymbols, SimpleDateFormat}
import java.time.format.{DateTimeFormatter, DateTimeFormatterBuilder}
import java.time.temporal.ChronoField
import java.time.{LocalDate, LocalDateTime, ZonedDateTime}
import java.util.TimeZone
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, ListBuffer}

/** Companion object of the `SparkFunctions` trait; it declares no members of its own. */
object SparkFunctions

  * Library of Spark functions that implement the Ab Initio functions used in Ab Initio workflows.
trait SparkFunctions {
  def registerAllUDFs(spark: SparkSession): Unit = {
    spark.udf.register("string_pad",                string_pad)
    spark.udf.register("string_index_with_offset",  string_index_with_offset)
    spark.udf.register("string_filter",             string_filter)
    spark.udf.register("generate_sequence",         generate_sequence)
    spark.udf.register("re_get_match_with_index",   re_get_match_with_index)
    spark.udf.register("string_filter_out",         string_filter_out)
    spark.udf.register("number_grouping",           number_grouping)
    spark.udf.register("re_get_match",              re_get_match)
    spark.udf.register("re_index_with_offset",      re_index_with_offset)
    spark.udf.register("file_information",          file_information)
    spark.udf.register("re_index",                  re_index)
    spark.udf.register("string_split_no_empty",     string_split_no_empty)
    spark.udf.register("string_split",              string_split)
    spark.udf.register("type_info",                 type_info)
    spark.udf.register("record_info",               record_info)
    spark.udf.register("type_info_with_includes",   type_info_with_includes)
    spark.udf.register("record_info_with_includes", record_info_with_includes)
    spark.udf.register("read_file",                 read_file)
    spark.udf.register("unique_identifier",         unique_identifier)
    spark.udf.register("hash_MD5",                  hash_MD5)
    spark.udf.register("string_like",               string_like)
    spark.udf.register("re_split_no_empty",         re_split_no_empty)
    spark.udf.register("string_index",              string_index)
    spark.udf.register("string_rindex",             string_rindex)
    spark.udf.register("string_rindex_with_offset", string_rindex_with_offset)
    spark.udf.register("make_byte_flags",           make_byte_flags)
    spark.udf.register("translate_bytes",           translate_bytes)
    spark.udf.register("test_characters_all",       test_characters_all)
    spark.udf.register("string_convert_explicit",   string_convert_explicit)
    spark.udf.register("cross_join_index_range",    cross_join_index_range)
    spark.udf.register("getByteFromByteArray",      getByteFromByteArray)
    spark.udf.register("getShortFromByteArray",     getShortFromByteArray)
    spark.udf.register("getIntFromByteArray",       getIntFromByteArray)
    spark.udf.register("url_encode_escapes",        url_encode_escapes)
    spark.udf.register("force_error",               force_error)
    spark.udf.register("getLongFromByteArray",      getLongFromByteArray)
    spark.udf.register("getLongArrayFromByteArray", getLongArrayFromByteArray)
    spark.udf.register("murmur",                    murmur)
    spark.udf.register("eval",                      eval)
    spark.udf.register("truncateMicroSeconds",      truncateMicroSeconds)
    spark.udf.register("string_cleanse",            string_cleanse)
    spark.udf.register("canonical_representation",  canonical_representation)
    spark.udf.register("string_representation",     string_representation)
    spark.udf.register("datetime_difference",       datetime_difference)
    spark.udf.register("encode_date",               encode_date)
    spark.udf.register("date_month_end",            date_month_end)
    spark.udf.register("is_bzero",                  is_bzero)
    spark.udf.register("readBytesIntoLong",         readBytesIntoLong)
    spark.udf.register("readBytesIntoInteger",      readBytesIntoInteger)
    spark.udf.register("writeLongToBytes",          writeLongToBytes)
    spark.udf.register("decodeString",              decodeString)
    spark.udf.register("encodeString",              encodeString)
    spark.udf.register("writeIntegerToBytes",       writeIntegerToBytes)

      udf((input: String, prefix: String) ⇒ if (input == null || prefix == null) null else input.startsWith(prefix),
      udf((input: String, suffix: String) ⇒ if (input == null || suffix == null) null else input.endsWith(suffix),
    spark.udf.register("is_blank", udf((input: String) ⇒ if (input == null) null else input.trim.isEmpty, BooleanType))
                       udf { (input: Any, len: Int) ⇒
                         _decimal_lpad(input, len)
                       udf { (input: Double, number_of_places: Int) ⇒
                         (input * Math.pow(10, number_of_places)).asInstanceOf[Long] / Math.pow(10, number_of_places)

    //    spark.udf.register("re_match_replace")

    //    spark.udf.register("vector_sort_dedup_first", )

    //    spark.udf.register("string_substring", (input: Column, start_position: Column, length: Column) ⇒ input.substr(start_position, length))
    //    import org.apache.hadoop.hive.ql.exec.UDF;
    //    class SimpleUdf extends UDF {
    //      def string_substring(input: Column, start_position: Column, length: Column): Column =
    //        input.substr(start_position, length)
    //    }
    //    spark.sql("CREATE FUNCTION simple_udf AS 'SimpleUdf'")


    * Method to find substring of input string.
    * @param input          string on which to find substring.
    * @param start_position 1 based starting position to find substring from.
    * @param length         total length of substring to be found.
    * @return substring of input string
  /**
    * Extracts a substring of a string column by delegating to `Column.substr`.
    *
    * @param input          string column to take the substring from.
    * @param start_position column holding the starting position handed to `substr`.
    * @param length         column holding the number of characters to take.
    * @return column containing the extracted substring.
    */
  def string_substring(input: Column, start_position: Column, length: Column): Column = {
    val extracted = input.substr(start_position, length)
    extracted
  }

  /**
    * Returns the leading characters of a string column.
    *
    * When `length` is negative the result is the empty string; otherwise the result is
    * `input.substr(1, length)`.
    *
    * @param input  string column to take the prefix from.
    * @param length column holding the number of leading characters to keep.
    * @return column containing the prefix.
    */
  def string_prefix(input: Column, length: Column): Column = {
    val emptyString  = lit("")
    val leadingChars = input.substr(lit(1), length)
    when(length < lit(0), emptyString).otherwise(leadingChars)
  }

    * Trims the input string and then left-pads it with the given character up to the given length.
    * If the length of the trimmed string is equal to or greater than the given length, the input
    * string is returned instead.
    * @param input            input string
    * @param len              length in number of characters.
    * @param char_to_pad_with A character used to pad input string to length len.
    * @return string of a specified length, trimmed of leading and trailing blanks and left-padded with a given character.
  def string_lrepad(input: Column, len: Int, char_to_pad_with: String = " "): Column = {
    when(isnull(input),                lit(null)).otherwise(
      when(length(trim(input)) >= len, input)
        .otherwise(lpad(trim(input),   len, char_to_pad_with))

    * Returns a number rounded up to a specified number of places to the right of the decimal point.
    * @param input
    * @param places
    * @return
  /**
    * Rounds a numeric column up to `places` digits after the decimal point.
    *
    * The value is multiplied by 10^places, passed through `ceil`, and divided by 10^places again.
    *
    * @param input  numeric column to round.
    * @param places number of digits to keep to the right of the decimal point.
    * @return column holding the rounded-up value.
    */
  def decimal_round_up(input: Column, places: Int): Column = {
    val scale        = pow(lit(10), places)
    val scaledCeiled = ceil(input * scale)
    scaledCeiled / scale
  }

    * Function returns a value which is rounded down to right_digits number of digits to the right of decimal point.
    * @param input
    * @param right_digits
    * @return
  /**
    * Rounds a numeric column down to `right_digits` digits after the decimal point.
    *
    * The value is multiplied by 10^right_digits, passed through `floor`, and divided by
    * 10^right_digits again.
    *
    * @param input        numeric column to round.
    * @param right_digits number of digits to keep to the right of the decimal point.
    * @return column holding the rounded-down value.
    */
  def decimal_round_down(input: Column, right_digits: Int): Column = {
    val scale         = pow(lit(10), right_digits)
    val scaledFloored = floor(input * scale)
    scaledFloored / scale
  }

  def decimal_round(input: Column, places: Int): Column =
    when(decimal_round_up(input, places) - input > input - decimal_round_down(input, places),
         decimal_round_down(input, places)
      when(decimal_round_up(input, places) - input < input - decimal_round_down(input, places),
           decimal_round_up(input, places)
        when(abs(decimal_round_up(input, places)) > abs(decimal_round_down(input, places)),
             decimal_round_up(input, places)
        ).otherwise(decimal_round_down(input, places))

  /**
    * Truncates a numeric column to a given number of decimal places.
    *
    * The value is multiplied by 10^number_of_places, cast to `LongType` (which discards the
    * remaining fraction), and divided by 10^number_of_places again.
    *
    * @param input            numeric column to truncate.
    * @param number_of_places column holding the number of decimal places to keep.
    * @return column holding the truncated value.
    */
  def decimal_truncate(input: Column, number_of_places: Column): Column = {
    val scale         = pow(lit(10), number_of_places)
    val wholeOfScaled = (input * scale).cast(LongType)
    wholeOfScaled / scale
  }

    * Returns the internal representation of a date resulting from adding (or subtracting) a number of months to the specified date.
    * @param inputDate in yyyy-MM-dd format
    * @param months
    * @return
  /**
    * Adds (or, for a negative value, subtracts) a number of months to a date column by
    * delegating to Spark's `add_months`.
    *
    * The return type is declared explicitly so the public signature does not depend on
    * type inference and matches the sibling `datetime_add_months`.
    *
    * @param inputDate date column to shift.
    * @param months    number of months to add; may be negative.
    * @return column holding the shifted date.
    */
  def date_add_months(inputDate: Column, months: Int): Column = add_months(inputDate, months)

    * integer values specifying days relative to January 1, 1900.
    * This function returns the internal representation of a date
    * given the year, month, and date.
    * encode_date returns the internal representation of the date specified by
    * the year 1998, the month 5, and the day 18:encode_date(1998, 5, 18) = 35931
  // UDF: number of days between 1900-01-01 and the date given by (year, month, day).
  // 1900-01-01 itself maps to 0; earlier dates produce negative values.
  val encode_date = udf { (year: Int, month: Int, day: Int) ⇒
    import java.time.LocalDate
    import java.time.format.DateTimeFormatter
    val formatter   = DateTimeFormatter.ofPattern("yyyy-MM-dd")
    // Reference date that all results are measured from.
    val baseDate    = LocalDate.parse("1900-01-01", formatter)
    var monthString = s"$month"
    var dayString   = s"$day"
    var yearString  = s"$year"
    // Left-pad each component with zeros so the assembled text fits the yyyy-MM-dd pattern.
    while (dayString.length < 2) dayString = "0" + dayString
    while (monthString.length < 2) monthString = "0" + monthString
    while (yearString.length < 4) yearString = "0" + yearString
    // LocalDate.parse throws DateTimeParseException when the components do not form a valid date.
    val endDate = LocalDate.parse(s"$yearString-$monthString-$dayString", formatter)
    // Difference of epoch-day counts = whole days elapsed since the reference date.
    endDate.toEpochDay - baseDate.toEpochDay

  val date_month_end = udf({ (month: Int, year: Int) ⇒
                             if (month == null || year == null) null
                             else {
                               month match {
                                 case 1 | 3 | 5 | 7 | 8 | 10 | 12 ⇒ 31
                                 case 4 | 6 | 9 | 11              ⇒ 30
                                 case 2                           ⇒ if (year % 4 != 0) 28 else 29

    * Returns the internal representation of a timestamp resulting from adding (or subtracting) a number of months to the specified timestamp.
    * @param input timestamp in  yyyy-MM-dd HH:mm:ss.SSSS format
    * @param months
    * @return
  /**
    * Adds (or, for a negative value, subtracts) a number of months to the given column by
    * delegating to Spark's `add_months`.
    *
    * NOTE(review): Spark documents `add_months` as returning a date; confirm whether dropping the
    * time-of-day part is intended for timestamp inputs.
    *
    * @param input  column to shift.
    * @param months number of months to add; may be negative.
    * @return column produced by `add_months`.
    */
  def datetime_add_months(input: Column, months: Int): Column = {
    val shifted = add_months(input, months)
    shifted
  }

    * Trims the input string and then pads it on the right side with the given character up to the
    * given length. If the length of the trimmed string is equal to or greater than the given length,
    * the input string is returned instead.
    * @param input            input string
    * @param len              length in number of characters.
    * @param char_to_pad_with A character used to pad input string to length len.
    * @return string of a specified length, trimmed of leading and trailing blanks and right-padded with a given character.
  /**
    * Trims a string column and right-pads the trimmed value to `len` characters.
    *
    * If the input is null, or its trimmed length is already `len` or more, the original
    * (untrimmed) input column value is returned instead.
    *
    * @param input            string column to pad.
    * @param len              target length in characters.
    * @param char_to_pad_with padding string handed to `rpad`; defaults to a single space.
    * @return column holding either the original value or the trimmed, right-padded value.
    */
  def string_repad(input: Column, len: Int, char_to_pad_with: String = " "): Column = {
    val trimmed       = trim(input)
    val keepOriginal  = isnull(input) || (length(trimmed) >= len)
    val paddedOnRight = rpad(trimmed, len, char_to_pad_with)
    when(keepOriginal, input).otherwise(paddedOnRight)
  }

    * function pads input on the right with the character char_to_pad_with to make the string length len. If str is
    * already len or more characters long, the function returns input unmodified.
    * @param input
    * @param len
    * @param char_to_pad_with
    * @return
  /**
    * Right-pads a string column to `len` characters without trimming it first.
    *
    * If the input is null, or is already `len` or more characters long, it is returned unchanged.
    *
    * NOTE(review): a `val string_pad` UDF is declared in this same file; confirm that both
    * definitions are meant to coexist under one name.
    *
    * @param input            string column to pad.
    * @param len              target length in characters.
    * @param char_to_pad_with padding string handed to `rpad`; defaults to a single space.
    * @return column holding either the original value or the right-padded value.
    */
  def string_pad(input: Column, len: Int, char_to_pad_with: String = " "): Column = {
    val keepOriginal  = isnull(input) || (length(input) >= len)
    val paddedOnRight = rpad(input, len, char_to_pad_with)
    when(keepOriginal, input).otherwise(paddedOnRight)
  }

  val string_pad = udf(
    { (input: String, len: Int) ⇒
      if (input == null || len == null) null
      else {
        if (input.length < len) input + (0 until (len - input.length)).map(" ").mkString("")
        else input

  val string_pad_with_char = udf(
    { (input: String, len: Int, char_to_pad_with: String) ⇒
      if (input == null || len == null || char_to_pad_with == null) null
      else {
        if (input.length < len) input + (0 until (len - input.length)).map(char_to_pad_with).mkString("")
        else input

    * Replaces all substrings in a target string that match a specified regular expression.
    * @param target      A string that the function searches for a substring that matches pattern_expr.
    * @param pattern     regular expression
    * @param replacement replacement string
    * @param offset      Number of characters, from the beginning of str, to skip before searching.
    * @return a replaced string in which all substrings, which matches a specified regular expression, are replaced.
  def re_replace(target: Column, pattern: String, replacement: String, offset: Int = 0): Column = {
    when(target.isNull, target).otherwise(
        target.substr(1, offset),
          target.substr(lit(offset + 1), length(target) - lit(offset)),

    * Replaces only the first regex matching occurrence in the target string.
    * @param target      A string that the function searches for a substring that matches pattern_expr.
    * @param pattern     regular expression
    * @param replacement replacement string
    * @return a replaced string in which first substring, which matches a specified regular expression, is replaced.
  def re_replace_first(target: Column, pattern: String, replacement: String, offset: Column = lit(0)): Column = {
    when(target.isNull.or(!target.rlike(pattern)), target).otherwise(
        regexp_extract(target.substr(offset, length(target) - offset), s"([^$pattern]*)($pattern)(.*)", 1),
        regexp_extract(target.substr(offset, length(target) - offset), s"([^$pattern]*)($pattern)(.*)", 3)

    * Left-pad the input string column with pad_char to a length of len. If length of input column is more than len
    * then returns input column unmodified.
    * @param input
    * @param len
    * @param pad_char
    * @return
  /**
    * Left-pads a string column with `pad_char` to `len` characters.
    *
    * If the input is null, or is already `len` or more characters long, it is returned unchanged.
    *
    * @param input    string column to pad.
    * @param len      target length in characters.
    * @param pad_char padding string handed to `lpad`; defaults to a single space.
    * @return column holding either the original value or the left-padded value.
    */
  def string_lpad(input: Column, len: Int, pad_char: String = " "): Column = {
    val alreadyLongEnough = length(input) >= lit(len)
    val keepOriginal      = input.isNull.or(alreadyLongEnough)
    val paddedOnLeft      = lpad(input, len, pad_char)
    when(keepOriginal, input).otherwise(paddedOnLeft)
  }

    * Function uses a java regex to identify decimal numbers from input string. This decimal number could be of 3 types
    * 1. Simple integral number. e.g. 013334848. This part is identified by combination of [1-9][0-9 ]*[0-9] and [1-9]+ regex
    * 2. decimal number with explicit decimal point. e.g. 123456.90. This part is identified by combination of
    * [1-9][0-9]*(\\$$decimal_point_char)[0-9 ]+ and (0\\$$decimal_point_char)[0-9 ]*[0-9] regex
    * After extracting decimal number this code looks for minus sign before extracted number in input and appends it
    * with decimal number if found minus sign.
    * In the end it replaces all whitespaces with empty string in the final resultant decimal number.
    * @param input              input string
    * @param decimal_point_char A string that specifies the character that represents the decimal point.
    * @return a decimal from a string that has been trimmed of leading zeros and non-numeric characters.
  def decimal_strip(input: Column, decimal_point_char: String = "."): Column = {
    val pattern =
      s"""([1-9][0-9]*(\\$decimal_point_char)[0-9 ]+|[1-9][0-9 ]+|[1-9]+|(0\\$decimal_point_char)[0-9 ]*[0-9])"""
    val sqlPattern = lit(pattern).expr.sql
    val columnSql  = input.expr.sql

    when(input.isNull, lit("0")).otherwise(
      when(regexp_extract(input, pattern, 1) === "", "0").otherwise(
              instr_udf(input, regexp_extract(input, pattern, 1))
          ) > 0,
            regexp_replace(regexp_extract(input, pattern, 1), "\\s+", "")
          regexp_replace(regexp_extract(input, pattern, 1), "\\s+", "")

  //  def decimal_strip(input: Column, columnName: String, decimal_point_char: String = "."): Column = {
  //    val pattern =
  //      s"""([1-9][0-9]*(\\$decimal_point_char)[0-9 ]+|[1-9][0-9 ]+|[1-9]+|(0\\$decimal_point_char)[0-9 ]*[0-9])"""
  //    val sqlPattern = lit(pattern).expr.sql
  //    when(input.isNull, lit("0")).otherwise(
  //      when(regexp_extract(input, pattern, 1) === "", "0").otherwise(
  //        when(
  //          instr(
  //            input.substr(
  //              lit(0),
  //              expr(
  //                s"instr($columnName, regexp_extract($columnName, $sqlPattern, 1))"
  //              )
  //            ),
  //            "-"
  //          ) > 0,
  //          concat(
  //            lit("-"),
  //            regexp_replace(regexp_extract(input, pattern, 1), "\\s+", "")
  //          )
  //        ).otherwise(
  //          regexp_replace(regexp_extract(input, pattern, 1), "\\s+", "")
  //        )
  //      )
  //    )
  //  }

  val instr_udf = udf({ (inputStr: String, substring: String) ⇒
                        if (inputStr == null || substring == null) null
                          inputStr.indexOf(substring) + 1

    * UDF to find index of seekStr in inputStr from offset index onwards.
    * Returned string position is 1 based position.
  val string_index_with_offset = udf(
    (inputStr: String, seekStr: String, offset: Int) ⇒ {
      if (inputStr == null || seekStr == null || offset == null) null
      else {
        if (offset <= inputStr.length) {
          val index = inputStr.substring(offset).indexOf(seekStr)
          if (index < 0) {
          } else {
            offset + index + 1
        } else {

    * Method which returns string of characters present in both of the strings in the same order as appearing in first
    * string
  val string_filter = udf(
    (inputStr1: String, inputStr2: String) ⇒ {
      if (inputStr1 == null || inputStr2 == null) null
      else {
        val str2Set = inputStr2.toCharArray.toSet

    * Function to replace occurrence of seekStr with newStr string in input string after offset characters from first character.
    * @param input   input string on which to perform replace operation.
    * @param seekStr string to be replaced in input string.
    * @param newStr  string to be used instead of seekStr in input string.
    * @param offset  number of characters to skip from beginning in input string before performing string_replace operation.
    * @return modified string where seekStr is replaced with newStr in input string.
  def string_replace(input: Column, seekStr: Column, newStr: Column, offset: Column = lit(0)): Column = {
    when(isnull(input),             lit(null)).otherwise(
      when(length(input) <= offset, input).otherwise(
          string_substring(input, lit(0), offset),
            string_substring(input, offset + 1, length(input)),

  val string_replace_first = udf({ (input: String, seekStr: String, newStr: String) ⇒
                                   if (input == null || seekStr == null || newStr == null) null
                                     input.replaceFirst(seekStr, newStr)


  val string_replace_in_loop = udf(
    { (input: String, seekStrs: Seq[String], newStrs: Seq[String]) ⇒
      if (input == null || seekStrs == null || newStrs == null) null
      else {
        var currentInput = input
        (0 until Math.min(seekStrs.length, newStrs.length)).foreach { itr ⇒
          currentInput = currentInput.replaceAll(seekStrs(itr), newStrs(itr))

    * Method which returns true if input string contains all alphabetic characters, or false otherwise.
    * @param input
    * @return
  /**
    * Tests whether a string column consists solely of ASCII letters.
    *
    * The check is an `rlike` match against an anchored `[a-zA-Z]+` pattern, so an empty string
    * does not match.
    *
    * @param input string column to test.
    * @return boolean column with the match result.
    */
  def string_is_alphabetic(input: Column): Column = {
    val lettersOnly = "^[a-zA-Z]+$"
    input.rlike(lettersOnly)
  }

    * Method which returns true if input string contains all numeric characters, or false otherwise.
    * @param input
    * @return
  /**
    * Tests whether a string column consists solely of the digits 0-9.
    *
    * An empty string is treated as numeric and yields true. Every other input, including a
    * null input (for which the condition evaluates to null), yields false through the explicit
    * `otherwise` branch, so the result is never null.
    *
    * @param input string column to test.
    * @return boolean column: true for all-digit or empty strings, false otherwise.
    */
  def string_is_numeric(input: Column): Column = {
    val digitsOnly = input.rlike("^[0-9]+$")
    val isEmpty    = length(input) === lit(0)
    // Without an explicit fallback, `when` alone would return null for non-matching rows.
    when(digitsOnly.or(isEmpty), lit(true)).otherwise(lit(false))
  }

    * Returns true if string columns starts with given prefix
  def starts_with(input: Column, prefix: String): Column =

    * Returns true if string columns ends with given suffix
  def ends_with(input: Column, suffix: String): Column =

    * UDF to generate column with sequence of integers between two passed start and end columns.
  val generate_sequence = udf { (start: Int, end: Int) ⇒
    (start to end).toArray

    * Method uses a java regex to identify decimal numbers from input string. This decimal number could be of 3 types
    * 1. Simple integral number. e.g. 013334848. This part is identified by combination of [1-9][0-9]*[0-9] and [1-9]+ regex
    * 2. decimal number with explicit decimal point. e.g. 123456.90. This part is identified by combination of
    * [1-9][0-9]*(\\\\$$decimal_point_char)[0-9]+ and (0\\\\$$decimal_point_char)[0-9]*[0-9] regex
    * After extracting decimal number this code checks if length of decimal number is more than len parameter or not.
    * If length is more than len parameter then it simply returns this extracted decimal number. Otherwise it first left pad
    * decimal number with char_to_pad_with to make its length equal to len parameter and then adjusts minus sign (-) to left most
    * part of decimal number.
    * @param input              input string.
    * @param len                length of characters.
    * @param char_to_pad_with   character to left pad with. default value is "0"
    * @param decimal_point_char A string that specifies the character that represents the decimal point.
    * @return a decimal string of the specified length or longer, left-padded with a specified character as needed and
    *         trimmed of leading zeros.
  def decimal_lrepad(
    input:              Column,
    len:                Int,
    char_to_pad_with:   String = "0",
    decimal_point_char: String = "."
  ): Column = {
    val pattern =

    when(input.isNull, lit(null)).otherwise(
        length(regexp_extract(input, pattern, 0)) >= len,
        regexp_extract(input, pattern, 0)
            lpad(regexp_extract(input, pattern, 0), len, char_to_pad_with),

    * Method to return character code of character at index position in inputStr string.
    * @param inputStr input string
    * @param index    location of character to get code.
    * @return integer column.
  def string_char(inputStr: Column, index: Int): Column = {
      (lit(index) <= lit(0)).or(lit(index) > length(inputStr)).or(inputStr.isNull),
      ascii(inputStr.substr(lit(index), length(inputStr) - lit(index) + 1))

    * Method uses a java regex to identify decimal numbers from input string. This decimal number could be of 3 types
    * 1. Simple integral number. e.g. 013334848. This part is identified by [0-9]+ regex.
    * 2. decimal number with explicit decimal point. e.g. 123456.90. This part is identified by combination of
    * [0-9]+(\\$$decimal_point_char)[0-9]+ and (0\\$$decimal_point_char)[0-9]+ regex
    * After extracting decimal number this code checks if length of decimal number is more than len parameter or not.
    * If length is more than len parameter then it simply returns this extracted decimal number. Otherwise it first left pad
    * decimal number with char_to_pad_with to make its length equal to len parameter and then adjusts minus sign (-) to left most
    * part of decimal number.
    * @param input              input string.
    * @param len                length of characters.
    * @param char_to_pad_with   character to left pad with. default value is "0"
    * @param decimal_point_char A string that specifies the character that represents the decimal point.
    * @return a decimal string of the specified length or longer, left-padded with a specified character as needed and
    *         trimmed of leading zeros.
  def decimal_lpad(
    input:              Column,
    len:                Int,
    char_to_pad_with:   String = "0",
    decimal_point_char: String = "."
  ): Column = {
    val pattern =

    when(input.isNull, lit(null)).otherwise(
        length(regexp_extract(input, pattern, 0)) >= len,
        regexp_extract(input, pattern, 0)
            lpad(regexp_extract(input, pattern, 0), len, char_to_pad_with),

  val re_get_match_with_index = udf(
    { (input: String, pattern: String, offset: Int) ⇒
      if (input == null || pattern == null || offset == null) null
      else {
        val result =
            .map(x ⇒ input.substring(x.start + offset, x.end + offset))
        if (result.hasNext) {

        } else null

    * Compares two input strings, then returns characters that appear in one string but not in the other.
  val string_filter_out = udf(
    (input1: String, input2: String) ⇒ {
      if (input1 == null || input2 == null) null
      else {
        val str2Set = input2.toCharArray.toSet

    * udf to group input decimal into multiple groups separated by separator
  val number_grouping = udf(
    { (input: String, groups: Int, seperator: String) ⇒
      if (input == null || groups == null || seperator == null) null
      else {
        val result = input.split("\\.").map(group_number(_, groups, ',')).mkString(".")
        if (seperator != ",") result.replaceAll(",", seperator) else result

  private def group_number(input: String, groups: Int, seperator: Char): String = {
    val symbols = new DecimalFormatSymbols()
    val dfDecimal = new DecimalFormat("###########0.0###")
    val result = dfDecimal.format(BigDecimal(input))
    result.substring(0, result.length - 2)

    * Returns the first string in a target string that matches a regular expression.
  val re_get_match = udf { (input: String, pattern: String) ⇒
    _re_get_match(input, pattern)
  val decodeString =
    udf { (input: String, charSet: String) ⇒
      new String(input.getBytes("windows-1252"), charSet)

  val decodeBytes =
    udf { (input: Array[Byte], charSet: String) ⇒
      new String(input, charSet)

  val encodeString =
    udf { (input: String, charSet: String) ⇒
      val result = new String(input.getBytes(charSet), "windows-1252")

  val encodeBytes =
    udf { (input: Array[Byte], charSet: String) ⇒
      new String(input).getBytes(charSet)

  val writeLongToBytes =
    udf { (input: Any, length: Int, endian: String, isUnsigned: Boolean) ⇒
      _writeLongToBytes(input, length, endian, isUnsigned)

  val writeIntegerToBytes =
    udf { (input: Any, length: Int, endian: String, isUnsigned: Boolean) ⇒
      _writeIntegerToBytes(input, length, endian, isUnsigned)

  val readBytesStringIntoLong =
    udf { (input: String, length: Int, endian: String, isUnsigned: Boolean) ⇒
      _readBytesStringIntoLong(input, length, endian, isUnsigned)

  val readBytesIntoLong =
    udf { (input: Array[Byte], length: Int, endian: String, isUnsigned: Boolean) ⇒
      _readBytesIntoLong(input, length, endian, isUnsigned)

  val readBytesIntoInteger =
    udf { (input: Array[Byte], length: Int, endian: String, isUnsigned: Boolean) ⇒
      _readBytesIntoInteger(input, length, endian, isUnsigned)

  val readBytesStringIntoInteger =
    udf { (input: String, length: Int, endian: String, isUnsigned: Boolean) ⇒
      _readBytesStringIntoInteger(input, length, endian, isUnsigned)

  val packedBytesToDecimal = udf { (input: Array[Byte], scale: Int, isUnsigned: Boolean, isStripped: Boolean) ⇒
    _packedBytesToDecimal(input, scale, isUnsigned, isStripped)

  val packedBytesStringToDecimal = udf { (input: String, scale: Int, isUnsigned: Boolean, isStripped: Boolean) ⇒
    _packedBytesStringToDecimal(input, scale, isUnsigned, isStripped)

  val bigDecimalToPackedBytes = udf {
    (input: java.math.BigDecimal, precision: Int, isUnsigned: Boolean, isStripped: Boolean) ⇒
      _bigDecimalToPackedBytes(input, precision, isUnsigned, isStripped)

    * Returns the first string in a target string that matches a regular expression.
  val re_index_with_offset = udf { (input: String, pattern: String, offset: Int) ⇒
    val result =
      pattern.r.findAllMatchIn(input.substring(offset - 1)).map(_.start)
    if (result.hasNext) { + offset
    } else {

  /**
    * Formats a column of Unix-epoch seconds as a string using the pattern
    * `yyyyMMddHHmmssSSSSSS`, by delegating to Spark's `from_unixtime`.
    *
    * @param seconds column holding seconds since the Unix epoch.
    * @return string column with the formatted timestamp.
    */
  def datetime_from_unixtime(seconds: Column): Column = {
    val outputPattern = "yyyyMMddHHmmssSSSSSS"
    from_unixtime(seconds, outputPattern)
  }

    * UDF to get file information for passed input file path.
  val file_information = udf { (inputPath: String) ⇒

  def getMTimeDataframe(filepath: String, format: String, spark: SparkSession): DataFrame = {
    import spark.implicits._
         .print(SparkFunctionsCrossSupport.file_information_def(filepath).modified * 1000)
    ).toDF("file", "mtime")

    * UDF wrapper over re_index function.
  val re_index = udf { (input: String, pattern: String) ⇒
    val result = pattern.r.findAllMatchIn(input.substring(0)).map(_.start)
    if (result.hasNext) { + 1
    } else {

  /**
    * Returns the trailing `len` characters of a string column.
    *
    * - If `len` is greater than or equal to the input length, the whole input is returned.
    * - Otherwise, if `len` is zero or negative, the empty string is returned.
    * - Otherwise the last `len` characters are returned.
    *
    * @param input string column to take the suffix from.
    * @param len   number of trailing characters to keep.
    * @return column containing the suffix.
    */
  def string_suffix(input: Column, len: Int): Column = {
    val requested   = lit(len)
    val inputLength = length(input)
    // substr positions are 1-based, hence the "+ 1" on the start position.
    val tail        = input.substr(inputLength - requested + lit(1), requested)
    val tailOrEmpty = when(requested <= lit(0), lit("")).otherwise(tail)
    when(requested >= inputLength, input).otherwise(tailOrEmpty)
  }

  /**
    * Casts `value` to `LongType` when `format` equals the literal `%Ld`; any other format
    * yields null.
    *
    * @param format column holding the scanf-style format string.
    * @param value  column holding the value to convert.
    * @return long column, or null where the format is not `%Ld`.
    */
  def scanf_long(format: Column, value: Column): Column = {
    val isLongFormat = format === lit("%Ld")
    val asLong       = value.cast(LongType)
    when(isLongFormat, asLong).otherwise(lit(null))
  }

  /**
    * Casts `value` to `DoubleType` when `format` equals the literal `%lf`; any other format
    * yields null.
    *
    * @param format column holding the scanf-style format string.
    * @param value  column holding the value to convert.
    * @return double column, or null where the format is not `%lf`.
    */
  def scanf_double(format: Column, value: Column): Column = {
    val isDoubleFormat = format === lit("%lf")
    val asDouble       = value.cast(DoubleType)
    when(isDoubleFormat, asDouble).otherwise(lit(null))
  }

    * Concatenates the elements of column using the delimiter.
  /**
    * Joins the elements of a column with a delimiter; the work is delegated to
    * `SparkFunctionsCrossSupport.string_join`.
    *
    * @param column    column whose elements are joined.
    * @param delimiter separator placed between elements.
    * @return column produced by the cross-support implementation.
    */
  def string_join(column: Column, delimiter: String): Column = {
    val joined = SparkFunctionsCrossSupport.string_join(column, delimiter)
    joined
  }

    * Method to zip two arrays with first one having event_type and second one having event_text
    * @param column1
    * @param column2
    * @return
  /**
    * Zips two array columns; the work is delegated to
    * `SparkFunctionsCrossSupport.zip_eventInfo_arrays`.
    *
    * @param column1 first array column.
    * @param column2 second array column.
    * @return column produced by the cross-support implementation.
    */
  def zip_eventInfo_arrays(column1: Column, column2: Column): Column = {
    val zipped = SparkFunctionsCrossSupport.zip_eventInfo_arrays(column1, column2)
    zipped
  }

    * Method to get field at specific position from struct column
    * @param column
    * @return
  /**
    * Fetches the field at a given position of a struct column; the work is delegated to
    * `SparkFunctionsCrossSupport.getFieldFromStructByPosition`.
    *
    * @param column   struct column to read from.
    * @param position position of the field, passed through unchanged to the cross-support call.
    * @return column produced by the cross-support implementation.
    */
  def getFieldFromStructByPosition(column: Column, position: Int): Column = {
    val field = SparkFunctionsCrossSupport.getFieldFromStructByPosition(column, position)
    field
  }

    * UDF to split input string via delimiter string and remove all empty substrings.
  val string_split_no_empty = udf { (input: String, delimiter: String) ⇒

  /**
    * Substitutes the empty string for null values of a column and leaves other values as they are.
    *
    * @param input column that may contain nulls.
    * @return column in which nulls are replaced by "".
    */
  def replace_null_with_blank(input: Column): Column = {
    val blank = lit("")
    when(input.isNull, blank).otherwise(input)
  }

    * UDF to split input string via delimiter string.
  val string_split = udf { (input: String, delimiter: String) ⇒

  val type_info = udf { (dml_type: String) ⇒

  val record_info = udf { (dml_type: String) ⇒

  val type_info_with_includes = udf { (dml_type: String, includes: Array[String]) ⇒
    _type_info(dml_type, includes)

  val record_info_with_includes = udf { (dml_type: String, includes: Array[String]) ⇒
    _record_info(dml_type, includes)

  val read_file = udf { (filePath: String) ⇒

  val unique_identifier = udf { () ⇒

  val hash_MD5 = udf { (input: Any) ⇒
    val inputBytes    = if (input.isInstanceOf[String]) input.asInstanceOf[String].getBytes else getByteArray(input)
    val messageDigest = MessageDigest.getInstance("MD5")

  val string_representation = udf { (input: Any) ⇒

  val canonical_representation =
      (input: Any) ⇒ {
        val (encoding, value): (Short, String) = if (input == null) {
          (0, "NULL")
        } else if (input.isInstanceOf[String]) {
          try {
            val rep = new String(input.toString.getBytes, "utf-8")
            (1, rep)
          } catch {
            case _: Throwable ⇒
              (2, input.toString)
        } else if (input.isInstanceOf[Row] || input.isInstanceOf[Seq[_]]) {
          (3, "")
        } else {
          (1, input.toString)
          StructField("encoding", ShortType,  false),
          StructField("value",    StringType, false)

    * Method to test whether a string matches a specified pattern. This function returns 1 if the input string matches
    * a specified pattern, and 0 if the string does not match the pattern.
    * In abinitio version % character in pattern means to match zero or more characters and _ character means matches a
    * single character.
  /**
    * UDF testing whether `input` matches a LIKE-style `pattern`.
    * `%` matches zero or more characters and `_` matches exactly one character; every other
    * character of the pattern is matched literally.
    *
    * Returns a boolean: true when the whole input matches, false otherwise.
    */
  val string_like = udf { (input: String, pattern: String) ⇒
    // Translate one character at a time so that regex metacharacters in the pattern
    // (".", "(", "[", "\\", ...) are quoted instead of being interpreted by the regex engine.
    val regex = pattern.map {
      case '%'   ⇒ ".*"
      case '_'   ⇒ "."
      case other ⇒ java.util.regex.Pattern.quote(other.toString)
    }.mkString
    input.matches(regex)
  }

  /**
    * UDF to split input string via pattern string and remove all empty substrings.
    */
  val re_split_no_empty = udf { (input: String, pattern: String) ⇒

    * UDF to find index of seekStr in inputStr. Returned index will be
    * 1 based index.
  /**
    * UDF to find the 1-based index of the first occurrence of `seekStr` in `inputStr`.
    * Yields 0 when `seekStr` does not occur, since `indexOf` reports -1 in that case.
    */
  val string_index = udf((inputStr: String, seekStr: String) ⇒ inputStr.indexOf(seekStr) + 1)

    * Returns the index of the first character of the last occurrence of a seek string within another input string.
    * Returned index is 1 based.
  /**
    * UDF to find the 1-based index of the first character of the last occurrence of `seekStr`
    * within `inputStr`. Yields 0 when `seekStr` does not occur, since `lastIndexOf` reports -1.
    */
  val string_rindex = udf((inputStr: String, seekStr: String) ⇒ inputStr.lastIndexOf(seekStr) + 1)

  /**
    * UDF to find the index of seekStr in inputStr, searching from the end of inputStr after skipping
    * offset characters. Offset is the number of characters, counted from the end of inputStr, to skip
    * before searching. The returned position is 1 based.
    */
  val string_rindex_with_offset = udf((inputStr: String, seekStr: String, offset: Int) ⇒ {
    if (offset <= inputStr.length) {
      val index =
        inputStr.substring(0, inputStr.length - offset).lastIndexOf(seekStr)
      index + 1
    } else {

  /**
    * UDF to return, for each character code, a flag telling whether that character is present in the input string.
    */
  val make_byte_flags = udf { (inputStr: String) ⇒
    val charFlag = new Array[Int](256)
    inputStr.toCharArray.foreach(c ⇒ charFlag(c.toInt) = 1)

  /**
    * UDF to return a string in the native character set made up of bytes from the given map. Each byte of the result is
    * the value of map indexed by the character code of the corresponding byte of the input string str. The function
    * returns NULL if any argument is NULL.
    */
  val translate_bytes = udf((inputStr: String, byteMap: Seq[Byte]) ⇒ {
    if (inputStr == null || byteMap == null)
    else {
      val translatedChars = { c ⇒
        if (c < byteMap.length && c >= 0)
        else c
      new String(translatedChars)

    * UDF to identify the number of characters in inputStr which are present in charFlag
  /**
    * UDF to count the characters of `inputStr` whose flag in `charFlag` is set to 1.
    * `charFlag` is indexed by character code; characters whose code lies beyond the end of
    * `charFlag` are treated as not flagged instead of failing with an index error.
    */
  val test_characters_all = udf { (inputStr: String, charFlag: Seq[Int]) ⇒
    inputStr.count { c ⇒
      val code = c.toInt
      code < charFlag.length && charFlag(code) == 1
    }
  }

    * Converts a string from one character set to another, replacing inconvertible characters with a specified string.
  /**
    * UDF that re-decodes the bytes of `input` using `charSet` and substitutes `replaceStr`
    * for every "�" produced by the decoding.
    */
  val string_convert_explicit = udf { (input: String, charSet: String, replaceStr: String) ⇒
    // NOTE(review): getBytes() without a charset uses the platform default encoding — confirm this is intended.
    val decoded = new String(input.getBytes(), charSet)
    // Literal replace: replaceAll would interpret '$' and '\' inside replaceStr as regex replacement syntax.
    decoded.replace("�", replaceStr)
  }

    * This implementation is incorrect.
  /**
    * UDF that re-decodes the bytes of `input` using `charSet` and substitutes `replacement`
    * for every "�" produced by the decoding.
    * The original author flagged this implementation as incorrect; only the replacement handling is fixed here.
    */
  val string_cleanse = udf { (input: String, replacement: String, charSet: String) ⇒
    // NOTE(review): getBytes() without a charset uses the platform default encoding — confirm this is intended.
    val decoded = new String(input.getBytes(), charSet)
    // Literal replace: replaceAll would interpret '$' and '\' inside replacement as regex replacement syntax.
    decoded.replace("�", replacement)
  }

  val decode_datetime_as_local =
      (time: String) ⇒ {

        val datetimeInfo = _decode_datetime(time)
 / 60
          StructField("year",            IntegerType, false),
          StructField("month",           IntegerType, false),
          StructField("day",             IntegerType, false),
          StructField("hour",            IntegerType, false),
          StructField("minute",          IntegerType, false),
          StructField("second",          IntegerType, false),
          StructField("microsecond",     IntegerType, false),
          StructField("timezone_offset", ShortType,   false)

  val xmlToJSON = udf { (input: String) ⇒
    val jsonObj    = XML.toJSONObject(input)
    val newJsonObj = if (jsonObj.length() == 1) jsonObj.getJSONObject(jsonObj.keys().next()) else jsonObj

  /**
    * Parses an XML string column into a struct: the content is first converted to JSON with
    * the `xmlToJSON` UDF and the result is parsed with `from_json` using `schema`.
    *
    * @param content column holding XML text
    * @param schema  schema handed to `from_json`
    */
  def from_xml(content: Column, schema: StructType): Column =
    from_json(xmlToJSON(content), schema)

  def numberOfPartitions(in: DataFrame): Column =

  /**
    * Picks the last element of an array column.
    *
    * @param input   array column
    * @param default value used when the last element is null (defaults to a null literal)
    * @return the first element of the reversed array, or `default` when that element is null
    */
  def findLastElement(input: Column, default: Column = lit(null)): Column = {
    val last = SparkFunctionsCrossSupport.element_at(reverse(input), lit(1))
    when(last.isNull, default).otherwise(last)
  }

  /**
    * Length of `input`, as computed by Spark's `length` function.
    */
  def string_length(input: Column) = {
    val size = length(input)
    size
  }

  /**
    * Picks the first element of an array column.
    *
    * @param input   array column
    * @param default value used when the first element is null (defaults to a null literal)
    * @return the element at position 1, or `default` when that element is null
    */
  def findFirstElement(input: Column, default: Column = lit(null)): Column = {
    val first = SparkFunctionsCrossSupport.element_at(input, lit(1))
    when(first.isNull, default).otherwise(first)
  }

  def findFirstNonBlankElement(input: Column, default: Column): Column = {
    val firstElement = SparkFunctionsCrossSupport.element_at(
      SparkFunctionsCrossSupport.array_except(input, array(lit(""), lit(" "), lit("  "), lit("   "))),
    when(firstElement.isNull, default).otherwise(firstElement)

  def getColumnInSecondArrayByFirstNonBlankPositionInFirstArray(
    nonBlankEntryExpr: Column,
    firstArray:        Column,
    secondArray:       Column
  ): Column = {
          SparkFunctionsCrossSupport.element_at(nonBlankEntryExpr, lit(1))

  /**
    * Builds an equality expression between two records described by their schemas: the columns
    * of each schema are flattened with `flattenStructSchema`, concatenated, and the two
    * concatenations are compared with `===`.
    *
    * @param row1 schema of the first record
    * @param row2 schema of the second record
    */
  def schemaRowCompareResult(row1: StructType, row2: StructType): Column = {
    val left  = concat(flattenStructSchema(row1): _*)
    val right = concat(flattenStructSchema(row2): _*)
    left === right
  }

  def flattenStructSchema(schema: StructType, prefix: String = null): Array[Column] = {
    schema.fields.flatMap { field ⇒
      val columnName = if (prefix == null) else (prefix + "." +

      field.dataType match {
        case st: StructType ⇒ flattenStructSchema(st, columnName)
        case _ ⇒ Array(col(columnName).as(columnName.replace(".", "_")))

  /**
    * UDF to get record of type decode_datetime_type. This record will have all its fields populated with
    * corresponding entries in input date/timestamp.
    * Returned record will have following schema.
    * integer(8) year;
    * integer(8) month;
    * integer(8) day;
    * integer(8) hour;
    * integer(8) minute;
    * integer(8) second;
    * integer(8) microsecond;
    * Note: Supported input time is in yyyy-MM-dd HH:mm:ss.SSSSSS or yyyy-MM-dd HH:mm:ss or yyyy-MM-dd formats only.
    * Additional handling is done to support timestamp retrieved from now() function call.
    */
  val decode_datetime =
      (time: String) ⇒ {

        val datetimeInfo = _decode_datetime(time)
          StructField("year",        IntegerType, false),
          StructField("month",       IntegerType, false),
          StructField("day",         IntegerType, false),
          StructField("hour",        IntegerType, false),
          StructField("minute",      IntegerType, false),
          StructField("second",      IntegerType, false),
          StructField("microsecond", IntegerType, false)

  val cross_join_index_range = udf(
    (input1: Seq[Int], input2: Seq[Int]) ⇒ input1.flatMap(i ⇒ ⇒ Row.fromSeq(Seq(i, j)))),
          StructField("i", IntegerType, false),
          StructField("j", IntegerType, false)

    * Method to identify if input string is a blank string or not.
    * @param input input string.
    * @return return 1 if given string contains all blank character or is a zero length string,
    *         otherwise it returns 0
  /**
    * Tells whether `input` is a blank string.
    *
    * @param input input string column
    * @return null when `input` is null, true when the trimmed input has length 0, false otherwise
    */
  def is_blank(input: Column): Column = {
    val emptyAfterTrim = when(length(trim(input)) === lit(0), lit(true)).otherwise(lit(false))
    when(isnull(input), lit(null)).otherwise(emptyAfterTrim)
  }

  /**
    * Method to list all files present in the passed path whose filename starts with the passed
    * prefix.
    * @param path
    * @param filePrefix
    * @return
    */
  private def listFiles(path: String, filePrefix: String) = {
    new File(path)

  /**
    * Wraps the result of `listFiles(path, filePrefix)` in a typed literal column.
    * The listing is computed eagerly, when this method is called.
    */
  def directory_listing(path: String, filePrefix: String) = {
    val matchingFiles = listFiles(path, filePrefix)
    typedLit(matchingFiles)
  }

  /**
    * Maps an integer byte width to the Spark integral type of that width.
    *
    * @param lenValue width in bytes (1, 2, 4 or 8)
    * @param dataType fallback returned for any other width
    */
  private def getDataTypeFromIntLength(lenValue: Int, dataType: DataType) =
    lenValue match {
      case 1 ⇒ ByteType
      case 2 ⇒ ShortType
      case 4 ⇒ IntegerType
      case 8 ⇒ LongType
      case _ ⇒ dataType
    }

  /** Validity check with `isNullable` fixed to false; forwards to the (input, isNullable) overload. */
  def is_valid(input: Column): Column = {
    val notNullable = false
    is_valid(input, notNullable)
  }

  /** Validity check without format or length information; forwards to the four-argument overload. */
  def is_valid(input: Column, isNullable: Boolean): Column =
    is_valid(input, isNullable, formatInfo = None, len = None)

  /** Validity check against `formatInfo` with `isNullable` fixed to false and no length; forwards to the four-argument overload. */
  def is_valid(input: Column, formatInfo: Option[Any]): Column =
    is_valid(input, isNullable = false, formatInfo = formatInfo, len = None)

  /** Validity check against `formatInfo` and `len` with `isNullable` fixed to false; forwards to the four-argument overload. */
  def is_valid(input: Column, formatInfo: Option[Any], len: Option[Seq[Int]]): Column =
    is_valid(input, isNullable = false, formatInfo = formatInfo, len = len)

  /** Validity check against `formatInfo` without length information; forwards to the four-argument overload. */
  def is_valid(input: Column, isNullable: Boolean, formatInfo: Option[Any]): Column =
    is_valid(input, isNullable, formatInfo, len = None)

  /**
    * Method to identify if the passed input column is a valid expression after typecasting to the passed dataType.
    * If len is present, this function also makes sure the max length of the input column
    * after the typecasting operation is not greater than len.
    * @param input      input column expression to be checked for validity.
    * @param isNullable when true the check is skipped and a true literal is returned.
    * @param formatInfo datatype to which the input column expression must be typecasted.
    *                   If it is a string then it is treated as a timestamp format. If it is
    *                   a list of two strings then it is treated as the current timestamp format
    *                   and the new timestamp format to which the input column needs to be typecasted.
    * @param len        max length of input column after typecasting it to dataType.
    * @return false if the input column is not valid after typecasting, true if it is valid.
    */
  def is_valid(
    input:      Column,
    isNullable: Boolean,
    formatInfo: Option[Any],
    len:        Option[Seq[Int]]
  ): Column = {
    if (isNullable) lit(true)
    else {
      val result: Column = len match {
        case None | Some(Nil) ⇒
          formatInfo match {
            case Some(typecast: DataType) ⇒ input.cast(typecast)
            case Some(format: String) ⇒ to_timestamp(input, format)
            case Some(List(currentFormat: String, newFormat: String)) ⇒
              date_format(to_timestamp(input, currentFormat), newFormat)
            case _ ⇒ input // TODO: support other types
        case Some(lenValue) ⇒
          formatInfo match {
            case Some(x: DecimalType) ⇒
              if (lenValue.length == 2) {
                input.cast(DecimalType(lenValue(0), lenValue(1)))
              } else if (lenValue.length == 1) {
                input.cast(DecimalType(lenValue(0), 0))
              } else {
            case Some(x: IntegerType) ⇒ input.cast(getDataTypeFromIntLength(lenValue.head, x))
            case Some(y: StringType) ⇒
              when(length(input.cast(y)) > lit(lenValue.head), lit(null))
            case Some(z: DoubleType) ⇒
              when(input.cast(DecimalType(lenValue.head, 0)).isNull, lit(null))
                  input.cast(DecimalType(lenValue.head, 0)).cast(StringType).cast(z)
            case _ ⇒ input // TODO: support other types
      when(result.isNull, lit(false)).otherwise(lit(true))

    * Method to get current timestamp.
    * @return current timestamp in YYYYMMddHHmmssSSSSSS format.
  /**
    * Current timestamp rendered as text.
    *
    * @return `current_timestamp()` formatted with the pattern "yyyyMMddHHmmssSSSSSS"
    */
  def now(): Column = {
    val pattern = "yyyyMMddHHmmssSSSSSS"
    date_format(current_timestamp(), pattern)
  }

    * Returns the number of hours between two specified dates in standard format yyyy-MM-dd HH:mm:ss.SSSS.
    * @param end
    * @param start
    * @return
  /**
    * Number of hours between two timestamps, computed from their `unix_timestamp` values.
    *
    * @param end   later timestamp
    * @param start earlier timestamp
    * @return (seconds(end) - seconds(start)) / 3600.0, so partial hours come back as fractions
    */
  def datetime_difference_hours(end: Column, start: Column): Column = {
    val elapsedSeconds = unix_timestamp(end) - unix_timestamp(start)
    elapsedSeconds / lit(3600.0)
  }

  val datetime_difference =
      { (timestamp1: String, timestamp2: String) ⇒
        val endTime        = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSSS").parse(timestamp1).getTime
        val startTime      = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSSS").parse(timestamp2).getTime()
        val differenceTime = endTime - startTime
        val days           = differenceTime / (24 * 60 * 60 * 1000)
        var remainingTime  = differenceTime % (24 * 60 * 60 * 1000)
        val hours          = remainingTime / (60 * 60 * 1000)
        remainingTime = remainingTime % (60 * 60 * 1000)
        val minutes = remainingTime / (60 * 1000)
        remainingTime = remainingTime % (60 * 1000)
        val seconds = remainingTime / 1000
        remainingTime = remainingTime % 1000
        val microSeconds = remainingTime * 1000
        Row(days, hours, minutes, seconds, microSeconds)
          StructField("days",         LongType, true),
          StructField("hours",        LongType, true),
          StructField("minutes",      LongType, true),
          StructField("seconds",      LongType, true),
          StructField("microseconds", LongType, true)

  /**
    * Three-way comparison of two columns.
    *
    * @return null when either input is null, -1 when input1 < input2, 1 when input1 > input2, 0 otherwise
    */
  def string_compare(input1: Column, input2: Column): Column = {
    val eitherNull = input1.isNull.or(input2.isNull)
    val ordering   = when(input1 < input2, lit(-1)).otherwise(when(input1 > input2, lit(1)).otherwise(lit(0)))
    when(eitherNull, lit(null)).otherwise(ordering)
  }

    * Method to convert
    * @param timezone
    * @param time
    * @return
  /**
    * Converts `time` from the given timezone to UTC via `to_utc_timestamp`.
    *
    * @param timezone timezone id; the special value "local" selects the JVM default timezone
    * @param time     timestamp column to convert
    */
  def timezone_to_utc(timezone: String, time: Column): Column = {
    val zoneId = timezone match {
      case "local" ⇒ TimeZone.getDefault.toZoneId.getId
      case other   ⇒ other
    }
    to_utc_timestamp(time, zoneId)
  }

    * Method to return integer value representing number of days to today from “1-1-1990”.
    * @return integer value
  /**
    * Number of days between 1990-01-01 and the current date.
    *
    * @return `datediff(current_date(), 1990-01-01)` as an integer column
    */
  def today(): Column = {
    val epoch = to_date(lit("1990-01-01"), "yyyy-MM-dd")
    datediff(current_date(), epoch)
  }

    * Method to check if current column is null or has empty value.
    * @param input
    * @return
  /**
    * Checks whether a column is null or holds an empty value.
    *
    * @param input column to inspect
    * @return true when `input` is null or its trimmed length is 0, false otherwise
    */
  def isNullOrEmpty(input: Column): Column =
    when(input.isNull || length(trim(input)) === lit(0), lit(true))
      // Without an explicit otherwise branch, non-empty values would evaluate to null rather than false.
      .otherwise(lit(false))

    * Function to create sequence of array between two passed numbers
    * @param start starting point of generated sequence
    * @param end   terminating point of generated sequence.
    * @return column containing sequence of integers.
  /**
    * Creates an array column holding the integers from `start` to `end` (inclusive).
    *
    * @param start starting point of the generated sequence
    * @param end   terminating point of the generated sequence
    * @param step  increment between consecutive entries (default 1); a negative step produces a
    *              descending sequence when start > end; a step of 0 is rejected by `Range`
    * @return column containing the sequence of integers
    */
  def generate_sequence(start: Int, end: Int, step: Int = 1): Column =
    // `step` used to be accepted but ignored; honour it when building the range.
    array((start to end by step).map(value ⇒ lit(value)): _*)

    * Method to create array of size "size" containing seedVal as each entry
    * @param size
    * @param seedVal
    * @return
  /**
    * Creates an array column of `size` entries, each being `seedVal`.
    *
    * @param size    number of entries; values of 0 or less give an empty array
    * @param seedVal column expression repeated in every position
    */
  def make_constant_vector(size: Int, seedVal: Column): Column = {
    val entries = Seq.fill(size)(seedVal)
    array(entries: _*)
  }

    * Method to create array of size "size" containing seedVal as each entry
    * @param size
    * @param seedVal
    * @return
  /**
    * Creates a plain Scala array of `size` entries, each being `seedVal`.
    *
    * @param size    number of entries; values of 0 or less give an empty array
    * @param seedVal integer repeated in every position
    */
  def make_constant_vector(size: Int, seedVal: Int) =
    Array.fill(size)(seedVal)

  /**
    * UDF to break the input string into multiple strings via delimiter. The number of strings after the split is
    * adjusted to the passed width parameter: if there are fewer strings, empty strings are added; if there are more,
    * the first width entries are picked and the remaining ones are discarded.
    */
  val splitIntoMultipleColumnsUdf =
    udf { (input: String, delimiter: String, width: Int) ⇒
      val entries = input.split(delimiter)
      (0 until width)
        .map(index ⇒
          if (entries.length > index) {
          } else {

  /**
    * Method to create a dataframe with a single column containing an increasing sequence id from start to end.
    * @param start
    * @param end
    * @param columnName
    * @param sparkSession
    * @return
    */
  def generateDataFrameWithSequenceColumn(
    start:        Int,
    end:          Int,
    columnName:   String,
    sparkSession: SparkSession
  ): DataFrame = {
    import sparkSession.implicits._

    val arrayDF = Seq(((start to end).toArray)).toDF(columnName)

  class StringAsStream(val content: String) extends Serializable {
    var currentOffset: Int = 0

    def read_string(len: Int): String = {
      if (currentOffset >= content.length || currentOffset + len > content.length)
      else {
        val result = content.substring(currentOffset, currentOffset + len)
        currentOffset = currentOffset + len

  /**
    * Wraps `content` in a new `StringAsStream`.
    */
  def getContentAsStream(content: String): StringAsStream = {
    val stream = new StringAsStream(content)
    stream
  }

    * Method to read values from inputData and create dataframe with column name as columnName and column type as
    * columnType for the values in inputData delimiter by delimiter.
    * @param inputData
    * @param delimiter
    * @param columnName
    * @param columnType
    * @param sparkSession
    * @return
  /**
    * Reads the values contained in `inputData` (separated by `delimiter`) and creates a
    * single-column dataframe named `columnName` of type `columnType`.
    *
    * @param inputData    delimited text holding one value per row
    * @param delimiter    separator handed to `String.split`, which interprets it as a regular expression
    * @param columnName   name of the resulting column
    * @param columnType   type name understood by `getColumnDataType`; unknown names fall back to string
    * @param sparkSession session used to build the dataframe
    * @return dataframe with one row per value
    */
  def createDataFrameFromData(
    inputData:    String,
    delimiter:    String,
    columnName:   String,
    columnType:   String,
    sparkSession: SparkSession
  ): DataFrame = {
    val dataType = getColumnDataType(columnType)
    val schema   = StructType(List(StructField(columnName, dataType)))

    // Row values must carry the JVM type matching the declared column type; raw strings in an
    // integer/double/boolean column do not match the schema.
    val parse: String ⇒ Any = dataType match {
      case IntegerType ⇒ (value: String) ⇒ value.trim.toInt
      case DoubleType  ⇒ (value: String) ⇒ value.trim.toDouble
      case BooleanType ⇒ (value: String) ⇒ value.trim.toBoolean
      case _           ⇒ (value: String) ⇒ value
    }

    val rowData = inputData.split(delimiter).map(value ⇒ Row(parse(value))).toList
    sparkSession.createDataFrame(rowData.asJava, schema)
  }

    * Method to get spark data type.
    * @param columnType
    * @return
  /**
    * Resolves a Spark data type from its textual name.
    *
    * @param columnType one of "StringType", "IntegerType", "DoubleType", "BooleanType"
    * @return the matching Spark type; any other name resolves to StringType
    */
  private def getColumnDataType(columnType: String): DataType =
    columnType match {
      case "StringType"  ⇒ StringType
      case "IntegerType" ⇒ IntegerType
      case "DoubleType"  ⇒ DoubleType
      case "BooleanType" ⇒ BooleanType
      case _             ⇒ StringType
    }

    * Method to get ByteArray from any passed object.
    * @param input
    * @return
  /**
    * Produces a byte array for an arbitrary object.
    *
    *  - strings are encoded as windows-1252
    *  - Long/Int/Short/Byte values use the two's-complement bytes of the equivalent BigInt
    *  - a non-empty row whose first field is an array of longs yields the concatenated BigInt bytes
    *    of its non-zero entries
    *  - anything else is serialized with a Jackson `ObjectMapper`
    *
    * @param input value to convert
    * @return byte representation of `input`
    */
  private def getByteArray(input: Any): Array[Byte] = {
    // Lazy so that the Jackson serialization only runs when a fallback branch is actually taken.
    lazy val default = new ObjectMapper().writeValueAsBytes(input)

    input match {
      case value:  String ⇒ value.getBytes("windows-1252")
      case value:  Long   ⇒ BigInt(value).toByteArray
      case value:  Int    ⇒ BigInt(value).toByteArray
      case value:  Short  ⇒ BigInt(value).toByteArray
      case value:  Byte   ⇒ BigInt(value).toByteArray
      case schema: GenericRowWithSchema if schema.size > 0 ⇒
        schema.get(0) match {
          case LongWrappedArray(array) ⇒ array.filterNot(_ == 0).flatMap(x ⇒ BigInt(x).toByteArray).toArray
          case _                       ⇒ default
        }
      case _ ⇒ default
    }
  }

    * Tests whether an object is composed of all binary zero bytes.
    * This function returns:
    * 1. 1 if obj contains only binary zero bytes or is a zero-length string
    * 2. 0 if obj contains any non-zero bytes
    * 3. NULL if obj is NULL
  /**
    * UDF testing whether a value is composed only of binary zero bytes.
    *
    *  - strings: true when zero-length or made only of NUL characters
    *  - Byte/Short/Int/Long: true when the value is 0
    *  - anything else, including null: false (the Boolean result cannot carry NULL)
    */
  val is_bzero = udf { (input: Any) ⇒
    input match {
      case text: String ⇒ text.forall(_ == '\u0000')
      case n:    Byte   ⇒ n == 0
      case n:    Short  ⇒ n == 0
      case n:    Int    ⇒ n == 0
      case n:    Long   ⇒ n == 0L
      case _            ⇒ false
    }
  }

    * UDF to get last Byte from ByteArray of input data.
  /**
    * UDF to get the last byte from the byte array of the input data, after skipping the
    * first `offset` bytes. The byte array comes from `getByteArray`.
    */
  val getByteFromByteArray = udf { (input: Any, offset: Int) ⇒
    val remaining = getByteArray(input).drop(offset)
    remaining(remaining.length - 1)
  }

    * UDF to get short comprising of last 2 Bytes from ByteArray of input data.
  /**
    * UDF to get a short built from the last 2 bytes of the byte array of the input data,
    * after skipping the first `offset` bytes. The byte array comes from `getByteArray`.
    */
  val getShortFromByteArray = udf { (input: Any, offset: Int) ⇒
    val remaining = getByteArray(input).drop(offset)
    new BigInteger(remaining.takeRight(2)).shortValue()
  }

    * UDF to get integer comprising of last 4 Bytes from ByteArray of input data.
  /**
    * UDF to get an integer built from the last 4 bytes of the byte array of the input data,
    * after skipping the first `offset` bytes. The byte array comes from `getByteArray`.
    */
  val getIntFromByteArray = udf { (input: Any, offset: Int) ⇒
    val remaining = getByteArray(input).drop(offset)
    new BigInteger(remaining.takeRight(4)).intValue()
  }

  /**
    * UDF that URL-encodes `input` with `URLEncoder.encode` using the UTF-8 charset.
    */
  val url_encode_escapes = udf((input: String) ⇒ URLEncoder.encode(input, "UTF-8"))

  // TODO support for threshold!
  /**
    * UDF that always fails: it throws an `Exception` whose message is the input string.
    */
  // NOTE(review): the lambda's result type is Nothing — confirm Spark can derive a return schema for it.
  val force_error = udf { (input: String) ⇒
    throw new Exception(s"$input")
  }

    * UDF to get long comprising of last 8 Bytes from ByteArray of input data.
  /**
    * UDF to get a long built from the last 8 bytes of the byte array of the input data,
    * after skipping the first `offset` bytes. The byte array comes from `getByteArray`.
    */
  val getLongFromByteArray = udf { (input: Any, offset: Int) ⇒
    val remaining = getByteArray(input).drop(offset)
    new BigInteger(remaining.takeRight(8)).longValue()
  }

  /**
    * Method to check whether all characters of the passed input are digits.
    * @param input
    * @return
    */
  private def isAllDigits(input: String): Boolean =

  object LongWrappedArray {
    def unapply(input: Any): Option[mutable.WrappedArray[Long]] = input match {
      case array: mutable.WrappedArray[_] if array.forall(_.isInstanceOf[Long]) ⇒
      case _ ⇒ None

  /**
    * Extractor matching any `Seq` whose elements are all `Long` values.
    * An empty sequence matches as well; non-sequences and null do not.
    */
  object LongSequence {
    def unapply(input: Any): Option[Seq[Long]] = input match {
      // The element check makes the cast safe despite type erasure.
      case sequence: Seq[_] if sequence.forall(_.isInstanceOf[Long]) ⇒ Some(sequence.asInstanceOf[Seq[Long]])
      case _ ⇒ None
    }
  }

  /**
    * Method used for abinitio's reinterpret_as function to read the necessary bytes from the byte array of the input
    * data and convert them into struct format as described by the typeInfo sequence.
    * TypeInfo can have multiple entries, each of which can be a decimal or a string type. Bytes are read from the input
    * byte array depending on the argument passed within decimal or string.
    * If the decimal or string argument is an integer then that many bytes are read from the input byte array; if it is
    * a string delimiter then bytes are read from the current position until the delimiter is found in the input
    * byte array.
    * @param input
    * @param typeInfo
    * @return
    */
  def convertInputBytesToStructType(input: Any, typeInfo: Seq[String], startByte: Int = 0): Row = input match {
    case LongSequence(sequence) ⇒
    case _ ⇒
      var curPointer = 0
      var byteArray  = getByteArray(input)
      byteArray = byteArray.slice(startByte, byteArray.length)
      val stringVal = new String(byteArray, "windows-1252")
      val rowValues = ListBuffer[Any]() { curType ⇒
        if (curType.startsWith("decimal") || curType.startsWith("string")) {
          val stInd  = curType.indexOf("(")
          val endInd = curType.indexOf(")")
          val arg    = curType.substring(stInd + 1, endInd)
          if (isAllDigits(arg)) {
            val takeUntil = Math.min(Integer.parseInt(arg) + curPointer, stringVal.length)
            rowValues += stringVal.substring(curPointer, takeUntil)
            curPointer = takeUntil
          } else {
            val trimmedArg = arg.trim
            val pattern    = trimmedArg.substring(1, trimmedArg.length - 1)
            val index      = stringVal.substring(curPointer).indexOf(pattern)
            rowValues += stringVal.substring(curPointer, index + curPointer)
            curPointer = curPointer + index + pattern.length
        } else if (curType == "long") {
          if (curPointer >= byteArray.length) {
            rowValues += 0L
          } else {
            val slicedByteArray = byteArray.slice(curPointer, curPointer + 8)
            rowValues += new BigInteger(slicedByteArray).longValue()
            curPointer = curPointer + 8

  /**
    * UDF to read `size` consecutive 8-byte longs from the byte array of the input data, starting at byte `offset`,
    * and return them as an array of longs.
    */
  val getLongArrayFromByteArray = udf(
    (input: Any, size: Int, offset: Int) ⇒ {
      var byteArray = getByteArray(input)
      byteArray = byteArray.slice(offset, byteArray.length)
      var curPointer = 0
      (0 until size).map { itr ⇒
        if (curPointer >= byteArray.length) {
        } else {
          val slicedByteArray = byteArray.slice(curPointer, curPointer + 8)
          curPointer = curPointer + 8
          new BigInteger(slicedByteArray).longValue()
    DataTypes.createArrayType(LongType, false)

  /**
    * UDF for murmur hash generation for any column type.
    */
  val murmur = udf { (x: Any) ⇒
    val murmurInstance = new Murmur3F()
    if (x.isInstanceOf[GenericRowWithSchema]) {
      val input          = x.asInstanceOf[GenericRowWithSchema]
      val inputByteArray = (0 until input.length).map(input.get(_)).toArray.flatMap(_.toString.getBytes())
      Array(murmurInstance.getValue, murmurInstance.getValueHigh)
    } else if (x.isInstanceOf[Seq[Any]]) {
      val inputByteArray = x.asInstanceOf[Seq[Any]].flatMap(_.toString.getBytes()).toArray
      Array(murmurInstance.getValue, murmurInstance.getValueHigh)
    } else {
      Array(murmurInstance.getValue, murmurInstance.getValueHigh)
  //    if (x.isInstanceOf[GenericRowWithSchema] && x
  //          .asInstanceOf[GenericRowWithSchema]
  //          .get(0)
  //          .isInstanceOf[mutable.WrappedArray[Any]]) {
  //      MurmurHash3.arrayHash(x.asInstanceOf[GenericRowWithSchema].get(0).asInstanceOf[mutable.WrappedArray[Any]].toArray)
  //    } else if (x.isInstanceOf[GenericRowWithSchema] && x
  //                 .asInstanceOf[GenericRowWithSchema]
  //                 .length > 0) {
  //      MurmurHash3.arrayHash(
  //        (0 until x.asInstanceOf[GenericRowWithSchema].length).map(x.asInstanceOf[GenericRowWithSchema].get(_)).toArray
  //      )
  //    } else if (x.isInstanceOf[Seq[Any]]) MurmurHash3.arrayHash(x.asInstanceOf[Array[Any]])
  //    else if (x.isInstanceOf[String]) MurmurHash3.stringHash(x.asInstanceOf[String])

  /**
    * Method to return the result of evaluating a string expression in the context of a specified input column. Here
    * the input column could be a struct type record, a simple column, an array type etc.
    * Here expr could be a reference to a nested column inside the input column or any expression which requires values
    * from the input column for its evaluation.
    * Note: The current implementation only supports the scenario where the input column is of struct type and expr is
    * simply a dot separated column reference into the input struct.
    * @param input
    * @param expr
    * @return
    */
  val eval = udf((input: Any, expr: String) ⇒ {
    if (input == null || expr == null || !input.isInstanceOf[Row]) null
    else {
      val tokens = expr.split("\\.")
      if (tokens.length == 2) {
        val result = input.asInstanceOf[Row].getAs[Row](tokens(0)).getAs[Any](tokens(1))
        if (result != null) {
        } else {
      } else if (tokens.length == 1) {
        val result = input.asInstanceOf[Row].getAs[Any](tokens(0))
        if (result != null) {
        } else {
      } else {

  /**
    * UDF to truncate the microseconds part of a timestamp. This is needed because abinitio and spark have some
    * incompatibility in the microseconds part of the timestamp format.
    */
  val truncateMicroSeconds = udf((sourceFormat: String, timestamp: String) ⇒ {
    if (timestamp == null || sourceFormat == null) null
    else {
      val index = sourceFormat.indexOf("S")
      if (index < 0) {
      } else {
        timestamp.substring(0, index)

    * Method to replace String Columns with Empty value to Null.
    * @param input
    * @return
  /**
    * Replaces empty values with null.
    *
    * @param input column to inspect
    * @return null where `input` cast to string has length 0, `input` unchanged otherwise
    */
  def replaceBlankColumnWithNull(input: Column): Column = {
    val isEmptyText = length(input.cast(StringType)) === lit(0)
    when(isEmptyText, lit(null)).otherwise(input)
  }

    * Method to identify and return first non null expression.
    * @param expr1
    * @param expr2
    * @return
  /**
    * Returns the first non-null expression of the two.
    *
    * @param expr1 preferred expression
    * @param expr2 fallback used where `expr1` is null
    */
  def first_defined(expr1: Column, expr2: Column): Column = {
    val firstMissing = expr1.isNull
    when(firstMissing, expr2).otherwise(expr1)
  }

    * Adds an explicit sign to the number. E.g.
    * 2 -> +2; -004 -> -004; 0 -> +0
  /**
    * Adds an explicit sign to a number, e.g. 2 -> +2; -004 -> -004; 0 -> +0.
    * Values already starting with "+" or "-" are returned unchanged; otherwise "+" is
    * prepended when `c >= 0` holds and "-" in every remaining case.
    */
  def sign_explicit(c: Column): Column = {
    val alreadySigned = c.startsWith("+") || c.startsWith("-")
    val withSign      = when(c >= 0, concat(lit("+"), c)).otherwise(concat(lit("-"), c))
    when(alreadySigned, c).otherwise(withSign)
  }

  def from_sv(input: Column, separator: String, schema: StructType): Column = {
    val splitResult = split(input, separator)
    val fields = {
      case (field, index) ⇒ splitResult.getItem(index).as(

    struct(fields: _*).as("value")

    * Method removes any non-digit characters from the specified string column.
    * @param input input String Column
    * @return Cleaned string column or null
  /**
    * Removes every non-digit character from the given string column.
    *
    * @param input input string column
    * @return cleaned string column, or null where `input` is null
    */
  def remove_non_digit(input: Column): Column = {
    val nonDigit = "[^\\d]"
    when(input.isNull, lit(null)).otherwise(regexp_replace(input, nonDigit, ""))
  }

    * Computes number of days between two specified dates in "yyyyMMdd" format
    * @param laterDate   input date
    * @param earlierDate input date
    * @return number of days between laterDate and earlierDate or null if either one is null
  /**
    * Computes the number of days between two dates given in "yyyyMMdd" format.
    *
    * @param laterDate   input date
    * @param earlierDate input date
    * @return `datediff(laterDate, earlierDate)`, or null if either input is null
    */
  def date_difference_days(laterDate: Column, earlierDate: Column): Column = {
    val dateFormat = "yyyyMMdd"
    val dayCount   = datediff(to_date(laterDate, dateFormat), to_date(earlierDate, dateFormat))
    when(laterDate.isNull.or(earlierDate.isNull), lit(null)).otherwise(dayCount)
  }

  /**
    * Checks whether a string is ASCII.
    * @param input column to be checked
    * @return true if the input string is ASCII, otherwise false
    */
  def is_ascii(input: Column): Column = {
    val asciiPattern = "[^\\x00-\\x7f]"
    when(input.isNull, lit(null)).otherwise(
      when(length(regexp_replace(input, asciiPattern, "")) < length(input), lit(false))

  /**
    * Checks whether an input string contains only ASCII characters and is numeric.
    * @param input string to be checked
    * @return true if the input string contains only ASCII characters and numbers, or null if input is null
    */

  def is_numeric_ascii(input: Column): Column = {
    when(input.isNull, lit(null))
        when((is_ascii(input) === lit(true))
               .and((string_is_numeric(input)) === lit(true))
               .and(length(input) > lit(0)),

    * Validates date against a input format
    * @param dateFormat A pattern such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` or `dd.MM.yyyy`
    * @param inDate     Input date to be validated
    * @return true if the input date is valid otherwise false
  /**
    * Validates a date against an input format.
    *
    * @param dateFormat a pattern such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` or `dd.MM.yyyy`
    * @param inDate     input date to be validated
    * @return false where `to_date(inDate, dateFormat)` yields null, true otherwise
    */
  def is_valid_date(dateFormat: String, inDate: Column): Column = {
    val parsed = to_date(inDate, dateFormat)
    when(parsed.isNull, lit(false)).otherwise(lit(true))
  }

  /**
    * Converts a Julian date with a 1-digit year to a Julian date with a 4-digit year.
    * @param in_date  date in Julian "YJJJ" format
    * @param ref_date date in "yyyyMMdd" format
    * @return a date in "YYYYJJJ" format
    */

  def YJJJ_to_YYYYJJJ(in_date: Column, ref_date: Column): Column = {

    val y_in_date     = string_substring(in_date,  lit(1), lit(1))
    val y_ref_date    = string_substring(ref_date, lit(4), lit(1))
    val yyyy_ref_date = string_substring(ref_date, lit(1), lit(4))

    val yyyy_ref_date_tmp = when((y_in_date.notEqual(y_ref_date)).and(y_ref_date.notEqual(lit(0))),
                                 concat(string_substring(yyyy_ref_date, lit(1), lit(3)), y_in_date)

    val result = concat(yyyy_ref_date_tmp, string_substring(in_date, lit(2), lit(3)))
    when((to_date(ref_date, "yyyyMMdd").isNull), lit(null))
      .otherwise(when(to_date(result, "yyyyDDD").isNull, yyyyMMdd_to_YYYYJJJ(ref_date)).otherwise(result))


    * Converts yyyyyMMdd to YYYYJJJ
    * @param in_date date in yyyyMMdd format
    * @return a date converted to YYYYJJJ
  /**
    * Converts a date in yyyyMMdd format to YYYYJJJ (4-digit year followed by the day of year).
    * The day of year is passed through `string_lrepad` with width 3 and pad "0".
    *
    * @param in_date date in yyyyMMdd format
    * @return the converted date, or null where `in_date` cannot be parsed as yyyyMMdd
    */
  def yyyyMMdd_to_YYYYJJJ(in_date: Column) = {
    val parsed      = to_date(in_date, "yyyyMMdd")
    val day_of_year = string_lrepad(date_format(parsed, "D"), 3, "0")
    val year        = string_substring(in_date, lit(1), lit(4))
    when(parsed.isNull, lit(null)).otherwise(concat(year, day_of_year))
  }

    * Computes number of days in February month in a given year
    * @param year year whose number of days in February needs to be calculated
    * @return number of days
  /**
    * Computes the number of days in February for a given year.
    *
    * @param year year whose number of days in February needs to be calculated
    * @return 29 for leap years (divisible by 400, or divisible by 4 but not by 100), 28 otherwise
    */
  def getFebruaryDay(year: Column): Column =
    when((year % 400 === lit(0)).or((year % 4 === lit(0)).and((year % 100).notEqual(lit(0)))), lit(29))
      // Without an explicit otherwise branch, non-leap years would evaluate to null.
      .otherwise(lit(28))


© 2015 - 2024 Weber Informatics LLC | Privacy Policy