All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.twitter.finagle.thriftmux.Netty3.scala Maven / Gradle / Ivy

The newest version!
package com.twitter.finagle.thriftmux

import com.twitter.finagle.{mux, Dtab, ThriftMuxUtil}
import com.twitter.finagle.mux.{BadMessageException, Message}
import com.twitter.finagle.netty3.Conversions._
import com.twitter.finagle.stats.{NullStatsReceiver, StatsReceiver}
import com.twitter.finagle.thrift._
import com.twitter.finagle.thrift.thrift.{
  ResponseHeader, RequestContext, RequestHeader, UpgradeReply}
import com.twitter.finagle.tracing.{Flags, SpanId, TraceContext, TraceId}
import com.twitter.util.{Try, Return, Throw, NonFatal}
import java.util.concurrent.LinkedBlockingDeque
import java.util.concurrent.atomic.AtomicInteger
import org.apache.thrift.protocol.{TProtocolFactory, TMessage, TMessageType}
import org.jboss.netty.buffer.{ChannelBuffers, ChannelBuffer}
import org.jboss.netty.channel._
import scala.collection.mutable.ArrayBuffer

/**
 * A [[org.jboss.netty.channel.ChannelPipelineFactory]] that records the number
 * of open ThriftMux and non-Mux downgraded connections in a pair of
 * [[java.util.concurrent.atomic.AtomicInteger AtomicIntegers]].
 */
private[finagle] class PipelineFactory(
    statsReceiver: StatsReceiver = NullStatsReceiver,
    protocolFactory: TProtocolFactory = Protocols.binaryFactory())
  extends ChannelPipelineFactory
{
  case class UnexpectedRequestException(err: String) extends Exception(err)

  private object TTwitterToMux {
    private val responseHeader = ChannelBuffers.wrappedBuffer(
      OutputBuffer.messageToArray(new ResponseHeader, protocolFactory))
  }

  private class TTwitterToMux extends SimpleChannelHandler {
    import TTwitterToMux._

    private[this] def contextStructToKVTuple(c: RequestContext): (ChannelBuffer, ChannelBuffer) =
      (ChannelBuffers.wrappedBuffer(c.getKey), ChannelBuffers.wrappedBuffer(c.getValue))

    private[this] def thriftToMux(req: ChannelBuffer): Message.Tdispatch = {
      val header = new RequestHeader
      val request_ = InputBuffer.peelMessage(
        ThriftMuxUtil.bufferToArray(req),
        header,
        protocolFactory
      )
      val sampled = if (header.isSetSampled) Some(header.isSampled) else None
      val traceId = TraceId(
        if (header.isSetTrace_id) Some(SpanId(header.getTrace_id)) else None,
        if (header.isSetParent_span_id) Some(SpanId(header.getParent_span_id)) else None,
        SpanId(header.getSpan_id),
        sampled,
        if (header.isSetFlags) Flags(header.getFlags) else Flags()
      )

      val clientIdOpt = Option(header.client_id) map { _.name }
      val contextBuf = ArrayBuffer.empty[(ChannelBuffer, ChannelBuffer)]
      contextBuf += TraceContext.newKVTuple(traceId)
      contextBuf += ClientIdContext.newKVTuple(clientIdOpt)
      if (header.contexts != null) {
        val iter = header.contexts.iterator()
        while (iter.hasNext) {
          contextBuf += contextStructToKVTuple(iter.next())
        }
      }

      Message.Tdispatch(
        Message.MinTag, contextBuf.toSeq, "", Dtab.empty, ChannelBuffers.wrappedBuffer(request_))
    }

    override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent)  {
      val buf = e.getMessage.asInstanceOf[ChannelBuffer]
      super.messageReceived(ctx, new UpstreamMessageEvent(
        e.getChannel, Message.encode(thriftToMux(buf)), e.getRemoteAddress))
    }

    override def writeRequested(ctx: ChannelHandlerContext, e: MessageEvent) {
      Message.decode(e.getMessage.asInstanceOf[ChannelBuffer]) match {
        case Message.RdispatchOk(_, _, rep) =>
          super.writeRequested(ctx,
            new DownstreamMessageEvent(e.getChannel, e.getFuture,
              ChannelBuffers.wrappedBuffer(responseHeader, rep), e.getRemoteAddress))

        case Message.RdispatchNack(_, _) =>
          // The only mechanism for negative acknowledgement afforded by non-Mux
          // clients is to tear down the connection.
          Channels.close(e.getChannel)

        case Message.Tdrain(_) =>
          // Ignore Tdrains because they are advisory and non-Mux clients
          // cannot handle them.
          e.getFuture.setSuccess()

        case Message.Tping(tag) =>
          e.getFuture.setSuccess()
          super.messageReceived(ctx,
            new UpstreamMessageEvent(
              e.getChannel,
              Message.encode(Message.Rping(tag)),
              e.getRemoteAddress))

        case Message.ControlMessage(tag) =>
          e.getFuture.setSuccess()
          super.messageReceived(ctx,
            new UpstreamMessageEvent(
              e.getChannel,
              Message.encode(
                Message.Rerr(tag, "Unable to send Mux control message to non-Mux client")
              ),
              e.getRemoteAddress))

        case Message.RdispatchError(_, _, error) =>
          // OK to throw an exception here as ServerBridge take cares it
          // by logging the error and then closing the channel.
          throw UnexpectedRequestException(error)

        case unexpected =>
          throw UnexpectedRequestException(
            "Unexpected request type %s".format(unexpected.getClass.getName))
      }
    }
  }

  private class TFramedToMux extends SimpleChannelHandler {
    override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent)  {
      val buf = e.getMessage.asInstanceOf[ChannelBuffer]
      super.messageReceived(ctx,
        new UpstreamMessageEvent(
          e.getChannel,
          Message.encode(Message.Tdispatch(Message.MinTag, Seq.empty, "", Dtab.empty, buf)),
          e.getRemoteAddress))
    }

    override def writeRequested(ctx: ChannelHandlerContext, e: MessageEvent) {
      Message.decode(e.getMessage.asInstanceOf[ChannelBuffer]) match {
        case Message.RdispatchOk(_, _, rep) =>
          super.writeRequested(ctx,
            new DownstreamMessageEvent(e.getChannel, e.getFuture, rep, e.getRemoteAddress))

        case Message.RdispatchNack(_, _) =>
          // The only mechanism for negative acknowledgement afforded by non-Mux
          // clients is to tear down the connection.
          Channels.close(e.getChannel)

        case Message.Tdrain(_) =>
          // Ignore Tdrains because they are advisory and non-Mux clients
          // cannot handle them.
          e.getFuture.setSuccess()

        // Non-mux clients can't handle T-type control messages, so we
        // simulate responses.
        case Message.Tping(tag) =>
          e.getFuture.setSuccess()
          super.messageReceived(ctx,
            new UpstreamMessageEvent(
              e.getChannel,
              Message.encode(Message.Rping(tag)),
              e.getRemoteAddress))

        case Message.ControlMessage(tag) =>
          e.getFuture.setSuccess()
          super.messageReceived(ctx,
            new UpstreamMessageEvent(
              e.getChannel,
              Message.encode(
                Message.Rerr(tag, "Unable to send Mux control message to non-Mux client")
              ),
              e.getRemoteAddress))

        case Message.RdispatchError(_, _, error) =>
          // OK to throw an exception here as ServerBridge take cares it
          // by logging the error and then closing the channel.
          throw UnexpectedRequestException(error)

        case unexpected =>
          throw UnexpectedRequestException(
            "Unexpected request type %s".format(unexpected.getClass.getName))
      }
    }
  }

  class RequestSerializer(pendingReqs: Int = 0) extends SimpleChannelHandler {
    // Note: Since there can only be at most one pending request at any time,
    // the only race condition that needs to be handled is one thread (a
    // Netty worker thread) executes messageReceived while another thread
    // executes writeRequested (the thread satisfies the request)
    private[this] val q = new LinkedBlockingDeque[MessageEvent]
    private[this] val n = new AtomicInteger(pendingReqs)

    override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent) {
      if (n.incrementAndGet() > 1) q.offer(e)
      else super.messageReceived(ctx, e)
    }

    override def writeRequested(ctx: ChannelHandlerContext, e: MessageEvent) {
      super.writeRequested(ctx, e)
      if (n.decrementAndGet() > 0) {
        // Need to call q.take() Since incrementing n and enqueueing the
        // request are not atomic. n>0 guarantees q.take() does not block forever.
        super.messageReceived(ctx, q.take())
      }
    }
  }

  private object Upgrader {
    val upNegotiationAck = {
      val buffer = new OutputBuffer(protocolFactory)
      buffer().writeMessageBegin(
        new TMessage(ThriftTracing.CanTraceMethodName, TMessageType.REPLY, 0))
      val upgradeReply = new UpgradeReply
      upgradeReply.write(buffer())
      buffer().writeMessageEnd()
      ChannelBuffers.copiedBuffer(buffer.toArray)
    }
  }

  private class Upgrader extends SimpleChannelHandler {
    import Upgrader._

    private[this] def isTTwitterUpNegotiation(req: ChannelBuffer): Boolean = {
      try {
        val buffer = new InputBuffer(ThriftMuxUtil.bufferToArray(req), protocolFactory)
        val msg = buffer().readMessageBegin()
        msg.`type` == TMessageType.CALL &&
          msg.name == ThriftTracing.CanTraceMethodName
      } catch {
        case NonFatal(_) => false
      }
    }

    override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent) {
      val buf = e.getMessage.asInstanceOf[ChannelBuffer]
      Try { Message.decode(buf.duplicate()) } match {
        // We assume that a bad message decode indicates an old-style
        // session. Due to Mux message numbering, a binary-encoded
        // thrift frame corresponds to an Rerr message with tag
        // 65537. Note that in this context, an R-message is never
        // valid.
        //
        // Binary-encoded thrift messages have the format
        //
        //     header:4 n:4 method:n seqid:4
        //
        // The header is
        //
        //     0x80010000 | type
        //
        // where the type of CALL is 1; the type of ONEWAY is 4. This makes
        // the first four bytes of a CALL message 0x80010001.
        //
        // Mux messages begin with
        //
        //     Type:1 tag:3
        //
        // Rerr is type 0x80, so we see the above thrift header
        // Rerr corresponds to (tag=0x010001).
        //
        // The hazards of protocol multiplexing.
        case Throw(_: BadMessageException) | Return(Message.Rerr(65537, _)) | Return(Message.Rerr(65540, _))=>
          // Increment ThriftMux connection count stats and wire up a callback to
          // decrement on channel closure.
          downgradedConnects.incr()
          downgradedConnectionCount.incrementAndGet()
          ctx.getChannel.getCloseFuture() onSuccessOrFailure {
            downgradedConnectionCount.decrementAndGet()
          }

          // Add a ChannelHandler to serialize the requests since we may
          // deal with a client that pipelines requests
          ctx.getPipeline.addBefore(ctx.getName, "request_serializer", new RequestSerializer(1))
          if (isTTwitterUpNegotiation(buf)) {
            ctx.getPipeline.replace(this, "twitter_thrift_to_mux", new TTwitterToMux)
            Channels.write(ctx, e.getFuture, upNegotiationAck, e.getRemoteAddress)
          } else {
            ctx.getPipeline.replace(this, "framed_thrift_to_mux", new TFramedToMux)
            super.messageReceived(ctx,
              new UpstreamMessageEvent(
                e.getChannel,
                Message.encode(Message.Tdispatch(Message.MinTag, Seq.empty, "", Dtab.empty, buf)),
                e.getRemoteAddress))
          }

        case Return(_) =>
          // Increment ThriftMux connection count stats and wire up a callback to
          // decrement on channel closure.
          thriftmuxConnects.incr()
          thriftMuxConnectionCount.incrementAndGet()
          ctx.getChannel.getCloseFuture() onSuccessOrFailure {
            thriftMuxConnectionCount.decrementAndGet()
          }

          ctx.getPipeline.remove(this)
          super.messageReceived(ctx, e)

        case Throw(exc) => throw exc
      }
    }
  }

  private[this] val downgradedConnectionCount = new AtomicInteger
  private[this] val thriftMuxConnectionCount = new AtomicInteger

  private[this] val thriftmuxConnects = statsReceiver.counter("connects")
  private[this] val downgradedConnects = statsReceiver.counter("downgraded_connects")
  private[this] val downgradedConnectionGauge =
    statsReceiver.addGauge("downgraded_connections") { downgradedConnectionCount.get() }
  private[this] val thriftmuxConnectionGauge =
    statsReceiver.addGauge("connections") { thriftMuxConnectionCount.get() }

  def getPipeline() = {
    val pipeline = mux.PipelineFactory.getPipeline()
    pipeline.addLast("upgrader", new Upgrader)
    pipeline
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy