All downloads are free. Search and download functionality uses the official Maven repository.

com.johnsnowlabs.nlp.annotators.ner.NerOverwriter.scala Maven / Gradle / Ivy

/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.annotators.ner

import com.johnsnowlabs.nlp.serialization.MapFeature
import com.johnsnowlabs.nlp.{Annotation, AnnotatorModel, HasSimpleAnnotate}
import org.apache.spark.ml.param.{Param, StringArrayParam}
import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable}

import java.util
import scala.collection.JavaConverters._

/** Overwrites entities of specified strings.
  *
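  * The input for this Annotator must be entities that have already been extracted (Annotator type
  * `NAMED_ENTITY`). Tokens whose text matches one of the strings set with `setNerWords` are
  * assigned the new entity set with `setNewNerEntity`. Afterwards, any entity label that is a key
  * of the map set with `setReplaceEntities` is replaced by the corresponding value.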
  *
  * ==Example==
  * {{{
  * import spark.implicits._
  * import com.johnsnowlabs.nlp.base.DocumentAssembler
  * import com.johnsnowlabs.nlp.annotators.Tokenizer
  * import com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetector
  * import com.johnsnowlabs.nlp.embeddings.WordEmbeddingsModel
  * import com.johnsnowlabs.nlp.annotators.ner.dl.NerDLModel
  * import com.johnsnowlabs.nlp.annotators.ner.NerOverwriter
  * import org.apache.spark.ml.Pipeline
  *
  * // First extract the prerequisite Entities
  * val documentAssembler = new DocumentAssembler()
  *   .setInputCol("text")
  *   .setOutputCol("document")
  *
  * val sentence = new SentenceDetector()
  *   .setInputCols("document")
  *   .setOutputCol("sentence")
  *
  * val tokenizer = new Tokenizer()
  *   .setInputCols("sentence")
  *   .setOutputCol("token")
  *
  * val embeddings = WordEmbeddingsModel.pretrained()
  *   .setInputCols("sentence", "token")
  *   .setOutputCol("embeddings")
  *
  * val nerTagger = NerDLModel.pretrained()
  *   .setInputCols("sentence", "token", "embeddings")
  *   .setOutputCol("ner")
  *
  * val pipeline = new Pipeline().setStages(Array(
  *   documentAssembler,
  *   sentence,
  *   tokenizer,
  *   embeddings,
  *   nerTagger
  * ))
  *
  * val data = Seq("Spark NLP Crosses Five Million Downloads, John Snow Labs Announces.").toDF("text")
  * val result = pipeline.fit(data).transform(data)
  *
  * result.selectExpr("explode(ner)").show(false)
  * /*
  * +------------------------------------------------------+
  * |col                                                   |
  * +------------------------------------------------------+
  * |[named_entity, 0, 4, B-ORG, [word -> Spark], []]      |
  * |[named_entity, 6, 8, I-ORG, [word -> NLP], []]        |
  * |[named_entity, 10, 16, O, [word -> Crosses], []]      |
  * |[named_entity, 18, 21, O, [word -> Five], []]         |
  * |[named_entity, 23, 29, O, [word -> Million], []]      |
  * |[named_entity, 31, 39, O, [word -> Downloads], []]    |
  * |[named_entity, 40, 40, O, [word -> ,], []]            |
  * |[named_entity, 42, 45, B-ORG, [word -> John], []]     |
  * |[named_entity, 47, 50, I-ORG, [word -> Snow], []]     |
  * |[named_entity, 52, 55, I-ORG, [word -> Labs], []]     |
  * |[named_entity, 57, 65, I-ORG, [word -> Announces], []]|
  * |[named_entity, 66, 66, O, [word -> .], []]            |
  * +------------------------------------------------------+
  * */
  * // The recognized entities can then be overwritten
  * val nerOverwriter = new NerOverwriter()
  *   .setInputCols("ner")
  *   .setOutputCol("ner_overwritten")
  *   .setNerWords(Array("Million"))
  *   .setNewNerEntity("B-CARDINAL")
  *
  * nerOverwriter.transform(result).selectExpr("explode(ner_overwritten)").show(false)
  * /*
  * +---------------------------------------------------------+
  * |col                                                      |
  * +---------------------------------------------------------+
  * |[named_entity, 0, 4, B-ORG, [word -> Spark], []]         |
  * |[named_entity, 6, 8, I-ORG, [word -> NLP], []]           |
  * |[named_entity, 10, 16, O, [word -> Crosses], []]         |
  * |[named_entity, 18, 21, O, [word -> Five], []]            |
  * |[named_entity, 23, 29, B-CARDINAL, [word -> Million], []]|
  * |[named_entity, 31, 39, O, [word -> Downloads], []]       |
  * |[named_entity, 40, 40, O, [word -> ,], []]               |
  * |[named_entity, 42, 45, B-ORG, [word -> John], []]        |
  * |[named_entity, 47, 50, I-ORG, [word -> Snow], []]        |
  * |[named_entity, 52, 55, I-ORG, [word -> Labs], []]        |
  * |[named_entity, 57, 65, I-ORG, [word -> Announces], []]   |
  * |[named_entity, 66, 66, O, [word -> .], []]               |
  * +---------------------------------------------------------+
  * */
  * }}}
  *
  * @groupname anno Annotator types
  * @groupdesc anno
  *   Required input and expected output annotator types
  * @groupname Ungrouped Members
  * @groupname param Parameters
  * @groupname setParam Parameter setters
  * @groupname getParam Parameter getters
  * @groupname Ungrouped Members
  * @groupprio param  1
  * @groupprio anno  2
  * @groupprio Ungrouped 3
  * @groupprio setParam  4
  * @groupprio getParam  5
  * @groupdesc param
  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
  *   parameter values through setters and getters, respectively.
  */
class NerOverwriter(override val uid: String)
    extends AnnotatorModel[NerOverwriter]
    with HasSimpleAnnotate[NerOverwriter] {

  import com.johnsnowlabs.nlp.AnnotatorType.NAMED_ENTITY

  /** Output Annotator Type : NAMED_ENTITY
    *
    * @group anno
    */
  override val outputAnnotatorType: AnnotatorType = NAMED_ENTITY

  /** Input Annotator Type : NAMED_ENTITY
    *
    * @group anno
    */
  override val inputAnnotatorTypes: Array[AnnotatorType] = Array(NAMED_ENTITY)

  def this() = this(Identifiable.randomUID("NER_OVERWRITER"))

  /** Words whose entity label is overwritten with [[newNerEntity]].
    *
    * A token matches when the value of its `word` metadata entry is exactly equal (case-sensitive)
    * to one of these strings. Defaults to an empty array, so no token is overwritten.
    *
    * @group param
    */
  val nerWords: StringArrayParam =
    new StringArrayParam(
      this,
      "nerWords",
      "The words to be filtered out to replace the current NER entity with a new NER Entity")

  /** Sets the words whose entity label should be overwritten with [[newNerEntity]].
    *
    * @group setParam
    */
  def setNerWords(value: Array[String]): this.type = set(nerWords, value)

  /** Words whose entity label is overwritten with [[newNerEntity]].
    *
    * @group getParam
    */
  def getNerWords: Array[String] = $(nerWords)

  /** Entity label assigned to tokens matching [[nerWords]] (default `"I-OVERWRITE"`).
    *
    * @group param
    */
  val newNerEntity: Param[String] = new Param(this, "newNerEntity", "New NER label to overwrite")

  /** Sets the entity label assigned to tokens matching [[nerWords]].
    *
    * @group setParam
    */
  def setNewNerEntity(r: String): this.type = set(newNerEntity, r)

  /** Entity label assigned to tokens matching [[nerWords]].
    *
    * @group getParam
    */
  def getNewNerEntity: String = $(newNerEntity)

  /** Mapping from an existing entity label to a replacement label.
    *
    * It is applied to every annotation after the [[nerWords]] overwrite, so a key equal to
    * [[newNerEntity]] also remaps overwritten tokens. Labels that are not keys are left unchanged.
    */
  val replaceEntities: MapFeature[String, String] =
    new MapFeature[String, String](this, "replaceEntities")

  /** Sets the label-to-label replacement map.
    *
    * @group setParam
    */
  def setReplaceEntities(w: Map[String, String]): this.type = set(replaceEntities, w)

  // for Python access
  /** Java-map overload of `setReplaceEntities`, for callers that pass a `java.util.HashMap`.
    *
    * @group setParam
    */
  def setReplaceEntities(w: util.HashMap[String, String]): this.type =
    set(replaceEntities, w.asScala.toMap)

  /** The label-to-label replacement map, or an empty map when it was never set.
    *
    * @group getParam
    */
  def getReplaceEntities: Map[String, String] =
    if (replaceEntities.isSet) $$(replaceEntities) else Map.empty[String, String]

  setDefault(newNerEntity -> "I-OVERWRITE", nerWords -> Array.empty[String])

  /** Overwrites entity labels of the given NER annotations.
    *
    * For each annotation:
    *   1. if its `word` metadata is one of [[nerWords]], its label becomes [[newNerEntity]];
    *   2. the resulting label is then looked up in [[replaceEntities]] and replaced when present.
    *
    * Positions and metadata are kept; the annotator type is set to `NAMED_ENTITY`.
    *
    * @param annotations
    *   NAMED_ENTITY annotations, one per token
    * @return
    *   the annotations with overwritten labels, in the same order
    */
  override def annotate(annotations: Seq[Annotation]): Seq[Annotation] = {
    // Resolve params once per call instead of once per token.
    val wordsToOverwrite = $(nerWords).toSet
    val overwriteLabel = $(newNerEntity)
    val replace = getReplaceEntities

    annotations.map { ann =>
      // `get` rather than `apply`: an annotation lacking a "word" entry is passed through
      // unchanged instead of throwing NoSuchElementException.
      val matched = ann.metadata.get("word").exists(wordsToOverwrite.contains)
      val label = if (matched) overwriteLabel else ann.result
      // `copy` keeps every other field of the input annotation intact.
      ann.copy(annotatorType = outputAnnotatorType, result = replace.getOrElse(label, label))
    }
  }

}

/** Companion of [[NerOverwriter]]; mixes in Spark ML's `DefaultParamsReadable` so saved instances
  * can be read back. Usage is documented on the class itself.
  */
object NerOverwriter extends DefaultParamsReadable[NerOverwriter] {}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy