com.johnsnowlabs.ml.tensorflow.TensorflowXlnetClassification.scala Maven / Gradle / Ivy
/*
* Copyright 2017-2022 John Snow Labs
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.johnsnowlabs.ml.tensorflow
import com.johnsnowlabs.ml.tensorflow.sentencepiece.{SentencePieceWrapper, SentencepieceEncoder}
import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}
import com.johnsnowlabs.nlp.{ActivationFunction, Annotation}
import com.johnsnowlabs.nlp.annotators.common._
import org.tensorflow.ndarray.buffer.IntDataBuffer
import scala.collection.JavaConverters._
/** Sequence- and token-level classification on top of an XLNet TensorFlow model.
  *
  * @param tensorflowWrapper
  *   XLNet Model wrapper with TensorFlow Wrapper
  * @param spp
  *   XLNet SentencePiece model with SentencePieceWrapper
  * @param configProtoBytes
  *   Configuration for TensorFlow session
  * @param tags
  *   labels which model was trained with in order (not referenced inside this class)
  * @param signatures
  *   TF v2 signatures in Spark NLP; when absent, `ModelSignatureManager.apply()` supplies them
  */
class TensorflowXlnetClassification(
    val tensorflowWrapper: TensorflowWrapper,
    val spp: SentencePieceWrapper,
    configProtoBytes: Option[Array[Byte]] = None,
    tags: Map[String, Int],
    signatures: Option[Map[String, String]] = None)
    extends Serializable
    with TensorflowForClassification {

  /** Mapping from signature keys to the graph's input/output tensor names. */
  val _tfXlnetSignatures: Map[String, String] =
    signatures.getOrElse(ModelSignatureManager.apply())

  // Ids of XLNet's special SentencePiece pieces. These must name the real pieces: if the three
  // lookups resolve to the same id, the pad-based attention mask in `computeLogits` can no
  // longer tell padding apart from content.
  override protected val sentenceStartTokenId: Int = spp.getSppModel.pieceToId("<cls>")
  override protected val sentenceEndTokenId: Int = spp.getSppModel.pieceToId("<sep>")
  override protected val sentencePadTokenId: Int = spp.getSppModel.pieceToId("<pad>")

  // "▁" (U+2581) is SentencePiece's word-boundary marker; the encoder is given its id.
  private val sentencePieceDelimiterId = spp.getSppModel.pieceToId("▁")

  /** Splits each tokenized sentence into SentencePiece pieces.
    *
    * At most `maxSeqLength - 2` word tokens are encoded per sentence, and the resulting pieces
    * are capped at `maxSeqLength`. Presumably the two reserved slots are for special tokens added
    * by the caller — confirm against `TensorflowForClassification`.
    *
    * @param sentences
    *   sentences to encode
    * @param maxSeqLength
    *   maximum sequence length, in pieces
    * @param caseSensitive
    *   forwarded to the SentencePiece encoder
    * @return
    *   one `WordpieceTokenizedSentence` per input sentence, in input order
    */
  def tokenizeWithAlignment(
      sentences: Seq[TokenizedSentence],
      maxSeqLength: Int,
      caseSensitive: Boolean): Seq[WordpieceTokenizedSentence] = {
    val encoder =
      new SentencepieceEncoder(spp, caseSensitive, delimiterId = sentencePieceDelimiterId)

    sentences.map { sentence =>
      val pieces = sentence.indexedTokens
        .take(maxSeqLength - 2)
        .flatMap(token => encoder.encode(token))
        .take(maxSeqLength)
      WordpieceTokenizedSentence(pieces)
    }
  }

  /** Document-level tokenization is not implemented for XLNet; always returns an empty sequence.
    */
  def tokenizeDocument(
      docs: Seq[Annotation],
      maxSeqLength: Int,
      caseSensitive: Boolean): Seq[WordpieceTokenizedSentence] = {
    Seq.empty[WordpieceTokenizedSentence]
  }

  /** Token-level classification.
    *
    * @param batch
    *   encoded sentences (piece ids); shorter rows are laid out in a buffer sized to the longest
    * @return
    *   for each sentence, `maxSentenceLength` rows of softmax scores; empty for an empty batch
    */
  def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = {
    if (batch.isEmpty) Seq.empty[Array[Array[Float]]]
    else {
      val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
      val rawScores = computeLogits(batch, maxSentenceLength)

      // The grouping below treats the flat logits as [batch, maxSentenceLength, dim].
      val dim = rawScores.length / (batch.length * maxSentenceLength)
      val batchScores: Array[Array[Array[Float]]] = rawScores
        .grouped(dim)
        .map(scores => calculateSoftmax(scores))
        .toArray
        .grouped(maxSentenceLength)
        .toArray
      batchScores
    }
  }

  /** Sequence-level classification.
    *
    * @param batch
    *   encoded sentences (piece ids)
    * @param activation
    *   `ActivationFunction.softmax` or `ActivationFunction.sigmoid`; any other value falls back to
    *   softmax
    * @return
    *   one score row per sentence; empty for an empty batch
    */
  def tagSequence(batch: Seq[Array[Int]], activation: String): Array[Array[Float]] = {
    if (batch.isEmpty) Array.empty[Array[Float]]
    else {
      val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
      val rawScores = computeLogits(batch, maxSentenceLength)

      // The grouping below treats the flat logits as [batch, dim].
      val dim = rawScores.length / batch.length
      rawScores
        .grouped(dim)
        .map(scores =>
          activation match {
            case ActivationFunction.softmax => calculateSoftmax(scores)
            case ActivationFunction.sigmoid => calculateSigmoid(scores)
            case _ => calculateSoftmax(scores)
          })
        .toArray
    }
  }

  /** Span prediction is not implemented for this model; returns two empty arrays. */
  def tagSpan(batch: Seq[Array[Int]]): (Array[Array[Float]], Array[Array[Float]]) = {
    (Array.empty[Array[Float]], Array.empty[Array[Float]])
  }

  /** Finds the original token that a word-initial piece starts.
    *
    * @param tokenizedSentences
    *   original sentences, indexed by the second element of `sentence`
    * @param sentence
    *   the piece-tokenized sentence paired with its index into `tokenizedSentences`
    * @param tokenPiece
    *   the piece to align
    * @return
    *   the token whose `begin` equals the piece's `begin`, or `None` if there is none or the piece
    *   is not a word start
    */
  def findIndexedToken(
      tokenizedSentences: Seq[TokenizedSentence],
      sentence: (WordpieceTokenizedSentence, Int),
      tokenPiece: TokenPiece): Option[IndexedToken] = {
    tokenizedSentences(sentence._2).indexedTokens.find(p =>
      p.begin == tokenPiece.begin && tokenPiece.isWordStart)
  }

  /** Runs the XLNet graph on one batch and returns the flattened logits.
    *
    * Builds token-id, attention-mask (1 for non-pad ids, 0 for `sentencePadTokenId`) and
    * all-zero segment inputs of shape [batch.length, maxSentenceLength]. It then feeds them under
    * the configured signature keys and fetches the logits output. All tensors are released even
    * when the session run fails.
    */
  private def computeLogits(batch: Seq[Array[Int]], maxSentenceLength: Int): Array[Float] = {
    val tensors = new TensorResources()
    val bufferSize = batch.length * maxSentenceLength
    val tokenBuffers: IntDataBuffer = tensors.createIntBuffer(bufferSize)
    val maskBuffers: IntDataBuffer = tensors.createIntBuffer(bufferSize)
    val segmentBuffers: IntDataBuffer = tensors.createIntBuffer(bufferSize)
    // [nb of encoded sentences, maxSentenceLength]
    val shape = Array(batch.length.toLong, maxSentenceLength.toLong)

    batch.zipWithIndex.foreach { case (sentence, idx) =>
      val offset = idx * maxSentenceLength
      tokenBuffers.offset(offset).write(sentence)
      maskBuffers
        .offset(offset)
        .write(sentence.map(x => if (x == sentencePadTokenId) 0 else 1))
      segmentBuffers.offset(offset).write(Array.fill(maxSentenceLength)(0))
    }

    val runner = tensorflowWrapper
      .getTFSessionWithSignature(configProtoBytes = configProtoBytes, initAllTables = false)
      .runner

    val tokenTensors = tensors.createIntBufferTensor(shape, tokenBuffers)
    val maskTensors = tensors.createIntBufferTensor(shape, maskBuffers)
    val segmentTensors = tensors.createIntBufferTensor(shape, segmentBuffers)

    try {
      runner
        .feed(
          _tfXlnetSignatures.getOrElse(
            ModelSignatureConstants.InputIds.key,
            "missing_input_id_key"),
          tokenTensors)
        .feed(
          _tfXlnetSignatures
            .getOrElse(ModelSignatureConstants.AttentionMask.key, "missing_input_mask_key"),
          maskTensors)
        .feed(
          _tfXlnetSignatures
            .getOrElse(ModelSignatureConstants.TokenTypeIds.key, "missing_segment_ids_key"),
          segmentTensors)
        .fetch(_tfXlnetSignatures
          .getOrElse(ModelSignatureConstants.LogitsOutput.key, "missing_logits_key"))

      val outs = runner.run().asScala
      try TensorResources.extractFloats(outs.head)
      finally {
        outs.foreach(_.close())
        tensors.clearSession(outs)
      }
    } finally {
      tensors.clearTensors()
      tokenTensors.close()
      maskTensors.close()
      segmentTensors.close()
    }
  }
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy