All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.mariadb.r2dbc.MariadbCommonStatement Maven / Gradle / Ivy

// SPDX-License-Identifier: Apache-2.0
// Copyright (c) 2020-2024 MariaDB Corporation Ab

package org.mariadb.r2dbc;

import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.Nonnull;
import org.mariadb.r2dbc.api.MariadbStatement;
import org.mariadb.r2dbc.client.Client;
import org.mariadb.r2dbc.client.MariadbResult;
import org.mariadb.r2dbc.codec.Codecs;
import org.mariadb.r2dbc.message.Protocol;
import org.mariadb.r2dbc.message.ServerMessage;
import org.mariadb.r2dbc.util.Assert;
import org.mariadb.r2dbc.util.Binding;
import org.mariadb.r2dbc.util.ServerPrepareResult;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;

public abstract class MariadbCommonStatement implements MariadbStatement {
  public static final int UNKNOWN_SIZE = -1;
  protected final List bindings = new ArrayList<>();
  protected final Client client;
  protected final String initialSql;
  protected final MariadbConnectionConfiguration configuration;
  protected final ExceptionFactory factory;
  protected int expectedSize;
  protected String[] generatedColumns;
  private Binding currentBinding;

  public MariadbCommonStatement(
      Client client, String sql, MariadbConnectionConfiguration configuration) {
    this.client = client;
    this.configuration = configuration;
    this.initialSql = Assert.requireNonNull(sql, "sql must not be null");
    this.factory = ExceptionFactory.withSql(sql);
  }

  /**
   * Augments an SQL statement with a {@code RETURNING} statement and column names. If the
   * collection is empty, uses {@code *} for column names.
   *
   * @param sql the SQL to augment
   * @param generatedColumns the names of the columns to augment with
   * @return an augmented sql statement returning the specified columns or a wildcard
   * @throws IllegalArgumentException if {@code sql} or {@code generatedColumns} is {@code null}
   */
  public static String augment(String sql, String[] generatedColumns) {
    Assert.requireNonNull(sql, "sql must not be null");
    Assert.requireNonNull(generatedColumns, "generatedColumns must not be null");
    return String.format(
        "%s RETURNING %s",
        sql, generatedColumns.length == 0 ? "*" : String.join(", ", generatedColumns));
  }

  protected static void tryNextBinding(
      Iterator iterator, Sinks.Many bindingSink, AtomicBoolean canceled) {

    if (canceled.get()) {
      return;
    }

    try {
      if (iterator.hasNext()) {
        bindingSink.emitNext(iterator.next(), Sinks.EmitFailureHandler.FAIL_FAST);
      } else {
        bindingSink.emitComplete(Sinks.EmitFailureHandler.FAIL_FAST);
      }
    } catch (Exception e) {
      bindingSink.emitError(e, Sinks.EmitFailureHandler.FAIL_FAST);
    }
  }

  protected void initializeBinding() {
    currentBinding = new Binding(getExpectedSize());
  }

  public MariadbStatement add() {
    currentBinding.validate(getExpectedSize());
    this.bindings.add(currentBinding);
    currentBinding = new Binding(getExpectedSize());
    return this;
  }

  @Override
  public MariadbStatement bind(String identifier, Object value) {
    return bind(getColumnIndex(identifier), value);
  }

  @Override
  public MariadbStatement bindNull(String identifier, Class type) {
    return bindNull(getColumnIndex(identifier), type);
  }

  @Override
  public MariadbStatement bindNull(int index, Class type) {
    if (index < 0) {
      throw new IndexOutOfBoundsException(
          String.format("wrong index value %d, index must be positive", index));
    }
    if (index >= expectedSize && expectedSize != UNKNOWN_SIZE) {

      throw new IndexOutOfBoundsException(
          (getExpectedSize() == 0)
              ? String.format(
                  "Binding parameters is not supported for the statement '%s'", initialSql)
              : String.format(
                  "Cannot bind parameter %d, statement has %d parameters", index, expectedSize));
    }
    getCurrentBinding().add(index, Codecs.encodeNull(type, index));
    return this;
  }

  @Override
  public MariadbStatement bind(int index, Object value) {
    Assert.requireNonNull(value, "value must not be null");
    if (index < 0) {
      throw new IndexOutOfBoundsException(
          String.format("wrong index value %d, index must be positive", index));
    }

    getCurrentBinding().add(index, Codecs.encode(value, index));
    return this;
  }

  protected abstract int getColumnIndex(String name);

  @Nonnull
  protected Binding getCurrentBinding() {
    return currentBinding;
  }

  public Flux toResult(
      final Protocol protocol,
      Flux messages,
      ExceptionFactory factory,
      AtomicReference prepareResult) {
    return messages
        .doOnDiscard(ReferenceCounted.class, ReferenceCountUtil::release)
        .windowUntil(ServerMessage::resultSetEnd)
        .map(
            dataRow ->
                new MariadbResult(
                    protocol,
                    prepareResult,
                    dataRow,
                    factory,
                    generatedColumns,
                    client.getVersion().supportReturning(),
                    configuration))
        .flatMap(m -> client.redirect().then(Mono.just(m)))
        .cast(org.mariadb.r2dbc.api.MariadbResult.class);
  }

  protected void clearBindings(Iterator iterator, AtomicBoolean canceled) {
    canceled.set(true);
    while (iterator.hasNext()) {
      iterator.next().clear();
    }
    this.bindings.forEach(Binding::clear);
  }

  protected int getExpectedSize() {
    return expectedSize;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy