All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.amazon.deequ.analyzers.MutualInformation.scala Maven / Gradle / Ivy

/**
 * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not
 * use this file except in compliance with the License. A copy of the License
 * is located at
 *
 *     http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on
 * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
 * express or implied. See the License for the specific language governing
 * permissions and limitations under the License.
 *
 */

package com.amazon.deequ.analyzers

import com.amazon.deequ.analyzers.Analyzers.COUNT_COL
import com.amazon.deequ.analyzers.Analyzers._
import com.amazon.deequ.metrics.DoubleMetric
import com.amazon.deequ.metrics.Entity
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.sum
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.StructType

/**
  * Analyzer computing the mutual information between exactly two columns, i.e. how much
  * knowing the value of one column tells us about the value of the other.
  *
  * The metric is the sum over all observed value pairs (x, y) of
  * p(x, y) * ln(p(x, y) / (p(x) * p(y))), with every probability estimated as a count
  * divided by the number of rows held in the state.
  *
  * For independent columns the result is zero. If each column functionally determines the
  * other, all information is shared and the result equals the entropy of either column.
  *
  * @param columns the two columns to relate to each other (enforced by the preconditions)
  * @param where   optional filter condition, exposed unchanged through `filterCondition`
  */
case class MutualInformation(columns: Seq[String], where: Option[String] = None)
  extends FrequencyBasedAnalyzer(columns)
    with FilterableAnalyzer {

  /**
    * Turns the joint frequency table of the two columns into the mutual information metric.
    *
    * @param state joint frequencies together with the row count, if a state is available
    * @return the metric holding the computed value, or the "empty" metric when there is no
    *         state or the aggregation produced no value (null sum)
    */
  override def computeMetricFrom(state: Option[FrequenciesAndNumRows]): DoubleMetric = {

    val metricName = "MutualInformation"
    val metricInstance = columns.mkString(",")

    def emptyMetric: DoubleMetric =
      metricFromEmpty(this, metricName, metricInstance, Entity.Multicolumn)

    state.fold(emptyMetric) { frequenciesAndNumRows =>
      val numRows = frequenciesAndNumRows.numRows
      val Seq(firstColumn, secondColumn) = columns

      val firstMarginalCol = s"__deequ_f1_$firstColumn"
      val secondMarginalCol = s"__deequ_f2_$secondColumn"
      val contributionCol = s"__deequ_mi_${firstColumn}_$secondColumn"

      val jointFrequencies = frequenciesAndNumRows.frequencies

      // Marginal counts of a single column, obtained by summing the joint counts per value.
      def marginalCounts(column: String, alias: String) =
        jointFrequencies
          .select(column, COUNT_COL)
          .groupBy(column)
          .agg(sum(COUNT_COL).as(alias))

      val firstMarginals = marginalCounts(firstColumn, firstMarginalCol)
      val secondMarginals = marginalCounts(secondColumn, secondMarginalCol)

      // Contribution of one (x, y) pair: p(x, y) * ln(p(x, y) / (p(x) * p(y))).
      // All three arguments are raw counts; they are normalised by the row count here.
      val pairContribution = udf {
        (countX: Double, countY: Double, countXY: Double) =>
          val jointProbability = countXY / numRows
          val independentProbability = (countX / numRows) * (countY / numRows)
          jointProbability * math.log(jointProbability / independentProbability)
      }

      // Attach both marginal counts to every joint row, then sum the per-pair contributions.
      val withMarginals = jointFrequencies
        .join(firstMarginals, usingColumn = firstColumn)
        .join(secondMarginals, usingColumn = secondColumn)

      val withContributions = withMarginals.withColumn(
        contributionCol,
        pairContribution(col(firstMarginalCol), col(secondMarginalCol), col(COUNT_COL)))

      val aggregated = withContributions.agg(sum(contributionCol)).head()

      // The sum is null when there was nothing to aggregate.
      if (aggregated.isNullAt(0)) {
        emptyMetric
      } else {
        metricFromValue(aggregated.getDouble(0), metricName, metricInstance, Entity.Multicolumn)
      }
    }
  }


  /** Exactly two columns must be given, in addition to the preconditions of the superclass. */
  override def preconditions: Seq[StructType => Unit] = {
    val exactlyTwoColumns = Preconditions.exactlyNColumns(columns, 2)
    exactlyTwoColumns +: super.preconditions
  }

  /** Wraps the given exception into a failed metric for this analyzer's columns. */
  override def toFailureMetric(exception: Exception): DoubleMetric = {
    val metricInstance = columns.mkString(",")
    metricFromFailure(exception, "MutualInformation", metricInstance, Entity.Multicolumn)
  }

  /** The optional filter condition supplied as `where`. */
  override def filterCondition: Option[String] = where
}

/** Companion offering a two-argument factory for the unfiltered case. */
object MutualInformation {

  /**
    * Creates an analyzer over the two given columns, leaving `where` at its default (`None`).
    *
    * @param columnA first column
    * @param columnB second column
    */
  def apply(columnA: String, columnB: String): MutualInformation =
    new MutualInformation(List(columnA, columnB))
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy