/*
* Licensed to Crate.io GmbH ("Crate") under one or more contributor
* license agreements. See the NOTICE file distributed with this work for
* additional information regarding copyright ownership. Crate licenses
* this file to you under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. You may
* obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* However, if you have executed another commercial license agreement
* with Crate these terms will supersede the license and you may use the
* software solely pursuant to the terms of the relevant commercial agreement.
*/
package org.finos.legend.engine.postgres;
import com.google.common.net.InetAddresses;
import com.sun.security.jgss.GSSUtil;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslHandler;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.api.trace.StatusCode;
import io.opentelemetry.api.trace.Tracer;
import io.opentelemetry.context.Scope;
import org.finos.legend.engine.postgres.auth.AuthenticationMethod;
import org.finos.legend.engine.postgres.auth.AuthenticationMethodType;
import org.finos.legend.engine.postgres.auth.AuthenticationProvider;
import org.finos.legend.engine.postgres.auth.KerberosIdentityProvider;
import org.finos.legend.engine.postgres.config.GSSConfig;
import org.finos.legend.engine.postgres.handler.PostgresResultSetMetaData;
import org.finos.legend.engine.postgres.types.PGType;
import org.finos.legend.engine.postgres.types.PGTypes;
import org.finos.legend.engine.postgres.utils.OpenTelemetryUtil;
import org.finos.legend.engine.shared.core.identity.Identity;
import org.finos.legend.engine.shared.core.kerberos.SubjectTools;
import org.finos.legend.engine.shared.core.operational.Assert;
import org.ietf.jgss.GSSContext;
import org.ietf.jgss.GSSCredential;
import org.ietf.jgss.GSSException;
import org.ietf.jgss.GSSManager;
import org.ietf.jgss.GSSName;
import org.ietf.jgss.Oid;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;
import javax.net.ssl.SSLSession;
import javax.security.auth.Subject;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.sql.ParameterMetaData;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Properties;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import static org.finos.legend.engine.postgres.FormatCodes.getFormatCode;
/**
* Netty Handler/FrameDecoder for the Postgres wire protocol. This class handles the message
* flow and dispatching
*
*
*
* Client Server
*
* (optional ssl negotiation)
*
*
* | SSLRequest |
* |--------------------------------->|
* | |
* | 'S' | 'N' | error | (supported in Enterprise version)
* |<---------------------------------|
*
*
* startup:
* The authentication flow is handled by implementations of {@link AuthenticationMethod}.
*
* | |
* | StartupMessage |
* |--------------------------------->|
* | |
* | Authentication Method |
* | or |
* | AuthenticationOK |
* | or |
* | ErrorResponse |
* |<---------------------------------|
* | |
* | ParameterStatus |
* |<---------------------------------|
* | |
* | ReadyForQuery |
* |<---------------------------------|
*
*
* Simple Query:
*
* + +
* | Q (query) |
* |--------------------------------->|
* | |
* | RowDescription |
* |<---------------------------------|
* | |
* | DataRow |
* |<---------------------------------|
* | DataRow |
* |<---------------------------------|
* | CommandComplete |
* |<---------------------------------|
* | ReadyForQuery |
* |<---------------------------------|
*
* Extended Query
*
* + +
* | Parse |
* |--------------------------------->|
* | |
* | ParseComplete or ErrorResponse |
* |<---------------------------------|
* | |
* | Describe Statement (optional) |
* |--------------------------------->|
* | |
* | ParameterDescription (optional) |
* |<-------------------------------- |
* | |
* | RowDescription (optional) |
* |<-------------------------------- |
* | |
* | Bind |
* |--------------------------------->|
* | |
* | BindComplete or ErrorResponse |
* |<---------------------------------|
* | |
* | Describe Portal (optional) |
* |--------------------------------->|
* | |
* | RowDescription (optional) |
* |<-------------------------------- |
* | |
* | Execute |
* |--------------------------------->|
* | |
* | DataRow | |
* | CommandComplete | |
* | EmptyQueryResponse | |
* | ErrorResponse |
* |<---------------------------------|
* | |
* | Sync |
* |--------------------------------->|
* | |
* | ReadyForQuery |
* |<---------------------------------|
*
*
* Take a look at {@link Messages} to see how the messages are structured.
*
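* See the PostgreSQL documentation, chapter "Frontend/Backend Protocol", section "Message Flow"
* (https://www.postgresql.org/docs/current/protocol-flow.html) for a more detailed
* description of the message flow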
*/
public class PostgresWireProtocol
{
// Logger shared by the protocol handler and its nested classes.
private static final Logger LOGGER = LoggerFactory.getLogger(
PostgresWireProtocol.class);
// Numeric server version; not read anywhere in this file.
public static int SERVER_VERSION_NUM = 100500;
// Sent to the client as the "server_version" parameter status after authentication.
public static String PG_SERVER_VERSION = "10.5";
// Decoder whose state, message type and payload length drive the dispatching in MessageHandler.
final PgDecoder decoder;
// Inbound handler that dispatches every decoded message.
final MessageHandler handler;
// Factory used to create the session once authentication succeeded.
private final SessionsFactory sessions;
/* private final Function getAccessControl;*/
// Resolves which authentication method applies to a connecting user.
private final AuthenticationProvider authService;
// GSS/Kerberos settings; when null, GSS authentication requests are rejected.
private final GSSConfig gssConfig;
/* private final Consumer addTransportHandler;
*/
// Wrapper around the Netty channel; assigned on channel registration and cleared on unregistration.
private DelayableWriteChannel channel;
// Session of the authenticated user; null until authentication succeeded and after it was closed.
Session session;
// Set when handling a message failed; while set, every message except Sync ('S') is skipped.
private boolean ignoreTillSync = false;
// Authentication state; only non-null between the startup message and the end of authentication.
private AuthenticationContext authContext;
// Key/value pairs read from the client's startup message.
private Properties properties;
// Writer for all protocol messages sent to the client.
private Messages messages;
/**
 * Wires up the frame decoder and the message handler of the protocol.
 *
 * @param sessions      factory used to create the {@link Session} once authentication succeeded
 * @param authService   resolves the authentication method for the connecting user
 * @param gssConfig     GSS/Kerberos settings; may be {@code null}, in which case GSS
 *                      authentication requests are answered with an authentication error
 * @param getSslContext supplier of the SSL context, handed to the {@link PgDecoder}
 * @param messages      writer for the protocol messages sent to the client
 */
public PostgresWireProtocol(SessionsFactory sessions,
                            AuthenticationProvider authService,
                            GSSConfig gssConfig, Supplier<SslContext> getSslContext,
                            Messages messages)
{
    this.sessions = sessions;
    this.authService = authService;
    this.decoder = new PgDecoder(getSslContext);
    this.handler = new MessageHandler();
    this.gssConfig = gssConfig;
    this.messages = messages;
}
/**
 * Reads a null-terminated UTF-8 string from the buffer and consumes the terminator.
 *
 * @param buffer buffer positioned at the first byte of the string
 * @return the decoded string without its terminator, or {@code null} when the computed
 *         length (bytes before the terminator plus one) is zero, i.e. no terminator was found
 */
static String readCString(ByteBuf buffer)
{
    int lengthWithTerminator = buffer.bytesBefore((byte) 0) + 1;
    byte[] raw = new byte[lengthWithTerminator];
    if (lengthWithTerminator == 0)
    {
        return null;
    }
    buffer.readBytes(raw);
    // The last byte is the terminator and is not part of the string.
    return new String(raw, 0, lengthWithTerminator - 1, StandardCharsets.UTF_8);
}
/**
 * Reads a null-terminated UTF-8 string from the buffer as a {@code char[]} and consumes
 * the terminator. A char array is used instead of a {@link String} so that the caller
 * can handle the value as a secret.
 *
 * @param buffer buffer positioned at the first byte of the string
 * @return exactly the decoded characters (without the terminator and without any
 *         padding), or {@code null} if the buffer contains no terminator
 */
private static char[] readCharArray(ByteBuf buffer)
{
    int length = buffer.bytesBefore((byte) 0);
    if (length < 0)
    {
        return null;
    }
    byte[] bytes = new byte[length];
    buffer.readBytes(bytes);
    // Consume the terminator so that it does not become part of the returned value.
    buffer.skipBytes(1);
    java.nio.CharBuffer decoded = StandardCharsets.UTF_8.decode(ByteBuffer.wrap(bytes));
    // The backing array of the decoded buffer may be larger than the decoded content
    // (multi-byte sequences yield fewer chars than bytes), so copy only the valid range.
    char[] chars = new char[decoded.remaining()];
    decoded.get(chars);
    return chars;
}
/**
 * Reads {@code payloadLength} bytes from the buffer.
 *
 * @param buffer        buffer to read from
 * @param payloadLength number of bytes to read
 * @return the bytes read, or {@code null} when {@code payloadLength} is zero
 */
private static byte[] readByteArray(ByteBuf buffer, int payloadLength)
{
    if (payloadLength != 0)
    {
        byte[] payload = new byte[payloadLength];
        buffer.readBytes(payload);
        return payload;
    }
    return null;
}
/**
 * Reads the key/value pairs of a startup message until no further null-terminated
 * string can be read from the buffer.
 *
 * Pairs with an empty key or an empty value are skipped.
 *
 * @param buffer buffer positioned at the first key
 * @return the parameters sent by the client
 * @throws IllegalArgumentException if a non-empty key is not followed by a
 *                                  null-terminated value
 */
private Properties readStartupMessage(ByteBuf buffer)
{
    Properties startupParameters = new Properties();
    String key;
    while ((key = readCString(buffer)) != null)
    {
        String value = readCString(buffer);
        LOGGER.trace("payload: key={} value={}", key, value);
        if (key.isEmpty())
        {
            continue;
        }
        if (value == null)
        {
            // Properties does not accept null values; fail with a descriptive error
            // instead of a bare NullPointerException.
            throw new IllegalArgumentException(
                "Malformed startup message: value of parameter \"" + key + "\" is not null-terminated");
        }
        if (!value.isEmpty())
        {
            startupParameters.setProperty(key, value);
        }
    }
    return startupParameters;
}
/**
 * Completion callback that finishes a query cycle: it reports a failure wrapped in a
 * {@link CompletionException} to the client and then sends ReadyForQuery, unless the
 * failure is (or is caused by) a {@link ClientInterrupted}.
 */
private static class ReadyForQueryCallback implements BiConsumer<Object, Throwable>
{
    private final Channel channel;
    private final Messages messages;

    private ReadyForQueryCallback(Channel channel, Messages messages)
    {
        this.channel = channel;
        this.messages = messages;
    }

    /**
     * @param result result of the completed stage; not used
     * @param t      failure of the completed stage, or {@code null} on success
     */
    @Override
    public void accept(Object result, Throwable t)
    {
        if (t instanceof CompletionException)
        {
            Throwable actualCause = t.getCause();
            PostgresServerException postgresServerException = PostgresServerException.wrapException(actualCause);
            messages.sendErrorResponse(channel, postgresServerException);
        }
        boolean clientInterrupted = t instanceof ClientInterrupted
            || (t != null && t.getCause() instanceof ClientInterrupted);
        if (!clientInterrupted)
        {
            messages.sendReadyForQuery(channel);
        }
    }
}
/**
 * Netty inbound handler that dispatches every inbound buffer according to the state
 * and message type reported by the {@link PgDecoder}.
 */
private class MessageHandler extends SimpleChannelInboundHandler<ByteBuf>
{
    /**
     * Wraps the registered Netty channel into the {@link DelayableWriteChannel} used
     * for all further writes.
     */
    @Override
    public void channelRegistered(ChannelHandlerContext ctx) throws Exception
    {
        channel = new DelayableWriteChannel(ctx.channel());
    }

    /** Accepts every inbound message. */
    @Override
    public boolean acceptInboundMessage(Object msg) throws Exception
    {
        return true;
    }

    /**
     * Dispatches one inbound buffer. Any failure is logged and reported to the client,
     * and all following messages are skipped until the client sends a Sync.
     */
    @Override
    public void channelRead0(ChannelHandlerContext ctx, ByteBuf buffer) throws Exception
    {
        Assert.assertTrue(channel != null, () -> "Channel must be initialized");
        try
        {
            dispatchState(buffer, channel);
        }
        catch (Throwable t)
        {
            ignoreTillSync = true;
            try
            {
                LOGGER.error("Unable to handle query", t);
                messages.sendErrorResponse(channel, t);
            }
            catch (Throwable ti)
            {
                LOGGER.error("Error trying to send error to client: {}", t, ti);
            }
        }
    }

    /**
     * Routes the buffer to the startup, cancel or regular message handling depending
     * on the decoder state.
     */
    private void dispatchState(ByteBuf buffer, DelayableWriteChannel channel) throws Exception
    {
        switch (decoder.state())
        {
            case STARTUP_PARAMETERS:
                handleStartupBody(buffer, channel);
                decoder.startupDone();
                return;
            case CANCEL:
                handleCancelRequestBody(buffer, channel);
                return;
            case MSG:
                LOGGER.trace("msg={} msgLength={} readableBytes={}", ((char) decoder.msgType()),
                    decoder.payloadLength(), buffer.readableBytes());
                if (ignoreTillSync && decoder.msgType() != 'S')
                {
                    // A previous message failed: drop the payload of everything but Sync.
                    buffer.skipBytes(decoder.payloadLength());
                    return;
                }
                dispatchMessage(buffer, channel);
                return;
            default:
                throw new IllegalStateException("Illegal state: " + decoder.state());
        }
    }

    /**
     * Routes a regular message to its handler based on the message type byte.
     * Unknown message types are answered with an error response.
     */
    private void dispatchMessage(ByteBuf buffer, DelayableWriteChannel channel) throws Exception
    {
        switch (decoder.msgType())
        {
            case 'Q': // Query (simple)
                LOGGER.trace("Dispatching simple query");
                handleSimpleQuery(buffer, channel);
                return;
            case 'P':
                LOGGER.trace("Dispatching parse");
                handleParseMessage(buffer, channel);
                return;
            case 'p':
                LOGGER.trace("Dispatching password");
                handlePassword(buffer, channel, decoder.payloadLength());
                return;
            case 'B':
                LOGGER.trace("Dispatching bind");
                handleBindMessage(buffer, channel);
                return;
            case 'D':
                LOGGER.trace("Dispatching describe");
                handleDescribeMessage(buffer, channel);
                return;
            case 'E':
                LOGGER.trace("Dispatching execute");
                handleExecute(buffer, channel);
                return;
            case 'H':
                LOGGER.trace("Dispatching flush");
                handleFlush(channel);
                return;
            case 'S':
                LOGGER.trace("Dispatching sync");
                handleSync(channel);
                return;
            case 'C':
                LOGGER.trace("Dispatching close ");
                handleClose(buffer, channel);
                return;
            case 'X': // Terminate (called when jdbc connection is closed)
                LOGGER.trace("Dispatching close session");
                closeSession();
                channel.close();
                return;
            default:
                messages.sendErrorResponse(channel,
                    new UnsupportedOperationException("Unsupported messageType: " + decoder.msgType()));
        }
    }

    /** Closes and forgets the current session, if there is one. */
    private void closeSession()
    {
        if (session != null)
        {
            session.close();
            session = null;
        }
    }

    /**
     * Closes the session when the client reset the connection; every other exception
     * is only logged.
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception
    {
        // Constant-first comparison: the exception message may be null.
        if (cause instanceof SocketException && "Connection reset".equals(cause.getMessage()))
        {
            LOGGER.info("Connection reset. Client likely terminated connection");
            closeSession();
        }
        else
        {
            LOGGER.error("Uncaught exception: ", cause);
        }
    }

    /** Drops the channel reference and closes the session when the channel is unregistered. */
    @Override
    public void channelUnregistered(ChannelHandlerContext ctx) throws Exception
    {
        LOGGER.trace("channelDisconnected");
        channel = null;
        closeSession();
        super.channelUnregistered(ctx);
    }
}
/**
 * Stores the parameters of the startup message and starts the authentication.
 *
 * @param buffer  buffer holding the body of the startup message
 * @param channel channel the authentication messages are written to
 */
private void handleStartupBody(ByteBuf buffer, Channel channel)
{
    this.properties = readStartupMessage(buffer);
    initAuthentication(channel);
}
/**
 * Returns the address of the remote peer of the channel.
 *
 * @param channel channel to inspect
 * @return the peer's address when the channel's remote address is an
 *         {@link InetSocketAddress}, otherwise the loopback address 127.0.0.1
 */
public static InetAddress getRemoteAddress(Channel channel)
{
    boolean hasInetAddress = channel.remoteAddress() instanceof InetSocketAddress;
    if (!hasInetAddress)
    {
        // E.g. an EmbeddedChannel (used in tests) has an EmbeddedSocketAddress without
        // an address; such a connection is treated like a local one via loopback.
        return InetAddresses.forString("127.0.0.1");
    }
    InetSocketAddress socketAddress = (InetSocketAddress) channel.remoteAddress();
    return socketAddress.getAddress();
}
/**
 * Returns the SSL session of the channel.
 *
 * @param channel channel whose pipeline is searched for an {@link SslHandler}
 * @return the session of the handler's SSL engine, or {@code null} when the pipeline
 *         contains no {@link SslHandler}
 */
public static SSLSession getSession(Channel channel)
{
    SslHandler tlsHandler = channel.pipeline().get(SslHandler.class);
    return tlsHandler == null ? null : tlsHandler.engine().getSession();
}
/**
 * Resolves the authentication method for the user named in the startup parameters and
 * starts the matching flow: a cleartext password request, a Kerberos request, or an
 * immediate authentication for every other method.
 *
 * An authentication error is sent when no method could be resolved, or when GSS was
 * resolved but no GSS configuration is present.
 *
 * @param channel channel the authentication messages are written to
 */
private void initAuthentication(Channel channel)
{
    String userName = properties.getProperty("user");
    InetAddress address = getRemoteAddress(channel);
    SSLSession sslSession = getSession(channel);
    ConnectionProperties connProperties = new ConnectionProperties(address, sslSession);
    AuthenticationMethod authMethod = authService.resolveAuthenticationType(userName, connProperties);

    if (authMethod == null)
    {
        messages.sendAuthenticationError(channel, String.format(
            Locale.ENGLISH,
            "No valid auth.host_based entry found for host \"%s\", user \"%s\". Did you enable TLS in your client?",
            address.getHostAddress(), userName
        ));
        return;
    }

    authContext = new AuthenticationContext(authMethod, connProperties, userName, LOGGER);
    if (authMethod.name() == AuthenticationMethodType.PASSWORD)
    {
        messages.sendAuthenticationCleartextPassword(channel);
    }
    else if (authMethod.name() == AuthenticationMethodType.GSS)
    {
        if (gssConfig == null)
        {
            messages.sendAuthenticationError(channel, "GSS Auth not configured in this server");
        }
        else
        {
            messages.sendAuthenticationKerberos(channel);
        }
    }
    else
    {
        // No further input is needed from the client for this method.
        finishAuthentication(channel);
    }
}
/**
 * Authenticates through the current authentication context and, on success, creates
 * the session. A failure is reported to the client and logged. The authentication
 * context is closed and cleared in every case.
 *
 * @param channel channel the result of the authentication is written to
 */
private void finishAuthentication(Channel channel)
{
    Assert.assertTrue(authContext != null, () -> "finishAuthentication() requires an authContext instance");
    try
    {
        handleAuthSuccess(channel, authContext.authenticate());
    }
    catch (Exception authFailure)
    {
        messages.sendAuthenticationError(channel, authFailure.getMessage());
        LOGGER.error("Auth Error", authFailure);
    }
    finally
    {
        // The context is single-use: release it regardless of the outcome.
        authContext.close();
        authContext = null;
    }
}
/**
 * Finishes a GSS authentication: derives the identity from the given subject and, on
 * success, creates the session. A failure is reported to the client and logged. The
 * authentication context is closed and cleared in every case.
 *
 * @param channel      channel the result of the authentication is written to
 * @param delegSubject subject the identity of the user is taken from
 */
private void finishAuthentication(Channel channel, Subject delegSubject)
{
    Assert.assertTrue(authContext != null, () -> "finishAuthentication() requires an authContext instance");
    try
    {
        handleAuthSuccess(channel, KerberosIdentityProvider.getIdentityForSubject(delegSubject));
    }
    catch (Exception authFailure)
    {
        messages.sendAuthenticationError(channel, authFailure.getMessage());
        LOGGER.error("Auth Error", authFailure);
    }
    finally
    {
        // The context is single-use: release it regardless of the outcome.
        authContext.close();
        authContext = null;
    }
}
/**
 * Creates the session for the authenticated user on the database named in the startup
 * parameters, stores the user name in the logging context and sends AuthenticationOK,
 * with listeners that send the parameter statuses and ReadyForQuery.
 *
 * @param channel           channel the messages are written to
 * @param authenticatedUser identity the session is created for
 * @throws Exception if the session cannot be created
 */
private void handleAuthSuccess(Channel channel, Identity authenticatedUser) throws Exception
{
    session = sessions.createSession(properties.getProperty("database"), authenticatedUser);
    MDC.put("user", authenticatedUser.getName());
    messages.sendAuthenticationOK(channel)
        .addListener(f -> sendParams(channel))
        .addListener(f -> messages.sendReadyForQuery(channel));
}
/* private void switchToTransportProtocol(Channel channel) {
var pipeline = channel.pipeline();
pipeline.remove("frame-decoder");
pipeline.remove("handler");
// SSL is already done via PostgreSQL handshake/auth
addTransportHandler.accept(pipeline);
}*/
/**
 * Sends the fixed set of ParameterStatus messages that follows a successful
 * authentication.
 *
 * @param channel channel the messages are written to
 */
private void sendParams(Channel channel)
{
    String[][] parameterStatuses = {
        {"server_version", PG_SERVER_VERSION},
        {"server_encoding", "UTF8"},
        {"client_encoding", "UTF8"},
        {"datestyle", "ISO"},
        {"TimeZone", "UTC"},
        {"integer_datetimes", "on"}
    };
    for (String[] status : parameterStatuses)
    {
        messages.sendParameterStatus(channel, status[0], status[1]);
    }
}
/**
 * Flush Message | 'H' | int32 len
 *
 * Flush forces the backend to deliver any data pending in its output buffers.
 *
 * This implementation does not buffer executions itself, so it triggers a sync of the
 * session right away; a failure is reported to the client as an error response.
 * TODO understand when this is called
 */
private void handleFlush(Channel channel)
{
    try
    {
        session.sync();
    }
    catch (Throwable failure)
    {
        messages.sendErrorResponse(channel, failure);
    }
}
/**
 * Parse Message header: | 'P' | int32 len
 *
 * body: | string statementName | string query | int16 numParamTypes | foreach param: | int32
 * type_oid (zero = unspecified)
 *
 * Parses the statement asynchronously; ParseComplete is sent once parsing finished
 * successfully. The resulting future is registered as the active execution of the session.
 */
private void handleParseMessage(ByteBuf buffer, final Channel channel)
{
    String statementName = readCString(buffer);
    final String query = readCString(buffer);
    short numParams = buffer.readShort();
    List<Integer> paramTypes = new ArrayList<>(numParams);
    for (int i = 0; i < numParams; i++)
    {
        int oid = buffer.readInt();
        int dataType = PGTypes.fromOID(oid);
        paramTypes.add(dataType);
    }
    CompletableFuture<?> parseCompletionFuture = session.parseAsync(statementName, query, paramTypes);
    CompletableFuture<Void> parseMsgSentFuture = parseCompletionFuture.thenRun(() -> messages.sendParseComplete(channel));
    session.setActiveExecution(parseMsgSentFuture);
}
/**
 * Handles a 'p' message, whose meaning depends on the authentication method in progress:
 * for GSS the payload is the client's GSS token, for every other method it is read as a
 * null-terminated password.
 *
 * @param buffer buffer positioned at the payload of the message
 * @param channel channel the responses are written to
 * @param payloadLength length of the payload; only used to read the GSS token
 */
private void handlePassword(ByteBuf buffer, final Channel channel, int payloadLength)
{
switch (authContext.getAuthenticationMethodType())
{
case GSS:
byte[] inputToken = readByteArray(buffer, payloadLength);
if (inputToken == null)
{
messages.sendErrorResponse(channel, new IllegalStateException("GSS Token cannot be empty"));
return;
}
// Log in as the server principal from the configured keytab; the acceptor credential is created as that subject.
Subject serverSubject = SubjectTools.getSubjectFromKeytab(gssConfig.getKerberosKeytabFile(), gssConfig.getKerberosUserPrincipal(), false);
GSSManager manager = GSSManager.getInstance();
try
{
GSSCredential gssCredential = Subject.doAs(serverSubject, new AcceptorCreator(manager, gssConfig.getKerberosUserPrincipal()));
GSSContext gssContext = manager.createContext(gssCredential);
gssContext.requestCredDeleg(true);
gssContext.requestMutualAuth(true);
byte[] outputToken;
if (!gssContext.isEstablished())
{
outputToken = gssContext.acceptSecContext(inputToken, 0, inputToken.length);
if (outputToken != null)
{
messages.sendGssOutToken(channel, outputToken);
}
}
// NOTE(review): the context is used right after a single acceptSecContext call, so this presumably
// relies on the handshake completing in one round trip; the context is also never disposed — confirm.
Subject delegatedSubject = GSSUtil.createSubject(gssContext.getSrcName(), gssContext.getDelegCred());
finishAuthentication(channel, delegatedSubject);
}
catch (PrivilegedActionException | GSSException e)
{
throw new RuntimeException(e);
}
break;
// PASSWORD shares the branch with every other (non-GSS) method.
case PASSWORD:
default:
char[] passwd = readCharArray(buffer);
if (passwd != null)
{
authContext.setSecurePassword(passwd);
}
finishAuthentication(channel);
}
}
/**
 * Bind Message Header: | 'B' | int32 len
 *
 * Body:
 *
 * | string portalName | string statementName
 * | int16 numFormatCodes
 * foreach
 * | int16 formatCode
 * | int16 numParams
 * foreach
 * | int32 valueLength
 * | byteN value
 * | int16 numResultColumnFormatCodes
 * foreach
 * | int16 formatCode
 *
 * Decodes the parameter values with the type registered for the statement and binds
 * them asynchronously; BindComplete is sent once binding finished successfully. The
 * resulting future is registered as the active execution of the session. A parameter
 * with an unsupported format code aborts the handling with an error response.
 */
private void handleBindMessage(ByteBuf buffer, Channel channel)
{
    String portalName = readCString(buffer);
    String statementName = readCString(buffer);
    FormatCodes.FormatCode[] formatCodes = FormatCodes.fromBuffer(buffer);
    short numParams = buffer.readShort();
    List<Object> params = createList(numParams);
    for (int i = 0; i < numParams; i++)
    {
        int valueLength = buffer.readInt();
        if (valueLength == -1)
        {
            // A length of -1 marks a NULL value; no value bytes follow.
            params.add(null);
        }
        else
        {
            int paramType = session.getParamType(statementName, i);
            PGType pgType = PGTypes.get(paramType, 0);
            FormatCodes.FormatCode formatCode = getFormatCode(formatCodes, i);
            switch (formatCode)
            {
                case TEXT:
                    params.add(pgType.readTextValue(buffer, valueLength));
                    break;
                case BINARY:
                    params.add(pgType.readBinaryValue(buffer, valueLength));
                    break;
                default:
                    messages.sendErrorResponse(
                        channel,
                        new UnsupportedOperationException(String.format(
                            Locale.ENGLISH,
                            "Unsupported format code '%d' for param '%s'",
                            formatCode.ordinal(),
                            paramType)
                        )
                    );
                    return;
            }
        }
    }
    FormatCodes.FormatCode[] resultFormatCodes = FormatCodes.fromBuffer(buffer);
    CompletableFuture<?> bindCompletionFuture = session.bindAsync(portalName, statementName, params, resultFormatCodes);
    CompletableFuture<Void> bindMsgSentFuture = bindCompletionFuture.thenRun(() -> messages.sendBindComplete(channel));
    session.setActiveExecution(bindMsgSentFuture);
}
/**
 * Creates the list that collects {@code size} values.
 *
 * @param size number of values that will be added
 * @return the shared immutable empty list when {@code size} is zero, otherwise a new
 *         {@link ArrayList} with that initial capacity
 */
private <T> List<T> createList(short size)
{
    return size == 0 ? Collections.emptyList() : new ArrayList<>(size);
}
/**
 * Describe Message Header: | 'D' | int32 len
 *
 * Body: | 'S' = prepared statement or 'P' = portal | string nameOfPortalOrStatement
 *
 * Describes the statement or portal asynchronously and answers with an optional
 * ParameterDescription (statements only) followed by RowDescription or NoData. The
 * resulting future is registered as the active execution of the session.
 *
 * The span and the active-metadata counter are released on every path: success, a
 * failure while sending, a failed describe future and a synchronous failure.
 *
 * @throws Exception if reading the message or starting the describe fails
 */
private void handleDescribeMessage(ByteBuf buffer, Channel channel) throws Exception
{
    OpenTelemetryUtil.TOTAL_METADATA.add(1);
    OpenTelemetryUtil.ACTIVE_METADATA.add(1);
    long startTime = System.currentTimeMillis();
    Tracer tracer = OpenTelemetryUtil.getTracer();
    Span span = tracer.spanBuilder("WireProtocol Handle Describe Message").startSpan();
    try (Scope ignored = span.makeCurrent())
    {
        CompletableFuture<Void> describeMsgSentFuture;
        try
        {
            byte type = buffer.readByte();
            String portalOrStatement = readCString(buffer);
            // handle() runs for both outcomes of the describe future, so the bookkeeping
            // in the finally block below is never skipped.
            describeMsgSentFuture = session.describeAsync((char) type, portalOrStatement)
                .handle((describeResult, failure) ->
                {
                    try
                    {
                        if (failure != null)
                        {
                            recordDescribeFailure(span, failure);
                            // Propagate the failure unchanged to the dependent stages.
                            throw failure instanceof CompletionException
                                ? (CompletionException) failure
                                : new CompletionException(failure);
                        }
                        try
                        {
                            PostgresResultSetMetaData fields = describeResult.getFields();
                            if (type == 'S')
                            {
                                ParameterMetaData parameters = describeResult.getParameters();
                                messages.sendParameterDescription(channel, parameters);
                            }
                            if (fields == null)
                            {
                                messages.sendNoData(channel);
                            }
                            else
                            {
                                FormatCodes.FormatCode[] resultFormatCodes =
                                    type == 'P' ? session.getResultFormatCodes(portalOrStatement) : null;
                                messages.sendRowDescription(channel, fields, resultFormatCodes);
                            }
                            OpenTelemetryUtil.TOTAL_SUCCESS_METADATA.add(1);
                            OpenTelemetryUtil.METADATA_DURATION.record(System.currentTimeMillis() - startTime);
                            return null;
                        }
                        catch (Exception e)
                        {
                            recordDescribeFailure(span, e);
                            throw PostgresServerException.wrapException(e);
                        }
                    }
                    finally
                    {
                        OpenTelemetryUtil.ACTIVE_METADATA.add(-1);
                        span.end();
                    }
                });
        }
        catch (Exception e)
        {
            // Failed before the callback was attached: nothing else will release the
            // span and the counter.
            recordDescribeFailure(span, e);
            OpenTelemetryUtil.ACTIVE_METADATA.add(-1);
            span.end();
            throw e;
        }
        session.setActiveExecution(describeMsgSentFuture);
    }
}

/** Marks the describe span as failed and counts the failure. */
private static void recordDescribeFailure(Span span, Throwable failure)
{
    span.setStatus(StatusCode.ERROR, "Failed to handle describe message");
    span.recordException(failure);
    OpenTelemetryUtil.TOTAL_FAILURE_METADATA.add(1);
}
/**
 * Execute Message Header: | 'E' | int32 len
 *
 * Body: | string portalName | int32 maxRows (0 = unlimited)
 *
 * Starts the asynchronous execution of the portal inside a tracing span. A failure is
 * recorded on the span and rethrown wrapped as a {@link PostgresServerException}.
 */
private void handleExecute(ByteBuf buffer, DelayableWriteChannel channel)
{
    Span executeSpan = OpenTelemetryUtil.getTracer().spanBuilder("WireProtocol Handle Execute").startSpan();
    try (Scope ignored = executeSpan.makeCurrent())
    {
        String portal = readCString(buffer);
        int rowLimit = buffer.readInt();
        // The execution is asynchronous and may run the query in another thread pool.
        // The results are sent to the client by the ResultSetReceiver created below,
        // whose channel.write calls may happen on a thread that is not a netty thread.
        // Netty then schedules those writes instead of running them immediately, so
        // this thread could go on processing further client messages and its replies
        // would overtake the results of this execute, e.g. a parseComplete could be
        // sent before the commandComplete of the previous statement.
        //
        // To keep the messages in order all writes are delayed; the "finish" logic of
        // the receiver writes out the pending messages and unblocks the channel.
        session.executeAsync(portal, rowLimit, q -> new ResultSetReceiver(q, channel, false, null, messages));
    }
    catch (Exception failure)
    {
        executeSpan.setStatus(StatusCode.ERROR, "Failed to execute");
        executeSpan.recordException(failure);
        throw PostgresServerException.wrapException(failure);
    }
    finally
    {
        executeSpan.end();
    }
}
/**
 * Handles a Sync message.
 *
 * After a failed message (messages were being skipped) the state of the session is
 * cleared and ReadyForQuery is sent immediately. Otherwise the session is synced and
 * ReadyForQuery is sent by the completion callback; when starting the sync fails, the
 * delayed writes are discarded and the error is followed by ReadyForQuery.
 *
 * @param channel channel the responses are written to
 */
private void handleSync(DelayableWriteChannel channel)
{
    if (!ignoreTillSync)
    {
        try
        {
            ReadyForQueryCallback onSyncDone = new ReadyForQueryCallback(channel, messages);
            session.sync().whenComplete(onSyncDone);
        }
        catch (Throwable failure)
        {
            channel.discardDelayedWrites();
            messages.sendErrorResponse(channel, failure);
            messages.sendReadyForQuery(channel);
        }
        return;
    }
    // An error happened earlier and everything up to this Sync was ignored: start clean.
    ignoreTillSync = false;
    session.clearState();
    messages.sendReadyForQuery(channel);
}
/**
 * | 'C' | int32 len | byte portalOrStatement | string portalOrStatementName |
 *
 * Closes the named portal or statement of the session and confirms with CloseComplete.
 */
private void handleClose(ByteBuf buffer, Channel channel)
{
    char targetKind = (char) buffer.readByte();
    String targetName = readCString(buffer);
    session.close(targetKind, targetName);
    messages.sendCloseComplete(channel);
}
/* void handleSimpleQuery(ByteBuf buffer, final DelayableWriteChannel channel) {
String queryString = readCString(buffer);
assert queryString != null : "query must not be nulL";
if (queryString.isEmpty() || ";".equals(queryString)) {
messages.sendEmptyQueryResponse(channel);
sendReadyForQuery(channel, TransactionState.IDLE);
return;
}
List statements;
try {
statements = SqlParser.createStatements(queryString);
} catch (Exception ex) {
messages.sendErrorResponse(channel, getAccessControl.apply(session.sessionSettings()), ex);
sendReadyForQuery(channel, TransactionState.IDLE);
return;
}
CompletableFuture> composedFuture = CompletableFuture.completedFuture(null);
for (var statement : statements) {
composedFuture = composedFuture.thenCompose(result -> handleSingleQuery(statement, channel));
}
composedFuture.whenComplete(new ReadyForQueryCallback(channel, TransactionState.IDLE));
}*/
/**
 * Handles a simple query ('Q') message inside a tracing span.
 *
 * An empty query (or a lone ";") is answered with EmptyQueryResponse and ReadyForQuery.
 * Otherwise the query string is split into single queries that are executed one after
 * the other; ReadyForQuery is sent by the completion callback of the last one.
 *
 * @param buffer  buffer positioned at the null-terminated query string
 * @param channel channel the responses are written to
 */
void handleSimpleQuery(ByteBuf buffer, final DelayableWriteChannel channel)
{
    Tracer tracer = OpenTelemetryUtil.getTracer();
    Span span = tracer.spanBuilder("WireProtocol Handle Simple Query").startSpan();
    try (Scope scope = span.makeCurrent())
    {
        String queryString = readCString(buffer);
        Assert.assertTrue(queryString != null, () -> "query must not be nulL");
        span.setAttribute("query", queryString);
        if (queryString.isEmpty() || ";".equals(queryString))
        {
            messages.sendEmptyQueryResponse(channel);
            messages.sendReadyForQuery(channel);
            return;
        }
        List<String> queries = QueryStringSplitter.splitQuery(queryString);
        // Chain the queries so that each one starts only after the previous completed.
        CompletableFuture<?> composedFuture = CompletableFuture.completedFuture(null);
        for (String query : queries)
        {
            composedFuture = composedFuture.thenCompose(result -> handleSingleQuery(query, channel));
        }
        composedFuture.whenComplete(new ReadyForQueryCallback(channel, messages));
    }
    catch (Exception e)
    {
        span.setStatus(StatusCode.ERROR, "Failed to handle simple query");
        span.recordException(e);
        throw e;
    }
    finally
    {
        span.end();
    }
}
private CompletableFuture> handleSingleQuery(String query, DelayableWriteChannel channel)
{
Tracer tracer = OpenTelemetryUtil.getTracer();
Span span = tracer.spanBuilder("WireProtocol Handle Simple Query").startSpan();
try (Scope scope = span.makeCurrent())
{
CompletableFuture> result = new CompletableFuture<>();
if (query.isEmpty() || ";".equals(query))
{
messages.sendEmptyQueryResponse(channel);
result.complete(null);
return result;
}
try
{
session.executeSimple(query, () -> new ResultSetReceiver(query, channel, true, null, messages));
return session.sync();
}
catch (Throwable t)
{
//TODO need to understand this usecase
LOGGER.warn("Error processing single query", t);
session.clearState();
messages.sendErrorResponse(channel, t);
result.completeExceptionally(t);
return result;
}
}
finally
{
span.end();
}
}
/**
 * Handles the body of a cancel request by closing the session of this handler and the
 * channel; the buffer is not read.
 *
 * A cancel request is sent by the client over a new connection, so this closes that
 * new connection, not the one running the query.
 */
private void handleCancelRequestBody(ByteBuf buffer, Channel channel)
{
    this.handler.closeSession();
    channel.close();
}
/**
 * Privileged action that creates the accept-only GSS credential of the given
 * principal for the Kerberos v5 and SPNEGO mechanisms.
 */
private static class AcceptorCreator implements PrivilegedExceptionAction<GSSCredential>
{
    private final GSSManager manager;
    private final String accountPrincipal;

    /**
     * @param manager          GSS manager used to create the name and the credential
     * @param accountPrincipal name of the principal the credential is created for
     */
    public AcceptorCreator(GSSManager manager, String accountPrincipal)
    {
        this.manager = manager;
        this.accountPrincipal = accountPrincipal;
    }

    /**
     * @return the accept-only credential with the default lifetime
     * @throws Exception if the name or the credential cannot be created
     */
    @Override
    public GSSCredential run() throws Exception
    {
        final GSSName gssName = manager.createName(this.accountPrincipal, GSSName.NT_USER_NAME);
        return manager
            .createCredential(gssName, GSSCredential.DEFAULT_LIFETIME, new Oid[]{
                new Oid("1.2.840.113554.1.2.2"), // Kerberos v5
                new Oid("1.3.6.1.5.5.2") // SPNEGO
            }, GSSCredential.ACCEPT_ONLY);
    }
}
}