io.netty.handler.codec.http.websocketx.extensions.WebSocketServerExtensionHandler Maven / Gradle / Ivy
/*
* Copyright 2014 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.http.websocketx.extensions;
import static io.netty.util.internal.ObjectUtil.checkNonEmpty;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.DefaultHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.util.internal.UnstableApi;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Queue;
/**
* This handler negotiates and initializes the WebSocket Extensions.
*
* It negotiates the extensions based on the client desired order,
* ensures that the successfully negotiated extensions are consistent between them,
* and initializes the channel pipeline with the extension decoder and encoder.
*
* Find a basic implementation for compression extensions at
* io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketServerCompressionHandler.
*/
public class WebSocketServerExtensionHandler extends ChannelDuplexHandler {
private final List extensionHandshakers;
private final Queue> validExtensions =
new ArrayDeque>(4);
/**
* Constructor
*
* @param extensionHandshakers
* The extension handshaker in priority order. A handshaker could be repeated many times
* with fallback configuration.
*/
public WebSocketServerExtensionHandler(WebSocketServerExtensionHandshaker... extensionHandshakers) {
this.extensionHandshakers = Arrays.asList(checkNonEmpty(extensionHandshakers, "extensionHandshakers"));
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
// JDK type checks vs non-implemented interfaces costs O(N), where
// N is the number of interfaces already implemented by the concrete type that's being tested.
// The only requirement for this call is to make HttpRequest(s) implementors to call onHttpRequestChannelRead
// and super.channelRead the others, but due to the O(n) cost we perform few fast-path for commonly met
// singleton and/or concrete types, to save performing such slow type checks.
if (msg != LastHttpContent.EMPTY_LAST_CONTENT) {
if (msg instanceof DefaultHttpRequest) {
// fast-path
onHttpRequestChannelRead(ctx, (DefaultHttpRequest) msg);
} else if (msg instanceof HttpRequest) {
// slow path
onHttpRequestChannelRead(ctx, (HttpRequest) msg);
} else {
super.channelRead(ctx, msg);
}
} else {
super.channelRead(ctx, msg);
}
}
/**
* This is a method exposed to perform fail-fast checks of user-defined http types.
* eg:
* If the user has defined a specific {@link HttpRequest} type i.e.{@code CustomHttpRequest} and
* {@link #channelRead} can receive {@link LastHttpContent#EMPTY_LAST_CONTENT} {@code msg}
* types too, can override it like this:
*
* public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
* if (msg != LastHttpContent.EMPTY_LAST_CONTENT) {
* if (msg instanceof CustomHttpRequest) {
* onHttpRequestChannelRead(ctx, (CustomHttpRequest) msg);
* } else {
* // if it's handling other HttpRequest types it MUST use onHttpRequestChannelRead again
* // or have to delegate it to super.channelRead (that can perform redundant checks).
* // If msg is not implementing HttpRequest, it can call ctx.fireChannelRead(msg) on it
* // ...
* super.channelRead(ctx, msg);
* }
* } else {
* // given that msg isn't a HttpRequest type we can just skip calling super.channelRead
* ctx.fireChannelRead(msg);
* }
* }
*
* IMPORTANT:
* It already call {@code super.channelRead(ctx, request)} before returning.
*/
@UnstableApi
protected void onHttpRequestChannelRead(ChannelHandlerContext ctx, HttpRequest request) throws Exception {
List validExtensionsList = null;
if (WebSocketExtensionUtil.isWebsocketUpgrade(request.headers())) {
String extensionsHeader = request.headers().getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);
if (extensionsHeader != null) {
List extensions =
WebSocketExtensionUtil.extractExtensions(extensionsHeader);
int rsv = 0;
for (WebSocketExtensionData extensionData : extensions) {
Iterator extensionHandshakersIterator =
extensionHandshakers.iterator();
WebSocketServerExtension validExtension = null;
while (validExtension == null && extensionHandshakersIterator.hasNext()) {
WebSocketServerExtensionHandshaker extensionHandshaker =
extensionHandshakersIterator.next();
validExtension = extensionHandshaker.handshakeExtension(extensionData);
}
if (validExtension != null && ((validExtension.rsv() & rsv) == 0)) {
if (validExtensionsList == null) {
validExtensionsList = new ArrayList(1);
}
rsv = rsv | validExtension.rsv();
validExtensionsList.add(validExtension);
}
}
}
}
if (validExtensionsList == null) {
validExtensionsList = Collections.emptyList();
}
validExtensions.offer(validExtensionsList);
super.channelRead(ctx, request);
}
@Override
public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
if (msg != Unpooled.EMPTY_BUFFER && !(msg instanceof ByteBuf)) {
if (msg instanceof DefaultHttpResponse) {
onHttpResponseWrite(ctx, (DefaultHttpResponse) msg, promise);
} else if (msg instanceof HttpResponse) {
onHttpResponseWrite(ctx, (HttpResponse) msg, promise);
} else {
super.write(ctx, msg, promise);
}
} else {
super.write(ctx, msg, promise);
}
}
/**
* This is a method exposed to perform fail-fast checks of user-defined http types.
* eg:
* If the user has defined a specific {@link HttpResponse} type i.e.{@code CustomHttpResponse} and
* {@link #write} can receive {@link ByteBuf} {@code msg} types too, it can be overridden like this:
*
* public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
* if (msg != Unpooled.EMPTY_BUFFER && !(msg instanceof ByteBuf)) {
* if (msg instanceof CustomHttpResponse) {
* onHttpResponseWrite(ctx, (CustomHttpResponse) msg, promise);
* } else {
* // if it's handling other HttpResponse types it MUST use onHttpResponseWrite again
* // or have to delegate it to super.write (that can perform redundant checks).
* // If msg is not implementing HttpResponse, it can call ctx.write(msg, promise) on it
* // ...
* super.write(ctx, msg, promise);
* }
* } else {
* // given that msg isn't a HttpResponse type we can just skip calling super.write
* ctx.write(msg, promise);
* }
* }
*
* IMPORTANT:
* It already call {@code super.write(ctx, response, promise)} before returning.
*/
@UnstableApi
protected void onHttpResponseWrite(ChannelHandlerContext ctx, HttpResponse response, ChannelPromise promise)
throws Exception {
List validExtensionsList = validExtensions.poll();
HttpResponse httpResponse = response;
//checking the status is faster than looking at headers
//so we do this first
if (HttpResponseStatus.SWITCHING_PROTOCOLS.equals(httpResponse.status())) {
handlePotentialUpgrade(ctx, promise, httpResponse, validExtensionsList);
}
super.write(ctx, response, promise);
}
private void handlePotentialUpgrade(final ChannelHandlerContext ctx,
ChannelPromise promise, HttpResponse httpResponse,
final List validExtensionsList) {
HttpHeaders headers = httpResponse.headers();
if (WebSocketExtensionUtil.isWebsocketUpgrade(headers)) {
if (validExtensionsList != null && !validExtensionsList.isEmpty()) {
String headerValue = headers.getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);
List extraExtensions =
new ArrayList(extensionHandshakers.size());
for (WebSocketServerExtension extension : validExtensionsList) {
extraExtensions.add(extension.newReponseData());
}
String newHeaderValue = WebSocketExtensionUtil
.computeMergeExtensionsHeaderValue(headerValue, extraExtensions);
promise.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) {
if (future.isSuccess()) {
for (WebSocketServerExtension extension : validExtensionsList) {
WebSocketExtensionDecoder decoder = extension.newExtensionDecoder();
WebSocketExtensionEncoder encoder = extension.newExtensionEncoder();
String name = ctx.name();
ctx.pipeline()
.addAfter(name, decoder.getClass().getName(), decoder)
.addAfter(name, encoder.getClass().getName(), encoder);
}
}
}
});
if (newHeaderValue != null) {
headers.set(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS, newHeaderValue);
}
}
promise.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) {
if (future.isSuccess()) {
ctx.pipeline().remove(WebSocketServerExtensionHandler.this);
}
}
});
}
}
}