All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.amazon.deequ.analyzers.Analyzer.scala Maven / Gradle / Ivy

/**
 * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not
 * use this file except in compliance with the License. A copy of the License
 * is located at
 *
 *     http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on
 * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
 * express or implied. See the License for the specific language governing
 * permissions and limitations under the License.
 *
 */

package com.amazon.deequ.analyzers

import com.amazon.deequ.analyzers.Analyzers._
import com.amazon.deequ.metrics.{DoubleMetric, Entity, Metric}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, DataFrame, Row}
import com.amazon.deequ.analyzers.runners._

import scala.language.existentials
import scala.util.{Failure, Success}

/**
  * A sufficient statistic computed from data, from which a metric can be derived.
  * States of the same type must be combinable with each other, i.e. they have the
  * algebraic properties of a commutative semi-group under `sum`.
  *
  * @tparam S the concrete state type, so that `sum` returns that same type
  */
trait State[S <: State[S]] {

  // Untyped entry point into `sum`; needed because callers frequently only know the
  // state as `State[_]`, so the argument has to be cast to the concrete type.
  private[analyzers] def sumUntyped(other: State[_]): S = sum(other.asInstanceOf[S])

  /** Combine this state with another state of the same type */
  def sum(other: S): S

  /** Syntactic sugar for `sum` */
  def +(other: S): S = this.sum(other)
}

/**
  * A [[State]] from which a single double-valued metric can be computed.
  *
  * @tparam S the concrete state type
  */
trait DoubleValuedState[S <: DoubleValuedState[S]] extends State[S] {

  /** The metric value represented by this state */
  def metricValue(): Double
}

/**
  * Common trait for all analyzers, which compute a state (sufficient statistics) from a
  * data frame and generate a metric from that state.
  *
  * @tparam S type of the state computed from the data
  * @tparam M type of the metric produced from the state
  */
trait Analyzer[S <: State[_], +M <: Metric[_]] {

  /**
    * Compute the state (sufficient statistics) from the data
    * @param data data frame
    * @return the computed state, if any
    */
  def computeStateFrom(data: DataFrame): Option[S]

  /**
    * Compute the metric from the state (sufficient statistics)
    * @param state wrapper holding a state of type S (required due to typing issues...)
    * @return the metric derived from the state
    */
  def computeMetricFrom(state: Option[S]): M

  /**
    * A set of assertions that must hold on the schema of the data frame. Each assertion
    * signals a violation by throwing an exception.
    * @return the preconditions of this analyzer (none by default)
    */
  def preconditions: Seq[StructType => Unit] = {
    Seq.empty
  }

  /**
    * Runs preconditions, calculates and returns the metric
    *
    * @param data Data frame being analyzed
    * @param aggregateWith loader for previous states to include in the computation (optional)
    * @param saveStatesWith persist internal states using this (optional)
    * @return the computed metric; a failure metric (via toFailureMetric) if a precondition
    *         fails or any other exception is thrown during the computation
    */
  def calculate(
      data: DataFrame,
      aggregateWith: Option[StateLoader] = None,
      saveStatesWith: Option[StatePersister] = None)
    : M = {

    try {
      preconditions.foreach { condition => condition(data.schema) }

      val state = computeStateFrom(data)

      calculateMetric(state, aggregateWith, saveStatesWith)
    } catch {
      case error: Exception => toFailureMetric(error)
    }
  }

  /** Wraps the given exception into a metric of type M that represents the failure */
  private[deequ] def toFailureMetric(failure: Exception): M

  /**
    * Computes the metric from an already computed state, optionally merging it with a
    * previously stored state and persisting the (merged) state.
    *
    * @param state the freshly computed state, if any
    * @param aggregateWith loader for a previous state to merge with (optional)
    * @param saveStatesWith persister for the merged state (optional)
    * @return the metric computed from the merged state
    */
  def calculateMetric(
      state: Option[S],
      aggregateWith: Option[StateLoader] = None,
      saveStatesWith: Option[StatePersister] = None)
    : M = {

    // Try to load the state
    val loadedState: Option[S] = aggregateWith.flatMap { _.load[S](this) }

    // Potentially merge existing and loaded state
    val stateToComputeMetricFrom: Option[S] = Analyzers.merge(state, loadedState)

    // Persist the state if it is not empty and a persister was provided
    stateToComputeMetricFrom
      .foreach { state =>
        saveStatesWith.foreach {
          _.persist[S](this, state)
        }
      }

    computeMetricFrom(stateToComputeMetricFrom)
  }

  /**
    * Loads this analyzer's state from two sources, merges them and persists the result.
    * Nothing is persisted if neither source holds a state.
    */
  private[deequ] def aggregateStateTo(
      sourceA: StateLoader,
      sourceB: StateLoader,
      target: StatePersister)
    : Unit = {

    val maybeStateA = sourceA.load[S](this)
    val maybeStateB = sourceB.load[S](this)

    // Same merge semantics as calculateMetric: sum both if present, else keep whichever exists
    val aggregated = Analyzers.merge(maybeStateA, maybeStateB)

    aggregated.foreach { state => target.persist[S](this, state) }
  }

  /** Loads this analyzer's state from `source` and computes the metric, if a state exists */
  private[deequ] def loadStateAndComputeMetric(source: StateLoader): Option[M] = {
    source.load[S](this).map { state =>
      computeMetricFrom(Option(state))
    }
  }

  /** Copy the state from source to target. Needs to be here to allow the compiler
    * to correctly infer the types.
    *
    * @param source state provider to read from
    * @param target state provider to write to
    */
  private[deequ] def copyStateTo(source: StateLoader, target: StatePersister): Unit = {
    source.load[S](this).foreach { state => target.persist[S](this, state) }
  }

}

/**
  * An analyzer whose state is computed by a set of aggregation functions over the data,
  * so that several such analyzers can share a single scan.
  *
  * @tparam S type of the state computed from the aggregation results
  * @tparam M type of the metric produced from the state
  */
trait ScanShareableAnalyzer[S <: State[_], +M <: Metric[_]] extends Analyzer[S, M] {

  /** Defines the aggregations to compute on the data */
  private[deequ] def aggregationFunctions(): Seq[Column]

  /**
    * Computes the state from the result of the aggregation functions
    * @param result row holding the aggregation results
    * @param offset position of this analyzer's first aggregation result within `result`
    */
  private[deequ] def fromAggregationResult(result: Row, offset: Int): Option[S]

  /** Runs this analyzer's aggregation functions on their own, without scan sharing */
  override def computeStateFrom(data: DataFrame): Option[S] = {
    val columns = aggregationFunctions()
    val aggregatedRows = data.agg(columns.head, columns.tail: _*).collect()
    // Only this analyzer's aggregations were run, so its results start at position 0
    fromAggregationResult(aggregatedRows.head, 0)
  }

  /**
    * Produces a metric from the aggregation result, optionally merging the state with a
    * previously stored one and persisting the merged state.
    */
  private[deequ] def metricFromAggregationResult(
      result: Row,
      offset: Int,
      aggregateWith: Option[StateLoader] = None,
      saveStatesWith: Option[StatePersister] = None)
    : M = {

    calculateMetric(fromAggregationResult(result, offset), aggregateWith, saveStatesWith)
  }

}

/**
  * A scan-shareable analyzer that produces a DoubleMetric
  *
  * @param name name of the produced metric
  * @param instance instance identifier attached to the produced metric
  * @param entity entity type of the produced metric, Entity.Column by default
  */
abstract class StandardScanShareableAnalyzer[S <: DoubleValuedState[_]](
    name: String,
    instance: String,
    entity: Entity.Value = Entity.Column)
  extends ScanShareableAnalyzer[S, DoubleMetric] {

  /** Metric from the state's value, or an empty-state failure metric if there is no state */
  override def computeMetricFrom(state: Option[S]): DoubleMetric = {
    state.fold(metricFromEmpty(this, name, instance, entity)) { existingState =>
      metricFromValue(existingState.metricValue(), name, instance, entity)
    }
  }

  override private[deequ] def toFailureMetric(exception: Exception): DoubleMetric =
    metricFromFailure(exception, name, instance, entity)

  /** Analyzer-specific preconditions, ordered before the inherited ones */
  override def preconditions: Seq[StructType => Unit] = {
    val specificChecks = additionalPreconditions()
    specificChecks ++ super.preconditions
  }

  /** Hook for subclasses to contribute preconditions; none by default */
  protected def additionalPreconditions(): Seq[StructType => Unit] = Seq.empty
}

/**
  * A state for computing ratio-based metrics: the number of rows that match a predicate
  * and the overall number of rows.
  *
  * @param numMatches number of rows matching the predicate
  * @param count overall number of rows
  */
case class NumMatchesAndCount(numMatches: Long, count: Long)
  extends DoubleValuedState[NumMatchesAndCount] {

  /** Adds up the matches and the counts of both states */
  override def sum(other: NumMatchesAndCount): NumMatchesAndCount =
    copy(numMatches = numMatches + other.numMatches, count = count + other.count)

  /** Fraction of matching rows, or NaN when no rows were counted */
  override def metricValue(): Double = count match {
    case 0L => Double.NaN
    case total => numMatches.toDouble / total
  }
}

/**
  * Base class for analyzers that compute ratios of matching predicates
  *
  * @param predicate the predicate whose matches are counted
  * @param where optional filter condition, given as a SQL expression string
  */
abstract class PredicateMatchingAnalyzer(
    name: String,
    instance: String,
    predicate: Column,
    where: Option[String])
  extends StandardScanShareableAnalyzer[NumMatchesAndCount](name, instance) {

  /** Reads the results at `offset` and `offset + 1`; None if either of them is null */
  override def fromAggregationResult(result: Row, offset: Int): Option[NumMatchesAndCount] = {
    ifNoNullsIn(result, offset, howMany = 2) { _ =>
      NumMatchesAndCount(result.getLong(offset), result.getLong(offset + 1))
    }
  }

  /** The (optionally filtered) predicate selection, followed by the overall row count */
  override def aggregationFunctions(): Seq[Column] = {
    val matchingRows = Analyzers.conditionalSelection(predicate, where)
    Seq(matchingRows, count("*"))
  }
}

/**
  * Base class for analyzers that require to group the data by specific columns
  *
  * @tparam S type of the state computed from the grouped data
  * @tparam M type of the metric produced from the state
  */
abstract class GroupingAnalyzer[S <: State[_], +M <: Metric[_]] extends Analyzer[S, M] {

  /** The columns to group the data by */
  def groupingColumns(): Seq[String]

  /** Every grouping column must exist in the data; inherited preconditions follow */
  override def preconditions: Seq[StructType => Unit] = {
    val columnsMustExist = groupingColumns().map(Preconditions.hasColumn(_))
    columnsMustExist ++ super.preconditions
  }
}

/** Helper methods to check conditions on the schema of the data */
object Preconditions {

  // Only used to list the accepted types in the error message of isNumeric
  private[this] val numericDataTypes =
    Set(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType)

  /**
    * Return the first (potential) exception thrown by a precondition.
    * Conditions are evaluated in order, and evaluation stops at the first failing one.
    *
    * @param schema schema to check the conditions against
    * @param conditions preconditions, each signalling a violation by throwing
    * @return the exception thrown by the first failing condition, None if all hold
    */
  def findFirstFailing(
      schema: StructType,
      conditions: Seq[StructType => Unit])
    : Option[Exception] = {

    // Iterate lazily so conditions after the first failing one are not evaluated
    conditions.iterator
      .map { condition =>
        try {
          condition(schema)
          None
        } catch {
          /* We only catch exceptions, not errors! */
          case e: Exception => Some(e)
        }
      }
      .collectFirst { case Some(e) => e }
  }

  /** At least one column is specified */
  def atLeastOne(columns: Seq[String]): StructType => Unit = { _ =>
    if (columns.isEmpty) {
      throw new NoColumnsSpecifiedException("At least one column needs to be specified!")
    }
  }

  /** Exactly `n` columns are specified */
  def exactlyNColumns(columns: Seq[String], n: Int): StructType => Unit = { _ =>
    if (columns.size != n) {
      throw new NumberOfSpecifiedColumnsException(s"$n columns have to be specified! " +
        s"Currently, columns contains only ${columns.size} column(s): ${columns.mkString(",")}!")
    }
  }

  /**
    * Specified column, if present in the data, is not of a nested type
    * (struct, map or array). A missing column is not reported by this check.
    */
  def isNotNested(column: String): StructType => Unit = { schema =>
    if (schema.fieldNames.contains(column)) {
      val columnDataType = schema(column).dataType
      columnDataType match {
        case _ : StructType | _ : MapType | _ : ArrayType =>
          throw new WrongColumnTypeException(
            s"Unsupported nested column type of column $column: $columnDataType!")
        case _ =>
      }
    }
  }

  /** Specified column exists in the data */
  def hasColumn(column: String): StructType => Unit = { schema =>
    if (!schema.fieldNames.contains(column)) {
      throw new NoSuchColumnException(s"Input data does not include column $column!")
    }
  }

  /** Specified column has a numeric type */
  def isNumeric(column: String): StructType => Unit = { schema =>
    val columnDataType = schema(column).dataType
    val hasNumericType = columnDataType match {
      case ByteType | ShortType | IntegerType | LongType | FloatType |
           DoubleType | _ : DecimalType => true
      case _ => false
    }

    if (!hasNumericType) {
      throw new WrongColumnTypeException(s"Expected type of column $column to be one of " +
        s"(${numericDataTypes.mkString(",")}), but found $columnDataType instead!")
    }
  }

  /** Specified column has string type */
  def isString(column: String): StructType => Unit = { schema =>
    val columnDataType = schema(column).dataType
    val hasStringType = columnDataType match {
      case StringType => true
      case _ => false
    }

    if (!hasStringType) {
      throw new WrongColumnTypeException(s"Expected type of column $column to be " +
        s"StringType, but found $columnDataType instead!")
    }
  }
}

private[deequ] object Analyzers {

  val COL_PREFIX = "com_amazon_deequ_dq_metrics_"
  val COUNT_COL = s"${COL_PREFIX}count"

  /**
    * Merges a sequence of potentially empty states. Empty states are skipped and the
    * remaining ones are summed from left to right.
    *
    * @return the sum of all non-empty states, or None if every state is empty
    */
  def merge[S <: State[_]](
      state: Option[S],
      anotherState: Option[S],
      moreStates: Option[S]*)
    : Option[S] = {

    val statesToMerge = (Seq(state, anotherState) ++ moreStates).flatten

    statesToMerge.reduceOption { (stateA: S, stateB: S) =>
      stateA.sumUntyped(stateB).asInstanceOf[S]
    }
  }

  /**
    * Tests whether the result columns at indices `offset` until `offset + howMany`
    * (exclusive) are non-null and, if so, builds a state via `func`.
    *
    * @return None if any inspected column is null, otherwise `Option(func(()))`
    */
  def ifNoNullsIn[S <: State[_]](
      result: Row,
      offset: Int,
      howMany: Int = 1)
      (func: Unit => S)
    : Option[S] = {

    val nullInResult = (offset until offset + howMany).exists { index => result.isNullAt(index) }

    if (nullInResult) {
      None
    } else {
      // Pass the unit value `()`, not the `Unit` companion object
      Option(func(()))
    }
  }

  /** Entity.Column for a single column, Entity.Mutlicolumn otherwise */
  def entityFrom(columns: Seq[String]): Entity.Value = {
    if (columns.size == 1) Entity.Column else Entity.Mutlicolumn
  }

  /** Selects the named column, restricted to rows matching the optional `where` expression */
  def conditionalSelection(selection: String, where: Option[String]): Column = {
    conditionalSelection(col(selection), where)
  }

  /**
    * Selects `selection` only for rows where the optional SQL `condition` holds
    * (via `when`); without a condition, `selection` is returned unchanged.
    */
  def conditionalSelection(selection: Column, condition: Option[String]): Column = {
    val conditionColumn = condition.map { expression => expr(expression) }
    conditionalSelectionFromColumns(selection, conditionColumn)
  }

  private[this] def conditionalSelectionFromColumns(
      selection: Column,
      conditionColumn: Option[Column])
    : Column = {

    conditionColumn
      .map { condition => when(condition, selection) }
      .getOrElse(selection)
  }

  /**
    * Counts the rows matching the optional filter expression (as the sum of the filter
    * cast to long), or all rows when no filter is given.
    */
  def conditionalCount(where: Option[String]): Column = {
    where
      .map { filter => sum(expr(filter).cast(LongType)) }
      .getOrElse(count("*"))
  }

  /** Wraps a successfully computed value into a DoubleMetric */
  def metricFromValue(
      value: Double,
      name: String,
      instance: String,
      entity: Entity.Value = Entity.Column)
    : DoubleMetric = {

    DoubleMetric(entity, name, instance, Success(value))
  }

  /** Exception describing that the analyzer's state is empty */
  def emptyStateException(analyzer: Analyzer[_, _]): EmptyStateException = {
    new EmptyStateException(s"Empty state for analyzer $analyzer, all input values were NULL.")
  }

  /** Failure metric for an analyzer whose state is empty */
  def metricFromEmpty(
      analyzer: Analyzer[_, _],
      name: String,
      instance: String,
      entity: Entity.Value = Entity.Column)
    : DoubleMetric = {
    metricFromFailure(emptyStateException(analyzer), name, instance, entity)
  }

  /** Failure metric holding the given exception, wrapped via MetricCalculationException */
  def metricFromFailure(
      exception: Throwable,
      name: String,
      instance: String,
      entity: Entity.Value = Entity.Column)
    : DoubleMetric = {

    DoubleMetric(entity, name, instance, Failure(
      MetricCalculationException.wrapIfNecessary(exception)))
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy