
org.jetbrains.kotlinx.jupyter.connection.kt Maven / Gradle / Ivy
package org.jetbrains.kotlinx.jupyter
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import kotlinx.serialization.decodeFromString
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonNull
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.decodeFromJsonElement
import kotlinx.serialization.json.encodeToJsonElement
import kotlinx.serialization.json.jsonObject
import org.jetbrains.kotlinx.jupyter.exceptions.ReplException
import org.zeromq.SocketType
import org.zeromq.ZMQ
import java.io.Closeable
import java.io.IOException
import java.security.MessageDigest
import java.security.SignatureException
import javax.crypto.Mac
import javax.crypto.spec.SecretKeySpec
import kotlin.concurrent.thread
import kotlin.math.min
/**
 * Kernel-side end of a Jupyter connection.
 *
 * Creates one ZeroMQ context and binds the heartbeat, shell, control, stdin and iopub sockets on the
 * ports listed in [config]. It also provides:
 * - [stdinIn], which serves user input through the stdin socket;
 * - [runExecution], which runs code on a separate thread that [interruptExecution] can stop;
 * - [launchJob], a coroutine scope for background jobs.
 */
class JupyterConnection(val config: KernelConfig) : Closeable {
    /**
     * ZeroMQ socket for one kernel channel.
     *
     * On construction it binds to `${config.transport}://*:<port>`, where the port is
     * `config.ports[socket.ordinal]`. Outgoing messages are signed, and incoming messages are
     * verified, with the connection's HMAC.
     *
     * @param socket kernel channel served by this socket; also supplies [name] for log lines.
     * @param type ZeroMQ socket type; defaults to the channel's `zmqKernelType`.
     */
    inner class Socket(private val socket: JupyterSockets, type: SocketType = socket.zmqKernelType) : ZMQ.Socket(context, type) {
        val name: String get() = socket.name

        init {
            val port = config.ports[socket.ordinal]
            bind("${config.transport}://*:$port")
            if (type == SocketType.PUB) {
                // Workaround to prevent losing the first few messages on kernel startup.
                // For more information on losing messages see this scheme:
                // http://zguide.zeromq.org/page:all#Missing-Message-Problem-Solver
                // It seems we cannot do a correct sync because the messaging protocol
                // doesn't support it. The value of 500 ms was chosen experimentally.
                Thread.sleep(500)
            }
            log.debug("[$name] listen: ${config.transport}://*:$port")
        }

        /** Receives one frame and, if it is non-null, passes its raw bytes to [body]. */
        inline fun onData(body: Socket.(ByteArray) -> Unit) = recv()?.let { body(it) }

        /**
         * Receives one frame, reads the rest of the message that starts with it, and passes the
         * decoded [Message] to [body]. Messages whose signature fails verification are logged and
         * skipped.
         */
        inline fun onMessage(body: Socket.(Message) -> Unit) = recv()?.let { bytes -> receiveMessage(bytes)?.let { body(it) } }

        /** Publishes a status message carrying [status] on iopub, built as a reply to [msg]. */
        fun sendStatus(status: KernelStatus, msg: Message) {
            connection.iopub.send(makeReplyMessage(msg, MessageType.STATUS, content = StatusReply(status)))
        }

        /** Sends [msg] on this socket, preceded by a BUSY status and followed by an IDLE status for [incomingMessage]. */
        fun sendWrapped(incomingMessage: Message, msg: Message) {
            sendStatus(KernelStatus.BUSY, incomingMessage)
            send(msg)
            sendStatus(KernelStatus.IDLE, incomingMessage)
        }

        /** Sends a stream message for [stream] containing [text] on this socket, built as a reply to [msg]. */
        fun sendOut(msg: Message, stream: JupyterOutType, text: String) {
            send(makeReplyMessage(msg, header = makeHeader(MessageType.STREAM, msg), content = StreamResponse(stream.optionName(), text)))
        }

        /** Signs [msg] with the connection's HMAC and sends it on this socket as a multipart message. */
        fun send(msg: Message) {
            log.debug("[$name] snd>: $msg")
            sendMessage(msg, hmac)
        }

        /**
         * Reads the rest of a message whose first frame is [start] and verifies its signature.
         *
         * @return the decoded message, or `null` if the signature is invalid (the error is logged).
         */
        fun receiveMessage(start: ByteArray): Message? {
            return try {
                val msg = receiveMessage(start, hmac)
                log.debug("[$name] >rcv: $msg")
                msg
            } catch (e: SignatureException) {
                log.error("[$name] ${e.message}")
                null
            }
        }

        /** The [JupyterConnection] that owns this socket. */
        val connection: JupyterConnection = this@JupyterConnection
    }

    /**
     * [java.io.InputStream] fed by Jupyter input requests on the [stdin] socket.
     *
     * When no input is buffered, a read sends an input request (prompt `"stdin:"`) built as a reply
     * to [contextMessage], then waits for the reply on [stdin]. The reply's text is served as UTF-8
     * bytes. Once that text is exhausted, the next read returns -1 (end of this piece of input), and
     * the read after that requests new input.
     */
    inner class StdinInputStream : java.io.InputStream() {
        private var currentBuf: ByteArray? = null
        private var currentBufPos = 0

        /** Sends an input request and returns the text of the user's reply. */
        private fun getInput(): String {
            val parent = checkNotNull(contextMessage) { "contextMessage must be set before reading from stdin" }
            stdin.send(
                makeReplyMessage(
                    parent,
                    MessageType.INPUT_REQUEST,
                    content = InputRequest("stdin:"),
                ),
            )
            val msg = stdin.receiveMessage(stdin.recv())
            val content = msg?.data?.content as? InputReply
            return content?.value ?: throw UnsupportedOperationException("Unexpected input message $msg")
        }

        /** Returns the current buffer, requesting new input (and rewinding the position) if there is none. */
        private fun initializeCurrentBuf(): ByteArray =
            currentBuf ?: getInput().toByteArray().also { newBuf ->
                currentBuf = newBuf
                currentBufPos = 0
            }

        @Synchronized
        override fun read(): Int {
            val buf = initializeCurrentBuf()
            if (currentBufPos >= buf.size) {
                currentBuf = null
                return -1
            }
            // InputStream.read() must return a value in 0..255. Without the mask, bytes >= 0x80
            // (e.g. non-ASCII UTF-8) come back negative, and 0xFF would be mistaken for end-of-stream.
            return buf[currentBufPos++].toInt() and 0xFF
        }

        @Synchronized
        override fun read(b: ByteArray, off: Int, len: Int): Int {
            if (off < 0 || len < 0 || len > b.size - off) {
                throw IndexOutOfBoundsException("off=$off, len=$len, size=${b.size}")
            }
            // Per the InputStream contract a zero-length read returns 0; it must not prompt the user.
            if (len == 0) return 0
            val buf = initializeCurrentBuf()
            val lenLeft = buf.size - currentBufPos
            if (lenLeft <= 0) {
                currentBuf = null
                return -1
            }
            val lenToRead = min(len, lenLeft)
            buf.copyInto(b, destinationOffset = off, startIndex = currentBufPos, endIndex = currentBufPos + lenToRead)
            currentBufPos += lenToRead
            return lenToRead
        }
    }

    // Converts the connection-file scheme name (e.g. "hmac-sha256") into a JCA algorithm name.
    private val hmac = HMAC(config.signatureScheme.replace("-", ""), config.signatureKey)

    // Must be initialised before the sockets below, whose constructors use it.
    private val context = ZMQ.context(1)

    val heartbeat = Socket(JupyterSockets.HB)
    val shell = Socket(JupyterSockets.SHELL)
    val control = Socket(JupyterSockets.CONTROL)
    val stdin = Socket(JupyterSockets.STDIN)
    val iopub = Socket(JupyterSockets.IOPUB)

    /** Stream that user code can read stdin from; see [StdinInputStream]. */
    val stdinIn = StdinInputStream()

    /** Message that stdin input requests are sent in reply to; must be set before [stdinIn] is read. */
    var contextMessage: Message? = null

    // Threads started by runExecution. Guarded by its own monitor, because runExecution's caller
    // (blocked in join) and interruptExecution's caller are necessarily different threads.
    private val currentExecutions = HashSet<Thread>()

    private val coroutineScope = CoroutineScope(Dispatchers.Default)

    /**
     * Outcome of [runExecution].
     *
     * @property result value returned by the executed code, or `null` if it threw.
     * @property throwable exception thrown by the executed code, if any.
     * @property isInterrupted `true` if the code ended with [ThreadDeath], either directly or as the
     *   cause of a [ReplException] (presumably after [interruptExecution]).
     */
    data class ConnectionExecutionResult<T>(
        val result: T?,
        val throwable: Throwable?,
        val isInterrupted: Boolean,
    )

    /**
     * Runs [body] on a new thread and waits for it to finish.
     *
     * While the thread runs it is registered, so [interruptExecution] can stop it. Anything thrown
     * by [body] is captured in the result rather than propagated.
     */
    fun <T> runExecution(body: () -> T): ConnectionExecutionResult<T> {
        var execRes: T? = null
        var execException: Throwable? = null
        val execThread = thread {
            try {
                execRes = body()
            } catch (e: Throwable) {
                execException = e
            }
        }
        synchronized(currentExecutions) { currentExecutions.add(execThread) }
        try {
            execThread.join()
        } finally {
            // Unregister even if join() is interrupted, so the set never keeps finished threads.
            synchronized(currentExecutions) { currentExecutions.remove(execThread) }
        }
        // join() happens-before these reads, so the values written by execThread are visible here.
        val exception = execException
        val isInterrupted = exception is ThreadDeath ||
            (exception is ReplException && exception.cause is ThreadDeath)
        return ConnectionExecutionResult(execRes, exception, isInterrupted)
    }

    /**
     * Stops every execution currently running via [runExecution].
     *
     * We cannot use [Thread.interrupt] here because we have no way to control the code the user
     * executes. [Thread.interrupt] does nothing for a simple calculation (like `while (true) 1`).
     * Consider replacing this with something smarter in the future.
     */
    @Suppress("deprecation")
    fun interruptExecution() {
        // Detach all registered threads under the lock, then stop them.
        val executions = synchronized(currentExecutions) {
            currentExecutions.toList().also { currentExecutions.clear() }
        }
        executions.forEach { it.stop() }
    }

    /** Launches [runnable] as a coroutine in this connection's [Dispatchers.Default]-backed scope. */
    fun launchJob(runnable: suspend CoroutineScope.() -> Unit) {
        coroutineScope.launch(block = runnable)
    }

    /**
     * Closes the five sockets and then the ZeroMQ context.
     * Coroutines started through [launchJob] are not cancelled here.
     */
    override fun close() {
        heartbeat.close()
        shell.close()
        control.close()
        stdin.close()
        iopub.close()
        context.close()
    }
}
/**
 * Frame that separates the routing identities from the signed part of a Jupyter wire-protocol
 * message. The messaging specification fixes it to the ASCII string `<IDS|MSG>`; an empty value
 * would make the receiver stop at the first empty frame instead.
 */
private val MESSAGE_DELIMITER: ByteArray = "<IDS|MSG>".map { it.code.toByte() }.toByteArray()
/**
 * Computes Jupyter message signatures: the lowercase hex HMAC over the given byte chunks,
 * fed to the MAC in order.
 *
 * Signing is disabled when [key] is null or blank. In that case every invocation returns `null`.
 *
 * @param algorithm JCA MAC algorithm name (for example `HmacSHA256`).
 * @param key shared signing key; its UTF-8 bytes are used as the secret.
 */
class HMAC(algorithm: String, key: String?) {
    // null when signing is disabled; otherwise a Mac already initialised with the key.
    private val mac: Mac? = key
        ?.takeIf { it.isNotBlank() }
        ?.let { secret ->
            Mac.getInstance(algorithm).also { it.init(SecretKeySpec(secret.toByteArray(), algorithm)) }
        }

    /**
     * Returns the hex signature over all [data] chunks in order, or `null` if signing is disabled.
     * Synchronized because the single [Mac] instance accumulates state between `update` and `doFinal`.
     */
    @Synchronized
    operator fun invoke(data: Iterable<ByteArray>): String? =
        mac?.let { m ->
            data.forEach { chunk -> m.update(chunk) }
            m.doFinal().toHexString()
        }

    /** Vararg convenience overload of [invoke]. */
    operator fun invoke(vararg data: ByteArray): String? = invoke(data.asIterable())
}
/**
 * Renders this byte array as lowercase hexadecimal, two digits per byte and no separators.
 * For example, bytes `0x0f, 0xa0` become `"0fa0"`. An empty array yields `""`.
 */
fun ByteArray.toHexString(): String {
    val hexDigits = "0123456789abcdef"
    val out = StringBuilder(size * 2)
    for (b in this) {
        val unsigned = b.toInt() and 0xFF
        out.append(hexDigits[unsigned ushr 4])
        out.append(hexDigits[unsigned and 0x0F])
    }
    return out.toString()
}
/**
 * Sends [msg] on this socket as one multipart Jupyter wire-protocol message. The frames are:
 * 1. the routing identities from `msg.id`;
 * 2. [MESSAGE_DELIMITER];
 * 3. the [hmac] signature, or an empty string when signing is disabled;
 * 4. the JSON header, parent header, metadata and content, in that order.
 *
 * Any of the four JSON fields missing from the encoded message data is sent as
 * `emptyJsonObjectStringBytes`. The whole sequence runs while holding this socket's monitor, so
 * concurrent callers of this function cannot interleave frames.
 */
fun ZMQ.Socket.sendMessage(msg: Message, hmac: HMAC) {
    synchronized(this) {
        for (identity in msg.id) {
            sendMore(identity)
        }
        sendMore(MESSAGE_DELIMITER)

        val encoded = Json.encodeToJsonElement(msg.data).jsonObject
        val frames = listOf("header", "parent_header", "metadata", "content").map { field ->
            val element = encoded[field]
            if (element == null) emptyJsonObjectStringBytes else Json.encodeToString(element).toByteArray()
        }

        val signature = hmac(frames) ?: ""
        sendMore(signature)

        // Every frame except the last carries SNDMORE; the final send completes the message.
        val lastIndex = frames.lastIndex
        frames.forEachIndexed { index, frame ->
            if (index == lastIndex) send(frame) else sendMore(frame)
        }
    }
}
/**
 * Reads the rest of a Jupyter wire-protocol message whose first frame, [start], has already been
 * received. Verifies the signature with [hmac] and decodes the message.
 *
 * Expected frame layout: zero or more routing identities, [MESSAGE_DELIMITER], then the signature,
 * header, parent header, metadata and content. An empty JSON object in the header or parent header
 * becomes `JsonNull`. Empty or null metadata/content becomes `emptyJsonObject`.
 *
 * @throws SignatureException if signing is enabled and the received signature does not match.
 */
fun ZMQ.Socket.receiveMessage(start: ByteArray, hmac: HMAC): Message {
    val ids = if (start.contentEquals(MESSAGE_DELIMITER)) {
        // No routing identities: the first frame already is the delimiter.
        emptyList<ByteArray>()
    } else {
        // Identities run up to the delimiter frame, which is consumed here but not kept.
        listOf(start) + generateSequence { recv() }.takeWhile { !it.contentEquals(MESSAGE_DELIMITER) }
    }
    // recvStr() returns a Java platform type; treat a missing frame as an empty (hence invalid) signature.
    val sig = recvStr()?.lowercase() ?: ""
    val header = recv()
    val parentHeader = recv()
    val metadata = recv()
    val content = recv()
    val calculatedSig = hmac(header, parentHeader, metadata, content)
    // Constant-time comparison, so the check does not leak how much of a forged signature matched.
    if (calculatedSig != null && !MessageDigest.isEqual(sig.toByteArray(), calculatedSig.toByteArray())) {
        throw SignatureException("Invalid signature: expected $calculatedSig, received $sig - $ids")
    }

    fun ByteArray.parseJson(): JsonElement {
        val json = Json.decodeFromString<JsonElement>(this.toString(Charsets.UTF_8))
        return if (json is JsonObject && json.isEmpty()) JsonNull else json
    }

    fun JsonElement.orEmptyObject() = if (this is JsonNull) emptyJsonObject else this

    val dataJson = jsonObject(
        "header" to header.parseJson(),
        "parent_header" to parentHeader.parseJson(),
        "metadata" to metadata.parseJson().orEmptyObject(),
        "content" to content.parseJson().orEmptyObject(),
    )
    // The target type for decoding is inferred from Message's data parameter.
    return Message(
        ids,
        Json.decodeFromJsonElement(dataJson),
    )
}
/**
 * Stdin replacement for clients that cannot provide input.
 * Reading a byte always fails with an [IOException].
 */
object DisabledStdinInputStream : java.io.InputStream() {
    override fun read(): Int =
        throw IOException("Input from stdin is unsupported by the client")
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy