All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.r2dbc.postgresql.client.ReactorNettyClient Maven / Gradle / Ivy

There is a newer version: 0.8.13.RELEASE
Show newest version
/*
 * Copyright 2017-2019 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.r2dbc.postgresql.client;

import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.ServerChannel;
import io.netty.channel.epoll.Epoll;
import io.netty.channel.socket.DatagramChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import io.r2dbc.postgresql.message.backend.BackendKeyData;
import io.r2dbc.postgresql.message.backend.BackendMessage;
import io.r2dbc.postgresql.message.backend.BackendMessageDecoder;
import io.r2dbc.postgresql.message.backend.ErrorResponse;
import io.r2dbc.postgresql.message.backend.Field;
import io.r2dbc.postgresql.message.backend.NoticeResponse;
import io.r2dbc.postgresql.message.backend.NotificationResponse;
import io.r2dbc.postgresql.message.backend.ParameterStatus;
import io.r2dbc.postgresql.message.backend.ReadyForQuery;
import io.r2dbc.postgresql.message.frontend.FrontendMessage;
import io.r2dbc.postgresql.message.frontend.Terminate;
import io.r2dbc.postgresql.util.Assert;
import io.r2dbc.spi.R2dbcNonTransientResourceException;
import org.reactivestreams.Publisher;
import reactor.core.Disposable;
import reactor.core.publisher.DirectProcessor;
import reactor.core.publisher.EmitterProcessor;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoSink;
import reactor.core.publisher.SynchronousSink;
import reactor.netty.Connection;
import reactor.netty.resources.ConnectionProvider;
import reactor.netty.resources.LoopResources;
import reactor.netty.tcp.TcpClient;
import reactor.netty.tcp.TcpResources;
import reactor.util.Logger;
import reactor.util.Loggers;
import reactor.util.annotation.Nullable;
import reactor.util.concurrent.Queues;

import javax.net.ssl.SSLException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.time.Duration;
import java.util.List;
import java.util.Optional;
import java.util.Queue;
import java.util.StringJoiner;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

import static io.r2dbc.postgresql.client.TransactionStatus.IDLE;

/**
 * An implementation of client based on the Reactor Netty project.
 *
 * @see TcpClient
 */
public final class ReactorNettyClient implements Client {

    private static final Logger logger = Loggers.getLogger(ReactorNettyClient.class);

    private static final boolean DEBUG_ENABLED = logger.isDebugEnabled();

    private static final Supplier UNEXPECTED = () -> new PostgresConnectionClosedException("Connection unexpectedly closed");

    private static final Supplier EXPECTED = () -> new PostgresConnectionClosedException("Connection closed");

    private final ByteBufAllocator byteBufAllocator;

    private final Connection connection;

    private final EmitterProcessor requestProcessor = EmitterProcessor.create(false);

    private final FluxSink requests = this.requestProcessor.sink();

    private final Queue>> responseReceivers = Queues.>>unbounded().get();

    private final DirectProcessor notificationProcessor = DirectProcessor.create();

    private final AtomicBoolean isClosed = new AtomicBoolean(false);

    private volatile Integer processId;

    private volatile Integer secretKey;

    private volatile TransactionStatus transactionStatus = IDLE;

    private volatile Version version = new Version("", 0);

    /**
     * Creates a new frame processor connected to a given TCP connection.
     *
     * @param connection the TCP connection
     * @throws IllegalArgumentException if {@code connection} is {@code null}
     */
    private ReactorNettyClient(Connection connection) {
        Assert.requireNonNull(connection, "Connection must not be null");

        connection.addHandler(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE - 5, 1, 4, -4, 0));
        connection.addHandler(new EnsureSubscribersCompleteChannelHandler(this.requestProcessor));
        this.connection = connection;
        this.byteBufAllocator = connection.outbound().alloc();

        AtomicReference receiveError = new AtomicReference<>();
        Mono receive = connection.inbound().receive()
            .map(BackendMessageDecoder::decode)
            .handle(this::handleResponse)
            .doOnError(throwable -> {
                receiveError.set(throwable);
                handleConnectionError(throwable);
            })
            .windowWhile(it -> it.getClass() != ReadyForQuery.class)
            .doOnNext(fluxOfMessages -> {
                MonoSink> receiver = this.responseReceivers.poll();
                if (receiver != null) {
                    receiver.success(fluxOfMessages.doOnComplete(() -> {

                        Throwable throwable = receiveError.get();
                        if (throwable != null) {
                            throw new PostgresConnectionException(throwable);
                        }

                        if (!isConnected()) {
                            throw EXPECTED.get();
                        }
                    }));
                }
            })
            .doOnComplete(this::handleClose)
            .then();

        Mono request = this.requestProcessor
            .flatMap(message -> {
                if (DEBUG_ENABLED) {
                    logger.debug("Request:  {}", message);
                }
                return connection.outbound().send(message.encode(this.byteBufAllocator));
            }, 1)
            .then();

        receive
            .onErrorResume(this::resumeError)
            .subscribe();

        request
            .onErrorResume(this::resumeError)
            .doAfterTerminate(this::handleClose)
            .subscribe();
    }

    private Mono resumeError(Throwable throwable) {

        handleConnectionError(throwable);
        this.requestProcessor.onComplete();

        if (isSslException(throwable)) {
            logger.debug("Connection Error", throwable);
        } else {
            logger.error("Connection Error", throwable);
        }

        return close();
    }

    private static boolean isSslException(Throwable throwable) {
        return throwable instanceof SSLException || throwable.getCause() instanceof SSLException;
    }

    private void handleResponse(BackendMessage message, SynchronousSink sink) {

        if (DEBUG_ENABLED) {
            logger.debug("Response: {}", message);
        }

        if (message.getClass() == NoticeResponse.class) {
            logger.warn("Notice: {}", toString(((NoticeResponse) message).getFields()));
            return;
        }

        if (message.getClass() == BackendKeyData.class) {

            BackendKeyData backendKeyData = (BackendKeyData) message;

            this.processId = backendKeyData.getProcessId();
            this.secretKey = backendKeyData.getSecretKey();
            return;
        }

        if (message.getClass() == ErrorResponse.class) {
            logger.warn("Error: {}", toString(((ErrorResponse) message).getFields()));
        }

        if (message.getClass() == ParameterStatus.class) {
            handleParameterStatus((ParameterStatus) message);
        }

        if (message.getClass() == ReadyForQuery.class) {
            this.transactionStatus = TransactionStatus.valueOf(((ReadyForQuery) message).getTransactionStatus());
        }

        if (message.getClass() == NotificationResponse.class) {
            this.notificationProcessor.onNext((NotificationResponse) message);
            return;
        }

        sink.next(message);
    }

    private void handleParameterStatus(ParameterStatus message) {

        Version existingVersion = this.version;

        String versionString = existingVersion.getVersion();
        int versionNum = existingVersion.getVersionNumber();

        if (message.getName().equals("server_version_num")) {
            versionNum = Integer.parseInt(message.getValue());
        }

        if (message.getName().equals("server_version")) {
            versionString = message.getValue();

            if (versionNum == 0) {
                versionNum = Version.parseServerVersionStr(versionString);
            }
        }

        this.version = new Version(versionString, versionNum);
    }

    /**
     * Creates a new frame processor connected to a given host.
     *
     * @param host the host to connect to
     * @param port the port to connect to
     * @throws IllegalArgumentException if {@code host} is {@code null}
     */
    public static Mono connect(String host, int port) {
        Assert.requireNonNull(host, "host must not be null");

        return connect(host, port, null, new SSLConfig(SSLMode.DISABLE, null, null));
    }

    /**
     * Creates a new frame processor connected to a given host.
     *
     * @param host           the host to connect to
     * @param port           the port to connect to
     * @param connectTimeout connect timeout
     * @param sslConfig      SSL configuration
     * @throws IllegalArgumentException if {@code host} is {@code null}
     */
    public static Mono connect(String host, int port, @Nullable Duration connectTimeout, SSLConfig sslConfig) {
        return connect(ConnectionProvider.newConnection(), InetSocketAddress.createUnresolved(host, port), connectTimeout, sslConfig);
    }

    /**
     * Creates a new frame processor connected to a given host.
     *
     * @param connectionProvider the connection provider resources
     * @param socketAddress      the socketAddress to connect to
     * @param connectTimeout     connect timeout
     * @param sslConfig          SSL configuration
     * @throws IllegalArgumentException if {@code host} is {@code null}
     */
    public static Mono connect(ConnectionProvider connectionProvider, SocketAddress socketAddress, @Nullable Duration connectTimeout, SSLConfig sslConfig) {
        Assert.requireNonNull(connectionProvider, "connectionProvider must not be null");
        Assert.requireNonNull(socketAddress, "socketAddress must not be null");

        TcpClient tcpClient = TcpClient.create(connectionProvider).addressSupplier(() -> socketAddress);

        if (!(socketAddress instanceof InetSocketAddress)) {
            tcpClient = tcpClient.runOn(new SocketLoopResources(), true);
        }

        if (connectTimeout != null) {
            tcpClient = tcpClient.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, Math.toIntExact(connectTimeout.toMillis()));
        }

        return tcpClient.connect().flatMap(it -> {

            ChannelPipeline pipeline = it.channel().pipeline();

            InternalLogger logger = InternalLoggerFactory.getInstance(ReactorNettyClient.class);
            if (logger.isTraceEnabled()) {
                pipeline.addFirst(LoggingHandler.class.getSimpleName(),
                    new LoggingHandler(ReactorNettyClient.class, LogLevel.TRACE));
            }

            return registerSslHandler(sslConfig, it).thenReturn(new ReactorNettyClient(it));
        });
    }

    private static Mono registerSslHandler(SSLConfig sslConfig, Connection it) {

        if (sslConfig.getSslMode().startSsl()) {
            SSLSessionHandlerAdapter sslSessionHandlerAdapter = new SSLSessionHandlerAdapter(it.outbound().alloc(), sslConfig);
            it.addHandlerFirst(sslSessionHandlerAdapter);
            return sslSessionHandlerAdapter.getHandshake();
        }

        return Mono.empty();
    }

    @Override
    public Mono close() {
        return Mono.defer(() -> {

            drainError(EXPECTED);
            if (this.isClosed.compareAndSet(false, true)) {

                if (!isConnected() || this.processId == null) {
                    this.connection.dispose();
                    return this.connection.onDispose();
                }

                return Flux.just(Terminate.INSTANCE)
                    .doOnNext(message -> logger.debug("Request:  {}", message))
                    .concatMap(message -> this.connection.outbound().send(message.encode(this.connection.outbound().alloc())))
                    .then()
                    .doOnSuccess(v -> this.connection.dispose())
                    .then(this.connection.onDispose());
            }

            return Mono.empty();
        });
    }

    @Override
    public Flux exchange(Publisher requests) {
        Assert.requireNonNull(requests, "requests must not be null");

        return Mono
            .>create(sink -> {

                final AtomicInteger once = new AtomicInteger();

                Flux.from(requests)
                    .subscribe(message -> {

                        if (!isConnected()) {
                            ReferenceCountUtil.safeRelease(message);
                            sink.error(new PostgresConnectionClosedException("Cannot exchange messages because the connection is closed"));
                            return;
                        }

                        if (once.get() == 0 && once.compareAndSet(0, 1)) {
                            synchronized (this) {
                                this.responseReceivers.add(sink);
                                this.requests.next(message);
                            }
                        } else {
                            this.requests.next(message);
                        }

                    }, this.requests::error, () -> {

                        if (!isConnected()) {
                            sink.error(new PostgresConnectionClosedException("Cannot exchange messages because the connection is closed"));
                        }
                    });

            })
            .flatMapMany(Function.identity());
    }

    @Override
    public ByteBufAllocator getByteBufAllocator() {
        return this.byteBufAllocator;
    }

    @Override
    public Optional getProcessId() {
        return Optional.ofNullable(this.processId);
    }

    @Override
    public Optional getSecretKey() {
        return Optional.ofNullable(this.secretKey);
    }

    @Override
    public TransactionStatus getTransactionStatus() {
        return this.transactionStatus;
    }

    @Override
    public Version getVersion() {
        return this.version;
    }

    @Override
    public boolean isConnected() {
        if (this.isClosed.get()) {
            return false;
        }

        if (this.requestProcessor.isDisposed()) {
            return false;
        }

        Channel channel = this.connection.channel();
        return channel.isOpen();
    }

    @Override
    public Disposable addNotificationListener(Consumer consumer) {
        return this.notificationProcessor.subscribe(consumer);
    }

    private static String toString(List fields) {

        StringJoiner joiner = new StringJoiner(", ");
        for (Field field : fields) {
            joiner.add(field.getType().name() + "=" + field.getValue());
        }

        return joiner.toString();
    }

    private void handleClose() {
        if (this.isClosed.compareAndSet(false, true)) {
            drainError(UNEXPECTED);
        } else {
            drainError(EXPECTED);
        }
    }

    private void handleConnectionError(Throwable error) {
        drainError(() -> new PostgresConnectionException(error));
    }

    private void drainError(Supplier supplier) {
        MonoSink> receiver;

        while ((receiver = this.responseReceivers.poll()) != null) {
            receiver.error(supplier.get());
        }
    }

    private final class EnsureSubscribersCompleteChannelHandler extends ChannelDuplexHandler {

        private final EmitterProcessor requestProcessor;

        private EnsureSubscribersCompleteChannelHandler(EmitterProcessor requestProcessor) {
            this.requestProcessor = requestProcessor;
        }

        @Override
        public void channelInactive(ChannelHandlerContext ctx) throws Exception {
            super.channelInactive(ctx);
        }

        @Override
        public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
            super.channelUnregistered(ctx);

            this.requestProcessor.onComplete();
            handleClose();
        }
    }

    static class PostgresConnectionClosedException extends R2dbcNonTransientResourceException {

        public PostgresConnectionClosedException(String reason) {
            super(reason);
        }
    }

    static class PostgresConnectionException extends R2dbcNonTransientResourceException {

        public PostgresConnectionException(Throwable cause) {
            super(cause);
        }
    }

    static class SocketLoopResources implements LoopResources {

        @Nullable
        private static final Class EPOLL_SOCKET = findClass("io.netty.channel.epoll.EpollDomainSocketChannel");

        @Nullable
        private static final Class KQUEUE_SOCKET = findClass("io.netty.channel.kqueue.KQueueDomainSocketChannel");

        private static final boolean kqueue;

        static {
            boolean kqueueCheck = false;
            try {
                Class.forName("io.netty.channel.kqueue.KQueue");
                kqueueCheck = io.netty.channel.kqueue.KQueue.isAvailable();
            } catch (ClassNotFoundException cnfe) {
            }
            kqueue = kqueueCheck;
        }

        private static final boolean epoll;

        static {
            boolean epollCheck = false;
            try {
                Class.forName("io.netty.channel.epoll.Epoll");
                epollCheck = Epoll.isAvailable();
            } catch (ClassNotFoundException cnfe) {
            }
            epoll = epollCheck;
        }

        private final LoopResources delegate = TcpResources.get();

        @SuppressWarnings("unchecked")
        private static Class findClass(String className) {
            try {
                return (Class) SocketLoopResources.class.getClassLoader().loadClass(className);
            } catch (ClassNotFoundException e) {
                return null;
            }
        }

        @Override
        public Class onChannel(EventLoopGroup group) {

            if (epoll && EPOLL_SOCKET != null) {
                return EPOLL_SOCKET;
            }

            if (kqueue && KQUEUE_SOCKET != null) {
                return KQUEUE_SOCKET;
            }

            return this.delegate.onChannel(group);
        }

        @Override
        public EventLoopGroup onClient(boolean useNative) {
            return this.delegate.onClient(useNative);
        }

        @Override
        public Class onDatagramChannel(EventLoopGroup group) {
            return this.delegate.onDatagramChannel(group);
        }

        @Override
        public EventLoopGroup onServer(boolean useNative) {
            return this.delegate.onServer(useNative);
        }

        @Override
        public Class onServerChannel(EventLoopGroup group) {
            return this.delegate.onServerChannel(group);
        }

        @Override
        public EventLoopGroup onServerSelect(boolean useNative) {
            return this.delegate.onServerSelect(useNative);
        }

        @Override
        public boolean preferNative() {
            return this.delegate.preferNative();
        }

        @Override
        public boolean daemon() {
            return this.delegate.daemon();
        }

        @Override
        public void dispose() {
            this.delegate.dispose();
        }

        @Override
        public Mono disposeLater() {
            return this.delegate.disposeLater();
        }
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy