All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.mongodb.internal.connection.tlschannel.async.AsynchronousTlsChannelGroup Maven / Gradle / Ivy

/*
 * Copyright 2008-present MongoDB, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * Original Work: MIT License, Copyright (c) [2015-2018] all contributors
 * https://github.com/marianobarrios/tls-channel
 */

package com.mongodb.internal.connection.tlschannel.async;

import com.mongodb.diagnostics.logging.Logger;
import com.mongodb.diagnostics.logging.Loggers;
import com.mongodb.internal.connection.tlschannel.NeedsReadException;
import com.mongodb.internal.connection.tlschannel.NeedsTaskException;
import com.mongodb.internal.connection.tlschannel.NeedsWriteException;
import com.mongodb.internal.connection.tlschannel.TlsChannel;
import com.mongodb.internal.connection.tlschannel.impl.ByteBufferSet;
import com.mongodb.internal.connection.tlschannel.util.Util;

import java.io.IOException;
import java.nio.channels.CancelledKeyException;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.InterruptedByTimeoutException;
import java.nio.channels.ReadPendingException;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ShutdownChannelGroupException;
import java.nio.channels.SocketChannel;
import java.nio.channels.WritePendingException;
import java.util.Iterator;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.LongAdder;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
import java.util.function.IntBinaryOperator;
import java.util.function.LongConsumer;

import static java.lang.String.format;

/**
 * This class encapsulates the infrastructure for running {@link AsynchronousTlsChannel}s. Each instance of this class
 * is a singleton-like object that manages a thread pool that makes it possible to run a group of asynchronous
 * channels.
 */
public class AsynchronousTlsChannelGroup {

    private static final Logger LOGGER = Loggers.getLogger("connection.tls");

    /**
     * The main executor of the group has a bounded queue, whose size is a multiple of the number of handler threads.
     */
    private static final int QUEUE_LENGTH_MULTIPLIER = 32;

    /**
     * Source of the per-group {@link #id}, used to give the threads of each group distinguishable names.
     */
    private static final AtomicInteger GLOBAL_GROUP_COUNT = new AtomicInteger();

    /**
     * Merges newly requested selector operations into an existing interest set.
     */
    private static final IntBinaryOperator BITWISE_OR = new IntBinaryOperator() {
        @Override
        public int applyAsInt(final int left, final int right) {
            return left | right;
        }
    };

    /**
     * A socket registered in this group, together with the state of its (at most one) in-flight read and (at most
     * one) in-flight write operation.
     */
    class RegisteredSocket {

        final TlsChannel tlsChannel;
        final SocketChannel socketChannel;

        /**
         * Used to wait until the channel is effectively in the selector (which happens asynchronously to the initial
         * registration).
         */
        final CountDownLatch registered = new CountDownLatch(1);

        /**
         * Selection key of this socket; assigned by the selector thread before {@link #registered} is counted down.
         */
        SelectionKey key;

        /**
         * Protects {@link #readOperation} reference and instance.
         */
        final Lock readLock = new ReentrantLock();

        /**
         * Protects {@link #writeOperation} reference and instance.
         */
        final Lock writeLock = new ReentrantLock();

        /**
         * Current read operation, if not null.
         */
        ReadOperation readOperation;

        /**
         * Current write operation, if not null.
         */
        WriteOperation writeOperation;

        /**
         * Bitwise union of pending operations to be registered in the selector.
         */
        final AtomicInteger pendingOps = new AtomicInteger();

        RegisteredSocket(final TlsChannel tlsChannel, final SocketChannel socketChannel) {
            this.tlsChannel = tlsChannel;
            this.socketChannel = socketChannel;
        }

        /**
         * Cancels any pending read and write operation (without invoking their callbacks) and cancels this socket's
         * selection key. This method does not itself close {@link #tlsChannel} or {@link #socketChannel}.
         */
        public void close() {
            doCancelRead(this, null);
            doCancelWrite(this, null);
            key.cancel();
            currentRegistrations.getAndDecrement();
            // Actual de-registration from the selector will happen asynchronously, on the next selection.
            selector.wakeup();
        }
    }

    /**
     * State shared by read and write operations: the buffers involved, completion callbacks and the optional
     * timeout task.
     */
    private abstract static class Operation {
        final ByteBufferSet bufferSet;
        final LongConsumer onSuccess;
        final Consumer<Throwable> onFailure;
        Future<?> timeoutFuture;

        Operation(final ByteBufferSet bufferSet, final LongConsumer onSuccess, final Consumer<Throwable> onFailure) {
            this.bufferSet = bufferSet;
            this.onSuccess = onSuccess;
            this.onFailure = onFailure;
        }
    }

    static final class ReadOperation extends Operation {
        ReadOperation(final ByteBufferSet bufferSet, final LongConsumer onSuccess, final Consumer<Throwable> onFailure) {
            super(bufferSet, onSuccess, onFailure);
        }
    }

    static final class WriteOperation extends Operation {

        /**
         * Because a write operation can flag a block (needs read/write) even after the source buffer was read from, we
         * need to accumulate consumed bytes.
         */
        long consumesBytes = 0;

        WriteOperation(final ByteBufferSet bufferSet, final LongConsumer onSuccess, final Consumer<Throwable> onFailure) {
            super(bufferSet, onSuccess, onFailure);
        }
    }

    // must be initialized before selectorThread, whose name is computed from it
    private final int id = GLOBAL_GROUP_COUNT.getAndIncrement();

    /**
     * With the intention of being sparing with warnings, use this flag to ensure that we only log the warning about
     * needed task once.
     */
    private final AtomicBoolean loggedTaskWarning = new AtomicBoolean();

    private final Selector selector;

    final ExecutorService executor;

    private final ScheduledThreadPoolExecutor timeoutExecutor = new ScheduledThreadPoolExecutor(1, new ThreadFactory() {
        @Override
        public Thread newThread(final Runnable runnable) {
            return new Thread(runnable, format("async-channel-group-%d-timeout-thread", id));
        }
    }
    );

    private final Thread selectorThread = new Thread(new Runnable() {
        @Override
        public void run() {
            AsynchronousTlsChannelGroup.this.loop();
        }
    }, format("async-channel-group-%d-selector", id));

    /**
     * Sockets handed over by {@link #registerSocket} and not yet registered in the selector by the selector thread.
     */
    private final ConcurrentLinkedQueue<RegisteredSocket> pendingRegistrations = new ConcurrentLinkedQueue<>();

    private enum Shutdown {
        No, Wait, Immediate
    }

    private volatile Shutdown shutdown = Shutdown.No;

    private final LongAdder selectionCount = new LongAdder();

    private final LongAdder startedReads = new LongAdder();
    private final LongAdder startedWrites = new LongAdder();
    private final LongAdder successfulReads = new LongAdder();
    private final LongAdder successfulWrites = new LongAdder();
    private final LongAdder failedReads = new LongAdder();
    private final LongAdder failedWrites = new LongAdder();
    private final LongAdder cancelledReads = new LongAdder();
    private final LongAdder cancelledWrites = new LongAdder();

    // used for synchronization: a graceful shutdown waits until this drops to zero
    private final AtomicInteger currentRegistrations = new AtomicInteger();

    private final LongAdder currentReads = new LongAdder();
    private final LongAdder currentWrites = new LongAdder();

    /**
     * Creates an instance of this class.
     *
     * @param nThreads number of threads in the executor used to assist the selector loop and run completion handlers.
     */
    public AsynchronousTlsChannelGroup(final int nThreads) {
        try {
            selector = Selector.open();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        // drop cancelled timeouts from the queue right away instead of keeping them until their delay elapses
        timeoutExecutor.setRemoveOnCancelPolicy(true);
        this.executor = new ThreadPoolExecutor(
                nThreads, nThreads,
                0, TimeUnit.MILLISECONDS,
                new LinkedBlockingQueue<Runnable>(nThreads * QUEUE_LENGTH_MULTIPLIER),
                new ThreadFactory() {
                    @Override
                    public Thread newThread(final Runnable runnable) {
                        return new Thread(runnable, format("async-channel-group-%d-handler-executor", id));
                    }
                },
                // when the queue is full, the submitting thread (usually the selector) runs the task itself
                new ThreadPoolExecutor.CallerRunsPolicy());
        selectorThread.start();
    }

    /**
     * Creates an instance of this class, using as many threads as available processors.
     */
    public AsynchronousTlsChannelGroup() {
        this(Runtime.getRuntime().availableProcessors());
    }

    /**
     * Queues a socket for registration in the selector. Registration itself is performed asynchronously by the
     * selector thread; operations wait for it before starting.
     *
     * @throws ShutdownChannelGroupException if shutdown of this group has been initiated
     */
    RegisteredSocket registerSocket(final TlsChannel reader, final SocketChannel socketChannel) {
        if (shutdown != Shutdown.No) {
            throw new ShutdownChannelGroupException();
        }
        RegisteredSocket socket = new RegisteredSocket(reader, socketChannel);
        currentRegistrations.getAndIncrement();
        pendingRegistrations.add(socket);
        selector.wakeup();
        return socket;
    }

    /**
     * Cancels the socket's current read operation, without invoking its callbacks.
     *
     * @param op the operation to cancel, or {@code null} to cancel whatever read operation is current
     * @return {@code true} if an operation was cancelled by this call
     */
    boolean doCancelRead(final RegisteredSocket socket, final ReadOperation op) {
        socket.readLock.lock();
        try {
            ReadOperation current = socket.readOperation;
            // a null op means cancel any operation
            boolean matches = op == null ? current != null : current == op;
            if (!matches) {
                return false;
            }
            socket.readOperation = null;
            // release the pending timeout task (and everything it references) instead of letting it fire uselessly
            cancelTimeout(current);
            cancelledReads.increment();
            currentReads.decrement();
            return true;
        } finally {
            socket.readLock.unlock();
        }
    }

    /**
     * Cancels the socket's current write operation, without invoking its callbacks.
     *
     * @param op the operation to cancel, or {@code null} to cancel whatever write operation is current
     * @return {@code true} if an operation was cancelled by this call
     */
    boolean doCancelWrite(final RegisteredSocket socket, final WriteOperation op) {
        socket.writeLock.lock();
        try {
            WriteOperation current = socket.writeOperation;
            // a null op means cancel any operation
            boolean matches = op == null ? current != null : current == op;
            if (!matches) {
                return false;
            }
            socket.writeOperation = null;
            // release the pending timeout task (and everything it references) instead of letting it fire uselessly
            cancelTimeout(current);
            cancelledWrites.increment();
            currentWrites.decrement();
            return true;
        } finally {
            socket.writeLock.unlock();
        }
    }

    /**
     * Starts an asynchronous read. Exactly one of {@code onSuccess} / {@code onFailure} is called on completion,
     * unless the operation is cancelled.
     *
     * @param timeout timeout of the operation, {@code 0} meaning no timeout
     * @throws ReadPendingException if another read operation is already in progress on the socket
     * @throws ShutdownChannelGroupException if this group has terminated
     */
    ReadOperation startRead(
            final RegisteredSocket socket,
            final ByteBufferSet buffer,
            final long timeout, final TimeUnit unit,
            final LongConsumer onSuccess, final Consumer<Throwable> onFailure)
            throws ReadPendingException {
        checkTerminated();
        Util.assertTrue(buffer.hasRemaining());
        waitForSocketRegistration(socket);
        final ReadOperation op;
        socket.readLock.lock();
        try {
            if (socket.readOperation != null) {
                throw new ReadPendingException();
            }
            op = new ReadOperation(buffer, onSuccess, onFailure);
            /*
             * we do not try to outsmart the TLS state machine and register for both IO operations for each new socket
             * operation
             */
            socket.pendingOps.set(SelectionKey.OP_WRITE | SelectionKey.OP_READ);
            if (timeout != 0) {
                op.timeoutFuture = timeoutExecutor.schedule(new Runnable() {
                    @Override
                    public void run() {
                        boolean success = AsynchronousTlsChannelGroup.this.doCancelRead(socket, op);
                        if (success) {
                            op.onFailure.accept(new InterruptedByTimeoutException());
                        }
                    }
                }, timeout, unit);
            }
            socket.readOperation = op;
            // counted under the lock: completion, failure and cancellation all need it, so they can never decrement
            // currentReads before it has been incremented here
            startedReads.increment();
            currentReads.increment();
        } finally {
            socket.readLock.unlock();
        }
        selector.wakeup();
        return op;
    }

    /**
     * Starts an asynchronous write. Exactly one of {@code onSuccess} / {@code onFailure} is called on completion,
     * unless the operation is cancelled.
     *
     * @param timeout timeout of the operation, {@code 0} meaning no timeout
     * @throws WritePendingException if another write operation is already in progress on the socket
     * @throws ShutdownChannelGroupException if this group has terminated
     */
    WriteOperation startWrite(
            final RegisteredSocket socket,
            final ByteBufferSet buffer,
            final long timeout, final TimeUnit unit,
            final LongConsumer onSuccess, final Consumer<Throwable> onFailure)
            throws WritePendingException {
        checkTerminated();
        Util.assertTrue(buffer.hasRemaining());
        waitForSocketRegistration(socket);
        final WriteOperation op;
        socket.writeLock.lock();
        try {
            if (socket.writeOperation != null) {
                throw new WritePendingException();
            }
            op = new WriteOperation(buffer, onSuccess, onFailure);
            /*
             * we do not try to outsmart the TLS state machine and register for both IO operations for each new socket
             * operation
             */
            socket.pendingOps.set(SelectionKey.OP_WRITE | SelectionKey.OP_READ);
            if (timeout != 0) {
                op.timeoutFuture = timeoutExecutor.schedule(new Runnable() {
                    @Override
                    public void run() {
                        boolean success = AsynchronousTlsChannelGroup.this.doCancelWrite(socket, op);
                        if (success) {
                            op.onFailure.accept(new InterruptedByTimeoutException());
                        }
                    }
                }, timeout, unit);
            }
            socket.writeOperation = op;
            // counted under the lock: completion, failure and cancellation all need it, so they can never decrement
            // currentWrites before it has been incremented here
            startedWrites.increment();
            currentWrites.increment();
        } finally {
            socket.writeLock.unlock();
        }
        selector.wakeup();
        return op;
    }

    private void checkTerminated() {
        if (isTerminated()) {
            throw new ShutdownChannelGroupException();
        }
    }

    /**
     * Blocks until the selector thread has registered the socket.
     *
     * @throws RuntimeException wrapping the {@link InterruptedException} if interrupted; the thread's interrupt
     *                          status is restored
     */
    private void waitForSocketRegistration(final RegisteredSocket socket) {
        try {
            socket.registered.await();
        } catch (InterruptedException e) {
            // keep the interruption visible to callers further up the stack
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /**
     * Body of the selector thread: dispatches ready sockets to the executor, registers new sockets and applies
     * requested interest operations, until shutdown. On exit it shuts down both executors and closes the selector.
     */
    private void loop() {
        try {
            while (shutdown == Shutdown.No || shutdown == Shutdown.Wait && currentRegistrations.intValue() > 0) {
                int c = selector.select(); // block
                selectionCount.increment();
                // avoid unnecessary creation of iterator object
                if (c > 0) {
                    Iterator<SelectionKey> it = selector.selectedKeys().iterator();
                    while (it.hasNext()) {
                        SelectionKey key = it.next();
                        it.remove();
                        try {
                            // interest is re-requested explicitly by the operations that still need it
                            key.interestOps(0);
                        } catch (CancelledKeyException e) {
                            // can happen when channels are closed with pending operations
                            continue;
                        }
                        RegisteredSocket socket = (RegisteredSocket) key.attachment();
                        processRead(socket);
                        processWrite(socket);
                    }
                }
                registerPendingSockets();
                processPendingInterests();
            }
        } catch (Throwable e) {
            LOGGER.error("error in selector loop", e);
        } finally {
            executor.shutdown();
            // use shutdownNow to stop delayed tasks
            timeoutExecutor.shutdownNow();
            if (shutdown == Shutdown.Immediate) {
                for (SelectionKey key : selector.keys()) {
                    RegisteredSocket socket = (RegisteredSocket) key.attachment();
                    socket.close();
                }
            }
            try {
                selector.close();
            } catch (IOException e) {
                LOGGER.warn(format("error closing selector: %s", e.getMessage()), e);
            }
        }
    }

    /**
     * Adds the operations accumulated in each socket's {@code pendingOps} to its key's interest set.
     */
    private void processPendingInterests() {
        for (SelectionKey key : selector.keys()) {
            RegisteredSocket socket = (RegisteredSocket) key.attachment();
            int pending = socket.pendingOps.getAndSet(0);
            if (pending != 0) {
                try {
                    key.interestOps(key.interestOps() | pending);
                } catch (CancelledKeyException e) {
                    // can happen when channels are closed with pending operations; skip only this key, so the
                    // remaining sockets still get their interest registered before the selector blocks again
                    continue;
                }
            }
        }
    }

    /**
     * Submits the socket's current write operation, if any, to the executor.
     */
    private void processWrite(final RegisteredSocket socket) {
        socket.writeLock.lock();
        try {
            final WriteOperation op = socket.writeOperation;
            if (op != null) {
                executor.execute(new Runnable() {
                    @Override
                    public void run() {
                        try {
                            AsynchronousTlsChannelGroup.this.doWrite(socket, op);
                        } catch (Throwable e) {
                            LOGGER.error("error in operation", e);
                        }
                    }
                });
            }
        } finally {
            socket.writeLock.unlock();
        }
    }

    /**
     * Submits the socket's current read operation, if any, to the executor.
     */
    private void processRead(final RegisteredSocket socket) {
        socket.readLock.lock();
        try {
            final ReadOperation op = socket.readOperation;
            if (op != null) {
                executor.execute(new Runnable() {
                    @Override
                    public void run() {
                        try {
                            AsynchronousTlsChannelGroup.this.doRead(socket, op);
                        } catch (Throwable e) {
                            LOGGER.error("error in operation", e);
                        }
                    }
                });
            }
        } finally {
            socket.readLock.unlock();
        }
    }

    /**
     * Attempts to complete a write operation. If the TLS channel would block, the needed selector operation is
     * requested and the operation stays current; otherwise it completes with success or failure.
     */
    private void doWrite(final RegisteredSocket socket, final WriteOperation op) {
        socket.writeLock.lock();
        try {
            if (socket.writeOperation != op) {
                // no longer the current operation (e.g. cancelled or timed out)
                return;
            }
            try {
                long before = op.bufferSet.remaining();
                try {
                    writeHandlingTasks(socket, op);
                } finally {
                    // account for bytes consumed even when the write ends by signaling a block
                    long c = before - op.bufferSet.remaining();
                    Util.assertTrue(c >= 0);
                    op.consumesBytes += c;
                }
                socket.writeOperation = null;
                cancelTimeout(op);
                op.onSuccess.accept(op.consumesBytes);
                successfulWrites.increment();
                currentWrites.decrement();
            } catch (NeedsReadException e) {
                addPendingOps(socket, SelectionKey.OP_READ);
            } catch (NeedsWriteException e) {
                addPendingOps(socket, SelectionKey.OP_WRITE);
            } catch (IOException e) {
                if (socket.writeOperation == op) {
                    socket.writeOperation = null;
                }
                cancelTimeout(op);
                op.onFailure.accept(e);
                failedWrites.increment();
                currentWrites.decrement();
            }
        } finally {
            socket.writeLock.unlock();
        }
    }

    /**
     * Intended use of the channel group is with sockets that run tasks internally, but out of tolerance, run tasks in
     * thread in case the socket does not.
     */
    private void writeHandlingTasks(final RegisteredSocket socket, final WriteOperation op) throws IOException {
        while (true) {
            try {
                socket.tlsChannel.write(op.bufferSet.array, op.bufferSet.offset, op.bufferSet.length);
                return;
            } catch (NeedsTaskException e) {
                warnAboutNeedTask();
                e.getTask().run();
            }
        }
    }

    private void warnAboutNeedTask() {
        if (!loggedTaskWarning.getAndSet(true)) {
            LOGGER.warn(format(
                    "caught %s; channels used in asynchronous groups should run tasks themselves; "
                            + "although task is being dealt with anyway, consider configuring channels properly",
                    NeedsTaskException.class.getName()));
        }
    }

    /**
     * Attempts to complete a read operation. If the TLS channel would block, the needed selector operation is
     * requested and the operation stays current; otherwise it completes with success or failure.
     */
    private void doRead(final RegisteredSocket socket, final ReadOperation op) {
        socket.readLock.lock();
        try {
            if (socket.readOperation != op) {
                // no longer the current operation (e.g. cancelled or timed out)
                return;
            }
            try {
                Util.assertTrue(op.bufferSet.hasRemaining());
                long c = readHandlingTasks(socket, op);
                Util.assertTrue(c > 0 || c == -1);
                socket.readOperation = null;
                cancelTimeout(op);
                op.onSuccess.accept(c);
                successfulReads.increment();
                currentReads.decrement();
            } catch (NeedsReadException e) {
                addPendingOps(socket, SelectionKey.OP_READ);
            } catch (NeedsWriteException e) {
                addPendingOps(socket, SelectionKey.OP_WRITE);
            } catch (IOException e) {
                if (socket.readOperation == op) {
                    socket.readOperation = null;
                }
                cancelTimeout(op);
                op.onFailure.accept(e);
                failedReads.increment();
                currentReads.decrement();
            }
        } finally {
            socket.readLock.unlock();
        }
    }

    /**
     * @see #writeHandlingTasks
     */
    private long readHandlingTasks(final RegisteredSocket socket, final ReadOperation op) throws IOException {
        while (true) {
            try {
                return socket.tlsChannel.read(op.bufferSet.array, op.bufferSet.offset, op.bufferSet.length);
            } catch (NeedsTaskException e) {
                warnAboutNeedTask();
                e.getTask().run();
            }
        }
    }

    /**
     * Records that the socket needs the given selector operations and wakes the selector so that they are added to
     * the key's interest set.
     */
    private void addPendingOps(final RegisteredSocket socket, final int ops) {
        socket.pendingOps.accumulateAndGet(ops, BITWISE_OR);
        selector.wakeup();
    }

    /**
     * Cancels the operation's timeout task, if one was scheduled.
     */
    private static void cancelTimeout(final Operation op) {
        if (op.timeoutFuture != null) {
            op.timeoutFuture.cancel(false);
        }
    }

    /**
     * Registers every queued socket in the selector (with an empty interest set) and releases threads waiting for
     * that registration.
     */
    private void registerPendingSockets() throws ClosedChannelException {
        RegisteredSocket socket;
        while ((socket = pendingRegistrations.poll()) != null) {
            socket.key = socket.socketChannel.register(selector, 0, socket);
            if (LOGGER.isTraceEnabled()) {
                LOGGER.trace(format("registered key: %s", socket.key));
            }
            socket.registered.countDown();
        }
    }

    /**
     * Whether either {@link #shutdown()} or {@link #shutdownNow()} have been called.
     *
     * @return {@code true} if this group has initiated shutdown and {@code false} if the group is active
     */
    public boolean isShutdown() {
        return shutdown != Shutdown.No;
    }

    /**
     * Starts the shutdown process. New sockets cannot be registered; already registered ones continue operating
     * normally until they are closed.
     */
    public void shutdown() {
        shutdown = Shutdown.Wait;
        selector.wakeup();
    }

    /**
     * Shuts down this channel group immediately. All registered sockets are closed, pending operations may or may not
     * finish.
     */
    public void shutdownNow() {
        shutdown = Shutdown.Immediate;
        selector.wakeup();
    }

    /**
     * Whether this channel group was shut down, and all pending tasks have drained.
     *
     * @return {@code true} if the group's handler executor has terminated
     */
    public boolean isTerminated() {
        return executor.isTerminated();
    }

    /**
     * Blocks until all registered sockets are closed and pending tasks finished execution after a shutdown request, or
     * the timeout occurs, or the current thread is interrupted, whichever happens first.
     *
     * @param timeout the maximum time to wait
     * @param unit    the time unit of the timeout argument
     * @return {@code true} if this group terminated and {@code false} if the timeout elapsed before termination
     * @throws InterruptedException if interrupted while waiting
     */
    public boolean awaitTermination(final long timeout, final TimeUnit unit) throws InterruptedException {
        return executor.awaitTermination(timeout, unit);
    }

    long getSelectionCount() {
        return selectionCount.longValue();
    }

    /**
     * Return the total number of read operations that were started.
     *
     * @return number of operations
     */
    public long getStartedReadCount() {
        return startedReads.longValue();
    }

    /**
     * Return the total number of write operations that were started.
     *
     * @return number of operations
     */
    public long getStartedWriteCount() {
        return startedWrites.longValue();
    }

    /**
     * Return the total number of read operations that succeeded.
     *
     * @return number of operations
     */
    public long getSuccessfulReadCount() {
        return successfulReads.longValue();
    }

    /**
     * Return the total number of write operations that succeeded.
     *
     * @return number of operations
     */
    public long getSuccessfulWriteCount() {
        return successfulWrites.longValue();
    }

    /**
     * Return the total number of read operations that failed.
     *
     * @return number of operations
     */
    public long getFailedReadCount() {
        return failedReads.longValue();
    }

    /**
     * Return the total number of write operations that failed.
     *
     * @return number of operations
     */
    public long getFailedWriteCount() {
        return failedWrites.longValue();
    }

    /**
     * Return the total number of read operations that were cancelled.
     *
     * @return number of operations
     */
    public long getCancelledReadCount() {
        return cancelledReads.longValue();
    }

    /**
     * Return the total number of write operations that were cancelled.
     *
     * @return number of operations
     */
    public long getCancelledWriteCount() {
        return cancelledWrites.longValue();
    }

    /**
     * Returns the current number of active read operations.
     *
     * @return number of operations
     */
    public long getCurrentReadCount() {
        return currentReads.longValue();
    }

    /**
     * Returns the current number of active write operations.
     *
     * @return number of operations
     */
    public long getCurrentWriteCount() {
        return currentWrites.longValue();
    }

    /**
     * Returns the current number of registered sockets.
     *
     * @return number of sockets
     */
    public long getCurrentRegistrationCount() {
        return currentRegistrations.longValue();
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy