All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.twitter.finagle.Mux.scala Maven / Gradle / Ivy

The newest version!
package com.twitter.finagle

import com.twitter.conversions.storage._
import com.twitter.finagle.client._
import com.twitter.finagle.factory.BindingFactory
import com.twitter.finagle.filter.PayloadSizeFilter
import com.twitter.finagle.mux.lease.exp.Lessor
import com.twitter.finagle.mux.transport.{Message, MuxFramer, Netty3Framer}
import com.twitter.finagle.mux.{Handshake, FailureDetector}
import com.twitter.finagle.netty3.{Netty3Listener, Netty3Transporter}
import com.twitter.finagle.param.{WithDefaultLoadBalancer, ProtocolLibrary}
import com.twitter.finagle.pool.SingletonPool
import com.twitter.finagle.server._
import com.twitter.finagle.stats.StatsReceiver
import com.twitter.finagle.tracing._
import com.twitter.finagle.transport.Transport
import com.twitter.finagle.{param => fparam}
import com.twitter.util.{Closable, Future, StorageUnit}
import java.net.SocketAddress
import org.jboss.netty.buffer.ChannelBuffer

/**
 * A client and server for the mux protocol described in [[com.twitter.finagle.mux]].
 *
 * Both sides perform the mux handshake when a session is established. Each peer
 * advertises the maximum frame size it is willing to receive (see
 * `Mux.param.MaxFrameSize`), and the handshaken transport is decorated by
 * `negotiate` according to the local and remote settings.
 */
object Mux extends Client[mux.Request, mux.Response] with Server[mux.Request, mux.Response] {
  /**
   * The current version of the mux protocol.
   */
  val LatestVersion: Short = 0x0001

  /**
   * Mux-specific stack params.
   */
  object param {
    /**
     * A class eligible for configuring the maximum size of a mux frame.
     * Any message that is larger than this value is fragmented across multiple
     * transmissions. Clients and Servers can use this to set an upper bound
     * on the size of messages they are willing to receive. The value is exchanged
     * and applied during the mux handshake.
     *
     * @param size the maximum frame size. It must be positive and no larger than
     * `Int.MaxValue` bytes; both conditions are checked with `assert` on construction.
     */
    case class MaxFrameSize(size: StorageUnit) {
      assert(size.inBytes <= Int.MaxValue, s"$size is not <= Int.MaxValue bytes")
      assert(size.inBytes > 0, s"$size must be positive")

      def mk(): (MaxFrameSize, Stack.Param[MaxFrameSize]) =
        (this, MaxFrameSize.param)
    }
    object MaxFrameSize {
      /**
       * Defaults to `Int.MaxValue` bytes. This is the sentinel value for which
       * `negotiate` installs no fragmenting framer.
       */
      implicit val param: Stack.Param[MaxFrameSize] =
        Stack.Param(MaxFrameSize(Int.MaxValue.bytes))
    }
  }

  /**
   * Extract feature flags from peer headers and decorate the transport.
   *
   * The framing that gets installed depends on the cross product of the local
   * and remote max frame size configuration. `Int.MaxValue` bytes acts as the
   * "no limit" sentinel:
   *  - If the remote advertised a limit below the sentinel, the result is a
   *    `MuxFramer` configured with the remote's limit, so outbound messages
   *    are fragmented.
   *  - If only the local side set a limit, the result is a `MuxFramer` without
   *    a remote limit. It is prepared to receive fragments from the peer.
   *  - If neither side set a limit, messages are encoded and decoded directly.
   *
   * @param maxFrameSize the maximum frame size that was sent to the peer.
   *
   * @param statsReceiver the stats receiver used to configure various modules
   * configured during negotiation. The framer reports under its "framer" scope.
   */
  private[finagle] def negotiate(
    maxFrameSize: StorageUnit,
    statsReceiver: StatsReceiver
  ): Handshake.Negotiator = (peerHeaders, trans) => {
    val remoteMaxFrameSize = Handshake.valueOf(MuxFramer.Header.KeyBuf, peerHeaders)
      .map { cb => MuxFramer.Header.decodeFrameSize(cb) }
    // Decorate the transport with the MuxFramer. We need to handle the
    // cross product of local and remote configuration. The idea is that
    // both clients and servers can specify the maximum frame size they
    // would like their peer to send.
    val framerStats = statsReceiver.scope("framer")
    (maxFrameSize, remoteMaxFrameSize) match {
      // The remote peer has suggested a max frame size less than the
      // sentinel value. We need to configure the framer to fragment.
      case (_, s@Some(remote)) if remote < Int.MaxValue =>
        MuxFramer(trans, s, framerStats)
      // The local instance has requested a max frame size less than the
      // sentinel value. We need to be prepared for the remote to send
      // fragments.
      case (local, _) if local.inBytes < Int.MaxValue =>
        MuxFramer(trans, None, framerStats)
      // Neither side limits the frame size, so no framer is needed.
      case (_, _) => trans.map(Message.encode, Message.decode)
    }
  }

  /**
   * A stack module that records a boolean `<process>/mux/enabled` binary
   * trace annotation for every request before passing the request on.
   *
   * @param process the trace key prefix, for example "clnt" or "srv".
   * @param role the stack role this module occupies.
   */
  private[finagle] abstract class ProtoTracing(
    process: String,
    val role: Stack.Role
  ) extends Stack.Module0[ServiceFactory[mux.Request, mux.Response]] {
    val description: String = s"Mux specific $process traces"

    // The key only depends on `process`, so build it once per module
    // instead of once per request.
    private[this] val enabledKey = s"$process/mux/enabled"

    private[this] val tracingFilter = new SimpleFilter[mux.Request, mux.Response] {
      def apply(req: mux.Request, svc: Service[mux.Request, mux.Response]): Future[mux.Response] = {
        Trace.recordBinary(enabledKey, true)
        svc(req)
      }
    }

    def make(
      next: ServiceFactory[mux.Request, mux.Response]
    ): ServiceFactory[mux.Request, mux.Response] =
      tracingFilter andThen next
  }

  /** Client-side protocol tracing: records "clnt/mux/enabled" in the client's protoTracing role. */
  private[finagle] class ClientProtoTracing extends ProtoTracing("clnt", StackClient.Role.protoTracing)

  /**
   * Builds the handshake headers that advertise `maxFrameSize` as the largest
   * mux frame this peer is willing to receive. Used by both the client and the server.
   */
  private def frameSizeHeaders(maxFrameSize: StorageUnit): Handshake.Headers = Seq(
    MuxFramer.Header.KeyBuf -> MuxFramer.Header.encodeFrameSize(maxFrameSize.inBytes.toInt)
  )

  object Client {
    /** Prepends bound residual paths to outbound Mux requests' destinations. */
    private object MuxBindingFactory extends BindingFactory.Module[mux.Request, mux.Response] {
      protected[this] def boundPathFilter(residual: Path) =
        Filter.mk[mux.Request, mux.Response, mux.Request, mux.Response] { (req, service) =>
          service(mux.Request(residual ++ req.destination, req.body))
        }
    }

    /**
     * The default mux client stack. It is `StackClient.newStack` with:
     *  - a `SingletonPool` in the pool role,
     *  - mux protocol tracing,
     *  - the destination-rewriting binding factory,
     *  - a prepended `PayloadSizeFilter` that is given the request and
     *    response body lengths.
     */
    val stack: Stack[ServiceFactory[mux.Request, mux.Response]] = StackClient.newStack
      .replace(StackClient.Role.pool, SingletonPool.module[mux.Request, mux.Response])
      .replace(StackClient.Role.protoTracing, new ClientProtoTracing)
      .replace(BindingFactory.role, MuxBindingFactory)
      .prepend(PayloadSizeFilter.module(_.body.length, _.body.length))

    /**
     * Returns the headers that a client sends to a server.
     *
     * @param maxFrameSize the maximum mux fragment size the client is willing to
     * receive from a server.
     */
    private def headers(maxFrameSize: StorageUnit): Handshake.Headers =
      frameSizeHeaders(maxFrameSize)
  }

  /**
   * A mux client that uses a Netty3 transporter with `Netty3Framer`. Each new
   * transport is handshaken and then driven by a `mux.ClientSession`.
   *
   * @param stack the client stack; defaults to `Client.stack`.
   * @param params the stack params; the default includes `ProtocolLibrary("mux")`.
   */
  case class Client(
      stack: Stack[ServiceFactory[mux.Request, mux.Response]] = Client.stack,
      params: Stack.Params = StackClient.defaultParams + ProtocolLibrary("mux"))
    extends StdStackClient[mux.Request, mux.Response, Client]
    with WithDefaultLoadBalancer[Client] {

    protected def copy1(
      stack: Stack[ServiceFactory[mux.Request, mux.Response]] = this.stack,
      params: Stack.Params = this.params
    ): Client = copy(stack, params)

    protected type In = ChannelBuffer
    protected type Out = ChannelBuffer

    // Mux-specific metrics are scoped under "mux".
    private[this] val statsReceiver = params[fparam.Stats].statsReceiver.scope("mux")

    protected def newTransporter(): Transporter[In, Out] =
      Netty3Transporter(Netty3Framer, params)

    /**
     * Runs the client side of the mux handshake on `transport`, advertising
     * this client's max frame size. It then wraps the negotiated transport in
     * a `mux.ClientSession` and returns a request/response dispatcher over it.
     */
    protected def newDispatcher(
      transport: Transport[In, Out]
    ): Service[mux.Request, mux.Response] = {
      val fparam.Label(name) = params[fparam.Label]
      val param.MaxFrameSize(maxFrameSize) = params[param.MaxFrameSize]
      val FailureDetector.Param(detectorConfig) = params[FailureDetector.Param]

      val negotiatedTrans = mux.Handshake.client(
        trans = transport,
        version = LatestVersion,
        headers = Client.headers(maxFrameSize),
        negotiate = negotiate(maxFrameSize, statsReceiver))

      val session = new mux.ClientSession(
        negotiatedTrans,
        detectorConfig,
        name,
        statsReceiver)

      mux.ClientDispatcher.newRequestResponse(session)
    }
  }

  /** The default mux client, configured with `Client.stack` and the default params. */
  val client: Client = Client()

  def newService(dest: Name, label: String): Service[mux.Request, mux.Response] =
    client.newService(dest, label)

  def newClient(dest: Name, label: String): ServiceFactory[mux.Request, mux.Response] =
    client.newClient(dest, label)

  /** Server-side protocol tracing: records "srv/mux/enabled" in the server's protoTracing role. */
  private[finagle] class ServerProtoTracing extends ProtoTracing("srv", StackServer.Role.protoTracing)

  object Server {
    /**
     * The default mux server stack. It is `StackServer.newStack` with:
     *  - the `TraceInitializerFilter` role removed,
     *  - mux protocol tracing,
     *  - a prepended `PayloadSizeFilter` that is given the request and
     *    response body lengths.
     */
    val stack: Stack[ServiceFactory[mux.Request, mux.Response]] = StackServer.newStack
      .remove(TraceInitializerFilter.role)
      .replace(StackServer.Role.protoTracing, new ServerProtoTracing)
      .prepend(PayloadSizeFilter.module(_.body.length, _.body.length))

    /**
     * Returns the headers that a server sends to a client.
     *
     * @param clientHeaders The headers received from the client. This is useful since
     * the headers the server responds with can be based on the clients. Currently
     * the response does not depend on them; only the frame size header is sent.
     *
     * @param maxFrameSize the maximum mux fragment size the server is willing to
     * receive from a client.
     */
    private[finagle] def headers(
      clientHeaders: Handshake.Headers,
      maxFrameSize: StorageUnit
    ): Handshake.Headers = frameSizeHeaders(maxFrameSize)
  }

  /**
   * A mux server that uses a Netty3 listener with `Netty3Framer`. Each accepted
   * transport is handshaken and then served by a `mux.ServerDispatcher`.
   *
   * @param stack the server stack; defaults to `Server.stack`.
   * @param params the stack params; the default includes `ProtocolLibrary("mux")`.
   */
  case class Server(
      stack: Stack[ServiceFactory[mux.Request, mux.Response]] = Server.stack,
      params: Stack.Params = StackServer.defaultParams + ProtocolLibrary("mux"))
    extends StdStackServer[mux.Request, mux.Response, Server] {

    protected def copy1(
      stack: Stack[ServiceFactory[mux.Request, mux.Response]] = this.stack,
      params: Stack.Params = this.params
    ): Server = copy(stack, params)

    protected type In = ChannelBuffer
    protected type Out = ChannelBuffer

    // Mux-specific metrics are scoped under "mux".
    private[this] val statsReceiver = params[fparam.Stats].statsReceiver.scope("mux")

    protected def newListener(): Listener[In, Out] =
      Netty3Listener(Netty3Framer, params)

    /**
     * Runs the server side of the mux handshake on `transport`, advertising
     * this server's max frame size. It then serves `service` over the
     * negotiated transport using the configured lessor and tracer.
     */
    protected def newDispatcher(
      transport: Transport[In, Out],
      service: Service[mux.Request, mux.Response]
    ): Closable = {
      val fparam.Tracer(tracer) = params[fparam.Tracer]
      val param.MaxFrameSize(maxFrameSize) = params[param.MaxFrameSize]
      val Lessor.Param(lessor) = params[Lessor.Param]

      val negotiatedTrans = mux.Handshake.server(
        trans = transport,
        version = LatestVersion,
        headers = Server.headers(_, maxFrameSize),
        negotiate = negotiate(maxFrameSize, statsReceiver))

      mux.ServerDispatcher.newRequestResponse(
        negotiatedTrans,
        service,
        lessor,
        tracer,
        statsReceiver)
    }
  }

  /** The default mux server, configured with `Server.stack` and the default params. */
  val server: Server = Server()

  def serve(
    addr: SocketAddress,
    service: ServiceFactory[mux.Request, mux.Response]
  ): ListeningServer = server.serve(addr, service)
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy