
/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.annotators.tokenizer.bpe

import com.johnsnowlabs.nlp.annotators.common.{IndexedToken, Sentence, TokenPiece}
import org.apache.commons.lang.StringUtils

import scala.collection.mutable
import scala.collection.mutable.ListBuffer

/** A BPE Tokenizer based on GPT-2's tokenization scheme. The tokenization can then be used for
  * models based on this scheme (e.g. GPT2, RoBERTa, DeBERTa).
  *
  * TODO: is truncation assumed?
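  *
  * A minimal usage sketch (illustrative, not from the original source; assumes `merges` and
  * `vocab` were already loaded from the model's merges and vocabulary files, and `sentence` is
  * a [[com.johnsnowlabs.nlp.annotators.common.Sentence]]):
  * {{{
  * val tokenizer = BpeTokenizer.forModel("gpt2", merges, vocab)
  * val indexedTokens = tokenizer.tokenize(sentence)
  * val pieces = tokenizer.encode(indexedTokens)
  * }}}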
  */
private[nlp] abstract class BpeTokenizer(
    val merges: Map[(String, String), Int],
    val vocab: Map[String, Int],
    val specialTokens: SpecialTokens,
    var padWithSentenceTokens: Boolean) {

  /** Rankings for the byte pairs. Derived from merges.txt */
  protected val bpeRanks: Map[(String, String), Int] = merges

  /** Look up the rank of a byte pair; pairs missing from the merges rank last. */
  protected def getBpeRanking: ((String, String)) => Int =
    (bytePair: (String, String)) => bpeRanks.getOrElse(bytePair, Integer.MAX_VALUE)

  /** Cache for already-encoded tokens (declared but not yet used; see the caching TODO in [[bpe]]) */
  protected val cache: mutable.Map[String, Array[String]] = mutable.Map()

  /** Create a sequence of byte-pairs of the word */
  protected def getBytePairs(word: Array[String]): Array[(String, String)] = {
    val createPairs = (i: Int) => (word(i), word(i + 1))
    (0 until (word.length - 1)).map(createPairs).toArray
  }
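
  // Illustrative example (not part of the original source): for the word "low" split into
  // characters, getBytePairs(Array("l", "o", "w")) yields Array(("l", "o"), ("o", "w")).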

  // Can be overridden in subclasses
  protected val prependForPieceId: Option[String] = None
  protected val appendForPieceId: Option[String] = None

  protected def performMerges(
      wordChars: Array[String],
      charPairs: Array[(String, String)]): Array[String] = {
    var word = wordChars
    var pairs = charPairs
    // get the highest-priority byte pair (lowest rank) first
    var bytePair: (String, String) = pairs.minBy(getBpeRanking)
    var done = false
    // while we still have byte-pairs from our vocabulary
    while (bpeRanks.contains(bytePair) && !done) {
      val (first, second) = bytePair
      val newWord: ListBuffer[String] = ListBuffer()
      var i = 0
      var j = 0
      // keep combining characters with the current byte-pair
      while ((i < word.length) && (j != -1)) {
        j = word.indexOf(first, i)
        if (j == -1) newWord ++= word.drop(i)
        else {
          newWord ++= word.slice(i, j)
          i = j
          val bpIsAtIndex =
            (word(i) == first) && (i < word.length - 1) && word(i + 1) == second
          if (bpIsAtIndex) {
            newWord += (first + second)
            i += 2
          } else {
            newWord += word(i)
            i += 1
          }
        }
      }
      word = newWord.toArray
      // if we were able to create a whole word that was in the vocabulary, we're done
      if (word.length == 1) {
        done = true
      } else {
        // do it again with the next byte-pair
        pairs = getBytePairs(word)
        bytePair = pairs.minBy(getBpeRanking)
      }
    }
    word
  }
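
  // Worked example (illustrative; assumes bpeRanks contains ("l", "o") -> 0 and ("lo", "w") -> 1):
  //   performMerges(Array("l", "o", "w"), Array(("l", "o"), ("o", "w")))
  //   1st pass: merge ("l", "o")  => Array("lo", "w")
  //   2nd pass: merge ("lo", "w") => Array("low")   // a single token remains, so we stop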

  protected def getTokenPieces(
      indToken: IndexedToken,
      word: Array[String],
      processedToken: String): Array[TokenPiece] = {
    var currentIndex = indToken.begin
    val wordIndexes = word.map((subWord: String) => {
      val startIndex = currentIndex
      currentIndex = startIndex + subWord.length
      (startIndex, startIndex + subWord.length - 1)
    })
    val result = word
      .zip(wordIndexes)
      .map { case (subWord: String, indexes: (Int, Int)) =>
        val isWordStart = indToken.begin == indexes._1
        var processedSubWord = subWord
        processedSubWord = prependForPieceId match {
          case None => processedSubWord
          case Some(prepend) =>
            if (isWordStart && subWord.indexOf(prepend) < 0) prepend + processedSubWord
            else processedSubWord
        }
        processedSubWord = appendForPieceId match {
          case None => processedSubWord
          case Some(append) =>
            val isWordEnd = indToken.end == indexes._2
            if (isWordEnd && subWord.indexOf(append) < 0) processedSubWord + append
            else processedSubWord
        }
        // Set unknown id if not found
        val subWordId: Int = vocab.getOrElse(processedSubWord, specialTokens.unk.id)

        TokenPiece(subWord, processedToken, subWordId, isWordStart, indexes._1, indexes._2)

      }
    result
  }
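
  // Indexing sketch (illustrative, assuming no prepend/append piece markers): for a token
  // "lower" with begin = 10 that BPE split into Array("low", "er"), the pieces carry
  // begin/end offsets (10, 12) and (13, 14), with isWordStart = true only for "low".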

  /** Run the BPE algorithm on a token. The goal is to encode the token as the largest pieces
    * found in the known vocabulary; if that is not possible, the word is split into smaller
    * subwords until all of them are known.
    *
    * @return
    *   Array of TokenPieces, corresponding to encoded token
    */
  protected def bpe(indToken: IndexedToken): Array[TokenPiece] = {
    var processedToken = ""
    try {
      processedToken = preProcessTokenForBpe(indToken.token)
      // TODO: Caching
      // split the word into characters, to be combined into subwords
      var word: Array[String] = processedToken.map(_.toString).toArray
      val pairs: Array[(String, String)] = getBytePairs(word)

      if (pairs.isEmpty)
        word = Array(processedToken)
      else
        word = performMerges(word, pairs)

      getTokenPieces(indToken, word, processedToken)
    } catch {
      case _: java.util.NoSuchElementException =>
        Array(
          TokenPiece(
            indToken.token,
            indToken.token,
            specialTokens.unk.id,
            isWordStart = true,
            indToken.begin,
            indToken.end))
    }
  }

  /** Split the individual sub-texts on special tokens, e.g. masking etc. */
  protected def splitOnSpecialToken(
      specialToken: SpecialToken,
      text: String): ListBuffer[String] = {
    val isControl = (c: Char) => {
      if (c == '\t' || c == '\n' || c == '\r') false // count as whitespace
      else c.isControl
    }
    // Note: Java's regex engine spells the POSIX "alnum" class as \p{Alnum};
    // [^[:alnum:]] would be parsed as an ordinary character class and match the wrong characters.
    val isPunctuation =
      (c: Char) => raw"""[^\p{Alnum}]""".r.findFirstIn(c.toString).isDefined
    val isWordBorder =
      (c: Char) => isControl(c) || isPunctuation(c) || c.isWhitespace

    val isEndOfWord = (text: String) => isWordBorder(text.last)
    val isStartOfWord = (text: String) => isWordBorder(text.head)

    val result: ListBuffer[String] = ListBuffer()
    val tok = specialToken.content

    val splitText = StringUtils.splitByWholeSeparator(text, tok)
    var fullWord = ""

    for ((subText, i) <- splitText.zipWithIndex) {
      var done = false
      // Try to avoid splitting on token
      if (specialToken.singleWord) {
        if ((i < (splitText.length - 1)) && !isEndOfWord(subText)
            && !isStartOfWord(splitText(i + 1)))
          fullWord += subText + tok
        else if (fullWord.nonEmpty) {
          fullWord += subText
          result += fullWord
          fullWord = ""
          done = true
        }
      }
      if (!done) {
        // A bit counter-intuitive, but we strip the left of the string,
        // since rstrip means the special token eats all whitespace on its right
        var subTextProcessed: String = subText
        if (specialToken.rstrip && i > 0)
          subTextProcessed = subTextProcessed.stripPrefix(" ")
        if (specialToken.lstrip && i < (splitText.length - 1))
          subTextProcessed = subTextProcessed.stripSuffix(" ")
        if (i == 0 && subTextProcessed.isEmpty)
          result += tok
        else if (i == (splitText.length - 1)) {
          if (subTextProcessed.nonEmpty) result += subTextProcessed
        } else {
          if (subTextProcessed.nonEmpty) result += subTextProcessed
          result += tok
        }
      }
    }
    result
  }
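
  // Illustrative example (assuming a special token with content "<mask>" and the singleWord,
  // lstrip and rstrip flags all false):
  //   splitOnSpecialToken(maskToken, "Hello <mask> world")
  // yields ListBuffer("Hello ", "<mask>", " world").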

  /** Tokenize a sub-text at the given index offset; must be implemented by the concrete tokenizer */
  def tokenizeSubText(text: String, indexOffset: Int): Array[IndexedToken]

  /** Sentence start/end padding strings, taken from the model's special tokens */
  val sentencePadding: (String, String) =
    (specialTokens.sentenceStart.content, specialTokens.sentenceEnd.content)

  /** Tokenize considering special tokens and split algorithm */
  def tokenize(sentence: Sentence): Array[IndexedToken] = {
    var text = sentence.content
    if (text.trim.isEmpty) Array[IndexedToken]()
    else {
      val splitTexts: ListBuffer[String] = ListBuffer()
      var textList: ListBuffer[String] = ListBuffer(text)

      for (transformations <- specialTokens.allTokens) {
        splitTexts.clear()
        for (subText <- textList) {
          if (!specialTokens.contains(subText))
            splitTexts ++= splitOnSpecialToken(transformations, subText)
          else
            splitTexts += subText
        }
        textList = splitTexts.clone()
      }

      if (padWithSentenceTokens) {
        text = sentencePadding._1 + text + sentencePadding._2
        splitTexts.prepend(sentencePadding._1)
        splitTexts.append(sentencePadding._2)
      }

      var currentIndex = 0
      val result = mutable.ArrayBuffer[IndexedToken]()
      for (subText <- splitTexts) {
        val subTextIndex = sentence.start + text.indexOf(subText, currentIndex)
        if (!specialTokens.contains(subText)) {
          val splitSubText: Array[IndexedToken] = tokenizeSubText(subText, subTextIndex)
          result.append(splitSubText: _*)
        } else // subtext is just the special token
          result.append(
            IndexedToken(subText, begin = subTextIndex, end = subTextIndex + subText.length - 1))
        currentIndex = subTextIndex + subText.length
      }
      result.toArray
    }
  }
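
  // End-to-end sketch (illustrative; assumes a "<mask>" special token and sentence.start = 0):
  // for a sentence with content "I saw <mask>", tokenize emits IndexedTokens for the sub-text
  // "I saw " via tokenizeSubText, plus IndexedToken("<mask>", begin = 6, end = 11) for the
  // special token itself.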

  protected def preProcessTokenForBpe(token: String): String = token

  def encode(indToken: IndexedToken): Array[TokenPiece] = {
    if (!specialTokens.contains(indToken.token))
      bpe(indToken)
    else
      Array(
        TokenPiece(
          indToken.token,
          indToken.token,
          vocab(indToken.token),
          isWordStart = true,
          indToken.begin,
          indToken.end))
  }
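
  // Sketch (illustrative, assuming "<mask>" is a registered special token): a special token maps
  // straight to its vocabulary id, so encode(IndexedToken("<mask>", 0, 5)) yields
  // Array(TokenPiece("<mask>", "<mask>", vocab("<mask>"), isWordStart = true, 0, 5)),
  // while any other token goes through the BPE merge loop above.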

  def encode(indTokens: Array[IndexedToken]): Array[TokenPiece] = indTokens.flatMap(encode(_))
}

object BpeTokenizer {
  def forModel(
      modelType: String,
      merges: Map[(String, String), Int],
      vocab: Map[String, Int],
      padWithSentenceTokens: Boolean = false,
      specialTokens: Option[SpecialTokens] = None): BpeTokenizer = {
    val availableModels = Array("roberta", "xlm", "gpt2")
    require(
      availableModels.contains(modelType),
      s"""Model type "$modelType" not supported yet.""")

    val modelSpecialTokens = specialTokens match {
      case Some(specialTok) => specialTok
      case None => SpecialTokens.getSpecialTokensForModel(modelType, vocab)
    }

    modelType match {
      case "roberta" =>
        new RobertaTokenizer(merges, vocab, modelSpecialTokens, padWithSentenceTokens)
      case "xlm" => new XlmTokenizer(merges, vocab, modelSpecialTokens, padWithSentenceTokens)
      case "gpt2" => new Gpt2Tokenizer(merges, vocab, modelSpecialTokens, padWithSentenceTokens)

    }
  }
}