All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.databricks.labs.automl.sanitize.FeatureCorrelationDetection.scala Maven / Gradle / Ivy

package com.databricks.labs.automl.sanitize

import com.databricks.labs.automl.exceptions.FeatureCorrelationException
import com.databricks.labs.automl.utils.SparkSessionWrapper
import com.databricks.labs.automl.utils.structures.{
  FieldCorrelationAggregationStats,
  FieldCorrelationPayload,
  FieldPairs,
  FieldRemovalPayload
}
import org.apache.log4j.Logger
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.collection.parallel.ForkJoinTaskSupport
import scala.concurrent.forkjoin.ForkJoinPool

/**
  * Detects feature fields that are linearly correlated with one another and removes them.
  *
  * Every distinct pair drawn from `fieldListing` is scored with a linear correlation coefficient that is
  * built from per-pair Spark aggregations. A pair whose coefficient is at or above the high cutoff, or at
  * or below the low cutoff, causes the right-hand field of that pair to be scheduled for removal.
  *
  * @param data DataFrame that holds the feature fields and the label column
  * @param fieldListing names of the feature fields to compare; pairs are generated in this order and the
  *                     left-hand (earlier) field of a pair is the one treated as the primary column
  */
class FeatureCorrelationDetection(data: DataFrame, fieldListing: Array[String])
    extends SparkSessionWrapper {

  // Suffixes / names of the temporary columns created while scoring a pair
  private final val DEVIATION = "_deviation"
  private final val SQUARED = "_squared"
  private final val COV = "covariance_value"
  private final val PRODUCT = "_product"

  private val logger: Logger = Logger.getLogger(this.getClass)

  // With both cutoffs left at 0.0, every non-NaN coefficient satisfies at least one of the cutoff
  // comparisons used in calculateGroupStats and determineFieldsToDrop.
  private var _correlationCutoffHigh: Double = 0.0
  private var _correlationCutoffLow: Double = 0.0
  private var _labelCol: String = "label"
  private var _parallelism: Int = 20

  final private val _dataFieldNames = data.schema.fieldNames

  /**
    * Sets the coefficient at or above which a pair counts as positively correlated.
    * @param value cutoff, at most 1.0
    * @return this instance, for chaining
    * @throws IllegalArgumentException if value is greater than 1.0
    */
  def setCorrelationCutoffHigh(value: Double): this.type = {
    require(
      value <= 1.0,
      "Maximum range of Correlation Cutoff on the high end must be less than 1.0"
    )
    _correlationCutoffHigh = value
    this
  }

  /**
    * Sets the coefficient at or below which a pair counts as negatively correlated.
    * @param value cutoff, at least -1.0
    * @return this instance, for chaining
    * @throws IllegalArgumentException if value is less than -1.0
    */
  def setCorrelationCutoffLow(value: Double): this.type = {
    require(
      value >= -1.0,
      "Minimum range of Correlation Cutoff on the low end must be greater than -1.0"
    )
    _correlationCutoffLow = value
    this
  }

  /**
    * Sets the name of the label column, which is used to obtain the row count for the correlation formula.
    * @param value name of a column that exists in the DataFrame schema
    * @return this instance, for chaining
    * @throws IllegalArgumentException if the DataFrame has no column with that name
    */
  def setLabelCol(value: String): this.type = {
    require(
      _dataFieldNames.contains(value),
      s"Label field $value is not in Dataframe"
    )
    _labelCol = value
    this
  }

  /**
    * Sets the size of the thread pool used to score feature pairs concurrently.
    * @param value number of worker threads; must be positive
    * @return this instance, for chaining
    * @throws IllegalArgumentException if value is not positive
    */
  def setParallelism(value: Int): this.type = {
    // Fail here rather than later inside the ForkJoinPool constructor, which rejects non-positive sizes.
    require(
      value > 0,
      s"Parallelism must be greater than zero, but was set to $value"
    )
    _parallelism = value
    this
  }

  def getCorrelationCutoffHigh: Double = _correlationCutoffHigh

  def getCorrelationCutoffLow: Double = _correlationCutoffLow

  def getLabelCol: String = _labelCol
  def getParallelism: Int = _parallelism

  /**
    * Main entry point: drops every field that determineFieldsToDrop marks for removal.
    * @return the input DataFrame without the fields selected for removal
    */
  def filterFeatureCorrelation(): DataFrame = {

    // The default label name is never checked by setLabelCol, so it is verified here.
    // Kept as an assert so that the error type seen by existing callers does not change.
    assert(
      _dataFieldNames.contains(_labelCol),
      s"Label field ${_labelCol} is not in Dataframe"
    )
    val fieldsToRemove = determineFieldsToDrop.dropFields
    data.drop(fieldsToRemove: _*)
  }

  /**
    * Create the left/right testing pairs to be used in determining correlation between feature fields
    * @return Array of distinct pairs of feature fields to test
    * @since 0.6.2
    * @author Ben Wilson, Databricks
    */
  def buildFeaturePairs(): Array[FieldPairs] = {

    fieldListing
      .combinations(2)
      .map { case Array(left, right) => FieldPairs(left, right) }
      .toArray

  }

  /**
    * Method for calculating all of the pairwise correlation calculations for the feature fields.
    * Pairs are scored concurrently on a dedicated pool sized by the parallelism setting.
    * @return Array of FieldCorrelationPayload data (left/right name pairs and the correlation value),
    *         grouped by primary column in the order of the field listing
    * @since 0.6.2
    * @author Ben Wilson, Databricks
    */
  def calculateFeatureCorrelation: Array[FieldCorrelationPayload] = {

    val aggregationStats = getAggregationStats
    val rowCounts = aggregationStats.rowCounts
    val averages = aggregationStats.averageMap

    val forkJoinPool = new ForkJoinPool(_parallelism)

    val interactionData: Array[FieldCorrelationPayload] = try {
      val pairs = buildFeaturePairs().par
      pairs.tasksupport = new ForkJoinTaskSupport(forkJoinPool)

      // A parallel map collects the results safely; appending to a shared ArrayBuffer from a
      // parallel foreach is a data race that can drop or corrupt entries.
      // The array is materialized before the pool is shut down.
      pairs
        .map(pair => calculateCorrelation(pair, rowCounts, averages))
        .toArray
    } finally {
      // Release the worker threads; a new pool is created on every invocation.
      forkJoinPool.shutdown()
    }

    // Emit the payloads grouped by primary column, following the order of the field listing.
    val preSort = interactionData.groupBy(_.primaryColumn)
    fieldListing.flatMap(x => preSort.get(x)).flatten

  }

  /**
    * Private method for determining, for each primary column, the share of its pairs that meet either
    * correlation cutoff.
    * A pair that satisfies both cutoff tests is counted twice, so the ratio can exceed 1.0 when the high
    * cutoff is not above the low cutoff.
    * @param data Array[FieldCorrelationPayload] that represents the pairwise correlation between fields
    * @return Map[String, Double] with each primary field and the ratio of its pairs that meet the criteria
    *         for filtering based on the correlation cutoff values.
    * @since 0.6.2
    * @author Ben Wilson, Databricks
    */
  private def calculateGroupStats(
    data: Array[FieldCorrelationPayload]
  ): Map[String, Double] = {
    data.groupBy(_.primaryColumn).map {
      case (field, payloads) =>
        val positiveCounts =
          payloads.count(_.correlation >= _correlationCutoffHigh).toDouble
        val negativeCounts =
          payloads.count(_.correlation <= _correlationCutoffLow).toDouble
        field -> (positiveCounts + negativeCounts) / payloads.length
    }
  }

  /**
    * Method for determining which columns need to be dropped from the feature set based on the correlation cutoff settings.
    *
    * A primary column that is already marked for removal, or whose group ratio is 1.0 or more, is removed.
    * Otherwise it is retained and the right-hand field of each of its pairs that meets a cutoff is removed.
    * Only primary columns are ever added to the retain set. When no pairs exist (fewer than two fields),
    * the retain set is empty and the exception below is raised.
    * @return FieldRemovalPayload that contains the removal and retain fields.
    * @since 0.6.2
    * @author Ben Wilson, Databricks
    * @throws FeatureCorrelationException: totalFields, removedFields
    */
  @throws(classOf[RuntimeException])
  def determineFieldsToDrop: FieldRemovalPayload = {

    val retainBuffer = mutable.SortedSet[String]()
    val removeBuffer = mutable.SortedSet[String]()

    val correlationResult = calculateFeatureCorrelation
    val groupData = calculateGroupStats(correlationResult)

    correlationResult.foreach { x =>
      val primaryIsKept = !removeBuffer.contains(x.primaryColumn) &&
        groupData(x.primaryColumn) < 1.0

      if (primaryIsKept) {
        retainBuffer += x.primaryColumn
        val exceedsCutoff = x.correlation >= _correlationCutoffHigh ||
          x.correlation <= _correlationCutoffLow
        if (exceedsCutoff) removeBuffer += x.pairs.right
      } else {
        removeBuffer += x.primaryColumn
      }
    }

    // Validation
    if (retainBuffer.isEmpty)
      throw FeatureCorrelationException(fieldListing, removeBuffer.toArray)
    FieldRemovalPayload(removeBuffer.toArray, retainBuffer.toArray)

  }

  /**
    * Debug method to allow for an inspection of the correlation between each feature value to one another
    * @return DataFrame that contains the pair information and the correlation values of those pairs
    * @since 0.6.2
    * @author Ben Wilson, Databricks
    */
  def generateFeatureCorrelationReport: DataFrame = {

    import spark.sqlContext.implicits._

    sc.parallelize(calculateFeatureCorrelation).toDF

  }

  /**
    * Private method for accessing the required data needed for correlation calculation (one-time calculation)
    * @return FieldCorrelationAggregationStats of row counts and map of average values for each feature field
    * @since 0.6.2
    * @author Ben Wilson, Databricks
    */
  private def getAggregationStats: FieldCorrelationAggregationStats = {

    val meanRow = data
      .select(fieldListing map col: _*)
      .summary("mean")
      .filter(col("summary") === "mean")
      .drop("summary")
      .first()

    // Dataset.summary reports its statistics as string-typed columns, so reading them through
    // getValuesMap[Double] would leave String objects behind an erased Double type. Parse explicitly.
    // A missing (null) mean is represented as NaN.
    val summaryMap = fieldListing.map { field =>
      field -> Option(meanRow.getAs[Any](field))
        .fold(Double.NaN)(_.toString.toDouble)
    }.toMap

    // count() on the label column counts the rows whose label is not null.
    val rowCounts = data
      .select(_labelCol)
      .agg(count(_labelCol).alias(_labelCol))
      .withColumn(_labelCol, col(_labelCol).cast(DoubleType))
      .first()
      .getAs[Double](_labelCol)
    FieldCorrelationAggregationStats(rowCounts, summaryMap)
  }

  /**
    * Private method for executing a pair-wise correlation calculation between two feature fields.
    * The coefficient is computed as
    * (sum(xy) - sum(x) * sum(y) / n) / sqrt((sum(x^2) - sum(x)^2 / n) * (sum(y^2) - sum(y)^2 / n)).
    * @param pair the left/right pair of feature fields to calculate
    * @param rowCount Number of rows within the data set
    * @param averages average value of each feature as a Map
    * @return FieldCorrelationPayload that contains the pair and the correlation value
    * @since 0.6.2
    * @author Ben Wilson, Databricks
    */
  private def calculateCorrelation(
    pair: FieldPairs,
    rowCount: Double,
    averages: Map[String, Double]
  ): FieldCorrelationPayload = {

    val leftSquared = pair.left + SQUARED
    val rightSquared = pair.right + SQUARED
    val leftDeviation = pair.left + DEVIATION
    val rightDeviation = pair.right + DEVIATION

    // Establish a subset DataFrame that contains only the two columns being tested
    val subsetDataFrame = data
      .select(pair.left, pair.right)
      .withColumn(pair.left, col(pair.left).cast(DoubleType))
      .withColumn(pair.right, col(pair.right).cast(DoubleType))

    // Get the map of the values required to support the linear correlation calculation
    // NOTE(review): the deviation columns and the summed COV value are computed here, but the
    // coefficient below does not read them.
    val covarianceMap = subsetDataFrame
      .withColumn(leftDeviation, col(pair.left) - averages(pair.left))
      .withColumn(rightDeviation, col(pair.right) - averages(pair.right))
      .withColumn(leftSquared, pow(col(pair.left), 2))
      .withColumn(rightSquared, pow(col(pair.right), 2))
      .withColumn(COV, col(leftDeviation) * col(rightDeviation))
      .withColumn(PRODUCT, col(pair.left) * col(pair.right))
      .agg(
        sum(pair.left).alias(pair.left),
        sum(pair.right).alias(pair.right),
        sum(PRODUCT).alias(PRODUCT),
        sum(COV).alias(COV),
        sum(leftSquared).alias(leftSquared),
        sum(rightSquared).alias(rightSquared)
      )
      .first()
      .getValuesMap[Double](
        Seq(COV, pair.left, pair.right, PRODUCT, leftSquared, rightSquared)
      )

    // Calculate the correlation
    val sumLeft = covarianceMap(pair.left)
    val sumRight = covarianceMap(pair.right)

    val numerator = covarianceMap(PRODUCT) - (sumLeft * sumRight / rowCount)
    val leftVariation = covarianceMap(leftSquared) - math.pow(sumLeft, 2.0) / rowCount
    val rightVariation = covarianceMap(rightSquared) - math.pow(sumRight, 2.0) / rowCount

    // A field without variation makes the denominator zero, which yields a non-finite coefficient.
    val linearCorrelationCoefficient = numerator / math.sqrt(
      leftVariation * rightVariation
    )

    FieldCorrelationPayload(pair.left, pair, linearCorrelationCoefficient)

  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy