All downloads are free. The search and download features use the official Maven repository.

org.opencypher.okapi.ir.api.pattern.Pattern.scala Maven / Gradle / Ivy

/*
 * Copyright (c) 2016-2019 "Neo4j Sweden, AB" [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * Attribution Notice under the terms of the Apache License 2.0
 *
 * This work was created by the collective efforts of the openCypher community.
 * Without limiting the terms of Section 6, any Derivative Work that is not
 * approved by the public consensus process of the openCypher Implementers Group
 * should not be described as “Cypher” (and Cypher® is a registered trademark of
 * Neo4j Inc.) or as "openCypher". Extensions by implementers or prototypes or
 * proposals for change that have been documented or implemented should only be
 * described as "implementation extensions to Cypher" or as "proposed changes to
 * Cypher that are not yet approved by the openCypher community".
 */
package org.opencypher.okapi.ir.api.pattern

import org.opencypher.okapi.api.types.{CTNode, CTRelationship, CypherType}
import org.opencypher.okapi.ir.api._
import org.opencypher.okapi.ir.api.block.Binds
import org.opencypher.okapi.ir.api.expr.MapExpression
import org.opencypher.okapi.ir.impl.exception.PatternConversionException

import scala.annotation.tailrec
import scala.collection.immutable.ListMap

/**
  * Factory methods for building simple [[Pattern]] instances.
  *
  * The type parameter `E` on both factories is not used by their bodies; it is retained so that
  * existing call sites that supply it keep compiling.
  */
case object Pattern {

  /** Creates a pattern without any fields or connections. */
  def empty[E]: Pattern = Pattern(Set.empty, ListMap.empty)

  /** Creates a pattern whose only field is `node` and which has no connections. */
  def node[E](node: IRField): Pattern = empty[E].copy(fields = Set(node))
}

/**
  * Immutable description of a pattern: the fields it binds and the connections between them.
  *
  * @param fields     all fields bound by this pattern
  * @param topology   connections of the pattern, keyed by the field each connection is bound to
  *                   (insertion-ordered)
  * @param properties property map expressions attached to individual fields
  * @param baseFields mapping from a field to its base field (conflicts are checked by the frontend)
  */
final case class Pattern(
  fields: Set[IRField],
  topology: ListMap[IRField, Connection],
  properties: Map[IRField, MapExpression] = Map.empty,
  baseFields: Map[IRField, IRField] = Map.empty
) extends Binds {

  /** Fields whose Cypher type is a subtype of `CTNode`. */
  lazy val nodes: Set[IRField] = getEntity(CTNode)

  /** Fields whose Cypher type is a subtype of `CTRelationship`. */
  lazy val rels: Set[IRField] = getEntity(CTRelationship)

  /** Selects all fields whose Cypher type is a subtype of `t`. */
  private def getEntity(t: CypherType) =
    fields.collect { case e if e.cypherType.subTypeOf(t) => e }

  /**
    * Fuse patterns but fail if they disagree in the definitions of entities or connections
    *
    * Properties and base fields of both patterns are merged; on key collisions the entry of
    * `other` wins.
    *
    * @param other the pattern to fuse with this one
    * @return A pattern that contains all entities and connections of their input
    * @throws PatternConversionException if a field name is bound with different types in the two
    *                                    patterns, or if a connection key maps to different
    *                                    connections in the two patterns
    */
  def ++(other: Pattern): Pattern = {
    val thisMap = fields.map(f => f.name -> f.cypherType).toMap
    val otherMap = other.fields.map(f => f.name -> f.cypherType).toMap

    verifyFieldTypes(thisMap, otherMap)

    // Connection keys present on both sides must refer to the very same connection
    val conflicts = topology.keySet.intersect(other.topology.keySet).filter(k => topology(k) != other.topology(k))
    conflicts.headOption.foreach { key =>
      throw PatternConversionException(
        s"Expected disjoint patterns but found conflicting connection for $key:\n" +
          s"${topology(key)} and ${other.topology(key)}")
    }
    val newTopology = topology ++ other.topology

    // Base field conflicts are checked by frontend
    val newBaseFields = baseFields ++ other.baseFields

    Pattern(fields ++ other.fields, newTopology, properties ++ other.properties, newBaseFields)
  }

  /**
    * Ensures that every field name occurring in both maps is associated with the same type.
    *
    * @throws PatternConversionException if a name is mapped to different types
    */
  private def verifyFieldTypes(map1: Map[String, CypherType], map2: Map[String, CypherType]): Unit = {
    (map1.keySet ++ map2.keySet).foreach { f =>
      (map1.get(f), map2.get(f)) match {
        case (Some(t1), Some(t2)) if t1 != t2 =>
          throw PatternConversionException(s"Expected disjoint patterns but found conflicting entities $f")
        case _ =>
      }
    }
  }

  /**
    * Returns all connections that have `node` as one of their endpoints.
    *
    * @param node the field to look up among the connection endpoints
    * @return the matching subset of the topology
    */
  def connectionsFor(node: IRField): Map[IRField, Connection] = {
    topology.filter {
      case (_, c) => c.endpoints.contains(node)
    }
  }

  /** `true` iff this pattern equals `Pattern.empty`, i.e. all four constructor fields are empty. */
  def isEmpty: Boolean = this == Pattern.empty

  /**
    * Returns a pattern in which `key` is bound to `connection`.
    *
    * @param key           the field the connection is bound to
    * @param connection    the connection to register
    * @param propertiesOpt if defined, stored as the property expression for `key`
    * @return the updated pattern; the topology is left untouched if it already contains the binding
    */
  def withConnection(key: IRField, connection: Connection, propertiesOpt: Option[MapExpression] = None): Pattern = {
    val withProperties: Pattern = propertiesOpt match {
      case Some(props) => copy(properties = properties.updated(key, props))
      case None => this
    }

    if (topology.get(key).contains(connection)) withProperties else withProperties.copy(topology = topology.updated(key, connection))
  }

  /**
    * Returns a pattern that binds `field`.
    *
    * @param field         the field to add
    * @param propertiesOpt if defined, stored as the property expression for `field`
    * @return the updated pattern; the field set is left untouched if it already contains `field`
    */
  def withEntity(field: IRField, propertiesOpt: Option[MapExpression] = None): Pattern = {
    val withProperties: Pattern = propertiesOpt match {
      case Some(props) => copy(properties = properties.updated(field, props))
      case None => this
    }

    if (fields(field)) withProperties else withProperties.copy(fields = fields + field)
  }

  /**
    * Registers `base` as the base field of `field`.
    *
    * @param field   the field to annotate; must already be part of this pattern to have an effect
    * @param baseOpt the base field, if any
    * @return the updated pattern, or this pattern if `baseOpt` is empty or `field` is not bound
    */
  def withBaseField(field: IRField, baseOpt: Option[IRField]): Pattern = baseOpt match {
    case Some(base) if fields.contains(field) => copy(baseFields = baseFields.updated(field, base))
    case _ => this
  }

  /**
    * Splits this pattern into its connected components.
    *
    * Every node initially forms a component of its own; the connections of the topology are then
    * processed in order and join the components of their endpoints. The resulting patterns are
    * rebuilt from fields and connections only, so `properties` and `baseFields` of this pattern
    * are not carried over.
    *
    * @return one pattern per connected component
    */
  def components: Set[Pattern] = {
    // Give every field a unique index; a node's index doubles as its initial component index
    val fieldIndex = fields.toSeq.zipWithIndex.toMap
    val initialComponents = nodes.map(n => fieldIndex(n) -> Pattern.node(n)).toMap
    computeComponents(topology.toSeq, initialComponents, fieldIndex.size, fieldIndex)
  }

  /**
    * Processes the remaining connections one by one and merges components accordingly.
    *
    * @param input                 connections that still have to be processed
    * @param components            components found so far, keyed by component index
    * @param count                 next unused component index (greater than every index in use)
    * @param fieldToComponentIndex component index each known field currently belongs to
    * @return the patterns of all components once the input is exhausted
    */
  @tailrec
  private def computeComponents(
    input: Seq[(IRField, Connection)],
    components: Map[Int, Pattern],
    count: Int,
    fieldToComponentIndex: Map[IRField, Int]
  ): Set[Pattern] = input match {
    case Seq((field, connection), tail@_*) =>
      val endpoints = connection.endpoints.toSet
      val links = endpoints.flatMap(fieldToComponentIndex.get)

      if (links.isEmpty) {
        // Connection forms a new connected component on its own
        val newPattern = Pattern(
          fields = fields intersect endpoints,
          topology = ListMap(field -> connection)
        ).withEntity(field)
        val newComponents = components.updated(count, newPattern)
        val newFields = endpoints.foldLeft(fieldToComponentIndex) { case (m, endpoint) => m.updated(endpoint, count) }
        computeComponents(tail, newComponents, count + 1, newFields)
      } else if (links.size == 1) {
        // Connection should be added to a single, existing component
        val link = links.head
        val oldPattern = components(link) // This is not supposed to fail
        val newPattern = oldPattern
          .withConnection(field, connection)
          .withEntity(field)
        val newComponents = components.updated(link, newPattern)
        computeComponents(tail, newComponents, count, fieldToComponentIndex)
      } else {
        // Connection bridges two connected components
        val fusedPattern = links.flatMap(components.get).reduce(_ ++ _)
        val newPattern = fusedPattern
          .withConnection(field, connection)
          .withEntity(field)
        // Store the fused component under the next unused index and advance the counter past it,
        // so that a component created later can never overwrite it.
        val fusedIndex = count
        val newComponents = (components -- links).updated(fusedIndex, newPattern)
        // Strict remapping: `mapValues` would return a lazy view that is re-evaluated on every
        // lookup and would nest one layer deeper with each bridging step.
        val newFields = fieldToComponentIndex.map {
          case (f, index) => f -> (if (links(index)) fusedIndex else index)
        }
        computeComponents(tail, newComponents, count + 1, newFields)
      }

    case Seq() =>
      components.values.toSet
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy