All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.nlp.annotators.ws.WordSegmenterModel.scala Maven / Gradle / Ivy

/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.annotators.ws

import com.johnsnowlabs.nlp.AnnotatorType.{DOCUMENT, TOKEN}
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.annotators.pos.perceptron.{
  AveragedPerceptron,
  PerceptronPredictionUtils
}
import com.johnsnowlabs.nlp.annotators.ws.TagsType.{LEFT_BOUNDARY, MIDDLE, RIGHT_BOUNDARY}
import com.johnsnowlabs.nlp.serialization.StructFeature
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.annotators.RegexTokenizer
import org.apache.spark.ml.param.{BooleanParam, Param}
import org.apache.spark.ml.util.Identifiable

/** WordSegmenter which tokenizes non-english or non-whitespace separated texts.
  *
  * Many languages are not whitespace separated and their sentences are a concatenation of many
  * symbols, like Korean, Japanese or Chinese. Without understanding the language, splitting the
  * words into their corresponding tokens is impossible. The WordSegmenter is trained to
  * understand these languages and split them into semantically correct parts.
  *
  * This is the instantiated model of the [[WordSegmenterApproach]]. For training your own model,
  * please see the documentation of that class.
  *
  * Pretrained models can be loaded with `pretrained` of the companion object:
  * {{{
  * val wordSegmenter = WordSegmenterModel.pretrained()
  *   .setInputCols("document")
  *   .setOutputCol("words_segmented")
  * }}}
  * The default model is `"wordseg_pku"`, default language is `"zh"`, if no values are provided.
  * For available pretrained models please see the
  * [[https://nlp.johnsnowlabs.com/models?task=Word+Segmentation Models Hub]].
  *
  * For extended examples of usage, see the
  * [[https://github.com/JohnSnowLabs/spark-nlp-workshop/blob/master/jupyter/annotation/chinese/word_segmentation/words_segmenter_demo.ipynb Spark NLP Workshop]]
  * and the
  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/annotators/WordSegmenterTest.scala WordSegmenterTest]].
  *
  * ==Example==
  * {{{
  * import spark.implicits._
  * import com.johnsnowlabs.nlp.base.DocumentAssembler
  * import com.johnsnowlabs.nlp.annotator.WordSegmenterModel
  * import org.apache.spark.ml.Pipeline
  *
  * val documentAssembler = new DocumentAssembler()
  *   .setInputCol("text")
  *   .setOutputCol("document")
  *
  * val wordSegmenter = WordSegmenterModel.pretrained()
  *   .setInputCols("document")
  *   .setOutputCol("token")
  *
  * val pipeline = new Pipeline().setStages(Array(
  *   documentAssembler,
  *   wordSegmenter
  * ))
  *
  * val data = Seq("然而,這樣的處理也衍生了一些問題。").toDF("text")
  * val result = pipeline.fit(data).transform(data)
  *
  * result.select("token.result").show(false)
  * +--------------------------------------------------------+
  * |result                                                  |
  * +--------------------------------------------------------+
  * |[然而, ,, 這樣, 的, 處理, 也, 衍生, 了, 一些, 問題, 。    ]|
  * +--------------------------------------------------------+
  * }}}
  *
  * @param uid
  *   required uid for storing annotator to disk
  * @groupname anno Annotator types
  * @groupdesc anno
  *   Required input and expected output annotator types
  * @groupname Ungrouped Members
  * @groupname param Parameters
  * @groupname setParam Parameter setters
  * @groupname getParam Parameter getters
  * @groupname Ungrouped Members
  * @groupprio param  1
  * @groupprio anno  2
  * @groupprio Ungrouped 3
  * @groupprio setParam  4
  * @groupprio getParam  5
  * @groupdesc param
  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
  *   parameter values through setters and getters, respectively.
  */
class WordSegmenterModel(override val uid: String)
    extends AnnotatorModel[WordSegmenterModel]
    with HasSimpleAnnotate[WordSegmenterModel]
    with PerceptronPredictionUtils {

  /** Annotator reference id. Used to identify elements in metadata or to refer to this annotator
    * type
    */
  def this() = this(Identifiable.randomUID("WORD_SEGMENTER"))

  /** Averaged perceptron whose predicted tags are used to merge single characters into words.
    *
    * @group param
    */
  val model: StructFeature[AveragedPerceptron] =
    new StructFeature[AveragedPerceptron](this, "POS Model")

  /** Whether to run a [[RegexTokenizer]] before segmentation (Default: `false`). When enabled,
    * tokens whose first code point is in the Latin script are emitted as they are and only the
    * remaining tokens are segmented. Useful for multilingual text.
    *
    * @group param
    */
  val enableRegexTokenizer: BooleanParam = new BooleanParam(
    this,
    "enableRegexTokenizer",
    "Whether to use RegexTokenizer before segmentation. Useful for multilingual text")

  /** Indicates whether to convert all characters to lowercase before tokenizing (Default:
    * `false`). Only forwarded to the [[RegexTokenizer]] used when `enableRegexTokenizer` is set.
    *
    * @group param
    */
  val toLowercase: BooleanParam = new BooleanParam(
    this,
    "toLowercase",
    "Indicates whether to convert all characters to lowercase before tokenizing.\n")

  /** Regex pattern used to match delimiters (Default: `"\\s+"`). Only forwarded to the
    * [[RegexTokenizer]] used when `enableRegexTokenizer` is set.
    *
    * @group param
    */
  val pattern: Param[String] = new Param(this, "pattern", "regex pattern used for tokenizing")

  /** @group getParam */
  def getModel: AveragedPerceptron = $$(model)

  /** @group setParam */
  def setModel(targetModel: AveragedPerceptron): this.type = set(model, targetModel)

  /** @group setParam */
  def setEnableRegexTokenizer(value: Boolean): this.type = set(enableRegexTokenizer, value)

  /** @group setParam */
  def setToLowercase(value: Boolean): this.type = set(toLowercase, value)

  /** @group setParam */
  def setPattern(value: String): this.type = set(pattern, value)

  setDefault(enableRegexTokenizer -> false, toLowercase -> false, pattern -> "\\s+")

  /** takes a document and annotations and produces new annotations of this annotator's annotation
    * type
    *
    * With `enableRegexTokenizer` set, the input is first split by a [[RegexTokenizer]] and each
    * resulting token is handled separately; otherwise every sentence is broken into
    * single-character tokens, tagged with the model and merged into word segments.
    *
    * @param annotations
    *   Annotations that correspond to inputAnnotationCols generated by previous annotators if any
    * @return
    *   any number of annotations processed for every input annotation. Not necessary one to one
    *   relationship
    */
  override def annotate(annotations: Seq[Annotation]): Seq[Annotation] =
    if ($(enableRegexTokenizer)) {
      segmentWithRegexAnnotator(annotations)
    } else {
      val sentences = SentenceSplit.unpack(annotations)
      val tokens = getTokenAnnotations(sentences)
      val tokenizedSentences = TokenizedWithSentence.unpack(annotations ++ tokens)
      val tagged = tag($$(model), tokenizedSentences.toArray)
      buildWordSegments(tagged)
    }

  /** Tokenizes the input with a [[RegexTokenizer]] configured from `toLowercase` and `pattern`,
    * then segments every token that does not start with a Latin-script code point.
    *
    * @param annotatedSentences
    *   the annotations received by `annotate`
    * @return
    *   Latin tokens unchanged, interleaved (in token order) with the word segments produced for
    *   all other tokens
    */
  private def segmentWithRegexAnnotator(annotatedSentences: Seq[Annotation]): Seq[Annotation] = {

    val outputCol = Identifiable.randomUID("regex_token")

    val regexTokenizer = new RegexTokenizer()
      .setInputCols(getInputCols)
      .setOutputCol(outputCol)
      .setToLowercase($(toLowercase))
      .setPattern($(pattern))

    val annotatedTokens = regexTokenizer.annotate(annotatedSentences)

    annotatedTokens.flatMap { annotatedToken =>
      if (annotatedToken.result.isEmpty) {
        // codePointAt(0) would throw on an empty string; an empty token has nothing to segment.
        Seq.empty[Annotation]
      } else {
        val firstCodePoint = annotatedToken.result.codePointAt(0)
        if (Character.UnicodeScript.of(firstCodePoint) == Character.UnicodeScript.LATIN) {
          Seq(annotatedToken)
        } else {
          segmentNonLatinToken(annotatedToken)
        }
      }
    }
  }

  /** Treats a single regex token as a sentence of its own and runs the segmentation model on it.
    *
    * NOTE(review): `buildWordSegments` labels the segments with the position of the tagged
    * sentence inside `tagged`, not with the token's original `"sentence"` metadata value —
    * confirm this is intended.
    */
  private def segmentNonLatinToken(annotatedToken: Annotation): Seq[Annotation] = {
    val sentenceIndex = annotatedToken.metadata("sentence")

    val annotatedSentence = Annotation(
      DOCUMENT,
      annotatedToken.begin,
      annotatedToken.end,
      annotatedToken.result,
      Map("sentence" -> sentenceIndex))
    val sentence = Sentence(
      annotatedToken.result,
      annotatedToken.begin,
      annotatedToken.end,
      sentenceIndex.toInt)
    val characterTokens = getTokenAnnotations(Seq(sentence))

    val tokenizedSentences =
      TokenizedWithSentence.unpack(characterTokens ++ Seq(annotatedSentence))
    val tagged = tag($$(model), tokenizedSentences.toArray)
    buildWordSegments(tagged)
  }

  /** Builds one TOKEN annotation per element of `sentence.content.split("")`, positioned at
    * `sentence.start` plus the element's position, and drops the elements equal to a single
    * space.
    */
  private def getTokenAnnotations(annotation: Seq[Sentence]): Seq[Annotation] =
    annotation.flatMap { sentence =>
      val sentenceMetadata = Map("sentence" -> sentence.index.toString)
      sentence.content
        .split("")
        .zipWithIndex
        .map { case (char, index) =>
          val tokenIndex = index + sentence.start
          Annotation(TOKEN, tokenIndex, tokenIndex, char, sentenceMetadata)
        }
        .filter(token => token.result != " ")
    }

  /** Turns tagged sentences into TOKEN annotations.
    *
    * The tags of a sentence are concatenated and searched for left-boundary / middle /
    * right-boundary runs. When no run is found every tagged word is emitted on its own;
    * otherwise the words of each run are merged into one segment.
    *
    * @param taggedSentences
    *   sentences as returned by `tag`
    * @return
    *   the resulting annotations; their `"sentence"` metadata is the position of the sentence in
    *   `taggedSentences`
    */
  def buildWordSegments(taggedSentences: Array[TaggedSentence]): Seq[Annotation] = {
    taggedSentences.zipWithIndex.flatMap { case (taggedSentence, index) =>
      val tagsSentence = taggedSentence.tags.mkString("")
      val wordIndexesByMatchedGroups = getWordIndexesByMatchedGroups(tagsSentence)
      if (wordIndexesByMatchedGroups.isEmpty) {
        taggedSentence.indexedTaggedWords.map(indexedTaggedWord =>
          Annotation(
            TOKEN,
            indexedTaggedWord.begin,
            indexedTaggedWord.end,
            indexedTaggedWord.word,
            Map("sentence" -> index.toString)))
      } else {
        annotateSegmentWords(wordIndexesByMatchedGroups, taggedSentence, index)
      }
    }
  }

  /** Finds every left-boundary / middle* / right-boundary run in the concatenated tag string.
    *
    * @return
    *   one list per match, holding a [[RegexTagsInfo]] for each capture group that took part in
    *   the match
    */
  private def getWordIndexesByMatchedGroups(tagsSentence: String): List[List[RegexTagsInfo]] = {
    val regexPattern = s"($LEFT_BOUNDARY)($MIDDLE*)*($RIGHT_BOUNDARY)".r
    regexPattern
      .findAllMatchIn(tagsSentence)
      .map(matchedResult => {
        val groups = (1 to matchedResult.groupCount).toList
        groups
          .map(g =>
            RegexTagsInfo(
              matchedResult.group(g),
              matchedResult.start(g),
              matchedResult.end(g),
              // Maps the end offset inside the tag string to a word position; presumably each
              // tag is two characters long — confirm against TagsType.
              (matchedResult.end(g) / 2) - 1))
          // A group that did not take part in the match reports end == -1, which yields -1 here.
          .filter(regexTagsInfo => regexTagsInfo.estimatedIndex != -1)
      })
      .toList
  }

  /** Emits the TOKEN annotations of a sentence that contains at least one matched run: words
    * outside any run are kept as they are, words inside a run are merged, and the result is
    * ordered by the `"index"` metadata.
    */
  private def annotateSegmentWords(
      wordIndexesByMatchedGroups: List[List[RegexTagsInfo]],
      taggedSentence: TaggedSentence,
      sentenceIndex: Int): Seq[Annotation] = {

    val singleTaggedWords =
      getSingleIndexedTaggedWords(wordIndexesByMatchedGroups, taggedSentence)
    val multipleTaggedWords = getMultipleTaggedWords(wordIndexesByMatchedGroups, taggedSentence)
    val sentenceMetadata = Map("sentence" -> sentenceIndex.toString)

    (singleTaggedWords ++ multipleTaggedWords)
      .sortBy(metadataIndexOf)
      .map(segmentedTaggedWord =>
        Annotation(
          TOKEN,
          segmentedTaggedWord.begin,
          segmentedTaggedWord.end,
          segmentedTaggedWord.word,
          sentenceMetadata))
  }

  /** Reads the `"index"` metadata entry of a tagged word as an Int, `-1` when it is absent. */
  private def metadataIndexOf(indexedTaggedWord: IndexedTaggedWord): Int =
    indexedTaggedWord.metadata.getOrElse("index", "-1").toInt

  /** Returns the tagged words that belong to no matched run: their position is not one of the
    * estimated indexes and `isMatchedWord` rejects them.
    */
  private def getSingleIndexedTaggedWords(
      wordIndexesByMatchedGroups: List[List[RegexTagsInfo]],
      taggedSentence: TaggedSentence): List[IndexedTaggedWord] = {
    // A Set keeps the membership test cheap for long sentences.
    val matchedPositions = wordIndexesByMatchedGroups.flatten.map(_.estimatedIndex).toSet
    taggedSentence.indexedTaggedWords.zipWithIndex
      .filter { case (_, position) => !matchedPositions.contains(position) }
      .map(_._1)
      .filter(candidate => !isMatchedWord(candidate, wordIndexesByMatchedGroups))
      .toList
  }

  /** Tells whether a word lies inside one of the given runs without sitting on an estimated
    * index.
    *
    * A word qualifies when its tag is the middle tag, its `"index"` metadata is strictly between
    * the estimated indexes of the first and last group of a run, and that run has a group that
    * contains the middle tag and is longer than two characters.
    */
  private def isMatchedWord(
      indexedTaggedWord: IndexedTaggedWord,
      regexTagsInfoList: List[List[RegexTagsInfo]]): Boolean = {
    val index = metadataIndexOf(indexedTaggedWord)

    indexedTaggedWord.tag == MIDDLE && regexTagsInfoList.exists { regexTagsInfo =>
      val leftBoundaryIndex = regexTagsInfo.head.estimatedIndex
      val rightBoundaryIndex = regexTagsInfo.last.estimatedIndex
      val isInRange = index > leftBoundaryIndex && index < rightBoundaryIndex
      isInRange && regexTagsInfo.exists(rti =>
        rti.tagsMatch.contains(MIDDLE) && rti.tagsMatch.length > 2)
    }
  }

  /** Merges, for every matched run, the words that belong to it into a single tagged word. Runs
    * that select no word contribute nothing.
    */
  private def getMultipleTaggedWords(
      wordIndexesByMatchedGroups: List[List[RegexTagsInfo]],
      taggedSentence: TaggedSentence): List[IndexedTaggedWord] = {
    wordIndexesByMatchedGroups.flatMap { wordIndexesGroup =>
      val wordIndexes = wordIndexesGroup.map(wi => wi.estimatedIndex)
      val taggedWords = taggedSentence.indexedTaggedWords.zipWithIndex
        .filter { case (indexedTaggedWord, index) =>
          wordIndexes.contains(index) || isMatchedWord(indexedTaggedWord, List(wordIndexesGroup))
        }
        .map(_._1)
      taggedWords.reduceLeftOption(processTags)
    }
  }

  /** Concatenates two tagged words into one segment. The segment starts at the smaller `begin`,
    * ends at `begin + length - 1` and keeps the smaller of the two `"index"` metadata values.
    */
  private val processTags = (current: IndexedTaggedWord, next: IndexedTaggedWord) => {
    val wordSegment = current.word + next.word
    val tagSegment = current.tag + next.tag
    val begin = if (current.begin <= next.begin) current.begin else next.begin
    val end = begin + wordSegment.length - 1
    val currentIndexValue = current.metadata.getOrElse("index", "-1")
    // Must come from `next`; reading `current` twice made the comparison below a no-op.
    val nextIndexValue = next.metadata.getOrElse("index", "-1")
    val index =
      if (currentIndexValue.toInt <= nextIndexValue.toInt) currentIndexValue else nextIndexValue
    IndexedTaggedWord(wordSegment, tagSegment, begin, end, None, Map("index" -> index))
  }

  /** Output Annotator Types: TOKEN
    *
    * @group anno
    */
  override val outputAnnotatorType: AnnotatorType = TOKEN

  /** Input Annotator Types: DOCUMENT
    *
    * @group anno
    */
  override val inputAnnotatorTypes: Array[String] = Array(DOCUMENT)
}

/** One capture group of a boundary-pattern match inside a concatenated tag string.
  *
  * @param tagsMatch
  *   text captured by the group
  * @param start
  *   offset of the group's first character in the tag string
  * @param end
  *   offset just past the group's last character in the tag string
  * @param estimatedIndex
  *   word position derived from `end` by the code that creates the instance
  */
private case class RegexTagsInfo(
    tagsMatch: String,
    start: Int,
    end: Int,
    estimatedIndex: Int)

/** Reader mix-in for [[WordSegmenterModel]] that exposes the `pretrained` loaders, with
  * `"wordseg_pku"` as the default model name and `"zh"` as the default language.
  */
trait ReadablePretrainedWordSegmenter
    extends ParamsAndFeaturesReadable[WordSegmenterModel]
    with HasPretrained[WordSegmenterModel] {
  // Default model name and language for the pretrained loaders.
  override val defaultModelName: Some[String] = Some("wordseg_pku")
  override val defaultLang: String = "zh"

  /** Java compliant-overrides: every overload below only delegates to the inherited
    * implementation while fixing the return type to [[WordSegmenterModel]].
    */
  override def pretrained(): WordSegmenterModel = super.pretrained()

  /** Delegates to the inherited loader with an explicit model name. */
  override def pretrained(name: String): WordSegmenterModel = super.pretrained(name)

  /** Delegates to the inherited loader with an explicit model name and language. */
  override def pretrained(name: String, lang: String): WordSegmenterModel =
    super.pretrained(name, lang)

  /** Delegates to the inherited loader with an explicit model name, language and remote
    * location.
    */
  override def pretrained(name: String, lang: String, remoteLoc: String): WordSegmenterModel =
    super.pretrained(name, lang, remoteLoc)
}

/** This is the companion object of [[WordSegmenterModel]]. It mixes in
  * [[ReadablePretrainedWordSegmenter]], which supplies the `pretrained` loaders. Please refer to
  * the class for the documentation.
  */
object WordSegmenterModel extends ReadablePretrainedWordSegmenter




© 2015 - 2024 Weber Informatics LLC | Privacy Policy