All downloads are free. Search and download functionality uses the official Maven repository.

com.microsoft.azure.synapse.ml.vw.VectorZipper.scala Maven / Gradle / Ivy

There is a newer version: 1.0.9
Show newest version
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.vw

import com.microsoft.azure.synapse.ml.codegen.Wrappable
import com.microsoft.azure.synapse.ml.core.contracts.{HasInputCols, HasOutputCol}
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer}
import org.apache.spark.sql._
import org.apache.spark.sql.functions.array
import org.apache.spark.sql.types.{ArrayType, StructField, StructType}

/** Companion object; mixes in `ComplexParamsReadable` so saved [[VectorZipper]] stages can be read back. */
object VectorZipper extends ComplexParamsReadable[VectorZipper] {}

/**
  * Packs one or more input columns into a single array column.
  *
  * All input columns must share the same data type. The output column has type
  * `ArrayType(<that type>)` and holds the input values in `inputCols` order. If
  * `outputCol` names a column that already exists, that column is replaced in
  * place, matching the semantics of `Dataset.withColumn`.
  */
class VectorZipper(override val uid: String) extends Transformer
  with HasInputCols with HasOutputCol with Wrappable with ComplexParamsWritable with SynapseMLLogging {
  logClass(FeatureNames.VowpalWabbit)

  def this() = this(Identifiable.randomUID("VectorZipper"))

  override def copy(extra: ParamMap): VectorZipper = defaultCopy(extra)

  /**
    * Validates the input columns and adds (or replaces) the output column.
    *
    * @param schema schema of the incoming dataset
    * @return `schema` with `outputCol` typed as `ArrayType` of the shared input type
    * @throws IllegalArgumentException if `inputCols` is empty, names a missing column,
    *                                  or the input columns do not all share one data type
    */
  override def transformSchema(schema: StructType): StructType = {
    val inputCols = getInputCols
    require(inputCols.nonEmpty, s"$uid: inputCols must contain at least one column")

    val firstCol = inputCols.head
    val firstDt = schema(firstCol).dataType
    // `require` rather than `assert`: the check must survive -Xdisable-assertions
    // and should say which column is at fault.
    inputCols.tail.foreach { name =>
      val dt = schema(name).dataType
      require(dt == firstDt,
        s"$uid: all inputCols must share one data type; column '$name' is ${dt.simpleString} " +
          s"but '$firstCol' is ${firstDt.simpleString}")
    }

    val outputField = StructField(getOutputCol, ArrayType(firstDt))
    if (schema.fieldNames.contains(getOutputCol)) {
      // Mirror Dataset.withColumn used by transform: an existing column of the same
      // name is replaced in place instead of producing a duplicate field.
      StructType(schema.fields.map(f => if (f.name == getOutputCol) outputField else f))
    } else {
      schema.add(outputField)
    }
  }

  /**
    * Adds `outputCol` as `array(inputCols...)` to the dataset.
    *
    * @param dataset input data; must satisfy the checks in `transformSchema`
    * @return `dataset` with the zipped array column added or replaced
    */
  override def transform(dataset: Dataset[_]): DataFrame = {
    logTransform[DataFrame]({
      // Validate up front so direct calls get the same checks (and resulting type)
      // that transformSchema advertises, rather than relying on Spark's coercion.
      transformSchema(dataset.schema)
      val inputCols = getInputCols
      dataset.withColumn(getOutputCol, array(inputCols.head, inputCols.tail: _*))
    }, dataset.columns.length)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy