All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.github.timgent.sparkdataquality.metrics.MetricCalculator.scala Maven / Gradle / Ivy

package com.github.timgent.sparkdataquality.metrics

import com.github.timgent.sparkdataquality.metrics.MetricValue.{DoubleMetric, LongMetric, NumericMetricValue}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{Column, Row}

/**
  * Base type for metric calculators. Each calculator fixes the kind of [[MetricValue]] it
  * produces and knows how to wrap a raw value into that metric type.
  */
private[sparkdataquality] sealed trait MetricCalculator {
  /** The concrete [[MetricValue]] subtype this calculator produces. */
  type MetricType <: MetricValue

  /**
    * Wraps a raw value in this calculator's metric type.
    *
    * @param metricValue raw value, typed as the `T` member of `MetricType`
    * @return the value wrapped as a `MetricType`
    */
  def wrapMetricValue(metricValue: MetricType#T): MetricType
}

private[sparkdataquality] object MetricCalculator {

  /**
    * A [[MetricCalculator]] whose metric is produced by a single aggregate expression evaluated
    * over the whole dataset.
    *
    * Implementations supply the aggregate column and the filter they apply; `valueFromRow` reads
    * the computed value back out of a result row.
    */
  sealed trait SimpleMetricCalculator extends MetricCalculator {

    /** Filter describing which rows this metric takes into account. */
    def filter: MetricFilter

    /** Aggregate expression that computes the raw metric value. */
    def aggFunction: Column

    /**
      * Reads the raw value at position `index` of `row` and wraps it in this calculator's
      * metric type.
      *
      * @param row   row holding the aggregation results
      * @param index position of this calculator's value within `row`
      * @return the wrapped metric value
      */
    def valueFromRow(row: Row, index: Int): MetricType = {
      val rawValue = row.getAs[MetricType#T](index)
      wrapMetricValue(rawValue)
    }
  }

  /**
    * Counts the rows that satisfy `filter`.
    *
    * @param filter selects the rows to count
    */
  case class SizeMetricCalculator(filter: MetricFilter) extends SimpleMetricCalculator {
    override type MetricType = LongMetric

    /** Sums a per-row flag that is 1 when the filter condition holds and 0 otherwise. */
    override def aggFunction: Column = {
      val matchFlag = when(filter.filter, lit(1)).otherwise(lit(0))
      sum(matchFlag)
    }

    /** Wraps the raw count in a [[LongMetric]]. */
    override def wrapMetricValue(metricValue: Long): LongMetric =
      LongMetric(metricValue)
  }

  /**
    * Calculates the fraction of filtered rows that also satisfy a compliance condition.
    *
    * @param complianceFn condition a row must meet to count as compliant
    * @param filter       selects the rows that are considered at all
    */
  case class ComplianceMetricCalculator(complianceFn: ComplianceFn, filter: MetricFilter)
      extends SimpleMetricCalculator {
    override type MetricType = DoubleMetric

    /**
      * Ratio of the number of rows passing both the filter and the compliance condition to the
      * number of rows passing the filter.
      */
    override def aggFunction: Column = {
      val compliantCount = sum(when(filter.filter && complianceFn.definition, 1).otherwise(0))
      val inScopeCount = sum(when(filter.filter, 1).otherwise(0))
      compliantCount / inScopeCount
    }

    /** Wraps the raw ratio in a [[DoubleMetric]]. */
    override def wrapMetricValue(metricValue: Double): DoubleMetric =
      DoubleMetric(metricValue)
  }

  /**
    * Sums the values of one column over the rows that satisfy `filter`.
    *
    * @param onColumn name of the column whose values are summed
    * @param filter   selects the rows that contribute to the sum
    * @tparam MV numeric metric type produced; a [[MetricValueConstructor]] for it must be in scope
    */
  case class SumValuesMetricCalculator[MV <: NumericMetricValue: MetricValueConstructor](
      onColumn: String,
      filter: MetricFilter
  ) extends SimpleMetricCalculator {
    override type MetricType = MV

    /** Sums the column value for rows matching the filter and 0 for every other row. */
    override def aggFunction: Column = {
      val valueInScope = when(filter.filter, col(onColumn)).otherwise(0)
      sum(valueInScope)
    }

    /** Wraps the raw sum using the implicit [[MetricValueConstructor]] for `MV`. */
    override def wrapMetricValue(metricValue: MV#T): MV = {
      val constructor = implicitly[MetricValueConstructor[MV]]
      constructor.apply(metricValue)
    }
  }

  /**
    * Counts the distinct combinations of values in `onColumns` among the rows that satisfy
    * `filter`.
    *
    * @param onColumns names of the columns whose value combinations are counted; must not be empty
    * @param filter    selects the rows that contribute to the count
    */
  case class DistinctValuesMetricCalculator(onColumns: List[String], filter: MetricFilter) extends SimpleMetricCalculator {
    override type MetricType = LongMetric

    /**
      * Builds a `countDistinct` over the selected columns, with every column replaced by null on
      * rows that do not satisfy the filter so that those rows are not counted.
      *
      * @throws NoSuchElementException if `onColumns` is empty
      */
    override def aggFunction: Column = {
      // `when` without `otherwise` yields null unless the condition is true, so rows whose filter
      // evaluates to false OR to NULL are both masked out. The previous
      // `when(not(filter), null).otherwise(col)` form let NULL-filter rows through, because
      // `not(NULL)` is NULL and therefore fell into the `otherwise` branch.
      // countDistinct does not count tuples that contain a null.
      val maskedColumns: List[Column] =
        onColumns.map(onColumn => when(filter.filter, col(onColumn)))

      maskedColumns match {
        case firstColumn :: otherColumns =>
          countDistinct(firstColumn, otherColumns: _*)
        case Nil =>
          // Same exception type as the former `.head` call, but with an actionable message.
          throw new NoSuchElementException(
            "DistinctValuesMetricCalculator requires at least one column in onColumns"
          )
      }
    }

    /** Wraps the raw distinct count in a [[LongMetric]]. */
    override def wrapMetricValue(metricValue: Long): LongMetric = LongMetric(metricValue)
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy