com.microsoft.ml.spark.stages.SummarizeData.scala

// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.stages

import com.microsoft.ml.spark.core.contracts.Wrappable
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{BooleanParam, DoubleParam, ParamMap}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{BooleanType, DoubleType, NumericType, StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.storage.StorageLevel

import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer

trait SummarizeDataParams extends Wrappable with DefaultParamsWritable {

  /** Compute count statistics. Default is true.
    * @group param
    */
  final val counts: BooleanParam = new BooleanParam(this, "counts", "Compute count statistics")
  setDefault(counts, true)

  /** @group getParam */
  final def getCounts: Boolean = $(counts)

  /** @group setParam */
  def setCounts(value: Boolean): this.type = set(counts, value)

  /** Compute basic statistics. Default is true.
    * @group param
    */
  final val basic: BooleanParam = new BooleanParam(this, "basic", "Compute basic statistics")
  setDefault(basic, true)

  /** @group getParam */
  final def getBasic: Boolean = $(basic)

  /** @group setParam */
  def setBasic(value: Boolean): this.type = set(basic, value)

  /** Compute sample statistics. Default is true.
    * @group param
    */
  final val sample: BooleanParam = new BooleanParam(this, "sample", "Compute sample statistics")
  setDefault(sample, true)

  /** @group getParam */
  final def getSample: Boolean = $(sample)

  /** @group setParam */
  def setSample(value: Boolean): this.type = set(sample, value)

  /** Compute percentiles. Default is true.
    * @group param
    */
  final val percentiles: BooleanParam = new BooleanParam(this, "percentiles", "Compute percentiles")
  setDefault(percentiles, true)

  /** @group getParam */
  final def getPercentiles: Boolean = $(percentiles)

  /** @group setParam */
  def setPercentiles(value: Boolean): this.type = set(percentiles, value)

  /** Relative error threshold for approximate quantile computation; 0 computes exact quantiles. Default is 0.
    * @group param
    */
  final val errorThreshold: DoubleParam =
    new DoubleParam(this, "errorThreshold", "Threshold for quantiles - 0 is exact")
  setDefault(errorThreshold, 0.0)

  /** @group getParam */
  final def getErrorThreshold: Double = $(errorThreshold)

  /** @group setParam */
  def setErrorThreshold(value: Double): this.type = set(errorThreshold, value)

  protected def validateAndTransformSchema(schema: StructType): StructType = {
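    // Output: the Feature column followed by the field block of each enabled statistic family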
    val columns = ListBuffer(SummarizeData.FeatureColumn)
    if ($(counts)) columns ++= SummarizeData.CountFields
    if ($(basic)) columns ++= SummarizeData.BasicFields
    if ($(sample)) columns ++= SummarizeData.SampleFields
    if ($(percentiles)) columns ++= SummarizeData.PercentilesFields
    StructType(columns)
  }
}

// UID should be overridden by driver for controlled identification at the DAG level
/** Computes summary statistics for the dataset. The following statistic
  * families can be enabled independently:
  * - counts
  * - basic
  * - sample
  * - percentiles
  * Quantile-based statistics honor the errorThreshold parameter.
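  *
  * A minimal usage sketch (the input DataFrame `df` is illustrative):
  * {{{
  *   val summary = new SummarizeData()
  *     .setErrorThreshold(0.01)
  *   val statsDf = summary.transform(df)
  *   statsDf.show()
  * }}}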
  * @param uid The id of the module
  */
class SummarizeData(override val uid: String)
  extends Transformer
    with SummarizeDataParams {

  def this() = this(Identifiable.randomUID("SummarizeData"))

  override def transform(dataset: Dataset[_]): DataFrame = {

    val df = dataset.toDF()
    // Cache the input: each enabled statistic family triggers at least one extra pass over the data
    df.persist(StorageLevel.MEMORY_ONLY)

    val subFrames = ListBuffer[DataFrame]()
    if ($(counts)) subFrames += computeCounts(df)
    if ($(basic)) subFrames += curriedBasic(df)
    if ($(sample)) subFrames += sampleStats(df)
    if ($(percentiles)) subFrames += curriedPerc(df)

    df.unpersist(false)

    val base = createJoinBase(df)
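    // Join each statistics frame onto the Feature-only base so every enabled family contributes a column block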
    subFrames.foldLeft(base) { (z, dfi) => z.join(dfi, SummarizeData.FeatureColumnName) }
  }

  def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema)
  }

  def copy(extra: ParamMap): SummarizeData = defaultCopy(extra)

  private def computeCounts = computeOnAll(computeCountsImpl, SummarizeData.CountFields)

  private def computeCountsImpl(col: String, df: DataFrame): Array[Double] = {
    val column = df.col(col)
    val dataType = df.schema(col).dataType
    val mExpr = isnull(column) || (if (dataType.equals(BooleanType)) isnan(column.cast(DoubleType)) else isnan(column))

    val countMissings = df.where(mExpr).count().toDouble
    // approx_count_distinct returns a Long, whose range exceeds what Double represents exactly
    val dExpr = approx_count_distinct(column)
    val distinctCount = df.select(dExpr).first.getLong(0).toDouble
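    // Order mirrors SummarizeData.CountFields: non-missing count, distinct count, missing count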
    Array(df.count() - countMissings, distinctCount, countMissings)
  }

  private def sampleStats = computeOnNumeric(sampleStatsImpl, SummarizeData.SampleFields)

  private def sampleStatsImpl(col: String, df: DataFrame): Array[Double] = {
    val column = df.col(col)
    val k = kurtosis(column)
    val sk = skewness(column)
    val v = variance(column)
    val sd = stddev(column)
    df.select(v, sd, sk, k).first.toSeq.map(_.asInstanceOf[Double]).toArray
  }

  private def curriedBasic = {
    val quants = SummarizeData.BasicQuantiles
    computeOnNumeric(quantStub(quants, $(errorThreshold)), SummarizeData.BasicFields)
  }

  private def curriedPerc = {
    val quants = SummarizeData.PercentilesQuantiles
    computeOnNumeric(quantStub(quants, $(errorThreshold)), SummarizeData.PercentilesFields)
  }

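  // approxQuantile implements a variant of the Greenwald-Khanna algorithm;
  // an error of 0.0 computes exact quantiles at the cost of a full pass.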
  private def quantStub(vals: Array[Double], err: Double) =
    (cn: String, df: DataFrame) => df.stat.approxQuantile(cn, vals, err)

  private def computeOnNumeric = computeColumnStats(sf => sf.dataType.isInstanceOf[NumericType]) _

  private def computeOnAll = computeColumnStats(sf => true) _

  private def allNaNs(l: Int): Array[Double] = Array.fill(l)(Double.NaN)

  private def createJoinBase(df: DataFrame) =
    computeColumnStats(_ => false)((_, _) => Array.empty, List.empty)(df)

  private def computeColumnStats
  (p: StructField => Boolean)
  (statFunc: (String, DataFrame) => Array[Double], newColumns: Seq[StructField])
  (df: DataFrame): DataFrame = {
    val emptyRow = allNaNs(newColumns.length)
    val outList = df.schema.map(field => (field.name, if (p(field)) statFunc(field.name, df) else emptyRow))
    val rows = outList.map { case (n, r) => Row.fromSeq(n +: r) }
    val schema = SummarizeData.FeatureColumn +: newColumns
    df.sparkSession.createDataFrame(rows.asJava, StructType(schema))
  }

}

object SummarizeData extends DefaultParamsReadable[SummarizeData] {

  object Statistic extends Enumeration {
    type Statistic = Value
    val Counts, Basic, Sample, Percentiles = Value
  }

  final val FeatureColumnName = "Feature"
  final val FeatureColumn = StructField(FeatureColumnName, StringType, false)

  final val PercentilesQuantiles = Array(0.005, 0.01, 0.05, 0.95, 0.99, 0.995)
  final val PercentilesFields = List(
    StructField("P0_5", DoubleType, true),
    StructField("P1", DoubleType, true),
    StructField("P5", DoubleType, true),
    StructField("P95", DoubleType, true),
    StructField("P99", DoubleType, true),
    StructField("P99_5", DoubleType, true))

  final val SampleFields = List(
    StructField("Sample_Variance", DoubleType, true),
    StructField("Sample_Standard_Deviation", DoubleType, true),
    StructField("Sample_Skewness", DoubleType, true),
    StructField("Sample_Kurtosis", DoubleType, true))

  final val BasicQuantiles = Array(0.0, 0.25, 0.5, 0.75, 1.0)
  final val BasicFields = List(
    StructField("Min", DoubleType, true),
    StructField("1st_Quartile", DoubleType, true),
    StructField("Median", DoubleType, true),
    StructField("3rd_Quartile", DoubleType, true),
    StructField("Max", DoubleType, true)
    //TODO: StructField("Range", DoubleType, true),
    //TODO: StructField("Mean", DoubleType, true),
    //TODO: StructField("Mean Deviation", DoubleType, true),
    // Mode is JSON Array of modes - needs a little special treatment
    //TODO: StructField("Mode", StringType, true))
  )

  final val CountFields = List(
    StructField("Count", DoubleType, false),
    StructField("Unique_Value_Count", DoubleType, false),
    StructField("Missing_Value_Count", DoubleType, false))
}