All downloads are FREE. The search and download functionality uses the official Maven repository.

io.netty.handler.ssl.SniHandler Maven / Gradle / Ivy

/*
 * Copyright 2014 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.ssl;

import java.net.IDN;
import java.net.SocketAddress;
import java.util.List;
import java.util.Locale;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandler;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.DecoderException;
import io.netty.util.AsyncMapping;
import io.netty.util.CharsetUtil;
import io.netty.util.DomainNameMapping;
import io.netty.util.Mapping;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener;
import io.netty.util.concurrent.Promise;
import io.netty.util.internal.ObjectUtil;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;

/**
 * 

Enables SNI * (Server Name Indication) extension for server side SSL. For clients * support SNI, the server could have multiple host name bound on a single IP. * The client will send host name in the handshake data so server could decide * which certificate to choose for the host name.

*/ public class SniHandler extends ByteToMessageDecoder implements ChannelOutboundHandler { // Maximal number of ssl records to inspect before fallback to the default SslContext. private static final int MAX_SSL_RECORDS = 4; private static final InternalLogger logger = InternalLoggerFactory.getInstance(SniHandler.class); private static final Selection EMPTY_SELECTION = new Selection(null, null); protected final AsyncMapping mapping; private boolean handshakeFailed; private boolean suppressRead; private boolean readPending; private volatile Selection selection = EMPTY_SELECTION; /** * Creates a SNI detection handler with configured {@link SslContext} * maintained by {@link Mapping} * * @param mapping the mapping of domain name to {@link SslContext} */ public SniHandler(Mapping mapping) { this(new AsyncMappingAdapter(mapping)); } /** * Creates a SNI detection handler with configured {@link SslContext} * maintained by {@link DomainNameMapping} * * @param mapping the mapping of domain name to {@link SslContext} */ public SniHandler(DomainNameMapping mapping) { this((Mapping) mapping); } /** * Creates a SNI detection handler with configured {@link SslContext} * maintained by {@link AsyncMapping} * * @param mapping the mapping of domain name to {@link SslContext} */ @SuppressWarnings("unchecked") public SniHandler(AsyncMapping mapping) { this.mapping = (AsyncMapping) ObjectUtil.checkNotNull(mapping, "mapping"); } /** * @return the selected hostname */ public String hostname() { return selection.hostname; } /** * @return the selected {@link SslContext} */ public SslContext sslContext() { return selection.context; } @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { if (!suppressRead && !handshakeFailed) { final int writerIndex = in.writerIndex(); try { loop: for (int i = 0; i < MAX_SSL_RECORDS; i++) { final int readerIndex = in.readerIndex(); final int readableBytes = writerIndex - readerIndex; if (readableBytes < 
SslUtils.SSL_RECORD_HEADER_LENGTH) { // Not enough data to determine the record type and length. return; } final int command = in.getUnsignedByte(readerIndex); // tls, but not handshake command switch (command) { case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC: case SslUtils.SSL_CONTENT_TYPE_ALERT: final int len = SslUtils.getEncryptedPacketLength(in, readerIndex); // Not an SSL/TLS packet if (len == SslUtils.NOT_ENCRYPTED) { handshakeFailed = true; NotSslRecordException e = new NotSslRecordException( "not an SSL/TLS record: " + ByteBufUtil.hexDump(in)); in.skipBytes(in.readableBytes()); SslUtils.notifyHandshakeFailure(ctx, e); throw e; } if (len == SslUtils.NOT_ENOUGH_DATA || writerIndex - readerIndex - SslUtils.SSL_RECORD_HEADER_LENGTH < len) { // Not enough data return; } // increase readerIndex and try again. in.skipBytes(len); continue; case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE: final int majorVersion = in.getUnsignedByte(readerIndex + 1); // SSLv3 or TLS if (majorVersion == 3) { final int packetLength = in.getUnsignedShort(readerIndex + 3) + SslUtils.SSL_RECORD_HEADER_LENGTH; if (readableBytes < packetLength) { // client hello incomplete; try again to decode once more data is ready. return; } // See https://tools.ietf.org/html/rfc5246#section-7.4.1.2 // // Decode the ssl client hello packet. // We have to skip bytes until SessionID (which sum to 43 bytes). 
// // struct { // ProtocolVersion client_version; // Random random; // SessionID session_id; // CipherSuite cipher_suites<2..2^16-2>; // CompressionMethod compression_methods<1..2^8-1>; // select (extensions_present) { // case false: // struct {}; // case true: // Extension extensions<0..2^16-1>; // }; // } ClientHello; // final int endOffset = readerIndex + packetLength; int offset = readerIndex + 43; if (endOffset - offset < 6) { break loop; } final int sessionIdLength = in.getUnsignedByte(offset); offset += sessionIdLength + 1; final int cipherSuitesLength = in.getUnsignedShort(offset); offset += cipherSuitesLength + 2; final int compressionMethodLength = in.getUnsignedByte(offset); offset += compressionMethodLength + 1; final int extensionsLength = in.getUnsignedShort(offset); offset += 2; final int extensionsLimit = offset + extensionsLength; if (extensionsLimit > endOffset) { // Extensions should never exceed the record boundary. break loop; } for (;;) { if (extensionsLimit - offset < 4) { break loop; } final int extensionType = in.getUnsignedShort(offset); offset += 2; final int extensionLength = in.getUnsignedShort(offset); offset += 2; if (extensionsLimit - offset < extensionLength) { break loop; } // SNI // See https://tools.ietf.org/html/rfc6066#page-6 if (extensionType == 0) { offset += 2; if (extensionsLimit - offset < 3) { break loop; } final int serverNameType = in.getUnsignedByte(offset); offset++; if (serverNameType == 0) { final int serverNameLength = in.getUnsignedShort(offset); offset += 2; if (extensionsLimit - offset < serverNameLength) { break loop; } final String hostname = in.toString(offset, serverNameLength, CharsetUtil.UTF_8); try { select(ctx, IDN.toASCII(hostname, IDN.ALLOW_UNASSIGNED).toLowerCase(Locale.US)); } catch (Throwable t) { PlatformDependent.throwException(t); } return; } else { // invalid enum value break loop; } } offset += extensionLength; } } // Fall-through default: //not tls, ssl or application data, do not try sni 
break loop; } } } catch (Throwable e) { // unexpected encoding, ignore sni and use default if (logger.isDebugEnabled()) { logger.debug("Unexpected client hello packet: " + ByteBufUtil.hexDump(in), e); } } // Just select the default SslContext select(ctx, null); } } private void select(final ChannelHandlerContext ctx, final String hostname) throws Exception { Future future = lookup(ctx, hostname); if (future.isDone()) { if (future.isSuccess()) { onSslContext(ctx, hostname, future.getNow()); } else { throw new DecoderException("failed to get the SslContext for " + hostname, future.cause()); } } else { suppressRead = true; future.addListener(new FutureListener() { @Override public void operationComplete(Future future) throws Exception { try { suppressRead = false; if (future.isSuccess()) { try { onSslContext(ctx, hostname, future.getNow()); } catch (Throwable cause) { ctx.fireExceptionCaught(new DecoderException(cause)); } } else { ctx.fireExceptionCaught(new DecoderException("failed to get the SslContext for " + hostname, future.cause())); } } finally { if (readPending) { readPending = false; ctx.read(); } } } }); } } /** * The default implementation will simply call {@link AsyncMapping#map(Object, Promise)} but * users can override this method to implement custom behavior. * * @see AsyncMapping#map(Object, Promise) */ protected Future lookup(ChannelHandlerContext ctx, String hostname) throws Exception { return mapping.map(hostname, ctx.executor().newPromise()); } /** * Called upon successful completion of the {@link AsyncMapping}'s {@link Future}. 
* * @see #select(ChannelHandlerContext, String) */ private void onSslContext(ChannelHandlerContext ctx, String hostname, SslContext sslContext) { selection = new Selection(sslContext, hostname); try { replaceHandler(ctx, hostname, sslContext); } catch (Throwable cause) { selection = EMPTY_SELECTION; PlatformDependent.throwException(cause); } } /** * The default implementation of this method will simply replace {@code this} {@link SniHandler} * instance with a {@link SslHandler}. Users may override this method to implement custom behavior. * * Please be aware that this method may get called after a client has already disconnected and * custom implementations must take it into consideration when overriding this method. * * It's also possible for the hostname argument to be {@code null}. */ protected void replaceHandler(ChannelHandlerContext ctx, String hostname, SslContext sslContext) throws Exception { SslHandler sslHandler = null; try { sslHandler = sslContext.newHandler(ctx.alloc()); ctx.pipeline().replace(this, SslHandler.class.getName(), sslHandler); sslHandler = null; } finally { // Since the SslHandler was not inserted into the pipeline the ownership of the SSLEngine was not // transferred to the SslHandler. 
// See https://github.com/netty/netty/issues/5678 if (sslHandler != null) { ReferenceCountUtil.safeRelease(sslHandler.engine()); } } } @Override public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) throws Exception { ctx.bind(localAddress, promise); } @Override public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) throws Exception { ctx.connect(remoteAddress, localAddress, promise); } @Override public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { ctx.disconnect(promise); } @Override public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { ctx.close(promise); } @Override public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { ctx.deregister(promise); } @Override public void read(ChannelHandlerContext ctx) throws Exception { if (suppressRead) { readPending = true; } else { ctx.read(); } } @Override public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { ctx.write(msg, promise); } @Override public void flush(ChannelHandlerContext ctx) throws Exception { ctx.flush(); } private static final class AsyncMappingAdapter implements AsyncMapping { private final Mapping mapping; private AsyncMappingAdapter(Mapping mapping) { this.mapping = ObjectUtil.checkNotNull(mapping, "mapping"); } @Override public Future map(String input, Promise promise) { final SslContext context; try { context = mapping.map(input); } catch (Throwable cause) { return promise.setFailure(cause); } return promise.setSuccess(context); } } private static final class Selection { final SslContext context; final String hostname; Selection(SslContext context, String hostname) { this.context = context; this.hostname = hostname; } } }