// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.vw

import com.microsoft.azure.synapse.ml.codegen.Wrappable
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.policyeval.PolicyEvalUDAFUtil
import org.apache.spark.ml.param.{DoubleParam, ParamMap, StringArrayParam}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{Column, DataFrame, Dataset, functions => F, types => T}

/**
  * Emits continuous success experimentation (off-policy evaluation) metrics for contextual-bandit-style
  * predictions and dsjson decision logs.
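  *
  * The expected input follows the column naming of [[VowpalWabbitDSJsonTransformer]] (rewards struct,
  * logged probability, chosen action index) plus a `predictions` column for the evaluated policy.
  *
  * A minimal usage sketch, assuming such a DataFrame `dsJsonDf` (the name is illustrative):
  * {{{
  * val cse = new VowpalWabbitCSETransformer()
  *   .setMinImportanceWeight(0)
  *   .setMaxImportanceWeight(100)
  * val metrics = cse.transform(dsJsonDf)
  * }}}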
  */
class VowpalWabbitCSETransformer(override val uid: String)
  extends Transformer
    with SynapseMLLogging
    with Wrappable
    with ComplexParamsWritable {

  import VowpalWabbitDSJsonTransformer._
  import VowpalWabbitCSETransformer._

  logClass(FeatureNames.VowpalWabbit)

  def this() = this(Identifiable.randomUID("VowpalWabbitCSETransformer"))

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

  val minImportanceWeight = new DoubleParam(
    this, "minImportanceWeight", "Clip importance weight at this lower bound. Defaults to 0.")

  def getMinImportanceWeight: Double = $(minImportanceWeight)
  def setMinImportanceWeight(value: Double): this.type = set(minImportanceWeight, value)

  val maxImportanceWeight = new DoubleParam(
    this, "maxImportanceWeight", "Clip importance weight at this upper bound. Defaults to 100.")

  def getMaxImportanceWeight: Double = $(maxImportanceWeight)
  def setMaxImportanceWeight(value: Double): this.type = set(maxImportanceWeight, value)

  val metricsStratificationCols = new StringArrayParam(
    this, "metricsStratificationCols", "Optional list of column names to stratify rewards by.")

  def getMetricsStratificationCols: Array[String] = $(metricsStratificationCols)
  def setMetricsStratificationCols(value: Array[String]): this.type = set(metricsStratificationCols, value)

  setDefault(minImportanceWeight -> 0, maxImportanceWeight -> 100, metricsStratificationCols -> Array.empty)

  // define reward independent metrics
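  // `w` (added in transform) is the importance weight: predicted probability / logged probability.
  // These reward-independent summaries mainly serve as diagnostics for how far the evaluated policy
  // deviates from the logging policy.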
  val globalMetrics: Seq[Column] = {
    val w = F.col("w")
    Seq(
      F.count("*").alias(ExampleCountName),
      F.sum(F.when(F.col(ProbabilityPredictedColName) > 0, 1).otherwise(0))
        .alias(ProbabilityPredictedNonZeroCount),
      F.min("w").alias(MinimumImportanceWeight),
      F.max("w").alias(MaximumImportanceWeight),
      F.avg(w).alias(AverageImportanceWeight),
      F.avg(w * w).alias(AverageSquaredImportanceWeight),
      (F.max(w) / F.count("*")).alias(PropOfMaximumImportanceWeight),
      F.expr("approx_percentile(w, array(0.25, 0.5, 0.75, 0.95))")
        .alias(QuantilesOfImportanceWeight))
  }

  case class RewardColumn(name: String, col: String, idx: Int) {
    def minRewardCol: String = s"min_reward_$idx"

    def maxRewardCol: String = s"max_reward_$idx"
  }

  def rewardColumns(schema: T.StructType): Seq[RewardColumn] =
    schema(RewardsColName)
      .dataType.asInstanceOf[T.StructType]
      .fields
      .zipWithIndex
      .map({ case (rewardField: T.StructField, idx: Int) =>
        RewardColumn(rewardField.name, s"$RewardsColName.${rewardField.name}", idx)
      })
      .toSeq

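  // Builds the per-reward metrics struct: min/max reward plus several off-policy value estimates.
  // Snips = self-normalized inverse propensity scoring, Ips = inverse propensity scoring, and the
  // CressieRead* aggregates are empirical-likelihood style estimators (point estimate and intervals)
  // that clip the importance weight to [minImportanceWeight, maxImportanceWeight].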
  private def rewardColumnToStruct(rewardCol: RewardColumn,
                                   minImportanceWeight: Double,
                                   maxImportanceWeight: Double): Column = {
    val countCol = F.col("count")
    val minImportanceWeightCol = F.lit(minImportanceWeight)
    val maxImportanceWeightCol = F.lit(maxImportanceWeight)

    val minRewardCol = F.col(rewardCol.minRewardCol)
    val maxRewardCol = F.col(rewardCol.maxRewardCol)

    F.struct(
      F.first(minRewardCol).alias(MinReward),
      F.first(maxRewardCol).alias(MaxReward),
      // off-policy value estimates for this reward column
      PolicyEvalUDAFUtil.Snips(F.col(ProbabilityLoggedColName), F.col(rewardCol.col),
        F.col(ProbabilityPredictedColName), countCol)
        .alias(Snips),
      PolicyEvalUDAFUtil.Ips(F.col(ProbabilityLoggedColName), F.col(rewardCol.col),
        F.col(ProbabilityPredictedColName), countCol)
        .alias(Ips),
      PolicyEvalUDAFUtil.CressieRead(F.col(ProbabilityLoggedColName), F.col(rewardCol.col),
        F.col(ProbabilityPredictedColName), countCol, minImportanceWeightCol, maxImportanceWeightCol)
        .alias(CressieRead),
      PolicyEvalUDAFUtil.CressieReadInterval(F.col(ProbabilityLoggedColName), F.col(rewardCol.col),
        F.col(ProbabilityPredictedColName), countCol, minImportanceWeightCol, maxImportanceWeightCol,
        minRewardCol, maxRewardCol)
        .alias(CressieReadInterval),
      PolicyEvalUDAFUtil.CressieReadIntervalEmpirical(F.col(ProbabilityLoggedColName), F.col(rewardCol.col),
        F.col(ProbabilityPredictedColName), countCol, minImportanceWeightCol, maxImportanceWeightCol,
        minRewardCol, maxRewardCol)
        .alias(CressieReadIntervalEmp))
      .alias(rewardCol.name)
  }

  def perRewardMetrics(rewardsCol: Seq[RewardColumn],
                       minImportanceWeight: Double,
                       maxImportanceWeight: Double): Seq[Column] = {
    rewardsCol.map(rewardColumnToStruct(_, minImportanceWeight, maxImportanceWeight))
  }

  override def transform(dataset: Dataset[_]): DataFrame = {
    logTransform[DataFrame]({
      val df = dataset.toDF

      // fetch reward columns out of the nested rewards structure
      val rewardsCol = rewardColumns(df.schema)

      // calculate min/max for each reward
      val minMaxes = rewardsCol
        .flatMap({ rewardCol: RewardColumn =>
          Seq(F.min(rewardCol.col).alias(rewardCol.minRewardCol),
            F.max(rewardCol.col).alias(rewardCol.maxRewardCol))
        })

      val minMaxRewards = df.agg(minMaxes.head, minMaxes.drop(1): _*)

      // register contextual bandit policy evaluation aggregates
      PolicyEvalUDAFUtil.registerUdafs()

      val metrics = globalMetrics ++ perRewardMetrics(rewardsCol, getMinImportanceWeight, getMaxImportanceWeight)

      df.crossJoin(minMaxRewards.hint("broadcast"))
        // probability predicted = 1 if the top action matches the observed action, otherwise 0.
        // Note: using the probability distribution produced by cb_adf_explore only increases variance
        // and doesn't help model selection. Also, the offline policy cannot influence which data
        // got collected.
        .withColumn(ProbabilityPredictedColName,
          F.when(F.expr(s"element_at(predictions, 1).action == $ChosenActionIndexColName"), 1f)
            .otherwise(0f))
        // example weight, defaults to 1
        .withColumn("count", F.lit(1))
        .withColumn("w", F.col(ProbabilityPredictedColName) / F.col(ProbabilityLoggedColName))
        // optional stratification
        .groupBy(getMetricsStratificationCols.map(F.col): _*)
        .agg(metrics.head, metrics.drop(1): _*)
    }, dataset.columns.length)
  }

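  // Schema of the per-reward struct; must mirror the columns produced by rewardColumnToStruct.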
  private def perRewardSchema(f: T.StructField): T.StructField = {
    T.StructField(f.name,
      T.StructType(Seq(
        T.StructField(MinReward, T.FloatType, false),
        T.StructField(MaxReward, T.FloatType, false),
        T.StructField(Snips, T.FloatType, false),
        T.StructField(Ips, T.FloatType, false),
        T.StructField(CressieRead, T.DoubleType, false),
        T.StructField(CressieReadInterval,
          T.StructType(Seq(
            T.StructField("lower", T.DoubleType, false),
            T.StructField("upper", T.DoubleType, false))),
          true),
        T.StructField(CressieReadIntervalEmp,
          T.StructType(Seq(
            T.StructField("lower", T.DoubleType, false),
            T.StructField("upper", T.DoubleType, false))),
          true)
      )),
      false)
  }

  override def transformSchema(schema: StructType): StructType =
    T.StructType(
      // groupBy
      getMetricsStratificationCols.map(T.StructField(_, T.StringType, true)) ++
      Seq(
        // global metrics
        T.StructField(ExampleCountName, T.LongType, false),
        T.StructField(ProbabilityPredictedNonZeroCount, T.LongType, false),
        T.StructField(MinimumImportanceWeight, T.DoubleType, false),
        T.StructField(MaximumImportanceWeight, T.DoubleType, false),
        T.StructField(AverageImportanceWeight, T.DoubleType, false),
        T.StructField(AverageSquaredImportanceWeight, T.DoubleType, false),
        T.StructField(PropOfMaximumImportanceWeight, T.DoubleType, false),
        T.StructField(QuantilesOfImportanceWeight, T.ArrayType(T.FloatType, false), false)) ++
      // perRewardMetric
      schema(RewardsColName).dataType.asInstanceOf[T.StructType].fields.map(perRewardSchema)
    )
}

object VowpalWabbitCSETransformer extends ComplexParamsReadable[VowpalWabbitCSETransformer] {
  val ExampleCountName = "exampleCount"
  val ProbabilityPredictedNonZeroCount = "probPredNonZeroCount"
  val MinimumImportanceWeight = "minimumImportanceWeight"
  val MaximumImportanceWeight = "maximumImportanceWeight"
  val AverageImportanceWeight = "averageImportanceWeight"
  val AverageSquaredImportanceWeight = "averageSquaredImportanceWeight"
  val PropOfMaximumImportanceWeight = "proportionOfMaximumImportanceWeight"
  val QuantilesOfImportanceWeight = "importance weight quantiles (0.25, 0.5, 0.75, 0.95)"

  val MinReward = "minReward"
  val MaxReward = "maxReward"
  val Snips = "snips"
  val Ips = "ips"
  val CressieRead = "cressieRead"
  val CressieReadInterval = "cressieReadInterval"
  val CressieReadIntervalEmp = "cressieReadIntervalEmpirical"
}
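
// Example (illustrative sketch, not part of the library): stratifying the evaluation metrics by a
// context column. `dsJsonDf` and the column name "country" are hypothetical placeholders.
//
//   val metricsByCountry = new VowpalWabbitCSETransformer()
//     .setMetricsStratificationCols(Array("country"))
//     .transform(dsJsonDf)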
