All Downloads are FREE. Search and download functionalities are using the official Maven repository.

sbt.internal.client.NetworkClient.scala Maven / Gradle / Ivy

The newest version!
/*
 * sbt
 * Copyright 2023, Scala center
 * Copyright 2011 - 2022, Lightbend, Inc.
 * Copyright 2008 - 2010, Mark Harrah
 * Licensed under Apache License 2.0 (see LICENSE)
 */

package sbt
package internal
package client

import java.io.{ File, IOException, InputStream, PrintStream }
import java.lang.ProcessBuilder.Redirect
import java.net.{ Socket, SocketException }
import java.nio.file.Files
import java.util.UUID
import java.util.concurrent.atomic.{ AtomicBoolean, AtomicReference }
import java.util.concurrent.{ ConcurrentHashMap, LinkedBlockingQueue, TimeUnit }

import sbt.BasicCommandStrings.{ DashDashDetachStdio, DashDashServer, Shutdown, TerminateAction }
import sbt.internal.client.NetworkClient.Arguments
import sbt.internal.langserver.{ LogMessageParams, MessageType, PublishDiagnosticsParams }
import sbt.internal.protocol._
import sbt.internal.util.{ ConsoleAppender, ConsoleOut, Signals, Terminal, Util }
import sbt.io.IO
import sbt.io.syntax._
import sbt.protocol._
import sbt.util.Level
import sjsonnew.BasicJsonProtocol._
import sjsonnew.shaded.scalajson.ast.unsafe.{ JObject, JValue }
import sjsonnew.support.scalajson.unsafe.Converter

import scala.annotation.tailrec
import scala.collection.mutable
import scala.concurrent.duration._
import scala.util.control.NonFatal
import scala.util.{ Failure, Properties, Success, Try }
import Serialization.{
  CancelAll,
  attach,
  cancelReadSystemIn,
  cancelRequest,
  promptChannel,
  readSystemIn,
  systemIn,
  systemErr,
  systemOut,
  systemOutFlush,
  systemErrFlush,
  terminalCapabilities,
  terminalCapabilitiesResponse,
  terminalGetSize,
  terminalPropertiesQuery,
  terminalPropertiesResponse,
  terminalSetEcho,
  terminalSetRawMode,
  terminalSetSize,
  getTerminalAttributes,
  setTerminalAttributes,
}
import NetworkClient.Arguments
import java.util.concurrent.TimeoutException

trait ConsoleInterface {
  def appendLog(level: Level.Value, message: => String): Unit
  def success(msg: String): Unit
}

/**
 * A NetworkClient connects to a running sbt instance or starts a
 * new instance if there isn't already one running. Once connected,
 * it can send commands for sbt to run, it can send completions to sbt
 * and print the completions to stdout so that a shell can consume
 * the completions or it can enter an interactive sbt shell session
 * in which it relays io bytes between sbt and the terminal.
 *
 * @param arguments   the arguments for the forked sbt server if the client
 *                    needs to start it. It also contains the sbt command
 *                    arguments to send to the server if any are present.
 * @param console     a logging instance. This can use a ConsoleAppender or
 *                    just simply print to a PrintStream.
 * @param inputStream the InputStream from which the client reads bytes. It
 *                    is not hardcoded to System.in so that a NetworkClient
 *                    can be remotely controlled by a java process, which
 *                    is useful in testing.
 * @param errorStream the sink for messages that we always want to be printed.
 *                    It is usually System.err but could be overridden in tests
 *                    or set to a null OutputStream if the NetworkClient needs
 *                    to be silent.
 * @param printStream the sink for standard out messages. It is typically
 *                    System.out but in the case of completions, the bytes written
 *                    to System.out are usually treated as completion results
 *                    so we need to reroute standard out messages to System.err.
 *                    It's also useful to override this in testing.
 * @param useJNI      toggles whether or not to use the jni based implementations
 *                    in org.scalasbt.ipcsocket. These are only available on
 *                    64 bit linux, mac and windows. Any other platform will need
 *                    to fall back on jna.
 */
class NetworkClient(
    arguments: Arguments,
    console: ConsoleInterface,
    inputStream: InputStream,
    errorStream: PrintStream,
    printStream: PrintStream,
    useJNI: Boolean,
) extends AutoCloseable { self =>
  def this(configuration: xsbti.AppConfiguration, arguments: Arguments) =
    this(
      arguments = arguments.withBaseDirectory(configuration.baseDirectory),
      console = NetworkClient.consoleAppenderInterface(System.out),
      inputStream = System.in,
      errorStream = System.err,
      printStream = System.out,
      useJNI = false,
    )
  def this(configuration: xsbti.AppConfiguration, args: List[String]) =
    this(
      console = NetworkClient.consoleAppenderInterface(System.out),
      arguments =
        NetworkClient.parseArgs(args.toArray).withBaseDirectory(configuration.baseDirectory),
      inputStream = System.in,
      errorStream = System.err,
      printStream = System.out,
      useJNI = false,
    )
  private val status = new AtomicReference("Ready")
  private val lock: AnyRef = new AnyRef {}
  private val running = new AtomicBoolean(true)
  private val pendingResults =
    new ConcurrentHashMap[String, (LinkedBlockingQueue[Integer], Long, String)]
  private val pendingCancellations = new ConcurrentHashMap[String, LinkedBlockingQueue[Boolean]]
  private val pendingCompletions = new ConcurrentHashMap[String, CompletionResponse => Unit]
  private val attached = new AtomicBoolean(false)
  private val attachUUID = new AtomicReference[String](null)
  private val connectionHolder = new AtomicReference[ServerConnection]
  private val batchMode = new AtomicBoolean(false)
  private val interactiveThread = new AtomicReference[Thread](null)
  private val rebooting = new AtomicBoolean(false)
  private lazy val noTab = arguments.completionArguments.contains("--no-tab")
  private lazy val noStdErr = arguments.completionArguments.contains("--no-stderr") &&
    !sys.env.contains("SBTN_AUTO_COMPLETE") && !sys.env.contains("SBTC_AUTO_COMPLETE")

  private def mkSocket(file: File): (Socket, Option[String]) = ClientSocket.socket(file, useJNI)

  private def portfile = arguments.baseDirectory / "project" / "target" / "active.json"

  def connection: ServerConnection = connectionHolder.synchronized {
    connectionHolder.get match {
      case null => init(promptCompleteUsers = false, retry = true)
      case c    => c
    }
  }

  private[this] val stdinBytes = new LinkedBlockingQueue[Integer]
  private[this] val inLock = new Object
  private[this] val inputThread = new AtomicReference[RawInputThread]
  private[this] val exitClean = new AtomicBoolean(true)
  private[this] val sbtProcess = new AtomicReference[Process](null)
  private class ConnectionRefusedException(t: Throwable) extends Throwable(t)
  private class ServerFailedException extends Exception
  private[this] def startInputThread(): Unit = inputThread.get match {
    case null => inputThread.set(new RawInputThread)
    case _    =>
  }

  private[sbt] def connectOrStartServerAndConnect(
      promptCompleteUsers: Boolean,
      retry: Boolean
  ): (Socket, Option[String]) =
    try {
      if (!portfile.exists) {
        if (promptCompleteUsers) {
          val msg = if (noTab) "" else "No sbt server is running. Press  to start one..."
          errorStream.print(s"\n$msg")
          if (noStdErr) System.exit(0)
          else if (noTab) waitForServer(portfile, log = true, startServer = true)
          else {
            startInputThread()
            stdinBytes.poll(5, TimeUnit.SECONDS) match {
              case null => System.exit(0)
              case i if i == 9 =>
                errorStream.println("\nStarting server...")
                waitForServer(portfile, !promptCompleteUsers, startServer = true)
              case _ => System.exit(0)
            }
          }
        } else {
          waitForServer(portfile, log = true, startServer = true)
        }
      }
      @tailrec def connect(attempt: Int): (Socket, Option[String]) = {
        val res = try Some(mkSocket(portfile))
        catch {
          // This catches a pipe busy exception which can happen if two windows clients
          // attempt to connect in rapid succession
          case e: IOException if e.getMessage.contains("Couldn't open") && attempt < 10 =>
            if (e.getMessage.contains("Access is denied") || e.getMessage.contains("(5)")) {
              errorStream.println(s"Access denied for portfile $portfile")
              throw new NetworkClient.AccessDeniedException
            }
            None
          case e: IOException => throw new ConnectionRefusedException(e)
        }
        res match {
          case Some(r) => r
          case None    =>
            // Use a random sleep to spread out the competing processes
            Thread.sleep(new java.util.Random().nextInt(20).toLong)
            connect(attempt + 1)
        }
      }
      connect(0)
    } catch {
      case e: ConnectionRefusedException if retry =>
        if (Files.deleteIfExists(portfile.toPath))
          connectOrStartServerAndConnect(promptCompleteUsers, retry = false)
        else throw e
    }

  // Open server connection based on the portfile
  def init(promptCompleteUsers: Boolean, retry: Boolean): ServerConnection = {
    val (sk, tkn) = connectOrStartServerAndConnect(promptCompleteUsers, retry)
    val conn = new ServerConnection(sk) {
      override def onNotification(msg: JsonRpcNotificationMessage): Unit = {
        msg.method match {
          case `Shutdown` =>
            val (log, rebootCommands) = msg.params match {
              case Some(jvalue) =>
                Converter
                  .fromJson[(Boolean, Option[(String, String)])](jvalue)
                  .getOrElse((true, None))
              case _ => (false, None)
            }
            if (rebootCommands.nonEmpty) {
              rebooting.set(true)
              attached.set(false)
              connectionHolder.getAndSet(null) match {
                case null =>
                case c    => c.shutdown()
              }
              waitForServer(portfile, true, false)
              init(promptCompleteUsers = false, retry = false)
              attachUUID.set(sendJson(attach, s"""{"interactive": ${!batchMode.get}}"""))
              rebooting.set(false)
              rebootCommands match {
                case Some((execId, cmd)) if execId.nonEmpty =>
                  if (batchMode.get && !pendingResults.containsKey(execId) && cmd.nonEmpty) {
                    console.appendLog(
                      Level.Error,
                      s"received request to re-run unknown command '$cmd' after reboot"
                    )
                  } else if (cmd.nonEmpty) {
                    if (batchMode.get) sendCommand(ExecCommand(cmd, execId))
                    else
                      inLock.synchronized {
                        val toSend = cmd.getBytes :+ '\r'.toByte
                        toSend.foreach(b => sendNotification(systemIn, b.toString))
                      }
                  } else completeExec(execId, 0)
                case _ =>
              }
            } else {
              if (!rebooting.get() && running.compareAndSet(true, false) && log) {
                if (!arguments.commandArguments.contains(Shutdown)) {
                  console.appendLog(Level.Error, "sbt server disconnected")
                  exitClean.set(false)
                }
              } else {
                console.appendLog(Level.Info, s"${if (log) "sbt server " else ""}disconnected")
              }
              stdinBytes.offer(-1)
              Option(inputThread.get).foreach(_.close())
              Option(interactiveThread.get).foreach(_.interrupt)
            }
          case `readSystemIn` => startInputThread()
          case `cancelReadSystemIn` =>
            inputThread.get match {
              case null =>
              case t    => t.close()
            }
          case _ => self.onNotification(msg)
        }
      }
      override def onRequest(msg: JsonRpcRequestMessage): Unit = self.onRequest(msg)
      override def onResponse(msg: JsonRpcResponseMessage): Unit = self.onResponse(msg)
      override def onShutdown(): Unit = if (!rebooting.get) {
        if (exitClean.get != false) exitClean.set(!running.get)
        running.set(false)
        Option(interactiveThread.get).foreach(_.interrupt())
      }
    }
    // initiate handshake
    val execId = UUID.randomUUID.toString
    val initCommand = InitCommand(tkn, Option(execId), Some(true))
    conn.sendString(Serialization.serializeCommandAsJsonMessage(initCommand))
    connectionHolder.set(conn)
    conn
  }

  /**
   * Forks another instance of sbt in the background.
   * This instance must be shutdown explicitly via `sbt -client shutdown`
   */
  def waitForServer(portfile: File, log: Boolean, startServer: Boolean): Unit = {
    val bootSocketName =
      BootServerSocket.socketLocation(arguments.baseDirectory.toPath.toRealPath())

    /*
     * For unknown reasons, linux sometimes struggles to connect to the socket in some
     * scenarios.
     */
    var socket: Option[Socket] =
      if (!Properties.isLinux) Try(ClientSocket.localSocket(bootSocketName, useJNI)).toOption
      else None
    val term = Terminal.console
    term.exitRawMode()
    val process = socket match {
      case None if startServer =>
        if (log) console.appendLog(Level.Info, "server was not detected. starting an instance")

        val props =
          Seq(
            term.getWidth,
            term.getHeight,
            term.isAnsiSupported,
            term.isColorEnabled,
            term.isSupershellEnabled
          ).mkString(",")

        val cmd = arguments.sbtLaunchJar match {
          case Some(lj) =>
            if (log) {
              val sbtScript = if (Properties.isWin) "sbt.bat" else "sbt"
              console.appendLog(Level.Warn, s"server is started using sbt-launch jar directly")
              console.appendLog(
                Level.Warn,
                "this is not the recommended way: .sbtopts and .jvmopts files are not loaded and SBT_OPTS is ignored"
              )
              console.appendLog(
                Level.Warn,
                s"either upgrade $sbtScript to its latest version or make sure it is accessible from $$PATH, and run 'sbt bspConfig'"
              )
            }
            val java = Option(Properties.javaHome)
              .map { javaHome =>
                s"$javaHome/bin/java"
              }
              .getOrElse("java")
            List(java) ++ arguments.sbtArguments ++
              List("-jar", lj, DashDashDetachStdio, DashDashServer)
          case _ =>
            List(arguments.sbtScript) ++ arguments.sbtArguments ++
              List(DashDashDetachStdio, DashDashServer)
        }

        // https://github.com/sbt/sbt/issues/6271
        val nohup =
          if (Util.isEmacs && !Util.isWindows) List("nohup")
          else Nil

        val processBuilder =
          new ProcessBuilder((nohup ++ cmd): _*)
            .directory(arguments.baseDirectory)
            .redirectInput(Redirect.PIPE)
        processBuilder.environment.put(Terminal.TERMINAL_PROPS, props)
        Try(processBuilder.start()) match {
          case Success(process) =>
            sbtProcess.set(process)
            Some(process)
          case Failure(e) =>
            if (log) console.appendLog(Level.Error, s"Failed to start server : $e")
            throw new ServerFailedException
        }
      case _ =>
        if (log) {
          console.appendLog(Level.Info, "sbt server is booting up")
        }
        None
    }
    if (!startServer) {
      val deadline = 5.seconds.fromNow
      while (socket.isEmpty && !deadline.isOverdue()) {
        socket = Try(ClientSocket.localSocket(bootSocketName, useJNI)).toOption
        if (socket.isEmpty) Thread.sleep(20)
      }
    }
    val shutdown = new Thread(() => Option(sbtProcess.get).foreach(_.destroyForcibly()))
    Runtime.getRuntime.addShutdownHook(shutdown)
    var gotInputBack = false
    val readThreadAlive = new AtomicBoolean(true)
    /*
     * Socket.getInputStream.available doesn't always return a value greater than 0
     * so it is necessary to read the process output from the socket on a background
     * thread.
     */
    val readThread = new Thread("client-read-thread") {
      setDaemon(true)
      start()
      override def run(): Unit = {
        try {
          while (readThreadAlive.get) {
            if (socket.isEmpty) {
              socket = Try(ClientSocket.localSocket(bootSocketName, useJNI)).toOption
            }
            socket.foreach { s =>
              try {
                s.getInputStream.read match {
                  case -1 | 0 => readThreadAlive.set(false)
                  case 2 => // STX: start of text
                    gotInputBack = true
                  case 5 => // ENQ: enquiry
                    term.enterRawMode(); startInputThread()
                  case 3 if gotInputBack => // ETX: end of text
                    readThreadAlive.set(false)
                  case i if gotInputBack => stdinBytes.offer(i)
                  case i                 => printStream.write(i)
                }
              } catch {
                case e @ (_: IOException | _: InterruptedException) =>
                  readThreadAlive.set(false)
              }
            }
            if (socket.isEmpty && readThreadAlive.get) {
              try Thread.sleep(10)
              catch { case _: InterruptedException => }
            }
          }
        } catch { case e: IOException => e.printStackTrace(System.err) }
      }
    }
    @tailrec
    def blockUntilStart(): Unit = {
      val stop = try {
        socket match {
          case None =>
            process.foreach { p =>
              val output = p.getInputStream
              while (output.available > 0) {
                printStream.write(output.read())
              }
            }
          case Some(s) =>
            while (!gotInputBack && !stdinBytes.isEmpty && socket.isDefined) {
              val out = s.getOutputStream
              val b = stdinBytes.poll
              if (b == -1) {
                // server waits for user input but stinBytes has ended
                shutdown.run()
              } else {
                out.write(b)
                out.flush()
              }
            }
        }
        process.foreach { p =>
          val error = p.getErrorStream
          while (error.available > 0) {
            errorStream.write(error.read())
          }
        }
        false
      } catch { case e: IOException => true }
      Thread.sleep(10)
      printStream.flush()
      errorStream.flush()
      /*
       * If an earlier server process is launching, the process launched by this client
       * will return with exit value 2. In that case, we can treat the process as alive
       * even if it is actually dead.
       */
      val existsValidProcess =
        process.fold(readThreadAlive.get)(p => p.isAlive || (Properties.isWin || p.exitValue == 2))
      if (!portfile.exists && !stop && existsValidProcess) {
        blockUntilStart()
      } else {
        socket.foreach { s =>
          s.getInputStream.close()
          s.getOutputStream.close()
          s.close()
        }
        readThread.interrupt()
        process.foreach { p =>
          p.getOutputStream.close()
          p.getErrorStream.close()
          p.getInputStream.close()
        }
      }
    }

    try blockUntilStart()
    catch { case t: Throwable => t.printStackTrace() } finally {
      sbtProcess.set(null)
      Util.ignoreResult(Runtime.getRuntime.removeShutdownHook(shutdown))
    }
    if (!portfile.exists()) throw new ServerFailedException
    if (attached.get && !stdinBytes.isEmpty) Option(inputThread.get).foreach(_.drain())
  }

  /** Called on the response for a returning message. */
  def onReturningReponse(msg: JsonRpcResponseMessage): Unit = {
    def printResponse(): Unit = {
      msg.result match {
        case Some(result) =>
          // ignore result JSON
          console.success("completed")
        case _ =>
          msg.error match {
            case Some(err) =>
              // ignore err details
              console.appendLog(Level.Error, "completed")
            case _ => // ignore
          }
      }
    }
    printResponse()
  }

  private def getExitCode(jvalue: Option[JValue]): Integer = jvalue match {
    case Some(o: JObject) =>
      o.value
        .collectFirst {
          case v if v.field == "exitCode" =>
            Converter.fromJson[Integer](v.value).getOrElse(Integer.valueOf(1))
        }
        .getOrElse(1)
    case _ => 1
  }
  private def completeExec(execId: String, exitCode: => Int): Unit =
    pendingResults.remove(execId) match {
      case null =>
      case (q, startTime, name) =>
        val now = System.currentTimeMillis
        val message = timing(startTime, now)
        val ec = exitCode
        if (batchMode.get || !attached.get) {
          if (ec == 0) console.success(message)
          else console.appendLog(Level.Error, message)
        }
        Util.ignoreResult(q.offer(ec))
    }
  def onResponse(msg: JsonRpcResponseMessage): Unit = {
    completeExec(msg.id, getExitCode(msg.result))
    pendingCancellations.remove(msg.id) match {
      case null =>
      case q    => q.offer(msg.toString.contains("Task cancelled"))
    }
    msg.id match {
      case execId =>
        if (attachUUID.get == msg.id) {
          attachUUID.set(null)
          attached.set(true)
          Option(inputThread.get).foreach(_.drain())
        }
        pendingCompletions.remove(execId) match {
          case null =>
          case completions =>
            completions(msg.result match {
              case Some(o: JObject) =>
                o.value
                  .foldLeft(CompletionResponse(Vector.empty[String])) {
                    case (resp, i) =>
                      if (i.field == "items")
                        resp.withItems(
                          Converter
                            .fromJson[Vector[String]](i.value)
                            .getOrElse(Vector.empty[String])
                        )
                      else if (i.field == "cachedTestNames")
                        resp.withCachedTestNames(
                          Converter.fromJson[Boolean](i.value).getOrElse(true)
                        )
                      else if (i.field == "cachedMainClassNames")
                        resp.withCachedMainClassNames(
                          Converter.fromJson[Boolean](i.value).getOrElse(true)
                        )
                      else resp
                  }
              case _ => CompletionResponse(Vector.empty[String])
            })
        }
    }
  }

  def onNotification(msg: JsonRpcNotificationMessage): Unit = {
    def splitToMessage: Vector[(Level.Value, String)] =
      (msg.method, msg.params) match {
        case ("build/logMessage", Some(json)) =>
          if (!attached.get) {
            import sbt.internal.langserver.codec.JsonProtocol._
            Converter.fromJson[LogMessageParams](json) match {
              case Success(params) => splitLogMessage(params)
              case Failure(_)      => Vector()
            }
          } else Vector()
        case (`systemOut`, Some(json)) =>
          Converter.fromJson[Array[Byte]](json) match {
            case Success(bytes) if bytes.nonEmpty && attached.get =>
              synchronized(printStream.write(bytes))
            case _ =>
          }
          Vector.empty
        case (`systemErr`, Some(json)) =>
          Converter.fromJson[Array[Byte]](json) match {
            case Success(bytes) if bytes.nonEmpty && attached.get =>
              synchronized(errorStream.write(bytes))
            case _ =>
          }
          Vector.empty
        case (`systemOutFlush`, _) =>
          synchronized(printStream.flush())
          Vector.empty
        case (`systemErrFlush`, _) =>
          synchronized(errorStream.flush())
          Vector.empty
        case (`promptChannel`, _) =>
          batchMode.set(false)
          Vector.empty
        case ("textDocument/publishDiagnostics", Some(json)) =>
          import sbt.internal.langserver.codec.JsonProtocol._
          Converter.fromJson[PublishDiagnosticsParams](json) match {
            case Success(params) => splitDiagnostics(params); Vector()
            case Failure(_)      => Vector()
          }
        case (`Shutdown`, Some(_))                => Vector.empty
        case (msg, _) if msg.startsWith("build/") => Vector.empty
        case _ =>
          Vector(
            (
              Level.Warn,
              s"unknown event: ${msg.method} " + Serialization.compactPrintJsonOpt(msg.params)
            )
          )
      }
    splitToMessage foreach {
      case (level, msg) => console.appendLog(level, msg)
    }
  }

  def splitLogMessage(params: LogMessageParams): Vector[(Level.Value, String)] = {
    val level = messageTypeToLevel(params.`type`)
    if (level == Level.Debug) Vector()
    else Vector((level, params.message))
  }

  def messageTypeToLevel(severity: Long): Level.Value = {
    severity match {
      case MessageType.Error   => Level.Error
      case MessageType.Warning => Level.Warn
      case MessageType.Info    => Level.Info
      case MessageType.Log     => Level.Debug
    }
  }

  def splitDiagnostics(params: PublishDiagnosticsParams): Vector[(Level.Value, String)] = {
    val uri = new URI(params.uri)
    val f = IO.toFile(uri)

    params.diagnostics map { d =>
      val level = d.severity match {
        case Some(severity) => messageTypeToLevel(severity)
        case _              => Level.Error
      }
      val line = d.range.start.line + 1
      val offset = d.range.start.character + 1
      val msg = s"$f:$line:$offset: ${d.message}"
      (level, msg)
    }
  }

  def onRequest(msg: JsonRpcRequestMessage): Unit = {
    import sbt.protocol.codec.JsonProtocol._
    (msg.method, msg.params) match {
      case (`terminalCapabilities`, Some(json)) =>
        Converter.fromJson[TerminalCapabilitiesQuery](json) match {
          case Success(terminalCapabilitiesQuery) =>
            val response = TerminalCapabilitiesResponse(
              terminalCapabilitiesQuery.boolean
                .map(Terminal.console.getBooleanCapability(_)),
              terminalCapabilitiesQuery.numeric
                .map(c => Option(Terminal.console.getNumericCapability(c)).fold(-1)(_.toInt)),
              terminalCapabilitiesQuery.string
                .map(s => Option(Terminal.console.getStringCapability(s)).getOrElse("null")),
            )
            sendCommandResponse(
              terminalCapabilitiesResponse,
              response,
              msg.id,
            )
          case Failure(_) =>
        }
      case (`terminalPropertiesQuery`, _) =>
        val response = TerminalPropertiesResponse.apply(
          width = Terminal.console.getWidth,
          height = Terminal.console.getHeight,
          isAnsiSupported = Terminal.console.isAnsiSupported,
          isColorEnabled = Terminal.console.isColorEnabled,
          isSupershellEnabled = Terminal.console.isSupershellEnabled,
          isEchoEnabled = Terminal.console.isEchoEnabled
        )
        sendCommandResponse(terminalPropertiesResponse, response, msg.id)
      case (`setTerminalAttributes`, Some(json)) =>
        Converter.fromJson[TerminalSetAttributesCommand](json) match {
          case Success(attributes) =>
            val attrs = Map(
              "iflag" -> attributes.iflag,
              "oflag" -> attributes.oflag,
              "cflag" -> attributes.cflag,
              "lflag" -> attributes.lflag,
              "cchars" -> attributes.cchars,
            )
            Terminal.console.setAttributes(attrs)
            sendCommandResponse("", TerminalSetAttributesResponse(), msg.id)
          case Failure(_) =>
        }
      case (`getTerminalAttributes`, _) =>
        val attrs = Terminal.console.getAttributes
        val response = TerminalAttributesResponse(
          iflag = attrs.getOrElse("iflag", ""),
          oflag = attrs.getOrElse("oflag", ""),
          cflag = attrs.getOrElse("cflag", ""),
          lflag = attrs.getOrElse("lflag", ""),
          cchars = attrs.getOrElse("cchars", ""),
        )
        sendCommandResponse("", response, msg.id)
      case (`terminalGetSize`, _) =>
        val response = TerminalGetSizeResponse(
          Terminal.console.getWidth,
          Terminal.console.getHeight,
        )
        sendCommandResponse("", response, msg.id)
      case (`terminalSetSize`, Some(json)) =>
        Converter.fromJson[TerminalSetSizeCommand](json) match {
          case Success(size) =>
            Terminal.console.setSize(size.width, size.height)
            sendCommandResponse("", TerminalSetSizeResponse(), msg.id)
          case Failure(_) =>
        }
      case (`terminalSetEcho`, Some(json)) =>
        Converter.fromJson[TerminalSetEchoCommand](json) match {
          case Success(echo) =>
            Terminal.console.setEchoEnabled(echo.toggle)
            sendCommandResponse("", TerminalSetEchoResponse(), msg.id)
          case Failure(_) =>
        }
      case (`terminalSetRawMode`, Some(json)) =>
        Converter.fromJson[TerminalSetRawModeCommand](json) match {
          case Success(raw) =>
            if (raw.toggle) Terminal.console.enterRawMode()
            else Terminal.console.exitRawMode()
            sendCommandResponse("", TerminalSetRawModeResponse(), msg.id)
          case Failure(_) =>
        }
      case _ =>
    }
  }

  def connect(log: Boolean, promptCompleteUsers: Boolean): Boolean = {
    if (log) console.appendLog(Level.Info, "entering *experimental* thin client - BEEP WHIRR")
    try {
      init(promptCompleteUsers, retry = true)
      true
    } catch {
      case _: ServerFailedException =>
        console.appendLog(Level.Error, "failed to connect to server")
        false
    }
  }

  private[this] val contHandler: () => Unit = () => {
    if (Terminal.console.getLastLine.nonEmpty)
      printStream.print(ConsoleAppender.DeleteLine + Terminal.console.getLastLine.get)
  }
  private[this] def withSignalHandler[R](handler: () => Unit, sig: String)(f: => R): R = {
    val registration = Signals.register(handler, sig)
    try f
    finally registration.remove()
  }
  private[this] val cancelled = new AtomicBoolean(false)

  def run(): Int =
    withSignalHandler(contHandler, Signals.CONT) {
      interactiveThread.set(Thread.currentThread)
      val cleaned = arguments.commandArguments
      val userCommands = cleaned.takeWhile(_ != TerminateAction)
      val interactive = cleaned.isEmpty
      val exit = cleaned.nonEmpty && userCommands.isEmpty
      attachUUID.set(sendJson(attach, s"""{"interactive": $interactive}"""))
      val handler: () => Unit = () => {
        def exitAbruptly() = {
          exitClean.set(false)
          close()
        }
        if (cancelled.compareAndSet(false, true)) {
          val cancelledTasks = {
            val queue = sendCancelAllCommand()
            Option(queue.poll(1, TimeUnit.SECONDS)).getOrElse(true)
          }
          if ((batchMode.get && pendingResults.isEmpty) || !cancelledTasks) exitAbruptly()
          else cancelled.set(false)
        } else exitAbruptly() // handles double ctrl+c to force a shutdown
      }
      withSignalHandler(handler, Signals.INT) {
        def block(): Int = {
          try this.synchronized(this.wait)
          catch { case _: InterruptedException => }
          if (exitClean.get) 0 else 1
        }
        console.appendLog(Level.Info, "terminate the server with `shutdown`")
        if (interactive) {
          console.appendLog(Level.Info, "disconnect from the server with `exit`")
          block()
        } else if (exit) 0
        else {
          batchMode.set(true)
          val res = batchExecute(userCommands.toList)
          if (!batchMode.get) block() else res
        }
      }
    }

  def batchExecute(userCommands: List[String]): Int = {
    val cmd = userCommands mkString " "
    printStream.println("> " + cmd)
    sendAndWait(cmd, None)
  }

  def getCompletions(query: String): Seq[String] = {
    val quoteCount = query.foldLeft(0) {
      case (count, '"') => count + 1
      case (count, _)   => count
    }
    val inQuote = quoteCount % 2 != 0
    val (rawPrefix, prefix, rawSuffix, suffix) = if (quoteCount > 0) {
      query.lastIndexOf('"') match {
        case -1 => (query, query, None, None) // shouldn't happen
        case i =>
          val rawPrefix = query.substring(0, i)
          val prefix = rawPrefix.replace("\"", "").replace("\\;", ";")
          val rawSuffix = query.substring(i).replace("\\;", ";")
          val suffix = if (rawSuffix.length > 1) rawSuffix.substring(1) else ""
          (rawPrefix, prefix, Some(rawSuffix), Some(suffix))
      }
    } else (query, query.replace("\\;", ";"), None, None)
    val tailSpace = query.endsWith(" ") || query.endsWith("\"")
    val sanitizedQuery = suffix.foldLeft(prefix) { _ + _ }
    def getCompletions(query: String, sendCommand: Boolean): Seq[String] = {
      val result = new LinkedBlockingQueue[CompletionResponse]()
      val json = s"""{"query":"$query","level":1}"""
      val execId = sendJson("sbt/completion", json)
      pendingCompletions.put(execId, result.put)
      val response = result.poll(30, TimeUnit.SECONDS) match {
        case null => throw new TimeoutException("no response from server within 30 seconds")
        case r    => r
      }
      def fillCompletions(label: String, regex: String, command: String): Seq[String] = {
        def updateCompletions(): Seq[String] = {
          errorStream.println()
          sendJson(attach, s"""{"interactive": false}""")
          sendAndWait(query.replaceAll(regex + ".*", command).trim, None)
          getCompletions(query, false)
        }
        if (noStdErr) Nil
        else if (noTab) updateCompletions()
        else {
          errorStream.print(s"\nNo cached $label names found. Press '' to compile: ")
          startInputThread()
          stdinBytes.poll(5, TimeUnit.SECONDS) match {
            case null        => Nil
            case i if i == 9 => updateCompletions()
            case _           => Nil
          }
        }
      }
      val testNameCompletions =
        if (!response.cachedTestNames.getOrElse(true) && sendCommand)
          fillCompletions("test", "test(Only|Quick)", "definedTestNames")
        else Nil
      val classNameCompletions =
        if (!response.cachedMainClassNames.getOrElse(true) && sendCommand)
          fillCompletions("main class", "runMain", "discoveredMainClasses")
        else Nil
      val completions = response.items
      testNameCompletions ++ classNameCompletions ++ completions
    }
    getCompletions(sanitizedQuery, true) collect {
      case c if inQuote                      => c
      case c if tailSpace && c.contains(" ") => c.replace(prefix, "")
      case c if !tailSpace                   => c.split(" ").last
    }
  }

  private def sendAndWait(cmd: String, limit: Option[Deadline]): Int = {
    val queue = sendExecCommand(cmd)
    var result: Integer = null
    while (running.get && result == null && limit.fold(true)(!_.isOverdue())) {
      try {
        result = limit match {
          case Some(l) => queue.poll((l - Deadline.now).toMillis, TimeUnit.MILLISECONDS)
          case _       => queue.take
        }
      } catch {
        case _: InterruptedException if cmd == Shutdown => result = 0
        case _: InterruptedException                    => result = if (exitClean.get) 0 else 1
      }
    }
    if (result == null) 1 else result
  }

  def sendExecCommand(commandLine: String): LinkedBlockingQueue[Integer] = {
    val execId = UUID.randomUUID.toString
    val queue = new LinkedBlockingQueue[Integer]
    sendCommand(ExecCommand(commandLine, execId))
    pendingResults.put(execId, (queue, System.currentTimeMillis, commandLine))
    queue
  }

  def sendCancelAllCommand(): LinkedBlockingQueue[Boolean] = {
    val queue = new LinkedBlockingQueue[Boolean]
    val execId = sendJson(cancelRequest, s"""{"id":"$CancelAll"}""")
    pendingCancellations.put(execId, queue)
    queue
  }

  def sendCommand(command: CommandMessage): Unit = {
    try {
      val s = Serialization.serializeCommandAsJsonMessage(command)
      connection.sendString(s)
      lock.synchronized {
        status.set("Processing")
      }
    } catch {
      case e: SocketException if command.toString.contains("exit") => running.set(false)
      case e: IOException =>
        errorStream.println(s"Caught exception writing command to server: $e")
        running.set(false)
    }
  }
  def sendCommandResponse(method: String, command: EventMessage, id: String): Unit = {
    try {
      val s = new String(Serialization.serializeEventMessage(command))
      val msg = s"""{ "jsonrpc": "2.0", "id": "$id", "result": $s }"""
      connection.sendString(msg)
    } catch {
      case e: IOException =>
        errorStream.println(s"Caught exception writing command to server: $e")
        running.set(false)
    }
  }
  def sendJson(method: String, params: String): String = {
    val uuid = UUID.randomUUID.toString
    sendJson(method, params, uuid)
    uuid
  }
  def sendJson(method: String, params: String, uuid: String): Unit = {
    val msg = s"""{ "jsonrpc": "2.0", "id": "$uuid", "method": "$method", "params": $params }"""
    connection.sendString(msg)
  }

  def sendNotification(method: String, params: String): Unit = {
    connection.sendString(s"""{ "jsonrpc": "2.0", "method": "$method", "params": $params }""")
  }

  override def close(): Unit =
    try {
      running.set(false)
      stdinBytes.offer(-1)
      val mainThread = interactiveThread.getAndSet(null)
      if (mainThread != null && mainThread != Thread.currentThread) mainThread.interrupt
      connectionHolder.get match {
        case null =>
        case c =>
          try sendExecCommand("exit")
          finally c.shutdown()
      }
      Option(inputThread.get).foreach(_.interrupt())
    } catch {
      case t: Throwable => t.printStackTrace(); throw t
    }

  private[this] class RawInputThread extends Thread("sbt-read-input-thread") with AutoCloseable {
    setDaemon(true)
    start()
    val stopped = new AtomicBoolean(false)
    override final def run(): Unit = {
      def read(): Unit = {
        val b = inputStream.read
        inLock.synchronized(stdinBytes.offer(b))
        if (attached.get()) drain()
      }
      try read()
      catch { case _: InterruptedException | NonFatal(_) => stopped.set(true) } finally {
        inputThread.set(null)
      }
    }

    def drain(): Unit = inLock.synchronized {
      while (!stdinBytes.isEmpty) {
        val byte = stdinBytes.poll()
        sendNotification(systemIn, byte.toString)
      }
    }

    override def close(): Unit = {
      RawInputThread.this.interrupt()
    }
  }

  // copied from Aggregation
  private def timing(startTime: Long, endTime: Long): String = {
    import java.text.DateFormat
    val format = DateFormat.getDateTimeInstance(DateFormat.MEDIUM, DateFormat.MEDIUM)
    val nowString = format.format(new java.util.Date(endTime))
    val total = math.max(0, (endTime - startTime + 500) / 1000)
    val totalString = s"$total s" +
      (if (total <= 60) ""
       else {
         val maybeHours = total / 3600 match {
           case 0 => ""
           case h => f"$h%02d:"
         }
         val mins = f"${total % 3600 / 60}%02d"
         val secs = f"${total % 60}%02d"
         s" ($maybeHours$mins:$secs)"
       })
    s"Total time: $totalString, completed $nowString"
  }
}

object NetworkClient {
  private[sbt] val CancelAll = "__CancelAll"
  private def consoleAppenderInterface(printStream: PrintStream): ConsoleInterface = {
    val appender = ConsoleAppender("thin", ConsoleOut.printStreamOut(printStream))
    new ConsoleInterface {
      override def appendLog(level: Level.Value, message: => String): Unit =
        appender.appendLog(level, message)
      override def success(msg: String): Unit = appender.success(msg)
    }
  }
  private def simpleConsoleInterface(doPrintln: String => Unit): ConsoleInterface =
    new ConsoleInterface {
      import scala.Console.{ GREEN, RED, RESET, YELLOW }
      override def appendLog(level: Level.Value, message: => String): Unit = synchronized {
        val prefix = level match {
          case Level.Error => s"[$RED$level$RESET]"
          case Level.Warn  => s"[$YELLOW$level$RESET]"
          case _           => s"[$RESET$level$RESET]"
        }
        message.linesIterator.foreach(line => doPrintln(s"$prefix $line"))
      }
      override def success(msg: String): Unit = doPrintln(s"[${GREEN}success$RESET] $msg")
    }
  private[client] class Arguments(
      val baseDirectory: File,
      val sbtArguments: Seq[String],
      val commandArguments: Seq[String],
      val completionArguments: Seq[String],
      val sbtScript: String,
      val bsp: Boolean,
      val sbtLaunchJar: Option[String],
  ) {
    def withBaseDirectory(file: File): Arguments =
      new Arguments(
        file,
        sbtArguments,
        commandArguments,
        completionArguments,
        sbtScript,
        bsp,
        sbtLaunchJar
      )
  }
  private[client] val completions = "--completions"
  private[client] val noTab = "--no-tab"
  private[client] val noStdErr = "--no-stderr"
  private[client] val sbtBase = "--sbt-base-directory"
  private[client] def parseArgs(args: Array[String]): Arguments = {
    val defaultSbtScript = if (Properties.isWin) "sbt.bat" else "sbt"
    var sbtScript = Properties.propOrNone("sbt.script")
    var launchJar: Option[String] = None
    var bsp = false
    val commandArgs = new mutable.ArrayBuffer[String]
    val sbtArguments = new mutable.ArrayBuffer[String]
    val completionArguments = new mutable.ArrayBuffer[String]
    val SysProp = "-D([^=]+)=(.*)".r
    val sanitized = args.flatMap {
      case a if a.startsWith("\"") => Array(a)
      case a                       => a.split(" ")
    }
    var i = 0
    while (i < sanitized.length) {
      sanitized(i) match {
        case a if completionArguments.nonEmpty => completionArguments += a
        case a if commandArgs.nonEmpty         => commandArgs += a
        case a if a == noStdErr || a == noTab || a.startsWith(completions) =>
          completionArguments += a
        case a if a.startsWith("--sbt-script=") =>
          sbtScript = a
            .split("--sbt-script=")
            .lastOption
            .orElse(sbtScript)
        case "--sbt-script" if i + 1 < sanitized.length =>
          i += 1
          sbtScript = Some(sanitized(i))
        case a if a.startsWith("--sbt-launch-jar=") =>
          launchJar = a
            .split("--sbt-launch-jar=")
            .lastOption
            .map(_.replace("%20", " "))
        case "--sbt-launch-jar" if i + 1 < sanitized.length =>
          i += 1
          launchJar = Option(sanitized(i).replace("%20", " "))
        case "-bsp" | "--bsp"        => bsp = true
        case a if !a.startsWith("-") => commandArgs += a
        case a @ SysProp(key, value) =>
          System.setProperty(key, value)
          sbtArguments += a
        case a => sbtArguments += a
      }
      i += 1
    }
    val base = new File("").getCanonicalFile
    if (!sbtArguments.contains("-Dsbt.io.virtual=true")) sbtArguments += "-Dsbt.io.virtual=true"
    if (!sbtArguments.exists(_.startsWith("-Dsbt.script"))) {
      sbtScript.foreach { sbtScript =>
        sbtArguments += s"-Dsbt.script=$sbtScript"
      }
    }
    new Arguments(
      base,
      sbtArguments.toSeq,
      commandArgs.toSeq,
      completionArguments.toSeq,
      sbtScript.getOrElse(defaultSbtScript).replace("%20", " "),
      bsp,
      launchJar
    )
  }

  def client(
      baseDirectory: File,
      args: Array[String],
      inputStream: InputStream,
      printStream: PrintStream,
      errorStream: PrintStream,
      useJNI: Boolean
  ): Int = {
    val client =
      simpleClient(
        NetworkClient.parseArgs(args).withBaseDirectory(baseDirectory),
        inputStream,
        printStream,
        errorStream,
        useJNI,
      )
    try {
      if (client.connect(log = true, promptCompleteUsers = false)) client.run()
      else 1
    } catch { case _: Exception => 1 } finally client.close()
  }
  def client(
      baseDirectory: File,
      args: Arguments,
      inputStream: InputStream,
      errorStream: PrintStream,
      terminal: Terminal,
      useJNI: Boolean
  ): Int = {
    val printStream = if (args.bsp) errorStream else terminal.printStream
    val client =
      simpleClient(
        args.withBaseDirectory(baseDirectory),
        inputStream,
        printStream,
        errorStream,
        useJNI,
      )
    clientImpl(client, args.bsp)
  }
  private def clientImpl(client: NetworkClient, isBsp: Boolean): Int = {
    try {
      if (isBsp) {
        val (socket, _) =
          client.connectOrStartServerAndConnect(promptCompleteUsers = false, retry = true)
        BspClient.bspRun(socket)
      } else {
        if (client.connect(log = true, promptCompleteUsers = false)) client.run()
        else 1
      }
    } catch { case _: Exception => 1 } finally client.close()
  }
  def client(
      baseDirectory: File,
      args: Array[String],
      inputStream: InputStream,
      errorStream: PrintStream,
      terminal: Terminal,
      useJNI: Boolean
  ): Int = client(baseDirectory, parseArgs(args), inputStream, errorStream, terminal, useJNI)

  private def simpleClient(
      arguments: Arguments,
      inputStream: InputStream,
      errorStream: PrintStream,
      useJNI: Boolean,
      terminal: Terminal
  ): NetworkClient = {
    val doPrint: String => Unit = line => {
      if (terminal.getLastLine.isDefined) terminal.printStream.println()
      terminal.printStream.println(line)
    }
    val interface = NetworkClient.simpleConsoleInterface(doPrint)
    val printStream = terminal.printStream
    new NetworkClient(arguments, interface, inputStream, errorStream, printStream, useJNI)
  }
  private def simpleClient(
      arguments: Arguments,
      inputStream: InputStream,
      printStream: PrintStream,
      errorStream: PrintStream,
      useJNI: Boolean,
  ): NetworkClient = {
    val interface = NetworkClient.simpleConsoleInterface(printStream.println)
    new NetworkClient(arguments, interface, inputStream, errorStream, printStream, useJNI)
  }
  def main(args: Array[String]): Unit = {
    val (jnaArg, restOfArgs) = args.partition(_ == "--jna")
    val useJNI = jnaArg.isEmpty
    val base = new File("").getCanonicalFile
    if (restOfArgs.exists(_.startsWith(NetworkClient.completions)))
      System.exit(complete(base, restOfArgs, useJNI, System.in, System.out))
    else {
      val hook = new Thread(() => {
        System.out.print(ConsoleAppender.ClearScreenAfterCursor)
        System.out.flush()
      })
      Runtime.getRuntime.addShutdownHook(hook)
      if (Util.isNonCygwinWindows) sbt.internal.util.JLine3.forceWindowsJansi()
      val parsed = parseArgs(restOfArgs)
      System.exit(Terminal.withStreams(isServer = false, isSubProcess = false) {
        val term = Terminal.console
        try client(base, parsed, term.inputStream, System.err, term, useJNI)
        catch { case _: AccessDeniedException => 1 } finally {
          Runtime.getRuntime.removeShutdownHook(hook)
          hook.run()
        }
      })
    }
  }
  def complete(
      baseDirectory: File,
      args: Array[String],
      useJNI: Boolean,
      in: InputStream,
      out: PrintStream
  ): Int = {
    val cmd: String = args.find(_.startsWith(NetworkClient.completions)) match {
      case Some(c) =>
        c.split('=').lastOption match {
          case Some(query) =>
            query.indexOf(" ") match {
              case -1 => throw new IllegalArgumentException(query)
              case i  => query.substring(i + 1)
            }
          case _ => throw new IllegalArgumentException(c)
        }
      case _ => throw new IllegalStateException("should be unreachable")
    }
    val quiet = args.exists(_ == "--quiet")
    val errorStream = if (quiet) new PrintStream(_ => {}, false) else System.err
    val sbtArgs = args.takeWhile(!_.startsWith(NetworkClient.completions))
    val arguments = NetworkClient.parseArgs(sbtArgs)
    val noTab = args.contains("--no-tab")
    try {
      val client =
        simpleClient(
          arguments.withBaseDirectory(baseDirectory),
          inputStream = in,
          errorStream = errorStream,
          printStream = errorStream,
          useJNI = useJNI,
        )
      try {
        val results =
          if (client.connect(log = false, promptCompleteUsers = true)) client.getCompletions(cmd)
          else Nil
        out.println(results.sorted.distinct mkString "\n")
        0
      } catch { case _: Exception => 1 } finally client.close()
    } catch { case _: AccessDeniedException => 1 }
  }

  def run(configuration: xsbti.AppConfiguration, arguments: List[String]): Int =
    run(configuration, arguments, false)
  def run(
      configuration: xsbti.AppConfiguration,
      arguments: List[String],
      redirectOutput: Boolean
  ): Int = {
    val term = Terminal.console
    val err = new PrintStream(term.errorStream)
    val out = if (redirectOutput) err else new PrintStream(term.outputStream)
    val args = parseArgs(arguments.toArray).withBaseDirectory(configuration.baseDirectory)
    val useJNI = BootServerSocket.requiresJNI || System.getProperty("sbt.ipcsocket.jni", "false") == "true"
    val client = simpleClient(args, term.inputStream, out, err, useJNI = useJNI)
    clientImpl(client, args.bsp)
  }
  private class AccessDeniedException extends Throwable
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy