All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.nlp.annotators.TextMatcherModel.scala Maven / Gradle / Ivy

/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.annotators

import com.johnsnowlabs.collections.SearchTrie
import com.johnsnowlabs.nlp.AnnotatorType._
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.serialization.StructFeature
import org.apache.spark.ml.param.{BooleanParam, Param}
import org.apache.spark.ml.util.Identifiable

import scala.annotation.{tailrec => tco}
import scala.collection.mutable.ArrayBuffer

/** Instantiated model of the [[TextMatcher]]. For usage and examples see the documentation of the
  * main class.
  *
  * @param uid
  *   internally renquired UID to make it writable
  * @groupname anno Annotator types
  * @groupdesc anno
  *   Required input and expected output annotator types
  * @groupname Ungrouped Members
  * @groupname param Parameters
  * @groupname setParam Parameter setters
  * @groupname getParam Parameter getters
  * @groupname Ungrouped Members
  * @groupprio param  1
  * @groupprio anno  2
  * @groupprio Ungrouped 3
  * @groupprio setParam  4
  * @groupprio getParam  5
  * @groupdesc param
  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
  *   parameter values through setters and getters, respectively.
  */
class TextMatcherModel(override val uid: String)
    extends AnnotatorModel[TextMatcherModel]
    with HasSimpleAnnotate[TextMatcherModel] {

  /** Output annotator type : CHUNK
    *
    * @group anno
    */
  override val outputAnnotatorType: AnnotatorType = CHUNK

  /** input annotator type : DOCUMENT, TOKEN
    *
    * @group anno
    */
  override val inputAnnotatorTypes: Array[AnnotatorType] = Array(DOCUMENT, TOKEN)

  /** searchTrie for Searching words
    *
    * @group param
    */
  val searchTrie = new StructFeature[SearchTrie](this, "searchTrie")

  /** whether to merge overlapping matched chunks. Defaults false
    *
    * @group param
    */
  val mergeOverlapping = new BooleanParam(
    this,
    "mergeOverlapping",
    "whether to merge overlapping matched chunks. Defaults false")

  /** Value for the entity metadata field
    *
    * @group param
    */
  val entityValue = new Param[String](this, "entityValue", "Value for the entity metadata field")

  /** Whether the TextMatcher should take the CHUNK from TOKEN or not
    *
    * @group param
    */
  val buildFromTokens = new BooleanParam(
    this,
    "buildFromTokens",
    "Whether the TextMatcher should take the CHUNK from TOKEN or not")

  /** SearchTrie of Tokens
    *
    * @group setParam
    */
  def setSearchTrie(value: SearchTrie): this.type = set(searchTrie, value)

  /** Whether to merge overlapping matched chunks. Defaults false
    *
    * @group setParam
    */
  def setMergeOverlapping(v: Boolean): this.type = set(mergeOverlapping, v)

  /** Whether to merge overlapping matched chunks. Defaults false
    *
    * @group getParam
    */
  def getMergeOverlapping: Boolean = $(mergeOverlapping)

  /** Setter for Value for the entity metadata field
    *
    * @group setParam
    */
  def setEntityValue(v: String): this.type = set(entityValue, v)

  /** Getter for Value for the entity metadata field
    *
    * @group getParam
    */
  def getEntityValue: String = $(entityValue)

  /** internal constructor for writabale annotator */
  def this() = this(Identifiable.randomUID("ENTITY_EXTRACTOR"))

  /** Setter for buildFromTokens param
    *
    * @group setParam
    */
  def setBuildFromTokens(v: Boolean): this.type = set(buildFromTokens, v)

  /** Getter for buildFromTokens param
    *
    * @group getParam
    */
  def getBuildFromTokens: Boolean = $(buildFromTokens)

  setDefault(inputCols, Array(DOCUMENT, TOKEN))
  setDefault(mergeOverlapping, false)
  setDefault(entityValue, "entity")

  @tco final protected def collapse(
      rs: List[(Int, Int)],
      sep: List[(Int, Int)] = Nil): List[(Int, Int)] = rs match {
    case x :: y :: rest =>
      if (y._1 > x._2) collapse(y :: rest, x :: sep)
      else collapse((x._1, x._2 max y._2) :: rest, sep)
    case _ =>
      (rs ::: sep).reverse
  }
  protected def merge(rs: List[(Int, Int)]): List[(Int, Int)] = collapse(rs.sortBy(_._1))

  /** Searches entities and stores them in the annotation. Defines annotator phrase matching
    * depending on whether we are using SBD or not
    *
    * @return
    *   Extracted Entities
    */
  override def annotate(annotations: Seq[Annotation]): Seq[Annotation] = {

    val result = ArrayBuffer[Annotation]()

    val sentences = annotations.filter(_.annotatorType == AnnotatorType.DOCUMENT)

    sentences.zipWithIndex.foreach { case (sentence, sentenceIndex) =>
      val tokens = annotations.filter(token =>
        token.annotatorType == AnnotatorType.TOKEN &&
          token.begin >= sentence.begin &&
          token.end <= sentence.end)

      val foundTokens = $$(searchTrie).search(tokens.map(_.result)).toList

      val finalTokens = if ($(mergeOverlapping)) merge(foundTokens) else foundTokens

      for ((begin, end) <- finalTokens) {

        val firstTokenBegin = tokens(begin).begin
        val lastTokenEnd = tokens(end).end

        /** token indices are not relative to sentence but to document, adjust offset accordingly
          */
        val normalizedText =
          if (! $(buildFromTokens))
            sentence.result
              .substring(firstTokenBegin - sentence.begin, lastTokenEnd - sentence.begin + 1)
          else
            tokens
              .filter(t => t.begin >= firstTokenBegin && t.end <= lastTokenEnd)
              .map(_.result)
              .mkString(" ")

        val annotation = Annotation(
          outputAnnotatorType,
          firstTokenBegin,
          lastTokenEnd,
          normalizedText,
          Map(
            "entity" -> $(entityValue),
            "sentence" -> sentenceIndex.toString,
            "chunk" -> result.length.toString))

        result.append(annotation)
      }
    }

    result
  }

}

trait ReadablePretrainedTextMatcher
    extends ParamsAndFeaturesReadable[TextMatcherModel]
    with HasPretrained[TextMatcherModel] {
  override val defaultModelName = None
  override def pretrained(): TextMatcherModel = super.pretrained()
  override def pretrained(name: String): TextMatcherModel = super.pretrained(name)
  override def pretrained(name: String, lang: String): TextMatcherModel =
    super.pretrained(name, lang)
  override def pretrained(name: String, lang: String, remoteLoc: String): TextMatcherModel =
    super.pretrained(name, lang, remoteLoc)
}

/** This is the companion object of [[TextMatcherModel]]. Please refer to that class for the
  * documentation.
  */
object TextMatcherModel extends ReadablePretrainedTextMatcher




© 2015 - 2024 Weber Informatics LLC | Privacy Policy