All downloads are free. The search and download functionality uses the official Maven repository.

polynote.kernel.remote.RemoteKernel.scala Maven / Gradle / Ivy

The newest version!
package polynote.kernel
package remote

import java.net.{InetAddress, InetSocketAddress}
import java.nio.channels.ClosedChannelException
import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger
import polynote.app.Environment
import polynote.config.PolynoteConfig
import polynote.kernel.Kernel.Factory
import polynote.kernel.environment.{Config, CurrentNotebook, Env, PublishResult, PublishStatus}
import polynote.kernel.interpreter.{Interpreter, InterpreterState}
import polynote.kernel.logging.Logging
import polynote.kernel.task.TaskManager
import polynote.kernel.util.{Publish, RPublish, TPublish, UPublish}
import polynote.messages._
import polynote.runtime.{StreamingDataRepr, TableOp}
import zio.{Cause, Layer, Promise, RIO, Semaphore, Task, UIO, URIO, ZIO, ZLayer}
import zio.blocking.effectBlocking
import zio.duration.Duration
import zio.stream.ZStream

import scala.collection.JavaConverters._

/**
  * A [[Kernel]] whose operations are carried out by sending a [[RemoteRequest]] over `transport`
  * and waiting for the [[RemoteRequestResponse]] that carries the same request ID.
  *
  * Outstanding requests are kept in `waiting`, keyed by request ID. The response stream is consumed by
  * `processResponses`, which `init()` starts on a daemon fiber.
  *
  * @param transport server side of the transport used to send requests and receive responses
  * @param closing   semaphore whose permit is held for the duration of `shutdown()`
  * @param closed    promise that is completed when this kernel is closed; once it is done, new requests
  *                  fail with a [[ClosedChannelException]]
  */
class RemoteKernel[ServerAddress](
  private[polynote] val transport: TransportServer[ServerAddress],
  closing: Semaphore,
  closed: Promise[Throwable, Unit]
) extends Kernel {

  private val requestId = new AtomicInteger(0)

  /** Returns the current request ID and advances the counter. */
  private def nextReq = requestId.getAndIncrement()

  /**
    * A registered continuation for one request ID.
    *
    * @param handler partial function producing the effect to run for a matching response
    * @param promise completed with the outcome of `handler`
    * @param env     environment captured when the handler was registered; provided to `handler`'s effect
    */
  private case class RequestHandler[R, T](handler: PartialFunction[RemoteRequestResponse, RIO[R, T]], promise: Promise[Throwable, T], env: R) {

    /**
      * Feeds a response to this handler. If `handler` is defined for it, the resulting effect's outcome
      * is written to `promise`. Otherwise an [[ErrorResponse]] fails `promise`, and any other response
      * is ignored.
      */
    def run(rep: RemoteRequestResponse): UIO[Unit] =
      if (handler.isDefinedAt(rep)) {
        handler(rep).provide(env).to(promise).unit
      } else rep match {
        case ErrorResponse(_, err) =>
          // The handler will never reach its own `done(...)` call after an error, so drop it here;
          // the two-argument remove leaves any newer handler for the same ID untouched.
          promise.fail(err) *> ZIO.effectTotal(waiting.remove(rep.reqId, this)).unit
        case _ => ZIO.unit
      }
  }

  // Handlers for requests that have not completed yet, keyed by request ID.
  private val waiting = new ConcurrentHashMap[Int, RequestHandler[_, _]]()

  /**
    * Registers `fn` as the handler for `reqId` (replacing any existing one) and returns a task that
    * completes with the handler's result. Nothing is sent over the transport.
    */
  private def wait[R, T](reqId: Int)(fn: PartialFunction[RemoteRequestResponse, RIO[R, T]]): RIO[R, Task[T]] = for {
    promise <- Promise.make[Throwable, T]
    env     <- ZIO.access[R](identity)
    _       <- ZIO(waiting.put(reqId, RequestHandler(fn, promise, env)))
  } yield promise.await

  /** Removes the handler registered for `reqId` and yields `t`. */
  private def done[T](reqId: Int, t: T): UIO[T] = ZIO.effectTotal(waiting.remove(reqId)).as(t)

  /**
    * Sends `req` and waits for `fn` to handle a response with the same request ID.
    *
    * Fails with [[ClosedChannelException]] if `closed` is already done. Errors other than pure
    * interruption are logged before being propagated.
    */
  private def request[R <: BaseEnv, T](req: RemoteRequest)(fn: PartialFunction[RemoteRequestResponse, RIO[R, T]]): RIO[R, T] = closed.isDone.flatMap {
    case false =>
      for {
        promise <- Promise.make[Throwable, T]
        env     <- ZIO.access[R](identity)
        handler  = RequestHandler(fn, promise, env)
        _       <- ZIO(waiting.put(req.reqId, handler))
        // If the request could not be sent, no response will arrive to clean up the handler.
        _       <- transport.sendRequest(req).tapError(_ => ZIO.effectTotal(waiting.remove(req.reqId, handler)))
        result  <- promise.await.onError {
          cause =>
            if (!cause.interruptedOnly)
              Logging.error("Error from remote kernel", cause)
            else
              ZIO.unit
        }
      } yield result
    case true => ZIO.fail(new ClosedChannelException())
  }

  /** Applies `fn` to the handler registered for `rep`'s request ID; does nothing if there is none. */
  private def withHandler(rep: RemoteRequestResponse)(fn: RequestHandler[_, _] => Task[Unit]): Task[Unit] =
    Option(waiting.get(rep.reqId)) match {
      case Some(handler) => fn(handler)
      case None          => ZIO.unit
    }

  // Routes every incoming response: status updates are published directly, results are published with
  // the environment captured by the waiting handler, and everything else is passed to the handler itself.
  private val processResponses = transport.responses.mapM {
    case KernelStatusResponse(status)    => PublishStatus(status)
    case rep@ResultResponse(_, result)   => withHandler(rep)(handler => PublishResult(result).provide(handler.env.asInstanceOf[PublishResult]))
    case rep@ResultsResponse(_, results) => withHandler(rep)(handler => PublishResult(results).provide(handler.env.asInstanceOf[PublishResult]))
    case rep: RemoteRequestResponse      => withHandler(rep)(_.run(rep).ignore)
  }

  /**
    * Starts the background fibers (forwarding notebook updates to the transport until `closed`
    * completes, and draining `processResponses`), then sends a [[StartupRequest]] with the current
    * notebook, its version and the configuration, and waits for the [[Announce]] reply.
    */
  def init(): RIO[BaseEnv with GlobalEnv with CellEnv, Unit] = for {
    notebookRef    <- CurrentNotebook.access
    _              <- notebookRef.updates.interruptWhen(closed.await.ignore).foreach(transport.sendNotebookUpdate).forkDaemon
    versioned      <- notebookRef.getVersioned
    (ver, notebook) = versioned
    config         <- Config.access
    _              <- processResponses.runDrain.forkDaemon
    _              <- request(StartupRequest(nextReq, notebook, ver, config)) {
      case Announce(reqId, _) => done(reqId, ()) // TODO: may want to keep the remote address to try reconnecting?
    }
  } yield ()

  /**
    * Sends a [[QueueCellRequest]] for the cell. The outer effect completes when the first of
    * [[UnitResponse]] / [[RunCompleteResponse]] arrives; the returned task completes when the other
    * one arrives. Both arrival orders are handled.
    */
  def queueCell(id: CellID): RIO[BaseEnv with GlobalEnv with CellEnv, Task[Unit]] = {
    val reqId = nextReq
    request(QueueCellRequest(reqId, id)) {
      case UnitResponse(_) => wait(reqId) {
        case RunCompleteResponse(`reqId`) =>
          done(reqId, ())
      }
      case RunCompleteResponse(`reqId`) => wait(reqId) {
        case UnitResponse(_) =>
          done(reqId, ())
      }
    }
  }

  /** Runs the inherited `cancelAll` in parallel with a [[CancelRequest]] that names no task. */
  override def cancelAll(): RIO[BaseEnv with TaskManager, Unit] = super.cancelAll() &> request(CancelRequest(nextReq, None)) {
    case UnitResponse(reqId) => done(reqId, ())
  }

  /** Runs the inherited `cancelTask` in parallel with a [[CancelRequest]] for `taskId`. */
  override def cancelTask(taskId: String): RIO[BaseEnv with TaskManager, Unit] = super.cancelTask(taskId) &> request(CancelRequest(nextReq, Some(taskId))) {
    case UnitResponse(reqId) => done(reqId, ())
  }

  /** Returns the inherited task list followed by the tasks reported in the [[ListTasksResponse]]. */
  override def tasks(): RIO[BaseEnv with TaskManager, List[TaskInfo]] = super.tasks().zip {
    request(ListTasksRequest(nextReq)) {
      case ListTasksResponse(reqId, result) => done(reqId, result)
    }
  }.map {
    case (a, b) => a ++ b
  }

  /** Requests completions for position `pos` of cell `id`. */
  def completionsAt(id: CellID, pos: Int): RIO[BaseEnv with GlobalEnv with CellEnv, List[Completion]] =
    request(CompletionsAtRequest(nextReq, id, pos)) {
      case CompletionsAtResponse(reqId, result) => done(reqId, result)
    }

  /** Requests parameter signatures for position `pos` of cell `id`. */
  def parametersAt(id: CellID, pos: Int): RIO[BaseEnv with GlobalEnv with CellEnv, Option[Signatures]] =
    request(ParametersAtRequest(nextReq, id, pos)) {
      case ParametersAtResponse(reqId, result) => done(reqId, result)
    }

  /**
    * Shuts the kernel down while holding the `closing` permit. If `closed` is not done and the
    * transport is connected, a [[ShutdownRequest]] is sent and awaited for up to 10 seconds; a timeout
    * is logged as a warning. `close()` always runs afterwards.
    *
    * NOTE(review): the warning text says the process is killed; the code visible here only calls
    * `close()` — confirm that closing the transport is what terminates the remote process.
    */
  def shutdown(): TaskB[Unit] = closing.withPermit {
    ZIO.whenM(ZIO.mapN(closed.isDone, transport.isConnected)(!_ && _)) {
      request(ShutdownRequest(nextReq)) {
        case ShutdownResponse(reqId) =>
          done(reqId, ())
      }.timeout(Duration(10, TimeUnit.SECONDS)).flatMap {
        case Some(()) => ZIO.succeed(())
        case None =>
          Logging.warn("Waited for remote kernel to shut down for 10 seconds; killing the process")
      }
    }.ensuring {
      close().orDie
    }
  }

  /** Requests the kernel's busy state. */
  def status(): TaskB[KernelBusyState] = request(StatusRequest(nextReq)) {
    case StatusResponse(reqId, status) => done(reqId, status)
  }

  /** Requests the kernel's current result values. */
  def values(): TaskB[List[ResultValue]] = request(ValuesRequest(nextReq)) {
    case ValuesResponse(reqId, values) => done(reqId, values)
  }

  /** Requests `count` items from a handle, using the session ID found in the [[StreamingHandles]] environment. */
  def getHandleData(handleType: HandleType, handle: Int, count: Int): RIO[BaseEnv with StreamingHandles, Array[ByteVector32]] =
    ZIO.access[StreamingHandles](_.get.sessionID).flatMap {
      sessionID => request(GetHandleDataRequest(nextReq, sessionID, handleType, handle, count)) {
        case GetHandleDataResponse(reqId, data) => done(reqId, data)
      }
    }

  /** Requests that `ops` be applied to the stream behind `handleId`, for the session in the environment. */
  def modifyStream(handleId: Int, ops: List[TableOp]): RIO[BaseEnv with StreamingHandles, Option[StreamingDataRepr]] =
    ZIO.access[StreamingHandles](_.get.sessionID).flatMap {
      sessionID => request(ModifyStreamRequest(nextReq, sessionID, handleId, ops)) {
        case ModifyStreamResponse(reqId, result) => done(reqId, result)
      }
    }

  /** Requests release of a handle, for the session in the environment. */
  def releaseHandle(handleType: HandleType, handleId: Int): RIO[BaseEnv with StreamingHandles, Unit] =
    ZIO.access[StreamingHandles](_.get.sessionID).flatMap {
      sessionID => request(ReleaseHandleRequest(nextReq, sessionID, handleType, handleId)) {
        case UnitResponse(reqId) => done(reqId, ())
      }
    }

  /** Requests the kernel info. */
  override def info(): TaskG[KernelInfo] = request(KernelInfoRequest(nextReq)) {
    case KernelInfoResponse(reqId, info) => done(reqId, info)
  }

  /**
    * Completes `closed`, closes the transport, then interrupts every pending request promise and
    * clears `waiting`. The pending promises are interrupted even when closing the transport fails,
    * so that callers blocked on a response are always released.
    */
  private[this] def close(): TaskB[Unit] = {
    val interruptPending: UIO[Unit] = for {
      pending <- ZIO.effectTotal(waiting.values().asScala.toList)
      _       <- ZIO.foreach_(pending)(_.promise.interrupt)
      _       <- ZIO.effectTotal(waiting.clear())
    } yield ()

    closed.succeed(()) *> transport.close().unit.ensuring(interruptPending)
  }

  /** Completes when `closed` completes. */
  override def awaitClosed: Task[Unit] = closed.await

}

/**
  * Factory for [[RemoteKernel]] instances. As a [[Kernel.Factory.Service]], its no-argument `apply`
  * builds a kernel over a [[SocketTransport]] that deploys a Java subprocess for [[LocalKernelFactory]].
  */
object RemoteKernel extends Kernel.Factory.Service {

  /**
    * Starts serving on `transport` and wraps the resulting server in a [[RemoteKernel]].
    *
    * A daemon fiber forwards the outcome of the server's `awaitClosed` into the kernel's `closed`
    * promise and then runs the kernel's `shutdown()`.
    */
  def apply[ServerAddress](transport: Transport[ServerAddress]): RIO[BaseEnv with GlobalEnv with CellEnv, RemoteKernel[ServerAddress]] =
    for {
      closeSignal <- Promise.make[Throwable, Unit]
      closeLock   <- Semaphore.make(1L)
      server      <- transport.serve()
      remote       = new RemoteKernel(server, closeLock, closeSignal)
      watcher      = server.awaitClosed.to(closeSignal).ensuring(remote.shutdown().orDie)
      _           <- watcher.forkDaemon
    } yield remote

  /** Builds a remote kernel backed by a Java subprocess reached over a socket transport. */
  override def apply(): RIO[BaseEnv with GlobalEnv with CellEnv, Kernel] = {
    val deployJava = new SocketTransport.DeploySubprocess.DeployJava[LocalKernelFactory]
    val deployment = new SocketTransport.DeploySubprocess(deployJava)
    apply(new SocketTransport(deployment))
  }

  /** Returns a factory service whose `apply()` builds a [[RemoteKernel]] over the given transport. */
  def service[ServerAddress](transport: Transport[ServerAddress]): Kernel.Factory.Service =
    new Factory.Service {
      override def apply(): RIO[BaseEnv with GlobalEnv with CellEnv, Kernel] =
        RemoteKernel(transport)
    }

  /** Alias for [[service]]. */
  def factory[ServerAddress](transport: Transport[ServerAddress]): Kernel.Factory.Service =
    service(transport)
}


/**
  * The remote-process side of the kernel protocol: consumes [[RemoteRequest]]s, executes them against
  * the wrapped `kernel`, and publishes the resulting [[RemoteResponse]]s.
  *
  * @param kernel          kernel that actually services the requests
  * @param requests        stream of incoming requests
  * @param publishResponse publisher used to send responses back
  * @param cleanup         effect run by `close()`
  * @param closed          promise that receives the outcome of `cleanup`; the request loop halts when it completes
  * @param notebookRef     exposed for tests
  */
class RemoteKernelClient(
  kernel: Kernel,
  requests: ZStream[BaseEnv, Throwable, RemoteRequest],
  publishResponse: RPublish[BaseEnv, RemoteResponse],
  cleanup: TaskB[Unit],
  closed: Promise[Throwable, Unit],
  private[remote] val notebookRef: RemoteNotebookRef // for testing
) {

  // Streaming-handle services that have been created so far, keyed by session ID.
  private val sessionHandles = new ConcurrentHashMap[Int, StreamingHandles.Service]()

  /**
    * Handles requests in parallel, publishing each response as it is produced. A
    * [[ShutdownResponse]] triggers `close()`, and the loop halts once `closed` completes.
    * The whole loop is marked uninterruptible.
    */
  def run(): RIO[BaseEnv with GlobalEnv with CellEnv with PublishRemoteResponse, Unit] = {
    val responses = requests
      .map(request => ZStream.fromEffect(handleRequest(request)))
      .flattenParUnbounded()

    responses
      .tap(publishResponse.publish)
      .mapM(closeOnShutdown)
      .haltWhen(closed)
      .runDrain
      .uninterruptible
  }

  /** Closes this client when the response just published was a [[ShutdownResponse]]. */
  private def closeOnShutdown(rep: RemoteResponse): URIO[BaseEnv, Unit] =
    rep match {
      case _: ShutdownResponse => close()
      case _                   => ZIO.unit
    }

  /**
    * Translates one request into the corresponding kernel call and wraps the outcome in the matching
    * response. Request types without a dedicated case are answered with a [[UnitResponse]]; a failure
    * becomes an [[ErrorResponse]] carrying the request's ID.
    *
    * For a [[QueueCellRequest]] the returned response is a [[UnitResponse]] issued once the cell is
    * queued; the cell's completion runs on a daemon fiber that publishes an error result on failure
    * and always publishes a [[RunCompleteResponse]] at the end.
    */
  private def handleRequest(req: RemoteRequest): RIO[BaseEnv with GlobalEnv with CellEnv with PublishRemoteResponse, RemoteResponse] = {
    val response = req match {
      case QueueCellRequest(rid, cell) =>
        kernel.queueCell(cell).flatMap { running =>
          running
            .catchAll(failure => publishResponse.publish(ResultResponse(rid, ErrorResult(failure))))
            .ensuring(publishResponse.publish(RunCompleteResponse(rid)).catchAll(Logging.error))
            .forkDaemon
        }.as(UnitResponse(rid))
      case ListTasksRequest(rid) =>
        kernel.tasks().map(ListTasksResponse(rid, _))
      case CancelRequest(rid, None) =>
        kernel.cancelAll().as(UnitResponse(rid))
      case CancelRequest(rid, Some(task)) =>
        kernel.cancelTask(task).as(UnitResponse(rid))
      case CompletionsAtRequest(rid, cell, offset) =>
        kernel.completionsAt(cell, offset).map(CompletionsAtResponse(rid, _))
      case ParametersAtRequest(rid, cell, offset) =>
        kernel.parametersAt(cell, offset).map(ParametersAtResponse(rid, _))
      case ShutdownRequest(rid) =>
        kernel.shutdown().as(ShutdownResponse(rid)).uninterruptible
      case StatusRequest(rid) =>
        kernel.status().map(StatusResponse(rid, _))
      case ValuesRequest(rid) =>
        kernel.values().map(ValuesResponse(rid, _))
      case GetHandleDataRequest(rid, session, kind, handle, count) =>
        kernel.getHandleData(kind, handle, count).map(GetHandleDataResponse(rid, _)).provideSomeLayer(streamingHandles(session))
      case ModifyStreamRequest(rid, session, handle, ops) =>
        kernel.modifyStream(handle, ops).map(ModifyStreamResponse(rid, _)).provideSomeLayer(streamingHandles(session))
      case ReleaseHandleRequest(rid, session, kind, handle) =>
        kernel.releaseHandle(kind, handle).as(UnitResponse(rid)).provideSomeLayer(streamingHandles(session))
      case KernelInfoRequest(rid) =>
        kernel.info().map(KernelInfoResponse(rid, _))
      // TODO: Kernel needs an API to release all streaming handles (then we could let go of elements from sessionHandles map; right now they will just accumulate forever)
      case other =>
        ZIO.succeed(UnitResponse(other.reqId))
    }

    response
      .interruptible
      .provideSomeLayer[BaseEnv with GlobalEnv with CellEnv with PublishRemoteResponse](RemoteKernelClient.mkResultPublisher(req.reqId))
      .catchAll(failure => ZIO.succeed(ErrorResponse(req.reqId, failure)))
  }

  /**
    * Layer providing the [[StreamingHandles]] service for a session. An already-registered service is
    * reused; otherwise a new one is made and registered, unless another registration for the same
    * session won the race, in which case that one is used instead.
    */
  private def streamingHandles(sessionId: Int): ZLayer[BaseEnv, Throwable, StreamingHandles] =
    ZLayer.fromEffect {
      Option(sessionHandles.get(sessionId)) match {
        case Some(existing) => ZIO.succeed(existing)
        case None =>
          StreamingHandles.make(sessionId).map { created =>
            // putIfAbsent returns the previously registered service, or null if `created` was stored.
            Option(sessionHandles.putIfAbsent(sessionId, created)).getOrElse(created)
          }
      }
    }

  /** Runs `cleanup` and records its outcome in `closed`. */
  def close(): URIO[BaseEnv, Unit] = cleanup.to(closed).unit
}

/**
  * Entry point of the remote kernel process. Parses the command line, connects back to the server,
  * performs the startup handshake and then runs a [[RemoteKernelClient]] until it finishes.
  */
object RemoteKernelClient extends polynote.app.App {

  /** Parses `args` and runs the client; any failure terminates the fiber as a defect. */
  override def main(args: List[String]): ZIO[Environment, Nothing, Int] =
    (parseArgs(args) >>= runThrowable).orDie

  /** Runs the client with the given arguments; yields exit code 0 on normal completion. */
  def runThrowable(args: Args): RIO[BaseEnv, Int] = tapRunThrowable(args, None)

  // for testing – so we can get out the remote kernel client
  /**
    * Connects to the address in `args`, expects the first request to be a [[StartupRequest]], builds
    * the kernel environment and kernel from it, announces the local host address, and runs the client.
    *
    * @param args      parsed command-line arguments
    * @param tapClient if defined, receives the created [[RemoteKernelClient]] before the kernel is initialized
    * @return 0 once the client's request loop has ended; fails if the handshake or any setup step fails
    */
  def tapRunThrowable(args: Args, tapClient: Option[zio.Ref[RemoteKernelClient]]): RIO[BaseEnv, Int] = for {
    addr            <- args.getSocketAddress
    kernelFactory   <- args.getKernelFactory
    transport       <- SocketTransport.connectClient(addr)
    baseEnv         <- ZIO.environment[BaseEnv]
    requests         = transport.requests
    publishResponse  = Publish.fn[BaseEnv, Throwable, RemoteResponse](rep => transport.sendResponse(rep).provide(baseEnv))
    localAddress    <- effectBlocking(InetAddress.getLocalHost.getHostAddress)
    // Handshake: the very first message must be the startup request.
    firstRequest    <- requests.take(1).runHead.someOrFail(new NoSuchElementException("No messages were received"))
    initial         <- firstRequest match {
      case req@StartupRequest(_, _, _, _) => ZIO.succeed(req)
      case other                          => ZIO.fail(new RuntimeException(s"Handshake error; expected StartupRequest but found ${other.getClass.getName}"))
    }
    _               <- Env.addLayer[BaseEnv, Throwable, InterpreterState](InterpreterState.emptyLayer)
    notebookRef     <- RemoteNotebookRef(initial.globalVersion -> initial.notebook, transport)
    closed          <- Promise.make[Throwable, Unit]
    interpFactories <- interpreter.Loader.load
    _               <- Env.addLayer(mkEnv(notebookRef, firstRequest.reqId, publishResponse, interpFactories, kernelFactory, initial.config))
    kernel          <- kernelFactory.apply()
    client           = new RemoteKernelClient(kernel, requests, publishResponse, transport.close(), closed, notebookRef)
    _               <- tapClient.fold(ZIO.unit)(_.set(client))
    _               <- kernel.init()
    // Answer the startup request only after the kernel has been initialized.
    _               <- publishResponse.publish(Announce(initial.reqId, localAddress))
    _               <- client.run().ensuring(client.close())
  } yield 0

  /**
    * Assembles the environment layer the kernel runs in: configuration, current notebook, interpreter
    * state and factories, the kernel factory, status publishing (responses wrapped in
    * [[KernelStatusResponse]], publish errors logged), a task manager built on that status publisher,
    * a result publisher bound to `reqId`, and the raw response publisher.
    */
  def mkEnv(
    currentNotebook: NotebookRef,
    reqId: Int,
    publishResponse: RPublish[BaseEnv, RemoteResponse],
    interpreterFactories: Map[String, List[Interpreter.Factory]],
    kernelFactory: Factory.Service,
    polynoteConfig: PolynoteConfig
  ): ZLayer[BaseEnv with InterpreterState, Throwable, GlobalEnv with CellEnv with PublishRemoteResponse] = {
    val publishStatus = ZLayer.fromFunction[BaseEnv, UPublish[KernelStatusUpdate]] {
      baseEnv => publishResponse
        .contramap(KernelStatusResponse.apply)
        .catchAll(err => baseEnv.get[Logging.Service].error(None, err))
        .provide(baseEnv)
    }

    val publishRemote: Layer[Nothing, PublishRemoteResponse] = ZLayer.succeed(publishResponse)

    Config.layerOf(polynoteConfig) ++
      CurrentNotebook.layer(currentNotebook) ++
      ZLayer.identity[InterpreterState] ++
      Interpreter.Factories.layer(interpreterFactories) ++
      ZLayer.succeed(kernelFactory) ++
      publishStatus ++
      (publishStatus >>> TaskManager.layer) ++
      ((ZLayer.requires[BaseEnv] ++ publishRemote) >>> mkResultPublisher(reqId)) ++
      publishRemote
  }

  /**
    * Layer providing a result publisher that wraps each result in a [[ResultResponse]] tagged with
    * `reqId` and sends it through the response publisher in the environment; publish errors are logged.
    */
  def mkResultPublisher(reqId: Int): ZLayer[BaseEnv with PublishRemoteResponse, Nothing, PublishResult] =
    ZLayer.fromFunction[BaseEnv with PublishRemoteResponse, UPublish[Result]] {
      env => env.get[RPublish[BaseEnv, RemoteResponse]]
        .contramap[Result](result => ResultResponse(reqId, result))
        .catchAll(env.get[Logging.Service].error(None, _))
        .provide(env)
    }


  /**
    * Parsed command-line arguments.
    *
    * @param address       value of `--address`, if given
    * @param port          value of `--port`, if given
    * @param kernelFactory factory instantiated from `--kernelFactory`, if given
    */
  case class Args(address: Option[String] = None, port: Option[Int] = None, kernelFactory: Option[Kernel.Factory.LocalService] = None) {

    /** Builds the socket address; fails with [[IllegalArgumentException]] if address or port is missing. */
    def getSocketAddress: Task[InetSocketAddress] = for {
      address       <- ZIO.fromOption(address).mapError(_ => new IllegalArgumentException("Missing required argument address"))
      port          <- ZIO.fromOption(port).mapError(_ => new IllegalArgumentException("Missing required argument port"))
      socketAddress <- ZIO(new InetSocketAddress(address, port))
    } yield socketAddress

    /** Returns the configured kernel factory, defaulting to [[LocalKernel]]. */
    def getKernelFactory: Task[Kernel.Factory.LocalService] = ZIO.succeed(kernelFactory.getOrElse(LocalKernel))
  }

  /**
    * Parses `--address`, `--port` and `--kernelFactory` (each followed by a value) into an [[Args]].
    * Any other token is skipped with a warning; warnings are emitted once parsing reaches the end.
    *
    * Both accumulators are threaded through every recursive call, so an unknown argument does not
    * discard values parsed before it and a known argument does not discard pending warnings.
    *
    * @param args        remaining tokens
    * @param current     values parsed so far
    * @param sideEffects warnings accumulated so far
    * @return the parsed arguments; fails if the port is not an integer or the kernel factory cannot be instantiated
    */
  private def parseArgs(args: List[String], current: Args = Args(), sideEffects: ZIO[Logging, Nothing, Unit] = ZIO.unit): RIO[Logging, Args] =
    args match {
      case Nil => sideEffects *> ZIO.succeed(current)
      case "--address" :: address :: rest =>
        parseArgs(rest, current.copy(address = Some(address)), sideEffects)
      case "--port" :: port :: rest =>
        // Parse inside the effect so that a malformed port is reported as a failure with a clear message.
        ZIO(port.toInt)
          .mapError(err => new IllegalArgumentException(s"Invalid value for --port: $port", err))
          .flatMap(parsed => parseArgs(rest, current.copy(port = Some(parsed)), sideEffects))
      case "--kernelFactory" :: className :: rest =>
        parseKernelFactory(className).flatMap(factory => parseArgs(rest, current.copy(kernelFactory = Some(factory)), sideEffects))
      case other :: rest =>
        parseArgs(rest, current, sideEffects *> Logging.warn(s"Ignoring unknown argument $other"))
    }

  /** Loads the class named `str` as a [[Kernel.Factory.LocalService]] and instantiates it with its no-argument constructor. */
  private def parseKernelFactory(str: String) = for {
    factoryClass <- ZIO(Class.forName(str).asSubclass(classOf[Kernel.Factory.LocalService]))
    factoryInst  <- ZIO(factoryClass.newInstance())
  } yield factoryInst

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy