All downloads are FREE. Search and download functionality uses the official Maven repository.

com.github.pgasync.impl.PgConnection Maven / Gradle / Ivy

The newest version!
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.github.pgasync.impl;

import com.github.pgasync.Connection;
import com.github.pgasync.ResultSet;
import com.github.pgasync.Row;
import com.github.pgasync.Transaction;
import com.github.pgasync.impl.conversion.DataConverter;
import com.github.pgasync.impl.message.*;
import com.github.pgasync.impl.protocol.ProtocolStream;
import lombok.Getter;
import lombok.experimental.Accessors;
import rx.Completable;
import rx.Observable;
import rx.Single;
import rx.Subscriber;
import rx.functions.Action0;
import rx.functions.Func2;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

import static com.nurkiewicz.typeof.TypeOf.whenTypeOf;

/**
 * A connection to PostgreSQL backed. The postmaster forks a backend process for
 * each connection. A connection can process only a single queryRows at a time.
 *
 * @author Antti Laisi
 * @author Jacek Sokol
 */
public class PgConnection implements Connection {
    private static int NEXT_CONNECTION_NUMBER = 0;
    private final int number;
    private final ProtocolStream stream;
    private final DataConverter dataConverter;
    private long timeout = 0;
    private Completable setTimeout = Completable.complete();

    PgConnection(ProtocolStream stream, DataConverter dataConverter) {
        this.number = NEXT_CONNECTION_NUMBER++;
        this.stream = stream;
        this.dataConverter = dataConverter;
    }

    @Override
    public Completable close() {
        return stream.close();
    }

    @Override
    public Observable listen(String channel) {
        // TODO: wait for commit before sending unlisten as otherwise it can be rolled back
        AtomicBoolean stopped = new AtomicBoolean();
        Action0 ensureStopped = () -> {
            if (!stopped.get()) {
                stopped.set(true);
                querySet("UNLISTEN " + channel).subscribe();
            }
        };

        return querySet("LISTEN " + channel)
                .flatMapObservable(__ -> stream.listen(channel))
                .doOnUnsubscribe(ensureStopped)
                .doOnTerminate(ensureStopped);
    }

    @Override
    public Single begin() {
        return querySet("BEGIN").map(__ -> new PgConnectionTransaction());
    }

    @Override
    public Observable queryRows(String sql, Object... params) {
        ResultBuilder resultBuilder = new ResultBuilder(dataConverter);
        return sendCommand(sql, params)
                .lift(resultBuilder);
    }

    @Override
    public Single querySet(String sql, Object... params) {
        ResultBuilder resultBuilder = new ResultBuilder(dataConverter);
        Func2, Row, ArrayList> reducer = (list, row) -> {
            list.add(row);
            return list;
        };

        return sendCommand(sql, params)
                .lift(resultBuilder)
                .reduce(new ArrayList<>(), reducer)
                .map(rows -> PgResultSet.create(rows, resultBuilder.columns(), resultBuilder.updated()))
                .last()
                .toSingle();
    }

    @Override
    public Connection withTimeout(long value, TimeUnit timeUnit) {
        long millis = timeUnit.toMillis(value);
        if (millis != timeout) {
            timeout = millis;
            setTimeout = sendCommand("SET statement_timeout = " + millis, new Object[]{}).last().toCompletable();
        }

        return this;
    }

    @Override
    public boolean isConnected() {
        return stream.isConnected();
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;

        PgConnection that = (PgConnection) o;

        return number == that.number;
    }

    @Override
    public int hashCode() {
        return number;
    }

    @Override
    public String toString() {
        return "PgConnection{" +
                "number=" + number +
                ", timeout=" + timeout +
                ", connected=" + stream.isConnected() +
                '}';
    }

    private Observable sendCommand(String sql, Object[] params) {
        Observable command = (params == null || params.length == 0)
                ? stream.command(new Query(sql))
                : stream.command(
                new Parse(sql),
                new Bind(dataConverter.fromParameters(params)),
                ExtendedQuery.DESCRIBE,
                ExtendedQuery.EXECUTE,
                ExtendedQuery.CLOSE,
                ExtendedQuery.SYNC
        );

        return setTimeout.doOnCompleted(() -> setTimeout = Completable.complete()).andThen(command);
    }

    Single connect(String username, String password, String database) {
        return stream.connect(new StartupMessage(username, database))
                .flatMapCompletable(auth ->
                        auth.success()
                                ? Completable.complete()
                                : stream.authenticate(new PasswordMessage(username, password, auth.md5salt()))
                )
                .toSingleDefault(this);
    }

    @Getter
    @Accessors(fluent = true)
    static class ResultBuilder implements Observable.Operator {
        private final DataConverter dataConverter;

        private Map columns;
        private int updated;

        ResultBuilder(DataConverter dataConverter) {
            this.dataConverter = dataConverter;
        }

        @Override
        public Subscriber call(Subscriber subscriber) {
            return new Subscriber(subscriber) {
                @Override
                public void onCompleted() {
                    subscriber.onCompleted();
                }

                @Override
                public void onError(Throwable e) {
                    subscriber.onError(e);
                }

                @Override
                public void onNext(Message o) {
                    whenTypeOf(o)
                            .is(DataRow.class).then(dataRow -> subscriber.onNext(PgRow.create(dataRow, columns, dataConverter)))
                            .is(RowDescription.class).then(rowDescription -> columns = readColumns(rowDescription.columns()))
                            .is(CommandComplete.class).then(commandComplete -> updated = commandComplete.updatedRows());
                }
            };
        }

        private Map readColumns(RowDescription.ColumnDescription[] descriptions) {
            Map columns = new HashMap<>();

            for (int i = 0; i < descriptions.length; i++) {
                String columnName = descriptions[i].name().toUpperCase();
                PgColumn pgColumn = PgColumn.create(i, descriptions[i].type());
                columns.put(columnName, pgColumn);
            }

            return columns;
        }
    }

    /**
     * Transaction that rollbacks the tx on backend error and closes the connection on COMMIT/ROLLBACK failure.
     */
    class PgConnectionTransaction implements Transaction {
        @Override
        public Single begin() {
            return querySet("SAVEPOINT sp_1").map(rs -> new PgConnectionNestedTransaction(1));
        }

        @Override
        public Completable commit() {
            return PgConnection.this.querySet("COMMIT")
                    .toCompletable()
                    .onErrorResumeNext(this::closeStream);
        }

        @Override
        public Completable rollback() {
            return PgConnection.this.querySet("ROLLBACK")
                    .toCompletable()
                    .onErrorResumeNext(this::closeStream);
        }

        @Override
        public Observable queryRows(String sql, Object... params) {
            return PgConnection.this.queryRows(sql, params)
                    .onErrorResumeNext(this::doRollback);
        }

        @Override
        public Single querySet(String sql, Object... params) {
            return PgConnection.this.querySet(sql, params)
                    .onErrorResumeNext(t -> this.doRollback(t).toSingle());
        }

        @Override
        public Transaction withTimeout(long timeout, TimeUnit timeUnit) {
            PgConnection.this.withTimeout(timeout, timeUnit);
            return this;
        }

        private  Observable doRollback(Throwable t) {
            return PgConnection.this.isConnected()
                    ? rollback().andThen(Observable.error(t))
                    : Observable.error(t);
        }

        private Completable closeStream(Throwable exception) {
            return stream.close().onErrorComplete().andThen(Completable.error(exception));
        }
    }

    /**
     * Nested Transaction using savepoints.
     */
    class PgConnectionNestedTransaction extends PgConnectionTransaction {
        final int depth;

        PgConnectionNestedTransaction(int depth) {
            this.depth = depth;
        }

        @Override
        public Single begin() {
            return querySet("SAVEPOINT sp_" + (depth + 1))
                    .map(rs -> new PgConnectionNestedTransaction(depth + 1));
        }

        @Override
        public Completable commit() {
            return PgConnection.this.querySet("RELEASE SAVEPOINT sp_" + depth)
                    .toCompletable();
        }

        @Override
        public Completable rollback() {
            return PgConnection.this.querySet("ROLLBACK TO SAVEPOINT sp_" + depth)
                    .toCompletable();
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy