All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.nlp.annotators.er.EntityRulerModel.scala Maven / Gradle / Ivy

/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.annotators.er

import com.johnsnowlabs.nlp.AnnotatorType.{CHUNK, DOCUMENT, TOKEN}
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.serialization.StructFeature
import com.johnsnowlabs.nlp.{Annotation, AnnotatorModel, HasPretrained, HasSimpleAnnotate}
import com.johnsnowlabs.storage.Database.{ENTITY_PATTERNS, ENTITY_REGEX_PATTERNS, Name}
import com.johnsnowlabs.storage._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.param.{BooleanParam, StringArrayParam}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.slf4j.{Logger, LoggerFactory}

/** Instantiated model of the [[EntityRulerApproach]]. For usage and examples see the
  * documentation of the main class.
  *
  * @param uid
  *   internally required UID to make it writable
  * @groupname anno Annotator types
  * @groupdesc anno
  *   Required input and expected output annotator types
  * @groupname Ungrouped Members
  * @groupname param Parameters
  * @groupname setParam Parameter setters
  * @groupname getParam Parameter getters
  * @groupname Ungrouped Members
  * @groupprio param  1
  * @groupprio anno  2
  * @groupprio Ungrouped 3
  * @groupprio setParam  4
  * @groupprio getParam  5
  * @groupdesc param
  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
  *   parameter values through setters and getters, respectively.
  */
class EntityRulerModel(override val uid: String)
    extends AnnotatorModel[EntityRulerModel]
    with HasSimpleAnnotate[EntityRulerModel]
    with HasStorageModel {

  // Local import keeps this class self-contained without touching the file-level import block.
  import scala.util.matching.Regex

  def this() = this(Identifiable.randomUID("ENTITY_RULER"))

  // Named after this class so that the warning emitted in getMatchedEntity can be traced back
  // to the entity ruler (the logger was previously registered under the unrelated name
  // "Credentials").
  private val logger: Logger = LoggerFactory.getLogger("EntityRulerModel")

  /** Deprecated flag for regex matching; regex is now enabled per pattern (see the deprecation
    * message).
    */
  @deprecated("Enabling pattern regex now is define on each pattern", "Since 4.2.0")
  private[er] val enablePatternRegex =
    new BooleanParam(this, "enablePatternRegex", "Enables regex pattern match")

  /** Whether regex patterns are read from RocksDB storage (true) or from
    * [[entityRulerFeatures]] (false). Default: false.
    */
  private[er] val useStorage =
    new BooleanParam(this, "useStorage", "Whether to use RocksDB storage to serialize patterns")

  /** Names of the entities that have regex patterns. When empty, no regex matching is run. */
  private[er] val regexEntities =
    new StringArrayParam(this, "regexEntities", "entities defined in regex patterns")

  /** Holds the regex patterns by entity when RocksDB is not used. */
  private[er] val entityRulerFeatures: StructFeature[EntityRulerFeatures] =
    new StructFeature[EntityRulerFeatures](
      this,
      "Structure to store data when RocksDB is not used")

  /** Whether regex patterns are matched against whole sentences (true) or against individual
    * tokens (false). Default: false.
    */
  private[er] val sentenceMatch = new BooleanParam(
    this,
    "sentenceMatch",
    "Whether to find match at sentence level (regex only). True: sentence level. False: token level")

  /** Optional keyword automaton stored with the model. */
  private[er] val ahoCorasickAutomaton: StructFeature[Option[AhoCorasickAutomaton]] =
    new StructFeature[Option[AhoCorasickAutomaton]](this, "AhoCorasickAutomaton")

  @deprecated("Enabling pattern regex now is define on each pattern", "Since 4.2.0")
  private[er] def setEnablePatternRegex(value: Boolean): this.type =
    set(enablePatternRegex, value)

  private[er] def setRegexEntities(value: Array[String]): this.type = set(regexEntities, value)

  private[er] def setEntityRulerFeatures(value: EntityRulerFeatures): this.type =
    set(entityRulerFeatures, value)

  private[er] def setUseStorage(value: Boolean): this.type = set(useStorage, value)

  private[er] def setSentenceMatch(value: Boolean): this.type = set(sentenceMatch, value)

  private[er] def setAhoCorasickAutomaton(value: Option[AhoCorasickAutomaton]): this.type =
    set(ahoCorasickAutomaton, value)

  /** Broadcast copy of the automaton; stays empty until [[setAutomatonModelIfNotSet]] receives a
    * defined automaton.
    */
  private var automatonModel: Option[Broadcast[AhoCorasickAutomaton]] = None

  /** Broadcasts `automaton` through the Spark context of `spark`, unless a broadcast already
    * exists or `automaton` is empty.
    *
    * @param spark
    *   session whose Spark context performs the broadcast
    * @param automaton
    *   automaton to broadcast, if any
    * @return
    *   this model
    */
  def setAutomatonModelIfNotSet(
      spark: SparkSession,
      automaton: Option[AhoCorasickAutomaton]): this.type = {
    if (automatonModel.isEmpty) {
      automatonModel = automaton.map(value => spark.sparkContext.broadcast(value))
    }
    this
  }

  /** Returns the broadcast automaton when one exists, otherwise the value of the
    * [[ahoCorasickAutomaton]] feature (which may itself be empty).
    */
  def getAutomatonModelIfNotSet: Option[AhoCorasickAutomaton] =
    automatonModel.map(_.value).orElse($$(ahoCorasickAutomaton))

  // sentenceMatch defaults to token-level matching so that a model on which the parameter was
  // never set explicitly does not fail when it is read during annotation.
  setDefault(useStorage -> false, caseSensitive -> true, sentenceMatch -> false)

  /** Required input annotator types: DOCUMENT */
  val inputAnnotatorTypes: Array[String] = Array(DOCUMENT)

  /** Optional input annotator types: TOKEN (checked in [[_transform]] when regex entities are
    * defined)
    */
  override val optionalInputAnnotatorTypes: Array[String] = Array(TOKEN)

  /** Output annotator type: CHUNK */
  val outputAnnotatorType: AnnotatorType = CHUNK

  /** Verifies that a TOKEN column is present whenever regex entities are defined, then delegates
    * to the parent implementation.
    *
    * @throws IllegalArgumentException
    *   if regex entities are defined and no column of `dataset` carries the TOKEN annotator type
    *   in its metadata
    */
  override def _transform(
      dataset: Dataset[_],
      recursivePipeline: Option[PipelineModel]): DataFrame = {
    if ($(regexEntities).nonEmpty) {
      val hasTokenColumn = dataset.schema.fields.exists { field =>
        field.metadata.contains("annotatorType") &&
        field.metadata.getString("annotatorType") == TOKEN
      }
      if (!hasTokenColumn) {
        throw new IllegalArgumentException(
          s"Missing $TOKEN annotator. Regex patterns requires it in your pipeline")
      }
    }
    super._transform(dataset, recursivePipeline)
  }

  /** Broadcasts the stored automaton (if any, and if not already broadcast) and returns the
    * dataset unchanged.
    */
  override def beforeAnnotate(dataset: Dataset[_]): Dataset[_] = {
    this.setAutomatonModelIfNotSet(dataset.sparkSession, $$(ahoCorasickAutomaton))
    dataset
  }

  /** Produces CHUNK annotations for the entities found in the input annotations.
    *
    * @param annotations
    *   annotations generated by the previous annotators of the input columns
    * @return
    *   the regex-based matches followed by the keyword matches found by the automaton; the
    *   keyword part is empty when no automaton is available
    */
  def annotate(annotations: Seq[Annotation]): Seq[Annotation] = {
    val sentences = SentenceSplit.unpack(annotations)
    val annotatedEntitiesByRegex = computeAnnotatedEntitiesByRegex(annotations, sentences)

    // The automaton is resolved once per call instead of once per sentence.
    val annotatedEntitiesByKeywords: Seq[Annotation] = getAutomatonModelIfNotSet match {
      case Some(automaton) =>
        sentences.flatMap(sentence => automaton.searchPatternsInText(sentence))
      case None => Seq()
    }

    annotatedEntitiesByRegex ++ annotatedEntitiesByKeywords
  }

  /** Runs regex matching at sentence or token level depending on [[sentenceMatch]]. Returns an
    * empty sequence when no regex entities are defined.
    */
  private def computeAnnotatedEntitiesByRegex(
      annotations: Seq[Annotation],
      sentences: Seq[Sentence]): Seq[Annotation] = {
    if ($(regexEntities).isEmpty) {
      Seq()
    } else {
      val regexPatternsReader: Option[RegexPatternsReader] =
        if ($(useStorage))
          Some(getReader(Database.ENTITY_REGEX_PATTERNS).asInstanceOf[RegexPatternsReader])
        else None

      if ($(sentenceMatch)) {
        annotateEntitiesFromRegexPatternsBySentence(sentences, regexPatternsReader)
      } else {
        val tokenizedWithSentences = TokenizedWithSentence.unpack(annotations)
        annotateEntitiesFromRegexPatterns(tokenizedWithSentences, regexPatternsReader)
      }
    }
  }

  /** Looks up and compiles the regex patterns of every regex entity, keeping the order of
    * [[regexEntities]]. Patterns come from `regexPatternsReader` when given, otherwise from
    * [[entityRulerFeatures]]. Entities without registered patterns are left out.
    */
  private def compileRegexPatterns(
      regexPatternsReader: Option[RegexPatternsReader]): Seq[(String, Seq[Regex])] =
    $(regexEntities).toSeq.flatMap { regexEntity =>
      val regexPatterns: Option[Seq[String]] = regexPatternsReader match {
        case Some(reader) => reader.lookup(regexEntity)
        case None => $$(entityRulerFeatures).regexPatterns.get(regexEntity)
      }
      regexPatterns.map(patterns => (regexEntity, patterns.map(_.r)))
    }

  /** Token-level matching: emits one CHUNK annotation for every token matched by a pattern of
    * some regex entity.
    */
  private def annotateEntitiesFromRegexPatterns(
      tokenizedWithSentences: Seq[TokenizedSentence],
      regexPatternsReader: Option[RegexPatternsReader]): Seq[Annotation] = {

    // Looked up and compiled at most once per call rather than for every token. Lazy so that
    // nothing is read or compiled when there is no token to inspect.
    lazy val patternsByEntity = compileRegexPatterns(regexPatternsReader)

    tokenizedWithSentences.flatMap { tokenizedWithSentence =>
      val sentenceMetadata = Map("sentence" -> tokenizedWithSentence.sentenceIndex.toString)
      tokenizedWithSentence.indexedTokens.flatMap { indexedToken =>
        getMatchedEntity(indexedToken.token, patternsByEntity).map { entity =>
          Annotation(
            CHUNK,
            indexedToken.begin,
            indexedToken.end,
            indexedToken.token,
            getEntityMetadata(entity) ++ sentenceMetadata)
        }
      }
    }
  }

  /** Returns the first entity (in [[regexEntities]] order) with a pattern that matches somewhere
    * in `token`. A warning is logged when more than one entity matches.
    */
  private def getMatchedEntity(
      token: String,
      patternsByEntity: Seq[(String, Seq[Regex])]): Option[String] = {

    val matchedEntities = patternsByEntity.collect {
      case (regexEntity, regexes) if regexes.exists(_.findFirstIn(token).isDefined) =>
        regexEntity
    }

    if (matchedEntities.size > 1) {
      logger.warn("More than one entity found. Sending the first element of the array")
    }

    matchedEntities.headOption
  }

  /** Sentence-level matching: for every entity takes the first match of each of its patterns in
    * the sentence, keeps only the matches whose interval is returned by
    * `EntityRulerUtil.mergeIntervals`, and returns all kept matches sorted by begin offset.
    *
    * @return
    *   pairs of matched chunk (with offsets shifted by the sentence start, end inclusive) and
    *   entity label
    */
  private def getMatchedEntityBySentence(
      sentence: Sentence,
      patternsByEntity: Seq[(String, Seq[Regex])]): Seq[(IndexedToken, String)] = {

    val matchesByEntity = patternsByEntity.flatMap { case (regexEntity, regexes) =>
      val resultMatches = regexes.flatMap { regex =>
        regex.findFirstMatchIn(sentence.content).map { regexMatch =>
          val begin = regexMatch.start + sentence.start
          // Match.end is exclusive; annotation offsets are inclusive.
          val end = regexMatch.end + sentence.start - 1
          (regexMatch.matched, begin, end, regexEntity)
        }
      }

      val intervals =
        resultMatches.map(resultMatch => List(resultMatch._2, resultMatch._3)).toList
      val mergedIntervals = EntityRulerUtil.mergeIntervals(intervals)
      resultMatches.filter(resultMatch =>
        mergedIntervals.contains(List(resultMatch._2, resultMatch._3)))
    }

    matchesByEntity.sortBy(_._2).map { case (chunk, begin, end, regexEntity) =>
      (IndexedToken(chunk, begin, end), regexEntity)
    }
  }

  /** Sentence-level matching: emits one CHUNK annotation for every match returned by
    * [[getMatchedEntityBySentence]].
    */
  private def annotateEntitiesFromRegexPatternsBySentence(
      sentences: Seq[Sentence],
      patternsReader: Option[RegexPatternsReader]): Seq[Annotation] = {

    // Looked up and compiled at most once per call rather than for every sentence. Lazy so
    // that nothing is read or compiled when there is no sentence to inspect.
    lazy val patternsByEntity = compileRegexPatterns(patternsReader)

    sentences.flatMap { sentence =>
      val sentenceMetadata = Map("sentence" -> sentence.index.toString)
      getMatchedEntityBySentence(sentence, patternsByEntity).map { case (indexedToken, label) =>
        Annotation(
          CHUNK,
          indexedToken.begin,
          indexedToken.end,
          indexedToken.token,
          getEntityMetadata(label) ++ sentenceMetadata)
      }
    }
  }

  /** Splits a comma-separated label into annotation metadata: the first part becomes "entity"
    * and, when further parts exist, the last one becomes "id".
    */
  private def getEntityMetadata(label: String): Map[String, String] = {
    val parts = label.split(",")
    val entity = parts.headOption.map(name => "entity" -> name)
    val id = parts.drop(1).lastOption.map(identifier => "id" -> identifier)
    (entity.toList ++ id.toList).toMap
  }

  /** Delegates to the parent implementation only when [[useStorage]] is enabled. */
  override def deserializeStorage(path: String, spark: SparkSession): Unit = {
    if ($(useStorage)) {
      super.deserializeStorage(path, spark)
    }
  }

  /** Delegates to the parent implementation only when [[useStorage]] is enabled. */
  override def onWrite(path: String, spark: SparkSession): Unit = {
    if ($(useStorage)) {
      super.onWrite(path, spark)
    }
  }

  /** Storage databases of this model, shared with the companion object. */
  protected val databases: Array[Name] = EntityRulerModel.databases

  /** Creates the reader for one of the two supported databases. Any other name is not handled
    * by the match and raises a `MatchError`.
    */
  protected def createReader(database: Name, connection: RocksDBConnection): StorageReader[_] = {
    database match {
      case Database.ENTITY_PATTERNS => new PatternsReader(connection)
      case Database.ENTITY_REGEX_PATTERNS => new RegexPatternsReader(connection)
    }
  }
}

/** Reader mix-in for [[EntityRulerModel]]: declares the storage databases of the model and
  * narrows the return type of every `pretrained` overload to [[EntityRulerModel]]. Each overload
  * simply forwards to the inherited implementation.
  */
trait ReadablePretrainedEntityRuler
    extends StorageReadable[EntityRulerModel]
    with HasPretrained[EntityRulerModel] {

  /** Databases used by the entity ruler: keyword patterns and regex patterns. */
  override val databases: Array[Name] = Array[Name](ENTITY_PATTERNS, ENTITY_REGEX_PATTERNS)

  /** No default model name is defined for this annotator. */
  override val defaultModelName: Option[String] = Option.empty[String]

  /** Forwards to the inherited no-argument `pretrained`. */
  override def pretrained(): EntityRulerModel = {
    super.pretrained()
  }

  /** Forwards to the inherited `pretrained` with the given model name. */
  override def pretrained(name: String): EntityRulerModel = {
    super.pretrained(name)
  }

  /** Forwards to the inherited `pretrained` with the given model name and language. */
  override def pretrained(name: String, lang: String): EntityRulerModel = {
    super.pretrained(name, lang)
  }

  /** Forwards to the inherited `pretrained` with the given model name, language and remote
    * location.
    */
  override def pretrained(name: String, lang: String, remoteLoc: String): EntityRulerModel = {
    super.pretrained(name, lang, remoteLoc)
  }

}

/** Companion object of the `EntityRulerModel` class. It inherits all of its members from
  * [[ReadablePretrainedEntityRuler]], including the `databases` array that the class reads.
  */
object EntityRulerModel extends ReadablePretrainedEntityRuler




© 2015 - 2024 Weber Informatics LLC | Privacy Policy