All Downloads are FREE. Search and download functionalities are using the official Maven repository.

zio.http.netty.client.NettyConnectionPool.scala Maven / Gradle / Ivy

/*
 * Copyright 2021 - 2023 Sporta Technologies PVT LTD & the ZIO HTTP contributors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package zio.http.netty.client

import java.net.InetSocketAddress
import java.util.concurrent.TimeUnit

import zio._
import zio.stacktracer.TracingImplicits.disableAutoTrace

import zio.http.URL.Location
import zio.http._
import zio.http.netty.{Names, NettyFutureExecutor, NettyProxy, NettyRuntime}

import io.netty.bootstrap.Bootstrap
import io.netty.channel.{Channel => JChannel, ChannelFactory => JChannelFactory, EventLoopGroup => JEventLoopGroup, _}
import io.netty.handler.codec.http.{HttpClientCodec, HttpContentDecompressor}
import io.netty.handler.proxy.HttpProxyHandler
import io.netty.handler.timeout.{ReadTimeoutException, ReadTimeoutHandler}

trait NettyConnectionPool extends ConnectionPool[JChannel]

object NettyConnectionPool {

  protected def createChannel(
    channelFactory: JChannelFactory[JChannel],
    eventLoopGroup: JEventLoopGroup,
    nettyRuntime: NettyRuntime,
    location: URL.Location.Absolute,
    proxy: Option[Proxy],
    sslOptions: ClientSSLConfig,
    maxInitialLineLength: Int,
    maxHeaderSize: Int,
    decompression: Decompression,
    idleTimeout: Option[Duration],
    connectionTimeout: Option[Duration],
    localAddress: Option[InetSocketAddress],
    dnsResolver: DnsResolver,
  )(implicit trace: Trace): ZIO[Scope, Throwable, JChannel] = {
    val initializer = new ChannelInitializer[JChannel] {
      override def initChannel(ch: JChannel): Unit = {
        val pipeline = ch.pipeline()

        proxy match {
          case Some(proxy) =>
            pipeline.addLast(
              Names.ProxyHandler,
              NettyProxy
                .fromProxy(proxy)
                .encode
                .getOrElse(new HttpProxyHandler(new InetSocketAddress(location.host, location.port))),
            )
          case None        =>
        }

        if (location.scheme.isSecure.getOrElse(false)) {
          pipeline.addLast(
            Names.SSLHandler,
            ClientSSLConverter
              .toNettySSLContext(sslOptions)
              .newHandler(ch.alloc, location.host, location.port),
          )
        }

        idleTimeout.foreach { timeout =>
          pipeline.addLast(Names.ReadTimeoutHandler, new ReadTimeoutHandler(timeout.toMillis, TimeUnit.MILLISECONDS))
        }

        pipeline.addLast(Names.ClientReadTimeoutErrorHandler, new ReadTimeoutErrorHandler(nettyRuntime))

        // Adding default client channel handlers
        // Defaults from netty:
        //   maxInitialLineLength=4096
        //   maxHeaderSize=8192
        //   maxChunkSize=8192
        // and we add: failOnMissingResponse=true
        // This way, if the server closes the connection before the whole response has been sent,
        // we get an error. (We can also handle the channelInactive callback, but since for now
        // we always buffer the whole HTTP response we can letty Netty take care of this)
        pipeline.addLast(Names.HttpClientCodec, new HttpClientCodec(maxInitialLineLength, maxHeaderSize, 8192, true))

        // HttpContentDecompressor
        if (decompression.enabled)
          pipeline.addLast(
            Names.HttpRequestDecompression,
            new HttpContentDecompressor(decompression.strict),
          )

        ()
      }
    }

    for {
      resolvedHosts <- dnsResolver.resolve(location.host)
      hosts         <- Random.shuffle(resolvedHosts.toList)
      hostsNec      <- ZIO.succeed(NonEmptyChunk.fromIterable(hosts.head, hosts.tail))
      ch            <- collectFirstSuccess(hostsNec) { host =>
        ZIO.suspend {
          val bootstrap = new Bootstrap()
            .channelFactory(channelFactory)
            .group(eventLoopGroup)
            .remoteAddress(new InetSocketAddress(host, location.port))
            .withOption[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, connectionTimeout.map(_.toMillis.toInt))
            .handler(initializer)
          localAddress.foreach(bootstrap.localAddress)

          val channelFuture = bootstrap.connect()
          val ch            = channelFuture.channel()
          Scope.addFinalizer {
            NettyFutureExecutor.executed {
              channelFuture.cancel(true)
              ch.close()
            }.when(ch.isOpen).ignoreLogged
          } *> NettyFutureExecutor.executed(channelFuture).as(ch)
        }
      }
    } yield ch
  }

  private def collectFirstSuccess[R, E, A, B](
    as: NonEmptyChunk[A],
  )(f: A => ZIO[R, E, B])(implicit trace: Trace): ZIO[R, E, B] = {
    ZIO.suspendSucceed {
      val it                 = as.iterator
      def loop: ZIO[R, E, B] = f(it.next()).catchAll(e => if (it.hasNext) loop else ZIO.fail(e))
      loop
    }
  }

  /**
   * Refreshes the idle timeout handler on the channel pipeline.
   * @return
   *   true if the handler was successfully refreshed prior to the channel being
   *   closed
   */
  private def refreshIdleTimeoutHandler(
    channel: JChannel,
    timeout: Duration,
  ): Boolean = {
    channel
      .pipeline()
      .replace(
        Names.ReadTimeoutHandler,
        Names.ReadTimeoutHandler,
        new ReadTimeoutHandler(timeout.toMillis, TimeUnit.MILLISECONDS),
      )
    channel.isOpen
  }

  private final class ReadTimeoutErrorHandler(nettyRuntime: NettyRuntime)(implicit trace: Trace)
      extends ChannelInboundHandlerAdapter {

    implicit private val unsafe: Unsafe = Unsafe.unsafe

    override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
      cause match {
        case _: ReadTimeoutException => nettyRuntime.unsafeRunSync(ZIO.logDebug("ReadTimeoutException caught"))
        case _                       => super.exceptionCaught(ctx, cause)
      }
    }
  }

  private final class NoNettyConnectionPool(
    channelFactory: JChannelFactory[JChannel],
    eventLoopGroup: JEventLoopGroup,
    nettyRuntime: NettyRuntime,
    dnsResolver: DnsResolver,
  ) extends NettyConnectionPool {
    override def get(
      location: Location.Absolute,
      proxy: Option[Proxy],
      sslOptions: ClientSSLConfig,
      maxInitialLineLength: Int,
      maxHeaderSize: Int,
      decompression: Decompression,
      idleTimeout: Option[Duration],
      connectionTimeout: Option[Duration],
      localAddress: Option[InetSocketAddress] = None,
    )(implicit trace: Trace): ZIO[Scope, Throwable, JChannel] =
      createChannel(
        channelFactory,
        eventLoopGroup,
        nettyRuntime,
        location,
        proxy,
        sslOptions,
        maxInitialLineLength,
        maxHeaderSize,
        decompression,
        idleTimeout,
        connectionTimeout,
        localAddress,
        dnsResolver,
      )

    override def invalidate(channel: JChannel)(implicit trace: Trace): ZIO[Any, Nothing, Unit] =
      ZIO.unit

    override def enableKeepAlive: Boolean =
      false
  }

  case class PoolKey(
    location: Location.Absolute,
    proxy: Option[Proxy],
    sslOptions: ClientSSLConfig,
    maxInitialLineLength: Int,
    maxHeaderSize: Int,
    decompression: Decompression,
    idleTimeout: Option[Duration],
    connectionTimeout: Option[Duration],
  )

  private final class ZioNettyConnectionPool(
    pool: ZKeyedPool[Throwable, PoolKey, JChannel],
    maxItems: PoolKey => Int,
  ) extends NettyConnectionPool {
    override def get(
      location: Location.Absolute,
      proxy: Option[Proxy],
      sslOptions: ClientSSLConfig,
      maxInitialLineLength: Int,
      maxHeaderSize: Int,
      decompression: Decompression,
      idleTimeout: Option[Duration],
      connectionTimeout: Option[Duration],
      localAddress: Option[InetSocketAddress] = None,
    )(implicit trace: Trace): ZIO[Scope, Throwable, JChannel] = ZIO.uninterruptibleMask { restore =>
      val key = PoolKey(
        location,
        proxy,
        sslOptions,
        maxInitialLineLength,
        maxHeaderSize,
        decompression,
        idleTimeout,
        connectionTimeout,
      )

      restore(pool.get(key)).withEarlyRelease.flatMap { case (release, channel) =>
        // Channel might have closed while in the pool, either because of a timeout or because of a connection error
        // We retry a few times hoping to obtain an open channel
        // NOTE: We need to release the channel before retrying, so that it can be closed and removed from the pool
        // We do that in a forked fiber so that we don't "block" the current fiber while the new resource is obtained
        if (channel.isOpen && idleTimeout.fold(true)(refreshIdleTimeoutHandler(channel, _)))
          ZIO.succeed(channel)
        else
          invalidate(channel) *> release.forkDaemon *> ZIO.fail(None)
      }
        .retry(retrySchedule(key))
        .catchAll {
          case None         => pool.get(key) // We did all we could, let the caller handle it
          case e: Throwable => ZIO.fail(e)
        }
        .withFinalizer(c => ZIO.unless(c.isOpen)(invalidate(c)))
    }

    override def invalidate(channel: JChannel)(implicit trace: Trace): ZIO[Any, Nothing, Unit] =
      pool.invalidate(channel)

    override def enableKeepAlive: Boolean = true

    private def retrySchedule[E](key: PoolKey)(implicit trace: Trace) =
      Schedule.recurWhile[E] {
        case None => true
        case _    => false
      } && Schedule.recurs(maxItems(key))
  }

  def fromConfig(
    config: ConnectionPoolConfig,
  )(implicit
    trace: Trace,
  ): ZIO[Scope with NettyClientDriver with DnsResolver, Nothing, NettyConnectionPool] =
    for {
      pool <- config match {
        case ConnectionPoolConfig.Disabled                         =>
          createDisabled
        case ConnectionPoolConfig.Fixed(size)                      =>
          createFixed(size)
        case ConnectionPoolConfig.FixedPerHost(sizes, default)     =>
          createFixedPerHost(location => sizes.getOrElse(location, default).size)
        case ConnectionPoolConfig.Dynamic(minimum, maximum, ttl)   =>
          createDynamic(minimum, maximum, ttl)
        case ConnectionPoolConfig.DynamicPerHost(configs, default) =>
          createDynamicPerHost(
            location => configs.getOrElse(location, default).minimum,
            location => configs.getOrElse(location, default).maximum,
            location => configs.getOrElse(location, default).ttl,
          )
      }
    } yield pool

  private def createDisabled(implicit
    trace: Trace,
  ): ZIO[NettyClientDriver with DnsResolver, Nothing, NettyConnectionPool] =
    for {
      driver      <- ZIO.service[NettyClientDriver]
      dnsResolver <- ZIO.service[DnsResolver]
    } yield new NoNettyConnectionPool(driver.channelFactory, driver.eventLoopGroup, driver.nettyRuntime, dnsResolver)

  private def createFixed(size: Int)(implicit
    trace: Trace,
  ): ZIO[Scope with NettyClientDriver with DnsResolver, Nothing, NettyConnectionPool] =
    createFixedPerHost(_ => size)

  private def createFixedPerHost(
    size: URL.Location.Absolute => Int,
  )(implicit trace: Trace): ZIO[Scope with NettyClientDriver with DnsResolver, Nothing, NettyConnectionPool] =
    for {
      driver      <- ZIO.service[NettyClientDriver]
      dnsResolver <- ZIO.service[DnsResolver]
      poolFn = (key: PoolKey) =>
        createChannel(
          driver.channelFactory,
          driver.eventLoopGroup,
          driver.nettyRuntime,
          key.location,
          key.proxy,
          key.sslOptions,
          key.maxInitialLineLength,
          key.maxHeaderSize,
          key.decompression,
          key.idleTimeout,
          key.connectionTimeout,
          None,
          dnsResolver,
        ).uninterruptible
      _size  = (key: PoolKey) => size(key.location)
      keyedPool <- ZKeyedPool.make(poolFn, _size)
    } yield new ZioNettyConnectionPool(keyedPool, _size)

  private def createDynamic(
    min: Int,
    max: Int,
    ttl: Duration,
  )(implicit trace: Trace): ZIO[Scope with NettyClientDriver with DnsResolver, Nothing, NettyConnectionPool] =
    createDynamicPerHost(_ => min, _ => max, _ => ttl)

  private def createDynamicPerHost(
    min: URL.Location.Absolute => Int,
    max: URL.Location.Absolute => Int,
    ttl: URL.Location.Absolute => Duration,
  )(implicit trace: Trace): ZIO[Scope with NettyClientDriver with DnsResolver, Nothing, NettyConnectionPool] =
    for {
      driver      <- ZIO.service[NettyClientDriver]
      dnsResolver <- ZIO.service[DnsResolver]
      poolFn = (key: PoolKey) =>
        createChannel(
          driver.channelFactory,
          driver.eventLoopGroup,
          driver.nettyRuntime,
          key.location,
          key.proxy,
          key.sslOptions,
          key.maxInitialLineLength,
          key.maxHeaderSize,
          key.decompression,
          key.idleTimeout,
          key.connectionTimeout,
          None,
          dnsResolver,
        ).uninterruptible
      keyedPool <- ZKeyedPool.make(
        poolFn,
        (key: PoolKey) => min(key.location) to max(key.location),
        (key: PoolKey) => ttl(key.location),
      )
    } yield new ZioNettyConnectionPool(keyedPool, key => max(key.location))

  implicit final class BootstrapSyntax(val bootstrap: Bootstrap) extends AnyVal {
    def withOption[T](option: ChannelOption[T], value: Option[T]): Bootstrap =
      value match {
        case Some(value) => bootstrap.option(option, value)
        case None        => bootstrap
      }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy