All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.ml.ai.NLLB.scala Maven / Gradle / Ivy

/*
 * Copyright 2017 - 2023  John Snow Labs
 *
 *    Licensed under the Apache License, Version 2.0 (the "License");
 *    you may not use this file except in compliance with the License.
 *    You may obtain a copy of the License at
 *
 *        http://www.apache.org/licenses/LICENSE-2.0
 *
 *    Unless required by applicable law or agreed to in writing, software
 *    distributed under the License is distributed on an "AS IS" BASIS,
 *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *    See the License for the specific language governing permissions and
 *    limitations under the License.
 */

package com.johnsnowlabs.ml.ai

import ai.onnxruntime.{OnnxTensor, OrtEnvironment, OrtSession}
import com.johnsnowlabs.ml.ai.util.Generation.{Generate, GenerationConfig}
import com.johnsnowlabs.ml.onnx.OnnxSession
import com.johnsnowlabs.ml.onnx.OnnxWrapper.EncoderDecoderWithoutPastWrappers
import com.johnsnowlabs.ml.openvino.OpenvinoWrapper.{
  EncoderDecoderWithoutPastWrappers => OpenvinoEncoderDecoderWithoutPastWrappers
}
import com.johnsnowlabs.ml.onnx.TensorResources.implicits._
import com.johnsnowlabs.ml.tensorflow.sentencepiece.SentencePieceWrapper
import com.johnsnowlabs.ml.util.{ONNX, Openvino, TensorFlow}
import com.johnsnowlabs.nlp.Annotation
import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT
import org.intel.openvino.InferRequest
import org.tensorflow.{Session, Tensor}

import scala.collection.JavaConverters._

/** NLLB translation model running on ONNX Runtime or OpenVINO.
  *
  * Sentences are split into SentencePiece pieces, mapped to model ids through `vocab`, encoded
  * once per batch, and decoded by `Generate.generate`, which asks [[getModelOutput]] for the
  * next-token logits.
  *
  * Language tokens are addressed as `vocabSize + langToken - LanguageTokenBlockSize`, and `decode`
  * drops every id `>= vocabSize - LanguageTokenBlockSize`. This presumably covers exactly the
  * language/special ids; confirm against the exported tokenizer.
  *
  * NOTE: the OpenVINO path parks the current batch's encoder outputs in instance fields while
  * `generate` runs. A single instance should therefore not translate concurrently in OpenVINO
  * mode.
  *
  * @param onnxWrappers
  *   ONNX encoder/decoder models; used when defined, even if `openvinoWrapper` is also set
  * @param openvinoWrapper
  *   OpenVINO encoder/decoder models, used when no ONNX model is present
  * @param spp
  *   SentencePiece model used to split and re-join text
  * @param generationConfig
  *   special token ids and vocabulary size of the model
  * @param vocab
  *   SentencePiece piece to model token id
  */
private[johnsnowlabs] class NLLB(
    val onnxWrappers: Option[EncoderDecoderWithoutPastWrappers],
    val openvinoWrapper: Option[OpenvinoEncoderDecoderWithoutPastWrappers],
    val spp: SentencePieceWrapper,
    generationConfig: GenerationConfig,
    vocab: Map[String, Int])
    extends Serializable
    with Generate {

  /** Size of the trailing id block used for language tokens. The value is kept from the original
    * implementation; TODO confirm it against the model's tokenizer.
    */
  private final val LanguageTokenBlockSize = 205

  private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions

  // OpenVINO encoder outputs for the batch currently being translated. `getModelOutput` only
  // receives ONNX/TF tensors from `generate`, so the OpenVINO ones are parked here for the
  // duration of a single `tag` call and cleared afterwards.
  private var decoderEncoderStateTensorsOV: Option[org.intel.openvino.Tensor] = None
  private var encoderAttentionMaskTensorsOV: Option[org.intel.openvino.Tensor] = None

  /** Backend in use. ONNX wins when both are present, and ONNX is also the fallback name. */
  val detectedEngine: String =
    if (onnxWrappers.isDefined) ONNX.name
    else if (openvinoWrapper.isDefined) Openvino.name
    else ONNX.name

  private val GenerationConfig(
    bosTokenId: Int,
    paddingTokenId: Int,
    eosTokenId: Int,
    vocabSize: Int,
    beginSuppressTokens,
    suppressTokenIds,
    forcedDecoderIds) =
    generationConfig

  private val reverseVocab: Map[Int, String] = vocab.map(_.swap)

  /** Decodes model token ids back to text.
    *
    * Ids `>= vocabSize - LanguageTokenBlockSize` (the language-token block) are dropped. Ids
    * missing from the vocabulary become empty pieces.
    *
    * @param sentences
    *   one array of model token ids per sentence
    * @return
    *   one decoded string per sentence
    */
  def decode(sentences: Array[Array[Int]]): Seq[String] =
    sentences.map { ids =>
      val pieces = ids
        .filter(_ < vocabSize - LanguageTokenBlockSize)
        .map(id => reverseVocab.getOrElse(id, ""))
      spp.getSppModel.decodePieces(pieces.toList.asJava)
    }

  /** Encodes annotations into model token ids.
    *
    * Each annotation's `result` is split into SentencePiece pieces, and each piece is looked up
    * in `vocab`. Unknown pieces map to id 0.
    * NOTE(review): confirm that 0 is the intended fallback id for this vocabulary.
    *
    * @param sentences
    *   annotations whose `result` text is encoded
    * @return
    *   one array of model token ids per annotation
    */
  def encode(sentences: Seq[Annotation]): Seq[Array[Int]] =
    sentences.map { sentence =>
      spp.getSppModel
        .encodeAsPieces(sentence.result)
        .toArray
        .map(piece => vocab.getOrElse(piece.toString, 0))
    }

  /** Translates a batch of already-encoded sentences from a source to a target language.
    *
    * Each sentence is truncated to `maxInputLength` ids and framed as `[srcLang] ids [EOS]`. It is
    * repeated `beamSize` times and right-padded so the encoder batch is rectangular. Decoding
    * starts from `[EOS, tgtLang]`.
    *
    * @param batch
    *   a batch of sentences to translate, as model token ids
    * @param minOutputLength
    *   minimum length of the output
    * @param maxOutputLength
    *   maximum length of the output, added on top of the longest (truncated) input
    * @param doSample
    *   whether to sample or not
    * @param temperature
    *   temperature for sampling
    * @param topK
    *   topK for sampling
    * @param topP
    *   topP for sampling
    * @param repetitionPenalty
    *   repetition penalty for sampling
    * @param noRepeatNgramSize
    *   no repeat ngram size for sampling
    * @param randomSeed
    *   random seed for sampling
    * @param ignoreTokenIds
    *   token ids to ignore
    * @param beamSize
    *   beam size for beam search
    * @param maxInputLength
    *   maximum length of the input
    * @param srcLangToken
    *   source language token
    * @param tgtLangToken
    *   target language token
    * @return
    *   generated token ids per sentence; empty for an empty batch or an unimplemented backend
    */
  def tag(
      batch: Seq[Array[Int]],
      minOutputLength: Int,
      maxOutputLength: Int,
      doSample: Boolean,
      temperature: Double,
      topK: Int,
      topP: Double,
      repetitionPenalty: Double,
      noRepeatNgramSize: Int,
      randomSeed: Option[Long],
      ignoreTokenIds: Array[Int] = Array(),
      beamSize: Int,
      maxInputLength: Int,
      srcLangToken: Int,
      tgtLangToken: Int): Array[Array[Int]] =
    if (batch.isEmpty) Array.empty[Array[Int]]
    else {
      val (encoderInputIds, encoderAttentionMask) =
        buildEncoderInputs(batch, beamSize, maxInputLength, srcLangToken)

      // Longest truncated source sentence, excluding the added language/EOS tokens. The decoding
      // budget handed to `generate` is maxOutputLength on top of it.
      val maxSentenceLength = batch.map(_.take(maxInputLength).length).max

      val tgtLangId = vocabSize + tgtLangToken - LanguageTokenBlockSize
      val decoderInputIds: Array[Array[Int]] =
        batch.map(_ => Array(eosTokenId, tgtLangId)).toArray

      // Beam search / sampling shared by both backends; only the tensors and sessions differ.
      def runGenerate(
          encoderState: Either[Tensor, OnnxTensor],
          encoderMask: Either[Tensor, OnnxTensor],
          decoderOnnxSession: Either[Session, (OrtEnvironment, OrtSession)],
          decoderRequest: Option[InferRequest]): Array[Array[Int]] =
        generate(
          batch,
          encoderState,
          encoderMask,
          decoderInputIds,
          maxOutputLength + maxSentenceLength,
          minOutputLength,
          doSample,
          beamSize,
          1,
          temperature,
          topK,
          topP,
          repetitionPenalty,
          noRepeatNgramSize,
          this.vocabSize,
          this.eosTokenId,
          this.paddingTokenId,
          randomSeed,
          ignoreTokenIds,
          decoderOnnxSession,
          applySoftmax = false,
          ovInferRequest = decoderRequest)

      detectedEngine match {
        case ONNX.name =>
          val wrappers = onnxWrappers.getOrElse(
            throw new IllegalStateException("NLLB: no ONNX or OpenVINO model is loaded"))
          val (encoderSession, encoderEnv) = wrappers.encoder.getSession(onnxSessionOptions)
          val (decoderSession, decoderEnv) = wrappers.decoder.getSession(onnxSessionOptions)
          try {
            val encoderState =
              getEncoderOutput(
                encoderInputIds,
                encoderAttentionMask,
                Right((encoderEnv, encoderSession)))
            try {
              withClosing(OnnxTensor.createTensor(decoderEnv, encoderAttentionMask)) {
                maskTensor =>
                  runGenerate(
                    encoderState,
                    Right(maskTensor),
                    Right((decoderEnv, decoderSession)),
                    None)
              }
            } finally encoderState.fold(_ => (), _.close())
          } finally {
            encoderEnv.close()
            decoderEnv.close()
          }

        case Openvino.name =>
          val wrapper = openvinoWrapper.getOrElse(
            throw new IllegalStateException("NLLB: no OpenVINO model is loaded"))
          val encoderRequest = wrapper.encoder.getCompiledModel().create_infer_request()
          val decoderRequest = wrapper.decoder.getCompiledModel().create_infer_request()
          val shape = Array(encoderInputIds.length, encoderInputIds.head.length)
          try {
            decoderEncoderStateTensorsOV =
              Some(getEncoderOutputOv(encoderInputIds, encoderAttentionMask, encoderRequest))
            encoderAttentionMaskTensorsOV =
              Some(new org.intel.openvino.Tensor(shape, encoderAttentionMask.flatten))
            runGenerate(null, null, Right((null, null)), Some(decoderRequest))
          } finally {
            // getModelOutput only reads these while `generate` runs; drop them afterwards.
            decoderEncoderStateTensorsOV = None
            encoderAttentionMaskTensorsOV = None
          }

        case _ =>
          // TensorFlow: not implemented yet
          Array.empty[Array[Int]]
      }
    }

  /** Translates annotations from a source language to a target language.
    *
    * The translations are laid end to end: each output annotation begins right after the
    * previous one ends, and it keeps the metadata of its source annotation.
    *
    * @param sentences
    *   a batch of sentences to translate
    * @param batchSize
    *   batch size
    * @param minOutputLength
    *   minimum length of the output
    * @param maxOutputLength
    *   maximum length of the output
    * @param doSample
    *   whether to sample or not
    * @param temperature
    *   temperature for sampling
    * @param topK
    *   topK for sampling
    * @param topP
    *   topP for sampling
    * @param repetitionPenalty
    *   repetition penalty for sampling
    * @param noRepeatNgramSize
    *   no repeat ngram size for sampling
    * @param randomSeed
    *   random seed for sampling
    * @param ignoreTokenIds
    *   token ids to ignore
    * @param beamSize
    *   beam size for beam search
    * @param maxInputLength
    *   maximum length of the input
    * @param srcLangToken
    *   source language token
    * @param tgtLangToken
    *   target language token
    * @return
    *   one DOCUMENT annotation per translated sentence
    */
  def predict(
      sentences: Seq[Annotation],
      batchSize: Int,
      minOutputLength: Int,
      maxOutputLength: Int,
      doSample: Boolean,
      temperature: Double,
      topK: Int,
      topP: Double,
      repetitionPenalty: Double,
      noRepeatNgramSize: Int,
      randomSeed: Option[Long] = None,
      ignoreTokenIds: Array[Int] = Array(),
      beamSize: Int,
      maxInputLength: Int,
      srcLangToken: Int,
      tgtLangToken: Int): Seq[Annotation] = {

    val translations: Array[String] = sentences.grouped(batchSize).toArray.flatMap { batch =>
      val outputIds = tag(
        encode(batch),
        minOutputLength,
        maxOutputLength,
        doSample,
        temperature,
        topK,
        topP,
        repetitionPenalty,
        noRepeatNgramSize,
        randomSeed,
        ignoreTokenIds,
        beamSize,
        maxInputLength,
        srcLangToken,
        tgtLangToken)
      decode(outputIds)
    }

    // begins(i) is the start offset of translation i; begins(i + 1) - 1 is its inclusive end.
    val begins = translations.scanLeft(0)((offset, text) => offset + text.length)
    translations.zip(sentences).zipWithIndex.map { case ((content, sentence), i) =>
      new Annotation(
        annotatorType = DOCUMENT,
        begin = begins(i),
        end = begins(i + 1) - 1,
        result = content,
        metadata = sentence.metadata)
    }
  }

  /** Frames, beam-expands and right-pads encoder input ids.
    *
    * Each sentence becomes `[srcLang] ++ ids.take(maxInputLength) ++ [EOS]`, repeated `beamSize`
    * times. Every row is then padded with `paddingTokenId` to the longest row.
    *
    * @return
    *   (rectangular input ids, attention mask with 1 on real tokens and 0 on padding)
    */
  private def buildEncoderInputs(
      batch: Seq[Array[Int]],
      beamSize: Int,
      maxInputLength: Int,
      srcLangToken: Int): (Array[Array[Int]], Array[Array[Long]]) = {
    val srcLangId = vocabSize + srcLangToken - LanguageTokenBlockSize
    val framed: Array[Array[Int]] = batch.toArray.flatMap { ids =>
      val row = Array(srcLangId) ++ ids.take(maxInputLength) ++ Array(eosTokenId)
      Array.fill(beamSize)(row)
    }
    val paddedLength = framed.map(_.length).max
    val paddedIds =
      framed.map(row => row ++ Array.fill(paddedLength - row.length)(paddingTokenId))
    val attentionMask =
      framed.map(row => Array.tabulate(paddedLength)(i => if (i < row.length) 1L else 0L))
    (paddedIds, attentionMask)
  }

  /** Runs `body` on `resource` and closes the resource afterwards, even if `body` throws. */
  private def withClosing[R <: AutoCloseable, A](resource: R)(body: R => A): A =
    try body(resource)
    finally resource.close()

  /** Runs the ONNX encoder and returns its last hidden state as a new tensor owned by the caller.
    *
    * @param encoderInputIds
    *   rectangular input ids for the encoder
    * @param encoderAttentionMask
    *   attention mask matching `encoderInputIds` (1 = token, 0 = padding)
    * @param session
    *   TensorFlow/ONNX session; the TensorFlow side is not implemented and yields null
    * @return
    *   last hidden state of the encoder; the caller must close it
    */
  private def getEncoderOutput(
      encoderInputIds: Array[Array[Int]],
      encoderAttentionMask: Array[Array[Long]],
      session: Either[Session, (OrtEnvironment, OrtSession)]): Either[Tensor, OnnxTensor] =
    session.fold[Either[Tensor, OnnxTensor]](
      _ => null, // TensorFlow: not implemented yet
      { case (env, encoderSession) =>
        val hiddenStates =
          withClosing(OnnxTensor.createTensor(env, encoderInputIds.map(_.map(_.toLong)))) {
            inputIdsTensor =>
              withClosing(OnnxTensor.createTensor(env, encoderAttentionMask)) { maskTensor =>
                val encoderInputs: java.util.Map[String, OnnxTensor] = Map(
                  OnnxSignatures.encoderInputIDs -> inputIdsTensor,
                  OnnxSignatures.encoderAttentionMask -> maskTensor).asJava
                withClosing(encoderSession.run(encoderInputs)) { results =>
                  val stateTensor = results
                    .get(OnnxSignatures.encoderOutput)
                    .get()
                    .asInstanceOf[OnnxTensor]
                  // Rebuild (batch, sequence, hidden) from the flat buffer using the output shape.
                  val shape = stateTensor.getInfo.getShape
                  stateTensor.getFloatBuffer
                    .array()
                    .grouped(shape(2).toInt)
                    .toArray
                    .grouped(shape(1).toInt)
                    .toArray
                }
              }
          }
        // Copy out of the session results, which are closed above.
        Right(OnnxTensor.createTensor(env, hiddenStates))
      })

  /** Runs the OpenVINO encoder and returns its last hidden state.
    *
    * @param encoderInputIds
    *   rectangular input ids for the encoder
    * @param encoderAttentionMask
    *   attention mask matching `encoderInputIds` (1 = token, 0 = padding)
    * @param inferRequest
    *   infer request of the compiled encoder model
    * @return
    *   the encoder output tensor read from `inferRequest`
    */
  private def getEncoderOutputOv(
      encoderInputIds: Array[Array[Int]],
      encoderAttentionMask: Array[Array[Long]],
      inferRequest: InferRequest): org.intel.openvino.Tensor = {
    val shape = Array(encoderInputIds.length, encoderInputIds.head.length)
    val inputIdsTensor =
      new org.intel.openvino.Tensor(shape, encoderInputIds.flatMap(_.map(_.toLong)))
    val maskTensor = new org.intel.openvino.Tensor(shape, encoderAttentionMask.flatten)

    inferRequest.set_tensor(OpenVinoSignatures.encoderInputIDs, inputIdsTensor)
    inferRequest.set_tensor(OpenVinoSignatures.encoderAttentionMask, maskTensor)
    inferRequest.infer()
    inferRequest.get_tensor(OpenVinoSignatures.encoderOutput)
  }

  /** Returns the next-token logits for every decoder row, using the active backend.
    *
    * @param encoderInputIds
    *   Input IDs for the Encoder (unused here)
    * @param decoderInputIds
    *   Input IDs for the Decoder
    * @param decoderEncoderStateTensors
    *   encoder output for the ONNX backend
    * @param encoderAttentionMaskTensors
    *   encoder attention mask for the ONNX backend
    * @param maxLength
    *   Max length of the input (unused here)
    * @param session
    *   TensorFlow/ONNX session; must be `Right` for the ONNX backend
    * @param ovInferRequest
    *   decoder infer request for the OpenVINO backend
    * @return
    *   logits of the last position, one row per decoder input
    */
  override def getModelOutput(
      encoderInputIds: Seq[Array[Int]],
      decoderInputIds: Seq[Array[Int]],
      decoderEncoderStateTensors: Either[Tensor, OnnxTensor],
      encoderAttentionMaskTensors: Either[Tensor, OnnxTensor],
      maxLength: Int,
      session: Either[Session, (OrtEnvironment, OrtSession)],
      ovInferRequest: Option[InferRequest]): Array[Array[Float]] =
    detectedEngine match {
      case TensorFlow.name =>
        // not implemented yet
        Array.empty[Array[Float]]
      case ONNX.name =>
        session match {
          case Right((env, decoderSession)) =>
            getDecoderOutputs(
              decoderInputIds.toArray,
              decoderEncoderStateTensors,
              encoderAttentionMaskTensors,
              onnxSession = (decoderSession, env))
          case Left(_) =>
            throw new IllegalStateException("NLLB: the ONNX backend needs an ONNX decoder session")
        }
      case Openvino.name =>
        getDecoderOutputsOv(
          decoderInputIds.toArray,
          decoderEncoderStateTensorsOV.getOrElse(
            throw new IllegalStateException("NLLB: OpenVINO encoder state is not set")),
          encoderAttentionMaskTensorsOV.getOrElse(
            throw new IllegalStateException("NLLB: OpenVINO encoder mask is not set")),
          ovInferRequest.getOrElse(
            throw new IllegalStateException("NLLB: OpenVINO decoder request is missing")))
    }

  /** Runs the ONNX decoder once and returns the logits of the last position of each row.
    *
    * @param inputIds
    *   decoder input ids (rectangular)
    * @param decoderEncoderStateTensors
    *   encoder hidden states
    * @param encoderAttentionMaskTensors
    *   encoder attention mask
    * @param onnxSession
    *   decoder session and its environment
    * @return
    *   one `vocabSize`-long logits row per input row
    */
  private def getDecoderOutputs(
      inputIds: Array[Array[Int]],
      decoderEncoderStateTensors: Either[Tensor, OnnxTensor],
      encoderAttentionMaskTensors: Either[Tensor, OnnxTensor],
      onnxSession: (OrtSession, OrtEnvironment)): Array[Array[Float]] = {
    val (session, env) = onnxSession

    // The TensorFlow side of these Eithers is not implemented and maps to null.
    val encoderAttentionMaskTensor = encoderAttentionMaskTensors.fold[OnnxTensor](_ => null, t => t)
    val decoderEncoderStateTensor = decoderEncoderStateTensors.fold[OnnxTensor](_ => null, t => t)

    val logitsRaw: Array[Float] =
      withClosing(OnnxTensor.createTensor(env, inputIds.map(_.map(_.toLong)))) { inputIdsTensor =>
        val decoderInputs: java.util.Map[String, OnnxTensor] = Map(
          OnnxSignatures.decoderInputIDs -> inputIdsTensor,
          OnnxSignatures.decoderEncoderAttentionMask -> encoderAttentionMaskTensor,
          OnnxSignatures.decoderEncoderState -> decoderEncoderStateTensor).asJava
        withClosing(session.run(decoderInputs)) { results =>
          results
            .get(OnnxSignatures.decoderOutput)
            .get()
            .asInstanceOf[OnnxTensor]
            .getFloatBuffer
            .array()
        }
      }

    lastStepLogits(logitsRaw, inputIds.length, inputIds.head.length)
  }

  /** Runs the OpenVINO decoder once and returns the logits of the last position of each row.
    *
    * @param inputIds
    *   decoder input ids (rectangular)
    * @param decoderEncoderStateTensors
    *   encoder hidden states
    * @param encoderAttentionMaskTensors
    *   encoder attention mask
    * @param inferRequest
    *   infer request of the compiled decoder model
    * @return
    *   one `vocabSize`-long logits row per input row
    */
  private def getDecoderOutputsOv(
      inputIds: Array[Array[Int]],
      decoderEncoderStateTensors: org.intel.openvino.Tensor,
      encoderAttentionMaskTensors: org.intel.openvino.Tensor,
      inferRequest: InferRequest): Array[Array[Float]] = {
    val batchSize = inputIds.length
    val inputIdsLong: Array[Long] = inputIds.flatMap(_.map(_.toLong))
    val sequenceLength = inputIdsLong.length / batchSize
    val inputIdsTensor =
      new org.intel.openvino.Tensor(Array(batchSize, sequenceLength), inputIdsLong)

    inferRequest.set_tensor(OpenVinoSignatures.decoderInputIDs, inputIdsTensor)
    inferRequest.set_tensor(OpenVinoSignatures.decoderEncoderState, decoderEncoderStateTensors)
    inferRequest.set_tensor(
      OpenVinoSignatures.decoderEncoderAttentionMask,
      encoderAttentionMaskTensors)
    inferRequest.infer()

    val logitsRaw = inferRequest.get_tensor(OpenVinoSignatures.decoderOutput).data()
    lastStepLogits(logitsRaw, batchSize, sequenceLength)
  }

  /** Extracts the last position's logits of every row from a flat buffer laid out as
    * (batch, sequence, vocabSize).
    */
  private def lastStepLogits(
      logitsRaw: Array[Float],
      batchSize: Int,
      sequenceLength: Int): Array[Array[Float]] =
    Array.tabulate(batchSize) { row =>
      val rowEnd = (row + 1) * sequenceLength * vocabSize
      logitsRaw.slice(rowEnd - vocabSize, rowEnd)
    }

  private object OnnxSignatures {
    val encoderInputIDs: String = "input_ids"
    val encoderAttentionMask: String = "attention_mask"

    val encoderOutput: String = "last_hidden_state"

    val decoderInputIDs: String = "input_ids"
    val decoderEncoderAttentionMask: String = "encoder_attention_mask"
    val decoderEncoderState: String = "encoder_hidden_states"

    val decoderOutput: String = "logits"
  }

  private object OpenVinoSignatures {
    val encoderInputIDs: String = "input_ids"
    val encoderAttentionMask: String = "attention_mask"

    val encoderOutput: String = "last_hidden_state"

    val decoderInputIDs: String = "input_ids"
    val decoderEncoderAttentionMask: String = "encoder_attention_mask"
    val decoderAttentionMask: String = "attention_mask"
    val decoderEncoderState: String = "encoder_hidden_states"

    val decoderOutput: String = "logits"
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy