All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.nlp.annotators.TokenizerModel.scala Maven / Gradle / Ivy

/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.annotators

import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.serialization.StructFeature
import com.johnsnowlabs.nlp.util.regex.RuleFactory
import org.apache.spark.ml.param.{BooleanParam, IntParam, Param, StringArrayParam}
import org.apache.spark.ml.util.Identifiable

import scala.collection.mutable
import scala.util.matching.Regex

/** Tokenizes raw text into word pieces, tokens. Identifies tokens with tokenization open
  * standards. A few rules will help customizing it if defaults do not fit user needs.
  *
  * This class represents an already fitted Tokenizer model.
  *
  * See the main class [[Tokenizer]] for more examples of usage.
  *
  * @param uid
  *   required uid for storing annotator to disk
  * @groupname anno Annotator types
  * @groupdesc anno
  *   Required input and expected output annotator types
  * @groupname Ungrouped Members
  * @groupname param Parameters
  * @groupname setParam Parameter setters
  * @groupname getParam Parameter getters
  * @groupname Ungrouped Members
  * @groupprio param  1
  * @groupprio anno  2
  * @groupprio Ungrouped 3
  * @groupprio setParam  4
  * @groupprio getParam  5
  * @groupdesc param
  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
  *   parameter values through setters and getters, respectively.
  */
class TokenizerModel(override val uid: String)
    extends AnnotatorModel[TokenizerModel]
    with HasSimpleAnnotate[TokenizerModel] {

  import com.johnsnowlabs.nlp.AnnotatorType._

  /** rules
    *
    * @group param
    */
  val rules: StructFeature[RuleFactory] = new StructFeature[RuleFactory](this, "rules")

  /** Words that won't be affected by tokenization rules
    *
    * @group param
    */
  val exceptions: StringArrayParam =
    new StringArrayParam(this, "exceptions", "Words that won't be affected by tokenization rules")

  /** Whether to care for case sensitiveness in exceptions (Default: `true`)
    *
    * @group param
    */
  val caseSensitiveExceptions: BooleanParam = new BooleanParam(
    this,
    "caseSensitiveExceptions",
    "Whether to care for case sensitiveness in exceptions")

  /** pattern to grab from text as token candidates. Defaults \\S+
    *
    * @group param
    */
  val targetPattern: Param[String] = new Param(
    this,
    "targetPattern",
    "pattern to grab from text as token candidates. Defaults \\S+")

  /** Set the minimum allowed length for each token
    *
    * @group param
    */
  val minLength = new IntParam(this, "minLength", "Set the minimum allowed length for each token")

  /** Set the maximum allowed length for each token
    *
    * @group param
    */
  val maxLength = new IntParam(this, "maxLength", "Set the maximum allowed length for each token")

  /** character list used to separate from the inside of tokens
    *
    * @group param
    */
  val splitChars: StringArrayParam = new StringArrayParam(
    this,
    "splitChars",
    "character list used to separate from the inside of tokens")

  /** pattern to separate from the inside of tokens. takes priority over splitChars.
    *
    * @group param
    */
  val splitPattern: Param[String] = new Param(
    this,
    "splitPattern",
    "pattern to separate from the inside of tokens. takes priority over splitChars.")

  setDefault(targetPattern -> "\\S+", caseSensitiveExceptions -> true)

  /** Output annotator type : TOKEN
    *
    * @group anno
    */
  override val outputAnnotatorType: AnnotatorType = TOKEN

  /** Input annotator type : DOCUMENT
    *
    * @group anno
    */
  override val inputAnnotatorTypes: Array[AnnotatorType] = Array[AnnotatorType](
    DOCUMENT
  ) // A Tokenizer could require only for now a SentenceDetector annotator

  def this() = this(Identifiable.randomUID("REGEX_TOKENIZER"))

  /** pattern to grab from text as token candidates. Defaults \\S+
    *
    * @group setParam
    */
  def setTargetPattern(value: String): this.type = set(targetPattern, value)

  /** pattern to grab from text as token candidates. Defaults \\S+
    *
    * @group getParam
    */
  def getTargetPattern: String = $(targetPattern)

  /** List of 1 character string to split tokens inside, such as hyphens. Ignored if using infix,
    * prefix or suffix patterns.
    *
    * @group setParam
    */
  def setSplitPattern(value: String): this.type = set(splitPattern, value)

  /** List of 1 character string to split tokens inside, such as hyphens. Ignored if using infix,
    * prefix or suffix patterns.
    *
    * @group getParam
    */
  def getSplitPattern: String = $(splitPattern)

  /** Words that won't be affected by tokenization rules
    *
    * @group setParam
    */
  def setExceptions(value: Array[String]): this.type = set(exceptions, value)

  /** Words that won't be affected by tokenization rules
    *
    * @group getParam
    */
  def getExceptions: Array[String] = $(exceptions)

  /** Rules factory for tokenization
    *
    * @group setParam
    */
  def setRules(ruleFactory: RuleFactory): this.type = set(rules, ruleFactory)

  /** Whether to follow case sensitiveness for matching exceptions in text
    *
    * @group setParam
    */
  def setCaseSensitiveExceptions(value: Boolean): this.type = set(caseSensitiveExceptions, value)

  /** Whether to follow case sensitiveness for matching exceptions in text
    *
    * @group getParam
    */
  def getCaseSensitiveExceptions(value: Boolean): Boolean = $(caseSensitiveExceptions)

  /** Set the minimum allowed length for each token
    *
    * @group setParam
    */
  def setMinLength(value: Int): this.type = set(minLength, value)

  /** Set the minimum allowed length for each token
    *
    * @group getParam
    */
  def getMinLength(value: Int): Int = $(minLength)

  /** Set the maximum allowed length for each token
    *
    * @group setParam
    */
  def setMaxLength(value: Int): this.type = set(maxLength, value)

  /** Set the maximum allowed length for each token
    *
    * @group getParam
    */
  def getMaxLength(value: Int): Int = $(maxLength)

  /** List of 1 character string to split tokens inside, such as hyphens. Ignored if using infix,
    * prefix or suffix patterns.
    *
    * @group setParam
    */
  def setSplitChars(v: Array[String]): this.type = {
    require(
      v.forall(x => x.length == 1 || (x.length == 2 && x.substring(0, 1) == "\\")),
      "All elements in context chars must have length == 1")
    set(splitChars, v)
  }

  /** One character string to split tokens inside, such as hyphens. Ignored if using infix, prefix
    * or suffix patterns.
    *
    * @group setParam
    */
  def addSplitChars(v: String): this.type = {
    require(
      v.length == 1 || (v.length == 2 && v.substring(0, 1) == "\\"),
      "Context char must have length == 1")
    set(splitChars, get(splitChars).getOrElse(Array.empty[String]) :+ v)
  }

  /** List of 1 character string to split tokens inside, such as hyphens. Ignored if using infix,
    * prefix or suffix patterns
    *
    * @group getParam
    * .
    */
  def getSplitChars: Array[String] = {
    $(splitChars)
  }

  private val PROTECT_CHAR = "ↈ"
  private val BREAK_CHAR = "ↇ"

  private lazy val BREAK_PATTERN: Regex =
    ("[^(?:" + $(targetPattern) + ")" + PROTECT_CHAR + "]").r
  private lazy val SPLIT_PATTERN: Regex =
    ("[^" + BREAK_CHAR + "]+").r

  private var compiledExceptions: Option[Regex] = None

  private def getOrCompileExceptionPattern(): Regex = compiledExceptions.getOrElse {
    get(exceptions) match {
      case None => "".r
      case Some(patterns) =>
        val compiledPattern: Regex = patterns
          .map { pattern =>
            val casedExceptionPattern =
              if ($(caseSensitiveExceptions)) pattern else "(?i)" + pattern
            casedExceptionPattern
          }
          .mkString("|")
          .r
        compiledExceptions = Some(compiledPattern)
        compiledPattern
    }
  }

  /** This func generates a Seq of TokenizedSentences from a Seq of Sentences.
    *
    * @param sentences
    *   to tag
    * @return
    *   Seq of TokenizedSentence objects
    */
  def tag(sentences: Seq[Sentence]): Seq[TokenizedSentence] = {
    lazy val splitCharsExists = $(splitChars).map(_.last.toString)
    sentences.map { text =>
      /** Step 1, define breaks from non breaks */
      val exceptionsDefined = get(exceptions).isDefined

      var exceptionsWithoutBreak: Option[mutable.HashSet[String]] = None

      val textContent = if (exceptionsDefined) {

        /** If found, replace BREAK_PATTERN with PROTECT_CHAR, otherwise add to an exception list.
          */
        getOrCompileExceptionPattern().replaceAllIn(
          text.content,
          { m: Regex.Match =>
            {
              val breakReplaced = BREAK_PATTERN.replaceAllIn(m.matched, PROTECT_CHAR)
              if (breakReplaced == m.matched)
                exceptionsWithoutBreak.getOrElse({
                  exceptionsWithoutBreak = Some(new mutable.HashSet[String])
                  exceptionsWithoutBreak.get
                }) += m.matched
              breakReplaced
            }
          })
      } else {
        text.content
      }

      val protectedText = BREAK_PATTERN.replaceAllIn(textContent, BREAK_CHAR)

      /** Step 2, Return protected tokens back into text and move on */
      val tokens = SPLIT_PATTERN
        .findAllMatchIn(protectedText)
        .flatMap { candidate =>
          /** If exceptions are defined, check for candidate whether PROTECT_CHAR present or in
            * exception list.
            */
          if (exceptionsDefined &&
            (candidate.matched.contains(PROTECT_CHAR) ||
              (exceptionsWithoutBreak.isDefined && exceptionsWithoutBreak.get.contains(
                candidate.matched)))) {

            /** Put back character and move on */
            Seq(
              IndexedToken(
                text.content.slice(candidate.start, candidate.end),
                text.start + candidate.start,
                text.start + candidate.end - 1))
          } else {

            /** Step 3, If no exception found, find candidates through the possible general rule
              * patterns
              */
            val rr = $$(rules)
              .findMatchFirstOnly(candidate.matched)
              .map { m =>
                var curPos = m.content.start
                (1 to m.content.groupCount)
                  .flatMap(i => {
                    val target = m.content.group(i)
                    val applyPattern = isSet(splitPattern) && (target
                      .split($(splitPattern))
                      .length > 1)
                    val applyChars = isSet(splitChars) && splitCharsExists.exists(target.contains)

                    def defaultCandidate = {
                      val it = IndexedToken(
                        target,
                        text.start + candidate.start + curPos,
                        text.start + candidate.start + curPos + target.length - 1)
                      curPos += target.length
                      Seq(it)
                    }

                    if (target.nonEmpty && (applyPattern || applyChars)) {
                      try {
                        val strs =
                          if (applyPattern) target.split($(splitPattern))
                          else target.split($(splitChars).mkString("|"))
                        strs.map { str =>
                          curPos = m.content.matched.indexOf(str, curPos)
                          val indexedToken = IndexedToken(
                            str,
                            text.start + candidate.start + curPos,
                            text.start + candidate.start + curPos + str.length - 1)
                          curPos += str.length
                          indexedToken
                        }
                      } catch {
                        case e: Throwable =>
                          logWarning(s"Tokenizer: Could not apply custom pattern or char. " +
                            s"Using default split pattern for \'$target\'. Exception: ${e.toString}")
                          defaultCandidate
                      }
                    } else defaultCandidate
                  })

              /** Step 4, If rules didn't match, return whatever candidate we have and leave it as
                * is
                */
              }
              .getOrElse(Seq(IndexedToken(
                candidate.matched,
                text.start + candidate.start,
                text.start + candidate.end - 1)))
            rr
          }
        }
        .filter(t =>
          t.token.nonEmpty && t.token.length >= $(minLength) && get(maxLength).forall(m =>
            t.token.length <= m))
        .toArray
      TokenizedSentence(tokens, text.index)
    }
  }

  /** one to many annotation */
  override def annotate(annotations: Seq[Annotation]): Seq[Annotation] = {
    val sentences = SentenceSplit.unpack(annotations)
    val tokenized = tag(sentences)
    TokenizedWithSentence.pack(tokenized)
  }
}

trait ReadablePretrainedTokenizer
    extends ParamsAndFeaturesReadable[TokenizerModel]
    with HasPretrained[TokenizerModel] {
  override val defaultModelName: Option[String] = Some("token_rules")

  /** Java compliant-overrides */
  override def pretrained(): TokenizerModel = super.pretrained()

  override def pretrained(name: String): TokenizerModel = super.pretrained(name)

  override def pretrained(name: String, lang: String): TokenizerModel =
    super.pretrained(name, lang)

  override def pretrained(name: String, lang: String, remoteLoc: String): TokenizerModel =
    super.pretrained(name, lang, remoteLoc)
}

/** This is the companion object of [[TokenizerModel]]. Please refer to that class for the
  * documentation.
  */
object TokenizerModel extends ReadablePretrainedTokenizer




© 2015 - 2024 Weber Informatics LLC | Privacy Policy