All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.deepsense.graph.DirectedGraph.scala Maven / Gradle / Ivy

/**
 * Copyright 2015 deepsense.ai (CodiLime, Inc)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package ai.deepsense.graph

/**
 * A directed graph whose nodes wrap operations of type `T` and whose edges connect an output
 * port of one node to an input port of another.
 *
 * Edges are validated on construction. Only edges whose two endpoint nodes are present in
 * `nodes`, and whose port indices lie in `[0, outArity)` for the source node and
 * `[0, inArity)` for the target node, are kept in [[validEdges]]. All derived structures
 * (predecessors, successors, subgraphs) are built from the valid edges only.
 *
 * NOTE(review): if several valid edges target the same input port, only one of them is
 * reflected in `predecessors` (whichever is visited last when iterating `validEdges`), while
 * `successors` records all of them.
 *
 * @param nodes the nodes of the graph; node ids are used as map keys, so they should be unique
 * @param edges candidate edges; edges that fail validation are silently dropped
 * @tparam T the operation type carried by each node
 * @tparam G the concrete graph type, returned by [[subgraph]]
 */
abstract class DirectedGraph[T <: Operation, G <: DirectedGraph[T, G]](
    val nodes: Set[Node[T]] = Set[Node[T]](),
    val edges: Set[Edge] = Set())
  extends TopologicallySortable[T]
  with Serializable {

  // Initialization order matters: each val below is derived from the ones declared before it.

  /** The subset of `edges` whose endpoint nodes and port indices exist in this graph. */
  val validEdges: Set[Edge] = filterValidEdges(nodes, edges)

  private val idToNode = nodes.map(n => n.id -> n).toMap
  private val _predecessors = preparePredecessors
  private val _successors = prepareSuccessors
  // NOTE(review): the cycle flag is taken from `TopologicalSort.isSorted`; confirm that
  // `isSorted` is true exactly when the graph HAS a cycle, otherwise this polarity is inverted.
  // `this` is handed to TopologicalSort before any subclass constructor has run.
  private val _containsCycle = new TopologicalSort(this).isSorted

  /**
   * Runs a fresh topological sort over this graph (recomputed on every call).
   *
   * @return the sorted nodes, or `None` when the sort fails (presumably because the graph
   *         contains a cycle — confirm against TopologicalSort)
   */
  def topologicallySorted: Option[List[Node[T]]] = new TopologicalSort(this).sortedNodes

  /**
   * Returns the node with the given id.
   *
   * @throws NoSuchElementException if `id` is not a node of this graph
   */
  def node(id: Node.Id): Node[T] = idToNode(id)

  /**
   * For each input port of node `id` (indexed by port number), the endpoint connected to it.
   *
   * @throws NoSuchElementException if `id` is not a node of this graph
   */
  def predecessors(id: Node.Id): IndexedSeq[Option[Endpoint]] = _predecessors(id)

  /**
   * For each output port of node `id` (indexed by port number), the endpoints it feeds.
   *
   * @throws NoSuchElementException if `id` is not a node of this graph
   */
  def successors(id: Node.Id): IndexedSeq[Set[Endpoint]] = _successors(id)

  /** Cycle flag computed once during construction (see the review note on `_containsCycle`). */
  def containsCycle: Boolean = _containsCycle

  /**
   * Returns all transitive predecessors of node `id`, following `predecessors` links.
   *
   * Visited ids are tracked, so every node is expanded at most once. The walk therefore
   * terminates on cyclic graphs (where `id` itself is included if it lies on a cycle) and
   * does not re-walk ancestors shared by several paths.
   *
   * @throws NoSuchElementException if `id` is not a node of this graph
   */
  def allPredecessorsOf(id: Node.Id): Set[Node[T]] = {
    @scala.annotation.tailrec
    def walk(pending: List[Node.Id], found: Set[Node.Id]): Set[Node.Id] = pending match {
      case Nil => found
      case current :: rest =>
        val newlyFound =
          predecessors(current).flatten.map(_.nodeId).distinct.filterNot(found)
        walk(newlyFound.toList ++ rest, found ++ newlyFound)
    }
    walk(List(id), Set.empty).map(node)
  }

  /** Number of nodes in the graph. */
  def size: Int = nodes.size

  /**
   * Nodes that have no connected input port, in topological order.
   *
   * @throws NoSuchElementException if `topologicallySorted` returns `None`
   */
  def rootNodes: Iterable[Node[T]] = {
    val sorted = topologicallySorted.getOrElse(
      throw new NoSuchElementException(
        "Cannot determine root nodes: the graph could not be topologically sorted"))
    sorted.filter(n => predecessors(n.id).flatten.isEmpty)
  }

  /** Ids of the nodes directly connected to an input port of any of the given nodes. */
  def predecessorsOf(nodes: Set[Node.Id]): Set[Node.Id] = {
    nodes.flatMap {
      node => predecessors(node).flatten.map { _.nodeId }
    }
  }

  /** Ids of the nodes directly fed by any output port of `node`. */
  def successorsOf(node: Node.Id): Set[Node.Id] =
    successors(node).flatMap(endpoints => endpoints.map(_.nodeId)).toSet

  /**
   * Returns the subgraph made of the given nodes, all of their transitive predecessors, and
   * every valid edge that ends at one of those collected nodes.
   *
   * @throws NoSuchElementException if any of the given ids is not a node of this graph
   */
  def subgraph(nodes: Set[Node.Id]): G = {
    @scala.annotation.tailrec
    def collectNodesEdges(
        previouslyCollectedNodes: Set[Node.Id],
        previouslyCollectedEdges: Set[Edge],
        toProcess: Set[Node.Id]): (Set[Node.Id], Set[Edge]) = {
      if (toProcess.isEmpty) {
        (previouslyCollectedNodes, previouslyCollectedEdges)
      } else {
        // Only predecessors that were not collected yet are processed next, so the walk
        // terminates even when the graph contains a cycle.
        val newPredecessors = predecessorsOf(toProcess) -- previouslyCollectedNodes
        collectNodesEdges(
          previouslyCollectedNodes ++ newPredecessors,
          previouslyCollectedEdges ++ edgesOf(toProcess),
          newPredecessors)
      }
    }

    val (collectedNodes, collectedEdges) = collectNodesEdges(nodes, Set(), nodes)
    subgraph(collectedNodes.map(node), collectedEdges)
  }

  /** Builds a graph of the concrete type `G` from the given nodes and edges. */
  def subgraph(nodes: Set[Node[T]], edges: Set[Edge]): G

  /** Accessor for [[validEdges]]. */
  def getValidEdges: Set[Edge] = validEdges

  /** All valid edges ending at any of the given nodes. */
  private def edgesOf(nodes: Set[Node.Id]): Set[Edge] = nodes.flatMap(edgesTo)

  /** All valid edges ending at `node` (a linear scan over `validEdges`). */
  private def edgesTo(node: Node.Id): Set[Edge] = validEdges.filter(edge => edge.to.nodeId == node)

  /**
   * Builds, for every node, one slot per input port (`inArity` slots) holding the endpoint
   * connected to that port, or `None`. When several valid edges target the same port, the one
   * visited last wins.
   */
  private def preparePredecessors: Map[Node.Id, IndexedSeq[Option[Endpoint]]] = {
    import scala.collection.mutable
    val mutablePredecessors: mutable.Map[Node.Id, mutable.IndexedSeq[Option[Endpoint]]] =
      mutable.Map()

    nodes.foreach { node =>
      mutablePredecessors +=
        node.id -> mutable.IndexedSeq.fill[Option[Endpoint]](node.value.inArity)(None)
    }
    validEdges.foreach { edge =>
      mutablePredecessors(edge.to.nodeId)(edge.to.portIndex) = Some(edge.from)
    }
    // Copy eagerly into immutable collections instead of using `mapValues`, which returns a
    // lazy view (and is deprecated in Scala 2.13).
    mutablePredecessors.iterator.map { case (id, ports) => id -> ports.toIndexedSeq }.toMap
  }

  /**
   * Builds, for every node, one set per output port (`outArity` sets) holding every endpoint
   * that port feeds.
   */
  private def prepareSuccessors: Map[Node.Id, IndexedSeq[Set[Endpoint]]] = {
    import scala.collection.mutable
    val mutableSuccessors: mutable.Map[Node.Id, IndexedSeq[mutable.Set[Endpoint]]] =
      mutable.Map()

    nodes.foreach { node =>
      // `fill` evaluates its element argument once per slot, so every port gets its own set.
      mutableSuccessors += node.id -> Vector.fill(node.value.outArity)(mutable.Set.empty[Endpoint])
    }
    validEdges.foreach { edge =>
      mutableSuccessors(edge.from.nodeId)(edge.from.portIndex) += edge.to
    }
    // Copy eagerly into immutable collections instead of using the lazy `mapValues` view.
    mutableSuccessors.iterator.map { case (id, ports) => id -> ports.map(_.toSet) }.toMap
  }

  /**
   * Keeps only the edges whose source and target nodes are in `nodes` and whose port indices
   * are non-negative and below the source's `outArity` and the target's `inArity`.
   * Negative indices are rejected here; otherwise they would crash the predecessor and
   * successor tables with an out-of-bounds update.
   */
  private def filterValidEdges(nodes: Set[Node[T]], edges: Set[Edge]): Set[Edge] = {
    // Built locally: this runs during construction, before `idToNode` is initialized.
    val nodesById = nodes.map(n => n.id -> n).toMap
    edges.filter { edge =>
      val sourcePortExists = nodesById.get(edge.from.nodeId).exists { source =>
        edge.from.portIndex >= 0 && edge.from.portIndex < source.value.outArity
      }
      val targetPortExists = nodesById.get(edge.to.nodeId).exists { target =>
        edge.to.portIndex >= 0 && edge.to.portIndex < target.value.inArity
      }
      sourcePortExists && targetPortExists
    }
  }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy