All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.microsoft.azure.synapse.ml.policyeval.Snips.scala Maven / Gradle / Ivy

There is a newer version: 1.0.9
Show newest version
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.policyeval

import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}

/**
  * SNIPS off policy estimator: the Self-Normalized Estimator for Counterfactual Learning
  *
  * See https://papers.nips.cc/paper/2015/file/39027dfad5138c9ca0c474d71db915c3-Paper.pdf
  */
class Snips
  extends Aggregator[SnipsInput, SnipsBuffer, Float]
    with Serializable
    with SynapseMLLogging {
  override val uid: String = Identifiable.randomUID("BanditEstimatorSnips")

  logClass(FeatureNames.VowpalWabbit)

  def zero: SnipsBuffer = SnipsBuffer(0, 0)

  def reduce(acc: SnipsBuffer, x: SnipsInput): SnipsBuffer = {
    val w = x.probabilityPredicted / x.probabilityLogged

    SnipsBuffer(
      acc.weightedExampleCount + w * x.count,
      acc.weightedReward + x.reward * w * x.count)
  }

  def merge(acc1: SnipsBuffer, acc2: SnipsBuffer): SnipsBuffer =
    SnipsBuffer(
      acc1.weightedExampleCount + acc2.weightedExampleCount,
      acc1.weightedReward + acc2.weightedReward)

  def finish(acc: SnipsBuffer): Float = {
    logVerb("aggregate", {
      if (acc.weightedExampleCount == 0)
        -1f // TODO: how to return null?
      else
        (acc.weightedReward / acc.weightedExampleCount).toFloat
    })
  }

  def bufferEncoder: Encoder[SnipsBuffer] = Encoders.product[SnipsBuffer]
  def outputEncoder: Encoder[Float] = Encoders.scalaFloat
}

final case class SnipsBuffer(weightedExampleCount: Double, weightedReward: Double)

final case class SnipsInput(probabilityLogged: Float,
                            reward: Float,
                            probabilityPredicted: Float,
                            count: Float)




© 2015 - 2025 Weber Informatics LLC | Privacy Policy