// connection.kt — package org.jetbrains.kotlinx.jupyter.
// Part of kotlin-jupyter-kernel ("Kotlin Jupyter kernel published as artifact"), available via Maven / Gradle / Ivy.
// Navigation text from the artifact-browser page this file was copied from has been dropped.
package org.jetbrains.kotlinx.jupyter
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import kotlinx.serialization.decodeFromString
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonNull
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.decodeFromJsonElement
import kotlinx.serialization.json.encodeToJsonElement
import kotlinx.serialization.json.jsonObject
import org.jetbrains.kotlinx.jupyter.exceptions.ReplException
import org.zeromq.SocketType
import org.zeromq.ZMQ
import java.io.Closeable
import java.io.IOException
import java.security.SignatureException
import javax.crypto.Mac
import javax.crypto.spec.SecretKeySpec
import kotlin.concurrent.thread
import kotlin.math.min
/**
 * Holds the ZeroMQ context and the kernel-side sockets (heartbeat, shell, control, stdin, iopub)
 * described by [config], plus helpers for running and interrupting user code on a separate thread.
 *
 * Every socket is bound in its constructor, so creating a [JupyterConnection] binds all five ports
 * in declaration order. Call [close] to release the sockets and the ZeroMQ context.
 */
class JupyterConnection(val config: KernelConfig) : Closeable {

    /**
     * A ZeroMQ socket bound to the port that [config] assigns to [socket].
     *
     * @param socket which of the kernel's sockets this instance represents; its `ordinal` indexes `config.ports`
     * @param type ZeroMQ socket type, by default the kernel-side type declared by [socket]
     */
    inner class Socket(private val socket: JupyterSockets, type: SocketType = socket.zmqKernelType) : ZMQ.Socket(context, type) {
        /** Name of the underlying [JupyterSockets] entry; used as a prefix in log lines. */
        val name: String get() = socket.name

        /** The connection that owns this socket. */
        val connection: JupyterConnection = this@JupyterConnection

        init {
            val port = config.ports[socket.ordinal]
            bind("${config.transport}://*:$port")
            if (type == SocketType.PUB) {
                // Workaround to prevent losing few first messages on kernel startup
                // For more information on losing messages see this scheme:
                // http://zguide.zeromq.org/page:all#Missing-Message-Problem-Solver
                // It seems we cannot do correct sync because messaging protocol
                // doesn't support this. Value of 500 ms was chosen experimentally.
                Thread.sleep(500)
            }
            log.debug("[$name] listen: ${config.transport}://*:$port")
        }

        /** Receives one frame and, if `recv()` returned non-null, passes it to [body]. */
        inline fun onData(body: Socket.(ByteArray) -> Unit) = recv()?.let { body(it) }

        /**
         * Receives one frame, reads the rest of the message via [receiveMessage] and passes the
         * parsed message to [body]. Nothing happens if no frame arrived or the signature was rejected.
         */
        inline fun onMessage(body: Socket.(Message) -> Unit) = recv()?.let { bytes -> receiveMessage(bytes)?.let { body(it) } }

        /** Publishes a status reply for [msg] carrying [status] on the connection's iopub socket. */
        fun sendStatus(status: KernelStatus, msg: Message) {
            connection.iopub.send(makeReplyMessage(msg, MessageType.STATUS, content = StatusReply(status)))
        }

        /** Sends [msg] on this socket, bracketed by BUSY and IDLE status updates for [incomingMessage]. */
        fun sendWrapped(incomingMessage: Message, msg: Message) {
            sendStatus(KernelStatus.BUSY, incomingMessage)
            send(msg)
            sendStatus(KernelStatus.IDLE, incomingMessage)
        }

        /** Sends [text] as a stream message for output channel [stream], built as a reply to [msg]. */
        fun sendOut(msg: Message, stream: JupyterOutType, text: String) {
            send(makeReplyMessage(msg, header = makeHeader(MessageType.STREAM, msg), content = StreamResponse(stream.optionName(), text)))
        }

        /** Logs [msg] at debug level and sends it using the connection's HMAC. */
        fun send(msg: Message) {
            log.debug("[$name] snd>: $msg")
            sendMessage(msg, hmac)
        }

        /**
         * Reads the remaining frames of a message whose first frame is [start].
         *
         * @return the parsed message, or `null` if signature verification failed (the failure is logged)
         */
        fun receiveMessage(start: ByteArray): Message? {
            return try {
                val msg = receiveMessage(start, hmac)
                log.debug("[$name] >rcv: $msg")
                msg
            } catch (e: SignatureException) {
                log.error("[$name] ${e.message}")
                null
            }
        }
    }

    /**
     * An input stream that obtains its data from the frontend: when no buffered input is available,
     * an input request is sent on the stdin socket and the reply value becomes the next buffer.
     *
     * Once a buffer has been fully consumed the next read reports end-of-stream (-1) and discards
     * the buffer; the read after that requests fresh input.
     */
    inner class StdinInputStream : java.io.InputStream() {
        private var currentBuf: ByteArray? = null
        private var currentBufPos = 0

        /**
         * Sends an input request (as a reply to [contextMessage]) and blocks for the answer.
         *
         * @throws UnsupportedOperationException if the received message is missing or is not an input reply
         */
        private fun getInput(): String {
            stdin.send(
                makeReplyMessage(
                    // contextMessage must have been set before user code reads from stdin.
                    contextMessage!!,
                    MessageType.INPUT_REQUEST,
                    content = InputRequest("stdin:")
                )
            )
            val msg = stdin.receiveMessage(stdin.recv())
            val content = msg?.data?.content as? InputReply
            return content?.value ?: throw UnsupportedOperationException("Unexpected input message $msg")
        }

        /** Returns the current buffer, requesting new input from the frontend if there is none. */
        private fun initializeCurrentBuf(): ByteArray =
            currentBuf ?: getInput().toByteArray().also { fresh ->
                currentBuf = fresh
                currentBufPos = 0
            }

        /**
         * Reads a single byte.
         *
         * @return the byte as an unsigned value in `0..255`, or -1 when the current buffer is exhausted
         */
        @Synchronized
        override fun read(): Int {
            val buf = initializeCurrentBuf()
            if (currentBufPos >= buf.size) {
                currentBuf = null
                return -1
            }
            // Mask to an unsigned value: a plain toInt() sign-extends, so bytes >= 0x80 would come
            // back negative and 0xFF would be indistinguishable from the end-of-stream marker.
            return buf[currentBufPos++].toInt() and 0xFF
        }

        /**
         * Copies up to [len] buffered bytes into [b] starting at [off].
         *
         * @return number of bytes copied, 0 if [len] is 0, or -1 when the current buffer is exhausted
         * @throws IndexOutOfBoundsException if [off] or [len] is negative or the range does not fit in [b]
         */
        @Synchronized
        override fun read(b: ByteArray, off: Int, len: Int): Int {
            if (off < 0 || len < 0 || len > b.size - off) {
                throw IndexOutOfBoundsException("off=$off, len=$len, size=${b.size}")
            }
            // Per the InputStream contract a zero-length read returns 0; do not block asking for input.
            if (len == 0) return 0

            val buf = initializeCurrentBuf()
            val lenLeft = buf.size - currentBufPos
            if (lenLeft <= 0) {
                currentBuf = null
                return -1
            }
            val lenToRead = min(len, lenLeft)
            buf.copyInto(b, destinationOffset = off, startIndex = currentBufPos, endIndex = currentBufPos + lenToRead)
            currentBufPos += lenToRead
            return lenToRead
        }
    }

    // hmac and context are declared before the sockets on purpose: each Socket uses them
    // as soon as it is constructed, and properties are initialized in declaration order.
    private val hmac = HMAC(config.signatureScheme.replace("-", ""), config.signatureKey)
    private val context = ZMQ.context(1)

    val heartbeat = Socket(JupyterSockets.HB)
    val shell = Socket(JupyterSockets.SHELL)
    val control = Socket(JupyterSockets.CONTROL)
    val stdin = Socket(JupyterSockets.STDIN)
    val iopub = Socket(JupyterSockets.IOPUB)

    val stdinIn = StdinInputStream()

    /** Message used as the parent when [StdinInputStream] builds an input request. */
    var contextMessage: Message? = null

    // Threads currently running user code. runExecution blocks its caller in join(), so
    // interruptExecution can only be useful when called from another thread; the set therefore
    // has to tolerate concurrent access.
    private val currentExecutions: MutableSet<Thread> = java.util.concurrent.ConcurrentHashMap.newKeySet<Thread>()

    private val coroutineScope = CoroutineScope(Dispatchers.Default)

    /**
     * Outcome of [runExecution].
     *
     * @property result value returned by the executed body, or `null` if it did not complete normally
     * @property throwable whatever the body threw, or `null`
     * @property isInterrupted `true` when [throwable] is a [ThreadDeath] or a [ReplException] caused by one
     */
    data class ConnectionExecutionResult<T>(
        val result: T?,
        val throwable: Throwable?,
        val isInterrupted: Boolean,
    )

    /**
     * Runs [body] on a freshly started thread and blocks until that thread finishes.
     *
     * The thread is registered for the duration of the run so that [interruptExecution] can stop it.
     * Anything thrown by [body] is captured in the result instead of being rethrown.
     */
    fun <T> runExecution(body: () -> T): ConnectionExecutionResult<T> {
        var execRes: T? = null
        var execException: Throwable? = null
        val execThread = thread {
            try {
                execRes = body()
            } catch (e: Throwable) {
                execException = e
            }
        }
        currentExecutions.add(execThread)
        try {
            execThread.join()
        } finally {
            // Unregister even if join() itself is interrupted, so no stale thread stays in the set.
            currentExecutions.remove(execThread)
        }

        val exception = execException
        val isInterrupted = exception is ThreadDeath ||
            (exception is ReplException && exception.cause is ThreadDeath)
        return ConnectionExecutionResult(execRes, exception, isInterrupted)
    }

    /**
     * We cannot use [Thread.interrupt] here because we have no way
     * to control the code user executes. [Thread.interrupt] will do nothing for
     * the simple calculation (like `while (true) 1`). Consider replacing with
     * something more smart in the future.
     *
     * NOTE(review): [Thread.stop] is deprecated and throws `UnsupportedOperationException`
     * on JDK 20 and later — confirm which runtimes must be supported.
     */
    @Suppress("deprecation")
    fun interruptExecution() {
        while (true) {
            val execution = currentExecutions.firstOrNull() ?: break
            execution.stop()
            currentExecutions.remove(execution)
        }
    }

    /** Launches [runnable] in this connection's coroutine scope (on [Dispatchers.Default]). */
    fun launchJob(runnable: suspend CoroutineScope.() -> Unit) {
        coroutineScope.launch(block = runnable)
    }

    /** Closes the five sockets and then the ZeroMQ context. */
    override fun close() {
        heartbeat.close()
        shell.close()
        control.close()
        stdin.close()
        iopub.close()
        context.close()
    }
}
/**
 * Frame that separates the routing identities from the signed part of a Jupyter wire-protocol
 * message. The protocol defines it as the ASCII text `<IDS|MSG>`.
 */
private val MESSAGE_DELIMITER: ByteArray = "<IDS|MSG>".toByteArray(Charsets.US_ASCII)
/**
 * Computes hex-encoded MAC signatures over sequences of byte chunks.
 *
 * Signing is disabled when [key] is `null` or blank: in that case every call returns `null`.
 *
 * @param algorithm name passed to [Mac.getInstance] and [SecretKeySpec]
 * @param key secret key text, or `null`/blank to disable signing
 */
class HMAC(algorithm: String, key: String?) {
    // Built and initialized in one step, so a MAC instance exists only when a usable key does.
    private val mac: Mac? = key
        ?.takeIf { it.isNotBlank() }
        ?.let { secret ->
            Mac.getInstance(algorithm).apply { init(SecretKeySpec(secret.toByteArray(), algorithm)) }
        }

    /**
     * Signs the concatenation of all chunks in [data].
     *
     * Synchronized because the underlying [Mac] accumulates state between `update` and `doFinal`.
     *
     * @return lowercase hex signature, or `null` when signing is disabled
     */
    @Synchronized
    operator fun invoke(data: Iterable<ByteArray>): String? =
        mac?.let { engine ->
            data.forEach { chunk -> engine.update(chunk) }
            engine.doFinal().toHexString()
        }

    /** Vararg convenience overload; see the [Iterable] variant. */
    operator fun invoke(vararg data: ByteArray): String? = invoke(data.asIterable())
}
/**
 * Renders this array as lowercase hexadecimal text, two characters per byte,
 * with no separators. An empty array yields an empty string.
 */
fun ByteArray.toHexString(): String {
    val hex = StringBuilder()
    for (octet in this) {
        hex.append("%02x".format(octet))
    }
    return hex.toString()
}
/**
 * Writes [msg] to this socket as one multipart message.
 *
 * Frame order: every entry of `msg.id`, the delimiter frame, the signature (an empty string when
 * [hmac] yields `null`), then the JSON text of `header`, `parent_header`, `metadata` and `content`.
 * A field missing from the serialized message data is sent as `emptyJsonObjectStringBytes`.
 * The whole send happens while holding this socket's monitor.
 */
fun ZMQ.Socket.sendMessage(msg: Message, hmac: HMAC) {
    synchronized(this) {
        for (identity in msg.id) {
            sendMore(identity)
        }
        sendMore(MESSAGE_DELIMITER)

        val envelope = Json.encodeToJsonElement(msg.data).jsonObject
        val frames = listOf("header", "parent_header", "metadata", "content").map { key ->
            val element = envelope[key]
            if (element != null) Json.encodeToString(element).toByteArray() else emptyJsonObjectStringBytes
        }

        val signature = hmac(frames)
        sendMore(signature ?: "")

        // Every frame but the last carries the "more" flag; the last one completes the message.
        val lastIndex = frames.size - 1
        for (index in 0 until lastIndex) {
            sendMore(frames[index])
        }
        send(frames.last())
    }
}
/**
 * Reads the rest of a multipart message whose first frame, [start], has already been received.
 *
 * Frames are consumed in this order: identity frames up to the delimiter, the signature, then
 * header, parent header, metadata and content. When [hmac] produces a signature (i.e. signing is
 * enabled) it must match the received one.
 *
 * @param start first frame of the message, treated as the first identity
 * @param hmac signer used to verify the four payload frames
 * @return the parsed message together with its identity frames
 * @throws SignatureException if the received signature does not match the computed one
 */
fun ZMQ.Socket.receiveMessage(start: ByteArray, hmac: HMAC): Message {
    val ids = listOf(start) + generateSequence { recv() }.takeWhile { !it.contentEquals(MESSAGE_DELIMITER) }
    val sig = recvStr().lowercase()
    val header = recv()
    val parentHeader = recv()
    val metadata = recv()
    val content = recv()

    val calculatedSig = hmac(header, parentHeader, metadata, content)
    // The signature comes from the network, so compare in constant time: an ordinary string
    // comparison stops at the first differing character and can leak how much of a forged
    // signature was correct.
    if (calculatedSig != null &&
        !java.security.MessageDigest.isEqual(sig.toByteArray(), calculatedSig.toByteArray())
    ) {
        throw SignatureException("Invalid signature: expected $calculatedSig, received $sig - $ids")
    }

    // Parses a frame as JSON; an empty object is normalized to JsonNull.
    fun ByteArray.parseJson(): JsonElement {
        val json = Json.decodeFromString<JsonElement>(this.toString(Charsets.UTF_8))
        return if (json is JsonObject && json.isEmpty()) JsonNull else json
    }

    // Metadata and content are turned back into an empty object after normalization.
    fun JsonElement.orEmptyObject() = if (this is JsonNull) emptyJsonObject else this

    val dataJson = jsonObject(
        "header" to header.parseJson(),
        "parent_header" to parentHeader.parseJson(),
        "metadata" to metadata.parseJson().orEmptyObject(),
        "content" to content.parseJson().orEmptyObject()
    )

    // The target type of the decode is inferred from Message's second constructor parameter.
    return Message(
        ids,
        Json.decodeFromJsonElement(dataJson)
    )
}
/**
 * Stand-in input stream whose single-byte read always fails.
 *
 * @throws IOException on every call to [read]
 */
object DisabledStdinInputStream : java.io.InputStream() {
    override fun read(): Int =
        throw IOException("Input from stdin is unsupported by the client")
}
// Footer of the hosting web page (not part of the original source): © 2015 - 2025 Weber Informatics LLC | Privacy Policy