// io.r2dbc.postgresql.client.ReactorNettyClient (Maven / Gradle / Ivy)
/*
* Copyright 2017-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.r2dbc.postgresql.client;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.ServerChannel;
import io.netty.channel.epoll.Epoll;
import io.netty.channel.socket.DatagramChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import io.r2dbc.postgresql.message.backend.BackendKeyData;
import io.r2dbc.postgresql.message.backend.BackendMessage;
import io.r2dbc.postgresql.message.backend.BackendMessageDecoder;
import io.r2dbc.postgresql.message.backend.ErrorResponse;
import io.r2dbc.postgresql.message.backend.Field;
import io.r2dbc.postgresql.message.backend.NoticeResponse;
import io.r2dbc.postgresql.message.backend.NotificationResponse;
import io.r2dbc.postgresql.message.backend.ParameterStatus;
import io.r2dbc.postgresql.message.backend.ReadyForQuery;
import io.r2dbc.postgresql.message.frontend.FrontendMessage;
import io.r2dbc.postgresql.message.frontend.Terminate;
import io.r2dbc.postgresql.util.Assert;
import io.r2dbc.spi.R2dbcNonTransientResourceException;
import io.r2dbc.spi.R2dbcTransientResourceException;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import reactor.core.CoreSubscriber;
import reactor.core.Disposable;
import reactor.core.publisher.DirectProcessor;
import reactor.core.publisher.EmitterProcessor;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Operators;
import reactor.netty.Connection;
import reactor.netty.resources.ConnectionProvider;
import reactor.netty.resources.LoopResources;
import reactor.netty.tcp.TcpClient;
import reactor.netty.tcp.TcpResources;
import reactor.util.Logger;
import reactor.util.Loggers;
import reactor.util.annotation.Nullable;
import reactor.util.concurrent.Queues;
import reactor.util.context.Context;
import javax.net.ssl.SSLException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.time.Duration;
import java.util.List;
import java.util.Optional;
import java.util.Queue;
import java.util.StringJoiner;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicLongFieldUpdater;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
import static io.r2dbc.postgresql.client.TransactionStatus.IDLE;
/**
* An implementation of client based on the Reactor Netty project.
*
* @see TcpClient
*/
public final class ReactorNettyClient implements Client {
private static final Logger logger = Loggers.getLogger(ReactorNettyClient.class);
private static final boolean DEBUG_ENABLED = logger.isDebugEnabled();
private static final Supplier UNEXPECTED = () -> new PostgresConnectionClosedException("Connection unexpectedly closed");
private static final Supplier EXPECTED = () -> new PostgresConnectionClosedException("Connection closed");
private final ByteBufAllocator byteBufAllocator;
private final Connection connection;
private final EmitterProcessor> requestProcessor = EmitterProcessor.create(false);
private final FluxSink> requests = this.requestProcessor.sink();
private final DirectProcessor notificationProcessor = DirectProcessor.create();
private final AtomicBoolean isClosed = new AtomicBoolean(false);
private final BackendMessageSubscriber messageSubscriber = new BackendMessageSubscriber();
private volatile Integer processId;
private volatile Integer secretKey;
private volatile TransactionStatus transactionStatus = IDLE;
private volatile Version version = new Version("", 0);
/**
* Creates a new frame processor connected to a given TCP connection.
*
* @param connection the TCP connection
* @throws IllegalArgumentException if {@code connection} is {@code null}
*/
private ReactorNettyClient(Connection connection) {
Assert.requireNonNull(connection, "Connection must not be null");
connection.addHandler(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE - 5, 1, 4, -4, 0));
connection.addHandler(new EnsureSubscribersCompleteChannelHandler(this.requestProcessor));
this.connection = connection;
this.byteBufAllocator = connection.outbound().alloc();
AtomicReference receiveError = new AtomicReference<>();
connection.inbound().receive()
.map(BackendMessageDecoder::decode)
.doOnError(throwable -> {
receiveError.set(throwable);
handleConnectionError(throwable);
})
.handle((backendMessage, sink) -> {
if (consumeMessage(backendMessage)) {
return;
}
sink.next(backendMessage);
})
.subscribe(this.messageSubscriber);
Mono request = this.requestProcessor
.concatMap(Function.identity())
.flatMap(message -> {
if (DEBUG_ENABLED) {
logger.debug("Request: {}", message);
}
return connection.outbound().send(message.encode(this.byteBufAllocator));
}, 1)
.then();
request
.onErrorResume(this::resumeError)
.doAfterTerminate(this::handleClose)
.subscribe();
}
@Override
public Mono close() {
return Mono.defer(() -> {
if (!this.notificationProcessor.isTerminated()) {
this.notificationProcessor.onComplete();
}
drainError(EXPECTED);
boolean connected = isConnected();
if (this.isClosed.compareAndSet(false, true)) {
if (!connected || this.processId == null) {
this.connection.dispose();
return this.connection.onDispose();
}
return Flux.just(Terminate.INSTANCE)
.doOnNext(message -> logger.debug("Request: {}", message))
.concatMap(message -> this.connection.outbound().send(message.encode(this.connection.outbound().alloc())))
.then()
.doOnSuccess(v -> this.connection.dispose())
.then(this.connection.onDispose());
}
return Mono.empty();
});
}
@Override
public Flux exchange(Predicate takeUntil, Publisher requests) {
Assert.requireNonNull(takeUntil, "takeUntil must not be null");
Assert.requireNonNull(requests, "requests must not be null");
return this.messageSubscriber.addConversation(takeUntil, requests, this.requests::next, this::isConnected);
}
@Override
public void send(FrontendMessage message) {
Assert.requireNonNull(message, "requests must not be null");
this.requests.next(Mono.just(message));
}
private Mono resumeError(Throwable throwable) {
handleConnectionError(throwable);
this.requestProcessor.onComplete();
if (isSslException(throwable)) {
logger.debug("Connection Error", throwable);
} else {
logger.error("Connection Error", throwable);
}
return close();
}
private static boolean isSslException(Throwable throwable) {
return throwable instanceof SSLException || throwable.getCause() instanceof SSLException;
}
/**
* Consume a {@link BackendMessage}. This method can either fully consume the message or it can signal by returning {@literal false} that the method wasn't able to fully consume the message and
* that the message needs to be passed to an active {@link Conversation}.
*
* @param message
* @return {@literal false} if the message could not be fully consumed and should be propagated to the active {@link Conversation}.
*/
private boolean consumeMessage(BackendMessage message) {
if (DEBUG_ENABLED) {
logger.debug("Response: {}", message);
}
if (message.getClass() == NoticeResponse.class) {
logger.warn("Notice: {}", toString(((NoticeResponse) message).getFields()));
return true;
}
if (message.getClass() == BackendKeyData.class) {
BackendKeyData backendKeyData = (BackendKeyData) message;
this.processId = backendKeyData.getProcessId();
this.secretKey = backendKeyData.getSecretKey();
return true;
}
if (message.getClass() == ErrorResponse.class) {
logger.warn("Error: {}", toString(((ErrorResponse) message).getFields()));
}
if (message.getClass() == ParameterStatus.class) {
handleParameterStatus((ParameterStatus) message);
}
if (message.getClass() == ReadyForQuery.class) {
this.transactionStatus = TransactionStatus.valueOf(((ReadyForQuery) message).getTransactionStatus());
}
if (message.getClass() == NotificationResponse.class) {
this.notificationProcessor.onNext((NotificationResponse) message);
return true;
}
return false;
}
private void handleParameterStatus(ParameterStatus message) {
Version existingVersion = this.version;
String versionString = existingVersion.getVersion();
int versionNum = existingVersion.getVersionNumber();
if (message.getName().equals("server_version_num")) {
versionNum = Integer.parseInt(message.getValue());
}
if (message.getName().equals("server_version")) {
versionString = message.getValue();
if (versionNum == 0) {
versionNum = Version.parseServerVersionStr(versionString);
}
}
this.version = new Version(versionString, versionNum);
}
/**
* Creates a new frame processor connected to a given host.
*
* @param host the host to connect to
* @param port the port to connect to
* @throws IllegalArgumentException if {@code host} is {@code null}
*/
public static Mono connect(String host, int port) {
Assert.requireNonNull(host, "host must not be null");
return connect(host, port, null, new SSLConfig(SSLMode.DISABLE, null, null));
}
/**
* Creates a new frame processor connected to a given host.
*
* @param host the host to connect to
* @param port the port to connect to
* @param connectTimeout connect timeout
* @param sslConfig SSL configuration
* @throws IllegalArgumentException if {@code host} is {@code null}
*/
public static Mono connect(String host, int port, @Nullable Duration connectTimeout, SSLConfig sslConfig) {
return connect(ConnectionProvider.newConnection(), InetSocketAddress.createUnresolved(host, port), connectTimeout, sslConfig);
}
/**
* Creates a new frame processor connected to a given host.
*
* @param connectionProvider the connection provider resources
* @param socketAddress the socketAddress to connect to
* @param connectTimeout connect timeout
* @param sslConfig SSL configuration
* @throws IllegalArgumentException if {@code host} is {@code null}
*/
public static Mono connect(ConnectionProvider connectionProvider, SocketAddress socketAddress, @Nullable Duration connectTimeout, SSLConfig sslConfig) {
Assert.requireNonNull(connectionProvider, "connectionProvider must not be null");
Assert.requireNonNull(socketAddress, "socketAddress must not be null");
TcpClient tcpClient = TcpClient.create(connectionProvider).addressSupplier(() -> socketAddress);
if (!(socketAddress instanceof InetSocketAddress)) {
tcpClient = tcpClient.runOn(new SocketLoopResources(), true);
}
if (connectTimeout != null) {
tcpClient = tcpClient.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, Math.toIntExact(connectTimeout.toMillis()));
}
return tcpClient.connect().flatMap(it -> {
ChannelPipeline pipeline = it.channel().pipeline();
InternalLogger logger = InternalLoggerFactory.getInstance(ReactorNettyClient.class);
if (logger.isTraceEnabled()) {
pipeline.addFirst(LoggingHandler.class.getSimpleName(),
new LoggingHandler(ReactorNettyClient.class, LogLevel.TRACE));
}
return registerSslHandler(sslConfig, it).thenReturn(new ReactorNettyClient(it));
});
}
private static Mono extends Void> registerSslHandler(SSLConfig sslConfig, Connection it) {
if (sslConfig.getSslMode().startSsl()) {
SSLSessionHandlerAdapter sslSessionHandlerAdapter = new SSLSessionHandlerAdapter(it.outbound().alloc(), sslConfig);
it.addHandlerFirst(sslSessionHandlerAdapter);
return sslSessionHandlerAdapter.getHandshake();
}
return Mono.empty();
}
@Override
public Disposable addNotificationListener(Consumer consumer) {
return this.notificationProcessor.subscribe(consumer);
}
@Override
public Disposable addNotificationListener(Subscriber consumer) {
return this.notificationProcessor.subscribe(consumer::onNext, consumer::onError, consumer::onComplete, consumer::onSubscribe);
}
@Override
public ByteBufAllocator getByteBufAllocator() {
return this.byteBufAllocator;
}
@Override
public Optional getProcessId() {
return Optional.ofNullable(this.processId);
}
@Override
public Optional getSecretKey() {
return Optional.ofNullable(this.secretKey);
}
@Override
public TransactionStatus getTransactionStatus() {
return this.transactionStatus;
}
@Override
public Version getVersion() {
return this.version;
}
@Override
public boolean isConnected() {
if (this.isClosed.get()) {
return false;
}
if (this.requestProcessor.isDisposed()) {
return false;
}
Channel channel = this.connection.channel();
return channel.isOpen();
}
private static String toString(List fields) {
StringJoiner joiner = new StringJoiner(", ");
for (Field field : fields) {
joiner.add(field.getType().name() + "=" + field.getValue());
}
return joiner.toString();
}
private void handleClose() {
if (this.isClosed.compareAndSet(false, true)) {
drainError(UNEXPECTED);
} else {
drainError(EXPECTED);
}
}
private void handleConnectionError(Throwable error) {
drainError(() -> new PostgresConnectionException(error));
}
private void drainError(Supplier extends Throwable> supplier) {
this.messageSubscriber.close(supplier);
if (!this.notificationProcessor.isTerminated()) {
this.notificationProcessor.onError(supplier.get());
}
}
private final class EnsureSubscribersCompleteChannelHandler extends ChannelDuplexHandler {
private final EmitterProcessor> requestProcessor;
private EnsureSubscribersCompleteChannelHandler(EmitterProcessor> requestProcessor) {
this.requestProcessor = requestProcessor;
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
super.channelInactive(ctx);
}
@Override
public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
super.channelUnregistered(ctx);
this.requestProcessor.onComplete();
handleClose();
}
}
static class PostgresConnectionClosedException extends R2dbcNonTransientResourceException {
public PostgresConnectionClosedException(String reason) {
super(reason);
}
}
static class PostgresConnectionException extends R2dbcNonTransientResourceException {
public PostgresConnectionException(Throwable cause) {
super(cause);
}
}
static class RequestQueueException extends R2dbcTransientResourceException {
public RequestQueueException(String message) {
super(message);
}
}
static class ResponseQueueException extends R2dbcNonTransientResourceException {
public ResponseQueueException(String message) {
super(message);
}
}
static class SocketLoopResources implements LoopResources {
@Nullable
private static final Class extends Channel> EPOLL_SOCKET = findClass("io.netty.channel.epoll.EpollDomainSocketChannel");
@Nullable
private static final Class extends Channel> KQUEUE_SOCKET = findClass("io.netty.channel.kqueue.KQueueDomainSocketChannel");
private static final boolean kqueue;
static {
boolean kqueueCheck = false;
try {
Class.forName("io.netty.channel.kqueue.KQueue");
kqueueCheck = io.netty.channel.kqueue.KQueue.isAvailable();
} catch (ClassNotFoundException cnfe) {
}
kqueue = kqueueCheck;
}
private static final boolean epoll;
static {
boolean epollCheck = false;
try {
Class.forName("io.netty.channel.epoll.Epoll");
epollCheck = Epoll.isAvailable();
} catch (ClassNotFoundException cnfe) {
}
epoll = epollCheck;
}
private final LoopResources delegate = TcpResources.get();
@SuppressWarnings("unchecked")
private static Class extends Channel> findClass(String className) {
try {
return (Class extends Channel>) SocketLoopResources.class.getClassLoader().loadClass(className);
} catch (ClassNotFoundException e) {
return null;
}
}
@Override
public Class extends Channel> onChannel(EventLoopGroup group) {
if (epoll && EPOLL_SOCKET != null) {
return EPOLL_SOCKET;
}
if (kqueue && KQUEUE_SOCKET != null) {
return KQUEUE_SOCKET;
}
return this.delegate.onChannel(group);
}
@Override
public EventLoopGroup onClient(boolean useNative) {
return this.delegate.onClient(useNative);
}
@Override
public Class extends DatagramChannel> onDatagramChannel(EventLoopGroup group) {
return this.delegate.onDatagramChannel(group);
}
@Override
public EventLoopGroup onServer(boolean useNative) {
return this.delegate.onServer(useNative);
}
@Override
public Class extends ServerChannel> onServerChannel(EventLoopGroup group) {
return this.delegate.onServerChannel(group);
}
@Override
public EventLoopGroup onServerSelect(boolean useNative) {
return this.delegate.onServerSelect(useNative);
}
@Override
public boolean preferNative() {
return this.delegate.preferNative();
}
@Override
public boolean daemon() {
return this.delegate.daemon();
}
@Override
public void dispose() {
this.delegate.dispose();
}
@Override
public Mono disposeLater() {
return this.delegate.disposeLater();
}
}
/**
* Value object representing a single conversation. The driver permits a single conversation at a time to ensure that request messages get routed to the proper response receiver and do not leak
* into other conversations. A conversation must be finished in the sense that the {@link Publisher} of {@link FrontendMessage} has completed before the next conversation is started.
*
* A single conversation can make use of pipelining.
*/
private static class Conversation {
private static final AtomicLongFieldUpdater DEMAND_UPDATER = AtomicLongFieldUpdater.newUpdater(Conversation.class, "demand");
private final Predicate takeUntil;
private final FluxSink sink;
// access via DEMAND_UPDATER
private volatile long demand;
private Conversation(Predicate takeUntil, FluxSink sink) {
this.sink = sink;
this.takeUntil = takeUntil;
}
private long decrementDemand() {
return Operators.addCap(DEMAND_UPDATER, this, -1);
}
/**
* Check whether the {@link BackendMessage} can complete the conversation.
*
* @param item
* @return
*/
public boolean canComplete(BackendMessage item) {
return this.takeUntil.test(item);
}
/**
* Complete the conversation.
*
* @param item
* @return
*/
public void complete(BackendMessage item) {
ReferenceCountUtil.release(item);
if (!this.sink.isCancelled()) {
this.sink.complete();
}
}
/**
* Emit a {@link BackendMessage}.
*
* @param item
* @return
*/
public void emit(BackendMessage item) {
if (this.sink.isCancelled()) {
ReferenceCountUtil.release(item);
}
decrementDemand();
this.sink.next(item);
}
/**
* Notify the conversation about an error. Drops errors silently if the conversation is finished.
*
* @param throwable
*/
public void onError(Throwable throwable) {
if (!this.sink.isCancelled()) {
this.sink.error(throwable);
}
}
public boolean hasDemand() {
return DEMAND_UPDATER.get(this) > 0;
}
public boolean isCancelled() {
return this.sink.isCancelled();
}
public void incrementDemand(long n) {
Operators.addCap(DEMAND_UPDATER, this, n);
}
}
/**
* Subscriber that handles {@link Conversation}s and keeps track of the current demand. It also routes {@link BackendMessage}s to the currently active {@link Conversation}.
*/
private class BackendMessageSubscriber implements CoreSubscriber {
private static final int DEMAND = 256;
private final Queue conversations = Queues.small().get();
private final Queue buffer = Queues.get(DEMAND).get();
private final AtomicLong demand = new AtomicLong(0);
private final AtomicBoolean drain = new AtomicBoolean();
private volatile boolean terminated;
private Subscription upstream;
public Flux addConversation(Predicate takeUntil, Publisher requests, Consumer> sender,
Supplier isConnected) {
return Flux.create(sink -> {
Conversation conversation = new Conversation(takeUntil, sink);
// ensure ordering in which conversations are added to both queues.
synchronized (this.conversations) {
if (this.conversations.offer(conversation)) {
sink.onRequest(value -> onRequest(conversation, value));
if (!isConnected.get()) {
sink.error(new PostgresConnectionClosedException("Cannot exchange messages because the connection is closed"));
return;
}
Flux requestMessages = Flux.from(requests).doOnNext(m -> {
if (!isConnected.get()) {
sink.error(new PostgresConnectionClosedException("Cannot exchange messages because the connection is closed"));
}
});
sender.accept(requestMessages);
} else {
sink.error(new RequestQueueException("Cannot exchange messages because the request queue limit is exceeded"));
}
}
});
}
public void onRequest(Conversation conversation, long n) {
conversation.incrementDemand(n);
while (hasBufferedItems() && hasDownstreamDemand()) {
drainLoop();
}
}
private void demandMore() {
if (!hasBufferedItems() && this.demand.compareAndSet(0, DEMAND)) {
this.upstream.request(DEMAND);
}
}
@Override
public void onSubscribe(Subscription s) {
this.upstream = s;
this.demandMore();
}
private boolean hasDownstreamDemand() {
Conversation conversation = this.conversations.peek();
return conversation != null && conversation.hasDemand();
}
@Override
public void onNext(BackendMessage message) {
if (this.terminated) {
ReferenceCountUtil.release(message);
Operators.onNextDropped(message, currentContext());
return;
}
this.demand.decrementAndGet();
// fast-path
if (this.buffer.isEmpty()) {
Conversation conversation = this.conversations.peek();
if (conversation != null && conversation.hasDemand()) {
emit(conversation, message);
potentiallyDemandMore(conversation);
return;
}
}
// slow-path
if (!this.buffer.offer(message)) {
ReferenceCountUtil.release(message);
Operators.onNextDropped(message, currentContext());
onError(new ResponseQueueException("Response queue is full"));
return;
}
while (hasBufferedItems() && hasDownstreamDemand()) {
this.drainLoop();
}
}
private void drainLoop() {
Conversation lastConversation = null;
if (this.drain.compareAndSet(false, true)) {
try {
while (hasBufferedItems()) {
Conversation conversation = this.conversations.peek();
lastConversation = conversation;
if (conversation == null) {
break;
}
if (conversation.hasDemand()) {
BackendMessage item = this.buffer.poll();
if (item == null) {
break;
}
emit(conversation, item);
} else {
break;
}
}
} finally {
this.drain.compareAndSet(true, false);
}
}
potentiallyDemandMore(lastConversation);
}
private void potentiallyDemandMore(@Nullable Conversation lastConversation) {
if (lastConversation == null || lastConversation.hasDemand() || lastConversation.isCancelled()) {
this.demandMore();
}
}
private void emit(Conversation conversation, BackendMessage item) {
if (conversation.canComplete(item)) {
this.conversations.poll();
conversation.complete(item);
} else {
conversation.emit(item);
}
}
private boolean hasBufferedItems() {
return !this.buffer.isEmpty();
}
@Override
public void onError(Throwable throwable) {
if (this.terminated) {
Operators.onErrorDropped(throwable, currentContext());
return;
}
handleConnectionError(throwable);
ReactorNettyClient.this.requestProcessor.onComplete();
this.terminated = true;
if (isSslException(throwable)) {
logger.debug("Connection Error", throwable);
} else {
logger.error("Connection Error", throwable);
}
ReactorNettyClient.this.close().subscribe();
}
@Override
public void onComplete() {
this.terminated = true;
ReactorNettyClient.this.handleClose();
}
@Override
public Context currentContext() {
Conversation receiver = this.conversations.peek();
if (receiver != null) {
return receiver.sink.currentContext();
} else {
return Context.empty();
}
}
/**
* Cleanup the subscriber by terminating all {@link Conversation}s and purging the data buffer.
*
* @param supplier
*/
public void close(Supplier extends Throwable> supplier) {
this.terminated = true;
Conversation receiver;
while ((receiver = this.conversations.poll()) != null) {
receiver.onError(supplier.get());
}
while (!this.buffer.isEmpty()) {
ReferenceCountUtil.release(this.buffer.poll());
}
}
}
}