All downloads are free. Search and download functionality uses the official Maven repository.

com.microsoft.azure.synapse.ml.policyeval.Ips.scala Maven / Gradle / Ivy

There is a newer version: 1.0.9
Show newest version
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.policyeval

import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}

/**
  * Simplest off-policy evaluation metric: IPS (Inverse Propensity Score)
  *
  * See https://courses.cs.washington.edu/courses/cse599m/19sp/notes/off_policy.pdf
  */
/**
  * Inverse Propensity Score (IPS) estimator, the simplest off-policy evaluation metric.
  *
  * Each input row gets an importance weight `probabilityPredicted / probabilityLogged`.
  * The estimate is the sum of `reward * weight * count` over all rows, divided by the
  * summed `count`. When the summed count is zero the sentinel `-1` is returned, because the
  * `Float` output type cannot express null.
  *
  * See https://courses.cs.washington.edu/courses/cse599m/19sp/notes/off_policy.pdf
  */
class Ips
  extends Aggregator[IpsInput, IpsBuffer, Float]
    with Serializable
    with SynapseMLLogging {
  override val uid: String = Identifiable.randomUID("BanditEstimatorIps")

  // Runs in the constructor, after `uid` has been initialized above.
  logClass(FeatureNames.VowpalWabbit)

  /** Empty accumulator: no examples counted and no reward accumulated yet. */
  override def zero: IpsBuffer = IpsBuffer(exampleCount = 0, weightedReward = 0)

  /**
    * Folds one logged row into the running totals.
    *
    * @param acc running totals so far
    * @param x   the row to add; its reward is scaled by the importance weight and by its count
    * @return the updated totals
    */
  override def reduce(acc: IpsBuffer, x: IpsInput): IpsBuffer = {
    val importanceWeight = x.probabilityPredicted / x.probabilityLogged
    // Keep the multiplication order (reward * weight * count) so Float rounding is unchanged.
    val contribution = x.reward * importanceWeight * x.count
    acc.copy(
      exampleCount = acc.exampleCount + x.count,
      weightedReward = acc.weightedReward + contribution)
  }

  /** Combines two partial accumulators by summing them field by field. */
  override def merge(acc1: IpsBuffer, acc2: IpsBuffer): IpsBuffer =
    IpsBuffer(
      exampleCount = acc1.exampleCount + acc2.exampleCount,
      weightedReward = acc1.weightedReward + acc2.weightedReward)

  /**
    * Produces the IPS estimate `weightedReward / exampleCount`.
    *
    * @return the estimate, or `-1` when the accumulated count is zero
    */
  override def finish(acc: IpsBuffer): Float =
    logVerb("aggregate", {
      val nothingSeen = acc.exampleCount == 0
      // TODO: report an empty aggregate as null rather than this sentinel.
      if (!nothingSeen) acc.weightedReward / acc.exampleCount else -1f
    })

  /** Encoder used by Spark to serialize the intermediate [[IpsBuffer]]. */
  override def bufferEncoder: Encoder[IpsBuffer] = Encoders.product[IpsBuffer]

  /** Encoder for the final `Float` estimate. */
  override def outputEncoder: Encoder[Float] = Encoders.scalaFloat
}

/**
  * One row of logged data consumed by the [[Ips]] aggregator.
  *
  * @param probabilityLogged    propensity of the logged action; the divisor of the importance weight
  * @param reward               observed reward for the logged action
  * @param probabilityPredicted probability the evaluated policy gives the same action; the dividend
  *                             of the importance weight
  * @param count                multiplicity of this row; scales its reward contribution and is
  *                             added to the example count
  */
final case class IpsInput(
    probabilityLogged: Float,
    reward: Float,
    probabilityPredicted: Float,
    count: Float)

/**
  * Intermediate accumulator for the [[Ips]] aggregator.
  *
  * @param exampleCount   sum of the `count` fields seen so far (the estimate's denominator)
  * @param weightedReward sum of `reward * importanceWeight * count` seen so far (the numerator)
  */
final case class IpsBuffer(
    exampleCount: Float,
    weightedReward: Float)




© 2015 - 2025 Weber Informatics LLC | Privacy Policy