All downloads are free. Search and download functionality uses the official Maven repository.

com.github.pgasync.impl.PgConnection Maven / Gradle / Ivy

There is a newer version: 0.9
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.github.pgasync.impl;

import com.github.pgasync.Connection;
import com.github.pgasync.ResultSet;
import com.github.pgasync.Row;
import com.github.pgasync.Transaction;
import com.github.pgasync.impl.conversion.DataConverter;
import com.github.pgasync.impl.message.*;
import rx.Observable;
import rx.Subscriber;
import rx.observers.Subscribers;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.logging.Logger;

import static com.github.pgasync.impl.message.RowDescription.ColumnDescription;

/**
 * A connection to PostgreSQL backed. The postmaster forks a backend process for
 * each connection. A connection can process only a single queryRows at a time.
 * 
 * @author Antti Laisi
 */
public class PgConnection implements Connection {

    final PgProtocolStream stream;
    final DataConverter dataConverter;

    public PgConnection(PgProtocolStream stream, DataConverter dataConverter) {
        this.stream = stream;
        this.dataConverter = dataConverter;
    }

    Observable connect(String username, String password, String database) {
       return stream.connect(new StartupMessage(username, database))
               .flatMap(message -> authenticate(username, password, message))
               .single(message -> message == ReadyForQuery.INSTANCE)
               .map(ready -> this);
    }

    Observable authenticate(String username, String password, Message message) {
        return message instanceof Authentication && !((Authentication) message).isAuthenticationOk()
                    ? stream.authenticate(new PasswordMessage(username, password, ((Authentication) message).getMd5Salt()))
                    : Observable.just(message);
    }

    boolean isConnected() {
        return stream.isConnected();
    }

    @Override
    public Observable querySet(String sql, Object... params) {
        return sendQuery(sql, params)
                .lift(toResultSet(dataConverter));
    }

    @Override
    public Observable queryRows(String sql, Object... params) {
        return sendQuery(sql, params)
                .lift(toRow(dataConverter));
    }

    @Override
    public Observable begin() {
        return querySet("BEGIN").map(rs -> new PgConnectionTransaction());
    }

    @Override
    public Observable listen(String channel) {
        // TODO: wait for commit before sending unlisten as otherwise it can be rolled back
        return querySet("LISTEN " + channel)
                .lift(subscriber -> Subscribers.create( rs -> stream.listen(channel)
                                                                        .subscribe(subscriber),
                                                                subscriber::onError))
                .doOnUnsubscribe(() -> querySet("UNLISTEN " + channel).subscribe(rs -> { }));
    }

    @Override
    public void close() throws Exception {
        CountDownLatch latch = new CountDownLatch(1);
        stream.close().subscribe(__ -> latch.countDown(), ex -> {
            Logger.getLogger(getClass().getName()).warning("Exception closing connection: " + ex);
            latch.countDown();
        });
        latch.await(1000, TimeUnit.MILLISECONDS);
    }

    private Observable sendQuery(String sql, Object[] params) {
        return params == null || params.length == 0
                ? stream.send(
                    new Query(sql))
                : stream.send(
                    new Parse(sql),
                    new Bind(dataConverter.fromParameters(params)),
                    ExtendedQuery.DESCRIBE,
                    ExtendedQuery.EXECUTE,
                    ExtendedQuery.CLOSE,
                    ExtendedQuery.SYNC);
    }

    static Observable.Operator toRow(DataConverter dataConverter) {
        return subscriber -> new Subscriber() {

            Map columns;

            @Override
            public void onNext(Message message) {
                if (message instanceof RowDescription) {
                    columns = getColumns(((RowDescription) message).getColumns());
                } else if(message instanceof DataRow) {
                    subscriber.onNext(new PgRow((DataRow) message, columns, dataConverter));
                }
            }
            @Override
            public void onError(Throwable e) {
                subscriber.onError(e);
            }
            @Override
            public void onCompleted() {
                subscriber.onCompleted();
            }
        };
    }

    static Observable.Operator toResultSet(DataConverter dataConverter) {
        return subscriber -> new Subscriber() {

            Map columns;
            List rows = new ArrayList<>();
            int updated;

            @Override
            public void onNext(Message message) {
                if (message instanceof RowDescription) {
                    columns = getColumns(((RowDescription) message).getColumns());
                } else if(message instanceof DataRow) {
                    rows.add(new PgRow((DataRow) message, columns, dataConverter));
                } else if(message instanceof CommandComplete) {
                    updated = ((CommandComplete) message).getUpdatedRows();
                } else if(message == ReadyForQuery.INSTANCE) {
                    subscriber.onNext(new PgResultSet(columns, rows, updated));
                }
            }
            @Override
            public void onError(Throwable e) {
                subscriber.onError(e);
            }
            @Override
            public void onCompleted() {
                subscriber.onCompleted();
            }
        };
    }

    static Map getColumns(ColumnDescription[] descriptions) {
        Map columns = new HashMap<>();
        for (int i = 0; i < descriptions.length; i++) {
            columns.put(descriptions[i].getName().toUpperCase(), new PgColumn(i, descriptions[i].getType()));
        }
        return columns;
    }

    /**
     * Transaction that rollbacks the tx on backend error and closes the connection on COMMIT/ROLLBACK failure.
     */
    class PgConnectionTransaction implements Transaction {

        @Override
        public Observable commit() {
            return PgConnection.this.querySet("COMMIT")
                    .map(rs -> (Void) null)
                    .doOnError(exception -> stream.close().subscribe());
        }
        @Override
        public Observable rollback() {
            return PgConnection.this.querySet("ROLLBACK")
                    .map(rs -> (Void) null)
                    .doOnError(exception -> stream.close().subscribe());
        }
        @Override
        public Observable queryRows(String sql, Object... params) {
            return PgConnection.this.queryRows(sql, params)
                    .onErrorResumeNext(this::doRollback);
        }
        @Override
        public Observable querySet(String sql, Object... params) {
            return PgConnection.this.querySet(sql, params)
                    .onErrorResumeNext(this::doRollback);
        }
         Observable doRollback(Throwable t) {
            return rollback().flatMap(__ -> Observable.error(t));
        }
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy