All Downloads are FREE. Search and download functionality uses the official Maven repository.

org.mariadb.r2dbc.MariadbServerParameterizedQueryStatement Maven / Gradle / Ivy

The newest version!
// SPDX-License-Identifier: Apache-2.0
// Copyright (c) 2020-2024 MariaDB Corporation Ab

package org.mariadb.r2dbc;

import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import org.mariadb.r2dbc.api.MariadbStatement;
import org.mariadb.r2dbc.client.Client;
import org.mariadb.r2dbc.client.DecoderState;
import org.mariadb.r2dbc.message.Protocol;
import org.mariadb.r2dbc.message.ServerMessage;
import org.mariadb.r2dbc.message.client.ExecutePacket;
import org.mariadb.r2dbc.message.client.PreparePacket;
import org.mariadb.r2dbc.message.client.QueryPacket;
import org.mariadb.r2dbc.util.Assert;
import org.mariadb.r2dbc.util.Binding;
import org.mariadb.r2dbc.util.ServerNamedParamParser;
import org.mariadb.r2dbc.util.ServerPrepareResult;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;

final class MariadbServerParameterizedQueryStatement extends MariadbCommonStatement
    implements MariadbStatement {

  /**
   * Server-side prepare result currently associated with this statement's SQL, or {@code null}
   * when none is known yet. Held in an {@link AtomicReference} because it is read and replaced
   * from the reactive callbacks built in {@code execute()}.
   */
  private final AtomicReference<ServerPrepareResult> prepareResult;

  /**
   * Parser used to resolve named parameters. Stays {@code null} until a parameter is looked up by
   * name, at which point it is created from the initial SQL.
   */
  private ServerNamedParamParser paramParser;

  /**
   * Creates a statement that will be executed through a server-side prepared statement.
   *
   * @param client connection client used to send commands and to reach the prepare cache
   * @param sql SQL text as supplied by the caller
   * @param configuration connection configuration
   */
  MariadbServerParameterizedQueryStatement(
      Client client, String sql, MariadbConnectionConfiguration configuration) {
    super(client, sql, configuration);
    // The parameter count is resolved lazily by getExpectedSize().
    this.expectedSize = UNKNOWN_SIZE;
    this.paramParser = null;
    // execute() accepts a client without a prepare cache (it null-checks getPrepareCache()), so
    // the constructor must not dereference the cache unconditionally either.
    this.prepareResult =
        new AtomicReference<>(
            client.getPrepareCache() != null ? client.getPrepareCache().get(sql) : null);
    initializeBinding();
  }

  /**
   * Returns the number of parameters this statement expects, resolving and caching it on first
   * use.
   *
   * <p>Resolution order: the parameter count of the known prepare result, then the count from the
   * named-parameter parser when one exists, and finally a fresh parse of the initial SQL.
   *
   * @return the expected number of bound parameters
   */
  @Override
  protected int getExpectedSize() {
    // Already resolved: reuse the cached value.
    if (expectedSize != UNKNOWN_SIZE) {
      return expectedSize;
    }

    if (prepareResult.get() != null) {
      expectedSize = prepareResult.get().getNumParams();
    } else if (paramParser != null) {
      expectedSize = paramParser.getParamCount();
    } else {
      // No parser yet: parse once just to count, without keeping the parser.
      expectedSize =
          ServerNamedParamParser.parameterParts(initialSql, this.client.noBackslashEscapes())
              .getParamCount();
    }
    return expectedSize;
  }

  /**
   * Resolves a named parameter to its zero-based position.
   *
   * <p>The named-parameter parser is created from the initial SQL on the first call and reused
   * afterwards.
   *
   * @param name parameter name to look up; must not be {@code null}
   * @return index of the first parameter whose name equals {@code name}
   * @throws NoSuchElementException if no parameter carries that name
   */
  protected int getColumnIndex(String name) {
    Assert.requireNonNull(name, "identifier cannot be null");

    // Lazily build the parser; execute() later uses it to obtain the SQL actually sent.
    if (paramParser == null) {
      paramParser =
          ServerNamedParamParser.parameterParts(initialSql, this.client.noBackslashEscapes());
    }

    int index = 0;
    while (index < this.paramParser.getParamNameList().size()) {
      if (name.equals(this.paramParser.getParamNameList().get(index))) {
        return index;
      }
      index++;
    }

    throw new NoSuchElementException(
        String.format(
            "No parameter with name '%s' found (possible values %s)",
            name, this.paramParser.getParamNameList().toString()));
  }

  /**
   * Registers the columns whose generated values should be returned by the statement.
   *
   * @param columns column names; must not be {@code null}
   * @return this statement, for call chaining
   * @throws IllegalArgumentException if more than one column is requested while the server
   *     version does not support RETURNING
   */
  @Override
  public MariadbServerParameterizedQueryStatement returnGeneratedValues(String... columns) {
    Assert.requireNonNull(columns, "columns must not be null");

    boolean returningSupported = client.getVersion().supportReturning();
    if (!returningSupported && columns.length > 1) {
      throw new IllegalArgumentException(
          "returnGeneratedValues can have only one column before MariaDB 10.5.1");
    }

    this.generatedColumns = columns;
    return this;
  }

  /**
   * Builds the publisher that executes this statement.
   *
   * <p>Three paths exist, selected from the expected parameter count and the queued bindings:
   *
   * <ul>
   *   <li>no parameters: the SQL is sent as a {@code QueryPacket} and decoded with {@code
   *       Protocol.TEXT};
   *   <li>parameters and no queued binding ("single query"): the statement is prepared if needed
   *       and one {@code ExecutePacket} is sent, decoded with {@code Protocol.BINARY};
   *   <li>parameters and queued bindings ("batch"): the statement is prepared once, then one
   *       {@code ExecutePacket} is sent per binding, each after the previous one completes.
   * </ul>
   *
   * <p>The SQL selection, the prepare-cache lookup and the validation of the current binding run
   * eagerly when this method is called; everything else is wrapped in {@code Flux.defer}.
   *
   * <p>NOTE(review): the raw {@code Flux}, {@code Iterator} and {@code Sinks.Many} types in this
   * listing look like generic arguments lost during extraction - confirm against the upstream
   * source before compiling.
   *
   * @return publisher of the execution results
   */
  @Override
  public Flux execute() {
    // paramParser is only created by getColumnIndex(); when present, its getRealSql() is sent
    // instead of the SQL given by the caller.
    String realSql = paramParser == null ? this.initialSql : paramParser.getRealSql();
    String sql;
    // augment(...) is defined outside this block; it is only applied when generated columns were
    // requested and the server supports RETURNING.
    if (this.generatedColumns == null || !client.getVersion().supportReturning()) {
      sql = realSql;
    } else {
      sql = augment(realSql, this.generatedColumns);
    }
    ExceptionFactory factory = ExceptionFactory.withSql(sql);

    // Second chance to find an already prepared statement, now keyed by the final SQL.
    if (prepareResult.get() == null && client.getPrepareCache() != null) {
      prepareResult.set(client.getPrepareCache().get(sql));
    }
    if (this.getExpectedSize() != 0) {
      // Validation happens here, outside the deferred part, so it runs on the execute() call.
      this.getCurrentBinding().validate(this.getExpectedSize());
      return Flux.defer(
          () -> {
            if (this.bindings.size() == 0) {
              // single query
              // Capture the binding to send and reset the current binding before sending.
              Binding binding = getCurrentBinding();
              this.initializeBinding();

              if (prepareResult.get() != null) {
                // If the cache now holds a different prepare result for this SQL, release the
                // one held so far and switch to the cached one.
                ServerPrepareResult res;
                if (this.client.getPrepareCache() != null
                    && (res = this.client.getPrepareCache().get(sql)) != null
                    && !res.equals(prepareResult.get())) {
                  prepareResult.get().decrementUse(client);
                  prepareResult.set(res);
                }

                // incrementUse() returning false means the prepare result cannot be used
                // anymore; in that case fall through and prepare again below.
                if (prepareResult.get().incrementUse()) {
                  Flux messages =
                      this.client
                          .sendCommand(
                              new ExecutePacket(
                                  sql,
                                  prepareResult.get(),
                                  binding.getBindResultParameters(
                                      prepareResult.get().getNumParams())),
                              DecoderState.QUERY_RESPONSE,
                              sql,
                              false)
                          // Balance the incrementUse() above once the command terminates.
                          .doFinally(s -> prepareResult.get().decrementUse(client));
                  return toResult(Protocol.BINARY, messages, factory, prepareResult);
                } else {
                  // prepare is closing
                  prepareResult.set(null);
                }
              }
              // No usable prepare result: prepare and execute.
              Flux messages;
              if (configuration.allowPipelining()
                  && client.getVersion().isMariaDBServer()
                  && client.getVersion().versionGreaterOrEqual(10, 2, 0)) {
                // Pipelining allowed on MariaDB >= 10.2.0: prepare and execute packets are
                // handed to the client in a single sendCommand call; the execute packet is
                // built without a prepare result.
                messages =
                    this.client.sendCommand(
                        new PreparePacket(sql),
                        new ExecutePacket(sql, null, binding.getBinds()),
                        false);
              } else {
                // Otherwise prepare first, remember the result, then send the execute packet.
                messages =
                    client
                        .sendPrepare(new PreparePacket(sql), factory, sql)
                        .flatMapMany(
                            serverPrepareResult -> {
                              prepareResult.set(serverPrepareResult);
                              return this.client.sendCommand(
                                  new ExecutePacket(
                                      sql,
                                      prepareResult.get(),
                                      binding.getBindResultParameters(
                                          prepareResult.get().getNumParams())),
                                  DecoderState.QUERY_RESPONSE,
                                  sql,
                                  false);
                            });
              }
              return toResult(Protocol.BINARY, messages, factory, prepareResult)
                  // Release the prepare result, if one was obtained, when the result terminates.
                  .doFinally(
                      s -> {
                        if (prepareResult.get() != null) {
                          prepareResult.get().decrementUse(client);
                        }
                      });
            }
            // batch
            // The binding being filled is queued as the last element of the batch.
            bindings.add(getCurrentBinding());
            this.initializeBinding();
            Iterator iterator = this.bindings.iterator();
            // The sink feeds bindings one at a time: the first on subscription, each following
            // one through tryNextBinding(...) when the previous command completes.
            // tryNextBinding and clearBindings are defined outside this block.
            Sinks.Many bindingSink = Sinks.many().unicast().onBackpressureBuffer();
            AtomicBoolean canceled = new AtomicBoolean();
            return prepareIfNotDone(sql, factory)
                .thenMany(
                    bindingSink
                        .asFlux()
                        .doOnComplete(() -> clearBindings(iterator, canceled))
                        .map(
                            binding -> {
                              Flux messages =
                                  this.client
                                      .sendCommand(
                                          new ExecutePacket(
                                              sql,
                                              prepareResult.get(),
                                              binding.getBindResultParameters(
                                                  prepareResult.get().getNumParams())),
                                          false)
                                      .doOnComplete(
                                          () -> tryNextBinding(iterator, bindingSink, canceled));

                              return toResult(Protocol.BINARY, messages, factory, prepareResult);
                            })
                        // Kick off the batch by emitting the first binding.
                        .doOnSubscribe(
                            it ->
                                bindingSink.emitNext(
                                    iterator.next(), Sinks.EmitFailureHandler.FAIL_FAST))
                        .doOnComplete(this.bindings::clear)
                        // Release the prepare result once for the whole batch.
                        .doFinally(
                            s -> {
                              if (prepareResult.get() != null) {
                                prepareResult.get().decrementUse(client);
                              }
                            })
                        .doOnCancel(() -> clearBindings(iterator, canceled))
                        .doOnError(e -> clearBindings(iterator, canceled)))
                // Flatten the per-binding result publishers into one stream.
                .flatMap(mariadbResultFlux -> mariadbResultFlux);
          });
    } else {
      // No parameter expected: send the SQL as a plain query, decoded with the text protocol.
      return Flux.defer(
          () -> {
            Flux messages =
                this.client.sendCommand(
                    new QueryPacket(sql), DecoderState.QUERY_RESPONSE, sql, false);
            return toResult(Protocol.TEXT, messages, factory, null);
          });
    }
  }

  /**
   * Ensures a usable server-side prepare result is available for {@code sql} before a batch is
   * executed.
   *
   * <p>An already known or cached prepare result is reused only when {@code incrementUse()}
   * accepts the new use. Otherwise the statement is prepared on the server and the new result is
   * stored in {@code prepareResult}.
   *
   * @param sql SQL text to prepare
   * @param factory exception factory passed to the prepare command
   * @return publisher completing once the prepare result is available
   */
  private Mono prepareIfNotDone(String sql, ExceptionFactory factory) {
    // Same guard as execute(): the client may have no prepare cache at all.
    if (prepareResult.get() == null && client.getPrepareCache() != null) {
      prepareResult.set(client.getPrepareCache().get(sql));
    }

    if (prepareResult.get() != null) {
      // incrementUse() returns false when the prepare result can no longer be used (execute()
      // handles this as "prepare is closing"); executing against it would target a statement
      // being released, so drop it and prepare again instead.
      if (prepareResult.get().incrementUse()) {
        return Mono.just(prepareResult.get());
      }
      prepareResult.set(null);
    }

    // prepare command, if not already done
    return client
        .sendPrepare(new PreparePacket(sql), factory, sql)
        .doOnSuccess(p -> prepareResult.set(p));
  }

  /**
   * Renders the statement state for diagnostics: client, initial SQL, all bindings (queued ones
   * followed by the current one), configuration, generated columns and the prepare result.
   *
   * @return textual description of this statement
   */
  @Override
  public String toString() {
    // Copy the queued bindings so the current binding can be appended without touching the
    // statement's own list.
    List<Object> allBindings = new ArrayList<>(bindings);
    allBindings.add(getCurrentBinding());

    StringBuilder text = new StringBuilder("MariadbServerParameterizedQueryStatement{");
    text.append("client=").append(client);
    text.append(", sql='").append(initialSql).append('\'');
    text.append(", bindings=").append(Arrays.toString(allBindings.toArray()));
    text.append(", configuration=").append(configuration);
    text.append(", generatedColumns=")
        .append(generatedColumns != null ? Arrays.toString(generatedColumns) : null);
    text.append(", prepareResult=").append(prepareResult.get());
    text.append('}');
    return text.toString();
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy