com.microsoft.ml.spark.stages.MultiColumnAdapter.scala

// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.stages

import com.microsoft.ml.spark.core.contracts.Wrappable
import org.apache.spark.sql.Dataset
import org.apache.spark.ml._
import org.apache.spark.ml.param.{ParamMap, PipelineStageParam, StringArrayParam}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.types._

object MultiColumnAdapter extends ComplexParamsReadable[MultiColumnAdapter]

/** The MultiColumnAdapter takes a unary pipeline stage and a list of input/output
  * column pairs. When fit, it applies a separate copy of the stage to each input
  * column, writing the result to the paired output column. A usage sketch follows
  * the class definition below.
  */
class MultiColumnAdapter(override val uid: String) extends Estimator[PipelineModel]
  with Wrappable with ComplexParamsWritable {
  def this() = this(Identifiable.randomUID("MultiColumnAdapter"))

  /** Names of the input columns to apply the base pipeline stage to.
    * @group param
    */
  val inputCols: StringArrayParam =
    new StringArrayParam(this,
      "inputCols",
      "list of input column names")

  /** @group getParam */
  final def getInputCols: Array[String] = $(inputCols)

  /** @group setParam */
  def setInputCols(value: Array[String]): this.type = set(inputCols, value)

  /** Names of the output columns, paired positionally with the input columns.
    * @group param
    */
  val outputCols: StringArrayParam =
    new StringArrayParam(this,
      "outputCols",
      "list of output column names")

  /** @group getParam */
  final def getOutputCols: Array[String] = $(outputCols)

  /** @group setParam */
  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)

  /** @return List of input/output column name pairs. */
  def getInputOutputPairs: List[(String, String)] = getInputCols.zip(getOutputCols).toList

  /** @return One rewired copy of the base stage per input/output column pair. */
  def getStages: Array[PipelineStage] = getInputOutputPairs.toArray.map(getInOutPairStage)

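  /** Copies the base stage and rewires the copy to one input/output column pair. */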
  private def getInOutPairStage(inputOutputPair: (String, String)): PipelineStage = {
    // Copy the base stage so every column pair gets its own independent instance
    val stage = getBaseStage.copy(new ParamMap())
    if (stage.hasParam("inputCol")) {
      setParamInternal(stage, "inputCol", inputOutputPair._1)
      setParamInternal(stage, "outputCol", inputOutputPair._2)
    } else {
      // Multi-column stages accepted by setBaseStage expose the plural params
      setParamInternal(stage, "inputCols", Array(inputOutputPair._1))
      setParamInternal(stage, "outputCols", Array(inputOutputPair._2))
    }
    stage
  }

  /** Base pipeline stage to apply to every column in the input column list.
    * @group param
    */
  val baseStage: PipelineStageParam =
    new PipelineStageParam(this,
      "baseStage",
      "base pipeline stage to apply to every column")

  /** @group getParam */
  final def getBaseStage: PipelineStage = $(baseStage)

  /** @group setParam */
  def setBaseStage(value: PipelineStage): this.type = {
    if (value.hasParam("inputCol") & value.hasParam("outputCol")){
      setParamInternal(value, "inputCol", this.uid + "_in")
      setParamInternal(value, "outputCol", this.uid + "_out")
    } else if (value.hasParam("inputCols") & value.hasParam("outputCols")){
      setParamInternal(value, "inputCols", Array(this.uid + "_in"))
      setParamInternal(value, "outputCols", Array(this.uid + "_out"))
    } else {
      throw new IllegalArgumentException(
        "Need to pass a pipeline stage with inputCol and outputCol com.microsoft.ml.spark.core.serialize.params")
    }
    set(baseStage, value)
  }

  // Sets a named param on the given stage, bypassing its typed setters
  private def setParamInternal[M <: PipelineStage, V](model: M,
                                                      name: String,
                                                      value: V) = {
    model.set(model.getParam(name), value)
  }

  // Reads a named param (or its default) from the given stage
  private def getParamInternal[M <: PipelineStage](model: M, name: String) = {
    model.getOrDefault(model.getParam(name))
  }

  /** Fits one copy of the base pipeline stage per input/output column pair.
    * @param dataset the dataset to fit the per-column stages on
    * @return PipelineModel fit on the columns bearing the input column names
    */
  override def fit(dataset: Dataset[_]): PipelineModel = {
    transformSchema(dataset.schema)
    new Pipeline(uid).setStages(getStages).fit(dataset)
  }

  def copy(extra: ParamMap): this.type = defaultCopy(extra)

  // Every input column must exist in the schema; no output column may already exist
  private def verifyCols(schema: StructType,
                         inputOutputPairs: List[(String, String)]): Unit = {
    inputOutputPairs.foreach {
      case (s1, _) if !schema.fieldNames.contains(s1) =>
        throw new IllegalArgumentException(
          s"DataFrame does not contain specified column: $s1")
      case (_, s2) if schema.fieldNames.contains(s2) =>
        throw new IllegalArgumentException(
          s"DataFrame already contains specified column: $s2")
      case _ =>
    }
  }

  override def transformSchema(schema: StructType): StructType = {
    verifyCols(schema, getInputOutputPairs)
    // Thread the schema through each per-pair stage, accumulating output columns
    getInputOutputPairs.foldLeft(schema) { (acc, pair) =>
      getInOutPairStage(pair).transformSchema(acc)
    }
  }

}
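
// A minimal usage sketch (object name illustrative), assuming a local
// SparkSession; Tokenizer is a unary Spark ML stage with inputCol/outputCol
// params, so it satisfies setBaseStage's requirements.
object MultiColumnAdapterExample {
  import org.apache.spark.ml.feature.Tokenizer
  import org.apache.spark.sql.SparkSession

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    val df = Seq(("hello world", "good morning")).toDF("text1", "text2")

    // One Tokenizer is copied and rewired per column pair, then fit as a Pipeline
    val adapter = new MultiColumnAdapter()
      .setBaseStage(new Tokenizer())
      .setInputCols(Array("text1", "text2"))
      .setOutputCols(Array("tokens1", "tokens2"))

    val model = adapter.fit(df)
    model.transform(df).show()

    spark.stop()
  }
}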