All downloads are free. Search and download functionality uses the official Maven repository.

main.com.netflix.graphql.dgs.subscriptions.websockets.WebsocketGraphQLTransportWSProtocolHandler.kt Maven / Gradle / Ivy

/*
 * Copyright 2022 Netflix, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.netflix.graphql.dgs.subscriptions.websockets

import com.fasterxml.jackson.databind.ObjectMapper
import com.netflix.graphql.dgs.DgsQueryExecutor
import com.netflix.graphql.types.subscription.websockets.CloseCode
import com.netflix.graphql.types.subscription.websockets.Message
import graphql.ExecutionResult
import graphql.GraphqlErrorBuilder
import jakarta.annotation.PostConstruct
import jakarta.annotation.PreDestroy
import org.reactivestreams.Publisher
import org.reactivestreams.Subscriber
import org.reactivestreams.Subscription
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import org.slf4j.event.Level
import org.springframework.web.socket.CloseStatus
import org.springframework.web.socket.TextMessage
import org.springframework.web.socket.WebSocketSession
import org.springframework.web.socket.handler.TextWebSocketHandler
import java.time.Duration
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.CopyOnWriteArrayList

/**
 * WebSocketHandler for GraphQL based on the
 * GraphQL Over WebSocket Protocol (`graphql-transport-ws`), for use in the DGS framework.
 *
 * Per-connection protocol state is kept in [contexts], keyed by session id. A connection
 * must send a `ConnectionInit` message within [connectionInitTimeout] or it is closed.
 * Subscriptions are executed through [dgsQueryExecutor] and their results are pushed to the
 * client one element at a time (the next element is only requested after the previous one
 * has been sent).
 *
 * @param dgsQueryExecutor executor used to run the operation carried by a `Subscribe` message
 * @param connectionInitTimeout how long a new connection may stay without `ConnectionInit`
 * @param subscriptionErrorLogLevel level at which subscription stream errors are logged
 * @param objectMapper mapper used to parse incoming and serialise outgoing protocol messages
 */
class WebsocketGraphQLTransportWSProtocolHandler(
    private val dgsQueryExecutor: DgsQueryExecutor,
    private val connectionInitTimeout: Duration,
    private val subscriptionErrorLogLevel: Level,
    private val objectMapper: ObjectMapper
) : TextWebSocketHandler() {

    /** Sessions that have sent `ConnectionInit`; scanned periodically for closed sessions. */
    internal val sessions = CopyOnWriteArrayList<WebSocketSession>()

    /** Per-session protocol state, keyed by session id. */
    // NOTE(review): the type arguments were reconstructed (Context is declared outside this
    // file) — confirm that Context<Any> matches the declaration.
    internal val contexts = ConcurrentHashMap<String, Context<Any>>()

    @Volatile
    private var timer: Timer? = null

    /**
     * Starts a daemon timer that, every [SESSION_CLEANUP_INTERVAL_MS] milliseconds, releases the
     * state of every tracked session that is no longer open.
     */
    @PostConstruct
    fun setupCleanup() {
        val cleanupTimer = Timer("dgs-graphql-ws-transport-session-cleanup", true)
        this.timer = cleanupTimer

        val cleanupTask = object : TimerTask() {
            override fun run() {
                for (session in sessions) {
                    if (session.isOpen) {
                        continue
                    }
                    // An exception escaping a TimerTask terminates the timer thread, which would
                    // silently stop all future cleanup; guard each session individually.
                    try {
                        cleanupSubscriptionsForSession(session)
                    } catch (e: Exception) {
                        logger.warn("Failed to clean up session {}", session.id, e)
                    }
                }
            }
        }
        cleanupTimer.scheduleAtFixedRate(cleanupTask, 0L, SESSION_CLEANUP_INTERVAL_MS)
    }

    /** Stops the periodic cleanup timer, if it was started. */
    @PreDestroy
    fun destroy() {
        val activeTimer = this.timer ?: return
        activeTimer.cancel()
        this.timer = null
    }

    /**
     * Registers a fresh context for the session and arms a one-shot watchdog that closes the
     * session if no `ConnectionInit` message has been received once [connectionInitTimeout] elapses.
     */
    override fun afterConnectionEstablished(session: WebSocketSession) {
        val context = Context<Any>()
        contexts[session.id] = context
        val watchdog = Timer("dgs-graphql-ws-transport-connection-timeout-watchdog-${session.id}", true)

        val timeoutTask = object : TimerTask() {
            override fun run() {
                try {
                    if (!context.getConnectionInitReceived()) {
                        // Drop the context first so it is released even if closing the session fails.
                        contexts.remove(session.id)
                        session.close(CloseStatus(CloseCode.ConnectionInitialisationTimeout.code))
                    }
                } finally {
                    // The watchdog is single-use; always release its thread.
                    watchdog.cancel()
                }
            }
        }

        watchdog.schedule(timeoutTask, connectionInitTimeout.toMillis())
    }

    /**
     * Releases the session's context and cancels its subscriptions.
     *
     * Cleanup runs for every close status: the connection is gone regardless of why it was
     * closed, and [CloseStatus] equality also compares the reason text, so comparing against
     * [CloseStatus.NORMAL] would miss normal closures that carry a reason.
     */
    override fun afterConnectionClosed(session: WebSocketSession, status: CloseStatus) {
        logger.debug("Connection closed for {} with status {}", session.id, status)
        cleanupSubscriptionsForSession(session)
    }

    /**
     * Parses an incoming text frame as a protocol [Message] and dispatches on its type.
     *
     * - `ConnectionInit`: acknowledges the connection; a repeated init closes the session.
     * - `Ping`: answered with a `Pong` echoing the payload.
     * - `Pong`: ignored.
     * - `Subscribe`: starts a subscription, unless the connection is not acknowledged or the
     *   id is already in use, in which case the session is closed.
     * - `Complete`: cancels and forgets the subscription with the given id.
     * - anything else: closes the session with a bad-request code.
     */
    public override fun handleTextMessage(session: WebSocketSession, textMessage: TextMessage) {
        val message = objectMapper.readValue(textMessage.payload, Message::class.java)
        val context = contexts[session.id]
        if (context == null) {
            // The context is removed on init timeout and on cleanup; a late message must not
            // crash the handler with a NullPointerException.
            logger.warn("Received a message for session {} which has no connection context", session.id)
            session.close(CloseStatus(CloseCode.BadRequest.code, "No connection context for session"))
            return
        }

        when (message) {
            is Message.ConnectionInitMessage -> {
                logger.info("Initialized connection for {}", session.id)
                // presumably returns true when an init message had already been received — verify in Context
                if (context.setConnectionInitReceived()) {
                    session.close(CloseStatus(CloseCode.BadRequest.code, "Too many initialisation requests"))
                    return
                }
                sessions.add(session)

                context.connectionParams = message.payload
                try {
                    session.sendMessage(toTextMessage(Message.ConnectionAckMessage()))
                    context.acknowledged = true
                } catch (e: Exception) {
                    logger.warn("Failed to acknowledge connection for {}", session.id, e)
                    session.close(CloseStatus(CloseCode.Forbidden.code, "Forbidden"))
                }
            }
            is Message.PingMessage -> {
                session.sendMessage(toTextMessage(Message.PongMessage(payload = message.payload)))
            }
            is Message.PongMessage -> {
                // Nothing to do for a pong.
            }
            is Message.SubscribeMessage -> {
                if (!context.acknowledged) {
                    session.close(CloseStatus(CloseCode.Unauthorized.code, "Unauthorized"))
                    return
                }
                val (id, payload) = message
                // containsKey is spelled out on purpose: on a java.util.concurrent.ConcurrentHashMap
                // the member `contains` tests *values*, not keys.
                if (context.subscriptions.containsKey(id)) {
                    session.close(
                        CloseStatus(CloseCode.SubscriberAlreadyExists.code, "Subscriber for $id already exists")
                    )
                    return
                }

                handleSubscription(id, payload, session)
            }
            is Message.CompleteMessage -> {
                logger.info("Complete subscription for {}", message.id)
                context.subscriptions.remove(message.id)?.cancel()
            }
            else -> session.close(CloseStatus(CloseCode.BadRequest.code, "Unexpected message format"))
        }
    }

    /**
     * Forgets the session and its context and cancels every subscription the context still
     * holds. Safe to call more than once for the same session.
     */
    private fun cleanupSubscriptionsForSession(session: WebSocketSession) {
        logger.info("Cleaning up for session {}", session.id)
        // Detach the state before cancelling so it is released even if a cancel() call throws,
        // and so concurrent callers do not cancel the same subscriptions twice.
        val context = contexts.remove(session.id)
        sessions.remove(session)
        context?.subscriptions?.values?.forEach { it.cancel() }
    }

    /** Serialises a protocol message with [objectMapper] into a WebSocket text frame. */
    private fun toTextMessage(message: Any): TextMessage =
        TextMessage(objectMapper.writeValueAsBytes(message))

    /**
     * Executes the operation of a `Subscribe` message and streams its results to the session.
     *
     * When execution does not yield a [Publisher] (for example because the operation failed
     * before a stream could be created), an `Error` message carrying the execution errors is
     * sent instead and no subscription is registered.
     *
     * @param id client-chosen subscription id, echoed in every message for this subscription
     * @param payload the operation to execute
     * @param session the session to which results are sent
     */
    private fun handleSubscription(
        id: String,
        payload: Message.SubscribeMessage.Payload,
        session: WebSocketSession
    ) {
        val executionResult: ExecutionResult =
            dgsQueryExecutor.execute(
                payload.query,
                payload.variables.orEmpty(),
                payload.extensions,
                null,
                payload.operationName,
                null
            )

        @Suppress("UNCHECKED_CAST")
        val subscriptionStream = executionResult.getData<Any?>() as? Publisher<ExecutionResult>
        if (subscriptionStream == null) {
            val reportedErrors = executionResult.errors.orEmpty()
            val errors = if (reportedErrors.isNotEmpty()) {
                reportedErrors
            } else {
                listOf(
                    GraphqlErrorBuilder.newError()
                        .message("Operation did not produce a subscription stream")
                        .build()
                )
            }
            logger.info("Subscription {} could not be started", id)
            if (session.isOpen) {
                session.sendMessage(toTextMessage(Message.ErrorMessage(id = id, payload = errors)))
            }
            return
        }

        subscriptionStream.subscribe(object : Subscriber<ExecutionResult> {
            override fun onSubscribe(s: Subscription) {
                val subscriptions = contexts[session.id]?.subscriptions
                if (subscriptions == null) {
                    // The session was cleaned up in the meantime; nobody would ever cancel this
                    // subscription, so release it right away.
                    logger.info("Session {} is gone, cancelling subscription {}", session.id, id)
                    s.cancel()
                    return
                }
                logger.info("Subscription started for {}", id)
                subscriptions[id] = s

                s.request(1)
            }

            override fun onNext(er: ExecutionResult) {
                val message = Message.NextMessage(
                    payload = com.netflix.graphql.types.subscription.websockets.ExecutionResult(er.getData(), er.errors),
                    id = id
                )
                val jsonMessage = toTextMessage(message)
                logger.debug("Sending subscription data: {}", jsonMessage)

                if (session.isOpen) {
                    session.sendMessage(jsonMessage)
                    // Ask for the next element only after this one has been sent.
                    contexts[session.id]?.subscriptions?.get(id)?.request(1)
                }
            }

            override fun onError(t: Throwable) {
                when (subscriptionErrorLogLevel) {
                    Level.ERROR -> logger.error("Error on subscription {}", id, t)
                    Level.WARN -> logger.warn("Error on subscription {}", id, t)
                    Level.INFO -> logger.info("Error on subscription {}: {}", id, t.message)
                    Level.DEBUG -> logger.debug("Error on subscription {}", id, t)
                    Level.TRACE -> logger.trace("Error on subscription {}", id, t)
                }
                val message =
                    Message.ErrorMessage(
                        id = id,
                        payload = listOf(GraphqlErrorBuilder.newError().message(t.message).build())
                    )
                val jsonMessage = toTextMessage(message)
                logger.debug("Sending subscription error: {}", jsonMessage)

                try {
                    if (session.isOpen) {
                        session.sendMessage(jsonMessage)
                    }
                } finally {
                    // The stream is terminated; free the id just like onComplete does.
                    contexts[session.id]?.subscriptions?.remove(id)
                }
            }

            override fun onComplete() {
                logger.info("Subscription completed for {}", id)
                val jsonMessage = toTextMessage(Message.CompleteMessage(id))

                if (session.isOpen) {
                    session.sendMessage(jsonMessage)
                }
                contexts[session.id]?.subscriptions?.remove(id)
            }
        })
    }

    private companion object {
        val logger: Logger = LoggerFactory.getLogger(WebsocketGraphQLTransportWSProtocolHandler::class.java)

        /** Period, in milliseconds, between two scans for closed sessions. */
        const val SESSION_CLEANUP_INTERVAL_MS = 5000L
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy