![JAR search and dependency download from the Maven repository](/logo.png)
com.johnsnowlabs.nlp.annotators.er.EntityRulerApproach.scala Maven / Gradle / Ivy
* Copyright 2017-2022 John Snow Labs
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
package com.johnsnowlabs.nlp.annotators.er
import com.johnsnowlabs.nlp.AnnotatorApproach
import com.johnsnowlabs.nlp.AnnotatorType.{CHUNK, DOCUMENT, TOKEN}
import com.johnsnowlabs.nlp.annotators.param.ExternalResourceParam
import com.johnsnowlabs.nlp.util.io.ResourceHelper.spark
import com.johnsnowlabs.nlp.util.io.{ExternalResource, ReadAs, ResourceHelper}
import com.johnsnowlabs.storage.Database.Name
import com.johnsnowlabs.storage._
import com.johnsnowlabs.util.JsonParser
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.param.BooleanParam
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.functions.{col, collect_set, concat, flatten, lit}
import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
/** Fits an Annotator to match exact strings or regex patterns provided in a file against a
* Document and assigns them a named entity. The definitions can contain any number of named
* entities.
* There are multiple ways and formats to set the extraction resource. It is possible to set it
* either as a "JSON", "JSONL" or "CSV" file. A path to the file needs to be provided to
* `setPatternsResource`. The file format needs to be set as the "format" field in the `option`
* parameter map and depending on the file type, additional parameters might need to be set.
* If the file is in a JSON format, then the rule definitions need to be given in a list with the
* fields "id", "label" and "patterns":
* {{{
* [
* {
* "id": "person-regex",
* "label": "PERSON",
* "patterns": ["\\w+\\s\\w+", "\\w+-\\w+"]
* },
* {
* "id": "locations-words",
* "label": "LOCATION",
* "patterns": ["Winterfell"]
* }
* ]
* }}}
* The same fields also apply to a file in the JSONL format:
* {{{
* {"id": "names-with-j", "label": "PERSON", "patterns": ["Jon", "John", "John Snow"]}
* {"id": "names-with-s", "label": "PERSON", "patterns": ["Stark", "Snow"]}
* {"id": "names-with-e", "label": "PERSON", "patterns": ["Eddard", "Eddard Stark"]}
* }}}
* In order to use a CSV file, an additional parameter "delimiter" needs to be set. In this case,
* the delimiter might be set by using `.setPatternsResource("patterns.csv", ReadAs.TEXT,
* Map("format"->"csv", "delimiter" -> "\\|"))`
* {{{
* PERSON|John Snow
* LOCATION|Winterfell
* }}}
* ==Example==
* In this example, the entities file has the form of
* {{{
* PERSON|John Snow
* LOCATION|Winterfell
* }}}
* where each line represents an entity and the associated string delimited by "|".
* {{{
* import spark.implicits._
* import com.johnsnowlabs.nlp.base.DocumentAssembler
* import com.johnsnowlabs.nlp.annotators.Tokenizer
* import com.johnsnowlabs.nlp.annotators.er.EntityRulerApproach
* import com.johnsnowlabs.nlp.util.io.ReadAs
* import org.apache.spark.ml.Pipeline
* val documentAssembler = new DocumentAssembler()
* .setInputCol("text")
* .setOutputCol("document")
* val tokenizer = new Tokenizer()
* .setInputCols("document")
* .setOutputCol("token")
* val entityRuler = new EntityRulerApproach()
* .setInputCols("document", "token")
* .setOutputCol("entities")
* .setPatternsResource(
* path = "src/test/resources/entity-ruler/patterns.csv",
* readAs = ReadAs.TEXT,
* options = Map("format" -> "csv", "delimiter" -> "\\|")
* )
* val pipeline = new Pipeline().setStages(Array(
* documentAssembler,
* tokenizer,
* entityRuler
* ))
* val data = Seq("Jon Snow wants to be lord of Winterfell.").toDF("text")
* val result = pipeline.fit(data).transform(data)
* result.selectExpr("explode(entities)").show(false)
* +--------------------------------------------------------------------+
* |col |
* +--------------------------------------------------------------------+
* |[chunk, 0, 2, Jon, [entity -> PERSON, sentence -> 0], []] |
* |[chunk, 29, 38, Winterfell, [entity -> LOCATION, sentence -> 0], []]|
* +--------------------------------------------------------------------+
* }}}
* @param uid
* required uid for storing annotator to disk
* @groupname anno Annotator types
* @groupdesc anno
* Required input and expected output annotator types
* @groupname Ungrouped Members
* @groupname param Parameters
* @groupname setParam Parameter setters
* @groupname getParam Parameter getters
* @groupname Ungrouped Members
* @groupprio param 1
* @groupprio anno 2
* @groupprio Ungrouped 3
* @groupprio setParam 4
* @groupprio getParam 5
* @groupdesc param
* A list of (hyper-)parameter keys this annotator can take. Users can set and get the
* parameter values through setters and getters, respectively.
class EntityRulerApproach(override val uid: String)
extends AnnotatorApproach[EntityRulerModel]
with HasStorage {
// Auxiliary no-arg constructor: generates a fresh uid prefixed with "ENTITY_RULER".
def this() = this(Identifiable.randomUID("ENTITY_RULER"))
override val description: String = "Entity Ruler matches entities based on text patterns"
// Mutable state filled while the patterns resource is read.
// NOTE(review): nothing in the visible code clears these fields, so reading the patterns
// twice on the same instance would accumulate entries — confirm whether that can happen.
// Entity names ("label" or "label,id") registered as having regex patterns.
private var entitiesForRegex: Array[String] = Array()
// Exact-match (non-regex) patterns; `train` passes them to the Aho-Corasick automaton.
private val keywordsPatterns: ArrayBuffer[EntityPattern] = ArrayBuffer()
// Entity name -> regex patterns, filled when no storage writer is used and handed to
// EntityRulerFeatures in `train`.
private var regexPatterns: Map[String, Seq[String]] = Map()
/** Resource in JSON, JSONL or CSV format to map entities to patterns (Default: `null`).
* @group param
val patternsResource: ExternalResourceParam = new ExternalResourceParam(
"Resource in JSON or CSV format to map entities to patterns")
// NOTE(review): unlike `useStorage` below, the constructor calls for `patternsResource`,
// `sentenceMatch` and `alphabet` show only the description argument; the parent and
// param-name arguments appear to be missing from this copy.
// Whether regex patterns are matched at sentence level (`true`) or token level (`false`)
// (Default: `false`).
val sentenceMatch = new BooleanParam(
"Whether to find match at sentence level (regex only). True: sentence level. False: token level")
/** Whether to use RocksDB storage to serialize patterns (Default: `false`).
* @group param
val useStorage =
new BooleanParam(this, "useStorage", "Whether to use RocksDB storage to serialize patterns")
// Plain-text resource listing the characters of the alphabet used to build the keyword
// automaton in `train` (Default: the "english" resource).
val alphabet = new ExternalResourceParam(
"Alphabet resource path to plain text file with all characters in a given alphabet")
/** Sets the resource that maps entity labels to patterns.
  *
  * @param path
  *   location of the patterns file
  * @param readAs
  *   how the resource is read; combined with the "format" option to choose a reader
  * @param options
  *   reader options; "format" (JSON, JSONL or CSV) is required and CSV also needs "delimiter"
  * @group setParam
  */
def setPatternsResource(
    path: String,
    readAs: ReadAs.Format,
    options: Map[String, String] = Map("format" -> "JSON")): this.type = {
  val resource = ExternalResource(path, readAs, options)
  set(patternsResource, resource)
}

/** Chooses sentence-level (`true`) or token-level (`false`) regex matching.
  *
  * @group setParam
  */
def setSentenceMatch(value: Boolean): this.type = {
  set(sentenceMatch, value)
}

/** Enables or disables serializing the patterns through RocksDB storage.
  *
  * @group setParam
  */
def setUseStorage(value: Boolean): this.type = {
  set(useStorage, value)
}
/** @group setParam */
def setAlphabetResource(path: String): this.type = {
// The path is wrapped as a plain-text resource with no extra options.
set(alphabet, ExternalResource(path, ReadAs.TEXT, Map()))
// NOTE(review): this copy appears to be missing the closing brace of setAlphabetResource
// and the opening of the defaults call; the pairs below read as the params' default values:
// empty storage path, no patterns resource, storage disabled, token-level regex matching,
// case-sensitive matching and the "english" alphabet.
storagePath -> ExternalResource("", ReadAs.TEXT, Map()),
patternsResource -> null,
useStorage -> false,
sentenceMatch -> false,
caseSensitive -> true,
alphabet -> ExternalResource("english", ReadAs.TEXT, Map()))
// Supported values of the "format" option of `patternsResource`.
private val AVAILABLE_FORMATS: Array[String] = Array("JSON", "JSONL", "CSV")
// Pre-training hook. NOTE(review): its body is missing from this copy (presumably it validates
// the params and reads the patterns resource — confirm against the full source).
override def beforeTraining(spark: SparkSession): Unit = {
// Builds the EntityRulerModel from the patterns collected while reading `patternsResource`.
// NOTE(review): parts of this method (the storage branch, the model setters and closing
// braces) are missing from this copy.
override def train(
dataset: Dataset[_],
recursivePipeline: Option[PipelineModel]): EntityRulerModel = {
val entityRuler = new EntityRulerModel()
// NOTE(review): the storage branch appears empty in this copy.
if ($(useStorage)) {
} else {
// Without storage, the in-memory regex patterns go into EntityRulerFeatures
// (the first argument is passed as an empty Map).
val entityRulerFeatures = EntityRulerFeatures(Map(), regexPatterns)
var automaton: Option[AhoCorasickAutomaton] = None
// Exact-match keywords are compiled into an Aho-Corasick automaton over the configured
// alphabet; no automaton is built when there are no keywords.
if (keywordsPatterns.nonEmpty) {
// `this.alphabet` refers to the param, not the local `alphabet` being defined here.
val alphabet = EntityRulerUtil.loadAlphabet($(this.alphabet).path)
automaton = Some(
new AhoCorasickAutomaton(alphabet, keywordsPatterns.toArray, $(caseSensitive)))
// Storage indexing hook; the visible part only acts when `useStorage` is set.
// NOTE(review): the writer lookup and the rest of this method are missing from this copy.
protected def index(
fitDataset: Dataset[_],
storageSourcePath: Option[String],
readAs: Option[ReadAs.Value],
writers: Map[Name, StorageWriter[_]],
readOptions: Option[Map[String, String]]): Unit = {
if ($(useStorage)) {
val storageWriter =
// Reads `patternsResource` with the reader matching its "<FORMAT>&<readAs>" key (see
// `resourceFormats`). In the visible readers, regex patterns go to the writer when one is
// given (otherwise to `regexPatterns`), and keyword patterns go to `keywordsPatterns`.
// Throws IllegalArgumentException for an unsupported format/readAs combination.
private def storePatterns(storageWriter: Option[RegexPatternsReadWriter]): Unit = {
resourceFormats match {
case "JSON&TEXT" => storePatternsFromJson(storageWriter)
case "JSONL&TEXT" => storePatternsFromJsonl(storageWriter)
case "JSON&SPARK" => storePatternsFromJSONDataFrame(storageWriter, "JSON")
case "JSONL&SPARK" => storePatternsFromJSONDataFrame(storageWriter, "JSONL")
case "CSV&TEXT" => storePatternsFromCSV(storageWriter)
case "CSV&SPARK" => storeEntityPatternsFromCSVDataFrame(storageWriter)
// `_ @format` binds the whole unmatched key to `format` (same as `case format =>`).
case _ @format => throw new IllegalArgumentException(s"format $format not available")
// Checks the `patternsResource` configuration and fails fast via `require`
// (IllegalArgumentException) when it is unset, has an empty path, lacks a supported "format"
// option, lacks a "delimiter" for CSV, or has no readAs.
// NOTE(review): the openings of the format and delimiter checks appear to be missing from this
// copy; the format message mentions only JSON or CSV although JSONL is also accepted.
private def validateParameters(): Unit = {
require($(patternsResource) != null, "patternsResource parameter required")
require($(patternsResource).path != "", "path for a patternsResource file is required")
$(patternsResource).options.getOrElse("format", "").toUpperCase()),
"format option parameter required with either JSON or CSV values")
if ($(patternsResource).options("format").toUpperCase() == "CSV") {
$(patternsResource).options.getOrElse("delimiter", "") != "",
"delimiter option parameter required")
require($(patternsResource).readAs != null, "readAs parameter required")
// Dispatch key for storePatterns: the upper-cased "format" option, "&", then readAs
// (e.g. "CSV&TEXT"). As a lazy val it is computed once, on first access, so later changes to
// `patternsResource` are not reflected.
// NOTE(review): the line selecting the "format" option appears to be missing from this copy.
private lazy val resourceFormats: String = $(patternsResource)
.toUpperCase() + "&" + $(patternsResource).readAs
// Reads a JSON array of entity patterns (see parseJSON). Patterns flagged as regex are stored
// via storeEntityPattern; the handling of the others is missing from this copy (presumably they
// are appended to keywordsPatterns, as the JSONL reader does).
private def storePatternsFromJson(storageReadWriter: Option[RegexPatternsReadWriter]): Unit = {
val entityPatterns: Array[EntityPattern] = parseJSON()
entityPatterns.foreach { entityPattern =>
// A missing "regex" flag counts as false.
if (entityPattern.regex.getOrElse(false)) {
storeEntityPattern(entityPattern, storageReadWriter)
} else {
// Stores one entity pattern. The entity name is "label,id" when an id is present, otherwise
// just the label.
private def storeEntityPattern(
entityPattern: EntityPattern,
storageReadWriter: Option[RegexPatternsReadWriter]): Unit = {
val entity =
if (entityPattern.id.isDefined) s"${entityPattern.label},${entityPattern.id.get}"
else entityPattern.label
// getOrElse(None) widens the value so the cases can test the writer's runtime type;
// None means no storage, so the patterns are registered in memory via computePatterns.
// NOTE(review): the writer is declared as RegexPatternsReadWriter, so the PatternsReadWriter
// case only matches if that writer is also a PatternsReadWriter — confirm, it may be dead.
storageReadWriter.getOrElse(None) match {
case patternsWriter: PatternsReadWriter =>
storePatterns(entityPattern.patterns.toIterator, entity, patternsWriter)
case regexPatternsWriter: RegexPatternsReadWriter =>
storeRegexPattern(entityPattern.patterns, entity, regexPatternsWriter)
case None => {
val isRegex = entityPattern.regex.getOrElse(false)
computePatterns(entityPattern.patterns, isRegex, entity)
// Parses the whole patterns file as a JSON array of EntityPattern.
// NOTE(review): neither `stream` nor the Source is closed in the visible lines (possible leak),
// and no explicit Codec is passed to Source.fromInputStream, so decoding depends on the
// implicit default — consider Codec.UTF8.
private def parseJSON(): Array[EntityPattern] = {
val stream = ResourceHelper.getResourceStream($(patternsResource).path)
val jsonContent = Source.fromInputStream(stream).mkString
val entityPatterns: Array[EntityPattern] = JsonParser.parseArray[EntityPattern](jsonContent)
// Reads JSONL: every line is parsed as one EntityPattern object. Regex patterns are stored via
// storeEntityPattern; all others are appended to keywordsPatterns.
// NOTE(review): every line, including blank ones, is passed to the parser as-is.
private def storePatternsFromJsonl(storageReadWriter: Option[RegexPatternsReadWriter]): Unit = {
val sourceStream = ResourceHelper.SourceStream($(patternsResource).path)
sourceStream.content.foreach(content =>
content.foreach { line =>
val entityPattern: EntityPattern = JsonParser.parseObject[EntityPattern](line)
if (entityPattern.regex.getOrElse(false)) {
storeEntityPattern(entityPattern, storageReadWriter)
} else keywordsPatterns.append(entityPattern)
// Reads CSV patterns with lines of the form label<delimiter>pattern[<delimiter>isRegex]; a
// missing third column means "not a regex". Keywords are grouped per label into
// keywordsPatterns. Regex patterns go to the writer if one is given, otherwise to regexPatterns.
// NOTE(review): several lines (including the regex-branch body and closing braces) appear to
// be missing from this copy.
private def storePatternsFromCSV(regexPatternsWriter: Option[RegexPatternsReadWriter]): Unit = {
val delimiter: String = $(patternsResource).options("delimiter")
val patternsLines = ResourceHelper.parseLines($(patternsResource))
// NOTE(review): `regexList` and `patternsHasRegex` are declared once, outside the per-label
// loop, yet `regexList` is assigned/stored for every label (and stored by reference). If regex
// patterns are appended to it per line, they accumulate across labels — confirm.
val regexList: ArrayBuffer[String] = ArrayBuffer()
val keywords: mutable.Map[String, Seq[String]] = mutable.Map()
// Local map; it shadows the `regexPatterns` field, which is assigned via `this.` below.
val regexPatterns: mutable.Map[String, Seq[String]] = mutable.Map()
var patternsHasRegex = false
// Groups the raw lines by their first column (the label).
val groupByLabel =
patternsLines.groupBy(pattern => EntityRulerUtil.splitString(pattern, delimiter)(0))
groupByLabel.foreach { case (label, lines) =>
// Assumes every line has at least two columns; columns(1) throws otherwise.
lines.foreach { line =>
val columns: Array[String] = EntityRulerUtil.splitString(line, delimiter)
val pattern = columns(1)
val isRegex = if (columns.length == 2) false else EntityRulerUtil.toBoolean(columns(2))
if (isRegex) {
patternsHasRegex = true
} else {
val patterns = keywords.getOrElse(label, Seq())
keywords(label) = patterns ++ Seq(pattern)
if (regexPatternsWriter.isEmpty) {
regexPatterns(label) = regexList
if (patternsHasRegex && regexPatternsWriter.nonEmpty) {
storeRegexPattern(regexList, label, regexPatternsWriter.get)
keywords.foreach { case (label, patterns) =>
keywordsPatterns.append(EntityPattern(label, patterns))
if (regexPatternsWriter.isEmpty) {
this.regexPatterns = regexPatterns.toMap
// NOTE(review): takes every key of the local map, which may include labels without regexes.
if (patternsHasRegex) entitiesForRegex = regexPatterns.keys.toArray
// Reads CSV patterns through Spark with a fixed schema (label, pattern, regex). Null regex
// values are filled with false, then rows are grouped by label and regex flag.
// NOTE(review): the reader chain, the aggregation and the call receiving the named arguments at
// the end (presumably storeFromDataFrame) appear to be cut off in this copy.
private def storeEntityPatternsFromCSVDataFrame(
storageReadWriter: Option[RegexPatternsReadWriter]): Unit = {
val patternOptions = $(patternsResource).options
val patternsSchema = StructType(
StructField("label", StringType, nullable = false),
StructField("pattern", StringType, nullable = false),
StructField("regex", BooleanType, nullable = true)))
val patternsDataFrame = spark.read
.option("delimiter", patternOptions("delimiter"))
.fill(value = false, Array("regex"))
val groupedByPatternsDataFrame = patternsDataFrame
.groupBy("label", "regex")
idFieldExist = false,
regexFieldExist = true,
// Reads JSON (an array) or JSONL (one object per line) patterns through Spark. A missing
// "regex" column is added as false, and null regex values are filled with false.
// NOTE(review): the read call that uses `path` and the final hand-off appear to be missing.
private def storePatternsFromJSONDataFrame(
storageReadWriter: Option[RegexPatternsReadWriter],
format: String): Unit = {
val path = $(patternsResource).path
val dataFrameReader = spark.read
// A plain JSON file holds a (possibly multi-line) array, so multiline mode is enabled for it.
// The return value of option() is ignored, so this relies on the reader being updated in place.
if (format.equals("JSON")) {
dataFrameReader.option("multiline", "true")
var patternsDataFrame = dataFrameReader
val idField: Array[StructField] =
patternsDataFrame.schema.fields.filter(field => field.name == "id")
val regexField: Array[StructField] =
patternsDataFrame.schema.fields.filter(field => field.name == "regex")
if (regexField.isEmpty) {
patternsDataFrame = patternsDataFrame.withColumn("regex", lit(false))
} else {
patternsDataFrame = patternsDataFrame.na.fill(value = false, Array("regex"))
// NOTE(review): `na.drop()` returns a new DataFrame that is discarded here, so
// patternsDataFrame keeps its rows with nulls; `patternsDataFrame = patternsDataFrame.na.drop()`
// was probably intended — confirm.
if (idField.nonEmpty) patternsDataFrame.na.drop()
// Splits the patterns DataFrame into regex rows and keyword rows. Regex rows go to the storage
// writer when one is given, otherwise to computePatterns (in memory). Keyword rows become
// EntityPattern entries in keywordsPatterns. Rows are consumed on the driver through
// toLocalIterator (presumably because the loop bodies update this instance's fields).
private def storeFromDataFrame(
patternsDataFrame: DataFrame,
idFieldExist: Boolean,
regexFieldExist: Boolean,
storageReadWriter: Option[RegexPatternsReadWriter]): Unit = {
val regexPatternsDataFrame = patternsDataFrame.filter(col("regex") === true)
val cleanedRegexPatternsDataFrame =
cleanPatternsDataFrame(regexPatternsDataFrame, idFieldExist)
cleanedRegexPatternsDataFrame.rdd.toLocalIterator.foreach { row =>
val patterns = row.getAs[Seq[String]]("flatten_patterns")
// The entity name is the combined "label,id" column when ids exist, else the label.
val entity =
if (idFieldExist) row.getAs[String]("label_id") else row.getAs[String]("label")
storageReadWriter.getOrElse(None) match {
case patternsWriter: PatternsReadWriter =>
storePatterns(patterns.toIterator, entity, patternsWriter)
case regexPatternsWriter: RegexPatternsReadWriter =>
storeRegexPattern(patterns, entity, regexPatternsWriter)
case None => computePatterns(patterns, isRegex = true, entity)
val keywordsDataFrame = patternsDataFrame.filter(col("regex") === false)
val cleanedKeywordsDataFrame = cleanPatternsDataFrame(keywordsDataFrame, idFieldExist)
// NOTE(review): these rows were selected with regex == false, yet the fourth EntityPattern
// argument below is Some(true); if that field is `regex`, the flag looks inverted — confirm.
cleanedKeywordsDataFrame.rdd.toLocalIterator.foreach { row =>
val patterns = row.getAs[Seq[String]]("flatten_patterns")
if (idFieldExist) {
// "label_id" is "label,id" (built in cleanPatternsDataFrame); split it back apart.
// NOTE(review): a label that itself contains a comma would be split incorrectly.
val labelId = row.getAs[String]("label_id")
val label = labelId.split(",")(0)
val id = labelId.split(",")(1)
keywordsPatterns.append(EntityPattern(label, patterns, Some(id), Some(true)))
} else {
val label = row.getAs[String]("label")
keywordsPatterns.append(EntityPattern(label, patterns, None, Some(true)))
// Adds a "label_id" column ("label,id") when ids are present.
// NOTE(review): the rest of this method (presumably the aggregation producing the
// "flatten_patterns" column read by storeFromDataFrame) is missing from this copy.
private def cleanPatternsDataFrame(
patternsDataFrame: DataFrame,
idFieldExist: Boolean): DataFrame = {
if (idFieldExist) {
val patternsWithIdDataFrame =
patternsDataFrame.withColumn("label_id", concat(col("label"), lit(","), col("id")))
} else {
// Stores every keyword pattern under `entity` through the given writer (see storePattern).
private def storePatterns(
patterns: Iterator[String],
entity: String,
patternsReaderWriter: PatternsReadWriter): Unit = {
patterns.foreach(pattern => storePattern(pattern, entity, patternsReaderWriter))
// Adds `pattern -> entity` only when the pattern is not stored yet. Assuming lookup returns an
// Option, getOrElse evaluates `add` only when lookup yields None, so the first entity stored for
// a pattern wins.
private def storePattern(
pattern: String,
entity: String,
patternsReaderWriter: PatternsReadWriter): Unit = {
patternsReaderWriter.lookup(pattern).getOrElse(patternsReaderWriter.add(pattern, entity))
// Registers `entity` as a regex entity (once) and stores its regex patterns through the writer.
// NOTE(review): the line before `.getOrElse(...)` (presumably a lookup of `entity`) and the
// closing braces appear to be missing from this copy.
private def storeRegexPattern(
pattern: Seq[String],
entity: String,
regexPatternsReaderWriter: RegexPatternsReadWriter): Unit = {
if (!entitiesForRegex.contains(entity)) {
entitiesForRegex = entitiesForRegex ++ Array(entity)
.getOrElse(regexPatternsReaderWriter.add(entity, pattern))
// Storage writer factory: always returns a RegexPatternsReadWriter on the given connection
// (the `database` argument is not used in the visible body).
protected def createWriter(database: Name, connection: RocksDBConnection): StorageWriter[_] = {
new RegexPatternsReadWriter(connection)
// Delegates to the inherited storage indexing only when `useStorage` is set.
override def indexStorage(fitDataset: Dataset[_], resource: Option[ExternalResource]): Unit = {
if ($(useStorage)) {
super.indexStorage(fitDataset, resource)
// In-memory registration used when no storage writer is given. For regex patterns, the
// entity's patterns replace any earlier entry (Map ++ keeps the right-hand value), and the
// entity is added to entitiesForRegex once.
// NOTE(review): the non-regex branch and the closing braces are missing from this copy.
private def computePatterns(patterns: Seq[String], isRegex: Boolean, entity: String): Unit = {
if (isRegex) {
regexPatterns = regexPatterns ++ Map(entity -> patterns)
if (!entitiesForRegex.contains(entity)) {
entitiesForRegex = entitiesForRegex ++ Array(entity)
/** Input annotator types: DOCUMENT (TOKEN is accepted as an optional input)
* @group anno
override val inputAnnotatorTypes: Array[String] = Array(DOCUMENT)
override val optionalInputAnnotatorTypes: Array[String] = Array(TOKEN)
/** Output annotator types: CHUNK
* @group anno
override val outputAnnotatorType: AnnotatorType = CHUNK
// Uses the same storage database names as EntityRulerModel.
override protected val databases: Array[Name] = EntityRulerModel.databases
© 2015 - 2025 Weber Informatics LLC | Privacy Policy