Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance. Project price only 1 $
You can buy this project and download/modify it how often you want.
/*
* Copyright 2017-2022 John Snow Labs
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.johnsnowlabs.nlp.annotators.audio
import com.johnsnowlabs.ml.ai.Whisper
import com.johnsnowlabs.ml.ai.util.Generation.GenerationConfig
import com.johnsnowlabs.ml.onnx.OnnxWrapper.EncoderDecoderWrappers
import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel}
import com.johnsnowlabs.ml.openvino.OpenvinoWrapper.{
EncoderDecoderWrappers => OpenvinoEncoderDecoderWrappers
}
import com.johnsnowlabs.ml.openvino.{OpenvinoWrapper, ReadOpenvinoModel, WriteOpenvinoModel}
import com.johnsnowlabs.ml.tensorflow.{
ReadTensorflowModel,
TensorflowWrapper,
WriteTensorflowModel
}
import com.johnsnowlabs.ml.util.LoadExternalModel.{
loadJsonStringAsset,
modelSanityCheck,
notSupportedEngineError
}
import com.johnsnowlabs.ml.util.{ONNX, Openvino, TensorFlow}
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.annotators.audio.feature_extractor.{Preprocessor, WhisperPreprocessor}
import com.johnsnowlabs.nlp.serialization.{MapFeature, StructFeature}
import com.johnsnowlabs.util.Version
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.param.{BooleanParam, IntArrayParam, Param}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.SparkSession
import org.json4s._
import org.json4s.jackson.JsonMethods._
/** Whisper Model with a language modeling head on top for Connectionist Temporal Classification
* (CTC).
*
* Whisper is an automatic speech recognition (ASR) system trained on 680,000 hours of
* multilingual and multitask supervised data collected from the web. It transcribe in multiple
* languages, as well as translate from those languages into English.
*
* The audio needs to be provided pre-processed an array of floats.
*
* For multilingual models, the language and the task (transcribe or translate) can be set with
* `setLanguage` and `setTask`.
*
* Note that at the moment, this annotator only supports greedy search and only Spark Versions
* 3.4 and up are supported.
*
* Pretrained models can be loaded with `pretrained` of the companion object:
* {{{
* val speechToText = WhisperForCTC.pretrained()
* .setInputCols("audio_assembler")
* .setOutputCol("text")
* }}}
* The default model is `"asr_whisper_tiny_opt"`, if no name is provided.
*
* For available pretrained models please see the [[https://sparknlp.org/models Models Hub]].
*
* To see which models are compatible and how to import them see
* [[https://github.com/JohnSnowLabs/spark-nlp/discussions/5669]] and to see more extended
* examples, see
* [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/annotators/audio/WhisperForCTCTest.scala WhisperForCTCTestSpec]].
*
* '''References:'''
*
* [[https://arxiv.org/abs/2212.04356 Robust Speech Recognition via Large-Scale Weak Supervision]]
*
* '''Paper Abstract:'''
*
* ''We study the capabilities of speech processing systems trained simply to predict large
* amounts of transcripts of audio on the internet. When scaled to 680,000 hours of multilingual
* and multitask supervision, the resulting models generalize well to standard benchmarks and are
* often competitive with prior fully supervised results but in a zero- shot transfer setting
* without the need for any fine- tuning. When compared to humans, the models approach their
* accuracy and robustness. We are releasing models and inference code to serve as a foundation
* for further work on robust speech processing.''
*
* ==Example==
*
* {{{
* import spark.implicits._
* import com.johnsnowlabs.nlp.base._
* import com.johnsnowlabs.nlp.annotators._
* import com.johnsnowlabs.nlp.annotators.audio.WhisperForCTC
* import org.apache.spark.ml.Pipeline
*
* val audioAssembler: AudioAssembler = new AudioAssembler()
* .setInputCol("audio_content")
* .setOutputCol("audio_assembler")
*
* val speechToText: WhisperForCTC = WhisperForCTC
* .pretrained()
* .setInputCols("audio_assembler")
* .setOutputCol("text")
*
* val pipeline: Pipeline = new Pipeline().setStages(Array(audioAssembler, speechToText))
*
* val bufferedSource =
* scala.io.Source.fromFile("src/test/resources/audio/txt/librispeech_asr_0.txt")
*
* val rawFloats = bufferedSource
* .getLines()
* .map(_.split(",").head.trim.toFloat)
* .toArray
* bufferedSource.close
*
* val processedAudioFloats = Seq(rawFloats).toDF("audio_content")
*
* val result = pipeline.fit(processedAudioFloats).transform(processedAudioFloats)
* result.select("text.result").show(truncate = false)
* +------------------------------------------------------------------------------------------+
* |result |
* +------------------------------------------------------------------------------------------+
* |[ Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.]|
* +------------------------------------------------------------------------------------------+
* }}}
*
* @param uid
* required uid for storing annotator to disk
* @groupname anno Annotator types
* @groupdesc anno
* Required input and expected output annotator types
* @groupname Ungrouped Members
* @groupname param Parameters
* @groupname setParam Parameter setters
* @groupname getParam Parameter getters
* @groupname Ungrouped Members
* @groupprio param 1
* @groupprio anno 2
* @groupprio Ungrouped 3
* @groupprio setParam 4
* @groupprio getParam 5
* @groupdesc param
* A list of (hyper-)parameter keys this annotator can take. Users can set and get the
* parameter values through setters and getters, respectively.
*/
class WhisperForCTC(override val uid: String)
extends AnnotatorModel[WhisperForCTC]
with HasBatchedAnnotateAudio[WhisperForCTC]
with HasAudioFeatureProperties
with WriteTensorflowModel
with WriteOpenvinoModel
with WriteOnnxModel
with HasEngine
with HasGeneratorProperties
with HasProtectedParams {
override val outputAnnotatorType: AnnotatorType = AnnotatorType.DOCUMENT
override val inputAnnotatorTypes: Array[AnnotatorType] = Array(AnnotatorType.AUDIO)
/** Annotator reference id. Used to identify elements in metadata or to refer to this annotator
* type
*/
def this() = this(Identifiable.randomUID("WhisperForCTC"))
/** Optional language to set for the transcription. The imported model needs to support multiple
* languages.
* @group param
*/
val language =
new Param[String](
this,
"language",
"Optional parameter to set the language for the transcription.")
/** Sets the language for the audio, formatted to e.g. `<|en|>`. Check the model description for
* supported languages.
*
* @group setParam
*/
def setLanguage(value: String): this.type = {
require(getIsMultilingual, "Only multilingual models can have the language set.")
require(
value.length == 6 && value.startsWith("<|") && value.endsWith("|>"),
"The language does not have the right format." +
" Should be a two letter code enclosed in angle brackets with a vertical line (e.g. <|en|>).")
require(getModelIfNotSet.tokenInVocabulary(value), "Language was not found in vocabulary.")
set(language, value)
this
}
/** @group getParam */
def getLanguage: Option[String] = get(this.language)
/** Sets the formatted task for the audio. Either `<|translate|>` or `<|transcribe|>`.
*
* Only multilingual models can do translation.
*
* @group setParam
*/
override def setTask(value: String): this.type = {
require(
getIsMultilingual,
"Only multilingual models can have tasks. For single language models, the default task will be transcribe.")
require(
value == "<|translate|>" || value == "<|transcribe|>",
"Task should be either '<|translate|>' or '<|transcribe|>'")
set(task, value)
this
}
/** Whether or not the model is multilingual.
*
* @group param
*/
val isMultilingual: ProtectedParam[Boolean] =
new BooleanParam(this, "isMultilingual", "Whether or not the model is multilingual.")
.setProtected()
/** @group setParam */
def setIsMultilingual(value: Boolean): this.type = {
set(isMultilingual, value)
this
}
/** @group getParam */
def getIsMultilingual: Boolean = getOrDefault(this.isMultilingual)
/** It contains TF model signatures for the loaded saved model
*
* @group param
*/
val signatures: MapFeature[AnnotatorType, AnnotatorType] =
new MapFeature[String, String](model = this, name = "signatures").setProtected()
/** @group setParam */
def setSignatures(value: Map[String, String]): this.type = {
set(signatures, value)
this
}
/** @group getParam */
def getSignatures: Option[Map[String, String]] = get(this.signatures)
/** ConfigProto from tensorflow, serialized into byte array. Get with
* config_proto.SerializeToString()
*
* @group param
*/
val configProtoBytes = new IntArrayParam(
this,
"configProtoBytes",
"ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()")
/** ConfigProto from tensorflow, serialized into byte array. Get with
* config_proto.SerializeToString()
*
* @group setParam
*/
def setConfigProtoBytes(bytes: Array[Int]): this.type =
set(this.configProtoBytes, bytes)
/** ConfigProto from tensorflow, serialized into byte array. Get with
* config_proto.SerializeToString()
*
* @group getParam
*/
def getConfigProtoBytes: Option[Array[Byte]] =
get(this.configProtoBytes).map(_.map(_.toByte))
/** Vocabulary used to encode the words to ids */
protected[nlp] val vocabulary: MapFeature[String, Int] =
new MapFeature(this, "vocabulary").setProtected()
def setVocabulary(value: Map[String, Int]): this.type = set(vocabulary, value)
def getVocabulary: Map[String, Int] = $$(vocabulary)
protected[nlp] val addedSpecialTokens: MapFeature[String, Int] =
new MapFeature(this, "addedSpecialTokens").setProtected()
protected[nlp] def setAddedSpecialTokens(value: Map[String, Int]): this.type =
set(addedSpecialTokens, value)
protected[nlp] def getAddedSpecialTokens: Map[String, Int] = $$(addedSpecialTokens)
protected[nlp] val generationConfig: StructFeature[GenerationConfig] =
new StructFeature(this, "generationConfig").setProtected()
protected[nlp] def setGenerationConfig(value: GenerationConfig): this.type =
set(generationConfig, value)
protected[nlp] def getGenerationConfig: GenerationConfig = $$(generationConfig)
protected[nlp] val preprocessor: StructFeature[WhisperPreprocessor] =
new StructFeature(this, "preprocessor").setProtected()
protected[nlp] def setPreprocessor(value: WhisperPreprocessor): this.type =
set(preprocessor, value)
protected[nlp] def getPreprocessor: WhisperPreprocessor = $$(preprocessor)
setDefault(
minOutputLength -> 0,
maxOutputLength -> 448,
doSample -> false,
temperature -> 1.0,
topK -> 1,
topP -> 1.0,
repetitionPenalty -> 1.0,
noRepeatNgramSize -> 0,
batchSize -> 2,
beamSize -> 1,
nReturnSequences -> 1,
isMultilingual -> true)
private var _model: Option[Broadcast[Whisper]] = None
/** @group getParam */
def getModelIfNotSet: Whisper = _model.get.value
/** @group setParam */
def setModelIfNotSet(
spark: SparkSession,
tensorflowWrapper: Option[TensorflowWrapper],
onnxWrappers: Option[EncoderDecoderWrappers],
openvinoWrapper: Option[OpenvinoEncoderDecoderWrappers]): this.type = {
if (_model.isEmpty) {
val preprocessor = getPreprocessor
_model = Some(
spark.sparkContext.broadcast(
new Whisper(
tensorflowWrapper,
onnxWrappers,
openvinoWrapper,
configProtoBytes = getConfigProtoBytes,
signatures = getSignatures,
preprocessor = preprocessor,
vocabulary = getVocabulary,
addedSpecialTokens = $$(addedSpecialTokens),
generationConfig = getGenerationConfig)))
}
this
}
override def onWrite(path: String, spark: SparkSession): Unit = {
super.onWrite(path, spark)
getEngine match {
case TensorFlow.name =>
writeTensorflowModelV2(
path,
spark,
getModelIfNotSet.tensorflowWrapper.get,
WhisperForCTC.suffix,
WhisperForCTC.tfFile,
configProtoBytes = getConfigProtoBytes,
savedSignatures = getSignatures)
case ONNX.name =>
val wrappers = getModelIfNotSet.onnxWrappers.get
writeOnnxModels(
path,
spark,
Seq(
(wrappers.encoder, "encoder_model"),
(wrappers.decoder, "decoder_model"),
(wrappers.decoderWithPast, "decoder_with_past_model")),
WhisperForCTC.suffix)
case Openvino.name =>
val wrappers = getModelIfNotSet.openvinoWrapper
writeOpenvinoModels(
path,
spark,
Seq((wrappers.get.encoder, "openvino_encoder_model.xml")),
WhisperForCTC.suffix)
writeOpenvinoModels(
path,
spark,
Seq((wrappers.get.decoder, "openvino_decoder_model.xml")),
WhisperForCTC.suffix)
writeOpenvinoModels(
path,
spark,
Seq((wrappers.get.decoderWithPast, "openvino_decoder_with_past_model.xml")),
WhisperForCTC.suffix)
}
}
/** Takes audio annotations and produces transcribed document annotations.
*
* @param batchedAnnotations
* Audio annotations in batches
* @return
* Transcribed audio as DOCUMENT type annotation
*/
override def batchAnnotate(
batchedAnnotations: Seq[Array[AnnotationAudio]]): Seq[Seq[Annotation]] = {
batchedAnnotations.map { audioAnnotations =>
if (audioAnnotations.nonEmpty) {
getModelIfNotSet.generateFromAudio(
batchAudio = audioAnnotations,
batchSize = getBatchSize,
maxOutputLength = getMaxOutputLength,
minOutputLength = getMinOutputLength,
doSample = getDoSample,
beamSize = getBeamSize,
numReturnSequences = getNReturnSequences,
temperature = getTemperature,
topK = getTopK,
topP = getTopP,
repetitionPenalty = getRepetitionPenalty,
noRepeatNgramSize = getNoRepeatNgramSize,
randomSeed = getRandomSeed,
task = getTask,
language = getLanguage)
} else Seq.empty
}
}
}
trait ReadablePretrainedWhisperForCTCModel
extends ParamsAndFeaturesReadable[WhisperForCTC]
with HasPretrained[WhisperForCTC] {
override val defaultModelName: Some[String] = Some("asr_whisper_tiny_opt")
override val defaultLang: String = "xx"
/** Java compliant-overrides */
override def pretrained(): WhisperForCTC = super.pretrained()
override def pretrained(name: String): WhisperForCTC = super.pretrained(name)
override def pretrained(name: String, lang: String): WhisperForCTC =
super.pretrained(name, lang)
override def pretrained(name: String, lang: String, remoteLoc: String): WhisperForCTC =
super.pretrained(name, lang, remoteLoc)
}
trait ReadWhisperForCTCDLModel
extends ReadTensorflowModel
with ReadOnnxModel
with ReadOpenvinoModel {
this: ParamsAndFeaturesReadable[WhisperForCTC] =>
override val tfFile: String = "whisper_ctc_tensorflow"
override val onnxFile: String = "whisper_ctc_onnx"
override val openvinoFile: String = "whisper_ctc_openvino"
val suffix: String = "_whisper_ctc"
private def checkVersion(spark: SparkSession): Unit = {
val version = Version.parse(spark.version).toFloat
require(version >= 3.4, "WhisperForCTC requires Spark versions 3.4 and up.")
}
def readModel(instance: WhisperForCTC, path: String, spark: SparkSession): Unit = {
checkVersion(spark)
instance.getEngine match {
case TensorFlow.name =>
val tfWrapper = readTensorflowModel(
path,
spark,
WhisperForCTC.suffix,
savedSignatures = instance.getSignatures)
instance.setModelIfNotSet(spark, Some(tfWrapper), None, None)
case ONNX.name =>
val wrappers =
readOnnxModels(
path,
spark,
Seq("encoder_model", "decoder_model", "decoder_with_past_model"),
WhisperForCTC.suffix,
dataFilePostfix = ".onnx_data")
val onnxWrappers = EncoderDecoderWrappers(
wrappers("encoder_model"),
decoder = wrappers("decoder_model"),
decoderWithPast = wrappers("decoder_with_past_model"))
instance.setModelIfNotSet(spark, None, Some(onnxWrappers), None)
case Openvino.name =>
val decoderWrappers =
readOpenvinoModels(path, spark, Seq("openvino_decoder_model.xml"), suffix)
val encoderWrappers =
readOpenvinoModels(path, spark, Seq("openvino_encoder_model.xml"), suffix)
val decoderWithPastWrappers =
readOpenvinoModels(path, spark, Seq("openvino_decoder_with_past_model.xml"), suffix)
val ovWrapper = {
OpenvinoEncoderDecoderWrappers(
encoder = encoderWrappers("openvino_encoder_model.xml"),
decoder = decoderWrappers("openvino_decoder_model.xml"),
decoderWithPast = decoderWithPastWrappers("openvino_decoder_with_past_model.xml"))
}
instance.setModelIfNotSet(spark, None, None, Some(ovWrapper))
case _ =>
throw new Exception(notSupportedEngineError)
}
}
addReader(readModel)
def loadSavedModel(modelPath: String, spark: SparkSession): WhisperForCTC = {
checkVersion(spark)
implicit val formats: DefaultFormats.type = DefaultFormats // for json4s
val (localModelPath, detectedEngine) =
modelSanityCheck(modelPath, isEncoderDecoder = true, withPast = true)
val ppJsonString: String = loadJsonStringAsset(localModelPath, "preprocessor_config.json")
val preprocessor: WhisperPreprocessor =
Preprocessor.loadPreprocessorConfig(ppJsonString).asInstanceOf[WhisperPreprocessor]
val addedTokens: Map[String, Int] =
try {
parse(loadJsonStringAsset(localModelPath, "added_tokens.json")).values
.asInstanceOf[Map[String, BigInt]]
.map {
case (key, value) if value.isValidInt => (key, value.toInt)
case _ =>
throw new IllegalArgumentException(
"Could not convert BigInt to Int while parsing added_tokens.json")
}
} catch {
case _: IllegalArgumentException =>
Map.empty
}
val vocabMap: Map[String, Int] = {
val vocabJsonContent = loadJsonStringAsset(localModelPath, "vocab.json")
parse(vocabJsonContent, useBigIntForLong = true).values
.asInstanceOf[Map[String, BigInt]]
.map {
case (key, value) if value.isValidInt => (key, value.toInt)
case _ =>
throw new IllegalArgumentException(
"Could not convert BigInt to Int while parsing vocab.json")
}
}
val modelConfig: JValue =
parse(loadJsonStringAsset(localModelPath, "config.json"))
val beginSuppressTokens: Array[Int] =
(modelConfig \ "begin_suppress_tokens").extract[Array[Int]]
val suppressTokenIds: Array[Int] =
(modelConfig \ "suppress_tokens").extract[Array[Int]]
val forcedDecoderIds: Array[(Int, Int)] =
(modelConfig \ "forced_decoder_ids").extract[Array[Array[Int]]].map {
case idxWithTokenId: Array[Int] if idxWithTokenId.length == 2 =>
(idxWithTokenId(0), idxWithTokenId(1))
case _ =>
throw new Exception(
"Could not extract forced_decoder_ids. Should be a list of tuples with 2 entries.")
}
val maxOutputLength = (modelConfig \ "max_length").extract[Int]
val bosTokenId = (modelConfig \ "decoder_start_token_id").extract[Int]
val eosTokenId = (modelConfig \ "eos_token_id").extract[Int]
val padTokenId = (modelConfig \ "pad_token_id").extract[Int]
val vocabSize = (modelConfig \ "vocab_size").extract[Int]
// 3 means multilingual (for official models), e.g. [<|en|>, <|transcribe|>, <|notimestamps|>]
// Single language models only force the force token to be <|notimestamps|>
// Some custom models might have no forced tokens at all, assume its multilingual
val isMultilingual = forcedDecoderIds.length != 1
def arrayOrNone[T](array: Array[T]): Option[Array[T]] =
if (array.nonEmpty) Some(array) else None
val annotatorModel = new WhisperForCTC()
.setVocabulary(vocabMap)
.setMaxOutputLength(maxOutputLength)
.setDoNormalize(preprocessor.do_normalize)
.setReturnAttentionMask(preprocessor.return_attention_mask)
.setPaddingSide(preprocessor.padding_side)
.setPaddingValue(preprocessor.padding_value)
.setFeatureSize(preprocessor.feature_size)
.setSamplingRate(preprocessor.sampling_rate)
.setAddedSpecialTokens(addedTokens)
.setGenerationConfig(
GenerationConfig(
bosTokenId,
padTokenId,
eosTokenId,
vocabSize,
arrayOrNone(beginSuppressTokens),
arrayOrNone(suppressTokenIds),
arrayOrNone(forcedDecoderIds)))
.setPreprocessor(preprocessor)
.setIsMultilingual(isMultilingual)
annotatorModel.set(annotatorModel.engine, detectedEngine)
detectedEngine match {
case TensorFlow.name =>
val (tfWrapper, signatures) =
TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true)
val _signatures = signatures match {
case Some(s) => s
case None => throw new Exception("Cannot load signature definitions from model!")
}
/** the order of setSignatures is important if we use getSignatures inside
* setModelIfNotSet
*/
annotatorModel
.setSignatures(_signatures)
.setModelIfNotSet(spark, Some(tfWrapper), None, None)
case ONNX.name =>
val onnxWrapperEncoder =
OnnxWrapper.read(
spark,
localModelPath,
zipped = false,
useBundle = true,
modelName = "encoder_model",
onnxFileSuffix = None)
val onnxWrapperDecoder =
OnnxWrapper.read(
spark,
localModelPath,
zipped = false,
useBundle = true,
modelName = "decoder_model",
onnxFileSuffix = None)
val onnxWrapperDecoderWithPast =
OnnxWrapper.read(
spark,
localModelPath,
zipped = false,
useBundle = true,
modelName = "decoder_with_past_model",
onnxFileSuffix = None)
val onnxWrappers = EncoderDecoderWrappers(
onnxWrapperEncoder,
onnxWrapperDecoder,
onnxWrapperDecoderWithPast)
annotatorModel
.setModelIfNotSet(spark, None, Some(onnxWrappers), None)
case Openvino.name =>
val openvinoEncoderWrapper =
OpenvinoWrapper.read(
spark,
localModelPath,
zipped = false,
useBundle = true,
detectedEngine = detectedEngine,
modelName = "openvino_encoder_model")
val openvinoDecoderWrapper =
OpenvinoWrapper.read(
spark,
localModelPath,
zipped = false,
useBundle = true,
detectedEngine = detectedEngine,
modelName = "openvino_decoder_model")
val openvinoDecoderWithPastWrapper =
OpenvinoWrapper.read(
spark,
localModelPath,
zipped = false,
useBundle = true,
detectedEngine = detectedEngine,
modelName = "openvino_decoder_with_past_model")
val openvinoWrapper =
OpenvinoEncoderDecoderWrappers(
encoder = openvinoEncoderWrapper,
decoder = openvinoDecoderWrapper,
decoderWithPast = openvinoDecoderWithPastWrapper)
annotatorModel.setModelIfNotSet(spark, None, None, Some(openvinoWrapper))
case _ =>
throw new Exception(notSupportedEngineError)
}
annotatorModel
}
}
/** This is the companion object of [[WhisperForCTC]]. Please refer to that class for the
* documentation.
*/
object WhisperForCTC extends ReadablePretrainedWhisperForCTCModel with ReadWhisperForCTCDLModel