All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.nlp.annotators.TextMatcher.scala Maven / Gradle / Ivy

package com.johnsnowlabs.nlp.annotators

import com.johnsnowlabs.nlp.{Annotation, AnnotatorApproach, DocumentAssembler}
import com.johnsnowlabs.nlp.AnnotatorType.TOKEN
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable}
import org.apache.spark.sql.Dataset
import com.johnsnowlabs.nlp.AnnotatorType._
import com.johnsnowlabs.nlp.annotators.param.ExternalResourceParam
import com.johnsnowlabs.nlp.util.io.{ExternalResource, ReadAs, ResourceHelper}
import org.apache.spark.ml.param.BooleanParam

/**
  * Annotator approach that reads entity phrases from an external resource, splits each phrase
  * into tokens and produces a [[TextMatcherModel]] configured with those token sequences.
  *
  * @param uid unique identifier of this pipeline stage
  */
class TextMatcher(override val uid: String) extends AnnotatorApproach[TextMatcherModel] {

  /** Creates an instance with a random uid prefixed with "ENTITY_EXTRACTOR". */
  def this() = this(Identifiable.randomUID("ENTITY_EXTRACTOR"))

  /** Annotation types this annotator declares as its inputs. */
  override val inputAnnotatorTypes = Array(DOCUMENT, TOKEN)

  /** Annotation type this annotator declares as its output. */
  override val outputAnnotatorType: AnnotatorType = CHUNK

  override val description: String = "Extracts entities from target dataset given in a text file"

  /** External resource holding the entity phrases, read line by line when training. */
  val entities = new ExternalResourceParam(this, "entities", "entities external resource.")

  /**
    * Case sensitivity flag forwarded to the trained model; defaults to `true`.
    *
    * NOTE(review): the description string reads as if `true` means "ignore case", while the
    * param name suggests the opposite. Confirm against TextMatcherModel before rewording it.
    */
  val caseSensitive = new BooleanParam(this, "caseSensitive", "whether to match regardless of case. Defaults true")

  setDefault(inputCols,Array(TOKEN))
  setDefault(caseSensitive, true)

  /** Sets the entities resource from an already built [[ExternalResource]]. */
  def setEntities(value: ExternalResource): this.type =
    set(entities, value)

  /**
    * Sets the entities resource from its components.
    *
    * @param path    location of the resource
    * @param readAs  how the resource is read
    * @param options reader options; defaults to `Map("format" -> "text")`
    */
  def setEntities(path: String, readAs: ReadAs.Format, options: Map[String, String] = Map("format" -> "text")): this.type =
    set(entities, ExternalResource(path, readAs, options))

  /** Sets the case sensitivity flag that is handed to the trained model. */
  def setCaseSensitive(v: Boolean): this.type =
    set(caseSensitive, v)

  /** Returns the current value of the case sensitivity flag. */
  def getCaseSensitive: Boolean =
    $(caseSensitive)

  /**
    * Loads the entity phrases and splits each one with a default [[Tokenizer]].
    *
    * @return one array of token strings per line of the entities resource
    */
  private def loadEntities(): Array[Array[String]] = {
    val phrases: Array[String] = ResourceHelper.parseLines($(entities))
    val tokenizer = new Tokenizer()
    phrases.map { line =>
      tokenizer.annotate(Seq(Annotation(line))).map(_.result).toArray
    }
  }

  /**
    * Loads the entity phrases and runs them through the supplied pipeline, so they are processed
    * by the same stages as the data the pipeline was fitted on.
    *
    * @param pipelineModel fitted pipeline; must contain a [[DocumentAssembler]] stage
    * @return one array of annotation results per line of the entities resource
    * @throws Exception if the pipeline contains no [[DocumentAssembler]]
    */
  private def loadEntities(pipelineModel: PipelineModel): Array[Array[String]] = {
    val phrases: Seq[String] = ResourceHelper.parseLines($(entities))
    import ResourceHelper.spark.implicits._
    // The phrases are fed to the pipeline under the column name its DocumentAssembler reads from.
    val textColumn = pipelineModel.stages.collectFirst {
      case assembler: DocumentAssembler => assembler.getInputCol
    }.getOrElse(throw new Exception("Could not retrieve DocumentAssembler from RecursivePipeline"))
    val data = phrases.toDS.withColumnRenamed("value", textColumn)
    // NOTE(review): the first input column is used as the source of tokens; this presumably
    // relies on it being the token column (the default inputCols is Array(TOKEN)) — confirm.
    pipelineModel.transform(data).select(getInputCols.head).as[Array[Annotation]].map(_.map(_.result)).collect
  }

  /**
    * Builds the [[TextMatcherModel]] from the configured entities resource.
    *
    * The entities are processed through `recursivePipeline` when one is provided, otherwise
    * with a default tokenizer. The case sensitivity setting is applied to the model on both
    * paths, so it is no longer dropped when a recursive pipeline is supplied.
    *
    * @param dataset           training dataset (not read by this implementation)
    * @param recursivePipeline optional fitted pipeline used to process the entity phrases
    * @return the configured model
    */
  override def train(dataset: Dataset[_], recursivePipeline: Option[PipelineModel]): TextMatcherModel = {
    val parsedEntities = recursivePipeline match {
      case Some(pipeline) => loadEntities(pipeline)
      case None => loadEntities()
    }
    new TextMatcherModel()
      .setEntities(parsedEntities)
      .setCaseSensitive($(caseSensitive))
  }

}

/**
  * Companion object of [[TextMatcher]]; extending DefaultParamsReadable gives it the default
  * reader for loading a previously saved TextMatcher.
  */
object TextMatcher extends DefaultParamsReadable[TextMatcher]




© 2015 - 2025 Weber Informatics LLC | Privacy Policy