
// com.twitter.finagle.thriftmux.ThriftEmulator.scala Maven / Gradle / Ivy
package com.twitter.finagle.thriftmux
import com.twitter.concurrent.AsyncQueue
import com.twitter.finagle.mux.transport.{BadMessageException, Message}
import com.twitter.finagle.stats.StatsReceiver
import com.twitter.finagle.thrift._
import com.twitter.finagle.thrift.thrift.{RequestHeader, ResponseHeader, UpgradeReply}
import com.twitter.finagle.tracing.Trace
import com.twitter.finagle.transport.{Transport, TransportProxy}
import com.twitter.finagle.{Failure, Path, Dtab}
import com.twitter.io.Buf
import com.twitter.logging.Level
import com.twitter.util.{Future, NonFatal, Try, Return, Promise, Throw, Updatable}
import java.util.concurrent.atomic.AtomicInteger
import java.util.logging.Logger
import org.apache.thrift.protocol.{TProtocolFactory, TMessage, TMessageType}
import scala.collection.mutable
/**
* A [[com.twitter.finagle.transport.Transport]] that manages the downgrading
* of mux server sessions to plain thrift or twitter thrift. Because this is used in the
* context of the mux server dispatcher, it's important that when we downgrade we
* faithfully emulate the mux protocol.
*/
private[finagle] object ThriftEmulator {
private[this] val log = Logger.getLogger(getClass.getName)
/**
 * A [[com.twitter.finagle.transport.Transport]] proxy whose delegate can be
 * replaced after construction.
 *
 * The current delegate lives in a `@volatile` field, so a replacement made
 * through `update` is visible to later calls that go through `self`.
 *
 * @param init the delegate used until `update` is called
 */
private class TransportRef[In, Out](init: Transport[In, Out])
  extends TransportProxy[In, Out](init)
  with Updatable[Transport[In, Out]] {

  @volatile private[this] var target: Transport[In, Out] = init

  // Always resolve the delegate through the mutable reference.
  override def self: Transport[In, Out] = target

  def read(): Future[Out] = self.read()

  def write(in: In): Future[Unit] = self.write(in)

  /** Replaces the delegate that subsequent reads and writes are sent to. */
  def update(trans: Transport[In, Out]): Unit =
    target = trans
}
/**
 * Wraps a Mux server transport so that it can also serve vanilla Thrift
 * and TTwitterThrift clients.
 *
 * @param underlying the transport to wrap
 * @param protocolFactory the thrift protocol factory handed to the session detector
 * @param sr receiver for the connection counters and gauges
 * @return a transport that initially delegates to an `Init` instance and is
 *         re-pointed at whatever transport `Init` settles on
 */
def apply(
  underlying: Transport[Buf, Buf],
  protocolFactory: TProtocolFactory,
  sr: StatsReceiver
): Transport[Buf, Buf] = {
  // Satisfied by `Init` once it knows which transport should serve the session.
  val resolved = new Promise[Transport[Buf, Buf]]

  // Swapping the delegate on completion keeps the footprint small in the
  // common case where no downgrade to thrift is needed.
  val ref = new TransportRef[Buf, Buf](
    new Init(underlying, resolved, protocolFactory, sr))
  resolved.onSuccess { trans => ref.update(trans) }

  ref
}
/**
 * The transport used while the type of the session is still unknown.
 *
 * The first frame read from `underlying` decides the session type: a frame
 * that does not decode as a valid mux T-message is treated as thrift and the
 * session is handed to an `Emulator`; otherwise `underlying` is used as is.
 * The decision is published through `transportP`.
 *
 * Writes issued before the decision are queued and flushed once `transportP`
 * is satisfied. They are flushed through the transport that was selected, so
 * that after a downgrade the queued mux frames are lowered by the `Emulator`
 * rather than written raw to a thrift client.
 *
 * @param underlying the raw connection transport
 * @param transportP satisfied with the transport that should serve the session
 * @param protocolFactory thrift protocol factory passed on to the `Emulator`
 * @param sr receiver for the `connects` / `downgraded_connects` counters and
 *           the `connections` / `downgraded_connections` gauges
 */
private class Init(
  underlying: Transport[Buf, Buf],
  transportP: Promise[Transport[Buf, Buf]],
  protocolFactory: TProtocolFactory,
  sr: StatsReceiver)
  extends TransportProxy[Buf, Buf](underlying) {

  private[this] val downgradedConnectionCount = new AtomicInteger
  private[this] val thriftMuxConnectionCount = new AtomicInteger

  private[this] val thriftmuxConnects = sr.counter("connects")
  private[this] val downgradedConnects = sr.counter("downgraded_connects")

  // NOTE(review): the gauges are retained in a field, presumably so they stay
  // strongly referenced -- confirm against StatsReceiver.addGauge semantics.
  private[this] val gauges = Seq(
    sr.addGauge("downgraded_connections") { downgradedConnectionCount.get() },
    sr.addGauge("connections") { thriftMuxConnectionCount.get() }
  )

  // queues writes while we determine the type of session.
  private[this] val writeq = new AsyncQueue[Buf]

  // drain `writeq` into `trans` sequentially
  private[this] def drain(trans: Transport[Buf, Buf]): Future[Unit] =
    if (writeq.size == 0) Future.Done
    else writeq.poll()
      .flatMap(trans.write)
      .before { drain(trans) }

  // initiate drain when the new transport is set. Queued buffers are mux
  // frames, so they must go through the selected transport: for a downgraded
  // session that is the Emulator, which lowers them to thrift.
  transportP.respond {
    case Return(trans) => drain(trans)
    // No transport was selected; fall back to the raw connection.
    case Throw(_) => drain(underlying)
  }

  /**
   * Queues `buf` until the session type is known.
   *
   * @return `Future.Done` when the buffer was queued, otherwise a failed future
   */
  def write(buf: Buf): Future[Unit] = {
    if (writeq.offer(buf)) Future.Done
    else Future.exception(Failure("unable to enqueue write"))
  }

  /**
   * Reads one frame from `underlying`, classifies the session from it and
   * satisfies `transportP` accordingly.
   *
   * @return the frame as the mux dispatcher should see it: the untouched
   *         frame for a mux session, or the first message produced by the
   *         `Emulator` for a downgraded one. The future fails (after closing
   *         the transport) when decoding fails for a reason other than a
   *         `BadMessageException`.
   */
  def read(): Future[Buf] =
    underlying.read().flatMap { buf =>
      Try { Message.decode(buf) } match {
        // We assume that a bad message decode indicates a thrift
        // session. Due to Mux message numbering, a binary-encoded
        // thrift frame corresponds to an Rerr message with tag
        // 65537. Note that in this context, an R-message is never
        // valid.
        //
        // Binary-encoded thrift messages have the format
        //
        //   header:4 n:4 method:n seqid:4
        //
        // The header is
        //
        //   0x80010000 | type
        //
        // where the type of CALL is 1; the type of ONEWAY is 4. This makes
        // the first four bytes of a CALL message 0x80010001.
        //
        // Mux messages begin with
        //
        //   Type:1 tag:3
        //
        // Rerr is type 0x80, so the thrift CALL header above reads as an
        // Rerr with tag=0x010001 (65537), and a ONEWAY header as an Rerr
        // with tag=0x010004 (65540).
        //
        // The hazards of protocol multiplexing.
        case Throw(_: BadMessageException) |
             Return(Message.Rerr(65537, _)) |
             Return(Message.Rerr(65540, _)) =>
          downgradedConnects.incr()
          downgradedConnectionCount.incrementAndGet()
          underlying.onClose.ensure { downgradedConnectionCount.decrementAndGet() }

          val trans = new Emulator(underlying, protocolFactory, buf)
          transportP.setValue(trans)
          trans.read()

        // We have a valid mux session, return the original
        // transport untouched.
        case Return(_) =>
          thriftmuxConnects.incr()
          thriftMuxConnectionCount.incrementAndGet()
          underlying.onClose.ensure { thriftMuxConnectionCount.decrementAndGet() }

          transportP.setValue(underlying)
          Future.value(buf)

        // Decoding failed in a way that tells us nothing about the
        // protocol: give up on the session.
        case Throw(exc) =>
          val msg = s"Unable to determine the protocol: $exc"
          log.log(Level.DEBUG, msg)
          transportP.setValue(underlying)
          close().before {
            Future.exception(Failure(msg).withLogLevel(Level.DEBUG))
          }
      }
    }
}
/**
 * Presents a thrift or TTwitter thrift connection as a mux session.
 *
 * Frames read from `underlying` are wrapped into encoded mux `Tdispatch`
 * messages and served from an internal queue; mux messages written to this
 * transport are lowered to thrift where possible.
 *
 * @param underlying the transport connected to the non-mux peer
 * @param protocolFactory used to decode and encode thrift messages
 * @param init the first frame that was read from `underlying`
 */
private class Emulator(
  underlying: Transport[Buf, Buf],
  protocolFactory: TProtocolFactory,
  init: Buf)
  extends TransportProxy[Buf, Buf](underlying) {

  // True when `init` is a CALL to the TTwitter upgrade method, i.e. we are
  // speaking twitter upgraded thrift. Any non-fatal decode error means "no".
  private[this] val ttwitter: Boolean =
    try {
      val firstFrame = new InputBuffer(
        Buf.ByteArray.Owned.extract(init),
        protocolFactory)
      val begin = firstFrame().readMessageBegin()
      begin.`type` == TMessageType.CALL &&
        begin.name == ThriftTracing.CanTraceMethodName
    } catch {
      case NonFatal(_) => false
    }

  // An encoded header message for TTwitter thrift using `protocolFactory`.
  private[this] val ttwitterHeader = {
    val encoded = OutputBuffer.messageToArray(new ResponseHeader, protocolFactory)
    Buf.ByteArray.Owned(encoded)
  }

  // An encoded ack message for TTwitter thrift using `protocolFactory`.
  private[this] val ttwitterAck: Buf = {
    val out = new OutputBuffer(protocolFactory)
    out().writeMessageBegin(
      new TMessage(ThriftTracing.CanTraceMethodName, TMessageType.REPLY, 0))
    (new UpgradeReply).write(out())
    out().writeMessageEnd()
    Buf.ByteArray.Shared(out.toArray)
  }

  // Reads are proxied via a queue so that incoming messages can be synthesized.
  private[this] val readq = new AsyncQueue[Buf]

  /**
   * Lowers a mux message into a Thrift message where possible and writes
   * the result to `underlying`.
   */
  private[this] def writeMuxToThrift(buf: Buf): Future[Unit] =
    Message.decode(buf) match {
      case Message.RdispatchOk(_, _, rep) =>
        // TTwitter replies carry the encoded response header in front.
        val payload = if (ttwitter) ttwitterHeader.concat(rep) else rep
        underlying.write(payload)

      case Message.RdispatchNack(_, _) =>
        // The only mechanism for negative acknowledgement afforded by non-Mux
        // clients is to tear down the connection.
        close()
        Future.Done

      case Message.Tdrain(tag) =>
        // Although downgraded connections don't understand Tdrains,
        // we synthesize an Rdrain so the server dispatcher enters draining
        // mode.
        readq.offer(Message.encode(Message.Rdrain(tag)))
        Future.Done

      case other =>
        // This cannot be written, so failure is signalled to the remote
        // by tearing down the session.
        close()
        // Log here to surface the error, then fail the write for the caller.
        val msg = s"unable to write ${other.getClass.getName} to non-mux client"
        log.log(Level.DEBUG, msg)
        Future.exception(Failure(msg).withLogLevel(Level.DEBUG))
    }

  /**
   * Returns a Mux.Tdispatch from a thrift dispatch message.
   */
  private[this] def thriftToMux(buf: Buf): Message = {
    // It's okay to use a static tag since we serialize messages into
    // the dispatcher so we are ensured no tag conflicts.
    val tag = Message.Tags.MinTag

    if (ttwitter) {
      // Peel the TTwitter request header off the frame; `header` is filled in
      // and the remaining bytes are the plain thrift request.
      val header = new RequestHeader
      val payload = InputBuffer.peelMessage(
        Buf.ByteArray.Owned.extract(buf),
        header,
        protocolFactory
      )
      val richHeader = new RichRequestHeader(header)

      val extraContexts = if (header.contexts == null) 0 else header.contexts.size
      val contexts = new mutable.ArrayBuffer[(Buf, Buf)](2 + extraContexts)

      // The trace id is always forwarded; the client id only when present.
      val traceEntry = Trace.idCtx.marshalId -> Trace.idCtx.marshal(richHeader.traceId)
      contexts += traceEntry
      richHeader.clientId.foreach { id =>
        val marshalled = ClientId.clientIdCtx.marshal(Some(id))
        val clientEntry = ClientId.clientIdCtx.marshalId -> marshalled
        contexts += clientEntry
      }

      // Copy over any additional contexts carried in the header.
      if (header.contexts != null) {
        val it = header.contexts.iterator()
        while (it.hasNext) {
          val ctx = it.next()
          val entry = Buf.ByteArray.Owned(ctx.getKey) -> Buf.ByteArray.Owned(ctx.getValue)
          contexts += entry
        }
      }

      Message.Tdispatch(
        tag,
        contexts.toSeq,
        richHeader.dest,
        richHeader.dtab,
        Buf.ByteArray.Owned(payload))
    } else {
      Message.Tdispatch(tag, Nil, Path.empty, Dtab.empty, buf)
    }
  }

  // Reads a frame from `underlying`, queues it as a mux dispatch and repeats.
  private[this] def readLoop(): Future[Unit] =
    underlying.read().flatMap(processRead)

  private[this] val processRead: Buf => Future[Unit] = { frame =>
    readq.offer(Message.encode(thriftToMux(frame)))
    readLoop()
  }

  if (!ttwitter) {
    // we are speaking vanilla thrift, encode `init` as a mux dispatch.
    readq.offer(Message.encode(thriftToMux(init)))
  } else {
    // write the TTwitter ack
    underlying.write(ttwitterAck)
  }

  // kick off readLoop and propagate failure.
  readLoop().onFailure { cause => readq.fail(cause) }

  def write(buf: Buf): Future[Unit] = writeMuxToThrift(buf)

  def read(): Future[Buf] = readq.poll()
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy