All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.microsoft.ml.spark.core.schema.DatasetExtensions.scala Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.core.schema

import org.apache.spark.ml.linalg.{DenseVector, SparseVector}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType

import scala.collection.mutable

/** Contains methods for manipulating spark dataframes and datasets. */
object DatasetExtensions {

  implicit class MMLDataFrame(val df: DataFrame) extends AnyVal {
    /** Finds an unused column name given initial column name in the given schema.
      * The unused column name will be given prefix with a number appended to it, eg "testColumn_5".
      * There will be an underscore between the column name and the number appended.
      *
      * @return The unused column name.
      */
    def withDerivativeCol(prefix: String): String = {
      val columnNamesSet = mutable.HashSet(df.columns: _*)
      findUnusedColumnName(prefix)(columnNamesSet)
    }

    /** Gets the column values as the given type.
      * @param colname The column name to retrieve from.
      * @tparam T The type to retrieve.
      * @return The sequence of values in the column.
      */
    def getColAs[T](colname: String): Seq[T] = {
      df.select(colname).collect.map(_.getAs[T](0))
    }

    /** Gets the spark sparse vector column.
      * @return The spark sparse vector column.
      */
    def getSVCol: String => Seq[SparseVector] = getColAs[SparseVector] _

    /** Gets the spark dense vector column.
      * @return The spark dense vector column.
      */
    def getDVCol: String => Seq[DenseVector] = getColAs[DenseVector] _
  }

  /** Finds an unused column name given initial column name and a list of existing column names.
    * The unused column name will be given prefix with a number appended to it, eg "testColumn_5".
    * There will be an underline between the column name and the number appended.
    *
    * @return The unused column name.
    */
  def findUnusedColumnName(prefix: String)(columnNames: scala.collection.Set[String]): String = {
    var counter = 2
    var unusedColumnName = prefix
    while (columnNames.contains(unusedColumnName)) {
      unusedColumnName += "_" + counter
      counter += 1
    }
    unusedColumnName
  }

  def findUnusedColumnName(prefix: String, schema: StructType): String = {
    findUnusedColumnName(prefix)(schema.fieldNames.toSet)
  }

  def findUnusedColumnName(prefix: String, df: DataFrame): String = {
    findUnusedColumnName(prefix, df.schema)
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy