com.github.pgasync.impl.netty.NettyPgProtocolStream Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of postgres-async-driver Show documentation
Show all versions of postgres-async-driver Show documentation
Asynchronous PostgreSQL Java driver
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.github.pgasync.impl.netty;
import com.github.pgasync.SqlException;
import com.github.pgasync.impl.PgProtocolStream;
import com.github.pgasync.impl.message.*;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.channel.*;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import rx.Observable;
import rx.Subscriber;
import java.io.IOException;
import java.net.SocketAddress;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.LinkedBlockingDeque;
/**
* Netty connection to PostgreSQL backend.
*
* @author Antti Laisi
*/
public class NettyPgProtocolStream implements PgProtocolStream {
final EventLoopGroup group;
final EventLoop eventLoop;
final SocketAddress address;
final boolean useSsl;
final boolean pipeline;
final GenericFutureListener> onError;
final Queue> subscribers;
final ConcurrentMap>> listeners = new ConcurrentHashMap<>();
ChannelHandlerContext ctx;
/**
 * Creates a stream that will connect to {@code address} using an event loop
 * taken from {@code group}.
 *
 * @param group    event loop group used for the connection; one loop is picked via {@code next()}
 * @param address  remote address to connect to
 * @param useSsl   SSL flag, stored as-is
 * @param pipeline whether sends issued off the event loop are re-submitted onto it
 */
public NettyPgProtocolStream(EventLoopGroup group, SocketAddress address, boolean useSsl, boolean pipeline) {
    this.group = group;
    this.eventLoop = group.next();
    this.address = address;
    this.useSsl = useSsl; // TODO: refactor into SSLConfig with trust parameters
    this.pipeline = pipeline;
    this.subscribers = new LinkedBlockingDeque<>(); // TODO: limit pipeline queue depth
    this.onError = future -> {
        if (future.isSuccess()) {
            return;
        }
        // Queue.peek() returns null when the queue is empty; without this guard a
        // failure arriving with no pending subscriber would raise a
        // NullPointerException inside the Netty listener callback.
        Subscriber<? super Message> head = subscribers.peek();
        if (head != null) {
            head.onError(future.cause());
        }
    };
}
/**
 * Connects to {@link #address} and sends the given startup message.
 *
 * <p>On subscription the subscriber is queued first, then a {@code Bootstrap}
 * is configured with {@link #group}, {@code NioSocketChannel} and a protocol
 * initializer wrapping the startup handler, and the connect is started.
 * A failed connect future is reported through {@link #onError}. Every emitted
 * item passes through {@code throwErrorResponses}.
 *
 * <p>NOTE(review): the return type carries no type argument here; presumably
 * {@code Observable<Message>} — confirm against {@code PgProtocolStream}.
 *
 * @param startup startup message handed to the startup handler
 * @return observable of the backend's replies
 */
@Override
public Observable connect(StartupMessage startup) {
    return Observable.create(replySink -> {
        pushSubscriber(replySink);

        Bootstrap bootstrap = new Bootstrap()
                .group(group)
                .channel(NioSocketChannel.class)
                .handler(newProtocolInitializer(newStartupHandler(startup)));

        ChannelFuture connecting = bootstrap.connect(address);
        connecting.addListener(onError);
    }).flatMap(this::throwErrorResponses);
}
/**
 * Sends a password message to the backend.
 *
 * <p>On subscription the subscriber is queued and the message is written and
 * flushed. Every emitted item passes through {@code throwErrorResponses}.
 *
 * <p>NOTE(review): the return type carries no type argument here; presumably
 * {@code Observable<Message>} — confirm against {@code PgProtocolStream}.
 *
 * @param password password message to write
 * @return observable of the backend's replies
 */
@Override
public Observable authenticate(PasswordMessage password) {
    return Observable.create(replySink -> {
        // Register before writing, mirroring connect() and send().
        pushSubscriber(replySink);
        write(password);
    }).flatMap(this::throwErrorResponses);
}
/**
 * Writes the given messages to the backend and streams back the replies.
 *
 * <p>On subscription:
 * <ul>
 *   <li>if the stream is not connected, the subscriber receives an
 *       {@link IllegalStateException} ("Channel is closed");</li>
 *   <li>if pipelining is enabled and the caller is not on {@link #eventLoop},
 *       queueing the subscriber and writing are handed off to the event loop;</li>
 *   <li>otherwise the subscriber is queued and the messages are written on the
 *       calling thread.</li>
 * </ul>
 * The resulting observable is lifted with {@code throwErrorResponsesOnComplete()}.
 *
 * <p>NOTE(review): the return type carries no type argument here; presumably
 * {@code Observable<Message>} — confirm against {@code PgProtocolStream}.
 *
 * @param messages messages to write, in order
 * @return observable of the backend's replies
 */
@Override
public Observable send(Message... messages) {
    return Observable.create(subscriber -> {
        if (!isConnected()) {
            subscriber.onError(new IllegalStateException("Channel is closed"));
            return;
        }

        if (pipeline && !eventLoop.inEventLoop()) {
            eventLoop.submit(() -> {
                // submit() stores any exception in the Future it returns, which is
                // discarded here; forward failures to the subscriber so it is not
                // left waiting forever.
                try {
                    pushSubscriber(subscriber);
                    write(messages);
                } catch (Exception e) {
                    subscriber.onError(e);
                }
            });
            return;
        }

        pushSubscriber(subscriber);
        write(messages);
    }).lift(throwErrorResponsesOnComplete());
}
/**
 * Reports whether the underlying channel is open.
 *
 * @return {@code true} when a channel context has been assigned and its
 *         channel is open; {@code false} otherwise
 */
@Override
public boolean isConnected() {
    // ctx has no initializer, so it is null until something assigns it;
    // treat that state as "not connected" instead of throwing.
    ChannelHandlerContext context = ctx;
    return context != null && context.channel().isOpen();
}
/**
 * Registers a consumer for notifications published on the given channel.
 *
 * <p>A subscription id is generated once per call. On subscription the
 * subscriber is stored under that id in the channel's consumer map (created on
 * first use); on unsubscription it is removed again.
 *
 * <p>NOTE(review): the return type carries no type argument here; presumably
 * {@code Observable<String>} — confirm against {@code PgProtocolStream}.
 *
 * @param channel notification channel name
 * @return observable that receives what {@code publishNotification} emits for the channel
 * @throws IllegalStateException from the unsubscribe action when no consumer is
 *         registered under this call's subscription id
 */
@Override
public Observable listen(String channel) {
    String subscriptionId = UUID.randomUUID().toString();
    return Observable.create(subscriber -> {
        // Atomically fetch the channel's consumer map, creating it on first use.
        Map<String, Subscriber<? super String>> consumers =
                listeners.computeIfAbsent(channel, name -> new ConcurrentHashMap<>());
        consumers.put(subscriptionId, subscriber);
    }).doOnUnsubscribe(() -> {
        Map<String, Subscriber<? super String>> consumers = listeners.get(channel);
        if (consumers == null || consumers.remove(subscriptionId) == null) {
            throw new IllegalStateException("No consumers on channel " + channel + " with id " + subscriptionId);
        }
    });
}
/**
 * Terminates the session and closes the channel.
 *
 * <p>On subscription a {@code Terminate} message is written and flushed. Once
 * that write future completes (its outcome is not inspected) the channel is
 * closed. A successful close emits a single {@code null} and completes; a
 * failed close signals the close future's cause.
 *
 * <p>NOTE(review): the return type carries no type argument here; presumably
 * {@code Observable<Void>} — confirm against {@code PgProtocolStream}.
 *
 * @return observable signalling the outcome of closing the channel
 */
@Override
public Observable close() {
    return Observable.create(subscriber -> {
        ChannelFuture terminated = ctx.writeAndFlush(Terminate.INSTANCE);
        terminated.addListener(written -> {
            ChannelFuture closing = ctx.close();
            closing.addListener(closed -> {
                if (closed.isSuccess()) {
                    subscriber.onNext(null);
                    subscriber.onCompleted();
                } else {
                    subscriber.onError(closed.cause());
                }
            });
        });
    });
}
private void pushSubscriber(Subscriber super Message> subscriber) {
if(!subscribers.offer(subscriber)) {
throw new IllegalStateException("Pipelining not enabled " + subscribers.peek());
}
}
/**
 * Writes each message to the channel context in order, attaching
 * {@link #onError} to every write future, then flushes once at the end.
 *
 * @param messages messages to write
 */
private void write(Message... messages) {
    for (int i = 0; i < messages.length; i++) {
        ChannelFuture pending = ctx.write(messages[i]);
        pending.addListener(onError);
    }
    // Single flush after all writes have been queued.
    ctx.flush();
}
/**
 * Delivers a notification's payload to every consumer registered for the
 * notification's channel. Does nothing when the channel has no consumer map.
 *
 * @param notification notification received from the backend
 */
private void publishNotification(NotificationResponse notification) {
    // NOTE(review): value type reconstructed from usage in listen() — confirm.
    Map<String, Subscriber<? super String>> consumers = listeners.get(notification.getChannel());
    if (consumers != null) {
        consumers.values().forEach(c -> c.onNext(notification.getPayload()));
    }
}
/**
 * Converts a backend message into an observable.
 *
 * @param message message received from the backend
 * @return an error observable carrying {@code toSqlException(...)} when the
 *         message is an {@code ErrorResponse}; otherwise a single-item
 *         observable of the message cast to {@code Message}
 */
private Observable throwErrorResponses(Object message) {
    if (message instanceof ErrorResponse) {
        ErrorResponse failure = (ErrorResponse) message;
        return Observable.error(toSqlException(failure));
    }
    return Observable.just((Message) message);
}
private static Observable.Operator throwErrorResponsesOnComplete() {
return subscriber -> new Subscriber
© 2015 - 2024 Weber Informatics LLC | Privacy Policy