com.microsoft.ml.spark.featurize.CleanMissingData.scala
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.featurize
import com.microsoft.ml.spark.core.contracts.{HasInputCols, HasOutputCols, Wrappable}
import com.microsoft.ml.spark.core.serialize.{ConstructorReadable, ConstructorWritable}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.ml._
import org.apache.spark.sql._
import org.apache.spark.sql.types._

import scala.reflect.runtime.universe.{TypeTag, typeTag}

object CleanMissingData extends DefaultParamsReadable[CleanMissingData] {
  val MeanOpt = "Mean"
  val MedianOpt = "Median"
  val CustomOpt = "Custom"
  val Modes = Array(MeanOpt, MedianOpt, CustomOpt)

  def validateAndTransformSchema(schema: StructType,
                                 inputCols: Array[String],
                                 outputCols: Array[String]): StructType = {
    inputCols.zip(outputCols).foldLeft(schema) { (oldSchema, io) =>
      // Copy the input field's type and nullability under the output column name
      val outputField = oldSchema.fields(oldSchema.fieldIndex(io._1)).copy(name = io._2)
      if (oldSchema.fieldNames.contains(io._2)) {
        // Output column already exists: replace its field in place
        val fields = oldSchema.fields.clone()
        fields(oldSchema.fieldIndex(io._2)) = outputField
        StructType(fields)
      } else {
        // Output column is new: append it to the schema
        oldSchema.add(outputField)
      }
    }
  }
}
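// Illustration (not part of the original source; column names are hypothetical):
// given a schema ("age": IntegerType, "income": DoubleType) with
// inputCols = Array("age") and outputCols = Array("age_clean"),
// validateAndTransformSchema appends an "age_clean" field with the input column's
// type, yielding ("age": IntegerType, "income": DoubleType, "age_clean": IntegerType);
// had "age_clean" already existed, its field would have been replaced in place.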

/** Replaces missing values in the input dataset.
  * The following modes are supported:
  *   Mean   - replaces missing values with the mean of the column, computed at fit time
  *   Median - replaces missing values with an approximate median of the column, computed at fit time
  *   Custom - replaces missing values with a custom value specified by the user
  * For the mean and median modes, only numeric column types are supported, specifically:
  *   `Int`, `Long`, `Float`, `Double`
  * For custom mode, the types above are supported and additionally:
  *   `String`, `Boolean`
  */
class CleanMissingData(override val uid: String) extends Estimator[CleanMissingDataModel]
  with HasInputCols with HasOutputCols with Wrappable with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("CleanMissingData"))

  val cleaningMode: Param[String] = new Param[String](this, "cleaningMode",
    "Cleaning mode, one of: Mean, Median, Custom", ParamValidators.inArray(CleanMissingData.Modes))
  setDefault(cleaningMode->CleanMissingData.MeanOpt)
  def setCleaningMode(value: String): this.type = set(cleaningMode, value)
  def getCleaningMode: String = $(cleaningMode)

  /** Custom value for imputation; supports numeric, string, and boolean types.
    * Date and Timestamp types are currently not supported.
    */
  val customValue: Param[String] = new Param[String](this, "customValue", "Custom value for replacement")
  def setCustomValue(value: String): this.type = set(customValue, value)
  def getCustomValue: String = $(customValue)

  /** Fits the dataset, prepares the transformation function.
    *
    * @param dataset The input dataset.
    * @return The model for removing missings.
    */
  override def fit(dataset: Dataset[_]): CleanMissingDataModel = {
    val replacementValues = getReplacementValues(dataset, getInputCols, getOutputCols, getCleaningMode)
    new CleanMissingDataModel(uid, replacementValues, getInputCols, getOutputCols)
  }

  override def copy(extra: ParamMap): Estimator[CleanMissingDataModel] = defaultCopy(extra)

  @DeveloperApi
  override def transformSchema(schema: StructType): StructType =
    CleanMissingData.validateAndTransformSchema(schema, getInputCols, getOutputCols)

  // Only numeric types are supported for mean/median imputation (no String, Boolean, or Date/Timestamp)
  private def isSupportedTypeForImpute(dataType: DataType): Boolean = dataType.isInstanceOf[NumericType]

  private def verifyColumnsSupported(dataset: Dataset[_], colsToClean: Array[String]): Unit = {
    if (colsToClean.exists(name => !isSupportedTypeForImpute(dataset.schema(name).dataType)))
      throw new UnsupportedOperationException("Only numeric types supported for numeric imputation")
  }

  private def rowToValues(row: Row): Array[Double] = {
    (0 until row.size).map { i =>
      row.get(i) match {
        // avg always yields Double, but percentile_approx preserves the input
        // column's numeric type, so all four supported types can appear here
        case n: Double => n
        case n: Float => n.toDouble
        case n: Long => n.toDouble
        case n: Int => n.toDouble
        case _ => throw new UnsupportedOperationException("Unknown type in row")
      }
    }.toArray
  }

  private def getReplacementValues(dataset: Dataset[_],
                                   colsToClean: Array[String],
                                   outputCols: Array[String],
                                   mode: String): Map[String, Any] = {
    import org.apache.spark.sql.functions._
    val columns = colsToClean.map(col => dataset(col))
    val metrics = mode match {
      case CleanMissingData.MeanOpt =>
        // Verify columns are supported for imputation
        verifyColumnsSupported(dataset, colsToClean)
        val row = dataset.select(columns.map(column => avg(column)): _*).collect()(0)
        rowToValues(row)
      case CleanMissingData.MedianOpt =>
        // Verify columns are supported for imputation
        verifyColumnsSupported(dataset, colsToClean)
        val row =
          dataset.select(columns.map(column => callUDF("percentile_approx",
                                                       column, lit(0.5))): _*)
          .collect()(0)
        rowToValues(row)
      case CleanMissingData.CustomOpt =>
        // Note: all column types are supported for custom values
        colsToClean.map(_ => getCustomValue)
      case other =>
        throw new IllegalArgumentException(s"Unsupported cleaning mode: $other")
    }
    outputCols.zip(metrics).toMap
  }
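  // Illustration (hypothetical names): with inputCols = Array("a"), outputCols = Array("a_out")
  // and MeanOpt, this returns Map("a_out" -> avg of "a"); the model later hands this map to
  // DataFrameNaFunctions.fill, which casts each replacement value to its column's type.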

}

/** Model produced by [[CleanMissingData]]. */
class CleanMissingDataModel(val uid: String,
                            val replacementValues: Map[String, Any],
                            val inputCols: Array[String],
                            val outputCols: Array[String])
    extends Model[CleanMissingDataModel] with ConstructorWritable[CleanMissingDataModel] {

  val ttag: TypeTag[CleanMissingDataModel] = typeTag[CleanMissingDataModel]
  def objectsToSave: List[Any] = List(uid, replacementValues, inputCols, outputCols)

  override def copy(extra: ParamMap): CleanMissingDataModel =
    new CleanMissingDataModel(uid, replacementValues, inputCols, outputCols)

  override def transform(dataset: Dataset[_]): DataFrame = {
    val datasetCols = dataset.columns.map(name => dataset(name)).toList
    val datasetInputCols = inputCols.zip(outputCols)
      .flatMap { case (inCol, outCol) =>
        // Alias the input column to the output name only when the names differ
        if (inCol == outCol) None
        else Some(dataset(inCol).as(outCol))
      }.toList
    val addedCols = dataset.select(datasetCols ::: datasetInputCols: _*)
    addedCols.na.fill(replacementValues)
  }

  @DeveloperApi
  override def transformSchema(schema: StructType): StructType =
    CleanMissingData.validateAndTransformSchema(schema, inputCols, outputCols)
}

object CleanMissingDataModel extends ConstructorReadable[CleanMissingDataModel]
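A minimal usage sketch (not part of the original file): it assumes the HasInputCols/HasOutputCols contracts expose setInputCols/setOutputCols setters, as elsewhere in MMLSpark, and uses a hypothetical DataFrame with a nullable "age" column.

// Hypothetical usage: impute the approximate median of "age" into a new "age_clean" column.
import org.apache.spark.sql.SparkSession
import com.microsoft.ml.spark.featurize.CleanMissingData

val spark = SparkSession.builder().master("local[*]").appName("CleanMissingDataExample").getOrCreate()
import spark.implicits._

// Nullable integer column; the two None rows are the values to impute.
val df = Seq(Some(25), None, Some(31), Some(40), None).toDF("age")

val cleaner = new CleanMissingData()
  .setInputCols(Array("age"))
  .setOutputCols(Array("age_clean"))
  .setCleaningMode(CleanMissingData.MedianOpt)

// fit computes percentile_approx("age", 0.5); transform copies "age" to "age_clean"
// and fills its nulls with the fitted replacement value.
cleaner.fit(df).transform(df).show()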



