All Downloads are FREE. Search and download functionalities are using the official Maven repository.

sparkplug.models.PlugRule.scala Maven / Gradle / Ivy

package sparkplug.models

import org.apache.spark.sql.types._
import org.apache.spark.sql.{Dataset, Row}

import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import scala.util.{Success, Try}

case class PlugAction(key: String, value: String) {
  val updateKey = key.split('.').head
}

case class PlugActionConverted(key: String, value: Any)

case class PlugRule(name: String, version: String, condition: String, actions: Seq[PlugAction]) {

  lazy val fieldNames =
    actions.map(x => s""""${x.key}"""").mkString("array(", ",", ")")

  private val defaultValidation: ((Seq[PlugAction], StructType)) => List[PlugRuleValidationError] = _ => List.empty

  private val actionsNotEmpty: PartialFunction[(Seq[PlugAction], StructType), List[PlugRuleValidationError]] = {
    case (as, _) if as.isEmpty =>
      List(validationError("At the least one action must be specified per rule."))
  }

  private val actionFieldsPresentInSchema
    : PartialFunction[(Seq[PlugAction], StructType), List[PlugRuleValidationError]] = {
    case (as, schema) =>
      val fields = buildFieldsMap(schema).toMap

      as.flatMap((action: PlugAction) => {
          fields.get(action.key) match {
            case None =>
              Some(validationError(s"""Field "${action.key}" not found in the schema."""))
            case Some(x) if convertActionValueTo(action.value, x).isFailure =>
              Some(validationError(s"""Value "${action.value}" cannot be assigned to field ${action.key}."""))
            case _ =>
              None
          }
        })
        .toList
  }

  def validate(schema: StructType): List[PlugRuleValidationError] =
    List(actionsNotEmpty, actionFieldsPresentInSchema).flatMap(_.applyOrElse((actions, schema), defaultValidation))

  def asSql(schema: StructType, plugDetailsColumn: Option[String]) = {
    val notEqualsBuilder = new ListBuffer[String]
    val builder          = new StringBuilder

    val convertedActions = convertActions(schema)
    convertedActions.zipWithIndex.foldLeft(builder) {
      case (b, (action, i)) =>
        val actionKey   = action.key
        val actionValue = action.value
        notEqualsBuilder.append(s"not($actionKey <=> $actionValue)")
        val builderWithField = updateField(schema, b, actionKey, actionValue)
        if (i < convertedActions.length - 1) {
          builderWithField.append(",")
        } else {
          builderWithField
        }
    }

    if (plugDetailsColumn.nonEmpty) {
      val notEqualCondition = s"(${notEqualsBuilder.mkString(" or ")})"
      builder.append(
        s""", if($condition and $notEqualCondition, addPlugDetail(${plugDetailsColumn.get}, "$name", "$version", $fieldNames), ${plugDetailsColumn.get}) as ${updatedPlugDetailsColumn(
          plugDetailsColumn.get
        )}"""
      )
    }

    builder.mkString
  }

  def withColumnsRenamed(dataset: Dataset[Row], plugDetailsColumn: Option[String], isKeepOldField: Boolean) = {
    val pluggedDf = actions
      .foldLeft(dataset)((overridden, action) => {
        val modified = overridden
          .withColumnRenamed(action.updateKey, oldKey(action.updateKey))
          .withColumnRenamed(newKey(action.updateKey), action.updateKey)

        if (!isKeepOldField)
          modified.drop(oldKey(action.key))
        else
          modified
      })

    plugDetailsColumn.fold(pluggedDf) { c =>
      pluggedDf
        .drop(c)
        .withColumnRenamed(updatedPlugDetailsColumn(c), c)
    }
  }

  private def updatedPlugDetailsColumn(plugDetailsColumn: String) =
    s"${plugDetailsColumn}_updated"

  private def updateField(schema: StructType, builder: mutable.StringBuilder, actionKey: String, actionValue: Any) =
    actionKey match {
      case x if x.contains('.') =>
        val Array(parent, child) = x.split('.')
        val parentSchema = schema
          .filter((field: StructField) => field.name == parent)
          .head
          .dataType
          .asInstanceOf[StructType]
        builder.append(s"if($parent is null, null, named_struct(")
        val namedStruct = parentSchema.fields
          .map((field: StructField) => {
            if (field.name == child)
              s"'${field.name}', if($condition, $actionValue, $actionKey)"
            else
              s"'${field.name}',$parent.${field.name}"
          })
          .mkString(",")
        builder.append(namedStruct)
        builder.append(s")) as ${newKey(parent)}")
      case _ =>
        builder.append(s"if($condition, $actionValue, $actionKey) as ${newKey(actionKey)}")
    }

  private def convertActions(schema: StructType) = {
    val fields = buildFieldsMap(schema).toMap
    actions
      .map(x => PlugActionConverted(x.key, convertActionValueTo(x.value, fields(x.key)).getOrElse(null)))
  }

  private def convertActionValueTo(actionValue: String, dataType: DataType) =
    if (actionValue.contains('`')) {
      Success(actionValue.replace("`", ""))
    } else {
      Try(dataType match {
        case IntegerType => actionValue.toInt
        case DoubleType  => s"cast(${actionValue.toDouble} as double)"
        case StringType  => s""""${actionValue.toString}""""
      })
    }

  private def validationError(message: String) =
    PlugRuleValidationError(name, message)

  private def buildFieldsMap(schema: StructType, prefix: String = ""): Seq[(String, DataType)] =
    schema.fields.flatMap {
      case f if f.dataType.isInstanceOf[StructType] =>
        (s"$prefix${f.name}", f.dataType) +: buildFieldsMap(f.dataType.asInstanceOf[StructType], s"$prefix${f.name}.")
      case f => Seq((s"$prefix${f.name}", f.dataType))
    }

  private def oldKey(actionKey: String) = s"${actionKey}_${name}_old"

  private def newKey(actionKey: String) = s"${actionKey}_new"
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy