All downloads are free. Search and download functionality uses the official Maven repository.

io.stoys.spark.db.DbLoader.scala Maven / Gradle / Ivy

package io.stoys.spark.db

import io.stoys.scala.{Configuration, IO, Reflection, Strings}
import io.stoys.spark._
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, SaveMode}

import java.nio.file.{Files, Paths}
import java.sql.{Connection, DriverManager}
import java.time.format.DateTimeFormatter
import java.time.{Instant, ZoneOffset}
import scala.reflect.runtime.universe._
import scala.util.{Failure, Success}

/**
 * Loads Spark tables into a relational database over JDBC.
 *
 * [[SparkConfig]], [[SparkIOConfig]] and [[DbLoaderConfig]] are read from `args`. [[run]] then:
 *   1. runs `config.beforeLoadScript` (if set),
 *   2. creates the target schema (unless `config.disableSchemaCreation`),
 *   3. for every entry of `config.tableNames`: creates the table, appends its data and adds its constraints
 *      (each step can be switched off in the config),
 *   4. runs `config.afterLoadScript` (if set),
 *   5. optionally writes every executed SQL statement to `config.executedSqlOutputFile`.
 *
 * The parameters `timestamp` and `schema_name` are substituted (via `Strings.replaceParams`) into the
 * before/after load scripts and into the JDBC session init statement of the jdbc and write options.
 *
 * @param args command line arguments holding the configuration
 */
class DbLoader(args: Array[String]) {
  private val logger = org.log4s.getLogger

  private val configuration = Configuration(args)
  private val sparkConfig = configuration.readConfig[SparkConfig]
  private val sparkIOConfig = configuration.readConfig[SparkIOConfig]
  private val config = configuration.readConfig[DbLoaderConfig]

  // Either the explicitly configured timestamp (interpreted in UTC by TIMESTAMP_FORMATTER) or "now".
  // Instant.from keeps the type precise instead of widening it to TemporalAccessor.
  private val explicitTimestamp: Option[Instant] =
    Strings.trim(config.timestamp).map(ts => Instant.from(DbLoader.TIMESTAMP_FORMATTER.parse(ts)))
  private val timestampInstant: Instant = explicitTimestamp.getOrElse(Instant.now())
  private val timestamp = DbLoader.TIMESTAMP_FORMATTER.format(timestampInstant)
  private val timestampParam = Map("timestamp" -> timestamp)
  require(Strings.trim(config.schemaName).isDefined, "Please specify schema_name!")
  // The schema name itself may reference the timestamp parameter.
  private val schemaName = Strings.replaceParams(config.schemaName, params = timestampParam)
  private val params = timestampParam ++ Map("schema_name" -> schemaName)

  // Only the session init statement gets parameters substituted; all other options are passed through as-is.
  private val jdbcOptions = replaceParams(config.jdbcOptions, JDBCOptions.JDBC_SESSION_INIT_STATEMENT)
  private val writeOptions = replaceParams(config.sparkWriteOptions, JDBCOptions.JDBC_SESSION_INIT_STATEMENT)
  // Shared by the DriverManager connection in run() and by Spark's jdbc writer in writeTable().
  private val jdbcProperties = new java.util.Properties()
  jdbcOptions.foreach { case (key, value) => jdbcProperties.put(key, value) }
  config.jdbcUser.foreach(user => jdbcProperties.put("user", user))
  config.jdbcPassword.foreach(password => jdbcProperties.put("password", password))

  // Every statement successfully executed by runDbSql, in execution order.
  private val executedSqlStatements = Seq.newBuilder[String]

  /**
   * Runs the whole load (see the class documentation for the individual steps).
   *
   * Table names that cannot be resolved to a case class are logged and skipped; they do not abort the load.
   */
  def run(): Unit = {
    val sparkSession = SparkUtils.createSparkSession(sparkConfig)
    val sparkIO = new SparkIO(sparkSession, sparkIOConfig)

    IO.using(DriverManager.getConnection(config.jdbcUrl, jdbcProperties)) { connection =>
      runDbSqlFile(connection, config.beforeLoadScript, params = params)
      if (!config.disableSchemaCreation) {
        logger.info(s"Creating schema $schemaName.")
        runDbSql(connection, s"CREATE SCHEMA $schemaName")
      }
      val tableNameLookup = new TableNameLookup(lookupClasses(config.caseClassNames))
      config.tableNames.foreach { fullTableName =>
        tableNameLookup.lookupEntityTableName(fullTableName) match {
          case Some(tableName) => writeTable(tableName, sparkIO, connection)
          case None => logger.error(s"Unable to lookup TableName for '$fullTableName'")
        }
      }
      runDbSqlFile(connection, config.afterLoadScript, params = params)
    }

    config.executedSqlOutputFile.foreach { fileName =>
      logger.info(s"Writing executed sql statements to $fileName.")
      val statements = executedSqlStatements.result()
      // Terminate every statement with ";", but do not emit a lone ";" when nothing was executed.
      val content = if (statements.isEmpty) "" else statements.mkString("", ";\n\n", ";\n")
      // Explicit charset: getBytes() without arguments would depend on the platform default encoding.
      Files.write(Paths.get(fileName), content.getBytes(java.nio.charset.StandardCharsets.UTF_8))
    }
  }

  /**
   * Creates the table for `tableName` in `schemaName`, appends its data and adds its constraints.
   *
   * Each step is skipped when the corresponding flag is set: `disableTableCreation`, a non-positive
   * `limit`, or `disableConstrainCreation`.
   *
   * @param tableName  table (and its case class type `T`) to load
   * @param sparkIO    source of the table data
   * @param connection open JDBC connection used for DDL statements
   */
  def writeTable[T <: Product](tableName: TableName[T], sparkIO: SparkIO, connection: Connection): Unit = {
    val qualifiedTableName = JdbcReflection.getQualifiedTableName[T](tableName, schemaName)

    if (!config.disableTableCreation) {
      logger.info(s"Creating table $qualifiedTableName.")
      runDbSql(connection, JdbcReflection.getCreateTableStatement[T](tableName, schemaName))
    }

    // No limit means "write everything"; a limit <= 0 means "create the table but write no rows".
    if (config.limit.forall(_ > 0)) {
      logger.info(s"Writing $qualifiedTableName:")
      val table = normalizeTable[T](sparkIO.df(tableName))(tableName.typeTag)
      val limited = config.limit.map(table.limit).getOrElse(table)
      limited.write.mode(SaveMode.Append).options(writeOptions).jdbc(config.jdbcUrl, qualifiedTableName, jdbcProperties)
    }

    // Constraints are added after the data is written.
    if (!config.disableConstrainCreation) {
      logger.info(s"Creating constraints for table $qualifiedTableName.")
      val constraints = JdbcReflection.getAddConstraintStatements[T](tableName, schemaName)
      constraints.foreach(constraint => runDbSql(connection, constraint))
    }
  }

  /**
   * Substitutes this loader's `params` (`timestamp`, `schema_name`) into the values of the given keys.
   *
   * @param options option map to process
   * @param keys    keys whose values get parameters substituted; all other entries are returned unchanged
   * @return `options` with the same keys and the selected values substituted
   */
  def replaceParams(options: Map[String, String], keys: String*): Map[String, String] =
    options.map {
      case (key, value) if keys.contains(key) => key -> Strings.replaceParams(value, params)
      case unchanged => unchanged
    }

  /**
   * Loads the SQL script `fileName` as a class path resource and executes it via [[runDbSql]].
   * Does nothing when `fileName` is blank.
   *
   * @param connection open JDBC connection
   * @param fileName   resource name of the script (relative to this class)
   * @param params     parameters substituted into every statement
   */
  def runDbSqlFile(connection: Connection, fileName: String, params: Map[String, Any] = Map.empty): Unit = {
    if (Strings.trim(fileName).isDefined) {
      logger.info(s"Running sql script $fileName.")
      val sql = IO.resourceToString(this.getClass, fileName)
      runDbSql(connection, sql, params)
    }
  }

  /**
   * Executes every statement in `sql` and records it in `executedSqlStatements`.
   *
   * Statements are split naively on ";" (a ";" inside a string literal or comment also splits). `--` line
   * comments are removed via `Strings.unsafeRemoveLineComments`, and pieces for which `Strings.trim` yields
   * nothing are skipped (presumably blank ones). Parameters are substituted after splitting.
   *
   * @param connection open JDBC connection
   * @param sql        one or more ";"-separated SQL statements
   * @param params     parameters substituted into every statement
   */
  def runDbSql(connection: Connection, sql: String, params: Map[String, Any] = Map.empty): Unit = {
    // TODO: Can we do proper splitting and comments removing which actually understand sql language?
    val rawSqlStatements = sql.split(";").map(l => Strings.unsafeRemoveLineComments(l, "--")).flatMap(Strings.trim)
    IO.using(connection.createStatement()) { statement =>
      rawSqlStatements.foreach { rawSqlStatement =>
        val sqlStatement = Strings.replaceParams(rawSqlStatement, params)
        statement.execute(sqlStatement)
        // Recorded only after a successful execute.
        executedSqlStatements += sqlStatement
      }
    }
  }

  /**
   * Prepares `df` for a JDBC write of case class `T`.
   *
   * - Columns that are not fields of `T` are dropped; the remaining columns keep `df`'s order.
   * - Columns are renamed to snake_case.
   * - Array, map and struct columns are serialized with `to_json`.
   */
  def normalizeTable[T: TypeTag](df: DataFrame): DataFrame = {
    val expectedColumnNames = Reflection.getCaseClassFieldNames[T]
    val normalizedColumns = df.schema.filter(field => expectedColumnNames.contains(field.name)).map { field =>
      val normalizedName = Strings.toSnakeCase(field.name)
      val normalizedColumn = field.dataType match {
        case _: ArrayType | _: MapType | _: StructType => to_json(col(field.name))
        case _ => col(field.name)
      }
      normalizedColumn.as(normalizedName)
    }
    df.select(normalizedColumns: _*)
  }

  /**
   * Resolves class names with `Class.forName`.
   *
   * Names that cannot be loaded are logged and skipped rather than failing the whole lookup.
   */
  private def lookupClasses[T](classNames: Iterable[String]): Iterable[Class[T]] = {
    classNames.flatMap { className =>
      scala.util.Try(Class.forName(className)) match {
        case Success(clazz) =>
          Some(clazz.asInstanceOf[Class[T]])
        case Failure(e) =>
          logger.error(e)(s"Unable to find class $className on class path!")
          None
      }
    }
  }
}

/** Companion of [[DbLoader]]: shared timestamp format and the command line entry point. */
object DbLoader {

  /**
   * Format of the `timestamp` parameter: 14 digits `yyyyMMddHHmmss`, always interpreted in UTC
   * (for example `20240131235959`).
   */
  val TIMESTAMP_FORMATTER: DateTimeFormatter =
    DateTimeFormatter
      .ofPattern("yyyyMMddHHmmss")
      .withZone(ZoneOffset.UTC)

  /**
   * Command line entry point.
   *
   * @param args configuration arguments forwarded to [[DbLoader]]
   */
  def main(args: Array[String]): Unit =
    new DbLoader(args).run()
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy