All downloads are free. The search and download functionality uses the official Maven repository.

com.microsoft.azure.synapse.ml.vw.VowpalWabbitPrediction.scala Maven / Gradle / Ivy

There is a newer version: 1.0.9
Show newest version
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.vw

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.types.StructType
import org.vowpalwabbit.spark.VowpalWabbitNative
import org.vowpalwabbit.spark.prediction.ScalarPrediction
import vowpalWabbit.responses.{ActionProbs, ActionScores, DecisionScores, Multilabels, PDF, PDFValue}

import scala.reflect.runtime.universe.TypeTag

/**
  * Provide schemas and accessor functions for almost all VW prediction types.
  *
  * Both lookup tables are keyed by the prediction type name returned by
  * `VowpalWabbitNative.getOutputPredictionType` (for example `prediction_type_t::scalar`):
  *  - [[getSchema]] returns the Spark schema registered for that prediction type.
  *  - [[getPredictionFunc]] returns the function converting the native prediction object into a
  *    sequence of values (presumably one per field of the matching schema - the schema case classes
  *    live in `VowpalWabbitSchema.Predictions`, so verify there).
  */
object VowpalWabbitPrediction {
  /** Derive the Spark SQL schema of `T` from its resolved and bound expression encoder. */
  private def toStructType[T: TypeTag]: StructType = ExpressionEncoder[T]().resolveAndBind().schema

  /** Spark schema per VW prediction type name. */
  private val SchemaMap: Map[String, StructType] = {
    import VowpalWabbitSchema.Predictions

    Map(
      "prediction_type_t::scalars" -> toStructType[Predictions.Scalars],
      "prediction_type_t::multiclass" -> toStructType[Predictions.Multiclass],
      "prediction_type_t::prob" -> toStructType[Predictions.Prob],
      "prediction_type_t::multilabels" -> toStructType[Predictions.Multilabels],
      "prediction_type_t::scalar" -> toStructType[Predictions.Scalar],
      "prediction_type_t::action_scores" -> toStructType[Predictions.ActionScores],
      "prediction_type_t::action_probs" -> toStructType[Predictions.ActionProbs],
      "prediction_type_t::decision_probs" -> toStructType[Predictions.DecisionProbs],
      "prediction_type_t::action_pdf_value" -> toStructType[Predictions.ActionPDFValue],
      "prediction_type_t::pdf" -> toStructType[Predictions.PDF]
    )
  }

  /**
    * Look up `predictionType` in `table`, failing with a descriptive message when it is missing.
    *
    * @param table          lookup table keyed by prediction type name.
    * @param predictionType prediction type name to resolve.
    * @return the value registered for `predictionType`.
    * @throws NoSuchElementException if `predictionType` is not a key of `table` (the same exception type
    *                                a plain `Map.apply` would raise, but naming the supported types).
    */
  private def lookupByPredictionType[V](table: Map[String, V], predictionType: String): V =
    table.getOrElse(
      predictionType,
      throw new NoSuchElementException(
        s"Unsupported VW prediction type '$predictionType'. " +
          s"Supported types: ${table.keys.toSeq.sorted.mkString(", ")}"))

  /**
    * Schema of the predictions produced by the given model.
    *
    * @param vw native VW instance whose output prediction type selects the schema.
    * @return the schema registered for the model's output prediction type.
    * @throws NoSuchElementException if the prediction type has no registered schema.
    */
  def getSchema(vw: VowpalWabbitNative): StructType =
    lookupByPredictionType(SchemaMap, vw.getOutputPredictionType)

  /** Converts a native prediction object into a sequence of values. */
  type VowpalWabbitPredictionToSeqFunc = Object => Seq[Any]

  // TODO: prediction types not registered below; the java wrapper would have to support them too
  // CASE(prediction_type_t::multiclassprobs)
  // CASE(prediction_type_t::active_multiclass)
  // CASE(prediction_type_t::nopred)

  /** Conversion function per VW prediction type name. */
  val PredictionFunctions: Map[String, VowpalWabbitPredictionToSeqFunc] = Map(
    "prediction_type_t::scalar" -> predictScalar,
    "prediction_type_t::scalars" -> predictScalars,
    "prediction_type_t::action_scores" -> predictActionScores,
    "prediction_type_t::action_probs" -> predictActionProbs,
    "prediction_type_t::decision_probs" -> predictDecisionProbs,
    "prediction_type_t::multilabels" -> predictMultilabels,
    "prediction_type_t::multiclass" -> predictMulticlass,
    "prediction_type_t::prob" -> predictProb,
    "prediction_type_t::action_pdf_value" -> predictActionPdfValue,
    "prediction_type_t::pdf" -> predictPdf)

  /**
    * Conversion function for the predictions produced by the given model.
    *
    * @param vw native VW instance whose output prediction type selects the function.
    * @return the function registered for the model's output prediction type.
    * @throws NoSuchElementException if the prediction type has no registered function.
    */
  def getPredictionFunc(vw: VowpalWabbitNative): VowpalWabbitPredictionToSeqFunc =
    lookupByPredictionType(PredictionFunctions, vw.getOutputPredictionType)

  // NOTE: the converters below are deliberately `def`s rather than `val`s. PredictionFunctions is
  // initialized above them in source order, so `val`s would still be null when the map is built.

  /** Scalar prediction: the value followed by the confidence. */
  private def predictScalar: VowpalWabbitPredictionToSeqFunc = obj => {
    val pred = obj.asInstanceOf[ScalarPrediction]
    Seq(pred.getValue, pred.getConfidence)
  }

  /** Scalars prediction: the float array as a single element. */
  private def predictScalars: VowpalWabbitPredictionToSeqFunc = obj => Seq(obj.asInstanceOf[Array[Float]])

  /** Action scores: a single element holding one (action, score) row per entry. */
  private def predictActionScores: VowpalWabbitPredictionToSeqFunc = obj => {
    val pred = obj.asInstanceOf[ActionScores].getActionScores
    // Pass an explicit tuple (as the sibling converters do) instead of relying on auto-tupling.
    Seq(pred.map { a_s => Row.fromTuple((a_s.getAction, a_s.getScore)) })
  }

  /** Action probabilities: a single element holding one (action, probability) row per entry. */
  private def predictActionProbs: VowpalWabbitPredictionToSeqFunc = obj => Seq(obj.asInstanceOf[ActionProbs]
    .getActionProbs
    .map { a_s => Row.fromTuple((a_s.getAction, a_s.getProbability)) })

  /** Decision scores: a single element holding, per decision, its (action, score) rows. */
  private def predictDecisionProbs: VowpalWabbitPredictionToSeqFunc = obj => Seq(obj.asInstanceOf[DecisionScores]
    .getDecisionScores
    .map { ds => ds.getActionScores.map { a_s => Row.fromTuple((a_s.getAction, a_s.getScore)) } })

  /** Multilabels prediction: the labels as a single element. */
  private def predictMultilabels: VowpalWabbitPredictionToSeqFunc = obj => Seq(obj.asInstanceOf[Multilabels].getLabels)

  /** Multiclass prediction: the boxed integer as a single element. */
  private def predictMulticlass: VowpalWabbitPredictionToSeqFunc = obj => Seq(obj.asInstanceOf[java.lang.Integer])

  /** Probability prediction: the boxed float as a single element. */
  private def predictProb: VowpalWabbitPredictionToSeqFunc = obj => Seq(obj.asInstanceOf[java.lang.Float])

  /** Action PDF value: the action followed by its PDF value. */
  private def predictActionPdfValue: VowpalWabbitPredictionToSeqFunc = obj => {
    val pred = obj.asInstanceOf[PDFValue]
    Seq(pred.getAction, pred.getPDFValue)
  }

  /** PDF prediction: a single element holding one (left, right, pdfValue) row per segment. */
  private def predictPdf: VowpalWabbitPredictionToSeqFunc = obj => {
    Seq(obj.asInstanceOf[PDF]
      .getPDFSegments
      .map { s => Row.fromTuple((s.getLeft, s.getRight, s.getPDFValue)) })
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy