/*
* ====================================================================
*
* PROPHECY CONFIDENTIAL
*
* Prophecy Inc
* All Rights Reserved.
*
* NOTICE: All information contained herein is, and remains
* the property of Prophecy Inc, the intellectual and technical concepts contained
* herein are proprietary to Prophecy Inc and may be covered by U.S. and Foreign Patents,
* patents in process, and are protected by trade secret or copyright law.
* Dissemination of this information or reproduction of this material
* is strictly forbidden unless prior written permission is obtained
* from Prophecy Inc.
*
* ====================================================================
*/
package io.prophecy.libs
import com.fasterxml.jackson.databind.ObjectMapper
import de.greenrobot.common.hash.Murmur3F
import io.prophecy.abinitio.ScalaFunctions._
import org.apache.commons.lang.StringEscapeUtils
import org.json.XML
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.crosssupport.SparkFunctionsCrossSupport
import org.apache.spark.sql.functions.{lit, _}
import org.apache.spark.sql.types._
import org.joda.time.format.DateTimeFormat
import java.io.File
import java.math.BigInteger
import java.net.URLEncoder
import java.nio.{ByteBuffer, ByteOrder}
import java.security.MessageDigest
import java.sql.Date
import java.text.{DecimalFormat, DecimalFormatSymbols, SimpleDateFormat}
import java.time.format.{DateTimeFormatter, DateTimeFormatterBuilder}
import java.time.temporal.ChronoField
import java.time.{LocalDate, LocalDateTime, ZonedDateTime}
import java.util.TimeZone
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, ListBuffer}
object SparkFunctions {}
/**
 * Library of Spark functions that implement various Ab Initio functions used in Ab Initio workflows.
*/
trait SparkFunctions {
def registerAllUDFs(spark: SparkSession): Unit = {
spark.udf.register("string_pad", string_pad)
spark.udf.register("string_index_with_offset", string_index_with_offset)
spark.udf.register("string_filter", string_filter)
spark.udf.register("generate_sequence", generate_sequence)
spark.udf.register("re_get_match_with_index", re_get_match_with_index)
spark.udf.register("string_filter_out", string_filter_out)
spark.udf.register("number_grouping", number_grouping)
spark.udf.register("re_get_match", re_get_match)
spark.udf.register("re_index_with_offset", re_index_with_offset)
spark.udf.register("file_information", file_information)
spark.udf.register("re_index", re_index)
spark.udf.register("string_split_no_empty", string_split_no_empty)
spark.udf.register("string_split", string_split)
spark.udf.register("type_info", type_info)
spark.udf.register("record_info", record_info)
spark.udf.register("type_info_with_includes", type_info_with_includes)
spark.udf.register("record_info_with_includes", record_info_with_includes)
spark.udf.register("read_file", read_file)
spark.udf.register("unique_identifier", unique_identifier)
spark.udf.register("hash_MD5", hash_MD5)
spark.udf.register("string_like", string_like)
spark.udf.register("re_split_no_empty", re_split_no_empty)
spark.udf.register("string_index", string_index)
spark.udf.register("string_rindex", string_rindex)
spark.udf.register("string_rindex_with_offset", string_rindex_with_offset)
spark.udf.register("make_byte_flags", make_byte_flags)
spark.udf.register("translate_bytes", translate_bytes)
spark.udf.register("test_characters_all", test_characters_all)
spark.udf.register("string_convert_explicit", string_convert_explicit)
spark.udf.register("cross_join_index_range", cross_join_index_range)
spark.udf.register("getByteFromByteArray", getByteFromByteArray)
spark.udf.register("getShortFromByteArray", getShortFromByteArray)
spark.udf.register("getIntFromByteArray", getIntFromByteArray)
spark.udf.register("url_encode_escapes", url_encode_escapes)
spark.udf.register("force_error", force_error)
spark.udf.register("getLongFromByteArray", getLongFromByteArray)
spark.udf.register("getLongArrayFromByteArray", getLongArrayFromByteArray)
spark.udf.register("murmur", murmur)
spark.udf.register("eval", eval)
spark.udf.register("truncateMicroSeconds", truncateMicroSeconds)
spark.udf.register("string_cleanse", string_cleanse)
spark.udf.register("canonical_representation", canonical_representation)
spark.udf.register("string_representation", string_representation)
spark.udf.register("datetime_difference", datetime_difference)
spark.udf.register("encode_date", encode_date)
spark.udf.register("date_month_end", date_month_end)
spark.udf.register("is_bzero", is_bzero)
spark.udf.register("readBytesIntoLong", readBytesIntoLong)
spark.udf.register("readBytesIntoInteger", readBytesIntoInteger)
spark.udf.register("writeLongToBytes", writeLongToBytes)
spark.udf.register("decodeString", decodeString)
spark.udf.register("encodeString", encodeString)
spark.udf.register("writeIntegerToBytes", writeIntegerToBytes)
spark.udf.register(
"starts_with",
udf((input: String, prefix: String) ⇒ if (input == null || prefix == null) null else input.startsWith(prefix),
BooleanType
)
)
spark.udf.register(
"ends_with",
udf((input: String, suffix: String) ⇒ if (input == null || suffix == null) null else input.endsWith(suffix),
BooleanType
)
)
spark.udf.register("is_blank", udf((input: String) ⇒ if (input == null) null else input.trim.isEmpty, BooleanType))
spark.udf.register("decimal_lpad",
udf { (input: Any, len: Int) ⇒
_decimal_lpad(input, len)
}
)
spark.udf.register("decimal_truncate",
udf { (input: Double, number_of_places: Int) ⇒
(input * Math.pow(10, number_of_places)).asInstanceOf[Long] / Math.pow(10, number_of_places)
}
)
// spark.udf.register("re_match_replace")
// spark.udf.register("vector_sort_dedup_first", )
// spark.udf.register("string_substring", (input: Column, start_position: Column, length: Column) ⇒ input.substr(start_position, length))
// import org.apache.hadoop.hive.ql.exec.UDF;
// class SimpleUdf extends UDF {
// def string_substring(input: Column, start_position: Column, length: Column): Column =
// input.substr(start_position, length)
// }
// spark.sql("CREATE FUNCTION simple_udf AS 'SimpleUdf'")
}
/**
* Method to find substring of input string.
*
* @param input string on which to find substring.
* @param start_position 1 based starting position to find substring from.
* @param length total length of substring to be found.
* @return substring of input string
*/
def string_substring(input: Column, start_position: Column, length: Column): Column =
input.substr(start_position, length)
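  // Usage sketch, assuming a DataFrame `df` with a string column "name" (illustrative names only):
  //   df.withColumn("first3", string_substring(col("name"), lit(1), lit(3)))
  //   // "Prophecy" -> "Pro"; start_position is 1-based, as in Ab Initio.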
def string_prefix(input: Column, length: Column): Column =
when(length < lit(0), lit("")).otherwise(input.substr(lit(1), length))
/**
 * Function trims the string and then left-pads it with the given character up to the given length.
 * If the length of the trimmed string is equal to or greater than the given length, the input
 * string is returned unmodified.
*
* @param input input string
* @param len length in number of characters.
* @param char_to_pad_with A character used to pad input string to length len.
* @return string of a specified length, trimmed of leading and trailing blanks and left-padded with a given character.
*/
def string_lrepad(input: Column, len: Int, char_to_pad_with: String = " "): Column = {
when(isnull(input), lit(null)).otherwise(
when(length(trim(input)) >= len, input)
.otherwise(lpad(trim(input), len, char_to_pad_with))
)
}
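  // Usage sketch, assuming a DataFrame `df` with a string column "acct" (illustrative names only):
  //   df.withColumn("acct10", string_lrepad(col("acct"), 10, "0"))
  //   // " 42 " -> "0000000042"; values whose trimmed length is already >= 10 are returned unchanged.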
/**
* Returns a number rounded up to a specified number of places to the right of the decimal point.
*
* @param input
* @param places
* @return
*/
def decimal_round_up(input: Column, places: Int): Column =
ceil(input * pow(lit(10), places)) / pow(lit(10), places)
/**
* Function returns a value which is rounded down to right_digits number of digits to the right of decimal point.
*
* @param input
* @param right_digits
* @return
*/
def decimal_round_down(input: Column, right_digits: Int): Column =
floor(input * pow(lit(10), right_digits)) / pow(lit(10), right_digits)
def decimal_round(input: Column, places: Int): Column =
when(decimal_round_up(input, places) - input > input - decimal_round_down(input, places),
decimal_round_down(input, places)
).otherwise(
when(decimal_round_up(input, places) - input < input - decimal_round_down(input, places),
decimal_round_up(input, places)
).otherwise(
when(abs(decimal_round_up(input, places)) > abs(decimal_round_down(input, places)),
decimal_round_up(input, places)
).otherwise(decimal_round_down(input, places))
)
)
def decimal_truncate(input: Column, number_of_places: Column): Column =
(input * pow(lit(10), number_of_places)).cast(LongType) / pow(lit(10), number_of_places)
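  // Sketch of the rounding/truncation helpers on literal values (illustrative only; inputs would normally be columns):
  //   decimal_round_up(lit(2.3412), 2)        // 2.35
  //   decimal_round_down(lit(2.3489), 2)      // 2.34
  //   decimal_truncate(lit(-2.789), lit(2))   // -2.78 (the cast to Long truncates toward zero)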
/**
* Returns the internal representation of a date resulting from adding (or subtracting) a number of months to the specified date.
*
* @param inputDate in yyyy-MM-dd format
* @param months
* @return
*/
def date_add_months(inputDate: Column, months: Int) = add_months(inputDate, months)
/**
 * Returns the internal representation of a date given the year, month, and day.
 * The internal representation is an integer value specifying the number of days relative to January 1, 1900.
 * For example, encode_date returns the internal representation of the date specified by
 * the year 1998, the month 5, and the day 18: encode_date(1998, 5, 18) = 35931
*/
val encode_date = udf { (year: Int, month: Int, day: Int) ⇒
import java.time.LocalDate
import java.time.format.DateTimeFormatter
val formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd")
val baseDate = LocalDate.parse("1900-01-01", formatter)
var monthString = s"$month"
var dayString = s"$day"
var yearString = s"$year"
while (dayString.length < 2) dayString = "0" + dayString
while (monthString.length < 2) monthString = "0" + monthString
while (yearString.length < 4) yearString = "0" + yearString
val endDate = LocalDate.parse(s"$yearString-$monthString-$dayString", formatter)
endDate.toEpochDay - baseDate.toEpochDay
}
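  // Usage sketch (illustrative; encode_date is a UDF, so its arguments must be Columns):
  //   df.withColumn("days_since_1900", encode_date(lit(1998), lit(5), lit(18)))   // 35931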
val date_month_end = udf({ (month: Int, year: Int) ⇒
if (month == null || year == null) null
else {
month match {
case 1 | 3 | 5 | 7 | 8 | 10 | 12 ⇒ 31
case 4 | 6 | 9 | 11 ⇒ 30
        case 2 ⇒ if (year % 400 == 0 || (year % 4 == 0 && year % 100 != 0)) 29 else 28 // full Gregorian leap-year rule
}
}
},
IntegerType
)
/**
* Returns the internal representation of a timestamp resulting from adding (or subtracting) a number of months to the specified timestamp.
*
* @param input timestamp in yyyy-MM-dd HH:mm:ss.SSSS format
* @param months
* @return
*/
def datetime_add_months(input: Column, months: Int): Column = add_months(input, months)
/**
 * Function trims the string and then right-pads it with the given character up to the given length.
 * If the length of the trimmed string is equal to or greater than the given length, the input
 * string is returned unmodified.
*
* @param input input string
* @param len length in number of characters.
* @param char_to_pad_with A character used to pad input string to length len.
 * @return string of a specified length, trimmed of leading and trailing blanks and right-padded with a given character.
*/
def string_repad(input: Column, len: Int, char_to_pad_with: String = " "): Column =
when(isnull(input).or(length(trim(input)) >= len), input)
.otherwise(rpad(trim(input), len, char_to_pad_with))
/**
 * Function pads input on the right with the character char_to_pad_with to make the string length len. If input is
 * already len or more characters long, the function returns input unmodified.
*
* @param input
* @param len
* @param char_to_pad_with
* @return
*/
def string_pad(input: Column, len: Int, char_to_pad_with: String = " "): Column =
when(isnull(input).or(length(input) >= len), input)
.otherwise(rpad(input, len, char_to_pad_with))
val string_pad = udf(
{ (input: String, len: Int) ⇒
if (input == null || len == null) null
else {
        if (input.length < len) input + (0 until (len - input.length)).map(_ ⇒ " ").mkString("")
else input
}
},
StringType
)
val string_pad_with_char = udf(
{ (input: String, len: Int, char_to_pad_with: String) ⇒
if (input == null || len == null || char_to_pad_with == null) null
else {
        if (input.length < len) input + (0 until (len - input.length)).map(_ ⇒ char_to_pad_with).mkString("")
else input
}
},
StringType
)
/**
* Replaces all substrings in a target string that match a specified regular expression.
*
* @param target A string that the function searches for a substring that matches pattern_expr.
* @param pattern regular expression
* @param replacement replacement string
* @param offset Number of characters, from the beginning of str, to skip before searching.
* @return a replaced string in which all substrings, which matches a specified regular expression, are replaced.
*/
def re_replace(target: Column, pattern: String, replacement: String, offset: Int = 0): Column = {
when(target.isNull, target).otherwise(
concat(
target.substr(1, offset),
regexp_replace(
target.substr(lit(offset + 1), length(target) - lit(offset)),
pattern,
replacement
)
)
)
}
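  // Usage sketch, assuming a string column "phone" (illustrative names only):
  //   df.withColumn("digits", re_replace(col("phone"), "[^0-9]", ""))
  //   // "(555) 010-1234" -> "5550101234"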
/**
* Replaces only the first regex matching occurrence in the target string.
*
* @param target A string that the function searches for a substring that matches pattern_expr.
* @param pattern regular expression
* @param replacement replacement string
* @return a replaced string in which first substring, which matches a specified regular expression, is replaced.
*/
def re_replace_first(target: Column, pattern: String, replacement: String, offset: Column = lit(0)): Column = {
when(target.isNull.or(!target.rlike(pattern)), target).otherwise(
concat(
regexp_extract(target.substr(offset, length(target) - offset), s"([^$pattern]*)($pattern)(.*)", 1),
lit(replacement),
regexp_extract(target.substr(offset, length(target) - offset), s"([^$pattern]*)($pattern)(.*)", 3)
)
)
}
/**
* Left-pad the input string column with pad_char to a length of len. If length of input column is more than len
* then returns input column unmodified.
*
* @param input
* @param len
* @param pad_char
* @return
*/
def string_lpad(input: Column, len: Int, pad_char: String = " "): Column =
when(input.isNull || length(input) >= lit(len), input)
.otherwise(lpad(input, len, pad_char))
/**
* Function uses a java regex to identify decimal numbers from input string. This decimal number could be of 3 types
* 1. Simple integral number. e.g. 013334848. This part is identified by combination of [1-9][0-9 ]*[0-9] and [1-9]+ regex
* 2. decimal number with explicit decimal point. e.g. 123456.90. This part is identified by combination of
* [1-9][0-9]*(\\$$decimal_point_char)[0-9 ]+ and (0\\$$decimal_point_char)[0-9 ]*[0-9] regex
*
 * After extracting the decimal number, this code looks for a minus sign before the extracted number in the input
 * and prepends it to the decimal number if one is found.
*
* In the end it replaces all whitespaces with empty string in the final resultant decimal number.
*
* @param input input string
* @param decimal_point_char A string that specifies the character that represents the decimal point.
* @return a decimal from a string that has been trimmed of leading zeros and non-numeric characters.
*/
def decimal_strip(input: Column, decimal_point_char: String = "."): Column = {
val pattern =
s"""([1-9][0-9]*(\\$decimal_point_char)[0-9 ]+|[1-9][0-9 ]+|[1-9]+|(0\\$decimal_point_char)[0-9 ]*[0-9])"""
when(input.isNull, lit("0")).otherwise(
when(regexp_extract(input, pattern, 1) === "", "0").otherwise(
when(
instr(
input.substr(
lit(0),
instr_udf(input, regexp_extract(input, pattern, 1))
),
"-"
) > 0,
concat(
lit("-"),
regexp_replace(regexp_extract(input, pattern, 1), "\\s+", "")
)
).otherwise(
regexp_replace(regexp_extract(input, pattern, 1), "\\s+", "")
)
)
)
}
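  // Usage sketch, assuming a string column "raw_amt" (illustrative names only):
  //   df.withColumn("amt", decimal_strip(col("raw_amt")))
  //   // "00123" -> "123", "abc" -> "0", null -> "0"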
// def decimal_strip(input: Column, columnName: String, decimal_point_char: String = "."): Column = {
// val pattern =
// s"""([1-9][0-9]*(\\$decimal_point_char)[0-9 ]+|[1-9][0-9 ]+|[1-9]+|(0\\$decimal_point_char)[0-9 ]*[0-9])"""
// val sqlPattern = lit(pattern).expr.sql
//
// when(input.isNull, lit("0")).otherwise(
// when(regexp_extract(input, pattern, 1) === "", "0").otherwise(
// when(
// instr(
// input.substr(
// lit(0),
// expr(
// s"instr($columnName, regexp_extract($columnName, $sqlPattern, 1))"
// )
// ),
// "-"
// ) > 0,
// concat(
// lit("-"),
// regexp_replace(regexp_extract(input, pattern, 1), "\\s+", "")
// )
// ).otherwise(
// regexp_replace(regexp_extract(input, pattern, 1), "\\s+", "")
// )
// )
// )
// }
val instr_udf = udf({ (inputStr: String, substring: String) ⇒
if (inputStr == null || substring == null) null
else
inputStr.indexOf(substring) + 1
},
IntegerType
)
/**
* UDF to find index of seekStr in inputStr from offset index onwards.
* Returned string position is 1 based position.
*/
val string_index_with_offset = udf(
(inputStr: String, seekStr: String, offset: Int) ⇒ {
if (inputStr == null || seekStr == null || offset == null) null
else {
if (offset <= inputStr.length) {
val index = inputStr.substring(offset).indexOf(seekStr)
if (index < 0) {
0
} else {
offset + index + 1
}
} else {
0
}
}
},
IntegerType
)
/**
* Method which returns string of characters present in both of the strings in the same order as appearing in first
* string
*/
val string_filter = udf(
(inputStr1: String, inputStr2: String) ⇒ {
if (inputStr1 == null || inputStr2 == null) null
else {
val str2Set = inputStr2.toCharArray.toSet
inputStr1.toCharArray.filter(str2Set.contains(_)).mkString("")
}
},
StringType
)
/**
* Function to replace occurrence of seekStr with newStr string in input string after offset characters from first character.
*
* @param input input string on which to perform replace operation.
* @param seekStr string to be replaced in input string.
* @param newStr string to be used instead of seekStr in input string.
 * @param offset number of characters to skip from the beginning of the input string before performing the string_replace operation.
* @return modified string where seekStr is replaced with newStr in input string.
*/
def string_replace(input: Column, seekStr: Column, newStr: Column, offset: Column = lit(0)): Column = {
when(isnull(input), lit(null)).otherwise(
when(length(input) <= offset, input).otherwise(
concat(
string_substring(input, lit(0), offset),
regexp_replace(
string_substring(input, offset + 1, length(input)),
seekStr,
newStr
)
)
)
)
}
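  // Usage sketch, assuming a string column "msg" (illustrative names only). Note that seekStr is passed to
  // regexp_replace internally, so it is interpreted as a regular expression:
  //   df.withColumn("fixed", string_replace(col("msg"), lit("colour"), lit("color")))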
val string_replace_first = udf({ (input: String, seekStr: String, newStr: String) ⇒
if (input == null || seekStr == null || newStr == null) null
else
input.replaceFirst(seekStr, newStr)
},
StringType
)
val string_replace_in_loop = udf(
{ (input: String, seekStrs: Seq[String], newStrs: Seq[String]) ⇒
if (input == null || seekStrs == null || newStrs == null) null
else {
var currentInput = input
(0 until Math.min(seekStrs.length, newStrs.length)).foreach { itr ⇒
currentInput = currentInput.replaceAll(seekStrs(itr), newStrs(itr))
}
currentInput
}
},
StringType
)
/**
* Method which returns true if input string contains all alphabetic characters, or false otherwise.
*
* @param input
* @return
*/
def string_is_alphabetic(input: Column): Column = input.rlike("^[a-zA-Z]+$")
/**
* Method which returns true if input string contains all numeric characters, or false otherwise.
*
* @param input
* @return
*/
def string_is_numeric(input: Column): Column =
when(input.rlike("^[0-9]+$").or(length(input) === lit(0)), lit(true))
.otherwise(lit(false))
/**
 * Returns true if the string column starts with the given prefix
*/
def starts_with(input: Column, prefix: String): Column =
input.startsWith(prefix)
/**
 * Returns true if the string column ends with the given suffix
*/
def ends_with(input: Column, suffix: String): Column =
input.endsWith(suffix)
/**
* UDF to generate column with sequence of integers between two passed start and end columns.
*/
val generate_sequence = udf { (start: Int, end: Int) ⇒
(start to end).toArray
}
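  // Usage sketch via the registered UDF (assuming registerAllUDFs(spark) has been called and `df` has an
  // integer column "n"; names are illustrative):
  //   df.selectExpr("generate_sequence(1, n) as ids")   // n = 4 -> [1, 2, 3, 4]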
/**
* Method uses a java regex to identify decimal numbers from input string. This decimal number could be of 3 types
* 1. Simple integral number. e.g. 013334848. This part is identified by combination of [1-9][0-9]*[0-9] and [1-9]+ regex
* 2. decimal number with explicit decimal point. e.g. 123456.90. This part is identified by combination of
* [1-9][0-9]*(\\\\$$decimal_point_char)[0-9]+ and (0\\\\$$decimal_point_char)[0-9]*[0-9] regex
*
* After extracting decimal number this code checks if length of decimal number is more than len parameter or not.
 * If the length is more than the len parameter then it simply returns this extracted decimal number. Otherwise it first left-pads the
 * decimal number with char_to_pad_with to make its length equal to the len parameter and then moves the minus sign (-) to the leftmost
* part of decimal number.
*
* @param input input string.
* @param len length of characters.
* @param char_to_pad_with character to left pad with. default value is "0"
* @param decimal_point_char A string that specifies the character that represents the decimal point.
* @return a decimal string of the specified length or longer, left-padded with a specified character as needed and
* trimmed of leading zeros.
*/
def decimal_lrepad(
input: Column,
len: Int,
char_to_pad_with: String = "0",
decimal_point_char: String = "."
): Column = {
val pattern =
s"""(-?)([1-9][0-9]*(\\$decimal_point_char)[0-9]+|[1-9][0-9]+|[1-9]+|(0\\$decimal_point_char)[0-9]*[0-9])"""
when(input.isNull, lit(null)).otherwise(
when(
length(regexp_extract(input, pattern, 0)) >= len,
regexp_extract(input, pattern, 0)
).otherwise(
lpad(
regexp_replace(
lpad(regexp_extract(input, pattern, 0), len, char_to_pad_with),
"\\-",
""
),
len,
"-"
)
)
)
}
/**
* Method to return character code of character at index position in inputStr string.
*
* @param inputStr input string
* @param index location of character to get code.
* @return integer column.
*/
def string_char(inputStr: Column, index: Int): Column = {
when(
(lit(index) <= lit(0)).or(lit(index) > length(inputStr)).or(inputStr.isNull),
lit(null)
).otherwise(
ascii(inputStr.substr(lit(index), length(inputStr) - lit(index) + 1))
)
}
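  // Usage sketch, assuming a string column "flag" (illustrative names only):
  //   df.withColumn("code", string_char(col("flag"), 1))
  //   // "ABC" -> 65 (code of 'A'); returns null for a null input or an out-of-range index.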
/**
* Method uses a java regex to identify decimal numbers from input string. This decimal number could be of 3 types
 * 1. Simple integral number. e.g. 013334848. This part is identified by [0-9]+ regex.
* 2. decimal number with explicit decimal point. e.g. 123456.90. This part is identified by combination of
* [0-9]+(\\$$decimal_point_char)[0-9]+ and (0\\$$decimal_point_char)[0-9]+ regex
*
* After extracting decimal number this code checks if length of decimal number is more than len parameter or not.
 * If the length is more than the len parameter then it simply returns this extracted decimal number. Otherwise it first left-pads the
 * decimal number with char_to_pad_with to make its length equal to the len parameter and then moves the minus sign (-) to the leftmost
* part of decimal number.
*
* @param input input string.
* @param len length of characters.
* @param char_to_pad_with character to left pad with. default value is "0"
* @param decimal_point_char A string that specifies the character that represents the decimal point.
* @return a decimal string of the specified length or longer, left-padded with a specified character as needed and
* trimmed of leading zeros.
*/
def decimal_lpad(
input: Column,
len: Int,
char_to_pad_with: String = "0",
decimal_point_char: String = "."
): Column = {
val pattern =
s"""(-?)([0-9]+(\\$decimal_point_char)[0-9]+|[0-9]+|(0\\$decimal_point_char)[0-9]+)"""
when(input.isNull, lit(null)).otherwise(
when(
length(regexp_extract(input, pattern, 0)) >= len,
regexp_extract(input, pattern, 0)
).otherwise(
lpad(
regexp_replace(
lpad(regexp_extract(input, pattern, 0), len, char_to_pad_with),
"\\-",
""
),
len,
"-"
)
)
)
}
val re_get_match_with_index = udf(
{ (input: String, pattern: String, offset: Int) ⇒
if (input == null || pattern == null || offset == null) null
else {
val result =
pattern.r
.findAllMatchIn(input.substring(offset))
.map(x ⇒ input.substring(x.start + offset, x.end + offset))
if (result.hasNext) {
result.next()
} else null
}
},
StringType
)
/**
 * Compares two input strings and returns the characters of the first string that do not appear in the second string.
*/
val string_filter_out = udf(
(input1: String, input2: String) ⇒ {
if (input1 == null || input2 == null) null
else {
val str2Set = input2.toCharArray.toSet
input1.toCharArray.filter(!str2Set.contains(_)).mkString("")
}
},
StringType
)
/**
* udf to group input decimal into multiple groups separated by separator
*/
val number_grouping = udf(
{ (input: String, groups: Int, seperator: String) ⇒
if (input == null || groups == null || seperator == null) null
else {
val result = input.split("\\.").map(group_number(_, groups, ',')).mkString(".")
if (seperator != ",") result.replaceAll(",", seperator) else result
}
},
StringType
)
private def group_number(input: String, groups: Int, seperator: Char): String = {
val symbols = new DecimalFormatSymbols()
symbols.setGroupingSeparator(seperator)
val dfDecimal = new DecimalFormat("###########0.0###")
dfDecimal.setDecimalFormatSymbols(symbols)
dfDecimal.setGroupingSize(groups)
dfDecimal.setGroupingUsed(true)
val result = dfDecimal.format(BigDecimal(input))
result.substring(0, result.length - 2)
}
/**
* Returns the first string in a target string that matches a regular expression.
*/
val re_get_match = udf { (input: String, pattern: String) ⇒
_re_get_match(input, pattern)
}
val decodeString =
udf { (input: String, charSet: String) ⇒
new String(input.getBytes("windows-1252"), charSet)
}
val decodeBytes =
udf { (input: Array[Byte], charSet: String) ⇒
new String(input, charSet)
}
val encodeString =
udf { (input: String, charSet: String) ⇒
val result = new String(input.getBytes(charSet), "windows-1252")
result
}
val encodeBytes =
udf { (input: Array[Byte], charSet: String) ⇒
new String(input).getBytes(charSet)
}
val writeLongToBytes =
udf { (input: Any, length: Int, endian: String, isUnsigned: Boolean) ⇒
_writeLongToBytes(input, length, endian, isUnsigned)
}
val writeIntegerToBytes =
udf { (input: Any, length: Int, endian: String, isUnsigned: Boolean) ⇒
_writeIntegerToBytes(input, length, endian, isUnsigned)
}
val readBytesStringIntoLong =
udf { (input: String, length: Int, endian: String, isUnsigned: Boolean) ⇒
_readBytesStringIntoLong(input, length, endian, isUnsigned)
}
val readBytesIntoLong =
udf { (input: Array[Byte], length: Int, endian: String, isUnsigned: Boolean) ⇒
_readBytesIntoLong(input, length, endian, isUnsigned)
}
val readBytesIntoInteger =
udf { (input: Array[Byte], length: Int, endian: String, isUnsigned: Boolean) ⇒
_readBytesIntoInteger(input, length, endian, isUnsigned)
}
val readBytesStringIntoInteger =
udf { (input: String, length: Int, endian: String, isUnsigned: Boolean) ⇒
_readBytesStringIntoInteger(input, length, endian, isUnsigned)
}
val packedBytesToDecimal = udf { (input: Array[Byte], scale: Int, isUnsigned: Boolean, isStripped: Boolean) ⇒
_packedBytesToDecimal(input, scale, isUnsigned, isStripped)
}
val packedBytesStringToDecimal = udf { (input: String, scale: Int, isUnsigned: Boolean, isStripped: Boolean) ⇒
_packedBytesStringToDecimal(input, scale, isUnsigned, isStripped)
}
val bigDecimalToPackedBytes = udf {
(input: java.math.BigDecimal, precision: Int, isUnsigned: Boolean, isStripped: Boolean) ⇒
_bigDecimalToPackedBytes(input, precision, isUnsigned, isStripped)
}
/**
 * Returns the 1-based index of the first match of a regular expression within the input string, searching from the
 * given 1-based offset. Returns 0 if no match is found.
*/
val re_index_with_offset = udf { (input: String, pattern: String, offset: Int) ⇒
val result =
pattern.r.findAllMatchIn(input.substring(offset - 1)).map(_.start)
if (result.hasNext) {
result.next() + offset
} else {
0
}
}
def datetime_from_unixtime(seconds: Column): Column =
from_unixtime(seconds, "yyyyMMddHHmmssSSSSSS")
/**
* UDF to get file information for passed input file path.
*/
val file_information = udf { (inputPath: String) ⇒
SparkFunctionsCrossSupport.file_information_def(inputPath)
}
def getMTimeDataframe(filepath: String, format: String, spark: SparkSession): DataFrame = {
import spark.implicits._
Seq(
(filepath,
DateTimeFormat
.forPattern(format)
.print(SparkFunctionsCrossSupport.file_information_def(filepath).modified * 1000)
)
).toDF("file", "mtime")
}
/**
 * UDF returning the 1-based index of the first match of a regular expression within the input string, or 0 if no match is found.
*/
val re_index = udf { (input: String, pattern: String) ⇒
val result = pattern.r.findAllMatchIn(input.substring(0)).map(_.start)
if (result.hasNext) {
result.next() + 1
} else {
0
}
}
def string_suffix(input: Column, len: Int): Column =
when(lit(len) >= length(input), input)
.otherwise(when(lit(len) <= lit(0), lit("")).otherwise(input.substr(length(input) - lit(len) + lit(1), lit(len))))
def scanf_long(format: Column, value: Column): Column =
when(format === lit("%Ld"), value.cast(LongType)).otherwise(lit(null))
def scanf_double(format: Column, value: Column): Column =
when(format === lit("%lf"), value.cast(DoubleType)).otherwise(lit(null))
/**
* Concatenates the elements of column using the delimiter.
*/
def string_join(column: Column, delimiter: String): Column =
SparkFunctionsCrossSupport.string_join(column, delimiter)
/**
* Method to zip two arrays with first one having event_type and second one having event_text
*
* @param column1
* @param column2
* @return
*/
def zip_eventInfo_arrays(column1: Column, column2: Column): Column =
SparkFunctionsCrossSupport.zip_eventInfo_arrays(column1, column2)
/**
* Method to get field at specific position from struct column
*
* @param column
* @return
*/
def getFieldFromStructByPosition(column: Column, position: Int): Column =
SparkFunctionsCrossSupport.getFieldFromStructByPosition(column, position)
/**
 * UDF to split input string via delimiter string and remove all empty substrings.
*/
val string_split_no_empty = udf { (input: String, delimiter: String) ⇒
input.split(delimiter).filterNot(_.isEmpty)
}
def replace_null_with_blank(input: Column): Column = when(input.isNull, lit("")).otherwise(input)
/**
* UDF to split input string via delimiter string.
*/
val string_split = udf { (input: String, delimiter: String) ⇒
input.split(delimiter)
}
val type_info = udf { (dml_type: String) ⇒
_type_info(dml_type)
}
val record_info = udf { (dml_type: String) ⇒
_record_info(dml_type)
}
val type_info_with_includes = udf { (dml_type: String, includes: Array[String]) ⇒
_type_info(dml_type, includes)
}
val record_info_with_includes = udf { (dml_type: String, includes: Array[String]) ⇒
_record_info(dml_type, includes)
}
val read_file = udf { (filePath: String) ⇒
_read_file(filePath)
}
val unique_identifier = udf { () ⇒
java.util.UUID.randomUUID.toString
}
val hash_MD5 = udf { (input: Any) ⇒
val inputBytes = if (input.isInstanceOf[String]) input.asInstanceOf[String].getBytes else getByteArray(input)
val messageDigest = MessageDigest.getInstance("MD5")
messageDigest.reset()
messageDigest.update(inputBytes)
messageDigest.digest()
}
val string_representation = udf { (input: Any) ⇒
_string_representation(input)
}
val canonical_representation =
udf(
(input: Any) ⇒ {
val (encoding, value): (Short, String) = if (input == null) {
(0, "NULL")
} else if (input.isInstanceOf[String]) {
try {
val rep = new String(input.toString.getBytes, "utf-8")
(1, rep)
} catch {
case _: Throwable ⇒
(2, input.toString)
}
} else if (input.isInstanceOf[Row] || input.isInstanceOf[Seq[_]]) {
(3, "")
} else {
(1, input.toString)
}
Row(
encoding,
value
)
},
StructType(
List(
StructField("encoding", ShortType, false),
StructField("value", StringType, false)
)
)
)
/**
 * Method to test whether a string matches a specified pattern. This function returns true if the input string matches
 * the specified pattern, and false if the string does not match the pattern.
 *
 * As in the Ab Initio version, the % character in the pattern matches zero or more characters and the _ character
 * matches a single character.
*/
val string_like = udf { (input: String, pattern: String) ⇒
input.matches(pattern.replaceAll("%", "(.*)").replaceAll("_", "."))
}
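  // Usage sketch, assuming a string column "name" (illustrative names only):
  //   df.withColumn("matches", string_like(col("name"), lit("J%n")))
  //   // "John" -> true, "Jane" -> false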
/**
 * UDF to split input string via pattern string and remove all empty substrings.
*/
val re_split_no_empty = udf { (input: String, pattern: String) ⇒
input.split(pattern).filterNot(_.isEmpty)
}
/**
* UDF to find index of seekStr in inputStr. Returned index will be
* 1 based index.
*/
val string_index = udf { (inputStr: String, seekStr: String) ⇒
inputStr.indexOf(seekStr) + 1
}
/**
* Returns the index of the first character of the last occurrence of a seek string within another input string.
* Returned index is 1 based.
*/
val string_rindex = udf { (inputStr: String, seekStr: String) ⇒
inputStr.lastIndexOf(seekStr) + 1
}
/**
* UDF to find index of seekStr in inputStr from end of inputStr skipping offset number of characters from end. Offset index is number
* of characters, from the end of str, to skip before searching. Returned string position is 1 based position.
*/
val string_rindex_with_offset = udf((inputStr: String, seekStr: String, offset: Int) ⇒ {
if (offset <= inputStr.length) {
val index =
inputStr.substring(0, inputStr.length - offset).lastIndexOf(seekStr)
index + 1
} else {
0
}
})
/**
* UDF to return a flag for each character if it is present or not in input String.
*/
val make_byte_flags = udf { (inputStr: String) ⇒
val charFlag = new Array[Int](256)
inputStr.toCharArray.foreach(c ⇒ charFlag(c.toInt) = 1)
charFlag
}
/**
* UDF to return a string in the native character set made up of bytes from the given map. Each byte of the result is
* the value of map indexed by the character code of the corresponding byte of the input string str. The function
* returns NULL if any argument is NULL.
*/
val translate_bytes = udf((inputStr: String, byteMap: Seq[Byte]) ⇒ {
if (inputStr == null || byteMap == null)
null
else {
val translatedChars = inputStr.getBytes.map { c ⇒
if (c < byteMap.length && c >= 0)
byteMap(c)
else c
}
new String(translatedChars)
}
})
/**
* UDF to identify the number of characters in inputStr which are present in charFlag
*/
val test_characters_all = udf { (inputStr: String, charFlag: Seq[Int]) ⇒
inputStr.toCharArray.filter(c ⇒ charFlag(c.toInt) == 1).size
}
/**
* Converts a string from one character set to another, replacing inconvertible characters with a specified string.
*/
val string_convert_explicit = udf { (input: String, charSet: String, replaceStr: String) ⇒
new String(input.getBytes(), charSet).replaceAll("�", replaceStr)
}
/**
* This implementation is incorrect.
*/
val string_cleanse = udf { (input: String, replacement: String, charSet: String) ⇒
new String(input.getBytes(), charSet).replaceAll("�", replacement)
}
val decode_datetime_as_local =
udf(
(time: String) ⇒ {
val datetimeInfo = _decode_datetime(time)
        Row(
          datetimeInfo.year,
          datetimeInfo.month,
          datetimeInfo.day,
          datetimeInfo.hour,
          datetimeInfo.minute,
          datetimeInfo.second,
          datetimeInfo.microsecond,
          (ZonedDateTime.now.getOffset.getTotalSeconds / 60).toShort
        )
},
StructType(
List(
StructField("year", IntegerType, false),
StructField("month", IntegerType, false),
StructField("day", IntegerType, false),
StructField("hour", IntegerType, false),
StructField("minute", IntegerType, false),
StructField("second", IntegerType, false),
StructField("microsecond", IntegerType, false),
StructField("timezone_offset", ShortType, false)
)
)
)
val xmlToJSON = udf { (input: String) ⇒
val jsonObj = XML.toJSONObject(input)
val newJsonObj = if (jsonObj.length() == 1) jsonObj.getJSONObject(jsonObj.keys().next()) else jsonObj
newJsonObj.toString
}
def from_xml(content: Column, schema: StructType): Column = {
val json = xmlToJSON(content)
from_json(json, schema)
}
def numberOfPartitions(in: DataFrame): Column =
lit(in.rdd.getNumPartitions)
def findLastElement(input: Column, default: Column = lit(null)): Column = {
val lastElement = SparkFunctionsCrossSupport.element_at(reverse(input), lit(1))
when(lastElement.isNull, default).otherwise(lastElement)
}
def string_length(input: Column) = length(input)
def findFirstElement(input: Column, default: Column = lit(null)): Column = {
val firstElement = SparkFunctionsCrossSupport.element_at(input, lit(1))
when(firstElement.isNull, default).otherwise(firstElement)
}
def findFirstNonBlankElement(input: Column, default: Column): Column = {
val firstElement = SparkFunctionsCrossSupport.element_at(
SparkFunctionsCrossSupport.array_except(input, array(lit(""), lit(" "), lit(" "), lit(" "))),
lit(1)
)
when(firstElement.isNull, default).otherwise(firstElement)
}
def getColumnInSecondArrayByFirstNonBlankPositionInFirstArray(
nonBlankEntryExpr: Column,
firstArray: Column,
secondArray: Column
): Column = {
SparkFunctionsCrossSupport.element_at(
secondArray,
SparkFunctionsCrossSupport
.array_position(
firstArray,
SparkFunctionsCrossSupport.element_at(nonBlankEntryExpr, lit(1))
)
.cast(IntegerType)
)
}
def schemaRowCompareResult(row1: StructType, row2: StructType): Column = {
val value1 = concat(flattenStructSchema(row1): _*)
val value2 = concat(flattenStructSchema(row2): _*)
value1 === value2
}
def flattenStructSchema(schema: StructType, prefix: String = null): Array[Column] = {
schema.fields.flatMap { field ⇒
val columnName = if (prefix == null) field.name else (prefix + "." + field.name)
field.dataType match {
case st: StructType ⇒ flattenStructSchema(st, columnName)
case _ ⇒ Array(col(columnName).as(columnName.replace(".", "_")))
}
}
}
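  // Usage sketch, assuming `df` has the schema struct<name: string, address: struct<city: string, zip: string>>
  // (illustrative schema only):
  //   df.select(flattenStructSchema(df.schema): _*)
  //   // yields the columns: name, address_city, address_zip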
/**
* UDF to get record of type decode_datetime_type. This record will have all its fields populated with
* corresponding entries in input date/timestamp.
*
* Returned record will have following schema.
*
* integer(8) year;
* integer(8) month;
* integer(8) day;
* integer(8) hour;
* integer(8) minute;
* integer(8) second;
* integer(8) microsecond;
*
* Note: Supported Input time is in yyyy-MM-dd HH:mm:ss.SSSSSS or yyyy-MM-dd HH:mm:ss or yyyy-MM-dd formats only.
* Additional handling is done to support timestamp retrieved from now() function call.
*/
val decode_datetime =
udf(
(time: String) ⇒ {
val datetimeInfo = _decode_datetime(time)
Row(
datetimeInfo.year,
datetimeInfo.month,
datetimeInfo.day,
datetimeInfo.hour,
datetimeInfo.minute,
datetimeInfo.second,
datetimeInfo.microsecond
)
},
StructType(
List(
StructField("year", IntegerType, false),
StructField("month", IntegerType, false),
StructField("day", IntegerType, false),
StructField("hour", IntegerType, false),
StructField("minute", IntegerType, false),
StructField("second", IntegerType, false),
StructField("microsecond", IntegerType, false)
)
)
)
val cross_join_index_range = udf(
(input1: Seq[Int], input2: Seq[Int]) ⇒ input1.flatMap(i ⇒ input2.map(j ⇒ Row.fromSeq(Seq(i, j)))),
ArrayType(
StructType(
List(
StructField("i", IntegerType, false),
StructField("j", IntegerType, false)
)
)
)
)
/**
* Method to identify if input string is a blank string or not.
*
* @param input input string.
 * @return true if the given string contains only blank characters or is a zero-length string,
 * false otherwise, or null if the input is null
*/
def is_blank(input: Column): Column =
when(isnull(input), lit(null))
.otherwise(when(length(trim(input)) === lit(0), lit(true)).otherwise(lit(false)))
/**
* Method to list all files present in a passed path where filename starts with passed
* prefix.
*
* @param path
* @param filePrefix
* @return
*/
private def listFiles(path: String, filePrefix: String) = {
new File(path)
.listFiles()
.filter(_.isFile)
.map(_.getName)
.filter(_.startsWith(filePrefix))
}
def directory_listing(path: String, filePrefix: String) =
typedLit(listFiles(path, filePrefix))
private def getDataTypeFromIntLength(lenValue: Int, dataType: DataType) = {
lenValue match {
case 1 ⇒ ByteType
case 2 ⇒ ShortType
case 4 ⇒ IntegerType
case 8 ⇒ LongType
case _ ⇒ dataType
}
}
def is_valid(input: Column): Column =
is_valid(input, false)
def is_valid(input: Column, isNullable: Boolean): Column =
is_valid(input, isNullable, None, None)
def is_valid(input: Column, formatInfo: Option[Any]): Column =
is_valid(input, false, formatInfo, None)
def is_valid(input: Column, formatInfo: Option[Any], len: Option[Seq[Int]]): Column =
is_valid(input, false, formatInfo, len)
def is_valid(input: Column, isNullable: Boolean, formatInfo: Option[Any]): Column =
is_valid(input, isNullable, formatInfo, None)
/**
* Method to identify if passed input column is a valid expression after typecasting to passed dataType.
* Also while typecasting if len is present then this function also makes sure the max length of input column
* after typecasting operation is not greater than len.
*
* @param input input column expression to be identified if is valid.
* @param formatInfo datatype to which input column expression must be typecasted.
* If datatype is a string then it is treated as timestamp format. If it is
* a list of string then it is treated as having current timestamp format and
 * the new timestamp format to which the input column needs to be typecasted.
* @param len max length of input column after typecasting it to dataType.
* @return 0 if input column is not valid after typecasting or 1 if it is valid.
*/
def is_valid(
input: Column,
isNullable: Boolean,
formatInfo: Option[Any],
len: Option[Seq[Int]]
): Column = {
if (isNullable) lit(true)
else {
val result: Column = len match {
case None | Some(Nil) ⇒
formatInfo match {
case Some(typecast: DataType) ⇒ input.cast(typecast)
case Some(format: String) ⇒ to_timestamp(input, format)
case Some(List(currentFormat: String, newFormat: String)) ⇒
date_format(to_timestamp(input, currentFormat), newFormat)
case _ ⇒ input // TODO: support other types
}
case Some(lenValue) ⇒
formatInfo match {
case Some(x: DecimalType) ⇒
if (lenValue.length == 2) {
input.cast(DecimalType(lenValue(0), lenValue(1)))
} else if (lenValue.length == 1) {
input.cast(DecimalType(lenValue(0), 0))
} else {
input.cast(DoubleType)
}
case Some(x: IntegerType) ⇒ input.cast(getDataTypeFromIntLength(lenValue.head, x))
case Some(y: StringType) ⇒
when(length(input.cast(y)) > lit(lenValue.head), lit(null))
.otherwise(input.cast(y))
case Some(z: DoubleType) ⇒
when(input.cast(DecimalType(lenValue.head, 0)).isNull, lit(null))
.otherwise(
input.cast(DecimalType(lenValue.head, 0)).cast(StringType).cast(z)
)
case _ ⇒ input // TODO: support other types
}
}
when(result.isNull, lit(false)).otherwise(lit(true))
}
}
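  // Usage sketch, assuming string columns "event_ts" and "id" (illustrative names only):
  //   df.withColumn("ts_ok", is_valid(col("event_ts"), false, Some("yyyy-MM-dd HH:mm:ss"), None))
  //   df.withColumn("id_ok", is_valid(col("id"), Some(IntegerType), Some(Seq(4))))
  //   // each expression evaluates to true when the cast/parse succeeds and false otherwise.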
/**
* Method to get current timestamp.
*
 * @return current timestamp in yyyyMMddHHmmssSSSSSS format.
*/
def now(): Column = date_format(current_timestamp(), "yyyyMMddHHmmssSSSSSS")
/**
* Returns the number of hours between two specified dates in standard format yyyy-MM-dd HH:mm:ss.SSSS.
*
* @param end
* @param start
* @return
*/
def datetime_difference_hours(end: Column, start: Column): Column =
(unix_timestamp(end) - unix_timestamp(start)) / lit(3600.0)
val datetime_difference =
udf(
{ (timestamp1: String, timestamp2: String) ⇒
val endTime = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSSS").parse(timestamp1).getTime
val startTime = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSSS").parse(timestamp2).getTime()
val differenceTime = endTime - startTime
val days = differenceTime / (24 * 60 * 60 * 1000)
var remainingTime = differenceTime % (24 * 60 * 60 * 1000)
val hours = remainingTime / (60 * 60 * 1000)
remainingTime = remainingTime % (60 * 60 * 1000)
val minutes = remainingTime / (60 * 1000)
remainingTime = remainingTime % (60 * 1000)
val seconds = remainingTime / 1000
remainingTime = remainingTime % 1000
val microSeconds = remainingTime * 1000
Row(days, hours, minutes, seconds, microSeconds)
},
StructType(
List(
StructField("days", LongType, true),
StructField("hours", LongType, true),
StructField("minutes", LongType, true),
StructField("seconds", LongType, true),
StructField("microseconds", LongType, true)
)
)
)
def string_compare(input1: Column, input2: Column): Column =
when(input1.isNull.or(input2.isNull), lit(null))
.otherwise(when(input1 < input2, lit(-1)).otherwise(when(input1 > input2, lit(1)).otherwise(lit(0))))
/**
 * Method to convert a timestamp in the given timezone to UTC. If timezone is "local", the JVM default timezone is used.
*
* @param timezone
* @param time
* @return
*/
def timezone_to_utc(timezone: String, time: Column): Column =
if (timezone == "local") to_utc_timestamp(time, TimeZone.getDefault.toZoneId.getId)
else to_utc_timestamp(time, timezone)
/**
 * Method to return integer value representing number of days from "1990-01-01" to today.
*
* @return integer value
*/
def today(): Column =
datediff(current_date(), to_date(lit("1990-01-01"), "yyyy-MM-dd"))
/**
* Method to check if current column is null or has empty value.
*
* @param input
* @return
*/
def isNullOrEmpty(input: Column): Column =
when(input.isNull || length(trim(input)) === lit(0), lit(true))
.otherwise(lit(false))
/**
* Function to create sequence of array between two passed numbers
*
* @param start starting point of generated sequence
* @param end terminating point of generated sequence.
* @return column containing sequence of integers.
*/
def generate_sequence(start: Int, end: Int, step: Int = 1): Column =
array((start to end).map(lit): _*)
/**
* Method to create array of size "size" containing seedVal as each entry
*
* @param size
* @param seedVal
* @return
*/
def make_constant_vector(size: Int, seedVal: Column): Column =
array((1 to size).map(x ⇒ seedVal): _*)
/**
* Method to create array of size "size" containing seedVal as each entry
*
* @param size
* @param seedVal
* @return
*/
def make_constant_vector(size: Int, seedVal: Int) =
(1 to size).map(x ⇒ seedVal).toArray
/**
 * UDF to break input string into multiple strings via delimiter. The number of strings after the split is adjusted to the
 * passed width parameter: if there are fewer strings, empty strings are appended; if there are more,
 * only the first width entries are kept and the remaining ones are discarded.
*/
val splitIntoMultipleColumnsUdf =
udf { (input: String, delimiter: String, width: Int) ⇒
val entries = input.split(delimiter)
(0 until width)
.map(index ⇒
if (entries.length > index) {
entries(index)
} else {
""
}
)
.toArray
}
/**
* Method to create dataframe with single column containing increasing sequence id from start to end.
*
* @param start
* @param end
* @param columnName
* @param sparkSession
* @return
*/
def generateDataFrameWithSequenceColumn(
start: Int,
end: Int,
columnName: String,
sparkSession: SparkSession
): DataFrame = {
import sparkSession.implicits._
val arrayDF = Seq(((start to end).toArray)).toDF(columnName)
arrayDF.select(explode(col(columnName)).as(columnName))
}
class StringAsStream(val content: String) extends Serializable {
var currentOffset: Int = 0
def read_string(len: Int): String = {
if (currentOffset >= content.length || currentOffset + len > content.length)
null
else {
val result = content.substring(currentOffset, currentOffset + len)
currentOffset = currentOffset + len
result
}
}
}
def getContentAsStream(content: String): StringAsStream =
new StringAsStream(content)
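  // Usage sketch for fixed-width parsing (values are illustrative):
  //   val stream = getContentAsStream("HDR0012345")
  //   stream.read_string(3)   // "HDR"
  //   stream.read_string(7)   // "0012345"
  //   stream.read_string(1)   // null, since the stream is exhausted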
/**
* Method to read values from inputData and create dataframe with column name as columnName and column type as
* columnType for the values in inputData delimiter by delimiter.
*
* @param inputData
* @param delimiter
* @param columnName
* @param columnType
* @param sparkSession
* @return
*/
def createDataFrameFromData(
inputData: String,
delimiter: String,
columnName: String,
columnType: String,
sparkSession: SparkSession
): DataFrame = {
val schema = StructType(
List(StructField(columnName, getColumnDataType(columnType)))
)
val rowData =
inputData.split(delimiter).map(value ⇒ Row.fromSeq(List(value))).toList
sparkSession.createDataFrame(rowData.asJava, schema)
}
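  // Usage sketch (illustrative values; `spark` is assumed to be an active SparkSession):
  //   val lettersDF = createDataFrameFromData("a,b,c", ",", "letter", "StringType", spark)
  //   // single-column DataFrame "letter" with the rows "a", "b", "c"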
/**
* Method to get spark data type.
*
* @param columnType
* @return
*/
private def getColumnDataType(columnType: String): DataType = {
columnType match {
case "StringType" ⇒ StringType
case "IntegerType" ⇒ IntegerType
case "DoubleType" ⇒ DoubleType
case "BooleanType" ⇒ BooleanType
case _ ⇒ StringType
}
}
/**
* Method to get ByteArray from any passed object.
*
* @param input
* @return
*/
private def getByteArray(input: Any): Array[Byte] = {
lazy val default = new ObjectMapper().writeValueAsBytes(input)
input match {
case value: String ⇒ value.getBytes("windows-1252")
case value: Long ⇒ BigInt(value).toByteArray
case value: Int ⇒ BigInt(value).toByteArray
case value: Short ⇒ BigInt(value).toByteArray
case value: Byte ⇒ BigInt(value).toByteArray
case schema: GenericRowWithSchema if schema.size > 0 ⇒
schema.get(0) match {
case LongWrappedArray(array) ⇒ array.filterNot(_ == 0).flatMap(x ⇒ BigInt(x).toByteArray).toArray
case _ ⇒ default
}
case _ ⇒ default
}
}
/**
* Tests whether an object is composed of all binary zero bytes.
* This function returns:
* 1. 1 if obj contains only binary zero bytes or is a zero-length string
* 2. 0 if obj contains any non-zero bytes
* 3. NULL if obj is NULL
*/
val is_bzero = udf { (input: Any) ⇒
if (input.isInstanceOf[String] && input.asInstanceOf[String].isEmpty) true
else if (input.isInstanceOf[Int] && input.asInstanceOf[Int] == 0) true
else false
}
/**
* UDF to get last Byte from ByteArray of input data.
*/
val getByteFromByteArray = udf { (input: Any, offset: Int) ⇒
var byteArray = getByteArray(input)
byteArray = byteArray.slice(offset, byteArray.length)
byteArray(byteArray.length - 1)
}
/**
* UDF to get short comprising of last 2 Bytes from ByteArray of input data.
*/
val getShortFromByteArray = udf { (input: Any, offset: Int) ⇒
var byteArray = getByteArray(input)
byteArray = byteArray.slice(offset, byteArray.length)
val slicedByteArray =
byteArray.slice(byteArray.length - 2, byteArray.length)
new BigInteger(slicedByteArray).shortValue()
}
/**
* UDF to get integer comprising of last 4 Bytes from ByteArray of input data.
*/
val getIntFromByteArray = udf { (input: Any, offset: Int) ⇒
var byteArray = getByteArray(input)
byteArray = byteArray.slice(offset, byteArray.length)
val slicedByteArray =
byteArray.slice(byteArray.length - 4, byteArray.length)
new BigInteger(slicedByteArray).intValue()
}
val url_encode_escapes = udf { (input: String) ⇒
URLEncoder.encode(input, "UTF-8")
}
// TODO support for threshold!
val force_error = udf { (input: String) ⇒
throw new Exception(s"$input")
}
/**
* UDF to get long comprising of last 8 Bytes from ByteArray of input data.
*/
val getLongFromByteArray = udf { (input: Any, offset: Int) ⇒
var byteArray = getByteArray(input)
byteArray = byteArray.slice(offset, byteArray.length)
val slicedByteArray =
byteArray.slice(byteArray.length - 8, byteArray.length)
new BigInteger(slicedByteArray).longValue()
}
/**
* Method to check if all characters of passed input are digit or not.
*
* @param input
* @return
*/
private def isAllDigits(input: String): Boolean =
input.forall(Character.isDigit)
object LongWrappedArray {
def unapply(input: Any): Option[mutable.WrappedArray[Long]] = input match {
case array: mutable.WrappedArray[_] if array.forall(_.isInstanceOf[Long]) ⇒
Some(array.asInstanceOf[mutable.WrappedArray[Long]])
case _ ⇒ None
}
}
object LongSequence {
def unapply(input: Any): Option[Seq[Long]] = input match {
case sequence: Seq[_] if sequence.forall(_.isInstanceOf[Long]) ⇒ Some(sequence.asInstanceOf[Seq[Long]])
case _ ⇒ None
}
}
/**
* Method used for abinitio's reinterpret_as function to read necessary bytes from byteArray for input data and convert
* into struct format as per provided in typeInfo sequence.
*
* TypeInfo can have multiple entries, each could be either decimal or string type. Depending on the argument passed
* within decimal or string bytes are read from input byte array.
*
* If decimal or string argument has some integer then that many bytes are read from input byte array or if decimal or
* string has some string delimiter as its argument then from the current position bytes are read until string delimiter
* is found in input byte array.
*
* @param input
* @param typeInfo
* @return
*/
def convertInputBytesToStructType(input: Any, typeInfo: Seq[String], startByte: Int = 0): Row = input match {
case LongSequence(sequence) ⇒
Row.fromSeq(sequence)
case _ ⇒
var curPointer = 0
var byteArray = getByteArray(input)
byteArray = byteArray.slice(startByte, byteArray.length)
val stringVal = new String(byteArray, "windows-1252")
val rowValues = ListBuffer[Any]()
typeInfo.map { curType ⇒
if (curType.startsWith("decimal") || curType.startsWith("string")) {
val stInd = curType.indexOf("(")
val endInd = curType.indexOf(")")
val arg = curType.substring(stInd + 1, endInd)
if (isAllDigits(arg)) {
val takeUntil = Math.min(Integer.parseInt(arg) + curPointer, stringVal.length)
rowValues += stringVal.substring(curPointer, takeUntil)
curPointer = takeUntil
} else {
val trimmedArg = arg.trim
val pattern = trimmedArg.substring(1, trimmedArg.length - 1)
val index = stringVal.substring(curPointer).indexOf(pattern)
rowValues += stringVal.substring(curPointer, index + curPointer)
curPointer = curPointer + index + pattern.length
}
} else if (curType == "long") {
if (curPointer >= byteArray.length) {
rowValues += 0L
} else {
val slicedByteArray = byteArray.slice(curPointer, curPointer + 8)
rowValues += new BigInteger(slicedByteArray).longValue()
curPointer = curPointer + 8
}
}
}
Row.fromSeq(rowValues)
}
/**
 * UDF to read an array of size longs (8 bytes each) sequentially from the ByteArray of input data, starting at the given offset.
*/
val getLongArrayFromByteArray = udf(
(input: Any, size: Int, offset: Int) ⇒ {
var byteArray = getByteArray(input)
byteArray = byteArray.slice(offset, byteArray.length)
var curPointer = 0
(0 until size).map { itr ⇒
if (curPointer >= byteArray.length) {
0L
} else {
val slicedByteArray = byteArray.slice(curPointer, curPointer + 8)
curPointer = curPointer + 8
new BigInteger(slicedByteArray).longValue()
}
}
},
DataTypes.createArrayType(LongType, false)
)
/**
* UDF for murmur hash generation for any column type
*/
val murmur = udf { (x: Any) ⇒
val murmurInstance = new Murmur3F()
if (x.isInstanceOf[GenericRowWithSchema]) {
val input = x.asInstanceOf[GenericRowWithSchema]
val inputByteArray = (0 until input.length).map(input.get(_)).toArray.flatMap(_.toString.getBytes())
murmurInstance.update(inputByteArray)
Array(murmurInstance.getValue, murmurInstance.getValueHigh)
} else if (x.isInstanceOf[Seq[Any]]) {
val inputByteArray = x.asInstanceOf[Seq[Any]].flatMap(_.toString.getBytes()).toArray
murmurInstance.update(inputByteArray)
Array(murmurInstance.getValue, murmurInstance.getValueHigh)
} else {
murmurInstance.update(x.toString.getBytes())
Array(murmurInstance.getValue, murmurInstance.getValueHigh)
}
// if (x.isInstanceOf[GenericRowWithSchema] && x
// .asInstanceOf[GenericRowWithSchema]
// .get(0)
// .isInstanceOf[mutable.WrappedArray[Any]]) {
// MurmurHash3.arrayHash(x.asInstanceOf[GenericRowWithSchema].get(0).asInstanceOf[mutable.WrappedArray[Any]].toArray)
// } else if (x.isInstanceOf[GenericRowWithSchema] && x
// .asInstanceOf[GenericRowWithSchema]
// .length > 0) {
// MurmurHash3.arrayHash(
// (0 until x.asInstanceOf[GenericRowWithSchema].length).map(x.asInstanceOf[GenericRowWithSchema].get(_)).toArray
// )
// } else if (x.isInstanceOf[Seq[Any]]) MurmurHash3.arrayHash(x.asInstanceOf[Array[Any]])
// else if (x.isInstanceOf[String]) MurmurHash3.stringHash(x.asInstanceOf[String])
}
/**
* Method to return the result of evaluating a string expression in the context of a specified input column. Here
* input column could be struct type record, simple column, array type etc.
* Here expr could be reference to nested column inside input column or any expression which requires values from input
 * column for its evaluation.
*
 * Note: Current implementation only supports the scenario where the input column is of struct type and expr is simply a dot
* separated column reference to input struct.
*
* @param input
* @param expr
* @return
*/
val eval = udf((input: Any, expr: String) ⇒ {
if (input == null || expr == null || !input.isInstanceOf[Row]) null
else {
val tokens = expr.split("\\.")
if (tokens.length == 2) {
val result = input.asInstanceOf[Row].getAs[Row](tokens(0)).getAs[Any](tokens(1))
if (result != null) {
result.toString
} else {
null
}
} else if (tokens.length == 1) {
val result = input.asInstanceOf[Row].getAs[Any](tokens(0))
if (result != null) {
result.toString
} else {
null
}
} else {
null
}
}
})
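  // Usage sketch, assuming a struct column "address" with a nested string field "city" (illustrative names only):
  //   df.withColumn("city", eval(col("address"), lit("city")))
  //   // returns the field value as a String, or null when the expression cannot be resolved.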
/**
 * UDF to truncate the microseconds part of a timestamp. This is needed as Ab Initio and Spark have some incompatibility in
* microseconds part of timestamp format.
*/
val truncateMicroSeconds = udf((sourceFormat: String, timestamp: String) ⇒ {
if (timestamp == null || sourceFormat == null) null
else {
val index = sourceFormat.indexOf("S")
if (index < 0) {
timestamp
} else {
timestamp.substring(0, index)
}
}
})
/**
* Method to replace String Columns with Empty value to Null.
*
* @param input
* @return
*/
def replaceBlankColumnWithNull(input: Column): Column =
when(length(input.cast(StringType)) === lit(0), lit(null)).otherwise(input)
/**
* Method to identify and return first non null expression.
*
* @param expr1
* @param expr2
* @return
*/
def first_defined(expr1: Column, expr2: Column): Column =
when(expr1.isNull, expr2).otherwise(expr1)
/**
* Adds an explicit sign to the number. E.g.
* 2 -> +2; -004 -> -004; 0 -> +0
*/
def sign_explicit(c: Column): Column =
when(c.startsWith("+") || c.startsWith("-"), c)
.otherwise(when(c >= 0, concat(lit("+"), c)).otherwise(concat(lit("-"), c)))
def from_sv(input: Column, separator: String, schema: StructType): Column = {
val splitResult = split(input, separator)
val fields = schema.fields.zipWithIndex.map {
case (field, index) ⇒ splitResult.getItem(index).as(field.name)
}
struct(fields: _*).as("value")
}
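  // Usage sketch, assuming a delimited string column "line" (illustrative names only); the schema supplies the
  // field names of the resulting struct, and every extracted field is a string:
  //   val recSchema = StructType(List(StructField("id", StringType), StructField("name", StringType)))
  //   df.withColumn("rec", from_sv(col("line"), ",", recSchema))
  //   // "42,Jane" -> struct(id = "42", name = "Jane")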
/**
* Method removes any non-digit characters from the specified string column.
*
* @param input input String Column
* @return Cleaned string column or null
*/
def remove_non_digit(input: Column): Column = {
val pattern = "[^\\d]"
when(input.isNull, lit(null)).otherwise(regexp_replace(input, pattern, ""))
}
/**
* Computes number of days between two specified dates in "yyyyMMdd" format
*
* @param laterDate input date
* @param earlierDate input date
* @return number of days between laterDate and earlierDate or null if either one is null
*/
def date_difference_days(laterDate: Column, earlierDate: Column): Column = {
val dateFormat = "yyyyMMdd"
when(laterDate.isNull.or(earlierDate.isNull), lit(null))
.otherwise(datediff(to_date(laterDate, dateFormat), to_date(earlierDate, dateFormat)))
}
/**
* Checks if a string is ascii
*
* @param input column to be checked
* @return true if the input string is ascii otherwise false
*/
def is_ascii(input: Column): Column = {
val asciiPattern = "[^\\x00-\\x7f]"
when(input.isNull, lit(null)).otherwise(
when(length(regexp_replace(input, asciiPattern, "")) < length(input), lit(false))
.otherwise(lit(true))
)
}
/**
* Checks if an input string contains only ascii code and numbers
*
* @param input string to be checked
* @return true if input string contains only ascii code and numbers or null if input is null
*/
def is_numeric_ascii(input: Column): Column = {
when(input.isNull, lit(null))
.otherwise(
when((is_ascii(input) === lit(true))
.and((string_is_numeric(input)) === lit(true))
.and(length(input) > lit(0)),
lit(true)
).otherwise(lit(false))
)
}
/**
* Validates date against a input format
*
* @param dateFormat A pattern such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` or `dd.MM.yyyy`
* @param inDate Input date to be validated
* @return true if the input date is valid otherwise false
*/
def is_valid_date(dateFormat: String, inDate: Column): Column =
//when(date_format(inDate, dateFormat).isNull, lit(false)).otherwise(lit(true))
when(to_date(inDate, dateFormat).isNull, lit(false)).otherwise(lit(true))
/**
* Converts 1 digit julian year to 4 digits julian year.
*
* @param in_date date in Julian in "YJJJ" format
* @param ref_date date in "yyyyMMdd" format
* @return a date in "YYYYJJJ"
*/
def YJJJ_to_YYYYJJJ(in_date: Column, ref_date: Column): Column = {
val y_in_date = string_substring(in_date, lit(1), lit(1))
val y_ref_date = string_substring(ref_date, lit(4), lit(1))
val yyyy_ref_date = string_substring(ref_date, lit(1), lit(4))
val yyyy_ref_date_tmp = when((y_in_date.notEqual(y_ref_date)).and(y_ref_date.notEqual(lit(0))),
concat(string_substring(yyyy_ref_date, lit(1), lit(3)), y_in_date)
).otherwise(yyyy_ref_date)
val result = concat(yyyy_ref_date_tmp, string_substring(in_date, lit(2), lit(3)))
when((to_date(ref_date, "yyyyMMdd").isNull), lit(null))
.otherwise(when(to_date(result, "yyyyDDD").isNull, yyyyMMdd_to_YYYYJJJ(ref_date)).otherwise(result))
}
/**
 * Converts yyyyMMdd to YYYYJJJ
*
* @param in_date date in yyyyMMdd format
* @return a date converted to YYYYJJJ
*/
def yyyyMMdd_to_YYYYJJJ(in_date: Column) = {
val day_of_year = string_lrepad(date_format(to_date(in_date, "yyyyMMdd"), "D"), 3, "0")
val year = string_substring(in_date, lit(1), lit(4))
when(to_date(in_date, "yyyyMMdd").isNull, lit(null))
.otherwise(concat(year, day_of_year))
}
/**
* Computes number of days in February month in a given year
*
* @param year year whose number of days in February needs to be calculated
* @return number of days
*/
def getFebruaryDay(year: Column): Column =
when((year % 400 === lit(0)).or((year % 4 === lit(0)).and((year % 100).notEqual(lit(0)))), lit(29))
.otherwise(lit(28))
}