// All Downloads are FREE. Search and download functionalities are using the official Maven repository.
//
// commonMain.com.apollographql.execution.websocket.SubscriptionWebSocketHandler.kt Maven / Gradle / Ivy
//
// The newest version!
package com.apollographql.execution.websocket

import com.apollographql.apollo.annotations.ApolloInternal
import com.apollographql.apollo.api.Error
import com.apollographql.apollo.api.ExecutionContext
import com.apollographql.apollo.api.Optional
import com.apollographql.apollo.api.json.JsonWriter
import com.apollographql.apollo.api.json.jsonReader
import com.apollographql.apollo.api.json.readAny
import com.apollographql.apollo.api.json.writeAny
import com.apollographql.apollo.api.json.writeObject
import com.apollographql.execution.ExecutableSchema
import com.apollographql.execution.GraphQLRequest
import com.apollographql.execution.GraphQLResponse
import com.apollographql.execution.SubscriptionError
import com.apollographql.execution.SubscriptionResponse
import com.apollographql.execution.jsonWriter
import com.apollographql.execution.parseGraphQLRequest
import com.apollographql.execution.writeError
import kotlinx.atomicfu.locks.reentrantLock
import kotlinx.atomicfu.locks.withLock
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.launch
import okio.Buffer
import okio.Sink


/**
 * A [WebSocketHandler] that implements https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md
 *
 * Every incoming message is parsed with [parseApolloWebsocketClientMessage] and dispatched on its type.
 * Subscriptions and the connection init handler run in coroutines launched in [scope]; the jobs are
 * tracked so that they can be cancelled by a `stop` message or by [close].
 *
 * All accesses to the mutable state of this class are guarded by a reentrant lock.
 *
 * @param executableSchema the schema used to execute incoming subscription requests.
 * @param scope the scope in which the init handler and the subscriptions are launched.
 * @param executionContext the base context passed to every execution. The id of the subscription is added
 * to it and can be retrieved with [subscriptionId].
 * @param sendMessage called for every message that must be sent back to the client.
 * @param connectionInitHandler called with the payload of `connection_init` messages. Acknowledges every
 * connection by default.
 */
class SubscriptionWebSocketHandler(
    private val executableSchema: ExecutableSchema,
    private val scope: CoroutineScope,
    private val executionContext: ExecutionContext,
    private val sendMessage: (WebSocketMessage) -> Unit,
    private val connectionInitHandler: ConnectionInitHandler = { ConnectionInitAck },
) : WebSocketHandler {
  private val lock = reentrantLock()

  // Subscription id -> job collecting the subscription flow. Guarded by [lock].
  private val activeSubscriptions = mutableMapOf<String, Job>()
  private var isClosed: Boolean = false
  private var initJob: Job? = null

  /**
   * Parses [message] and handles it. Binary messages are decoded to a String before parsing.
   * Messages that cannot be parsed are answered with an `error` message without an id.
   */
  override fun handleMessage(message: WebSocketMessage) {
    val clientMessage = when (message) {
      is WebSocketBinaryMessage -> message.data.decodeToString()
      is WebSocketTextMessage -> message.data
    }.parseApolloWebsocketClientMessage()

    when (clientMessage) {
      is SubscriptionWebsocketInit -> {
        // Write initJob while holding the lock: close() reads it under the same lock.
        lock.withLock {
          initJob = scope.launch {
            when (val result = connectionInitHandler.invoke(clientMessage.connectionParams)) {
              is ConnectionInitAck -> {
                sendMessage(SubscriptionWebsocketConnectionAck.toWsMessage())
              }
              is ConnectionInitError -> {
                sendMessage(SubscriptionWebsocketConnectionError(result.payload).toWsMessage())
              }
            }
          }
        }
      }

      is SubscriptionWebsocketStart -> {
        startSubscription(clientMessage)
      }

      is SubscriptionWebsocketStop -> {
        lock.withLock {
          activeSubscriptions.remove(clientMessage.id)?.cancel()
        }
      }

      SubscriptionWebsocketTerminate -> {
        // nothing to do
      }

      is SubscriptionWebsocketClientMessageParseError -> {
        sendMessage(SubscriptionWebsocketError(null, Error.Builder("Cannot handle message (${clientMessage.message})").build()).toWsMessage())
      }
    }
  }

  /**
   * Starts the subscription described by [message], or answers with an `error` message if a subscription
   * with the same id is already active.
   */
  private fun startSubscription(message: SubscriptionWebsocketStart) {
    val id = message.id

    val isActive = lock.withLock {
      activeSubscriptions.containsKey(id)
    }
    if (isActive) {
      sendAlreadyActive(id)
      return
    }

    val flow = executableSchema.executeSubscription(message.request, executionContext + CurrentSubscription(id))

    // The job is created lazily and only started once it is registered. An eagerly started job could
    // complete before being registered, leaving a finished job in the map and making the id unusable.
    val job = scope.launch(start = kotlinx.coroutines.CoroutineStart.LAZY) {
      flow.collect { event ->
        when (event) {
          is SubscriptionResponse -> {
            sendMessage(SubscriptionWebsocketData(id = id, response = event.response).toWsMessage())
          }

          is SubscriptionError -> {
            sendMessage(SubscriptionWebsocketError(id = id, error = event.errors.first()).toWsMessage())
          }
        }
      }
      sendMessage(SubscriptionWebsocketComplete(id = id).toWsMessage())

      val self = coroutineContext[Job]
      lock.withLock {
        // Only unregister ourselves: the id may have been stopped and reused by a newer subscription.
        if (activeSubscriptions[id] === self) {
          activeSubscriptions.remove(id)
        }
      }
    }

    // Check again and register atomically in case the same id was registered concurrently.
    val registered = lock.withLock {
      if (activeSubscriptions.containsKey(id)) {
        false
      } else {
        activeSubscriptions[id] = job
        true
      }
    }

    if (registered) {
      job.start()
    } else {
      // The job was never started, cancelling it just discards it.
      job.cancel()
      sendAlreadyActive(id)
    }
  }

  /** Sends the `error` message used when the client starts a subscription whose id is already in use. */
  private fun sendAlreadyActive(id: String) {
    sendMessage(SubscriptionWebsocketError(id = id, error = Error.Builder("Subscription $id is already active").build()).toWsMessage())
  }

  /**
   * Cancels the pending init job and all the active subscriptions.
   * Calling this function more than once has no additional effect.
   */
  fun close() {
    lock.withLock {
      if (isClosed) {
        return
      }

      activeSubscriptions.values.forEach { it.cancel() }
      activeSubscriptions.clear()

      initJob?.cancel()
      isClosed = true
    }
  }
}



/**
 * The result of parsing a message sent by the client: either a [SubscriptionWebsocketClientMessage]
 * or a [SubscriptionWebsocketClientMessageParseError].
 */
internal sealed interface SubscriptionWebsocketClientMessageResult

/**
 * A client message that could not be parsed.
 *
 * @property message a description of why parsing failed.
 */
internal class SubscriptionWebsocketClientMessageParseError internal constructor(
    val message: String,
) : SubscriptionWebsocketClientMessageResult

/**
 * A successfully parsed message sent by the client.
 */
internal sealed interface SubscriptionWebsocketClientMessage : SubscriptionWebsocketClientMessageResult

/**
 * A `connection_init` message.
 *
 * @property connectionParams the value of the "payload" field of the message, or null if there is none.
 */
internal class SubscriptionWebsocketInit(
    val connectionParams: Any?,
) : SubscriptionWebsocketClientMessage

/**
 * A `start` message.
 *
 * @property id the value of the "id" field of the message.
 * @property request the GraphQL request parsed from the "payload" field of the message.
 */
internal class SubscriptionWebsocketStart(
    val id: String,
    val request: GraphQLRequest,
) : SubscriptionWebsocketClientMessage

/**
 * A `stop` message.
 *
 * @property id the value of the "id" field of the message.
 */
internal class SubscriptionWebsocketStop(
    val id: String,
) : SubscriptionWebsocketClientMessage


/**
 * A `connection_terminate` message.
 */
internal object SubscriptionWebsocketTerminate : SubscriptionWebsocketClientMessage

/**
 * A message sent by the server to the client.
 */
internal sealed interface SubscriptionWebsocketServerMessage {
  /**
   * Writes this message to [sink].
   */
  fun serialize(sink: Sink)
}

/**
 * Writes a Json object to this [Sink] and flushes the writer.
 *
 * The object starts with a "type" field set to [type]. [block], if any, is then called to write
 * the remaining fields of the object.
 */
private fun Sink.writeMessage(type: String, block: (JsonWriter.() -> Unit)? = null) {
  val writer = jsonWriter()
  writer.writeObject {
    name("type")
    value(type)
    if (block != null) {
      block(this)
    }
  }
  writer.flush()
}

/**
 * The `connection_ack` message. It has no field besides its type.
 */
internal data object SubscriptionWebsocketConnectionAck : SubscriptionWebsocketServerMessage {
  override fun serialize(sink: Sink) = sink.writeMessage(type = "connection_ack")
}

/**
 * The `connection_error` message.
 *
 * @param payload the optional payload of the error. The "payload" field is only written when it is
 * [Optional.Present], in which case its value may be any Json value, including null.
 */
internal class SubscriptionWebsocketConnectionError(private val payload: Optional<Any?>) : SubscriptionWebsocketServerMessage {
  override fun serialize(sink: Sink) {
    sink.writeMessage("connection_error") {
      if (payload is Optional.Present<*>) {
        name("payload")
        writeAny(payload.value)
      }
    }
  }
}

/**
 * The `data` message, carrying one response of a subscription.
 *
 * @property id written as the "id" field of the message.
 * @property response serialized as the "payload" field of the message.
 */
internal class SubscriptionWebsocketData(
    val id: String,
    val response: GraphQLResponse,
) : SubscriptionWebsocketServerMessage {
  override fun serialize(sink: Sink) = sink.writeMessage(type = "data") {
    name("id").value(id)
    name("payload")
    response.serialize(this)
  }
}

/**
 * The `error` message.
 *
 * @property id written as the "id" field of the message. The field is omitted when null.
 * @property error written as the "payload" field of the message.
 */
internal class SubscriptionWebsocketError(
    val id: String?,
    val error: Error,
) : SubscriptionWebsocketServerMessage {

  override fun serialize(sink: Sink) = sink.writeMessage(type = "error") {
    id?.let { subscriptionId ->
      name("id")
      value(subscriptionId)
    }
    name("payload")
    writeError(error)
  }
}

/**
 * The `complete` message.
 *
 * @property id written as the "id" field of the message.
 */
internal class SubscriptionWebsocketComplete(
    val id: String,
) : SubscriptionWebsocketServerMessage {
  override fun serialize(sink: Sink) = sink.writeMessage(type = "complete") {
    name("id").value(id)
  }
}

/**
 * Parses this String as a message sent by the client.
 *
 * The String must contain a Json object with a String "type" field:
 * - "start" requires a String "id" and an object "payload" containing the GraphQL request.
 * - "stop" requires a String "id".
 * - "connection_init" may have a "payload" of any type.
 * - "connection_terminate" has no other field.
 *
 * This function does not throw on invalid input: anything that does not match the above is
 * returned as a [SubscriptionWebsocketClientMessageParseError] describing the problem.
 */
@OptIn(ApolloInternal::class)
internal fun String.parseApolloWebsocketClientMessage(): SubscriptionWebsocketClientMessageResult {
  // The cast fails for Json that is not an object (array, scalar, null). The resulting exception
  // is caught below and reported like any other malformed input.
  @Suppress("UNCHECKED_CAST")
  val map = try {
    Buffer().writeUtf8(this).jsonReader().readAny() as Map<String, Any?>
  } catch (e: Exception) {
    return SubscriptionWebsocketClientMessageParseError("Malformed Json: ${e.message}")
  }

  val type = map["type"]
      ?: return SubscriptionWebsocketClientMessageParseError("No 'type' found in $this")
  if (type !is String) {
    return SubscriptionWebsocketClientMessageParseError("'type' must be a String in $this")
  }

  when (type) {
    "start", "stop" -> {
      // Both messages are addressed to a subscription and share the same 'id' validation.
      val id = map["id"]
          ?: return SubscriptionWebsocketClientMessageParseError("No 'id' found in $this")
      if (id !is String) {
        return SubscriptionWebsocketClientMessageParseError("'id' must be a String in $this")
      }

      if (type == "stop") {
        return SubscriptionWebsocketStop(id)
      }

      val payload = map["payload"]
          ?: return SubscriptionWebsocketClientMessageParseError("No 'payload' found in $this")
      if (payload !is Map<*, *>) {
        return SubscriptionWebsocketClientMessageParseError("'payload' must be an Object in $this")
      }

      @Suppress("UNCHECKED_CAST")
      val request = (payload as Map<String, Any?>).parseGraphQLRequest()
      return request.fold(
          onFailure = { SubscriptionWebsocketClientMessageParseError("Cannot parse start payload: '${it.message}'") },
          onSuccess = { SubscriptionWebsocketStart(id, request = it) }
      )
    }

    "connection_init" -> {
      return SubscriptionWebsocketInit(map["payload"])
    }

    "connection_terminate" -> {
      return SubscriptionWebsocketTerminate
    }

    else -> return SubscriptionWebsocketClientMessageParseError("Unknown message type '$type'")
  }
}

/**
 * Serializes this message to an in-memory buffer and wraps the resulting UTF-8 text
 * in a [WebSocketTextMessage].
 */
private fun SubscriptionWebsocketServerMessage.toWsMessage(): WebSocketMessage {
  val buffer = Buffer()
  serialize(buffer)
  return WebSocketTextMessage(buffer.readUtf8())
}

/**
 * The result of a [ConnectionInitHandler].
 */
sealed interface ConnectionInitResult

/**
 * Accepts the connection: a `connection_ack` message is sent to the client.
 */
data object ConnectionInitAck : ConnectionInitResult

/**
 * Rejects the connection: a `connection_error` message is sent to the client.
 *
 * @property payload the optional payload of the `connection_error` message. Absent by default.
 */
class ConnectionInitError(val payload: Optional<Any?> = Optional.absent()): ConnectionInitResult

/**
 * A suspending callback invoked with the payload of a `connection_init` message (null if the message
 * has no payload) that decides how the server answers, see [ConnectionInitResult].
 */
typealias ConnectionInitHandler = suspend (Any?) -> ConnectionInitResult

/**
 * An [ExecutionContext.Element] holding the id of the subscription being executed.
 * Use [subscriptionId] to read it from an [ExecutionContext].
 *
 * @property id the id of the subscription.
 */
private class CurrentSubscription(val id: String) : ExecutionContext.Element {

  override val key: ExecutionContext.Key<CurrentSubscription> = Key

  companion object Key : ExecutionContext.Key<CurrentSubscription>
}

/**
 * Returns the id of the subscription being executed with this [ExecutionContext].
 *
 * @throws IllegalStateException if this context does not contain a subscription id.
 */
fun ExecutionContext.subscriptionId(): String {
  val current = get(CurrentSubscription) ?: error("Apollo: not executing a subscription")
  return current.id
}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy