All Downloads are FREE. Search and download functionalities are using the official Maven repository.

izumi.idealingua.runtime.rpc.http4s.HttpServer.scala Maven / Gradle / Ivy

package izumi.idealingua.runtime.rpc.http4s

import _root_.io.circe.parser.*
import cats.data.OptionT
import cats.effect.Async
import cats.effect.std.Queue
import fs2.Stream
import io.circe
import io.circe.syntax.EncoderOps
import io.circe.{Json, Printer}
import izumi.functional.bio.Exit.{Error, Interruption, Success, Termination}
import izumi.functional.bio.{Exit, F, IO2, Primitives2, Temporal2, UnsafeRun2}
import izumi.fundamentals.platform.language.Quirks
import izumi.fundamentals.platform.time.IzTime
import izumi.idealingua.runtime.rpc.*
import izumi.idealingua.runtime.rpc.http4s.HttpServer.{ServerWsRpcHandler, WsResponseMarker}
import izumi.idealingua.runtime.rpc.http4s.ws.*
import izumi.idealingua.runtime.rpc.http4s.ws.WsClientSession.WsClientSessionImpl
import izumi.idealingua.runtime.rpc.http4s.ws.WsContextProvider.WsAuthResult
import logstage.LogIO2
import org.http4s.*
import org.http4s.dsl.Http4sDsl
import org.http4s.server.AuthMiddleware
import org.http4s.server.websocket.WebSocketBuilder2
import org.http4s.websocket.WebSocketFrame
import org.typelevel.vault.Key

import java.time.ZonedDateTime
import java.util.concurrent.RejectedExecutionException
import scala.concurrent.duration.DurationInt

class HttpServer[F[+_, +_]: IO2: Temporal2: Primitives2: UnsafeRun2, RequestCtx, MethodCtx, ClientId](
  val muxer: IRTServerMultiplexor[F, RequestCtx],
  val codec: IRTClientMultiplexor[F],
  val contextProvider: AuthMiddleware[F[Throwable, _], RequestCtx],
  val wsContextProvider: WsContextProvider[F, RequestCtx, ClientId],
  val wsSessionStorage: WsSessionsStorage[F, RequestCtx, ClientId],
  val listeners: Seq[WsSessionListener[F, ClientId]],
  dsl: Http4sDsl[F[Throwable, _]],
  logger: LogIO2[F],
  printer: Printer,
)(implicit val AT: Async[F[Throwable, _]]
) {
  import dsl.*

  // WS Response attribute key, to differ from usual HTTP responses
  private val wsAttributeKey = UnsafeRun2[F].unsafeRun(Key.newKey[F[Throwable, _], WsResponseMarker.type])

  protected def loggingMiddle(service: HttpRoutes[F[Throwable, _]]): HttpRoutes[F[Throwable, _]] = {
    cats.data.Kleisli {
      (req: Request[F[Throwable, _]]) =>
        OptionT.apply {
          (for {
            _    <- logger.trace(s"${req.method.name -> "method"} ${req.pathInfo -> "path"}: initiated")
            resp <- service(req).value
            _ <- F.traverse(resp) {
              case Status.Successful(resp) =>
                logger.debug(s"${req.method.name -> "method"} ${req.pathInfo -> "path"}: success, ${resp.status.code -> "code"} ${resp.status.reason -> "reason"}")
              case resp if resp.attributes.contains(wsAttributeKey) =>
                logger.debug(s"${req.method.name -> "method"} ${req.pathInfo -> "path"}: websocket request")
              case resp =>
                logger.info(s"${req.method.name -> "method"} ${req.pathInfo -> "uri"}: rejection, ${resp.status.code -> "code"} ${resp.status.reason -> "reason"}")
            }
          } yield resp).tapError {
            cause =>
              logger.error(s"${req.method.name -> "method"} ${req.pathInfo -> "path"}: failure, $cause")
          }
        }
    }
  }

  def service(ws: WebSocketBuilder2[F[Throwable, _]]): HttpRoutes[F[Throwable, _]] = {
    val svc                                   = AuthedRoutes.of(router(ws))
    val aservice: HttpRoutes[F[Throwable, _]] = contextProvider(svc)
    loggingMiddle(aservice)
  }

  protected def router(ws: WebSocketBuilder2[F[Throwable, _]]): PartialFunction[AuthedRequest[F[Throwable, _], RequestCtx], F[Throwable, Response[F[Throwable, _]]]] = {
    case request @ GET -> Root / "ws" as ctx =>
      setupWs(request, ctx, ws)

    case request @ GET -> Root / service / method as ctx =>
      val methodId = IRTMethodId(IRTServiceId(service), IRTMethodName(method))
      processHttpRequest(new HttpRequestContext(request, ctx), body = "{}", methodId)

    case request @ POST -> Root / service / method as ctx =>
      val methodId = IRTMethodId(IRTServiceId(service), IRTMethodName(method))
      request.req.decode[String] {
        body =>
          processHttpRequest(new HttpRequestContext(request, ctx), body, methodId)
      }
  }

  protected def handleWsClose(context: WsClientSession[F, RequestCtx, ClientId]): F[Throwable, Unit] = {
    logger.debug(s"WS Session: Websocket client disconnected ${context.id}.") *>
    context.finish()
  }

  protected def globalWsListener: WsSessionListener[F, ClientId] = new WsSessionListener[F, ClientId] {
    def onSessionOpened(context: WsClientId[ClientId]): F[Throwable, Unit] = {
      logger.debug(s"WS Session: opened ${context.id}.")
    }
    def onClientIdUpdate(context: WsClientId[ClientId], old: WsClientId[ClientId]): F[Throwable, Unit] = {
      logger.debug(s"WS Session: Id updated to ${context.id}, was: ${old.id}.")
    }
    def onSessionClosed(context: WsClientId[ClientId]): F[Throwable, Unit] = {
      logger.debug(s"WS Session: closed ${context.id}.")
    }
  }

  protected def setupWs(
    request: AuthedRequest[F[Throwable, _], RequestCtx],
    initialContext: RequestCtx,
    ws: WebSocketBuilder2[F[Throwable, _]],
  ): F[Throwable, Response[F[Throwable, _]]] = {
    Quirks.discard(request)
    def pingStream: Stream[F[Throwable, _], WebSocketFrame.Ping] = {
      Stream
        .awakeEvery[F[Throwable, _]](5.second)
        .evalMap(_ => logger.debug("WS Server: Sending ping frame.").as(WebSocketFrame.Ping()))
    }
    for {
      outQueue           <- Queue.unbounded[F[Throwable, _], WebSocketFrame]
      listenersWithGlobal = Seq(globalWsListener) ++ listeners
      context             = new WsClientSessionImpl(outQueue, initialContext, listenersWithGlobal, wsSessionStorage, printer, logger)
      _                  <- context.start()
      outStream           = Stream.fromQueueUnterminated(outQueue).merge(pingStream)
      inStream = {
        (inputStream: Stream[F[Throwable, _], WebSocketFrame]) =>
          inputStream.evalMap {
            processWsRequest(context, IzTime.utcNow)(_).flatMap {
              case Some(v) => outQueue.offer(WebSocketFrame.Text(v))
              case None    => F.unit
            }
          }
      }
      response <- ws.withOnClose(handleWsClose(context)).build(outStream, inStream)
    } yield {
      response.withAttribute(wsAttributeKey, WsResponseMarker)
    }
  }

  protected def processWsRequest(
    context: WsClientSession[F, RequestCtx, ClientId],
    requestTime: ZonedDateTime,
  )(frame: WebSocketFrame
  ): F[Throwable, Option[String]] = {
    (frame match {
      case WebSocketFrame.Text(msg, _) => wsHandler(context).processRpcMessage(msg)
      case WebSocketFrame.Close(_)     => F.pure(None)
      case _: WebSocketFrame.Pong      => onWsHeartbeat(requestTime).as(None)
      case unknownMessage =>
        val message = s"Unsupported WS frame: $unknownMessage."
        logger
          .error(s"WS request failed: $message.")
          .as(Some(RpcPacket.rpcCritical(message, None)))
    }).map(_.map(p => printer.print(p.asJson)))
  }

  protected def wsHandler(context: WsClientSession[F, RequestCtx, ClientId]): WsRpcHandler[F, RequestCtx] = {
    new ServerWsRpcHandler(muxer, wsContextProvider, context, logger)
  }

  protected def onWsHeartbeat(requestTime: ZonedDateTime): F[Throwable, Unit] = {
    logger.debug(s"WS Session: pong frame at $requestTime")
  }

  protected def processHttpRequest(
    context: HttpRequestContext[F[Throwable, _], RequestCtx],
    body: String,
    method: IRTMethodId,
  ): F[Throwable, Response[F[Throwable, _]]] = {
    val ioR = for {
      parsed      <- F.fromEither(parse(body))
      maybeResult <- muxer.doInvoke(parsed, context.context, method)
    } yield {
      maybeResult
    }

    ioR.sandboxExit.flatMap(handleHttpResult(context, method, _))
  }

  private def handleHttpResult(
    context: HttpRequestContext[F[Throwable, _], RequestCtx],
    method: IRTMethodId,
    result: Exit[Throwable, Option[Json]],
  ): F[Throwable, Response[F[Throwable, _]]] = {
    result match {
      case Success(Some(value)) =>
        dsl.Ok(printer.print(value))

      case Success(None) =>
        logger.warn(s"${context -> null}: No service handler for $method") *>
        dsl.NotFound()

      case Error(error: circe.Error, trace) =>
        logger.info(s"${context -> null}: Parsing failure while handling $method: $error $trace") *>
        dsl.BadRequest()

      case Error(error: IRTDecodingException, trace) =>
        logger.info(s"${context -> null}: Parsing failure while handling $method: $error $trace") *>
        dsl.BadRequest()

      case Error(error: IRTLimitReachedException, trace) =>
        logger.debug(s"${context -> null}: Request failed because of request limit reached $method: $error $trace") *>
        dsl.TooManyRequests()

      case Error(error: IRTUnathorizedRequestContextException, trace) =>
        // workarount because implicits conflict
        logger.debug(s"${context -> null}: Request failed because of unexpected request context reached $method: $error $trace") *>
        dsl.Forbidden().map(_.copy(status = dsl.Unauthorized))

      case Error(error, trace) =>
        logger.info(s"${context -> null}: Unexpected failure while handling $method: $error $trace") *>
        dsl.InternalServerError()

      case Termination(_, (cause: IRTHttpFailureException) :: _, trace) =>
        logger.debug(s"${context -> null}: Request rejected, $method, ${context.request}, $cause, $trace") *>
        F.pure(Response(status = cause.status))

      case Termination(_, (cause: RejectedExecutionException) :: _, trace) =>
        logger.warn(s"${context -> null}: Not enough capacity to handle $method: $cause $trace") *>
        dsl.TooManyRequests()

      case Termination(cause, _, trace) =>
        logger.error(s"${context -> null}: Execution failed, termination, $method, ${context.request}, $cause, $trace") *>
        dsl.InternalServerError()

      case Interruption(cause, _, trace) =>
        logger.info(s"${context -> null}: Unexpected interruption while handling $method: $cause $trace") *>
        dsl.InternalServerError()
    }
  }

}

object HttpServer {
  case object WsResponseMarker
  class ServerWsRpcHandler[F[+_, +_]: IO2, RequestCtx, ClientId](
    muxer: IRTServerMultiplexor[F, RequestCtx],
    wsContextProvider: WsContextProvider[F, RequestCtx, ClientId],
    context: WsClientSession[F, RequestCtx, ClientId],
    logger: LogIO2[F],
  ) extends WsRpcHandler[F, RequestCtx](muxer, context, logger) {
    override def handlePacket(packet: RpcPacket): F[Throwable, Unit] = {
      wsContextProvider.toId(context.initialContext, context.id, packet).flatMap(context.updateId)
    }
    override def handleAuthRequest(packet: RpcPacket): F[Throwable, Option[RpcPacket]] = {
      wsContextProvider.handleAuthorizationPacket(context.id, context.initialContext, packet).flatMap {
        case WsAuthResult(id, packet) => context.updateId(id).as(Some(packet))
      }
    }
    override def extractContext(packet: RpcPacket): F[Throwable, RequestCtx] = {
      wsContextProvider.toContext(context.id, context.initialContext, packet)
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy