
ai.chronon.spark.LogFlattenerJob.scala

package ai.chronon.spark

import ai.chronon.api
import ai.chronon.api.Extensions._
import ai.chronon.api._
import ai.chronon.online._
import ai.chronon.spark.Extensions.StructTypeOps
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}

import java.util.Base64
import scala.collection.mutable
import scala.collection.Seq
import scala.util.ScalaJavaConversions.{IterableOps, MapOps}
import scala.util.{Failure, Success, Try}

/**
  * LogFlattenerJob unpacks serialized Avro data from online requests, flattens each field
  * (both keys and values) into individual columns, and saves the result to an offline "flattened" log table.
  *
  * Steps:
  * 1. determine unfilled range and pull raw logs from partitioned log table
  * 2. fetch joinCodecs for all unique schema_hash present in the logs
  * 3. build a merged schema from all schema versions, which will be used as output schema
  * 4. unpack each row and adhere to the output schema
  * 5. save the schema info in the flattened log table properties (cumulatively)
  */
class LogFlattenerJob(session: SparkSession,
                      joinConf: api.Join,
                      endDate: String,
                      logTable: String,
                      schemaTable: String,
                      stepDays: Option[Int] = None)
    extends Serializable {
  implicit val tableUtils: TableUtils = TableUtils(session)
  val joinTblProps: Map[String, String] = Option(joinConf.metaData.tableProperties)
    .map(_.toScala)
    .getOrElse(Map.empty[String, String])
  val metrics: Metrics.Context = Metrics.Context(Metrics.Environment.JoinLogFlatten, joinConf)

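  // Computes the partition ranges of the flattened output table that still need to be filled, based on
  // the partitions of the raw log table (restricted to this join's name) up to endDate. Returns an empty
  // Seq when no logged data is available or the output is already caught up; splits the result into
  // stepDays-sized chunks when stepDays is set.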
  private def getUnfilledRanges(inputTable: String, outputTable: String): Seq[PartitionRange] = {
    val partitionName: String = joinConf.metaData.nameToFilePath.replace("/", "%2F")
    val unfilledRangeTry = Try(
      tableUtils.unfilledRanges(
        outputTable,
        PartitionRange(null, endDate),
        Some(Seq(inputTable)),
        Map(inputTable -> Map("name" -> partitionName))
      )
    )

    val ranges = unfilledRangeTry match {
      case Failure(_: AssertionError) => {
        println(s"""
             |The join name ${joinConf.metaData.nameToFilePath} does not have available logged data yet.
             |Please double check your logging status""".stripMargin)
        Seq()
      }
      case Success(None) => {
        println(
          s"$outputTable seems to be caught up to either " +
            s"$inputTable (latest ${tableUtils.lastAvailablePartition(inputTable)}) or $endDate.")
        Seq()
      }
      case Success(Some(partitionRange)) => {
        partitionRange
      }
      case Failure(otherException) => throw otherException
    }

    if (stepDays.isEmpty) {
      ranges
    } else {
      ranges.flatMap(_.steps(stepDays.get))
    }
  }

  // Removes duplicate fields and throws if two fields share the same name but have different data types.
  // The output order of the fields is based on the first appearance of each field in the iterable.
  private def dedupeFields(fields: Iterable[StructField]): Iterable[StructField] = {
    val fieldsBuilder = mutable.LinkedHashMap[String, DataType]()
    fields.foreach { f =>
      if (fieldsBuilder.contains(f.name)) {
        if (fieldsBuilder(f.name) != f.fieldType) {
          throw new Exception(
            s"Found field with same name ${f.name} but different dataTypes: ${fieldsBuilder(f.name)} vs ${f.fieldType}")
        }
      } else {
        fieldsBuilder.put(f.name, f.fieldType)
      }
    }
    fieldsBuilder.map { case (name, dataType) => StructField(name, dataType) }
  }

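  // Decodes the base64-encoded Avro key and value payloads of each logged row using the codec matching
  // its schema_hash, and projects the decoded fields into a single merged output schema. Rows without a
  // schema_hash are skipped, and rows that fail to decode are dropped after incrementing the exception metric.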
  private def flattenKeyValueBytes(rawDf: Dataset[Row], schemaMap: Map[String, LoggingSchema]): DataFrame = {

    val allDataFields = dedupeFields(schemaMap.values.flatMap(_.keyFields) ++ schemaMap.values.flatMap(_.valueFields))
    // contextual features are logged twice, in both keys and values, with the value-side copies prefixed by ext_contextual
    // we exclude the prefixed duplicates here since the two are always identical
    val dataFields = allDataFields.filterNot(_.name.startsWith(Constants.ContextualPrefix))
    val metadataFields = StructField(Constants.SchemaHash, StringType) +: JoinCodec.timeFields
    val outputSchema = StructType("", metadataFields ++ dataFields)
    val (keyBase64Idx, valueBase64Idx, tsIdx, dsIdx, schemaHashIdx) = (0, 1, 2, 3, 4)
    val outputRdd: RDD[Row] = rawDf
      .select("key_base64", "value_base64", "ts_millis", "ds", Constants.SchemaHash)
      .rdd
      .flatMap { row =>
        if (row.isNullAt(schemaHashIdx)) {
          // ignore older logs that do not have schema_hash info
          None
        } else {
          val joinCodec = schemaMap(row.getString(schemaHashIdx))
          val keyBytes = Base64.getDecoder.decode(row.getString(keyBase64Idx))
          val valueBytes = Base64.getDecoder.decode(row.getString(valueBase64Idx))
          val keyRow = Try(joinCodec.keyCodec.decodeRow(keyBytes))
          val valueRow = Try(joinCodec.valueCodec.decodeRow(valueBytes))

          if (keyRow.isFailure || valueRow.isFailure) {
            metrics.increment(Metrics.Name.Exception)
            None
          } else {
            val dataColumns = dataFields.map { field =>
              val keyIdxOpt = joinCodec.keyIndices.get(field)
              val valIdxOpt = joinCodec.valueIndices.get(field)
              if (keyIdxOpt.isDefined) {
                keyRow.toOption.map(_.apply(keyIdxOpt.get)).orNull
              } else if (valIdxOpt.isDefined) {
                valueRow.toOption.map(_.apply(valIdxOpt.get)).orNull
              } else {
                null
              }
            }.toArray

            val metadataColumns = Array(row.get(schemaHashIdx), row.get(tsIdx), row.get(dsIdx))
            val outputRow = metadataColumns ++ dataColumns
            val unpackedRow = SparkConversions.toSparkRow(outputRow, outputSchema).asInstanceOf[GenericRow]
            Some(unpackedRow)
          }
        }
      }

    val outputSparkSchema = SparkConversions.fromChrononSchema(outputSchema)
    session.createDataFrame(outputRdd, outputSparkSchema)
  }

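  // Reads the latest available partition of the schema table and returns a mapping from schema_hash to
  // its logging schema string, restricted to the requested hashes.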
  private def fetchSchemas(hashes: Seq[String]): Map[String, String] = {
    val schemaTableDs = tableUtils.lastAvailablePartition(schemaTable)
    if (schemaTableDs.isEmpty) {
      throw new Exception(s"$schemaTable has no partitions available!")
    }

    session
      .table(schemaTable)
      .where(col(tableUtils.partitionColumn) === schemaTableDs.get)
      .where(col(Constants.SchemaHash).isin(hashes.toSeq: _*))
      .select(
        col(Constants.SchemaHash),
        col("schema_value_last").as("schema_value")
      )
      .collect()
      .map(row => (row.getString(0), row.getString(1)))
      .toMap
  }

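  // Merges the schema properties already stored on the logged table with the newly fetched schemas,
  // re-prefixing each key with schema_hash and escaping backslashes so the values are safe to store as
  // table properties.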
  def buildTableProperties(schemaMap: Map[String, String]): Map[String, String] = {
    def escape(str: String): String = str.replace("""\""", """\\""")
    (LogFlattenerJob.readSchemaTableProperties(tableUtils, joinConf.metaData.loggedTable) ++ schemaMap)
      .map {
        case (key, value) => (escape(s"${Constants.SchemaHash}_$key"), escape(value))
      }
  }

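  // Number of columns currently present in the flattened log table, or 0 if its schema cannot be read
  // (e.g. the table does not exist yet).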
  private def columnCount(): Int = {
    Try(tableUtils.getSchemaFromTable(joinConf.metaData.loggedTable).fields.length).toOption.getOrElse(0)
  }

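  // Main entry point: for each unfilled partition range, scans the raw log table for this join, fetches
  // the logging schemas seen in that range, flattens the rows into the merged schema and writes them to
  // the flattened log table, then reports row counts, column counts and latency as metrics.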
  def buildLogTable(): Unit = {
    if (!joinConf.metaData.isSetSamplePercent) {
      println(s"samplePercent is unset for ${joinConf.metaData.name}. Exit.")
      return
    }
    val unfilledRanges = getUnfilledRanges(logTable, joinConf.metaData.loggedTable)
    if (unfilledRanges.isEmpty) return
    val joinName = joinConf.metaData.nameToFilePath

    val start = System.currentTimeMillis()
    val columnBeforeCount = columnCount()
    val rowCounts = unfilledRanges.map { unfilled =>
      val rawTableScan = unfilled.genScanQuery(null, logTable)
      val rawDf = tableUtils.sql(rawTableScan).where(col("name") === joinName)
      val schemaHashes = rawDf.select(col(Constants.SchemaHash)).distinct().collect().map(_.getString(0)).toSeq
      val schemaStringsMap = fetchSchemas(schemaHashes)

      // we do not have the exact joinConf at the time of logging, and since it is not used during flattening, we pass in null
      val schemaMap = schemaStringsMap.mapValues(LoggingSchema.parseLoggingSchema).map(identity).toMap
      val flattenedDf = flattenKeyValueBytes(rawDf, schemaMap)

      val schemaTblProps = buildTableProperties(schemaStringsMap)
      println("======= Log table schema =======")
      println(flattenedDf.schema.pretty)
      tableUtils.insertPartitions(flattenedDf,
                                  joinConf.metaData.loggedTable,
                                  tableProperties =
                                    joinTblProps ++ schemaTblProps ++ Map(Constants.ChrononLogTable -> true.toString),
                                  autoExpand = true)

      val inputRowCount = rawDf.count()
      // read from output table to avoid recomputation
      val outputRowCount = tableUtils.sql(unfilled.genScanQuery(null, joinConf.metaData.loggedTable)).count()

      (inputRowCount, outputRowCount)
    }
    val totalInputRowCount = rowCounts.map(_._1).sum
    val totalOutputRowCount = rowCounts.map(_._2).sum
    val columnAfterCount = columnCount()

    val failureCount = totalInputRowCount - totalOutputRowCount
    metrics.gauge(Metrics.Name.RowCount, totalOutputRowCount)
    metrics.gauge(Metrics.Name.FailureCount, failureCount)
    println(s"Processed logs: ${totalOutputRowCount} rows success, ${failureCount} rows failed.")
    metrics.gauge(Metrics.Name.ColumnBeforeCount, columnBeforeCount)
    metrics.gauge(Metrics.Name.ColumnAfterCount, columnAfterCount)
    val elapsedMins = (System.currentTimeMillis() - start) / 60000
    metrics.gauge(Metrics.Name.LatencyMinutes, elapsedMins)
  }
}
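
// Example usage (a minimal sketch, not part of the original source). The table names below and the
// loadJoinConf helper are hypothetical placeholders; a real driver would supply concrete values:
//
//   val spark: SparkSession = SparkSession.builder().appName("log-flattener").getOrCreate()
//   val joinConf: api.Join = loadJoinConf() // hypothetical: e.g. parse a compiled Join config
//   new LogFlattenerJob(
//     session = spark,
//     joinConf = joinConf,
//     endDate = "2023-01-01",
//     logTable = "namespace.chronon_log_table",
//     schemaTable = "namespace.schema_table",
//     stepDays = Some(30)
//   ).buildLogTable()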

object LogFlattenerJob {

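  // Reads the schema_hash-prefixed properties previously stored on the logged table and strips the
  // prefix, returning a mapping from schema_hash to schema string.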
  def readSchemaTableProperties(tableUtils: TableUtils, logTable: String): Map[String, String] = {
    val curTblProps = tableUtils.getTableProperties(logTable).getOrElse(Map.empty)
    curTblProps
      .filterKeys(_.startsWith(Constants.SchemaHash))
      .map {
        case (key, value) => (key.substring(Constants.SchemaHash.length + 1), value)
      }
      .toMap
  }
}
