// All downloads are free. Search and download functionality uses the official Maven repository.

// polynote.kernel.task.package.scala Maven / Gradle / Ivy

// The newest version!
package polynote.kernel

import polynote.kernel.environment.{CurrentTask, PublishStatus, TaskRef}
import polynote.kernel.logging.Logging
import polynote.kernel.util.UPublish
import polynote.messages.TinyString
import zio.blocking.Blocking
import zio.clock.Clock
import zio.stream.SubscriptionRef
import zio.{Cause, Fiber, Has, Promise, Queue, RIO, Semaphore, Task, UIO, URIO, ZIO, ZLayer, ZManaged}

import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong}
import scala.collection.JavaConverters._

package object task {

  type TaskManager = Has[TaskManager.Service]

  object TaskManager {

    /**
      * The task manager's service interface. It can queue tasks, run tasks immediately, register externally-managed
      * tasks, and list or cancel the tasks it is tracking. Each task's [[TaskInfo]] status is broadcast as it changes.
      */
    trait Service {

      /**
        * Queue a task, which can access a reference to the TaskInfo and modify it to broadcast updates.
        * When the given task finishes, errors, or is interrupted, a completion message of the appropriate status will
        * be broadcast.
        *
        * Evaluating the returned outer [[Task]] results in queueing of the given task, which will eventually cause the
        * given task to be evaluated. Evaluating the inner [[Task]] results in blocking (asynchronously) until the given
        * task completes.
        *
        * Interrupting the inner Task results in cancelling the queued task, or interrupting it if it's running.
        *
        * Note that status updates are sent somewhat lazily, and for a series of rapid updates to the task status only the
        * last update might get sent.
        *
        * @param id        the task's identifier
        * @param label     the label for the task's [[TaskInfo]] (defaults to the empty string)
        * @param detail    the detail text for the task's [[TaskInfo]] (defaults to the empty string)
        * @param errorWith given the [[Cause]] of a failure, returns the function used to update the [[TaskInfo]];
        *                  by default the task info is updated with `failed(cause)`
        * @param task      the task to queue; it requires a [[CurrentTask]] environment
        * @return an outer effect which queues the task, yielding an inner [[Task]] which waits for its result
        */
      def queue[R <: CurrentTask, A, R1 >: R <: Has[_]](
        id: String,
        label: String = "",
        detail: String = "",
        errorWith: Cause[Throwable] => TaskInfo => TaskInfo = cause => _.failed(cause)
      )(task: RIO[R, A])(implicit ev: R1 with CurrentTask <:< R): RIO[R1, Task[A]]

      /**
        * This overload is more useful if the return type needs to be inferred. It delegates to [[queue]], leaving
        * `errorWith` at its default.
        */
      def queue_[A, R <: Has[_]](id: String, label: String = "", detail: String = "")(task: RIO[R with CurrentTask, A]): RIO[R, Task[A]] =
        queue[R with CurrentTask, A, R](id, label, detail)(task)

      /**
        * Run the given task. The task will be independent of the task queue, but will have access to a [[TaskInfo]]
        * reference; it can update this reference to broadcast task updates. The first update will be broadcast when the
        * task is evaluated, and a completion update will be broadcast when it completes or fails.
        *
        * @param errorWith given the [[Cause]] of a failure, returns the function used to update the [[TaskInfo]];
        *                  by default the task info is updated with `failed(cause)`
        */
      def run[R <: Has[_], A](id: String, label: String = "", detail: String = "", errorWith: Cause[Throwable] => TaskInfo => TaskInfo = cause => _.failed(cause))(task: RIO[CurrentTask with R, A]): RIO[R, A]

      /**
        * Run the given task as a subtask of the current task (the task found in the [[CurrentTask]] environment).
        * @see [[run]]
        */
      def runSubtask[R <: CurrentTask, A](id: String, label: String = "", detail: String = "", errorWith: Cause[Throwable] => TaskInfo => TaskInfo = cause => _.failed(cause))(task: RIO[R, A]): RIO[R, A]

      /**
        * Register an external task for status broadcasting and cancellation, by providing a function which will receive a
        * function for modifying the [[TaskInfo]] and return a [[UIO]] that cancels the external task upon evaluation. The
        * cancellation task returned by the provided function may be evaluated even after the external task is finished.
        *
        * The [[TaskInfo]] should be updated by invoking the function (TaskInfo => TaskInfo) => Unit, passing as an argument
        * the function which updates the status (e.g. `_.progress(0.5)`).
        *
        * The external task must report itself as being finished by updating the [[TaskInfo]] to a completed or failed
        * state. All updates to the task reference will be broadcast. If [[cancelAll]] is run on this task manager, the
        * given task will be interrupted (if it has not yet reported completion) using the returned cancellation task.
        *
        * Returns the [[Fiber]] which updates the task status. Interrupting this fiber results in the cancellation task
        * returned from cancelCallback being evaluated.
        *
        * @param parent    optional ID of a parent task, recorded in the [[TaskInfo]]
        * @param errorWith the [[DoneStatus]] to apply to the task when the returned fiber is interrupted
        */
      def register(id: String, label: String = "", detail: String = "", parent: Option[String] = None, errorWith: DoneStatus = ErrorStatus)(cancelCallback: ((TaskInfo => TaskInfo) => Unit) => ZIO[Logging, Nothing, Unit]): RIO[Blocking with Clock with Logging, Fiber[Throwable, Unit]]

      /**
        * Cancel all tasks. If a task has not yet begun running, it will simply be cancelled. If a task is already running,
        * it will be interrupted. (see [[ZIO.interrupt]])
        */
      def cancelAll(): UIO[Unit]

      /**
        * Cancel the task with the given ID, and all currently queued or running tasks which were queued or started
        * after the given task was queued or started.
        */
      def cancelTask(id: String): UIO[Unit]

      /**
        * Shut down the task manager and stop executing any tasks in the queue. Cancels all running tasks (see [[cancelAll]])
        */
      def shutdown(): UIO[Unit]

      /**
        * List all currently running tasks
        */
      def list: UIO[List[TaskInfo]]
    }

    /**
      * An entry in the ready queue.
      *
      * @param id    the ID of the queued task
      * @param ready completed by the queue processor when it is this task's turn to run; interrupted when the task is
      *              cancelled while still queued
      * @param done  completed (or interrupted) when the task has finished; the queue processor waits on it before
      *              taking the next entry
      */
    private case class QueuedTask(id: String, ready: Promise[Throwable, Unit], done: Promise[Throwable, Unit])
    /**
      * Bookkeeping record for a task tracked by the task manager.
      *
      * @param id          the task's ID
      * @param taskInfoRef reference holding the task's current [[TaskInfo]]
      * @param fiber       the fiber executing (or waiting to execute) the task
      * @param counter     sequence number assigned when the descriptor was created; used to order tasks
      * @param ready       for queued tasks, the promise that signals the task's turn to run; `None` for tasks that
      *                    bypass the queue
      */
    private case class TaskDescriptor(
      id: String,
      taskInfoRef: TaskRef,
      fiber: Fiber[Throwable, Any],
      counter: Long,
      ready: Option[Promise[Throwable, Unit]]
    ) {
      /**
        * Cancel this task. When there is a `ready` promise, it is interrupted first; the fiber itself is only
        * interrupted when that interruption reports `false` (or when there is no `ready` promise at all).
        */
      def cancel: UIO[Unit] = {
        // Evaluates to `true` only when interrupting the `ready` promise is what completed it.
        val interruptedWhileQueued = ready.fold[UIO[Boolean]](ZIO.succeed(false))(_.interrupt)
        ZIO.unlessM(interruptedWhileQueued)(fiber.interrupt)
      }
    }

    private class Impl (
      queueing: Semaphore,
      statusUpdates: UPublish[KernelStatusUpdate],
      readyQueue: zio.Queue[QueuedTask],
      process: Fiber[Throwable, Nothing]
    ) extends Service {

      // Source of sequence numbers: every TaskDescriptor takes the next value, so tasks can be ordered by creation.
      private val taskCounter = new AtomicLong(0)
      // Descriptors of the tasks currently tracked by this manager, keyed by task ID.
      private val tasks = new ConcurrentHashMap[String, TaskDescriptor]
      // Publisher for individual task updates: each TaskInfo is wrapped via UpdatedTasks.one before being published.
      private val updates = statusUpdates.contramap(UpdatedTasks.one)

      // The label to display for a task: the given label, falling back to the task's ID when the label is empty.
      private def lbl(id: String, label: String): String =
        if (label.nonEmpty) label else id

      /**
        * Queue the given task behind any previously queued tasks.
        *
        * The whole queueing step runs while holding the `queueing` semaphore permit. The task is forked immediately as
        * a daemon fiber, but that fiber first waits on a promise (`myTurn`) that is placed on the ready queue. If a
        * descriptor was already registered under the same `id`, that previous task is cancelled.
        *
        * @return an effect which performs the queueing and yields an inner effect that joins the task's fiber
        */
      override def queue[R <: CurrentTask, A, R1 >: R <: Has[_]](
        id: String, label: String = "",
        detail: String = "",
        errorWith: Cause[Throwable] => TaskInfo => TaskInfo
      )(task: RIO[R, A])(implicit ev: R1 with CurrentTask <:< R): RIO[R1, Task[A]] = queueing.withPermit {
        for {
          // The task starts out in the Queued state.
          statusSubRef  <- SubscriptionRef.make(TaskInfo(id, lbl(id, label), detail, Queued))
          statusRef      = statusSubRef.ref
          // Stops tracking this task ID.
          remove         = ZIO.effectTotal(tasks.remove(id))
          // myTurn signals that the task may start; imDone signals that the task has finished.
          myTurn        <- Promise.make[Throwable, Unit]
          imDone        <- Promise.make[Throwable, Unit]
          // On failure: apply errorWith to the status, then stop tracking the task and interrupt imDone.
          fail           = (err: Cause[Throwable]) => statusRef.update(t => ZIO.succeed(errorWith(err)(t))).ensuring(remove &> imDone.interrupt).uninterruptible
          // On success: mark the status completed, then stop tracking the task.
          complete       = statusRef.update(t => ZIO.succeed(t.completed)).ensuring(remove).uninterruptible
          // Publishes every status change; stops after publishing a status which is done.
          updater       <- statusSubRef.changes.foreachWhile(t => updates.publish(t).as(!t.status.isDone)).uninterruptible.forkDaemon
          // Runs the task with the status reference provided as its TaskRef, records the outcome in the status,
          // waits for the updater to finish, and then re-raises the task's error (if any) via absolve.
          taskBody       = ZIO.absolve {
            task.provideSomeLayer[R1](ZLayer.succeed[TaskRef](statusRef))
              .either
              .tap(_.fold(err => fail(Cause.fail(err)), _ => complete)) <* updater.join
          }
          _             <- readyQueue.offer(QueuedTask(id, myTurn, imDone))
          // Wait for this task's turn; if interrupted while waiting, interrupt imDone as well.
          wait           = myTurn.await.onInterrupt(imDone.interrupt)
          // Once it's this task's turn, mark it running and run the body. `fail` is attached via onTermination, and
          // imDone is always completed at the end.
          runTask        = (wait *> statusRef.update(t => ZIO.succeed(t.running)) *> taskBody)
            .onTermination(fail)
            .ensuring(imDone.succeed(()))
          // Publish the initial (Queued) status.
          _             <- statusRef.get >>= updates.publish
          taskFiber     <- runTask.ensuring(remove).forkDaemon
          descriptor     = TaskDescriptor(id, statusRef, taskFiber, taskCounter.getAndIncrement(), Some(myTurn))
          // Track the new task; a descriptor previously registered under the same ID is cancelled.
          // NOTE(review): the descriptor is registered after the fiber has been forked — confirm that the fiber's
          // `remove` can't run before this `put`.
          _             <- Option(tasks.put(id, descriptor)).map(_.cancel).getOrElse(ZIO.unit)
        } yield taskFiber.join
      }

      /**
        * Shared implementation of [[run]] and [[runSubtask]]: runs the task immediately (not through the queue) on a
        * forked fiber, publishing its status as it changes.
        *
        * @param task      the task to run; the status reference is provided to it as its [[CurrentTask]]
        * @param id        the task's ID; a descriptor previously registered under the same ID is cancelled
        * @param label     the task's label; when empty, the ID is used instead
        * @param detail    the detail text for the task's [[TaskInfo]]
        * @param parent    the ID of the parent task, if this is a subtask
        * @param errorWith given the cause of a failure or interruption, returns the update to apply to the [[TaskInfo]]
        * @return an effect which yields the task's result; interrupting it interrupts the task's fiber
        */
      private def runImpl[R <: Has[_], A](
        task: RIO[R with CurrentTask, A],
        id: String,
        label: String,
        detail: String,
        parent: Option[TinyString],
        errorWith: Cause[Throwable] => TaskInfo => TaskInfo
      ): RIO[R, A] = for {
        // The task starts out in the Running state, with no progress.
        statusSubRef  <- SubscriptionRef.make(TaskInfo(id, lbl(id, label), detail, Running, progress = 0, parent = parent))
        statusRef      = statusSubRef.ref
        remove         = ZIO.effectTotal(tasks.remove(id))
        // Publishes every status change, stops after publishing a status which is done, and then stops tracking
        // the task.
        updater       <- statusSubRef.changes
          .foreachWhile(t => updates.publish(t).as(!t.status.isDone))
          .uninterruptible
          .ensuring(remove).fork
        taskBody       = task
          .onInterrupt(fibers => statusRef.update(t => ZIO.succeed(errorWith(Cause.interrupt(fibers.headOption.getOrElse(Fiber.Id.None)))(t))).ignore.unit)
          .provideSomeLayer[R](CurrentTask.layer(statusRef))
        // Publish the initial status.
        _             <- statusRef.get >>= updates.publish
        // The set of interrupting fibers may be empty, so fall back to Fiber.Id.None rather than calling `head`
        // (which would throw inside the interruption handler); this matches the handler on taskBody above.
        taskFiber     <- (taskBody <* statusRef.update(t => ZIO.succeed(t.completed)) <* updater.join)
          .onError(cause => statusRef.update(t => ZIO.succeed(errorWith(cause)(t))))
          .onInterrupt(fibers => statusRef.update(t => ZIO.succeed(errorWith(Cause.interrupt(fibers.headOption.getOrElse(Fiber.Id.None)))(t))))
          .fork
        // No `ready` promise: this task isn't queued, so cancelling it interrupts the fiber directly.
        descriptor     = TaskDescriptor(id, statusRef, taskFiber, taskCounter.getAndIncrement(), None)
        _             <- Option(tasks.put(id, descriptor)).map(_.cancel).getOrElse(ZIO.unit)
        result        <- taskFiber.join.onInterrupt(taskFiber.interrupt)
      } yield result

      /**
        * Run the task immediately as a top-level task (one with no parent).
        */
      override def run[R <: Has[_], A](id: String, label: String = "", detail: String = "", errorWith: Cause[Throwable] => TaskInfo => TaskInfo)(task: RIO[CurrentTask with R, A]): RIO[R, A] =
        runImpl[R, A](
          task      = task,
          id        = id,
          label     = label,
          detail    = detail,
          parent    = None,
          errorWith = errorWith
        )

      /**
        * Run the task immediately as a child of the current task: the current task's ID (read from the
        * [[CurrentTask]] environment) is recorded as the new task's parent.
        */
      override def runSubtask[R <: CurrentTask, A](id: String, label: String, detail: String, errorWith: Cause[Throwable] => TaskInfo => TaskInfo)(task: RIO[R, A]): RIO[R, A] =
        ZIO.accessM[CurrentTask](_.get.get).flatMap { parentInfo =>
          runImpl[R, A](task, id, label, detail, Some(parentInfo.id), errorWith)
        }

      /**
        * Register an externally-managed task. The external task reports status changes by calling the function passed
        * to `cancelCallback`; those changes are applied to the task's status and published until the status is done.
        *
        * @return the fiber which applies the status changes. If it is interrupted, the status is marked done with
        *         `errorWith` and the cancellation effect returned by `cancelCallback` is run.
        */
      override def register(id: String, label: String = "", detail: String = "", parent: Option[String], errorWith: DoneStatus)(cancelCallback: ((TaskInfo => TaskInfo) => Unit) => ZIO[Logging, Nothing, Unit]): RIO[Blocking with Clock with Logging, Fiber[Throwable, Unit]] =
        for {
          // Captured so that the plain (non-effectful) update callback can run an effect.
          runtime      <- ZIO.runtime[Any]
          statusSubRef <- SubscriptionRef.make(TaskInfo(id, lbl(id, label), detail, Running, progress = 0, parent = parent.map(TinyString(_))))
          statusRef     = statusSubRef.ref
          // Status modifications submitted by the external task, in the order they were offered.
          updateTasks  <- Queue.unbounded[TaskInfo => TaskInfo]
          // Set once the updater below has finished. NOTE(review): nothing in the visible code reads this flag.
          completed     = new AtomicBoolean(false)
          // Publishes every status change; stops after publishing a status which is done.
          updater      <- statusSubRef.changes
            .foreachWhile(t => updates.publish(t).as(!t.status.isDone))
            .uninterruptible
            .ensuring(ZIO.effectTotal(completed.set(true)))
            .forkDaemon
          // The function handed to the external task: enqueues a status modification, running the offer synchronously.
          onUpdate      = (fn: TaskInfo => TaskInfo) => runtime.unsafeRun(updateTasks.offer(fn).unit)
          // cancelCallback is invoked here, at registration; the effect it returns cancels the external task.
          cancel        = cancelCallback(onUpdate)
          // Applies queued modifications one at a time until the status is done, then stops tracking the task. If
          // interrupted, marks the status done with errorWith and runs the cancellation effect.
          // (The lambda parameter named `updater` shadows the updater fiber above.)
          process      <- updateTasks.take.flatMap(updater => statusRef.update(t => ZIO.succeed(updater(t))) *> statusRef.get)
            .repeatUntil(_.status.isDone).unit
            .ensuring(ZIO.effectTotal(tasks.remove(id)))
            .onInterrupt(statusRef.update(t => ZIO.succeed(t.done(errorWith))).ensuring(cancel))
            .forkDaemon
          // No `ready` promise: cancelling this descriptor interrupts the `process` fiber directly.
          descriptor    = TaskDescriptor(id, statusRef, process, taskCounter.getAndIncrement(), None)
          // Track the new task; a descriptor previously registered under the same ID is cancelled.
          _            <- Option(tasks.put(id, descriptor)).map(_.cancel).getOrElse(ZIO.unit)
        } yield process

      /**
        * Cancel everything: first drain the ready queue, interrupting both promises of every entry that was still
        * waiting, and then interrupt the fibers of all tracked tasks (without waiting for them to finish).
        */
      override def cancelAll(): UIO[Unit] = {
        val cancelQueued: UIO[Unit] = readyQueue.takeAll.flatMap { waiting =>
          ZIO.foreachPar_(waiting) { queued =>
            queued.ready.interrupt &> queued.done.interrupt
          }
        }

        // The tracked fibers are looked up when this effect runs, i.e. after the queue has been drained.
        val cancelRunning: UIO[Unit] =
          ZIO.effectTotal(tasks.values().asScala.map(_.fiber)).flatMap { fibers =>
            ZIO.foreachPar_(fibers)(_.interruptFork)
          }

        cancelQueued *> cancelRunning
      }

      /**
        * Cancel the task with the given ID along with every tracked task whose descriptor was created after it.
        * Does nothing when no task with that ID is being tracked.
        */
      override def cancelTask(id: String): UIO[Unit] =
        ZIO.effectTotal(Option(tasks.get(id))).flatMap {
          case Some(target) =>
            // Tasks created after the target, newest first.
            val laterTasks = tasks.values().asScala.toSeq
              .filter(_.counter > target.counter)
              .sortBy(-_.counter)
            ZIO.foreachPar_(laterTasks)(_.cancel) *> target.cancel

          case None =>
            ZIO.unit
        }

      /**
        * The current [[TaskInfo]] of every tracked task, ordered by the sequence number assigned when each task's
        * descriptor was created.
        *
        * The set of tracked tasks is read when the returned effect is evaluated (rather than when this method is
        * called), so an effect which is stored and evaluated later, or more than once, always reflects the tasks
        * tracked at that time.
        */
      override def list: UIO[List[TaskInfo]] =
        ZIO.effectTotal(tasks.values().asScala.toList.sortBy(_.counter)).flatMap {
          descriptors => ZIO.foreach(descriptors)(_.taskInfoRef.get)
        }

      /**
        * Shut the task manager down: cancel all tasks, shut down the ready queue, and then interrupt the fiber which
        * processes the queue.
        */
      override def shutdown(): UIO[Unit] =
        for {
          _ <- cancelAll()
          _ <- readyQueue.shutdown
          _ <- process.interrupt
        } yield ()
    }

    /**
      * Create a task manager which publishes task status updates to the given publisher.
      *
      * A daemon fiber is started to process the ready queue one entry at a time: it takes the next entry, completes
      * its `ready` promise so that the task may start, and waits for its `done` promise before taking another.
      */
    def apply(
      statusUpdates: UPublish[KernelStatusUpdate]
    ): Task[TaskManager.Service] = for {
      permit    <- Semaphore.make(1)
      pending   <- Queue.unbounded[QueuedTask]
      startNext  = (ZIO.yieldNow *> ZIO.allowInterrupt *> pending.take).flatMap {
        next => next.ready.succeed(()) *> next.done.await.run
      }
      processor <- startNext.forever.forkDaemon
    } yield new Impl(permit, statusUpdates, pending, processor)

    /**
      * A managed task manager which publishes to the status publisher found in the environment, and which is shut
      * down when the managed resource is released.
      */
    def make: ZManaged[PublishStatus, Throwable, TaskManager.Service] =
      ZIO.access[PublishStatus](_.get).toManaged_.flatMap { publisher =>
        apply(publisher).toManaged(_.shutdown())
      }

    /** A layer providing the task manager created by [[make]]. */
    def layer: ZLayer[PublishStatus, Throwable, TaskManager] = ZLayer.fromManaged(make)

    /** Retrieve the task manager service from the environment. */
    def access: URIO[TaskManager, TaskManager.Service] = ZIO.access[TaskManager](_.get)

    /** Wrap the given service as a [[TaskManager]] environment. */
    def of(service: Service): TaskManager = Has(service)

    /**
      * Queue a task using the task manager from the environment.
      * @see [[Service.queue]]
      */
    def queue[R <: CurrentTask, A, R1 >: R <: TaskManager](
      id: String,
      label: String = "",
      detail: String = "",
      errorWith: Cause[Throwable] => TaskInfo => TaskInfo = cause => _.failed(cause)
    )(task: RIO[R, A])(implicit ev: R1 with CurrentTask <:< R): RIO[R1, Task[A]] =
      access.flatMap { manager =>
        manager.queue[R, A, R1](id, label, detail, errorWith)(task)
      }

    /**
      * Run a task using the task manager from the environment.
      * @see [[Service.run]]
      */
    def run[R <: TaskManager, A](id: String, label: String = "", detail: String = "", errorWith: Cause[Throwable] => TaskInfo => TaskInfo = cause => _.failed(cause))(task: RIO[CurrentTask with R, A]): RIO[R, A] =
      access.flatMap { manager =>
        manager.run[R, A](id, label, detail, errorWith)(task)
      }

    /**
      * Run a task as a subtask of the current task, using the task manager from the environment.
      * @see [[Service.runSubtask]]
      */
    def runSubtask[R <: CurrentTask, A](id: String, label: String = "", detail: String = "", errorWith: Cause[Throwable] => TaskInfo => TaskInfo = cause => _.failed(cause))(task: RIO[R, A]): RIO[R with TaskManager, A] =
      access.flatMap { manager =>
        manager.runSubtask[R, A](id, label, detail, errorWith)(task)
      }

  }

}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy