All downloads are free. Search and download functionality uses the official Maven repository. Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.


import java.util.concurrent.TimeoutException

import org.apache.http.client.methods.{HttpEntityEnclosingRequestBase, HttpGet, HttpRequestBase}
import org.apache.http.entity.{AbstractHttpEntity, ByteArrayEntity, StringEntity}
import org.apache.http.impl.client.CloseableHttpClient
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types._
import spray.json.DefaultJsonProtocol._
import spray.json._
import org.apache.http.entity.ContentType
import scala.concurrent.blocking

import scala.language.existentials

/** Adds an optional `imageUrl` service parameter.
  *
  * The URL can be given once for all rows (scalar setter) or taken per row from a column
  * (`*Col` setter). It is optional because an image may instead be supplied as raw bytes.
  */
trait HasImageUrl extends HasServiceParams {
  /** URL of the image to send to the service. */
  val imageUrl = new ServiceParam[String](
    this, "imageUrl", "the url of the image to use", isRequired = false)

  /** Returns the single image URL set via [[setImageUrl]]. */
  def getImageUrl: String = getScalarParam(imageUrl)

  /** Uses the URL `v` for every row. */
  def setImageUrl(v: String): this.type = setScalarParam(imageUrl, v)

  /** Returns the column name set via [[setImageUrlCol]]. */
  def getImageUrlCol: String = getVectorParam(imageUrl)

  /** Reads the image URL for each row from the column named `v`. */
  def setImageUrlCol(v: String): this.type = setVectorParam(imageUrl, v)
}

/** Adds an optional `imageBytes` service parameter holding the raw image payload.
  *
  * The bytes can be given once for all rows (scalar setter) or taken per row from a column
  * (`*Col` setter). It is optional because an image may instead be referenced by URL.
  */
trait HasImageBytes extends HasServiceParams {

  /** Raw bytes of the image to send to the service. */
  val imageBytes = new ServiceParam[Array[Byte]](
    this, "imageBytes", "bytestream of the image to use", isRequired = false)

  /** Returns the single byte payload set via [[setImageBytes]]. */
  def getImageBytes: Array[Byte] = getScalarParam(imageBytes)

  /** Uses the payload `v` for every row. */
  def setImageBytes(v: Array[Byte]): this.type = setScalarParam(imageBytes, v)

  /** Returns the column name set via [[setImageBytesCol]]. */
  def getImageBytesCol: String = getVectorParam(imageBytes)

  /** Reads the image bytes for each row from the column named `v`. */
  def setImageBytesCol(v: String): this.type = setVectorParam(imageBytes, v)
}

/** Combines URL- and byte-based image input for a cognitive-service transformer.
  *
  * For each row it builds the HTTP entity from whichever of `imageUrl` / `imageBytes`
  * the row supplies.
  *
  * NOTE(review): this copy is truncated. Closing braces and several statements (the bodies of
  * the `if`/`case` branches below) are missing, so it does not compile as shown. Restore it
  * from upstream rather than guessing.
  */
trait HasImageInput extends HasImageUrl
  with HasImageBytes with HasCognitiveServiceInput {

  /** Builds the request body for a row.
    *
    * If a URL is present, it is sent as the JSON object `{"url": ...}` with content type
    * `application/json`. Otherwise the raw bytes are sent as `application/octet-stream`.
    * The final `orElse` only throws when neither value is present, because `Option.orElse`
    * evaluates its argument lazily.
    */
  override protected def prepareEntity: Row => Option[AbstractHttpEntity] = {
    r =>
      getValueOpt(r, imageUrl)
        .map(url => new StringEntity(Map("url" -> url).toJson.compactPrint, ContentType.APPLICATION_JSON))
        .orElse(getValueOpt(r, imageBytes)
          .map(bytes => new ByteArrayEntity(bytes, ContentType.APPLICATION_OCTET_STREAM))
        ).orElse(throw new IllegalArgumentException(
        "Payload needs to contain image bytes or url. This code should not run"))
  // NOTE(review): the closing braces of `prepareEntity` are missing in this copy.

  /** Turns a row into an HTTP request (or, presumably, `None` for rows `shouldSkip` rejects).
    *
    * The request URI comes from `prepareUrl`. The subscription-key header is set when a key
    * is available for the row.
    *
    * NOTE(review): `URI` is not in the visible imports (this presumably needs `java.net.URI`).
    * The skip branch, the entity-attaching `case` and the fall-through `case` are empty here,
    * and `rowToEntity` is never used in the visible lines. TODO: restore from upstream.
    */
  override protected def inputFunc(schema: StructType): Row => Option[HttpRequestBase] = {
    val rowToUrl = prepareUrl
    // The explicit `;` below stops Scala from parsing the following `{ row: Row => ... }`
    // block as an argument applied to `prepareEntity`.
    val rowToEntity = prepareEntity;
    { row: Row =>
      if (shouldSkip(row)) {
      } else {
        val req = prepareMethod()
        req.setURI(new URI(rowToUrl(row)))
        getValueOpt(row, subscriptionKey).foreach(
          req.setHeader(subscriptionKeyHeaderName, _))
        req match {
          case er: HttpEntityEnclosingRequestBase =>
          case _ =>


  /** Decides whether a row should not be sent to the service, based on whether exactly one
    * of `imageUrl` / `imageBytes` is present.
    *
    * NOTE(review): both variables are named `has*Input` but hold `emptyParamData(...)`. That name
    * suggests they are true when the value is ABSENT, so the names look inverted (TODO confirm).
    * The XOR is unaffected by negating both operands. Both branch bodies are missing in this copy.
    */
  override protected def shouldSkip(row: Row): Boolean = {
    val hasUrlInput = emptyParamData(row, imageUrl)
    val hasBytesInput = emptyParamData(row, imageBytes)

    if (hasUrlInput ^ hasBytesInput) {
    } else {

/** Adds the `detectOrientation` service parameter, which is flagged as a URL parameter.
  *
  * The value can be fixed for all rows or taken per row from a column.
  */
trait HasDetectOrientation extends HasServiceParams {
  /** Whether the service should detect image orientation before processing. */
  val detectOrientation = new ServiceParam[Boolean](
    this, "detectOrientation", "whether to detect image orientation prior to processing", isURLParam = true)

  /** Returns the value set via [[setDetectOrientation]]. */
  def getDetectOrientation: Boolean = getScalarParam(detectOrientation)

  /** Uses `v` for every row. */
  def setDetectOrientation(v: Boolean): this.type = setScalarParam(detectOrientation, v)

  /** Returns the column name set via [[setDetectOrientationCol]]. */
  def getDetectOrientationCol: String = getVectorParam(detectOrientation)

  /** Reads the value for each row from the column named `v`. */
  def setDetectOrientationCol(v: String): this.type = setVectorParam(detectOrientation, v)
}

/** Adds the integer `width` service parameter, which is flagged as a URL parameter.
  *
  * The value can be fixed for all rows or taken per row from a column.
  */
trait HasWidth extends HasServiceParams {
  /** Desired width of the produced image. */
  val width = new ServiceParam[Int](
    this, "width", "the desired width of the image", isURLParam = true)

  /** Returns the value set via [[setWidth]]. */
  def getWidth: Int = getScalarParam(width)

  /** Uses `v` as the width for every row. */
  def setWidth(v: Int): this.type = setScalarParam(width, v)

  /** Returns the column name set via [[setWidthCol]]. */
  def getWidthCol: String = getVectorParam(width)

  /** Reads the width for each row from the column named `v`. */
  def setWidthCol(v: String): this.type = setVectorParam(width, v)
}

/** Adds the integer `height` service parameter, which is flagged as a URL parameter.
  *
  * The value can be fixed for all rows or taken per row from a column.
  */
trait HasHeight extends HasServiceParams {
  /** Desired height of the produced image. */
  val height = new ServiceParam[Int](
    this, "height", "the desired height of the image", isURLParam = true)

  /** Returns the value set via [[setHeight]]. */
  def getHeight: Int = getScalarParam(height)

  /** Uses `v` as the height for every row. */
  def setHeight(v: Int): this.type = setScalarParam(height, v)

  /** Returns the column name set via [[setHeightCol]]. */
  def getHeightCol: String = getVectorParam(height)

  /** Reads the height for each row from the column named `v`. */
  def setHeightCol(v: String): this.type = setVectorParam(height, v)
}

/** Adds the boolean `smartCropping` service parameter, which is flagged as a URL parameter.
  *
  * The value can be fixed for all rows or taken per row from a column.
  */
trait HasSmartCropping extends HasServiceParams {
  /** Whether the service should crop the image intelligently. */
  val smartCropping = new ServiceParam[Boolean](
    this, "smartCropping", "whether to intelligently crop the image", isURLParam = true)

  /** Returns the value set via [[setSmartCropping]]. */
  def getSmartCropping: Boolean = getScalarParam(smartCropping)

  /** Uses `v` for every row. */
  def setSmartCropping(v: Boolean): this.type = setScalarParam(smartCropping, v)

  /** Returns the column name set via [[setSmartCroppingCol]]. */
  def getSmartCroppingCol: String = getVectorParam(smartCropping)

  /** Reads the value for each row from the column named `v`. */
  def setSmartCroppingCol(v: String): this.type = setVectorParam(smartCropping, v)
}

/** Reader for [[OCR]] instances. It also offers a helper that flattens OCR responses to text. */
object OCR extends ComplexParamsReadable[OCR] with Serializable {

  /** Builds a `UDFTransformer` that turns the OCR response in `inputCol` into text in
    * `outputCol`. A `null` row becomes `None` (via `Option(r)`). The nested
    * `mkString(" ")` calls suggest that words, lines and regions are joined with spaces.
    *
    * NOTE(review): this copy is truncated. `inputCol`/`outputCol` are unused in the visible
    * lines, and the transformer configuration, the innermost mapping and the closing braces
    * are missing, so this does not compile as shown.
    */
  def flatten(inputCol: String, outputCol: String): UDFTransformer = {
    val fromRow = OCRResponse.makeFromRowConverter
    new UDFTransformer()
        { r: Row =>
          Option(r).map(fromRow).map { resp =>
      " ")
              ).mkString(" ")
            ).mkString(" ")

/** Cognitive-service transformer that runs OCR on an image given by URL or bytes.
  * It accepts optional language and orientation-detection parameters.
  *
  * NOTE(review): this copy is truncated. `setLocation` has no body and the class has no
  * closing brace, so it does not compile as shown.
  */
class OCR(override val uid: String) extends CognitiveServicesBase(uid)
  with HasLanguage with HasImageInput with HasDetectOrientation
  with HasCognitiveServiceInput with HasInternalJsonOutputParser {

  def this() = this(Identifiable.randomUID("OCR"))

  // NOTE(review): body missing. It presumably derives the service URL from the region `v`;
  // restore it from upstream.
  def setLocation(v: String): this.type =

  /** Sets the default value of the `language` service parameter. */
  def setDefaultLanguage(v: String): this.type = setDefaultValue(language, v)

  /** Schema of the parsed OCR response. */
  override def responseDataType: DataType = OCRResponse.schema

/** Reader for [[RecognizeText]] instances, plus a response-flattening helper.
  *
  * NOTE(review): unlike `object OCR`, this companion is not declared `with Serializable`.
  * Check whether that matters for closures created in `flatten`.
  */
object RecognizeText extends ComplexParamsReadable[RecognizeText] {
  /** Builds a `UDFTransformer` that turns the text-recognition response in `inputCol` into
    * text in `outputCol`.
    *
    * NOTE(review): almost the whole body is missing in this copy. Only the converter setup and
    * a trailing `" "))` fragment survive, so this does not compile as shown.
    */
  def flatten(inputCol: String, outputCol: String): UDFTransformer = {
    val fromRow = RTResponse.makeFromRowConverter
    new UDFTransformer()
        { r: Row =>
  " "))

/** Submits images to a text-recognition endpoint.
  *
  * When the service answers `202`, the transformer polls the `Operation-Location` URL until the
  * operation reports a terminal status, or until the polling budget runs out.
  *
  * NOTE(review): this copy is truncated. Closing braces and several statements are missing
  * (see the inline notes), so it does not compile as shown.
  */
class RecognizeText(override val uid: String)
  extends CognitiveServicesBaseWithoutHandler(uid)
    with HasImageInput with HasCognitiveServiceInput
    with HasInternalJsonOutputParser {

  def this() = this(Identifiable.randomUID("RecognizeText"))

  /** Backoff schedule passed to the retrying HTTP helpers (presumably in milliseconds; confirm). */
  val backoffs: IntArrayParam = new IntArrayParam(
    this, "backoffs", "array of backoffs to use in the handler")

  /** @group getParam */
  def getBackoffs: Array[Int] = $(backoffs)

  /** @group setParam */
  def setBackoffs(value: Array[Int]): this.type = set(backoffs, value)

  /** Polling budget. `handlingFunc` iterates over `0 to maxPollingRetries`, i.e. at most
    * maxPollingRetries + 1 polls.
    */
  val maxPollingRetries: IntParam = new IntParam(
    this, "maxPollingRetries", "number of times to poll")

  /** @group getParam */
  def getMaxPollingRetries: Int = $(maxPollingRetries)

  /** @group setParam */
  def setMaxPollingRetries(value: Int): this.type = set(maxPollingRetries, value)

  /** Milliseconds to sleep between polls that return no result yet. */
  val pollingDelay: IntParam = new IntParam(
    this, "pollingDelay", "number of milliseconds to wait between polling")

  /** @group getParam */
  def getPollingDelay: Int = $(pollingDelay)

  /** @group setParam */
  def setPollingDelay(value: Int): this.type = set(pollingDelay, value)

  // Defaults: backoffs of 100/500/1000, up to 1000 polling retries, 300 ms between polls.
  //scalastyle:off magic.number
  setDefault(backoffs -> Array(100, 500, 1000), maxPollingRetries -> 1000, pollingDelay -> 300)
  //scalastyle:on magic.number

  /** Recognition mode, either "Printed" or "Handwritten"; it is sent as a URL parameter.
    *
    * NOTE(review): the validator has lost its scrutinee (`=> match {`). It also checks the
    * `Right` side, while the other validators in this file check `Left`, which there holds the
    * value type. That suggests this validates the column name rather than the value; confirm.
    */
  val mode = new ServiceParam[String](this, "mode",
    "If this parameter is set to 'Printed', " +
      "printed text recognition is performed. If 'Handwritten' is specified," +
      " handwriting recognition is performed",
    { spd: ServiceParamData[String] => match {
        case Left(_) => true
        case Right(s) => Set("Printed", "Handwritten")(s)
    }, isURLParam = true)

  def getMode: String = getScalarParam(mode)

  def setMode(v: String): this.type = setScalarParam(mode, v)

  def getModeCol: String = getVectorParam(mode)

  def setModeCol(v: String): this.type = setVectorParam(mode, v)

  // NOTE(review): body missing; presumably derives the service URL from the region `v`.
  def setLocation(v: String): this.type =

  /** Issues one GET for the asynchronous operation and classifies its status.
    *
    * Returns `Some(resp)` for "Succeeded"/"Failed" and `None` for "NotStarted"/"Running".
    * Any other status throws a `RuntimeException`.
    *
    * NOTE(review): `location` is never used in the visible lines, so the GET has no URI set.
    * `status` is the raw body `String`, so `status.flatMap` iterates characters and the string
    * patterns would not type-check; the JSON parsing/field extraction step looks missing.
    * `IOUtils` is not imported, and `resp.entity.get` throws if there is no entity.
    */
  private def queryForResult(key: Option[String],
                             client: CloseableHttpClient,
                             location: URI): Option[HTTPResponseData] = {
    val get = new HttpGet()
    key.foreach(get.setHeader("Ocp-Apim-Subscription-Key", _))
    val resp = convertAndClose(sendWithRetries(client, get, getBackoffs))
    val status = IOUtils.toString(resp.entity.get.content, "UTF-8")
    status.flatMap {
      case "Succeeded" | "Failed" => Some(resp)
      case "NotStarted" | "Running" => None
      case s => throw new RuntimeException(s"Received unknown status code: $s")

  /** Sends the request with backoff. On `202 Accepted` it polls the operation location until a
    * terminal result arrives, and otherwise throws `TimeoutException` after the polling budget.
    *
    * NOTE(review): the `filter`/`find` predicates have lost their left operand (presumably a
    * header-name accessor). `.head` throws if the 202 response has no Operation-Location header.
    * `0 to maxTries` allows maxTries + 1 polls, while the timeout message says `maxTries`.
    * The branch bodies are missing in this copy.
    */
  override protected def handlingFunc(client: CloseableHttpClient,
                                      request: HTTPRequestData): HTTPResponseData = {
    val response = HandlingUtils.advanced(getBackoffs: _*)(client, request)
    if (response.statusLine.statusCode == 202) {
      val location = new URI(response.headers.filter( == "Operation-Location").head.value)
      val maxTries = getMaxPollingRetries
      val key = request.headers.find( == "Ocp-Apim-Subscription-Key").map(_.value)
      // `Iterator.flatMap` is lazy, so polling (and sleeping) only happens as `it.hasNext`
      // advances. It stops at the first attempt that yields a result.
      val it = (0 to maxTries).toIterator.flatMap { _ =>
        queryForResult(key, client, location).orElse({
          blocking {Thread.sleep(getPollingDelay.toLong)}
      if (it.hasNext) {
      } else {
        throw new TimeoutException(
          s"Querying for results did not complete within $maxTries tries")
    } else {

  /** Schema of the parsed text-recognition response. */
  override protected def responseDataType: DataType = RTResponse.schema

/** Reader for [[GenerateThumbnails]] instances. */
object GenerateThumbnails extends ComplexParamsReadable[GenerateThumbnails] with Serializable

/** Cognitive-service transformer that requests a thumbnail of an image.
  * It takes width, height and smart-cropping parameters; its response type is `BinaryType`.
  *
  * NOTE(review): this copy is truncated. The output-parser UDF body, `setLocation` and the
  * closing braces are missing, so it does not compile as shown.
  */
class GenerateThumbnails(override val uid: String)
  extends CognitiveServicesBase(uid) with HasImageInput
    with HasWidth with HasHeight with HasSmartCropping
    with HasInternalJsonOutputParser with HasCognitiveServiceInput {

  def this() = this(Identifiable.randomUID("GenerateThumbnails"))

  /** Overrides the JSON parser with a custom one.
    * Given `responseDataType = BinaryType`, it presumably returns the raw response bytes; confirm.
    */
  override protected def getInternalOutputParser(schema: StructType): HTTPOutputParser = {
    new CustomOutputParser().setUDF({ r: HTTPResponseData => })

  override def responseDataType: DataType = BinaryType

  // NOTE(review): body missing; presumably derives the service URL from the region `v`.
  def setLocation(v: String): this.type =


/** Reader for [[AnalyzeImage]] instances. */
object AnalyzeImage extends ComplexParamsReadable[AnalyzeImage]

/** Cognitive-service transformer that runs general image analysis.
  * Visual features, domain details and response language are sent as URL parameters.
  *
  * NOTE(review): this copy is truncated. The validators have lost their scrutinee
  * (`=> match {`) and closing tokens, `setLocation` has no body, and the class has no closing
  * brace, so it does not compile as shown.
  */
class AnalyzeImage(override val uid: String)
  extends CognitiveServicesBase(uid) with HasImageInput
    with HasInternalJsonOutputParser with HasCognitiveServiceInput {

  /** Feature types to request. A scalar value must be a subset of the listed names.
    * `toValueString` joins the selection with commas for the URL.
    */
  val visualFeatures = new ServiceParam[Seq[String]](
    this, "visualFeatures", "what visual feature types to return",
    { spd: ServiceParamData[Seq[String]] => match {
        case Left(seq) => seq.forall(Set(
          "Categories", "Tags", "Description", "Faces", "ImageType", "Color", "Adult"
        case _ => true
    isURLParam = true,
    toValueString = { seq => seq.mkString(",") }

  def getVisualFeatures: Seq[String] = getScalarParam(visualFeatures)

  def getVisualFeaturesCol: String = getVectorParam(visualFeatures)

  def setVisualFeatures(v: Seq[String]): this.type = setScalarParam(visualFeatures, v)

  def setVisualFeaturesCol(v: String): this.type = setVectorParam(visualFeatures, v)

  /** Domain-specific details to request. A scalar value must be a subset of
    * {"Celebrities", "Landmarks"}; it is comma-joined for the URL.
    *
    * NOTE(review): the description string duplicates the `visualFeatures` description and looks
    * like a copy-paste slip. Consider something like "domain-specific details to return".
    */
  val details = new ServiceParam[Seq[String]](
    this, "details", "what visual feature types to return",
    { spd: ServiceParamData[Seq[String]] => match {
        case Left(seq) => seq.forall(Set("Celebrities", "Landmarks"))
        case _ => true
    isURLParam = true,
    toValueString = { seq => seq.mkString(",") }

  def getDetails: Seq[String] = getScalarParam(details)

  def getDetailsCol: String = getVectorParam(details)

  def setDetails(v: Seq[String]): this.type = setScalarParam(details, v)

  def setDetailsCol(v: String): this.type = setVectorParam(details, v)

  /** Language of the response. The default is set to "en" below. */
  val language = new ServiceParam[String](
    this, "language", "the language of the response (en if none given)", isURLParam = true

  def getLanguage: String = getScalarParam(language)

  def getLanguageCol: String = getVectorParam(language)

  def setLanguage(v: String): this.type = setScalarParam(language, v)

  def setLanguageCol(v: String): this.type = setVectorParam(language, v)

  def setDefaultLanguage(v: String): this.type = setDefaultValue(language, v)

  setDefault(language, ServiceParamData(None, Some("en")))

  def this() = this(Identifiable.randomUID("AnalyzeImage"))

  /** Schema of the parsed analysis response. */
  override def responseDataType: DataType = AIResponse.schema

  // NOTE(review): body missing; presumably derives the service URL from the region `v`.
  def setLocation(v: String): this.type =


/** Reader for [[RecognizeDomainSpecificContent]] instances, plus a celebrity-extraction helper. */
object RecognizeDomainSpecificContent
  extends ComplexParamsReadable[RecognizeDomainSpecificContent] with Serializable {

  /** Builds a `UDFTransformer` that extracts, from the response in `inputCol`, the name of the
    * celebrity with the highest confidence. Rows with no celebrities map to `None`.
    *
    * NOTE(review): this copy is truncated. `inputCol`/`outputCol` are unused in the visible
    * lines, and the remaining transformer configuration and closing braces are missing.
    */
  def getMostProbableCeleb(inputCol: String, outputCol: String): UDFTransformer = {
    val fromRow = DSIRResponse.makeFromRowConverter
    new UDFTransformer()
        { r: Row =>
          Option(r).map { r =>
            fromRow(r).result.celebrities.flatMap {
              case Seq() => None
              case celebs => Some(celebs.maxBy(_.confidence).name)

/** Cognitive-service transformer that runs a domain-specific model (e.g. celebrities or
  * landmarks) against an image.
  *
  * NOTE(review): this copy is truncated. `setLocation` has no body and the class has no closing
  * brace. Unlike other params in this file, `model` has no getters.
  */
class RecognizeDomainSpecificContent(override val uid: String)
  extends CognitiveServicesBase(uid) with HasImageInput
    with HasServiceParams with HasCognitiveServiceInput
    with HasInternalJsonOutputParser {

  def this() = this(Identifiable.randomUID("RecognizeDomainSpecificContent"))

  /** Name of the domain model; it becomes a path segment in the request URL (see `prepareUrl`). */
  val model = new ServiceParam[String](this, "model",
    "the domain specific model: celebrities, landmarks")

  def setModel(v: String): this.type = setScalarParam(model, v)

  def setModelCol(v: String): this.type = setVectorParam(model, v)

  /** Schema of the parsed domain-specific response. */
  override def responseDataType: DataType = DSIRResponse.schema

  // NOTE(review): body missing; presumably derives the service URL from the region `v`.
  def setLocation(v: String): this.type =

  /** Appends `/models/<model>/analyze` to the base URL for each row.
    * NOTE(review): the model value is interpolated without validation or URL-encoding.
    */
  override protected def prepareUrl: Row => String = { r => getUrl + s"/models/${getValue(r, model)}/analyze" }


/** Reader for [[TagImage]] instances. */
object TagImage extends ComplexParamsReadable[TagImage]

/** Cognitive-service transformer that tags an image, with an optional output language.
  *
  * NOTE(review): this copy is truncated. `setLocation` has no body, `validateLanguage` has lost
  * its scrutinee and closing braces, and the class has no closing brace.
  */
class TagImage(override val uid: String)
  extends CognitiveServicesBase(uid) with HasImageInput
    with HasCognitiveServiceInput with HasInternalJsonOutputParser {

  def this() = this(Identifiable.randomUID("TagImage"))

  // NOTE(review): body missing; presumably derives the service URL from the region `v`.
  def setLocation(v: String): this.type =

  /** Schema of the parsed tagging response. */
  override def responseDataType: DataType = TagImagesResponse.schema

  /** Accepts a scalar language only if it is one of en, es, ja, pt, zh; other cases pass. */
  private def validateLanguage(spd: ServiceParamData[String]): Boolean = { {
      case Left(lang) => Set("en", "es", "ja", "pt", "zh")(lang)
      case _ => true

  /** Output language, sent as an optional URL parameter. The default is set to "en" below. */
  val language = new ServiceParam[String](this, "language",
    "The desired language for output generation.",
    isRequired = false, isURLParam = true, isValid = validateLanguage)

  def setLanguage(v: String): this.type = setScalarParam(language, v)

  def setLanguageCol(v: String): this.type = setVectorParam(language, v)

  setDefault(language -> ServiceParamData(None, Some("en")))

/** Reader for [[DescribeImage]] instances. */
object DescribeImage extends ComplexParamsReadable[DescribeImage]

/** Cognitive-service transformer that generates natural-language descriptions of an image.
  *
  * NOTE(review): this copy is truncated. `setLocation` has no body, the `maxCandidates` and
  * `language` constructors are missing their closing tokens, and the class has no closing brace.
  */
class DescribeImage(override val uid: String)
  extends CognitiveServicesBase(uid) with HasCognitiveServiceInput
    with HasImageInput with HasInternalJsonOutputParser {

  def this() = this(Identifiable.randomUID("DescribeImage"))

  /** Schema of the parsed description response. */
  override def responseDataType: DataType = DescribeImageResponse.schema

  // NOTE(review): body missing; presumably derives the service URL from the region `v`.
  def setLocation(v: String): this.type =

  /** Maximum number of candidate descriptions, sent as a URL parameter. The default is set to 1 below. */
  val maxCandidates = new ServiceParam[Int](this, "maxCandidates", "Maximum candidate descriptions to return",
    isURLParam = true

  def setMaxCandidates(v: Int): this.type = setScalarParam(maxCandidates, v)

  def setMaxCandidatesCol(v: String): this.type = setVectorParam(maxCandidates, v)

  setDefault(maxCandidates, ServiceParamData(None, Some(1)))

  /** Description language; a scalar value must be one of en, ja, pt, zh.
    * It is sent as a URL parameter, and the default is set to "en" below.
    */
  val language = new ServiceParam[String](this, "language", "Language of image description",
    isValid = { spd: ServiceParamData[String] => {
        case Left(lang) => Set("en", "ja", "pt", "zh")(lang)
        case _ => true
    isURLParam = true

  def setLanguage(v: String): this.type = setScalarParam(language, v)

  def setLanguageCol(v: String): this.type = setVectorParam(language, v)

  setDefault(language, ServiceParamData(None, Some("en")))


© 2015 - 2024 Weber Informatics LLC | Privacy Policy