All Downloads are FREE. The search and download functionality uses the official Maven repository.

com.microsoft.azure.synapse.ml.services.speech.SpeechToTextSDK.scala Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.services.speech

import com.microsoft.azure.synapse.ml.build.BuildInfo
import com.microsoft.azure.synapse.ml.services._
import com.microsoft.azure.synapse.ml.services.speech.SpeechFormat._
import com.microsoft.azure.synapse.ml.core.contracts.HasOutputCol
import com.microsoft.azure.synapse.ml.core.schema.{DatasetExtensions, SparkBindings}
import com.microsoft.azure.synapse.ml.core.utils.OsUtils
import com.microsoft.azure.synapse.ml.io.http.HasURL
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.param.ServiceParam
import com.microsoft.cognitiveservices.speech._
import com.microsoft.cognitiveservices.speech.audio._
import com.microsoft.cognitiveservices.speech.transcription.{Conversation, ConversationTranscriber,
  ConversationTranscriptionEventArgs, Participant}
import com.microsoft.cognitiveservices.speech.util.EventHandler
import org.apache.commons.io.FilenameUtils
import org.apache.hadoop.fs.Path
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.injections.SConf
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer}
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import spray.json._

import java.io.{BufferedInputStream, ByteArrayInputStream, Closeable, InputStream}
import java.lang.ProcessBuilder.Redirect
import java.net.{URI, URL}
import java.util.UUID
import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}
import scala.concurrent.{ExecutionContext, Future, blocking}
import scala.language.existentials

/** Companion object of [[SpeechToTextSDK]]; mixes in `ComplexParamsReadable` for that class. */
object SpeechToTextSDK extends ComplexParamsReadable[SpeechToTextSDK]

//scalastyle:off no.finalize
/** Iterator view over a blocking queue of optional elements.
  *
  * Elements are taken from `lbq` one at a time; a `None` entry marks the end of the
  * sequence. `hasNext` blocks until the next entry (or the `None` terminator) arrives.
  *
  * `onClose` is evaluated at most once, no matter how many times the iterator is
  * closed (explicitly, on exhaustion, or from the finalizer).
  *
  * @param lbq     queue that a producer fills with `Some(element)` entries followed by `None`
  * @param onClose clean-up action run the first time the iterator is closed
  */
private[ml] class BlockingQueueIterator[T](lbq: LinkedBlockingQueue[Option[T]],
                                           onClose: => Unit) extends Iterator[T] with Closeable {
  var nextVar: Option[T] = None
  var isDone = false
  var takeAnother = true

  // close() can be reached from the consuming thread (hasNext/close) and from the
  // finalizer thread, so the "already closed" state is tracked atomically.
  private val closed = new java.util.concurrent.atomic.AtomicBoolean(false)

  /** Blocks until an entry is available, then reports whether it is a real element. */
  override def hasNext: Boolean = {
    if (takeAnother) {
      nextVar = lbq.take()
      takeAnother = false
      isDone = nextVar.isEmpty
    }
    if (isDone) {
      close()
    }
    !isDone
  }

  /** Runs `onClose` the first time it is called; later calls are no-ops. */
  override def close(): Unit = {
    if (closed.compareAndSet(false, true)) {
      onClose
    }
  }

  // This covers the case where someone starts pulling rows but doesn't finish, like in a df.show
  override def finalize(): Unit = {
    try {
      close()
    } finally {
      super.finalize()
    }
  }

  /** Returns the next element, fetching it first when `hasNext` was not called beforehand.
    *
    * @throws NoSuchElementException when the terminator has already been reached
    */
  override def next(): T = {
    if (hasNext) {
      // Only re-arm the fetch after a real element was consumed; re-arming after the
      // terminator would make the following hasNext block forever on an empty queue.
      takeAnother = true
      nextVar.get
    } else {
      throw new NoSuchElementException("next on exhausted BlockingQueueIterator")
    }
  }
}
//scalastyle:on no.finalize

abstract class SpeechSDKBase extends Transformer
  with HasSetLocation with HasServiceParams
  with HasOutputCol with HasURL with HasSubscriptionKey with ComplexParamsWritable with SynapseMLLogging
  with HasSetLinkedServiceUsingLocation {

  type ResponseType <: SharedSpeechFields

  val responseTypeBinding: SparkBindings[ResponseType]

  /** Name of the input column that carries the audio. Per the param description the column
    * holds either byte arrays or strings representing file URIs.
    */
  val audioDataCol = new Param[String](this, "audioDataCol",
    "Column holding audio data, must be either ByteArrays or Strings representing file URIs"
  )

  /** Sets [[audioDataCol]]. */
  def setAudioDataCol(v: String): this.type = set(audioDataCol, v)

  /** @return the current value of [[audioDataCol]] */
  def getAudioDataCol: String = $(audioDataCol)

  /** Flag controlling whether the audio is also recorded to a file; the param description
    * restricts its use to m3u8 streams. Defaults to `false` (see `setDefault` below).
    */
  val recordAudioData = new BooleanParam(this, "recordAudioData",
    "Whether to record audio data to a file location, for use only with m3u8 streams"
  )

  /** Endpoint id for custom speech models; it has no default in this block. */
  val endpointId = new Param[String](this, "endpointId",
    "endpoint for custom speech models"
  )

  /** Sets [[endpointId]]. */
  def setEndpointId(v: String): this.type = set(endpointId, v)

  /** @return the current value of [[endpointId]] */
  def getEndpointId: String = $(endpointId)

  /** Sets [[recordAudioData]]. */
  def setRecordAudioData(v: Boolean): this.type = set(recordAudioData, v)

  /** @return the current value of [[recordAudioData]] */
  def getRecordAudioData: Boolean = $(recordAudioData)

  // Recording is opt-in.
  setDefault(recordAudioData -> false)

  /** Name of the column holding the file names that recorded audio is written to when
    * [[recordAudioData]] is enabled.
    */
  val recordedFileNameCol = new Param[String](this, "recordedFileNameCol",
    "Column holding file names to write audio data to if ``recordAudioData'' is set to true"
  )

  /** Sets [[recordedFileNameCol]]. */
  def setRecordedFileNameCol(v: String): this.type = set(recordedFileNameCol, v)

  /** @return the current value of [[recordedFileNameCol]] */
  def getRecordedFileNameCol: String = $(recordedFileNameCol)

  /** File type of the audio (wav, ogg or mp3 per the description); it can be given as a
    * single value or read per row from a column.
    */
  val fileType = new ServiceParam[String](
    this, "fileType", "The file type of the sound files, supported types: wav, ogg, mp3")

  /** Sets one file type for every row. */
  def setFileType(v: String): this.type = setScalarParam(fileType, v)

  /** Sets the name of the column that supplies the file type per row. */
  def setFileTypeCol(v: String): this.type = setVectorParam(fileType, v)

  /** Whether word-level timestamps are requested from the service. */
  val wordLevelTimestamps = new ServiceParam[Boolean](
    this, "wordLevelTimestamps", "Whether to request timestamps for each individual word")

  /** Sets one word-level-timestamp flag for every row. */
  def setWordLevelTimestamps(v: Boolean): this.type = setScalarParam(wordLevelTimestamps, v)

  /** Sets the name of the column that supplies the word-level-timestamp flag per row. */
  def setWordLevelTimestampsCol(v: String): this.type = setVectorParam(wordLevelTimestamps, v)


  /** Spoken language to recognize. Declared as a required URL parameter; the validator
    * passed here accepts any value.
    */
  val language = new ServiceParam[String](this, "language",
    """
      |Identifies the spoken language that is being recognized.
    """.stripMargin.replace("\n", " ").replace("\r", " "),
    { _ => true },
    isRequired = true,
    isURLParam = true
  )

  /** JSON text describing the list of conversation participants (email, language, user). */
  val participantsJson = new ServiceParam[String](
    this, "participantsJson",
    "a json representation of a list of conversation participants (email, language, user)")

  /** Sets one participants JSON document for every row. */
  def setParticipantsJson(v: String): this.type = setScalarParam(participantsJson, v)

  /** Convenience setter: wraps each tuple in a `TranscriptionParticipant`, serializes the
    * sequence to compact JSON and stores it through [[setParticipantsJson]].
    */
  def setParticipants(v: Seq[(String, String, String)]): this.type =
    setParticipantsJson(v.map(t => TranscriptionParticipant(t._1, t._2, t._3)).toJson.compactPrint)

  /** Sets the name of the column that supplies the participants JSON per row. */
  def setParticipantsJsonCol(v: String): this.type = setVectorParam(participantsJson, v)

  /** Sets one language for every row. */
  def setLanguage(v: String): this.type = setScalarParam(language, v)

  /** Sets the name of the column that supplies the language per row. */
  def setLanguageCol(v: String): this.type = setVectorParam(language, v)

  /** Additional command-line arguments handed to ffmpeg; empty by default. */
  val extraFfmpegArgs = new StringArrayParam(this, "extraFfmpegArgs",
    "extra arguments for ffmpeg output decoding")

  /** Sets [[extraFfmpegArgs]]. */
  def setExtraFfmpegArgs(v: Array[String]): this.type = set(extraFfmpegArgs, v)

  /** @return the current value of [[extraFfmpegArgs]] */
  def getExtraFfmpegArgs: Array[String] = $(extraFfmpegArgs)

  setDefault(extraFfmpegArgs -> Array())

  /** When true (the default) intermediate results are returned immediately; when false
    * they are grouped in a sequence.
    */
  val streamIntermediateResults = new BooleanParam(this, "streamIntermediateResults",
    "Whether or not to immediately return intermediate results, or group in a sequence"
  )

  /** Sets [[streamIntermediateResults]]. */
  def setStreamIntermediateResults(v: Boolean): this.type = set(streamIntermediateResults, v)

  /** @return the current value of [[streamIntermediateResults]] */
  def getStreamIntermediateResults: Boolean = $(streamIntermediateResults)

  setDefault(streamIntermediateResults -> true)

  /** Result format requested from the service. Declared as an optional URL parameter; the
    * validator passed here accepts any value.
    */
  val format = new ServiceParam[String](this, "format",
    """
      |Specifies the result format. Accepted values are simple and detailed. Default is simple.
    """.stripMargin.replace("\n", " ").replace("\r", " "),
    { _ => true },
    isRequired = false,
    isURLParam = true
  )

  /** Sets one result format for every row. */
  def setFormat(v: String): this.type = setScalarParam(format, v)

  /** Sets the name of the column that supplies the result format per row. */
  def setFormatCol(v: String): this.type = setVectorParam(format, v)

  /** Profanity handling mode. Declared as an optional URL parameter; the validator passed
    * here accepts any value.
    */
  val profanity = new ServiceParam[String](this, "profanity",
    """
      |Specifies how to handle profanity in recognition results.
      |Accepted values are masked, which replaces profanity with asterisks,
      |removed, which remove all profanity from the result, or raw,
      |which includes the profanity in the result. The default setting is masked.
    """.stripMargin.replace("\n", " ").replace("\r", " "),
    { _ => true },
    isRequired = false,
    isURLParam = true
  )

  /** Sets one profanity mode for every row. */
  def setProfanity(v: String): this.type = setScalarParam(profanity, v)

  /** Sets the name of the column that supplies the profanity mode per row. */
  def setProfanityCol(v: String): this.type = setVectorParam(profanity, v)

  /** URL path constant for this stage. */
  def urlPath: String = "/sts/v1.0/issuetoken"

  // Scalar defaults for the service params declared in this class.
  setDefault(language -> Left("en-us"))
  setDefault(profanity -> Left("Masked"))
  setDefault(format -> Left("Simple"))
  setDefault(wordLevelTimestamps -> Left(false))

  /** Adapts a two-argument Scala function to the Speech SDK `EventHandler` interface.
    *
    * @param f callback receiving the event sender and the event arguments
    * @return a handler whose `onEvent` forwards both arguments to `f`
    */
  protected def makeEventHandler[T](f: (Any, T) => Unit): EventHandler[T] = {
    val handler = new EventHandler[T] {
      def onEvent(sender: Any, eventArgs: T): Unit = f.apply(sender, eventArgs)
    }
    handler
  }

  /** Maps a file type name to the SDK audio stream format.
    *
    * The name is compared case-insensitively against wav, mp3 and ogg. An unrecognized
    * name falls back to `default` (with a warning) when one is supplied.
    *
    * @param fileType file type name to resolve
    * @param default  optional fallback file type name, tried once
    * @return the audio stream format for the resolved file type
    * @throws IllegalArgumentException when neither `fileType` nor the fallback is recognized
    */
  protected def getAudioFormat(fileType: String, default: Option[String]): AudioStreamFormat = {
    val codec = fileType.toLowerCase
    if (codec == "wav") {
      AudioStreamFormat.getDefaultInputFormat
    } else if (codec == "mp3") {
      AudioStreamFormat.getCompressedFormat(AudioStreamContainerFormat.MP3)
    } else if (codec == "ogg") {
      AudioStreamFormat.getCompressedFormat(AudioStreamContainerFormat.OGG_OPUS)
    } else if (default.isDefined) {
      log.warn(s"Could not identify codec $fileType using default ${default.get} instead")
      // The fallback is resolved without a further fallback, so it is tried only once.
      getAudioFormat(default.get, None)
    } else {
      throw new IllegalArgumentException(s"Could not identify codec $fileType")
    }
  }

  //scalastyle:off cyclomatic.complexity
  /** Opens the audio referenced by one row and reports its format.
    *
    * For URI audio the source depends on the URI:
    *  - http(s) URIs whose path ends in m3u8 or m4a are decoded through an ffmpeg child
    *    process whose mp3 output is returned;
    *  - other http(s) URIs are read directly over a URL connection;
    *  - anything else is opened through the Hadoop file system of the path.
    * For binary audio the bytes of the column are wrapped in an in-memory stream.
    *
    * @param bconf           broadcast Hadoop configuration used to resolve file systems
    * @param isUriAudio      true when the audio column holds URIs, false when it holds bytes
    * @param row             the input row
    * @param dynamicParamRow the per-row parameter struct, used to look up the file type
    * @return the opened stream and the audio format name (file extension, "mp3" for
    *         ffmpeg output, or the configured file type / "wav" for binary audio)
    */
  private def getStream(bconf: Broadcast[SConf], //scalastyle:ignore method.length
                        isUriAudio: Boolean,
                        row: Row,
                        dynamicParamRow: Row): (InputStream, String) = {
    if (isUriAudio) { //scalastyle:ignore cyclomatic.complexity
      val uri = row.getAs[String](getAudioDataCol)

      // POSIX single-quote escaping: everything inside '...' is literal, and an embedded
      // single quote is written as '\''. Needed because the recording pipeline below is
      // handed to `sh -c` as one string, and URLs routinely contain `&` or `?`.
      def shellQuote(arg: String): String = "'" + arg.replace("'", "'\\''") + "'"

      val ffmpegCommand: Seq[String] = {
        val body = Seq("ffmpeg", "-y",
          "-reconnect", "1", "-reconnect_streamed", "1", "-reconnect_delay_max", "2000",
          "-i", uri) ++ getExtraFfmpegArgs ++ Seq("-acodec", "mp3", "-ab", "257k", "-f", "mp3", "pipe:1")

        if (getRecordAudioData && OsUtils.IsWindows) {
          // On Windows a second ffmpeg output writes the recording.
          val fn = row.getAs[String](getRecordedFileNameCol)
          body ++ Seq("-acodec", "mp3", "-ab", "257k", "-f", "mp3", fn)
        } else if (getRecordAudioData && !OsUtils.IsWindows) {
          // Elsewhere the mp3 output is duplicated into the recording file with `tee`.
          val fn = row.getAs[String](getRecordedFileNameCol)
          val pipeline = (body.map(shellQuote) ++ Seq("|", "tee", shellQuote(fn))).mkString(" ")
          Seq("/bin/sh", "-c", pipeline)
        } else {
          body
        }
      }

      // URI.getPath is null for opaque URIs; treat that as "no extension".
      val extension = Option(new URI(uri).getPath)
        .map(path => FilenameUtils.getExtension(path))
        .getOrElse("")
        .toLowerCase()

      if (Set("m3u8", "m4a")(extension) && uri.startsWith("http")) {
        // Read the optional "-t <seconds>" limit BEFORE launching ffmpeg so that a missing
        // or non-integer value can never leave an orphaned child process behind.
        val extraArgs = getExtraFfmpegArgs
        val flagIndex = extraArgs.indexOf("-t")
        val timeLimit: Option[Int] =
          if (flagIndex >= 0 && flagIndex + 1 < extraArgs.length) {
            scala.util.Try(extraArgs(flagIndex + 1).toInt).toOption
          } else {
            None
          }
        if (flagIndex >= 0 && timeLimit.isEmpty) {
          log.warn("Could not read an integer number of seconds after -t, ffmpeg will run without a watchdog")
        }

        val proc = new ProcessBuilder()
          .redirectError(Redirect.INHERIT)
          .redirectInput(Redirect.INHERIT)
          .command(ffmpegCommand: _*)
          .start()
        val stream = proc.getInputStream

        // Watchdog: give ffmpeg its time limit plus a grace period, then stop it.
        timeLimit.foreach { limit =>
          Future {
            blocking {
              proc.waitFor(limit.toLong + 10, TimeUnit.SECONDS)
            }
            if (proc.isAlive) {
              proc.destroy()
              proc.waitFor(10, TimeUnit.SECONDS) //scalastyle:ignore magic.number
              proc.destroyForcibly()
              proc.waitFor()
              log.warn("Had to forcibly stop ffmpeg")
            }
          }(ExecutionContext.global)
        }

        (stream, "mp3")
      } else if (uri.startsWith("http")) {
        val conn = new URL(uri).openConnection
        conn.setConnectTimeout(5000) //scalastyle:ignore magic.number
        conn.setReadTimeout(5000) //scalastyle:ignore magic.number
        conn.connect()
        (new BufferedInputStream(conn.getInputStream), extension)
      } else {
        val path = new Path(uri)
        val fs = path.getFileSystem(bconf.value.value2)
        (fs.open(path), extension)
      }
    } else {
      val bytes = row.getAs[Array[Byte]](getAudioDataCol)
      (new ByteArrayInputStream(bytes), getValueOpt(dynamicParamRow, fileType).getOrElse("wav"))
    }
  }
  //scalastyle:on cyclomatic.complexity

  /** Transcribes one audio stream; implemented by the concrete subclasses.
    *
    * The parameter notes below describe what `transformAudioRows` passes in.
    *
    * @param stream              the opened audio stream for the row
    * @param audioFormat         format name reported alongside the stream
    * @param uri                 service endpoint, built from the configured URL
    * @param speechKey           the row's subscription key value
    * @param profanity           the row's profanity setting
    * @param wordLevelTimestamps the row's word-level-timestamp flag
    * @param language            the row's language setting
    * @param format              the row's result format setting
    * @param defaultAudioFormat  the row's `fileType` value, if any
    * @param participants        participants parsed from `participantsJson` (empty when unset)
    * @return an iterator over the responses produced for this stream
    */
  def inputStreamToText(stream: InputStream,
                        audioFormat: String,
                        uri: URI,
                        speechKey: String,
                        profanity: String,
                        wordLevelTimestamps: Boolean,
                        language: String,
                        format: String,
                        defaultAudioFormat: Option[String],
                        participants: Seq[TranscriptionParticipant]
                       ): Iterator[ResponseType]

  /** Transcribes every row of a partition and appends the result as a trailing field.
    *
    * Rows for which `shouldSkip` is true get a null output. Otherwise the audio is opened
    * with `getStream` and passed to `inputStreamToText`. Depending on
    * `streamIntermediateResults` each response becomes its own output row, or all
    * responses of one input row are collected into a single sequence-valued output.
    *
    * @param dynamicParamColName name of the struct column holding the per-row parameters
    * @param toRow               converter from a response object to a Spark row
    * @param bconf               broadcast Hadoop configuration forwarded to `getStream`
    * @param isUriAudio          whether the audio column holds URIs rather than bytes
    * @param rows                the partition's input rows
    * @return the output rows, each an input row extended by one field
    */
  protected def transformAudioRows(dynamicParamColName: String,
                                   toRow: ResponseType => Row,
                                   bconf: Broadcast[SConf],
                                   isUriAudio: Boolean)(rows: Iterator[Row]): Iterator[Row] = {
    rows.flatMap { inputRow =>
      // Builds the output row: the input fields followed by the given output value.
      def withOutput(value: Any): Row = Row.fromSeq(inputRow.toSeq :+ value)

      if (!shouldSkip(inputRow)) {
        val paramRow = inputRow.getAs[Row](dynamicParamColName)
        val (audioStream, audioFileFormat) = getStream(bconf, isUriAudio, inputRow, paramRow)
        // Arguments are resolved in the same order as the parameter list.
        val serviceUri = new URI(getUrl)
        val key = getValue(paramRow, subscriptionKey)
        val profanityValue = getValue(paramRow, profanity)
        val timestampsValue = getValue(paramRow, wordLevelTimestamps)
        val languageValue = getValue(paramRow, language)
        val formatValue = getValue(paramRow, format)
        val fallbackFileType = getValueOpt(paramRow, fileType)
        val participantList = getValueOpt(paramRow, participantsJson)
          .getOrElse("[]")
          .parseJson.convertTo[Seq[TranscriptionParticipant]]

        val responses = inputStreamToText(
          audioStream, audioFileFormat, serviceUri, key, profanityValue, timestampsValue,
          languageValue, formatValue, fallbackFileType, participantList)

        if (getStreamIntermediateResults) {
          responses.map(response => withOutput(toRow(response)))
        } else {
          Seq(withOutput(responses.map(response => toRow(response)).toSeq))
        }
      } else {
        Seq(withOutput(null)) //scalastyle:ignore null
      }
    }

  }

  /** Wraps an input stream in an SDK pull stream matching its audio format.
    *
    * The stream wrapper is chosen from the codec that `getAudioFormat` actually resolves:
    * `WavStream` for wav, `CompressedStream` otherwise. Resolution is case-insensitive and
    * falls back to `defaultAudioFormat` when `audioFormat` is not recognized, so the
    * wrapper and the declared stream format always describe the same codec.
    *
    * @param stream             the raw audio stream
    * @param audioFormat        format name of the audio (for example a file extension)
    * @param defaultAudioFormat optional fallback format name
    * @return a pull stream over `stream` with the resolved audio format
    * @throws IllegalArgumentException when no format can be resolved (raised by `getAudioFormat`)
    */
  def getPullStream(stream: InputStream,
                    audioFormat: String,
                    defaultAudioFormat: Option[String]): PullAudioInputStream = {
    // Throws if neither name is recognized, so a codec is always found below.
    val af = getAudioFormat(audioFormat, defaultAudioFormat)

    // Mirror getAudioFormat's resolution: the first recognized name among the requested
    // format and its fallback, compared case-insensitively.
    val knownCodecs = Set("wav", "mp3", "ogg")
    val resolvedCodec = (audioFormat +: defaultAudioFormat.toSeq).map(_.toLowerCase).find(knownCodecs)

    if (resolvedCodec.contains("wav")) {
      AudioInputStream.createPullStream(new WavStream(stream), af)
    } else {
      AudioInputStream.createPullStream(new CompressedStream(stream), af)
    }
  }

  /** Builds the SDK speech configuration for one recognition request.
    *
    * @param uri                 service endpoint
    * @param speechKey           subscription key used with the endpoint
    * @param language            speech recognition language
    * @param profanity           value for the profanity option property
    * @param wordLevelTimestamps whether to request word-level timestamps
    * @param format              value for the output format option property
    * @return the configured `SpeechConfig`; the custom endpoint id is applied when set
    */
  def getSpeechConfig(uri: URI,
                      speechKey: String,
                      language: String,
                      profanity: String,
                      wordLevelTimestamps: Boolean,
                      format: String): SpeechConfig = {
    val config: SpeechConfig = SpeechConfig.fromEndpoint(uri, speechKey)
    assert(config != null)
    // The endpoint id is optional; only apply it when the param has been set.
    get(endpointId) match {
      case Some(id) => config.setEndpointId(id)
      case None => ()
    }
    config.setProperty(PropertyId.SpeechServiceResponse_ProfanityOption, profanity)
    config.setSpeechRecognitionLanguage(language)
    if (wordLevelTimestamps) {
      config.requestWordLevelTimestamps()
    }
    config.setProperty(PropertyId.SpeechServiceResponse_OutputFormatOption, format) //scalastyle:ignore token
    config
  }

  /** Transcribes the audio column of `dataset` and appends the output column.
    *
    * Column-valued (vector) params are packed into a temporary struct column so that
    * each row carries its own parameter values; the temporary column is dropped again
    * from the result. The output column holds one response per row, or an array of
    * responses when `streamIntermediateResults` is false.
    *
    * @param dataset input data containing the audio column
    * @return the input columns plus the output column
    * @throws IllegalArgumentException when the audio column is neither String nor Binary
    */
  override def transform(dataset: Dataset[_]): DataFrame = {
    logTransform[DataFrame]({
      val df = dataset.toDF
      val schema = dataset.schema
      val dynamicParamColName = DatasetExtensions.findUnusedColumnName("dynamic", dataset)

      // Every column referenced by a vector param must exist in the input.
      val badColumns = getVectorParamMap.values.toSet.diff(schema.fieldNames.toSet)
      assert(badColumns.isEmpty,
        s"Could not find dynamic columns: $badColumns in columns: ${schema.fieldNames.toSet}")

      // struct() needs at least one column, hence the placeholder when nothing is dynamic.
      val vectorCols = getVectorParamMap.values.toList.map(col)
      val dynamicParamCols =
        if (vectorCols.isEmpty) Seq(lit(false).alias("placeholder")) else vectorCols
      val enrichedDf = df.withColumn(dynamicParamColName, struct(dynamicParamCols: _*))

      val responseSchema = responseTypeBinding.schema
      val addedSchema: DataType =
        if (getStreamIntermediateResults) responseSchema else ArrayType(responseSchema)
      val enc = RowEncoder(enrichedDf.schema.add(getOutputCol, addedSchema))

      val sc = df.sparkSession.sparkContext
      val bConf = sc.broadcast(new SConf(sc.hadoopConfiguration))

      val isUriAudio = df.schema(getAudioDataCol).dataType match {
        case BinaryType => false
        case StringType => true
        case t => throw new IllegalArgumentException(s"AudioDataCol must be String or Binary Type, got: $t")
      }

      val toRow = responseTypeBinding.makeToRowConverter
      val transcribed = enrichedDf.mapPartitions(
        transformAudioRows(dynamicParamColName, toRow, bConf, isUriAudio)
      )(enc)
      transcribed.drop(dynamicParamColName)
    }, dataset.columns.length)
  }

  /** Copies this stage via `defaultCopy`, passing along the `extra` params. */
  override def copy(extra: ParamMap): this.type = defaultCopy(extra)

  /** Validates the audio column type and appends the output column to the schema.
    *
    * @param schema the input schema
    * @return `schema` plus the output column, typed as a single response or as an array
    *         of responses depending on `streamIntermediateResults`
    * @throws IllegalArgumentException when the audio column is neither String nor Binary
    */
  override def transformSchema(schema: StructType): StructType = {
    schema(getAudioDataCol).dataType match {
      case StringType | BinaryType => ()
      case t => throw new IllegalArgumentException(s"AudioDataCol must be String or Binary Type, got: $t")
    }
    val outputType: DataType =
      if (getStreamIntermediateResults) {
        responseTypeBinding.schema
      } else {
        ArrayType(responseTypeBinding.schema)
      }
    schema.add(getOutputCol, outputType)
  }
}

/** Speech-to-text transformer backed by the Speech SDK's continuous recognition.
  *
  * Responses are `SpeechResponse` values parsed from the JSON result of every
  * recognized utterance.
  */
class SpeechToTextSDK(override val uid: String) extends SpeechSDKBase with SynapseMLLogging {
  logClass(FeatureNames.AiServices.Speech)

  override type ResponseType = SpeechResponse

  override val responseTypeBinding: SparkBindings[SpeechResponse] = SpeechResponse

  def this() = this(Identifiable.randomUID("SpeechToTextSDK"))

  /** Starts continuous recognition on `stream` and exposes the results as an iterator.
    *
    * Recognized utterances are pushed onto a queue by the SDK callbacks and pulled by the
    * returned iterator; a `None` entry (session stopped, or the optional `-t` watchdog
    * firing) ends the iteration. SDK resources are released exactly once, by whichever of
    * the session-stopped callback, the watchdog or the iterator's close gets there first.
    *
    * @return text transcription of the audio
    */
  def inputStreamToText(stream: InputStream,
                        audioFormat: String,
                        uri: URI,
                        speechKey: String,
                        profanity: String,
                        wordLevelTimestamps: Boolean,
                        language: String,
                        format: String,
                        defaultAudioFormat: Option[String],
                        participants: Seq[TranscriptionParticipant]
                       ): Iterator[SpeechResponse] = {

    val speechConfig = getSpeechConfig(uri, speechKey, language, profanity, wordLevelTimestamps, format)
    val pullStream = getPullStream(stream, audioFormat, defaultAudioFormat)
    val audioConfig = AudioConfig.fromStreamInput(pullStream)
    val recognizer = new SpeechRecognizer(speechConfig, audioConfig)
    val connection = Connection.fromRecognizer(recognizer)
    connection.setMessageProperty("speech.config", "application",
      s"""{"name":"synapseml", "version": "${BuildInfo.version}"}""")
    val queue = new LinkedBlockingQueue[Option[String]]()

    // cleanUp is reachable from several threads (SDK callback, watchdog, consumer,
    // finalizer); this flag makes sure the resources are stopped and closed only once.
    val cleanedUp = new java.util.concurrent.atomic.AtomicBoolean(false)

    def recognizedHandler(s: Any, e: SpeechRecognitionEventArgs): Unit = {
      if (e.getResult.getReason eq ResultReason.RecognizedSpeech) {
        queue.put(Some(e.getResult.getProperties.getProperty(PropertyId.SpeechServiceResponse_JsonResult)))
      }
    }

    // NOTE(review): `recognizer` and `connection` are never closed here (as before);
    // confirm whether the SDK requires an explicit close for them.
    def cleanUp(): Unit = {
      if (cleanedUp.compareAndSet(false, true)) {
        try {
          recognizer.stopContinuousRecognitionAsync().get()
        } finally {
          // Release the remaining resources even when stopping fails.
          Option(pullStream).foreach(_.close())
          Option(speechConfig).foreach(_.close())
          Option(audioConfig).foreach(_.close())
        }
      }
    }

    def sessionStoppedHandler(s: Any, e: SessionEventArgs): Unit = {
      queue.put(None)
      cleanUp()
    }

    recognizer.recognized.addEventListener(makeEventHandler(recognizedHandler))
    recognizer.sessionStopped.addEventListener(makeEventHandler(sessionStoppedHandler))
    recognizer.startContinuousRecognitionAsync.get

    // Optional watchdog driven by ffmpeg's "-t <seconds>" argument. A missing or
    // non-integer value disables the watchdog instead of failing after recognition
    // has already started.
    val extraArgs = getExtraFfmpegArgs
    val flagIndex = extraArgs.indexOf("-t")
    val timeLimit: Option[Int] =
      if (flagIndex >= 0 && flagIndex + 1 < extraArgs.length) {
        scala.util.Try(extraArgs(flagIndex + 1).toInt).toOption
      } else {
        None
      }
    timeLimit.foreach { limit =>
      Future {
        blocking {
          // Long arithmetic: the millisecond count must not overflow Int.
          Thread.sleep((limit.toLong + 20L) * 1000L)
        }
        queue.put(None)
        cleanUp()
      }(ExecutionContext.global)
    }

    new BlockingQueueIterator[String](queue, cleanUp()).map { jsonString =>
      jsonString.parseJson.convertTo[SpeechResponse]
    }
  }
}

/** Companion object of [[ConversationTranscription]]; mixes in `ComplexParamsReadable` for that class. */
object ConversationTranscription extends ComplexParamsReadable[ConversationTranscription]

/** Conversation transcription transformer backed by the Speech SDK's
  * `ConversationTranscriber`.
  *
  * Responses are `TranscriptionResponse` values parsed from the JSON result of every
  * transcribed utterance.
  */
class ConversationTranscription(override val uid: String) extends SpeechSDKBase with SynapseMLLogging {
  logClass(FeatureNames.AiServices.Speech)

  override type ResponseType = TranscriptionResponse

  override val responseTypeBinding: SparkBindings[TranscriptionResponse] = TranscriptionResponse

  def this() = this(Identifiable.randomUID("ConversationTranscription"))

  /** Creates a conversation with the given participants, starts transcribing `stream`
    * and exposes the results as an iterator.
    *
    * Transcribed utterances are pushed onto a queue by the SDK callbacks and pulled by the
    * returned iterator; a `None` entry (session stopped, or the optional `-t` watchdog
    * firing) ends the iteration. SDK resources are released exactly once, by whichever of
    * the session-stopped callback, the watchdog or the iterator's close gets there first.
    *
    * @return text transcription of the audio
    */
  def inputStreamToText(stream: InputStream, //scalastyle:ignore method.length
                        audioFormat: String,
                        uri: URI,
                        speechKey: String,
                        profanity: String,
                        wordLevelTimestamps: Boolean,
                        language: String,
                        format: String,
                        defaultAudioFormat: Option[String],
                        participants: Seq[TranscriptionParticipant]
                       ): Iterator[TranscriptionResponse] = {
    val speechConfig = getSpeechConfig(uri, speechKey, language, profanity, wordLevelTimestamps, format)
    speechConfig.setProperty("ConversationTranscriptionInRoomAndOnline", "true")

    speechConfig.setServiceProperty("transcriptionMode",
      "RealTimeAndAsync", ServicePropertyChannel.UriQueryParameter)

    // Each call transcribes into its own conversation, identified by a random id.
    val guid = UUID.randomUUID().toString
    val conversation = Conversation.createConversationAsync(speechConfig, guid).get()
    participants.foreach(p =>
      conversation.addParticipantAsync(
        Participant.from(p.name, p.language, p.signature)
      ).get()
    )

    val pullStream = getPullStream(stream, audioFormat, defaultAudioFormat)
    val audioConfig = AudioConfig.fromStreamInput(pullStream)
    audioConfig.setProperty("f0f5debc-f8c9-4892-ac4b-90a7ab359fd2", "true")

    val transcriber = new ConversationTranscriber(audioConfig)

    conversation.getProperties.setProperty("DifferentiateGuestSpeakers", "true")

    transcriber.joinConversationAsync(conversation).get()
    val connection = Connection.fromRecognizer(transcriber)

    connection.setMessageProperty("speech.config", "application",
      s"""{"name":"synapseml", "version": "${BuildInfo.version}"}""")
    val queue = new LinkedBlockingQueue[Option[String]]()

    // cleanUp is reachable from several threads (SDK callback, watchdog, consumer,
    // finalizer); this flag makes sure the resources are stopped and closed only once.
    val cleanedUp = new java.util.concurrent.atomic.AtomicBoolean(false)

    // NOTE(review): `transcriber` and `connection` are never closed here (as before);
    // confirm whether the SDK requires an explicit close for them.
    def cleanUp(): Unit = {
      if (cleanedUp.compareAndSet(false, true)) {
        try {
          transcriber.stopTranscribingAsync().get()
        } finally {
          // Release the remaining resources even when stopping fails.
          Option(conversation).foreach(_.close())
          Option(pullStream).foreach(_.close())
          Option(speechConfig).foreach(_.close())
          Option(audioConfig).foreach(_.close())
        }
      }
    }

    def recognizedHandler(s: Any, e: ConversationTranscriptionEventArgs): Unit = {
      if (e.getResult.getReason eq ResultReason.RecognizedSpeech) {
        queue.put(Some(e.getResult.getProperties.getProperty(PropertyId.SpeechServiceResponse_JsonResult)))
      }
    }

    def sessionStoppedHandler(s: Any, e: SessionEventArgs): Unit = {
      queue.put(None)
      cleanUp()
    }

    transcriber.transcribed.addEventListener(makeEventHandler(recognizedHandler))
    transcriber.sessionStopped.addEventListener(makeEventHandler(sessionStoppedHandler))
    transcriber.startTranscribingAsync().get

    // Optional watchdog driven by ffmpeg's "-t <seconds>" argument. A missing or
    // non-integer value disables the watchdog instead of failing after transcription
    // has already started.
    val extraArgs = getExtraFfmpegArgs
    val flagIndex = extraArgs.indexOf("-t")
    val timeLimit: Option[Int] =
      if (flagIndex >= 0 && flagIndex + 1 < extraArgs.length) {
        scala.util.Try(extraArgs(flagIndex + 1).toInt).toOption
      } else {
        None
      }
    timeLimit.foreach { limit =>
      Future {
        blocking {
          // Long arithmetic: the millisecond count must not overflow Int.
          Thread.sleep((limit.toLong + 20L) * 1000L)
        }
        queue.put(None)
        cleanUp()
      }(ExecutionContext.global)
    }

    new BlockingQueueIterator[String](queue, cleanUp()).map { jsonString =>
      jsonString.parseJson.convertTo[TranscriptionResponse]
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy