// com.datawizards.dqm.validator.DataValidator.scala
package com.datawizards.dqm.validator
import com.datawizards.dqm.configuration.{GroupByConfiguration, TableConfiguration, ValidationContext}
import com.datawizards.dqm.history.HistoryStatisticsReader
import com.datawizards.dqm.result._
import com.datawizards.dqm.rules.FieldRules
import com.datawizards.dqm.rules.trend.TableTrendRule
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, DataFrame, Row}
import scala.collection.mutable.ListBuffer
import scala.util.parsing.json.JSONObject
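
/**
 * Core validation engine: loads a table, optionally filters it to the current
 * processing date, and computes invalid records, table and column statistics,
 * group-by statistics, invalid groups, and invalid table trends for that date.
 */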
object DataValidator {
  private val countColumn = "count"
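
  /**
   * Runs all configured validations for a single table and processing date:
   * row-level field rules, table and column statistics, group-by statistics
   * with their group rules, and table trend rules evaluated against
   * historical statistics.
   */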
  def validate(
    tableConfiguration: TableConfiguration,
    context: ValidationContext,
    historyStatisticsReader: HistoryStatisticsReader
  ): ValidationResult = {
    val tableLocation = tableConfiguration.location
    val filterByProcessingDateStrategy = tableConfiguration.filterByProcessingDateStrategy
    val input = tableLocation.load()
    val df = filterByProcessingDateStrategy
      .map(_.filter(input, context.processingDate))
      .getOrElse(input)
    val fields = getDataFrameFields(df)
    val aggregate = aggregateDataFrame(df, fields)
    val rowsCount = calculateRowsCount(aggregate)
    val (groupByStatisticsList, invalidGroups) = calculateGroupByStatistics(df, context, tableConfiguration.groups)
    ValidationResult(
      invalidRecords = calculateInvalidRecords(df, context, tableConfiguration.rules.rowRules),
      tableStatistics = calculateTableStatistics(df, context, rowsCount),
      columnsStatistics = calculateColumnStatistics(context, rowsCount, fields, aggregate),
      groupByStatisticsList = groupByStatisticsList,
      invalidGroups = invalidGroups,
      invalidTableTrends = calculateInvalidTableTrends(historyStatisticsReader, context, tableConfiguration.rules.tableTrendRules)
    )
  }
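
  /**
   * Applies every field-level rule to every row and collects the violations
   * to the driver. For each violation the full row is serialized as JSON,
   * truncated to 1000 characters, so it can be stored alongside the result.
   */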
  private def calculateInvalidRecords(input: DataFrame, context: ValidationContext, rowRules: Seq[FieldRules]): Array[InvalidRecord] = {
    val spark = input.sparkSession
    import spark.implicits._
    input.flatMap { row =>
      rowRules.flatMap { fieldRules =>
        val field = fieldRules.field
        fieldRules
          .rules
          .withFilter(fr => !fr.validate(field, row))
          .map { fr =>
            val fieldValue = row.getAs[Any](field)
            val values = row.getValuesMap[Any](row.schema.fieldNames).mapValues(v => if (v == null) "null" else v)
            val rowString = JSONObject(values).toString()
            val maxSize = 1000
            InvalidRecord(
              tableName = context.tableName,
              columnName = field,
              row = if (rowString.length > maxSize) rowString.substring(0, maxSize) else rowString,
              value = if (fieldValue == null) "null" else fieldValue.toString,
              rule = fr.name,
              year = context.processingYear,
              month = context.processingMonth,
              day = context.processingDay
            )
          }
      }
    }.collect()
  }
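
  /** Packages the pre-computed row count and the column count into a TableStatistics record. */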
  private def calculateTableStatistics(df: DataFrame, context: ValidationContext, rowsCount: Long): TableStatistics = {
    TableStatistics(
      tableName = context.tableName,
      rowsCount = rowsCount,
      columnsCount = calculateColumnsCount(df),
      year = context.processingYear,
      month = context.processingMonth,
      day = context.processingDay
    )
  }
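
  /**
   * Computes all statistics in a single Spark aggregation: one global row
   * count plus per-column counts and, for numeric columns, min/max/avg/stddev.
   * The result is a single Row whose values are later looked up by the
   * aliases produced by the *ForColumnName helpers below.
   */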
  private def aggregateDataFrame(df: DataFrame, fields: Seq[StructField]): Row = {
    val exprs = buildAggregateExpressions(fields)
    df
      .agg(count("*").alias(countColumn), exprs: _*)
      .collect()
      .head
  }
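
  /**
   * Builds one count expression per column and, for numeric columns, four
   * additional min/max/avg/stddev expressions, each aliased with a
   * deterministic suffix so the aggregate Row can be read back by name.
   */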
  private def buildAggregateExpressions(fields: Seq[StructField]): Seq[Column] = {
    fields.flatMap { f =>
      val columnName = f.name
      val numericFieldAggregations =
        if (isNumericType(f.dataType))
          List(
            min(columnName).alias(minForColumnName(columnName)),
            max(columnName).alias(maxForColumnName(columnName)),
            avg(columnName).alias(avgForColumnName(columnName)),
            stddev(columnName).alias(stdDevForColumnName(columnName))
          )
        else
          Nil
      count(columnName).alias(countForColumnName(columnName)) :: numericFieldAggregations
    }
  }

  private def getDataFrameFields(df: DataFrame): Seq[StructField] =
    df.schema.fields

  private def calculateColumnsCount(df: DataFrame): Int =
    df.columns.length

  private def countForColumnName(c: String) = c + "_count"
  private def minForColumnName(c: String) = c + "_min"
  private def maxForColumnName(c: String) = c + "_max"
  private def avgForColumnName(c: String) = c + "_avg"
  private def stdDevForColumnName(c: String) = c + "_stddev"

  private def calculateRowsCount(aggregate: Row): Long =
    aggregate.getAs[Long](countColumn)
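
  /**
   * Reads the per-column aliases out of the aggregate Row and builds one
   * ColumnStatistics record per column. min/max/avg/stddev stay None for
   * non-numeric columns, and percentageNotMissing is None for an empty table.
   */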
  private def calculateColumnStatistics(context: ValidationContext, rowsCount: Long, fields: Seq[StructField], aggregate: Row): Seq[ColumnStatistics] = {
    fields.map { f =>
      val columnName = getColumnName(f)
      val columnType = getColumnType(f)
      val min = getAggregateValueIfNumericField(aggregate, minForColumnName(columnName), columnType)
      val max = getAggregateValueIfNumericField(aggregate, maxForColumnName(columnName), columnType)
      val notMissingCount = aggregate.getAs[Long](countForColumnName(columnName))
      ColumnStatistics(
        tableName = context.tableName,
        columnName = columnName,
        columnType = f.dataType.toString,
        notMissingCount = notMissingCount,
        rowsCount = rowsCount,
        percentageNotMissing = if (rowsCount == 0) None else Some(1.0 * notMissingCount / rowsCount),
        min = min,
        max = max,
        avg = getAggregateValueIfNumericField(aggregate, avgForColumnName(columnName), columnType),
        stddev = getAggregateValueIfNumericField(aggregate, stdDevForColumnName(columnName), columnType),
        year = context.processingYear,
        month = context.processingMonth,
        day = context.processingDay
      )
    }
  }

  private def getAggregateValueIfNumericField(aggregate: Row, columnName: String, columnType: DataType): Option[Double] =
    if (isNumericType(columnType)) castToDouble(aggregate.getAs[Double](columnName))
    else None

  private def getColumnType(f: StructField): DataType = f.dataType

  private def getColumnName(f: StructField): String = f.name

  private def isNumericType(columnType: DataType): Boolean =
    columnType == IntegerType || columnType == LongType || columnType == DoubleType
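
  /**
   * Safely converts a boxed numeric aggregate value to Double. The cast in
   * getAs[Double] above is erased at runtime, so e.g. min/max of an integer
   * column arrives here as a boxed Int or Long; matching on the actual type
   * avoids a ClassCastException. Returns None for null or non-numeric values.
   */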
  private def castToDouble(v: Any): Option[Double] = v match {
    case d: Double => Some(d)
    case b: Byte => Some(b.toDouble)
    case s: Short => Some(s.toDouble)
    case i: Int => Some(i.toDouble)
    case j: Long => Some(j.toDouble)
    case f: Float => Some(f.toDouble)
    case Some(d: Double) => Some(d)
    case Some(b: Byte) => Some(b.toDouble)
    case Some(s: Short) => Some(s.toDouble)
    case Some(i: Int) => Some(i.toDouble)
    case Some(j: Long) => Some(j.toDouble)
    case Some(f: Float) => Some(f.toDouble)
    case _ => None
  }
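
  /**
   * For each configured grouping, runs a groupBy/count over the DataFrame,
   * converts the result into GroupByStatistics records, and evaluates the
   * group rules against them, collecting any invalid groups.
   */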
  private def calculateGroupByStatistics(df: DataFrame, context: ValidationContext, groups: Seq[GroupByConfiguration]): (Seq[GroupByStatistics], Seq[InvalidGroup]) = {
    val groupByStatisticsListBuffer = new ListBuffer[GroupByStatistics]
    val invalidGroupsBuffer = new ListBuffer[InvalidGroup]
    for (group <- groups) {
      val groupByResult = df
        .groupBy(group.groupByFieldName)
        .count()
        .collect()
        .map(r => r.getAs[String](group.groupByFieldName) -> r.getAs[Long]("count"))
      val groupByStatisticsList = groupByResult.map { case (groupByFieldValue, rowsCount) =>
        GroupByStatistics(
          tableName = context.tableName,
          groupName = group.groupName,
          groupByFieldValue = groupByFieldValue,
          rowsCount = rowsCount,
          year = context.processingYear,
          month = context.processingMonth,
          day = context.processingDay
        )
      }
      groupByStatisticsListBuffer ++= groupByStatisticsList
      for (groupRule <- group.rules)
        invalidGroupsBuffer ++= groupRule.validate(groupByStatisticsList, context, group.groupName)
    }
    (groupByStatisticsListBuffer.toList, invalidGroupsBuffer.toList)
  }
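
  /**
   * Evaluates trend rules (e.g. comparing the current row count against
   * previous days) against the table's historical statistics. History is
   * only read when at least one trend rule is configured, so tables without
   * trend rules pay no extra cost.
   */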
  private def calculateInvalidTableTrends(
    historyStatisticsReader: HistoryStatisticsReader,
    context: ValidationContext,
    tableTrendRules: Seq[TableTrendRule]
  ): Seq[InvalidTableTrend] = {
    if (tableTrendRules.isEmpty) Seq.empty
    else {
      val tableStatistics = historyStatisticsReader.readTableStatistics(context.tableName)
      tableTrendRules.flatMap(r => r.validate(tableStatistics, context))
    }
  }
}
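
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only): one way a caller might drive a validation
// run. The constructor arguments shown for ValidationContext and the shape of
// TableConfiguration are assumptions based on how this file uses them; check
// the com.datawizards.dqm.configuration package for the actual signatures.
// ---------------------------------------------------------------------------
// val context = ValidationContext(
//   tableName = "clients",                      // hypothetical table name
//   processingDate = java.sql.Date.valueOf("2024-01-15")
// )
// val result: ValidationResult = DataValidator.validate(
//   tableConfiguration = myTableConfiguration,  // hypothetical, built elsewhere
//   context = context,
//   historyStatisticsReader = myHistoryReader   // hypothetical implementation
// )
// result.invalidRecords.foreach(println)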