
// almond.kernel.Kernel.scala Maven / Gradle / Ivy
package almond.kernel
import java.nio.charset.StandardCharsets.UTF_8
import java.nio.file.{Files, Path, Paths}
import java.util.UUID
import almond.channels.zeromq.ZeromqThreads
import almond.channels.{Channel, Connection, ConnectionParameters, Message => RawMessage}
import almond.interpreter.{IOInterpreter, Interpreter, InterpreterToIOInterpreter, Message}
import almond.interpreter.comm.DefaultCommHandler
import almond.interpreter.input.InputHandler
import almond.interpreter.messagehandlers.{
CloseExecutionException,
CommMessageHandlers,
InterpreterMessageHandlers,
MessageHandler
}
import almond.logger.LoggerContext
import almond.protocol.{Header, Protocol, Status, Connection => JsonConnection}
import cats.effect.IO
import cats.effect.std.Queue
import fs2.concurrent.SignallingRef
import fs2.{Pipe, Stream}
import scala.concurrent.ExecutionContext
import scala.concurrent.duration.DurationInt
/** Message-processing core of a kernel: turns a stream of incoming `(Channel, RawMessage)` requests
  * into the stream of messages to send back, and knows how to run that loop over a connection.
  *
  * @param interpreter
  *   interpreter that is initialised when `replies` starts and that backs the interpreter message
  *   handlers
  * @param backgroundMessagesQueue
  *   messages taken from this queue are merged into the replies; a `null` element is used as the
  *   end-of-stream marker
  * @param executeQueue
  *   entries are `Some((originalRequestOpt, responseStream))`; the response streams are emitted one
  *   after the other, in queue order. `None` is the end-of-stream marker
  * @param otherQueue
  *   response streams of the messages accepted by the "immediate" handlers; `None` is the
  *   end-of-stream marker
  * @param backgroundCommHandlerOpt
  *   when defined, its `commTargetManager` is used to build the comm message handler
  * @param inputHandler
  *   its `messageHandler` is the first of the immediate handlers; it is also handed to the
  *   interpreter message handlers
  * @param kernelThreads
  *   provides the execution context (`queueEc`) handed to the message handlers
  * @param logCtx
  *   logger context, used for this class' logger and passed down to the handlers
  * @param extraHandler
  *   additional handler, tried last among the immediate handlers
  * @param noExecuteInputFor
  *   forwarded as is to `InterpreterMessageHandlers` (its meaning is defined there)
  */
final case class Kernel(
  interpreter: IOInterpreter,
  backgroundMessagesQueue: Queue[IO, (Channel, RawMessage)],
  executeQueue: Queue[IO, Option[(
    Option[(Channel, RawMessage)],
    Stream[IO, (Channel, RawMessage)]
  )]],
  otherQueue: Queue[IO, Option[Stream[IO, (Channel, RawMessage)]]],
  backgroundCommHandlerOpt: Option[DefaultCommHandler],
  inputHandler: InputHandler,
  kernelThreads: KernelThreads,
  logCtx: LoggerContext,
  extraHandler: MessageHandler,
  noExecuteInputFor: Set[String]
) {

  private lazy val log = logCtx(getClass)

  /** Builds the stream of outgoing messages for a stream of incoming ones.
    *
    * The resulting stream first emits `starting` / `busy` status messages, initialises the
    * interpreter, emits an `idle` status, and then emits the responses to `requests`, merged with
    * whatever is pushed to `backgroundMessagesQueue`.
    *
    * @param requests
    *   incoming messages, along with the channel they arrived on
    * @return
    *   messages to send back, along with the channel to send them on
    */
  def replies(requests: Stream[IO, (Channel, RawMessage)]): Stream[IO, (Channel, RawMessage)] = {

    // Flipped to true to stop reading `requests` (handed to the interpreter message handlers)
    val exitSignal = SignallingRef[IO, Boolean](false)

    Stream.eval(exitSignal).flatMap { exitSignal0 =>

      val interpreterMessageHandler = InterpreterMessageHandlers(
        interpreter,
        backgroundCommHandlerOpt,
        Some(inputHandler),
        kernelThreads.queueEc,
        logCtx,
        // lets the handlers schedule an IO on the execute queue, with no originating request
        io => executeQueue.offer(Some(None -> Stream.exec(io))),
        exitSignal0,
        noExecuteInputFor
      )

      val commMessageHandler = backgroundCommHandlerOpt match {
        case None =>
          MessageHandler.empty
        case Some(commHandler) =>
          CommMessageHandlers(commHandler.commTargetManager, kernelThreads.queueEc, logCtx)
            .messageHandler
      }

      // Handlers whose messages are processed straightaway (no queueing to enforce sequential
      // processing). They are tried in this order.
      val immediateHandlers = inputHandler.messageHandler
        .orElse(interpreterMessageHandler.otherHandlers)
        .orElse(commMessageHandler)
        .orElse(extraHandler)

      // For whatever reason, these seem not to be processed on time by the Jupyter classic UI
      // (don't know about lab, nteract seems fine, unless it just marks kernels as starting by itself)
      val initStream = {

        // Single-element stream with a status message for the publish channel
        def sendStatus(status: Status) =
          Stream(
            Message(
              Header(
                UUID.randomUUID().toString,
                "username",
                UUID.randomUUID().toString, // Would there be a way to get the session id from the client?
                Status.messageType.messageType,
                Some(Protocol.versionStr)
              ),
              status,
              idents = List(Status.messageType.messageType.getBytes(UTF_8).toSeq)
            ).on(Channel.Publish)
          )

        // Initialise the interpreter; log the error, if any, before re-raising it
        val attemptInit = interpreter.init.attempt.flatMap { a =>

          for (e <- a.left)
            log.error("Error initializing interpreter", e)

          IO.fromEither(a)
        }

        sendStatus(Status.starting) ++
          sendStatus(Status.busy) ++
          Stream.exec(attemptInit) ++
          sendStatus(Status.idle)
      }

      val mainStream = {

        // Stop reading requests as soon as the exit signal is set
        val requests0 = requests.interruptWhen(exitSignal0)

        // For each incoming message, an IO that dispatches it to the right queue
        val scatterMessages: Stream[IO, Unit] =
          requests0.evalMap {
            case (channel, rawMessage) =>
              val outputOpt = interpreterMessageHandler.executeHandler.handleOrLogError(
                channel,
                rawMessage,
                log
              )

              outputOpt match {
                case None =>
                  // interpreter message handler passes, try with the other handlers
                  immediateHandlers.handleOrLogError(channel, rawMessage, log) match {
                    case None =>
                      log.warn(s"Ignoring unhandled message on $channel:\n$rawMessage")
                      IO.unit
                    case Some(output) =>
                      // process stdin messages and send response back straightaway
                      otherQueue.offer(Some(output))
                  }
                case Some(output) =>
                  // enqueue stream that processes the incoming message, so that the main messages are
                  // still processed and answered in order. The original request is kept alongside it,
                  // so that it can be recovered by drainExecuteMessages if it ends up not being run.
                  executeQueue.offer(Some((Some((channel, rawMessage)), output)))
              }
          }

        // Put a poison pill (None) at the end of executeQueue and otherQueue when all input messages
        // have been scattered (or when scattering stops for any other reason)
        val scatterMessages0: Stream[IO, Nothing] = {
          val bracket = Stream.bracket(IO.unit) { _ =>
            executeQueue.offer(None).flatMap(_ => otherQueue.offer(None))
          }
          bracket.flatMap(_ => Stream.exec(scatterMessages.compile.drain))
        }

        // Responses for the main messages
        val executeReplies = Stream.repeatEval(executeQueue.take)
          .takeWhile(_.nonEmpty)
          .flatMap(s => s.map(_._2).getOrElse[Stream[IO, (Channel, RawMessage)]](Stream.empty))

        // Responses for the other messages
        val otherReplies = Stream.repeatEval(otherQueue.take)
          .takeWhile(_.nonEmpty)
          .flatMap(s => s.getOrElse[Stream[IO, (Channel, RawMessage)]](Stream.empty))

        // Merge scatterMessages0 (messages scattered straightaway), executeReplies (responses of
        // execute messages, that are processed sequentially via executeQueue), and otherReplies
        // (responses of other messages, that go through otherQueue, concurrently with the former)
        scatterMessages0.merge(executeReplies).merge(otherReplies)
      }

      // Put poison pill (null) at the end of backgroundMessagesQueue when all input messages have been
      // processed and answered.
      // NOTE: the element type of backgroundMessagesQueue is part of the public constructor, which is
      // why a null sentinel is used here rather than an Option.
      val mainStream0 = Stream.bracket(IO.unit)(_ => backgroundMessagesQueue.offer(null))
        .flatMap(_ => initStream ++ mainStream)

      // Merge responses to all incoming messages with background messages (comm messages sent by user
      // code when it is run)
      mainStream0.merge(Stream.repeatEval(backgroundMessagesQueue.take).takeWhile(_ != null))
    }
  }

  /** Processes incoming messages and sends the replies to `sink`, until the replies stream ends.
    *
    * @param stream
    *   incoming messages
    * @param sink
    *   where outgoing messages are sent
    * @param leftoverMessages
    *   messages to process before those of `stream`
    */
  def run(
    stream: Stream[IO, (Channel, RawMessage)],
    sink: Pipe[IO, (Channel, RawMessage), Unit],
    leftoverMessages: Seq[(Channel, RawMessage)]
  ): IO[Unit] =
    sink(replies(Stream(leftoverMessages: _*) ++ stream)).compile.drain

  /** Opens a connection from `connection` and runs the kernel on it (with `autoClose = true`).
    *
    * @param connection
    *   parameters of the connection to create
    * @param kernelId
    *   passed as identity when creating the channels
    * @param zeromqThreads
    *   threads used by the channels
    * @param leftoverMessages
    *   messages to process before those read from the connection
    */
  def runOnConnection(
    connection: ConnectionParameters,
    kernelId: String,
    zeromqThreads: ZeromqThreads,
    leftoverMessages: Seq[(Channel, RawMessage)]
  ): IO[Unit] =
    for {
      t <- runOnConnectionAllowClose0(
        connection,
        kernelId,
        zeromqThreads,
        leftoverMessages,
        autoClose = true
      )
      (run, _) = t
      _ <- run
    } yield ()

  /** Creates the channels for `connection` (binding, 5 minute linger period, `kernelId` as
    * identity), and returns an IO that opens them and runs the kernel, alongside the connection
    * itself.
    *
    * Nothing is opened or run until the returned inner IO is run. `autoClose` is forwarded,
    * negated, as the `partial` argument of `autoCloseSink`.
    */
  private def runOnConnectionAllowClose0(
    connection: ConnectionParameters,
    kernelId: String,
    zeromqThreads: ZeromqThreads,
    leftoverMessages: Seq[(Channel, RawMessage)],
    autoClose: Boolean
  ): IO[(IO[Unit], Connection)] =
    for {
      c <- connection.channels(
        bind = true,
        zeromqThreads,
        lingerPeriod = Some(5.minutes),
        logCtx = logCtx,
        identityOpt = Some(kernelId)
      )
    } yield {
      val run0 =
        for {
          _ <- c.open
          _ <- run(c.stream(), c.autoCloseSink(partial = !autoClose), leftoverMessages)
        } yield ()
      (run0, c)
    }

  /** Takes entries from `executeQueue` until its `None` end marker, and returns the original
    * requests (when there is one) of the entries that were taken.
    *
    * This waits for the `None` marker to show up in the queue.
    */
  private def drainExecuteMessages: IO[Seq[(Channel, RawMessage)]] =
    Stream.repeatEval(executeQueue.take)
      .takeWhile(_.nonEmpty)
      .flatMap(s => Stream(s.flatMap(_._1).toSeq: _*))
      .compile
      .toVector

  /** Same as `runOnConnection`, but also returns the connection, and has the run IO return the
    * messages that were left unprocessed.
    *
    * When the run fails with a `CloseExecutionException`, the returned IO succeeds with the
    * messages of the exception followed by the requests still pending in `executeQueue`. Any other
    * error is re-raised. On normal completion, no message is returned.
    *
    * @param autoClose
    *   forwarded, negated, as the `partial` argument of the connection's `autoCloseSink`
    */
  def runOnConnectionAllowClose(
    connection: ConnectionParameters,
    kernelId: String,
    zeromqThreads: ZeromqThreads,
    leftoverMessages: Seq[(Channel, RawMessage)],
    autoClose: Boolean
  ): IO[(IO[Seq[(Channel, RawMessage)]], Connection)] =
    runOnConnectionAllowClose0(
      connection,
      kernelId,
      zeromqThreads,
      leftoverMessages,
      autoClose
    ).map {
      case (run, conn) =>
        val run0 = run.attempt.flatMap {
          case Left(e: CloseExecutionException) =>
            drainExecuteMessages.map { messages =>
              e.messages ++ messages
            }
          case Left(e) =>
            IO.raiseError(e)
          case Right(()) =>
            IO.pure(Nil)
        }
        (run0, conn)
    }

  /** Same as `runOnConnectionAllowClose`, reading the connection parameters from a connection
    * file.
    *
    * The returned IO fails with an `Exception` if `connectionPath` does not exist or is not a
    * regular file. Both checks are run when the returned IO is run, not when this method is called.
    */
  def runOnConnectionFileAllowClose(
    connectionPath: Path,
    kernelId: String,
    zeromqThreads: ZeromqThreads,
    leftoverMessages: Seq[(Channel, RawMessage)],
    autoClose: Boolean
  ): IO[(IO[Seq[(Channel, RawMessage)]], Connection)] =
    for {
      // IO.defer: the first generator of a for-comprehension is evaluated eagerly, so without it the
      // file system would be probed upon calling this method rather than upon running its result
      _ <- IO.defer {
        if (Files.exists(connectionPath))
          IO.unit
        else
          IO.raiseError(new Exception(s"Connection file $connectionPath not found"))
      }
      _ <- IO.defer {
        if (Files.isRegularFile(connectionPath))
          IO.unit
        else
          IO.raiseError(new Exception(s"Connection file $connectionPath not a regular file"))
      }
      connection <- JsonConnection.fromPath(connectionPath)
      value <- runOnConnectionAllowClose(
        connection.connectionParameters,
        kernelId,
        zeromqThreads,
        leftoverMessages,
        autoClose
      )
    } yield value

  /** Runs the kernel on the connection described by the connection file at `connectionPath`,
    * discarding the messages left unprocessed, if any.
    */
  def runOnConnectionFile(
    connectionPath: Path,
    kernelId: String,
    zeromqThreads: ZeromqThreads,
    leftoverMessages: Seq[(Channel, RawMessage)],
    autoClose: Boolean
  ): IO[Unit] =
    for {
      t <- runOnConnectionFileAllowClose(
        connectionPath,
        kernelId,
        zeromqThreads,
        leftoverMessages,
        autoClose
      )
      (run, _) = t
      _ <- run
    } yield ()

  /** Same as the `Path`-based `runOnConnectionFile`, with the connection file passed as a string.
    */
  def runOnConnectionFile(
    connectionPath: String,
    kernelId: String,
    zeromqThreads: ZeromqThreads,
    leftoverMessages: Seq[(Channel, RawMessage)],
    autoClose: Boolean
  ): IO[Unit] =
    for {
      t <- runOnConnectionFileAllowClose(
        connectionPath,
        kernelId,
        zeromqThreads,
        leftoverMessages,
        autoClose
      )
      (run, _) = t
      _ <- run
    } yield ()

  /** Same as the `Path`-based `runOnConnectionFileAllowClose`, with the connection file passed as
    * a string.
    *
    * The string is converted to a `Path` when the returned IO is run, so that an invalid path makes
    * that IO fail rather than making this method throw.
    */
  def runOnConnectionFileAllowClose(
    connectionPath: String,
    kernelId: String,
    zeromqThreads: ZeromqThreads,
    leftoverMessages: Seq[(Channel, RawMessage)],
    autoClose: Boolean
  ): IO[(IO[Seq[(Channel, RawMessage)]], Connection)] =
    IO.defer {
      runOnConnectionFileAllowClose(
        Paths.get(connectionPath),
        kernelId,
        zeromqThreads,
        leftoverMessages,
        autoClose
      )
    }
}
object Kernel {

  /** A message, paired with the channel it is received from or sent to. */
  private type ChannelMessage = (Channel, RawMessage)

  /** Creates a kernel out of a plain `Interpreter`.
    *
    * The interpreter is wrapped in an `InterpreterToIOInterpreter` using `interpreterEc`, then
    * handed to the `IOInterpreter`-based `create`.
    */
  def create(
    interpreter: Interpreter,
    interpreterEc: ExecutionContext,
    kernelThreads: KernelThreads,
    logCtx: LoggerContext,
    extraHandler: MessageHandler,
    noExecuteInputFor: Set[String]
  ): IO[Kernel] = {
    val ioInterpreter = new InterpreterToIOInterpreter(interpreter, interpreterEc, logCtx)
    create(ioInterpreter, kernelThreads, logCtx, extraHandler, noExecuteInputFor)
  }

  /** Creates a kernel out of a plain `Interpreter`, with no extra message handler and an empty
    * `noExecuteInputFor` set.
    *
    * @param logCtx
    *   logger context, a no-op one by default
    */
  def create(
    interpreter: Interpreter,
    interpreterEc: ExecutionContext,
    kernelThreads: KernelThreads,
    logCtx: LoggerContext = LoggerContext.nop
  ): IO[Kernel] =
    create(interpreter, interpreterEc, kernelThreads, logCtx, MessageHandler.empty, Set.empty[String])

  /** Creates a kernel out of an `IOInterpreter`.
    *
    * When run, the returned IO allocates the three unbounded queues of the kernel, registers a comm
    * handler on the interpreter if it supports comms, and creates the input handler, in that order.
    */
  def create(
    interpreter: IOInterpreter,
    kernelThreads: KernelThreads,
    logCtx: LoggerContext,
    extraHandler: MessageHandler,
    noExecuteInputFor: Set[String]
  ): IO[Kernel] =
    for {
      backgroundMessages <- Queue.unbounded[IO, ChannelMessage]
      pendingExecutes <-
        Queue.unbounded[IO, Option[(Option[ChannelMessage], Stream[IO, ChannelMessage])]]
      pendingOthers  <- Queue.unbounded[IO, Option[Stream[IO, ChannelMessage]]]
      commHandlerOpt <- registerCommHandler(interpreter, backgroundMessages, kernelThreads)
      input          <- IO(new InputHandler(kernelThreads.futureEc, logCtx))
    } yield Kernel(
      interpreter = interpreter,
      backgroundMessagesQueue = backgroundMessages,
      executeQueue = pendingExecutes,
      otherQueue = pendingOthers,
      backgroundCommHandlerOpt = commHandlerOpt,
      inputHandler = input,
      kernelThreads = kernelThreads,
      logCtx = logCtx,
      extraHandler = extraHandler,
      noExecuteInputFor = noExecuteInputFor
    )

  /** If `interpreter` supports comms, creates a comm handler backed by `backgroundMessages` and
    * `kernelThreads.commEc`, sets it on the interpreter, and returns it. Returns `None` otherwise.
    */
  private def registerCommHandler(
    interpreter: IOInterpreter,
    backgroundMessages: Queue[IO, ChannelMessage],
    kernelThreads: KernelThreads
  ): IO[Option[DefaultCommHandler]] =
    IO {
      if (interpreter.supportComm) {
        val handler = new DefaultCommHandler(backgroundMessages, kernelThreads.commEc)
        interpreter.setCommHandler(handler)
        Some(handler)
      }
      else
        None
    }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy