/*
 * ====================================================================
 *
 * PROPHECY CONFIDENTIAL
 *
 * Prophecy Inc
 * All Rights Reserved.
 *
 * NOTICE:  All information contained herein is, and remains
 * the property of Prophecy Inc, the intellectual and technical concepts contained
 * herein are proprietary to Prophecy Inc and may be covered by U.S. and Foreign Patents,
 * patents in process, and are protected by trade secret or copyright law.
 * Dissemination of this information or reproduction of this material
 * is strictly forbidden unless prior written permission is obtained
 * from Prophecy Inc.
 *
 * ====================================================================
 */
package io.prophecy.libs

import io.prophecy.libs.crosssupport.{RowEncoderCrossSupport, UnsafeRowWriterCrossSupport}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.{array, callUDF, lit, udf}
import org.apache.spark.sql.types._
import com.typesafe.scalalogging.LazyLogging
import org.apache.spark.unsafe.types.UTF8String

import scala.collection.mutable
import scala.language.implicitConversions

/**
  * Utility trait with various UDFs and helper functions to take care of
  * miscellaneous tasks.
  */
trait UDFUtils extends RestAPIUtils with Serializable with LazyLogging {

  /**
    * UDF to return the nth element from the end of the passed array of elements.
    * If the input sequence has fewer than n elements, the first element
    * is returned.
    */
  val take_last_nth: UserDefinedFunction = udf((a: Seq[String], b: Int) ⇒ a.takeRight(b).head)

  /**
    * UDF to return the element at index n (zero-based) from the beginning. If the
    * input sequence has fewer elements, an IndexOutOfBoundsException is thrown.
    */
  val take_nth: UserDefinedFunction = udf((a: Seq[String], b: Int) ⇒ a.apply(b))
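
  // Usage sketch (hypothetical names): assumes a DataFrame `df` with an
  // array<string> column "parts".
  //
  //   import org.apache.spark.sql.functions.lit
  //   df.withColumn("second_from_end", take_last_nth(df("parts"), lit(2)))
  //     .withColumn("first_element",   take_nth(df("parts"), lit(0)))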

  /**
    * UDF to search for str in the input sequence toBeReplaced and return replace
    * if found. Otherwise str is returned unchanged.
    */
  val replace_string: UserDefinedFunction = udf((str: String, replace: String, toBeReplaced: Seq[String]) ⇒
    if (toBeReplaced.contains(str)) replace else str
  )

  /**
    * UDF to search for str in the input sequence toBeReplaced and return null
    * if found. Otherwise str is returned unchanged.
    */
  val replace_string_with_null: UserDefinedFunction = udf((str: String, toBeReplaced: Seq[String]) ⇒
    if (toBeReplaced.contains(str)) null else str
  )

  /**
    * UDF to return the element of the arr sequence at the passed index. If the
    * index is out of bounds, null is returned.
    */
  val array_value: UserDefinedFunction =
    udf((arr: Seq[String], index: Int) ⇒ if (arr.size > index) arr(index) else null)
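
  // Usage sketch (hypothetical names): assumes a DataFrame `df` with a string
  // column "code" and an array<string> column "parts".
  //
  //   import org.apache.spark.sql.functions.{array, lit}
  //   val sentinels = array(lit("N/A"), lit("UNKNOWN"))
  //   df.withColumn("code_clean", replace_string_with_null(df("code"), sentinels))
  //     .withColumn("third_part", array_value(df("parts"), lit(2)))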

  /**
    * Function to take a variable number of values and create an array column from them.
    *
    * @param value first input value.
    * @param values remaining input values.
    * @return an array column of literals.
    */
  def arrayColumn(value: String, values: String*) = {
    val allValues = (value +: values).map(lit)
    array(allValues: _*)
  }
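
  // Usage sketch: build a literal array column for the replace helpers below.
  //
  //   val sentinels = arrayColumn("N/A", "UNKNOWN", "-")  // array of three string literals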

  /**
    * Function to add a new column to the passed dataframe. The new column's value depends on whether
    * the inputCol value is found in the array built from value and values: if found, the value of
    * replaceWith is written to the new column; otherwise the original inputCol value is kept.
    *
    * @param sparkSession spark session.
    * @param df input dataframe.
    * @param outputCol name of the new column to be added.
    * @param inputCol column whose value is searched for.
    * @param replaceWith value with which to replace the searched value if found.
    * @param value first element of the search array.
    * @param values remaining elements of the search array.
    * @return dataframe with a new column named outputCol.
    */
  def replaceString(
    sparkSession: SparkSession,
    df:           DataFrame,
    outputCol:    String,
    inputCol:     String,
    replaceWith:  String,
    value:        String,
    values:       String*
  ) =
    df.withColumn(
      outputCol,
      replace_string(
        df(inputCol),
        lit(replaceWith),
        arrayColumn(value, values: _*)
      ).as(outputCol)
    )

  /**
    * Function to add a new column to the passed dataframe. The new column's value depends on whether
    * the inputCol value is found in the array built from null, value, and values: if found, the value
    * of replaceWith is written to the new column; otherwise the original inputCol value is kept.
    *
    * @param sparkSession spark session.
    * @param df input dataframe.
    * @param outputCol name of the new column to be added.
    * @param inputCol column whose value is searched for.
    * @param replaceWith value with which to replace the searched value if found.
    * @param value first element of the search array.
    * @param values remaining elements of the search array.
    * @return dataframe with a new column named outputCol.
    */
  def replaceStringNull(
    sparkSession: SparkSession,
    df:           DataFrame,
    outputCol:    String,
    inputCol:     String,
    replaceWith:  String,
    value:        String,
    values:       String*
  ) = {
    val allValues = value :: values.toList
    df.withColumn(
      outputCol,
      replace_string(
        df(inputCol),
        lit(replaceWith),
        arrayColumn(null, allValues: _*)
      ).as(outputCol)
    )
  }

  /**
    * Function to add a new column to the passed dataframe. The new column's value depends on whether
    * the inputCol value is found in the array built from value and values: if found, null is written
    * to the new column; otherwise the original inputCol value is kept.
    *
    * @param sparkSession spark session.
    * @param df input dataframe.
    * @param outputCol name of the new column to be added.
    * @param inputCol column whose value is searched for.
    * @param value first element of the search array.
    * @param values remaining elements of the search array.
    * @return dataframe with a new column named outputCol.
    */
  def replaceStringWithNull(
    sparkSession: SparkSession,
    df:           DataFrame,
    outputCol:    String,
    inputCol:     String,
    value:        String,
    values:       String*
  ) =
    df.withColumn(
      outputCol,
      replace_string_with_null(df(inputCol), arrayColumn(value, values: _*))
        .as(outputCol)
    )
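
  // Usage sketch (hypothetical names): assumes a SparkSession `spark` and a
  // DataFrame `df` with a string column "status".
  //
  //   // "status_fixed" holds "inactive" wherever "status" is "closed" or "archived",
  //   // and the original value everywhere else.
  //   val out = replaceString(spark, df, "status_fixed", "status", "inactive", "closed", "archived")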

  /**
    * Function to split the column colName in the input dataframe into multiple columns using the
    * split pattern. If a prefix is provided, each generated column is named with the prefix followed
    * by the column number; otherwise the original column name is used as the prefix.
    *
    * @param sparkSession spark session.
    * @param df input dataframe.
    * @param colName column in the dataframe to be split into multiple columns.
    * @param pattern regex with which the column in the input dataframe is split into
    *                multiple columns.
    * @param prefix column prefix to be used for all newly generated columns.
    * @return new dataframe whose added columns hold the values produced by splitting
    *         the original column colName.
    */
  def splitIntoMultipleColumns(
    sparkSession: SparkSession,
    df:           DataFrame,
    colName:      String,
    pattern:      String,
    prefix:       String = null
  ): DataFrame = {
    val intermittentColumn = s"${colName}_result_split_multiple_columns"
    val intermittentSizeColumn =
      s"${colName}_result_split_multiple_columns_size"
    val t1 = df.withColumn(
      intermittentColumn,
      functions.split(df(colName), pattern).as(intermittentColumn)
    )
    val newdf = t1.withColumn(
      intermittentSizeColumn,
      functions.size(t1(intermittentColumn)).as(intermittentSizeColumn)
    )
    val row = newdf
      .select(functions.max(newdf(intermittentSizeColumn)).as("max"))
      .collect()(0)
    val maxLength = row.getInt(0)
    val finalPrefix = {
      if (prefix == null) {
        colName
      } else {
        prefix
      }
    }
    var tmp = newdf
    for (x ← 0 until maxLength) {
      val outputColumn = s"${finalPrefix}_${x + 1}"
      tmp = tmp.withColumn(
        outputColumn,
        array_value(tmp(intermittentColumn), lit(x)).as(outputColumn)
      )
    }
    tmp = tmp.drop(intermittentColumn, intermittentSizeColumn)
    tmp
  }
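
  // Usage sketch (hypothetical names): assumes a DataFrame `df` with a string
  // column "full_name" holding comma-separated parts.
  //
  //   // Produces columns name_1, name_2, ... up to the longest split in the data.
  //   val out = splitIntoMultipleColumns(spark, df, "full_name", ",", prefix = "name")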

  /**
    * Function to add a new typecast column to the input dataframe. The added column is the
    * passed column cast to the requested type. Casting is supported to string, boolean, byte,
    * short, int, long, float, double, decimal, date, and timestamp; for any other type the
    * dataframe is returned unchanged.
    *
    * @param sparkSession spark session.
    * @param df input dataframe.
    * @param column input column to be cast.
    * @param dataType data type to cast the column to.
    * @param replaceColumn name of the column to be added to the dataframe.
    * @return new dataframe with the new typecast column.
    */
  def castDataType(
    sparkSession:  SparkSession,
    df:            DataFrame,
    column:        Column,
    dataType:      String,
    replaceColumn: String
  ): DataFrame = {
    val validCast = Set(
      "string", "boolean", "byte", "short", "int", "long",
      "float", "double", "decimal", "date", "timestamp"
    ).contains(dataType)
    if (validCast) {
      df.withColumn(replaceColumn, column.cast(dataType).as(replaceColumn))
    } else {
      df
    }
  }
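
  // Usage sketch (hypothetical names): assumes a DataFrame `df` with a string
  // column "amount" holding numeric text.
  //
  //   // Adds "amount_dbl" as a double column; an unsupported dataType leaves df unchanged.
  //   val out = castDataType(spark, df, df("amount"), "double", "amount_dbl")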

  /**
    * Function to drop the passed columns from the input dataframe.
    *
    * @param sparkSession spark session.
    * @param df input dataframe.
    * @param columns columns to be dropped from the dataframe.
    * @return new dataframe without the dropped columns.
    */
  def dropColumns(sparkSession: SparkSession, df: DataFrame, columns: Column*): DataFrame =
    columns.foldLeft(df)((acc, c) ⇒ acc.drop(c))

  /**
    * Method to create a range-lookup UDF. The dataframe's rows are first collected into a
    * broadcast variable; the UDF then searches that data for its input value. If the input lies
    * in the half-open range [minColumn, maxColumn) of some row, the value columns of the first
    * such row are returned; otherwise a row of nulls is returned. Supported input types are
    * string, int, long, decimal, float, short, and double.
    *
    * @param name name under which the UDF is registered.
    * @param df input dataframe.
    * @param spark spark session.
    * @param minColumn column whose value is treated as the inclusive lower bound.
    * @param maxColumn column whose value is treated as the exclusive upper bound.
    * @param valueColumns remaining column names to be part of the result.
    * @return the registered UDF, which returns the matching value columns for each input value.
    */
  def createRangeLookup(
    name:         String,
    df:           DataFrame,
    spark:        SparkSession,
    minColumn:    String,
    maxColumn:    String,
    valueColumns: String*
  ): UserDefinedFunction = {
    def error(message: String) = throw new Exception(s"Error creating a ranged lookup $name: $message")

    val rowSchema = df.schema.filter { x ⇒
      valueColumns.contains(x.name)
    }

    val minColumnType = df.schema.fields
      .find(_.name == minColumn)
      .getOrElse(error(s"Min column $minColumn doesn't exist in the DataFrame (${df.schema.sql})."))
    val maxColumnType = df.schema.fields
      .find(_.name == maxColumn)
      .getOrElse(error(s"Max column $maxColumn doesn't exist in the DataFrame (${df.schema.sql})."))

    if (minColumnType.dataType != maxColumnType.dataType) {
      error(
        s"Max and min column types have to be the same, but found ${maxColumnType.dataType} & ${minColumnType.dataType}"
      )
    }

    val collected = measure(df.collect())(s"$name - collecting lookup data")

    // Spark broadcast serialization is slow for the deeply nested data structures (such as
    // BigDecimal) that a Row can contain. Therefore, we convert Rows into UnsafeRows here, which
    // ensures that only the byte array buffers stored within the UnsafeRows are serialized.
    val data = measure(
      collected
        .map { row ⇒
          val unsafeRowWriter = new UnsafeRowWriterCrossSupport(rowSchema.size + 2)
          unsafeRowWriter.reset()

          val allColumns = minColumn :: maxColumn :: valueColumns.toList
          allColumns.zipWithIndex.foreach {
            case (columnName, idx) ⇒
              val dataType = df.schema.apply(columnName).dataType
              dataType match {
                case StringType ⇒
                  val value = row.getString(row.fieldIndex(columnName))
                  unsafeRowWriter.writer.write(idx, UTF8String.fromString(value))
                case IntegerType ⇒
                  val value = row.getInt(row.fieldIndex(columnName))
                  unsafeRowWriter.writer.write(idx, value)
                case LongType ⇒
                  val value = row.getLong(row.fieldIndex(columnName))
                  unsafeRowWriter.writer.write(idx, value)
                case FloatType ⇒
                  val value = row.getFloat(row.fieldIndex(columnName))
                  unsafeRowWriter.writer.write(idx, value)
                case ShortType ⇒
                  val value = row.getShort(row.fieldIndex(columnName))
                  unsafeRowWriter.writer.write(idx, value)
                case DoubleType ⇒
                  val value = row.getDouble(row.fieldIndex(columnName))
                  unsafeRowWriter.writer.write(idx, value)
                case DecimalType() ⇒
                  val value = row.getDecimal(row.fieldIndex(columnName))
                  unsafeRowWriter.writer.write(idx, Decimal(value), value.precision(), value.scale())
                case x ⇒ error(s"Column $columnName has unsupported type: $x")
              }
          }

          unsafeRowWriter.getRow
        }
    )(s"$name - conversion of rows to unsafe rows")

    val broadcastVer = measure(spark.sparkContext.broadcast(data))(s"$name - broadcasting lookup data")

    val schema       = StructType(minColumnType :: maxColumnType :: rowSchema.toList)
    val deserializer = new RowEncoderCrossSupport(schema)

    def findValue[Type](x: Type)(implicit ord: Ordering[Type]) =
      broadcastVer.value
        .map(unsafeRow ⇒ deserializer.deserialize(unsafeRow))
        .find(row ⇒ ord.lteq(row.get(0).asInstanceOf[Type], x) && ord.gt(row.get(1).asInstanceOf[Type], x))
        .map(row ⇒ Row.apply(row.toSeq.drop(2): _*))
        .getOrElse(Row.apply(rowSchema.map(_ ⇒ null): _*))

    val rangeUdf = minColumnType.dataType match {
      case StringType    ⇒ udf((x: String) ⇒ findValue(x),               StructType(rowSchema))
      case IntegerType   ⇒ udf((x: Int) ⇒ findValue(x),                  StructType(rowSchema))
      case LongType      ⇒ udf((x: Long) ⇒ findValue(x),                 StructType(rowSchema))
      case DecimalType() ⇒ udf((x: java.math.BigDecimal) ⇒ findValue(x), StructType(rowSchema))
      case FloatType     ⇒ udf((x: Float) ⇒ findValue(x),                StructType(rowSchema))
      case ShortType     ⇒ udf((x: Short) ⇒ findValue(x),                StructType(rowSchema))
      case DoubleType    ⇒ udf((x: Double) ⇒ findValue(x),               StructType(rowSchema))
      case _             ⇒ error(s"Unsupported type of min-column $minColumn: ${minColumnType.dataType}")
    }

    spark.udf.register(name, rangeUdf)
  }
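
  // Usage sketch (hypothetical names): assumes a `ranges` DataFrame with columns
  // min_age, max_age, and label, and a `people` DataFrame with an int column "age".
  //
  //   createRangeLookup("age_band", ranges, spark, "min_age", "max_age", "label")
  //   // Each row gets the value columns of the first range with min_age <= age < max_age,
  //   // or nulls when no range matches.
  //   val out = people.withColumn("band", lookup_range("age_band", people("age")))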

  def measure[T](fn: ⇒ T)(caller: String = findCaller()): T = {
    val start  = System.nanoTime()
    val result = fn
    logger.info(s"$caller: %.6f seconds".format((System.nanoTime() - start) / 1e9))
    result
  }

  private def findCaller(): String = {
    val callerThread       = Thread.getAllStackTraces.get(Thread.currentThread())
    val callerStackElement = callerThread(4)
    s"(${callerStackElement.getFileName}:${callerStackElement.getLineNumber})"
  }

  implicit private class IterableIntExtension(iterable: Iterable[Int]) {
    def maxOrZero: Int = if (iterable.nonEmpty) iterable.max else 0
  }

  /**
    * Function to register seven lookup UDFs with the Spark registry: lookup, lookup_last,
    * lookup_match, lookup_count, lookup_row, lookup_row_reverse, and lookup_nth. The data of the
    * input dataframe is collected into a broadcast variable, which the lookup functions then query.
    *
    * lookup : returns the first matching row for the given input keys.
    * lookup_last : returns the last matching row for the given input keys.
    * lookup_match : returns false if there is no matching row and true otherwise for the given input keys.
    * lookup_count : returns the count of all matching rows for the given input keys.
    * lookup_row : returns all matching rows for the given input keys.
    * lookup_row_reverse : returns all matching rows in reverse order for the given input keys.
    * lookup_nth : returns the nth matching row for the given input keys.
    *
    * UDFs are registered for up to 10 input keys.
    *
    * @param name UDF name.
    * @param df input dataframe.
    * @param spark spark session.
    * @param keyCols columns to be used as keys in the lookup functions.
    * @param rowCols columns stored as the row value for each matching key.
    * @return the registered UDF definition for the plain lookup function; the individual
    *         lookup functions return different results as described above.
    */
  def createLookup(
    name:    String,
    df:      DataFrame,
    spark:   SparkSession,
    keyCols: List[String],
    rowCols: String*
  ): UserDefinedFunction = {
    val rowFields = df.schema
      .filter { x ⇒
        rowCols.contains(x.name)
      }
      .map(x ⇒ (x.name, x))
      .toMap
    val rowSchema = StructType(rowCols.map(x ⇒ rowFields(x)))

    val keyFields = df.schema
      .filter { x ⇒
        keyCols.contains(x.name)
      }
      .map(x ⇒ (x.name, x))
      .toMap
    val keySchema = StructType(keyCols.map(x ⇒ keyFields(x)))

    val fullRowSchema = StructType(
      List(
        StructField("key", keySchema, false),
        StructField("row", rowSchema, false)
      )
    )
    val tmpDataMap = mutable.Map[List[Any], mutable.ArrayBuffer[Row]]()

    val data = df
      .collect()
      .map { x ⇒
        new GenericRowWithSchema(
          Array(
            new GenericRowWithSchema(
              keyCols.map { y ⇒
                x.getAs[Any](y)
              }.toArray,
              StructType(keySchema)
            ): Row,
            new GenericRowWithSchema(
              rowCols.map { y ⇒
                x.getAs[Any](y)
              }.toArray,
              StructType(rowSchema)
            ): Row
          ),
          fullRowSchema
        ): Row
      }
      .toList
    data.foreach { row ⇒
      val keyPartOfRow = row.getStruct(row.fieldIndex("key"))
      val value        = row.getStruct(row.fieldIndex("row")).toSeq
      val mapKey       = keyCols.map(keyPartOfRow.getAs[Any])
      val tmpList      = tmpDataMap.getOrElse(mapKey, mutable.ArrayBuffer[Row]())
      tmpList += Row.fromSeq(value)
      tmpDataMap.put(mapKey, tmpList)
    }

    // .mapValues in Scala is non-strict and not serializable, so an identity map must be run
    // afterwards. This is a Scala bug, as described here: https://github.com/scala/bug/issues/7005
    val dataMap = tmpDataMap.mapValues(_.toArray).map(identity)

    val entriesCount      = dataMap.map(_._2.length).sum
    val keyColumnsCount   = dataMap.map(_._1.size).maxOrZero
    val valueColumnsCount = dataMap.map(_._2.map(_.length).toIterable.maxOrZero).maxOrZero
    logger.info(
      s"Created lookup $name with $entriesCount entries, " +
        s"$keyColumnsCount key columns, and $valueColumnsCount value columns."
    )

    val (
      lookupUDF,
      lookup_lastUDF,
      lookup_matchUDF,
      lookup_countUDF,
      lookup_rowUDF,
      lookup_row_reverseUDF,
      lookup_nthUDF
    ) = {
      val broadcastVerMap = spark.sparkContext.broadcast(dataMap)

      keyCols.length match {
        case 0 ⇒ throw new Exception("lookup udf can't have 0 arguments.")
        case 1 ⇒
          (
            udf(
              (x: Any) ⇒ broadcastVerMap.value.getOrElse(List(x), Array.empty).headOption,
              rowSchema
            ),
            udf(
              (x: Any) ⇒ broadcastVerMap.value.getOrElse(List(x), Array.empty).lastOption,
              rowSchema
            ),
            udf { (x: Any) ⇒
              broadcastVerMap.value.contains(List(x))
            },
            udf { (x: Any) ⇒
              broadcastVerMap.value.getOrElse(List(x), Array.empty).length
            },
            udf(
              (x: Any) ⇒ broadcastVerMap.value.getOrElse(List(x), Array.empty),
              DataTypes.createArrayType(rowSchema, false)
            ),
            udf(
              (x: Any) ⇒ broadcastVerMap.value.getOrElse(List(x), Array.empty).reverse,
              DataTypes.createArrayType(rowSchema, false)
            ),
            udf(
              (x: Any, index: Int) ⇒ {
                val result = broadcastVerMap.value.getOrElse(List(x), Array.empty)
                if (result.length > index) Some(result(index)) else None
              },
              rowSchema
            )
          )
        case 2 ⇒
          (
            udf(
              (x: Any, y: Any) ⇒ broadcastVerMap.value.getOrElse(List(x, y), Array.empty).headOption,
              rowSchema
            ),
            udf(
              (x: Any, y: Any) ⇒ broadcastVerMap.value.getOrElse(List(x, y), Array.empty).lastOption,
              rowSchema
            ),
            udf { (x: Any, y: Any) ⇒
              broadcastVerMap.value.contains(List(x, y))
            },
            udf { (x: Any, y: Any) ⇒
              broadcastVerMap.value.getOrElse(List(x, y), Array.empty).length
            },
            udf(
              (x: Any, y: Any) ⇒ broadcastVerMap.value.getOrElse(List(x, y), Array.empty),
              DataTypes.createArrayType(rowSchema, false)
            ),
            udf(
              (x: Any, y: Any) ⇒ broadcastVerMap.value.getOrElse(List(x, y), Array.empty).reverse,
              DataTypes.createArrayType(rowSchema, false)
            ),
            udf(
              (x: Any, y: Any, index: Int) ⇒ {
                val result = broadcastVerMap.value.getOrElse(List(x, y), Array.empty)
                if (result.length > index) Some(result(index)) else None
              },
              rowSchema
            )
          )
        case 3 ⇒
          (
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any) ⇒
                broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2), Array.empty).headOption,
              rowSchema
            ),
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any) ⇒
                broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2), Array.empty).lastOption,
              rowSchema
            ),
            udf { (tmp0: Any, tmp1: Any, tmp2: Any) ⇒
              broadcastVerMap.value.contains(List(tmp0, tmp1, tmp2))
            },
            udf { (tmp0: Any, tmp1: Any, tmp2: Any) ⇒
              broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2), Array.empty).length
            },
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any) ⇒ broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2), Array.empty),
              DataTypes.createArrayType(rowSchema, false)
            ),
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any) ⇒
                broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2), Array.empty).reverse,
              DataTypes.createArrayType(rowSchema, false)
            ),
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any, index: Int) ⇒ {
                val result = broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2), Array.empty)
                if (result.length > index) Some(result(index)) else None
              },
              rowSchema
            )
          )
        case 4 ⇒
          (
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any) ⇒
                broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2, tmp3), Array.empty).headOption,
              rowSchema
            ),
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any) ⇒
                broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2, tmp3), Array.empty).lastOption,
              rowSchema
            ),
            udf { (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any) ⇒
              broadcastVerMap.value.contains(List(tmp0, tmp1, tmp2, tmp3))
            },
            udf { (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any) ⇒
              broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2, tmp3), Array.empty).length
            },
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any) ⇒
                broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2, tmp3), Array.empty),
              DataTypes.createArrayType(rowSchema, false)
            ),
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any) ⇒
                broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2, tmp3), Array.empty).reverse,
              DataTypes.createArrayType(rowSchema, false)
            ),
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, index: Int) ⇒ {
                val result = broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2, tmp3), Array.empty)
                if (result.length > index) Some(result(index)) else None
              },
              rowSchema
            )
          )
        case 5 ⇒
          (
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any) ⇒
                broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4), Array.empty).headOption,
              rowSchema
            ),
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any) ⇒
                broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4), Array.empty).lastOption,
              rowSchema
            ),
            udf { (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any) ⇒
              broadcastVerMap.value.contains(List(tmp0, tmp1, tmp2, tmp3, tmp4))
            },
            udf { (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any) ⇒
              broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4), Array.empty).length
            },
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any) ⇒
                broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4), Array.empty),
              DataTypes.createArrayType(rowSchema, false)
            ),
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any) ⇒
                broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4), Array.empty).reverse,
              DataTypes.createArrayType(rowSchema, false)
            ),
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any, index: Int) ⇒ {
                val result = broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4), Array.empty)
                if (result.length > index) Some(result(index)) else None
              },
              rowSchema
            )
          )
        case 6 ⇒
          (
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any, tmp5: Any) ⇒
                broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5), Array.empty).headOption,
              rowSchema
            ),
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any, tmp5: Any) ⇒
                broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5), Array.empty).lastOption,
              rowSchema
            ),
            udf { (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any, tmp5: Any) ⇒
              broadcastVerMap.value.contains(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5))
            },
            udf { (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any, tmp5: Any) ⇒
              broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5), Array.empty).length
            },
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any, tmp5: Any) ⇒
                broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5), Array.empty),
              DataTypes.createArrayType(rowSchema, false)
            ),
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any, tmp5: Any) ⇒
                broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5), Array.empty).reverse,
              DataTypes.createArrayType(rowSchema, false)
            ),
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any, tmp5: Any, index: Int) ⇒ {
                val result = broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5), Array.empty)
                if (result.length > index) Some(result(index)) else None
              },
              rowSchema
            )
          )
        case 7 ⇒
          (
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any, tmp5: Any, tmp6: Any) ⇒
                broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6), Array.empty).headOption,
              rowSchema
            ),
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any, tmp5: Any, tmp6: Any) ⇒
                broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6), Array.empty).lastOption,
              rowSchema
            ),
            udf { (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any, tmp5: Any, tmp6: Any) ⇒
              broadcastVerMap.value.contains(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6))
            },
            udf { (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any, tmp5: Any, tmp6: Any) ⇒
              broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6), Array.empty).length
            },
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any, tmp5: Any, tmp6: Any) ⇒
                broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6), Array.empty),
              DataTypes.createArrayType(rowSchema, false)
            ),
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any, tmp5: Any, tmp6: Any) ⇒
                broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6), Array.empty).reverse,
              DataTypes.createArrayType(rowSchema, false)
            ),
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any, tmp5: Any, tmp6: Any, index: Int) ⇒ {
                val result =
                  broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6), Array.empty)
                if (result.length > index) Some(result(index)) else None
              },
              rowSchema
            )
          )
        case 8 ⇒
          (
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any, tmp5: Any, tmp6: Any, tmp7: Any) ⇒
                broadcastVerMap.value
                  .getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7), Array.empty)
                  .headOption,
              rowSchema
            ),
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any, tmp5: Any, tmp6: Any, tmp7: Any) ⇒
                broadcastVerMap.value
                  .getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7), Array.empty)
                  .lastOption,
              rowSchema
            ),
            udf { (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any, tmp5: Any, tmp6: Any, tmp7: Any) ⇒
              broadcastVerMap.value.contains(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7))
            },
            udf { (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any, tmp5: Any, tmp6: Any, tmp7: Any) ⇒
              broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7), Array.empty).length
            },
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any, tmp5: Any, tmp6: Any, tmp7: Any) ⇒
                broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7), Array.empty),
              DataTypes.createArrayType(rowSchema, false)
            ),
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any, tmp5: Any, tmp6: Any, tmp7: Any) ⇒
                broadcastVerMap.value
                  .getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7), Array.empty)
                  .reverse,
              DataTypes.createArrayType(rowSchema, false)
            ),
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any, tmp5: Any, tmp6: Any, tmp7: Any, index: Int) ⇒ {
                val result =
                  broadcastVerMap.value.getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7), Array.empty)
                if (result.length > index) Some(result(index)) else None
              },
              rowSchema
            )
          )
        case 9 ⇒
          (
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any, tmp5: Any, tmp6: Any, tmp7: Any, tmp8: Any) ⇒ {
                broadcastVerMap.value
                  .getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8), Array.empty)
                  .headOption
              },
              rowSchema
            ),
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any, tmp5: Any, tmp6: Any, tmp7: Any, tmp8: Any) ⇒ {
                broadcastVerMap.value
                  .getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8), Array.empty)
                  .lastOption
              },
              rowSchema
            ),
            udf { (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any, tmp5: Any, tmp6: Any, tmp7: Any, tmp8: Any) ⇒
              broadcastVerMap.value.contains(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8))
            },
            udf { (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any, tmp5: Any, tmp6: Any, tmp7: Any, tmp8: Any) ⇒
              broadcastVerMap.value
                .getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8), Array.empty)
                .length
            },
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any, tmp5: Any, tmp6: Any, tmp7: Any, tmp8: Any) ⇒
                broadcastVerMap.value
                  .getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8), Array.empty),
              DataTypes.createArrayType(rowSchema, false)
            ),
            udf(
              (tmp0: Any, tmp1: Any, tmp2: Any, tmp3: Any, tmp4: Any, tmp5: Any, tmp6: Any, tmp7: Any, tmp8: Any) ⇒
                broadcastVerMap.value
                  .getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8), Array.empty)
                  .reverse,
              DataTypes.createArrayType(rowSchema, false)
            ),
            udf(
              (
                tmp0:  Any,
                tmp1:  Any,
                tmp2:  Any,
                tmp3:  Any,
                tmp4:  Any,
                tmp5:  Any,
                tmp6:  Any,
                tmp7:  Any,
                tmp8:  Any,
                index: Int
              ) ⇒ {
                val result =
                  broadcastVerMap.value
                    .getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8), Array.empty)
                if (result.length > index) Some(result(index)) else None
              },
              rowSchema
            )
          )
        case 10 ⇒
          (
            udf(
              (
                tmp0: Any,
                tmp1: Any,
                tmp2: Any,
                tmp3: Any,
                tmp4: Any,
                tmp5: Any,
                tmp6: Any,
                tmp7: Any,
                tmp8: Any,
                tmp9: Any
              ) ⇒ {
                broadcastVerMap.value
                  .getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9), Array.empty)
                  .headOption
              },
              rowSchema
            ),
            udf(
              (
                tmp0: Any,
                tmp1: Any,
                tmp2: Any,
                tmp3: Any,
                tmp4: Any,
                tmp5: Any,
                tmp6: Any,
                tmp7: Any,
                tmp8: Any,
                tmp9: Any
              ) ⇒ {
                broadcastVerMap.value
                  .getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9), Array.empty)
                  .lastOption
              },
              rowSchema
            ),
            udf {
              (
                tmp0: Any,
                tmp1: Any,
                tmp2: Any,
                tmp3: Any,
                tmp4: Any,
                tmp5: Any,
                tmp6: Any,
                tmp7: Any,
                tmp8: Any,
                tmp9: Any
              ) ⇒
                broadcastVerMap.value.contains(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9))
            },
            udf {
              (
                tmp0: Any,
                tmp1: Any,
                tmp2: Any,
                tmp3: Any,
                tmp4: Any,
                tmp5: Any,
                tmp6: Any,
                tmp7: Any,
                tmp8: Any,
                tmp9: Any
              ) ⇒
                broadcastVerMap.value
                  .getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9), Array.empty)
                  .length
            },
            udf(
              (
                tmp0: Any,
                tmp1: Any,
                tmp2: Any,
                tmp3: Any,
                tmp4: Any,
                tmp5: Any,
                tmp6: Any,
                tmp7: Any,
                tmp8: Any,
                tmp9: Any
              ) ⇒
                broadcastVerMap.value
                  .getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9), Array.empty),
              DataTypes.createArrayType(rowSchema, false)
            ),
            udf(
              (
                tmp0: Any,
                tmp1: Any,
                tmp2: Any,
                tmp3: Any,
                tmp4: Any,
                tmp5: Any,
                tmp6: Any,
                tmp7: Any,
                tmp8: Any,
                tmp9: Any
              ) ⇒
                broadcastVerMap.value
                  .getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9), Array.empty)
                  .reverse,
              DataTypes.createArrayType(rowSchema, false)
            ),
            udf(
              (
                tmp0:  Any,
                tmp1:  Any,
                tmp2:  Any,
                tmp3:  Any,
                tmp4:  Any,
                tmp5:  Any,
                tmp6:  Any,
                tmp7:  Any,
                tmp8:  Any,
                tmp9:  Any,
                index: Int
              ) ⇒ {
                val result =
                  broadcastVerMap.value
                    .getOrElse(List(tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9), Array.empty)
                if (result.length > index) Some(result(index)) else None
              },
              rowSchema
            )
          )
      }
    }

    if (
      lookupUDF != null && lookup_matchUDF != null && lookup_countUDF != null && lookup_rowUDF != null && lookup_nthUDF != null && lookup_row_reverseUDF != null && lookup_lastUDF != null
    ) {
      spark.udf.register(name,                  lookupUDF)
      spark.udf.register(name + "_last",        lookup_lastUDF)
      spark.udf.register(name + "_match",       lookup_matchUDF)
      spark.udf.register(name + "_count",       lookup_countUDF)
      spark.udf.register(name + "_row",         lookup_rowUDF)
      spark.udf.register(name + "_row_reverse", lookup_row_reverseUDF)
      spark.udf.register(name + "_nth",         lookup_nthUDF)

      // Deprecated functions:
      spark.udf.register(name + "_lookup_last",        lookup_lastUDF)
      spark.udf.register(name + "_lookup_match",       lookup_matchUDF)
      spark.udf.register(name + "_lookup_count",       lookup_countUDF)
      spark.udf.register(name + "_lookup_row",         lookup_rowUDF)
      spark.udf.register(name + "_lookup_row_reverse", lookup_row_reverseUDF)
      spark.udf.register(name + "_lookup_nth",         lookup_nthUDF)
    } else {
      null
    }
  }

  /**
    * By default, returns only the first matching record.
    */
  def lookup(lookupName: String, cols: Column*): Column =
    callUDF(lookupName, cols: _*)

  /**
    * Returns the last matching record.
    * @param lookupName name of the registered lookup.
    * @param cols key columns to look up.
    * @return the last matching row as a column.
    */
  def lookup_last(lookupName: String, cols: Column*): Column =
    callUDF(lookupName + "_last", cols: _*)

  /**
    * Returns whether any record matches the given keys.
    * @param lookupName name of the registered lookup.
    * @param cols key columns to look up.
    * @return a boolean column.
    */
  def lookup_match(lookupName: String, cols: Column*): Column =
    callUDF(lookupName + "_match", cols: _*)

  def lookup_count(lookupName: String, cols: Column*): Column =
    callUDF(lookupName + "_count", cols: _*)

  def lookup_row(lookupName: String, cols: Column*): Column =
    callUDF(lookupName + "_row", cols: _*)

  def lookup_row_reverse(lookupName: String, cols: Column*): Column =
    callUDF(lookupName + "_row_reverse", cols: _*)

  def lookup_nth(lookupName: String, cols: Column*): Column =
    callUDF(lookupName + "_nth", cols: _*)

  def lookup_range(lookupName: String, input: Column): Column =
    callUDF(lookupName, input)
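
  // Usage sketch (hypothetical names): assumes a `countries` DataFrame with key
  // column "code" and value columns "name" and "region", and an `orders` DataFrame
  // with a string column "country_code".
  //
  //   createLookup("country", countries, spark, List("code"), "name", "region")
  //   val out = orders
  //     .withColumn("country",     lookup("country", orders("country_code")))
  //     .withColumn("has_country", lookup_match("country", orders("country_code")))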

  def registerProphecyUdfs(spark: SparkSession): Unit =
    registerRestAPIUdfs(spark)

}



