All downloads are free. Search and download functionality uses the official Maven repository.

almond.launcher.Launcher.scala Maven / Gradle / Ivy

The newest version!
package almond.launcher

import almond.channels.{Channel, Connection, Message => RawMessage}
import almond.channels.zeromq.ZeromqThreads
import almond.cslogger.NotebookCacheLogger
import almond.directives.KernelOptions
import almond.interpreter.ExecuteResult
import almond.interpreter.api.{DisplayData, OutputHandler}
import almond.kernel.install.Install
import almond.kernel.{Kernel, KernelThreads, MessageFile}
import almond.launcher.directives.LauncherParameters
import almond.logger.{Level, LoggerContext}
import almond.protocol.{Execute, RawJson}
import almond.util.ThreadUtil.singleThreadedExecutionContext
import caseapp.core.RemainingArgs
import caseapp.core.app.CaseApp
import cats.effect.IO
import cats.effect.unsafe.IORuntime
import com.github.plokhotnyuk.jsoniter_scala.core._
import coursier.launcher.{BootstrapGenerator, ClassLoaderContent, ClassPathEntry, Parameters}
import dependency.ScalaParameters

import java.io.{File, FileOutputStream, PrintStream}
import java.nio.channels.ClosedSelectorException

import scala.concurrent.duration.DurationInt
import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal
import scala.concurrent.duration.Duration

/** Scala kernel launcher.
  *
  * [[run]] first serves the Jupyter connection file with a [[LauncherInterpreter]] and collects
  * the messages it leaves unprocessed. It then uses the parameters gathered by that interpreter
  * to resolve and bootstrap the actual Almond Scala kernel. Finally, it runs that kernel as a
  * sub-process on the same connection file, handing it the leftover messages.
  */
object Launcher extends CaseApp[LauncherOptions] {

  /** Builds the command that starts the actual Scala kernel.
    *
    * This resolves `sh.almond:scala-kernel-api` and `sh.almond:scala-kernel` (plus their sources
    * and any shared dependencies) for the selected Scala version. It then generates a coursier
    * bootstrap jar whose first class loader content holds the API files and whose second holds
    * the remaining kernel files. Finally, it assembles the `java` command line running that jar.
    *
    * @param connectionFile    Jupyter connection file, forwarded via `--connection-file`
    * @param msgFileOpt        file of messages left unprocessed by the launcher, forwarded via
    *                          `--leftover-messages` when defined
    * @param currentCellCount  forwarded via `--initial-cell-count`
    * @param options           launcher command-line options
    * @param noExecuteInputFor message ids passed to both `--no-execute-input-for` and
    *                          `--ignore-launcher-directives-in`
    * @param params0           launcher parameters (Scala version, JVM, Java command / options, …)
    * @param kernelOptions     kernel options, written to a temporary JSON file when non-empty
    * @param outputHandler     receives coursier download progress unless `options.quiet0` is set
    * @param logCtx            logger context handed to the download progress logger
    * @return the kernel command, the Scala version it runs, and the requested JVM id if any
    */
  private def actualKernelCommand(
    connectionFile: String,
    msgFileOpt: Option[os.Path],
    currentCellCount: Int,
    options: LauncherOptions,
    noExecuteInputFor: Seq[String],
    params0: LauncherParameters,
    kernelOptions: KernelOptions,
    outputHandler: OutputHandler,
    logCtx: LoggerContext
  ): (os.proc, String, Option[String]) = {

    val (scalaVersion, _) = LauncherInterpreter.computeScalaVersion(params0, options)

    // Only artifact URLs are written to the bootstrap jar; the downloaded files are not used here.
    def content(entries: Seq[(coursierapi.Artifact, File)]): ClassLoaderContent =
      ClassLoaderContent(entries.map { case (artifact, _) => ClassPathEntry.Url(artifact.getUrl) })

    val logger =
      if (options.quiet0)
        coursierapi.Logger.nop()
      else
        new NotebookCacheLogger(outputHandler, logCtx)

    val cache = coursierapi.Cache.create().withLogger(logger)

    // Force the Scala library / compiler artifacts to the selected Scala version during resolution.
    val scalaModuleNames =
      if (scalaVersion.startsWith("2."))
        Seq("scala-library", "scala-compiler", "scala-reflect")
      else
        Seq("scala3-library_3", "scala3-compiler_3", "scala3-interfaces")
    val forcedVersions = scalaModuleNames
      .map(name => coursierapi.Module.of("org.scala-lang", name) -> scalaVersion)
      .toMap
    val resolutionParams = coursierapi.ResolutionParams.create()
      .forceVersions(forcedVersions.asJava)

    // A fresh Fetch instance per call, so that dependencies / classifiers added by one
    // fetch do not leak into the next.
    def fetcher = coursierapi.Fetch.create()
      .withCache(cache)
      .addRepositories(coursierapi.MavenRepository.of("https://jitpack.io"))
      .withResolutionParams(resolutionParams)

    /** Fetches `dep` and `extraDeps` (their sources instead when `sources` is true). */
    def fetch(
      dep: coursierapi.Dependency,
      extraDeps: Seq[coursierapi.Dependency] = Nil,
      sources: Boolean = false
    ): Vector[(coursierapi.Artifact, File)] = {
      val base = if (sources) fetcher.addClassifiers("sources") else fetcher
      (dep +: extraDeps)
        .foldLeft(base)((f, d) => f.addDependencies(d))
        .fetchResult()
        .getArtifacts
        .asScala
        .toVector
        .map(e => (e.getKey, e.getValue))
    }

    val apiFiles = {
      val scalaParams = ScalaParameters(scalaVersion)
      val extraDeps = options.sharedDependencies.map { str =>
        val dep = dependency.parser.DependencyParser.parse(str) match {
          case Left(err) =>
            sys.error(s"Malformed shared dependency '$str': $err")
          case Right(dep0) => dep0
        }

        val javaDep = dep.applyParams(scalaParams)
        // A "_" version is replaced by Properties.version, the version used for the
        // sh.almond kernel artifacts below.
        val version = if (javaDep.version == "_") Properties.version else javaDep.version
        coursierapi.Dependency.of(javaDep.organization, javaDep.name, version)
      }
      val dep = coursierapi.Dependency.of(
        "sh.almond",
        s"scala-kernel-api_$scalaVersion",
        Properties.version
      )
      fetch(dep, extraDeps = extraDeps) ++ fetch(dep, extraDeps = extraDeps, sources = true)
    }

    // Files already in apiFiles are dropped, so they only appear in the first loader content.
    val kernelFiles = {
      val dep = coursierapi.Dependency.of(
        "sh.almond",
        s"scala-kernel_$scalaVersion",
        Properties.version
      )
      val files      = fetch(dep) ++ fetch(dep, sources = true)
      val apiFileSet = apiFiles.map(_._2).toSet
      files.filterNot { case (_, f) => apiFileSet.contains(f) }
    }

    val launcher = os.temp(prefix = "almond", suffix = ".jar")
    val params = Parameters.Bootstrap(
      Seq(content(apiFiles), content(kernelFiles)),
      Properties.kernelMainClass
    )
    BootstrapGenerator.generate(params, launcher.toNIO)

    val msgFileArgs = msgFileOpt.toSeq.flatMap { msgFile =>
      Seq[os.Shellable]("--leftover-messages", msgFile)
    }
    val noExecuteInputArgs = noExecuteInputFor.flatMap { id =>
      Seq("--no-execute-input-for", id, "--ignore-launcher-directives-in", id)
    }

    val optionsArgs =
      if (kernelOptions.isEmpty) Nil
      else {
        val asJson      = KernelOptions.AsJson(kernelOptions)
        val bytes       = writeToArray(asJson)(KernelOptions.AsJson.codec)
        val optionsFile = os.temp(bytes, prefix = "almond-options-", suffix = ".json")
        Seq[os.Shellable]("--kernel-options", optionsFile)
      }

    val jvmIdOpt = params0.jvm.filter(_.trim.nonEmpty)
    val javaCommand = jvmIdOpt match {
      case Some(jvmId) =>
        val jvmManager = coursierapi.JvmManager.create().setArchiveCache(
          coursierapi.ArchiveCache.create().withCache(cache)
        )
        val javaHome = os.Path(jvmManager.get(jvmId), os.pwd)
        val ext      = if (scala.util.Properties.isWin) ".exe" else ""
        Seq((javaHome / "bin" / s"java$ext").toString)
      case None =>
        params0.javaCmd.getOrElse(Seq("java"))
    }

    val javaOptions = options.javaOpt ++ params0.javaOptions

    // Default heap size, unless the user passed their own -Xmx
    val memOptions =
      if (javaOptions.exists(_.startsWith("-Xmx"))) Nil
      else Seq("-Xmx512m")

    val proc = os.proc(
      javaCommand,
      memOptions,
      javaOptions,
      "-cp",
      (options.extraStartupClassPath :+ launcher.toString)
        .filter(_.nonEmpty)
        .mkString(File.pathSeparator),
      "coursier.bootstrap.launcher.Launcher",
      "--connection-file",
      connectionFile,
      "--initial-cell-count",
      currentCellCount,
      msgFileArgs,
      noExecuteInputArgs,
      optionsArgs,
      options.kernelOptions,
      params0.kernelOptions
    )

    (proc, scalaVersion, jvmIdOpt)
  }

  /** Runs `proc` with inherited stdin / stdout, and waits for it to exit.
    *
    * If it exits with a non-zero code, the launcher exits with the same code.
    */
  private def launchActualKernel(proc: os.proc): Unit = {

    System.err.println(s"Launching ${proc.command.flatMap(_.value).mkString(" ")}")
    val p = proc.spawn(stdin = os.Inherit, stdout = os.Inherit)
    // If the launcher JVM shuts down while the sub-kernel is still alive, close it, and
    // destroy it forcibly if it is still running after a short grace period.
    val hook: Thread =
      new Thread("shutdown-kernel") {
        setDaemon(true)
        override def run(): Unit =
          if (p.isAlive()) {
            p.close()
            val timeout = 500.millis
            if (!p.waitFor(timeout.toMillis)) {
              System.err.println(
                s"Underlying kernel still running after $timeout, destroying it forcibly"
              )
              p.destroyForcibly()
            }
          }
      }
    Runtime.getRuntime.addShutdownHook(hook)
    p.waitFor()
    // removeShutdownHook throws IllegalStateException if the JVM is already shutting down
    try Runtime.getRuntime.removeShutdownHook(hook)
    catch {
      case e: IllegalStateException =>
        System.err.println("Ignoring error while trying to remove shutdown hook")
        e.printStackTrace(System.err)
    }
    val exitCode = p.exitCode()
    System.err.println(s"Sub-kernel exited with return code $exitCode")
    if (exitCode != 0)
      sys.exit(exitCode)
  }

  /** Creates the launcher logger context from `options.log`, `options.logTo` and
    * `options.color`. It exits with status 1 if the log level cannot be parsed.
    */
  private def createLoggerContext(options: LauncherOptions): LoggerContext = {
    val colored = options.color.getOrElse(true)
    Level.fromString(options.log.getOrElse("warn")) match {
      case Left(err) =>
        Console.err.println(err)
        sys.exit(1)
      case Right(level) =>
        options.logTo match {
          case None =>
            LoggerContext.stderr(level, colored, addPid = true)
          case Some(f) =>
            LoggerContext.printStream(
              level,
              new PrintStream(new FileOutputStream(new File(f))),
              colored,
              addPid = true
            )
        }
    }
  }

  /** Writes `messages` to a temporary message file, or returns `None` if there are none. */
  private def leftoverMessagesFile(messages: Seq[(Channel, RawMessage)]): Option[os.Path] =
    if (messages.isEmpty) None
    else
      Some(
        os.temp(
          MessageFile.from(messages).asJson.value,
          prefix = "almond-launcher-leftover-messages-",
          suffix = ".json"
        )
      )

  /** Escapes the characters that are special in HTML text and attribute values. */
  private def escapeHtml(s: String): String =
    s.replace("&", "&amp;")
      .replace("<", "&lt;")
      .replace(">", "&gt;")
      .replace("\"", "&quot;")
      .replace("'", "&#39;")

  /** Launcher entry point.
    *
    * It either installs the kernel (`options.install`), or runs the launcher interpreter on the
    * connection file and then spawns the actual kernel. Arguments left unparsed by case-app are
    * appended to the kernel command.
    */
  def run(options: LauncherOptions, remainingArgs: RemainingArgs): Unit = {

    // FIXME We'd need coursier-interface to allow us to do these:

    // if (Properties.isWin && isGraalvmNativeImage)
    //   // have to be initialized before running (new Argv0).get because Argv0SubstWindows uses csjniutils library
    //   // The DLL loaded by LoadWindowsLibrary is statically linked in
    //   // the Scala CLI native image, no need to manually load it.
    //   coursier.jniutils.LoadWindowsLibrary.assumeInitialized()

    // coursier.Resolve.proxySetup()

    // if (Properties.isWin && System.console() != null && coursier.paths.Util.useJni())
    //   // Enable ANSI output in Windows terminal
    //   coursier.jniutils.WindowsAnsiTerminal.enableAnsiOutput()

    val logCtx = createLoggerContext(options)
    val log    = logCtx(getClass)

    if (options.install)
      Install.installOrError(
        defaultId = "scala",
        defaultDisplayName = "Scala",
        language = "scala",
        options = options.installOptions,
        defaultLogoOpt = Option(
          Thread.currentThread()
            .getContextClassLoader
            .getResource("almond/scala-logo-64x64.png")
        ),
        connectionFileArgs = Install.defaultConnectionFileArgs,
        interruptMode =
          if (options.installOptions.interruptViaMessage)
            Some("message")
          else
            None,
        env = options.installOptions.envMap(),
        extraStartupClassPath = Nil
      ) match {
        case Left(e) =>
          log.debug("Cannot install kernel", e)
          Console.err.println(s"Error: ${e.getMessage}")
          sys.exit(1)
        case Right(dir) =>
          println(s"Installed scala kernel under $dir")
          sys.exit(0)
      }

    val connectionFile = options.connectionFile.getOrElse {
      Console.err.println(
        "No connection file passed, and installation not asked. Run with --install to install the kernel, " +
          "or pass a connection file via --connection-file to run the kernel."
      )
      sys.exit(1)
    }

    val interpreterEc = singleThreadedExecutionContext("scala-launcher-interpreter")

    val zeromqThreads = ZeromqThreads.create("scala-kernel-launcher")
    val kernelThreads = KernelThreads.create("scala-kernel-launcher")

    val interpreter = new LauncherInterpreter(
      connectionFile,
      options
    )

    val (kernelRun, conn) = Kernel.create(interpreter, interpreterEc, kernelThreads, logCtx)
      .flatMap(_.runOnConnectionFileAllowClose(
        connectionFile,
        "scala",
        zeromqThreads,
        Nil,
        autoClose = false,
        lingerDuration = Duration.Inf // unused here
      ))
      .unsafeRunSync()(IORuntime.global)
    val leftoverMessages: Seq[(Channel, RawMessage)] = kernelRun.unsafeRunSync()(IORuntime.global)

    val leftoverMessagesFileOpt = leftoverMessagesFile(leftoverMessages)

    // The first leftover message, if it is a request, is the one we report progress / errors to
    val firstMessageOpt = leftoverMessages.headOption.flatMap {
      case (Channel.Requests, m) =>
        almond.interpreter.Message.parse[RawJson](m) match {
          case Left(err) =>
            log.warn(s"Cannot parse first leftover request message: $err")
            None
          case Right(msg) => Some(msg)
        }
      case _ => None
    }

    val firstMessageIdOpt = firstMessageOpt.map(_.header.msg_id)

    val outputHandlerOpt = firstMessageOpt.map { firstMessage =>
      new LauncherOutputHandler(firstMessage, conn)
    }

    // Failures are captured rather than thrown, so that the connection and the ZeroMQ
    // context below get closed in every case before the error is rethrown.
    val maybeActualKernelCommand: Either[Throwable, os.proc] =
      try {
        val (launcherParams, kernelParams) =
          interpreter.params.processCustomDirectives(interpreter.kernelOptions)
        val (actualKernelCommand0, scalaVersion, jvmOpt) = actualKernelCommand(
          connectionFile,
          leftoverMessagesFileOpt,
          interpreter.lineCount,
          options,
          firstMessageIdOpt.toSeq,
          launcherParams,
          kernelParams,
          outputHandlerOpt.getOrElse(OutputHandler.NopOutputHandler),
          logCtx
        )

        if (!options.quiet0)
          for (outputHandler <- outputHandlerOpt) {
            val jvmSuffix = jvmOpt.fold("")(jvm => s" with JVM $jvm")
            val text      = s"Launching Scala $scalaVersion kernel$jvmSuffix"
            outputHandler.display(
              DisplayData(
                Map(
                  DisplayData.ContentType.text -> text,
                  DisplayData.ContentType.html -> escapeHtml(text)
                )
              )
            )
          }

        Right(actualKernelCommand0)
      }
      catch {
        case NonFatal(e) =>
          // Report the error to the front-end, if there is a request we can reply to
          for (firstMessage <- firstMessageOpt) {
            val err = ExecuteResult.Error.error(fansi.Color.Red, fansi.Color.Green, Some(e), "")
            val errMsg = firstMessage.publish(
              Execute.errorType,
              Execute.Error("", "", List(err.message))
            )
            try conn.send(Channel.Publish, errMsg.asRawMessage).unsafeRunSync()(IORuntime.global)
            catch {
              case NonFatal(sendErr) =>
                throw new Exception(sendErr)
            }
          }
          Left(e)
      }

    for (outputHandler <- outputHandlerOpt)
      outputHandler.done()

    try
      conn.close(partial = false, lingerDuration = options.lingerDuration)
        .unsafeRunSync()(IORuntime.global)
    catch {
      case NonFatal(e) =>
        throw new Exception(e)
    }

    log.debug("Closing ZeroMQ context")
    IO(zeromqThreads.context.close())
      .evalOn(zeromqThreads.pollingEc)
      .unsafeRunSync()(IORuntime.global)
    log.debug("ZeroMQ context closed")

    maybeActualKernelCommand match {
      case Right(actualKernelCommand0) =>
        val proc0 = os.proc(actualKernelCommand0.commandChunks, remainingArgs.unparsed)
        launchActualKernel(proc0)
      case Left(e) =>
        throw new Exception(e)
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy