// com.twitter.finagle.http.Codec.scala (Maven / Gradle / Ivy)
package com.twitter.finagle.http
import com.twitter.conversions.storage._
import com.twitter.finagle._
import com.twitter.finagle.dispatch.GenSerialClientDispatcher
import com.twitter.finagle.filter.PayloadSizeFilter
import com.twitter.finagle.http.codec._
import com.twitter.finagle.http.filter.{ClientContextFilter, DtabFilter, HttpNackFilter, ServerContextFilter}
import com.twitter.finagle.http.netty.{Netty3ClientStreamTransport, Netty3ServerStreamTransport}
import com.twitter.finagle.stats.{NullStatsReceiver, StatsReceiver, ServerStatsReceiver}
import com.twitter.finagle.tracing._
import com.twitter.finagle.transport.Transport
import com.twitter.util.{NonFatal, Closable, StorageUnit, Try}
import java.net.InetSocketAddress
import org.jboss.netty.channel.{Channel, ChannelEvent, ChannelHandlerContext, ChannelPipelineFactory, Channels, UpstreamMessageEvent}
import org.jboss.netty.handler.codec.http._
/**
 * A Netty request that stands in for an inbound message whose handling threw.
 *
 * It is an ordinary `DefaultHttpRequest` built from the given version, method and URI,
 * and additionally carries the `exception` that was raised so later stages can inspect it.
 *
 * @param httpVersion HTTP version passed to `DefaultHttpRequest`
 * @param method      HTTP method passed to `DefaultHttpRequest`
 * @param uri         request URI passed to `DefaultHttpRequest`
 * @param exception   the failure this placeholder represents
 */
private[finagle] case class BadHttpRequest(
  httpVersion: HttpVersion,
  method: HttpMethod,
  uri: String,
  exception: Throwable
) extends DefaultHttpRequest(httpVersion, method, uri)
object BadHttpRequest {

  /**
   * Wraps `exception` in a placeholder request: HTTP/1.0, `GET`, URI `/bad-http-request`.
   *
   * @param exception the failure to carry
   * @return a [[BadHttpRequest]] holding `exception`
   */
  def apply(exception: Throwable): BadHttpRequest =
    new BadHttpRequest(
      httpVersion = HttpVersion.HTTP_1_0,
      method = HttpMethod.GET,
      uri = "/bad-http-request",
      exception = exception
    )
}
/**
 * Marker for a request that could not be handled normally. Sealed, so every direct
 * subtype is declared in this file.
 */
private[finagle] sealed trait BadReq
/** Refinement of [[BadReq]] mixed in by `BadRequest.contentTooLong`. */
private[finagle] trait ContentTooLong extends BadReq
/** Refinement of [[BadReq]] mixed in by `BadRequest.uriTooLong`. */
private[finagle] trait UriTooLong extends BadReq
/** Refinement of [[BadReq]] mixed in by `BadRequest.headerTooLong`. */
private[finagle] trait HeaderFieldsTooLarge extends BadReq
/**
 * A Finagle [[Request]] tagged as [[BadReq]], pairing the underlying Netty request with
 * the exception that made it bad.
 *
 * @param httpRequest the underlying Netty request
 * @param exception   the failure associated with the request
 */
private[http] case class BadRequest(httpRequest: HttpRequest, exception: Throwable)
  extends Request
  with BadReq {

  // No real peer address is recorded here: this is the wildcard address with port 0,
  // created only on first access.
  lazy val remoteSocketAddress: InetSocketAddress = new InetSocketAddress(0)
}
/**
 * Factories for [[BadRequest]] values, optionally refined with a marker trait naming the
 * kind of failure.
 *
 * Every factory has two forms: one taking an already-built [[BadHttpRequest]], and one
 * taking only the exception, which first wraps it in the standard placeholder
 * (HTTP/1.0 `GET /bad-http-request`) via `BadHttpRequest(exn)`.
 */
private[finagle] object BadRequest {

  /** A plain bad request built from `msg` and the exception it carries. */
  def apply(msg: BadHttpRequest): BadRequest =
    new BadRequest(msg, msg.exception)

  /** A plain bad request built from a placeholder wrapping `exn`. */
  def apply(exn: Throwable): BadRequest =
    apply(BadHttpRequest(exn))

  /** A bad request tagged [[ContentTooLong]], built from `msg`. */
  def contentTooLong(msg: BadHttpRequest): BadRequest with ContentTooLong =
    new BadRequest(msg, msg.exception) with ContentTooLong

  /** A bad request tagged [[ContentTooLong]], built from a placeholder wrapping `exn`. */
  def contentTooLong(exn: Throwable): BadRequest with ContentTooLong =
    contentTooLong(BadHttpRequest(exn))

  /** A bad request tagged [[UriTooLong]], built from `msg`. */
  def uriTooLong(msg: BadHttpRequest): BadRequest with UriTooLong =
    new BadRequest(msg, msg.exception) with UriTooLong

  /** A bad request tagged [[UriTooLong]], built from a placeholder wrapping `exn`. */
  def uriTooLong(exn: Throwable): BadRequest with UriTooLong =
    uriTooLong(BadHttpRequest(exn))

  /** A bad request tagged [[HeaderFieldsTooLarge]], built from `msg`. */
  def headerTooLong(msg: BadHttpRequest): BadRequest with HeaderFieldsTooLarge =
    new BadRequest(msg, msg.exception) with HeaderFieldsTooLarge

  /** A bad request tagged [[HeaderFieldsTooLarge]], built from a placeholder wrapping `exn`. */
  def headerTooLong(exn: Throwable): BadRequest with HeaderFieldsTooLarge =
    headerTooLong(BadHttpRequest(exn))
}
/**
 * An `HttpChunkAggregator` that does not let aggregation failures escape.
 *
 * Any non-fatal exception thrown by the parent's `handleUpstream` is replaced by a
 * [[BadHttpRequest]] carrying that exception, which is sent upstream in its place.
 *
 * @param maxContentSizeBytes limit handed to `HttpChunkAggregator`
 */
private[http] class SafeServerHttpChunkAggregator(maxContentSizeBytes: Int)
  extends HttpChunkAggregator(maxContentSizeBytes) {

  override def handleUpstream(ctx: ChannelHandlerContext, e: ChannelEvent): Unit =
    try super.handleUpstream(ctx, e)
    catch {
      case NonFatal(cause) =>
        val ch = ctx.getChannel()
        // Hand the failure to the next handler as if it were an inbound message.
        val badRequestEvent =
          new UpstreamMessageEvent(ch, BadHttpRequest(cause), ch.getRemoteAddress())
        ctx.sendUpstream(badRequestEvent)
    }
}
/**
 * An `HttpServerCodec` that turns exceptions thrown while handling an upstream event
 * into [[BadHttpRequest]] messages.
 *
 * @param maxInitialLineLength limit handed to `HttpServerCodec`
 * @param maxHeaderSize        limit handed to `HttpServerCodec`
 * @param maxChunkSize         limit handed to `HttpServerCodec`
 */
class SafeHttpServerCodec(
  maxInitialLineLength: Int,
  maxHeaderSize: Int,
  maxChunkSize: Int
) extends HttpServerCodec(maxInitialLineLength, maxHeaderSize, maxChunkSize) {

  override def handleUpstream(ctx: ChannelHandlerContext, e: ChannelEvent): Unit =
    // Only exceptions raised by the codec itself arrive here: when a handler calls
    // sendUpStream(), failures in the upstream handlers are rescued and routed to
    // notifyHandlerException(), which does not throw.
    try super.handleUpstream(ctx, e)
    catch {
      case cause: Exception =>
        val ch = ctx.getChannel()
        val badRequestEvent =
          new UpstreamMessageEvent(ch, BadHttpRequest(cause), ch.getRemoteAddress())
        ctx.sendUpstream(badRequestEvent)
    }
}
/**
* A [[CodecFactory]] that builds the HTTP client and server codecs: the Netty channel
* pipeline plus the Finagle filters, dispatcher and trace initializer that surround it.
* Each builder-style method (`compressionLevel`, `maxRequestSize`, ...) returns a modified
* copy and leaves the receiver unchanged.
*
* @param _compressionLevel The compression level to use. If passed the default value (-1) then use
* [[com.twitter.finagle.http.codec.TextualContentCompressor TextualContentCompressor]] which will
* compress text-like content-types with the default compression level (6). For values greater
* than 0, use
* [[org.jboss.netty.handler.codec.http.HttpContentCompressor HttpContentCompressor]] for all
* content-types with specified compression level. For 0, or any negative value other than -1,
* the server pipeline installs no compressor at all.
*
* @param _maxRequestSize The maximum size of the inbound request an HTTP server constructed with
* this codec can receive (default is 5 megabytes). Should be less than 2 gigabytes (up to
* `Int.MaxValue` bytes). Use streaming/chunked requests to handle larger messages.
* The server pipeline also passes this value to the HTTP codec as its maximum chunk size.
*
* @param _maxResponseSize The maximum size of the inbound response an HTTP client constructed with
* this codec can receive (default is 5 megabytes). Should be less than 2 gigabytes (up to
* `Int.MaxValue` bytes). Use streaming/chunked responses to handle larger messages.
*
* @param _decompressionEnabled When `true`, an `HttpContentDecompressor` is added to both the
* client and the server pipeline.
*
* @param _channelBufferUsageTracker When defined, a `ChannelBufferManager` wrapping the tracker
* is installed as the first handler of the server pipeline. The corresponding builder method
* is deprecated in favour of `maxRequestSize`.
*
* @param _annotateCipherHeader When defined, an `AnnotateCipher` handler configured with this
* header name is appended to the server pipeline.
*
* @param _enableTracing When `true`, the HTTP-specific trace initializers are used; otherwise
* `TraceInitializerFilter.empty` is used.
*
* @param _maxInitialLineLength Passed, in bytes, to the client and server HTTP codecs as the
* maximum initial line length (default is 4096 bytes).
*
* @param _maxHeaderSize Passed, in bytes, to the client and server HTTP codecs as the maximum
* header size (default is 8192 bytes).
*
* @param _streaming Streaming allows applications to work with HTTP messages
* that have large (or infinite) content bodies. When this flag is set to
* `true`, the message content is available through a [[com.twitter.io.Reader]],
* which gives the application a handle to the byte stream. If `false`, the
* entire message content is buffered into a [[com.twitter.io.Buf]].
*
* @param _statsReceiver Not read anywhere inside this class; the stats receivers actually used
* come from `Stack.Params` or `ServerStatsReceiver`.
*/
case class Http(
_compressionLevel: Int = -1,
_maxRequestSize: StorageUnit = 5.megabytes,
_maxResponseSize: StorageUnit = 5.megabytes,
_decompressionEnabled: Boolean = true,
_channelBufferUsageTracker: Option[ChannelBufferUsageTracker] = None,
_annotateCipherHeader: Option[String] = None,
_enableTracing: Boolean = false,
_maxInitialLineLength: StorageUnit = 4096.bytes,
_maxHeaderSize: StorageUnit = 8192.bytes,
_streaming: Boolean = false,
_statsReceiver: StatsReceiver = NullStatsReceiver
) extends CodecFactory[Request, Response] {
// Ten-argument constructor that omits `_statsReceiver` and fills in `NullStatsReceiver`.
def this(
_compressionLevel: Int,
_maxRequestSize: StorageUnit,
_maxResponseSize: StorageUnit,
_decompressionEnabled: Boolean,
_channelBufferUsageTracker: Option[ChannelBufferUsageTracker],
_annotateCipherHeader: Option[String],
_enableTracing: Boolean,
_maxInitialLineLength: StorageUnit,
_maxHeaderSize: StorageUnit,
_streaming: Boolean
) =
this(
_compressionLevel,
_maxRequestSize,
_maxResponseSize,
_decompressionEnabled,
_channelBufferUsageTracker,
_annotateCipherHeader,
_enableTracing,
_maxInitialLineLength,
_maxHeaderSize,
_streaming,
NullStatsReceiver)
// Both size limits are later narrowed with `.inBytes.toInt`, so they must stay below 2 GB.
// These checks run on every construction, including each `copy` made by the builders below.
require(_maxRequestSize < 2.gigabytes,
s"maxRequestSize should be less than 2 Gb, but was ${_maxRequestSize}")
require(_maxResponseSize < 2.gigabytes,
s"maxResponseSize should be less than 2 Gb, but was ${_maxResponseSize}")
// Builder-style setters: each returns a copy with a single field replaced.
def compressionLevel(level: Int) = copy(_compressionLevel = level)
def maxRequestSize(bufferSize: StorageUnit) = copy(_maxRequestSize = bufferSize)
def maxResponseSize(bufferSize: StorageUnit) = copy(_maxResponseSize = bufferSize)
def decompressionEnabled(yesno: Boolean) = copy(_decompressionEnabled = yesno)
@deprecated("Use maxRequestSize to enforce buffer footprint limits", "2016-05-10")
def channelBufferUsageTracker(usageTracker: ChannelBufferUsageTracker) =
copy(_channelBufferUsageTracker = Some(usageTracker))
// `Option(headerName)` maps a null header name to `None`.
def annotateCipherHeader(headerName: String) = copy(_annotateCipherHeader = Option(headerName))
def enableTracing(enable: Boolean) = copy(_enableTracing = enable)
def maxInitialLineLength(length: StorageUnit) = copy(_maxInitialLineLength = length)
def maxHeaderSize(size: StorageUnit) = copy(_maxHeaderSize = size)
def streaming(enable: Boolean) = copy(_streaming = enable)
// Client-side codec. The `config` argument is not used by this implementation.
def client = { config =>
new Codec[Request, Response] {
// Client pipeline order: HTTP codec, then (non-streaming only) chunk aggregator,
// then (if enabled) content decompressor.
def pipelineFactory = new ChannelPipelineFactory {
def getPipeline() = {
val pipeline = Channels.pipeline()
val maxInitialLineLengthInBytes = _maxInitialLineLength.inBytes.toInt
val maxHeaderSizeInBytes = _maxHeaderSize.inBytes.toInt
val maxChunkSize = 8192
pipeline.addLast(
"httpCodec", new HttpClientCodec(
maxInitialLineLengthInBytes, maxHeaderSizeInBytes, maxChunkSize))
if (!_streaming)
pipeline.addLast(
"httpDechunker",
new HttpChunkAggregator(_maxResponseSize.inBytes.toInt))
if (_decompressionEnabled)
pipeline.addLast("httpDecompressor", new HttpContentDecompressor)
pipeline
}
}
// Wraps every service produced by the factory in a DelayedReleaseService.
override def prepareServiceFactory(
underlying: ServiceFactory[Request, Response]
): ServiceFactory[Request, Response] =
underlying.map(new DelayedReleaseService(_))
override def prepareConnFactory(
underlying: ServiceFactory[Request, Response],
params: Stack.Params
): ServiceFactory[Request, Response] =
// Note: This is a horrible hack to ensure that close() calls from
// ExpiringService do not propagate until all chunks have been read
// Waiting on CSL-915 for a proper fix.
underlying.map { u =>
// Filter order: client context, Dtab injection, then (non-streaming only)
// payload size stats measured from the content length of request and response.
val filters =
new ClientContextFilter[Request, Response]
.andThen(new DtabFilter.Injector)
.andThenIf(!_streaming ->
new PayloadSizeFilter[Request, Response](
params[param.Stats].statsReceiver, _.content.length, _.content.length
)
)
filters.andThen(new DelayedReleaseService(u))
}
// Delegates straight to the default implementation.
override def newClientTransport(ch: Channel, statsReceiver: StatsReceiver): Transport[Any,Any] =
super.newClientTransport(ch, statsReceiver)
// Dispatcher stats are scoped under GenSerialClientDispatcher.StatsScope.
override def newClientDispatcher(transport: Transport[Any, Any], params: Stack.Params) =
new HttpClientDispatcher(
new HttpTransport(new Netty3ClientStreamTransport(transport)),
params[param.Stats].statsReceiver.scope(GenSerialClientDispatcher.StatsScope)
)
override def newTraceInitializer =
if (_enableTracing) new HttpClientTraceInitializer[Request, Response]
else TraceInitializerFilter.empty[Request, Response]
override def protocolLibraryName: String = Http.this.protocolLibraryName
}
}
// Server-side codec. The `config` argument is not used by this implementation.
def server = { config =>
new Codec[Request, Response] {
// Server pipeline order: optional buffer manager, HTTP codec, optional compressor,
// optional decompressor, payload size handler, Expect-Continue responder,
// (non-streaming only) chunk aggregator, optional cipher annotation.
def pipelineFactory = new ChannelPipelineFactory {
def getPipeline() = {
val pipeline = Channels.pipeline()
if (_channelBufferUsageTracker.isDefined) {
pipeline.addLast(
"channelBufferManager", new ChannelBufferManager(_channelBufferUsageTracker.get))
}
val maxRequestSizeInBytes = _maxRequestSize.inBytes.toInt
val maxInitialLineLengthInBytes = _maxInitialLineLength.inBytes.toInt
val maxHeaderSizeInBytes = _maxHeaderSize.inBytes.toInt
// The maximum request size doubles as the codec's maximum chunk size.
pipeline.addLast("httpCodec", new SafeHttpServerCodec(maxInitialLineLengthInBytes, maxHeaderSizeInBytes, maxRequestSizeInBytes))
// Level > 0: compress everything at that level; level -1: textual content only;
// any other value (including 0): no compressor is installed.
if (_compressionLevel > 0) {
pipeline.addLast("httpCompressor", new HttpContentCompressor(_compressionLevel))
} else if (_compressionLevel == -1) {
pipeline.addLast("httpCompressor", new TextualContentCompressor)
}
if (_decompressionEnabled)
pipeline.addLast("httpDecompressor", new HttpContentDecompressor)
// The payload size handler should come before the RespondToExpectContinue handler so that we don't
// send a 100 CONTINUE for oversize requests we have no intention of handling.
pipeline.addLast("payloadSizeHandler", new PayloadSizeHandler(maxRequestSizeInBytes))
// Response to ``Expect: Continue'' requests.
pipeline.addLast("respondToExpectContinue", new RespondToExpectContinue)
if (!_streaming)
pipeline.addLast(
"httpDechunker",
new SafeServerHttpChunkAggregator(maxRequestSizeInBytes))
_annotateCipherHeader foreach { headerName: String =>
pipeline.addLast("annotateCipher", new AnnotateCipher(headerName))
}
pipeline
}
}
// The server dispatcher always reports to the global ServerStatsReceiver.
override def newServerDispatcher(
transport: Transport[Any, Any],
service: Service[Request, Response]
): Closable = new HttpServerDispatcher(
new HttpTransport(new Netty3ServerStreamTransport(transport)),
service,
ServerStatsReceiver)
// Filter order: nack handling, Dtab extraction, server context, then
// (non-streaming only) payload size stats, and finally the underlying factory.
override def prepareConnFactory(
underlying: ServiceFactory[Request, Response],
params: Stack.Params
): ServiceFactory[Request, Response] = {
val param.Stats(stats) = params[param.Stats]
new HttpNackFilter(stats)
.andThen(new DtabFilter.Extractor)
.andThen(new ServerContextFilter[Request, Response])
.andThenIf(!_streaming -> new PayloadSizeFilter[Request, Response](
stats, _.content.length, _.content.length)
)
.andThen(underlying)
}
override def newTraceInitializer =
if (_enableTracing) new HttpServerTraceInitializer[Request, Response]
else TraceInitializerFilter.empty[Request, Response]
}
}
override val protocolLibraryName: String = "http"
}
object Http {

  /** Returns an [[Http]] codec factory with every setting left at its default value. */
  def get(): Http = new Http()
}
object HttpTracing {

  /**
   * Names of the HTTP headers that carry trace information.
   *
   * See [[headers()]] for Java compatibility.
   */
  object Header {
    val TraceId: String = "X-B3-TraceId"
    val SpanId: String = "X-B3-SpanId"
    val ParentSpanId: String = "X-B3-ParentSpanId"
    val Sampled: String = "X-B3-Sampled"
    val Flags: String = "X-B3-Flags"

    /** Every tracing header, in declaration order. */
    val All: Seq[String] = Seq(TraceId, SpanId, ParentSpanId, Sampled, Flags)

    /** The headers listed as required: trace id and span id. */
    val Required: Seq[String] = Seq(TraceId, SpanId)
  }

  /** Java compatibility API for [[Header]]. */
  def headers(): Header.type = Header

  /**
   * Drops the query string from `uri`.
   *
   * @param uri the URI to trim
   * @return everything before the first `?`, or `uri` itself when it contains none
   */
  private[http] def stripParameters(uri: String): String = {
    val queryStart = uri.indexOf('?')
    if (queryStart == -1) uri else uri.substring(0, queryStart)
  }
}
/**
 * Helpers that move trace state between the tracing HTTP headers (see
 * `HttpTracing.Header`) and the ambient [[Trace]] context.
 */
private object TraceInfo {
  import HttpTracing._

  /**
   * Runs `f` with the trace id described by the request's tracing headers.
   *
   *  - If all required headers are present, the id is assembled from the headers. When the
   *    span id header cannot be parsed, no id is installed and `f` runs in the current
   *    trace context.
   *  - Otherwise, if only the flags header is present, a fresh id carrying those flags is
   *    used.
   *  - Otherwise a fresh id is used.
   *
   * All tracing headers are removed from `request` before `f` runs, and the RPC is
   * recorded via [[traceRpc]].
   *
   * @param request the inbound request; its tracing headers are removed as a side effect
   * @param f       the computation to run under the chosen trace id
   * @return the result of `f`
   */
  def letTraceIdFromRequestHeaders[R](request: Request)(f: => R): R = {
    val id =
      if (Header.Required.forall { request.headers.contains(_) }) {
        val spanId = SpanId.fromString(request.headers.get(Header.SpanId))

        spanId map { sid =>
          val traceId = SpanId.fromString(request.headers.get(Header.TraceId))
          val parentSpanId = SpanId.fromString(request.headers.get(Header.ParentSpanId))
          // A missing or non-boolean sampled header leaves the sampling decision unset.
          val sampled = Option(request.headers.get(Header.Sampled)) flatMap { sampled =>
            Try(sampled.toBoolean).toOption
          }
          val flags = getFlags(request)
          TraceId(traceId, parentSpanId, sid, sampled, flags)
        }
      } else if (request.headers.contains(Header.Flags)) {
        // even if there are no id headers we want to get the debug flag
        // this is to allow developers to just set the debug flag to ensure their
        // trace is collected
        Some(Trace.nextId.copy(flags = getFlags(request)))
      } else {
        Some(Trace.nextId)
      }

    // remove so the header is not visible to users
    Header.All foreach { request.headers.remove(_) }

    id match {
      case Some(traceId) =>
        Trace.letId(traceId) {
          traceRpc(request)
          f
        }
      case None =>
        traceRpc(request)
        f
    }
  }

  /**
   * Replaces any tracing headers on `request` with ones describing the current
   * `Trace.id`, then records the RPC via [[traceRpc]].
   *
   * @param request the outbound request; mutated in place
   */
  def setClientRequestHeaders(request: Request): Unit = {
    Header.All.foreach { request.headers.remove(_) }

    val traceId = Trace.id
    request.headers.add(Header.TraceId, traceId.traceId.toString)
    request.headers.add(Header.SpanId, traceId.spanId.toString)
    // no parent id set means this is the root span
    traceId._parentId.foreach { id =>
      request.headers.add(Header.ParentSpanId, id.toString)
    }
    // three states of sampled, yes, no or none (let the server decide)
    traceId.sampled.foreach { sampled =>
      request.headers.add(Header.Sampled, sampled.toString)
    }
    request.headers.add(Header.Flags, traceId.flags.toLong)
    traceRpc(request)
  }

  /**
   * When tracing is active, records the request method as the RPC name and the URI,
   * without its query string, as the `http.uri` binary annotation.
   */
  def traceRpc(request: Request): Unit = {
    if (Trace.isActivelyTracing) {
      Trace.recordRpc(request.getMethod.getName)
      Trace.recordBinary("http.uri", stripParameters(request.getUri))
    }
  }

  /**
   * Safely extract the flags from the header, if they exist. Otherwise return empty flag.
   *
   * A header value that cannot be read as a `Long` also yields the empty flags. Only
   * non-fatal failures are absorbed; fatal errors and interrupts propagate to the caller
   * instead of being silently swallowed.
   */
  def getFlags(request: Request): Flags = {
    try {
      Flags(Option(request.headers.get(Header.Flags)).map(_.toLong).getOrElse(0L))
    } catch {
      case NonFatal(_) => Flags()
    }
  }
}
/**
 * Stack module for the trace-initializer role on the server side: every request is served
 * under the configured tracer and under the trace id derived from its headers.
 */
private[finagle] class HttpServerTraceInitializer[Req <: Request, Rep]
  extends Stack.Module1[param.Tracer, ServiceFactory[Req, Rep]] {

  val role = TraceInitializerFilter.role
  val description = "Initialize the tracing system with trace info from the incoming request"

  /**
   * Prepends the trace-initializing filter to `next`.
   *
   * @param _tracer the tracer parameter taken from the stack
   * @param next    the factory the filter is placed in front of
   */
  def make(_tracer: param.Tracer, next: ServiceFactory[Req, Rep]) = {
    val param.Tracer(activeTracer) = _tracer

    Filter
      .mk[Req, Rep, Req, Rep] { (request, service) =>
        Trace.letTracer(activeTracer) {
          TraceInfo.letTraceIdFromRequestHeaders(request)(service(request))
        }
      }
      .andThen(next)
  }
}
/**
 * Stack module for the trace-initializer role on the client side: every request is sent
 * under the configured tracer with a new trace id, and that id is written into the
 * request's tracing headers before dispatch.
 */
private[finagle] class HttpClientTraceInitializer[Req <: Request, Rep]
  extends Stack.Module1[param.Tracer, ServiceFactory[Req, Rep]] {

  val role = TraceInitializerFilter.role
  val description = "Sets the next TraceId and attaches trace information to the outgoing request"

  /**
   * Prepends the trace-initializing filter to `next`.
   *
   * @param _tracer the tracer parameter taken from the stack
   * @param next    the factory the filter is placed in front of
   */
  def make(_tracer: param.Tracer, next: ServiceFactory[Req, Rep]) = {
    val param.Tracer(activeTracer) = _tracer

    Filter
      .mk[Req, Rep, Req, Rep] { (request, service) =>
        Trace.letTracerAndNextId(activeTracer) {
          TraceInfo.setClientRequestHeaders(request)
          service(request)
        }
      }
      .andThen(next)
  }
}
/**
 * Client filter that stamps the current trace id onto each outgoing request's tracing
 * headers before handing the request to the service.
 *
 * @param serviceName accepted by the constructor but not read by this filter
 */
private[finagle] class HttpClientTracingFilter[Req <: Request, Res](serviceName: String)
  extends SimpleFilter[Req, Res] {

  def apply(request: Req, service: Service[Req, Res]) = {
    // Headers must be in place before the request is dispatched.
    TraceInfo.setClientRequestHeaders(request)
    service.apply(request)
  }
}
/**
 * Server filter that invokes the service under the trace id taken from the incoming
 * request's tracing headers (see `TraceInfo.letTraceIdFromRequestHeaders`), which also
 * records the request method and URI when tracing is active.
 *
 * @param serviceName accepted by the constructor but not read by this filter
 */
private[finagle] class HttpServerTracingFilter[Req <: Request, Res](serviceName: String)
  extends SimpleFilter[Req, Res] {

  def apply(request: Req, service: Service[Req, Res]) =
    TraceInfo.letTraceIdFromRequestHeaders(request)(service.apply(request))
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy