All downloads are free. Search and download functionality uses the official Maven repository.

io.github.zeal18.zio.mongodb.driver.reactivestreams.IterableSubscriber.scala Maven / Gradle / Ivy

package io.github.zeal18.zio.mongodb.driver.reactivestreams

import org.reactivestreams.Subscription
import zio.Scope
import zio.Task
import zio.Trace
import zio.UIO
import zio.URIO
import zio.ZIO

import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.IterableFactory
import scala.collection.mutable.Builder

/** Factory for Reactive Streams subscribers that collect every published element into a collection.
  *
  * The produced subscriber requests unbounded demand once subscribed, appends each `onNext` element to a
  * builder obtained from the supplied [[scala.collection.IterableFactory]], and completes its result
  * promise with the built collection on `onComplete` (or fails it on `onError`).
  */
private object IterableSubscriber {

  /** Creates a scoped subscriber that accumulates all published elements into an `I[A]`.
    *
    * @param factory
    *   collection factory whose builder accumulates the received elements
    * @tparam A
    *   element type emitted by the publisher
    * @tparam I
    *   target collection type constructor
    * @tparam B
    *   not used by the implementation; presumably kept so existing callers passing explicit type
    *   arguments keep compiling
    * @return
    *   a subscriber whose `await` yields the collected elements, or fails with the error signalled via
    *   `onError`. When the enclosing `Scope` closes, the subscription received via `onSubscribe` (if
    *   any) is cancelled.
    */
  def make[A, I[B] <: Iterable[B], B](
    factory: IterableFactory[I],
  )(implicit trace: Trace): URIO[Scope, InterruptibleSubscriber[A, I[A]]] = for {
    subscriptionP <- ZIO.acquireRelease(
      Promise.make[Throwable, Subscription],
    )(
      // On scope exit: if a subscription was ever received, cancel it; otherwise there is nothing to do.
      _.poll.flatMap(_.fold(ZIO.unit)(_.foldZIO(_ => ZIO.unit, sub => ZIO.succeed(sub.cancel())))),
    )
    promise <- Promise.make[Throwable, I[A]]
  } yield new InterruptibleSubscriber[A, I[A]] {

    // Becomes true once a subscription has been accepted OR the subscriber has been interrupted.
    // Any subscription arriving after that point is cancelled straight away.
    val isSubscribedOrInterrupted           = new AtomicBoolean
    val collectionBuilder: Builder[A, I[A]] = factory.newBuilder

    /** Marks the subscriber as interrupted and interrupts the result promise.
      *
      * The flag update is suspended inside the returned effect, so merely constructing this `UIO`
      * (without running it) has no observable side effect. An already-active subscription is not
      * cancelled here; that happens when the enclosing `Scope` closes.
      */
    override def interrupt(implicit trace: Trace): UIO[Unit] =
      ZIO.succeed(isSubscribedOrInterrupted.set(true)) *> promise.interrupt.unit

    /** Waits for the collected elements, or for the failure/interruption of the result promise. */
    override def await(implicit trace: Trace): Task[I[A]] = promise.await

    override def onSubscribe(s: Subscription): Unit =
      if (s == null)
        failNPE("s was null in onSubscribe")
      else {
        // Cancel if we already hold a subscription or were interrupted before one arrived.
        val shouldCancel = isSubscribedOrInterrupted.getAndSet(true)
        if (shouldCancel) s.cancel()
        else {
          subscriptionP.unsafe.done(ZIO.succeed(s))
          // Long.MaxValue is the Reactive Streams signal for unbounded demand (rule 3.17).
          s.request(Long.MaxValue)
        }
      }

    override def onNext(t: A): Unit =
      if (t == null)
        failNPE("t was null in onNext")
      else
        collectionBuilder += t

    override def onError(e: Throwable): Unit =
      if (e == null)
        failNPE("e was null in onError")
      else
        fail(e)

    override def onComplete(): Unit =
      promise.unsafe.done(ZIO.succeed(collectionBuilder.result()))

    /** Fails the result promise with an NPE carrying `msg`, then throws that NPE to the caller. */
    private def failNPE(msg: String): Nothing = {
      val e = new NullPointerException(msg)
      fail(e)
      throw e
    }

    /** Completes the result promise with the given failure. */
    private def fail(e: Throwable): Unit =
      promise.unsafe.done(ZIO.fail(e))
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy