All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.microsoft.azure.synapse.ml.vw.VowpalWabbitDSJsonTransformer.scala Maven / Gradle / Ivy

There is a newer version: 1.0.9
Show newest version
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.vw

import com.microsoft.azure.synapse.ml.codegen.Wrappable
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.param.StringStringMapParam
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Dataset, functions => F, types => T}

import scala.jdk.CollectionConverters.mapAsScalaMapConverter

class VowpalWabbitDSJsonTransformer(override val uid: String)
  extends Transformer
    with SynapseMLLogging
    with Wrappable
    with ComplexParamsWritable {
  import VowpalWabbitDSJsonTransformer._

  logClass(FeatureNames.VowpalWabbit)

  def this() = this(Identifiable.randomUID("VowpalWabbitDSJsonTransformer"))

  val dsJsonColumn = new Param[String](
    this,"dsJsonColumn", "Column containing ds-json. defaults to \"value\".")

  def getDsJsonColumn: String = $(dsJsonColumn)
  def setDsJsonColumn(value: String): this.type = set(dsJsonColumn, value)

  val rewards = new StringStringMapParam(
    this, "rewards", "Extract bandit reward(s) from DS json. Defaults to _label_cost.")

  def getRewards: Map[String, String] = $(rewards)
  def setRewards(v: Map[String, String]): this.type = set(rewards, v)
  def setRewards(value: java.util.HashMap[String, String]): this.type = set(rewards, value.asScala.toMap)

  setDefault(dsJsonColumn -> "value",
    rewards -> Map("reward" -> "_label_cost"))

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

  private def eventIdField: T.StructField =
    T.StructField(EventIdColName, T.StringType, false)

  private def rewardFields: Seq[T.StructField] =
    getRewards.map { case (_, v) => T.StructField(v, T.FloatType, false) }.toSeq

  private def jsonSchema: T.StructType =
    T.StructType(Seq(
      eventIdField,
      T.StructField("_label_probability", T.FloatType, false),
      T.StructField("_labelIndex", T.IntegerType, false)) ++
      // extract rewards from JSON
      rewardFields)

  override def transform(dataset: Dataset[_]): DataFrame = {
    logTransform[DataFrame]({
      // TODO: extract all headers as well
      val jsonCol = F.col(JsonColName)

      val rewardCols = getRewards.map({ case (alias, col) => F.col(JsonColName).getField(col).alias(alias) }).toSeq

      val outputFields =
        dataset.schema.names.map(F.col) ++
          Seq(jsonCol,
            jsonCol.getField(EventIdColName).as(EventIdColName),
            F.struct(rewardCols: _*).as(RewardsColName),
            jsonCol.getField(LabelProbability).as(ProbabilityLoggedColName),
            jsonCol.getField(LabelIndex).as(ChosenActionIndexColName)
          )

      dataset.toDF
        .withColumn(JsonColName, F.from_json(F.col(getDsJsonColumn), jsonSchema))
        .select(outputFields: _ *)
    }, dataset.columns.length)
  }

  override def transformSchema(schema: StructType): StructType =
    T.StructType(schema.fields ++ Seq(
      T.StructField(JsonColName, jsonSchema, false),
      eventIdField,
      T.StructField(RewardsColName, T.StructType(rewardFields), false),
      T.StructField(ProbabilityLoggedColName, T.FloatType, false),
      T.StructField(ChosenActionIndexColName, T.IntegerType, false))
    )
}

object VowpalWabbitDSJsonTransformer extends ComplexParamsReadable[VowpalWabbitDSJsonTransformer] {
  val EventIdColName = "EventId"

  val JsonColName = "json"

  val LabelProbability = "_label_probability"

  val LabelIndex = "_labelIndex"

  val ProbabilityLoggedColName = "probLog"

  val ProbabilityPredictedColName = "probPred"

  val ChosenActionIndexColName = "chosenActionIndex"

  val RewardsColName = "rewards"

  val HeaderColNames = Seq(EventIdColName, RewardsColName, ProbabilityLoggedColName, ChosenActionIndexColName)
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy