All downloads are free. Search and download functionality uses the official Maven repository.

caliban.ws.Protocol.scala Maven / Gradle / Ivy

The newest version!
package caliban.ws

import caliban.ResponseValue.{ ObjectValue, StreamValue }
import caliban.Value.StringValue
import caliban._
import zio.stm.{ STM, TMap }
import zio.stream.{ UStream, ZStream }
import zio.{ Duration, Promise, Queue, Random, Ref, Schedule, UIO, URIO, ZIO }

/**
 * A WebSocket sub-protocol over which a GraphQL interpreter can be served.
 *
 * The trait is sealed; the implementations are `Protocol.GraphQLWS` and `Protocol.Legacy`.
 */
sealed trait Protocol {
  /** Name of the sub-protocol; `Protocol.fromName` compares against it ignoring case. */
  def name: String

  /**
   * Creates the pipe that serves a single WebSocket connection using `interpreter`.
   *
   * @param interpreter    interpreter that executes the operations sent by the client
   * @param keepAliveTime  optional interval for server-initiated keep-alive messages
   * @param webSocketHooks user callbacks invoked by the protocol implementation
   * @return an effect requiring `R` that produces the connection's pipe
   */
  def make[R, E](
    interpreter: GraphQLInterpreter[R, E],
    keepAliveTime: Option[Duration],
    webSocketHooks: WebSocketHooks[R, E]
  ): URIO[R, CalibanPipe]

}

object Protocol {

  /**
   * Picks the protocol implementation for a requested sub-protocol name.
   *
   * A name equal to `GraphQLWS.name` (ignoring case) selects `GraphQLWS`; every other name
   * falls back to `Legacy`.
   */
  def fromName(name: String): Protocol =
    name match {
      case requested if requested.equalsIgnoreCase(GraphQLWS.name) => GraphQLWS
      case _                                                        => Legacy
    }

  /**
   * Server side of the `graphql-transport-ws` WebSocket sub-protocol.
   *
   * Close codes sent by this implementation: 4400 (unsupported message type), 4401 (`subscribe` before the
   * connection was acknowledged, or `afterInit` failed), 4403 (`beforeInit` rejected the connection),
   * 4409 (a subscription with the same id is already tracked) and 4429 (repeated `connection_init`).
   */
  object GraphQLWS extends Protocol {

    /** Values of the `type` field of the messages exchanged with the client. */
    object Ops {
      final val Next           = "next"
      final val Error          = "error"
      final val Complete       = "complete"
      final val Pong           = "pong"
      final val Ping           = "ping"
      final val Subscribe      = "subscribe"
      final val ConnectionInit = "connection_init"
      final val ConnectionAck  = "connection_ack"
    }

    final val name = "graphql-transport-ws"

    // Results are sent as `next`, completion as `complete`, and failures as `error` whose payload is a
    // one-element list holding either the CalibanError's response value or the error's `toString`.
    private val handler: ResponseHandler = new ResponseHandler {
      override def toResponse[E](id: String, r: GraphQLResponse[E]): GraphQLWSOutput =
        GraphQLWSOutput(Ops.Next, Some(id), Some(r.toResponseValue))

      override def complete(id: String): GraphQLWSOutput =
        GraphQLWSOutput(Ops.Complete, Some(id), None)

      override def error[E](id: Option[String], e: E): GraphQLWSOutput =
        GraphQLWSOutput(
          Ops.Error,
          id,
          Some(ResponseValue.ListValue(List(e match {
            case e: CalibanError => e.toResponseValue
            case e               => StringValue(e.toString)
          })))
        )
    }

    /**
     * Builds the pipe serving one WebSocket connection.
     *
     * Incoming messages are handled one at a time by a fiber forked into the output stream's scope.
     * Outgoing messages, and requests to close the socket (`Left`), are offered to an unbounded queue
     * that backs the returned stream.
     *
     * @param interpreter    executes the payload of each `subscribe` message
     * @param keepAliveTime  when defined, a `ping` is sent at this interval once the connection is acknowledged
     * @param webSocketHooks hooks run on init/ack/ping/pong and applied (`onMessage`) to every response stream
     */
    override def make[R, E](
      interpreter: GraphQLInterpreter[R, E],
      keepAliveTime: Option[Duration],
      webSocketHooks: WebSocketHooks[R, E]
    ): URIO[R, CalibanPipe] =
      for {
        env           <- ZIO.environment[R]
        subscriptions <- SubscriptionManager.make
        ack           <- Ref.make(false)
        output        <- Queue.unbounded[Either[GraphQLWSClose, GraphQLWSOutput]]
        pipe          <- ZIO.succeed[CalibanPipe] { input =>
                           ZStream.scoped(
                             input.runForeach {
                               case GraphQLWSInput(Ops.ConnectionInit, id, payload)  =>
                                 // `false` when `beforeInit` rejects the payload: the socket is closed with 4403
                                 // and nothing else (ack, keep-alive, afterInit) may follow.
                                 val authorized: URIO[R, Boolean] =
                                   (webSocketHooks.beforeInit, payload) match {
                                     case (Some(beforeInit), Some(payload)) =>
                                       beforeInit(payload)
                                         .as(true)
                                         .orElse(output.offer(Left(GraphQLWSClose(4403, "Forbidden"))).as(false))
                                     case _                                 =>
                                       ZIO.succeed(true)
                                   }
                                 val ackPayload = webSocketHooks.onAck.fold[URIO[R, Option[ResponseValue]]](ZIO.none)(_.option)
                                 val response   =
                                   ack.set(true) *> ackPayload.flatMap(payload => output.offer(Right(connectionAck(payload))))
                                 val ka         = ping(keepAliveTime).runForeach(output.offer).fork
                                 val after      = ZIO.whenCase(webSocketHooks.afterInit) { case Some(afterInit) =>
                                   afterInit
                                     .catchAllCause(cause =>
                                       ZIO.foreachDiscard(cause.failureOption)(e =>
                                         generateId(id).flatMap(uuid => output.offer(Right(handler.error(uuid, e))))
                                       ) *> output.offer(Left(GraphQLWSClose(4401, "Unauthorized")))
                                     )
                                     .fork
                                 }
                                 val init       = authorized.flatMap(ok => ZIO.when(ok)(response *> ka *> after))

                                 // Only one `connection_init` is allowed per socket; a repeat would otherwise
                                 // re-run the hooks and fork an additional keep-alive fiber.
                                 ZIO.ifZIO(ack.get)(
                                   output.offer(Left(GraphQLWSClose(4429, "Too many initialisation requests"))).unit,
                                   init.unit
                                 )
                               case GraphQLWSInput(Ops.Pong, id, payload)            =>
                                 ZIO.whenCase(webSocketHooks.onPong -> payload) { case (Some(onPong), Some(payload)) =>
                                   onPong(payload).catchAll(e =>
                                     generateId(id).flatMap(uuid => output.offer(Right(handler.error(uuid, e))))
                                   )
                                 }
                               case GraphQLWSInput(Ops.Ping, id, payload)            =>
                                 def sendPong(p: Option[ResponseValue]) = output.offer(Right(GraphQLWSOutput(Ops.Pong, id, p)))

                                 webSocketHooks.onPing match {
                                   case Some(onPing) =>
                                     onPing(payload)
                                       .flatMap(sendPong)
                                       .catchAll(e =>
                                         generateId(id).flatMap(uuid => output.offer(Right(handler.error(uuid, e))))
                                       )
                                   case _            => sendPong(None)
                                 }
                               case GraphQLWSInput(Ops.Subscribe, Some(id), payload) =>
                                 val request = payload.collect { case InputValue.ObjectValue(fields) =>
                                   val query         = fields.get("query").collect { case StringValue(v) => v }
                                   val operationName = fields.get("operationName").collect { case StringValue(v) => v }
                                   val variables     = fields.get("variables").collect { case InputValue.ObjectValue(v) => v }
                                   val extensions    = fields.get("extensions").collect { case InputValue.ObjectValue(v) => v }
                                   GraphQLRequest(query, operationName, variables, extensions)
                                 }

                                 val continue = request match {
                                   case Some(req) =>
                                     val stream = handler.generateGraphQLResponse(req, id, interpreter, subscriptions)

                                     ZIO.ifZIO(subscriptions.isTracking(id))(
                                       output.offer(Left(GraphQLWSClose(4409, s"Subscriber for $id already exists"))).unit,
                                       webSocketHooks.onMessage
                                         .fold(stream)(stream.via(_))
                                         .map(Right(_))
                                         .runForeachChunk(output.offerAll)
                                         .catchAll(e => output.offer(Right(handler.error(Some(id), e))))
                                         .fork
                                         .interruptible
                                         .unit
                                     )

                                   case None =>
                                     generateId(None).flatMap(uuid => output.offer(Right(connectionError(uuid))))
                                 }

                                 ZIO.ifZIO(ack.get)(continue, output.offer(Left(GraphQLWSClose(4401, "Unauthorized"))))
                               case GraphQLWSInput(Ops.Complete, Some(id), _)        =>
                                 subscriptions.untrack(id)
                               case GraphQLWSInput(unsupported, _, _)                =>
                                 output.offer(Left(GraphQLWSClose(4400, s"Unsupported operation: $unsupported")))
                             }.interruptible
                               // `generateId` already yields an Option[String]; pass it through unchanged
                               // (wrapping it again would send the text "Some(<uuid>)" as the id).
                               .orElse(
                                 generateId(None).flatMap(uuid => output.offer(Right(connectionError(uuid))))
                               )
                               .ensuring(subscriptions.untrackAll)
                               .provideEnvironment(env)
                               .forkScoped
                           ) *> ZStream.fromQueueWithShutdown(output)
                         }
      } yield pipe

    /** An `error` message with no payload, optionally bound to an id. */
    private def connectionError(id: Option[String]): GraphQLWSOutput           = GraphQLWSOutput(Ops.Error, id, None)

    /** The `connection_ack` message, carrying the optional payload produced by the `onAck` hook. */
    private def connectionAck(payload: Option[ResponseValue]): GraphQLWSOutput =
      GraphQLWSOutput(Ops.ConnectionAck, None, payload)

    /** Returns `id` when defined, otherwise a fresh random UUID rendered as a string. */
    private def generateId(id: Option[String]): ZIO[Any, Nothing, Option[String]] =
      id match {
        case Some(_) => ZIO.succeed(id)
        case None    => Random.nextUUID.map(uuid => Some(uuid.toString))
      }

    /** Endless stream of `ping` messages spaced by `keepAlive`; empty when no interval is configured. */
    private def ping(keepAlive: Option[Duration]): UStream[Either[Nothing, GraphQLWSOutput]] =
      keepAlive match {
        case None           => ZStream.empty
        case Some(duration) =>
          ZStream
            .repeatWithSchedule(Right(GraphQLWSOutput(Ops.Ping, None, None)), Schedule.spaced(duration))
      }

  }

  /**
   * Server side of the legacy `graphql-ws` WebSocket sub-protocol.
   *
   * Operations (`start`) are refused with close code 4401 until the connection has been acknowledged.
   */
  object Legacy extends Protocol {

    /** Values of the `type` field of the messages exchanged with the client. */
    object Ops {
      final val ConnectionInit      = "connection_init"
      final val ConnectionAck       = "connection_ack"
      final val ConnectionKeepAlive = "ka"
      final val ConnectionTerminate = "connection_terminate"
      final val Start               = "start"
      final val Stop                = "stop"
      final val Error               = "error"
      final val ConnectionError     = "connection_error"
      final val Complete            = "complete"
      final val Data                = "data"
    }

    final val name = "graphql-ws"

    // Results are sent as `data`, completion as `complete`, and failures as `error` whose payload is a
    // one-element list holding either the CalibanError's response value or the error's `toString`.
    private val handler: ResponseHandler = new ResponseHandler {
      override def toResponse[E](id: String, r: GraphQLResponse[E]): GraphQLWSOutput =
        GraphQLWSOutput(Ops.Data, Some(id), Some(r.toResponseValue))

      override def complete(id: String): GraphQLWSOutput =
        GraphQLWSOutput(Ops.Complete, Some(id), None)

      override def error[E](id: Option[String], e: E): GraphQLWSOutput =
        GraphQLWSOutput(
          Ops.Error,
          id,
          Some(ResponseValue.ListValue(List(e match {
            case e: CalibanError => e.toResponseValue
            case e               => StringValue(e.toString)
          })))
        )
    }

    /**
     * Builds the pipe serving one WebSocket connection.
     *
     * Incoming messages are handled one at a time by a daemon fiber that is interrupted when the output
     * stream is released. Outgoing messages, and requests to close the socket (`Left`), are offered to an
     * unbounded queue that backs the returned stream.
     *
     * @param interpreter    executes the payload of each `start` message
     * @param keepAliveTime  when defined, `ka` messages are sent at this interval once the connection is acknowledged
     * @param webSocketHooks hooks run on init/ack and applied (`onMessage`) to every response stream
     */
    override def make[R, E](
      interpreter: GraphQLInterpreter[R, E],
      keepAliveTime: Option[Duration],
      webSocketHooks: WebSocketHooks[R, E]
    ): URIO[R, CalibanPipe] =
      for {
        env           <- ZIO.environment[R]
        ack           <- Ref.make(false)
        subscriptions <- SubscriptionManager.make
        output        <- Queue.unbounded[Either[GraphQLWSClose, GraphQLWSOutput]]
        pipe          <- ZIO.succeed[CalibanPipe] { input =>
                           ZStream
                             .acquireReleaseWith(
                               input.runForeach {
                                 case GraphQLWSInput(Ops.ConnectionInit, id, payload) =>
                                   // `false` when `beforeInit` rejects the payload: the failure is reported to the
                                   // client and the connection stays unacknowledged, so later `start`s are refused.
                                   val authorized: URIO[R, Boolean] =
                                     (webSocketHooks.beforeInit, payload) match {
                                       case (Some(beforeInit), Some(payload)) =>
                                         beforeInit(payload)
                                           .as(true)
                                           .catchAll(e => output.offer(Right(handler.error(id, e))).as(false))
                                       case _                                 =>
                                         ZIO.succeed(true)
                                     }
                                   val ackPayload = webSocketHooks.onAck.fold[URIO[R, Option[ResponseValue]]](ZIO.none)(_.option)

                                   val response =
                                     ack.set(true) *> ackPayload.flatMap(payload => output.offer(Right(connectionAck(payload))))
                                   val ka       = keepAlive(keepAliveTime).runForeach(o => output.offer(Right(o))).fork
                                   val after    = ZIO.whenCase(webSocketHooks.afterInit) { case Some(afterInit) =>
                                     afterInit
                                       .catchAllCause(cause =>
                                         ZIO.foreachDiscard(cause.failureOption)(e =>
                                           output.offer(Right(handler.error(id, e)))
                                         ) *> output.offer(Left(GraphQLWSClose(4401, "Unauthorized")))
                                       )
                                       .fork
                                   }

                                   authorized.flatMap(ok => ZIO.when(ok)(response *> ka *> after))
                                 case GraphQLWSInput(Ops.Start, id, payload)          =>
                                   val request  = payload.collect { case InputValue.ObjectValue(fields) =>
                                     val query         = fields.get("query").collect { case StringValue(v) => v }
                                     val operationName = fields.get("operationName").collect { case StringValue(v) => v }
                                     val variables     = fields.get("variables").collect { case InputValue.ObjectValue(v) => v }
                                     val extensions    = fields.get("extensions").collect { case InputValue.ObjectValue(v) => v }
                                     GraphQLRequest(query, operationName, variables, extensions)
                                   }
                                   val continue = request match {
                                     case Some(req) =>
                                       // Operations without an id are tracked under the empty string.
                                       val stream =
                                         handler.generateGraphQLResponse(req, id.getOrElse(""), interpreter, subscriptions)
                                       webSocketHooks.onMessage
                                         .fold(stream)(stream.via(_))
                                         .runForeachChunk(o => output.offerAll(o.map(Right(_))))
                                         .catchAll(e => output.offer(Right(handler.error(id, e))))
                                         .fork
                                         .interruptible
                                         .unit

                                     case None => output.offer(Right(connectionError))
                                   }

                                   ZIO.ifZIO(ack.get)(continue, output.offer(Left(GraphQLWSClose(4401, "Unauthorized"))))
                                 case GraphQLWSInput(Ops.Stop, Some(id), _)           =>
                                   subscriptions.untrack(id)
                                 case GraphQLWSInput(Ops.ConnectionTerminate, _, _)   =>
                                   ZIO.interrupt
                                 case _                                               =>
                                   ZIO.unit
                               }.interruptible
                                 .orElse(output.offer(Right(connectionError)))
                                 .ensuring(subscriptions.untrackAll)
                                 .provideEnvironment(env)
                                 .forkDaemon
                             )(_.interrupt) *> ZStream.fromQueueWithShutdown(output)
                         }
      } yield pipe

    /** Endless stream of `ka` messages spaced by `keepAlive`; empty when no interval is configured. */
    private def keepAlive(keepAlive: Option[Duration]): UStream[GraphQLWSOutput] =
      keepAlive match {
        case None           => ZStream.empty
        case Some(duration) =>
          ZStream
            .repeatWithSchedule(GraphQLWSOutput(Ops.ConnectionKeepAlive, None, None), Schedule.spaced(duration))
      }

    /** A payload-less `connection_error` message. */
    private val connectionError: GraphQLWSOutput                               = GraphQLWSOutput(Ops.ConnectionError, None, None)

    /** The `connection_ack` message, carrying the optional payload produced by the `onAck` hook. */
    private def connectionAck(payload: Option[ResponseValue]): GraphQLWSOutput =
      GraphQLWSOutput(Ops.ConnectionAck, None, payload)
  }

  /**
   * Protocol-specific encoding of server messages, plus the shared logic that turns one GraphQL request
   * into the stream of messages sent back for a single operation id.
   */
  private trait ResponseHandler {
    self =>

    /** Wraps one value of the root field `fieldName` into a response and encodes it for operation `id`. */
    def toResponse[E](id: String, fieldName: String, r: ResponseValue, errors: List[E]): GraphQLWSOutput = {
      val data = ObjectValue(List(fieldName -> r))
      toResponse(id, GraphQLResponse(data, errors))
    }

    /** Encodes a GraphQL response for operation `id`. */
    def toResponse[E](id: String, r: GraphQLResponse[E]): GraphQLWSOutput

    /** Message signalling that operation `id` has finished. */
    def complete(id: String): GraphQLWSOutput

    /** Encodes `e` as an error message, optionally bound to an operation id. */
    def error[E](id: Option[String], e: E): GraphQLWSOutput

    /** Single-element stream holding the completion message for `id`. */
    def toStreamComplete(id: String): UStream[GraphQLWSOutput] =
      ZStream.succeed(self.complete(id))

    /** Single-element stream holding the error message for `e`. */
    def toStreamError[E](id: Option[String], e: E): UStream[GraphQLWSOutput] =
      ZStream.succeed(self.error(id, e))

    /**
     * Executes `payload` and emits the resulting messages for operation `id`, followed by a completion message.
     *
     * When the result is a single root field holding a stream, that stream is registered with
     * `subscriptions` and every element is emitted until the registration's promise is completed.
     * Any other result is emitted as one response. A failure anywhere is turned into a single error
     * message, and in that case no completion message is sent.
     */
    final def generateGraphQLResponse[R, E](
      payload: GraphQLRequest,
      id: String,
      interpreter: GraphQLInterpreter[R, E],
      subscriptions: SubscriptionManager
    ): ZStream[R, E, GraphQLWSOutput] = {
      val executed = ZStream.fromZIO(interpreter.executeRequest(payload))

      val messages = executed.flatMap { result =>
        result.data match {
          case ObjectValue((fieldName, StreamValue(events)) :: Nil) =>
            subscriptions.track(id).flatMap { stop =>
              events
                .map(event => self.toResponse(id, fieldName, event, result.errors))
                .interruptWhen(stop)
            }
          case single                                               =>
            ZStream.succeed(self.toResponse(id, GraphQLResponse(single, result.errors)))
        }
      }

      messages
        .concat(self.toStreamComplete(id))
        .catchAll(failure => self.toStreamError(Option(id), failure))
    }
  }

  /**
   * Per-connection registry of running subscription streams, keyed by operation id.
   *
   * Each registered id maps to a promise; the stream registered under that id is expected to stop once
   * the promise is completed (callers combine it with `interruptWhen`).
   */
  private class SubscriptionManager private (private val tracked: TMap[String, Promise[Any, Unit]]) {

    /**
     * Registers a subscription under `id` and emits the promise that signals it to stop.
     *
     * The registration is released when the stream built on top of this one terminates (normally, with an
     * error, or by interruption). A finished subscription therefore no longer counts as tracked, so its id
     * can be reused and the map does not grow for the lifetime of the connection. Registering an id that
     * is already tracked replaces the previous entry.
     */
    def track(id: String): UStream[Promise[Any, Unit]] =
      ZStream.acquireReleaseWith(Promise.make[Any, Unit].tap(tracked.put(id, _).commit))(releaseIfOwned(id, _))

    /** Forgets `id` and completes its promise, if it was tracked. */
    def untrack(id: String): UIO[Unit] =
      (tracked.get(id) <* tracked.delete(id)).commit.flatMap {
        case None    => ZIO.unit
        case Some(p) => p.succeed(()).unit
      }

    /** Untracks every currently tracked id. */
    def untrackAll: UIO[Unit] =
      tracked.keys.map(ids => ZIO.foreachDiscard(ids)(untrack)).commit.flatten

    /** Whether a subscription is currently registered under `id`. */
    def isTracking(id: String): UIO[Boolean] = tracked.contains(id).commit

    // Removes `id` only while it still maps to `p`. If the entry was already untracked, or was replaced by
    // a newer registration under the same id, it is left untouched.
    private def releaseIfOwned(id: String, p: Promise[Any, Unit]): UIO[Unit] =
      tracked
        .get(id)
        .flatMap {
          case Some(current) if current eq p => tracked.delete(id)
          case _                             => STM.unit
        }
        .commit
  }

  private object SubscriptionManager {

    /** Creates a manager with nothing tracked yet. */
    val make: UIO[SubscriptionManager] =
      TMap
        .make[String, Promise[Any, Unit]]()
        .commit
        .map(registry => new SubscriptionManager(registry))
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy