All downloads are free. Search and download functionality uses the official Maven repository.

izumi.idealingua.runtime.rpc.http4s.clients.WsRpcDispatcherFactory.scala Maven / Gradle / Ivy

package izumi.idealingua.runtime.rpc.http4s.clients

import io.circe.syntax.*
import io.circe.{Json, Printer}
import izumi.functional.bio.{Async2, Exit, F, IO2, Primitives2, Temporal2, UnsafeRun2}
import izumi.functional.lifecycle.Lifecycle
import izumi.fundamentals.platform.language.Quirks.Discarder
import izumi.idealingua.runtime.rpc.*
import izumi.idealingua.runtime.rpc.http4s.clients.WsRpcDispatcher.IRTDispatcherWs
import izumi.idealingua.runtime.rpc.http4s.clients.WsRpcDispatcherFactory.{ClientWsRpcHandler, WsRpcClientConnection, WsRpcContextProvider, fromNettyFuture}
import izumi.idealingua.runtime.rpc.http4s.ws.{RawResponse, WsRequestState, WsRpcHandler}
import izumi.logstage.api.IzLogger
import logstage.LogIO2
import org.asynchttpclient.netty.ws.NettyWebSocket
import org.asynchttpclient.ws.{WebSocket, WebSocketListener, WebSocketUpgradeHandler}
import org.asynchttpclient.{DefaultAsyncHttpClient, DefaultAsyncHttpClientConfig}
import org.http4s.Uri

import java.util.concurrent.atomic.AtomicReference
import scala.concurrent.duration.{DurationInt, FiniteDuration}
import scala.jdk.CollectionConverters.*

/** Creates WebSocket RPC client connections and dispatchers on top of AsyncHttpClient.
  *
  * Text frames received from the server are handled by a [[WsRpcHandler]] built over the given
  * server multiplexor, so the server may also call methods exposed by the client.
  *
  * @param codec    client multiplexor passed to every created [[WsRpcDispatcher]]
  * @param printer  JSON printer used to serialise outgoing packets
  * @param logger   effectful logger given to dispatchers and to the incoming-message handler
  * @param izLogger synchronous logger used from WebSocket listener callbacks to report failed requests
  */
class WsRpcDispatcherFactory[F[+_, +_]: Async2: Temporal2: Primitives2: UnsafeRun2](
  codec: IRTClientMultiplexor[F],
  printer: Printer,
  logger: LogIO2[F],
  izLogger: IzLogger,
) {

  /** Opens a WebSocket connection to `uri`.
    *
    * On release, pending requests in the request state are cleared first. Then a close frame is sent,
    * and finally the underlying HTTP client is closed.
    *
    * @param uri             WebSocket endpoint to connect to
    * @param muxer           server-side multiplexor serving requests initiated by the remote side
    * @param contextProvider builds the server context from incoming packets
    */
  def connect[ServerContext](
    uri: Uri,
    muxer: IRTServerMultiplexor[F, ServerContext],
    contextProvider: WsRpcContextProvider[ServerContext],
  ): Lifecycle[F[Throwable, _], WsRpcClientConnection[F]] = {
    for {
      client       <- WsRpcDispatcherFactory.asyncHttpClient[F]
      requestState <- Lifecycle.liftF(F.syncThrowable(WsRequestState.create[F]))
      listener     <- Lifecycle.liftF(F.syncThrowable(createListener(muxer, contextProvider, requestState, dispatcherLogger(uri, logger))))
      handler      <- Lifecycle.liftF(F.syncThrowable(new WebSocketUpgradeHandler(List(listener).asJava)))
      nettyWebSocket <- Lifecycle.make(
        F.fromFutureJava(client.prepareGet(uri.toString()).execute(handler).toCompletableFuture)
      )(nettyWebSocket => fromNettyFuture(nettyWebSocket.sendCloseFrame()).void)
      // Acquired last, so released first: fill pending promises before the WS connection is closed,
      // potentially giving a chance to send out an error response before closing.
      _ <- Lifecycle.make(F.unit)(_ => requestState.clear())
    } yield {
      new WsRpcClientConnection.Netty(nettyWebSocket, requestState, printer)
    }
  }

  /** Opens a connection (see [[connect]]) and wraps it in an RPC dispatcher.
    *
    * @param tweakRequest transformation applied to every request packet built by the dispatcher
    * @param timeout      timeout passed to the dispatcher
    */
  def dispatcher[ServerContext](
    uri: Uri,
    muxer: IRTServerMultiplexor[F, ServerContext],
    contextProvider: WsRpcContextProvider[ServerContext],
    tweakRequest: RpcPacket => RpcPacket = identity,
    timeout: FiniteDuration              = 30.seconds,
  ): Lifecycle[F[Throwable, _], IRTDispatcherWs[F]] = {
    connect(uri, muxer, contextProvider).map {
      new WsRpcDispatcher(_, timeout, codec, dispatcherLogger(uri, logger)) {
        override protected def buildRequest(rpcPacketId: RpcPacketId, method: IRTMethodId, body: Json): RpcPacket = {
          tweakRequest(super.buildRequest(rpcPacketId, method, body))
        }
      }
    }
  }

  /** Builds the handler for text frames received from the server. Override to customise. */
  protected def wsHandler[ServerContext](
    logger: LogIO2[F],
    muxer: IRTServerMultiplexor[F, ServerContext],
    contextProvider: WsRpcContextProvider[ServerContext],
    requestState: WsRequestState[F],
  ): WsRpcHandler[F, ServerContext] = {
    new ClientWsRpcHandler(muxer, requestState, contextProvider, logger)
  }

  /** Builds the AsyncHttpClient listener that tracks the open socket and dispatches incoming frames. */
  protected def createListener[ServerContext](
    muxer: IRTServerMultiplexor[F, ServerContext],
    contextProvider: WsRpcContextProvider[ServerContext],
    requestState: WsRequestState[F],
    logger: LogIO2[F],
  ): WebSocketListener = new WebSocketListener() {
    private val handler   = wsHandler(logger, muxer, contextProvider, requestState)
    // Set once the socket opens, cleared on close/error; replies are only sent while it is Some.
    private val socketRef = new AtomicReference[Option[WebSocket]](None)

    override def onOpen(websocket: WebSocket): Unit = {
      socketRef.set(Some(websocket))
    }

    override def onClose(websocket: WebSocket, code: Int, reason: String): Unit = {
      socketRef.set(None)
      websocket.sendCloseFrame()
      ()
    }

    override def onError(t: Throwable): Unit = {
      socketRef.getAndSet(None).foreach(_.sendCloseFrame())
    }

    // RFC 6455 §5.5.3: a Pong sent in response to a Ping must carry the same application data
    // as that Ping, so the received payload is echoed back rather than sending an empty pong.
    override def onPingFrame(payload: Array[Byte]): Unit = {
      socketRef.get().foreach(_.sendPongFrame(payload))
    }

    // Incoming messages are processed asynchronously. The handler's response, or a critical
    // packet describing the failure, is written back to the socket only if it is still open.
    override def onTextFrame(payload: String, finalFragment: Boolean, rsv: Int): Unit = {
      UnsafeRun2[F].unsafeRunAsync(handler.processRpcMessage(payload)) {
        exit =>
          val maybeResponse = exit match {
            case Exit.Success(response)         => response
            case Exit.Error(error, _)           => handleWsError(List(error), "errored")
            case Exit.Termination(error, _, _)  => handleWsError(List(error), "terminated")
            case Exit.Interruption(error, _, _) => handleWsError(List(error), "interrupted")
          }
          maybeResponse.foreach {
            response =>
              socketRef.get().foreach {
                ws => ws.sendTextFrame(printer.print(response.asJson))
              }
          }
      }
    }
  }

  /** Hook for subclasses to derive a per-endpoint logger; by default `uri` is ignored and `logger` is returned. */
  protected def dispatcherLogger(uri: Uri, logger: LogIO2[F]): LogIO2[F] = {
    uri.discard()
    logger
  }

  /** Logs a failed incoming-request computation and builds a critical RPC packet to send back.
    * Always returns `Some`; only the first cause, if any, is reported.
    */
  private def handleWsError(causes: List[Throwable], message: String): Option[RpcPacket] = {
    causes.headOption match {
      case Some(cause) =>
        izLogger.error(s"WS request failed: $message, $cause")
        Some(RpcPacket.rpcCritical(s"$message, cause: $cause", None))
      case None =>
        izLogger.error(s"WS request failed: $message.")
        Some(RpcPacket.rpcCritical(message, None))
    }
  }
}

object WsRpcDispatcherFactory {

  /** Creates an AsyncHttpClient configured for WebSocket RPC traffic; the client is closed when the Lifecycle is released. */
  def asyncHttpClient[F[+_, +_]: IO2]: Lifecycle[F[Throwable, _], DefaultAsyncHttpClient] = {
    Lifecycle.fromAutoCloseable(F.syncThrowable {
      new DefaultAsyncHttpClient(
        new DefaultAsyncHttpClientConfig.Builder()
          .setWebSocketMaxBufferSize(64 * 1024 * 1024 * 8) // 536870912 bytes = 512 MiB; 128000000 is the default value
          .setWebSocketMaxFrameSize(64 * 1024 * 1024 * 8) // 536870912 bytes = 512 MiB
          .setKeepAlive(true)
          .setSoKeepAlive(true)
          .setRequestTimeout(30 * 1000) // 60 seconds is default
          .setPooledConnectionIdleTimeout(60 * 1000) // 60 seconds is default
          .setConnectTimeout(30 * 1000) // 5 seconds is default
          .setReadTimeout(60 * 1000) // 60 seconds is default
          .setShutdownTimeout(15 * 1000) // 15 seconds is default
          .build()
      )
    })
  }

  /** Client-side handler for server-initiated packets.
    * Plain packets are ignored, auth requests yield no reply, and the context comes from `contextProvider`.
    */
  class ClientWsRpcHandler[F[+_, +_]: IO2, ServerCtx](
    muxer: IRTServerMultiplexor[F, ServerCtx],
    requestState: WsRequestState[F],
    contextProvider: WsRpcContextProvider[ServerCtx],
    logger: LogIO2[F],
  ) extends WsRpcHandler[F, ServerCtx](muxer, requestState, logger) {
    override def handlePacket(packet: RpcPacket): F[Throwable, Unit] = {
      F.unit
    }
    override def handleAuthRequest(packet: RpcPacket): F[Throwable, Option[RpcPacket]] = {
      F.pure(None)
    }
    override def extractContext(packet: RpcPacket): F[Throwable, ServerCtx] = {
      F.sync(contextProvider.toContext(packet))
    }
  }

  /** An open client WebSocket connection able to send requests and await their responses. */
  trait WsRpcClientConnection[F[_, _]] {
    private[clients] def requestAndAwait(id: RpcPacketId, packet: RpcPacket, method: Option[IRTMethodId], timeout: FiniteDuration): F[Throwable, Option[RawResponse]]
    def authorize(headers: Map[String, String], timeout: FiniteDuration = 30.seconds): F[Throwable, Unit]
  }
  object WsRpcClientConnection {

    /** [[WsRpcClientConnection]] backed by an AsyncHttpClient Netty WebSocket. */
    class Netty[F[+_, +_]: Async2](
      nettyWebSocket: NettyWebSocket,
      requestState: WsRequestState[F],
      printer: Printer,
    ) extends WsRpcClientConnection[F] {

      /** Sends an auth packet carrying `headers`.
        * Succeeds on a good or empty response. Fails with [[IRTGenericFailure]] on an error response
        * or when no response is returned (presumably on timeout — depends on `WsRequestState`).
        */
      override def authorize(headers: Map[String, String], timeout: FiniteDuration): F[Throwable, Unit] = {
        val packetId = RpcPacketId.random()
        requestAndAwait(packetId, RpcPacket.auth(packetId, headers), None, timeout).flatMap {
          case Some(_: RawResponse.GoodRawResponse)    => F.unit
          case Some(_: RawResponse.EmptyRawResponse)   => F.unit
          case Some(value: RawResponse.BadRawResponse) => F.fail(new IRTGenericFailure(s"Authorization failed: ${value.error}."))
          case None                                    => F.fail(new IRTGenericFailure("Unable to authorize."))
        }
      }

      /** Registers `id` in the request state and sends `packet` as a JSON text frame. */
      override private[clients] def requestAndAwait(
        id: RpcPacketId,
        packet: RpcPacket,
        method: Option[IRTMethodId],
        timeout: FiniteDuration,
      ): F[Throwable, Option[RawResponse]] = {
        requestState.requestAndAwait(id, method, timeout) {
          fromNettyFuture(nettyWebSocket.sendTextFrame(printer.print(packet.asJson)))
        }
      }
    }
  }

  /** Derives the server-side context for a packet received from the remote side. */
  trait WsRpcContextProvider[Ctx] {
    def toContext(packet: RpcPacket): Ctx
  }
  object WsRpcContextProvider {
    def unit: WsRpcContextProvider[Unit] = _ => ()
  }

  /** Lifts a lazily-created Netty future into an effect.
    * Cancelling the effect cancels the Netty future without interrupting it (`cancel(false)`).
    */
  private def fromNettyFuture[F[+_, +_]: Async2, A](mkNettyFuture: => io.netty.util.concurrent.Future[A]): F[Throwable, A] = {
    F.syncThrowable(mkNettyFuture).flatMap {
      nettyFuture =>
        F.asyncCancelable {
          callback =>
            nettyFuture.addListener {
              (completedFuture: io.netty.util.concurrent.Future[A]) =>
                try {
                  if (!completedFuture.isDone) {
                    // shouldn't be possible, future should already be completed
                    completedFuture.await(1000L)
                  }
                  if (completedFuture.isSuccess) {
                    callback(Right(completedFuture.getNow))
                  } else {
                    Option(completedFuture.cause()) match {
                      case Some(error) => callback(Left(error))
                      case None        => callback(Left(new RuntimeException("Awaiting NettyFuture failed, but no exception was available.")))
                    }
                  }
                } catch {
                  // Deliberately catches everything so the callback is always completed.
                  // The original exception is chained as the cause so its stack trace is not lost.
                  case exception: Throwable =>
                    callback(Left(new RuntimeException(s"Awaiting NettyFuture threw an exception=$exception", exception)))
                }
            }
            F.sync {
              nettyFuture.cancel(false)
              ()
            }
        }
    }
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy