All downloads are free. Search and download functionality uses the official Maven repository.

zhttp.service.Client.scala Maven / Gradle / Ivy

There is a newer version: 2.0.0-RC11
Show newest version
package zhttp.service

import io.netty.bootstrap.Bootstrap
import io.netty.channel.{
  Channel,
  ChannelFactory => JChannelFactory,
  ChannelFuture => JChannelFuture,
  ChannelInitializer,
  EventLoopGroup => JEventLoopGroup,
}
import io.netty.handler.codec.http._
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler
import zhttp.http._
import zhttp.service
import zhttp.service.Client.Config
import zhttp.service.client.ClientSSLHandler.ClientSSLOptions
import zhttp.service.client.{ClientInboundHandler, ClientSSLHandler}
import zhttp.socket.{Socket, SocketApp}
import zio.{Promise, Task, ZIO}

import java.net.{InetSocketAddress, URI}

/**
 * HTTP and WebSocket client backed by Netty.
 *
 * @param rtm runtime handed to the pipeline handlers (`ClientInboundHandler`, `WebSocketAppHandler`)
 * @param cf  Netty factory used to create each client channel
 * @param el  event loop group that client channels are registered with
 */
final case class Client[R](rtm: HttpRuntime[R], cf: JChannelFactory[Channel], el: JEventLoopGroup)
    extends HttpMessageCodec {

  /**
   * Sends `request` and waits for its response.
   *
   * The request is encoded to a Netty `FullHttpRequest`, a connection is opened via
   * `unsafeRequest`, and the result is awaited on a promise that is handed to
   * `ClientInboundHandler`. Any failure reported while opening the connection is used
   * to fail that promise, so it surfaces as the failure of the returned effect.
   *
   * @param request      the request to send
   * @param clientConfig SSL options and, for WebSocket URLs, the socket app to run
   * @return the value the promise is completed with
   */
  private[zhttp] def request(request: Request, clientConfig: Config): Task[Response] =
    for {
      promise <- Promise.make[Throwable, Response]
      jReq    <- encode(request)
      // NOTE(review): if the connect attempt itself fails, nothing in this file releases
      // `jReq`; confirm whether ClientInboundHandler releases it on that path.
      _       <- ChannelFuture
        .unit(unsafeRequest(request, clientConfig, jReq, promise))
        .catchAll(cause => promise.fail(cause))
      res     <- promise.await
    } yield res

  /**
   * Sends a GET request to `url` with `socketApp` attached as the WebSocket app.
   *
   * `socketApp` is provided with the caller's environment `R` up front, so the
   * resulting `Client.Config` only carries a `SocketApp[Any]`. The WebSocket
   * handlers are only installed when the scheme of `url` is a WebSocket scheme.
   *
   * @param url        target URL
   * @param headers    extra headers, also sent as custom headers of the WebSocket handshake
   * @param socketApp  app to run on the WebSocket connection
   * @param sslOptions SSL options used when the URL scheme is secure
   * @return the response the promise is completed with
   */
  def socket(
    url: URL,
    headers: Headers = Headers.empty,
    socketApp: SocketApp[R],
    sslOptions: ClientSSLOptions = ClientSSLOptions.DefaultSSL,
  ): ZIO[R, Throwable, Response] = for {
    env <- ZIO.environment[R]
    res <- request(
      Request(
        version = Version.Http_1_1,
        Method.GET,
        url,
        headers,
      ),
      clientConfig = Client.Config(socketApp = Some(socketApp.provideEnvironment(env)), ssl = Some(sslOptions)),
    )
  } yield res

  /**
   * Opens a connection for `req` and installs the client pipeline on it.
   * It handles both WebSocket and HTTP requests.
   *
   * The host is taken from the encoded request URI, falling back to the `Host`
   * header. The port is taken from `req.url`, falling back to the scheme's default
   * (443 for secure schemes, 80 otherwise).
   *
   * If anything throws while setting up the connection, `jReq` is released before
   * the error is rethrown unchanged.
   *
   * @return Netty's future for the connect attempt
   */
  private def unsafeRequest(
    req: Request,
    clientConfig: Config,
    jReq: FullHttpRequest,
    promise: Promise[Throwable, Response],
  ): JChannelFuture = {

    try {
      val uri  = new URI(jReq.uri())
      // A request URI without a host (e.g. a bare path) relies on the Host header instead.
      val host = if (uri.getHost == null) jReq.headers().get(HeaderNames.host) else uri.getHost

      assert(host != null, "Host name is required")

      val isWebSocket = req.url.scheme.exists(_.isWebSocket)
      val isSSL       = req.url.scheme.exists(_.isSecure)

      // Fall back to the scheme's well-known port so that a secure URL without an
      // explicit port is not sent over TLS to port 80.
      val port = req.url.port.getOrElse(if (isSSL) 443 else 80)

      val initializer = new ChannelInitializer[Channel]() {
        override def initChannel(ch: Channel): Unit = {

          val pipeline = ch.pipeline()

          // For https / wss requests the SSL handler has to be first in the pipeline.
          if (isSSL) {
            val sslOption: ClientSSLOptions = clientConfig.ssl.getOrElse(ClientSSLOptions.DefaultSSL)
            pipeline.addLast(SSL_HANDLER, ClientSSLHandler.ssl(sslOption).newHandler(ch.alloc, host, port))
          }

          // Default client channel handlers.
          pipeline.addLast(HTTP_CLIENT_CODEC, new HttpClientCodec)

          // The aggregator lets the pipeline work with FullHttpRequests and FullHttpResponses;
          // the WebSocket handlers need it as well.
          pipeline.addLast(HTTP_OBJECT_AGGREGATOR, new HttpObjectAggregator(Int.MaxValue))

          // ClientInboundHandler builds the client Response from the FullHttpResponse.
          pipeline.addLast(CLIENT_INBOUND_HANDLER, new ClientInboundHandler(rtm, jReq, promise, isWebSocket))

          // WebSocket handlers are only added for `ws` / `wss` requests.
          if (isWebSocket) {
            val headers = req.headers.encode
            val app     = clientConfig.socketApp.getOrElse(Socket.empty.toSocketApp)
            val config  = app.protocol.clientBuilder
              .customHeaders(headers)
              .webSocketUri(req.url.encode)
              .build()

            // Netty's protocol handler performs the upgrade handshake; `app` is then
            // handed to WebSocketAppHandler.
            pipeline.addLast(WEB_SOCKET_CLIENT_PROTOCOL_HANDLER, new WebSocketClientProtocolHandler(config))
            pipeline.addLast(WEB_SOCKET_HANDLER, new WebSocketAppHandler(rtm, app))
          }
          ()
        }
      }

      new Bootstrap()
        .channelFactory(cf)
        .group(el)
        .handler(initializer)
        .connect(new InetSocketAddress(host, port))
    } catch {
      // Cleanup-and-rethrow: release the encoded request, then propagate the original
      // error unchanged.
      case err: Throwable =>
        if (jReq.refCnt() > 0) {
          jReq.release(jReq.refCnt()): Unit
        }
        throw err
    }
  }
}

/** Companion of `Client`: constructors and one-shot request helpers. */
object Client {

  /**
   * Builds a `Client[R]` from the Netty channel factory and event loop group found in
   * the environment, paired with `HttpRuntime.default[R]`.
   */
  def make[R]: ZIO[R with EventLoopGroup with ChannelFactory, Nothing, Client[R]] =
    ZIO.service[JChannelFactory[Channel]].flatMap { channelFactory =>
      ZIO.service[JEventLoopGroup].flatMap { eventLoopGroup =>
        HttpRuntime.default[R].map(runtime => service.Client(runtime, channelFactory, eventLoopGroup))
      }
    }

  /**
   * Parses `url` and sends a single HTTP/1.1 request to it with a fresh client.
   *
   * @param url     absolute or relative URL; a parse failure fails the effect
   * @param method  HTTP method, GET by default
   * @param headers request headers
   * @param content request body
   * @param ssl     SSL options used when the URL scheme is secure
   */
  def request(
    url: String,
    method: Method = Method.GET,
    headers: Headers = Headers.empty,
    content: HttpData = HttpData.empty,
    ssl: ClientSSLOptions = ClientSSLOptions.DefaultSSL,
  ): ZIO[EventLoopGroup with ChannelFactory, Throwable, Response] =
    ZIO.fromEither(URL.fromString(url)).flatMap { target =>
      val outgoing = Request(Version.Http_1_1, method, target, headers, data = content)
      request(outgoing, clientConfig = Config(ssl = Some(ssl)))
    }

  /**
   * Sends an already-built `request` with a fresh client that needs no extra environment.
   */
  def request(
    request: Request,
    clientConfig: Config,
  ): ZIO[EventLoopGroup with ChannelFactory, Throwable, Response] =
    make[Any].flatMap(client => client.request(request, clientConfig))

  /**
   * Creates a client for environment `R`, parses `url`, and opens a socket running `app`.
   *
   * The client is created before the URL is parsed.
   */
  def socket[R](
    url: String,
    app: SocketApp[R],
    headers: Headers = Headers.empty,
    sslOptions: ClientSSLOptions = ClientSSLOptions.DefaultSSL,
  ): ZIO[R with EventLoopGroup with ChannelFactory, Throwable, Response] =
    make[R].flatMap { client =>
      ZIO.fromEither(URL.fromString(url)).flatMap(target => client.socket(target, headers, app, sslOptions))
    }

  /**
   * Per-request client options.
   *
   * @param socketApp app installed when the request URL has a WebSocket scheme
   *                  (an empty socket app is used when absent)
   * @param ssl       SSL options for secure schemes (`ClientSSLOptions.DefaultSSL` when absent)
   */
  case class Config(socketApp: Option[SocketApp[Any]] = None, ssl: Option[ClientSSLOptions] = None) {

    /** Returns a copy with `ssl` set. */
    def withSSL(ssl: ClientSSLOptions): Config = copy(ssl = Some(ssl))

    /** Returns a copy with `socketApp` set. */
    def withSocketApp(socketApp: SocketApp[Any]): Config = copy(socketApp = Some(socketApp))
  }

  object Config {

    /** A configuration with neither a socket app nor explicit SSL options. */
    def empty: Config = Config()
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy