// File: com/microsoft/azure/synapse/ml/services/anomaly/MultivariateAnomalyDetection.scala
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.services.anomaly

import com.microsoft.azure.synapse.ml.build.BuildInfo
import com.microsoft.azure.synapse.ml.codegen.Wrappable
import com.microsoft.azure.synapse.ml.services._
import com.microsoft.azure.synapse.ml.services.anomaly.MADJsonProtocol._
import com.microsoft.azure.synapse.ml.services.vision.HasAsyncReply
import com.microsoft.azure.synapse.ml.core.contracts.{HasInputCols, HasOutputCol}
import com.microsoft.azure.synapse.ml.core.env.StreamUtilities.using
import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions
import com.microsoft.azure.synapse.ml.io.http.HandlingUtils.{convertAndClose, sendWithRetries}
import com.microsoft.azure.synapse.ml.io.http.RESTHelpers.{Client, retry}
import com.microsoft.azure.synapse.ml.io.http._
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.stages._
import com.microsoft.azure.synapse.ml.param.CognitiveServiceStructParam
import org.apache.commons.io.IOUtils
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.http.client.methods._
import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity}
import org.apache.http.impl.client.CloseableHttpClient
import org.apache.spark.injections.UDFUtils
import org.apache.spark.internal.Logging
import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import spray.json._

import java.net.URI
import java.time.format.DateTimeFormatter
import java.util.concurrent.TimeoutException
import scala.collection.parallel.mutable
import scala.collection.parallel.mutable.ParHashSet
import scala.concurrent.blocking
import scala.language.existentials

private[ml] case class RemoteIteratorWrapper[T](underlying: org.apache.hadoop.fs.RemoteIterator[T])
  extends scala.collection.AbstractIterator[T] with scala.collection.Iterator[T] {
  def hasNext: Boolean = underlying.hasNext

  def next(): T = underlying.next()
}

private[ml] object Conversions {
  implicit def remoteIterator2ScalaIterator[T](underlying: org.apache.hadoop.fs.RemoteIterator[T]):
  scala.collection.Iterator[T] = RemoteIteratorWrapper[T](underlying)
}
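
// Hedged sketch (not part of the original source): with Conversions in scope, a
// Hadoop RemoteIterator can be consumed with ordinary Scala Iterator methods, as
// MADBase.upload does below when locating the written CSV part file.
private object ConversionsExample {
  import Conversions._

  def firstPartFile(fs: FileSystem, dir: Path): Option[String] =
    fs.listFiles(dir, true)                           // Hadoop RemoteIterator[LocatedFileStatus]
      .filter(_.getPath.getName.startsWith("part-"))  // Scala Iterator ops via the implicit wrapper
      .toSeq.headOption.map(_.getPath.toString)
}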

object MADUtils extends Logging {

  private[ml] val CreatedModels: mutable.ParHashSet[String] = new ParHashSet[String]()

  //noinspection ScalaStyle
  private[ml] def madSend(request: HttpRequestBase,
                          path: String,
                          key: String,
                          params: Map[String, String] = Map()): String = {

    val paramString = if (params.isEmpty) {
      ""
    } else {
      "?" + URLEncodingUtils.format(params)
    }
    request.setURI(new URI(path + paramString))

    retry(List(100, 500, 1000), { () => //scalastyle:ignore magic.number
      request.addHeader("Ocp-Apim-Subscription-Key", key)
      request.addHeader("Content-Type", "application/json")
      using(Client.execute(request)) { response =>
        if (!response.getStatusLine.getStatusCode.toString.startsWith("2")) {
          val bodyOpt = request match {
            case er: HttpEntityEnclosingRequestBase => IOUtils.toString(er.getEntity.getContent, "UTF-8")
            case _ => ""
          }
          if (response.getStatusLine.getStatusCode.toString.equals("429")) {
            val retryTime = response.getHeaders("Retry-After").head.getValue.toInt * 1000
            Thread.sleep(retryTime.toLong)
          }
          throw new RuntimeException(s"Failed: response: $response " + s"requestUrl: ${request.getURI} " +
            s"requestBody: $bodyOpt")
        }
        if (response.getStatusLine.getReasonPhrase == "No Content") {
          ""
        }
        else if (response.getStatusLine.getReasonPhrase == "Created") {
          response.getHeaders("Location").head.getValue
        }
        else {
          IOUtils.toString(response.getEntity.getContent, "UTF-8")
        }
      }.get
    })
  }

  private[ml] def madGetModel(url: String, modelId: String,
                              key: String, params: Map[String, String] = Map()): String = {
    madSend(new HttpGet(), url + modelId, key, params)
  }

  private[ml] def madUrl(location: String): String = {
    s"https://$location.api.cognitive.microsoft.com/anomalydetector/v1.1/multivariate/"
  }

  private[ml] def madDelete(modelId: String,
                            key: String,
                            location: String,
                            params: Map[String, String] = Map()): String = {
    madSend(new HttpDelete(), madUrl(location) + "models/" + modelId, key, params)
  }

  private[ml] def madGetBatchDetectionResults(url: String,
                                              resultId: String,
                                              key: String,
                                              params: Map[String, String] = Map(),
                                              maxTries: Int,
                                              pollingDelay: Int): String = {

    val it = (0 to maxTries).toIterator.flatMap { _ =>
      val resp = madSend(new HttpGet(), url + resultId, key, params)
      val fields = resp.parseJson.asJsObject.fields
      fields("summary").convertTo[DMASummary].status.toLowerCase() match {
        case "ready" | "failed" => Some(resp)
        case "created" | "running" => {
          blocking {
            Thread.sleep(pollingDelay.toLong)
          }
          None
        }
        case s => throw new RuntimeException(s"Received unknown status: $s")
      }
    }
    if (it.hasNext) {
      it.next()
    } else {
      throw new TimeoutException(
        s"Querying for results with resultId $resultId did not complete within $maxTries tries")
    }
  }

  private[ml] def madListModels(key: String,
                                location: String,
                                params: Map[String, String] = Map()): String = {
    madSend(new HttpGet(), madUrl(location) + "models?$top=500", key, params)
  }

  private[ml] def cleanUpAllModels(key: String, location: String): Unit = {
    for (modelId <- CreatedModels) {
      println(s"Deleting mvad model $modelId")
      madDelete(modelId, key, location)
    }
    CreatedModels.clear()
  }

  private[ml] def checkModelStatus(url: String, modelId: String, subscriptionKey: String): Unit = try {
    val response = madGetModel(url, modelId, subscriptionKey)
      .parseJson.asJsObject.fields

    val modelInfo = response("modelInfo").asJsObject.fields
    val modelStatus = modelInfo("status").asInstanceOf[JsString].value.toLowerCase
    modelStatus match {
      case "failed" =>
        val errors = modelInfo("errors").toJson.compactPrint
        throw new RuntimeException(s"Caught errors during fitting: $errors")
      case "created" | "running" =>
        throw new RuntimeException(s"model $modelId is not ready yet")
      case "ready" =>
        logInfo("model is ready for inference")
    }
  } catch {
    case e: RuntimeException =>
      throw new RuntimeException(s"Encounter error while fetching model $modelId, " +
        s"please double check the modelId is correct: ${e.getMessage}")
  }

}
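
// Hedged sketch (not part of the original source): every model id created during
// fitting is tracked in MADUtils.CreatedModels, so a test suite can delete the
// server-side models in one call. The key env var and region are placeholders.
private object MADUtilsExample {
  def deleteAllTrainedModels(): Unit =
    MADUtils.cleanUpAllModels(sys.env("ANOMALY_KEY"), "westus2")
}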

trait MADHttpRequest extends HasURL with HasSubscriptionKey with HasAsyncReply {
  protected def prepareUrl: String

  protected def prepareMethod(): HttpRequestBase = new HttpPost()

  protected def prepareEntity(dataSource: String): Option[AbstractHttpEntity]

  protected val subscriptionKeyHeaderName = "Ocp-Apim-Subscription-Key"

  protected def contentType: String = "application/json"

  protected def prepareRequest(entity: AbstractHttpEntity): Option[HttpRequestBase] = {
    val req = prepareMethod()
    req.setURI(new URI(prepareUrl))
    req.setHeader(subscriptionKeyHeaderName, getSubscriptionKey)
    req.setHeader("Content-Type", contentType)

    req match {
      case er: HttpEntityEnclosingRequestBase =>
        er.setEntity(entity)
      case _ =>
    }
    Some(req)
  }

  protected def queryForResult(key: Option[String],
                               client: CloseableHttpClient,
                               location: URI): Option[HTTPResponseData] = {
    val get = new HttpGet()
    get.setURI(location)
    key.foreach(get.setHeader("Ocp-Apim-Subscription-Key", _))
    get.setHeader("User-Agent", s"synapseml/${BuildInfo.version}${HeaderValues.PlatformInfo}")
    val resp = convertAndClose(sendWithRetries(client, get, getBackoffs))
    get.releaseConnection()
    Some(resp)
  }

  //noinspection ScalaStyle
  protected def timeoutResult(key: Option[String], client: CloseableHttpClient,
                              queryUrl: URI, maxTries: Int): HTTPResponseData = {
    throw new TimeoutException(
      s"Querying for results did not complete within $maxTries tries")
  }

  //scalastyle:off cyclomatic.complexity
  protected def handlingFunc(client: CloseableHttpClient,
                             request: HTTPRequestData): HTTPResponseData = {
    val response = HandlingUtils.advanced(getBackoffs: _*)(client, request)
    if (response.statusLine.statusCode == 201) {
      val location = new URI(response.headers.filter(_.name == "Location").head.value)
      val maxTries = getMaxPollingRetries
      val key = request.headers.find(_.name == "Ocp-Apim-Subscription-Key").map(_.value)
      val it = (0 to maxTries).toIterator.flatMap { _ =>
        val resp = queryForResult(key, client, location)
        val fields = IOUtils.toString(resp.get.entity.get.content, "UTF-8").parseJson.asJsObject.fields
        val status = fields match {
          case f if f.contains("modelInfo") => f("modelInfo").convertTo[MAEModelInfo].status
          case f if f.contains("summary") => f("summary").convertTo[DMASummary].status
          case _ => "None"
        }
        status.toLowerCase() match {
          case "ready" | "failed" => resp
          case "created" | "running" =>
            blocking {
              Thread.sleep(getPollingDelay.toLong)
            }
            None
          case s => throw new RuntimeException(s"Received unknown status: $s")
        }
      }
      if (it.hasNext) {
        it.next()
      } else {
        timeoutResult(key, client, location, maxTries)
      }
    } else {
      val error = IOUtils.toString(response.entity.get.content, "UTF-8")
      throw new RuntimeException(s"Caught error: $error")
    }
  }
  //scalastyle:on cyclomatic.complexity
}

private case class StorageInfo(account: String, container: String, key: String, blob: String)

trait TimeConverter {
  protected def convertTimeFormat(name: String, v: String): String = {
    try {
      DateTimeFormatter.ISO_INSTANT.format(DateTimeFormatter.ISO_INSTANT.parse(v))
    }
    catch {
      case e: java.time.format.DateTimeParseException =>
        throw new IllegalArgumentException(
          s"${name.capitalize} should be in ISO 8601 format, e.g. 2021-01-01T00:00:00Z; received: ${e.toString}")
    }
  }
}
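
// Minimal sketch (not part of the original source): convertTimeFormat re-parses an
// ISO-8601 instant and emits the canonical ISO_INSTANT form, rejecting anything
// else with an IllegalArgumentException.
private object TimeConverterExample extends TimeConverter {
  // e.g. "2021-01-01T00:00:00.000Z" is expected to normalize to "2021-01-01T00:00:00Z"
  def normalize(ts: String): String = convertTimeFormat("startTime", ts)
}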

trait HasTimestampCol extends Params {
  val timestampCol = new Param[String](this, "timestampCol", "Timestamp column name")

  def setTimestampCol(v: String): this.type = set(timestampCol, v)

  def getTimestampCol: String = $(timestampCol)

  setDefault(timestampCol -> "timestamp")
}

trait MADBase extends HasOutputCol with TimeConverter
  with MADHttpRequest with HasSetLocation with HasInputCols
  with ComplexParamsWritable with Wrappable with HasTimestampCol
  with HasErrorCol with SynapseMLLogging {

  val startTime = new Param[String](this, "startTime", "A required field, the start time of the" +
    " data used for detection or for generating the multivariate anomaly detection model; should be date-time.")

  def setStartTime(v: String): this.type = set(startTime, convertTimeFormat(startTime.name, v))

  def getStartTime: String = $(startTime)

  val endTime = new Param[String](this, "endTime", "A required field, the end time of the" +
    " data used for detection or for generating the multivariate anomaly detection model; should be date-time.")

  def setEndTime(v: String): this.type = set(endTime, convertTimeFormat(endTime.name, v))

  def getEndTime: String = $(endTime)

  private def validateIntermediateSaveDir(dir: String): Boolean = {
    if (!dir.startsWith("wasbs://") && !dir.startsWith("abfss://")) {
      throw new IllegalArgumentException("improper HDFS loacation. Please use a wasb path such as: \n" +
        "wasbs://[CONTAINER]@[ACCOUNT].blob.core.windows.net/[DIRECTORY]" +
        "For more information on connecting storage accounts to spark visit " +
        "https://docs.microsoft.com/en-us/azure/databricks/data/data-sources" +
        "/azure/azure-storage#--access-azure-data-lake-storage-gen2-or-blob-storage-using-the-account-key"
      )
    }
    true
  }

  val intermediateSaveDir = new Param[String](
    this,
    "intermediateSaveDir",
    "Blob storage location in HDFS where intermediate data is saved while training.",
    isValid = validateIntermediateSaveDir _
  )

  def setIntermediateSaveDir(v: String): this.type = set(intermediateSaveDir, v)

  def getIntermediateSaveDir: String = $(intermediateSaveDir)

  setDefault(
    outputCol -> (this.uid + "_output"),
    errorCol -> (this.uid + "_error"))

  private def getStorageInfo: StorageInfo = {
    val uri = new URI(getIntermediateSaveDir)
    val account = uri.getHost.split(".".toCharArray).head
    val blobConfig = s"fs.azure.account.key.$account.blob.core.windows.net"
    val adlsConfig = s"fs.azure.account.key.$account.dfs.core.windows.net"
    val hc = SparkSession.builder().getOrCreate()
      .sparkContext.hadoopConfiguration
    val key = Option(hc.get(adlsConfig)).orElse(Option(hc.get(blobConfig)))

    if (key.isEmpty) {
      throw new IllegalAccessError("Could not find the storage account credentials." +
        s" Make sure your hadoopConfiguration has the" +
        s" ''$blobConfig'' or ''$adlsConfig'' configuration set.")
    }

    StorageInfo(account, uri.getUserInfo, key.get, uri.getPath.stripPrefix("/"))
  }

  protected def blobPath: Path = new Path(new URI(getIntermediateSaveDir.stripSuffix("/") + s"/$uid.csv"))

  protected def upload(df: DataFrame): String = {
    val convertTimeFormatUdf = UDFUtils.oldUdf(
      { value: String => convertTimeFormat("Timestamp column", value) },
      StringType
    )
    val formatDf = df.withColumn(getTimestampCol, convertTimeFormatUdf(col(getTimestampCol)))
      .sort(col(getTimestampCol).asc)

    val storageInfo = getStorageInfo

    formatDf.coalesce(1)
      .write.mode("overwrite").format("csv")
      .option("header", "true")
      .save(blobPath.toString)

    // MVAD no longer supports SAS URLs; you need to grant the Anomaly Detector's
    // managed identity access to the storage account
    val hconf = SparkSession.builder().getOrCreate().sparkContext.hadoopConfiguration
    val fs = FileSystem.get(blobPath.toUri, hconf)
    import Conversions._
    val filePath = fs.listFiles(blobPath, true)
      .filter(file => file.getPath.toString.contains("part-00000"))
      .toSeq.head.getPath.toString
    s"https://${storageInfo.account}.blob.core.windows.net/${storageInfo.container}/" +
      s"${filePath.split("/").drop(3).mkString("/")}"
  }

  def cleanUpIntermediateData(): Unit = {
    val hconf = SparkSession.builder().getOrCreate().sparkContext.hadoopConfiguration
    val fs = FileSystem.get(blobPath.toUri, hconf)
    fs.delete(blobPath, true)
  }

  override def pyAdditionalMethods: String = super.pyAdditionalMethods + {
    """
      |def cleanUpIntermediateData(self):
      |    self._java_obj.cleanUpIntermediateData()
      |    return
      |""".stripMargin
  }

  protected def submitDatasetAndJob(dataset: Dataset[_]): Map[String, JsValue] = {
    val df = dataset.toDF().select((Array(getTimestampCol) ++ getInputCols).map(col): _*)
    val url = upload(df)

    val httpRequestBase = prepareRequest(prepareEntity(url).get)
    val request = new HTTPRequestData(httpRequestBase.get)
    val response = handlingFunc(Client, request)

    val responseJson = IOUtils.toString(response.entity.get.content, "UTF-8")
      .parseJson.asJsObject.fields

    responseJson
  }

}

@deprecated("The Anomaly Detection Service will be shutting down in 2026," +
  " please use IsolationForest for anomaly detection", "v1.0.0")
object SimpleFitMultivariateAnomaly extends ComplexParamsReadable[SimpleFitMultivariateAnomaly] with Serializable

@deprecated("The Anomaly Detection Service will be shutting down in 2026," +
  " please use IsolationForest for anomaly detection", "v1.0.0")
class SimpleFitMultivariateAnomaly(override val uid: String) extends Estimator[SimpleDetectMultivariateAnomaly]
  with MADBase {
  logClass(FeatureNames.AiServices.Anomaly)

  def this() = this(Identifiable.randomUID("SimpleFitMultivariateAnomaly"))

  def urlPath: String = "anomalydetector/v1.1/multivariate/models"

  val dataSchema = "OneTable"

  val slidingWindow = new IntParam(this, "slidingWindow", "An optional field, indicates" +
    " how many historical points will be used to determine the anomaly score of each subsequent point.")

  def setSlidingWindow(v: Int): this.type = {
    if ((v >= 28) && (v <= 2880)) {
      set(slidingWindow, v)
    } else {
      throw new IllegalArgumentException("slidingWindow must be between 28 and 2880 (both inclusive).")
    }
  }

  def getSlidingWindow: Int = $(slidingWindow)

  val alignMode = new Param[String](this, "alignMode", "An optional field, indicates how " +
    "we align different variables to the same time range, which is required by the model. {Inner, Outer}")

  def setAlignMode(v: String): this.type = {
    if (Set("inner", "outer").contains(v.toLowerCase)) {
      set(alignMode, v.toLowerCase.capitalize)
    } else {
      throw new IllegalArgumentException("alignMode must be either `inner` or `outer`.")
    }
  }

  def getAlignMode: String = $(alignMode)

  val fillNAMethod = new Param[String](this, "fillNAMethod", "An optional field, indicates how missing " +
    "values will be filled. Cannot be set to NotFill when alignMode is Outer. {Previous, Subsequent," +
    " Linear, Zero, Fixed}")

  def setFillNAMethod(v: String): this.type = {
    if (Set("previous", "subsequent", "linear", "zero", "fixed").contains(v.toLowerCase)) {
      set(fillNAMethod, v.toLowerCase.capitalize)
    } else {
      throw new IllegalArgumentException("fillNAMethod must be one of {Previous, Subsequent, Linear, Zero, Fixed}.")
    }
  }

  def getFillNAMethod: String = $(fillNAMethod)

  val paddingValue = new IntParam(this, "paddingValue", "An optional field, only used" +
    " when fillNAMethod is set to Fixed.")

  def setPaddingValue(v: Int): this.type = set(paddingValue, v)

  def getPaddingValue: Int = $(paddingValue)

  val displayName = new Param[String](this, "displayName", "An optional field," +
    " the display name of the model.")

  def setDisplayName(v: String): this.type = set(displayName, v)

  def getDisplayName: String = $(displayName)

  setDefault(slidingWindow -> 300, alignMode -> "Outer", fillNAMethod -> "Linear")

  protected def prepareEntity(dataSource: String): Option[AbstractHttpEntity] = {
    Some(new StringEntity(
      MAERequest(
        dataSource,
        dataSchema,
        getStartTime,
        getEndTime,
        get(slidingWindow).orElse(getDefault(slidingWindow)),
        Option(AlignPolicy(
          get(alignMode).orElse(getDefault(alignMode)),
          get(fillNAMethod).orElse(getDefault(fillNAMethod)),
          get(paddingValue))),
        get(displayName)
      ).toJson.compactPrint, ContentType.APPLICATION_JSON))
  }

  protected def prepareUrl: String = getUrl

  //noinspection ScalaStyle
  override protected def timeoutResult(key: Option[String], client: CloseableHttpClient,
                                       queryUrl: URI, maxTries: Int): HTTPResponseData = {
    // if no response after max retries, return the response containing modelId directly
    queryForResult(key, client, queryUrl).get
  }

  override def fit(dataset: Dataset[_]): SimpleDetectMultivariateAnomaly = {
    logFit({
      val response = submitDatasetAndJob(dataset)

      val modelInfo = response("modelInfo").asJsObject.fields
      val modelId = response("modelId").convertTo[String]

      if (modelInfo("status").asInstanceOf[JsString].value.toLowerCase() == "failed") {
        val errors = modelInfo("errors").toJson.compactPrint
        throw new RuntimeException(s"Caught errors during fitting: $errors")
      }

      MADUtils.CreatedModels += modelId

      new SimpleDetectMultivariateAnomaly()
        .setSubscriptionKey(getSubscriptionKey)
        .setLocation(getUrl.split("/".toCharArray)(2).split(".".toCharArray).head)
        .setModelId(modelId)
        .setIntermediateSaveDir(getIntermediateSaveDir)
        .setDiagnosticsInfo(modelInfo("diagnosticsInfo").convertTo[DiagnosticsInfo])
    }, dataset.columns.length)
  }

  override def copy(extra: ParamMap): SimpleFitMultivariateAnomaly = defaultCopy(extra)

  override def transformSchema(schema: StructType): StructType = {
    schema.add(getErrorCol, DMAError.schema)
      .add(getOutputCol, DMAResponse.schema)
      .add("isAnomaly", BooleanType)
  }

}
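
// Hedged usage sketch (not part of the original source): training a multivariate
// model end to end. The key env var, region, storage directory, and column names
// are hypothetical placeholders; fit trains remotely and returns a configured
// SimpleDetectMultivariateAnomaly.
private object FitExample {
  def train(df: DataFrame): SimpleDetectMultivariateAnomaly =
    new SimpleFitMultivariateAnomaly()
      .setSubscriptionKey(sys.env("ANOMALY_KEY"))  // hypothetical env var
      .setLocation("westus2")                      // hypothetical region
      .setInputCols(Array("sensor1", "sensor2"))   // hypothetical variable columns
      .setTimestampCol("timestamp")
      .setStartTime("2021-01-01T00:00:00Z")
      .setEndTime("2021-01-02T00:00:00Z")
      .setIntermediateSaveDir("wasbs://container@account.blob.core.windows.net/mvad") // hypothetical
      .fit(df)
}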

trait DetectMAParams extends Params {
  val modelId = new Param[String](this, "modelId", "Format - uuid. Model identifier.")

  def setModelId(v: String): this.type = set(modelId, v)

  def getModelId: String = $(modelId)

  val diagnosticsInfo = new CognitiveServiceStructParam[DiagnosticsInfo](this, "diagnosticsInfo",
    "diagnosticsInfo for training a multivariate anomaly detection model")

  def setDiagnosticsInfo(v: DiagnosticsInfo): this.type = set(diagnosticsInfo, v)

  def getDiagnosticsInfo: DiagnosticsInfo = $(diagnosticsInfo)

  val topContributorCount = new IntParam(this, "topContributorCount", "An optional field, a number" +
    " N from 1 to 30 specifying how many of the top contributing variables to include in the anomaly" +
    " results. For example, if the model has 100 variables but you only care about the five variables" +
    " that contribute most to each detected anomaly, set this field to 5. The default" +
    " number is 10.", isValid = ParamValidators.inRange(1.0, 30.0))

  def setTopContributorCount(v: Int): this.type = set(topContributorCount, v)

  def getTopContributorCount: Int = $(topContributorCount)

  setDefault(topContributorCount -> 10)
}

@deprecated("The Anomaly Detection Service will be shutting down in 2026," +
  " please use IsolationForest for anomaly detection", "v1.0.0")
object SimpleDetectMultivariateAnomaly extends ComplexParamsReadable[SimpleDetectMultivariateAnomaly] with Serializable

@deprecated("The Anomaly Detection Service will be shutting down in 2026," +
  " please use IsolationForest for anomaly detection", "v1.0.0")
class SimpleDetectMultivariateAnomaly(override val uid: String) extends Model[SimpleDetectMultivariateAnomaly]
  with MADBase with HasHandler with DetectMAParams {
  logClass(FeatureNames.AiServices.Anomaly)

  def this() = this(Identifiable.randomUID("SimpleDetectMultivariateAnomaly"))

  def urlPath: String = "anomalydetector/v1.1/multivariate/models/"

  protected def prepareEntity(dataSource: String): Option[AbstractHttpEntity] = {
    Some(new StringEntity(
      DMARequest(dataSource, getStartTime, getEndTime, Some(getTopContributorCount))
        .toJson.compactPrint))
  }

  protected def prepareUrl: String = getUrl + s"$getModelId:detect-batch"

  override def handlingFunc(client: CloseableHttpClient,
                            request: HTTPRequestData): HTTPResponseData = getHandler(client, request)

  //scalastyle:off method.length
  override def transform(dataset: Dataset[_]): DataFrame =
    logTransform[DataFrame] ({

      // check model status first
      MADUtils.checkModelStatus(getUrl, getModelId, getSubscriptionKey)

      val spark = dataset.sparkSession
      val responseJson = submitDatasetAndJob(dataset)

      // need to fetch batch inference result using resultId
      val response = MADUtils.madGetBatchDetectionResults(
        getUrl.split("/".toCharArray).dropRight(1).mkString("/") + "/detect-batch/",
        responseJson("resultId").convertTo[String],
        getSubscriptionKey,
        maxTries = getMaxPollingRetries,
        pollingDelay = getPollingDelay)
      val fields = response.parseJson.asJsObject.fields
      val summary = fields("summary").convertTo[DMASummary]
      if (summary.status.toLowerCase() == "failed") {
        val errors = summary.errors.get.toJson.compactPrint
        throw new RuntimeException(s"Failure during inference: $errors")
      }

      val resultDF = spark.createDataFrame(fields("results").convertTo[Seq[DMAResult]])

      val sortedDF = resultDF
        .sort(col("timestamp").asc)
        .withColumnRenamed("timestamp", "resultTimestamp")

      val simplifiedDF = if (sortedDF.columns.contains("value")) {
        sortedDF.withColumn("isAnomaly", col("value.isAnomaly"))
          .withColumnRenamed("value", getOutputCol)
      } else {
        sortedDF.withColumn(getOutputCol, lit(None))
          .withColumn("isAnomaly", lit(None))
      }

      val finalDF = if (simplifiedDF.columns.contains("errors")) {
        simplifiedDF.withColumnRenamed("errors", getErrorCol)
      } else {
        simplifiedDF.withColumn(getErrorCol, lit(None))
      }

      val df = dataset.toDF()
      df.join(finalDF, df(getTimestampCol) === finalDF("resultTimestamp"), "left")
        .drop("resultTimestamp")
        .sort(col(getTimestampCol).asc)
    }, dataset.columns.length)
  //scalastyle:on method.length

  override def copy(extra: ParamMap): SimpleDetectMultivariateAnomaly = defaultCopy(extra)

  override def transformSchema(schema: StructType): StructType = {
    schema.add(getErrorCol, DMAError.schema)
      .add(getOutputCol, DMAResponse.schema)
      .add("isAnomaly", BooleanType)
  }

}
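
// Hedged usage sketch (not part of the original source): batch inference against a
// previously trained model referenced by its modelId. All literal values below are
// hypothetical placeholders.
private object BatchDetectExample {
  def detect(df: DataFrame, modelId: String): DataFrame =
    new SimpleDetectMultivariateAnomaly()
      .setSubscriptionKey(sys.env("ANOMALY_KEY"))
      .setLocation("westus2")
      .setModelId(modelId)
      .setInputCols(Array("sensor1", "sensor2"))
      .setTimestampCol("timestamp")
      .setStartTime("2021-01-02T00:00:00Z")
      .setEndTime("2021-01-03T00:00:00Z")
      .setIntermediateSaveDir("wasbs://container@account.blob.core.windows.net/mvad")
      .transform(df)
}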

@deprecated("The Anomaly Detection Service will be shutting down in 2026," +
  " please use IsolationForest for anomaly detection", "v1.0.0")
object DetectLastMultivariateAnomaly extends ComplexParamsReadable[DetectLastMultivariateAnomaly] with Serializable

@deprecated("The Anomaly Detection Service will be shutting down in 2026," +
  " please use IsolationForest for anomaly detection", "v1.0.0")
class DetectLastMultivariateAnomaly(override val uid: String) extends CognitiveServicesBase(uid)
  with HasInternalJsonOutputParser with TimeConverter with HasTimestampCol
  with HasSetLocation with HasCognitiveServiceInput with HasBatchSize
  with ComplexParamsWritable with Wrappable
  with HasErrorCol with SynapseMLLogging with DetectMAParams {
  logClass(FeatureNames.AiServices.Anomaly)

  def this() = this(Identifiable.randomUID("DetectLastMultivariateAnomaly"))

  def urlPath: String = "anomalydetector/v1.1/multivariate/models/"

  val inputVariablesCols = new StringArrayParam(this, "inputVariablesCols",
    "The names of the input variables columns")

  def setInputVariablesCols(value: Array[String]): this.type = set(inputVariablesCols, value)

  def getInputVariablesCols: Array[String] = $(inputVariablesCols)

  override def setBatchSize(value: Int): this.type = {
    logWarning("batchSize should be equal to 1 sliding window.")
    set(batchSize, value)
  }

  setDefault(batchSize -> 300)

  override protected def prepareUrl: Row => String = {
    row: Row => getUrl + s"$getModelId:detect-last"
  }

  protected def prepareEntity: Row => Option[AbstractHttpEntity] = { row =>
    val timestamps = row.getAs[Seq[String]](s"${getTimestampCol}_list")
    val variables = getInputVariablesCols.map(
      variable => Variable(timestamps, row.getAs[Seq[Double]](s"${variable}_list"), variable))
    Some(new StringEntity(
      DLMARequest(variables, getTopContributorCount).toJson.compactPrint
    ))
  }

  // scalastyle:off null
  override def transform(dataset: Dataset[_]): DataFrame = {
    logTransform[DataFrame]({
      // check model status first
      MADUtils.checkModelStatus(getUrl, getModelId, getSubscriptionKey)

      val convertTimeFormatUdf = UDFUtils.oldUdf(
        { value: String => convertTimeFormat("Timestamp column", value) },
        StringType
      )
      val formattedDF = dataset.withColumn(getTimestampCol, convertTimeFormatUdf(col(getTimestampCol)))
        .sort(col(getTimestampCol).asc)
        .withColumn("group", lit(1))

      val window = Window.partitionBy("group").rowsBetween(-getBatchSize, 0)
      var collectedDF = formattedDF
      var columnNames = Array(getTimestampCol) ++ getInputVariablesCols
      for (columnName <- columnNames) {
        collectedDF = collectedDF.withColumn(s"${columnName}_list", collect_list(columnName).over(window))
      }
      collectedDF = collectedDF.drop("group")
      columnNames = columnNames.map(name => s"${name}_list")

      val testDF = getInternalTransformer(collectedDF.schema).transform(collectedDF)

      testDF
        .withColumn("isAnomaly", when(col(getOutputCol).isNotNull,
          col(s"$getOutputCol.results.value.isAnomaly")(0)).otherwise(null))
        .withColumn("DetectDataTimestamp", when(col(getOutputCol).isNotNull,
          col(s"$getOutputCol.results.timestamp")(0)).otherwise(null))
        .drop(columnNames: _*)

    }, dataset.columns.length)
  }
  // scalastyle:on null

  override protected def getInternalTransformer(schema: StructType): PipelineModel = {
    val dynamicParamColName = DatasetExtensions.findUnusedColumnName("dynamic", schema)
    val lambda = Lambda(_.withColumn(dynamicParamColName, struct(
      s"${getTimestampCol}_list", getInputVariablesCols.map(name => s"${name}_list"): _*)))

    val stages = Array(
      lambda,
      new SimpleHTTPTransformer()
        .setInputCol(dynamicParamColName)
        .setOutputCol(getOutputCol)
        .setInputParser(getInternalInputParser(schema))
        .setOutputParser(getInternalOutputParser(schema))
        .setHandler(handlingFunc _)
        .setConcurrency(getConcurrency)
        .setConcurrentTimeout(get(concurrentTimeout))
        .setErrorCol(getErrorCol),
      new DropColumns().setCol(dynamicParamColName)
    )

    NamespaceInjections.pipelineModel(stages)

  }

  override def transformSchema(schema: StructType): StructType = {
    schema.add(getErrorCol, DMAError.schema)
      .add(getOutputCol, DLMAResponse.schema)
      .add("isAnomaly", BooleanType)
  }

  override def responseDataType: DataType = DLMAResponse.schema

}
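
// Hedged usage sketch (not part of the original source): streaming-style last-point
// detection, scoring each row against the preceding batchSize rows. All literal
// values below are hypothetical placeholders.
private object DetectLastExample {
  def detectLast(df: DataFrame, modelId: String): DataFrame =
    new DetectLastMultivariateAnomaly()
      .setSubscriptionKey(sys.env("ANOMALY_KEY"))
      .setLocation("westus2")
      .setModelId(modelId)
      .setInputVariablesCols(Array("sensor1", "sensor2"))
      .setTimestampCol("timestamp")
      .setBatchSize(300) // should match the sliding window used when the model was trained
      .transform(df)
}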
