com.johnsnowlabs.ml.ai.seq2seq.TensorflowMarianEncoderDecoder.scala

package com.johnsnowlabs.ml.ai.seq2seq

import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
import com.johnsnowlabs.ml.tensorflow.sentencepiece.SentencePieceWrapper
import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}

import scala.collection.JavaConverters._

/** TensorFlow backend for Marian NMT encoder-decoder translation models.
  *
  * @param tensorflow wrapped TF model and session
  * @param sppSrc SentencePiece model for the source language
  * @param sppTrg SentencePiece model for the target language
  * @param configProtoBytes optional serialized TF ConfigProto for session tuning
  * @param signatures optional map from canonical signature names to the
  *   model's actual input/output tensor names
  */
private[johnsnowlabs] class TensorflowMarianEncoderDecoder(
    val tensorflow: TensorflowWrapper,
    override val sppSrc: SentencePieceWrapper,
    override val sppTrg: SentencePieceWrapper,
    configProtoBytes: Option[Array[Byte]] = None,
    signatures: Option[Map[String, String]] = None)
    extends MarianEncoderDecoder(sppSrc, sppTrg) {

  // Tensor names resolved from user-provided signatures, falling back to the
  // defaults supplied by ModelSignatureManager.
  val _tfMarianSignatures: Map[String, String] =
    signatures.getOrElse(ModelSignatureManager.apply())

  // Warm up the TF session so graph loading and optimization are not paid on
  // the first real translation request.
  sessionWarmup()

  /** Translates a batch of tokenized sentences.
    *
    * Runs the encoder once over the padded batch, then decodes
    * autoregressively until every sequence emits EOS or the length limit is
    * reached.
    *
    * @param batch token ids per sentence, produced by the source
    *   SentencePiece model
    * @return generated target token ids, with EOS and padding removed
    */
  override protected def tag(
      batch: Seq[Array[Int]],
      maxOutputLength: Int,
      paddingTokenId: Int,
      eosTokenId: Int,
      vocabSize: Int,
      doSample: Boolean = false,
      temperature: Double = 1.0d,
      topK: Int = 50,
      topP: Double = 1.0d,
      repetitionPenalty: Double = 1.0d,
      noRepeatNgramSize: Int = 0,
      randomSeed: Option[Long] = None,
      ignoreTokenIds: Array[Int] = Array()): Array[Array[Int]] = {

    /* Actual length of each sentence, so padding can be skipped in the TF model */
    val sequencesLength = batch.map(x => x.length).toArray
    val maxSentenceLength = sequencesLength.max

    // Run encoder
    val tensorEncoder = new TensorResources()
    val inputDim = batch.length * maxSentenceLength // total token positions; used later to infer the hidden size

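    // Flat int buffers, row-major [batch, maxSentenceLength], backing the
    // encoder input ids and the encoder/decoder attention masks.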
    val encoderInputIdsBuffers = tensorEncoder.createIntBuffer(batch.length * maxSentenceLength)
    val encoderAttentionMaskBuffers =
      tensorEncoder.createIntBuffer(batch.length * maxSentenceLength)
    val decoderAttentionMaskBuffers =
      tensorEncoder.createIntBuffer(batch.length * maxSentenceLength)

    val shape = Array(batch.length.toLong, maxSentenceLength)

    batch.zipWithIndex.foreach { case (tokenIds, idx) =>
      // offset marks the beginning of this sentence in the flattened buffers
      val offset = idx * maxSentenceLength
      val diff = maxSentenceLength - tokenIds.length

      val s = tokenIds.take(maxSentenceLength) ++ Array.fill[Int](diff)(paddingTokenId)
      encoderInputIdsBuffers.offset(offset).write(s)
      val mask = s.map(x => if (x != paddingTokenId) 1 else 0)
      encoderAttentionMaskBuffers.offset(offset).write(mask)
      decoderAttentionMaskBuffers.offset(offset).write(mask)
    }

    val encoderInputIdsTensors =
      tensorEncoder.createIntBufferTensor(shape, encoderInputIdsBuffers)
    val encoderAttentionMaskKeyTensors =
      tensorEncoder.createIntBufferTensor(shape, encoderAttentionMaskBuffers)
    val decoderAttentionMaskTensors =
      tensorEncoder.createIntBufferTensor(shape, decoderAttentionMaskBuffers)

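    // A single TF session serves the encoder pass and every decoder step.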
    val session = tensorflow.getTFSessionWithSignature(
      configProtoBytes = configProtoBytes,
      initAllTables = false,
      savedSignatures = signatures)
    val runner = session.runner

    runner
      .feed(
        _tfMarianSignatures
          .getOrElse(ModelSignatureConstants.EncoderInputIds.key, "missing_encoder_input_ids"),
        encoderInputIdsTensors)
      .feed(
        _tfMarianSignatures.getOrElse(
          ModelSignatureConstants.EncoderAttentionMask.key,
          "missing_encoder_attention_mask"),
        encoderAttentionMaskKeyTensors)
      .fetch(_tfMarianSignatures
        .getOrElse(ModelSignatureConstants.EncoderOutput.key, "missing_last_hidden_state"))

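    // Extract the encoder's last hidden state and regroup the flat float
    // array into [batch][time][hiddenDim].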
    val encoderOuts = runner.run().asScala
    val encoderOutsFloats = TensorResources.extractFloats(encoderOuts.head)
    val dim = encoderOutsFloats.length / inputDim
    val encoderOutsBatch =
      encoderOutsFloats.grouped(dim).toArray.grouped(maxSentenceLength).toArray

    encoderOuts.foreach(_.close())
    tensorEncoder.clearSession(encoderOuts)

    // Run decoder
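    // Copy the encoder hidden states into a tensor that is fed to the decoder
    // on every generation step.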
    val decoderEncoderStateBuffers =
      tensorEncoder.createFloatBuffer(batch.length * maxSentenceLength * dim)
    batch.zipWithIndex.foreach { case (_, index) =>
      var offset = index * maxSentenceLength * dim
      encoderOutsBatch(index).foreach(encoderOutput => {
        decoderEncoderStateBuffers.offset(offset).write(encoderOutput)
        offset += dim
      })
    }

    val decoderEncoderStateTensors = tensorEncoder.createFloatBufferTensor(
      Array(batch.length.toLong, maxSentenceLength, dim),
      decoderEncoderStateBuffers)

    // Marian models use the padding token as the decoder start token
    var decoderInputIds = batch.map(_ => Array(paddingTokenId)).toArray
    val batchSize = decoderInputIds.length

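    // DecoderProcessor encapsulates the generation strategy: greedy search or
    // sampling (topK/topP/temperature), repetition penalty, n-gram blocking
    // and stop-token detection.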
    val decoderProcessor = new DecoderProcessor(
      batchSize = batchSize,
      maxTextLength = maxOutputLength + maxSentenceLength,
      sequenceLength = decoderInputIds(0).length,
      doSample = doSample,
      topK = topK,
      topP = topP,
      temperature = temperature,
      vocabSize = vocabSize,
      noRepeatNgramSize = noRepeatNgramSize,
      randomSeed = randomSeed,
      stopTokens = Array(eosTokenId),
      paddingTokenId = paddingTokenId,
      ignoreTokenIds = ignoreTokenIds ++ Array(paddingTokenId),
      maxNewTokens = maxOutputLength,
      repetitionPenalty = repetitionPenalty)

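    // Autoregressive loop: feed the tokens generated so far, take the logits
    // of the last position, and append one new token per sequence.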
    while (!decoderProcessor.stopDecoding(decoderInputIds)) {
      val decoderInputLength = decoderInputIds.head.length
      val tensorDecoder = new TensorResources()

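      // Re-pack the current decoder input ids (all sequences share the same
      // length) into a flat buffer for this step's feed.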
      val decoderInputBuffers = tensorDecoder.createIntBuffer(batch.length * decoderInputLength)

      decoderInputIds.zipWithIndex.foreach { case (pieceIds, idx) =>
        val offset = idx * decoderInputLength
        decoderInputBuffers.offset(offset).write(pieceIds)
      }

      val decoderInputTensors = tensorDecoder.createIntBufferTensor(
        Array(batch.length.toLong, decoderInputLength),
        decoderInputBuffers)

      val runner = session.runner

      runner
        .feed(
          _tfMarianSignatures.getOrElse(
            ModelSignatureConstants.DecoderEncoderInputIds.key,
            "missing_encoder_state"),
          decoderEncoderStateTensors)
        .feed(
          _tfMarianSignatures
            .getOrElse(ModelSignatureConstants.DecoderInputIds.key, "missing_decoder_input_ids"),
          decoderInputTensors)
        .feed(
          _tfMarianSignatures.getOrElse(
            ModelSignatureConstants.DecoderAttentionMask.key,
            "missing_encoder_attention_mask"),
          decoderAttentionMaskTensors)
        .fetch(_tfMarianSignatures
          .getOrElse(ModelSignatureConstants.DecoderOutput.key, "missing_output_0"))

      val decoderOuts = runner.run().asScala

      val logitsRaw = TensorResources.extractFloats(decoderOuts.head)
      decoderOuts.head.close()

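      // Keep only the logits of the last decoder position for each batch
      // element, i.e. the slice [i, decoderInputLength - 1, 0 until vocabSize].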
      val logits = (0 until batchSize).map(i => {
        logitsRaw
          .slice(
            i * decoderInputLength * vocabSize + (decoderInputLength - 1) * vocabSize,
            i * decoderInputLength * vocabSize + decoderInputLength * vocabSize)
      })

      decoderInputIds = decoderProcessor.processLogits(
        batchLogits = logits.toArray,
        decoderInputIds = decoderInputIds)
      decoderInputTensors.close()
      tensorDecoder.clearSession(decoderOuts)
      tensorDecoder.clearTensors()
    }

    // Strip EOS and padding before returning the generated sequences
    decoderInputIds.map(x => x.filter(y => y != eosTokenId && y != paddingTokenId))
  }
}
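
// Construction sketch (hypothetical wiring; `tfWrapper`, `sourceSpp` and
// `targetSpp` are placeholders for objects obtained through the library's
// own loading utilities, which are not shown in this file):
//
//   val marian = new TensorflowMarianEncoderDecoder(
//     tensorflow = tfWrapper,   // TensorflowWrapper around the saved model
//     sppSrc = sourceSpp,       // SentencePieceWrapper for the source language
//     sppTrg = targetSpp)       // SentencePieceWrapper for the target language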



