All downloads are free. Search and download functionality uses the official Maven repository.

org.jetbrains.kotlinx.jupyter.messaging.IdeCompatibleMessageRequestProcessor.kt Maven / Gradle / Ivy

Go to download

Implementation of REPL compiler and preprocessor for Jupyter dialect of Kotlin (IDE-compatible)

The newest version!
package org.jetbrains.kotlinx.jupyter.messaging

import kotlinx.serialization.json.Json
import kotlinx.serialization.json.encodeToJsonElement
import org.jetbrains.kotlinx.jupyter.api.Code
import org.jetbrains.kotlinx.jupyter.api.KernelLoggerFactory
import org.jetbrains.kotlinx.jupyter.api.KotlinKernelVersion.Companion.toMaybeUnspecifiedString
import org.jetbrains.kotlinx.jupyter.api.StreamSubstitutionType
import org.jetbrains.kotlinx.jupyter.api.getLogger
import org.jetbrains.kotlinx.jupyter.api.libraries.RawMessage
import org.jetbrains.kotlinx.jupyter.commands.runCommand
import org.jetbrains.kotlinx.jupyter.common.looksLikeReplCommand
import org.jetbrains.kotlinx.jupyter.config.currentKernelVersion
import org.jetbrains.kotlinx.jupyter.config.currentKotlinVersion
import org.jetbrains.kotlinx.jupyter.config.notebookLanguageInfo
import org.jetbrains.kotlinx.jupyter.exceptions.ReplException
import org.jetbrains.kotlinx.jupyter.execution.ExecutionResult
import org.jetbrains.kotlinx.jupyter.execution.JupyterExecutor
import org.jetbrains.kotlinx.jupyter.messaging.StdIOSubstitutionManager.stderrContext
import org.jetbrains.kotlinx.jupyter.messaging.StdIOSubstitutionManager.stdinContext
import org.jetbrains.kotlinx.jupyter.messaging.StdIOSubstitutionManager.stdoutContext
import org.jetbrains.kotlinx.jupyter.messaging.StdIOSubstitutionManager.substitutionEngineType
import org.jetbrains.kotlinx.jupyter.messaging.comms.CommManagerInternal
import org.jetbrains.kotlinx.jupyter.protocol.PROTOCOL_VERSION
import org.jetbrains.kotlinx.jupyter.repl.EvalRequestData
import org.jetbrains.kotlinx.jupyter.repl.ReplForJupyter
import org.jetbrains.kotlinx.jupyter.repl.result.EvalResultEx
import org.jetbrains.kotlinx.jupyter.streams.CapturingOutputStream
import org.jetbrains.kotlinx.jupyter.streams.DisabledStdinInputStream
import org.jetbrains.kotlinx.jupyter.streams.StdinInputStream
import org.jetbrains.kotlinx.jupyter.streams.StreamSubstitutionManager
import org.jetbrains.kotlinx.jupyter.util.EMPTY
import java.io.InputStream
import java.io.OutputStream
import java.io.PrintStream
import kotlin.system.exitProcess

/**
 * Singleton holder of the stdout/stdin/stderr substitution contexts used when evaluating cells.
 *
 * Each context is created lazily and captures [substitutionEngineType] at creation time, so the
 * engine type has to be assigned before any context is first accessed.
 */
private object StdIOSubstitutionManager {
    // Backing storage for substitutionEngineType; remains null until the first assignment.
    private var engineType: StreamSubstitutionType? = null

    /**
     * Stream substitution engine type used by every context of this object.
     *
     * It is assumed that inside one environment there is only one correct value for this property:
     * reading it before the first assignment throws [UninitializedPropertyAccessException], and
     * re-assigning it is accepted only with the value already stored (otherwise `require` fails).
     */
    var substitutionEngineType: StreamSubstitutionType
        get() =
            engineType
                ?: throw UninitializedPropertyAccessException("Substitution engine type is not initialized yet")
        set(value) {
            val current = engineType
            if (current == null) {
                // First assignment wins.
                engineType = value
                return
            }
            require(current == value) {
                "Attempt to set substitution engine type to $value which is different from already set value"
            }
        }

    /** Substitution context for standard output, created on first access. */
    val stdoutContext by lazy { StreamSubstitutionManager.StdOut(substitutionEngineType) }

    /** Substitution context for standard input, created on first access. */
    val stdinContext by lazy { StreamSubstitutionManager.StdIn(substitutionEngineType) }

    /** Substitution context for standard error, created on first access. */
    val stderrContext by lazy { StreamSubstitutionManager.StdErr(substitutionEngineType) }
}

/**
 * Processes the Jupyter request carried by a single incoming [RawMessage].
 *
 * [AbstractMessageRequestProcessor] dispatches the request to one of the `process*` overrides below.
 * Replies are built with [messageFactory], which is bound to the incoming message, and sent through
 * [socketManager]. Code evaluation is delegated to [repl] and scheduled through [executor]. While
 * code is evaluated, stdout, stderr and stdin are substituted so that user output can be forwarded
 * to the client.
 */
@Suppress("MemberVisibilityCanBePrivate")
open class IdeCompatibleMessageRequestProcessor(
    rawIncomingMessage: RawMessage,
    messageFactoryProvider: MessageFactoryProvider,
    final override val socketManager: JupyterBaseSockets,
    protected val commManager: CommManagerInternal,
    protected val executor: JupyterExecutor,
    protected val executionCounter: ExecutionCounter,
    loggerFactory: KernelLoggerFactory,
    protected val repl: ReplForJupyter,
) : AbstractMessageRequestProcessor(rawIncomingMessage),
    JupyterCommunicationFacility {
    private val logger = loggerFactory.getLogger(this::class)

    init {
        // Registers the engine type in the file-level StdIOSubstitutionManager singleton, which
        // rejects (via `require`) a value that differs from the one registered earlier.
        substitutionEngineType = repl.notebook.kernelRunMode.streamSubstitutionType
    }

    /**
     * Message factory bound to the incoming message; every reply of this processor is built with it.
     * The provider is first updated with the incoming message, then asked for its factory.
     */
    final override val messageFactory =
        run {
            messageFactoryProvider.update(rawIncomingMessage)
            checkNotNull(messageFactoryProvider.provide()) {
                "Message factory provider returned no factory after being updated with the incoming message"
            }
        }

    /** Stdin stream handed to user code when the execute request allows stdin (see [withForkedIn]). */
    @Suppress("LeakingThis")
    protected val stdinIn: InputStream = StdinInputStream(this)

    /** Answers an unrecognized shell message with a reply of type [MessageType.NONE]. */
    override fun processUnknownShellMessage(content: MessageContent) {
        socketManager.shell.sendMessage(
            messageFactory.makeReplyMessage(MessageType.NONE),
        )
    }

    override fun processUnknownControlMessage(content: MessageContent) {
        // Unrecognized control messages are ignored.
    }

    override fun processUnknownStdinMessage(content: MessageContent) {
        // Unrecognized stdin messages are ignored.
    }

    /** Always reports the code as `complete`; the submitted code is not inspected. */
    override fun processIsCompleteRequest(content: IsCompleteRequest) {
        socketManager.shell.sendMessage(
            messageFactory.makeReplyMessage(MessageType.IS_COMPLETE_REPLY, content = IsCompleteReply("complete")),
        )
    }

    /** Schedules error analysis of the code via [JupyterExecutor.launchJob] and replies with the result. */
    override fun processListErrorsRequest(content: ListErrorsRequest) {
        executor.launchJob {
            repl.listErrors(content.code) { result ->
                sendWrapped(messageFactory.makeReplyMessage(MessageType.LIST_ERRORS_REPLY, content = result.message))
            }
        }
    }

    /** Schedules code completion at the cursor via [JupyterExecutor.launchJob] and replies with the result. */
    override fun processCompleteRequest(content: CompleteRequest) {
        executor.launchJob {
            repl.complete(content.code, content.cursorPos) { result ->
                sendWrapped(messageFactory.makeReplyMessage(MessageType.COMPLETE_REPLY, content = result.message))
            }
        }
    }

    /** Forwards a `comm_msg` to [commManager] inside an executor-managed execution. */
    override fun processCommMsg(content: CommMsg) {
        executor.runExecution("Execution of comm_msg request for ${content.commId}") {
            commManager.processCommMessage(incomingMessage, content)
        }
    }

    /** Forwards a `comm_close` to [commManager] inside an executor-managed execution. */
    override fun processCommClose(content: CommClose) {
        executor.runExecution("Execution of comm_close request for ${content.commId}") {
            commManager.processCommClose(incomingMessage, content)
        }
    }

    /**
     * Opens a comm through [commManager] inside an executor-managed execution.
     * A [ReplException] is thrown inside that execution when the comm cannot be opened.
     */
    override fun processCommOpen(content: CommOpen) {
        executor.runExecution("Execution of comm_open request for ${content.commId} of target ${content.targetName}") {
            commManager.processCommOpen(incomingMessage, content)
                ?: throw ReplException("Cannot open comm for ${content.commId} of target ${content.targetName}")
        }
    }

    /** Replies with a map from comm id to its target for the comms matching the requested target name. */
    override fun processCommInfoRequest(content: CommInfoRequest) {
        val comms = commManager.getComms(content.targetName)
        val replyMap = comms.associate { comm -> comm.id to Comm(comm.target) }
        sendWrapped(messageFactory.makeReplyMessage(MessageType.COMM_INFO_REPLY, content = CommInfoReply(replyMap)))
    }

    /**
     * Executes a cell.
     *
     * The execution counter is advanced first, and the input is broadcast on iopub as EXECUTE_INPUT.
     * REPL commands are handled by [runCommand]. Any other code is evaluated by [repl] with
     * redirected standard streams. Everything is wrapped in busy/idle status notifications.
     */
    override fun processExecuteRequest(content: ExecuteRequest) {
        val count = executionCounter.next(content.storeHistory)
        val startedTime = ISO8601DateNow

        doWrappedInBusyIdle {
            val code = content.code
            socketManager.iopub.sendMessage(
                messageFactory.makeReplyMessage(
                    MessageType.EXECUTE_INPUT,
                    content = ExecutionInputReply(code, count),
                ),
            )
            val response: JupyterResponse =
                if (looksLikeReplCommand(code)) {
                    runCommand(code, repl)
                } else {
                    runExecution("Execution of code '${code.presentableForThreadName()}'") {
                        evalWithIO(content.allowStdin) {
                            repl.evalEx(
                                EvalRequestData(
                                    code,
                                    count,
                                    content.storeHistory,
                                    content.silent,
                                ),
                            )
                        }
                    }
                }

            sendResponse(response, count, startedTime)
        }
    }

    /** Replies to a connect request with [Json.EMPTY] as the connection information. */
    override fun processConnectRequest(content: ConnectRequest) {
        sendWrapped(
            messageFactory.makeReplyMessage(
                MessageType.CONNECT_REPLY,
                content =
                    ConnectReply(
                        Json.EMPTY,
                    ),
            ),
        )
    }

    /** History is not implemented: always replies with an empty history list. */
    override fun processHistoryRequest(content: HistoryRequest) {
        sendWrapped(
            messageFactory.makeReplyMessage(
                MessageType.HISTORY_REPLY,
                content = HistoryReply(listOf()), // not implemented
            ),
        )
    }

    /**
     * Replies with protocol, kernel and language information.
     * The current REPL session state is attached as metadata.
     */
    override fun processKernelInfoRequest(content: KernelInfoRequest) {
        sendWrapped(
            messageFactory.makeReplyMessage(
                MessageType.KERNEL_INFO_REPLY,
                content =
                    KernelInfoReply(
                        PROTOCOL_VERSION,
                        "Kotlin",
                        currentKernelVersion.toMaybeUnspecifiedString(),
                        "Kotlin kernel v. ${currentKernelVersion.toMaybeUnspecifiedString()}, Kotlin v. $currentKotlinVersion",
                        notebookLanguageInfo,
                        listOf(),
                    ),
                metadata =
                    Json.encodeToJsonElement(
                        KernelInfoReplyMetadata(repl.currentSessionState),
                    ),
            ),
        )
    }

    /**
     * Shuts the kernel down.
     *
     * The sequence is:
     * 1. run the REPL shutdown hooks;
     * 2. close the executor;
     * 3. echo the request content back on the control socket.
     *
     * Depending on the run mode, it then either exits the JVM or throws [InterruptedException].
     */
    override fun processShutdownRequest(content: ShutdownRequest) {
        repl.evalOnShutdown()
        executor.close()
        socketManager.control.sendMessage(
            messageFactory.makeReplyMessage(MessageType.SHUTDOWN_REPLY, content = incomingMessage.content),
        )
        if (repl.kernelRunMode.shouldKillProcessOnShutdown) {
            exitProcess(0)
        } else {
            // exitProcess would kill the entire process that embedded the kernel
            // Instead the controlThread will be interrupted,
            // which will then interrupt the mainThread and make kernelServer return
            logger.info("Interrupting controlThread to trigger kernel shutdown")
            throw InterruptedException()
        }
    }

    /** Interrupts the current execution and echoes the request content back on the control socket. */
    override fun processInterruptRequest(content: InterruptRequest) {
        executor.interruptExecution()
        socketManager.control.sendMessage(
            messageFactory.makeReplyMessage(MessageType.INTERRUPT_REPLY, content = incomingMessage.content),
        )
    }

    override fun processInputReply(content: InputReply) {
        // Intentionally a no-op in this processor.
    }

    /**
     * Runs [execution] through [executor] using the REPL's current class loader.
     * The outcome is converted into a [JupyterResponse].
     *
     * Any [Throwable] thrown while converting a successful result is reported as an error response
     * instead of being propagated.
     */
    protected open fun runExecution(
        executionName: String,
        execution: () -> EvalResultEx,
    ): JupyterResponse {
        return when (
            val res =
                executor.runExecution(
                    executionName,
                    repl.currentClassLoader,
                    execution,
                )
        ) {
            is ExecutionResult.Success -> {
                try {
                    when (val replResult = res.result) {
                        is EvalResultEx.Success -> {
                            OkJupyterResponse(replResult.displayValue, replResult.metadata)
                        }
                        is EvalResultEx.Error -> {
                            replResult.error.toErrorJupyterResponse(replResult.metadata)
                        }
                        is EvalResultEx.RenderedError -> {
                            OkJupyterResponse(replResult.displayError, replResult.metadata)
                        }
                        is EvalResultEx.Interrupted -> {
                            ErrorJupyterResponse(EXECUTION_INTERRUPTED_MESSAGE, metadata = replResult.metadata)
                        }
                    }
                } catch (e: Throwable) {
                    ErrorJupyterResponse("error:  Unable to convert result to a string: $e")
                }
            }
            is ExecutionResult.Failure -> {
                res.throwable.toErrorJupyterResponse()
            }
            ExecutionResult.Interrupted -> {
                ErrorJupyterResponse(EXECUTION_INTERRUPTED_MESSAGE)
            }
        }
    }

    private val replOutputConfig get() = repl.options.outputConfig

    /**
     * Creates a capturing [PrintStream] that forwards its output to the given [parentStream] (if not null)
     * and optionally captures the output for further processing.
     *
     * Captured text is appended to the current notebook cell's stream output, if there is a current cell.
     * It is also sent back to the client as a [MessageType.STREAM] message.
     *
     * @param parentStream the parent [PrintStream] to forward the output to, can be null
     * @param outType the type of output (stdout or stderr) to be associated with the captured output
     * @param captureOutput a flag indicating whether to capture the output
     * @return a new UTF-8 [PrintStream] that forwards output to [parentStream]
     * and captures it if [captureOutput] is true
     */
    private fun getCapturingStream(
        parentStream: PrintStream?,
        outType: JupyterOutType,
        captureOutput: Boolean,
    ): PrintStream {
        return CapturingOutputStream(
            parentStream,
            replOutputConfig,
            captureOutput,
        ) { text ->
            repl.notebook.currentCell?.appendStreamOutput(text)
            this.sendOut(outType, text)
        }.asPrintStream()
    }

    private fun OutputStream.asPrintStream() = PrintStream(this, false, "UTF-8")

    /**
     * Runs [body] with stdout substituted by a capturing stream that wraps the original stdout.
     * Capturing is controlled by the REPL output config.
     */
    private fun <T> withForkedOut(body: () -> T): T {
        return stdoutContext.withSubstitutedStreams(
            systemStreamFactory = { out: PrintStream? -> getCapturingStream(out, JupyterOutType.STDOUT, replOutputConfig.captureOutput) },
            kernelStreamFactory = { null },
            body = body,
        )
    }

    /**
     * Runs [body] with stderr substituted by capturing streams.
     *
     * The system stderr is wrapped around the original stream with capturing disabled. The kernel
     * stderr gets a stream with no parent that always captures.
     */
    private fun <T> withForkedErr(body: () -> T): T {
        return stderrContext.withSubstitutedStreams(
            systemStreamFactory = { err: PrintStream? -> getCapturingStream(err, JupyterOutType.STDERR, false) },
            kernelStreamFactory = { getCapturingStream(null, JupyterOutType.STDERR, true) },
            body = body,
        )
    }

    /**
     * Runs [body] with stdin substituted.
     * The substitute is [stdinIn] when [allowStdIn] is true, and [DisabledStdinInputStream] otherwise.
     */
    private fun <T> withForkedIn(
        allowStdIn: Boolean,
        body: () -> T,
    ): T {
        return stdinContext.withSubstitutedStreams(
            systemStreamFactory = { if (allowStdIn) stdinIn else DisabledStdinInputStream },
            kernelStreamFactory = { null },
            body = body,
        )
    }

    /**
     * Starts a notebook evaluation session, then runs [body] with stdout, stderr and stdin substituted.
     * The substitutions are nested in that order (stdout outermost).
     *
     * @param allowStdIn whether user code may read from stdin during [body]
     * @return the value returned by [body]
     */
    protected open fun <T> evalWithIO(
        allowStdIn: Boolean,
        body: () -> T,
    ): T {
        repl.notebook.beginEvalSession()
        return withForkedOut {
            withForkedErr {
                withForkedIn(allowStdIn, body)
            }
        }
    }

    /**
     * Short label for thread names: the first line of the code truncated to 20 characters.
     * `...` is appended when anything was cut off.
     */
    private fun Code.presentableForThreadName(): String {
        val newName = substringBefore('\n').take(20)
        return if (newName.length < length) {
            "$newName..."
        } else {
            this
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy