All downloads are free. Search and download functionality uses the official Maven repository.

org.jetbrains.kotlinx.jupyter.test.protocol.executeTests.kt Maven / Gradle / Ivy

package org.jetbrains.kotlinx.jupyter.test.protocol

import ch.qos.logback.classic.Level.DEBUG
import ch.qos.logback.classic.Level.OFF
import io.kotest.matchers.paths.shouldBeAFile
import io.kotest.matchers.shouldBe
import io.kotest.matchers.string.shouldContain
import io.kotest.matchers.types.shouldBeTypeOf
import jupyter.kotlin.providers.UserHandlesProvider
import kotlinx.serialization.json.JsonNull
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonPrimitive
import kotlinx.serialization.json.decodeFromJsonElement
import kotlinx.serialization.json.jsonPrimitive
import org.jetbrains.kotlinx.jupyter.LoggingManagement.mainLoggerLevel
import org.jetbrains.kotlinx.jupyter.api.MimeTypes
import org.jetbrains.kotlinx.jupyter.api.Notebook
import org.jetbrains.kotlinx.jupyter.api.SessionOptions
import org.jetbrains.kotlinx.jupyter.compiler.util.EvaluatedSnippetMetadata
import org.jetbrains.kotlinx.jupyter.config.currentKotlinVersion
import org.jetbrains.kotlinx.jupyter.messaging.CommMsg
import org.jetbrains.kotlinx.jupyter.messaging.CommOpen
import org.jetbrains.kotlinx.jupyter.messaging.DisplayDataResponse
import org.jetbrains.kotlinx.jupyter.messaging.EXECUTION_INTERRUPTED_MESSAGE
import org.jetbrains.kotlinx.jupyter.messaging.ExecuteReply
import org.jetbrains.kotlinx.jupyter.messaging.ExecuteRequest
import org.jetbrains.kotlinx.jupyter.messaging.ExecutionResultMessage
import org.jetbrains.kotlinx.jupyter.messaging.InputReply
import org.jetbrains.kotlinx.jupyter.messaging.InterruptRequest
import org.jetbrains.kotlinx.jupyter.messaging.IsCompleteReply
import org.jetbrains.kotlinx.jupyter.messaging.IsCompleteRequest
import org.jetbrains.kotlinx.jupyter.messaging.KernelStatus
import org.jetbrains.kotlinx.jupyter.messaging.Message
import org.jetbrains.kotlinx.jupyter.messaging.MessageStatus
import org.jetbrains.kotlinx.jupyter.messaging.MessageType
import org.jetbrains.kotlinx.jupyter.messaging.OpenDebugPortReply
import org.jetbrains.kotlinx.jupyter.messaging.ProvidedCommMessages
import org.jetbrains.kotlinx.jupyter.messaging.StatusReply
import org.jetbrains.kotlinx.jupyter.messaging.StreamResponse
import org.jetbrains.kotlinx.jupyter.messaging.jsonObject
import org.jetbrains.kotlinx.jupyter.protocol.JupyterSocket
import org.jetbrains.kotlinx.jupyter.protocol.JupyterSocketInfo
import org.jetbrains.kotlinx.jupyter.protocol.MessageFormat
import org.jetbrains.kotlinx.jupyter.test.NotebookMock
import org.jetbrains.kotlinx.jupyter.test.assertStartsWith
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertNull
import org.junit.jupiter.api.Disabled
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.Timeout
import org.junit.jupiter.api.parallel.Execution
import org.junit.jupiter.api.parallel.ExecutionMode
import org.zeromq.ZMQ
import java.io.File
import java.net.URLClassLoader
import java.nio.file.Files
import java.util.concurrent.TimeUnit
import kotlin.io.path.readText
import kotlin.reflect.KProperty1
import kotlin.reflect.full.memberProperties
import kotlin.test.assertNotNull
import kotlin.test.assertTrue

/**
 * Returns the string content of the JSON primitive stored under [key].
 *
 * A [JsonNull] entry yields its textual content (`"null"`), as before.
 *
 * @throws NoSuchElementException if this object has no entry for [key].
 * @throws IllegalArgumentException if the entry is not a [JsonPrimitive] (e.g. a nested object or array).
 */
fun JsonObject.string(key: String): String {
    // getValue/jsonPrimitive name the missing key or the offending element type,
    // unlike a bare cast of a possibly-null lookup.
    return getValue(key).jsonPrimitive.content
}

/**
 * Protocol-level tests for the kernel. They use client sockets created through
 * [KernelServerTestsBase.createClientSocket] for the shell, control, iopub and stdin channels.
 *
 * Every test gets a fresh ZMQ context and freshly connected sockets (see [beforeEach]).
 * Each test is bounded by a 100-second timeout, and tests in this class run on the same thread.
 */
@Timeout(100, unit = TimeUnit.SECONDS)
@Execution(ExecutionMode.SAME_THREAD)
class ExecuteTests : KernelServerTestsBase() {
    private var _context: ZMQ.Context? = null
    override val context: ZMQ.Context
        get() = checkNotNull(_context) { "ZMQ context is not initialized (beforeEach not run or already torn down)" }

    private var shell: JupyterSocket? = null
    private var control: JupyterSocket? = null
    private var ioPub: JupyterSocket? = null
    private var stdin: JupyterSocket? = null

    /**
     * Creates a new ZMQ context and connects client sockets for all four kernel channels.
     *
     * The iopub socket subscribes with an empty prefix, i.e. to every message. If anything fails,
     * the partially created resources are released before the error is rethrown.
     */
    override fun beforeEach() {
        tearDownOnFailure {
            _context = ZMQ.context(1)
            shell = createClientSocket(JupyterSocketInfo.SHELL).apply {
                makeRelaxed()
            }
            ioPub = createClientSocket(JupyterSocketInfo.IOPUB)
            stdin = createClientSocket(JupyterSocketInfo.STDIN)
            control = createClientSocket(JupyterSocketInfo.CONTROL)

            ioPub?.subscribe(byteArrayOf())
            shell?.connect()
            ioPub?.connect()
            stdin?.connect()
            control?.connect()
        }
    }

    /**
     * Closes all client sockets and terminates the ZMQ context.
     *
     * This is safe to call more than once. The helpers in this class call it when a request
     * fails, and the base class presumably calls it again after the test, so an
     * already-released context is not an error.
     */
    override fun afterEach() {
        listOf(::shell, ::ioPub, ::stdin, ::control).forEach { socketProp ->
            socketProp.get()?.close()
            socketProp.set(null)
        }
        // Go through the nullable backing field: `context` throws once teardown has already run.
        _context?.term()
        _context = null
    }

    /**
     * Runs [action]. If it throws, releases sockets and context via [afterEach] and rethrows the
     * original error. A failure during that cleanup is attached as a suppressed exception
     * instead of hiding the root cause.
     */
    private inline fun <R> tearDownOnFailure(action: () -> R): R {
        try {
            return action()
        } catch (e: Throwable) {
            try {
                afterEach()
            } catch (cleanupError: Throwable) {
                e.addSuppressed(cleanupError)
            }
            throw e
        }
    }

    /**
     * Sends an `execute_request` for [code] on the shell channel and checks the expected replies.
     *
     * The expected sequence is `execute_reply` on shell, then on iopub: `status: busy`,
     * `execute_input`, whatever [ioPubChecker] consumes, an optional `execute_result`, and
     * finally `status: idle`.
     *
     * @param code the snippet to execute.
     * @param hasResult whether an `execute_result` message is expected on iopub.
     * @param ioPubChecker reads and checks any extra iopub messages emitted between
     *   `execute_input` and the result/idle status.
     * @param executeRequestSent hook invoked right after the request has been sent.
     * @param executeReplyChecker inspects the `execute_reply` message received on shell.
     * @param inputs values sent, in order, as `input_reply` messages on stdin right after the request.
     * @param allowStdin forwarded to [ExecuteRequest].
     * @param storeHistory forwarded to [ExecuteRequest].
     * @return the `data` of the `execute_result` message, or `null` when [hasResult] is false.
     */
    private fun doExecute(
        code: String,
        hasResult: Boolean = true,
        ioPubChecker: (JupyterSocket) -> Unit = {},
        executeRequestSent: () -> Unit = {},
        executeReplyChecker: (Message) -> Unit = {},
        inputs: List<String> = emptyList(),
        allowStdin: Boolean = true,
        storeHistory: Boolean = true,
    ): Any? = tearDownOnFailure {
        val shell = checkNotNull(this.shell) { "shell socket is not connected" }
        val ioPub = checkNotNull(this.ioPub) { "iopub socket is not connected" }
        val stdin = checkNotNull(this.stdin) { "stdin socket is not connected" }
        shell.sendMessage(MessageType.EXECUTE_REQUEST, content = ExecuteRequest(code, allowStdin = allowStdin, storeHistory = storeHistory))
        executeRequestSent()
        inputs.forEach {
            stdin.sendMessage(MessageType.INPUT_REPLY, InputReply(it))
        }

        var msg = shell.receiveMessage()
        assertEquals(MessageType.EXECUTE_REPLY, msg.type)
        executeReplyChecker(msg)

        msg = ioPub.receiveMessage()
        assertEquals(MessageType.STATUS, msg.type)
        assertEquals(KernelStatus.BUSY, (msg.content as StatusReply).status)
        msg = ioPub.receiveMessage()
        assertEquals(MessageType.EXECUTE_INPUT, msg.type)

        ioPubChecker(ioPub)

        var response: Any? = null
        if (hasResult) {
            msg = ioPub.receiveMessage()
            val content = msg.content as ExecutionResultMessage
            assertEquals(MessageType.EXECUTE_RESULT, msg.type)
            response = content.data
        }

        msg = ioPub.receiveMessage()
        assertEquals(MessageType.STATUS, msg.type)
        assertEquals(KernelStatus.IDLE, (msg.content as StatusReply).status)
        response
    }

    /**
     * Executes [code] with stdin disallowed.
     *
     * Expects no result and a single stream message starting with the "unsupported by the
     * client" notice.
     */
    private fun executeWithNoStdin(code: String) {
        doExecute(
            code,
            hasResult = false,
            allowStdin = false,
            ioPubChecker = {
                val msg = it.receiveMessage()
                assertEquals(MessageType.STREAM, msg.type)
                assertStartsWith("Input from stdin is unsupported by the client", (msg.content as StreamResponse).text)
            },
        )
    }

    /** Sends an `is_complete_request` for [code] and returns the `status` field of the reply. */
    private fun doIsComplete(code: String): String = tearDownOnFailure {
        val shell = checkNotNull(this.shell) { "shell socket is not connected" }
        shell.sendMessage(MessageType.IS_COMPLETE_REQUEST, content = IsCompleteRequest(code))

        val responseMsg = shell.receiveMessage()
        assertEquals(MessageType.IS_COMPLETE_REPLY, responseMsg.type)

        val content = responseMsg.content as IsCompleteReply
        content.status
    }

    /** Sends an `interrupt_request` on the control channel. */
    private fun interruptExecution() {
        checkNotNull(control) { "control socket is not connected" }
            .sendMessage(MessageType.INTERRUPT_REQUEST, InterruptRequest())
    }

    /**
     * Receives the next message, asserts it has type [messageType], and returns its content.
     * The content must be an instance of [T].
     */
    private inline fun <reified T : Any> JupyterSocket.receiveMessageOfType(messageType: MessageType): T {
        val msg = receiveMessage()
        assertEquals(messageType, msg.type)
        return msg.content.shouldBeTypeOf<T>()
    }

    /** Receives the next message as a `stream` message and returns its text. */
    private fun JupyterSocket.receiveStreamResponse(): String {
        return receiveMessageOfType<StreamResponse>(MessageType.STREAM).text
    }

    /** Receives the next message as a `display_data` message. */
    private fun JupyterSocket.receiveDisplayDataResponse(): DisplayDataResponse {
        return receiveMessageOfType<DisplayDataResponse>(MessageType.DISPLAY_DATA)
    }

    @Test
    fun testExecute() {
        val res = doExecute("2+2") as JsonObject
        assertEquals("4", res.string(MimeTypes.PLAIN_TEXT))
    }

    /** Output printed during execution arrives as one stream message per `print` call. */
    @Test
    fun testOutput() {
        val code =
            """
            for (i in 1..5) {
                Thread.sleep(200)
                print(i)
            }
            """.trimIndent()

        fun checker(ioPub: JupyterSocket) {
            for (i in 1..5) {
                val msg = ioPub.receiveMessage()
                assertEquals(MessageType.STREAM, msg.type)
                assertEquals(i.toString(), (msg.content as StreamResponse).text)
            }
        }

        val res = doExecute(code, hasResult = false, ioPubChecker = ::checker)
        assertNull(res)
    }

    /** `%output --max-buffer=2` flushes stdout in chunks of at most two characters. */
    @Test
    fun testOutputMagic() {
        val code =
            """
            %output --max-buffer=2 --max-time=10000
            for (i in 1..5) {
                print(i)
            }
            """.trimIndent()

        val expected = arrayOf("12", "34", "5")

        fun checker(ioPub: JupyterSocket) {
            for (el in expected) {
                val msgText = ioPub.receiveStreamResponse()
                assertEquals(el, msgText)
            }
        }

        val res = doExecute(code, hasResult = false, ioPubChecker = ::checker)
        assertNull(res)
    }

    @Test
    fun testOutputStrings() {
        val repetitions = 5
        val code =
            """
            for (i in 1..$repetitions) {
                Thread.sleep(200)
                println("text" + i)
            }
            """.trimIndent()

        fun checker(ioPub: JupyterSocket) {
            val lineSeparator = System.lineSeparator()
            val actualText = (1..repetitions).joinToString("") { ioPub.receiveStreamResponse() }
            val expectedText = (1..repetitions).joinToString("") { i -> "text$i$lineSeparator" }
            actualText shouldBe expectedText
        }

        val res = doExecute(code, hasResult = false, ioPubChecker = ::checker)
        assertNull(res)
    }

    // TODO: investigate why this test hangs
    @Test
    fun testReadLine() {
        val code =
            """
            val answer = readLine()
            answer
            """.trimIndent()
        val res = doExecute(code, inputs = listOf("42"))
        assertEquals(jsonObject(MimeTypes.PLAIN_TEXT to "42"), res)
    }

    /**
     * With `serializeScriptData` enabled, the execute reply carries the compiled snippet.
     * The snippet's class is loaded and instantiated to check that its property holds the
     * expected value, and its saved source file is checked as well.
     */
    @Test
    fun testCompiledData() {
        doExecute(
            """
            SessionOptions.serializeScriptData = true
            """.trimIndent(),
            hasResult = false,
        )

        val code =
            """
            val xyz = 42
            """.trimIndent()
        val res = doExecute(
            code,
            hasResult = false,
            executeReplyChecker = { message ->
                val metadata = message.data.metadata
                assertTrue(metadata is JsonObject)
                val snippetMetadata = MessageFormat.decodeFromJsonElement<EvaluatedSnippetMetadata?>(
                    metadata["eval_metadata"] ?: JsonNull,
                )
                val compiledData = snippetMetadata?.compiledData
                assertNotNull(compiledData)

                val deserializer = org.jetbrains.kotlinx.jupyter.compiler.CompiledScriptsSerializer()
                val dir = Files.createTempDirectory("kotlin-jupyter-exec-test")
                // deleteOnExit() cannot remove a non-empty directory, so clean up explicitly.
                try {
                    val classesDir = dir.resolve("classes")
                    val sourcesDir = dir.resolve("sources")

                    val names = deserializer.deserializeAndSave(compiledData, classesDir, sourcesDir)
                    val kClassName = names.single()
                    URLClassLoader(arrayOf(classesDir.toUri().toURL()), ClassLoader.getSystemClassLoader()).use { classLoader ->
                        val loadedClass = classLoader.loadClass(kClassName).kotlin

                        @Suppress("UNCHECKED_CAST")
                        val xyzProperty = loadedClass.memberProperties.single { it.name == "xyz" } as KProperty1<Any, Int>
                        val constructor = loadedClass.constructors.single()

                        val userHandlesProvider = object : UserHandlesProvider {
                            override val host: Nothing? = null
                            override val notebook: Notebook = NotebookMock
                            override val sessionOptions: SessionOptions
                                get() = throw NotImplementedError()
                        }

                        val instance = constructor.call(emptyArray<Any>(), userHandlesProvider)
                        xyzProperty.get(instance) shouldBe 42
                    }

                    val sourceFile = sourcesDir.resolve("Line_1.kts")
                    sourceFile.shouldBeAFile()
                    sourceFile.readText() shouldBe "val xyz = 42"
                } finally {
                    dir.toFile().deleteRecursively()
                }
            },
        )
        assertNull(res)
    }

    @Test
    fun testLibraryLoadingErrors() {
        doExecute(
            """
                USE {
                    import("xyz.ods")
                }
            """.trimIndent(),
            hasResult = false,
            ioPubChecker = {
                val msgText = it.receiveStreamResponse()
                assertTrue("The problem is found in one of the loaded libraries" in msgText)
            },
        )
    }

    /**
     * Checks the execution counter and `Out` storage.
     *
     * Requests sent with `storeHistory = false` report count 3 without advancing it further.
     * Their result is not stored, so `Out[3]` is absent and the last snippet evaluates to `null`.
     */
    @Test
    fun testCounter() {
        fun checkCounter(message: Message, expectedCounter: Long) {
            val data = message.data.content as ExecuteReply
            assertEquals(expectedCounter, data.executionCount)
        }
        val res1 = doExecute("41", executeReplyChecker = { checkCounter(it, 1) })
        val res2 = doExecute("42", executeReplyChecker = { checkCounter(it, 2) })
        val res3 = doExecute(
            " \"\${Out[1]} \${Out[2]}\" ",
            storeHistory = false,
            executeReplyChecker = { checkCounter(it, 3) },
        )
        val res4 = doExecute(
            "val x = try { Out[3] } catch(e: ArrayIndexOutOfBoundsException) { null }; x",
            storeHistory = false,
            executeReplyChecker = { checkCounter(it, 3) },
        )

        assertEquals(jsonObject(MimeTypes.PLAIN_TEXT to "41"), res1)
        assertEquals(jsonObject(MimeTypes.PLAIN_TEXT to "42"), res2)
        assertEquals(jsonObject(MimeTypes.PLAIN_TEXT to "41 42"), res3)
        assertEquals(jsonObject(MimeTypes.PLAIN_TEXT to "null"), res4)
    }

    @Test
    fun testReadLineWithNoStdin() {
        executeWithNoStdin("readLine() ?: \"blah\"")
    }

    @Test
    fun testStdinReadWithNoStdin() {
        executeWithNoStdin("System.`in`.read()")
    }

    @Test
    fun testIsComplete() {
        assertEquals("complete", doIsComplete("2 + 2"))
        assertEquals("incomplete", doIsComplete("fun f() : Int { return 1"))
        assertEquals(if (runInSeparateProcess) DEBUG else OFF, mainLoggerLevel())
    }

    /** Code executed while a file log handler is registered is logged; code after its removal is not. */
    @Test
    fun testLoggerAppender() {
        val file = File.createTempFile("kotlin-jupyter-logger-appender-test", ".txt")
        try {
            doExecute("%logHandler add f1 --file ${file.absolutePath}", hasResult = false)
            val result1 = doExecute("2 + 2")
            assertEquals(jsonObject(MimeTypes.PLAIN_TEXT to "4"), result1)

            doExecute("%logHandler remove f1", hasResult = false)
            val result2 = doExecute("3 + 4")
            assertEquals(jsonObject(MimeTypes.PLAIN_TEXT to "7"), result2)

            val logText = file.readText()
            assertTrue("2 + 2" in logText)
            assertTrue("3 + 4" !in logText)
        } finally {
            file.delete()
        }
    }

    /**
     * Registers a comm target from notebook code and opens a comm to it from the client.
     *
     * The target's open handler sends back the comm id. Its message handler echoes the
     * received `x` field as `y`.
     */
    @Test
    fun testComms() {
        val shell = checkNotNull(shell) { "shell socket is not connected" }
        val iopub = checkNotNull(ioPub) { "iopub socket is not connected" }

        val targetName = "my_comm"
        val commId = "xyz"

        val registerCode = """
            import kotlinx.serialization.*
            import kotlinx.serialization.json.*

            notebook.commManager.registerCommTarget("$targetName") { comm, openData ->
                comm.send(
                    JsonObject(
                        mapOf(
                            "xo" to JsonPrimitive(comm.id)
                        )
                    )
                )

                comm.onMessage { d ->
                    comm.send(
                        JsonObject(
                            mapOf(
                                "y" to JsonPrimitive("received: " + d["x"]!!.jsonPrimitive.content)
                            )
                        )
                    )
                }
            }
        """.trimIndent()
        doExecute(registerCode, hasResult = false)

        shell.sendMessage(MessageType.COMM_OPEN, CommOpen(commId, targetName))

        iopub.receiveMessage().apply {
            val c = content.shouldBeTypeOf<CommMsg>()
            c.commId shouldBe commId
            c.data["xo"]!!.jsonPrimitive.content shouldBe commId
        }

        shell.sendMessage(
            MessageType.COMM_MSG,
            CommMsg(
                commId,
                JsonObject(
                    mapOf(
                        "x" to JsonPrimitive("4321"),
                    ),
                ),
            ),
        )

        iopub.wrapActionInBusyIdleStatusChange {
            iopub.receiveMessage().apply {
                val c = content.shouldBeTypeOf<CommMsg>()
                c.commId shouldBe commId
                c.data["y"]!!.jsonPrimitive.content shouldBe "received: 4321"
            }
        }
    }

    /** The built-in debug-port comm replies with the kernel's configured debug port and an OK status. */
    @Test
    fun testDebugPortCommHandler() {
        val shell = checkNotNull(shell) { "shell socket is not connected" }
        val iopub = checkNotNull(ioPub) { "iopub socket is not connected" }

        val targetName = ProvidedCommMessages.OPEN_DEBUG_PORT_TARGET
        val commId = "some"
        val actualDebugPort = kernelConfig.debugPort

        shell.sendMessage(
            MessageType.COMM_OPEN,
            CommOpen(
                commId,
                targetName,
            ),
        )

        shell.sendMessage(
            MessageType.COMM_MSG,
            CommMsg(commId),
        )

        iopub.wrapActionInBusyIdleStatusChange {
            iopub.receiveMessage().apply {
                val c = content.shouldBeTypeOf<CommMsg>()
                val data = MessageFormat.decodeFromJsonElement<OpenDebugPortReply>(c.data)
                c.commId shouldBe commId
                data.port shouldBe actualDebugPort
                data.status shouldBe MessageStatus.OK
            }
        }
    }

    /** The `:help` command returns plain text mentioning the current Kotlin version. */
    @Test
    fun testCommand() {
        val res = doExecute(":help").shouldBeTypeOf<JsonObject>()
        val text = res[MimeTypes.PLAIN_TEXT]!!.jsonPrimitive.content
        text.shouldContain(currentKotlinVersion)
        print(text)
    }

    /**
     * An infinite loop is stopped by an interrupt sent on the control channel 15 seconds after
     * the request. The kernel then reports the interruption as a stream message.
     */
    @Test
    fun testInterrupt() {
        doExecute(
            "while(true);",
            hasResult = false,
            executeRequestSent = {
                Thread.sleep(15000)
                interruptExecution()
            },
            ioPubChecker = { iopubSocket ->
                val msgText = iopubSocket.receiveStreamResponse()
                msgText shouldBe EXECUTION_INTERRUPTED_MESSAGE
            },
        ) shouldBe null
    }

    @Test
    @Disabled
    fun testBigDataFrame() {
        doExecute(
            """
                %use dataframe
                DataFrame.read("https://api.apis.guru/v2/list.json")
            """.trimIndent(),
            ioPubChecker = { iopubSocket ->
                iopubSocket.receiveDisplayDataResponse()
                iopubSocket.receiveDisplayDataResponse()
            },
        )
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy