
// Source: commonMain/io/ktor/client/plugins/websocket/WebSockets.kt (Maven / Gradle / Ivy)
/*
* Copyright 2014-2021 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/
package io.ktor.client.plugins.websocket
import io.ktor.client.*
import io.ktor.client.call.*
import io.ktor.client.engine.*
import io.ktor.client.plugins.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.http.*
import io.ktor.serialization.*
import io.ktor.util.*
import io.ktor.util.logging.*
import io.ktor.websocket.*
import kotlin.native.concurrent.*
/**
 * Request attribute under which the extensions built for a WebSocket request are stored,
 * so that they can be matched against the server's answer once the handshake response arrives.
 */
private val REQUEST_EXTENSIONS_KEY = AttributeKey<List<WebSocketExtension<*>>>("Websocket extensions")

/** Logger used by the client WebSockets plugin to trace request/response handling. */
internal val LOGGER = KtorSimpleLogger("io.ktor.client.plugins.websocket.WebSockets")
/**
 * Indicates if a client engine supports WebSockets.
 *
 * The WebSockets plugin sets this capability, with a [Unit] value, on every WebSocket request it renders.
 */
public object WebSocketCapability : HttpClientEngineCapability<Unit> {
    override fun toString(): String = "WebSocketCapability"
}
/**
 * Indicates if a client engine supports extensions for WebSocket plugin.
 *
 * The WebSockets plugin offers and negotiates extensions only when the engine's supported
 * capabilities contain this object; otherwise extensions are skipped entirely.
 */
public object WebSocketExtensionsCapability : HttpClientEngineCapability<Unit> {
    override fun toString(): String = "WebSocketExtensionsCapability"
}
/**
 * Client WebSocket plugin.
 *
 * Once installed, requests whose URL uses a WebSocket protocol are rendered as WebSocket handshakes.
 * A `101 Switching Protocols` response carrying a [WebSocketSession] is then exposed to the caller
 * as a [ClientWebSocketSession].
 *
 * @property pingInterval interval between [FrameType.PING] messages, in milliseconds; `-1` disables pings.
 * @property maxFrameSize maximum size of a single WebSocket frame, in bytes.
 * @property extensionsConfig configuration of the WebSocket extensions offered to the server.
 * @property contentConverter converter for serialization/deserialization, or `null` when none is configured.
 */
public class WebSockets internal constructor(
    public val pingInterval: Long,
    public val maxFrameSize: Long,
    private val extensionsConfig: WebSocketExtensionsConfig,
    public val contentConverter: WebsocketContentConverter? = null
) {
    /**
     * Creates the plugin with an empty extensions configuration and no content converter.
     *
     * @param pingInterval interval between [FrameType.PING] messages, in milliseconds; `-1` disables pings.
     * @param maxFrameSize maximum size of a single WebSocket frame, in bytes.
     */
    public constructor(
        pingInterval: Long = -1L,
        maxFrameSize: Long = Int.MAX_VALUE.toLong()
    ) : this(pingInterval, maxFrameSize, WebSocketExtensionsConfig())

    /**
     * Creates the plugin with pings disabled, a frame size limit of [Int.MAX_VALUE] bytes,
     * an empty extensions configuration and no content converter.
     */
    public constructor() : this(-1L, Int.MAX_VALUE.toLong(), WebSocketExtensionsConfig())

    /**
     * Builds the configured extensions for one request.
     *
     * The built extensions are stored in the request attributes, where [completeNegotiation]
     * reads them back. Their protocols are advertised in the `Sec-WebSocket-Extensions` request header.
     */
    private fun installExtensions(context: HttpRequestBuilder) {
        val installed = extensionsConfig.build()
        context.attributes.put(REQUEST_EXTENSIONS_KEY, installed)

        val protocols = installed.flatMap { it.protocols }
        addNegotiatedProtocols(context, protocols)
    }

    /**
     * Returns the client extensions that accept the server's handshake answer.
     *
     * A missing `Sec-WebSocket-Extensions` response header is treated as "no server extensions".
     * Each extension stored by [installExtensions] is kept only if its `clientNegotiation` accepts
     * the parsed server headers.
     */
    @Suppress("UNCHECKED_CAST")
    private fun completeNegotiation(
        call: HttpClientCall
    ): List<WebSocketExtension<*>> {
        val serverExtensions: List<WebSocketExtensionHeader> = call.response
            .headers[HttpHeaders.SecWebSocketExtensions]
            ?.let { parseWebSocketExtensions(it) } ?: emptyList()

        val clientExtensions = call.attributes[REQUEST_EXTENSIONS_KEY]
        return clientExtensions.filter { it.clientNegotiation(serverExtensions) }
    }

    /**
     * Writes [protocols] into the `Sec-WebSocket-Extensions` request header; does nothing when the list is empty.
     *
     * RFC 6455 (section 9.1) defines this header as a comma-separated list of extensions, and `;`
     * separates the parameters of a single extension. Joining with `;` would make every extension
     * after the first look like a parameter of the first one.
     */
    private fun addNegotiatedProtocols(context: HttpRequestBuilder, protocols: List<WebSocketExtensionHeader>) {
        if (protocols.isEmpty()) return

        val headerValue = protocols.joinToString(",")
        context.header(HttpHeaders.SecWebSocketExtensions, headerValue)
    }

    /**
     * Returns [session] unchanged when it already is a [DefaultWebSocketSession]; this plugin's settings are
     * not applied to it in that case.
     *
     * Any other session is wrapped in a [DefaultWebSocketSession]. The wrapper uses [pingInterval], a timeout
     * of `pingInterval * 2` milliseconds, and this plugin's [maxFrameSize].
     */
    internal fun convertSessionToDefault(session: WebSocketSession): DefaultWebSocketSession {
        if (session is DefaultWebSocketSession) return session

        return DefaultWebSocketSession(session, pingInterval, timeoutMillis = pingInterval * 2).also {
            it.maxFrameSize = this@WebSockets.maxFrameSize
        }
    }

    /**
     * [WebSockets] configuration.
     */
    @KtorDsl
    public class Config {
        internal val extensionsConfig: WebSocketExtensionsConfig = WebSocketExtensionsConfig()

        /**
         * Sets interval of sending ping frames, in milliseconds.
         *
         * Value -1L is for disabled ping (the default).
         */
        public var pingInterval: Long = -1L

        /**
         * Sets maximum frame size in bytes. Defaults to [Int.MAX_VALUE].
         */
        public var maxFrameSize: Long = Int.MAX_VALUE.toLong()

        /**
         * A converter for serialization/deserialization; `null` (the default) means no converter.
         */
        public var contentConverter: WebsocketContentConverter? = null

        /**
         * Configures WebSocket extensions.
         *
         * May be called more than once; every [block] is applied to the same extensions configuration.
         */
        public fun extensions(block: WebSocketExtensionsConfig.() -> Unit) {
            extensionsConfig.apply(block)
        }
    }

    /**
     * Adds WebSockets support to the Ktor HTTP client.
     */
    public companion object Plugin : HttpClientPlugin<Config, WebSockets> {
        override val key: AttributeKey<WebSockets> = AttributeKey("Websocket")

        override fun prepare(block: Config.() -> Unit): WebSockets {
            val config = Config().apply(block)
            return WebSockets(
                config.pingInterval,
                config.maxFrameSize,
                config.extensionsConfig,
                config.contentConverter
            )
        }

        @OptIn(InternalAPI::class)
        override fun install(plugin: WebSockets, scope: HttpClient) {
            // Extensions are offered and negotiated only if the engine declares support for them.
            val extensionsSupported = scope.engine.supportedCapabilities.contains(WebSocketExtensionsCapability)

            // Render requests to WebSocket URLs as handshakes; every other request passes through untouched.
            scope.requestPipeline.intercept(HttpRequestPipeline.Render) {
                if (!context.url.protocol.isWebsocket()) {
                    LOGGER.trace("Skipping WebSocket plugin for non-websocket request: ${context.url}")
                    return@intercept
                }

                LOGGER.trace("Sending WebSocket request ${context.url}")
                context.setCapability(WebSocketCapability, Unit)

                if (extensionsSupported) {
                    plugin.installExtensions(context)
                }

                proceedWith(WebSocketContent())
            }

            // Validate the handshake response and hand the caller a client session of the requested type.
            scope.responsePipeline.intercept(HttpResponsePipeline.Transform) { (info, session) ->
                val response = this.context.response
                val status = response.status
                val requestContent = response.request.content

                // Only responses to requests rendered by the interceptor above carry WebSocketContent.
                if (requestContent !is WebSocketContent) {
                    LOGGER.trace("Skipping non-websocket response from ${context.request.url}: $session")
                    return@intercept
                }

                if (status != HttpStatusCode.SwitchingProtocols) {
                    throw WebSocketException(
                        "Handshake exception, expected status code ${HttpStatusCode.SwitchingProtocols.value} but was ${status.value}" // ktlint-disable max-line-length
                    )
                }

                if (session !is WebSocketSession) {
                    throw WebSocketException(
                        "Handshake exception, expected `WebSocketSession` content but was $session"
                    )
                }

                LOGGER.trace("Receive websocket session from ${context.request.url}: $session")

                val clientSession: ClientWebSocketSession = when (info.type) {
                    // Default sessions get the plugin's ping/frame settings and start with the negotiated extensions.
                    DefaultClientWebSocketSession::class -> {
                        val defaultSession = plugin.convertSessionToDefault(session)
                        val clientSession = DefaultClientWebSocketSession(context, defaultSession)

                        val negotiated = if (extensionsSupported) {
                            plugin.completeNegotiation(context)
                        } else {
                            emptyList()
                        }

                        clientSession.apply {
                            start(negotiated)
                        }
                    }
                    // Any other requested type receives a plain delegating wrapper around the engine session.
                    else -> DelegatingClientWebSocketSession(context, session)
                }

                proceedWith(HttpResponseContainer(info, clientSession))
            }
        }
    }
}
/**
 * Exception raised for WebSocket failures, for example when a handshake response is not
 * `101 Switching Protocols` or does not carry a WebSocket session.
 *
 * @param message description of the failure.
 * @param cause underlying error, or `null` when there is none.
 */
public class WebSocketException(message: String, cause: Throwable?) : IllegalStateException(message, cause) {
    /**
     * Creates the exception without an underlying cause.
     *
     * Kept as a separate constructor, rather than a default value for `cause`, for backwards binary compatibility.
     */
    public constructor(message: String) : this(message, null)
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy