org.mariadb.r2dbc.client.MariadbResult Maven / Gradle / Ivy
The newest version!
// SPDX-License-Identifier: Apache-2.0
// Copyright (c) 2020-2024 MariaDB Corporation Ab
package org.mariadb.r2dbc.client;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.util.AbstractReferenceCounted;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;
import io.r2dbc.spi.Row;
import io.r2dbc.spi.RowMetadata;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Predicate;
import org.mariadb.r2dbc.ExceptionFactory;
import org.mariadb.r2dbc.MariadbConnectionConfiguration;
import org.mariadb.r2dbc.message.Protocol;
import org.mariadb.r2dbc.message.ServerMessage;
import org.mariadb.r2dbc.message.server.*;
import org.mariadb.r2dbc.util.ServerPrepareResult;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
public class MariadbResult extends AbstractReferenceCounted
implements org.mariadb.r2dbc.api.MariadbResult {
private final Protocol protocol;
private final Flux messages;
private final ExceptionFactory factory;
private final String[] generatedColumns;
private final boolean supportReturning;
private final MariadbConnectionConfiguration conf;
private final AtomicReference prepareResult;
public MariadbResult(
Protocol protocol,
AtomicReference prepareResult,
Flux messages,
ExceptionFactory factory,
String[] generatedColumns,
boolean supportReturning,
MariadbConnectionConfiguration conf) {
this.protocol = protocol;
this.messages = messages;
this.factory = factory;
this.generatedColumns = generatedColumns;
this.supportReturning = supportReturning;
this.conf = conf;
this.prepareResult = prepareResult;
}
public static ByteBuf getLongTextEncoded(long value) {
byte[] byteValue = Long.toString(value).getBytes(StandardCharsets.US_ASCII);
byte[] encodedLength;
int length = byteValue.length;
encodedLength = new byte[] {(byte) length};
return Unpooled.copiedBuffer(encodedLength, byteValue);
}
@Override
@SuppressWarnings({"rawtypes", "unchecked"})
public Mono getRowsUpdated() {
// Since CLIENT_DEPRECATE_EOF is not set in order to identify output parameter
// number of updated row can be identified either by OK_Packet or number of rows in case of
// RETURNING
final AtomicLong rowCount = new AtomicLong(0);
return this.messages
.handle(
(serverMessage, sink) -> {
if (serverMessage instanceof OkPacket) {
OkPacket okPacket = ((OkPacket) serverMessage);
sink.next(okPacket.value());
return;
}
if (serverMessage instanceof ErrorPacket) {
sink.error(this.factory.from((ErrorPacket) serverMessage));
return;
}
if (serverMessage instanceof EofPacket) {
EofPacket eofPacket = ((EofPacket) serverMessage);
if (eofPacket.resultSetEnd()) {
sink.next(rowCount.get());
rowCount.set(0);
}
return;
}
if (serverMessage instanceof RowPacket) {
rowCount.incrementAndGet();
serverMessage.release();
}
})
.collectList()
.handle(
(list, sink) -> {
if (list.isEmpty()) {
return;
}
int sum = 0;
for (Long i : list) {
sum += i;
}
sink.next(sum);
sink.complete();
});
}
@Override
@SuppressWarnings({"rawtypes", "unchecked"})
public Flux map(BiFunction f) {
final List columns = new ArrayList<>();
final AtomicBoolean metaFollows = new AtomicBoolean(true);
final AtomicReference rowConstructor =
new AtomicReference<>();
final AtomicReference meta = new AtomicReference<>();
return this.messages.handle(
(message, sink) -> {
if (message instanceof ErrorPacket) {
sink.error(this.factory.from((ErrorPacket) message));
return;
}
if (message instanceof CompletePrepareResult) {
this.prepareResult.set(((CompletePrepareResult) message).getPrepare());
return;
}
if (message instanceof ColumnCountPacket) {
metaFollows.set(((ColumnCountPacket) message).isMetaFollows());
if (!metaFollows.get()) {
columns.addAll(Arrays.asList(this.prepareResult.get().getColumns()));
}
return;
}
if (message instanceof OkPacket) {
OkPacket okPacket = ((OkPacket) message);
// This is for server that doesn't permit RETURNING: rely on OK_packet LastInsertId
// to retrieve the last generated ID.
if (generatedColumns != null && !supportReturning) {
String colName = generatedColumns.length > 0 ? generatedColumns[0] : "ID";
MariadbRowMetadata tmpMeta =
new MariadbRowMetadata(
new ColumnDefinitionPacket[] {
ColumnDefinitionPacket.fromGeneratedId(colName, conf)
});
if (okPacket.value() > 1) {
sink.error(
this.factory.createException(
"Connector cannot get generated ID (using returnGeneratedValues) multiple"
+ " rows before MariaDB 10.5.1",
"HY000",
-1));
return;
}
ByteBuf buf = getLongTextEncoded(okPacket.getLastInsertId());
org.mariadb.r2dbc.api.MariadbRow row = new MariadbRowText(buf, tmpMeta, factory);
sink.next(f.apply(row, row.getMetadata()));
ReferenceCountUtil.release(row);
if (okPacket.ending()) sink.complete();
}
return;
}
if (message instanceof ColumnDefinitionPacket) {
columns.add((ColumnDefinitionPacket) message);
return;
}
if (message instanceof EofPacket) {
EofPacket eof = (EofPacket) message;
if (!eof.ending()) {
rowConstructor.set(
protocol == Protocol.TEXT ? MariadbRowText::new : MariadbRowBinary::new);
ColumnDefinitionPacket[] columnsArray =
columns.toArray(new ColumnDefinitionPacket[0]);
meta.set(new MariadbRowMetadata(columnsArray));
// in case metadata follows and prepared statement, update meta
if (prepareResult != null && prepareResult.get() != null && metaFollows.get()) {
prepareResult.get().setColumns(columnsArray);
}
} else sink.complete();
return;
}
if (message instanceof RowPacket) {
try {
org.mariadb.r2dbc.api.MariadbRow row =
rowConstructor.get().create(((RowPacket) message).getRaw(), meta.get(), factory);
sink.next(f.apply(row, meta.get()));
} finally {
message.release();
}
}
});
}
@Override
public org.mariadb.r2dbc.api.MariadbResult filter(Predicate filter) {
return MariadbSegmentResult.toResult(
protocol, prepareResult, messages, factory, generatedColumns, supportReturning, conf)
.filter(filter);
}
@Override
public Publisher flatMap(
Function> mappingFunction) {
return MariadbSegmentResult.toResult(
protocol, prepareResult, messages, factory, generatedColumns, supportReturning, conf)
.flatMap(mappingFunction);
}
@Override
protected void deallocate() {
// drain messages for cleanup
this.getRowsUpdated().subscribe();
}
@Override
public ReferenceCounted touch(Object hint) {
return this;
}
@Override
public String toString() {
return "MariadbResult{}";
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy