All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.nlp.annotators.tokenizer.bpe.XlmTokenizer.scala Maven / Gradle / Ivy

/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.annotators.tokenizer.bpe

import com.johnsnowlabs.nlp.annotators.common.{IndexedToken, TokenPiece}
import com.johnsnowlabs.nlp.annotators.tokenizer.moses.MosesTokenizer
import com.johnsnowlabs.nlp.annotators.tokenizer.normalizer.MosesPunctNormalizer

/** XLM Tokenizer
  *
  * @param merges
  *   Combinations of byte pairs with ranking
  * @param vocab
  *   Mapping from byte pair to an id
  * @param lang
  *   Language of the text (Currently only english supported)
  * @param specialTokens
  *   Special Tokens of the model to not split on
  * @param doLowercaseAndRemoveAccent
  *   True for current supported model (v1.2.0), False for XLM-17 & 100
  */
private[nlp] class XlmTokenizer(
    merges: Map[(String, String), Int],
    vocab: Map[String, Int],
    specialTokens: SpecialTokens,
    padWithSentenceTokens: Boolean = false,
    lang: String = "en",
    doLowercaseAndRemoveAccent: Boolean = true)
    extends BpeTokenizer(
      merges,
      vocab,
      specialTokens,
      padWithSentenceTokens,
      addPrefixSpace = false) {
  require(lang == "en", "Only English is supported currently.")

  /** Lowercase and strips accents from a piece of text based on
    * https://github.com/facebookresearch/XLM/blob/master/tools/lowercase_and_remove_accent.py
    */
  def lowercaseAndRemoveAccent(input: String): String = {
    var text = input
    text = text.toLowerCase()
    text = java.text.Normalizer.normalize(text, java.text.Normalizer.Form.NFD)
    text.toCharArray
      .filter(Character.getType(_) != Character.NON_SPACING_MARK) // Unicode Category "Mn"
      .mkString
      .toLowerCase
  }

  val mosesNormalizer = new MosesPunctNormalizer()
  val mosesTokenizer = new MosesTokenizer(lang)

  private def mosesPipeline(text: String): Array[String] = {
    var processed = text
    processed = mosesNormalizer.normalize(processed)
    processed = mosesNormalizer.removeNonPrintingChar(processed)
    mosesTokenizer.tokenize(processed)
  }

  override def tokenizeSubText(text: String, indexOffset: Int): Array[IndexedToken] = {
    var indexedTokens: Array[IndexedToken] = Array()
    val mosesTokenized = mosesPipeline(text)
    val processedText =
      if (doLowercaseAndRemoveAccent)
        lowercaseAndRemoveAccent(mosesTokenized.mkString(" "))
      else mosesTokenized.mkString(" ")

    val textForIndexing = if (doLowercaseAndRemoveAccent) lowercaseAndRemoveAccent(text) else text
    indexedTokens = processedText
      .split(" ")
      .map((token: String) => {
        val tokenTextIndex = textForIndexing.indexOf(token)
        IndexedToken(
          token,
          indexOffset + tokenTextIndex,
          indexOffset + tokenTextIndex + token.length - 1
        ) // TODO: What if special characters were removed?
      })
    indexedTokens
  }

  override val appendForPieceId: Option[String] = Some("")

  override def bpe(indToken: IndexedToken): Array[TokenPiece] = {
    val processedToken = preProcessTokenForBpe(indToken.token)

    var word: Array[String] = Array[String]()
    // split the word into characters, to be combined into subwords
    word = processedToken.map(_.toString).toArray
    val pairs: Array[(String, String)] = getBytePairs(word)

    // XLM Specific: append word end indicator
    if (pairs.isEmpty) {
      word = Array(processedToken)
    } else {
      pairs(pairs.length - 1) = pairs(pairs.length - 1) match {
        case (s, s1) => (s, s1 + "")
      }
      word(word.length - 1) += ""
      word = performMerges(word, pairs)

      // remove end token again for correct indexes
      word = word.map(_.replace("", ""))
    }

    getTokenPieces(indToken, word)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy