com.github.pgasync.impl.protocol.ProtocolStream Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of postgres-async-driver Show documentation
Show all versions of postgres-async-driver Show documentation
Asynchronous PostgreSQL Java driver
The newest version!
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.github.pgasync.impl.protocol;
import com.github.pgasync.DatabaseConfig;
import com.github.pgasync.SqlException;
import com.github.pgasync.impl.NettyScheduler;
import com.github.pgasync.impl.message.*;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import rx.*;
import rx.Observable;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import static com.nurkiewicz.typeof.TypeOf.whenTypeOf;
import static rx.subscriptions.Subscriptions.create;
/**
* Protocol stream handler.
*
* @author Jacek Sokol
*/
public class ProtocolStream {
private static final Logger LOG = LoggerFactory.getLogger(ProtocolStream.class);
/**
 * Base type for everything that consumes backend messages on this stream.
 *
 * <p>Bound to {@code Consumer<Message>} so that the {@code accept(Message)}
 * implementations in the subclasses really override the functional method
 * (with a raw {@code Consumer} the method to override would be
 * {@code accept(Object)}).
 */
abstract class PgConsumer implements Consumer<Message> {
    /** Label of the operation this consumer belongs to; used in the log line of {@link #closeStream()}. */
    final String query;

    PgConsumer(String query) {
        this.query = query;
    }

    /**
     * Reports a failure to whoever is waiting on this consumer.
     *
     * @param throwable the failure to deliver
     */
    abstract void error(Throwable throwable);

    /**
     * Tears the connection down after a premature cancellation.
     *
     * <p>The work is submitted to {@code scheduler}: a warning is logged, this
     * consumer is removed from {@code subscribers}, the stream is flagged
     * {@code dirty} and the channel is closed. The calling thread then waits
     * for that action for at most one second; the boolean result of the wait
     * is not inspected.
     */
    void closeStream() {
        Completable
            .fromAction(() -> {
                LOG.warn("Closing channel due to premature cancellation [{}]", query);
                subscribers.remove(this);
                dirty = true;
                ctx.channel().close();
            })
            .subscribeOn(scheduler)
            .await(1, TimeUnit.SECONDS);
    }
}
abstract class ProtocolConsumer extends PgConsumer {
final SingleSubscriber subscriber;
final AtomicBoolean done = new AtomicBoolean();
@SuppressWarnings("unchecked")
ProtocolConsumer(SingleSubscriber> subscriber, String query) {
super(query);
this.subscriber = (SingleSubscriber) subscriber;
subscriber.add(create(ProtocolConsumer.this::unsubscribe));
}
void complete(T value) {
if (!done.get()) {
done.set(true);
subscriber.onSuccess(value);
}
}
void complete() {
complete(null);
}
void error(Throwable throwable) {
done.set(true);
subscriber.onError(throwable);
}
void unsubscribe() {
if (!done.get()) closeStream();
}
}
abstract class StreamConsumer extends PgConsumer {
final Emitter subscriber;
final AtomicBoolean done = new AtomicBoolean();
@SuppressWarnings("unchecked")
StreamConsumer(Emitter subscriber, String query) {
super(query);
this.subscriber = subscriber;
subscriber.setSubscription(create(StreamConsumer.this::unsubscribe));
}
void complete() {
if (!done.get()) {
done.set(true);
subscriber.onCompleted();
}
}
public void error(Throwable throwable) {
done.set(true);
subscriber.onError(throwable);
}
void unsubscribe() {
if (!done.get()) closeStream();
}
}
private final EventLoopGroup group;
private final DatabaseConfig config;
private final GenericFutureListener> onError;
private final Queue subscribers = new LinkedBlockingDeque<>(); // TODO: limit pipeline queue depth
private final ConcurrentMap>> listeners = new ConcurrentHashMap<>();
private ChannelHandlerContext ctx;
private boolean dirty;
private Scheduler scheduler;
/**
 * Creates a stream that is not yet connected; call {@code connect} to open the channel.
 *
 * @param group  event loop group used when bootstrapping the channel
 * @param config connection settings
 */
public ProtocolStream(EventLoopGroup group, DatabaseConfig config) {
    this.group = group;
    this.config = config;
    // Successful futures are ignored; failed ones are routed into handleError().
    this.onError = result -> {
        if (result.isSuccess()) {
            return;
        }
        handleError(result.cause());
    };
}
/**
 * Opens the channel and runs the startup exchange.
 *
 * <p>On subscription a "CONNECT" consumer is queued and a Netty bootstrap is
 * started against {@code config.address()}. The resulting single
 * <ul>
 *   <li>succeeds with {@code new Authentication(true, null)} when
 *       {@code ReadyForQuery} arrives,</li>
 *   <li>succeeds with the received {@code Authentication} when it reports
 *       {@code success() == false} (the consumer is then removed from the
 *       queue head),</li>
 *   <li>fails with a {@code SqlException} on {@code ErrorResponse} or on any
 *       other message type.</li>
 * </ul>
 *
 * @param startup startup message written once the channel is ready
 * @return single carrying the authentication state
 */
public Single connect(StartupMessage startup) {
    return Single.create(subscriber -> {
        ProtocolConsumer startupConsumer = new ProtocolConsumer(subscriber, "CONNECT") {
            private void handleAuthRequest(Authentication auth) {
                if (auth.success()) {
                    // Nothing to report yet; ReadyForQuery completes the single.
                    return;
                }
                subscribers.remove();
                complete(auth);
            }

            @Override
            public void accept(Message message) {
                whenTypeOf(message)
                    .is(ErrorResponse.class).then(e -> error(toSqlException(e)))
                    .is(ReadyForQuery.class).then(r -> complete(new Authentication(true, null)))
                    .is(Authentication.class).then(this::handleAuthRequest)
                    .orElse(m -> error(new SqlException("Unexpected message at startup stage: " + m)));
            }
        };
        subscribers.add(startupConsumer);

        InboundChannelInitializer startupWriter = new InboundChannelInitializer(startup);
        MessageHandler dispatcher = new MessageHandler(subscribers, listeners, this::handleError);

        new Bootstrap()
            .group(group)
            .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, config.connectTimeout())
            .channel(NioSocketChannel.class)
            .handler(new ProtocolInitializer(config, startupWriter, dispatcher))
            .connect(config.address())
            .addListener(onError);
    });
}
/**
 * Sends the password message and waits for the backend's verdict.
 *
 * <p>Subscription happens on {@code scheduler}. The completable finishes when
 * {@code ReadyForQuery} is received and fails with a {@code SqlException} on
 * {@code ErrorResponse} or on an {@code Authentication} message whose
 * {@code success()} is {@code false}.
 *
 * @param password password message to write
 * @return completable signalling the end of authentication
 */
public Completable authenticate(PasswordMessage password) {
    return Single
        .create(subscriber -> {
            ProtocolConsumer authConsumer = new ProtocolConsumer(subscriber, "AUTHENTICATE") {
                private void handleAuthResponse(Authentication a) {
                    if (a.success()) {
                        return;
                    }
                    error(new SqlException("Failed to authenticate"));
                }

                @Override
                public void accept(Message message) {
                    whenTypeOf(message)
                        .is(ErrorResponse.class).then(e -> error(toSqlException(e)))
                        .is(Authentication.class).then(this::handleAuthResponse)
                        .is(ReadyForQuery.class).then(r -> complete());
                }
            };
            // Register for the response before the request goes out.
            subscribers.add(authConsumer);
            write(password);
        })
        .subscribeOn(scheduler)
        .toCompletable();
}
/**
 * Writes the given messages and streams the backend's responses.
 *
 * <p>Returns an error observable immediately when no message is supplied
 * ({@code IllegalArgumentException}) or when the stream is not connected
 * ({@code IllegalStateException}). Otherwise, on subscription (performed on
 * {@code scheduler}) a consumer labelled with {@code messages[0].toString()}
 * is queued, the messages are written, auto-read is switched off and one read
 * is requested.
 *
 * <p>An {@code ErrorResponse} is remembered rather than emitted at once; it is
 * delivered as the stream's error when {@code ReadyForQuery} arrives.
 * {@code CommandComplete} re-enables auto-read and is emitted; every other
 * message is emitted as is.
 *
 * @param messages messages to send; must contain at least one element
 * @return observable of backend messages for this command
 */
public Observable command(Message... messages) {
    if (messages.length == 0) {
        return Observable.error(new IllegalArgumentException("No messages to send"));
    }
    if (!isConnected()) {
        return Observable.error(new IllegalStateException("Channel is closed [" + messages[0] + "]"));
    }

    // The second argument is this::readNext - presumably invoked when downstream
    // requests more items; verify against BackPressuredEmitter.
    return Observable.unsafeCreate(BackPressuredEmitter.create(emitter -> {
        StreamConsumer responseConsumer = new StreamConsumer(emitter, messages[0].toString()) {
            // Backend error seen so far; reported once ReadyForQuery arrives.
            SqlException failure;

            @Override
            public void accept(Message message) {
                whenTypeOf(message)
                    .is(ErrorResponse.class).then(this::handleError)
                    .is(ReadyForQuery.class).then(r -> handleReady())
                    .is(CommandComplete.class).then(this::handleCompletion)
                    .is(Message.class).then(emitter::onNext);
            }

            private void handleError(ErrorResponse e) {
                failure = toSqlException(e);
                enableAutoRead();
            }

            private void handleReady() {
                if (failure != null) {
                    error(failure);
                } else {
                    complete();
                }
            }

            private void handleCompletion(CommandComplete commandComplete) {
                enableAutoRead();
                emitter.onNext(commandComplete);
            }
        };
        subscribers.add(responseConsumer);
        write(messages);
        disableAutoRead();
        readNext();
    }, this::readNext)).subscribeOn(scheduler);
}
/**
 * Streams the payloads of notifications delivered for {@code channel}.
 *
 * <p>Returns an error observable ({@code IllegalStateException}) when the
 * stream is not connected. On subscription (performed on {@code scheduler}) a
 * consumer is registered in {@code listeners} under {@code channel},
 * auto-read is switched off and one read is requested. This method itself
 * does not write any message to the backend.
 *
 * <p>When downstream unsubscribes, auto-read is re-enabled and the consumer is
 * removed from {@code listeners} on the channel's executor; the map entry is
 * dropped once its list becomes empty.
 *
 * @param channel notification channel name
 * @return observable of notification payloads
 */
public Observable listen(String channel) {
    if (!isConnected())
        return Observable.error(new IllegalStateException("Channel is closed [LISTEN]"));
    return Observable.unsafeCreate(BackPressuredEmitter.create(emitter -> {
        StreamConsumer consumer = new StreamConsumer(emitter, "LISTEN") {
            @Override
            public void accept(Message message) {
                whenTypeOf(message)
                    .is(ErrorResponse.class).then(this::handleError)
                    .is(CommandComplete.class).then(commandComplete -> enableAutoRead())
                    .is(NotificationResponse.class).then(notificationResponse -> emitter.onNext(notificationResponse.payload()));
            }

            private void handleError(ErrorResponse e) {
                emitter.onError(toSqlException(e));
                enableAutoRead();
            }

            @Override
            protected void unsubscribe() {
                enableAutoRead();
                // The entry may already be gone (handleError() clears the whole map),
                // so a missing list must be tolerated instead of raising an NPE.
                ctx.executor().submit(() ->
                    Optional.ofNullable(listeners.get(channel)).ifPresent(list -> {
                        list.remove(this);
                        if (list.isEmpty())
                            listeners.remove(channel);
                    })
                );
            }
        };
        // Create the per-channel list on first use and register in one atomic map operation.
        listeners.computeIfAbsent(channel, key -> new LinkedList<>()).add(consumer);
        disableAutoRead();
        readNext();
    }, this::readNext)).subscribeOn(scheduler);
}
/**
 * Tells whether the stream is usable.
 *
 * @return {@code true} only when the stream is not flagged dirty, a channel
 *         context has been assigned and its channel is open
 */
public boolean isConnected() {
    if (dirty) {
        return false;
    }
    ChannelHandlerContext current = ctx;
    return current != null && current.channel().isOpen();
}
/**
 * Shuts the connection down.
 *
 * <p>On subscription (performed on {@code scheduler}) the stream is flagged
 * dirty, {@code handleError} is invoked with a "Closing connection" exception,
 * and a {@code Terminate} message is written and flushed. The completable
 * mirrors the outcome of that write.
 *
 * @return completable that finishes when the terminate message has been written
 */
public Completable close() {
    return Completable
        .create(subscriber -> {
            dirty = true;
            handleError(new RuntimeException("Closing connection"));
            ctx.writeAndFlush(Terminate.INSTANCE).addListener(outcome -> {
                if (!outcome.isSuccess()) {
                    subscriber.onError(outcome.cause());
                    return;
                }
                subscriber.onCompleted();
            });
        })
        .subscribeOn(scheduler);
}
/**
 * Writes each message to the channel in order and flushes once at the end.
 * A failed write is reported through the shared {@code onError} listener.
 *
 * @param messages messages to send
 */
private void write(Message... messages) {
    for (int i = 0; i < messages.length; i++) {
        Message outgoing = messages[i];
        LOG.trace("Writing: {}", outgoing);
        ctx.write(outgoing).addListener(onError);
    }
    ctx.flush();
}
/** Requests a single read on the channel. */
private void readNext() {
    final ChannelHandlerContext context = ctx;
    context.channel().read();
}
/** Turns the channel's auto-read option on. */
private void enableAutoRead() {
    final ChannelHandlerContext context = ctx;
    context.channel().config().setAutoRead(true);
}
/** Turns the channel's auto-read option off, so reads happen only on explicit request. */
private void disableAutoRead() {
    final ChannelHandlerContext context = ctx;
    context.channel().config().setAutoRead(false);
}
/**
 * Propagates a failure to waiting consumers and flags the stream dirty.
 *
 * <p>While the stream still counts as connected only the consumer at the head
 * of {@code subscribers} (if any) is failed. Otherwise every queued consumer
 * and every registered listener is failed and both collections are emptied.
 *
 * @param throwable the failure to deliver
 */
private void handleError(Throwable throwable) {
    if (isConnected()) {
        Optional.ofNullable(subscribers.poll()).ifPresent(head -> head.error(throwable));
    } else {
        subscribers.forEach(pending -> pending.error(throwable));
        subscribers.clear();
        listeners.values().forEach(consumers ->
            consumers.forEach(listener -> listener.error(throwable)));
        listeners.clear();
    }
    dirty = true;
}
/**
 * Converts a backend error response into a {@code SqlException}.
 *
 * @param error error response received from the backend
 * @return exception built from the response's level name, code and message
 */
private SqlException toSqlException(ErrorResponse error) {
    return new SqlException(
        error.level().name(),
        error.code(),
        error.message()
    );
}
/**
 * Handler that kicks off the session once the channel is active: it records
 * the channel context and scheduler on the enclosing stream, optionally starts
 * the SSL handshake, writes the startup message and then removes itself from
 * the pipeline.
 */
private class InboundChannelInitializer extends ChannelInboundHandlerAdapter {
    private final StartupMessage startup;

    InboundChannelInitializer(StartupMessage startup) {
        this.startup = startup;
    }

    @Override
    public void channelActive(ChannelHandlerContext context) {
        ProtocolStream.this.ctx = context;
        scheduler = NettyScheduler.forEventExecutor(ctx.executor());
        if (!config.useSsl()) {
            writeStartupAndFixPipeline(context);
            return;
        }
        // With SSL the startup message is postponed until the handshake event arrives.
        write(SSLHandshake.INSTANCE);
    }

    @Override
    public void userEventTriggered(ChannelHandlerContext context, Object evt) {
        whenTypeOf(evt)
            .is(SslHandshakeCompletionEvent.class)
            .then(handshake -> onHandshakeCompleted(context, handshake));
    }

    /** Continues with the startup message on success, otherwise raises an exception in the pipeline. */
    private void onHandshakeCompleted(ChannelHandlerContext context, SslHandshakeCompletionEvent handshake) {
        if (!handshake.isSuccess()) {
            context.fireExceptionCaught(new SqlException("Failed to initialise SSL"));
            return;
        }
        writeStartupAndFixPipeline(context);
    }

    /** Writes the startup message and removes this handler from the pipeline. */
    private void writeStartupAndFixPipeline(ChannelHandlerContext context) {
        write(startup);
        context.pipeline().remove(this);
    }
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy