All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.opencypher.okapi.ir.impl.ExpressionConverter.scala Maven / Gradle / Ivy

/*
 * Copyright (c) 2016-2019 "Neo4j Sweden, AB" [https://neo4j.com]
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * Attribution Notice under the terms of the Apache License 2.0
 *
 * This work was created by the collective efforts of the openCypher community.
 * Without limiting the terms of Section 6, any Derivative Work that is not
 * approved by the public consensus process of the openCypher Implementers Group
 * should not be described as “Cypher” (and Cypher® is a registered trademark of
 * Neo4j Inc.) or as "openCypher". Extensions by implementers or prototypes or
 * proposals for change that have been documented or implemented should only be
 * described as "implementation extensions to Cypher" or as "proposed changes to
 * Cypher that are not yet approved by the openCypher community".
 */
package org.opencypher.okapi.ir.impl

import org.opencypher.okapi.api.types.CypherType._
import org.opencypher.okapi.api.types._
import org.opencypher.okapi.impl.exception.NotImplementedException
import org.opencypher.okapi.ir.api._
import org.opencypher.okapi.ir.api.expr._
import org.opencypher.okapi.ir.impl.OperatorTyping._
import org.opencypher.okapi.ir.impl.parse.functions.FunctionExtensions
import org.opencypher.okapi.ir.impl.parse.{functions => f}
import org.opencypher.okapi.ir.impl.typer.{InvalidArgument, InvalidContainerAccess, MissingParameter, NoSuitableSignatureForExpr, SignatureConverter, UnTypedExpr}
import org.opencypher.v9_0.expressions.{OperatorExpression, RegexMatch, TypeSignatures, functions}
import org.opencypher.v9_0.{expressions => ast}

import scala.language.implicitConversions

/**
  * Type inference for the Cypher `+` operator.
  *
  * `+` is overloaded: it adds numbers, concatenates strings (also with a numeric operand), concatenates lists or adds a
  * non-list operand to a list, and adds durations to dates / local date-times. The lookup is performed on the material
  * operand types; the resulting type is then adjusted to the nullability of `lhs join rhs`.
  */
object AddType {

  /** Typing for list concatenation: joins both element types, or joins a non-list operand into the element type. */
  private implicit class RichCTList(val left: CTList) extends AnyVal {
    def listConcatJoin(right: CypherType): CypherType = (left, right) match {
      case (CTList(lInner), CTList(rInner)) => CTList(lInner join rInner)
      case (CTList(lInner), _) => CTList(lInner join right)
    }
  }

  /**
    * Computes the result type of `lhs + rhs`.
    *
    * @param lhs type of the left operand
    * @param rhs type of the right operand
    * @return the result type, or `None` if `+` is not defined for the given operand types
    */
  def apply(lhs: CypherType, rhs: CypherType): Option[CypherType] = {
    val materialResult: Option[CypherType] = (lhs.material, rhs.material) match {
      // an operand whose material type is void makes the result null
      case (CTVoid, _) | (_, CTVoid) => Some(CTNull)
      // list concatenation; note that the full (possibly nullable) type of the other operand is passed on
      case (left: CTList, _) => Some(left listConcatJoin rhs)
      case (_, right: CTList) => Some(right listConcatJoin lhs)
      // string concatenation, also with a numeric operand; the guards check the material type so that
      // nullable numbers are accepted just like in the numeric cases below
      case (CTString, r) if r.subTypeOf(CTNumber) => Some(CTString)
      case (l, CTString) if l.subTypeOf(CTNumber) => Some(CTString)
      case (CTString, CTString) => Some(CTString)
      // temporal arithmetic
      case (CTDuration, CTDuration) => Some(CTDuration)
      case (CTLocalDateTime, CTDuration) => Some(CTLocalDateTime)
      case (CTDuration, CTLocalDateTime) => Some(CTLocalDateTime)
      case (CTDate, CTDuration) => Some(CTDate)
      case (CTDuration, CTDate) => Some(CTDate)
      // numeric addition
      case (CTInteger, CTInteger) => Some(CTInteger)
      case (CTFloat, CTInteger) => Some(CTFloat)
      case (CTInteger, CTFloat) => Some(CTFloat)
      case (CTFloat, CTFloat) => Some(CTFloat)
      case (CTNumber, r) if r.subTypeOf(CTNumber) => Some(CTNumber)
      case (l, CTNumber) if l.subTypeOf(CTNumber) => Some(CTNumber)
      case _ => None
    }
    materialResult.map(_.asNullableAs(lhs.join(rhs)))
  }

}

/**
  * Converts openCypher front-end AST expressions (`ast.Expression`) into okapi IR expressions (`Expr`), computing the
  * Cypher type of every converted expression.
  *
  * Types are derived from the builder context: types already known for expressions (`context.knownTypes`), the
  * supplied query parameters (`context.parameters`) and the schema of the working graph.
  *
  * @param context IR builder context providing known expression types, parameters and the working graph
  */
final class ExpressionConverter(context: IRBuilderContext) {

  /** Schema of the working graph, used to type property accesses. */
  private def schema = context.workingGraph.schema

  /** Cypher type of a supplied query parameter; throws `MissingParameter` if the parameter was not supplied. */
  private def parameterType(p: ast.Parameter): CypherType = {
    context.parameters.get(p.name) match {
      case None => throw MissingParameter(p.name)
      case Some(param) => param.cypherType
    }
  }

  /**
    * Converts a front-end expression into its typed IR counterpart.
    *
    * @param e the front-end expression to convert
    * @return the typed IR expression
    * @throws UnTypedExpr if a variable has no known type
    * @throws MissingParameter if a referenced parameter was not supplied
    * @throws NoSuitableSignatureForExpr if no operator/function signature accepts the argument types
    * @throws InvalidContainerAccess if a property or index is accessed on a value that does not support it
    * @throws NotImplementedException if the expression or function is not supported yet
    */
  def convert(e: ast.Expression): Expr = e match {
    case ast.Variable(name) => Var(name)(context.knownTypes.getOrElse(e, throw UnTypedExpr(e)))
    case p @ ast.Parameter(name, _) => Param(name)(parameterType(p))

    // Literals
    case astExpr: ast.IntegerLiteral => IntegerLit(astExpr.value)
    case ast.StringLiteral(value) => StringLit(value)
    case _: ast.True => TrueLit
    case _: ast.False => FalseLit
    case ast.ListLiteral(exprs) =>
      val elements = exprs.map(convert).toList
      val elementType = elements.foldLeft(CTVoid: CypherType) { case (agg, nextExpr) => agg.join(nextExpr.cypherType) }
      ListLit(elements)(CTList(elementType))

    case ast.Property(m, ast.PropertyKeyName(name)) =>
      val mapLike = convert(m)
      Property(mapLike, PropertyKey(name))(propertyType(e, mapLike, name))

    // Predicates
    case ast.Ands(expressions) => Ands(expressions.map(convert))
    case ast.Ors(expressions) => Ors(expressions.map(convert))
    case ast.HasLabels(node, labels) =>
      // convert the node expression once and share it between all label checks
      val convertedNode = convert(node)
      val exprs = labels.map { l: ast.LabelName =>
        HasLabel(convertedNode, Label(l.name))
      }
      if (exprs.size == 1) exprs.head else Ands(exprs.toSet)
    case ast.Not(expr) => Not(convert(expr))
    // `type(r) = 'SOME_TYPE'` is converted into a relationship type check
    case ast.Equals(func: ast.FunctionInvocation, s: ast.StringLiteral) if func.function == functions.Type =>
      HasType(convert(func.args.head), RelType(s.value))
    case ast.Equals(lhs, rhs) => Equals(convert(lhs), convert(rhs))
    case ast.LessThan(lhs, rhs) => LessThan(convert(lhs), convert(rhs))
    case ast.LessThanOrEqual(lhs, rhs) => LessThanOrEqual(convert(lhs), convert(rhs))
    case ast.GreaterThan(lhs, rhs) => GreaterThan(convert(lhs), convert(rhs))
    case ast.GreaterThanOrEqual(lhs, rhs) => GreaterThanOrEqual(convert(lhs), convert(rhs))
    // if the list only contains a single element, convert to simple equality to avoid list construction
    case ast.In(lhs, ast.ListLiteral(elems)) if elems.size == 1 => Equals(convert(lhs), convert(elems.head))
    case ast.In(lhs, rhs) => In(convert(lhs), convert(rhs))
    case ast.IsNull(expr) => IsNull(convert(expr))
    case ast.IsNotNull(expr) => IsNotNull(convert(expr))
    case ast.StartsWith(lhs, rhs) => StartsWith(convert(lhs), convert(rhs))
    case ast.EndsWith(lhs, rhs) => EndsWith(convert(lhs), convert(rhs))
    case ast.Contains(lhs, rhs) => Contains(convert(lhs), convert(rhs))

    // Arithmetics
    case ast.Add(lhs, rhs) =>
      val convertedLhs = convert(lhs)
      val convertedRhs = convert(rhs)
      val addType = AddType(convertedLhs.cypherType, convertedRhs.cypherType).getOrElse(
        throw NoSuitableSignatureForExpr(e, Seq(convertedLhs.cypherType, convertedRhs.cypherType))
      )
      Add(convertedLhs, convertedRhs)(addType)
    case s @ ast.Subtract(lhs, rhs) =>
      val convertedLhs = convert(lhs)
      val convertedRhs = convert(rhs)
      val exprType = s.returnTypeFor(convertedLhs.cypherType, convertedRhs.cypherType)
      Subtract(convertedLhs, convertedRhs)(exprType)
    case m @ ast.Multiply(lhs, rhs) =>
      val convertedLhs = convert(lhs)
      val convertedRhs = convert(rhs)
      val exprType = m.returnTypeFor(convertedLhs.cypherType, convertedRhs.cypherType)
      Multiply(convertedLhs, convertedRhs)(exprType)
    case d @ ast.Divide(lhs, rhs) =>
      val convertedLhs = convert(lhs)
      val convertedRhs = convert(rhs)
      val exprType = d.returnTypeFor(convertedLhs.cypherType, convertedRhs.cypherType)
      Divide(convertedLhs, convertedRhs)(exprType)

    case funcInv: ast.FunctionInvocation => convertFunctionInvocation(funcInv)

    case _: ast.CountStar => CountStar

    // Exists (rewritten Pattern Expressions)
    case org.opencypher.okapi.ir.impl.parse.rewriter.ExistsPattern(subquery, trueVar) =>
      val innerModel = IRBuilder(subquery)(context) match {
        case sq: SingleQuery => sq
        case _ => throw new IllegalArgumentException("ExistsPattern only accepts SingleQuery")
      }
      ExistsPatternExpr(
        Var(trueVar.name)(CTBoolean),
        innerModel
      )

    // Case When .. Then .. [Else ..] End
    case ast.CaseExpression(None, alternatives, default) =>
      val convertedAlternatives = alternatives.toList.map { case (left, right) => convert(left) -> convert(right) }
      val maybeConvertedDefault: Option[Expr] = default.map(expr => convert(expr))
      val possibleTypes = convertedAlternatives.map { case (_, thenExpr) => thenExpr.cypherType }
      // without an ELSE branch the expression evaluates to null when no alternative matches
      val defaultCaseType = maybeConvertedDefault.map(_.cypherType).getOrElse(CTNull)
      val returnType = possibleTypes.foldLeft(defaultCaseType)(_ join _)
      CaseExpr(convertedAlternatives, maybeConvertedDefault)(returnType)

    case ast.MapExpression(items) =>
      val convertedMap = items.map { case (key, value) => key.name -> convert(value) }.toMap
      val mapType = CTMap(convertedMap.map { case (key, value) => key -> value.cypherType })
      MapExpression(convertedMap)(mapType)

    // Expression
    case ast.ListSlice(list, Some(from), Some(to)) => ListSliceFromTo(convert(list), convert(from), convert(to))
    case ast.ListSlice(list, None, Some(to)) => ListSliceTo(convert(list), convert(to))
    case ast.ListSlice(list, Some(from), None) => ListSliceFrom(convert(list), convert(from))

    case ast.ContainerIndex(container, index) =>
      val convertedContainer = convert(container)
      val elementType = containerElementType(e, convertedContainer, index)
      ContainerIndex(convertedContainer, convert(index))(elementType)

    case ast.Null() => NullLit

    case RegexMatch(lhs, rhs) => expr.RegexMatch(convert(lhs), convert(rhs))

    case _ =>
      throw NotImplementedException(s"Not yet able to convert expression: $e")
  }

  /**
    * Infers the type of property `name` read from `mapLike`.
    *
    * Node and relationship properties are typed via the working graph's schema; map entries via the map type; any
    * property of a temporal value is typed as an integer. `access` is the original expression, used for error reporting.
    */
  private def propertyType(access: ast.Expression, mapLike: Expr, name: String): CypherType =
    mapLike.cypherType.material match {
      case CTVoid => CTNull
      // This means that the node can have any possible label combination, as the user did not specify any constraints
      case n: CTNode if n.labels.isEmpty =>
        schema.allCombinations
          .map(l => schema.nodePropertyKeyType(l, name).getOrElse(CTNull))
          .foldLeft(CTVoid: CypherType)(_ join _)
      // User specified label constraints - we can use those for type inference
      case CTNode(labels, _) =>
        schema.nodePropertyKeyType(labels, name).getOrElse(CTNull)
      case CTRelationship(types, _) =>
        schema.relationshipPropertyKeyType(types, name).getOrElse(CTNull)
      case CTMap(inner) =>
        inner.getOrElse(name, CTVoid)
      case _: TemporalValueCypherType =>
        CTInteger.asNullableAs(mapLike.cypherType)
      case _ => throw InvalidContainerAccess(access)
    }

  /**
    * Infers the element type of `container[index]`.
    *
    * For maps, a literal or parameter key selects the type of that entry; any other key yields the nullable join of
    * all entry types. A parameter key that was not supplied results in `MissingParameter`.
    */
  private def containerElementType(access: ast.Expression, container: Expr, index: ast.Expression): CypherType =
    container.cypherType.material match {
      case CTList(eltTyp) => eltTyp
      case CTMap(innerTypes) =>
        index match {
          case ast.Parameter(name, _) =>
            val key = context.parameters.get(name).getOrElse(throw MissingParameter(name)).cast[String]
            innerTypes.getOrElse(key, CTVoid)
          case ast.StringLiteral(key) => innerTypes.getOrElse(key, CTVoid)
          case _ => innerTypes.values.foldLeft(CTVoid: CypherType)(_ join _).nullable
        }
      case _ => throw InvalidContainerAccess(access)
    }

  /** Converts a function invocation, dispatching on the resolved function or, if unresolved, on its name. */
  private def convertFunctionInvocation(funcInv: ast.FunctionInvocation): Expr = {
    val convertedArgs = funcInv.args.map(convert).toList
    // a `def`, so the signature-based lookup only runs for the functions that actually need it
    def returnType: CypherType = funcInv.returnTypeFor(convertedArgs.map(_.cypherType): _*)

    val distinct = funcInv.distinct

    def arg0 = convertedArgs(0)

    def arg1 = convertedArgs(1)

    def arg2 = convertedArgs(2)

    funcInv.function match {
      case functions.Id => Id(arg0)
      case functions.Labels => Labels(arg0)
      case functions.Type => Type(arg0)
      case functions.Avg => Avg(arg0)
      case functions.Max => Max(arg0)(returnType)
      case functions.Min => Min(arg0)(returnType)
      case functions.Sum => Sum(arg0)(returnType)
      case functions.Count => Count(arg0, distinct)
      case functions.Collect => Collect(arg0, distinct)
      case functions.Exists => Exists(arg0)
      case functions.Size => Size(arg0)
      case functions.Keys => Keys(arg0)(returnType)
      case functions.StartNode => StartNodeFunction(arg0)(returnType)
      case functions.EndNode => EndNodeFunction(arg0)(returnType)
      case functions.ToFloat => ToFloat(arg0)
      case functions.ToInteger => ToInteger(arg0)
      case functions.ToString => ToString(arg0)
      case functions.ToBoolean => ToBoolean(arg0)
      case functions.Coalesce =>
        // Special optimisation for coalesce using short-circuit logic
        convertedArgs.map(_.cypherType).indexWhere(!_.isNullable) match {
          case 0 =>
            // first argument is non-nullable; just use it directly without coalesce
            convertedArgs.head
          case -1 =>
            // nothing was non-nullable; keep all args
            val outType = convertedArgs.map(_.cypherType).reduceLeft(_ join _)
            Coalesce(convertedArgs)(outType)
          case other =>
            // keep only the args up until the first non-nullable (inclusive)
            val relevantArgs = convertedArgs.slice(0, other + 1)
            val outType = relevantArgs.map(_.cypherType).reduceLeft(_ join _)
            Coalesce(relevantArgs)(outType.material)
        }
      case functions.Range => Range(arg0, arg1, convertedArgs.lift(2))
      case functions.Substring => Substring(arg0, arg1, convertedArgs.lift(2))
      case functions.Left => Substring(arg0, IntegerLit(0), convertedArgs.lift(1))
      case functions.Right => Substring(arg0, Subtract(Multiply(IntegerLit(-1), arg1)(CTInteger), IntegerLit(1))(CTInteger), None)
      case functions.Replace => Replace(arg0, arg1, arg2)
      case functions.Trim => Trim(arg0)
      case functions.LTrim => LTrim(arg0)
      case functions.RTrim => RTrim(arg0)
      case functions.ToUpper => ToUpper(arg0)
      case functions.ToLower => ToLower(arg0)
      case functions.Properties =>
        val outType = arg0.cypherType.material match {
          case CTVoid => CTNull
          case CTNode(labels, _) =>
            CTMap(schema.nodePropertyKeysForCombinations(schema.combinationsFor(labels)))
          case CTRelationship(types, _) =>
            CTMap(schema.relationshipPropertyKeysForTypes(types))
          case m: CTMap => m
          case _ => throw InvalidArgument(funcInv, funcInv.args(0))
        }
        Properties(arg0)(outType)

      // Logarithmic functions
      case functions.Sqrt => Sqrt(arg0)
      case functions.Log => Log(arg0)
      case functions.Log10 => Log10(arg0)
      case functions.Exp => Exp(arg0)
      case functions.E => E
      case functions.Pi => Pi

      // Numeric functions
      case functions.Abs => Abs(arg0)(returnType)
      case functions.Ceil => Ceil(arg0)
      case functions.Floor => Floor(arg0)
      case functions.Rand => Rand
      case functions.Round => Round(arg0)
      case functions.Sign => Sign(arg0)

      // Trigonometric functions
      case functions.Acos => Acos(arg0)
      case functions.Asin => Asin(arg0)
      case functions.Atan => Atan(arg0)
      case functions.Atan2 => Atan2(arg0, arg1)
      case functions.Cos => Cos(arg0)
      case functions.Cot => Cot(arg0)
      case functions.Degrees => Degrees(arg0)
      case functions.Haversin => Haversin(arg0)
      case functions.Radians => Radians(arg0)
      case functions.Sin => Sin(arg0)
      case functions.Tan => Tan(arg0)

      // Match by name
      case functions.UnresolvedFunction => funcInv.name match {
        // Time functions
        case f.Timestamp.name => Timestamp
        case f.LocalDateTime.name => LocalDateTime(convertedArgs.headOption)
        case f.Date.name => Date(convertedArgs.headOption)
        case f.Duration.name => Duration(arg0)
        case name => throw NotImplementedException(s"Support for converting function '$name' is not yet implemented")
      }

      case a: functions.Function =>
        throw NotImplementedException(s"Support for converting function '${a.name}' is not yet implemented")
    }
  }

}

/**
  * Return type inference for operators and functions, based on the type signatures declared by the openCypher
  * front-end (plus `FunctionExtensions` for functions unknown to the front-end).
  */
object OperatorTyping {

  /**
    * Computes the return type of a call with the given argument types.
    *
    * The declared signatures are expanded via `expandWithNulls` and `expandWithSubstitutions(CTFloat, CTInteger)`.
    * Every expanded signature that has exactly as many inputs as there are arguments, and whose inputs could each be of
    * the corresponding argument type, contributes its output type. The result is the join of all contributed types.
    *
    * @param signatures the declared type signatures
    * @param args       the argument types of the call
    * @return the joined return type, or `None` if no signature accepts the arguments
    */
  def returnTypeFor(signatures: Seq[ast.TypeSignature], args: Seq[CypherType]): Option[CypherType] = {
    val expandedSignatures = SignatureConverter.from(signatures)
      .expandWithNulls
      .expandWithSubstitutions(CTFloat, CTInteger)
      .signatures

    val possibleReturnTypes = expandedSignatures.filter { sig =>
      // `zip` truncates to the shorter side, so the arity has to be checked explicitly; otherwise a signature
      // with a different number of inputs could be counted as a match
      sig.input.size == args.size && sig.input.zip(args).forall {
        case (sigType, argType) =>
          argType.couldBeSameTypeAs(sigType)
      }
    }.map(_.output)

    possibleReturnTypes.reduceLeftOption(_ join _)
  }

  /** Adds signature-based return type inference to operator expressions such as `-`, `*` and `/`. */
  implicit class RichOperatorExpression(val o: ast.Expression with OperatorExpression) {
    /** Return type for the given argument types; throws `NoSuitableSignatureForExpr` if no signature matches. */
    def returnTypeFor(args: CypherType*): CypherType = {
      OperatorTyping.returnTypeFor(o.signatures, args).getOrElse(throw NoSuitableSignatureForExpr(o, args))
    }
  }

  /** Adds signature-based return type inference to function invocations. */
  implicit class RichTypeSignatures(val f: ast.FunctionInvocation) {
    /**
      * Return type for the given argument types. Function extensions registered under the function's name take
      * precedence over the front-end function. Throws `NoSuitableSignatureForExpr` if the function declares no type
      * signatures or none of them matches.
      */
    def returnTypeFor(args: CypherType*): CypherType = {

      val signatures = FunctionExtensions.getOrElse(f.function.name, f.function) match {
        case t: TypeSignatures => t.signatures
        case _ => throw NoSuitableSignatureForExpr(f, args)
      }

      OperatorTyping.returnTypeFor(signatures, args).getOrElse(throw NoSuitableSignatureForExpr(f, args))
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy