All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.nlp.annotators.ner.dl.ZeroShotNerModel.scala Maven / Gradle / Ivy

/*
 * Copyright 2017-2023 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.annotators.ner.dl

import com.johnsnowlabs.ml.ai.{RoBertaClassification, ZeroShotNerClassification}
import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel}
import com.johnsnowlabs.ml.tensorflow.{ReadTensorflowModel, TensorflowWrapper}
import com.johnsnowlabs.ml.util.LoadExternalModel.notSupportedEngineError
import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
import com.johnsnowlabs.nlp.AnnotatorType.{DOCUMENT, NAMED_ENTITY, TOKEN}
import com.johnsnowlabs.nlp.annotator.RoBertaForQuestionAnswering
import com.johnsnowlabs.nlp.pretrained.ResourceDownloader
import com.johnsnowlabs.nlp.serialization.MapFeature
import com.johnsnowlabs.nlp.{Annotation, AnnotatorType, HasPretrained, ParamsAndFeaturesReadable}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.PipelineStage
import org.apache.spark.ml.param.{FloatParam, StringArrayParam}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.SparkSession

import java.util
import scala.collection.JavaConverters._

/** ZeroShotNerModel implements zero shot named entity recognition by utilizing RoBERTa
  * transformer models fine tuned on a question answering task.
  *
  * Its input is a list of document annotations together with their tokens, and it automatically
  * generates questions which are used to recognize entities. The definitions of entities are
  * given by a dictionary structure, specifying a set of questions for each entity. The model is
  * based on RoBertaForQuestionAnswering.
  *
  * For more extended examples see the
  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/annotation/text/english/named-entity-recognition/ZeroShot_NER.ipynb Examples]]
  *
  * Pretrained models can be loaded with `pretrained` of the companion object:
  * {{{
  * val zeroShotNer = ZeroShotNerModel.pretrained()
  *   .setInputCols("document")
  *   .setOutputCol("zero_shot_ner")
  * }}}
  *
  * For available pretrained models please see the
  * [[https://sparknlp.org/models?task=Zero-Shot-NER Models Hub]].
  *
  * ==Example==
  * {{{
  *  val documentAssembler = new DocumentAssembler()
  *    .setInputCol("text")
  *    .setOutputCol("document")
  *
  *  val sentenceDetector = new SentenceDetector()
  *    .setInputCols(Array("document"))
  *    .setOutputCol("sentences")
  *
  *  val zeroShotNer = ZeroShotNerModel
  *    .pretrained()
  *    .setEntityDefinitions(
  *      Map(
  *        "NAME" -> Array("What is his name?", "What is her name?"),
  *        "CITY" -> Array("Which city?")))
  *    .setPredictionThreshold(0.01f)
  *    .setInputCols("sentences")
  *    .setOutputCol("zero_shot_ner")
  *
  *  val pipeline = new Pipeline()
  *    .setStages(Array(
  *      documentAssembler,
  *      sentenceDetector,
  *      zeroShotNer))
  *
  *  val model = pipeline.fit(Seq("").toDS.toDF("text"))
  *  val results = model.transform(
  *    Seq("Clara often travels between New York and Paris.").toDS.toDF("text"))
  *
  *  results
  *    .selectExpr("document", "explode(zero_shot_ner) AS entity")
  *    .select(
  *      col("entity.result"),
  *      col("entity.metadata.word"),
  *      col("entity.metadata.sentence"),
  *      col("entity.begin"),
  *      col("entity.end"),
  *      col("entity.metadata.confidence"),
  *      col("entity.metadata.question"))
  *    .show(truncate=false)
  *
  * +------+-----+--------+-----+---+----------+------------------+
  * |result|word |sentence|begin|end|confidence|question          |
  * +------+-----+--------+-----+---+----------+------------------+
  * |B-CITY|Paris|0       |41   |45 |0.78655756|Which is the city?|
  * |B-CITY|New  |0       |28   |30 |0.29346612|Which city?       |
  * |I-CITY|York |0       |32   |35 |0.29346612|Which city?       |
  * +------+-----+--------+-----+---+----------+------------------+
  *
  * }}}
  *
  * @see
  *   [[https://arxiv.org/abs/1907.11692]] for details about the RoBERTa transformer
  * @see
  *   [[RoBertaForQuestionAnswering]] for the SparkNLP implementation of RoBERTa question
  *   answering
  * @param uid
  *   required uid for storing annotator to disk
  * @groupname anno Annotator types
  * @groupdesc anno
  *   Required input and expected output annotator types
  * @groupname Ungrouped Members
  * @groupname param Parameters
  * @groupname setParam Parameter setters
  * @groupname getParam Parameter getters
  * @groupname Ungrouped Members
  * @groupprio param  1
  * @groupprio anno  2
  * @groupprio Ungrouped 3
  * @groupprio setParam  4
  * @groupprio getParam  5
  * @groupdesc param
  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
  *   parameter values through setters and getters, respectively.
  */
class ZeroShotNerModel(override val uid: String) extends RoBertaForQuestionAnswering {

  /** Creates the annotator with a random uid prefixed with "ZeroShotNerModel". */
  def this() = this(Identifiable.randomUID("ZeroShotNerModel"))

  /** Input Annotator Types: DOCUMENT, TOKEN
    *
    * @group anno
    */
  override val inputAnnotatorTypes: Array[String] = Array(DOCUMENT, TOKEN)

  /** Output Annotator Types: NAMED_ENTITY
    *
    * @group anno
    */
  override val outputAnnotatorType: AnnotatorType = NAMED_ENTITY

  /** Definitions of named entities: entity name mapped to the questions used to find it
    *
    * @group param
    */
  private val entityDefinitions = new MapFeature[String, Array[String]](this, "entityDefinitions")

  /** Set definitions of named entities
    *
    * @param definitions
    *   entity name mapped to the questions which are asked to find that entity
    * @group setParam
    */
  def setEntityDefinitions(definitions: Map[String, Array[String]]): this.type = {
    set(this.entityDefinitions, definitions)
  }

  /** Set definitions of named entities (variant taking Java collections)
    *
    * @param definitions
    *   entity name mapped to the questions which are asked to find that entity
    * @group setParam
    */
  def setEntityDefinitions(definitions: util.HashMap[String, util.List[String]]): this.type = {
    val converted = definitions.asScala.map { case (entity, questions) =>
      entity -> questions.asScala.toArray
    }.toMap
    set(this.entityDefinitions, converted)
  }

  /** Get definitions of named entities, or an empty map when none have been set
    *
    * @group getParam
    */
  private def getEntityDefinitions: scala.collection.immutable.Map[String, Array[String]] =
    if (entityDefinitions.isSet) $$(entityDefinitions) else Map.empty

  /** Get the entity definitions flattened to strings: one element per entity, made of the entity
    * name followed by its questions, all joined with "@@@".
    *
    * @group getParam
    */
  def getEntityDefinitionsStr: Array[String] = {
    getEntityDefinitions.map { case (entity, questions) =>
      entity + "@@@" + questions.mkString("@@@")
    }.toArray
  }

  /** Minimal score of predicted entity (Default: `0.01f`)
    *
    * @group param
    */
  var predictionThreshold =
    new FloatParam(this, "predictionThreshold", "Minimal score of predicted entity")

  /** List of entities to ignore (Default: empty)
    *
    * @group param
    */
  var ignoreEntities = new StringArrayParam(this, "ignoreEntities", "List of entities to ignore")

  /** Get the minimum entity prediction score
    *
    * @group getParam
    */
  def getPredictionThreshold: Float = $(predictionThreshold)

  /** Set the minimum entity prediction score
    *
    * @group setParam
    */
  def setPredictionThreshold(value: Float): this.type = set(this.predictionThreshold, value)

  /** Get the list of entities to ignore. Predictions for these entities are dropped from the
    * output, but they still take part in the resolution of overlapping predictions.
    *
    * @group getParam
    */
  def getIgnoreEntities: Array[String] = $(ignoreEntities)

  /** Get the list of entities which are recognized
    *
    * @group getParam
    */
  def getEntities: Array[String] = getEntityDefinitions.keys.toArray

  /** Set the list of entities to ignore
    *
    * @group setParam
    */
  def setIgnoreEntities(value: Array[String]): this.type = set(this.ignoreEntities, value)

  /** Turns every question of every entity definition into a DOCUMENT annotation whose metadata
    * carries the entity name ("entity") and the question text ("ner_question").
    */
  private def getNerQuestionAnnotations
      : scala.collection.immutable.Map[String, Array[Annotation]] = {
    getEntityDefinitions.map { case (entity, questions) =>
      entity -> questions.map(question =>
        new Annotation(
          AnnotatorType.DOCUMENT,
          0,
          question.length,
          question,
          Map("entity" -> entity, "ner_question" -> question)))
    }
  }

  private var _model: Option[Broadcast[ZeroShotNerClassification]] = None

  /** Builds and broadcasts the underlying [[ZeroShotNerClassification]] unless one is already
    * set.
    */
  override def setModelIfNotSet(
      spark: SparkSession,
      tensorflowWrapper: Option[TensorflowWrapper],
      onnxWrapper: Option[OnnxWrapper]): ZeroShotNerModel = {
    if (_model.isEmpty) {
      _model = Some(
        spark.sparkContext.broadcast(
          new ZeroShotNerClassification(
            tensorflowWrapper,
            onnxWrapper,
            sentenceStartTokenId,
            sentenceEndTokenId,
            padTokenId,
            false,
            configProtoBytes = getConfigProtoBytes,
            tags = Map.empty[String, Int],
            signatures = getSignatures,
            $$(merges),
            $$(vocabulary))))
    }

    this
  }

  /** Returns the broadcast model; fails with NoSuchElementException if it has not been set. */
  override def getModelIfNotSet: RoBertaClassification = _model.get.value

  setDefault(ignoreEntities -> Array(), predictionThreshold -> 0.01f)

  /** Symbol written in place of an already recognized entity before the document is re-queried */
  val maskSymbol = "_"

  /** Checks whether two spans, given as (begin, end) with both bounds inclusive, overlap */
  private def spansOverlap(span1: (Int, Int), span2: (Int, Int)): Boolean =
    span2._1 <= span1._2 && span1._1 <= span2._2

  /** Asks every question of every entity definition against the document and keeps the answers
    * which are non-empty and score above the prediction threshold. Among overlapping answers
    * only the one with the highest confidence survives; answers of ignored entities are removed
    * from the result but still compete in the overlap resolution.
    *
    * @param document
    *   document annotation to search; its metadata must contain "sentence"
    * @param nerDefs
    *   entity name mapped to its question annotations
    * @return
    *   CHUNK annotations with "entity", "sentence", "word", "confidence" and "question" metadata
    */
  private def recognizeEntities(
      document: Annotation,
      nerDefs: Map[String, Array[Annotation]]): Seq[Annotation] = {
    val threshold = getPredictionThreshold

    val docPredictions = nerDefs
      .flatMap { case (entityName, questions) =>
        val nerBatch = questions.map(question => Array(question, document)).toSeq
        super
          .batchAnnotate(nerBatch)
          .zip(nerBatch.map(_.head.result))
          .map { case (answers, question) => (answers.head, question) }
          .filter { case (answer, _) => answer.result.nonEmpty }
          .filter { case (answer, _) =>
            // An answer without a score can never pass the threshold
            answer.metadata.get("score").map(_.toFloat).getOrElse(Float.MinValue) > threshold
          }
          .map { case (answer, question) =>
            new Annotation(
              AnnotatorType.CHUNK,
              answer.begin,
              answer.end,
              answer.result,
              Map(
                "entity" -> entityName,
                "sentence" -> document.metadata("sentence"),
                "word" -> answer.result,
                "confidence" -> answer.metadata("score"),
                "question" -> question))
          }
      }
      .toSeq

    val ignored = $(ignoreEntities)

    // Detect overlapping predictions and choose the one with the higher score
    docPredictions
      .filter(prediction => !ignored.contains(prediction.metadata("entity"))) // Discard ignored entities
      .filter { prediction =>
        // Confidences are stored as strings: compare them numerically, since lexicographic
        // order is wrong for e.g. scientific notation ("1.0E-4" vs "0.9")
        val confidence = prediction.metadata("confidence").toFloat
        !docPredictions.exists { other =>
          other != prediction &&
          spansOverlap((prediction.begin, prediction.end), (other.begin, other.end)) &&
          other.metadata("confidence").toFloat > confidence
        }
      }
  }

  /** Builds a copy of the document text in which the entity is replaced by [[maskSymbol]] followed
    * by padding spaces. Offsets are taken relative to the beginning of the document, and the text
    * following the entity is cut at `maxSentenceLength` characters.
    *
    * @param document
    *   document annotation containing the entity
    * @param entity
    *   entity annotation to mask
    * @return
    *   the masked document text
    */
  def maskEntity(document: Annotation, entity: Annotation): String = {
    val entityStart = entity.begin - document.begin
    val entityEnd = entity.end - document.begin
    val text = document.result
    val padding = (entityStart to entityEnd - 2).map(_ => " ").mkString

    text.slice(0, entityStart) + maskSymbol + padding + text.slice(
      entityEnd,
      $(maxSentenceLength))
  }

  /** Recognizes entities repeatedly: every newly found entity is masked in the document and the
    * questions of that entity type are asked again, so that further mentions of the same type
    * can be found.
    *
    * @param document
    *   document annotation to search
    * @param nerDefs
    *   entity name mapped to its question annotations
    * @param recognizedEntities
    *   entities found so far; new predictions overlapping any of them are discarded
    * @return
    *   the entities found in this call and in the recursive calls it triggers
    */
  def recognizeMultipleEntities(
      document: Annotation,
      nerDefs: Map[String, Array[Annotation]],
      recognizedEntities: Seq[Annotation] = Seq()): Seq[Annotation] = {
    val newEntities = recognizeEntities(document, nerDefs).filter { entity =>
      val overlapsKnownEntity = recognizedEntities.exists(known =>
        spansOverlap((entity.begin, entity.end), (known.begin, known.end)))
      // A prediction equal to the mask symbol is an artefact of an earlier masking round
      !overlapsKnownEntity && entity.result != maskSymbol
    }

    val knownEntities = recognizedEntities ++ newEntities

    newEntities ++ newEntities.flatMap { entity =>
      val maskedDocument = new Annotation(
        document.annotatorType,
        document.begin,
        document.end,
        maskEntity(document, entity),
        document.metadata)
      val entityType = entity.metadata("entity")
      recognizeMultipleEntities(
        maskedDocument,
        nerDefs.filter { case (name, _) => name == entityType },
        knownEntities)
    }
  }

  /** Checks whether the token lies inside the span of the entity and belongs to the same sentence
    * (both annotations must carry "sentence" metadata).
    */
  def isTokenInEntity(token: Annotation, entity: Annotation): Boolean = {
    token.metadata("sentence") == entity.metadata("sentence") &&
    token.begin >= entity.begin &&
    token.end <= entity.end
  }

  /** Produces one NAMED_ENTITY annotation per input token: tokens inside a recognized entity are
    * tagged "B-&lt;entity&gt;" (first token) or "I-&lt;entity&gt;" (following tokens) and carry
    * the confidence and the question of the prediction, all other tokens are tagged "O".
    *
    * @param batchedAnnotations
    *   rows of annotations, each holding the DOCUMENT and TOKEN annotations of one input row
    * @return
    *   for each row, the NAMED_ENTITY annotations of its tokens, in token order
    */
  override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = {
    // The question annotations depend only on the entity definitions: build them once
    val nerQuestions = getNerQuestionAnnotations

    batchedAnnotations.map { annotations =>
      val documents = annotations.filter(_.annotatorType == AnnotatorType.DOCUMENT).toSeq
      val tokens = annotations.filter(_.annotatorType == AnnotatorType.TOKEN)

      val entities = documents.flatMap { doc =>
        recognizeMultipleEntities(doc, nerQuestions).flatMap { entity =>
          tokens
            .filter(token => isTokenInEntity(token, entity))
            .zipWithIndex
            .map { case (token, i) =>
              val bioPrefix = if (i == 0) "B-" else "I-"
              new Annotation(
                annotatorType = AnnotatorType.NAMED_ENTITY,
                begin = token.begin,
                end = token.end,
                result = bioPrefix + entity.metadata("entity"),
                metadata = Map(
                  "sentence" -> entity.metadata("sentence"),
                  "word" -> token.result,
                  "confidence" -> entity.metadata("confidence"),
                  "question" -> entity.metadata("question")))
            }
        }.toList
      }

      tokens
        .map { token =>
          entities
            .find(entity => isTokenInEntity(token, entity))
            .getOrElse(
              new Annotation(
                annotatorType = AnnotatorType.NAMED_ENTITY,
                begin = token.begin,
                end = token.end,
                result = "O",
                metadata = Map("sentence" -> token.metadata("sentence"), "word" -> token.result)))
        }
        .toSeq
    }
  }

}

/** Loading of [[ZeroShotNerModel]] from pretrained resources or from disk. Both entry points
  * also accept a model stored as RoBertaForQuestionAnswering and convert it with
  * `ZeroShotNerModel.getFromRoBertaForQuestionAnswering`.
  */
trait ReadablePretrainedZeroShotNer
    extends ParamsAndFeaturesReadable[ZeroShotNerModel]
    with HasPretrained[ZeroShotNerModel] {
  override val defaultModelName: Some[String] = Some("zero_shot_ner_roberta")

  /** Java compliant-overrides */
  override def pretrained(): ZeroShotNerModel =
    pretrained(defaultModelName.get, defaultLang, defaultLoc)

  override def pretrained(name: String): ZeroShotNerModel =
    pretrained(name, defaultLang, defaultLoc)

  override def pretrained(name: String, lang: String): ZeroShotNerModel =
    pretrained(name, lang, defaultLoc)

  /** Downloads the model as RoBertaForQuestionAnswering and converts it; when that fails with a
    * RuntimeException the resource is downloaded as a ZeroShotNerModel instead.
    */
  override def pretrained(name: String, lang: String, remoteLoc: String): ZeroShotNerModel = {
    try {
      ZeroShotNerModel.getFromRoBertaForQuestionAnswering(
        ResourceDownloader
          .downloadModel(RoBertaForQuestionAnswering, name, Option(lang), remoteLoc))
    } catch {
      case _: java.lang.RuntimeException =>
        ResourceDownloader.downloadModel(ZeroShotNerModel, name, Option(lang), remoteLoc)
    }
  }

  /** Loads a model from `path`. When the regular load fails with a ClassCastException, the path
    * is loaded as RoBertaForQuestionAnswering and converted. If that fallback fails as well, the
    * original ClassCastException is thrown with the fallback failure attached as a suppressed
    * exception.
    */
  override def load(path: String): ZeroShotNerModel = {
    import scala.util.control.NonFatal

    try {
      super.load(path)
    } catch {
      case e: java.lang.ClassCastException =>
        try {
          ZeroShotNerModel.getFromRoBertaForQuestionAnswering(
            RoBertaForQuestionAnswering.load(path))
        } catch {
          // Only non-fatal failures are replaced by the original error; fatal ones
          // (e.g. OutOfMemoryError, InterruptedException) must propagate untouched
          case NonFatal(fallbackError) =>
            e.addSuppressed(fallbackError)
            throw e
        }
    }
  }
}

/** Reader which restores the inference backend of a stored [[ZeroShotNerModel]], using either
  * the TensorFlow or the ONNX reader depending on the engine recorded in the instance.
  */
trait ReadZeroShotNerDLModel extends ReadTensorflowModel with ReadOnnxModel {
  this: ParamsAndFeaturesReadable[ZeroShotNerModel] =>

  override val tfFile: String = "roberta_classification_tensorflow"
  override val onnxFile: String = "roberta_classification_onnx"

  /** Reads the serialized model found under `path` and attaches it to `instance`.
    *
    * @param instance
    *   model whose engine decides which reader is used and which receives the wrapper
    * @param path
    *   location of the stored model
    * @param spark
    *   active Spark session
    * @throws Exception
    *   when the engine of the instance is neither TensorFlow nor ONNX
    */
  def readModel(instance: ZeroShotNerModel, path: String, spark: SparkSession): Unit = {
    val engine = instance.getEngine

    engine match {
      case TensorFlow.name =>
        val tensorflowModel =
          readTensorflowModel(path, spark, "_roberta_classification_tf", initAllTables = false)
        instance.setModelIfNotSet(spark, Some(tensorflowModel), None)

      case ONNX.name =>
        val onnxModel = readOnnxModel(
          path,
          spark,
          "_roberta_classification_onnx",
          zipped = true,
          useBundle = false,
          None)
        instance.setModelIfNotSet(spark, None, Some(onnxModel))

      case _ =>
        throw new Exception(notSupportedEngineError)
    }
  }

  addReader(readModel)
}
/** Companion object of [[ZeroShotNerModel]]: provides `pretrained`/`load` and the conversion of
  * a RoBertaForQuestionAnswering model into a ZeroShotNerModel.
  */
object ZeroShotNerModel extends ReadablePretrainedZeroShotNer with ReadZeroShotNerDLModel {

  /** Converts a plain RoBertaForQuestionAnswering stage into a ZeroShotNerModel. A stage which
    * already is a ZeroShotNerModel, or which is of any other type, is returned unchanged.
    */
  def apply(model: PipelineStage): PipelineStage = {
    model match {
      // Must come first: ZeroShotNerModel is itself a RoBertaForQuestionAnswering
      case zeroShot: ZeroShotNerModel => zeroShot
      case answering: RoBertaForQuestionAnswering =>
        getFromRoBertaForQuestionAnswering(answering)
      case other => other
    }
  }

  /** Creates a ZeroShotNerModel sharing the vocabulary, merges, signatures (when set), backend
    * wrapper and params of the given question answering model.
    *
    * @param model
    *   source model; its vocabulary and merges must be set
    * @return
    *   a new ZeroShotNerModel
    * @throws RuntimeException
    *   when the vocabulary or the merges are not set, or when the engine of the source model is
    *   neither TensorFlow nor ONNX
    */
  def getFromRoBertaForQuestionAnswering(model: RoBertaForQuestionAnswering): ZeroShotNerModel = {
    val spark = SparkSession.builder.getOrCreate()

    val newModel = new ZeroShotNerModel()
      .setVocabulary(
        model.vocabulary.get.getOrElse(throw new RuntimeException("Vocabulary not set")))
      .setMerges(model.merges.get.getOrElse(throw new RuntimeException("Merges not set")))
      .setCaseSensitive(model.getCaseSensitive)
      .setBatchSize(model.getBatchSize)

    if (model.signatures.isSet)
      newModel.setSignatures(
        model.signatures.get.getOrElse(throw new RuntimeException("Signatures not set")))

    model.getEngine match {
      case TensorFlow.name =>
        newModel.setModelIfNotSet(spark, model.getModelIfNotSet.tensorflowWrapper, None)
      case ONNX.name =>
        newModel.setModelIfNotSet(spark, None, model.getModelIfNotSet.onnxWrapper)
      case _ =>
        // Explicit failure instead of an opaque MatchError. Kept as a RuntimeException so
        // that callers which fall back on RuntimeException keep working.
        throw new RuntimeException(notSupportedEngineError)
    }

    // Copy every param of the source model onto the new model, addressed by param name
    model
      .extractParamMap()
      .toSeq
      .foreach(pair => newModel.set(pair.param.name, pair.value))

    newModel
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy