All downloads are FREE. Search and download functionality uses the official Maven repository.

com.databricks.labs.automl.pipeline.DatasetsUnionTransformer.scala Maven / Gradle / Ivy

package com.databricks.labs.automl.pipeline

import com.databricks.labs.automl.utils.AutoMlPipelineMlFlowUtils
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{Column, DataFrame, Dataset}

import scala.util.Sorting

/**
  * @author Jas Bali
  * A transformer stage that unions two datasets sharing the same schema. It is useful
  * when two datasets need to be combined in an intermediate step of a pipeline.
  *
  * NOTE: Transformer semantics do not allow passing two datasets to a transform method.
  * As a workaround, the first dataset needs to be registered as a temp table outside of this transformer
  * using [[RegisterTempTableTransformer]] transformer; its name is supplied through `unionDatasetName`.
  *
  * Both datasets must have the same column names, data types and nullability (column order may
  * differ). Duplicate rows are kept (UNION ALL semantics) and the output columns are ordered
  * alphabetically by name.
  */
class DatasetsUnionTransformer(override val uid: String)
  extends AbstractTransformer
    with DefaultParamsWritable {

  /** Name of the registered temp table/view whose rows are unioned with the input dataset. */
  final val unionDatasetName = new Param[String](this, "unionDatasetName", "unionDatasetName")

  def setUnionDatasetName(value: String): this.type = set(unionDatasetName, value)

  def getUnionDatasetName: String = $(unionDatasetName)

  def this() = {
    this(Identifiable.randomUID("DatasetsUnionTransformer"))
    setAutomlInternalId(AutoMlPipelineMlFlowUtils.AUTOML_INTERNAL_ID_COL)
    setDebugEnabled(false)
  }

  /**
    * Unions the table registered under `unionDatasetName` with `dataset`.
    *
    * @param dataset the input dataset; its schema must match the registered table's schema
    * @return the union of both datasets, with columns ordered alphabetically by name
    * @throws AssertionError if the two schemas differ
    */
  override def transformInternal(dataset: Dataset[_]): DataFrame = {
    // Resolve the table by identifier rather than splicing the name into a SQL string.
    val registered = dataset.sparkSession.table(getUnionDatasetName)
    val (left, right) = prepareUnion(registered, dataset.toDF())
    left.unionByName(right)
  }

  /** Validates both frames, then projects each onto the same alphabetically ordered columns. */
  private def prepareUnion(df1: DataFrame, df2: DataFrame): (DataFrame, DataFrame) = {
    validateUnion(df1, df2)
    val orderedCols: Array[Column] = df1.schema.fieldNames.sorted.map(quotedCol)
    (df1.select(orderedCols: _*), df2.select(orderedCols: _*))
  }

  /**
    * References a top-level column by its literal name. Quoting with backticks (and doubling any
    * embedded backtick) stops names containing dots from being parsed as nested-field access.
    */
  private def quotedCol(name: String): Column = col(s"`${name.replace("`", "``")}`")

  /** Returns `schema` with its fields ordered alphabetically by name. */
  private def sortedSchema(schema: StructType): StructType =
    StructType(schema.fields.sortBy(_.name))

  /**
    * Asserts that both frames have the same fields regardless of column order. The check compares
    * the `toString` of the name-sorted schemas (field names, data types and nullability).
    */
  private def validateUnion(df1: DataFrame, df2: DataFrame): Unit = {
    val df1SchemaString = sortedSchema(df1.schema).toString()
    val df2SchemaString = sortedSchema(df2.schema).toString()
    assert(df1SchemaString.equals(df2SchemaString),
      s"Different schemas for union DFs. \n DF1 schema $df1SchemaString \n " +
        s"DF2 schema $df2SchemaString \n")
  }

  /** The output has the input's fields, reordered alphabetically to match what `transform` emits. */
  override def transformSchemaInternal(schema: StructType): StructType = sortedSchema(schema)

  override def copy(extra: ParamMap): DatasetsUnionTransformer = defaultCopy(extra)
}

/** Companion reader that restores a persisted [[DatasetsUnionTransformer]] from disk. */
object DatasetsUnionTransformer extends DefaultParamsReadable[DatasetsUnionTransformer] {

  /**
    * Loads a transformer previously saved with `DefaultParamsWritable`.
    *
    * @param path directory the transformer was written to
    * @return the restored transformer
    */
  override def load(path: String): DatasetsUnionTransformer = {
    val restored = super.load(path)
    restored
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy