All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.mariadb.r2dbc.client.MariadbSegmentResult Maven / Gradle / Ivy

The newest version!
// SPDX-License-Identifier: Apache-2.0
// Copyright (c) 2020-2024 MariaDB Corporation Ab

package org.mariadb.r2dbc.client;

import io.netty.buffer.ByteBuf;
import io.netty.util.AbstractReferenceCounted;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;
import io.r2dbc.spi.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Predicate;
import org.mariadb.r2dbc.ExceptionFactory;
import org.mariadb.r2dbc.MariadbConnectionConfiguration;
import org.mariadb.r2dbc.api.MariadbResult;
import org.mariadb.r2dbc.message.Protocol;
import org.mariadb.r2dbc.message.ServerMessage;
import org.mariadb.r2dbc.message.server.*;
import org.mariadb.r2dbc.util.Assert;
import org.mariadb.r2dbc.util.ServerPrepareResult;
import org.mariadb.r2dbc.util.constants.ServerStatus;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

public final class MariadbSegmentResult extends AbstractReferenceCounted implements MariadbResult {

  private final Flux segments;

  private MariadbSegmentResult(Flux segments) {
    this.segments = segments;
  }

  MariadbSegmentResult(
      Protocol protocol,
      AtomicReference prepareResult,
      Flux messages,
      ExceptionFactory factory,
      String[] generatedColumns,
      boolean supportReturning,
      MariadbConnectionConfiguration conf) {

    final List columns = new ArrayList<>();
    final AtomicBoolean metaFollows = new AtomicBoolean(true);
    final AtomicReference rowConstructor =
        new AtomicReference<>();
    final AtomicReference meta = new AtomicReference<>();
    final AtomicBoolean isOutputParameter = new AtomicBoolean();
    this.segments =
        messages.handle(
            (message, sink) -> {
              if (message instanceof CompletePrepareResult) {
                prepareResult.set(((CompletePrepareResult) message).getPrepare());
                return;
              }

              if (message instanceof ColumnCountPacket) {
                metaFollows.set(((ColumnCountPacket) message).isMetaFollows());
                if (!metaFollows.get()) {
                  columns.addAll(Arrays.asList(prepareResult.get().getColumns()));
                }
                return;
              }

              if (message instanceof ColumnDefinitionPacket) {
                columns.add((ColumnDefinitionPacket) message);
                return;
              }

              if (message instanceof EofPacket) {
                EofPacket eof = (EofPacket) message;
                if (!eof.ending()) {
                  rowConstructor.set(
                      protocol == Protocol.TEXT ? MariadbRowText::new : MariadbRowBinary::new);
                  ColumnDefinitionPacket[] columnsArray =
                      columns.toArray(new ColumnDefinitionPacket[0]);

                  meta.set(new MariadbRowMetadata(columnsArray));

                  // in case metadata follows and prepared statement, update meta
                  if (prepareResult != null && prepareResult.get() != null && metaFollows.get()) {
                    prepareResult.get().setColumns(columnsArray);
                  }

                  isOutputParameter.set(
                      (eof.getServerStatus() & ServerStatus.PS_OUT_PARAMETERS) > 0);
                }
                return;
              }

              if (message instanceof ErrorPacket) {
                sink.next(new MariadbErrorSegment((ErrorPacket) message, factory));
                return;
              }

              if (message instanceof OkPacket) {

                if (generatedColumns != null && !supportReturning) {
                  String colName = generatedColumns.length > 0 ? generatedColumns[0] : "ID";
                  MariadbRowMetadata tmpMeta =
                      new MariadbRowMetadata(
                          new ColumnDefinitionPacket[] {
                            ColumnDefinitionPacket.fromGeneratedId(colName, conf)
                          });
                  if (((OkPacket) message).value() > 1) {
                    sink.error(
                        factory.createException(
                            "Connector cannot get generated ID (using returnGeneratedValues)"
                                + " multiple rows before MariaDB 10.5.1",
                            "HY000",
                            -1));
                    return;
                  }
                  ByteBuf buf =
                      org.mariadb.r2dbc.client.MariadbResult.getLongTextEncoded(
                          ((OkPacket) message).getLastInsertId());
                  org.mariadb.r2dbc.api.MariadbRow row = new MariadbRowText(buf, tmpMeta, factory);
                  sink.next(new MariadbRowSegment(row, buf));
                }

                Long rowCount = ((OkPacket) message).value();
                if (rowCount != null) {
                  sink.next(new MariadbUpdateCountSegment(rowCount));
                }
                return;
              }

              if (message instanceof RowPacket) {
                RowPacket row = ((RowPacket) message);
                if (isOutputParameter.get()) {
                  org.mariadb.r2dbc.api.MariadbOutParameters outParameters =
                      new MariadbOutParameters(
                          row.getRaw(), new MariadbOutParametersMetadata(columns), factory);
                  sink.next(new MariadbOutSegment(outParameters, (RowPacket) message));
                } else {
                  org.mariadb.r2dbc.api.MariadbRow rowSegment =
                      rowConstructor.get().create(row.getRaw(), meta.get(), factory);
                  sink.next(new MariadbRowSegment(rowSegment, (RowPacket) message));
                }
              }
            });
  }

  static MariadbSegmentResult toResult(
      Protocol protocol,
      AtomicReference prepareResult,
      Flux messages,
      ExceptionFactory factory,
      String[] generatedColumns,
      boolean supportReturning,
      MariadbConnectionConfiguration conf) {
    return new MariadbSegmentResult(
        protocol, prepareResult, messages, factory, generatedColumns, supportReturning, conf);
  }

  @Override
  public Mono getRowsUpdated() {
    return this.segments
        .handle(
            (segment, sink) -> {
              try {
                if (segment instanceof MariadbErrorSegment) {
                  sink.error(((MariadbErrorSegment) segment).exception());
                  return;
                }

                if (segment instanceof Result.UpdateCount) {
                  sink.next((int) (((Result.UpdateCount) segment).value()));
                }

              } finally {
                ReferenceCountUtil.release(segment);
              }
            })
        .collectList()
        .handle(
            (list, sink) -> {
              if (list.isEmpty()) {
                return;
              }

              long sum = 0;

              for (Integer integer : list) {
                sum += integer;
              }

              sink.next(sum);
            });
  }

  @Override
  public  Flux map(BiFunction f) {
    Assert.requireNonNull(f, "f must not be null");

    return this.segments.handle(
        (segment, sink) -> {
          try {
            if (segment instanceof MariadbErrorSegment) {
              sink.error(((MariadbErrorSegment) segment).exception());
              return;
            }

            if (segment instanceof Result.RowSegment) {
              Result.RowSegment row = (Result.RowSegment) segment;
              sink.next(f.apply(row.row(), row.row().getMetadata()));
            }

          } finally {
            ReferenceCountUtil.release(segment);
          }
        });
  }

  @Override
  public MariadbSegmentResult filter(Predicate filter) {
    Assert.requireNonNull(filter, "filter must not be null");
    return new MariadbSegmentResult(
        this.segments.filter(
            it -> {
              boolean result = filter.test(it);
              if (!result) {
                ReferenceCountUtil.release(it);
              }
              return result;
            }));
  }

  @Override
  @SuppressWarnings("unchecked")
  public  Publisher flatMap(
      Function> mappingFunction) {
    Assert.requireNonNull(mappingFunction, "mappingFunction must not be null");
    return this.segments.concatMap(
        segment -> {
          Publisher result = mappingFunction.apply(segment);
          if (result == null) {
            return Mono.error(new IllegalStateException("The mapper returned a null Publisher"));
          }

          if (result instanceof Mono) {
            return ((Mono) result).doFinally(s -> ReferenceCountUtil.release(segment));
          }

          return Flux.from(result).doFinally(s -> ReferenceCountUtil.release(segment));
        });
  }

  @Override
  protected void deallocate() {
    this.getRowsUpdated().subscribe();
  }

  @Override
  public ReferenceCounted touch(Object hint) {
    return this;
  }

  @Override
  public String toString() {
    return "MariadbSegmentResult{segments=" + this.segments + '}';
  }

  static class MariadbRowSegment extends AbstractReferenceCounted implements Result.RowSegment {
    private final Row row;
    private final ReferenceCounted releaseable;

    public MariadbRowSegment(Row row, ReferenceCounted releaseable) {
      this.row = row;
      this.releaseable = releaseable;
    }

    @Override
    public Row row() {
      return this.row;
    }

    @Override
    protected void deallocate() {
      ReferenceCountUtil.release(this.releaseable);
    }

    @Override
    public ReferenceCounted touch(Object hint) {
      return this;
    }
  }

  static class MariadbOutSegment extends AbstractReferenceCounted implements Result.OutSegment {

    private final OutParameters outParameters;

    private final ReferenceCounted releaseable;

    public MariadbOutSegment(OutParameters outParameters, ReferenceCounted releaseable) {
      this.outParameters = outParameters;
      this.releaseable = releaseable;
    }

    @Override
    public OutParameters outParameters() {
      return outParameters;
    }

    @Override
    protected void deallocate() {
      ReferenceCountUtil.release(this.releaseable);
    }

    @Override
    public ReferenceCounted touch(Object hint) {
      return this;
    }
  }

  static class MariadbUpdateCountSegment implements Result.UpdateCount {

    private final long value;

    public MariadbUpdateCountSegment(long value) {
      this.value = value;
    }

    @Override
    public long value() {
      return this.value;
    }
  }

  static class MariadbErrorSegment implements Result.Message {

    private final ExceptionFactory factory;

    private final ErrorPacket error;

    public MariadbErrorSegment(ErrorPacket error, ExceptionFactory factory) {
      this.factory = factory;
      this.error = error;
    }

    @Override
    public R2dbcException exception() {
      return this.factory.from(error);
    }

    @Override
    public int errorCode() {
      return error.errorCode();
    }

    @Override
    public String sqlState() {
      return error.sqlState();
    }

    @Override
    public String message() {
      return error.getMessage();
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy