All Downloads are FREE. Search and download functionalities are using the official Maven repository.

scalaz.zio.interop.reactiveStreams.StreamPublisher.scala Maven / Gradle / Ivy

package scalaz.zio.interop.reactiveStreams

import org.reactivestreams.{ Publisher, Subscriber, Subscription }
import scalaz.zio.stream.Sink.Step
import scalaz.zio.stream.{ Sink, ZStream }
import scalaz.zio.{ Queue, Runtime, UIO, ZIO }

/**
 * A Reactive Streams `Publisher` backed by a `ZStream`.
 *
 * Each call to `subscribe` runs `stream` on `runtime`, handing elements to the
 * subscriber only as far as the demand signalled through
 * `Subscription.request` allows.
 *
 * @param stream  the stream whose elements are published
 * @param runtime the runtime used to run the stream and the demand-queue effects
 * @tparam R the environment required by `stream`
 * @tparam E the stream's failure type; failures are forwarded to `Subscriber.onError`
 * @tparam A the element type
 */
class StreamPublisher[R, E <: Throwable, A](
  stream: ZStream[R, E, A],
  runtime: Runtime[R]
) extends Publisher[A] {

  /**
   * Starts running the stream for `subscriber`.
   *
   * The subscriber first receives `onSubscribe`; the stream is then run in a
   * forked fiber that signals `onComplete` or `onError` when it ends and shuts
   * the demand queue down afterwards.
   *
   * @param subscriber the subscriber to publish to
   * @throws NullPointerException if `subscriber` is `null`
   */
  override def subscribe(subscriber: Subscriber[_ >: A]): Unit =
    if (subscriber == null) {
      throw new NullPointerException("Subscriber must not be null.")
    } else {
      runtime.unsafeRunAsync_(
        for {
          // Carries the amounts passed to Subscription.request over to the sink.
          demand <- Queue.unbounded[Long]
          _      <- UIO(subscriber.onSubscribe(createSubscription(subscriber, demand)))
          fiber <- stream
                    .run(demandUnfoldSink(subscriber, demand))
                    .flatMap(_ => UIO(subscriber.onComplete()))
                    .catchAll(e => UIO(subscriber.onError(e)))
                    .flatMap(_ => demand.shutdown)
                    .fork
          // reactive streams rule 3.13
          // Shutting the queue down (cancel, or a rejected request) interrupts the stream fiber.
          _ <- (demand.awaitShutdown *> fiber.interrupt).fork
        } yield ()
      )
    }

  /**
   * Builds a sink that forwards every element to `subscriber.onNext`, waiting
   * on `demand` whenever the previously requested amount has been used up.
   *
   * The sink state is the number of elements that may still be emitted without
   * taking a new request from the queue.
   */
  private def demandUnfoldSink(subscriber: Subscriber[_ >: A], demand: Queue[Long]): Sink[Nothing, A, A, Unit] =
    new Sink[Nothing, A, A, Unit] {
      override type State = Long

      // Nothing has been requested yet, so the first element has to wait for demand.
      override def initial: UIO[Step[Long, Nothing]] = UIO(Step.more(0L))

      override def step(state: Long, a: A): UIO[Step[Long, A]] =
        if (state > 0) {
          UIO(subscriber.onNext(a)).map(_ => Step.more(state - 1))
        } else {
          // Out of demand: wait for the next request, emit, and keep the rest of it as state.
          for {
            n <- demand.take
            _ <- UIO(subscriber.onNext(a))
          } yield Step.more(n - 1)
        }

      override def extract(state: Long): UIO[Unit] = UIO.unit

    }

  /**
   * Builds the `Subscription` handed to `subscriber`.
   *
   * Valid requests are offered to `demandQ`; `cancel` shuts `demandQ` down.
   */
  private def createSubscription(subscriber: Subscriber[_ >: A], demandQ: Queue[Long]): Subscription =
    new Subscription {
      override def request(n: Long): Unit =
        if (n <= 0) {
          // Reactive streams rule 3.9: a non-positive request must be signalled via onError.
          subscriber.onError(new IllegalArgumentException("n must be > 0"))
          // onError is terminal (rule 1.7). The invalid amount must not reach the sink, which
          // would otherwise call onNext after the error, so stop the stream instead.
          runtime.unsafeRunAsync_(demandQ.shutdown)
        } else {
          runtime.unsafeRunAsync_(demandQ.offer(n).void)
        }
      override def cancel(): Unit = runtime.unsafeRun(demandQ.shutdown)
    }
}

object StreamPublisher {

  /**
   * Wraps `stream` in a [[StreamPublisher]] that runs on the runtime of the
   * calling effect.
   *
   * @param stream the stream to expose as a Reactive Streams `Publisher`
   * @return an effect that reads the current runtime and yields the publisher
   */
  def sinkToPublisher[R, E <: Throwable, A](stream: ZStream[R, E, A]): ZIO[R, Nothing, StreamPublisher[R, E, A]] =
    for {
      currentRuntime <- ZIO.runtime[R]
    } yield new StreamPublisher[R, E, A](stream, currentRuntime)
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy