All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.nlp.annotators.GraphExtraction.scala Maven / Gradle / Ivy

/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.annotators

import com.johnsnowlabs.nlp.AnnotatorType._
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.annotators.common.LabeledDependency.DependencyInfo
import com.johnsnowlabs.nlp.annotators.common.{LabeledDependency, NerTagged}
import com.johnsnowlabs.nlp.annotators.ner.NerTagsEncoding
import com.johnsnowlabs.nlp.annotators.parser.dep.DependencyParserModel
import com.johnsnowlabs.nlp.annotators.parser.typdep.TypedDependencyParserModel
import com.johnsnowlabs.nlp.annotators.pos.perceptron.PerceptronModel
import com.johnsnowlabs.nlp.util.GraphBuilder
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.param.{BooleanParam, IntParam, Param, StringArrayParam}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.functions.array
import org.apache.spark.sql.{DataFrame, Dataset}

/** Extracts a dependency graph between entities.
  *
  * The GraphExtraction class takes e.g. extracted entities from a
  * [[com.johnsnowlabs.nlp.annotators.ner.dl.NerDLModel NerDLModel]] and creates a dependency tree
  * which describes how the entities relate to each other. For that a triple store format is used.
  * Nodes represent the entities and the edges represent the relations between those entities. The
  * graph can then be used to find relevant relationships between words.
  *
  * Both the
  * [[com.johnsnowlabs.nlp.annotators.parser.dep.DependencyParserModel DependencyParserModel]] and
  * [[com.johnsnowlabs.nlp.annotators.parser.typdep.TypedDependencyParserModel TypedDependencyParserModel]]
  * are needed to build the graph. There are two ways to provide them:
  *
  *   1. Both Annotators are present in the pipeline already and `setMergeEntities` is set to
  *      `false`. The dependencies are then taken implicitly from the output columns of these two
  *      Annotators.
  *   1. With `setMergeEntities` set to `true` (the default), pretrained POS, dependency parser
  *      and typed dependency parser models are downloaded automatically, and dependency columns
  *      already in the pipeline are not used. The specific models can also be set with
  *      `setDependencyParserModel` and `setTypedDependencyParserModel`:
  *      {{{
  *            val graph_extraction = new GraphExtraction()
  *              .setInputCols("document", "token", "ner")
  *              .setOutputCol("graph")
  *              .setRelationshipTypes(Array("prefer-LOC"))
  *              .setMergeEntities(true)
  *            //.setDependencyParserModel(Array("dependency_conllu", "en",  "public/models"))
  *            //.setTypedDependencyParserModel(Array("dependency_typed_conllu", "en",  "public/models"))
  *      }}}
  *
  * To transform the resulting graph into a more generic form such as RDF, see the
  * [[com.johnsnowlabs.nlp.GraphFinisher GraphFinisher]].
  *
  * ==Example==
  * {{{
  * import spark.implicits._
  * import com.johnsnowlabs.nlp.base.DocumentAssembler
  * import com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetector
  * import com.johnsnowlabs.nlp.annotators.Tokenizer
  * import com.johnsnowlabs.nlp.annotators.ner.dl.NerDLModel
  * import com.johnsnowlabs.nlp.embeddings.WordEmbeddingsModel
  * import com.johnsnowlabs.nlp.annotators.pos.perceptron.PerceptronModel
  * import com.johnsnowlabs.nlp.annotators.parser.dep.DependencyParserModel
  * import com.johnsnowlabs.nlp.annotators.parser.typdep.TypedDependencyParserModel
  * import org.apache.spark.ml.Pipeline
  * import com.johnsnowlabs.nlp.annotators.GraphExtraction
  *
  * val documentAssembler = new DocumentAssembler()
  *   .setInputCol("text")
  *   .setOutputCol("document")
  *
  * val sentence = new SentenceDetector()
  *   .setInputCols("document")
  *   .setOutputCol("sentence")
  *
  * val tokenizer = new Tokenizer()
  *   .setInputCols("sentence")
  *   .setOutputCol("token")
  *
  * val embeddings = WordEmbeddingsModel.pretrained()
  *   .setInputCols("sentence", "token")
  *   .setOutputCol("embeddings")
  *
  * val nerTagger = NerDLModel.pretrained()
  *   .setInputCols("sentence", "token", "embeddings")
  *   .setOutputCol("ner")
  *
  * val posTagger = PerceptronModel.pretrained()
  *   .setInputCols("sentence", "token")
  *   .setOutputCol("pos")
  *
  * val dependencyParser = DependencyParserModel.pretrained()
  *   .setInputCols("sentence", "pos", "token")
  *   .setOutputCol("dependency")
  *
  * val typedDependencyParser = TypedDependencyParserModel.pretrained()
  *   .setInputCols("dependency", "pos", "token")
  *   .setOutputCol("dependency_type")
  *
  * val graph_extraction = new GraphExtraction()
  *   .setInputCols("document", "token", "ner")
  *   .setOutputCol("graph")
  *   .setRelationshipTypes(Array("prefer-LOC"))
  *   .setExplodeEntities(false)
  *
  * val pipeline = new Pipeline().setStages(Array(
  *   documentAssembler,
  *   sentence,
  *   tokenizer,
  *   embeddings,
  *   nerTagger,
  *   posTagger,
  *   dependencyParser,
  *   typedDependencyParser,
  *   graph_extraction
  * ))
  *
  * val data = Seq("You and John prefer the morning flight through Denver").toDF("text")
  * val result = pipeline.fit(data).transform(data)
  *
  * result.select("graph").show(false)
  * +-----------------------------------------------------------------------------------------------------------------+
  * |graph                                                                                                            |
  * +-----------------------------------------------------------------------------------------------------------------+
  * |[[node, 13, 18, prefer, [relationship -> prefer,LOC, path1 -> prefer,nsubj,morning,flat,flight,flat,Denver], []]]|
  * +-----------------------------------------------------------------------------------------------------------------+
  * }}}
  *
  * @see
  *   [[com.johnsnowlabs.nlp.GraphFinisher GraphFinisher]] to output the paths in a more generic
  *   format, like RDF
  * @param uid
  *   required uid for storing annotator to disk
  * @groupname anno Annotator types
  * @groupdesc anno
  *   Required input and expected output annotator types
  * @groupname Ungrouped Members
  * @groupname param Parameters
  * @groupname setParam Parameter setters
  * @groupname getParam Parameter getters
  * @groupprio param  1
  * @groupprio anno  2
  * @groupprio Ungrouped 3
  * @groupprio setParam  4
  * @groupprio getParam  5
  * @groupdesc param
  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
  *   parameter values through setters and getters, respectively.
  */
class GraphExtraction(override val uid: String)
    extends AnnotatorModel[GraphExtraction]
    with HasSimpleAnnotate[GraphExtraction] {

  def this() = this(Identifiable.randomUID("GRAPH_EXTRACTOR"))

  /** Pairs of the form `token-ENTITY` (e.g. `prefer-LOC`). For each pair, paths are searched from
    * every token whose text equals the left-hand side to every NER entity whose type equals the
    * right-hand side. Only used when `explodeEntities` is `false` (Default: `Array()`)
    *
    * @group param
    */
  val relationshipTypes = new StringArrayParam(
    this,
    "relationshipTypes",
    "Find paths between a pair of token and entity")

  /** Pairs of the form `ENTITY-ENTITY` restricting which entity pairs paths are searched for. When
    * empty, every combination of two entities is used. Only used when `explodeEntities` is `true`
    * (Default: `Array()`)
    *
    * @group param
    */
  val entityTypes =
    new StringArrayParam(this, "entityTypes", "Find paths between a pair of entities")

  /** When set to true, find paths from the root token(s) to pairs of entities; when false, use
    * `relationshipTypes` instead (Default: `true`)
    *
    * @group param
    */
  val explodeEntities =
    new BooleanParam(this, "explodeEntities", "When set to true find paths between entities")

  /** Tokens to be considered as roots to start traversing the paths (Default: `Array()`). When
    * empty, the first token whose dependency head index is 0 is used. Use it along with
    * `explodeEntities`
    *
    * @group param
    */
  val rootTokens = new StringArrayParam(
    this,
    "rootTokens",
    "Tokens to be consider as root to start traversing the paths. Use it along with explodeEntities")

  /** Maximum sentence size, in characters of the sentence text, that the annotator will process
    * (Default: `1000`). Above this, the sentence is skipped
    *
    * @group param
    */
  val maxSentenceSize = new IntParam(
    this,
    "maxSentenceSize",
    "Maximum sentence size that the annotator will process. Above this, the sentence is skipped")

  /** Minimum sentence size, in characters of the sentence text, that the annotator will process
    * (Default: `2`). Below this, the sentence is skipped
    *
    * @group param
    */
  val minSentenceSize = new IntParam(
    this,
    "minSentenceSize",
    "Minimum sentence size that the annotator will process. Below this, the sentence is skipped")

  /** Merge same neighboring entities as a single token (Default: `true`). When true, pretrained
    * POS, dependency parser and typed dependency parser models are used instead of dependency
    * columns present in the dataset.
    *
    * @group param
    */
  val mergeEntities =
    new BooleanParam(this, "mergeEntities", "Merge same neighboring entities as a single token")

  /** IOB format to apply when merging entities, `IOB` or `IOB2` (Default: `"IOB2"`)
    *
    * @group param
    */
  val mergeEntitiesIOBFormat = new Param[String](
    this,
    "mergeEntitiesIOBFormat",
    "IOB format to apply when merging entities. Values: IOB or IOB2")

  /** Whether to include edges when building paths (Default: `true`)
    *
    * @group param
    */
  val includeEdges =
    new BooleanParam(this, "includeEdges", "Whether to include edges when building paths")

  /** Delimiter symbol used for path output (Default: `","`)
    *
    * @group param
    */
  val delimiter = new Param[String](this, "delimiter", "Delimiter symbol used for path output")

  /** Coordinates (name, lang, remoteLoc) to a pretrained POS model. Only used when
    * `mergeEntities` is `true` (Default: `Array()`)
    *
    * @group param
    */
  val posModel = new StringArrayParam(
    this,
    "posModel",
    "Coordinates (name, lang, remoteLoc) to a pretrained POS model")

  /** Coordinates (name, lang, remoteLoc) to a pretrained Dependency Parser model. Only used when
    * `mergeEntities` is `true` (Default: `Array()`)
    *
    * @group param
    */
  val dependencyParserModel = new StringArrayParam(
    this,
    "dependencyParserModel",
    "Coordinates (name, lang, remoteLoc) to a pretrained Dependency Parser model")

  /** Coordinates (name, lang, remoteLoc) to a pretrained Typed Dependency Parser model. Only used
    * when `mergeEntities` is `true`; when empty the default pretrained model is loaded (Default:
    * `Array()`)
    *
    * @group param
    */
  val typedDependencyParserModel = new StringArrayParam(
    this,
    "typedDependencyParserModel",
    "Coordinates (name, lang, remoteLoc) to a pretrained Typed Dependency Parser model")

  /** @group setParam */
  def setRelationshipTypes(value: Array[String]): this.type = set(relationshipTypes, value)

  /** @group setParam */
  def setEntityTypes(value: Array[String]): this.type = set(entityTypes, value)

  /** @group setParam */
  def setExplodeEntities(value: Boolean): this.type = set(explodeEntities, value)

  /** @group setParam */
  def setRootTokens(value: Array[String]): this.type = set(rootTokens, value)

  /** @group setParam */
  def setMaxSentenceSize(value: Int): this.type = set(maxSentenceSize, value)

  /** @group setParam */
  def setMinSentenceSize(value: Int): this.type = set(minSentenceSize, value)

  /** @group setParam */
  def setMergeEntities(value: Boolean): this.type = set(mergeEntities, value)

  /** @group setParam */
  def setMergeEntitiesIOBFormat(value: String): this.type = set(mergeEntitiesIOBFormat, value)

  /** @group setParam */
  def setIncludeEdges(value: Boolean): this.type = set(includeEdges, value)

  /** @group setParam */
  def setDelimiter(value: String): this.type = set(delimiter, value)

  /** @group setParam */
  def setPosModel(value: Array[String]): this.type = set(posModel, value)

  /** @group setParam */
  def setDependencyParserModel(value: Array[String]): this.type =
    set(dependencyParserModel, value)

  /** @group setParam */
  def setTypedDependencyParserModel(value: Array[String]): this.type =
    set(typedDependencyParserModel, value)

  setDefault(
    entityTypes -> Array(),
    explodeEntities -> true,
    maxSentenceSize -> 1000,
    minSentenceSize -> 2,
    mergeEntities -> true,
    rootTokens -> Array(),
    relationshipTypes -> Array(),
    includeEdges -> true,
    delimiter -> ",",
    posModel -> Array(),
    dependencyParserModel -> Array(),
    typedDependencyParserModel -> Array(),
    mergeEntitiesIOBFormat -> "IOB2")

  // These are recomputed on each access instead of being cached in a lazy val. A cached value
  // would keep returning the pairs seen at first use even after the params are changed.
  private def allowedEntityRelationships: Array[(String, String)] =
    parseTypePairs($(entityTypes))

  private def allowedRelationshipTypes: Array[(String, String)] =
    parseTypePairs($(relationshipTypes))

  /** Splits every `A-B` entry on `-` into `(first segment, last segment)` and drops duplicates. */
  private def parseTypePairs(pairs: Array[String]): Array[(String, String)] =
    pairs.map { pair =>
      val segments = pair.split("-")
      (segments.head, segments.last)
    }.distinct

  // Populated by beforeAnnotate when mergeEntities is true; read by getPretrainedAnnotations.
  private var pretrainedPos: Option[PerceptronModel] = None
  private var pretrainedDependencyParser: Option[DependencyParserModel] = None
  private var pretrainedTypedDependencyParser: Option[TypedDependencyParserModel] =
    None

  /** Annotates `dataset`.
    *
    * When `mergeEntities` is `true` this delegates to the parent implementation. Otherwise, every
    * column whose metadata marks it as a DEPENDENCY or LABELED_DEPENDENCY annotation is appended
    * to the declared input columns.
    *
    * @throws IllegalArgumentException
    *   if `mergeEntities` is `false` and the dataset lacks a DEPENDENCY or a LABELED_DEPENDENCY
    *   column
    */
  override def _transform(
      dataset: Dataset[_],
      recursivePipeline: Option[PipelineModel]): DataFrame = {
    if ($(mergeEntities)) {
      super._transform(dataset, recursivePipeline)
    } else {
      val dependencyTypes = Set(DEPENDENCY, LABELED_DEPENDENCY)
      val dependencyFields = dataset.schema.fields.filter { field =>
        field.metadata.contains("annotatorType") &&
        dependencyTypes.contains(field.metadata.getString("annotatorType"))
      }
      // Both kinds are required. Counting columns alone would accept e.g. two DEPENDENCY columns.
      val foundTypes = dependencyFields.map(_.metadata.getString("annotatorType")).toSet
      if (foundTypes != dependencyTypes) {
        throw new IllegalArgumentException(
          s"Missing either $DEPENDENCY or $LABELED_DEPENDENCY annotators. " +
            s"Make sure such annotators exist in your pipeline or setMergeEntities parameter to True")
      }

      val inputCols = getInputCols ++ dependencyFields.map(_.name)
      dataset.withColumn(
        getOutputCol,
        wrapColumnMetadata(dfAnnotate(array(inputCols.map(c => dataset.col(c)): _*))))
    }
  }

  /** Loads the pretrained POS, dependency parser and typed dependency parser models when
    * `mergeEntities` is `true`. The dataset is returned unchanged.
    */
  override def beforeAnnotate(dataset: Dataset[_]): Dataset[_] = {
    if ($(mergeEntities)) {
      pretrainedPos = Some(PretrainedAnnotations.getPretrainedPos($(posModel)))
      pretrainedDependencyParser = Some(
        PretrainedAnnotations.getDependencyParser($(dependencyParserModel)))
      // Honour the user-supplied coordinates instead of always loading the default model.
      pretrainedTypedDependencyParser = Some(
        loadTypedDependencyParser($(typedDependencyParserModel)))
    }
    dataset
  }

  /** Loads a typed dependency parser from `(name, lang, remoteLoc)` coordinates. Missing trailing
    * coordinates fall back to the pretrained defaults; an empty array loads the default model.
    */
  private def loadTypedDependencyParser(
      coordinates: Array[String]): TypedDependencyParserModel =
    coordinates match {
      case Array() => TypedDependencyParserModel.pretrained()
      case Array(name) => TypedDependencyParserModel.pretrained(name)
      case Array(name, lang) => TypedDependencyParserModel.pretrained(name, lang)
      case Array(name, lang, remoteLoc, _*) =>
        TypedDependencyParserModel.pretrained(name, lang, remoteLoc)
    }

  /** Produces graph NODE annotations for the annotations of one row.
    *
    * Sentences (DOCUMENT annotations) whose text length is outside
    * `[minSentenceSize, maxSentenceSize]` are dropped together with every annotation tagged with
    * the same `sentence` metadata.
    *
    * @param annotations
    *   Annotations that correspond to inputAnnotationCols generated by previous annotators if any
    * @return
    *   the extracted NODE annotations, or a single empty NODE annotation when nothing is left to
    *   process
    */
  override def annotate(annotations: Seq[Annotation]): Seq[Annotation] = {
    val sentenceIndexesToSkip = annotations
      .filter(_.annotatorType == AnnotatorType.DOCUMENT)
      .filter { sentence =>
        val size = sentence.result.length
        size > $(maxSentenceSize) || size < $(minSentenceSize)
      }
      .map(sentenceIndexOf)
      .toSet

    val annotationsToProcess =
      annotations.filterNot(annotation =>
        sentenceIndexesToSkip.contains(sentenceIndexOf(annotation)))

    if (annotationsToProcess.isEmpty) emptyGraph else computeAnnotatePaths(annotationsToProcess)
  }

  /** Sentence index stored in the annotation metadata, defaulting to `"0"` when absent. */
  private def sentenceIndexOf(annotation: Annotation): String =
    annotation.metadata.getOrElse("sentence", "0")

  /** Placeholder output used when there is nothing to extract. */
  private def emptyGraph: Seq[Annotation] = Seq(Annotation(NODE, 0, 0, "", Map()))

  /** Builds one dependency graph per sentence (in sentence order) and extracts its paths. */
  private def computeAnnotatePaths(annotations: Seq[Annotation]): Seq[Annotation] = {
    val annotationsBySentence = annotations
      .groupBy(sentenceIndexOf(_).toInt)
      .toSeq
      .sortBy { case (sentenceIndex, _) => sentenceIndex }
      .map { case (_, sentenceAnnotations) => sentenceAnnotations }

    annotationsBySentence.flatMap { sentenceAnnotations =>
      val annotationsWithDependencies = getAnnotationsWithDependencies(sentenceAnnotations)
      val tokens = annotationsWithDependencies.filter(_.annotatorType == AnnotatorType.TOKEN)
      val nerEntities = tokens.filter(_.metadata("entity") != "O")

      if (nerEntities.isEmpty) {
        emptyGraph
      } else {
        val dependencyData = LabeledDependency.unpackHeadAndRelation(annotationsWithDependencies)
        val annotationsInfo = AnnotationsInfo(tokens, nerEntities, dependencyData)

        // Token i (0-based) becomes node i + 1 with an edge from its head; head index 0 marks
        // the sentence root.
        val graph = new GraphBuilder(dependencyData.length + 1)
        dependencyData.zipWithIndex.foreach { case (dependencyInfo, index) =>
          graph.addEdge(dependencyInfo.headIndex, index + 1)
        }

        if ($(explodeEntities)) extractGraphsFromEntities(annotationsInfo, graph)
        else extractGraphsFromRelationships(annotationsInfo, graph)
      }
    }
  }

  private def getAnnotationsWithDependencies(
      sentenceAnnotations: Seq[Annotation]): Seq[Annotation] =
    if ($(mergeEntities)) getPretrainedAnnotations(sentenceAnnotations)
    else getEntityAnnotations(sentenceAnnotations)

  /** Merges neighboring entity tokens, then runs the pretrained POS, dependency and typed
    * dependency models over the merged tokens. The models must have been loaded by
    * `beforeAnnotate`.
    */
  private def getPretrainedAnnotations(annotationsToProcess: Seq[Annotation]): Seq[Annotation] = {
    val relatedAnnotatedTokens = mergeRelatedTokens(annotationsToProcess)
    val sentence = annotationsToProcess.filter(_.annotatorType == AnnotatorType.DOCUMENT)

    val posInput = sentence ++ relatedAnnotatedTokens
    val posAnnotations = PretrainedAnnotations.getPosOutput(posInput, pretrainedPos.get)

    val dependencyParserInput = sentence ++ relatedAnnotatedTokens ++ posAnnotations
    val dependencyParserAnnotations =
      PretrainedAnnotations.getDependencyParserOutput(
        dependencyParserInput,
        pretrainedDependencyParser.get)

    val typedDependencyParserInput =
      relatedAnnotatedTokens ++ posAnnotations ++ dependencyParserAnnotations
    val typedDependencyParserAnnotations = PretrainedAnnotations.getTypedDependencyParserOutput(
      typedDependencyParserInput,
      pretrainedTypedDependencyParser.get)

    relatedAnnotatedTokens ++ dependencyParserAnnotations ++ typedDependencyParserAnnotations
  }

  /** Copies each token's NER tag (without its IOB prefix) into the token's `entity` metadata and
    * keeps the DEPENDENCY / LABELED_DEPENDENCY annotations from the dataset.
    */
  private def getEntityAnnotations(annotationsToProcess: Seq[Annotation]): Seq[Annotation] = {
    val entityAnnotations = annotationsToProcess.filter(_.annotatorType == NAMED_ENTITY)
    // NOTE(review): assumes NER annotations are aligned one-to-one, in order, with the tokens.
    val tokensWithEntity =
      annotationsToProcess.filter(_.annotatorType == TOKEN).zipWithIndex.map {
        case (annotation, index) =>
          val entity = stripIobPrefix(entityAnnotations(index).result)
          Annotation(
            annotation.annotatorType,
            annotation.begin,
            annotation.end,
            annotation.result,
            annotation.metadata ++ Map("entity" -> entity))
      }
    val dependencyParserAnnotations = annotationsToProcess.filter(annotation =>
      annotation.annotatorType == DEPENDENCY || annotation.annotatorType == LABELED_DEPENDENCY)

    tokensWithEntity ++ dependencyParserAnnotations
  }

  /** Removes a leading IOB prefix such as `B-` or `I-` from an NER tag. Tags without one are
    * returned unchanged: `B-PER` -> `PER`, `O` -> `O`, `PER` -> `PER`.
    */
  private def stripIobPrefix(tag: String): String =
    if (tag.length > 2 && tag.charAt(1) == '-') tag.substring(2) else tag

  /** Collapses neighboring tokens of the same entity into single TOKEN annotations carrying
    * `sentence` and `entity` metadata.
    */
  private def mergeRelatedTokens(annotations: Seq[Annotation]): Seq[Annotation] = {
    val sentences = NerTagged.unpack(annotations)
    // Keep only documents that contain at least one tagged word.
    val docs = annotations.filter(document =>
      document.annotatorType == AnnotatorType.DOCUMENT && sentences.exists(sentence =>
        sentence.indexedTaggedWords.exists(word =>
          word.begin >= document.begin && word.end <= document.end)))

    val entities = sentences.zip(docs.zipWithIndex).flatMap { case (sentence, (doc, docIndex)) =>
      NerTagsEncoding.fromIOB(
        sentence,
        doc,
        sentenceIndex = docIndex,
        includeNoneEntities = true,
        format = $(mergeEntitiesIOBFormat))
    }

    entities.map(entity =>
      Annotation(
        TOKEN,
        entity.start,
        entity.end,
        entity.text,
        Map("sentence" -> entity.sentenceId, "entity" -> entity.entity)))
  }

  /** Finds paths from the root token(s) to every allowed pair of entities.
    *
    * Without `rootTokens` only the first dependency with head index 0 is used as root. When no
    * path is found a warning is printed.
    */
  private def extractGraphsFromEntities(
      annotationsInfo: AnnotationsInfo,
      graph: GraphBuilder): Seq[Annotation] = {

    val dependencyData = annotationsInfo.dependencyData
    val roots = $(rootTokens)

    val sourceDependencies: Seq[DependencyInfo] =
      if (roots.isEmpty) dependencyData.filter(_.headIndex == 0)
      else roots.toSeq.flatMap(rootToken => dependencyData.filter(_.token == rootToken))

    // Graph nodes are 1-based, hence the +1 on the position in dependencyData.
    val rootIndices: Seq[Int] = {
      val rootDependencies = if (roots.isEmpty) sourceDependencies.take(1) else sourceDependencies
      rootDependencies.map(dependency => dependencyData.indexOf(dependency) + 1)
    }

    val entitiesPairData = getEntitiesData(annotationsInfo.nerEntities, dependencyData)
    val annotatedPaths = rootIndices.flatMap(rootIndex =>
      getAnnotatedPaths(entitiesPairData, graph, rootIndex, annotationsInfo))

    if (annotatedPaths.isEmpty && roots.nonEmpty) {
      println(
        s"[WARN] Not found paths between given roots: [${roots.mkString(",")}] and" +
          s" entities pairs: ${entitiesPairData.map(x => x.entities).mkString(",")}.\n" +
          s"This could mean there are no more labeled tokens below the given roots or NER didn't label any token.\n" +
          s"$entitiesWarnMessage")
    }

    if (annotatedPaths.isEmpty && roots.isEmpty) {
      // headOption: a parse without any root must not crash while reporting the problem.
      val rootToken = sourceDependencies.headOption.map(_.token).getOrElse("")
      println(
        s"[WARN] Not found paths between the root [$rootToken] and " +
          s" entities pairs ${entitiesPairData.map(x => x.entities).mkString(",")}.\n" +
          s"This could mean there are no more labeled tokens below the root or NER didn't label any token.\n" +
          s"$entitiesWarnMessage")
    }

    annotatedPaths
  }

  /** Hint appended to the "no paths found" warnings, pointing at example notebooks. */
  private def entitiesWarnMessage: String = {
    // No trailing slash: the notebook paths below add their own separator.
    val notebooksURI =
      "https://github.com/JohnSnowLabs/spark-nlp-workshop/blob/master/jupyter/annotation/english"
    val relationshipTypesNotebook =
      s"$notebooksURI/graph-extraction/graph_extraction_roots_paths.ipynb"
    val displayNotebook = s"$notebooksURI/graph-extraction/graph_extraction_helper_display.ipynb"
    s"You can try using relationshipTypes parameter, check this notebook: $relationshipTypesNotebook \n" +
      s"You can also use spark-nlp-display to visualize Dependency Parser and NER output to help identify the kind of relations you can extract" +
      s", check this notebook: $displayNotebook"
  }

  /** For every `token-ENTITY` pair in `relationshipTypes`, emits one NODE per matching root token.
    * The node carries a `relationship` key plus `path1`, `path2`, ... for every entity reachable
    * from it. Root tokens with no reachable entity are omitted.
    */
  private def extractGraphsFromRelationships(
      annotationsInfo: AnnotationsInfo,
      graph: GraphBuilder): Seq[Annotation] = {
    val tokens = annotationsInfo.tokens

    allowedRelationshipTypes.toSeq.flatMap { case (rootText, entityType) =>
      val rootData = tokens
        .filter(_.result == rootText)
        .map(token => (token, tokens.indexOf(token) + 1))
      val entityIndexes = annotationsInfo.nerEntities
        .filter(_.metadata("entity") == entityType)
        .map(nerEntity => tokens.indexOf(nerEntity) + 1)

      rootData.flatMap { case (rootToken, rootIndex) =>
        val paths = entityIndexes.flatMap(entityIndex =>
          buildPath(graph, (rootIndex, entityIndex), annotationsInfo.dependencyData))
        if (paths.isEmpty) {
          None
        } else {
          val pathsMap = paths.zipWithIndex.map { case (path, index) =>
            s"path${index + 1}" -> path
          }.toMap
          Some(
            Annotation(
              NODE,
              rootToken.begin,
              rootToken.end,
              rootToken.result,
              Map("relationship" -> s"${rootToken.result},$entityType") ++ pathsMap))
        }
      }
    }
  }

  /** Renders the graph path between two 1-based nodes as a delimited string.
    *
    * When `includeEdges` is true every node except the path's start and the `*root*` relation is
    * prefixed with its dependency relation. Returns None when the graph has no such path.
    */
  private def buildPath(
      graph: GraphBuilder,
      nodesIndexes: (Int, Int),
      dependencyData: Seq[DependencyInfo]): Option[String] = {
    val (rootIndex, targetIndex) = nodesIndexes
    val path = graph.findPath(rootIndex, targetIndex).map { nodeIndex =>
      val dependencyInfo = dependencyData(nodeIndex - 1)
      val relation = dependencyInfo.relation
      val edge =
        if (!$(includeEdges) || relation == "*root*" || nodeIndex == rootIndex) ""
        else relation + $(delimiter)
      edge + dependencyInfo.token
    }
    if (path.isEmpty) None else Some(path.mkString($(delimiter)))
  }

  /** Emits one NODE, anchored on the root token, for every entity pair for which both entities
    * are reachable from `rootIndex`.
    */
  private def getAnnotatedPaths(
      entitiesPairData: List[EntitiesPairInfo],
      graph: GraphBuilder,
      rootIndex: Int,
      annotationsInfo: AnnotationsInfo): Seq[Annotation] = {

    val dependencyData = annotationsInfo.dependencyData
    val sourceToken = annotationsInfo.tokens(rootIndex - 1)

    entitiesPairData.flatMap { pairInfo =>
      val (leftIndex, rightIndex) = pairInfo.entitiesIndex
      for {
        leftPath <- buildPath(graph, (rootIndex, leftIndex), dependencyData)
        rightPath <- buildPath(graph, (rootIndex, rightIndex), dependencyData)
      } yield {
        val (leftEntity, rightEntity) = pairInfo.entities
        Annotation(
          NODE,
          sourceToken.begin,
          sourceToken.end,
          sourceToken.result,
          Map(
            "entities" -> s"$leftEntity,$rightEntity",
            "left_path" -> leftPath,
            "right_path" -> rightPath))
      }
    }
  }

  /** Builds the distinct entity pairs to search paths for, with each entity's 1-based graph node.
    *
    * With no `entityTypes` every combination of two entities is used; otherwise only pairs whose
    * types match an allowed pair. Entities with no dependency spanning exactly the same
    * offsets are skipped.
    */
  private def getEntitiesData(
      annotatedEntities: Seq[Annotation],
      dependencyData: Seq[DependencyInfo]): List[EntitiesPairInfo] = {
    val allowedPairs = allowedEntityRelationships

    val annotatedEntitiesPairs: List[(Annotation, Annotation)] =
      if (allowedPairs.isEmpty) {
        annotatedEntities.combinations(2).map(pair => (pair.head, pair.last)).toList
      } else {
        allowedPairs.toList
          .flatMap(entities => getAnnotatedNerEntitiesPairs(entities, annotatedEntities))
          // Drop pairs made of the same entity (possible when both types are equal).
          .filter { case (left, right) => left.begin != right.begin && left.end != right.end }
      }

    def nodeIndexOf(entity: Annotation): Option[Int] = {
      val position = dependencyData.indexWhere(dependencyInfo =>
        dependencyInfo.beginToken == entity.begin && dependencyInfo.endToken == entity.end)
      if (position >= 0) Some(position + 1) else None
    }

    annotatedEntitiesPairs.flatMap { case (left, right) =>
      for {
        indexLeft <- nodeIndexOf(left)
        indexRight <- nodeIndexOf(right)
      } yield EntitiesPairInfo(
        (indexLeft, indexRight),
        (left.metadata("entity"), right.metadata("entity")))
    }.distinct
  }

  /** Cartesian product of entities of the two requested types, as `(left, right)` tuples. The
    * side with more entities drives the outer loop, which only affects the output order.
    */
  private def getAnnotatedNerEntitiesPairs(
      entities: (String, String),
      annotatedEntities: Seq[Annotation]): List[(Annotation, Annotation)] = {
    val (leftType, rightType) = entities
    val leftEntities = annotatedEntities.filter(_.metadata("entity") == leftType)
    val rightEntities = annotatedEntities.filter(_.metadata("entity") == rightType)

    val pairs =
      if (leftEntities.length > rightEntities.length) {
        for (left <- leftEntities; right <- rightEntities) yield (left, right)
      } else {
        for (right <- rightEntities; left <- leftEntities) yield (left, right)
      }
    pairs.toList
  }

  /** A pair of entities: their 1-based graph node indexes and their entity types. */
  private case class EntitiesPairInfo(entitiesIndex: (Int, Int), entities: (String, String))

  /** Per-sentence inputs to path extraction. */
  private case class AnnotationsInfo(
      tokens: Seq[Annotation],
      nerEntities: Seq[Annotation],
      dependencyData: Seq[DependencyInfo])

  /** Output annotator types: NODE
    *
    * @group anno
    */
  override val outputAnnotatorType: AnnotatorType = NODE

  /** Input annotator types: DOCUMENT, TOKEN, NAMED_ENTITY
    *
    * @group anno
    */
  override val inputAnnotatorTypes: Array[String] = Array(DOCUMENT, TOKEN, NAMED_ENTITY)

  /** Optional input annotator types: DEPENDENCY, LABELED_DEPENDENCY. They are read from the
    * dataset when `mergeEntities` is `false`.
    *
    * @group anno
    */
  override val optionalInputAnnotatorTypes: Array[String] = Array(DEPENDENCY, LABELED_DEPENDENCY)

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy