All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.ml.ai.Albert.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.ml.ai

import ai.onnxruntime.OnnxTensor
import com.johnsnowlabs.ml.ai.util.PrepareEmbeddings
import com.johnsnowlabs.ml.onnx.{OnnxSession, OnnxWrapper}
import com.johnsnowlabs.ml.tensorflow.sentencepiece.{SentencePieceWrapper, SentencepieceEncoder}
import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}
import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
import com.johnsnowlabs.nlp.annotators.common._
import org.slf4j.{Logger, LoggerFactory}

import scala.collection.JavaConverters._

/** This class is used to calculate ALBERT embeddings for sequence batches of
  * WordpieceTokenizedSentence. Input for this model must be tokenized with a SentencePiece model.
  *
  * This Tensorflow model is using the weights provided by https://tfhub.dev/google/albert_base/3
  * * sequence_output: representations of every token in the input sequence with shape
  * [batch_size, max_sequence_length, hidden_size].
  *
  * ALBERT: A LITE BERT FOR SELF-SUPERVISED LEARNING OF LANGUAGE REPRESENTATIONS - Google
  * Research, Toyota Technological Institute at Chicago. These embeddings represent the
  * outputs generated by the Albert model. All official Albert releases by Google in TF-HUB are
  * supported with this Albert Wrapper:
  *
  * TF-HUB Models : albert_base = https://tfhub.dev/google/albert_base/3 | 768-embed-dim,
  * 12-layer, 12-heads, 12M parameters albert_large = https://tfhub.dev/google/albert_large/3 |
  * 1024-embed-dim, 24-layer, 16-heads, 18M parameters albert_xlarge =
  * https://tfhub.dev/google/albert_xlarge/3 | 2048-embed-dim, 24-layer, 32-heads, 60M parameters
  * albert_xxlarge = https://tfhub.dev/google/albert_xxlarge/3 | 4096-embed-dim, 12-layer,
  * 64-heads, 235M parameters
  *
  * This model requires input tokenization with SentencePiece model, which is provided by Spark
  * NLP
  *
  * For additional information see : https://arxiv.org/pdf/1909.11942.pdf
  * https://github.com/google-research/ALBERT https://tfhub.dev/s?q=albert
  *
  * Tips:
  *
  * ALBERT uses repeating layers which results in a small memory footprint, however the
  * computational cost remains similar to a BERT-like architecture with the same number of hidden
  * layers as it has to iterate through the same number of (repeating) layers.
  *
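  * @param tensorflowWrapper
  *   Optional Albert model wrapped in a TensorflowWrapper; when defined, the TensorFlow engine is
  *   used
  * @param onnxWrapper
  *   Optional Albert model wrapped in an OnnxWrapper; used when no TensorFlow wrapper is defined
  * @param spp
  *   Albert SentencePiece model with SentencePieceWrapper
  * @param batchSize
  *   size of batch
  * @param configProtoBytes
  *   Configuration for TensorFlow session
  * @param signatures
  *   Optional TensorFlow model signatures; when absent, the defaults from ModelSignatureManager
  *   are used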
  */
private[johnsnowlabs] class Albert(
    val tensorflowWrapper: Option[TensorflowWrapper],
    val onnxWrapper: Option[OnnxWrapper],
    val spp: SentencePieceWrapper,
    batchSize: Int,
    configProtoBytes: Option[Array[Byte]] = None,
    signatures: Option[Map[String, String]] = None)
    extends Serializable {

  /** TensorFlow signature map: the caller-supplied `signatures`, or the default obtained from
    * `ModelSignatureManager.apply()` when none were given.
    */
  val _tfAlbertSignatures: Map[String, String] =
    signatures.getOrElse(ModelSignatureManager.apply())

  /** Name of the inference engine in use. TensorFlow wins when its wrapper is defined, ONNX is
    * chosen when only the ONNX wrapper is defined, and TensorFlow is the fallback otherwise.
    */
  val detectedEngine: String =
    if (tensorflowWrapper.isDefined) TensorFlow.name
    else if (onnxWrapper.isDefined) ONNX.name
    else TensorFlow.name
  private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions

  // SentencePiece ids of the special pieces used when encoding and aligning sentences
  private val SentenceStartTokenId = spp.getSppModel.pieceToId("[CLS]")
  private val SentenceEndTokenId = spp.getSppModel.pieceToId("[SEP]")
  private val SentencePadTokenId = spp.getSppModel.pieceToId("[pad]")
  private val SentencePieceDelimiterId = spp.getSppModel.pieceToId("▁")

  protected val logger: Logger = LoggerFactory.getLogger("Albert")

  /** Runs a single inference pass over a fixed, hard-coded sequence of piece ids. Invoked once
    * from the constructor body, so it needs every field declared above to be initialised.
    */
  private def sessionWarmup(): Unit = {
    val dummyInput =
      Array(101, 2292, 1005, 1055, 4010, 6279, 1996, 5219, 2005, 1996, 2034, 28937, 1012, 102)
    tag(Seq(dummyInput))
  }

  sessionWarmup()

  /** Computes embeddings for a batch of encoded sentences with the detected engine.
    *
    * @param batch
    *   piece ids of each sentence; the ONNX path builds a rectangular tensor from them, so all
    *   rows are expected to share the same length (NOTE(review): padding is done by the caller
    *   in `predict` - confirm for other callers)
    * @return
    *   the result of `PrepareEmbeddings.prepareBatchWordEmbeddings` for the batch, or an empty
    *   sequence when `batch` is empty
    */
  def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = {
    // `max` on an empty collection throws; an empty batch simply has no embeddings.
    if (batch.isEmpty) Seq.empty
    else {
      val maxSentenceLength = batch.map(_.length).max
      val batchLength = batch.length

      val embeddings = detectedEngine match {
        case ONNX.name => runOnnx(batch, maxSentenceLength)
        case _ => runTensorflow(batch, maxSentenceLength, batchLength)
      }

      PrepareEmbeddings.prepareBatchWordEmbeddings(
        batch,
        embeddings,
        maxSentenceLength,
        batchLength)
    }
  }

  /** Runs the ONNX session and returns the flat `last_hidden_state` buffer.
    *
    * Every input tensor is closed when this method exits, including when creating a later tensor
    * fails. A failure while running the session is logged and yields an empty array (the
    * pre-existing fallback behaviour).
    */
  private def runOnnx(batch: Seq[Array[Int]], maxSentenceLength: Int): Array[Float] = {
    val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions)

    // Tensors are registered as soon as they exist so the finally block can release all of them.
    val openTensors = scala.collection.mutable.ArrayBuffer.empty[OnnxTensor]
    def tracked(tensor: OnnxTensor): OnnxTensor = {
      openTensors += tensor
      tensor
    }

    try {
      // [nb of encoded sentences , maxSentenceLength]
      val tokenTensors =
        tracked(OnnxTensor.createTensor(env, batch.map(_.map(_.toLong)).toArray))

      // positions holding id 0 are masked out, everything else is attended to
      val maskTensors = tracked(
        OnnxTensor.createTensor(
          env,
          batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray))

      val segmentTensors = tracked(
        OnnxTensor.createTensor(env, batch.map(_ => Array.fill(maxSentenceLength)(0L)).toArray))

      val inputs =
        Map(
          "input_ids" -> tokenTensors,
          "attention_mask" -> maskTensors,
          "token_type_ids" -> segmentTensors).asJava

      try {
        val results = runner.run(inputs)
        try {
          results
            .get("last_hidden_state")
            .get()
            .asInstanceOf[OnnxTensor]
            .getFloatBuffer
            .array()
        } finally if (results != null) results.close()
      } catch {
        case e: Exception =>
          // Report through the class logger instead of stderr; keep the empty-array fallback.
          logger.error("ALBERT ONNX inference failed; returning empty embeddings", e)
          Array.empty[Float]
      }
    } finally openTensors.foreach(_.close())
  }

  /** Runs the TensorFlow session and returns the floats extracted from the first fetched output.
    *
    * Input tensors, session outputs and the `TensorResources` holder are released in `finally`
    * blocks so that a failing run does not leak them.
    */
  private def runTensorflow(
      batch: Seq[Array[Int]],
      maxSentenceLength: Int,
      batchLength: Int): Array[Float] = {
    val tensors = new TensorResources()

    try {
      val (tokenTensors, maskTensors, segmentTensors) =
        PrepareEmbeddings.prepareBatchTensorsWithSegment(
          tensors = tensors,
          batch = batch,
          maxSentenceLength = maxSentenceLength,
          batchLength = batchLength)

      try {
        val runner = tensorflowWrapper.get
          .getTFSessionWithSignature(
            configProtoBytes = configProtoBytes,
            savedSignatures = signatures,
            initAllTables = false)
          .runner

        runner
          .feed(
            _tfAlbertSignatures.getOrElse(
              ModelSignatureConstants.InputIdsV1.key,
              "missing_input_id_key"),
            tokenTensors)
          .feed(
            _tfAlbertSignatures
              .getOrElse(ModelSignatureConstants.AttentionMaskV1.key, "missing_input_mask_key"),
            maskTensors)
          .feed(
            _tfAlbertSignatures
              .getOrElse(ModelSignatureConstants.TokenTypeIdsV1.key, "missing_segment_ids_key"),
            segmentTensors)
          .fetch(
            _tfAlbertSignatures
              .getOrElse(
                ModelSignatureConstants.LastHiddenStateV1.key,
                "missing_sequence_output_key"))

        val outs = runner.run().asScala
        try TensorResources.extractFloats(outs.head)
        finally tensors.clearSession(outs)
      } finally {
        tokenTensors.close()
        maskTensors.close()
        segmentTensors.close()
      }
    } finally tensors.clearTensors()
  }

  /** Tokenizes the sentences into word pieces, computes their embeddings batch by batch and
    * aligns the word-start pieces with the original tokens.
    *
    * @param tokenizedSentences
    *   sentences to embed
    * @param batchSize
    *   number of sentences sent to the model per call to `tag`
    * @param maxSentenceLength
    *   maximum sequence length passed to tokenization and to input padding
    * @param caseSensitive
    *   when false, the token text stored in the result is lower-cased
    * @return
    *   one WordpieceEmbeddingsSentence per input sentence, holding only the pieces that start a
    *   word and whose `begin` matches an original token
    */
  def predict(
      tokenizedSentences: Seq[TokenizedSentence],
      batchSize: Int,
      maxSentenceLength: Int,
      caseSensitive: Boolean): Seq[WordpieceEmbeddingsSentence] = {

    val wordPieceTokenizedSentences =
      tokenizeWithAlignment(tokenizedSentences, maxSentenceLength, caseSensitive)
    wordPieceTokenizedSentences.zipWithIndex
      .grouped(batchSize)
      .flatMap { batch =>
        val encoded = PrepareEmbeddings.prepareBatchInputsWithPadding(
          batch,
          maxSentenceLength,
          SentenceStartTokenId,
          SentenceEndTokenId,
          SentencePadTokenId)
        val vectors = tag(encoded)

        /*Combine tokens and calculated embeddings*/
        batch.zip(vectors).map { case ((wordpieceSentence, sentencePosition), tokenVectors) =>
          val tokenLength = wordpieceSentence.tokens.length
          /*All wordpiece embeddings; index 0 is skipped, the pieces start at index 1*/
          val tokenEmbeddings = tokenVectors.slice(1, tokenLength + 1)
          val originalIndexedTokens = tokenizedSentences(sentencePosition)

          val tokensWithEmbeddings =
            wordpieceSentence.tokens.zip(tokenEmbeddings).flatMap {
              case (piece, pieceEmbedding) =>
                val pieceWithEmbedding = TokenPieceEmbeddings(piece, pieceEmbedding)
                // Only word-start pieces can match; checking first avoids a pointless scan of
                // the original tokens for continuation pieces.
                val alignedPiece =
                  if (pieceWithEmbedding.isWordStart)
                    originalIndexedTokens.indexedTokens
                      .find(_.begin == pieceWithEmbedding.begin)
                      .map { originalToken =>
                        TokenPieceEmbeddings(
                          TokenPiece(
                            wordpiece = pieceWithEmbedding.wordpiece,
                            token =
                              if (caseSensitive) originalToken.token
                              else originalToken.token.toLowerCase(),
                            pieceId = pieceWithEmbedding.pieceId,
                            isWordStart = pieceWithEmbedding.isWordStart,
                            begin = originalToken.begin,
                            end = originalToken.end),
                          pieceEmbedding)
                      }
                  else None
                alignedPiece
            }

          WordpieceEmbeddingsSentence(tokensWithEmbeddings, originalIndexedTokens.sentenceIndex)
        }
      }
      .toSeq
  }

  /** Encodes every sentence into SentencePiece word pieces.
    *
    * Each sentence is first cut to `maxSeqLength - 2` tokens, then its encoded pieces are cut to
    * `maxSeqLength` pieces.
    *
    * @param sentences
    *   sentences to encode
    * @param maxSeqLength
    *   maximum sequence length used for both truncation steps
    * @param caseSensitive
    *   forwarded to the SentencepieceEncoder
    * @return
    *   one WordpieceTokenizedSentence per input sentence, in the same order
    */
  def tokenizeWithAlignment(
      sentences: Seq[TokenizedSentence],
      maxSeqLength: Int,
      caseSensitive: Boolean): Seq[WordpieceTokenizedSentence] = {
    val encoder =
      new SentencepieceEncoder(spp, caseSensitive, delimiterId = SentencePieceDelimiterId)

    sentences.map { sentence =>
      val trimmedSentence = sentence.indexedTokens.take(maxSeqLength - 2)
      val wordpieceTokens =
        trimmedSentence.flatMap(token => encoder.encode(token)).take(maxSeqLength)
      WordpieceTokenizedSentence(wordpieceTokens)
    }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy