All downloads are free. Search and download functionality uses the official Maven repository.

sttp.client3.asynchttpclient.WebSocketImpl.scala Maven / Gradle / Ivy

The newest version!
package sttp.client3.asynchttpclient

import java.util.concurrent.atomic.AtomicBoolean

import io.netty.util.concurrent.{Future, FutureListener}
import org.asynchttpclient.ws.{WebSocket => AHCWebSocket, WebSocketListener => AHCWebSocketListener}
import sttp.client3.internal.ws.{SimpleQueue, WebSocketEvent}
import sttp.model.{Header, Headers}
import sttp.monad.syntax._
import sttp.monad.{Canceler, MonadAsyncError}
import sttp.ws.{WebSocket, WebSocketClosed, WebSocketFrame}

import scala.collection.JavaConverters.asScalaIteratorConverter

/** [[WebSocket]] implementation on top of an async-http-client websocket.
  *
  * Incoming events are read from `queue`; outgoing frames are written directly to `ws` and the
  * resulting Netty futures are lifted into `F`. `_isOpen` tracks whether a close has already been
  * performed, so that a close frame is only sent once.
  *
  * @param ws
  *   the underlying AHC websocket used for sending frames and reading upgrade headers
  * @param queue
  *   source of received events (frames, errors, open notifications)
  * @param _isOpen
  *   flag flipped to `false` by the first close performed through `send`
  * @param monad
  *   effect type used to describe all operations
  */
private[asynchttpclient] class WebSocketImpl[F[_]](
    ws: AHCWebSocket,
    queue: SimpleQueue[F, WebSocketEvent],
    _isOpen: AtomicBoolean,
    implicit val monad: MonadAsyncError[F]
) extends WebSocket[F] {

  /** Waits for the next frame available in `queue`, skipping `Open` events.
    *
    * Receiving a close frame enqueues a `WebSocketClosed` error, and an `Exception` error is put back
    * into the queue. Either way, later calls fail instead of waiting forever on an empty queue.
    * Non-`Exception` throwables are rethrown without being re-enqueued.
    */
  override def receive(): F[WebSocketFrame] =
    queue.poll.flatMap(event => interpretEvent(event))

  /** Maps a single polled queue event to the outcome of `receive()`. */
  private def interpretEvent(event: WebSocketEvent): F[WebSocketFrame] =
    event match {
      case WebSocketEvent.Open() =>
        // not a frame: keep waiting for the next event
        receive()
      case WebSocketEvent.Frame(close: WebSocketFrame.Close) =>
        queue.offer(WebSocketEvent.Error(WebSocketClosed(Some(close))))
        monad.unit(close)
      case WebSocketEvent.Frame(other: WebSocketFrame) =>
        monad.unit(other)
      case failure @ WebSocketEvent.Error(cause: Exception) =>
        // re-enqueue so that subsequent receive() calls fail with the same error rather than hang
        queue.offer(failure)
        monad.error(cause)
      case WebSocketEvent.Error(fatal) =>
        throw fatal
    }

  /** Sends `f`; the returned effect completes when the underlying Netty write future completes.
    *
    * Text and binary frames are sent as continuation frames when `isContinuation` is set. A close frame
    * is sent at most once. The first call flips `_isOpen`, enqueues a `WebSocketClosed` error for
    * readers, and writes the frame. Later calls complete immediately without sending anything.
    */
  override def send(f: WebSocketFrame, isContinuation: Boolean = false): F[Unit] =
    monad.flatten(monad.eval(transmitFrame(f, isContinuation)))

  /** Performs the AHC call for `frame`; only invoked lazily, from inside `monad.eval` in `send`. */
  private def transmitFrame(frame: WebSocketFrame, isContinuation: Boolean): F[Unit] =
    frame match {
      case WebSocketFrame.Text(payload, finalFragment, rsv) =>
        val rsvBits = rsv.getOrElse(0)
        fromNettyFuture(
          if (isContinuation) ws.sendContinuationFrame(payload, finalFragment, rsvBits)
          else ws.sendTextFrame(payload, finalFragment, rsvBits)
        )
      case WebSocketFrame.Binary(payload, finalFragment, rsv) =>
        val rsvBits = rsv.getOrElse(0)
        fromNettyFuture(
          if (isContinuation) ws.sendContinuationFrame(payload, finalFragment, rsvBits)
          else ws.sendBinaryFrame(payload, finalFragment, rsvBits)
        )
      case WebSocketFrame.Ping(payload) =>
        fromNettyFuture(ws.sendPingFrame(payload))
      case WebSocketFrame.Pong(payload) =>
        fromNettyFuture(ws.sendPongFrame(payload))
      case WebSocketFrame.Close(statusCode, reasonText) =>
        // only the caller that observes `true` here performs the close; repeated closes are no-ops
        if (_isOpen.getAndSet(false)) {
          queue.offer(WebSocketEvent.Error(WebSocketClosed(None)))
          fromNettyFuture(ws.sendCloseFrame(statusCode, reasonText))
        } else monad.unit(())
    }

  /** Headers returned by the server in the websocket upgrade response. */
  override def upgradeHeaders: Headers = {
    val entries = ws.getUpgradeHeaders.iteratorAsString().asScala
    Headers(entries.map(entry => Header(entry.getKey, entry.getValue)).toList)
  }

  /** Reads the current value of the open flag, lazily within `F`. */
  override def isOpen(): F[Boolean] = monad.eval { _isOpen.get() }

  /** Lifts a Netty future into `F`.
    *
    * The effect completes with `()` on success or fails with the future's cause. Cancelling the effect
    * cancels the Netty future.
    */
  private def fromNettyFuture(f: io.netty.util.concurrent.Future[Void]): F[Unit] =
    monad.async { callback =>
      val onComplete = new FutureListener[Void] {
        override def operationComplete(done: Future[Void]): Unit =
          callback(if (done.isSuccess) Right(()) else Left(done.cause()))
      }
      val registered = f.addListener(onComplete)
      Canceler(() => registered.cancel(true))
    }
}

object WebSocketImpl {

  /** Wraps an AHC websocket as an sttp [[WebSocket]].
    *
    * An [[AddToQueueListener]] is registered on `ws` so that incoming frames, closes and errors are
    * offered to `queue`. The returned [[WebSocket]] reads from that same queue. A single open flag is
    * shared by the listener and the returned websocket.
    *
    * @param ws
    *   the AHC websocket to adapt
    * @param queue
    *   queue receiving the websocket events
    * @param monad
    *   effect type of the resulting websocket
    */
  def newCoupledToAHCWebSocket[F[_]](
      ws: AHCWebSocket,
      queue: SimpleQueue[F, WebSocketEvent]
  )(implicit monad: MonadAsyncError[F]): WebSocket[F] = {
    val openFlag = new AtomicBoolean(true)
    val listener = new AddToQueueListener(queue, openFlag)
    ws.addWebSocketListener(listener)
    new WebSocketImpl(ws, queue, openFlag, monad)
  }
}

/** AHC websocket listener that forwards every callback into `queue` as a [[WebSocketEvent]].
  *
  * Only the first close or error notification is forwarded. The one that flips `isOpen` from `true` to
  * `false` enqueues its event, and any later close/error notification is dropped. Data and control
  * frames are always enqueued. If enqueueing a frame throws, the exception is reported via `onError`.
  *
  * @param queue
  *   destination for received events
  * @param isOpen
  *   open flag, flipped to `false` by the first close/error notification
  */
class AddToQueueListener[F[_]](queue: SimpleQueue[F, WebSocketEvent], isOpen: AtomicBoolean)
    extends AHCWebSocketListener {

  /** Not expected to be invoked for this listener; fails loudly if it is. */
  override def onOpen(websocket: AHCWebSocket): Unit =
    throw new IllegalStateException("Should never be called!")

  override def onClose(websocket: AHCWebSocket, code: Int, reason: String): Unit =
    offerIfFirstToClose(WebSocketEvent.Frame(WebSocketFrame.Close(code, reason)))

  override def onError(t: Throwable): Unit =
    offerIfFirstToClose(WebSocketEvent.Error(t))

  override def onTextFrame(payload: String, finalFragment: Boolean, rsv: Int): Unit =
    enqueueFrame(WebSocketFrame.Text(payload, finalFragment, optionalRsv(rsv)))

  override def onBinaryFrame(payload: Array[Byte], finalFragment: Boolean, rsv: Int): Unit =
    enqueueFrame(WebSocketFrame.Binary(payload, finalFragment, optionalRsv(rsv)))

  override def onPingFrame(payload: Array[Byte]): Unit =
    enqueueFrame(WebSocketFrame.Ping(payload))

  override def onPongFrame(payload: Array[Byte]): Unit =
    enqueueFrame(WebSocketFrame.Pong(payload))

  /** Enqueues `event` only if this call is the one that flips `isOpen` from `true` to `false`.
    * `event` is by-name, so it is only built when it is actually offered.
    */
  private def offerIfFirstToClose(event: => WebSocketEvent): Unit =
    if (isOpen.getAndSet(false)) queue.offer(event)

  /** Offers a frame event; a failure while enqueueing is routed to `onError`. */
  private def enqueueFrame(frame: WebSocketFrame): Unit =
    try {
      queue.offer(WebSocketEvent.Frame(frame))
    } catch {
      case e: Exception => onError(e)
    }

  /** Maps the raw reserved-bits value to `None` when no bits are set. */
  private def optionalRsv(rsv: Int): Option[Int] =
    if (rsv != 0) Some(rsv) else None
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy