// ai.chronon.spark.Analyzer.scala
// Chronon is a feature engineering platform.
package ai.chronon.spark
import ai.chronon.api
import ai.chronon.api.{Accuracy, AggregationPart, Constants, DataType, TimeUnit, Window}
import ai.chronon.api.Extensions._
import ai.chronon.online.SparkConversions
import ai.chronon.spark.Driver.parseConf
import com.yahoo.memory.Memory
import com.yahoo.sketches.ArrayOfStringsSerDe
import com.yahoo.sketches.frequencies.{ErrorType, ItemsSketch}
import org.apache.spark.sql.{DataFrame, Row, types}
import org.apache.spark.sql.functions.{col, from_unixtime, lit}
import org.apache.spark.sql.types.StringType
import ai.chronon.aggregator.row.StatsGenerator
import ai.chronon.api.DataModel.{DataModel, Entities, Events}
import scala.collection.{Seq, immutable, mutable}
import scala.collection.mutable.ListBuffer
import scala.util.ScalaJavaConversions.ListOps
//@SerialVersionUID(3457890987L)
//class ItemSketchSerializable(var mapSize: Int) extends ItemsSketch[String](mapSize) with Serializable {}
// Serializable wrapper around the DataSketches ItemsSketch so it can be shipped inside Spark closures
// (e.g. the treeAggregate below); the custom writeObject/readObject hooks delegate to the sketch serde.
class ItemSketchSerializable extends Serializable {
  var sketch: ItemsSketch[String] = null

  def init(mapSize: Int): ItemSketchSerializable = {
    sketch = new ItemsSketch[String](mapSize)
    this
  }

  // necessary for serialization
  private def writeObject(out: java.io.ObjectOutputStream): Unit = {
    val serDe = new ArrayOfStringsSerDe
    val bytes = sketch.toByteArray(serDe)
    out.writeInt(bytes.size)
    out.writeBytes(new String(bytes))
  }

  private def readObject(input: java.io.ObjectInputStream): Unit = {
    val size = input.readInt()
    val bytes = new Array[Byte](size)
    input.read(bytes)
    val serDe = new ArrayOfStringsSerDe
    sketch = ItemsSketch.getInstance[String](Memory.wrap(bytes), serDe)
  }
}
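
// Illustrative sketch of how this wrapper is used (not part of the original file; values are made up):
//   val hh = (new ItemSketchSerializable).init(mapSize = 1024)
//   Seq("a", "b", "a").foreach(hh.sketch.update)
//   hh.sketch.getFrequentItems(ErrorType.NO_FALSE_POSITIVES).foreach { row =>
//     println(s"${row.getItem} -> ${row.getEstimate}")
//   }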

class Analyzer(tableUtils: TableUtils,
               conf: Any,
               startDate: String,
               endDate: String,
               count: Int = 64,
               sample: Double = 0.1,
               enableHitter: Boolean = false,
               silenceMode: Boolean = false) {

  // include ts into heavy hitter analysis - useful to surface timestamps that have wrong units
  // include total approx row count - so it is easy to understand the percentage of skewed data
  def heavyHittersWithTsAndCount(df: DataFrame,
                                 keys: Array[String],
                                 frequentItemMapSize: Int = 1024,
                                 sampleFraction: Double = 0.1): Array[(String, Array[(String, Long)])] = {
    val baseDf = df.withColumn("total_count", lit("rows"))
    val baseKeys = keys :+ "total_count"
    if (df.schema.fieldNames.contains(Constants.TimeColumn)) {
      heavyHitters(baseDf.withColumn("ts_year", from_unixtime(col("ts") / 1000, "yyyy")),
                   baseKeys :+ "ts_year",
                   frequentItemMapSize,
                   sampleFraction)
    } else {
      heavyHitters(baseDf, baseKeys, frequentItemMapSize, sampleFraction)
    }
  }
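
  // Hypothetical example of the returned shape (column and key names, and counts, are invented):
  //   heavyHittersWithTsAndCount(df, Array("user_id"))
  //   // => Array(("user_id", Array(("user_123", 50000L), ...)),
  //   //          ("total_count", Array(("rows", 1000000L))),
  //   //          ("ts_year", Array(("2023", 990000L), ...)))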

  // Uses a variant of the Misra-Gries heavy hitter algorithm from DataSketches to find the topK most frequent items
  // in the data frame. The result is an Array of tuples of (column name, array of (heavy hitter key, count) tuples):
  // [(keyCol1, [(key1: count1) ...]), (keyCol2, [...]), ....]
  def heavyHitters(df: DataFrame,
                   frequentItemKeys: Array[String],
                   frequentItemMapSize: Int = 1024,
                   sampleFraction: Double = 0.1): Array[(String, Array[(String, Long)])] = {
    assert(frequentItemKeys.nonEmpty, "No column arrays specified for frequent items summary")
    // convert all keys into string
    val stringifiedCols = frequentItemKeys.map { col =>
      val stringified = df.schema.fields.find(_.name == col) match {
        case Some(types.StructField(name, StringType, _, _)) => name
        case Some(types.StructField(name, _, _, _))          => s"CAST($name AS STRING)"
        case None =>
          throw new IllegalArgumentException(s"$col is not present among: [${df.schema.fieldNames.mkString(", ")}]")
      }
      s"COALESCE($stringified, 'NULL')"
    }
    val colsLength = stringifiedCols.length
    val init = Array.fill(colsLength)((new ItemSketchSerializable).init(frequentItemMapSize))
    val freqMaps = df
      .selectExpr(stringifiedCols: _*)
      .sample(sampleFraction)
      .rdd
      .treeAggregate(init)(
        seqOp = {
          case (sketches, row) =>
            var i = 0
            while (i < colsLength) {
              sketches(i).sketch.update(row.getString(i))
              i += 1
            }
            sketches
        },
        combOp = {
          case (sketches1, sketches2) =>
            var i = 0
            while (i < colsLength) {
              sketches1(i).sketch.merge(sketches2(i).sketch)
              i += 1
            }
            sketches1
        }
      )
      .map(_.sketch.getFrequentItems(ErrorType.NO_FALSE_POSITIVES))
      .map(_.map(sketchRow => sketchRow.getItem -> (sketchRow.getEstimate.toDouble / sampleFraction).toLong).toArray)
    frequentItemKeys.zip(freqMaps)
  }
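
  // A minimal usage sketch (assumes a DataFrame `df` with columns "user_id" and "country"; not from the
  // original file). Estimates are scaled back up by 1 / sampleFraction, so counts are approximate:
  //   val hitters = heavyHitters(df, Array("user_id", "country"), frequentItemMapSize = 1024, sampleFraction = 0.1)
  //   hitters.foreach { case (column, items) =>
  //     println(column)
  //     items.foreach { case (value, approxCount) => println(s"  $value: $approxCount") }
  //   }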

  private val range = PartitionRange(startDate, endDate)(tableUtils)

  // returns a report string with the heavy hitter analysis for the specified keys
  def analyze(df: DataFrame, keys: Array[String], sourceTable: String): String = {
    val result = heavyHittersWithTsAndCount(df, keys, count, sample)
    val header = s"Analyzing heavy-hitters from table $sourceTable over columns: [${keys.mkString(", ")}]"
    val colPrints = result.flatMap {
      case (col, heavyHitters) =>
        Seq(s"  $col") ++ heavyHitters.map { case (name, count) => s"    $name: $count" }
    }
    (header +: colPrints).mkString("\n")
  }
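
  // Abridged, illustrative output (table, values, and counts invented; the total_count and ts_year
  // columns that heavyHittersWithTsAndCount adds are omitted here):
  //   println(analyze(eventsDf, Array("user_id"), "db.events"))
  //   // Analyzing heavy-hitters from table db.events over columns: [user_id]
  //   //   user_id
  //   //     user_123: 50000
  //   //     NULL: 12000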

  // Rich version of structType which includes additional info for a groupBy feature schema
  case class AggregationMetadata(name: String,
                                 columnType: DataType,
                                 operation: String = null,
                                 window: String = null,
                                 inputColumn: String = null,
                                 groupByName: String = null) {

    def asMap: Map[String, String] = {
      Map(
        "name" -> name,
        "window" -> window,
        "columnType" -> DataType.toString(columnType),
        "inputColumn" -> inputColumn,
        "operation" -> operation,
        "groupBy" -> groupByName
      )
    }
  }
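
  // Hypothetical example of one aggregation's metadata (field values invented):
  //   AggregationMetadata("price_sum_7d", api.DoubleType, "sum", "7d", "price", "user_purchases").asMap
  //   // => Map("name" -> "price_sum_7d", "window" -> "7d", "inputColumn" -> "price", "operation" -> "sum", ...)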

  def toAggregationMetadata(aggPart: AggregationPart, columnType: DataType): AggregationMetadata = {
    AggregationMetadata(aggPart.outputColumnName,
                        columnType,
                        aggPart.operation.toString.toLowerCase,
                        aggPart.window.str.toLowerCase,
                        aggPart.inputColumn.toLowerCase)
  }

  def toAggregationMetadata(columnName: String, columnType: DataType): AggregationMetadata = {
    AggregationMetadata(columnName, columnType, "No operation", "Unbounded", columnName)
  }

  def analyzeGroupBy(groupByConf: api.GroupBy,
                     prefix: String = "",
                     includeOutputTableName: Boolean = false,
                     enableHitter: Boolean = false): (Array[AggregationMetadata], Map[String, DataType]) = {
    groupByConf.setups.foreach(tableUtils.sql)
    val groupBy = GroupBy.from(groupByConf, range, tableUtils, computeDependency = enableHitter, finalize = true)
    val name = "group_by/" + prefix + groupByConf.metaData.name
    println(s"""|Running GroupBy analysis for $name ...""".stripMargin)
    val analysis =
      if (enableHitter)
        analyze(groupBy.inputDf,
                groupByConf.keyColumns.toScala.toArray,
                groupByConf.sources.toScala.map(_.table).mkString(","))
      else ""
    val schema = if (groupByConf.isSetBackfillStartDate && groupByConf.hasDerivations) {
      // handle group by backfill mode for derivations
      // todo: add similar logic to join derivations
      val sparkSchema = SparkConversions.fromChrononSchema(groupBy.outputSchema)
      val dummyOutputDf = tableUtils.sparkSession
        .createDataFrame(tableUtils.sparkSession.sparkContext.parallelize(immutable.Seq[Row]()), sparkSchema)
      val finalOutputColumns = groupByConf.derivations.toScala.finalOutputColumn(dummyOutputDf.columns).toSeq
      val derivedDummyOutputDf = dummyOutputDf.select(finalOutputColumns: _*)
      val columns = SparkConversions.toChrononSchema(derivedDummyOutputDf.schema)
      api.StructType("", columns.map(tup => api.StructField(tup._1, tup._2)))
    } else {
      groupBy.outputSchema
    }
    if (silenceMode) {
      println(s"""ANALYSIS completed for $name.""".stripMargin)
    } else {
      println(s"""
           |ANALYSIS for $name:
           |$analysis
           """.stripMargin)
      if (includeOutputTableName)
        println(s"""
             |----- OUTPUT TABLE NAME -----
             |${groupByConf.metaData.outputTable}
             """.stripMargin)
      val keySchema = groupBy.keySchema.fields.map { field => s"  ${field.name} => ${field.dataType}" }
      val valueSchema = schema.fields.map { field => s"  ${field.name} => ${field.fieldType}" }
      println(s"""
           |----- KEY SCHEMA -----
           |${keySchema.mkString("\n")}
           |----- OUTPUT SCHEMA -----
           |${valueSchema.mkString("\n")}
           |------ END --------------
           |""".stripMargin)
    }
    val aggMetadata = if (groupByConf.aggregations != null) {
      groupBy.aggPartWithSchema.map { entry => toAggregationMetadata(entry._1, entry._2) }.toArray
    } else {
      schema.map { tup => toAggregationMetadata(tup.name, tup.fieldType) }.toArray
    }
    val keySchemaMap = groupBy.keySchema.map { field =>
      field.name -> SparkConversions.toChrononType(field.name, field.dataType)
    }.toMap
    (aggMetadata, keySchemaMap)
  }
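
  // Usage sketch (the conf path below is hypothetical):
  //   val groupByConf = parseConf[api.GroupBy]("production/group_bys/team/example.v1")
  //   val (aggMetadata, keySchema) = analyzeGroupBy(groupByConf, enableHitter = true)
  //   aggMetadata.foreach(m => println(m.asMap))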

  def analyzeJoin(joinConf: api.Join, enableHitter: Boolean = false, validationAssert: Boolean = false)
      : (Map[String, DataType], ListBuffer[AggregationMetadata], Map[String, DataType]) = {
    val name = "joins/" + joinConf.metaData.name
    println(s"""|Running join analysis for $name ...""".stripMargin)
    // run SQL environment setups such as UDFs and JARs
    joinConf.setups.foreach(tableUtils.sql)
    val leftDf = JoinUtils.leftDf(joinConf, range, tableUtils, allowEmpty = true).get
    val analysis = if (enableHitter) analyze(leftDf, joinConf.leftKeyCols, joinConf.left.table) else ""
    val leftSchema: Map[String, DataType] =
      leftDf.schema.fields.map(field => (field.name, SparkConversions.toChrononType(field.name, field.dataType))).toMap
    val aggregationsMetadata = ListBuffer[AggregationMetadata]()
    val keysWithError: ListBuffer[(String, String)] = ListBuffer.empty[(String, String)]
    val gbTables = ListBuffer[String]()
    val gbStartPartitions = mutable.Map[String, List[String]]()
    // Tuples of (table name, group_by name, expected_start) indicating that the table does not have data
    // available for the required group_by
    val dataAvailabilityErrors: ListBuffer[(String, String, String)] = ListBuffer.empty[(String, String, String)]
    val rangeToFill =
      JoinUtils.getRangesToFill(joinConf.left, tableUtils, endDate, historicalBackfill = joinConf.historicalBackfill)
    println(s"[Analyzer] Join range to fill $rangeToFill")
    val unfilledRanges = tableUtils
      .unfilledRanges(joinConf.metaData.outputTable, rangeToFill, Some(Seq(joinConf.left.table)))
      .getOrElse(Seq.empty)
    joinConf.joinParts.toScala.foreach { part =>
      val (aggMetadata, gbKeySchema) =
        analyzeGroupBy(part.groupBy, part.fullPrefix, includeOutputTableName = true, enableHitter = enableHitter)
      aggregationsMetadata ++= aggMetadata.map { aggMeta =>
        AggregationMetadata(part.fullPrefix + "_" + aggMeta.name,
                            aggMeta.columnType,
                            aggMeta.operation,
                            aggMeta.window,
                            aggMeta.inputColumn,
                            part.getGroupBy.getMetaData.getName)
      }
      // Run validation checks.
      keysWithError ++= runSchemaValidation(leftSchema, gbKeySchema, part.rightToLeft)
      gbTables ++= part.groupBy.sources.toScala.map(_.table)
      dataAvailabilityErrors ++= runDataAvailabilityCheck(joinConf.left.dataModel, part.groupBy, unfilledRanges)
      // list any startPartition dates for conflict checks
      val gbStartPartition = part.groupBy.sources.toScala
        .map(_.query.startPartition)
        .filter(_ != null)
      if (gbStartPartition.nonEmpty)
        gbStartPartitions += (part.groupBy.metaData.name -> gbStartPartition)
    }
    val noAccessTables = runTablePermissionValidation((gbTables.toList ++ List(joinConf.left.table)).toSet)

    val rightSchema: Map[String, DataType] =
      aggregationsMetadata.map(aggregation => (aggregation.name, aggregation.columnType)).toMap
    val statsSchema = StatsGenerator.statsIrSchema(api.StructType.from("Stats", rightSchema.toArray))

    if (silenceMode) {
      println(s"""ANALYSIS completed for join/${joinConf.metaData.cleanName}.""".stripMargin)
    } else {
      println(s"""
           |ANALYSIS for join/${joinConf.metaData.cleanName}:
           |$analysis
           |----- OUTPUT TABLE NAME -----
           |${joinConf.metaData.outputTable}
           |------ LEFT SIDE SCHEMA -------
           |${leftSchema.mkString("\n")}
           |------ RIGHT SIDE SCHEMA ----
           |${rightSchema.mkString("\n")}
           |------ STATS SCHEMA ---------
           |${statsSchema.unpack.toMap.mkString("\n")}
           |------ END ------------------
           |""".stripMargin)
    }
println(s"----- Validations for join/${joinConf.metaData.cleanName} -----")
if (!gbStartPartitions.isEmpty) {
println(
"----- Following Group_Bys contains a startPartition. Please check if any startPartition will conflict with your backfill. -----")
gbStartPartitions.foreach {
case (gbName, startPartitions) =>
println(s"$gbName : ${startPartitions.mkString(",")}")
}
}
if (keysWithError.isEmpty && noAccessTables.isEmpty && dataAvailabilityErrors.isEmpty) {
println("----- Backfill validation completed. No errors found. -----")
} else {
println(s"----- Schema validation completed. Found ${keysWithError.size} errors")
val keyErrorSet: Set[(String, String)] = keysWithError.toSet
println(keyErrorSet.map { case (key, errorMsg) => s"$key => $errorMsg" }.mkString("\n"))
println(s"---- Table permission check completed. Found permission errors in ${noAccessTables.size} tables ----")
println(noAccessTables.mkString("\n"))
println(s"---- Data availability check completed. Found issue in ${dataAvailabilityErrors.size} tables ----")
dataAvailabilityErrors.foreach(error =>
println(s"Table ${error._1} : Group_By ${error._2} : Expected start ${error._3}"))
}
if (validationAssert) {
assert(
keysWithError.isEmpty && noAccessTables.isEmpty && dataAvailabilityErrors.isEmpty,
"ERROR: Join validation failed. Please check error message for details."
)
}
// (schema map showing the names and datatypes, right side feature aggregations metadata for metadata upload)
(leftSchema ++ rightSchema, aggregationsMetadata, statsSchema.unpack.toMap)
}
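
  // Usage sketch (the conf path below is hypothetical):
  //   val joinConf = parseConf[api.Join]("production/joins/team/example_join.v1")
  //   val (joinSchema, rightAggMetadata, statsSchema) =
  //     analyzeJoin(joinConf, enableHitter = false, validationAssert = true)
  //   joinSchema.foreach { case (column, dataType) => println(s"$column => $dataType") }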

  // validate the schema of the left and right side of the join and make sure the types match
  // return a map of keys and corresponding error message that failed validation
  def runSchemaValidation(left: Map[String, DataType],
                          right: Map[String, DataType],
                          keyMapping: Map[String, String]): Map[String, String] = {
    keyMapping.flatMap {
      case (_, leftKey) if !left.contains(leftKey) =>
        Some(leftKey ->
          s"[ERROR]: Left side of the join doesn't contain the key $leftKey. Available keys are [${left.keys.mkString(",")}]")
      case (rightKey, _) if !right.contains(rightKey) =>
        Some(
          rightKey ->
            s"[ERROR]: Right side of the join doesn't contain the key $rightKey. Available keys are [${right.keys
              .mkString(",")}]")
      case (rightKey, leftKey) if left(leftKey) != right(rightKey) =>
        Some(
          leftKey ->
            s"[ERROR]: Join key, '$leftKey', has mismatched data types - left type: ${left(
              leftKey)} vs. right type ${right(rightKey)}")
      case _ => None
    }
  }
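
  // Illustrative check (key names and types invented): the key mapping is rightKey -> leftKey, and a
  // missing or type-mismatched left column yields one error message per bad key:
  //   runSchemaValidation(
  //     left = Map("user" -> api.StringType),
  //     right = Map("user_id" -> api.LongType),
  //     keyMapping = Map("user_id" -> "user"))
  //   // => Map("user" -> "[ERROR]: Join key, 'user', has mismatched data types - ...")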

  // validate the table permissions for a given list of tables
  // return a list of tables that the user doesn't have access to
  def runTablePermissionValidation(sources: Set[String]): Set[String] = {
    println(s"Validating permissions for ${sources.size} tables ...")
    val today = tableUtils.partitionSpec.at(System.currentTimeMillis())
    // todo: handle offset-by-1 depending on temporal vs snapshot accuracy
    val partitionFilter = tableUtils.partitionSpec.minus(today, new Window(2, TimeUnit.DAYS))
    sources.filter { sourceTable =>
      !tableUtils.checkTablePermission(sourceTable, partitionFilter)
    }
  }
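
  // Usage sketch (table names invented): an empty result means every table passed the permission check.
  //   val noAccess = runTablePermissionValidation(Set("db.user_events", "db.user_profiles"))
  //   assert(noAccess.isEmpty, s"No read access to: ${noAccess.mkString(", ")}")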

  // validate that data is available for the group by
  // - For the aggregation case, the gb table's earliest partition should go back to the (first_unfilled_partition - max_window) date
  // - For the non-aggregation case, or for unbounded windows, no earliest partition is required
  // return a list of (table, gb_name, expected_start) that don't have data available
  def runDataAvailabilityCheck(leftDataModel: DataModel,
                               groupBy: api.GroupBy,
                               unfilledRanges: Seq[PartitionRange]): List[(String, String, String)] = {
    if (unfilledRanges.isEmpty) {
      println("No unfilled ranges found.")
      List.empty
    } else {
      val firstUnfilledPartition = unfilledRanges.min.start
      lazy val groupByOps = new GroupByOps(groupBy)
      lazy val leftShiftedPartitionRangeStart = unfilledRanges.min.shift(-1).start
      lazy val rightShiftedPartitionRangeStart = unfilledRanges.min.shift(1).start
      val maxWindow = groupByOps.maxWindow
      maxWindow match {
        case Some(window) =>
          val expectedStart = (leftDataModel, groupBy.dataModel, groupBy.inferredAccuracy) match {
            // based on the end of the day snapshot
            case (Entities, Events, _)   => tableUtils.partitionSpec.minus(rightShiftedPartitionRangeStart, window)
            case (Entities, Entities, _) => firstUnfilledPartition
            case (Events, Events, Accuracy.SNAPSHOT) =>
              tableUtils.partitionSpec.minus(leftShiftedPartitionRangeStart, window)
            case (Events, Events, Accuracy.TEMPORAL) =>
              tableUtils.partitionSpec.minus(firstUnfilledPartition, window)
            case (Events, Entities, Accuracy.SNAPSHOT) => leftShiftedPartitionRangeStart
            case (Events, Entities, Accuracy.TEMPORAL) =>
              tableUtils.partitionSpec.minus(leftShiftedPartitionRangeStart, window)
          }
          groupBy.sources.toScala.flatMap { source =>
            val table = source.table
            println(s"Checking table $table for data availability ... Expected start partition: $expectedStart")
            // check if the partition is available or the table is cumulative
            if (!tableUtils.ifPartitionExistsInTable(table, expectedStart) && !source.isCumulative) {
              Some((table, groupBy.getMetaData.getName, expectedStart))
            } else {
              None
            }
          }
        case None =>
          List.empty
      }
    }
  }
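
  // Worked example of the expected-start rule above (dates and window are invented): for an Events left source
  // joined to an Events group_by with SNAPSHOT accuracy, a 7-day max window, and first unfilled partition
  // 2023-03-10, the check expects source data back to minus(shift(-1) = 2023-03-09, 7 days) = 2023-03-02.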

  def run(): Unit =
    conf match {
      case confPath: String =>
        if (confPath.contains("/joins/")) {
          val joinConf = parseConf[api.Join](confPath)
          analyzeJoin(joinConf, enableHitter = enableHitter)
        } else if (confPath.contains("/group_bys/")) {
          val groupByConf = parseConf[api.GroupBy](confPath)
          analyzeGroupBy(groupByConf, enableHitter = enableHitter)
        }
      case groupByConf: api.GroupBy => analyzeGroupBy(groupByConf, enableHitter = enableHitter)
      case joinConf: api.Join       => analyzeJoin(joinConf, enableHitter = enableHitter)
    }
}
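
// End-to-end usage sketch (not part of the original file; the SparkSession/TableUtils construction and the
// conf path are assumptions for illustration):
//   val spark = SparkSessionBuilder.build("analyzer_example")
//   val tableUtils = TableUtils(spark)
//   val analyzer = new Analyzer(tableUtils,
//                               "production/joins/team/example_join.v1",
//                               startDate = "2023-01-01",
//                               endDate = "2023-01-31",
//                               enableHitter = true)
//   analyzer.run()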