All downloads are free. Search and download functionality uses the official Maven repository.

io.r2dbc.postgresql.PostgresqlStatement Maven / Gradle / Ivy

/*
 * Copyright 2021 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.r2dbc.postgresql;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;
import io.r2dbc.postgresql.client.Binding;
import io.r2dbc.postgresql.client.ConnectionContext;
import io.r2dbc.postgresql.client.EncodedParameter;
import io.r2dbc.postgresql.client.SimpleQueryMessageFlow;
import io.r2dbc.postgresql.message.backend.BackendMessage;
import io.r2dbc.postgresql.message.backend.CommandComplete;
import io.r2dbc.postgresql.message.backend.EmptyQueryResponse;
import io.r2dbc.postgresql.message.backend.ErrorResponse;
import io.r2dbc.postgresql.message.frontend.Bind;
import io.r2dbc.postgresql.util.Assert;
import io.r2dbc.postgresql.util.GeneratedValuesUtils;
import io.r2dbc.postgresql.util.Operators;
import io.r2dbc.spi.Statement;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;

import javax.annotation.Nonnull;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Predicate;

import static io.r2dbc.postgresql.message.frontend.Execute.NO_LIMIT;
import static io.r2dbc.postgresql.util.PredicateUtils.or;

/**
 * A generic {@link Statement}.
 *
 * @since 0.9
 */
final class PostgresqlStatement implements io.r2dbc.postgresql.api.PostgresqlStatement {

    private static final Predicate WINDOW_UNTIL = or(CommandComplete.class::isInstance, EmptyQueryResponse.class::isInstance, ErrorResponse.class::isInstance);

    private final ArrayDeque bindings;

    private final ConnectionResources resources;

    private final ConnectionContext connectionContext;

    private final ParsedSql parsedSql;

    private int fetchSize;

    private String[] generatedColumns;

    PostgresqlStatement(ConnectionResources resources, String sql) {
        this.resources = Assert.requireNonNull(resources, "resources must not be null");
        this.parsedSql = PostgresqlSqlParser.parse(Assert.requireNonNull(sql, "sql must not be null"));
        this.connectionContext = resources.getClient().getContext();
        this.bindings = new ArrayDeque<>(this.parsedSql.getParameterCount());

        if (this.parsedSql.getStatementCount() > 1 && this.parsedSql.getParameterCount() > 0) {
            throw new IllegalArgumentException(String.format("Statement '%s' cannot be created. This is often due to the presence of both multiple statements and parameters at the same time.", sql));
        }

        fetchSize(this.resources.getConfiguration().getFetchSize(sql));
    }

    @Override
    public PostgresqlStatement add() {
        Binding binding = this.bindings.peekLast();
        if (binding != null) {
            binding.validate();
        }
        this.bindings.add(new Binding(this.parsedSql.getParameterCount()));
        return this;
    }

    @Override
    public PostgresqlStatement bind(String identifier, Object value) {
        return bind(getIdentifierIndex(identifier), value);
    }

    @Override
    public PostgresqlStatement bind(int index, Object value) {
        Assert.requireNonNull(value, "value must not be null");

        BindingLogger.logBind(this.connectionContext, index, value);
        getCurrentOrFirstBinding().add(index, this.resources.getCodecs().encode(value));
        return this;
    }

    @Override
    public PostgresqlStatement bindNull(String identifier, Class type) {
        return bindNull(getIdentifierIndex(identifier), type);
    }

    @Override
    public PostgresqlStatement bindNull(int index, Class type) {
        Assert.requireNonNull(type, "type must not be null");

        if (index >= this.parsedSql.getParameterCount()) {
            throw new UnsupportedOperationException(String.format("Cannot bind parameter %d, statement has %d parameters", index, this.parsedSql.getParameterCount()));
        }

        BindingLogger.logBindNull(this.connectionContext, index, type);
        getCurrentOrFirstBinding().add(index, this.resources.getCodecs().encodeNull(type));
        return this;
    }

    @Nonnull
    private Binding getCurrentOrFirstBinding() {
        Binding binding = this.bindings.peekLast();
        if (binding == null) {
            Binding newBinding = new Binding(this.parsedSql.getParameterCount());
            this.bindings.add(newBinding);
            return newBinding;
        } else {
            return binding;
        }
    }

    @Override
    public Flux execute() {
        if (this.generatedColumns == null) {
            return execute(this.parsedSql.getSql());
        }
        return execute(GeneratedValuesUtils.augment(this.parsedSql.getSql(), this.generatedColumns));
    }

    @Override
    public PostgresqlStatement returnGeneratedValues(String... columns) {
        Assert.requireNonNull(columns, "columns must not be null");

        if (this.parsedSql.hasDefaultTokenValue("RETURNING")) {
            throw new IllegalStateException("Statement already includes RETURNING clause");
        }

        if (!this.parsedSql.hasDefaultTokenValue("DELETE", "INSERT", "UPDATE")) {
            throw new IllegalStateException("Statement is not a DELETE, INSERT, or UPDATE command");
        }

        this.generatedColumns = columns;
        return this;
    }

    @Override
    public PostgresqlStatement fetchSize(int rows) {
        Assert.isTrue(rows >= 0, "fetch size must be greater or equal zero");
        this.fetchSize = rows;
        return this;
    }

    @Override
    public String toString() {
        return "PostgresqlStatement{" +
            "bindings=" + this.bindings +
            ", context=" + this.resources +
            ", sql='" + this.parsedSql.getSql() + '\'' +
            ", generatedColumns=" + Arrays.toString(this.generatedColumns) +
            '}';
    }

    Binding getCurrentBinding() {
        return getCurrentOrFirstBinding();
    }

    private int getIdentifierIndex(String identifier) {
        Assert.requireNonNull(identifier, "identifier must not be null");
        Assert.requireType(identifier, String.class, "identifier must be a String");
        if (!identifier.startsWith("$")) {
            throw new NoSuchElementException(String.format("\"%s\" is not a valid identifier", identifier));
        }
        try {
            return Integer.parseInt(identifier.substring(1)) - 1;
        } catch (NumberFormatException e) {
            throw new NoSuchElementException(String.format("\"%s\" is not a valid identifier", identifier));
        }
    }

    private Flux execute(String sql) {
        ExceptionFactory factory = ExceptionFactory.withSql(sql);

        if (this.parsedSql.getParameterCount() != 0) {
            // Extended query protocol
            if (this.bindings.size() == 0) {
                throw new IllegalStateException("No parameters have been bound");
            }

            this.bindings.forEach(Binding::validate);
            int fetchSize = this.fetchSize;
            return Flux.defer(() -> {

                // possible optimization: fetch all when statement is already prepared or first statement to be prepared
                if (this.bindings.size() == 1) {

                    Binding binding = this.bindings.peekFirst();
                    Flux messages = collectBindingParameters(binding).flatMapMany(values -> ExtendedFlowDelegate.runQuery(this.resources, factory, sql, binding, values, fetchSize));
                    return Flux.just(PostgresqlResult.toResult(this.resources, messages, factory));
                }

                Iterator iterator = this.bindings.iterator();
                Sinks.Many bindings = Sinks.many().unicast().onBackpressureBuffer();
                AtomicBoolean canceled = new AtomicBoolean();
                return bindings.asFlux()
                    .map(it -> {
                        Flux messages =
                            collectBindingParameters(it).flatMapMany(values -> ExtendedFlowDelegate.runQuery(this.resources, factory, sql, it, values, this.fetchSize)).doOnComplete(() -> tryNextBinding(iterator, bindings, canceled));

                        return PostgresqlResult.toResult(this.resources, messages, factory);
                    })
                    .doOnCancel(() -> clearBindings(iterator, canceled))
                    .doOnError(e -> clearBindings(iterator, canceled))
                    .doOnSubscribe(it -> bindings.emitNext(iterator.next(), Sinks.EmitFailureHandler.FAIL_FAST));

            }).cast(io.r2dbc.postgresql.api.PostgresqlResult.class);
        }

        Flux exchange;
        // Simple Query protocol
        if (this.fetchSize != NO_LIMIT) {
            exchange = ExtendedFlowDelegate.runQuery(this.resources, factory, sql, Binding.EMPTY, Collections.emptyList(), this.fetchSize);
        } else {
            exchange = SimpleQueryMessageFlow.exchange(this.resources.getClient(), sql);
        }

        return exchange.windowUntil(WINDOW_UNTIL)
            .doOnDiscard(ReferenceCounted.class, ReferenceCountUtil::release) // ensure release of rows within WindowPredicate
            .map(messages -> PostgresqlResult.toResult(this.resources, messages, factory))
            .as(Operators::discardOnCancel);
    }

    private static void tryNextBinding(Iterator iterator, Sinks.Many bindingSink, AtomicBoolean canceled) {

        if (canceled.get()) {
            return;
        }

        try {
            if (iterator.hasNext()) {
                bindingSink.emitNext(iterator.next(), Sinks.EmitFailureHandler.FAIL_FAST);
            } else {
                bindingSink.emitComplete(Sinks.EmitFailureHandler.FAIL_FAST);
            }
        } catch (Exception e) {
            bindingSink.emitError(e, Sinks.EmitFailureHandler.FAIL_FAST);
        }
    }

    private static Mono> collectBindingParameters(Binding binding) {

        return Flux.fromIterable(binding.getParameterValues())
            .concatMap(f -> {
                if (f == EncodedParameter.NULL_VALUE) {
                    return Flux.just(Bind.NULL_VALUE);
                } else {
                    return Flux.from(f)
                        .reduce(Unpooled.compositeBuffer(), (c, b) -> c.addComponent(true, b));
                }
            })
            .collectList();
    }

    private void clearBindings(Iterator iterator, AtomicBoolean canceled) {

        canceled.set(true);

        while (iterator.hasNext()) {
            // exhaust iterator, ignore returned elements
            iterator.next();
        }

        this.bindings.forEach(Binding::clear);
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy