All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.clulab.reach.utils.DependencyUtils.scala Maven / Gradle / Ivy

The newest version!
package org.clulab.reach.utils

import org.clulab.processors.Sentence
import org.clulab.struct.{DirectedGraph, Interval}

/**
 * Utility functions for use with directed (dependency) graphs
 * User: danebell
 * Date: 2/23/15
 */
object DependencyUtils {

  val defaultPolicy: (Seq[Int]) => Int = _.last

  /**
   * Given an Interval, finds the minimal span covering all of the Interval's nodes' children (recursively).
   *
   * @param span Interval of nodes
   * @param sent the sentence over which the interval applies
   * @return the minimal Interval that contains all the nodes that are children of span's nodes
   */
  def subgraph(span: Interval, sent: Sentence): Option[Interval] = {
    val graph = sent.dependencies.getOrElse(return None)

    val heads = if (span.size < 2) Seq(span.start) else findHeadsStrict(span, sent)

    @annotation.tailrec
    def followTrail(remaining: Seq[Int], results: Seq[Int]): Seq[Int] = remaining match {
      case Nil => results
      case first +: rest if results contains first => followTrail(rest, results)
      case first +: rest =>
        val children: Seq[Int] = try {
          graph.getOutgoingEdges(first).map(_._1)
        } catch {
          case e: Exception =>
            Nil
        }
        followTrail(children ++ rest, first +: results)
    }

    val outgoing = (for (h <- heads) yield followTrail(Seq(h), Nil)).flatten.distinct

    // outgoing may only have a single index
    if (outgoing.isEmpty) Some(span) // no head found, so no expansion is possible
    else if (outgoing.length == 1) Some(Interval(outgoing.head, outgoing.head + 1))
    else Some(Interval(outgoing.min, outgoing.max+1))
  }

  /**
   * Finds the highest node (i.e. closest to a root) in an Interval of a directed graph. If there are multiple nodes of
   * the same rank, chooseWhich adjudicates which single node is returned.
   *
   * @param span an Interval of nodes
   * @param graph a directed graph containing the nodes in span
   * @param chooseWhich a function deciding which of multiple heads is returned; the rightmost head selected by default
   * @return the single node which is closest to the root among those in span
   */
  def findHead(span: Interval, graph: DirectedGraph[String], chooseWhich:(Seq[Int]) => Int = defaultPolicy): Int = {
    chooseWhich(findHeads(span, graph))
  }

  /**
   * Finds the highest node (i.e. closest to a root) in an Interval of a directed graph. If there are multiple nodes of
   * the same rank, all are returned.
   *
   * @param span an Interval of nodes
   * @param graph a directed graph containing the nodes in span
   * @return the single node which is closest to the root among those in span
   */
  def findHeads(span: Interval, graph: DirectedGraph[String]): Seq[Int] = {

    def getIncoming(i: Int) = graph.incomingEdges.lift(i).getOrElse(Array.empty).map(_._1)

    def followTrail (i: Int, heads:Seq[Int]): Seq[Int] = {

      // dependents
      val incoming = getIncoming(i)

      incoming match {
        case valid if valid.isEmpty | valid.contains(i) => Seq(i)
        case found if found.min < span.start | found.max > (span.end - 1) => followTrailOut(found.head, heads ++ Seq(i), i)
        case _ => followTrail(incoming.head, heads ++ Seq(i))
      }
    }

    def followTrailOut (i: Int, heads:Seq[Int], highest: Int): Seq[Int] = {

      val incoming = getIncoming(i)

      incoming match {
        case valid if valid.isEmpty | valid.contains(i) => Seq(highest)
        case outgoing if outgoing.min < span.start | outgoing.max > (span.end - 1) => followTrailOut (incoming.head, heads ++ Seq(i), highest)
        case _ => followTrail (incoming.head, heads ++ Seq(i))
      }
    }

    val heads = for (i <- span.start until span.end) yield followTrail(i, Nil)

    heads.flatten.distinct.sorted
  }

  /**
   * Find the single highest node in an interval of a dependency graph, ignoring punctuation, coordinations, and
   * prepositions.
   * @param span the interval within which to search
   * @param sent the Sentence within which to look
   * @param chooseWhich the function to adjudicate which is highest when there's a tie
   * @return Option containing the highest node index, or None if no such node is found
   */
  def findHeadStrict(span: Interval, sent: Sentence, chooseWhich:(Seq[Int]) => Int = defaultPolicy): Option[Int] = {
    val hds = findHeadsStrict(span, sent)
    hds match{
      case valid if valid.nonEmpty => Some(chooseWhich(valid))
      case _ => None
    }
  }

  /**
   * Find the highest nodes in an interval of a dependency graph, ignoring puncutation, coordinations, and prepositions.
   * Allows multiple node indices to be "highest" in the case of a tie.
   * @param span the interval within which to search
   * @param sent the Sentence within which to look
   * @return Option containing a sequence of highest node indices, or None if no such node is found
   */
  def findHeadsStrict(span: Interval, sent: Sentence): Seq[Int] = {

    if (sent.dependencies.isEmpty) return Nil

    val stopTags = "(.|,|\\(|\\)|:|''|``|#|$|CC|TO|IN)"

    def getIncoming(i: Int) = sent.dependencies.get.incomingEdges.lift(i).getOrElse(Array.empty).map(_._1)

    def followTrail (i: Int, heads:Seq[Int]): Seq[Int] = {

      // dependents
      val incoming = getIncoming(i)

      incoming match {
        case valid if valid.isEmpty | valid.contains(i) =>  Seq(i)
        case found if found.min < span.start | found.max > (span.end - 1) => followTrailOut(found.head, heads ++ Seq(i), i)
        case _ => followTrail(incoming.head, heads ++ Seq(i))
      }
    }

    def followTrailOut (i: Int, heads:Seq[Int], highest: Int): Seq[Int] = {

      val incoming = getIncoming(i)

      incoming match {
        case valid if valid.isEmpty | valid.contains(i) => Seq(highest)
        case outgoing if outgoing.min < span.start | outgoing.max > (span.end - 1) => followTrailOut (incoming.head, heads ++ Seq(i), highest)
        case _ => followTrail (incoming.head, heads ++ Seq(i))
      }
    }

    val heads = for (i <- span.start until span.end) yield followTrail(i, Nil)

    val filtered = heads.flatten.distinct.filter(x => !(sent.tags.get(x) matches stopTags)).sorted
    filtered
  }


  /**
   * Finds the highest node (i.e. closest to a root) in an Interval of a directed graph, ignoring the graph outside the
   * Interval. If there are multiple nodes of the same rank, chooseWhich adjudicates which single node is returned.
   * 

* Crucially, any node that has an incoming edge from outside the Interval is considered a head. This is efficient if * you know your span has a single head that is also within the span. * * @param span an Interval of nodes * @param graph a directed graph containing the nodes in span * @param chooseWhich a function deciding which of multiple heads is returned; the rightmost head selected by default * @return the single node which is closest to the root among those in span */ def findHeadLocal(span: Interval, graph: DirectedGraph[String], chooseWhich:(Seq[Int]) => Int = defaultPolicy): Int = { chooseWhich(findHeadsLocal(span, graph)) } /** * Finds the highest node (i.e. closest to a root) in an Interval of a directed graph. If there are multiple nodes of * the same rank, all are returned. *

* Crucially, any node that has an incoming edge from outside the Interval is considered a head. This is efficient if * you know your span has a single head that is also within the span. * * @param span an Interval of nodes * @param graph a directed graph containing the nodes in span * @return the single node which is closest to the root among those in span */ def findHeadsLocal(span: Interval, graph: DirectedGraph[String]): Seq[Int] = { def followTrail (i: Int, heads:Seq[Int]): Seq[Int] = { // dependents val incoming = graph.getIncomingEdges(i).map(_._1) incoming match { case valid if valid.isEmpty | valid.contains(i) => Seq(i) case found if heads.contains(i) | found.min < span.start | found.max > (span.end - 1) => Seq(i) case _ => followTrail(incoming.head, heads ++ Seq(i)) } } val heads = for (i <- span.start until span.end) yield followTrail(i, Nil) heads.flatten.distinct.sorted } /** * * @param a Interval in Sentence sentA * @param b Interval in Sentence sentB * @param sentA Sentence containing a * @param sentB Sentence containing b * @return returns true if Interval a contains Interval b or vice versa */ def nested(a: Interval, b: Interval, sentA: Sentence, sentB: Sentence): Boolean = { if (sentA != sentB) return false val graph = sentA.dependencies.getOrElse(return false) val aSubgraph = subgraph(a, sentA) val bSubgraph = subgraph(b, sentB) if((aSubgraph.isDefined && aSubgraph.get.contains(b)) || (bSubgraph.isDefined && bSubgraph.get.contains(a))) true else false } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy