All Downloads are FREE. Search and download functionalities are using the official Maven repository.

japgolly.webapputil.websocket.WebSocketClient.scala Maven / Gradle / Ivy

There is a newer version: 2.0.0-RC12
Show newest version
package japgolly.webapputil.websocket

import japgolly.scalajs.react.{AsyncCallback, Callback, CallbackTo}
import japgolly.univeq._
import japgolly.webapputil.binary._
import japgolly.webapputil.general.JsExt._
import japgolly.webapputil.general._
import japgolly.webapputil.websocket.WebSocket
import japgolly.webapputil.websocket.WebSocket.ReadyState
import japgolly.webapputil.websocket.WebSocketShared._
import org.scalajs.dom.{CloseEvent, Event, MessageEvent}
import scala.annotation.{elidable, nowarn}
import scala.scalajs.js
import scala.scalajs.js.timers.SetTimeoutHandle
import scala.scalajs.js.typedarray.ArrayBuffer
import scala.util.{Failure, Success, Try}

trait WebSocketClient[Codec[_], ReqRes <: Protocol.RequestResponse[Codec]] {
  val readyState: CallbackTo[Option[ReadyState]]
  val keepAlive: Callback
  def send(p: ReqRes)(request: p.RequestType): CallbackTo[AsyncCallback[p.ResponseType]]
  def asyncFunction(p: ReqRes): AsyncFunction[p.RequestType, ErrorMsg, p.ResponseType]
  def connect: Callback
  val close: Callback
}

object WebSocketClient {

  type OnDecodeError = CallbackTo[Option[CloseReason]]

  trait Builder[Codec[_], ReqRes <: Protocol.RequestResponse[Codec], Push] {
    def build(reauthorise  : AsyncCallback[Permission],
              onServerPush : Push => Callback,
              onStateChange: WebSocketClient[Codec, ReqRes] => State => Callback,
              retries      : Retries,
              sessionExpiry: VarJs[Boolean],
              timers       : TimersJs,
              logger       : LoggerJs): WebSocketClient[Codec, ReqRes]
  }

  object Builder {

    def apply[Codec[_]](protocol     : Protocol.WebSocket.ClientReqServerPush[Codec])
                       (mkProtocolCS : Codec[protocol.Req] => Protocol.Of[Codec, ClientToServer[protocol.Req]],
                        mkProtocolSC : (ReqId => Option[Protocol[Codec]], Codec[protocol.Push]) => Protocol.Of[Codec, ServerToClient[Codec, protocol.Push]],
                        createWS     : CallbackTo[WebSocket],
                        codecEngine  : CodecEngine[Codec, OnDecodeError],
                       ): Builder[Codec, protocol.ReqRes, protocol.Push] =
      new Builder[Codec, protocol.ReqRes, protocol.Push] {
        override def build(reauthorise  : AsyncCallback[Permission],
                           onServerPush : protocol.Push => Callback,
                           onStateChange: WebSocketClient[Codec, protocol.ReqRes] => State => Callback,
                           retries      : Retries,
                           sessionExpiry: VarJs[Boolean],
                           timers       : TimersJs,
                           logger       : LoggerJs) =
          new Impl[Codec, protocol.Req, protocol.ReqRes, protocol.Push](
            createWS,
            retries,
            reauthorise,
            onStateChange,
            codecEngine,
            mkProtocolCS(protocol.req.codec),
            mkProtocolSC(_, protocol.push.codec),
            onServerPush,
            timers,
            sessionExpiry,
            logger)
      }
  }

  object CloseReasons {
    val parseError = CloseReason(CloseCode.protocolError, CloseReasonPhrase("Failed to parse server response"))
  }

  sealed trait State
  object State {
    case object Unauthorised extends State
    final case class Authorised(readyState: ReadyState) extends State

    implicit def univEq: UnivEq[State] = UnivEq.derive
  }

  // ===================================================================================================================

  final class Impl[
      Codec[_],
      Req,
      ReqRes <: Protocol.RequestResponse[Codec] { type PreparedRequestType = Req },
      Push](
      createWS         : CallbackTo[WebSocket],
      connectionRetries: Retries,
      reauthorise      : AsyncCallback[Permission],
      onStateChange    : WebSocketClient[Codec, ReqRes] => State => Callback,
      codecEngine      : CodecEngine[Codec, OnDecodeError],
      protocolCS       : Protocol.Of[Codec, ClientToServer[Req]],
      mkProtocolSC     : (ReqId => Option[Protocol[Codec]]) => Protocol.Of[Codec, ServerToClient[Codec, Push]],
      recvPush         : Push => Callback,
      timers           : TimersJs,
      sessionExpiry    : VarJs[Boolean],
      logger           : LoggerJs) extends WebSocketClient[Codec, ReqRes] { self =>

    private val requestManager: RequestManager[ReqId, Protocol.AndValue[Codec], Request[Codec, ReqRes]] =
      RequestManager.arrayStore

    private val protocolSC: Protocol.Of[Codec, ServerToClient[Codec, Push]] =
      mkProtocolSC(requestManager.get(_).runNow().map(_.state.prep.response))

    /** All state for the WebSocket client.
      *
      * This is impure because of .instance.
      *
      * Unfortunately splitting this into multiple, clearer cases isn't an option. Even if you just had Connected
      * and Disconnected as cases, Connected holds a WebSocket instance which has it's own mutable state, and can change
      * to Closed without gives us a chance to change the state over to Disconnected.
      *
      * Correctness, confidence and sanity are instead achieved by formal specification in `websocket_client.tla`.
      *
      * @param authorised Whether we think we're logged in or not.
      * @param instance An optional WebSocket instance that may or may not be connected. See it's readyState.
      * @param retries Remaining retries when attempting to reconnect.
      * @param scheduled A scheduled task that will attempt to (re)connect when it eventually executes.
      * @param prevState The last [[State]] used to notify readyState-change listeners. Used to prevent sending
      *                  consecutive, identical notifications.
      */
    private case class InternalState(authorised: Boolean,
                                     instance  : Option[Instance],
                                     retries   : Retries,
                                     scheduled : Option[SetTimeoutHandle],
                                     prevState : Option[State])

    private var state = InternalState(
      authorised = true,
      instance   = None,
      retries    = connectionRetries,
      scheduled  = None,
      prevState  = None)

    private var queueOldestToNewest = Vector.empty[(ReqId, BinaryData)]

    private val readyStateCB: CallbackTo[Option[ReadyState]] =
      CallbackTo(state.instance.map(_.readyState()))

    private def setPublicState(newState: State): Callback =
      Callback {
        if (!state.prevState.exists(_ ==* newState)) {
          state = state.copy(prevState = Some(newState))
          onReadyStateChange2(newState).runNow()
        }
      }

    override lazy val connect: Callback = {
      val attemptConnect: Callback =
        Callback {
          state.scheduled.foreach(timers.clearTimeout)
          val i = unsafeNewInstance()
          state = state.copy(instance = i, scheduled = None)
          i match {
            case Some(_) =>
              state = state.copy(retries = connectionRetries)
              for (e <- requestManager.getAll.runNow())
                resend(e.reqId, e.state).runNow()
            case None =>
              unsafeScheduleReconnect()
          }
        }

      val reauthoriseAndReconnect: Callback =
        reauthorise.attempt.flatMap {

          case Right(Allow) =>
            AsyncCallback.delay {
              state = state.copy(authorised = true)
            } >> connect.asAsyncCallback

          case Right(Deny) | Left(_) =>
            AsyncCallback.delay {
              unsafeFailQueued(errorUnauthorised)
            }

        }.toCallback

      readyStateCB.flatMap {
        case None | Some(ReadyState.Closed) =>

          // Check if re-authorisation has occurred in a different tab
          if (!state.authorised) {
            val sessionExpired = sessionExpiry.unsafeGet()
            if (!sessionExpired)
              state = state.copy(authorised = true)
          }

          if (state.authorised)
            attemptConnect
          else
            reauthoriseAndReconnect

        case Some(ReadyState.Connecting | ReadyState.Open | ReadyState.Closing) =>
          Callback.empty
      }
    }

    private def unsafeNewInstance(): Option[Instance] = {
      // console.info("Connecting to server...")
      createWS.map(new Instance(_)).attempt.runNow() match {
        case Right(i) => Some(i)
        case Left(e) =>
          logger(_.warn(s"Failed to create WebSocket instance."))
          LoggerJs.exception(e)
          None
      }
    }

    private def unsafeScheduleReconnect(): Unit =
      state.retries.pop match {
        case Some((retry, nextRetries)) =>
          logger(_.info(s"WebSocketClient: retry connection in ${retry.toMillis} ms..."))
          val h = timers.setTimeout(retry.toMillis.toDouble) {
            // This bit here is Schedule in websocket_client.tla
            val i = unsafeNewInstance()
            state = state.copy(instance = i, scheduled = None)
            if (i.isEmpty)
              unsafeScheduleReconnect()
          }
          state = state.copy(retries = nextRetries, scheduled = Some(h))

        case None =>
          logger(_.info("WebSocketClient: out of retries. Leaving disconnected."))
          state = state.copy(scheduled = None)
          unsafeFailQueued(errorClosed)
      }

    private val onReadyStateChange2 = onStateChange(this)

    private class Instance(val ws: WebSocket) {
      ws.binaryType.unsafeSet(WebSocket.BinaryType.ArrayBuffer)
      ws.onOpen    .unsafeSet(onOpen _)
      ws.onClose   .unsafeSet(onClose _)
      ws.onMessage .unsafeSet(onMessage _)
      ws.onError   .unsafeSet(onError _)

      private var opened = false

      ws.readyState() match {
        case ReadyState.Open =>
          onOpened()

        case ReadyState.Closed =>
          // Sometimes (like when the security policy blocks the connection), the WebSocket starts in a closed state
          // without sending an onClose event. Catch that here so that the normal retry process kicks in.
          onClosed(None)

          case ReadyState.Closing
             | ReadyState.Connecting => ()
        }

      @elidable(elidable.INFO)
      override def toString = s"WebSocketClient.Instance(${readyState()})"

      def readyState(): ReadyState =
        ws.readyState()

      def isOpen(): Boolean =
        ws.readyState() == ReadyState.Open

      private def runReadyStateChange(): Unit = {
        val newState = State.Authorised(readyState())
        setPublicState(newState).runNow()
      }

      @nowarn("cat=unused")
      private def onOpen(e: Event): Unit =
        onOpened()

      private def onOpened(): Unit = {
        opened = true
        state = state.copy(retries = connectionRetries) // reset retry counter
        runReadyStateChange()
        processQueue()
      }

      def processQueue(): Unit = {
        while (queueOldestToNewest.nonEmpty) {
          val (reqId, payload) = queueOldestToNewest.head

          if (requestManager.get(reqId).runNow().isDefined)
            try
              ws.send(payload.toArrayBuffer)
            catch {
              case t: Throwable =>
                LoggerJs.exception(t)
                logger(_.warn(s"WebSocket.send($payload) failed"))
                throw t
            }

          queueOldestToNewest = queueOldestToNewest.tail
        }
      }

      private def onMessage(e: MessageEvent): Unit = {
        // console.log(s"[${ws.readyState}] onmessage: ", e.data.asInstanceOf[ArrayBuffer])
        def msg = e.data.asInstanceOf[ArrayBuffer]

        val decode = CallbackTo(codecEngine.decode(BinaryData.unsafeFromArrayBuffer(msg))(protocolSC.codec))

        val handler: Callback =
          decode.attempt.flatMap {
            case Right(Right(Right((id, null)))) =>
              logger.pure(_.debug(s"Unable to decode response to req #${id.value}; request has been removed")) >>
                requestManager.remove(id)

            case Right(Right(Right((id, res)))) =>
              logger.pure(_.debug(s"WebSocketClient received response to req #${id.value}: ${res.value}")) >>
                requestManager.complete(id, Success(res))

            case Right(Right(Left(push))) =>
              logger.pure(_.debug(s"WebSocketClient received push: $push")) >>
                recvPush(push)

            case Right(Left(err)) =>
              logger.pure(_.error(s"WebSocketClient failed to process msg: ${BinaryData.fromArrayBuffer(msg)}\n$err")) >>
                err.map(_.foreach(ws.close))

            case Left(err) =>
              logger.pure(_.error(s"WebSocketClient failed to process msg: ${BinaryData.fromArrayBuffer(msg)}\n$err")) >>
                onException(err)
          }
        handler.runNow()
      }

      private def onException(err: Throwable): Callback =
        Callback {
          LoggerJs.exception(err)
          val message = Option(err.getMessage)
          unsafeCloseDueToError(message)
        }

      private def onError(e: Event): Unit = {
        val message = e.asInstanceOf[js.Dynamic].message.asInstanceOf[js.UndefOr[String]]
        unsafeCloseDueToError(message.toOption)
      }

      private def unsafeCloseDueToError(msg: Option[String]): Unit = {
        val desc = s"WebSocket error occurred.${msg.fold("")(" " + _)}"
        //console.error(desc, e)
        val reason = CloseReason(CloseCode.unhandledException, CloseReasonPhrase(desc))
        ws.close(reason)
      }

      private def onClose(e: CloseEvent): Unit = {
        onClosed(Some(CloseCode(e.code)))
      }

      private def onClosed(code: Option[CloseCode]): Unit =
        code match {
          case Some(CloseCode.`unauthorised`) =>
            sessionExpiry.unsafeSet(true)
            state = state.copy(authorised = false)
            setPublicState(State.Unauthorised).runNow()
            connect.runNow()

          case _ =>
            runReadyStateChange()
            unsafeFailQueued(if (opened) errorClosed else errorFailed)
            unsafeScheduleReconnect()
        }
    }

    private val processQueue: Callback =
      Callback {
        for (i <- state.instance)
          if (i.isOpen())
            i.processQueue()
      }

    private def unsafeFailQueued(error: Throwable): Unit =
      requestManager.completeAll(Failure(error))
        .map(_.foreach(_.printStackTrace()))
        .runNow()

    override val close: Callback =
      Callback {
        state = state.copy(retries = Retries.none)
        for (i <- state.instance)
          if (i.isOpen())
            i.ws.close(CloseReason.normalClosure)
      }

    override val readyState: CallbackTo[Option[ReadyState]] =
      CallbackTo(state.instance.map(_.readyState()))

    /** If a connection is open, send an empty message to prevent server-side timeout and keep the connection alive. */
    override val keepAlive: Callback = {
      val ab = new ArrayBuffer(0)
      Callback {
        for (i <- state.instance)
          if (i.isOpen())
            i.ws.send(ab)
      }
    }

    override def send(p: ReqRes)(request: p.RequestType): CallbackTo[AsyncCallback[p.ResponseType]] =
      _send(Request[Codec, ReqRes](p)(request))

    private def _send(request: Request[Codec, ReqRes]): CallbackTo[AsyncCallback[request.p.ResponseType]] =
      _sendGeneric(request) {
        requestManager.newRequest(request).flatMap { case (reqId, promise) =>
          _addToQueue(reqId, request)(promise)
        }
      }

    private def resend(reqId: ReqId, request: Request[Codec, ReqRes]): CallbackTo[AsyncCallback[request.p.ResponseType]] =
      _sendGeneric(request) {
        requestManager.get(reqId).flatMap {
          case Some(e) => _addToQueue(reqId, request)(e.promise)
          case None    => _send(request)
        }
      }

    private def _sendGeneric(request: Request[Codec, ReqRes])
                            (addToQueue: => CallbackTo[AsyncCallback[request.p.ResponseType]]): CallbackTo[AsyncCallback[request.p.ResponseType]] = {
      type R = CallbackTo[AsyncCallback[request.p.ResponseType]]

      def rejectImmediately(error: Throwable): R =
        CallbackTo.pure(AsyncCallback.throwException(error))

      CallbackTo.suspend {
        state.instance.map(_.readyState()) match {
          case Some(ReadyState.Connecting) => addToQueue
          case Some(ReadyState.Open)       => addToQueue <* processQueue
          case Some(ReadyState.Closing)    => rejectImmediately(errorClosing)
          case None
             | Some(ReadyState.Closed)     => addToQueue <* connect
        }
      }
    }

    private def _addToQueue(reqId  : ReqId,
                            request: Request[Codec, ReqRes])
                           (promise: AsyncCallback[Protocol.AndValue[Codec]]): CallbackTo[AsyncCallback[request.p.ResponseType]] =
      CallbackTo {
        val msgValue = (reqId, request.prep.request)
        val msgBin   = codecEngine.encode(msgValue)(protocolCS.codec)
        queueOldestToNewest :+= ((reqId, msgBin))
        promise.map(_.unsafeForceType[request.p.ResponseType].value)
      }

    override def asyncFunction(p: ReqRes): AsyncFunction[p.RequestType, ErrorMsg, p.ResponseType] =
      AsyncFunction.simple(send(p)(_).asAsyncCallback.flatten)
  }

  private val errorClosing      = js.JavaScriptException("Connection is closing.")
  private val errorClosed       = js.JavaScriptException("Connection is closed.")
  private val errorFailed       = js.JavaScriptException("Failed to connect to server.")
  private val errorUnauthorised = js.JavaScriptException("Session expired. You must login again.")

  // ===================================================================================================================

  private sealed trait Request[Codec[_], ReqRes <: Protocol.RequestResponse[Codec]] {
    val p: ReqRes
    val req: p.RequestType

    final lazy val prep = p.prepareSend(req)
  }

  private object Request {
    type Aux[Codec[_], ReqRes <: Protocol.RequestResponse[Codec], P <: ReqRes] = Request[Codec, ReqRes] { val p: P }

    def apply[Codec[_], ReqRes <: Protocol.RequestResponse[Codec]](p: ReqRes)(req: p.RequestType): Aux[Codec, ReqRes, p.type] = {
      val _p: p.type = p
      val _req = req
      new Request[Codec, ReqRes] {
        override val p: _p.type = _p
        override val req = _req
      }
    }
  }

  // ===================================================================================================================

  private trait RequestManager[Id, A, S] {
    def newRequest(state: S): CallbackTo[(Id, AsyncCallback[A])]
    def get(id: Id): CallbackTo[Option[RequestManager.StateEntry[A, S]]]
    def getAll: CallbackTo[List[RequestManager.StateEntry[A, S]]]
    def complete(id: Id, a: Try[A]): Callback
    def completeAll(t: Try[A]): CallbackTo[List[Throwable]]
    def remove(id: Id): Callback
  }

  private object RequestManager {

    final case class StateEntry[A, S](reqId   : ReqId,
                                      state   : S,
                                      complete: Try[A] => Callback,
                                      promise : AsyncCallback[A])

    private final class ArrayStore[A, S] extends RequestManager[ReqId, A, S] {
      private var prevId = 0
      private var size = 0
      private var state = new js.Array[js.UndefOr[StateEntry[A, S]]]

      private def unsafeGet(id: ReqId): Option[StateEntry[A, S]] = {
        // console.log("GET: ", state.asInstanceOf[js.Array[js.Any]])
        val e = state(id.value)
        e.toOption
      }

      override def get(id: ReqId) =
        CallbackTo(unsafeGet(id))

      override val getAll =
        CallbackTo(state.iterator.flatMap(_.toList).toList)

      override def newRequest(s: S): CallbackTo[(ReqId, AsyncCallback[A])] =
        CallbackTo {
          prevId += 1
          val id = prevId
          val reqId = ReqId(id)
          val (ac, tryToCb) = AsyncCallback.promise[A].runNow()
          state(id) = StateEntry(reqId, s, tryToCb, ac)
          size += 1
          (reqId, ac)
        }

      private def unsafeGetAndRemove(reqId: ReqId): Option[StateEntry[A, S]] = {
        val r = unsafeGet(reqId)
        if (r.isDefined) {
          assert(size > 0, s"WebSocketClient.RequestManager: size=$size, reqId=${reqId.value}, state=$state")
          size -= 1
          if (size == 0)
            state = new js.Array
          else
            js.special.delete(state, reqId.value)
        }
        r
      }

      override def complete(reqId: ReqId, a: Try[A]): Callback =
        Callback {
          for (e <- unsafeGetAndRemove(reqId))
            e.complete(a).runNow()
        }

      override def completeAll(t: Try[A]): CallbackTo[List[Throwable]] =
        CallbackTo {
          // Extract handlers
          var l = List.empty[Try[A] => Callback]
          state.forEachJs(_.foreach(l ::= _.complete))

          // Clear state
          state = new js.Array
          size = 0

          // Call handlers
          l.flatMap(_(t).attempt.runNow().left.toSeq)
        }

      override def remove(reqId: ReqId): Callback =
        Callback {
          unsafeGetAndRemove(reqId)
          ()
        }
    }

    def arrayStore[A, S]: RequestManager[ReqId, A, S] =
      new ArrayStore[A, S]
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy