
// com.microsoft.azure.synapse.ml.vw.VectorZipper.scala (Maven / Gradle / Ivy)
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.
package com.microsoft.azure.synapse.ml.vw
import com.microsoft.azure.synapse.ml.codegen.Wrappable
import com.microsoft.azure.synapse.ml.core.contracts.{HasInputCols, HasOutputCol}
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer}
import org.apache.spark.sql._
import org.apache.spark.sql.functions.array
import org.apache.spark.sql.types.{ArrayType, StructType}
/**
  * Companion object of [[VectorZipper]].
  *
  * Mixes in `ComplexParamsReadable[VectorZipper]`; presumably this supplies the reading/loading
  * counterpart to the `ComplexParamsWritable` trait mixed into the class -- confirm against
  * `ComplexParamsReadable`.
  */
object VectorZipper extends ComplexParamsReadable[VectorZipper]
/**
  * Combine one or more input columns into a sequence in the output column.
  *
  * The output column is built with Spark's `array` function over the configured input columns.
  * During schema validation all input columns are required to share the data type of the first
  * input column, and the output column is declared as an array of that type.
  *
  * @param uid unique identifier of this transformer instance
  */
class VectorZipper(override val uid: String) extends Transformer
  with HasInputCols with HasOutputCol with Wrappable with ComplexParamsWritable with SynapseMLLogging {
  logClass(FeatureNames.VowpalWabbit)

  /** Creates an instance with a randomly generated uid prefixed with "VectorZipper". */
  def this() = this(Identifiable.randomUID("VectorZipper"))

  override def copy(extra: ParamMap): VectorZipper = defaultCopy(extra)

  /**
    * Returns the configured input columns, failing with a descriptive message when none are set.
    * Guards the `head`/`tail` accesses below, which would otherwise fail with an opaque error.
    */
  private def nonEmptyInputCols: Array[String] = {
    val cols = getInputCols
    require(cols.nonEmpty, "VectorZipper requires at least one input column, but inputCols is empty")
    cols
  }

  /**
    * Validates the input columns against `schema` and appends the output column.
    *
    * @param schema schema of the dataset to be transformed
    * @return `schema` with an additional field named by `outputCol` of type array-of-first-input-type
    * @throws IllegalArgumentException if no input columns are configured, or if an input column's
    *                                  data type differs from that of the first input column
    */
  override def transformSchema(schema: StructType): StructType = {
    val cols = nonEmptyInputCols
    val firstCol = cols.head
    val firstDt = schema(firstCol).dataType
    // `require` instead of `assert`: assertions can be elided at compile time and give no context.
    cols.tail.foreach { col =>
      val dt = schema(col).dataType
      require(dt == firstDt,
        s"Input column '$col' has data type $dt, which differs from $firstDt of input column '$firstCol'")
    }
    schema.add(getOutputCol, ArrayType(firstDt))
  }

  /**
    * Adds (or replaces) the output column with an array built from the input columns.
    *
    * @param dataset dataset containing the input columns
    * @return a DataFrame with the output column set to `array(inputCols...)`
    * @throws IllegalArgumentException if no input columns are configured
    */
  override def transform(dataset: Dataset[_]): DataFrame = {
    logTransform[DataFrame]({
      val cols = nonEmptyInputCols
      dataset.withColumn(getOutputCol, array(cols.head, cols.tail: _*))
    }, dataset.columns.length)
  }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy