// All downloads are free. Search and download functionality uses the official Maven repository.
//
// ai.starlake.utils.Job.scala (Maven / Gradle / Ivy)

package ai.starlake.utils

import ai.starlake.config.{Settings, SparkEnv, UdfRegistration}
import com.google.gson.Gson
import com.typesafe.scalalogging.StrictLogging
import org.apache.spark.SparkConf
import org.apache.spark.sql._

import java.io.{ByteArrayOutputStream, PrintStream}
import scala.jdk.CollectionConverters._
import scala.util.Try

/** Three counters attached to a job result.
  *
  * NOTE(review): judging by the names these are record counts (read / accepted / rejected) for one
  * ingestion run; nothing in this file enforces `inputCount == acceptedCount + rejectedCount` —
  * confirm with the producer of these values.
  *
  * @param inputCount
  *   value of the "input" counter
  * @param acceptedCount
  *   value of the "accepted" counter
  * @param rejectedCount
  *   value of the "rejected" counter
  */
case class IngestionCounters(
  inputCount: Long,
  acceptedCount: Long,
  rejectedCount: Long
)

/** Outcome of a job, with a helper to render tabular data as text.
  *
  * The default members describe a result that carries no data: `asMap()` yields `Nil` and the
  * `(format, dryRun)` overload of `prettyPrint` yields an empty string. The implementations in this
  * file override them and delegate to the `(format, headers, values)` overload.
  */
trait JobResult {

  /** Rows of the result, one map per row keyed by column name. Empty by default. */
  def asMap(): List[Map[String, Any]] = Nil

  /** Renders this result as text.
    *
    * @param format
    *   output format name; ignored by this default implementation
    * @param dryRun
    *   ignored by this default implementation
    * @return
    *   an empty string by default
    */
  def prettyPrint(format: String, dryRun: Boolean = false): String = ""

  /** Renders a header row and data rows as text.
    *
    * Supported formats:
    *   - `"csv"`: the header line followed by one line per row, comma-separated. Fields containing
    *     a comma, a double quote, CR or LF are wrapped in double quotes with embedded quotes
    *     doubled (RFC 4180); all other fields are written unchanged.
    *   - `"table"`: the output of `TableFormatter.format` applied to the header followed by the
    *     rows, or `Result is empty.` when there are neither headers nor rows.
    *   - `"json"`: one JSON object per row, one per line.
    *   - `"json-array"`: a single JSON array holding one object per row.
    *
    * For the JSON formats each row is zipped with `headers`, so surplus cells or surplus headers
    * are dropped, and keys are emitted in header order.
    *
    * @param format
    *   one of `csv`, `table`, `json`, `json-array`
    * @param headers
    *   column names
    * @param values
    *   data rows, each a list of already-stringified cells
    * @return
    *   the rendered text; every emitted line ends with the platform line separator
    * @throws java.lang.IllegalArgumentException
    *   if `format` is not one of the supported names
    */
  def prettyPrint(
    format: String,
    headers: List[String],
    values: List[List[String]]
  ): String = {
    // Encode and decode with the same explicit charset so the rendered text does not depend on
    // (and cannot be mangled by) the platform default encoding.
    val charsetName = java.nio.charset.StandardCharsets.UTF_8.name()
    val baos = new ByteArrayOutputStream()
    val printStream = new PrintStream(baos, true, charsetName)

    // Quotes a CSV field only when it would otherwise break the row structure.
    // A null cell is rendered as "null", which is what mkString produces for it.
    def csvField(field: String): String =
      Option(field).fold("null") { text =>
        val needsQuoting = text.exists(c => c == ',' || c == '"' || c == '\n' || c == '\r')
        if (needsQuoting) "\"" + text.replace("\"", "\"\"") + "\"" else text
      }

    // Insertion-ordered map so JSON keys follow the column order; on duplicate headers the last
    // value wins, exactly as with `toMap`.
    def rowAsJavaMap(row: List[String]): java.util.Map[String, String] = {
      val ordered = new java.util.LinkedHashMap[String, String]()
      headers.zip(row).foreach { case (header, cell) => ordered.put(header, cell) }
      ordered
    }

    // One Gson instance per call, created only if a JSON format is requested.
    lazy val gson = new Gson()

    format match {
      case "csv" =>
        (headers :: values).foreach { row =>
          printStream.println(row.map(csvField).mkString(","))
        }

      case "table" =>
        // `headers :: values` is never Nil, so emptiness has to be tested on its parts.
        if (headers.isEmpty && values.isEmpty)
          printStream.println("Result is empty.")
        else
          printStream.println(TableFormatter.format(headers :: values))

      case "json" =>
        values.foreach(row => printStream.println(gson.toJson(rowAsJavaMap(row))))

      case "json-array" =>
        printStream.println(gson.toJson(values.map(rowAsJavaMap).asJava))

      case unsupported =>
        throw new IllegalArgumentException(
          s"Unsupported format '$unsupported'; expected one of: csv, table, json, json-array"
        )
    }
    printStream.flush()
    baos.toString(charsetName)
  }
}
/*
          val values = rows.iterateAll().asScala.toList.map { row =>
            val fields = row
              .iterator()
              .asScala
              .toList
            asMap(fields, headers)

 */
/** Job result backed by an optional Spark DataFrame and optional counters.
  *
  * Both rendering methods call `collect()` on the DataFrame, so every row is brought to the
  * driver. When `dataframe` is `None` they return `Nil` / an empty string.
  *
  * @param dataframe
  *   the data produced by the job, if any
  * @param counters
  *   counters attached to the result, if any; not used by the methods of this class
  */
case class SparkJobResult(
  dataframe: Option[DataFrame],
  counters: Option[IngestionCounters]
) extends JobResult {

  /** Column names of `df`, in schema order. */
  private def columnNames(df: DataFrame): List[String] =
    df.schema.fields.toList.map(_.name)

  /** Collects `df` and turns every cell into its `toString`, with `null` cells shown as "NULL". */
  private def stringifiedRows(df: DataFrame): List[List[String]] =
    df.collect().toList.map { row =>
      row.toSeq.toList.map {
        case null => "NULL"
        case cell => cell.toString
      }
    }

  /** One map per collected row, keyed by column name, with stringified cell values. */
  override def asMap(): List[Map[String, Any]] =
    dataframe match {
      case None => Nil
      case Some(df) =>
        val names = columnNames(df)
        stringifiedRows(df).map(cells => names.zip(cells).toMap)
    }

  /** Renders the collected DataFrame through the `(format, headers, values)` overload.
    *
    * @param format
    *   output format name, passed through unchanged
    * @param dryRun
    *   not used here
    */
  override def prettyPrint(format: String, dryRun: Boolean = false): String =
    dataframe.fold("") { df =>
      val cells = stringifiedRows(df)
      val names = columnNames(df)
      prettyPrint(format, names, cells)
    }

}

/** Job result holding already-materialised rows as strings.
  *
  * @param headers
  *   column names
  * @param rows
  *   data rows, one list of cells per row; empty by default
  */
case class JdbcJobResult(headers: List[String], rows: List[List[String]] = Nil) extends JobResult {

  /** One map per row, pairing each header with the cell at the same position. */
  override def asMap(): List[Map[String, Any]] =
    for (row <- rows) yield headers.zip(row).toMap

  /** Renders `headers` and `rows` in the given format; `dryRun` is not used here. */
  override def prettyPrint(format: String, dryRun: Boolean = false): String =
    prettyPrint(format, headers, rows)

  /** Prints the rendered result to standard output. */
  def show(format: String): Unit =
    println(prettyPrint(format, headers, rows))
}
/** Companion of the `JobResult` trait. */
object JobResult {

  /** The result to use when a job produced no data: the `EmptyJobResult` singleton. */
  def empty: JobResult = EmptyJobResult
}
/** Result without data; it keeps every `JobResult` default (`asMap()` is `Nil`, `prettyPrint` is
  * an empty string).
  */
case object EmptyJobResult extends JobResult
/** Marker result that keeps every `JobResult` default, so it renders exactly like
  * `EmptyJobResult`.
  *
  * NOTE(review): presumably returned when a job failed; confirm against the callers.
  */
case object FailedJobResult extends JobResult

/** Base contract for every job: a name, the settings in scope, an application name and a `run`
  * entry point. It also mixes in logging support.
  */
trait JobBase extends StrictLogging with DatasetLogging {

  /** Name of the job; used as the prefix of the generated application name. */
  def name: String

  /** Settings available implicitly to the job. */
  implicit def settings: Settings

  /** Application name, resolved once, from the first source that yields a value:
    *   1. the `SL_JOB_ID` environment variable;
    *   1. the environment variable whose name is configured in `settings.appConfig.jobIdEnvName`;
    *   1. `name` followed by `-` and the current time in milliseconds.
    *
    * NOTE(review): this is a strict `val`, so it is computed while the trait is being initialised
    * and may read `name` / `settings` at that point; implementors presumably provide them as
    * `def` or `lazy val` — confirm.
    */
  val appName: String = {
    // Only evaluated when SL_JOB_ID is not set.
    def configuredJobId: Option[String] =
      for {
        envName <- settings.appConfig.jobIdEnvName
        jobId   <- Option(System.getenv(envName))
      } yield jobId

    Option(System.getenv("SL_JOB_ID")) match {
      case Some(jobId) => jobId
      case None        => configuredJobId.getOrElse(s"$name-${System.currentTimeMillis()}")
    }
  }

  /** Identifier of this run; same value as `appName`. */
  def applicationId(): String = appName

  /** Entry point every job has to implement.
    *
    * @return
    *   the job result wrapped in a `Try`
    */
  def run(): Try[JobResult]

}

/** Base trait of the Spark jobs. The Spark environment and session are `lazy`, so mixing in this
  * trait does not build a Spark session until one is actually requested.
  */
trait SparkJob extends JobBase {

  /** Adjusts the given Spark configuration for this job and returns it.
    *
    * The configuration is modified in place: two BigQuery options are removed, every entry of the
    * storage handler's extra configuration is added under the `spark.hadoop.` prefix, and the
    * application name and `spark.app.id` are both set to `appName`.
    *
    * @param sourceConfig
    *   configuration to adjust
    * @return
    *   the adjusted configuration
    */
  protected def withExtraSparkConf(sourceConfig: SparkConf): SparkConf = {
    // Schema updates are applied to the table before the data is written, and a merge writes in
    // WRITE_TRUNCATE mode even when WRITE_APPEND was requested. Schema validity is driven by the
    // YAML file, so these two connector options are handled automatically and removed here.
    val disabledOptions = Seq(
      "spark.datasource.bigquery.allowFieldAddition",
      "spark.datasource.bigquery.allowFieldRelaxation"
    )
    disabledOptions.foreach(option => sourceConfig.remove(option))

    val storageConf = settings.storageHandler().extraConf
    storageConf.foreach { case (key, value) =>
      sourceConfig.set("spark.hadoop." + key, value)
    }

    val namedConf = sourceConfig.setAppName(appName).set("spark.app.id", appName)
    logger.whenDebugEnabled {
      logger.debug(namedConf.toDebugString)
    }
    namedConf
  }

  // Created on first use only, see the trait documentation.
  private lazy val sparkEnv: SparkEnv = SparkEnv.get(name, withExtraSparkConf)

  /** Location of the table `domain.schema`; see the single-argument overload. */
  def getTableLocation(domain: String, schema: String): String = {
    val fullTableName = s"$domain.$schema"
    getTableLocation(fullTableName)
  }

  /** Location of a table, read from the output of `desc formatted`.
    *
    * Takes the second column of the first row whose `col_name` is `Location`. If no such row
    * exists, indexing the empty result fails with an `ArrayIndexOutOfBoundsException`.
    *
    * @param fullTableName
    *   table name, inserted as-is into the SQL statement
    */
  def getTableLocation(fullTableName: String): String = {
    import session.implicits._
    val description = session.sql(s"desc formatted $fullTableName").toDF()
    val locationRows = description.filter($"col_name" === "Location").collect()
    locationRows(0)(1).toString
  }

  /** Instantiates the class named `udf` through its no-argument constructor and lets it register
    * itself on the Spark session.
    *
    * @param udf
    *   fully qualified name of a class implementing `UdfRegistration`
    */
  protected def registerUdf(udf: String): Unit = {
    val constructor = Class.forName(udf).getDeclaredConstructor()
    val registration = constructor.newInstance().asInstanceOf[UdfRegistration]
    registration.register(sparkEnv.session)
  }

  /** Spark session of the job; on first access the UDFs returned by
    * `settings.appConfig.getEffectiveUdfs()` are registered.
    */
  lazy val session: SparkSession = {
    settings.appConfig.getEffectiveUdfs().foreach(registerUdf)
    sparkEnv.session
  }

}




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy