package ai.chronon.spark

import ai.chronon.aggregator.windowing.{FinalBatchIr, FiveMinuteResolution, Resolution, SawtoothOnlineAggregator}
import ai.chronon.api
import ai.chronon.api.{Accuracy, Constants, DataModel, GroupByServingInfo, QueryUtils, ThriftJsonCodec}
import ai.chronon.api.Extensions.{GroupByOps, MetadataOps, SourceOps}
import ai.chronon.online.Extensions.ChrononStructTypeOps
import ai.chronon.online.{GroupByServingInfoParsed, Metrics, SparkConversions}
import ai.chronon.spark.Extensions._
import org.apache.spark.SparkEnv
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.functions.{col, lit, not}
import org.apache.spark.sql.{Row, SparkSession, types}

import scala.collection.Seq
import scala.util.ScalaJavaConversions.{ListOps, MapOps}
import scala.util.Try

class GroupByUpload(endPartition: String, groupBy: GroupBy) extends Serializable {
  implicit val sparkSession: SparkSession = groupBy.sparkSession
  implicit private val tableUtils: TableUtils = TableUtils(sparkSession)
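  // Drops the trailing partition value from each composite key (key columns :+ partition column)
  // and pairs the remaining key columns with the post-aggregation values.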
  private def fromBase(rdd: RDD[(Array[Any], Array[Any])]): KvRdd = {
    KvRdd(rdd.map { case (keyAndDs, values) => keyAndDs.init -> values }, groupBy.keySchema, groupBy.postAggSchema)
  }
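  // Entities with SNAPSHOT accuracy: without aggregations, upload the selected columns keyed by
  // the entity keys; with aggregations, upload the fully aggregated (post-agg) values.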
  def snapshotEntities: KvRdd = {
    if (groupBy.aggregations == null || groupBy.aggregations.isEmpty) {
      // pre-agg to PairRdd
      val keysAndPartition = (groupBy.keyColumns :+ tableUtils.partitionColumn).toArray
      val keyBuilder = FastHashing.generateKeyBuilder(keysAndPartition, groupBy.inputDf.schema)
      val values = groupBy.inputDf.schema.map(_.name).filterNot(keysAndPartition.contains)
      val valuesIndices = values.map(groupBy.inputDf.schema.fieldIndex).toArray
      val rdd = groupBy.inputDf.rdd
        .map { row =>
          keyBuilder(row).data.init -> valuesIndices.map(row.get)
        }
      KvRdd(rdd, groupBy.keySchema, groupBy.preAggSchema)
    } else {
      fromBase(groupBy.snapshotEntitiesBase)
    }
  }

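  // Events with SNAPSHOT accuracy: aggregate input events into a single snapshot as of endPartition.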
  def snapshotEvents: KvRdd =
    fromBase(groupBy.snapshotEventsBase(PartitionRange(endPartition, endPartition)))

  // Shared between events and mutations (temporal entities).
  def temporalEvents(resolution: Resolution = FiveMinuteResolution): KvRdd = {
    val endTs = tableUtils.partitionSpec.epochMillis(endPartition)
    println(s"TemporalEvents upload end ts: $endTs")
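    // The sawtooth aggregator produces batch IRs - a collapsed IR up to the batch end plus windowed
    // tail hops - which the online side merges with streaming events for temporally accurate results.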
    val sawtoothOnlineAggregator = new SawtoothOnlineAggregator(
      endTs,
      groupBy.aggregations,
      SparkConversions.toChrononSchema(groupBy.inputDf.schema),
      resolution)
    val irSchema = SparkConversions.fromChrononSchema(sawtoothOnlineAggregator.batchIrSchema)
    val keyBuilder = FastHashing.generateKeyBuilder(groupBy.keyColumns.toArray, groupBy.inputDf.schema)

    println(s"""
        |BatchIR Element Size: ${SparkEnv.get.serializer
      .newInstance()
      .serialize(sawtoothOnlineAggregator.init)
      .capacity()}
        |""".stripMargin)
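    // Key rows by hashed key columns, convert them to Chronon rows, and aggregate each key's events
    // into a FinalBatchIr; aggregateByKey is the shuffle point.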
    val outputRdd = groupBy.inputDf.rdd
      .keyBy(keyBuilder)
      .mapValues(SparkConversions.toChrononRow(_, groupBy.tsIndex))
      .aggregateByKey(sawtoothOnlineAggregator.init)( // shuffle point
        seqOp = sawtoothOnlineAggregator.update, combOp = sawtoothOnlineAggregator.merge)
      .mapValues(sawtoothOnlineAggregator.normalizeBatchIr)
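      // Pack each IR as [collapsed, tailHops] so the output rows line up with batchIrSchema.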
      .map {
        case (keyWithHash: KeyWithHash, finalBatchIr: FinalBatchIr) =>
          val irArray = new Array[Any](2)
          irArray.update(0, finalBatchIr.collapsed)
          irArray.update(1, finalBatchIr.tailHops)
          keyWithHash.data -> irArray
      }
    KvRdd(outputRdd, groupBy.keySchema, irSchema)
  }

}

object GroupByUpload {

  // TODO - remove this if spark streaming can't reach hive tables
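  // Builds serving metadata for the online store: batch end date, the GroupBy config, and the
  // key / selected / input Avro schemas derived from a single-partition scan of the sources.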
  def buildServingInfo(groupByConf: api.GroupBy, session: SparkSession, endDs: String): GroupByServingInfoParsed = {
    val groupByServingInfo = new GroupByServingInfo()
    implicit val tableUtils: TableUtils = TableUtils(session)
    val nextDay = tableUtils.partitionSpec.after(endDs)

    val groupBy = ai.chronon.spark.GroupBy
      .from(groupByConf,
            PartitionRange(endDs, endDs),
            TableUtils(session),
            computeDependency = false,
            mutationScan = false)
    groupByServingInfo.setBatchEndDate(nextDay)
    groupByServingInfo.setGroupBy(groupByConf)
    groupByServingInfo.setKeyAvroSchema(groupBy.keySchema.toAvroSchema("Key").toString(true))
    groupByServingInfo.setSelectedAvroSchema(groupBy.preAggSchema.toAvroSchema("Value").toString(true))
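    // When a streaming source is configured, also record the Avro schema of its (possibly
    // column-pruned) input table for the streaming side.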
    if (groupByConf.streamingSource.isDefined) {
      val streamingSource = groupByConf.streamingSource.get

      // TODO: move this to SourceOps
      def getInfo(source: api.Source): (String, api.Query, Boolean) = {
        if (source.isSetEvents) {
          (source.getEvents.getTable, source.getEvents.getQuery, false)
        } else if (source.isSetEntities) {
          (source.getEntities.getSnapshotTable, source.getEntities.getQuery, true)
        } else {
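          // JoinSource: recurse into the left side of the upstream join to find the root table.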
          val left = source.getJoinSource.getJoin.getLeft
          getInfo(left)
        }
      }

      val (rootTable, query, _) = getInfo(streamingSource)
      val fullInputSchema = tableUtils.getSchemaFromTable(rootTable)
      val inputSchema: types.StructType =
        if (Option(query.selects).isEmpty) fullInputSchema
        else {
          val selects = query.selects.toScala ++ Map(Constants.TimeColumn -> query.timeColumn)
          val streamingQuery =
            QueryUtils.build(selects, rootTable, query.wheres.toScala)
          val reqColumns = tableUtils.getColumnsFromQuery(streamingQuery)
          types.StructType(fullInputSchema.filter(col => reqColumns.contains(col.name)))
        }
      groupByServingInfo.setInputAvroSchema(inputSchema.toAvroSchema(name = "Input").toString(true))
    } else {
      println("Not setting InputAvroSchema to GroupByServingInfo as there is no streaming source defined.")
    }

    val result = new GroupByServingInfoParsed(groupByServingInfo, tableUtils.partitionSpec)
    val firstSource = groupByConf.sources.get(0)
    println(s"""
        |Built GroupByServingInfo for ${groupByConf.metaData.name}:
        |table: ${firstSource.table} / data-model: ${firstSource.dataModel}
        |     keySchema: ${Try(result.keyChrononSchema.catalogString)}
        |   valueSchema: ${Try(result.valueChrononSchema.catalogString)}
        |mutationSchema: ${Try(result.mutationChrononSchema.catalogString)}
        |   inputSchema: ${Try(result.inputChrononSchema.catalogString)}
        |selectedSchema: ${Try(result.selectedChrononSchema.catalogString)}
        |  streamSchema: ${Try(result.streamChrononSchema.catalogString)}
        |""".stripMargin)
    result
  }

  def run(groupByConf: api.GroupBy,
          endDs: String,
          tableUtilsOpt: Option[TableUtils] = None,
          showDf: Boolean = false,
          jsonPercent: Int = 1): Unit = {
    val context = Metrics.Context(Metrics.Environment.GroupByUpload, groupByConf)
    val startTs = System.currentTimeMillis()
    implicit val tableUtils: TableUtils =
      tableUtilsOpt.getOrElse(
        TableUtils(
          SparkSessionBuilder
            .build(s"groupBy_${groupByConf.metaData.name}_upload")))
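    // Run any setup SQL statements declared on the GroupBy conf before scanning its sources.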
    groupByConf.setups.foreach(tableUtils.sql)
    // add 1 day to the batch end time to reflect data [ds 00:00:00.000, ds + 1 00:00:00.000)
    val batchEndDate = tableUtils.partitionSpec.after(endDs)
    // for snapshot accuracy - we don't need to scan mutations
    lazy val groupBy = GroupBy.from(groupByConf,
                                    PartitionRange(endDs, endDs),
                                    tableUtils,
                                    computeDependency = true,
                                    mutationScan = false,
                                    showDf = showDf)
    lazy val groupByUpload = new GroupByUpload(endDs, groupBy)
    // for temporal accuracy - we don't need to scan mutations for upload
    // when endDs = xxxx-01-02, the airflow trigger timestamp is past (xxxx-01-03 00:00:00),
    // so we wait for the event partition of (xxxx-01-02), which contains data up to (xxxx-01-02 23:59:59.999)
    lazy val shiftedGroupBy =
      GroupBy.from(groupByConf,
                   PartitionRange(endDs, endDs).shift(1),
                   tableUtils,
                   computeDependency = true,
                   mutationScan = false,
                   showDf = showDf)
    lazy val shiftedGroupByUpload = new GroupByUpload(batchEndDate, shiftedGroupBy)
    // for mutations we need the snapshot from the previous day, but a batch end date of ds + 1
    lazy val otherGroupByUpload = new GroupByUpload(batchEndDate, groupBy)

    println(s"""
         |GroupBy upload for: ${groupByConf.metaData.team}.${groupByConf.metaData.name}
         |Accuracy: ${groupByConf.inferredAccuracy}
         |Data Model: ${groupByConf.dataModel}
         |""".stripMargin)

    val kvRdd = ((groupByConf.inferredAccuracy, groupByConf.dataModel) match {
      case (Accuracy.SNAPSHOT, DataModel.Events)   => groupByUpload.snapshotEvents
      case (Accuracy.SNAPSHOT, DataModel.Entities) => groupByUpload.snapshotEntities
      case (Accuracy.TEMPORAL, DataModel.Events)   => shiftedGroupByUpload.temporalEvents()
      case (Accuracy.TEMPORAL, DataModel.Entities) => otherGroupByUpload.temporalEvents()
    })

    val kvDf = kvRdd.toAvroDf(jsonPercent = jsonPercent)
    if (showDf) {
      kvRdd.toFlatDf.prettyPrint()
    }

    val groupByServingInfo = buildServingInfo(groupByConf, session = tableUtils.sparkSession, endDs).groupByServingInfo

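    // Append a metadata row keyed by GroupByServingInfoKey that carries the serialized serving info
    // alongside the regular key/value rows in the upload table.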
    val metaRows = Seq(
      Row(
        Constants.GroupByServingInfoKey.getBytes(Constants.UTF8),
        ThriftJsonCodec.toJsonStr(groupByServingInfo).getBytes(Constants.UTF8),
        Constants.GroupByServingInfoKey,
        ThriftJsonCodec.toJsonStr(groupByServingInfo)
      ))
    val metaRdd = tableUtils.sparkSession.sparkContext.parallelize(metaRows.toSeq)
    val metaDf = tableUtils.sparkSession.createDataFrame(metaRdd, kvDf.schema)
    kvDf
      .union(metaDf)
      .withColumn("ds", lit(endDs))
      .saveUnPartitioned(groupByConf.metaData.uploadTable, groupByConf.metaData.tableProps)

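    // Reload the upload table (excluding the metadata row) to emit size and row-count metrics.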
    val kvDfReloaded = tableUtils.sparkSession
      .table(groupByConf.metaData.uploadTable)
      .where(not(col("key_json").eqNullSafe(Constants.GroupByServingInfoKey)))

    val metricRow =
      kvDfReloaded.selectExpr("sum(bit_length(key_bytes))/8", "sum(bit_length(value_bytes))/8", "count(*)").collect()
    context.gauge(Metrics.Name.KeyBytes, metricRow(0).getDouble(0).toLong)
    context.gauge(Metrics.Name.ValueBytes, metricRow(0).getDouble(1).toLong)
    context.gauge(Metrics.Name.RowCount, metricRow(0).getLong(2))
    context.gauge(Metrics.Name.LatencyMinutes, (System.currentTimeMillis() - startTs) / (60 * 1000))
  }
}