
com.johnsnowlabs.nlp.annotators.ner.crf.NerCrfModel.scala Maven / Gradle / Ivy
package com.johnsnowlabs.nlp.annotators.ner.crf
import com.johnsnowlabs.ml.crf.{FbCalculator, LinearChainCrfModel, VectorMath}
import com.johnsnowlabs.nlp.AnnotatorType._
import com.johnsnowlabs.nlp.annotators.common.Annotated.{NerTaggedSentence, PosTaggedSentence}
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.pretrained.ResourceDownloader
import com.johnsnowlabs.nlp.serialization.{MapFeature, StructFeature}
import com.johnsnowlabs.nlp.{Annotation, AnnotatorModel, ParamsAndFeaturesReadable}
import org.apache.spark.ml.param.{BooleanParam, StringArrayParam}
import org.apache.spark.ml.util._
/**
  * Named Entity Recognition model backed by a linear-chain CRF ([[LinearChainCrfModel]]).
  *
  * Consumes DOCUMENT, TOKEN, POS and WORD_EMBEDDINGS annotations and produces NAMED_ENTITY
  * annotations. The output can be restricted to a subset of labels via `entities`, and a
  * per-token confidence value can be attached via `includeConfidence`.
  */
class NerCrfModel(override val uid: String) extends AnnotatorModel[NerCrfModel] {

  def this() = this(Identifiable.randomUID("NER"))

  /** Labels to keep in the output; when unset or empty, every predicted label is kept. */
  val entities: StringArrayParam =
    new StringArrayParam(this, "entities", "List of Entities to recognize")

  /** The trained CRF used for prediction; must be set before `tag` is called. */
  val model: StructFeature[LinearChainCrfModel] =
    new StructFeature[LinearChainCrfModel](this, "crfModel")

  /** Dictionary entries used to build [[DictionaryFeatures]] for feature generation (default: empty). */
  val dictionaryFeatures: MapFeature[String, String] =
    new MapFeature[String, String](this, "dictionaryFeatures")

  /** Whether to attach a per-token confidence value to each tagged word (default: false). */
  val includeConfidence: BooleanParam = new BooleanParam(
    this,
    "includeConfidence",
    "whether or not to calculate prediction confidence by token, includes in metadata")

  def setModel(crf: LinearChainCrfModel): this.type = set(model, crf)

  def setDictionaryFeatures(dictFeatures: DictionaryFeatures): this.type =
    set(dictionaryFeatures, dictFeatures.dict)

  def setEntities(toExtract: Array[String]): this.type = set(entities, toExtract)

  def setIncludeConfidence(c: Boolean): this.type = set(includeConfidence, c)

  setDefault(dictionaryFeatures, () => Map.empty[String, String])
  setDefault(includeConfidence, false)

  /**
    * Predicts Named Entities in input sentences.
    *
    * @param sentences pairs of a POS-tagged sentence and its word-embeddings sentence
    * @return one tagged sentence per input pair; tokens whose predicted label is not in
    *         `entities` (when `entities` is set and non-empty) are omitted
    */
  def tag(sentences: Seq[(PosTaggedSentence, WordpieceEmbeddingsSentence)]): Seq[NerTaggedSentence] = {
    require(model.isSet, "model must be set before tagging")

    val crf = $$(model)
    val fg = FeatureGenerator(new DictionaryFeatures($$(dictionaryFeatures)))

    // Read the flag once through `$` so the default value is honoured consistently
    // (the former `get(...)` ignored defaults and could disagree with `$(...)`).
    val withConfidence = $(includeConfidence)

    // Loop-invariant label filter: None means "keep every label".
    val allowedLabels: Option[Set[String]] =
      if (isDefined(entities) && $(entities).nonEmpty) Some($(entities).toSet) else None

    sentences.map { case (sentence, withEmbeddings) =>
      val instance = fg.generate(sentence, withEmbeddings, crf.metadata)

      val confidenceValues =
        if (withConfidence) {
          val fb = new FbCalculator(instance.items.length, crf.metadata)
          fb.calculate(instance, crf.weights, 1)
          Some(fb.alpha)
        } else None

      val labelIds = crf.predict(instance)

      val words = sentence.indexedTaggedWords
        .zip(labelIds.labels)
        .zipWithIndex
        .flatMap { case ((word, labelId), idx) =>
          val label = crf.metadata.labels(labelId)
          if (allowedLabels.forall(_.contains(label))) {
            // Confidence is the max over the `fb.alpha` row for this token position.
            // NOTE(review): this is not indexed by `labelId`, so it is not necessarily the
            // score of the predicted label itself — confirm this is the intended metric.
            val confidence = confidenceValues.map(alpha => alpha(idx).max)
            Some(IndexedTaggedWord(word.word, label, word.begin, word.end, confidence))
          } else None
        }

      TaggedSentence(words)
    }
  }

  /** Unpacks POS and embedding sentences, tags them, and packs the result as NAMED_ENTITY. */
  override def annotate(annotations: Seq[Annotation]): Seq[Annotation] = {
    val sourceSentences = PosTagged.unpack(annotations)
    val withEmbeddings = WordpieceEmbeddingsSentence.unpack(annotations)
    // `zip` truncates to the shorter sequence if the two unpacked lists differ in length.
    val taggedSentences = tag(sourceSentences.zip(withEmbeddings))
    NerTagged.pack(taggedSentences)
  }

  /** Replaces the CRF with `$$(model).shrink(minW)` and returns this model. */
  def shrink(minW: Float): NerCrfModel = set(model, $$(model).shrink(minW))

  override val inputAnnotatorTypes: Array[AnnotatorType] = Array(DOCUMENT, TOKEN, POS, WORD_EMBEDDINGS)

  override val outputAnnotatorType: AnnotatorType = NAMED_ENTITY
}
/** Mixin providing a loader for pretrained [[NerCrfModel]] instances. */
trait PretrainedNerCrf {

  /**
    * Downloads a pretrained NerCrfModel through [[ResourceDownloader]].
    *
    * @param name      name of the pretrained model
    * @param lang      language code; wrapped with `Option`, so a null becomes None
    * @param remoteLoc remote location to download from
    * @return the downloaded model
    */
  def pretrained(
      name: String = "ner_crf",
      lang: String = "en",
      remoteLoc: String = ResourceDownloader.publicLoc
  ): NerCrfModel = {
    val language: Option[String] = Option(lang)
    ResourceDownloader.downloadModel(NerCrfModel, name, language, remoteLoc)
  }
}
/** Companion object: reads persisted [[NerCrfModel]] instances and loads pretrained ones. */
object NerCrfModel
  extends ParamsAndFeaturesReadable[NerCrfModel]
  with PretrainedNerCrf {}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy