// com.microsoft.ml.spark.core.schema.Categoricals.scala
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.core.schema

/** Contains objects and functions to manipulate Categoricals */

import com.microsoft.ml.spark.core.schema.SchemaConstants._
import javassist.bytecode.DuplicateMemberException
import org.apache.spark.ml.attribute._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.injections.MetadataUtilities

import scala.reflect.ClassTag

object CategoricalUtilities {

  /** Sets the given levels on the column.
    * @param dataset The dataset whose column metadata to update.
    * @param column The name of the column to set the levels on.
    * @param levels The levels to set; null entries are dropped and recorded via the has-null-levels flag.
    * @return The modified dataset.
    */
  def setLevels(dataset: DataFrame, column: String, levels: Array[_]): DataFrame = {
    if (levels == null) dataset
    else {
      val nonNullLevels = levels.filter(_ != null)
      val hasNullLevels = nonNullLevels.length != levels.length
      dataset.withColumn(column,
        dataset.col(column).as(column,
          updateLevelsMetadata(dataset.schema(column).metadata,
            nonNullLevels,
            getCategoricalTypeForValue(nonNullLevels.head), hasNullLevels)))
    }
  }
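
  // Usage sketch (a hedged example; `df` is a hypothetical DataFrame with a string column "color"):
  //   val withLevels = CategoricalUtilities.setLevels(df, "color", Array("red", "green", "blue"))
  //   // The "color" column's metadata now lists the three levels under the MML categorical tag.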

  /** Update the levels on the existing metadata.
    * @param existingMetadata The existing metadata to add to.
    * @param levels The levels to add to the metadata.
    * @param dataType The datatype of the levels.
    * @param hasNullLevels Whether the levels include a null level.
    * @return The new metadata.
    */
  def updateLevelsMetadata(existingMetadata: Metadata,
                           levels: Array[_],
                           dataType: DataType,
                           hasNullLevels: Boolean): Metadata = {
    val bldr =
      if (existingMetadata.contains(MMLTag)) {
        new MetadataBuilder().withMetadata(existingMetadata.getMetadata(MMLTag))
      } else {
        new MetadataBuilder()
      }
    bldr.putBoolean(Ordinal, false)
    bldr.putBoolean(HasNullLevels, hasNullLevels)
    dataType match {
      case DataTypes.StringType  => bldr.putStringArray(ValuesString, levels.asInstanceOf[Array[String]])
      case DataTypes.DoubleType  => bldr.putDoubleArray(ValuesDouble, levels.asInstanceOf[Array[Double]])
      // Ints require special treatment, because Spark does not have putIntArray yet:
      case DataTypes.IntegerType => bldr.putLongArray(ValuesInt, levels.asInstanceOf[Array[Int]].map(_.toLong))
      case DataTypes.LongType    => bldr.putLongArray(ValuesLong, levels.asInstanceOf[Array[Long]])
      case DataTypes.BooleanType => bldr.putBooleanArray(ValuesBool, levels.asInstanceOf[Array[Boolean]])
      case _           => throw new UnsupportedOperationException("Unsupported categorical data type: " + dataType)
    }
    val metadata = bldr.build()

    new MetadataBuilder().withMetadata(existingMetadata).putMetadata(MMLTag, metadata).build()
  }
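
  // Usage sketch (a hedged example): attach integer levels to otherwise empty metadata.
  //   val md = CategoricalUtilities.updateLevelsMetadata(
  //     Metadata.empty, Array(1, 2, 3), DataTypes.IntegerType, hasNullLevels = false)
  //   // md now holds an MMLTag sub-metadata with Ordinal = false and the levels stored as longs.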

  /** Gets the levels from the dataset.
    * @param schema The schema to get the levels from.
    * @param column The column to retrieve metadata levels from.
    * @return The levels.
    */
  def getLevels(schema: StructType, column: String): Option[Array[_]] = {
    val metadata = schema(column).metadata

    if (metadata.contains(MMLTag)) {
      val dataType: Option[DataType] = CategoricalColumnInfo.getDataType(metadata, false)
      if (dataType.isEmpty) None
      else {
        dataType.get match {
          case DataTypes.StringType => Some(getMap[String](metadata).levels)
          case DataTypes.LongType => Some(getMap[Long](metadata).levels)
          case DataTypes.IntegerType => Some(getMap[Int](metadata).levels)
          case DataTypes.DoubleType => Some(getMap[Double](metadata).levels)
          case DataTypes.BooleanType => Some(getMap[Boolean](metadata).levels)
          case default => throw new UnsupportedOperationException("Unknown categorical type: " + default.typeName)
        }
      }
    } else {
      None
    }
  }
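
  // Usage sketch (a hedged example): read levels back from a schema produced by setLevels above.
  //   val levels: Option[Array[_]] = CategoricalUtilities.getLevels(df.schema, "color")
  //   // Returns None if the column carries no MML categorical metadata.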

  /** Get the map of array of T from the metadata.
    *
    * @param ct Implicit class tag.
    * @param metadata The metadata to retrieve from.
    * @tparam T The type of map to retrieve.
    * @return The map of array of T.
    */
  def getMap[T](metadata: Metadata)(implicit ct: ClassTag[T]): CategoricalMap[T] = {
    val data =
      if (metadata.contains(MMLTag)) {
        metadata.getMetadata(MMLTag)
      } else if (metadata.contains(MLlibTag)) {
        metadata.getMetadata(MLlibTag)
      } else {
        sys.error("Invalid metadata to retrieve map from")
      }

    val hasNullLevel =
      if (data.contains(HasNullLevels)) data.getBoolean(HasNullLevels)
      else false
    val isOrdinal = false

    val categoricalMap = implicitly[ClassTag[T]] match  {
      case ClassTag.Int => new CategoricalMap[Int](data.getLongArray(ValuesInt).map(_.toInt), isOrdinal, hasNullLevel)
      case ClassTag.Double => new CategoricalMap[Double](data.getDoubleArray(ValuesDouble), isOrdinal, hasNullLevel)
      case ClassTag.Boolean => new CategoricalMap[Boolean](data.getBooleanArray(ValuesBool), isOrdinal, hasNullLevel)
      case ClassTag.Long => new CategoricalMap[Long](data.getLongArray(ValuesLong), isOrdinal, hasNullLevel)
      case _ => new CategoricalMap[String](data.getStringArray(ValuesString), isOrdinal, hasNullLevel)
    }
    categoricalMap.asInstanceOf[CategoricalMap[T]]
  }
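
  // Usage sketch (a hedged example; assumes the "color" column metadata was written in MML format,
  // e.g. by setLevels above):
  //   val cm = CategoricalUtilities.getMap[String](df.schema("color").metadata)
  //   val greenIndex = cm.getIndex("green")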

  /** Get a type for the given value.
    * @param value The value to get the type from.
    * @tparam T The generic type of the value.
    * @return The DataType based on the value.
    */
  def getCategoricalTypeForValue[T](value: T): DataType = {
    value match {
      // Complicated type matching is required to get around type erasure
      case _: String  => DataTypes.StringType
      case _: Double  => DataTypes.DoubleType
      case _: Int     => DataTypes.IntegerType
      case _: Long    => DataTypes.LongType
      case _: Boolean => DataTypes.BooleanType
      case _          => throw new UnsupportedOperationException("Unsupported categorical data type " + value)
    }
  }
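
  // For example: getCategoricalTypeForValue("a") == DataTypes.StringType,
  // and getCategoricalTypeForValue(1.0) == DataTypes.DoubleType.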

}

/** A wrapper around level maps: Map[T -> Int] and Map[Int -> T] that converts
  *   the data to/from Spark Metadata in both MLlib and AzureML formats.
  * @param levels  The level values; assumed to be already sorted as needed
  * @param isOrdinal  A flag that indicates if the data are ordinal
  * @param hasNullLevel  A flag that indicates whether a null level exists
  * @tparam T  Input levels could be String, Double, Int, Long, Boolean
  */
class CategoricalMap[T](val levels: Array[T],
                        val isOrdinal: Boolean = false,
                        val hasNullLevel: Boolean = false) extends Serializable {
  require(levels.distinct.size == levels.size, "Categorical levels are not unique.")
  require(!levels.isEmpty, "Levels should not be empty")

  /** Total number of levels */
  val numLevels = levels.length //TODO: add the maximum possible number of levels?

  /** Spark DataType corresponding to type T */
  val dataType = CategoricalUtilities.getCategoricalTypeForValue(levels.find(_ != null).head)

  /** Maps levels to the corresponding integer index */
  private lazy val levelToIndex: Map[T, Int] = levels.zipWithIndex.toMap

  /** Returns the index of the given level, can throw */
  def getIndex(level: T): Int = levelToIndex(level)

  /** Returns the index of a given level as Option; does not throw */
  def getIndexOption(level: T): Option[Int] = levelToIndex.get(level)

  /** Checks if the given level exists */
  def hasLevel(level: T): Boolean = levelToIndex.contains(level)

  /** Returns the level of the given index; can throw */
  def getLevel(index: Int): T = levels(index)

  /** Returns the level of the given index as Option; does not throw */
  def getLevelOption(index: Int): Option[T] =
    if (index < 0 || index >= numLevels) None else Some(levels(index))
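
  // Lookup sketch (a hedged example over three string levels):
  //   val cm = new CategoricalMap[String](Array("a", "b", "c"))
  //   cm.getIndex("b")      // 1
  //   cm.getLevelOption(5)  // None: index out of range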

  /** Stores levels in Spark Metadata in MLlib format */
  private def toMetadataMllib(existingMetadata: Metadata): Metadata = {
    require(!isOrdinal, "Cannot save Ordinal data in MLlib Nominal format currently," +
                        " because it does not have a public constructor that accepts Ordinal")

    // Currently, MLlib converts all non-string categorical values to string;
    // see org.apache.spark.ml.feature.StringIndexer
    val strLevels = levels.filter(_ != null).map(_.toString).asInstanceOf[Array[String]]

    NominalAttribute.defaultAttr.withValues(strLevels).toMetadata(existingMetadata)
  }

  /** Stores levels in Spark Metadata in MML format */
  private def toMetadataMML(existingMetadata: Metadata): Metadata = {
    CategoricalUtilities.updateLevelsMetadata(existingMetadata, levels, dataType, hasNullLevel)
  }

  /** Add categorical levels to existing Spark Metadata
    * @param existingMetadata [tag, categorical metadata] pair is added to existingMetadata,
    *   where tag is either MLlib or MML
    * @param mmlStyle MML (true) or MLlib metadata (false)
    */
  def toMetadata(existingMetadata: Metadata, mmlStyle: Boolean): Metadata = {

    // assert that metadata does not have data with this tag
    def assertNoTag(tag: String) =
      assert(!existingMetadata.contains(tag),
             //TODO: add tests to ensure
             s"Metadata already contains the tag $tag; all the data are eraised")

    if (mmlStyle) {
      assertNoTag(MMLTag)
      toMetadataMML(existingMetadata)
    } else {
      assertNoTag(MLlibTag)
      toMetadataMllib(existingMetadata)
    }
  }

  /** Add categorical levels as either MML- or MLlib-style metadata.
    * @param mmlStyle MML (true) or MLlib metadata (false)
    */
  def toMetadata(mmlStyle: Boolean): Metadata = toMetadata(Metadata.empty, mmlStyle)
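
  // Usage sketch (a hedged example, continuing the three-level map above):
  //   val mmlMeta   = cm.toMetadata(mmlStyle = true)   // MML format, keyed by MMLTag
  //   val mllibMeta = cm.toMetadata(mmlStyle = false)  // MLlib NominalAttribute format, string levels only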

}

/** Utilities for getting categorical column info. */
object CategoricalColumnInfo {
  /** Gets the datatype from the column metadata.
    * @param metadata The column metadata
    * @param throwOnInvalid Whether to throw when no valid datatype can be determined
    * @return The datatype, if one was found
    */
  def getDataType(metadata: Metadata, throwOnInvalid: Boolean = true): Option[DataType] = {
    val mmlMetadata =
      if (metadata.contains(MMLTag)) {
        metadata.getMetadata(MMLTag)
      } else {
        throw new NoSuchFieldException(s"Could not find valid $MMLTag metadata")
      }
    val keys = MetadataUtilities.getMetadataKeys(mmlMetadata)
    val validatedDataType = keys.foldRight(None: Option[DataType])((metadataKey, result) => metadataKey match {
      case ValuesString => getValidated(result, DataTypes.StringType)
      case ValuesLong => getValidated(result, DataTypes.LongType)
      case ValuesInt => getValidated(result, DataTypes.IntegerType)
      case ValuesDouble => getValidated(result, DataTypes.DoubleType)
      case ValuesBool => getValidated(result, DataTypes.BooleanType)
      case _ => if (result.isDefined) result else None
    })
    if (validatedDataType.isEmpty && throwOnInvalid) {
      throw new NoSuchElementException("Unrecognized datatype or no datatype found in MML metadata")
    }
    validatedDataType
  }
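
  // Usage sketch (a hedged example; the metadata must already hold an MMLTag entry):
  //   val dt = CategoricalColumnInfo.getDataType(df.schema("color").metadata)
  //   // With throwOnInvalid = false, an unrecognized values key yields None instead of throwing
  //   // (a missing MMLTag entry still throws).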

  private def getValidated(result: Option[DataType], dataType: DataType): Option[DataType] = {
    if (result.isDefined) {
      throw new DuplicateMemberException("DataType metadata specified twice")
    }
    Option(dataType)
  }
}

/** Extracts categorical info from a DataFrame column.
  * @param df The DataFrame
  * @param column The column name
  */
class CategoricalColumnInfo(df: DataFrame, column: String) {

  private val columnSchema   = df.schema(column)
  private val metadata = columnSchema.metadata

  /** Basic info about the column: whether it is categorical, whether the metadata is in MML format,
    * whether the levels are ordinal, the level data type, and whether a null level exists. */
  val (isCategorical, isMML, isOrdinal, dataType, hasNullLevels) = {

    val notCategorical = (false, false, false, NullType, false)

    if (columnSchema.dataType != DataTypes.IntegerType
      && columnSchema.dataType != DataTypes.DoubleType) notCategorical
    else if (metadata.contains(MMLTag)) {
      val columnMetadata = metadata.getMetadata(MMLTag)

      if (!columnMetadata.contains(Ordinal)) notCategorical
      else {
        val isOrdinal = columnMetadata.getBoolean(Ordinal)
        val hasNullLevels =
          if (columnMetadata.contains(HasNullLevels)) columnMetadata.getBoolean(HasNullLevels)
          else false

        val dataType: DataType = CategoricalColumnInfo.getDataType(metadata).get

        (true, true, isOrdinal, dataType, hasNullLevels)
      }
    } else if (metadata.contains(MLlibTag)) {
      val columnMetadata = metadata.getMetadata(MLlibTag)
      // nominal metadata has ["type" -> "nominal"] pair
      val isCategorical = columnMetadata.contains(MLlibTypeTag) &&
                          columnMetadata.getString(MLlibTypeTag) == AttributeType.Nominal.name

      if (!isCategorical) notCategorical
      else {
        val isOrdinal = if (columnMetadata.contains(Ordinal)) columnMetadata.getBoolean(Ordinal) else false
        val hasNullLevels =
          if (columnMetadata.contains(HasNullLevels)) columnMetadata.getBoolean(HasNullLevels)
          else false
        val dataType =
          if (columnMetadata.contains(ValuesString)) DataTypes.StringType
          else throw new UnsupportedOperationException("nominal attribute does not contain string levels")
        (true, false, isOrdinal, dataType, hasNullLevels)
      }
    } else
      notCategorical
  }

}
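
// Usage sketch (a hedged example; `df` is a hypothetical DataFrame with a categorical column "color"):
//   val info = new CategoricalColumnInfo(df, "color")
//   if (info.isCategorical)
//     println(s"levels are ${info.dataType}, MML format: ${info.isMML}, ordinal: ${info.isOrdinal}")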



