All downloads are free. Search and download functionality uses the official Maven repository.

com.johnsnowlabs.ml.ai.t5.T5EncoderDecoder.scala Maven / Gradle / Ivy

package com.johnsnowlabs.ml.ai.t5

import com.johnsnowlabs.ml.tensorflow.sentencepiece.SentencePieceWrapper
import com.johnsnowlabs.nlp.{Annotation, AnnotatorType}

import scala.collection.mutable
import scala.math.exp

/** Shared machinery for T5-style encoder-decoder generation.
  *
  * Converts text to SentencePiece ids and back, batches requests, wraps generated text into
  * [[Annotation]]s, and provides a [[DecoderProcessor]] that post-processes the logits of each
  * decoding step (ignored tokens, repetition penalty, n-gram blocking, temperature, top-k / top-p
  * sampling). Concrete subclasses implement the model-specific [[tag]] loop.
  *
  * @param spp
  *   SentencePiece model used for encoding and decoding
  * @param additionalTokens
  *   id -> text mappings that [[decode]] emits verbatim instead of decoding with SentencePiece
  */
abstract class T5EncoderDecoder(
    val spp: SentencePieceWrapper,
    val additionalTokens: Map[Int, String] = Map()) {

  /** Id written in place of new tokens for sentences that have already finished. */
  protected val paddingTokenId = 0

  /** End-of-sequence id appended to every encoded input. */
  protected val eosTokenId = 1

  /** Number of pieces in the SentencePiece model; its ids are `0 until pieceSize`. */
  protected val pieceSize: Int = spp.getSppModel.getPieceSize

  // NOTE(review): hard-coded output vocabulary size; presumably the logits width of the T5
  // checkpoint (SentencePiece pieces plus extra ids) — confirm against the exported model.
  protected val vocabSize = 32128

  /** Runs [[tag]] once on a two-token dummy input (greedy, a single new token) and discards the
    * result; presumably used to warm up the underlying inference session before real requests.
    */
  def sessionWarmup(): Unit = {
    val dummyInput = Array(0, eosTokenId)

    tag(
      Seq(dummyInput),
      maxNewTokens = 1,
      maxTextLength = 1,
      doSample = false,
      temperature = 0f,
      topK = 0,
      topP = 0f,
      repetitionPenalty = 0f,
      noRepeatNgramSize = 0,
      randomSeed = Option(0L),
      stopAtEos = true,
      ignoreTokenIds = Array(0))
  }

  /** Converts generated id sequences back to text.
    *
    * Ids that are neither valid SentencePiece ids (`0 until pieceSize`) nor keys of
    * `additionalTokens` are dropped. Ids found in `additionalTokens` are replaced by their mapped
    * string; the runs of ids between them are decoded with SentencePiece.
    *
    * @param sentences
    *   one id sequence per generated sentence
    * @return
    *   one decoded string per input sequence, in the same order
    */
  protected def decode(sentences: Array[Array[Int]]): Seq[String] =
    sentences.map(ids => decodeSequence(ids)).toSeq

  /** Decodes one id sequence, splicing in `additionalTokens` strings where they occur. */
  private def decodeSequence(ids: Array[Int]): String = {
    // `< pieceSize` (not `<=`): `pieceSize` itself is outside the SentencePiece id range and
    // would make `decodeIds` fail unless it is mapped in `additionalTokens`.
    val kept = ids.filter(id => (id >= 0 && id < pieceSize) || additionalTokens.contains(id))
    val sppModel = spp.getSppModel
    val text = new StringBuilder
    var segmentStart = 0
    kept.indices.foreach { pos =>
      additionalTokens.get(kept(pos)).foreach { special =>
        text.append(sppModel.decodeIds(kept.slice(segmentStart, pos): _*))
        text.append(special)
        segmentStart = pos + 1
      }
    }
    text.append(sppModel.decodeIds(kept.slice(segmentStart, kept.length): _*))
    text.toString
  }

  /** Encodes annotation texts (optionally prefixed with `task` and a space) and appends EOS. */
  protected def encode(sentences: Seq[Annotation], task: String): Seq[Array[Int]] =
    sentences.map(s => encodeWithTask(s.result, task))

  /** Encodes raw strings (optionally prefixed with `task` and a space) and appends EOS. */
  protected def encodeS(sentences: Seq[String], task: String): Seq[Array[Int]] =
    sentences.map(s => encodeWithTask(s, task))

  /** Prefixes `text` with `task` plus a space when `task` is non-empty, encodes it with
    * SentencePiece and appends [[eosTokenId]].
    */
  private def encodeWithTask(text: String, task: String): Array[Int] = {
    val input = if (task.nonEmpty) task.concat(" ").concat(text) else text
    spp.getSppModel.encodeAsIds(input) :+ eosTokenId
  }

  /** Returns the tokens that would complete an n-gram already generated by hypothesis
    * `hypoIdx`, given its `noRepeatNgramSize - 1` tokens ending just before position `curLen`.
    */
  protected def getGeneratedNgrams(
      prevInputIds: Seq[Array[Int]],
      generatedNgrams: Array[mutable.Map[IndexedSeq[Int], List[Int]]],
      hypoIdx: Int,
      curLen: Int,
      noRepeatNgramSize: Int): Array[Int] = {
    // Before decoding the next token, prevent decoding of ngrams that have already appeared
    val startIdx = curLen + 1 - noRepeatNgramSize
    // explicit conversion: the map keys are IndexedSeq[Int], not arrays
    val ngramPrefix = prevInputIds(hypoIdx).slice(startIdx, curLen).toIndexedSeq
    generatedNgrams(hypoIdx).getOrElse(ngramPrefix, List.empty[Int]).toArray
  }

  /** Returns a copy of `prevInputIds` in which every position flagged `true` in `indices` is
    * replaced by `value`. The result is as long as the shorter of the two inputs.
    */
  protected def setTensorByIndicesToValue(
      prevInputIds: Array[Float],
      indices: IndexedSeq[Boolean],
      value: Float): Array[Float] =
    prevInputIds.zip(indices).map { case (current, masked) => if (masked) value else current }

  /** Computes, per hypothesis, the tokens that must not be generated next because they would
    * repeat an n-gram of size `noRepeatNgramSize` (based on fairseq's no-repeat-ngram logic).
    *
    * @return
    *   `numHypos` arrays of banned token ids; all empty when blocking is disabled
    *   (`noRepeatNgramSize <= 0`) or fewer than `noRepeatNgramSize` tokens exist yet
    */
  protected def calcBannedNgramTokens(
      prevInputIds: Seq[Array[Int]],
      numHypos: Int,
      noRepeatNgramSize: Int,
      curLen: Int): Array[Array[Int]] = {
    if (noRepeatNgramSize <= 0 || curLen + 1 < noRepeatNgramSize) {
      Array.fill(numHypos)(Array.empty[Int])
    } else {
      // For each hypothesis: map every (n-1)-token prefix to the tokens that followed it.
      val generatedNgrams = Array.tabulate(numHypos) { idx =>
        val tokens = prevInputIds(idx)
        val ngramMap = mutable.Map.empty[IndexedSeq[Int], List[Int]]
        for (start <- 0 to tokens.length - noRepeatNgramSize) {
          val prefix: IndexedSeq[Int] =
            tokens.slice(start, start + noRepeatNgramSize - 1).toIndexedSeq
          val next = tokens(start + noRepeatNgramSize - 1)
          ngramMap(prefix) = ngramMap.getOrElse(prefix, List.empty[Int]) :+ next
        }
        ngramMap
      }
      Array.tabulate(numHypos)(hypoIdx =>
        getGeneratedNgrams(prevInputIds, generatedNgrams, hypoIdx, curLen, noRepeatNgramSize))
    }
  }

  /** Model-specific generation loop implemented by subclasses.
    *
    * @param batch
    *   encoded input ids, one array per sentence
    * @return
    *   generated ids, one array per sentence
    */
  protected def tag(
      batch: Seq[Array[Int]],
      maxNewTokens: Int,
      maxTextLength: Int,
      doSample: Boolean,
      topK: Int,
      topP: Double,
      temperature: Double,
      noRepeatNgramSize: Int,
      repetitionPenalty: Double,
      randomSeed: Option[Long],
      ignoreTokenIds: Array[Int] = Array(),
      stopAtEos: Boolean): Array[Array[Int]]

  /** Generates text for every input annotation, in batches of `batchSize`.
    *
    * Each input text is optionally prefixed with `task`, encoded, passed to [[tag]] and decoded.
    * The resulting `DOCUMENT` annotations keep the input's metadata and are laid out back to
    * back: each one begins one character after the previous one ends. `isCaseSensitive` is not
    * used by this implementation.
    */
  def predict(
      sentences: Seq[Annotation],
      task: String,
      batchSize: Int,
      maxNewTokens: Int,
      maxTextLength: Int,
      doSample: Boolean,
      topK: Int,
      topP: Double,
      temperature: Double,
      randomSeed: Option[Long] = None,
      ignoreTokenIds: Array[Int] = Array(),
      isCaseSensitive: Boolean,
      stopAtEos: Boolean,
      noRepeatNgramSize: Int,
      repetitionPenalty: Double): Seq[Annotation] = {

    val generatedTexts = sentences.grouped(batchSize).toArray.flatMap { batch =>
      val spIds = tag(
        batch = encode(batch, task),
        maxNewTokens = maxNewTokens,
        maxTextLength = maxTextLength,
        doSample = doSample,
        topK = topK,
        topP = topP,
        temperature = temperature,
        randomSeed = randomSeed,
        ignoreTokenIds = ignoreTokenIds,
        stopAtEos = stopAtEos,
        noRepeatNgramSize = noRepeatNgramSize,
        repetitionPenalty = repetitionPenalty)
      decode(spIds)
    }

    // Contiguous offsets: [begin, begin + length - 1], next begin = previous end + 1.
    var sentBegin = 0
    generatedTexts
      .zip(sentences)
      .map { case (content, sent) =>
        val sentEnd = sentBegin + content.length - 1
        val annotation = new Annotation(
          annotatorType = AnnotatorType.DOCUMENT,
          begin = sentBegin,
          end = sentEnd,
          result = content,
          metadata = sent.metadata)
        sentBegin = sentEnd + 1
        annotation
      }
      .toSeq
  }

  /** Free-form generation: [[predict]] with an empty task prefix and `maxContextLength` as the
    * maximum text length.
    */
  def generate(
      prompts: Seq[Annotation],
      batchSize: Int,
      maxNewTokens: Int,
      maxContextLength: Int,
      doSample: Boolean,
      topK: Int,
      topP: Double,
      temperature: Double,
      randomSeed: Option[Long],
      ignoreTokenIds: Array[Int],
      isCaseSensitive: Boolean,
      stopAtEos: Boolean,
      noRepeatNgramSize: Int,
      repetitionPenalty: Double): Seq[Annotation] = {
    predict(
      sentences = prompts,
      task = "",
      batchSize = batchSize,
      maxNewTokens = maxNewTokens,
      maxTextLength = maxContextLength,
      doSample = doSample,
      topK = topK,
      topP = topP,
      temperature = temperature,
      randomSeed = randomSeed,
      ignoreTokenIds = ignoreTokenIds,
      isCaseSensitive = isCaseSensitive,
      stopAtEos = stopAtEos,
      noRepeatNgramSize = noRepeatNgramSize,
      repetitionPenalty = repetitionPenalty)
  }

  /** Per-request state and logits post-processing for one decoding loop.
    *
    * Each call to [[processLogits]] turns one step's logits (one row per sentence) into one new
    * token per sentence and appends it to the decoder input ids. A sentence that emits one of
    * `stopTokens` is marked finished and receives `paddingTokenId` on later steps.
    *
    * NOTE(review): `sequenceLength` seeds `currentLength`, which serves as the current length for
    * n-gram blocking — confirm what callers pass.
    */
  class DecoderProcessor(
      val batchSize: Int,
      val maxTextLength: Int,
      val sequenceLength: Int,
      val doSample: Boolean,
      val topK: Int,
      val topP: Double,
      val temperature: Double,
      val vocabSize: Int,
      val noRepeatNgramSize: Int,
      val repetitionPenalty: Double,
      val randomSeed: Option[Long],
      val stopTokens: Array[Int],
      val ignoreTokenIds: Array[Int],
      val maxNewTokens: Int) {

    /** 1 for sentences still being generated, 0 once a stop token was produced. */
    var unfinishedSentences: List[Int] = List.fill(batchSize)(1)

    /** Length at which each sentence finished (`maxTextLength` while unfinished). */
    var sentenceLengths: List[Int] = List.fill(batchSize)(maxTextLength)
    var currentLength = sequenceLength
    var nPredictedTokens = 0

    /** True when every sentence contains a stop token, `maxNewTokens` steps were taken, or the
      * first sentence is longer than `maxTextLength`.
      */
    def stopDecoding(decoderInputIds: Array[Array[Int]]): Boolean =
      decoderInputIds.forall(ids => ids.exists(t => stopTokens.contains(t))) ||
        nPredictedTokens >= maxNewTokens ||
        decoderInputIds.head.length > maxTextLength

    /** `Long`-id variant of [[stopDecoding]]. */
    def stopDecoding(decoderInputIds: Array[Array[Long]]): Boolean =
      stopDecoding(decoderInputIds.map(ids => ids.map(_.toInt)))

    /** `Long`-id variant of [[processLogits]]. */
    def processLogits(
        batchLogits: Array[Array[Float]],
        decoderInputIds: Array[Array[Long]]): Array[Array[Long]] =
      processLogits(batchLogits, decoderInputIds.map(ids => ids.map(_.toInt)))
        .map(ids => ids.map(_.toLong))

    /** Applies the CTRL repetition penalty (https://arxiv.org/abs/1909.05858).
      *
      * For row `i`, every distinct token already present in `inputIds(i)` has its logit
      * multiplied by `repetitionPenalty` when negative and divided by it otherwise. Token ids
      * outside the row are ignored. The input arrays are not modified.
      */
    def createNextTokenLogitsPenalties(
        inputIds: Seq[Array[Int]],
        logits: Array[Array[Float]],
        repetitionPenalty: Double): Array[Array[Float]] =
      Array.tabulate(logits.length) { i =>
        val penalized = logits(i).clone()
        // each row is penalized by its OWN history (not the first row's)
        inputIds(i).distinct.foreach { tokenId =>
          if (tokenId >= 0 && tokenId < penalized.length) {
            val logit = penalized(tokenId)
            val factor = if (logit < 0) repetitionPenalty else 1 / repetitionPenalty
            penalized(tokenId) = (factor * logit).toFloat
          }
        }
        penalized
      }

    /** Numerically stable softmax: subtracts the maximum (when finite) before exponentiating,
      * which yields the same distribution without overflowing `exp`.
      */
    private def softmax(values: Array[Float]): Array[Float] = {
      val maxValue = values.foldLeft(Double.NegativeInfinity)((m, v) => math.max(m, v.toDouble))
      val shift = if (maxValue.isInfinite || maxValue.isNaN) 0.0 else maxValue
      val expElem = values.map(v => exp(v - shift))
      val total = expElem.sum
      expElem.map(e => (e / total).toFloat)
    }

    /** Index of the first maximum of `values` (0 for an empty array). */
    private def argmax(values: Array[Float]): Int =
      values.indices.foldLeft(0)((best, i) => if (values(i) > values(best)) i else best)

    /** Draws one token index from the softmax over the finite entries of `logits`.
      *
      * Infinite entries (filtered tokens) are never chosen. If a single finite entry remains it
      * is returned directly; if none remain, the arg-max of `logits` is returned.
      *
      * NOTE(review): with a seed, a fresh `Random(seed)` is built on every call, so every step
      * uses the same uniform draw; left unchanged here.
      */
    private def categoricalSample(logits: Array[Float], randomSeed: Option[Long]): Int = {
      val (finiteLogits, tokenIds) =
        logits.zipWithIndex.filter { case (logit, _) => !logit.isInfinite }.sorted.unzip

      if (tokenIds.isEmpty) argmax(logits)
      else if (tokenIds.length == 1) tokenIds(0)
      else {
        val probabilities = softmax(finiteLogits)
        val draw = randomSeed.fold(scala.util.Random.nextDouble())(seed =>
          new scala.util.Random(seed).nextDouble())
        val cumulative = probabilities.scanLeft(0.0)(_ + _).tail
        val hit = cumulative.indexWhere(_ >= draw)
        if (hit >= 0) tokenIds(hit) else tokenIds(0)
      }
    }

    /** Filters each row of logits with top-k and then top-p (nucleus) filtering, replacing
      * removed entries by `filterValue`.
      *
      * Each row is filtered independently, and top-p operates on the output of top-k, so both
      * filters apply when both are enabled. See Holtzman et al. (http://arxiv.org/abs/1904.09751)
      * and https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317.
      *
      * @param topK
      *   keep only the `topK` highest logits (ties with the k-th are kept); disabled when <= 0
      * @param topP
      *   keep the smallest prefix of tokens, by descending probability, whose cumulative
      *   probability exceeds `topP`; disabled when >= 1.0
      * @param minTokensToKeep
      *   lower bound on the number of tokens each filter keeps
      */
    private def topKTopPFiltering(
        logits: Array[Array[Float]],
        topK: Int,
        topP: Double,
        filterValue: Float = Float.NegativeInfinity,
        minTokensToKeep: Int = 1): Array[Array[Float]] =
      logits.map { row =>
        val afterTopK =
          if (topK > 0) topKFilterRow(row, topK, filterValue, minTokensToKeep) else row
        if (topP < 1.0) topPFilterRow(afterTopK, topP, filterValue, minTokensToKeep)
        else afterTopK
      }

    /** Top-k filtering of a single row: entries below the k-th largest value are removed. */
    private def topKFilterRow(
        row: Array[Float],
        topK: Int,
        filterValue: Float,
        minTokensToKeep: Int): Array[Float] =
      if (row.isEmpty) row
      else {
        val k = topK.max(minTokensToKeep).min(row.length) // Safety check
        val threshold = row.sortWith(_ > _)(k - 1)
        row.map(v => if (v < threshold) filterValue else v)
      }

    /** Nucleus filtering of a single row.
      *
      * Tokens are ranked by descending logit. A token is removed when the cumulative
      * probability of all higher-ranked tokens already exceeds `topP`, which always keeps the
      * first token that crosses the threshold. When `minTokensToKeep > 1`, the first
      * `minTokensToKeep + 1` ranks are always kept.
      */
    private def topPFilterRow(
        row: Array[Float],
        topP: Double,
        filterValue: Float,
        minTokensToKeep: Int): Array[Float] = {
      val (sortedLogits, sortedIndices) = row.zipWithIndex.sorted.reverse.unzip
      val probabilities = softmax(sortedLogits)
      val protectedRanks = if (minTokensToKeep > 1) minTokensToKeep else 0
      val filtered = row.clone()
      var cumulativeBefore = 0.0
      sortedIndices.indices.foreach { rank =>
        if (rank > protectedRanks && cumulativeBefore > topP)
          filtered(sortedIndices(rank)) = filterValue
        cumulativeBefore += probabilities(rank)
      }
      filtered
    }

    /** Sets to -Infinity every token that would repeat an already generated n-gram.
      * Banned ids outside a row are ignored, and rows are never truncated.
      */
    private def banRepeatedNgrams(
        history: Seq[Array[Int]],
        logits: Array[Array[Float]]): Array[Array[Float]] = {
      // from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
      val bannedTokens =
        calcBannedNgramTokens(history, batchSize, noRepeatNgramSize, currentLength)
      Array.tabulate(logits.length) { i =>
        val row = logits(i).clone()
        if (i < bannedTokens.length) {
          bannedTokens(i).foreach { token =>
            if (token >= 0 && token < row.length) row(token) = Float.NegativeInfinity
          }
        }
        row
      }
    }

    /** Runs one decoding step.
      *
      * The logits are processed in this order:
      *   1. `ignoreTokenIds` are masked.
      *   2. The repetition penalty is applied (if != 1.0).
      *   3. Repeated n-grams are banned (if `noRepeatNgramSize > 0`).
      *   4. The next token of each sentence is chosen, either greedily or by sampling after
      *      temperature scaling and top-k/top-p filtering.
      *
      * Finished sentences receive `paddingTokenId`. The method also updates
      * `unfinishedSentences`, `sentenceLengths`, `currentLength` and `nPredictedTokens`. When
      * `randomSeed` is set, it re-seeds the global `scala.util.Random`.
      *
      * @return
      *   `decoderInputIds` with the chosen token appended to each sentence
      */
    def processLogits(
        batchLogits: Array[Array[Float]],
        decoderInputIds: Array[Array[Int]]): Array[Array[Int]] = {

      nPredictedTokens += 1
      val history: Seq[Array[Int]] = decoderInputIds.toSeq

      val withoutIgnored = batchLogits.map { logits =>
        val masked = logits.clone()
        ignoreTokenIds.foreach { id =>
          if (id >= 0 && id < masked.length) masked(id) = Float.NegativeInfinity
        }
        masked
      }

      // repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
      val penalized =
        if (repetitionPenalty != 1.0)
          createNextTokenLogitsPenalties(history, withoutIgnored, repetitionPenalty)
        else withoutIgnored

      val nextTokenLogits =
        if (noRepeatNgramSize > 0) banRepeatedNgrams(history, penalized) else penalized

      randomSeed.foreach(seed => scala.util.Random.setSeed(seed))

      val predictions: Array[Int] =
        if (doSample) {
          // Temperature (higher temperature => more likely to sample low probability tokens)
          val scaled =
            if (temperature != 1.0) nextTokenLogits.map(row => row.map(_ / temperature.toFloat))
            else nextTokenLogits
          topKTopPFiltering(scaled, topK, topP).map(row => categoricalSample(row, randomSeed))
        } else {
          nextTokenLogits.map(row => argmax(row))
        }

      // finished sentences (flag 0) get the padding token instead of the prediction
      val tokensToAdd = predictions.zip(unfinishedSentences).map { case (token, unfinished) =>
        token * unfinished + paddingTokenId * (1 - unfinished)
      }

      currentLength += 1

      // 1 where a still-unfinished sentence just produced a stop token
      val finishingNow = unfinishedSentences.zip(tokensToAdd).map { case (unfinished, token) =>
        if (stopTokens.contains(token)) unfinished else 0
      }

      sentenceLengths = sentenceLengths.zip(finishingNow).map { case (length, finishing) =>
        length * (1 - finishing) + currentLength * finishing
      }

      unfinishedSentences = unfinishedSentences.zip(finishingNow).map {
        case (unfinished, finishing) => unfinished - finishing
      }

      decoderInputIds.zip(tokensToAdd).map { case (ids, token) => ids :+ token }
    }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy