All Downloads are FREE. Search and download functionality uses the official Maven repository.

io.prophecy.libs.SparkFunctions.scala Maven / Gradle / Ivy

There is a newer version: 6.3.0-3.3.0
Show newest version
/*
 * ====================================================================
 *
 * PROPHECY CONFIDENTIAL
 *
 * Prophecy Inc
 * All Rights Reserved.
 *
 * NOTICE:  All information contained herein is, and remains
 * the property of Prophecy Inc, the intellectual and technical concepts contained
 * herein are proprietary to Prophecy Inc and may be covered by U.S. and Foreign Patents,
 * patents in process, and are protected by trade secret or copyright law.
 * Dissemination of this information or reproduction of this material
 * is strictly forbidden unless prior written permission is obtained
 * from Prophecy Inc.
 *
 * ====================================================================
 */
package io.prophecy.libs

import com.fasterxml.jackson.databind.ObjectMapper
import de.greenrobot.common.hash.Murmur3F
import io.prophecy.abinitio.ScalaFunctions._
import org.apache.commons.lang.StringEscapeUtils
import org.json.XML
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.crosssupport.SparkFunctionsCrossSupport
import org.apache.spark.sql.functions.{lit, _}
import org.apache.spark.sql.types._
import org.joda.time.format.DateTimeFormat

import java.io.File
import java.math.BigInteger
import java.net.URLEncoder
import java.nio.{ByteBuffer, ByteOrder}
import java.security.MessageDigest
import java.sql.Date
import java.text.{DecimalFormat,            DecimalFormatSymbols, SimpleDateFormat}
import java.time.format.{DateTimeFormatter, DateTimeFormatterBuilder}
import java.time.temporal.ChronoField
import java.time.{LocalDate, LocalDateTime, ZonedDateTime}
import java.util.TimeZone
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, ListBuffer}

/** Companion object of the `SparkFunctions` trait; it declares no members of its own. */
object SparkFunctions

/**
  * Library of Spark functions implementing the Ab Initio functions used in Ab Initio workflows.
  */
trait SparkFunctions {
  /**
    * Registers UDFs with `spark` so they can be called by name from Spark SQL expressions.
    *
    * Most entries register a UDF value referenced by name. Some are defined in this trait; others
    * (e.g. `murmur`, `eval`, `hash_MD5`) are not visible in this part of the file. NOTE(review): presumably
    * they are defined elsewhere in the trait or come from the imported `ScalaFunctions` — confirm.
    * The last five (`starts_with`, `ends_with`, `is_blank`, `decimal_lpad`, `decimal_truncate`) are built
    * inline below.
    *
    * @param spark session whose UDF registry receives the functions.
    */
  def registerAllUDFs(spark: SparkSession): Unit = {
    spark.udf.register("string_pad",                string_pad)
    spark.udf.register("string_index_with_offset",  string_index_with_offset)
    spark.udf.register("string_filter",             string_filter)
    spark.udf.register("generate_sequence",         generate_sequence)
    spark.udf.register("re_get_match_with_index",   re_get_match_with_index)
    spark.udf.register("string_filter_out",         string_filter_out)
    spark.udf.register("number_grouping",           number_grouping)
    spark.udf.register("re_get_match",              re_get_match)
    spark.udf.register("re_index_with_offset",      re_index_with_offset)
    spark.udf.register("file_information",          file_information)
    spark.udf.register("re_index",                  re_index)
    spark.udf.register("string_split_no_empty",     string_split_no_empty)
    spark.udf.register("string_split",              string_split)
    spark.udf.register("type_info",                 type_info)
    spark.udf.register("record_info",               record_info)
    spark.udf.register("type_info_with_includes",   type_info_with_includes)
    spark.udf.register("record_info_with_includes", record_info_with_includes)
    spark.udf.register("read_file",                 read_file)
    spark.udf.register("unique_identifier",         unique_identifier)
    spark.udf.register("hash_MD5",                  hash_MD5)
    spark.udf.register("string_like",               string_like)
    spark.udf.register("re_split_no_empty",         re_split_no_empty)
    spark.udf.register("string_index",              string_index)
    spark.udf.register("string_rindex",             string_rindex)
    spark.udf.register("string_rindex_with_offset", string_rindex_with_offset)
    spark.udf.register("make_byte_flags",           make_byte_flags)
    spark.udf.register("translate_bytes",           translate_bytes)
    spark.udf.register("test_characters_all",       test_characters_all)
    spark.udf.register("string_convert_explicit",   string_convert_explicit)
    spark.udf.register("cross_join_index_range",    cross_join_index_range)
    spark.udf.register("getByteFromByteArray",      getByteFromByteArray)
    spark.udf.register("getShortFromByteArray",     getShortFromByteArray)
    spark.udf.register("getIntFromByteArray",       getIntFromByteArray)
    spark.udf.register("url_encode_escapes",        url_encode_escapes)
    spark.udf.register("force_error",               force_error)
    spark.udf.register("getLongFromByteArray",      getLongFromByteArray)
    spark.udf.register("getLongArrayFromByteArray", getLongArrayFromByteArray)
    spark.udf.register("murmur",                    murmur)
    spark.udf.register("eval",                      eval)
    spark.udf.register("truncateMicroSeconds",      truncateMicroSeconds)
    spark.udf.register("string_cleanse",            string_cleanse)
    spark.udf.register("canonical_representation",  canonical_representation)
    spark.udf.register("string_representation",     string_representation)
    spark.udf.register("datetime_difference",       datetime_difference)
    spark.udf.register("encode_date",               encode_date)
    spark.udf.register("date_month_end",            date_month_end)
    spark.udf.register("is_bzero",                  is_bzero)
    spark.udf.register("readBytesIntoLong",         readBytesIntoLong)
    spark.udf.register("readBytesIntoInteger",      readBytesIntoInteger)
    spark.udf.register("writeLongToBytes",          writeLongToBytes)
    spark.udf.register("decodeString",              decodeString)
    spark.udf.register("encodeString",              encodeString)
    spark.udf.register("writeIntegerToBytes",       writeIntegerToBytes)

    // SQL-callable, null-safe variants of the Column helpers: any null argument yields null.
    spark.udf.register(
      "starts_with",
      udf((input: String, prefix: String) ⇒ if (input == null || prefix == null) null else input.startsWith(prefix),
          BooleanType
      )
    )
    spark.udf.register(
      "ends_with",
      udf((input: String, suffix: String) ⇒ if (input == null || suffix == null) null else input.endsWith(suffix),
          BooleanType
      )
    )
    // true when the string is empty or contains only characters removed by String.trim.
    spark.udf.register("is_blank", udf((input: String) ⇒ if (input == null) null else input.trim.isEmpty, BooleanType))
    spark.udf.register("decimal_lpad",
                       udf { (input: Any, len: Int) ⇒
                         _decimal_lpad(input, len)
                       }
    )
    // Scale up by 10^places, drop the fraction with the Long cast (truncates toward zero), then scale back down.
    spark.udf.register("decimal_truncate",
                       udf { (input: Double, number_of_places: Int) ⇒
                         (input * Math.pow(10, number_of_places)).asInstanceOf[Long] / Math.pow(10, number_of_places)
                       }
    )

    //    spark.udf.register("re_match_replace")

    //    spark.udf.register("vector_sort_dedup_first", )

    //    spark.udf.register("string_substring", (input: Column, start_position: Column, length: Column) ⇒ input.substr(start_position, length))
    //    import org.apache.hadoop.hive.ql.exec.UDF;
    //    class SimpleUdf extends UDF {
    //      def string_substring(input: Column, start_position: Column, length: Column): Column =
    //        input.substr(start_position, length)
    //    }
    //    spark.sql("CREATE FUNCTION simple_udf AS 'SimpleUdf'")

  }

  /**
    * Extracts part of a string column.
    *
    * @param input          string column to cut from.
    * @param start_position 1-based position of the first character to keep.
    * @param length         number of characters to keep.
    * @return the selected part of `input`.
    */
  def string_substring(input: Column, start_position: Column, length: Column): Column = {
    val slice = input.substr(start_position, length)
    slice
  }

  /**
    * Returns the first `length` characters of `input`.
    *
    * @param input  string column.
    * @param length number of leading characters to keep; a negative value yields the empty string.
    * @return the prefix of `input`, or null when `input` is null.
    */
  def string_prefix(input: Column, length: Column): Column =
    // Null must stay null even for a negative length (previously a null input produced "" in that case).
    when(input.isNull, lit(null)).otherwise(
      when(length < lit(0), lit("")).otherwise(input.substr(lit(1), length))
    )

  /**
    * Trims `input` and left-pads the trimmed value with `char_to_pad_with` up to `len` characters.
    * When the trimmed value already has `len` or more characters, the ORIGINAL (untrimmed) `input` is returned.
    *
    * @param input            string column.
    * @param len              target length in characters.
    * @param char_to_pad_with padding string used by `lpad`; defaults to a single space.
    * @return the padded string, `input` itself when no padding is needed, or null for a null input.
    */
  def string_lrepad(input: Column, len: Int, char_to_pad_with: String = " "): Column = {
    val trimmed = trim(input)
    val padded  = when(length(trimmed) >= len, input).otherwise(lpad(trimmed, len, char_to_pad_with))
    when(isnull(input), lit(null)).otherwise(padded)
  }

  /**
    * Rounds `input` towards positive infinity at `places` digits after the decimal point.
    *
    * @param input  numeric column.
    * @param places number of fractional digits to keep.
    * @return `ceil(input * 10^places) / 10^places`.
    */
  def decimal_round_up(input: Column, places: Int): Column = {
    val scale = pow(lit(10), places)
    ceil(input * scale) / scale
  }

  /**
    * Rounds `input` towards negative infinity at `right_digits` digits after the decimal point.
    *
    * @param input        numeric column.
    * @param right_digits number of fractional digits to keep.
    * @return `floor(input * 10^right_digits) / 10^right_digits`.
    */
  def decimal_round_down(input: Column, right_digits: Int): Column = {
    val scale = pow(lit(10), right_digits)
    floor(input * scale) / scale
  }

  /**
    * Rounds `input` to `places` fractional digits, choosing whichever of the rounded-up and rounded-down
    * candidates is closer. On an exact tie the candidate with the larger absolute value wins
    * (i.e. half away from zero). Ties are decided with floating-point arithmetic, so they may not be exact.
    *
    * @param input  numeric column.
    * @param places number of fractional digits to keep.
    * @return the rounded value.
    */
  def decimal_round(input: Column, places: Int): Column = {
    val up       = decimal_round_up(input, places)
    val down     = decimal_round_down(input, places)
    val distUp   = up - input
    val distDown = input - down
    val onTie    = when(abs(up) > abs(down), up).otherwise(down)
    when(distUp > distDown, down).otherwise(when(distUp < distDown, up).otherwise(onTie))
  }

  /**
    * Truncates `input` to `number_of_places` fractional digits by scaling, casting to Long and scaling back.
    */
  def decimal_truncate(input: Column, number_of_places: Column): Column = {
    val factor = pow(lit(10), number_of_places)
    (input * factor).cast(LongType) / factor
  }

  /**
    * Adds `months` months (subtracts when negative) to a date column using Spark's `add_months`.
    *
    * @param inputDate date column, expected in yyyy-MM-dd format.
    * @param months    number of months to add.
    * @return the shifted date column.
    */
  def date_add_months(inputDate: Column, months: Int): Column =
    add_months(inputDate, months)

  /**
    * UDF returning the number of days between 1900-01-01 and the date given by `year`, `month` and `day`.
    * Example: encode_date(1998, 5, 18) = 35931.
    *
    * The components are zero-padded and parsed with the "yyyy-MM-dd" pattern.
    */
  val encode_date = udf { (year: Int, month: Int, day: Int) ⇒
    import java.time.LocalDate
    import java.time.format.DateTimeFormatter
    val isoDate   = DateTimeFormatter.ofPattern("yyyy-MM-dd")
    val epoch1900 = LocalDate.parse("1900-01-01", isoDate)
    // left-pad each component with zeros to its field width (4, 2, 2); values already that wide are kept
    val fields = Seq(year -> 4, month -> 2, day -> 2).map {
      case (value, width) ⇒
        val digits = value.toString
        "0" * (width - digits.length) + digits
    }
    val target = LocalDate.parse(fields.mkString("-"), isoDate)
    target.toEpochDay - epoch1900.toEpochDay
  }

  /**
    * UDF returning the last day (number of days) of `month` in `year`.
    *
    * February uses the full Gregorian leap-year rule (every 4th year, except centuries not divisible by 400).
    * A month outside 1..12 yields null instead of failing the job. A null argument presumably reaches this
    * untyped UDF unboxed as 0 — NOTE(review): confirm on the target Spark version — and so also yields null
    * for `month`.
    */
  val date_month_end = udf(
    { (month: Int, year: Int) ⇒
      month match {
        case 1 | 3 | 5 | 7 | 8 | 10 | 12 ⇒ 31
        case 4 | 6 | 9 | 11              ⇒ 30
        // `year % 4` alone wrongly treated 1900, 2100, ... as leap years.
        case 2                           ⇒ if (java.time.Year.isLeap(year.toLong)) 29 else 28
        case _                           ⇒ null
      }
    },
    IntegerType
  )

  /**
    * Shifts `input` by `months` months (backwards when negative) using Spark's `add_months`.
    *
    * NOTE(review): `add_months` is a date function, so the time-of-day of a timestamp input is presumably
    * dropped — confirm this is what callers expect.
    *
    * @param input  timestamp column, expected in yyyy-MM-dd HH:mm:ss.SSSS format.
    * @param months number of months to add.
    * @return the shifted column.
    */
  def datetime_add_months(input: Column, months: Int): Column = {
    val shifted = add_months(input, months)
    shifted
  }

  /**
    * Trims `input` and right-pads the trimmed value with `char_to_pad_with` up to `len` characters.
    * When `input` is null, or the trimmed value already has `len` or more characters, `input` is returned unchanged.
    *
    * @param input            string column.
    * @param len              target length in characters.
    * @param char_to_pad_with padding string used by `rpad`; defaults to a single space.
    * @return the trimmed, right-padded string, or `input` itself when no padding applies.
    */
  def string_repad(input: Column, len: Int, char_to_pad_with: String = " "): Column = {
    val trimmed  = trim(input)
    val keepAsIs = isnull(input) || length(trimmed) >= len
    when(keepAsIs, input).otherwise(rpad(trimmed, len, char_to_pad_with))
  }

  /**
    * Right-pads `input` with `char_to_pad_with` up to `len` characters.
    * A null `input`, or one already `len` or more characters long, is returned unchanged.
    *
    * @param input            string column.
    * @param len              target length in characters.
    * @param char_to_pad_with padding string used by `rpad`; defaults to a single space.
    * @return the padded string column.
    */
  def string_pad(input: Column, len: Int, char_to_pad_with: String = " "): Column = {
    val alreadyLongEnough = input.isNull || length(input) >= len
    when(alreadyLongEnough, input).otherwise(rpad(input, len, char_to_pad_with))
  }

  /**
    * UDF that right-pads `input` with spaces up to `len` characters.
    * Returns `input` unchanged when it is already `len` or more characters long, and null when `input` is null.
    */
  val string_pad = udf(
    { (input: String, len: Int) ⇒
      if (input == null) null
      // `" " * n` repeats the space; the former `.map(" ")` indexed into " " and threw for n >= 2.
      else if (input.length < len) input + " " * (len - input.length)
      else input
    },
    StringType
  )

  /**
    * UDF that right-pads `input` with `char_to_pad_with` up to `len` characters.
    *
    * A multi-character pad string is repeated and cut so the result has exactly `len` characters; an empty
    * pad string leaves `input` unpadded. Returns `input` unchanged when it is already `len` or more characters
    * long, and null when `input` or `char_to_pad_with` is null.
    */
  val string_pad_with_char = udf(
    { (input: String, len: Int, char_to_pad_with: String) ⇒
      if (input == null || char_to_pad_with == null) null
      else if (input.length < len) {
        val missing = len - input.length
        // Repeat the pad string and cut it to size; the former `.map(char_to_pad_with)` indexed into the
        // pad string and threw once the padding was longer than the pad string itself.
        input + (char_to_pad_with * missing).take(missing)
      } else input
    },
    StringType
  )

  /**
    * Replaces every match of a regular expression in `target`, leaving the first `offset` characters untouched.
    *
    * @param target      string column to search.
    * @param pattern     Java regular expression.
    * @param replacement replacement string (regexp_replace syntax, so `$n` group references apply).
    * @param offset      number of leading characters to skip before searching.
    * @return `target` with all matches after `offset` replaced; null stays null.
    */
  def re_replace(target: Column, pattern: String, replacement: String, offset: Int = 0): Column = {
    val skipped   = target.substr(1, offset)
    val remainder = target.substr(lit(offset + 1), length(target) - lit(offset))
    when(target.isNull, target).otherwise(concat(skipped, regexp_replace(remainder, pattern, replacement)))
  }

  /**
    * Replaces only the first match of a regular expression in `target`.
    *
    * The first `offset` characters are kept as-is and the search starts right after them, as in `re_replace`.
    * `replacement` is inserted literally: `$` and `\` have no special meaning.
    *
    * @param target      string column to search.
    * @param pattern     Java regular expression.
    * @param replacement literal replacement text.
    * @param offset      number of leading characters to skip before searching (default 0).
    * @return `target` with its first match after `offset` replaced; unchanged when nothing matches; null stays null.
    */
  def re_replace_first(target: Column, pattern: String, replacement: String, offset: Column = lit(0)): Column = {
    // `^([\s\S]*?)` lazily captures the text before the earliest match. Because it is anchored at the start of
    // the string, regexp_replace can rewrite at most one occurrence. Note: this extra capturing group shifts
    // numbered back-references inside `pattern` by one.
    val firstMatchRegex       = s"^([\\s\\S]*?)(?:$pattern)"
    val prefixThenReplacement = "$1" + java.util.regex.Matcher.quoteReplacement(replacement)
    when(target.isNull, target).otherwise(
      concat(
        target.substr(lit(1), offset),
        regexp_replace(target.substr(offset + 1, length(target) - offset), firstMatchRegex, prefixThenReplacement)
      )
    )
  }

  /**
    * Left-pads `input` with `pad_char` up to `len` characters.
    * A null `input`, or one already `len` or more characters long, is returned unchanged.
    *
    * @param input    string column.
    * @param len      target length in characters.
    * @param pad_char padding string used by `lpad`; defaults to a single space.
    * @return the padded string column.
    */
  def string_lpad(input: Column, len: Int, pad_char: String = " "): Column = {
    val unchanged = input.isNull.or(length(input) >= lit(len))
    when(unchanged, input).otherwise(lpad(input, len, pad_char))
  }

  /**
    * Extracts the first decimal number from `input` and strips leading zeros, non-numeric characters and spaces.
    *
    * The regex accepts:
    *  1. integral numbers starting with a non-zero digit (`[1-9][0-9 ]+`, `[1-9]+`);
    *  2. numbers with an explicit decimal point (`[1-9][0-9]*\<point>[0-9 ]+`, `0\<point>[0-9 ]*[0-9]`).
    *
    * If a '-' occurs in `input` before or at the extracted number, the result is prefixed with '-'.
    * All whitespace inside the extracted number is removed.
    *
    * @param input              input string column.
    * @param decimal_point_char character that represents the decimal point (escaped into the regex).
    * @return the stripped decimal string, or "0" when `input` is null or contains no number.
    */
  def decimal_strip(input: Column, decimal_point_char: String = "."): Column = {
    val pattern =
      s"""([1-9][0-9]*(\\$decimal_point_char)[0-9 ]+|[1-9][0-9 ]+|[1-9]+|(0\\$decimal_point_char)[0-9 ]*[0-9])"""
    val number     = regexp_extract(input, pattern, 1)
    val digitsOnly = regexp_replace(number, "\\s+", "")
    // Text from the start of `input` up to the first character of the number; a '-' there makes it negative.
    val leadingText = input.substr(lit(0), instr_udf(input, number))

    when(input.isNull, lit("0")).otherwise(
      when(number === "", "0").otherwise(
        when(instr(leadingText, "-") > 0, concat(lit("-"), digitsOnly)).otherwise(digitsOnly)
      )
    )
  }

  //  def decimal_strip(input: Column, columnName: String, decimal_point_char: String = "."): Column = {
  //    val pattern =
  //      s"""([1-9][0-9]*(\\$decimal_point_char)[0-9 ]+|[1-9][0-9 ]+|[1-9]+|(0\\$decimal_point_char)[0-9 ]*[0-9])"""
  //    val sqlPattern = lit(pattern).expr.sql
  //
  //    when(input.isNull, lit("0")).otherwise(
  //      when(regexp_extract(input, pattern, 1) === "", "0").otherwise(
  //        when(
  //          instr(
  //            input.substr(
  //              lit(0),
  //              expr(
  //                s"instr($columnName, regexp_extract($columnName, $sqlPattern, 1))"
  //              )
  //            ),
  //            "-"
  //          ) > 0,
  //          concat(
  //            lit("-"),
  //            regexp_replace(regexp_extract(input, pattern, 1), "\\s+", "")
  //          )
  //        ).otherwise(
  //          regexp_replace(regexp_extract(input, pattern, 1), "\\s+", "")
  //        )
  //      )
  //    )
  //  }

  /**
    * UDF returning the 1-based position of the first occurrence of `substring` in `inputStr`
    * (0 when absent), or null when either argument is null.
    */
  val instr_udf = udf(
    (haystack: String, needle: String) ⇒
      if (haystack != null && needle != null) haystack.indexOf(needle) + 1
      else null,
    IntegerType
  )

  /**
    * UDF returning the 1-based position of `seekStr` in `inputStr`, searching after the first `offset` characters.
    * Returns 0 when not found or when `offset` exceeds the string length, and null when a string argument is null.
    */
  val string_index_with_offset = udf(
    (inputStr: String, seekStr: String, offset: Int) ⇒
      if (inputStr == null || seekStr == null) null
      else if (offset > inputStr.length) 0
      else {
        val found = inputStr.substring(offset).indexOf(seekStr)
        if (found >= 0) offset + found + 1 else 0
      },
    IntegerType
  )

  /**
    * UDF keeping only the characters of `inputStr1` that also occur somewhere in `inputStr2`,
    * preserving their order (and repetitions) from `inputStr1`. Null if either argument is null.
    */
  val string_filter = udf(
    { (inputStr1: String, inputStr2: String) ⇒
      if (inputStr1 == null || inputStr2 == null) null
      else {
        val allowed = inputStr2.toSet
        inputStr1.filter(allowed)
      }
    },
    StringType
  )

  /**
    * Replaces occurrences of `seekStr` with `newStr` in `input`, leaving the first `offset` characters untouched.
    *
    * NOTE(review): matching goes through `regexp_replace`, so `seekStr` is interpreted as a regular expression
    * and `newStr` as a regex replacement, not as literal text — confirm callers expect this.
    *
    * @param input   string column to modify.
    * @param seekStr text (regex) to replace.
    * @param newStr  replacement.
    * @param offset  number of leading characters to skip before replacing.
    * @return the modified string; `input` unchanged when it is no longer than `offset`; null for a null input.
    */
  def string_replace(input: Column, seekStr: Column, newStr: Column, offset: Column = lit(0)): Column = {
    val untouchedPrefix = string_substring(input, lit(0), offset)
    val searchedTail    = string_substring(input, offset + 1, length(input))
    val replaced        = concat(untouchedPrefix, regexp_replace(searchedTail, seekStr, newStr))
    when(isnull(input), lit(null)).otherwise(when(length(input) <= offset, input).otherwise(replaced))
  }

  /**
    * UDF replacing the first match of the regex `seekStr` in `input` with `newStr` (String.replaceFirst semantics).
    * Null if any argument is null.
    */
  val string_replace_first = udf(
    (input: String, seekStr: String, newStr: String) ⇒
      if (input != null && seekStr != null && newStr != null) input.replaceFirst(seekStr, newStr)
      else null,
    StringType
  )

  /**
    * UDF applying `replaceAll(seekStrs(i), newStrs(i))` to `input` for each index pair in order.
    * Extra elements of the longer list are ignored. Null if any argument is null.
    */
  val string_replace_in_loop = udf(
    { (input: String, seekStrs: Seq[String], newStrs: Seq[String]) ⇒
      if (input == null || seekStrs == null || newStrs == null) null
      else
        seekStrs.zip(newStrs).foldLeft(input) {
          case (current, (seek, replacement)) ⇒ current.replaceAll(seek, replacement)
        }
    },
    StringType
  )

  /**
    * True when `input` is non-empty and consists only of ASCII letters a-z / A-Z; false otherwise.
    *
    * @param input string column.
    * @return boolean column (null for a null input).
    */
  def string_is_alphabetic(input: Column): Column = {
    val lettersOnly = "^[a-zA-Z]+$"
    input.rlike(lettersOnly)
  }

  /**
    * True when `input` consists only of ASCII digits, or is the empty string; false otherwise.
    * A null input makes the condition null, which `when` treats as unsatisfied, so the result is false.
    *
    * @param input string column.
    * @return boolean column.
    */
  def string_is_numeric(input: Column): Column = {
    val digitsOnlyOrEmpty = input.rlike("^[0-9]+$") || length(input) === lit(0)
    when(digitsOnlyOrEmpty, lit(true)).otherwise(lit(false))
  }

  /**
    * Boolean column: whether `input` begins with the literal `prefix`.
    */
  def starts_with(input: Column, prefix: String): Column = {
    val literalPrefix = lit(prefix)
    input.startsWith(literalPrefix)
  }

  /**
    * Boolean column: whether `input` ends with the literal `suffix`.
    */
  def ends_with(input: Column, suffix: String): Column = {
    val literalSuffix = lit(suffix)
    input.endsWith(literalSuffix)
  }

  /**
    * UDF producing the array of integers from `start` to `end`, both inclusive (empty when `start > end`).
    */
  val generate_sequence = udf((start: Int, end: Int) ⇒ Range.inclusive(start, end).toArray)

  /**
    * Extracts a decimal number from `input` and left-pads it to `len` characters.
    *
    * The regex matches an optional '-' followed by:
    *  1. an integral number starting with a non-zero digit (`[1-9][0-9]+`, `[1-9]+`), or
    *  2. a number with an explicit decimal point (`[1-9][0-9]*\<point>[0-9]+`, `0\<point>[0-9]*[0-9]`).
    *
    * If the extracted text already has `len` or more characters it is returned as-is. Otherwise it is
    * left-padded with `char_to_pad_with`. Then every '-' is removed and the result is left-padded with '-'
    * back to `len`, which moves a minus sign to the far left.
    *
    * @param input              input string column.
    * @param len                target length in characters.
    * @param char_to_pad_with   padding string; defaults to "0".
    * @param decimal_point_char character that represents the decimal point (escaped into the regex).
    * @return the padded decimal string, or null for a null input.
    */
  def decimal_lrepad(
    input:              Column,
    len:                Int,
    char_to_pad_with:   String = "0",
    decimal_point_char: String = "."
  ): Column = {
    val pattern =
      s"""(-?)([1-9][0-9]*(\\$decimal_point_char)[0-9]+|[1-9][0-9]+|[1-9]+|(0\\$decimal_point_char)[0-9]*[0-9])"""
    val number      = regexp_extract(input, pattern, 0)
    val padded      = lpad(number, len, char_to_pad_with)
    val signFixed   = lpad(regexp_replace(padded, "\\-", ""), len, "-")
    val sizedNumber = when(length(number) >= len, number).otherwise(signFixed)
    when(input.isNull, lit(null)).otherwise(sizedNumber)
  }

  /**
    * Returns the character code (via Spark `ascii`) of the character at 1-based position `index` in `inputStr`.
    *
    * @param inputStr input string column.
    * @param index    1-based character position.
    * @return integer column; null when `index` is below 1, beyond the string length, or `inputStr` is null.
    */
  def string_char(inputStr: Column, index: Int): Column = {
    val position   = lit(index)
    val outOfRange = position <= lit(0) || position > length(inputStr) || inputStr.isNull
    val fromIndex  = inputStr.substr(position, length(inputStr) - position + 1)
    when(outOfRange, lit(null)).otherwise(ascii(fromIndex))
  }

  /**
    * Extracts a decimal number from `input` and left-pads it to `len` characters.
    *
    * The regex matches an optional '-' followed by:
    *  1. an integral number (`[0-9]+`), or
    *  2. a number with an explicit decimal point (`[0-9]+\<point>[0-9]+`, `0\<point>[0-9]+`).
    *
    * If the extracted text already has `len` or more characters it is returned as-is. Otherwise it is
    * left-padded with `char_to_pad_with`. Then every '-' is removed and the result is left-padded with '-'
    * back to `len`, which moves a minus sign to the far left. Unlike `decimal_lrepad`, leading zeros are accepted.
    *
    * @param input              input string column.
    * @param len                target length in characters.
    * @param char_to_pad_with   padding string; defaults to "0".
    * @param decimal_point_char character that represents the decimal point (escaped into the regex).
    * @return the padded decimal string, or null for a null input.
    */
  def decimal_lpad(
    input:              Column,
    len:                Int,
    char_to_pad_with:   String = "0",
    decimal_point_char: String = "."
  ): Column = {
    val pattern =
      s"""(-?)([0-9]+(\\$decimal_point_char)[0-9]+|[0-9]+|(0\\$decimal_point_char)[0-9]+)"""
    val number      = regexp_extract(input, pattern, 0)
    val padded      = lpad(number, len, char_to_pad_with)
    val signFixed   = lpad(regexp_replace(padded, "\\-", ""), len, "-")
    val sizedNumber = when(length(number) >= len, number).otherwise(signFixed)
    when(input.isNull, lit(null)).otherwise(sizedNumber)
  }

  /**
    * UDF returning the first substring of `input` matching `pattern`, searching from 0-based index `offset`.
    * Null when there is no match or when `input` / `pattern` is null.
    */
  val re_get_match_with_index = udf(
    { (input: String, pattern: String, offset: Int) ⇒
      if (input == null || pattern == null) null
      else pattern.r.findFirstIn(input.substring(offset)).orNull
    },
    StringType
  )

  /**
    * UDF keeping only the characters of `input1` that do NOT occur anywhere in `input2`,
    * preserving their order from `input1`. Null if either argument is null.
    */
  val string_filter_out = udf(
    { (input1: String, input2: String) ⇒
      if (input1 == null || input2 == null) null
      else {
        val excluded = input2.toSet
        input1.filterNot(excluded)
      }
    },
    StringType
  )

  /**
    * UDF that inserts `seperator` every `groups` digits in the integer part of a decimal string.
    *
    * Only the part before the first '.' is grouped; everything from the '.' on is kept verbatim, so
    * fractional digits (including leading zeros such as "1.05") are preserved. An integer part that is not an
    * optionally signed run of digits (e.g. empty, as in ".5") is left unchanged instead of failing.
    * Null when `input` or `seperator` is null.
    */
  val number_grouping = udf(
    { (input: String, groups: Int, seperator: String) ⇒
      if (input == null || seperator == null) null
      else {
        val dot              = input.indexOf('.')
        val (intPart, tail)  = if (dot < 0) (input, "") else (input.substring(0, dot), input.substring(dot))
        val groupedIntPart =
          if (intPart.matches("[+-]?[0-9]+")) group_number(intPart, groups, ',').replace(",", seperator)
          else intPart
        // literal `replace`: the separator must not be interpreted as a regex replacement string
        groupedIntPart + tail
      }
    },
    StringType
  )

  /**
    * Formats the numeric string `input` with `seperator` inserted every `groups` integer digits.
    *
    * @param input     numeric text accepted by `java.math.BigDecimal`.
    * @param groups    number of digits per group.
    * @param seperator grouping separator character.
    * @return the grouped number with the final two characters (the "0.0###" pattern's mandatory
    *         decimal point and fraction digit, presumably ".0" for integral input) removed.
    */
  private def group_number(input: String, groups: Int, seperator: Char): String = {
    // NOTE(review): DecimalFormatSymbols() follows the JVM default locale (decimal separator, minus sign);
    // only the grouping separator is overridden here.
    val symbols = new DecimalFormatSymbols()
    symbols.setGroupingSeparator(seperator)
    val dfDecimal = new DecimalFormat("###########0.0###")
    dfDecimal.setDecimalFormatSymbols(symbols)
    dfDecimal.setGroupingSize(groups)
    dfDecimal.setGroupingUsed(true)
    // Pass a java.math.BigDecimal: DecimalFormat formats other Number types (including scala.math.BigDecimal)
    // through doubleValue(), which silently loses digits beyond ~15-17 significant places.
    val result = dfDecimal.format(new java.math.BigDecimal(input))
    result.substring(0, result.length - 2)
  }

  /**
    * UDF wrapper around `_re_get_match`, which is not defined in the visible part of this file (presumably the
    * imported ScalaFunctions). Its name suggests it returns the first substring of `input` matching `pattern`
    * — NOTE(review): confirm against the helper.
    */
  val re_get_match = udf((input: String, pattern: String) ⇒ _re_get_match(input, pattern))
  /**
    * UDF that takes the windows-1252 bytes of `input` and decodes them with `charSet`.
    * Returns null when `input` or `charSet` is null instead of throwing a NullPointerException.
    */
  val decodeString =
    udf { (input: String, charSet: String) ⇒
      if (input == null || charSet == null) null
      else new String(input.getBytes("windows-1252"), charSet)
    }

  /**
    * UDF decoding the raw bytes `input` with `charSet`.
    * Returns null when `input` or `charSet` is null instead of throwing a NullPointerException.
    */
  val decodeBytes =
    udf { (input: Array[Byte], charSet: String) ⇒
      if (input == null || charSet == null) null
      else new String(input, charSet)
    }

  /**
    * UDF that encodes `input` with `charSet` and reinterprets the bytes as a windows-1252 string.
    * Returns null when `input` or `charSet` is null instead of throwing a NullPointerException.
    */
  val encodeString =
    udf { (input: String, charSet: String) ⇒
      if (input == null || charSet == null) null
      else new String(input.getBytes(charSet), "windows-1252")
    }

  /**
    * UDF that decodes `input` into a string and re-encodes it with `charSet`.
    * Returns null when `input` or `charSet` is null instead of throwing a NullPointerException.
    */
  val encodeBytes =
    udf { (input: Array[Byte], charSet: String) ⇒
      // NOTE(review): `new String(input)` decodes with the JVM default charset, whereas decodeString/encodeString
      // use windows-1252 explicitly — confirm this asymmetry is intended.
      if (input == null || charSet == null) null
      else new String(input).getBytes(charSet)
    }

  /**
    * UDF wrapper around `_writeLongToBytes` (not defined in the visible part of this file).
    * Arguments are forwarded unchanged: value, byte length, endianness, unsigned flag.
    */
  val writeLongToBytes =
    udf((value: Any, byteCount: Int, byteOrder: String, unsigned: Boolean) ⇒
      _writeLongToBytes(value, byteCount, byteOrder, unsigned))

  /**
    * UDF wrapper around `_writeIntegerToBytes` (not defined in the visible part of this file).
    * Arguments are forwarded unchanged: value, byte length, endianness, unsigned flag.
    */
  val writeIntegerToBytes =
    udf((value: Any, byteCount: Int, byteOrder: String, unsigned: Boolean) ⇒
      _writeIntegerToBytes(value, byteCount, byteOrder, unsigned))

  /**
    * UDF wrapper around `_readBytesStringIntoLong` (not defined in the visible part of this file).
    * Arguments are forwarded unchanged: byte string, byte length, endianness, unsigned flag.
    */
  val readBytesStringIntoLong =
    udf((encoded: String, byteCount: Int, byteOrder: String, unsigned: Boolean) ⇒
      _readBytesStringIntoLong(encoded, byteCount, byteOrder, unsigned))

  /**
    * UDF wrapper around `_readBytesIntoLong` (not defined in the visible part of this file).
    * Arguments are forwarded unchanged: bytes, byte length, endianness, unsigned flag.
    */
  val readBytesIntoLong =
    udf((raw: Array[Byte], byteCount: Int, byteOrder: String, unsigned: Boolean) ⇒
      _readBytesIntoLong(raw, byteCount, byteOrder, unsigned))

  /**
    * UDF wrapper around `_readBytesIntoInteger` (not defined in the visible part of this file).
    * Arguments are forwarded unchanged: bytes, byte length, endianness, unsigned flag.
    */
  val readBytesIntoInteger =
    udf((raw: Array[Byte], byteCount: Int, byteOrder: String, unsigned: Boolean) ⇒
      _readBytesIntoInteger(raw, byteCount, byteOrder, unsigned))

  /**
    * UDF wrapper around `_readBytesStringIntoInteger` (not defined in the visible part of this file).
    * Arguments are forwarded unchanged: byte string, byte length, endianness, unsigned flag.
    */
  val readBytesStringIntoInteger =
    udf((encoded: String, byteCount: Int, byteOrder: String, unsigned: Boolean) ⇒
      _readBytesStringIntoInteger(encoded, byteCount, byteOrder, unsigned))

  /**
    * UDF wrapper around `_packedBytesToDecimal` (not defined in the visible part of this file).
    * Arguments are forwarded unchanged: packed bytes, scale, unsigned flag, stripped flag.
    */
  val packedBytesToDecimal =
    udf((packed: Array[Byte], scale: Int, unsigned: Boolean, stripped: Boolean) ⇒
      _packedBytesToDecimal(packed, scale, unsigned, stripped))

  /**
    * UDF wrapper around `_packedBytesStringToDecimal` (not defined in the visible part of this file).
    * Arguments are forwarded unchanged: packed byte string, scale, unsigned flag, stripped flag.
    */
  val packedBytesStringToDecimal =
    udf((packed: String, scale: Int, unsigned: Boolean, stripped: Boolean) ⇒
      _packedBytesStringToDecimal(packed, scale, unsigned, stripped))

  /** UDF delegating to `_bigDecimalToPackedBytes(input, precision, isUnsigned, isStripped)`. */
  val bigDecimalToPackedBytes = udf(
    (input: java.math.BigDecimal, precision: Int, isUnsigned: Boolean, isStripped: Boolean) ⇒
      _bigDecimalToPackedBytes(input, precision, isUnsigned, isStripped)
  )

  /**
    * UDF returning the 1-based position in `input` of the first match of the regular expression `pattern`,
    * searching from the 1-based position `offset` onwards. Returns 0 when there is no match.
    *
    * An `offset` below 1 is treated as 1 (search the whole string). An `offset` more than one past the end of
    * `input` yields 0 instead of throwing.
    */
  val re_index_with_offset = udf { (input: String, pattern: String, offset: Int) ⇒
    // 0-based start of the searched suffix.
    val searchFrom = math.max(offset, 1) - 1
    if (searchFrom > input.length) 0
    else
      pattern.r
        .findFirstMatchIn(input.substring(searchFrom))
        .map(m ⇒ m.start + searchFrom + 1)
        .getOrElse(0)
  }

  /** Formats the epoch-seconds column `seconds` as a `yyyyMMddHHmmssSSSSSS` string using Spark's `from_unixtime`. */
  def datetime_from_unixtime(seconds: Column): Column = {
    val outputPattern = "yyyyMMddHHmmssSSSSSS"
    from_unixtime(seconds, outputPattern)
  }

  /**
    * UDF returning the file information for `inputPath`, exactly as produced by
    * `SparkFunctionsCrossSupport.file_information_def`.
    */
  val file_information = udf((inputPath: String) ⇒ SparkFunctionsCrossSupport.file_information_def(inputPath))

  /**
    * Builds a one-row DataFrame with columns `file` (the given path) and `mtime` (the file's modification time
    * rendered with the Joda-Time pattern `format`).
    */
  def getMTimeDataframe(filepath: String, format: String, spark: SparkSession): DataFrame = {
    import spark.implicits._

    val formatter = DateTimeFormat.forPattern(format)
    // NOTE(review): `modified` is scaled by 1000 before printing; presumably it is in seconds while Joda expects
    // milliseconds — confirm against file_information_def.
    val formattedMtime = formatter.print(SparkFunctionsCrossSupport.file_information_def(filepath).modified * 1000)
    Seq((filepath, formattedMtime)).toDF("file", "mtime")
  }

  /**
    * UDF returning the 1-based position in `input` of the first match of the regular expression `pattern`,
    * or 0 when the pattern does not match anywhere.
    */
  val re_index = udf { (input: String, pattern: String) ⇒
    pattern.r.findFirstMatchIn(input) match {
      case Some(firstMatch) ⇒ firstMatch.start + 1
      case None             ⇒ 0
    }
  }

  /**
    * Returns the last `len` characters of `input`. The whole string is returned when `len` is at least its length,
    * and the empty string when `len` is zero or negative.
    */
  def string_suffix(input: Column, len: Int): Column = {
    val inputLength = length(input)
    val wanted      = lit(len)
    val suffix = when(wanted <= lit(0), lit(""))
      .otherwise(input.substr(inputLength - wanted + lit(1), wanted))
    when(wanted >= inputLength, input).otherwise(suffix)
  }

  /** Casts `value` to a long when `format` is exactly "%Ld"; any other format yields null. */
  def scanf_long(format: Column, value: Column): Column = {
    val isLongFormat = format === lit("%Ld")
    when(isLongFormat, value.cast(LongType)).otherwise(lit(null))
  }

  /** Casts `value` to a double when `format` is exactly "%lf"; any other format yields null. */
  def scanf_double(format: Column, value: Column): Column = {
    val isDoubleFormat = format === lit("%lf")
    when(isDoubleFormat, value.cast(DoubleType)).otherwise(lit(null))
  }

  /**
    * Concatenates the elements of `column`, separated by `delimiter`.
    * Delegates to `SparkFunctionsCrossSupport.string_join`.
    */
  def string_join(column: Column, delimiter: String): Column = {
    SparkFunctionsCrossSupport.string_join(column, delimiter)
  }

  /**
    * Zips two array columns: the first holding event types and the second holding event texts.
    * Delegates to `SparkFunctionsCrossSupport.zip_eventInfo_arrays`.
    *
    * @param column1 array column of event types
    * @param column2 array column of event texts
    * @return the zipped column produced by the cross-support implementation
    */
  def zip_eventInfo_arrays(column1: Column, column2: Column): Column = {
    SparkFunctionsCrossSupport.zip_eventInfo_arrays(column1, column2)
  }

  /**
    * Extracts the field at `position` from the struct column `column`.
    * Delegates to `SparkFunctionsCrossSupport.getFieldFromStructByPosition`.
    *
    * @param column   struct column
    * @param position field position (NOTE(review): confirm 0- vs 1-based against the cross-support implementation)
    * @return column holding the selected field
    */
  def getFieldFromStructByPosition(column: Column, position: Int): Column = {
    SparkFunctionsCrossSupport.getFieldFromStructByPosition(column, position)
  }

  /**
    * UDF that splits `input` on every literal occurrence of `delimiter` and drops all empty substrings.
    *
    * The delimiter is quoted before it reaches `String.split`, so regex metacharacters such as `|` or `.` are
    * matched literally. Use `re_split_no_empty` to split on a regular expression.
    */
  val string_split_no_empty = udf { (input: String, delimiter: String) ⇒
    input.split(java.util.regex.Pattern.quote(delimiter)).filterNot(_.isEmpty)
  }

  /** Replaces a null `input` with the empty string; non-null values pass through unchanged. */
  def replace_null_with_blank(input: Column): Column = {
    val blank = lit("")
    when(input.isNull, blank).otherwise(input)
  }

  /**
    * UDF that splits `input` on every literal occurrence of `delimiter`.
    *
    * The delimiter is quoted before it reaches `String.split`, so regex metacharacters such as `|` or `.` are
    * matched literally. As with `String.split`, trailing empty strings are dropped.
    */
  val string_split = udf { (input: String, delimiter: String) ⇒
    input.split(java.util.regex.Pattern.quote(delimiter))
  }

  /** UDF delegating to `_type_info(dml_type)`. */
  val type_info = udf((dml_type: String) ⇒ _type_info(dml_type))

  /** UDF delegating to `_record_info(dml_type)`. */
  val record_info = udf((dml_type: String) ⇒ _record_info(dml_type))

  /** UDF delegating to `_type_info(dml_type, includes)`. */
  val type_info_with_includes = udf((dml_type: String, includes: Array[String]) ⇒ _type_info(dml_type, includes))

  /** UDF delegating to `_record_info(dml_type, includes)`. */
  val record_info_with_includes = udf((dml_type: String, includes: Array[String]) ⇒ _record_info(dml_type, includes))

  /** UDF delegating to `_read_file(filePath)`. */
  val read_file = udf((filePath: String) ⇒ _read_file(filePath))

  /**
    * UDF returning a freshly generated random UUID string on each invocation.
    *
    * NOTE(review): the UDF is not marked non-deterministic, so the optimizer may treat repeated calls as equal.
    * Consider `.asNondeterministic()` if every supported Spark version provides it.
    */
  val unique_identifier = udf(() ⇒ java.util.UUID.randomUUID().toString)

  /**
    * UDF returning the 16-byte MD5 digest of `input`.
    *
    * Strings are hashed from their UTF-8 bytes. The charset is spelled out so the digest does not depend on the
    * executor JVM's default charset. Every other value is first converted with `getByteArray`.
    */
  val hash_MD5 = udf { (input: Any) ⇒
    val inputBytes = input match {
      case text: String ⇒ text.getBytes("UTF-8")
      case other        ⇒ getByteArray(other)
    }
    // A fresh MessageDigest per call: instances are stateful and not thread-safe.
    MessageDigest.getInstance("MD5").digest(inputBytes)
  }

  /** UDF delegating to `_string_representation(input)`. */
  val string_representation = udf((input: Any) ⇒ _string_representation(input))

  /**
    * UDF returning a struct `(encoding: short, value: string)` describing `input`:
    *  - null        -> `(0, "NULL")`
    *  - String      -> `(1, s)`, where `s` is the string's default-charset bytes decoded as UTF-8;
    *                   `(2, original)` if that decoding throws a non-fatal exception
    *  - Row or Seq  -> `(3, "")`
    *  - other value -> `(1, input.toString)`
    */
  val canonical_representation =
    udf(
      (input: Any) ⇒ {
        val (encoding, value): (Short, String) = input match {
          case null ⇒ (0, "NULL")
          case text: String ⇒
            try {
              (1, new String(text.getBytes, "utf-8"))
            } catch {
              // Fall back only on non-fatal errors; fatal JVM errors (e.g. OutOfMemoryError) must propagate.
              case scala.util.control.NonFatal(_) ⇒ (2, text)
            }
          case _: Row | _: Seq[_] ⇒ (3, "")
          case other ⇒ (1, other.toString)
        }
        Row(encoding, value)
      },
      StructType(
        List(
          StructField("encoding", ShortType,  false),
          StructField("value",    StringType, false)
        )
      )
    )

  /**
    * UDF testing whether `input` matches a specified pattern. Returns true when the whole input matches the
    * pattern and false otherwise.
    *
    * As in the Ab Initio version, `%` in the pattern matches zero or more characters and `_` matches exactly one
    * character. Every other pattern character is matched literally: regex metacharacters such as `.`, `(` or `+`
    * are escaped. The wildcards, like `.` in java.util.regex, do not match line terminators.
    */
  val string_like = udf { (input: String, pattern: String) ⇒
    val regex = pattern.iterator.map {
      case '%'   ⇒ "(.*)"
      case '_'   ⇒ "."
      case other ⇒ java.util.regex.Pattern.quote(other.toString)
    }.mkString
    input.matches(regex)
  }

  /**
    * UDF splitting `input` around matches of the regular expression `pattern` and discarding every empty piece.
    */
  val re_split_no_empty = udf { (input: String, pattern: String) ⇒
    val pieces = input.split(pattern)
    pieces.filter(_.nonEmpty)
  }

  /**
    * UDF returning the 1-based position of the first occurrence of `seekStr` in `inputStr`, or 0 when absent.
    */
  val string_index = udf((inputStr: String, seekStr: String) ⇒ 1 + inputStr.indexOf(seekStr))

  /**
    * UDF returning the 1-based position of the first character of the last occurrence of `seekStr` in `inputStr`,
    * or 0 when absent.
    */
  val string_rindex = udf((inputStr: String, seekStr: String) ⇒ 1 + inputStr.lastIndexOf(seekStr))

  /**
    * UDF returning the 1-based position of the last occurrence of `seekStr` that lies entirely within `inputStr`
    * once its final `offset` characters are ignored. Returns 0 when there is no such occurrence or when `offset`
    * exceeds the input length.
    */
  val string_rindex_with_offset = udf { (inputStr: String, seekStr: String, offset: Int) ⇒
    if (offset > inputStr.length) 0
    else {
      val searchable = inputStr.substring(0, inputStr.length - offset)
      1 + searchable.lastIndexOf(seekStr)
    }
  }

  /**
    * UDF returning a 256-entry flag table. Entry `i` is 1 when `inputStr` contains the character with code `i`,
    * and 0 otherwise.
    *
    * Characters with codes above 255 cannot be represented in the table and are ignored; they previously caused an
    * ArrayIndexOutOfBoundsException.
    */
  val make_byte_flags = udf { (inputStr: String) ⇒
    val charFlag = new Array[Int](256)
    inputStr.foreach { c ⇒
      if (c < charFlag.length) charFlag(c) = 1
    }
    charFlag
  }

  /**
    * UDF returning a string made up of bytes from the given map. Each byte of the result is the value of `byteMap`
    * indexed by the unsigned code (0-255) of the corresponding byte of `inputStr`. Bytes whose code is beyond the
    * end of `byteMap` are kept unchanged. Returns NULL if any argument is NULL.
    *
    * NOTE(review): the string is converted to and from bytes with the JVM default charset.
    */
  val translate_bytes = udf((inputStr: String, byteMap: Seq[Byte]) ⇒ {
    if (inputStr == null || byteMap == null)
      null
    else {
      val translated = inputStr.getBytes.map { b ⇒
        // Bytes are signed; mask to 0..255 so bytes >= 0x80 are looked up instead of silently skipped.
        val code = b & 0xFF
        if (code < byteMap.length) byteMap(code) else b
      }
      new String(translated)
    }
  })

  /**
    * UDF counting the characters of `inputStr` whose flag in `charFlag` (e.g. a table built by `make_byte_flags`)
    * equals 1. Characters whose code lies outside the flag table are counted as absent instead of throwing.
    */
  val test_characters_all = udf { (inputStr: String, charFlag: Seq[Int]) ⇒
    inputStr.count(c ⇒ c < charFlag.length && charFlag(c) == 1)
  }

  /**
    * Converts a string from one character set to another, replacing inconvertible characters with a specified
    * string. The default-charset bytes of `input` are decoded with `charSet`, and every resulting replacement
    * character is substituted with `replaceStr`.
    *
    * `replaceStr` is inserted literally: `$` and `\` have no special meaning.
    */
  val string_convert_explicit = udf { (input: String, charSet: String, replaceStr: String) ⇒
    new String(input.getBytes(), charSet).replace("�", replaceStr)
  }

  /**
    * This implementation is incorrect (known to be incomplete).
    *
    * It only decodes the default-charset bytes of `input` with `charSet` and substitutes every resulting
    * replacement character with `replacement`, which is inserted literally (`$` and `\` have no special meaning).
    */
  val string_cleanse = udf { (input: String, replacement: String, charSet: String) ⇒
    new String(input.getBytes(), charSet).replace("�", replacement)
  }

  /**
    * UDF decoding `time` (via `_decode_datetime`) into a struct of its year, month, day, hour, minute, second and
    * microsecond components. The struct also has a `timezone_offset` field holding the UTC offset, in minutes, of
    * the JVM default time zone.
    *
    * NOTE(review): the offset is taken at the current instant, not at the decoded datetime, so it may differ
    * across DST boundaries.
    */
  val decode_datetime_as_local =
    udf(
      (time: String) ⇒ {
        val datetimeInfo = _decode_datetime(time)
        // The schema's timezone_offset field is ShortType, so the minute offset is narrowed explicitly.
        val offsetMinutes = (ZonedDateTime.now.getOffset.getTotalSeconds / 60).toShort
        // One value per schema field, in schema order (microsecond was previously missing).
        Row(
          datetimeInfo.year,
          datetimeInfo.month,
          datetimeInfo.day,
          datetimeInfo.hour,
          datetimeInfo.minute,
          datetimeInfo.second,
          datetimeInfo.microsecond,
          offsetMinutes
        )
      },
      StructType(
        List(
          StructField("year",            IntegerType, false),
          StructField("month",           IntegerType, false),
          StructField("day",             IntegerType, false),
          StructField("hour",            IntegerType, false),
          StructField("minute",          IntegerType, false),
          StructField("second",          IntegerType, false),
          StructField("microsecond",     IntegerType, false),
          StructField("timezone_offset", ShortType,   false)
        )
      )
    )

  /**
    * UDF converting an XML document to a JSON string with `org.json.XML`. When the converted JSON object has
    * exactly one top-level key, the object stored under that key is returned instead of the wrapper.
    */
  val xmlToJSON = udf { (input: String) ⇒
    val converted = XML.toJSONObject(input)
    val unwrapped =
      if (converted.length() != 1) converted
      else converted.getJSONObject(converted.keys().next())
    unwrapped.toString
  }

  /** Parses the XML in `content` into a struct with `schema` by converting it to JSON first (see `xmlToJSON`). */
  def from_xml(content: Column, schema: StructType): Column =
    from_json(xmlToJSON(content), schema)

  /** Returns a literal column holding the number of partitions of `in`'s underlying RDD. */
  def numberOfPartitions(in: DataFrame): Column = {
    val partitionCount = in.rdd.getNumPartitions
    lit(partitionCount)
  }

  /** Returns the last element of the array column `input`, or `default` when that element is null. */
  def findLastElement(input: Column, default: Column = lit(null)): Column = {
    val reversed    = reverse(input)
    val lastElement = SparkFunctionsCrossSupport.element_at(reversed, lit(1))
    when(lastElement.isNull, default).otherwise(lastElement)
  }

  /** Returns the length of `input` as computed by Spark's `length` function. */
  def string_length(input: Column): Column = length(input)

  /** Returns the first element of the array column `input`, or `default` when that element is null. */
  def findFirstElement(input: Column, default: Column = lit(null)): Column = {
    val head = SparkFunctionsCrossSupport.element_at(input, lit(1))
    when(head.isNull, default).otherwise(head)
  }

  /**
    * Returns the first element of the array column `input` after removing the values "", " ", "  " and "   "
    * (via `array_except`), or `default` when no such element exists.
    *
    * NOTE(review): only those four literal values count as blank; longer runs of spaces or other whitespace do not.
    */
  def findFirstNonBlankElement(input: Column, default: Column): Column = {
    val blankValues  = array(lit(""), lit(" "), lit("  "), lit("   "))
    val nonBlank     = SparkFunctionsCrossSupport.array_except(input, blankValues)
    val firstElement = SparkFunctionsCrossSupport.element_at(nonBlank, lit(1))
    when(firstElement.isNull, default).otherwise(firstElement)
  }

  /**
    * Takes the first element of `nonBlankEntryExpr` and finds its position in `firstArray`. Returns the element at
    * that same position in `secondArray`.
    */
  def getColumnInSecondArrayByFirstNonBlankPositionInFirstArray(
    nonBlankEntryExpr: Column,
    firstArray:        Column,
    secondArray:       Column
  ): Column = {
    val firstNonBlankValue = SparkFunctionsCrossSupport.element_at(nonBlankEntryExpr, lit(1))
    val positionInFirst =
      SparkFunctionsCrossSupport.array_position(firstArray, firstNonBlankValue).cast(IntegerType)
    SparkFunctionsCrossSupport.element_at(secondArray, positionInFirst)
  }

  /**
    * Compares the concatenation of all leaf columns named by `row1` with the concatenation of all leaf columns
    * named by `row2` (see `flattenStructSchema`).
    *
    * NOTE(review): values are concatenated without a separator, so differently split values can compare equal.
    */
  def schemaRowCompareResult(row1: StructType, row2: StructType): Column = {
    def concatenated(schema: StructType): Column = concat(flattenStructSchema(schema): _*)
    concatenated(row1) === concatenated(row2)
  }

  /**
    * Recursively lists one column per non-struct leaf of `schema`. Each column references the dotted path (prefixed
    * with `prefix` when given) and is aliased with dots replaced by underscores.
    */
  def flattenStructSchema(schema: StructType, prefix: String = null): Array[Column] =
    schema.fields.flatMap { field ⇒
      val qualifiedName = Option(prefix).fold(field.name)(parent ⇒ s"$parent.${field.name}")
      field.dataType match {
        case nested: StructType ⇒ flattenStructSchema(nested, qualifiedName)
        case _                  ⇒ Array(col(qualifiedName).as(qualifiedName.replace(".", "_")))
      }
    }

  /**
    * UDF to get a record of type decode_datetime_type whose fields are populated from the corresponding entries of
    * the input date/timestamp.
    *
    * Returned record has the following schema (all non-nullable integers):
    *
    * integer(8) year;
    * integer(8) month;
    * integer(8) day;
    * integer(8) hour;
    * integer(8) minute;
    * integer(8) second;
    * integer(8) microsecond;
    *
    * Note: Supported input time is in yyyy-MM-dd HH:mm:ss.SSSSSS or yyyy-MM-dd HH:mm:ss or yyyy-MM-dd formats only.
    * Additional handling is done to support timestamp retrieved from now() function call.
    */
  val decode_datetime =
    udf(
      (time: String) ⇒ {
        val info = _decode_datetime(time)
        Row(info.year, info.month, info.day, info.hour, info.minute, info.second, info.microsecond)
      },
      StructType(
        Seq("year", "month", "day", "hour", "minute", "second", "microsecond")
          .map(name ⇒ StructField(name, IntegerType, nullable = false))
      )
    )

  /** UDF returning every `(i, j)` pair from `input1 × input2`, ordered by `i` then `j`, as structs. */
  val cross_join_index_range = udf(
    (input1: Seq[Int], input2: Seq[Int]) ⇒
      for {
        i <- input1
        j <- input2
      } yield Row(i, j),
    ArrayType(
      StructType(
        List(
          StructField("i", IntegerType, false),
          StructField("j", IntegerType, false)
        )
      )
    )
  )

  /**
    * Method to identify whether the input string is blank.
    *
    * @param input input string column.
    * @return true if the string is empty after Spark's `trim`, false otherwise, and null when `input` is null.
    */
  def is_blank(input: Column): Column = {
    val blankCheck = when(length(trim(input)) === lit(0), lit(true)).otherwise(lit(false))
    when(isnull(input), lit(null)).otherwise(blankCheck)
  }

  /**
    * Lists the names of regular files directly inside `path` whose names start with `filePrefix`.
    *
    * @param path       directory to scan
    * @param filePrefix required file-name prefix
    * @return matching file names. Empty when `path` does not exist, is not a directory, or cannot be read
    *         (`File.listFiles` returns null in those cases, which previously caused a NullPointerException).
    */
  private def listFiles(path: String, filePrefix: String): Array[String] =
    Option(new File(path).listFiles())
      .getOrElse(Array.empty[File])
      .filter(_.isFile)
      .map(_.getName)
      .filter(_.startsWith(filePrefix))

  /** Returns a literal array column of the names of regular files in `path` that start with `filePrefix`. */
  def directory_listing(path: String, filePrefix: String): Column = {
    val names = listFiles(path, filePrefix)
    typedLit(names)
  }

  /**
    * Maps an integer byte width to the matching Spark integral type (1 → Byte, 2 → Short, 4 → Integer, 8 → Long).
    * Any other width falls back to `dataType`.
    */
  private def getDataTypeFromIntLength(lenValue: Int, dataType: DataType): DataType = {
    val typeByWidth: Map[Int, DataType] = Map(1 → ByteType, 2 → ShortType, 4 → IntegerType, 8 → LongType)
    typeByWidth.getOrElse(lenValue, dataType)
  }

  /** Checks `input` as non-nullable without any conversion or length constraint. */
  def is_valid(input: Column): Column =
    is_valid(input, isNullable = false, formatInfo = None, len = None)

  /** Checks `input` without any conversion or length constraint. */
  def is_valid(input: Column, isNullable: Boolean): Column =
    is_valid(input, isNullable = isNullable, formatInfo = None, len = None)

  /** Checks `input` as non-nullable against `formatInfo`, without a length constraint. */
  def is_valid(input: Column, formatInfo: Option[Any]): Column =
    is_valid(input, isNullable = false, formatInfo = formatInfo, len = None)

  /** Checks `input` as non-nullable against `formatInfo` and the length constraint `len`. */
  def is_valid(input: Column, formatInfo: Option[Any], len: Option[Seq[Int]]): Column =
    is_valid(input, isNullable = false, formatInfo = formatInfo, len = len)

  /** Checks `input` against `formatInfo`, without a length constraint. */
  def is_valid(input: Column, isNullable: Boolean, formatInfo: Option[Any]): Column =
    is_valid(input, isNullable = isNullable, formatInfo = formatInfo, len = None)

  /**
    * Method to identify if passed input column is a valid expression after typecasting to passed dataType.
    * Also while typecasting if len is present then this function also makes sure the max length of input column
    * after typecasting operation is not greater than len.
    *
    * A value is considered valid when the converted expression is non-null; a failed cast/parse (null) is invalid.
    * Note that a null `input` is therefore also reported as invalid unless `isNullable` is set.
    *
    * @param input      input column expression to be identified if is valid.
    * @param isNullable when true, no check is performed and the result is always `lit(true)`.
    * @param formatInfo datatype to which input column expression must be typecasted.
    *                   If datatype is a string then it is treated as timestamp format. If it is
    *                   a list of string then it is treated as having current timestamp format and
    *                   and new timestamp format to which input column needs to be typecasted.
    *                   Unsupported values leave `input` unconverted (only its nullness is tested).
    * @param len        max length of input column after typecasting it to dataType. For DecimalType it holds
    *                   (precision) or (precision, scale); for IntegerType its head is a byte width; for
    *                   StringType and DoubleType its head is the maximum length / number of integer digits.
    * @return a boolean column: false if input column is not valid after typecasting, true if it is valid.
    */
  def is_valid(
    input:      Column,
    isNullable: Boolean,
    formatInfo: Option[Any],
    len:        Option[Seq[Int]]
  ): Column = {
    if (isNullable) lit(true)
    else {
      // `result` is the converted expression; it is null exactly when the conversion is considered invalid.
      val result: Column = len match {
        case None | Some(Nil) ⇒
          formatInfo match {
            case Some(typecast: DataType) ⇒ input.cast(typecast)
            case Some(format: String) ⇒ to_timestamp(input, format)
            // Parse with the current format, then re-render with the new one.
            case Some(List(currentFormat: String, newFormat: String)) ⇒
              date_format(to_timestamp(input, currentFormat), newFormat)
            case _ ⇒ input // TODO: support other types
          }
        case Some(lenValue) ⇒
          formatInfo match {
            // Precision/scale come from len; more than two entries falls back to a double cast.
            case Some(x: DecimalType) ⇒
              if (lenValue.length == 2) {
                input.cast(DecimalType(lenValue(0), lenValue(1)))
              } else if (lenValue.length == 1) {
                input.cast(DecimalType(lenValue(0), 0))
              } else {
                input.cast(DoubleType)
              }
            // len.head is a byte width (1/2/4/8) selecting Byte/Short/Integer/Long; other widths keep IntegerType.
            case Some(x: IntegerType) ⇒ input.cast(getDataTypeFromIntLength(lenValue.head, x))
            // Strings longer than len.head are treated as invalid.
            case Some(y: StringType) ⇒
              when(length(input.cast(y)) > lit(lenValue.head), lit(null))
                .otherwise(input.cast(y))
            // Doubles must fit DecimalType(len.head, 0), i.e. at most len.head integer digits.
            case Some(z: DoubleType) ⇒
              when(input.cast(DecimalType(lenValue.head, 0)).isNull, lit(null))
                .otherwise(
                  input.cast(DecimalType(lenValue.head, 0)).cast(StringType).cast(z)
                )
            case _ ⇒ input // TODO: support other types
          }
      }
      when(result.isNull, lit(false)).otherwise(lit(true))
    }
  }

  /**
    * Returns the current timestamp rendered as a `yyyyMMddHHmmssSSSSSS` string.
    *
    * @return string column with the formatted current timestamp.
    */
  def now(): Column = {
    val pattern = "yyyyMMddHHmmssSSSSSS"
    date_format(current_timestamp(), pattern)
  }

  /**
    * Returns the (fractional) number of hours from `start` to `end`, computed from whole-second
    * `unix_timestamp` values.
    *
    * @param end   later timestamp column
    * @param start earlier timestamp column
    * @return double column with the elapsed hours (negative if `end` precedes `start`)
    */
  def datetime_difference_hours(end: Column, start: Column): Column = {
    val elapsedSeconds = unix_timestamp(end) - unix_timestamp(start)
    elapsedSeconds / lit(3600.0)
  }

  /**
    * UDF returning `timestamp1 - timestamp2` as a struct of days, hours, minutes, seconds and microseconds.
    * Each component carries the sign of the overall difference.
    *
    * Inputs are expected as `yyyy-MM-dd HH:mm:ss`, optionally followed by a fractional second of up to 9 digits
    * (e.g. `2020-01-01 10:00:00.1234`). Both are interpreted in the JVM default time zone.
    *
    * The fraction is read as a decimal fraction of a second. The previous SimpleDateFormat "SSSS" pattern read it as
    * a count of milliseconds, so ".5000" became 5 seconds.
    */
  val datetime_difference =
    udf(
      { (timestamp1: String, timestamp2: String) ⇒
        // Built per call: DateTimeFormatter is not Serializable, so it must not be captured by the UDF closure.
        val formatter = new DateTimeFormatterBuilder()
          .appendPattern("yyyy-MM-dd HH:mm:ss")
          .optionalStart()
          .appendFraction(ChronoField.NANO_OF_SECOND, 0, 9, true)
          .optionalEnd()
          .toFormatter()
        def toInstant(text: String) =
          LocalDateTime.parse(text, formatter).atZone(java.time.ZoneId.systemDefault()).toInstant

        val totalMicros =
          java.time.temporal.ChronoUnit.MICROS.between(toInstant(timestamp2), toInstant(timestamp1))
        val microsPerSecond = 1000000L
        val microsPerMinute = 60L * microsPerSecond
        val microsPerHour   = 60L * microsPerMinute
        val microsPerDay    = 24L * microsPerHour

        val days         = totalMicros / microsPerDay
        val hours        = (totalMicros % microsPerDay) / microsPerHour
        val minutes      = (totalMicros % microsPerHour) / microsPerMinute
        val seconds      = (totalMicros % microsPerMinute) / microsPerSecond
        val microSeconds = totalMicros % microsPerSecond
        Row(days, hours, minutes, seconds, microSeconds)
      },
      StructType(
        List(
          StructField("days",         LongType, true),
          StructField("hours",        LongType, true),
          StructField("minutes",      LongType, true),
          StructField("seconds",      LongType, true),
          StructField("microseconds", LongType, true)
        )
      )
    )

  /** Returns -1, 0 or 1 as `input1` sorts before, equal to or after `input2`; null if either input is null. */
  def string_compare(input1: Column, input2: Column): Column = {
    val ordering = when(input1 < input2, lit(-1)).otherwise(when(input1 > input2, lit(1)).otherwise(lit(0)))
    when(input1.isNull.or(input2.isNull), lit(null)).otherwise(ordering)
  }

  /**
    * Converts `time`, interpreted in `timezone`, to UTC via Spark's `to_utc_timestamp`.
    *
    * @param timezone zone id, or "local" for the JVM default time zone
    * @param time     timestamp column to convert
    * @return the UTC timestamp column
    */
  def timezone_to_utc(timezone: String, time: Column): Column = {
    val zoneId = if (timezone == "local") TimeZone.getDefault.toZoneId.getId else timezone
    to_utc_timestamp(time, zoneId)
  }

  /**
    * Returns the number of days from 1990-01-01 to the current date.
    *
    * NOTE(review): confirm that 1990 (rather than 1900) is the intended epoch.
    *
    * @return integer column
    */
  def today(): Column = {
    val epoch = to_date(lit("1990-01-01"), "yyyy-MM-dd")
    datediff(current_date(), epoch)
  }

  /**
    * Returns true when `input` is null or empty after Spark's `trim`, false otherwise.
    *
    * @param input string column to test
    * @return boolean column
    */
  def isNullOrEmpty(input: Column): Column = {
    val nullOrBlank = input.isNull || length(trim(input)) === lit(0)
    when(nullOrBlank, lit(true)).otherwise(lit(false))
  }

  /**
    * Builds an array column of integer literals from `start` to `end` (inclusive), advancing by `step`.
    *
    * @param start first value of the generated sequence
    * @param end   inclusive terminating point (an upper bound for positive steps, a lower bound for negative ones)
    * @param step  increment between consecutive values; must not be 0 (defaults to 1). It was previously ignored.
    * @return column containing the sequence of integers.
    */
  def generate_sequence(start: Int, end: Int, step: Int = 1): Column =
    array((start to end by step).map(lit): _*)

  /**
    * Creates an array column of `size` entries, each equal to `seedVal`.
    *
    * @param size    number of entries (non-positive yields an empty array)
    * @param seedVal column repeated in every entry
    * @return array column
    */
  def make_constant_vector(size: Int, seedVal: Column): Column =
    array(Seq.fill(size)(seedVal): _*)

  /**
    * Creates an array of `size` entries, each equal to `seedVal`.
    *
    * @param size    number of entries (non-positive yields an empty array)
    * @param seedVal value repeated in every entry
    * @return the filled array
    */
  def make_constant_vector(size: Int, seedVal: Int): Array[Int] =
    Array.fill(size)(seedVal)

  /**
    * UDF to break the input string into `width` strings on every literal occurrence of `delimiter`. When fewer
    * pieces are produced, the result is padded with empty strings; extra pieces beyond `width` are discarded.
    *
    * The delimiter is quoted before it reaches `String.split`, so regex metacharacters such as `|` or `.` are
    * matched literally.
    */
  val splitIntoMultipleColumnsUdf =
    udf { (input: String, delimiter: String, width: Int) ⇒
      val entries = input.split(java.util.regex.Pattern.quote(delimiter))
      Array.tabulate(width)(index ⇒ if (index < entries.length) entries(index) else "")
    }

  /**
    * Creates a single-column DataFrame named `columnName` containing the integers `start` to `end` (inclusive),
    * one per row.
    *
    * @param start        first value
    * @param end          last value (inclusive)
    * @param columnName   name of the produced column
    * @param sparkSession session used to build the DataFrame
    * @return DataFrame with one integer row per sequence value
    */
  def generateDataFrameWithSequenceColumn(
    start:        Int,
    end:          Int,
    columnName:   String,
    sparkSession: SparkSession
  ): DataFrame = {
    import sparkSession.implicits._

    val sequence: Array[Int] = (start to end).toArray
    Seq(sequence)
      .toDF(columnName)
      .select(explode(col(columnName)).as(columnName))
  }

  /**
    * Sequential reader over a string. Each `read_string(len)` call returns the next `len` characters and advances
    * the cursor. It returns null without advancing when the cursor is already at the end or fewer than `len`
    * characters remain.
    */
  class StringAsStream(val content: String) extends Serializable {
    var currentOffset: Int = 0

    def read_string(len: Int): String = {
      val end       = currentOffset + len
      val exhausted = currentOffset >= content.length || end > content.length
      if (exhausted) null
      else {
        val chunk = content.substring(currentOffset, end)
        currentOffset = end
        chunk
      }
    }
  }

  /** Wraps `content` in a fresh `StringAsStream` positioned at its start. */
  def getContentAsStream(content: String): StringAsStream = {
    val stream = new StringAsStream(content)
    stream
  }

  /**
    * Creates a single-column DataFrame from the values in `inputData` separated by `delimiter` (a regular
    * expression, as used by `String.split`). The column is named `columnName` and typed according to `columnType`.
    *
    * Each value is converted to the declared type before the Row is built. For "IntegerType", "DoubleType" and
    * "BooleanType" the value is trimmed and parsed, so an unparsable value throws; any other type keeps the raw
    * string. Previously every Row held a String, which failed at runtime for non-string schemas.
    *
    * @param inputData    delimited values
    * @param delimiter    separator regex passed to `String.split`
    * @param columnName   name of the produced column
    * @param columnType   "StringType", "IntegerType", "DoubleType" or "BooleanType" (anything else → string)
    * @param sparkSession session used to build the DataFrame
    * @return DataFrame with one row per value
    */
  def createDataFrameFromData(
    inputData:    String,
    delimiter:    String,
    columnName:   String,
    columnType:   String,
    sparkSession: SparkSession
  ): DataFrame = {
    val dataType = getColumnDataType(columnType)
    val schema   = StructType(List(StructField(columnName, dataType)))

    // Row values must match the declared column type, otherwise Spark fails when encoding the rows.
    val toValue: String ⇒ Any = dataType match {
      case IntegerType ⇒ _.trim.toInt
      case DoubleType  ⇒ _.trim.toDouble
      case BooleanType ⇒ _.trim.toBoolean
      case _           ⇒ (raw: String) ⇒ raw
    }

    val rowData = inputData.split(delimiter).map(value ⇒ Row(toValue(value))).toList
    sparkSession.createDataFrame(rowData.asJava, schema)
  }

  /**
    * Maps a Spark type name to its DataType. "IntegerType", "DoubleType" and "BooleanType" are recognised;
    * "StringType" and any unrecognised name map to StringType.
    *
    * @param columnType type name
    * @return the corresponding Spark DataType
    */
  private def getColumnDataType(columnType: String): DataType = columnType match {
    case "IntegerType" ⇒ IntegerType
    case "DoubleType"  ⇒ DoubleType
    case "BooleanType" ⇒ BooleanType
    case _             ⇒ StringType
  }

  /**
    * Converts an arbitrary value to bytes:
    *  - String: its windows-1252 encoding
    *  - Long/Int/Short/Byte: minimal two's-complement bytes (`BigInt.toByteArray`)
    *  - non-empty GenericRowWithSchema whose first field is an array of longs: the concatenated bytes of the
    *    non-zero longs
    *  - anything else (including null): Jackson JSON serialization
    *
    * @param input value to convert
    * @return the byte representation
    */
  private def getByteArray(input: Any): Array[Byte] = {
    // JSON fallback, evaluated only on the branches that need it.
    def jsonBytes: Array[Byte] = new ObjectMapper().writeValueAsBytes(input)

    input match {
      case text: String ⇒ text.getBytes("windows-1252")
      case n: Long      ⇒ BigInt(n).toByteArray
      case n: Int       ⇒ BigInt(n).toByteArray
      case n: Short     ⇒ BigInt(n).toByteArray
      case n: Byte      ⇒ BigInt(n).toByteArray
      case row: GenericRowWithSchema if row.size > 0 ⇒
        row.get(0) match {
          case LongWrappedArray(longs) ⇒ longs.filterNot(_ == 0).flatMap(x ⇒ BigInt(x).toByteArray).toArray
          case _                       ⇒ jsonBytes
        }
      case _ ⇒ jsonBytes
    }
  }

  /**
    * Tests whether an object is composed of all binary zero bytes.
    * This function returns:
    * 1. true if obj contains only binary zero bytes (a string of only NUL characters, a zero-filled byte array,
    *    or a zero integral number) or is a zero-length string
    * 2. false if obj contains any non-zero bytes (or is of an unsupported type)
    * 3. NULL if obj is NULL
    */
  val is_bzero = udf { (input: Any) ⇒
    input match {
      case null               ⇒ None
      case text: String       ⇒ Some(text.forall(_ == 0))
      case bytes: Array[Byte] ⇒ Some(bytes.forall(_ == 0))
      case n: Int             ⇒ Some(n == 0)
      case n: Long            ⇒ Some(n == 0L)
      case n: Short           ⇒ Some(n == 0)
      case n: Byte            ⇒ Some(n == 0)
      case _                  ⇒ Some(false)
    }
  }

  /**
    * UDF returning the last byte of `input`'s byte representation (see `getByteArray`) after skipping its first
    * `offset` bytes.
    */
  val getByteFromByteArray = udf { (input: Any, offset: Int) ⇒
    val bytes = getByteArray(input).drop(offset)
    bytes(bytes.length - 1)
  }

  /**
    * UDF returning the signed short formed by the last 2 bytes (big-endian, via `BigInteger`) of `input`'s byte
    * representation after skipping its first `offset` bytes.
    */
  val getShortFromByteArray = udf { (input: Any, offset: Int) ⇒
    val tail = getByteArray(input).drop(offset).takeRight(2)
    new BigInteger(tail).shortValue()
  }

  /**
    * UDF returning the signed int formed by the last 4 bytes (big-endian, via `BigInteger`) of `input`'s byte
    * representation after skipping its first `offset` bytes.
    */
  val getIntFromByteArray = udf { (input: Any, offset: Int) ⇒
    val tail = getByteArray(input).drop(offset).takeRight(4)
    new BigInteger(tail).intValue()
  }
  /** UDF encoding `input` in application/x-www-form-urlencoded format using UTF-8 (`URLEncoder.encode`). */
  val url_encode_escapes = udf((input: String) ⇒ URLEncoder.encode(input, "UTF-8"))

  // TODO support for threshold!
  /** UDF that always fails by throwing an Exception whose message is `input`. */
  val force_error = udf { (input: String) ⇒
    val message = s"$input"
    throw new Exception(message)
  }

  /**
    * UDF returning the signed long formed by the last 8 bytes (big-endian, via `BigInteger`) of `input`'s byte
    * representation after skipping its first `offset` bytes.
    */
  val getLongFromByteArray = udf { (input: Any, offset: Int) ⇒
    val tail = getByteArray(input).drop(offset).takeRight(8)
    new BigInteger(tail).longValue()
  }

  /**
    * Returns true when no character of `input` fails `Character.isDigit` (so the empty string yields true).
    *
    * @param input string to test
    * @return whether every character is a digit
    */
  private def isAllDigits(input: String): Boolean =
    !input.exists(c ⇒ !Character.isDigit(c))

  /** Extractor matching a `mutable.WrappedArray` whose elements are all (boxed) `Long` values; empty arrays match. */
  object LongWrappedArray {
    def unapply(input: Any): Option[mutable.WrappedArray[Long]] = input match {
      case wrapped: mutable.WrappedArray[_] if wrapped.forall {
            case _: Long ⇒ true
            case _       ⇒ false
          } ⇒
        Some(wrapped.asInstanceOf[mutable.WrappedArray[Long]])
      case _ ⇒ None
    }
  }

  /** Extractor matching any `Seq` whose elements are all (boxed) `Long` values; empty sequences match. */
  object LongSequence {
    def unapply(input: Any): Option[Seq[Long]] = input match {
      case items: Seq[_] if items.forall {
            case _: Long ⇒ true
            case _       ⇒ false
          } ⇒
        Some(items.asInstanceOf[Seq[Long]])
      case _ ⇒ None
    }
  }

  /**
    * Method used for abinitio's reinterpret_as function to read necessary bytes from byteArray for input data and convert
    * into struct format as per provided in typeInfo sequence.
    *
    * TypeInfo can have multiple entries, each could be either decimal or string type. Depending on the argument passed
    * within decimal or string bytes are read from input byte array.
    *
    * If decimal or string argument has some integer then that many bytes are read from input byte array or if decimal or
    * string has some string delimiter as its argument then from the current position bytes are read until string delimiter
    * is found in input byte array.
    *
    * Entries equal to "long" read the next 8 bytes as a signed big-endian long (0 once the data is exhausted); any
    * other entry is skipped without adding a value. If `input` already is a sequence of longs it is returned as a Row
    * directly and `typeInfo` is ignored.
    *
    * @param input     value to reinterpret; converted to bytes with `getByteArray` unless it is a sequence of longs
    * @param typeInfo  field descriptors such as `decimal(5)`, a delimited `string(...)` whose argument's first and last
    *                  characters (presumably quotes) are stripped to get the delimiter, or `long`
    * @param startByte number of leading bytes to skip before reading
    * @return a Row with one value per decimal/string/long entry, in order (decimal/string values are Strings)
    */
  def convertInputBytesToStructType(input: Any, typeInfo: Seq[String], startByte: Int = 0): Row = input match {
    case LongSequence(sequence) ⇒
      Row.fromSeq(sequence)
    case _ ⇒
      var curPointer = 0
      var byteArray  = getByteArray(input)
      byteArray = byteArray.slice(startByte, byteArray.length)
      // windows-1252 decodes each byte to exactly one char, so curPointer indexes stringVal and byteArray alike.
      val stringVal = new String(byteArray, "windows-1252")
      val rowValues = ListBuffer[Any]()
      // `map` is used only for its side effects on rowValues and curPointer; its result is discarded.
      typeInfo.map { curType ⇒
        if (curType.startsWith("decimal") || curType.startsWith("string")) {
          // Argument between the parentheses: either a character count or a quoted delimiter.
          val stInd  = curType.indexOf("(")
          val endInd = curType.indexOf(")")
          val arg    = curType.substring(stInd + 1, endInd)
          if (isAllDigits(arg)) {
            // Fixed width: take up to `arg` characters, clamped to the end of the data.
            val takeUntil = Math.min(Integer.parseInt(arg) + curPointer, stringVal.length)
            rowValues += stringVal.substring(curPointer, takeUntil)
            curPointer = takeUntil
          } else {
            // Delimited: strip the surrounding characters of the argument, read up to the delimiter, skip past it.
            // NOTE(review): if the delimiter is absent, index is -1 and substring(curPointer, curPointer - 1) throws.
            val trimmedArg = arg.trim
            val pattern    = trimmedArg.substring(1, trimmedArg.length - 1)
            val index      = stringVal.substring(curPointer).indexOf(pattern)
            rowValues += stringVal.substring(curPointer, index + curPointer)
            curPointer = curPointer + index + pattern.length
          }
        } else if (curType == "long") {
          if (curPointer >= byteArray.length) {
            rowValues += 0L
          } else {
            // Up to 8 bytes as a signed big-endian value; a shorter tail is sign-extended by BigInteger.
            val slicedByteArray = byteArray.slice(curPointer, curPointer + 8)
            rowValues += new BigInteger(slicedByteArray).longValue()
            curPointer = curPointer + 8
          }
        }
      }
      Row.fromSeq(rowValues)
  }

  /**
    * UDF reading `size` consecutive 8-byte chunks from `input`'s byte representation (see `getByteArray`), starting
    * after its first `offset` bytes. Each chunk is decoded as a signed big-endian long via `BigInteger`. Positions
    * past the end of the data yield 0, and a trailing chunk shorter than 8 bytes is sign-extended.
    */
  val getLongArrayFromByteArray = udf(
    (input: Any, size: Int, offset: Int) ⇒ {
      val bytes = getByteArray(input).drop(offset)
      (0 until size).map { i ⇒
        val from = i * 8
        if (from >= bytes.length) 0L
        else new BigInteger(bytes.slice(from, from + 8)).longValue()
      }
    },
    DataTypes.createArrayType(LongType, false)
  )

  /**
    * UDF computing a 128-bit Murmur3 hash (greenrobot `Murmur3F`) of any value, returned as
    * `Array(getValue, getValueHigh)`.
    *
    * Rows and Seqs are hashed over the concatenated `toString` bytes (JVM default charset) of their elements. Any
    * other value is hashed over its own `toString` bytes. A null value or null element throws a
    * NullPointerException.
    */
  val murmur = udf { (x: Any) ⇒
    val bytes: Array[Byte] = x match {
      case row: GenericRowWithSchema ⇒ (0 until row.length).flatMap(i ⇒ row.get(i).toString.getBytes()).toArray
      case items: Seq[_]             ⇒ items.flatMap(_.toString.getBytes()).toArray
      case other                     ⇒ other.toString.getBytes()
    }
    val hasher = new Murmur3F()
    hasher.update(bytes)
    Array(hasher.getValue, hasher.getValueHigh)
  }

  /**
    * UDF that resolves a dot-separated field path against a struct value and returns the result as a string.
    *
    * Only plain field references are supported (e.g. `"a"`, `"a.b"` or `"a.b.c"`); arbitrary expressions are not
    * evaluated. Each path segment is looked up by name on the current struct, descending one level per segment.
    * An unknown field name makes `Row.getAs` (via `fieldIndex`) throw.
    *
    * @param input struct (Row) value to read from
    * @param expr  dot-separated field path inside `input`
    * @return `toString` of the referenced value. Returns null when:
    *          - `input` or `expr` is null;
    *          - `input` is not a struct;
    *          - the path is empty;
    *          - a value along the path is null;
    *          - a value along the path is not a struct but a nested field is requested from it.
    */
  val eval = udf((input: Any, expr: String) ⇒ {
    if (input == null || expr == null || !input.isInstanceOf[Row]) null
    else {
      val path = expr.split("\\.")
      if (path.isEmpty) null
      else {
        val resolved = path.foldLeft(Option(input)) {
          case (Some(struct: Row), field) ⇒ Option(struct.getAs[Any](field))
          // a null or non-struct value cannot be descended into
          case _ ⇒ None
        }
        resolved.map(_.toString).orNull
      }
    }
  })

  /**
    * UDF that strips the fractional-seconds digits from a timestamp string. This is needed because Ab Initio and
    * Spark have some incompatibility in the sub-second part of timestamp formats.
    *
    * The cut point is the position of the first `S` in `sourceFormat`. The timestamp is truncated to that many
    * characters, which assumes it is laid out character-for-character like the format.
    * - Any separator just before the `S` run (e.g. `.`) is kept.
    * - If the format has no `S`, the timestamp is returned unchanged.
    * - If the timestamp is shorter than the cut point, it is returned whole.
    *
    * @return the truncated timestamp, or null if either argument is null
    */
  val truncateMicroSeconds = udf((sourceFormat: String, timestamp: String) ⇒ {
    if (timestamp == null || sourceFormat == null) null
    else {
      val index = sourceFormat.indexOf("S")
      // `take` (unlike `substring`) tolerates a timestamp shorter than the format prefix instead of throwing
      if (index < 0) timestamp else timestamp.take(index)
    }
  })

  /**
    * Turns values whose string representation is empty into null; all other values pass through unchanged.
    *
    * @param input column to clean (cast to string only for the emptiness check)
    * @return null where `input` renders as "", otherwise `input` itself
    */
  def replaceBlankColumnWithNull(input: Column): Column = {
    val rendersEmpty = length(input.cast(StringType)) === lit(0)
    when(rendersEmpty, lit(null)).otherwise(input)
  }

  /**
    * Picks `expr1` when it holds a value, falling back to `expr2` otherwise.
    *
    * @param expr1 preferred expression
    * @param expr2 fallback used when `expr1` is null
    * @return the first non-null of the two (null when both are null)
    */
  def first_defined(expr1: Column, expr2: Column): Column =
    when(expr1.isNotNull, expr1).otherwise(expr2)

  /**
    * Makes the sign of a number explicit, e.g. 2 -> +2; -004 -> -004; 0 -> +0.
    *
    * Values already beginning with "+" or "-" are returned untouched. Values that compare `>= 0` get a "+" prefix.
    * NOTE(review): anything else (e.g. a non-numeric string whose comparison is not true) gets a "-" prefix —
    * confirm this is intended.
    *
    * @param c numeric or numeric-string column
    * @return the value with an explicit leading sign
    */
  def sign_explicit(c: Column): Column = {
    val alreadySigned = c.startsWith("+") || c.startsWith("-")
    val prefixed      = when(c >= 0, concat(lit("+"), c)).otherwise(concat(lit("-"), c))
    when(alreadySigned, c).otherwise(prefixed)
  }

  /**
    * Splits a separated-values string into a struct column named `value`, with one field per entry of `schema`
    * matched by position.
    *
    * `separator` is passed to Spark's `split`, which treats it as a Java regular expression. Metacharacters such as
    * `|` must therefore be escaped by the caller. Only the field names of `schema` are used: every field holds the
    * raw string piece, and the schema's data types are not applied. Pieces missing at the end presumably come out
    * as null (array index out of range) — confirm under ANSI mode.
    *
    * @param input     string column to split
    * @param separator regex separating the values
    * @param schema    supplies the output field names, in order
    * @return struct column aliased as `value`
    */
  def from_sv(input: Column, separator: String, schema: StructType): Column = {
    val pieces = split(input, separator)
    val namedPieces = schema.fields.indices.map { position ⇒
      pieces.getItem(position).as(schema.fields(position).name)
    }
    struct(namedPieces: _*).as("value")
  }

  /**
    * Strips every character that is not a digit from a string column.
    *
    * @param input input String Column
    * @return the digits of `input` in their original order, or null when `input` is null
    */
  def remove_non_digit(input: Column): Column = {
    val digitsOnly = regexp_replace(input, "[^\\d]", "")
    when(input.isNotNull, digitsOnly)
  }

  /**
    * Number of days from `earlierDate` to `laterDate`, both given as "yyyyMMdd" strings.
    *
    * @param laterDate   end date in "yyyyMMdd"
    * @param earlierDate start date in "yyyyMMdd"
    * @return `laterDate - earlierDate` in days (negative if the "later" date is actually earlier); null when either
    *         input is null or does not parse
    */
  def date_difference_days(laterDate: Column, earlierDate: Column): Column = {
    val pattern = "yyyyMMdd"
    val end     = to_date(laterDate,   pattern)
    val start   = to_date(earlierDate, pattern)
    when(laterDate.isNotNull && earlierDate.isNotNull, datediff(end, start))
  }

  /**
    * Tests whether a string consists solely of 7-bit ASCII characters (code points 0x00-0x7F).
    *
    * Non-ASCII characters are stripped out. If that shortens the string, it contained at least one of them.
    *
    * @param input column to be checked
    * @return true for all-ASCII strings, false otherwise, null when `input` is null
    */
  def is_ascii(input: Column): Column = {
    val nonAsciiChar = "[^\\x00-\\x7f]"
    val asciiPart    = regexp_replace(input, nonAsciiChar, "")
    val lostChars    = length(asciiPart) < length(input)
    when(input.isNotNull, when(lostChars, lit(false)).otherwise(lit(true)))
  }

  /**
    * Checks that a string is non-empty, all-ASCII (per `is_ascii`) and numeric (per `string_is_numeric`).
    *
    * @param input string to be checked
    * @return true when all three checks hold, false otherwise, null when `input` is null
    */
  def is_numeric_ascii(input: Column): Column = {
    val asciiCheck   = is_ascii(input) === lit(true)
    val numericCheck = string_is_numeric(input) === lit(true)
    val nonEmpty     = length(input) > lit(0)
    val allPass      = asciiCheck && numericCheck && nonEmpty
    when(input.isNotNull, when(allPass, lit(true)).otherwise(lit(false)))
  }

  /**
    * Tells whether `inDate` can be parsed as a date with the given pattern.
    *
    * @param dateFormat A pattern such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` or `dd.MM.yyyy`
    * @param inDate     Input date to be validated
    * @return true when `to_date(inDate, dateFormat)` yields a value; false when it yields null (including null input)
    */
  def is_valid_date(dateFormat: String, inDate: Column): Column = {
    val parsed = to_date(inDate, dateFormat)
    when(parsed.isNotNull, lit(true)).otherwise(lit(false))
  }

  /**
    * Expands a Julian date whose year is a single digit ("YJJJ") to a four-digit-year Julian date ("YYYYJJJ"),
    * borrowing the century and decade from a reference date.
    *
    * The year is built from the reference year, with its last digit replaced by the input's year digit. The
    * replacement is skipped when the reference year already ends in 0. NOTE(review): that "0" rule looks
    * deliberate but is undocumented — confirm.
    * If the assembled value does not parse as "yyyyDDD", the reference date's own Julian form is returned instead.
    *
    * @param in_date  date in Julian in "YJJJ" format
    * @param ref_date date in "yyyyMMdd" format
    * @return a date in "YYYYJJJ", or null when `ref_date` does not parse as "yyyyMMdd"
    */
  def YJJJ_to_YYYYJJJ(in_date: Column, ref_date: Column): Column = {
    val inYearDigit  = string_substring(in_date,  lit(1), lit(1))
    val refYearDigit = string_substring(ref_date, lit(4), lit(1))
    val refYear      = string_substring(ref_date, lit(1), lit(4))

    val swapDigit = inYearDigit.notEqual(refYearDigit) && refYearDigit.notEqual(lit(0))
    val fullYear = when(swapDigit, concat(string_substring(refYear, lit(1), lit(3)), inYearDigit))
      .otherwise(refYear)

    val candidate = concat(fullYear, string_substring(in_date, lit(2), lit(3)))
    val julian = when(to_date(candidate, "yyyyDDD").isNull, yyyyMMdd_to_YYYYJJJ(ref_date)).otherwise(candidate)

    when(to_date(ref_date, "yyyyMMdd").isNull, lit(null)).otherwise(julian)
  }

  /**
    * Converts a "yyyyMMdd" date string to Julian "YYYYJJJ": the four-digit year followed by the zero-padded,
    * three-digit day of the year.
    *
    * @param in_date date in yyyyMMdd format
    * @return the "YYYYJJJ" string, or null when `in_date` does not parse as "yyyyMMdd"
    */
  def yyyyMMdd_to_YYYYJJJ(in_date: Column): Column = {
    val parsed    = to_date(in_date, "yyyyMMdd")
    val year      = string_substring(in_date, lit(1), lit(4))
    val dayOfYear = string_lrepad(date_format(parsed, "D"), 3, "0")
    when(parsed.isNotNull, concat(year, dayOfYear))
  }

  /**
    * Length of February for the given year under Gregorian leap-year rules: 29 days when the year is divisible by
    * 400, or by 4 but not by 100; 28 days otherwise.
    *
    * @param year numeric year column
    * @return 29 or 28 (28 when `year` is null)
    */
  def getFebruaryDay(year: Column): Column = {
    def divisibleBy(n: Int): Column = year % n === lit(0)
    val isLeapYear = divisibleBy(400) || (divisibleBy(4) && (year % 100).notEqual(lit(0)))
    when(isLeapYear, lit(29)).otherwise(lit(28))
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy