All downloads are free. Search and download functionality uses the official Maven repository.

main.com.netflix.graphql.dgs.subscriptions.websockets.WebsocketGraphQLWSProtocolHandler.kt Maven / Gradle / Ivy

There is a newer version: 9.1.3
Show newest version
/*
 * Copyright 2020 Netflix, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.netflix.graphql.dgs.subscriptions.websockets

import com.fasterxml.jackson.core.JsonProcessingException
import com.fasterxml.jackson.databind.ObjectMapper
import com.netflix.graphql.dgs.DgsQueryExecutor
import com.netflix.graphql.types.subscription.*
import graphql.ExecutionResult
import jakarta.annotation.PostConstruct
import jakarta.annotation.PreDestroy
import org.reactivestreams.Publisher
import org.reactivestreams.Subscriber
import org.reactivestreams.Subscription
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import org.slf4j.event.Level
import org.springframework.web.socket.TextMessage
import org.springframework.web.socket.WebSocketSession
import org.springframework.web.socket.handler.TextWebSocketHandler
import java.io.UncheckedIOException
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.CopyOnWriteArrayList

/**
 * Spring [TextWebSocketHandler] that serves GraphQL subscriptions over a WebSocket using
 * `OperationMessage` frames.
 *
 * Handled message types:
 * - `GQL_CONNECTION_INIT`: registers the session and replies with `GQL_CONNECTION_ACK`.
 * - `GQL_START`: executes the query with [DgsQueryExecutor] and subscribes to the resulting [Publisher].
 *   Every emitted [ExecutionResult] is sent back as `GQL_DATA`. A terminal error is sent as `GQL_ERROR`
 *   and completion as `GQL_COMPLETE`.
 * - `GQL_STOP`: cancels the subscription with the given id.
 * - `GQL_CONNECTION_TERMINATE`: cancels all subscriptions of the session and closes it.
 * - anything else: replies with an `"error"` message.
 *
 * @param dgsQueryExecutor executes the subscription query; its result data must be a `Publisher<ExecutionResult>`.
 * @param subscriptionErrorLogLevel level at which errors signalled by a subscription stream are logged.
 * @param objectMapper Jackson mapper used to read and write `OperationMessage` frames.
 */
class WebsocketGraphQLWSProtocolHandler(
    private val dgsQueryExecutor: DgsQueryExecutor,
    private val subscriptionErrorLogLevel: Level,
    private val objectMapper: ObjectMapper
) : TextWebSocketHandler() {

    /**
     * Active subscriptions, keyed by WebSocket session id and then by client-supplied operation id.
     *
     * The inner maps are created as [ConcurrentHashMap]s. They are mutated from the message-handling
     * thread (`GQL_STOP`) and from the publisher's signal threads (`onSubscribe`/`onComplete`/`onError`),
     * and they are iterated by the cleanup timer thread.
     */
    internal val subscriptions = ConcurrentHashMap<String, MutableMap<String, Subscription>>()

    /** Sessions that sent `GQL_CONNECTION_INIT`; scanned periodically so closed ones get cleaned up. */
    internal val sessions = CopyOnWriteArrayList<WebSocketSession>()

    @Volatile
    private var timer: Timer? = null

    /**
     * Starts a daemon [Timer] that every 5000 ms cleans up the subscriptions of registered sessions
     * that are no longer open (e.g. sessions dropped without a `GQL_CONNECTION_TERMINATE`).
     */
    @PostConstruct
    fun setupCleanup() {
        val timer = Timer("dgs-graphql-ws-session-cleanup", true)
        this.timer = timer
        val timerTask = object : TimerTask() {
            override fun run() {
                for (session in sessions) {
                    if (session.isOpen) continue
                    // An exception escaping a TimerTask terminates the Timer thread and silently stops
                    // every future cleanup run, so failures are logged per session instead.
                    try {
                        cleanupSubscriptionsForSession(session)
                    } catch (e: Exception) {
                        logger.warn("Failed to clean up subscriptions for session {}", session.id, e)
                    }
                }
            }
        }
        timer.scheduleAtFixedRate(timerTask, 0, 5000)
    }

    /** Stops the cleanup timer, if it was started. */
    @PreDestroy
    fun destroy() {
        val timer = this.timer ?: return
        timer.cancel()
        this.timer = null
    }

    /**
     * Decodes one `OperationMessage` frame and dispatches on its type.
     *
     * JSON decoding failures and a `GQL_START` without an id propagate as exceptions to the caller.
     */
    public override fun handleTextMessage(session: WebSocketSession, message: TextMessage) {
        val (type, payload, id) = objectMapper.readValue(message.payload, OperationMessage::class.java)
        when (type) {
            GQL_CONNECTION_INIT -> {
                logger.info("Initialized connection for {}", session.id)
                // addIfAbsent: a repeated init must not register the same session twice.
                sessions.addIfAbsent(session)
                session.sendMessage(
                    TextMessage(
                        objectMapper.writeValueAsBytes(
                            OperationMessage(
                                GQL_CONNECTION_ACK
                            )
                        )
                    )
                )
            }
            GQL_START -> {
                val queryPayload = objectMapper.convertValue(payload, QueryPayload::class.java)
                val subscriptionId = requireNotNull(id) { "Start message for session ${session.id} has no id" }
                handleSubscription(subscriptionId, queryPayload, session)
            }
            GQL_STOP -> {
                // A single remove() avoids a window between lookup and removal. ConcurrentHashMap
                // rejects null keys, so a missing id stays a no-op as before.
                if (id != null) {
                    subscriptions[session.id]?.remove(id)?.cancel()
                }
            }
            GQL_CONNECTION_TERMINATE -> {
                logger.info("Terminated session {}", session.id)
                cleanupSubscriptionsForSession(session)
                session.close()
            }
            else -> session.sendMessage(TextMessage(objectMapper.writeValueAsBytes(OperationMessage("error"))))
        }
    }

    /** Cancels and forgets every subscription of [session], then unregisters the session. */
    private fun cleanupSubscriptionsForSession(session: WebSocketSession) {
        logger.info("Cleaning up for session {}", session.id)
        subscriptions.remove(session.id)?.values?.forEach { it.cancel() }
        sessions.remove(session)
    }

    /**
     * Executes the subscription query and relays the resulting stream to [session] under operation [id].
     *
     * Demand is requested one element at a time: `request(1)` on subscribe, and again after each
     * element has been sent while the session is open.
     */
    private fun handleSubscription(id: String, payload: QueryPayload, session: WebSocketSession) {
        val executionResult: ExecutionResult = dgsQueryExecutor.execute(payload.query, payload.variables.orEmpty())
        val subscriptionStream: Publisher<ExecutionResult> = executionResult.getData()

        // NOTE(review): Spring documents WebSocketSession.sendMessage as not safe for concurrent use. If several
        // subscriptions on one session emit from different threads, the session may need wrapping in
        // ConcurrentWebSocketSessionDecorator — confirm how sessions are supplied to this handler.
        subscriptionStream.subscribe(object : Subscriber<ExecutionResult> {
            override fun onSubscribe(s: Subscription) {
                logger.info("Subscription started for {}", id)
                // computeIfAbsent is atomic, unlike the former putIfAbsent followed by a separate lookup.
                subscriptions.computeIfAbsent(session.id) { ConcurrentHashMap<String, Subscription>() }[id] = s

                s.request(1)
            }

            override fun onNext(er: ExecutionResult) {
                val message = OperationMessage(GQL_DATA, DataPayload(er.getData(), er.errors), id)
                val jsonMessage = try {
                    TextMessage(objectMapper.writeValueAsBytes(message))
                } catch (exc: JsonProcessingException) {
                    throw UncheckedIOException(exc)
                }
                logger.debug("Sending subscription data: {}", jsonMessage)

                if (session.isOpen) {
                    session.sendMessage(jsonMessage)
                    // Looked up again so that a subscription stopped via GQL_STOP gets no further demand.
                    subscriptions[session.id]?.get(id)?.request(1)
                }
            }

            override fun onError(t: Throwable) {
                when (subscriptionErrorLogLevel) {
                    Level.ERROR -> logger.error("Error on subscription {}", id, t)
                    Level.WARN -> logger.warn("Error on subscription {}", id, t)
                    Level.INFO -> logger.info("Error on subscription {}: {}", id, t.message)
                    Level.DEBUG -> logger.debug("Error on subscription {}", id, t)
                    Level.TRACE -> logger.trace("Error on subscription {}", id, t)
                }
                // The stream is terminated after onError; drop it first so a failing send cannot leave it behind.
                subscriptions[session.id]?.remove(id)

                // Throwable.message is nullable; fall back to toString() instead of throwing an NPE here.
                val errorText = t.message ?: t.toString()
                val message = OperationMessage(GQL_ERROR, DataPayload(null, listOf(errorText)), id)
                val jsonMessage = TextMessage(objectMapper.writeValueAsBytes(message))
                logger.debug("Sending subscription error: {}", jsonMessage)

                if (session.isOpen) {
                    session.sendMessage(jsonMessage)
                }
            }

            override fun onComplete() {
                logger.info("Subscription completed for {}", id)
                // Drop bookkeeping first so a failing send cannot leave a stale entry.
                subscriptions[session.id]?.remove(id)

                val message = OperationMessage(GQL_COMPLETE, null, id)
                val jsonMessage = TextMessage(objectMapper.writeValueAsBytes(message))

                if (session.isOpen) {
                    session.sendMessage(jsonMessage)
                }
            }
        })
    }

    private companion object {
        val logger: Logger = LoggerFactory.getLogger(WebsocketGraphQLWSProtocolHandler::class.java)
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy