All Downloads are FREE. Search and download functionality uses the official Maven repository.

com.microsoft.ml.spark.core.schema.SparkSchema.scala Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.core.schema

import com.microsoft.ml.spark.core.schema.SchemaConstants._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types._

/** Schema modification and information retrieval methods.
  *
  * Score information is kept in column metadata under the `MMLTag` key. That entry holds one
  * nested metadata object per module (model) name, which in turn stores the `ScoreColumnKind`
  * and `ScoreValueKind` strings. Lookup methods return `null` (not an `Option`) when the
  * requested information is absent.
  */
object SparkSchema {

  /** Marks a column as the label column of a model.
    *
    * The returned function takes `(dataset, modelName, columnName, scoreValueKindModel)`:
    * the dataset to modify, the model name, the column to tag as the label and the model type.
    *
    * @return A function producing the dataset with the column's metadata updated.
    */
  def setLabelColumnName: (DataFrame, String, String, String) => DataFrame =
    setColumnName(TrueLabelsColumn)

  /** Marks a column as the scored labels column of a model.
    *
    * The returned function takes `(dataset, modelName, columnName, scoreValueKindModel)`:
    * the dataset to modify, the model name, the column to tag as the scored labels and the
    * model type.
    *
    * @return A function producing the dataset with the column's metadata updated.
    */
  def setScoredLabelsColumnName: (DataFrame, String, String, String) => DataFrame =
    setColumnName(ScoredLabelsColumn)

  /** Marks a column as the scored probabilities column of a model.
    *
    * The returned function takes `(dataset, modelName, columnName, scoreValueKindModel)`:
    * the dataset to modify, the model name, the column to tag as the scored probabilities and
    * the model type.
    *
    * @return A function producing the dataset with the column's metadata updated.
    */
  def setScoredProbabilitiesColumnName: (DataFrame, String, String, String) => DataFrame =
    setColumnName(ScoredProbabilitiesColumn)

  /** Marks a column as the scores column of a model.
    *
    * The returned function takes `(dataset, modelName, columnName, scoreValueKindModel)`:
    * the dataset to modify, the model name, the column to tag as the scores and the model type.
    *
    * @return A function producing the dataset with the column's metadata updated.
    */
  def setScoresColumnName: (DataFrame, String, String, String) => DataFrame =
    setColumnName(ScoresColumn)

  /** Gets the label column name.
    *
    * @param dataset   The dataset to get the label column from.
    * @param modelName The model to retrieve the label column from.
    * @return The label column name, or null if no column is tagged as such for the model.
    */
  def getLabelColumnName(dataset: DataFrame, modelName: String): String =
    getScoreColumnKindColumn(TrueLabelsColumn)(dataset.schema, modelName)

  /** Gets the scored labels column name.
    *
    * @param dataset   The dataset to get the scored labels column from.
    * @param modelName The model to retrieve the scored labels column from.
    * @return The scored labels column name, or null if no column is tagged as such for the model.
    */
  def getScoredLabelsColumnName(dataset: DataFrame, modelName: String): String =
    getScoreColumnKindColumn(ScoredLabelsColumn)(dataset.schema, modelName)

  /** Gets the scores column name.
    *
    * @param dataset   The dataset to get the scores column from.
    * @param modelName The model to retrieve the scores column from.
    * @return The scores column name, or null if no column is tagged as such for the model.
    */
  def getScoresColumnName(dataset: DataFrame, modelName: String): String =
    getScoreColumnKindColumn(ScoresColumn)(dataset.schema, modelName)

  /** Gets the scored probabilities column name.
    *
    * @param dataset   The dataset to get the scored probabilities column from.
    * @param modelName The model to retrieve the scored probabilities column from.
    * @return The scored probabilities column name, or null if no column is tagged as such
    *         for the model.
    */
  def getScoredProbabilitiesColumnName(dataset: DataFrame, modelName: String): String =
    getScoreColumnKindColumn(ScoredProbabilitiesColumn)(dataset.schema, modelName)

  /** Gets the label column name from a schema.
    *
    * The returned function takes `(schema, modelName)`: the schema to search and the model
    * whose label column is wanted.
    *
    * @return A function yielding the label column name, or null if there is none.
    */
  def getLabelColumnName: (StructType, String) => String =
    getScoreColumnKindColumn(TrueLabelsColumn)

  /** Gets the scored labels column name from a schema.
    *
    * The returned function takes `(schema, modelName)`: the schema to search and the model
    * whose scored labels column is wanted.
    *
    * @return A function yielding the scored labels column name, or null if there is none.
    */
  def getScoredLabelsColumnName: (StructType, String) => String =
    getScoreColumnKindColumn(ScoredLabelsColumn)

  /** Gets the scores column name from a schema.
    *
    * The returned function takes `(schema, modelName)`: the schema to search and the model
    * whose scores column is wanted.
    *
    * @return A function yielding the scores column name, or null if there is none.
    */
  def getScoresColumnName: (StructType, String) => String =
    getScoreColumnKindColumn(ScoresColumn)

  /** Gets the scored probabilities column name from a schema.
    *
    * The returned function takes `(schema, modelName)`: the schema to search and the model
    * whose scored probabilities column is wanted.
    *
    * @return A function yielding the scored probabilities column name, or null if there is none.
    */
  def getScoredProbabilitiesColumnName: (StructType, String) => String =
    getScoreColumnKindColumn(ScoredProbabilitiesColumn)

  /** Gets the score value kind or null if it does not exist from a dataset.
    *
    * @param dataset    The dataset whose schema holds the column.
    * @param modelName  The model to retrieve the score value kind for.
    * @param columnName The column to retrieve the score value kind from.
    * @return The score value kind recorded for the model on that column, or null if absent.
    */
  def getScoreValueKind(dataset: DataFrame, modelName: String, columnName: String): String =
    getScoreValueKind(dataset.schema, modelName, columnName)

  /** Gets the score value kind or null if it does not exist from the schema.
    *
    * The column is looked up with `schema(columnName)`, so it must be present in the schema.
    *
    * @param schema     The schema holding the column.
    * @param modelName  The model to retrieve the score value kind for.
    * @param columnName The column to retrieve the score value kind from.
    * @return The score value kind recorded for the model on that column, or null if absent.
    */
  def getScoreValueKind(schema: StructType, modelName: String, columnName: String): String =
    // getMetadataFromModule treats null metadata as "nothing recorded", so no guard is needed here.
    getMetadataFromModule(schema(columnName).metadata, modelName, ScoreValueKind)

  /** Sets the score column kind.
    *
    * Replaces the column with an identically named alias of itself that carries the updated
    * metadata.
    *
    * @param scoreColumnKindColumn The score column kind to record.
    * @param dataset               The dataset to set the score column kind on.
    * @param modelName             The model name.
    * @param columnName            The column name to set as the specified score column kind.
    * @param scoreValueKindModel   The model type.
    * @return The dataset with the column's metadata updated.
    */
  private def setColumnName(scoreColumnKindColumn: String)
                           (dataset: DataFrame, modelName: String,
                            columnName: String, scoreValueKindModel: String): DataFrame = {
    // Resolve the column before touching the schema so a missing column fails the same way
    // it always has.
    val column = dataset.col(columnName)
    val taggedMetadata = updateMetadata(dataset.schema(columnName).metadata,
      scoreColumnKindColumn, scoreValueKindModel, modelName)
    dataset.withColumn(columnName, column.as(columnName, taggedMetadata))
  }

  /** Gets the score column kind column name or null if it does not exist.
    *
    * @param scoreColumnKindColumn The score column kind to retrieve.
    * @param schema    The schema to get the score column kind column name from.
    * @param modelName The model to retrieve the score column kind column name from.
    * @return The name of the first field in schema order tagged with the given kind for the
    *         model, or null if there is none.
    */
  private def getScoreColumnKindColumn(scoreColumnKindColumn: String)
                                      (schema: StructType, modelName: String): String =
    schema
      .find { field =>
        getMetadataFromModule(field.metadata, modelName, ScoreColumnKind) == scoreColumnKindColumn
      }
      .map(_.name)
      .orNull

  /** Returns a copy of the column metadata with the module's score information recorded.
    *
    * Everything already stored in the metadata, under `MMLTag` and under the module's own
    * entry is carried over; only `ScoreColumnKind` and `ScoreValueKind` for the given module
    * are (over)written.
    *
    * @param metadata              The column's current metadata; null is treated as empty.
    * @param scoreColumnKindColumn The score column kind to record.
    * @param scoreValueKindModel   The score value kind (model type) to record.
    * @param moduleName            The module (model) name the information belongs to.
    * @return The updated metadata.
    */
  private def updateMetadata(metadata: Metadata, scoreColumnKindColumn: String,
                             scoreValueKindModel: String, moduleName: String): Metadata = {
    val columnMetadata = Option(metadata).getOrElse(Metadata.empty)
    val existingTag = nestedMetadata(columnMetadata, MMLTag)
    val existingModule = existingTag.flatMap(nestedMetadata(_, moduleName))

    // Innermost level: the module's own entry.
    val moduleMetadataBuilder = new MetadataBuilder()
    existingModule.foreach(m => moduleMetadataBuilder.withMetadata(m))
    moduleMetadataBuilder
      .putString(ScoreColumnKind, scoreColumnKindColumn)
      .putString(ScoreValueKind, scoreValueKindModel)

    // Middle level: the MMLTag entry, keeping the entries of any other modules.
    val tagBuilder = new MetadataBuilder()
    existingTag.foreach(t => tagBuilder.withMetadata(t))
    tagBuilder.putMetadata(moduleName, moduleMetadataBuilder.build())

    // Outermost level: the column metadata itself.
    new MetadataBuilder()
      .withMetadata(columnMetadata)
      .putMetadata(MMLTag, tagBuilder.build())
      .build()
  }

  /** Reads the string stored under `tag` for the given module in a column's metadata.
    *
    * @param colMetadata The column metadata; may be null.
    * @param moduleName  The module (model) name to look under.
    * @param tag         The key to read within the module's entry.
    * @return The stored string, or null if the metadata, the `MMLTag` entry, the module entry
    *         or the tag itself is missing.
    */
  private def getMetadataFromModule(colMetadata: Metadata, moduleName: String, tag: String): String = {
    val tagValue = for {
      column <- Option(colMetadata)
      mmlTag <- nestedMetadata(column, MMLTag)
      module <- nestedMetadata(mmlTag, moduleName)
      if module.contains(tag)
    } yield module.getString(tag)
    tagValue.orNull
  }

  /** Returns the metadata nested under `key`, or None when `parent` has no such key. */
  private def nestedMetadata(parent: Metadata, key: String): Option[Metadata] =
    if (parent.contains(key)) Some(parent.getMetadata(key)) else None

  /** Find if the given column is a string */
  def isString(df: DataFrame, column: String): Boolean =
    df.schema(column).dataType == DataTypes.StringType

  /** Find if the given column is numeric */
  def isNumeric(df: DataFrame, column: String): Boolean =
    df.schema(column).dataType.isInstanceOf[NumericType]

  /** Find if the given column is boolean */
  def isBoolean(df: DataFrame, column: String): Boolean =
    df.schema(column).dataType.isInstanceOf[BooleanType]

  /** Find if the given column is Categorical; use CategoricalColumnInfo for more details */
  def isCategorical(df: DataFrame, column: String): Boolean =
    new CategoricalColumnInfo(df, column).isCategorical

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy