com.johnsnowlabs.ml.ai.XlmRoberta.scala Maven / Gradle / Ivy
The newest version!
/*
* Copyright 2017-2022 John Snow Labs
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.johnsnowlabs.ml.ai
import ai.onnxruntime.{OnnxTensor, TensorInfo}
import com.johnsnowlabs.ml.ai.util.PrepareEmbeddings
import com.johnsnowlabs.ml.onnx.{OnnxSession, OnnxWrapper}
import com.johnsnowlabs.ml.openvino.OpenvinoWrapper
import com.johnsnowlabs.ml.tensorflow.sentencepiece.{SentencePieceWrapper, SentencepieceEncoder}
import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}
import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
import com.johnsnowlabs.ml.util._
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.{Annotation, AnnotatorType}
import org.slf4j.{Logger, LoggerFactory}
import scala.collection.JavaConverters._
/** Token-level and sentence-level embeddings using XLM-RoBERTa. The XLM-RoBERTa model was proposed in
* '''Unsupervised Cross-lingual Representation Learning at Scale'''
* [[https://arxiv.org/abs/1911.02116]] by Alexis Conneau, Kartikay Khandelwal, Naman Goyal,
* Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke
* Zettlemoyer and Veselin Stoyanov. It is based on Facebook's RoBERTa model released in 2019. It
* is a large multi-lingual language model, trained on 2.5TB of filtered CommonCrawl data.
*
* The abstract from the paper is the following:
*
* This paper shows that pretraining multilingual language models at scale leads to significant
* performance gains for a wide range of cross-lingual transfer tasks. We train a
* Transformer-based masked language model on one hundred languages, using more than two
* terabytes of filtered CommonCrawl data. Our model, dubbed XLM-R, significantly outperforms
* multilingual BERT (mBERT) on a variety of cross-lingual benchmarks, including +13.8% average
* accuracy on XNLI, +12.3% average F1 score on MLQA, and +2.1% average F1 score on NER. XLM-R
* performs particularly well on low-resource languages, improving 11.8% in XNLI accuracy for
* Swahili and 9.2% for Urdu over the previous XLM model. We also present a detailed empirical
* evaluation of the key factors that are required to achieve these gains, including the
* trade-offs between (1) positive transfer and capacity dilution and (2) the performance of high
* and low resource languages at scale. Finally, we show, for the first time, the possibility of
* multilingual modeling without sacrificing per-language performance; XLM-R is very competitive
* with strong monolingual models on the GLUE and XNLI benchmarks. We will make XLM-R code, data,
* and models publicly available.
*
* Tips:
*
* - XLM-RoBERTa is a multilingual model trained on 100 different languages. Unlike some XLM
* multilingual models, it does not require '''lang''' parameter to understand which language
* is used, and should be able to determine the correct language from the input ids.
* - This implementation is the same as RoBERTa. Refer to the
* [[com.johnsnowlabs.nlp.embeddings.RoBertaEmbeddings]] for usage examples as well as the
* information relative to the inputs and outputs.
*
* @param tensorflowWrapper
* XlmRoberta Model wrapper with TensorFlowWrapper
* @param onnxWrapper
* XlmRoberta Model wrapper with ONNX Wrapper
* @param openvinoWrapper
* XlmRoberta Model wrapper with OpenVINO Wrapper
* @param spp
* XlmRoberta SentencePiece model with SentencePieceWrapper
* @param caseSensitive
* Whether the tokenizer is case sensitive; when false, the original tokens are lowercased
* @param configProtoBytes
* Configuration for TensorFlow session
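* @param signatures
* Model's inputs and output(s) signatures
* @param modelArch
* Which kind of embeddings the model serves (`ModelArch.wordEmbeddings` or
* `ModelArch.sentenceEmbeddings`); selects the inference path exercised by the warm-up run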
*/
private[johnsnowlabs] class XlmRoberta(
val tensorflowWrapper: Option[TensorflowWrapper],
val onnxWrapper: Option[OnnxWrapper],
val openvinoWrapper: Option[OpenvinoWrapper],
val spp: SentencePieceWrapper,
caseSensitive: Boolean = true,
configProtoBytes: Option[Array[Byte]] = None,
signatures: Option[Map[String, String]] = None,
modelArch: String = ModelArch.wordEmbeddings)
extends Serializable {
// Class-level logger, registered under the name "XlmRoberta".
protected val logger: Logger = LoggerFactory.getLogger("XlmRoberta")
// Signature map used to resolve the TensorFlow input/output tensor names; falls back to
// ModelSignatureManager.apply() when the caller supplied no signatures.
val _tfRoBertaSignatures: Map[String, String] =
signatures.getOrElse(ModelSignatureManager.apply())
// Name of the inference engine, chosen from the first wrapper that is defined, checked in the
// order TensorFlow, ONNX, OpenVINO. When no wrapper is defined the TensorFlow name is used.
val detectedEngine: String =
if (tensorflowWrapper.isDefined) TensorFlow.name
else if (onnxWrapper.isDefined) ONNX.name
else if (openvinoWrapper.isDefined) Openvino.name
else TensorFlow.name
// Session options handed to onnxWrapper.getSession on every ONNX inference call.
private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions
// Special piece ids passed to PrepareEmbeddings when building padded model inputs.
private val SentenceStartTokenId = 0
private val SentenceEndTokenId = 2
private val SentencePadTokenId = 1
// Id of the "▁" piece, looked up in the loaded SentencePiece model and handed to
// SentencepieceEncoder as its delimiter id.
private val SentencePieceDelimiterId = spp.getSppModel.pieceToId("▁")
/** Runs a single inference over a fixed, already-encoded dummy sentence so that the underlying
  * session is exercised once when the class is constructed. The inference path is selected by
  * `modelArch`; any other architecture value results in no warm-up at all.
  */
private def sessionWarmup(): Unit = {
  val warmupPieceIds =
    Array(0, 10842, 25, 7, 24814, 2037, 70, 148735, 100, 70, 5117, 53498, 6620, 5, 2)
  val warmupBatch = Seq(warmupPieceIds)

  modelArch match {
    case arch if arch == ModelArch.wordEmbeddings => tag(warmupBatch)
    case arch if arch == ModelArch.sentenceEmbeddings => tagSequence(warmupBatch)
    case _ => ()
  }
}

// Executed as part of construction, after the fields above have been initialised.
sessionWarmup()
/** Computes token-level embeddings for a batch of encoded sentences using the detected engine
  * (ONNX, OpenVINO, or TensorFlow as the default).
  *
  * @param batch
  *   Encoded sentences as piece ids. The ONNX branch builds a rectangular tensor straight from
  *   this value, so every row is expected to have the same (padded) length. The batch must be
  *   non-empty because the maximum row length is computed eagerly.
  * @return
  *   The result of `PrepareEmbeddings.prepareBatchWordEmbeddings` applied to the flat model
  *   output: presumably one matrix of piece vectors per input sentence — confirm against helper.
  */
def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = {

  val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max
  val batchLength = batch.length

  val embeddings = detectedEngine match {

    case ONNX.name =>
      // [nb of encoded sentences , maxSentenceLength]
      val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions)

      val tokenTensors =
        OnnxTensor.createTensor(env, batch.map(_.map(_.toLong)).toArray)
      // Padding is marked with SentencePadTokenId (1). Id 0 is the sentence start token and
      // must stay visible to the model, so the mask is derived from the pad id, exactly as the
      // TensorFlow branch and tagSequence do.
      val maskTensors =
        OnnxTensor.createTensor(
          env,
          batch
            .map(sentence => sentence.map(x => if (x == SentencePadTokenId) 0L else 1L))
            .toArray)

      val inputs =
        Map("input_ids" -> tokenTensors, "attention_mask" -> maskTensors).asJava

      try {
        val results = runner.run(inputs)
        try {
          results
            .get("last_hidden_state")
            .get()
            .asInstanceOf[OnnxTensor]
            .getFloatBuffer
            .array()
        } finally results.close()
      } catch {
        case e: Exception =>
          // Best-effort behaviour is kept (an empty result is returned), but the failure is
          // reported through the class logger instead of being printed to stderr.
          logger.error(
            "XlmRoberta ONNX inference failed; returning empty embeddings for this batch",
            e)
          Array.empty[Float]
      } finally {
        // The input tensors are created before the try, so they are always released here.
        tokenTensors.close()
        maskTensors.close()
      }

    case Openvino.name =>
      // NOTE(review): the attention mask is built inside prepareOvLongBatchTensors and no pad id
      // is passed here — confirm the helper's default matches SentencePadTokenId.
      val (tokenTensors, maskTensors) = PrepareEmbeddings.prepareOvLongBatchTensors(
        batch = batch,
        maxSentenceLength = maxSentenceLength,
        batchLength = batchLength)

      val inferRequest = openvinoWrapper.get.getCompiledModel().create_infer_request()
      inferRequest.set_tensor("input_ids", tokenTensors)
      inferRequest.set_tensor("attention_mask", maskTensors)

      inferRequest.infer()

      val result = inferRequest.get_tensor("last_hidden_state")
      result.data()

    case _ =>
      val tensors = new TensorResources()

      val (tokenTensors, maskTensors) =
        PrepareEmbeddings.prepareBatchTensors(
          tensors = tensors,
          batch = batch,
          maxSentenceLength = maxSentenceLength,
          batchLength = batchLength,
          sentencePadTokenId = SentencePadTokenId)

      // Everything allocated above is released in the finally block so that a failing session
      // run does not leak native tensors.
      try {
        val runner = tensorflowWrapper.get
          .getTFSessionWithSignature(
            configProtoBytes = configProtoBytes,
            savedSignatures = signatures,
            initAllTables = false)
          .runner

        runner
          .feed(
            _tfRoBertaSignatures
              .getOrElse(ModelSignatureConstants.InputIds.key, "missing_input_id_key"),
            tokenTensors)
          .feed(
            _tfRoBertaSignatures
              .getOrElse(ModelSignatureConstants.AttentionMask.key, "missing_input_mask_key"),
            maskTensors)
          .fetch(
            _tfRoBertaSignatures
              .getOrElse(
                ModelSignatureConstants.LastHiddenState.key,
                "missing_sequence_output_key"))

        val outs = runner.run().asScala
        try TensorResources.extractFloats(outs.head)
        finally tensors.clearSession(outs)
      } finally {
        tokenTensors.close()
        maskTensors.close()
        tensors.clearTensors()
      }
  }

  PrepareEmbeddings.prepareBatchWordEmbeddings(
    batch,
    embeddings,
    maxSentenceLength,
    batchLength)
}
/** Computes one sentence-level embedding per encoded sentence using the detected engine.
  *
  * The ONNX branch average-pools `last_hidden_state` with the attention mask and L2-normalises
  * the result; the TensorFlow branch fetches the model's pooler output and splits the flat
  * array evenly across the batch. No OpenVINO implementation exists in this method.
  *
  * @param batch
  *   Encoded sentences as piece ids. The ONNX branch builds a rectangular tensor straight from
  *   this value, so every row is expected to have the same (padded) length. The batch must be
  *   non-empty because the maximum row length is computed eagerly.
  * @return
  *   One embedding vector per sentence of the batch.
  * @throws UnsupportedOperationException
  *   when the detected engine is neither ONNX nor TensorFlow
  */
def tagSequence(batch: Seq[Array[Int]]): Array[Array[Float]] = {

  val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max
  val batchLength = batch.length

  detectedEngine match {

    case ONNX.name =>
      val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions)

      // 0 for padding positions, 1 everywhere else; reused below for the average pooling.
      val attentionMask = batch
        .map(sentence => sentence.map(x => if (x == SentencePadTokenId) 0L else 1L))
        .toArray

      val tokenTensors =
        OnnxTensor.createTensor(env, batch.map(_.map(_.toLong)).toArray)
      val maskTensors =
        OnnxTensor.createTensor(env, attentionMask)

      // The input tensors and the session result are released even when inference or
      // post-processing throws.
      try {
        val inputs =
          Map("input_ids" -> tokenTensors, "attention_mask" -> maskTensors).asJava
        val results = runner.run(inputs)
        try {
          val lastHiddenState = results.get("last_hidden_state").get()
          val shape = lastHiddenState.getInfo.asInstanceOf[TensorInfo].getShape
          val flattenEmbeddings = lastHiddenState
            .asInstanceOf[OnnxTensor]
            .getFloatBuffer
            .array()

          val pooled = LinAlg.avgPooling(flattenEmbeddings, attentionMask, shape)
          val normalized = LinAlg.l2Normalize(pooled)
          LinAlg.denseMatrixToArray(normalized)
        } finally results.close()
      } finally {
        tokenTensors.close()
        maskTensors.close()
      }

    case TensorFlow.name =>
      val tensors = new TensorResources()

      val (tokenTensors, maskTensors) =
        PrepareEmbeddings.prepareBatchTensors(
          tensors = tensors,
          batch = batch,
          maxSentenceLength = maxSentenceLength,
          batchLength = batchLength,
          sentencePadTokenId = SentencePadTokenId)

      // Everything allocated above is released in the finally block so that a failing session
      // run does not leak native tensors.
      try {
        val runner = tensorflowWrapper.get
          .getTFSessionWithSignature(
            configProtoBytes = configProtoBytes,
            savedSignatures = signatures,
            initAllTables = false)
          .runner

        runner
          .feed(
            _tfRoBertaSignatures
              .getOrElse(ModelSignatureConstants.InputIds.key, "missing_input_id_key"),
            tokenTensors)
          .feed(
            _tfRoBertaSignatures
              .getOrElse(ModelSignatureConstants.AttentionMask.key, "missing_input_mask_key"),
            maskTensors)
          .fetch(_tfRoBertaSignatures
            .getOrElse(ModelSignatureConstants.PoolerOutput.key, "missing_pooled_output_key"))

        val outs = runner.run().asScala
        val flatEmbeddings =
          try TensorResources.extractFloats(outs.head)
          finally tensors.clearSession(outs)

        // The flat output is split into equally sized vectors, one per sentence.
        val dim = flatEmbeddings.length / batchLength
        flatEmbeddings.grouped(dim).toArray
      } finally {
        tokenTensors.close()
        maskTensors.close()
        tensors.clearTensors()
      }

    case other =>
      // Previously this fell through the match and surfaced as an opaque MatchError.
      throw new UnsupportedOperationException(
        s"XlmRoberta sentence embeddings are not implemented for engine '$other'")
  }
}
/** Produces word piece embeddings for already tokenized sentences and aligns them back to the
  * original tokens.
  *
  * Sentences are split into word pieces, processed in groups of `batchSize` through [[tag]],
  * and only pieces that start a word and whose `begin` offset matches an original token are
  * kept in the output.
  *
  * @param tokenizedSentences
  *   Sentences with their original indexed tokens
  * @param batchSize
  *   Number of sentences sent to the model per inference call
  * @param maxSentenceLength
  *   Maximum sequence length used for truncation and padding
  * @return
  *   One [[WordpieceEmbeddingsSentence]] per input sentence
  */
def predict(
    tokenizedSentences: Seq[TokenizedSentence],
    batchSize: Int,
    maxSentenceLength: Int): Seq[WordpieceEmbeddingsSentence] = {

  val pieceSentences = tokenizeWithAlignment(tokenizedSentences, maxSentenceLength)

  pieceSentences.zipWithIndex
    .grouped(batchSize)
    .flatMap { indexedBatch =>
      val paddedIds = PrepareEmbeddings.prepareBatchInputsWithPadding(
        indexedBatch,
        maxSentenceLength,
        SentenceStartTokenId,
        SentenceEndTokenId,
        SentencePadTokenId)
      val batchVectors = tag(paddedIds)

      // Pair every sentence of the batch with the vectors the model produced for it.
      indexedBatch.zip(batchVectors).map {
        case ((pieceSentence, sentencePosition), sentenceVectors) =>
          val pieces = pieceSentence.tokens
          // Index 0 is skipped: presumably it holds the sentence start token added by padding.
          val pieceVectors = sentenceVectors.slice(1, pieces.length + 1)
          val sourceSentence = tokenizedSentences(sentencePosition)

          val alignedPieces = pieces.zip(pieceVectors).flatMap { case (piece, vector) =>
            val pieceWithVector = TokenPieceEmbeddings(piece, vector)

            // Only word-starting pieces can be matched to an original token.
            val matchingToken =
              if (pieceWithVector.isWordStart)
                sourceSentence.indexedTokens.find(candidate =>
                  candidate.begin == pieceWithVector.begin)
              else None

            matchingToken.map { original =>
              val surface = if (caseSensitive) original.token else original.token.toLowerCase()
              TokenPieceEmbeddings(
                TokenPiece(
                  wordpiece = pieceWithVector.wordpiece,
                  token = surface,
                  pieceId = pieceWithVector.pieceId,
                  isWordStart = pieceWithVector.isWordStart,
                  begin = original.begin,
                  end = original.end),
                vector)
            }
          }

          WordpieceEmbeddingsSentence(alignedPieces, sourceSentence.sentenceIndex)
      }
    }
    .toSeq
}
/** Produces one sentence embedding annotation per input sentence.
  *
  * Sentences are encoded into word pieces, then processed in groups of `batchSize` through
  * [[tagSequence]]; each resulting vector is wrapped in a `SENTENCE_EMBEDDINGS` annotation
  * carrying the sentence offsets and content.
  *
  * @param sentences
  *   Sentences to embed
  * @param batchSize
  *   Number of sentences sent to the model per inference call
  * @param maxSentenceLength
  *   Maximum sequence length used for truncation and padding
  * @return
  *   One [[Annotation]] per input sentence
  */
def predictSequence(
    sentences: Seq[Sentence],
    batchSize: Int,
    maxSentenceLength: Int): Seq[Annotation] = {

  val pieceSentences = sentences
    .grouped(batchSize)
    .flatMap(chunk => tokenizeSentence(chunk, maxSentenceLength))
    .toSeq

  // Embeddings are computed batch by batch.
  pieceSentences
    .zip(sentences)
    .zipWithIndex
    .grouped(batchSize)
    .flatMap { chunk =>
      val indexedPieces = chunk.map { case ((pieces, _), position) => (pieces, position) }
      val sourceSentences = chunk.map { case ((_, source), _) => source }

      val paddedIds = PrepareEmbeddings.prepareBatchInputsWithPadding(
        indexedPieces,
        maxSentenceLength,
        SentenceStartTokenId,
        SentenceEndTokenId,
        SentencePadTokenId)
      val sentenceVectors = tagSequence(paddedIds)

      sourceSentences.zip(sentenceVectors).map { case (source, vector) =>
        val details = Map(
          "sentence" -> source.index.toString,
          "token" -> source.content,
          "pieceId" -> "-1",
          "isWordStart" -> "true")

        Annotation(
          annotatorType = AnnotatorType.SENTENCE_EMBEDDINGS,
          begin = source.start,
          end = source.end,
          result = source.content,
          metadata = details,
          embeddings = vector)
      }
    }
    .toSeq
}
/** Encodes every token of each sentence into SentencePiece word pieces.
  *
  * At most `maxSeqLength - 2` tokens per sentence are encoded (presumably leaving room for the
  * sentence start and end tokens added later), and the resulting pieces are capped at
  * `maxSeqLength`.
  *
  * @param sentences
  *   Sentences with their original indexed tokens
  * @param maxSeqLength
  *   Maximum sequence length
  * @return
  *   One [[WordpieceTokenizedSentence]] per input sentence, in the same order
  */
def tokenizeWithAlignment(
    sentences: Seq[TokenizedSentence],
    maxSeqLength: Int): Seq[WordpieceTokenizedSentence] = {

  val pieceEncoder =
    new SentencepieceEncoder(spp, caseSensitive, SentencePieceDelimiterId, pieceIdOffset = 1)
  val tokenBudget = maxSeqLength - 2

  sentences.map { sentence =>
    val pieces = sentence.indexedTokens
      .take(tokenBudget)
      .flatMap(indexedToken => pieceEncoder.encode(indexedToken))
      .take(maxSeqLength)
    WordpieceTokenizedSentence(pieces)
  }
}
/** Encodes whole sentences into SentencePiece word pieces, without token alignment.
  *
  * @param sentences
  *   Sentences to encode
  * @param maxSeqLength
  *   Maximum sequence length; passed to the encoder and also used to cap the resulting pieces
  * @return
  *   One [[WordpieceTokenizedSentence]] per input sentence, in the same order
  */
def tokenizeSentence(
    sentences: Seq[Sentence],
    maxSeqLength: Int): Seq[WordpieceTokenizedSentence] = {

  val pieceEncoder =
    new SentencepieceEncoder(spp, caseSensitive, SentencePieceDelimiterId, pieceIdOffset = 1)

  sentences.map { sentence =>
    val pieces = pieceEncoder
      .encodeSentence(sentence, maxLength = maxSeqLength)
      .take(maxSeqLength)
    WordpieceTokenizedSentence(pieces)
  }
}
}