All downloads are free. Search and download functionality uses the official Maven repository.

com.mongodb.connection.netty.NettyStream Maven / Gradle / Ivy

Go to download

The MongoDB Java Driver uber-artifact, containing mongodb-driver, mongodb-driver-core, and bson

There is a newer version: 3.12.14
Show newest version
/*
 * Copyright 2008-present MongoDB, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.mongodb.connection.netty;

import com.mongodb.MongoClientException;
import com.mongodb.MongoException;
import com.mongodb.MongoInternalException;
import com.mongodb.MongoInterruptedException;
import com.mongodb.MongoSocketException;
import com.mongodb.MongoSocketOpenException;
import com.mongodb.MongoSocketReadTimeoutException;
import com.mongodb.ServerAddress;
import com.mongodb.connection.AsyncCompletionHandler;
import com.mongodb.connection.SocketSettings;
import com.mongodb.connection.SslSettings;
import com.mongodb.connection.Stream;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.timeout.ReadTimeoutException;
import io.netty.util.concurrent.EventExecutor;
import org.bson.ByteBuf;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLParameters;
import java.io.IOException;
import java.net.SocketAddress;
import java.security.NoSuchAlgorithmException;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.CountDownLatch;

import static com.mongodb.internal.connection.SslHelper.enableHostNameVerification;
import static com.mongodb.internal.connection.SslHelper.enableSni;
import static java.util.concurrent.TimeUnit.MILLISECONDS;

/**
 * A Stream implementation based on Netty 4.0.
 *
 * <p>Inbound data handed over by the Netty pipeline is queued in {@code pendingInboundBuffers} and drained by
 * {@link #readAsync(int, AsyncCompletionHandler)}. The queue, {@code pendingReader}, {@code pendingException} and the
 * open/close transitions of {@code channel} are guarded by this object's monitor.</p>
 *
 * <p>Only a single waiting reader is tracked: a second {@code readAsync} issued while an earlier one is still waiting for
 * data replaces the earlier one.</p>
 */
final class NettyStream implements Stream {
    /** Pipeline name under which the read-timeout handler is registered (only when a read timeout is configured). */
    private static final String READ_HANDLER_NAME = "ReadTimeoutHandler";
    private final ServerAddress address;
    private final SocketSettings settings;
    private final SslSettings sslSettings;
    private final EventLoopGroup workerGroup;
    private final Class<? extends SocketChannel> socketChannelClass;
    private final ByteBufAllocator allocator;

    private volatile boolean isClosed;
    private volatile Channel channel;

    // Buffers received from the channel but not yet consumed by a reader; each entry holds one reference owned by this queue.
    private final LinkedList<io.netty.buffer.ByteBuf> pendingInboundBuffers = new LinkedList<io.netty.buffer.ByteBuf>();
    // The read request waiting for more data, if any.
    private volatile PendingReader pendingReader;
    // The most recent failure reported by the channel; once set, every subsequent read fails with it.
    private volatile Throwable pendingException;

    /**
     * Creates an unopened stream; no connection is attempted until {@link #open()} or {@link #openAsync} is called.
     *
     * @param address            the server to connect to
     * @param settings           socket options (timeouts, keep-alive, buffer sizes)
     * @param sslSettings        TLS configuration
     * @param workerGroup        the Netty event loop group used for the connection
     * @param socketChannelClass the Netty channel implementation to instantiate
     * @param allocator          allocator used for buffers handed out and assembled by this stream
     */
    NettyStream(final ServerAddress address, final SocketSettings settings, final SslSettings sslSettings, final EventLoopGroup workerGroup,
                final Class<? extends SocketChannel> socketChannelClass, final ByteBufAllocator allocator) {
        this.address = address;
        this.settings = settings;
        this.sslSettings = sslSettings;
        this.workerGroup = workerGroup;
        this.socketChannelClass = socketChannelClass;
        this.allocator = allocator;
    }

    /**
     * Allocates a buffer whose initial and maximum capacity are both {@code size}.
     *
     * @param size the capacity in bytes
     * @return the new buffer, wrapped as a driver {@code ByteBuf}
     */
    @Override
    public ByteBuf getBuffer(final int size) {
        return new NettyByteBuf(allocator.buffer(size, size));
    }

    /**
     * Opens the stream and blocks until the connection attempt has completed.
     *
     * @throws IOException if the connection attempt failed with an {@code IOException}
     */
    @Override
    public void open() throws IOException {
        FutureAsyncCompletionHandler<Void> handler = new FutureAsyncCompletionHandler<Void>();
        openAsync(handler);
        handler.get();
    }

    /**
     * Starts connecting to the socket addresses of the server address, one after another, until one succeeds.
     *
     * @param handler notified once a connection is established or all addresses have failed
     */
    @Override
    public void openAsync(final AsyncCompletionHandler<Void> handler) {
        initializeChannel(handler, new LinkedList<SocketAddress>(address.getSocketAddresses()));
    }

    /**
     * Attempts a connection to the next address in the queue. On failure, {@link OpenChannelFutureListener} calls back
     * into this method with the remaining addresses.
     *
     * @param handler            notified of the final outcome
     * @param socketAddressQueue the addresses that have not been tried yet; the head is removed by this call
     */
    @SuppressWarnings("deprecation")
    private void initializeChannel(final AsyncCompletionHandler<Void> handler, final Queue<SocketAddress> socketAddressQueue) {
        if (socketAddressQueue.isEmpty()) {
            handler.failed(new MongoSocketException("Exception opening socket", getAddress()));
        } else {
            SocketAddress nextAddress = socketAddressQueue.poll();

            Bootstrap bootstrap = new Bootstrap();
            bootstrap.group(workerGroup);
            bootstrap.channel(socketChannelClass);

            bootstrap.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, settings.getConnectTimeout(MILLISECONDS));
            bootstrap.option(ChannelOption.TCP_NODELAY, true);
            bootstrap.option(ChannelOption.SO_KEEPALIVE, settings.isKeepAlive());

            // A non-positive size leaves the corresponding socket buffer option untouched.
            if (settings.getReceiveBufferSize() > 0) {
                bootstrap.option(ChannelOption.SO_RCVBUF, settings.getReceiveBufferSize());
            }
            if (settings.getSendBufferSize() > 0) {
                bootstrap.option(ChannelOption.SO_SNDBUF, settings.getSendBufferSize());
            }
            bootstrap.option(ChannelOption.ALLOCATOR, allocator);

            bootstrap.handler(new ChannelInitializer<SocketChannel>() {
                @Override
                public void initChannel(final SocketChannel ch) {
                    if (sslSettings.isEnabled()) {
                        SSLEngine engine = getSslContext().createSSLEngine(address.getHost(), address.getPort());
                        engine.setUseClientMode(true);
                        SSLParameters sslParameters = engine.getSSLParameters();
                        enableSni(address.getHost(), sslParameters);
                        if (!sslSettings.isInvalidHostNameAllowed()) {
                            enableHostNameVerification(sslParameters);
                        }
                        engine.setSSLParameters(sslParameters);
                        // The TLS handler goes first so that every other handler sees decrypted data.
                        ch.pipeline().addFirst("ssl", new SslHandler(engine, false));
                    }
                    int readTimeout = settings.getReadTimeout(MILLISECONDS);
                    if (readTimeout > 0) {
                        ch.pipeline().addLast(READ_HANDLER_NAME, new ReadTimeoutHandler(readTimeout));
                    }
                    ch.pipeline().addLast(new InboundBufferHandler());
                }
            });
            final ChannelFuture channelFuture = bootstrap.connect(nextAddress);
            channelFuture.addListener(new OpenChannelFutureListener(socketAddressQueue, channelFuture, handler));
        }
    }

    /**
     * Writes the buffers and blocks until the write has completed.
     *
     * @param buffers the buffers to write, in order
     * @throws IOException if the write failed with an {@code IOException}
     */
    @Override
    public void write(final List<ByteBuf> buffers) throws IOException {
        FutureAsyncCompletionHandler<Void> future = new FutureAsyncCompletionHandler<Void>();
        writeAsync(buffers, future);
        future.get();
    }

    /**
     * Reads exactly {@code numBytes} bytes, blocking until they are available or the channel reports a failure.
     *
     * @param numBytes the number of bytes to read
     * @return a buffer containing the requested bytes
     * @throws IOException if the read failed with an {@code IOException}
     */
    @Override
    public ByteBuf read(final int numBytes) throws IOException {
        FutureAsyncCompletionHandler<ByteBuf> future = new FutureAsyncCompletionHandler<ByteBuf>();
        readAsync(numBytes, future);
        return future.get();
    }

    /**
     * Combines the buffers into one composite buffer and writes it to the channel.
     *
     * @param buffers the buffers to write, in order; each must be a {@code NettyByteBuf}
     * @param handler completed with {@code null} on success, or failed with the cause reported by the channel
     */
    @Override
    public void writeAsync(final List<ByteBuf> buffers, final AsyncCompletionHandler<Void> handler) {
        CompositeByteBuf composite = PooledByteBufAllocator.DEFAULT.compositeBuffer();
        for (ByteBuf cur : buffers) {
            // 'true' asks the composite to advance its writer index past the added component.
            composite.addComponent(true, ((NettyByteBuf) cur).asByteBuf());
        }

        channel.writeAndFlush(composite).addListener(new ChannelFutureListener() {
            @Override
            public void operationComplete(final ChannelFuture future) throws Exception {
                if (!future.isSuccess()) {
                    handler.failed(future.cause());
                } else {
                    handler.completed(null);
                }
            }
        });
    }

    /**
     * Delivers {@code numBytes} bytes to the handler as soon as that many bytes have been received. If a failure has
     * already been recorded, the handler fails with it immediately. Otherwise, when not enough data is queued, the request
     * is parked and retried each time the channel delivers data or reports a failure.
     *
     * @param numBytes the number of bytes to read
     * @param handler  receives the assembled buffer or the failure
     */
    @Override
    public void readAsync(final int numBytes, final AsyncCompletionHandler<ByteBuf> handler) {
        scheduleReadTimeout();
        ByteBuf buffer = null;
        Throwable exceptionResult = null;
        synchronized (this) {
            exceptionResult = pendingException;
            if (exceptionResult == null) {
                if (!hasBytesAvailable(numBytes)) {
                    pendingReader = new PendingReader(numBytes, handler);
                } else {
                    CompositeByteBuf composite = allocator.compositeBuffer(pendingInboundBuffers.size());
                    int bytesNeeded = numBytes;
                    for (Iterator<io.netty.buffer.ByteBuf> iter = pendingInboundBuffers.iterator(); iter.hasNext();) {
                        io.netty.buffer.ByteBuf next = iter.next();
                        int bytesNeededFromCurrentBuffer = Math.min(next.readableBytes(), bytesNeeded);
                        if (bytesNeededFromCurrentBuffer == next.readableBytes()) {
                            // Fully consumed: the queue's reference moves to the composite.
                            composite.addComponent(next);
                            iter.remove();
                        } else {
                            // Partially consumed: the buffer stays queued, so the slice given to the composite needs
                            // a reference of its own.
                            next.retain();
                            composite.addComponent(next.readSlice(bytesNeededFromCurrentBuffer));
                        }
                        // This addComponent overload is not told to advance the writer index, so do it by hand.
                        composite.writerIndex(composite.writerIndex() + bytesNeededFromCurrentBuffer);
                        bytesNeeded -= bytesNeededFromCurrentBuffer;
                        if (bytesNeeded == 0) {
                            break;
                        }
                    }
                    buffer = new NettyByteBuf(composite).flip();
                }
            }
        }
        // The handler is invoked outside the monitor so that user code never runs while the lock is held.
        if (exceptionResult != null) {
            disableReadTimeout();
            handler.failed(exceptionResult);
        }
        if (buffer != null) {
            disableReadTimeout();
            handler.completed(buffer);
        }
    }

    /**
     * Reports whether the queued inbound buffers hold at least {@code numBytes} readable bytes. Callers must hold this
     * object's monitor.
     */
    private boolean hasBytesAvailable(final int numBytes) {
        int bytesAvailable = 0;
        for (io.netty.buffer.ByteBuf cur : pendingInboundBuffers) {
            bytesAvailable += cur.readableBytes();
            if (bytesAvailable >= numBytes) {
                return true;
            }
        }
        return false;
    }

    /**
     * Records data or a failure reported by the channel and re-runs the parked read request, if there is one.
     *
     * @param buffer newly received data, or {@code null} when a failure is being reported
     * @param t      the failure; only used when {@code buffer} is {@code null}
     */
    private void handleReadResponse(final io.netty.buffer.ByteBuf buffer, final Throwable t) {
        PendingReader localPendingReader = null;
        synchronized (this) {
            if (buffer != null) {
                if (isClosed) {
                    // close() has already drained the queue; retaining this buffer would leak it because nothing
                    // will ever release it. Leave it un-retained and make any late reader fail instead.
                    pendingException = new MongoSocketException("Received data after the stream was closed.", address);
                } else {
                    // Take a reference for the queue; the caller keeps responsibility for its own reference.
                    pendingInboundBuffers.add(buffer.retain());
                }
            } else {
                pendingException = t;
            }
            if (pendingReader != null) {
                localPendingReader = pendingReader;
                pendingReader = null;
            }
        }

        if (localPendingReader != null) {
            readAsync(localPendingReader.numBytes, localPendingReader.handler);
        }
    }

    @Override
    public ServerAddress getAddress() {
        return address;
    }

    /**
     * Marks the stream closed, releases every queued inbound buffer and closes the channel if one was opened.
     */
    @Override
    public void close() {
        Channel channelToClose;
        // Taken under the same monitor as readAsync/handleReadResponse: the queue is not thread-safe, and data may be
        // arriving on the event loop while another thread closes the stream.
        synchronized (this) {
            isClosed = true;
            channelToClose = channel;
            channel = null;
            for (Iterator<io.netty.buffer.ByteBuf> iterator = pendingInboundBuffers.iterator(); iterator.hasNext();) {
                io.netty.buffer.ByteBuf nextByteBuf = iterator.next();
                iterator.remove();
                nextByteBuf.release();
            }
        }
        // Closed outside the monitor so that close listeners (which may invoke user handlers) do not run under the lock.
        if (channelToClose != null) {
            channelToClose.close();
        }
    }

    @Override
    public boolean isClosed() {
        return isClosed;
    }

    public SocketSettings getSettings() {
        return settings;
    }

    public SslSettings getSslSettings() {
        return sslSettings;
    }

    public EventLoopGroup getWorkerGroup() {
        return workerGroup;
    }

    public Class<? extends SocketChannel> getSocketChannelClass() {
        return socketChannelClass;
    }

    public ByteBufAllocator getAllocator() {
        return allocator;
    }

    /**
     * Returns the SSL context from the SSL settings, falling back to the JVM default context when none is configured.
     *
     * @throws MongoClientException if the default context cannot be created
     */
    private SSLContext getSslContext() {
        try {
            return (sslSettings.getContext() == null) ? SSLContext.getDefault() : sslSettings.getContext();
        } catch (NoSuchAlgorithmException e) {
            throw new MongoClientException("Unable to create default SSLContext", e);
        }
    }

    /**
     * Last handler in the pipeline: forwards received buffers and pipeline failures to the enclosing stream.
     */
    private class InboundBufferHandler extends SimpleChannelInboundHandler<io.netty.buffer.ByteBuf> {
        @Override
        protected void channelRead0(final ChannelHandlerContext ctx, final io.netty.buffer.ByteBuf buffer) throws Exception {
            handleReadResponse(buffer, null);
        }

        @Override
        public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable t) {
            if (t instanceof ReadTimeoutException) {
                // Translate Netty's timeout into the driver's own exception type.
                handleReadResponse(null, new MongoSocketReadTimeoutException("Timeout while receiving message", address, t));
            } else {
                handleReadResponse(null, t);
            }
            ctx.close();
        }
    }

    /**
     * A read request that could not be satisfied yet because not enough data had been received.
     */
    private static final class PendingReader {
        private final int numBytes;
        private final AsyncCompletionHandler<ByteBuf> handler;

        private PendingReader(final int numBytes, final AsyncCompletionHandler<ByteBuf> handler) {
            this.numBytes = numBytes;
            this.handler = handler;
        }
    }

    /**
     * Adapts the callback-style {@code AsyncCompletionHandler} to a blocking {@link #get()} call.
     *
     * @param <T> the result type
     */
    private static final class FutureAsyncCompletionHandler<T> implements AsyncCompletionHandler<T> {
        private final CountDownLatch latch = new CountDownLatch(1);
        private volatile T t;
        private volatile Throwable throwable;

        FutureAsyncCompletionHandler() {
        }

        @Override
        public void completed(final T t) {
            this.t = t;
            latch.countDown();
        }

        @Override
        public void failed(final Throwable t) {
            this.throwable = t;
            latch.countDown();
        }

        /**
         * Blocks until the operation has completed or failed.
         *
         * @return the result passed to {@link #completed(Object)}
         * @throws IOException                if the operation failed with an {@code IOException}
         * @throws MongoException             if the operation failed with a {@code MongoException}
         * @throws MongoInternalException     if the operation failed with any other throwable
         * @throws MongoInterruptedException  if the waiting thread was interrupted; the thread's interrupt status is
         *                                    restored before this is thrown
         */
        public T get() throws IOException {
            try {
                latch.await();
            } catch (InterruptedException e) {
                // await() cleared the interrupt status; restore it so code further up the stack can still observe it.
                Thread.currentThread().interrupt();
                throw new MongoInterruptedException("Interrupted", e);
            }
            if (throwable != null) {
                if (throwable instanceof IOException) {
                    throw (IOException) throwable;
                } else if (throwable instanceof MongoException) {
                    throw (MongoException) throwable;
                } else {
                    throw new MongoInternalException("Exception thrown from Netty Stream", throwable);
                }
            }
            return t;
        }
    }

    /**
     * Handles the outcome of one connection attempt: publishes the channel on success, otherwise moves on to the next
     * address or reports the failure when none is left.
     */
    private class OpenChannelFutureListener implements ChannelFutureListener {
        private final Queue<SocketAddress> socketAddressQueue;
        private final ChannelFuture channelFuture;
        private final AsyncCompletionHandler<Void> handler;

        OpenChannelFutureListener(final Queue<SocketAddress> socketAddressQueue, final ChannelFuture channelFuture,
                                  final AsyncCompletionHandler<Void> handler) {
            this.socketAddressQueue = socketAddressQueue;
            this.channelFuture = channelFuture;
            this.handler = handler;
        }

        @Override
        public void operationComplete(final ChannelFuture future) {
            if (future.isSuccess()) {
                Channel connectedChannel = channelFuture.channel();
                boolean closedWhileConnecting;
                // Same monitor as close(): either close() sees the published channel and closes it, or this listener
                // sees isClosed and closes the channel itself. Without this, a close() racing with the connect would
                // leave the new channel open forever.
                synchronized (NettyStream.this) {
                    closedWhileConnecting = isClosed;
                    if (!closedWhileConnecting) {
                        channel = connectedChannel;
                    }
                }
                if (closedWhileConnecting) {
                    connectedChannel.close();
                } else {
                    connectedChannel.closeFuture().addListener(new ChannelFutureListener() {
                        @Override
                        public void operationComplete(final ChannelFuture future) {
                            handleReadResponse(null, new IOException("The connection to the server was closed"));
                        }
                    });
                }
                handler.completed(null);
            } else {
                if (socketAddressQueue.isEmpty()) {
                    handler.failed(new MongoSocketOpenException("Exception opening socket", getAddress(), future.cause()));
                } else {
                    initializeChannel(handler, socketAddressQueue);
                }
            }
        }
    }

    /** Arms the read timeout, if a read-timeout handler is installed. */
    private void scheduleReadTimeout() {
        adjustTimeout(false);
    }

    /** Disarms the read timeout, if a read-timeout handler is installed. */
    private void disableReadTimeout() {
        adjustTimeout(true);
    }

    /**
     * Arms or disarms the read-timeout handler on the channel's event loop. Does nothing when the stream is closed or
     * when no read-timeout handler is installed in the pipeline.
     *
     * @param disable {@code true} to remove the timeout, {@code false} to schedule it
     */
    private void adjustTimeout(final boolean disable) {
        // Snapshot the channel BEFORE checking isClosed: close() sets isClosed before it clears the channel, so a
        // snapshot that is null because of close() is always caught by the check below instead of causing an NPE.
        final Channel localChannel = channel;
        if (isClosed) {
            return;
        }
        ChannelHandler timeoutHandler = localChannel.pipeline().get(READ_HANDLER_NAME);
        if (timeoutHandler != null) {
            final ReadTimeoutHandler readTimeoutHandler = (ReadTimeoutHandler) timeoutHandler;
            final ChannelHandlerContext handlerContext = localChannel.pipeline().context(timeoutHandler);
            EventExecutor executor = handlerContext.executor();

            // The timeout handler is only touched from its own event loop; hop onto it when called from elsewhere.
            if (disable) {
                if (executor.inEventLoop()) {
                    readTimeoutHandler.removeTimeout(handlerContext);
                } else {
                    executor.submit(new Runnable() {
                        @Override
                        public void run() {
                            readTimeoutHandler.removeTimeout(handlerContext);
                        }
                    });
                }
            } else {
                if (executor.inEventLoop()) {
                    readTimeoutHandler.scheduleTimeout(handlerContext);
                } else {
                    executor.submit(new Runnable() {
                        @Override
                        public void run() {
                            readTimeoutHandler.scheduleTimeout(handlerContext);
                        }
                    });
                }
            }
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy