org.jboss.ejb.protocol.remote.EJBClientChannel Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of jboss-ejb-client Show documentation
Show all versions of jboss-ejb-client Show documentation
Client library for EJB applications working against Wildfly - Jakarta EE Variant
/*
* JBoss, Home of Professional Open Source.
* Copyright 2017 Red Hat, Inc., and individual contributors
* as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.jboss.ejb.protocol.remote;
import static java.lang.Math.min;
import static org.jboss.ejb.protocol.remote.TCCLUtils.getAndSetSafeTCCL;
import static org.jboss.ejb.protocol.remote.TCCLUtils.resetTCCL;
import static org.xnio.Bits.allAreClear;
import static org.xnio.Bits.allAreSet;
import static org.xnio.IoUtils.safeClose;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.UnsupportedEncodingException;
import java.lang.reflect.Method;
import java.net.InetSocketAddress;
import java.net.URI;
import java.nio.channels.ClosedChannelException;
import java.nio.charset.StandardCharsets;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.zip.Deflater;
import java.util.zip.DeflaterOutputStream;
import java.util.zip.InflaterInputStream;
import javax.ejb.CreateException;
import javax.ejb.EJBException;
import javax.ejb.NoSuchEJBException;
import javax.transaction.InvalidTransactionException;
import javax.transaction.RollbackException;
import javax.transaction.SystemException;
import javax.transaction.Transaction;
import javax.transaction.xa.Xid;
import org.jboss.ejb._private.Keys;
import org.jboss.ejb._private.Logs;
import org.jboss.ejb.client.AbstractInvocationContext;
import org.jboss.ejb.client.Affinity;
import org.jboss.ejb.client.AttachmentKey;
import org.jboss.ejb.client.AttachmentKeys;
import org.jboss.ejb.client.ClusterAffinity;
import org.jboss.ejb.client.EJBClient;
import org.jboss.ejb.client.EJBClientContext;
import org.jboss.ejb.client.EJBClientInvocationContext;
import org.jboss.ejb.client.EJBLocator;
import org.jboss.ejb.client.EJBModuleIdentifier;
import org.jboss.ejb.client.EJBReceiverInvocationContext;
import org.jboss.ejb.client.EJBSessionCreationInvocationContext;
import org.jboss.ejb.client.NodeAffinity;
import org.jboss.ejb.client.RequestSendFailedException;
import org.jboss.ejb.client.SessionID;
import org.jboss.ejb.client.StatefulEJBLocator;
import org.jboss.ejb.client.StatelessEJBLocator;
import org.jboss.ejb.client.TransactionID;
import org.jboss.ejb.client.UserTransactionID;
import org.jboss.ejb.client.XidTransactionID;
import org.jboss.marshalling.ByteInput;
import org.jboss.marshalling.Marshaller;
import org.jboss.marshalling.MarshallerFactory;
import org.jboss.marshalling.Marshalling;
import org.jboss.marshalling.MarshallingConfiguration;
import org.jboss.marshalling.Unmarshaller;
import org.jboss.remoting3.Channel;
import org.jboss.remoting3.Connection;
import org.jboss.remoting3.ConnectionPeerIdentity;
import org.jboss.remoting3.MessageInputStream;
import org.jboss.remoting3.MessageOutputStream;
import org.jboss.remoting3.RemotingOptions;
import org.jboss.remoting3._private.IntIndexHashMap;
import org.jboss.remoting3._private.IntIndexMap;
import org.jboss.remoting3.util.Invocation;
import org.jboss.remoting3.util.InvocationTracker;
import org.jboss.remoting3.util.StreamUtils;
import org.wildfly.common.Assert;
import org.wildfly.common.function.ExceptionBiFunction;
import org.wildfly.common.net.CidrAddress;
import org.wildfly.discovery.Discovery;
import org.wildfly.naming.client.NamingProvider;
import org.wildfly.security.auth.client.AuthenticationContext;
import org.wildfly.transaction.client.AbstractTransaction;
import org.wildfly.transaction.client.LocalTransaction;
import org.wildfly.transaction.client.RemoteTransaction;
import org.wildfly.transaction.client.RemoteTransactionContext;
import org.wildfly.transaction.client.XAOutflowHandle;
import org.wildfly.transaction.client.provider.remoting.SimpleIdResolver;
import org.xnio.Cancellable;
import org.xnio.FutureResult;
import org.xnio.IoFuture;
/**
* @author David M. Lloyd
* @author Tomasz Adamski
* @author Richard Opalka
*/
@SuppressWarnings("deprecation")
class EJBClientChannel {
private final MarshallerFactory marshallerFactory;
private final Channel channel;
private final int version;
private final DiscoveredNodeRegistry discoveredNodeRegistry;
private final InvocationTracker invocationTracker;
private final MarshallingConfiguration configuration;
private final IntIndexMap userTxnIds = new IntIndexHashMap(UserTransactionID::getId);
private final RemoteTransactionContext transactionContext;
private final AtomicInteger finishedParts = new AtomicInteger(0);
private final AtomicReference> futureResultRef;
private final RetryExecutorWrapper retryExecutorWrapper;
/**
 * Create the client side of an EJB channel for an already negotiated protocol version.
 * <p>
 * Selects the marshalling configuration matching the protocol version, creates the invocation tracker
 * and registers this channel with the node information of the remote endpoint. The registration is
 * undone by a close handler when the channel closes.
 *
 * @param channel the remoting channel
 * @param version the negotiated protocol version
 * @param discoveredNodeRegistry the registry to update with information learned over this channel
 * @param futureResult the future result associated with the construction of this channel
 * @param retryExecutorWrapper the retry executor wrapper
 */
EJBClientChannel(final Channel channel, final int version, final DiscoveredNodeRegistry discoveredNodeRegistry, final FutureResult<EJBClientChannel> futureResult, RetryExecutorWrapper retryExecutorWrapper) {
    this.channel = channel;
    this.version = version;
    this.discoveredNodeRegistry = discoveredNodeRegistry;
    this.retryExecutorWrapper = retryExecutorWrapper;
    marshallerFactory = Marshalling.getProvidedMarshallerFactory("river");
    final MarshallingConfiguration marshallingConfiguration = new MarshallingConfiguration();
    marshallingConfiguration.setClassResolver(ProtocolClassResolver.INSTANCE);
    final Connection connection = channel.getConnection();
    if (version < 3) {
        marshallingConfiguration.setClassTable(ProtocolV1ClassTable.INSTANCE);
        marshallingConfiguration.setObjectTable(ProtocolV1ObjectTable.INSTANCE);
        marshallingConfiguration.setObjectResolver(new ProtocolV1ObjectResolver(connection, true));
        marshallingConfiguration.setObjectPreResolver(ProtocolV1ObjectPreResolver.INSTANCE);
        marshallingConfiguration.setVersion(2);
        // Do not wait for cluster topology report.
        finishedParts.set(0b10);
    } else {
        // server does not present v3 unless the transaction service is also present
        marshallingConfiguration.setObjectTable(ProtocolV3ObjectTable.INSTANCE);
        marshallingConfiguration.setObjectResolver(new ProtocolV3ObjectResolver(connection, true));
        marshallingConfiguration.setVersion(4);
    }
    transactionContext = RemoteTransactionContext.getInstance();
    this.configuration = marshallingConfiguration;
    invocationTracker = new InvocationTracker(this.channel, channel.getOption(RemotingOptions.MAX_OUTBOUND_MESSAGES).intValue(), EJBClientChannel::mask);
    futureResultRef = new AtomicReference<>(futureResult);
    // register this channel with the node it is connected to and mark that node usable again
    final String nodeName = connection.getRemoteEndpointName();
    final NodeInformation nodeInformation = discoveredNodeRegistry.getNodeInformation(nodeName);
    nodeInformation.addAddress(this);
    nodeInformation.setInvalid(false);
    channel.addCloseHandler((ignored1, ignored2) -> nodeInformation.removeConnection(this));
}
/**
 * Reduce an invocation index to its low 16 bits.
 *
 * @param original the unmasked index
 * @return the value of {@code original} with everything above bit 15 cleared
 */
static int mask(int original) {
    final int lowSixteenBits = 0xffff;
    return original & lowSixteenBits;
}
/**
 * Dispatch a single inbound message received on this channel.
 * <p>
 * The first byte selects the message type. Invocation related responses are handed to the invocation
 * tracker together with the message stream; module and cluster topology notifications update the
 * discovered node registry. Unknown message types are ignored. The message is closed on exit unless
 * the invocation tracker reported that it has to be left open.
 *
 * @param message the inbound message
 */
private void processMessage(final MessageInputStream message) {
    boolean leaveOpen = false;
    try {
        final int msg = message.readUnsignedByte();
        // only used for debug logging
        String remoteEndpoint = channel.getConnection().getRemoteEndpointName();
        switch (msg) {
            case Protocol.TXN_RESPONSE:
            case Protocol.INVOCATION_RESPONSE:
            case Protocol.OPEN_SESSION_RESPONSE:
            case Protocol.APPLICATION_EXCEPTION:
            case Protocol.CANCEL_RESPONSE:
            case Protocol.NO_SUCH_EJB:
            case Protocol.NO_SUCH_METHOD:
            case Protocol.SESSION_NOT_ACTIVE:
            case Protocol.EJB_NOT_STATEFUL:
            case Protocol.BAD_VIEW_TYPE:
            case Protocol.PROCEED_ASYNC_RESPONSE: {
                // the invocation ID follows the message type; the tracker routes the rest of the message
                final int invId = message.readUnsignedShort();
                leaveOpen = invocationTracker.signalResponse(invId, msg, message, false);
                break;
            }
            case Protocol.COMPRESSED_INVOCATION_MESSAGE: {
                // the real message type and invocation ID are inside the deflated payload
                DataInputStream inputStream = new DataInputStream(new InflaterInputStream(message));
                final int realMessageId = inputStream.readByte();
                final int invId = inputStream.readUnsignedShort();
                leaveOpen = invocationTracker.signalResponse(invId, realMessageId, new ResponseMessageInputStream(inputStream, invId), false);
                break;
            }
            case Protocol.MODULE_AVAILABLE: {
                int count = PackedInteger.readPackedInteger(message);
                final NodeInformation nodeInformation = discoveredNodeRegistry.getNodeInformation(getChannel().getConnection().getRemoteEndpointName());
                final EJBModuleIdentifier[] moduleList = new EJBModuleIdentifier[count];
                for (int i = 0; i < count; i ++) {
                    final String appName = message.readUTF();
                    final String moduleName = message.readUTF();
                    final String distinctName = message.readUTF();
                    final EJBModuleIdentifier moduleIdentifier = new EJBModuleIdentifier(appName, moduleName, distinctName);
                    moduleList[i] = moduleIdentifier;
                    if (Logs.INVOCATION.isDebugEnabled()) {
                        Logs.INVOCATION.debugf("Received MODULE_AVAILABLE(%x) message from node %s for module %s", msg, remoteEndpoint, moduleIdentifier);
                    }
                }
                nodeInformation.addModules(this, moduleList);
                finishPart(0b01);
                break;
            }
            case Protocol.MODULE_UNAVAILABLE: {
                int count = PackedInteger.readPackedInteger(message);
                final NodeInformation nodeInformation = discoveredNodeRegistry.getNodeInformation(getChannel().getConnection().getRemoteEndpointName());
                final HashSet<EJBModuleIdentifier> set = new HashSet<>(count);
                for (int i = 0; i < count; i ++) {
                    final String appName = message.readUTF();
                    final String moduleName = message.readUTF();
                    final String distinctName = message.readUTF();
                    final EJBModuleIdentifier moduleIdentifier = new EJBModuleIdentifier(appName, moduleName, distinctName);
                    set.add(moduleIdentifier);
                    if (Logs.INVOCATION.isDebugEnabled()) {
                        Logs.INVOCATION.debugf("Received MODULE_UNAVAILABLE(%x) message from node %s for module %s", msg, remoteEndpoint, moduleIdentifier);
                    }
                }
                nodeInformation.removeModules(this, set);
                break;
            }
            case Protocol.CLUSTER_TOPOLOGY_ADDITION:
            case Protocol.CLUSTER_TOPOLOGY_COMPLETE: {
                int clusterCount = PackedInteger.readPackedInteger(message);
                for (int i = 0; i < clusterCount; i ++) {
                    final String clusterName = message.readUTF();
                    int memberCount = PackedInteger.readPackedInteger(message);
                    for (int j = 0; j < memberCount; j ++) {
                        final String nodeName = message.readUTF();
                        discoveredNodeRegistry.addNode(clusterName, nodeName, channel.getConnection().getPeerURI());
                        final NodeInformation nodeInformation = discoveredNodeRegistry.getNodeInformation(nodeName);
                        if (Logs.INVOCATION.isDebugEnabled()) {
                            Logs.INVOCATION.debugf("Received CLUSTER_TOPOLOGY(%x) message from node %s, registering cluster %s to node %s", msg, remoteEndpoint, clusterName, nodeName);
                        }
                        // create and register the concrete ServiceURLs for each client mapping
                        int mappingCount = PackedInteger.readPackedInteger(message);
                        for (int k = 0; k < mappingCount; k ++) {
                            // bit 0 clear selects a 16 byte source address, the remaining bits carry the netmask length
                            int b = message.readUnsignedByte();
                            final boolean ip6 = allAreClear(b, 1);
                            int netmaskBits = b >>> 1;
                            final byte[] sourceIpBytes = new byte[ip6 ? 16 : 4];
                            message.readFully(sourceIpBytes);
                            final CidrAddress block = CidrAddress.create(sourceIpBytes, netmaskBits);
                            final String destHost = message.readUTF();
                            final int destPort = message.readUnsignedShort();
                            final InetSocketAddress destUnresolved = InetSocketAddress.createUnresolved(destHost, destPort);
                            nodeInformation.addAddress(channel.getConnection().getProtocol(), clusterName, block, destUnresolved);
                            if (Logs.INVOCATION.isDebugEnabled()) {
                                Logs.INVOCATION.debugf("Received CLUSTER_TOPOLOGY(%x) message block from %s, registering block %s to address %s", msg, remoteEndpoint, block, destUnresolved);
                            }
                        }
                    }
                }
                finishPart(0b10);
                break;
            }
            case Protocol.CLUSTER_TOPOLOGY_REMOVAL: {
                int clusterCount = PackedInteger.readPackedInteger(message);
                for (int i = 0; i < clusterCount; i ++) {
                    String clusterName = message.readUTF();
                    discoveredNodeRegistry.removeCluster(clusterName);
                    if (Logs.INVOCATION.isDebugEnabled()) {
                        Logs.INVOCATION.debugf("Received CLUSTER_TOPOLOGY_REMOVAL(%x) message from node %s for cluster %s", msg, remoteEndpoint, clusterName);
                    }
                    for (NodeInformation nodeInformation : discoveredNodeRegistry.getAllNodeInformation()) {
                        nodeInformation.removeCluster(clusterName);
                    }
                }
                break;
            }
            case Protocol.CLUSTER_TOPOLOGY_NODE_REMOVAL: {
                int clusterCount = PackedInteger.readPackedInteger(message);
                for (int i = 0; i < clusterCount; i ++) {
                    String clusterName = message.readUTF();
                    int memberCount = PackedInteger.readPackedInteger(message);
                    for (int j = 0; j < memberCount; j ++) {
                        String nodeName = message.readUTF();
                        discoveredNodeRegistry.removeNode(clusterName, nodeName);
                        final NodeInformation nodeInformation = discoveredNodeRegistry.getNodeInformation(nodeName);
                        nodeInformation.removeCluster(clusterName);
                        if (Logs.INVOCATION.isDebugEnabled()) {
                            Logs.INVOCATION.debugf("Received CLUSTER_TOPOLOGY_NODE_REMOVAL(%x) message from node %s for (cluster, node) = (%s, %s)", msg, remoteEndpoint, clusterName, nodeName);
                        }
                    }
                }
                break;
            }
            default: {
                // ignore message
            }
        }
    } catch (IOException e) {
        // a message that cannot be read is dropped, but leave a trace so the failure can be diagnosed
        Logs.INVOCATION.debugf(e, "Failed to process message from node %s", channel.getConnection().getRemoteEndpointName());
    } finally {
        if (! leaveOpen) {
            safeClose(message);
        }
    }
}
/** Attachment key under which the {@code MethodInvocation} of a request is stored on its client invocation context. */
private static final AttachmentKey<MethodInvocation> INV_KEY = new AttachmentKey<>();
public void processInvocation(final EJBReceiverInvocationContext receiverContext, final ConnectionPeerIdentity peerIdentity) {
MethodInvocation invocation = invocationTracker.addInvocation(id -> new MethodInvocation(id, receiverContext));
final EJBClientInvocationContext invocationContext = receiverContext.getClientInvocationContext();
invocationContext.putAttachment(INV_KEY, invocation);
final EJBLocator> locator = invocationContext.getLocator();
final int peerIdentityId;
if (version >= 3) {
peerIdentityId = peerIdentity.getId();
} else {
peerIdentityId = 0; // unused
}
try (MessageOutputStream underlying = invocationTracker.allocateMessage()) {
MessageOutputStream out = handleCompression(invocationContext, underlying);
try {
out.write(Protocol.INVOCATION_REQUEST);
out.writeShort(invocation.getIndex());
Marshaller marshaller = getMarshaller();
marshaller.start(new NoFlushByteOutput(Marshalling.createByteOutput(out)));
final Method invokedMethod = invocationContext.getInvokedMethod();
final Object[] parameters = invocationContext.getParameters();
if (version < 3) {
// method name as UTF string
out.writeUTF(invokedMethod.getName());
// write the method signature as UTF string
out.writeUTF(invocationContext.getMethodSignatureString());
// protocol 1 & 2 redundant locator objects
marshaller.writeObject(locator.getAppName());
marshaller.writeObject(locator.getModuleName());
marshaller.writeObject(locator.getDistinctName());
marshaller.writeObject(locator.getBeanName());
} else {
// write identifier to allow the peer to find the class loader
marshaller.writeObject(locator.getIdentifier());
// write method locator
marshaller.writeObject(invocationContext.getMethodLocator());
// write sec context
marshaller.writeInt(peerIdentityId);
// write weak affinity
marshaller.writeObject(invocationContext.getWeakAffinity());
// write response compression info
if (invocationContext.isCompressResponse()) {
int compressionLevel = invocationContext.getCompressionLevel() > 0 ? invocationContext.getCompressionLevel() : 9;
marshaller.writeByte(compressionLevel);
} else {
marshaller.writeByte(0);
}
// write txn context
invocation.setOutflowHandle(writeTransaction(invocationContext.getTransaction(), marshaller, invocationContext.getAuthenticationContext()));
}
// write the invocation locator itself
marshaller.writeObject(locator);
// and the parameters
if (parameters != null && parameters.length > 0) {
for (final Object methodParam : parameters) {
marshaller.writeObject(methodParam);
}
}
// now, attachments
// we write out the private (a.k.a JBoss specific) attachments as well as public invocation context data
// (a.k.a user application specific data)
final Map, ?> privateAttachments = invocationContext.getAttachments();
final Map contextData = invocationContext.getContextData();
// write the attachment count which is the sum of invocation context data + 1 (since we write
// out the private attachments under a single key with the value being the entire attachment map)
int totalContextData = contextData.size();
if (version >= 3) {
// Just write the attachments.
PackedInteger.writePackedInteger(marshaller, totalContextData);
for (Map.Entry invocationContextData : contextData.entrySet()) {
marshaller.writeObject(invocationContextData.getKey());
marshaller.writeObject(invocationContextData.getValue());
}
} else {
final Transaction transaction = invocationContext.getTransaction();
// We are only marshalling those attachments whose keys are present in the object table
final Map, Object> marshalledPrivateAttachments = new HashMap<>();
for (final Map.Entry, ?> entry : privateAttachments.entrySet()) {
final AttachmentKey> key = entry.getKey();
if (key == AttachmentKeys.TRANSACTION_ID_KEY) {
// skip!
} else if (ProtocolV1ObjectTable.INSTANCE.getObjectWriter(key) != null) {
marshalledPrivateAttachments.put(key, entry.getValue());
}
}
if (transaction != null) {
marshalledPrivateAttachments.put(AttachmentKeys.TRANSACTION_ID_KEY, calculateTransactionId(transaction));
}
final boolean hasPrivateAttachments = ! marshalledPrivateAttachments.isEmpty();
if (hasPrivateAttachments) {
totalContextData++;
}
// Note: The code here is just for backward compatibility of 1.x and 2.x versions of EJB client project.
// Attach legacy transaction ID, if there is an active txn.
if (transaction != null) {
// we additionally add/duplicate the transaction id under a different attachment key
// to preserve backward compatibility. This is here just for 1.0.x backward compatibility
totalContextData++;
}
// backward compatibility code block for transaction id ends here.
PackedInteger.writePackedInteger(marshaller, totalContextData);
// write out public (application specific) context data
ProtocolObjectResolver.enableNonSerReplacement();
try {
for (Map.Entry invocationContextData : contextData.entrySet()) {
marshaller.writeObject(invocationContextData.getKey());
marshaller.writeObject(invocationContextData.getValue());
}
} finally {
ProtocolObjectResolver.disableNonSerReplacement();
}
if (hasPrivateAttachments) {
// now write out the JBoss specific attachments under a single key and the value will be the
// entire map of JBoss specific attachments
marshaller.writeObject(EJBClientInvocationContext.PRIVATE_ATTACHMENTS_KEY);
marshaller.writeObject(marshalledPrivateAttachments);
}
// Note: The code here is just for backward compatibility of 1.0.x version of EJB client project
// against AS7 7.1.x releases. Discussion here https://github.com/wildfly/jboss-ejb-client/pull/11#issuecomment-6573863
if (transaction != null) {
// we additionally add/duplicate the transaction id under a different attachment key
// to preserve backward compatibility. This is here just for 1.0.x backward compatibility
marshaller.writeObject(TransactionID.PRIVATE_DATA_KEY);
// This transaction id attachment duplication *won't* cause increase in EJB protocol message payload
// since we rely on JBoss Marshalling to use back references for the same transaction id object being
// written out
marshaller.writeObject(marshalledPrivateAttachments.get(AttachmentKeys.TRANSACTION_ID_KEY));
}
// backward compatibility code block for transaction id ends here.
}
// finished
marshaller.finish();
} catch (IOException e) {
underlying.cancel();
throw e;
} finally {
out.close();
}
} catch (IOException e) {
receiverContext.requestFailed(new RequestSendFailedException(e.getMessage() + " @ " + peerIdentity.getConnection().getPeerURI(), e, true), getRetryExecutor(receiverContext) );
} catch (RollbackException | SystemException | RuntimeException e) {
receiverContext.requestFailed(new EJBException(e.getMessage(), e), getRetryExecutor(receiverContext));
return;
}
}
/**
 * Select the stream the invocation request has to be written to, applying request compression when
 * the invocation asks for it.
 * <p>
 * The given stream is returned unchanged when the protocol version is 1, when hints are disabled on the
 * proxy, or when the request is not to be compressed. For protocol version 2 a requested response
 * compression is recorded as attachments on the invocation context. When the request is to be
 * compressed, the compressed-message marker is written to the given stream and a wrapper combining
 * that stream with a deflating stream is returned.
 *
 * @param invocationContext The Enterprise Beans client invocation context
 * @param messageOutputStream The message output stream of the request
 * @return the stream to write the request to
 * @throws IOException if a problem occurs
 */
private MessageOutputStream handleCompression(final EJBClientInvocationContext invocationContext, final MessageOutputStream messageOutputStream) throws IOException {
    // protocol version 1 has no compressed messages
    if (version == 1) {
        return messageOutputStream;
    }
    // a proxy may switch off hint processing altogether
    final Boolean hintsDisabled = invocationContext.getProxyAttachment(AttachmentKeys.HINTS_DISABLED);
    if (Boolean.TRUE.equals(hintsDisabled)) {
        if (Logs.REMOTING.isTraceEnabled()) {
            Logs.REMOTING.trace("Hints are disabled. Ignoring any CompressionHint on methods being invoked on view " + invocationContext.getViewClass());
        }
        return messageOutputStream;
    }
    final int level = invocationContext.getCompressionLevel();
    // version 2 learns about response compression through attachments
    final boolean announceResponseCompression = version == 2 && invocationContext.isCompressResponse();
    if (announceResponseCompression) {
        invocationContext.putAttachment(AttachmentKeys.COMPRESS_RESPONSE, true);
        invocationContext.putAttachment(AttachmentKeys.RESPONSE_COMPRESSION_LEVEL, level);
        if (Logs.REMOTING.isTraceEnabled()) {
            Logs.REMOTING.trace("Letting the server know that the response of method " + invocationContext.getInvokedMethod() + " has to be compressed with compression level = " + level);
        }
    }
    // compressing only the response is perfectly valid, in which case the request goes out as is
    if (! invocationContext.isCompressRequest()) {
        return messageOutputStream;
    }
    // marker byte first, everything written afterwards goes through the deflater
    messageOutputStream.write(Protocol.COMPRESSED_INVOCATION_MESSAGE);
    final DeflaterOutputStream deflating = new DeflaterOutputStream(messageOutputStream, new Deflater(level));
    if (Logs.REMOTING.isTraceEnabled()) {
        Logs.REMOTING.trace("Using a compressing stream with compression level = " + level + " for request data for EJB invocation on method " + invocationContext.getInvokedMethod());
    }
    return new WrapperMessageOutputStream(messageOutputStream, deflating);
}
/**
 * Compute the transaction ID that represents the given transaction towards the peer of this channel.
 * <p>
 * A remote transaction is bound to the peer location and yields a {@link UserTransactionID}; a local
 * transaction is outflowed to the peer, its enlistment is verified and the result is an
 * {@link XidTransactionID}. Any other transaction type cannot be enlisted.
 *
 * @param transaction the transaction
 * @return the transaction ID to send to the peer
 */
private TransactionID calculateTransactionId(final Transaction transaction) throws RollbackException, SystemException, InvalidTransactionException {
    final URI peerLocation = channel.getConnection().getPeerURI();
    Assert.assertNotNull(transaction);
    if (transaction instanceof RemoteTransaction) {
        final RemoteTransaction remote = (RemoteTransaction) transaction;
        remote.setLocation(peerLocation);
        final SimpleIdResolver resolver = remote.getProviderInterface(SimpleIdResolver.class);
        if (resolver == null) {
            throw Logs.TXN.cannotEnlistTx();
        }
        final String endpointName = channel.getConnection().getRemoteEndpointName();
        return new UserTransactionID(endpointName, resolver.getTransactionId(channel.getConnection()));
    }
    if (transaction instanceof LocalTransaction) {
        final XAOutflowHandle handle = transactionContext.outflowTransaction(peerLocation, (LocalTransaction) transaction);
        // always verify V1/2 outflows
        handle.verifyEnlistment();
        return new XidTransactionID(handle.getXid());
    }
    throw Logs.TXN.cannotEnlistTx();
}
/**
 * Write the transaction context of a request, running under the given authentication context when one
 * is supplied.
 * <p>
 * Checked exceptions raised inside the authentication context are unwrapped and rethrown with their
 * original type; anything else is wrapped in a {@link RuntimeException}.
 *
 * @param transaction the transaction to propagate, or {@code null} if there is none
 * @param dataOutput the output to write the transaction context to
 * @param authenticationContext the authentication context to run under, or {@code null} to run directly
 * @return the outflow handle of an outflowed local transaction, otherwise {@code null}
 * @throws IOException if writing fails
 * @throws RollbackException if raised while writing the transaction context
 * @throws SystemException if raised while writing the transaction context
 */
private XAOutflowHandle writeTransaction(final Transaction transaction, final DataOutput dataOutput,
        final AuthenticationContext authenticationContext) throws IOException, RollbackException, SystemException {
    if (authenticationContext != null) {
        if (Logs.MAIN.isDebugEnabled()) {
            Logs.MAIN.debug("Using existing AuthenticationContext for writeTransaction(...)");
        }
        try {
            return authenticationContext.run((PrivilegedExceptionAction<XAOutflowHandle>) () -> _writeTransaction(transaction, dataOutput));
        } catch (PrivilegedActionException e) {
            // restore the declared exception types hidden by the privileged action wrapper
            Throwable cause = e.getCause();
            if (cause instanceof IOException) {
                throw (IOException) cause;
            } else if (cause instanceof RollbackException) {
                throw (RollbackException) cause;
            } else if (cause instanceof SystemException) {
                throw (SystemException) cause;
            }
            throw new RuntimeException(e);
        }
    } else {
        if (Logs.MAIN.isDebugEnabled()) {
            Logs.MAIN.debug("No existing AuthenticationContext for writeTransaction(...)");
        }
        return _writeTransaction(transaction, dataOutput);
    }
}
/**
 * Encode the transaction context of a request.
 * <p>
 * The first byte identifies the kind of context: {@code 0} for no transaction, {@code 1} for a remote
 * transaction (followed by its integer ID and remaining time) and {@code 2} for a local transaction
 * (followed by the XID format ID, global transaction ID, branch qualifier and remaining time).
 *
 * @param transaction the transaction to propagate, or {@code null} if there is none
 * @param dataOutput the output to write to
 * @return the outflow handle when a local transaction was outflowed, otherwise {@code null}
 */
private XAOutflowHandle _writeTransaction(final Transaction transaction, final DataOutput dataOutput) throws IOException, RollbackException, SystemException {
    final URI peerLocation = channel.getConnection().getPeerURI();
    if (transaction == null) {
        dataOutput.writeByte(0);
        return null;
    }
    if (transaction instanceof RemoteTransaction) {
        final RemoteTransaction remote = (RemoteTransaction) transaction;
        remote.setLocation(peerLocation);
        dataOutput.writeByte(1);
        final SimpleIdResolver resolver = remote.getProviderInterface(SimpleIdResolver.class);
        if (resolver == null) {
            throw Logs.TXN.cannotEnlistTx();
        }
        dataOutput.writeInt(resolver.getTransactionId(channel.getConnection()));
        final int remaining = remote.getEstimatedRemainingTime();
        if (remaining == 0) {
            throw Logs.TXN.outflowTransactionTimeoutElapsed(transaction);
        }
        PackedInteger.writePackedInteger(dataOutput, remaining);
        return null;
    }
    if (transaction instanceof LocalTransaction) {
        final XAOutflowHandle handle = transactionContext.outflowTransaction(peerLocation, (LocalTransaction) transaction);
        final Xid xid = handle.getXid();
        dataOutput.writeByte(2);
        PackedInteger.writePackedInteger(dataOutput, xid.getFormatId());
        final byte[] globalId = xid.getGlobalTransactionId();
        dataOutput.writeByte(globalId.length);
        dataOutput.write(globalId);
        // this will normally be zero, but write it anyway just in case we need to change it later
        final byte[] branchQualifier = xid.getBranchQualifier();
        dataOutput.writeByte(branchQualifier.length);
        dataOutput.write(branchQualifier);
        final int remaining = handle.getRemainingTime();
        if (remaining == 0) {
            throw Logs.TXN.outflowTransactionTimeoutElapsed(transaction);
        }
        PackedInteger.writePackedInteger(dataOutput, remaining);
        return handle;
    }
    throw Logs.TXN.cannotEnlistTx();
}
/**
 * Ask the peer to cancel an invocation previously sent through this channel and wait for the outcome.
 *
 * @param receiverContext the receiver invocation context of the invocation to cancel
 * @param cancelIfRunning whether the invocation may be cancelled even if it is already running
 * @return {@code false} if no cancel request is attempted (legacy protocol without
 *         {@code cancelIfRunning}, or the invocation is no longer attached), otherwise the awaited
 *         cancellation result
 */
boolean cancelInvocation(final EJBReceiverInvocationContext receiverContext, boolean cancelIfRunning) {
    final boolean legacyProtocol = version < 3;
    if (legacyProtocol && ! cancelIfRunning) {
        // keep legacy behavior
        return false;
    }
    final MethodInvocation pending = receiverContext.getClientInvocationContext().getAttachment(INV_KEY);
    if (pending == null) {
        // lost it somehow
        return false;
    }
    if (pending.alloc()) {
        try {
            final int invocationId = pending.getIndex();
            try (MessageOutputStream request = invocationTracker.allocateMessage()) {
                request.write(Protocol.CANCEL_REQUEST);
                request.writeShort(invocationId);
                if (! legacyProtocol) {
                    request.writeBoolean(cancelIfRunning);
                }
            } catch (IOException ignored) {
                // sending the cancel request is best effort; the result is awaited below either way
            }
        } finally {
            pending.free();
        }
    }
    // now await the result
    return pending.receiverInvocationContext.getClientInvocationContext().awaitCancellationResult();
}
/**
 * Create a new marshaller from this channel's marshaller factory and marshalling configuration.
 *
 * @return the new marshaller
 * @throws IOException if the marshaller cannot be created
 */
private Marshaller getMarshaller() throws IOException {
    final Marshaller marshaller = marshallerFactory.createMarshaller(configuration);
    return marshaller;
}
/**
 * Send a session open request for the given stateless locator and wait for the response.
 * <p>
 * The request carries the raw bean identifier and, for protocol version 3 and later, the peer identity
 * ID and the transaction context.
 *
 * @param statelessLocator the locator of the bean to open a session on
 * @param identity the identity of the peer
 * @param clientInvocationContext the session creation invocation context
 * @param <T> the bean view type
 * @return the result of the session open invocation
 * @throws Exception if the request could not be written (reported as a {@code CreateException} whose
 *         cause is the I/O failure) or if the invocation itself fails
 */
public <T> StatefulEJBLocator<T> openSession(final StatelessEJBLocator<T> statelessLocator, final ConnectionPeerIdentity identity, EJBSessionCreationInvocationContext clientInvocationContext) throws Exception {
    SessionOpenInvocation<T> invocation = invocationTracker.addInvocation(id -> new SessionOpenInvocation<>(id, statelessLocator, clientInvocationContext));
    try (MessageOutputStream out = invocationTracker.allocateMessage()) {
        out.write(Protocol.OPEN_SESSION_REQUEST);
        out.writeShort(invocation.getIndex());
        writeRawIdentifier(statelessLocator, out);
        if (version >= 3) {
            out.writeInt(identity.getId());
            invocation.setOutflowHandle(writeTransaction(clientInvocationContext.getTransaction(), out, clientInvocationContext.getAuthenticationContext()));
        }
    } catch (IOException e) {
        // CreateException has no cause-taking constructor, so attach the cause explicitly
        CreateException createException = new CreateException(e.getMessage());
        createException.initCause(e);
        throw createException;
    }
    // await the response
    return invocation.getResult();
}
/**
 * Write the application, module, distinct and bean name of a locator as four UTF strings.
 * A {@code null} application or distinct name is written as the empty string.
 *
 * @param statelessLocator the locator whose identifier is written
 * @param out the message to write to
 * @throws IOException if writing fails
 */
private static void writeRawIdentifier(final EJBLocator<?> statelessLocator, final MessageOutputStream out) throws IOException {
    final String appName = statelessLocator.getAppName();
    out.writeUTF(appName == null ? "" : appName);
    out.writeUTF(statelessLocator.getModuleName());
    final String distinctName = statelessLocator.getDistinctName();
    out.writeUTF(distinctName == null ? "" : distinctName);
    out.writeUTF(statelessLocator.getBeanName());
}
/**
 * Perform the opening negotiation on a channel and return a future for the resulting client channel.
 * <p>
 * The server greeting carries the server's protocol version. The client answers with the lower of that
 * version and 3 plus the marshalling strategy name, creates the {@code EJBClientChannel} and installs a
 * receiver that feeds every subsequent message to {@code processMessage}. A failure or end of channel
 * before the greeting fails or cancels the future; afterwards it fails the future and closes the channel.
 * Cancelling the future closes the channel.
 *
 * @param channel the freshly opened channel
 * @param discoveredNodeRegistry the registry to update with information learned over the channel
 * @param retryExecutorWrapper the retry executor wrapper passed on to the client channel
 * @return the future client channel
 */
static IoFuture<EJBClientChannel> construct(final Channel channel, final DiscoveredNodeRegistry discoveredNodeRegistry, RetryExecutorWrapper retryExecutorWrapper) {
    FutureResult<EJBClientChannel> futureResult = new FutureResult<>();
    // now perform opening negotiation: receive server greeting
    channel.receiveMessage(new Channel.Receiver() {
        @Override
        public void handleError(final Channel channel, final IOException error) {
            final ClassLoader oldCL = getAndSetSafeTCCL();
            try {
                futureResult.setException(error);
            } finally {
                resetTCCL(oldCL);
            }
        }
        @Override
        public void handleEnd(final Channel channel) {
            final ClassLoader oldCL = getAndSetSafeTCCL();
            try {
                futureResult.setCancelled();
            } finally {
                resetTCCL(oldCL);
            }
        }
        @Override
        public void handleMessage(final Channel channel, final MessageInputStream message) {
            final ClassLoader oldCL = getAndSetSafeTCCL();
            // receive message body
            try {
                final int version = min(3, StreamUtils.readInt8(message));
                // drain the rest of the message because it's just garbage really
                while (message.read() != -1) {
                    message.skip(Long.MAX_VALUE);
                }
                // send back result
                try (MessageOutputStream out = channel.writeMessage()) {
                    out.write(version);
                    out.writeUTF("river");
                }
                // almost done; wait for initial module available report
                final EJBClientChannel ejbClientChannel = new EJBClientChannel(channel, version, discoveredNodeRegistry, futureResult, retryExecutorWrapper);
                channel.receiveMessage(new Channel.Receiver() {
                    @Override
                    public void handleError(final Channel channel, final IOException error) {
                        final ClassLoader oldCL = getAndSetSafeTCCL();
                        try {
                            futureResult.setException(error);
                        } finally {
                            safeClose(channel);
                            resetTCCL(oldCL);
                        }
                    }
                    @Override
                    public void handleEnd(final Channel channel) {
                        final ClassLoader oldCL = getAndSetSafeTCCL();
                        try {
                            futureResult.setException(new EOFException());
                        } finally {
                            safeClose(channel);
                            resetTCCL(oldCL);
                        }
                    }
                    @Override
                    public void handleMessage(final Channel channel, final MessageInputStream message) {
                        final ClassLoader oldCL = getAndSetSafeTCCL();
                        try {
                            ejbClientChannel.processMessage(message);
                        } finally {
                            // re-arm the receiver for the next message
                            channel.receiveMessage(this);
                            resetTCCL(oldCL);
                        }
                    }
                });
            } catch (final IOException e) {
                channel.closeAsync();
                channel.addCloseHandler((closed, exception) -> futureResult.setException(e));
            } finally {
                // the greeting is fully consumed (or unusable) at this point; release it like any other message
                safeClose(message);
                resetTCCL(oldCL);
            }
        }
    });
    futureResult.addCancelHandler(new Cancellable() {
        @Override
        public Cancellable cancel() {
            if (futureResult.setCancelled()) {
                safeClose(channel);
            }
            return this;
        }
    });
    return futureResult.getIoFuture();
}
/**
 * Returns the invocation tracker held by this channel.
 *
 * @return the value of the {@code invocationTracker} field
 */
InvocationTracker getInvocationTracker() {
    return this.invocationTracker;
}
/**
 * Glue two stack traces together.
 * <p>
 * The stack trace of {@code exception} is replaced by: its own frames, then one synthetic marker frame
 * built from {@code msg}, then the frames of {@code userStackTrace} with the first {@code trimCount}
 * frames removed.
 * <p>
 * {@code trimCount} is clamped to the range {@code [0, userStackTrace.length]} so that a short (or empty)
 * user stack trace never causes an out-of-bounds access; in that case only the marker frame is appended.
 *
 * @param exception the exception which occurred in another thread
 * @param userStackTrace the stack trace of the current thread from {@link Thread#getStackTrace()}
 * @param trimCount the number of frames to trim
 * @param msg the message to use
 */
private static void glueStackTraces(final Throwable exception, final StackTraceElement[] userStackTrace,
        final int trimCount, final String msg) {
    final StackTraceElement[] est = exception.getStackTrace();
    // never trim more frames than the user stack trace actually has (and never a negative amount)
    final int skipped = Math.max(0, Math.min(trimCount, userStackTrace.length));
    final int kept = userStackTrace.length - skipped;
    final StackTraceElement[] fst = Arrays.copyOf(est, est.length + kept + 1);
    // The marker is a frame with an empty method name; StackTraceElement#toString() joins class and method
    // name with a '.', so the two trailing dots are printed as three.
    fst[est.length] = new StackTraceElement("..." + msg + "..", "", null, -1);
    System.arraycopy(userStackTrace, skipped, fst, est.length + 1, kept);
    exception.setStackTrace(fst);
}
final class SessionOpenInvocation extends Invocation {
private final StatelessEJBLocator statelessLocator;
private final EJBSessionCreationInvocationContext clientInvocationContext;
private int id;
private MessageInputStream inputStream;
private XAOutflowHandle outflowHandle;
private IOException ex;
protected SessionOpenInvocation(final int index, final StatelessEJBLocator statelessLocator, EJBSessionCreationInvocationContext clientInvocationContext) {
super(index);
this.statelessLocator = statelessLocator;
this.clientInvocationContext = clientInvocationContext;
}
public void handleResponse(final int id, final MessageInputStream inputStream) {
synchronized (this) {
this.id = id;
this.inputStream = inputStream;
notifyAll();
}
}
public void handleClosed() {
synchronized (this) {
notifyAll();
}
}
public void handleException(IOException cause) {
synchronized (this) {
this.ex = cause;
notifyAll();
}
}
void setOutflowHandle(final XAOutflowHandle outflowHandle) {
this.outflowHandle = outflowHandle;
}
XAOutflowHandle getOutflowHandle() {
return outflowHandle;
}
StatefulEJBLocator getResult() throws Exception {
Exception e;
try (ResponseMessageInputStream response = removeInvocationResult()) {
switch (id) {
case Protocol.OPEN_SESSION_RESPONSE: {
Affinity affinity;
int size = PackedInteger.readPackedInteger(response);
byte[] bytes = new byte[size];
response.readFully(bytes);
// todo: pool unmarshallers? use very small instance count config?
if (1 <= version && version <= 2) {
try (final Unmarshaller unmarshaller = createUnmarshaller()) {
unmarshaller.start(response);
affinity = unmarshaller.readObject(Affinity.class);
unmarshaller.finish();
}
} else {
affinity = statelessLocator.getAffinity();
final int cmd = response.readUnsignedByte();
final XAOutflowHandle outflowHandle = getOutflowHandle();
if (outflowHandle != null) {
if (cmd == 0) {
outflowHandle.forgetEnlistment();
} else if (cmd == 1) {
try {
outflowHandle.verifyEnlistment();
} catch (RollbackException | SystemException e1) {
throw new EJBException(e1);
}
} else if (cmd == 2) {
outflowHandle.nonMasterEnlistment();
}
}
final int updateBits = response.readUnsignedByte();
if (allAreSet(updateBits, Protocol.UPDATE_BIT_WEAK_AFFINITY)) {
final byte[] b = new byte[PackedInteger.readPackedInteger(response)];
response.readFully(b);
clientInvocationContext.setWeakAffinity(new NodeAffinity(new String(b, StandardCharsets.UTF_8)));
}
if (allAreSet(updateBits, Protocol.UPDATE_BIT_STRONG_AFFINITY)) {
final byte[] b = new byte[PackedInteger.readPackedInteger(response)];
response.readFully(b);
String clusterName = new String(b, StandardCharsets.UTF_8);
affinity = new ClusterAffinity(clusterName);
}
}
StatefulEJBLocator locator = statelessLocator.withSessionAndAffinity(SessionID.createSessionID(bytes), affinity);
if (Logs.INVOCATION.isDebugEnabled()) {
Logs.INVOCATION.debugf("EJBClientChannel.SessionOpenInvocation.getResult(): updating Locator (sessionID), new = %s; weakAffinity = %s",
locator, clientInvocationContext.getWeakAffinity());
}
clientInvocationContext.setLocator(locator);
return locator;
}
case Protocol.APPLICATION_EXCEPTION: {
if (version >= 3) {
final int cmd = response.readUnsignedByte();
final XAOutflowHandle outflowHandle = getOutflowHandle();
if (outflowHandle != null) {
if (cmd == 0) {
outflowHandle.forgetEnlistment();
} else if (cmd == 1) {
try {
outflowHandle.verifyEnlistment();
} catch (RollbackException | SystemException e1) {
throw new EJBException(e1);
}
} else if (cmd == 2) {
outflowHandle.nonMasterEnlistment();
}
}
}
try (final Unmarshaller unmarshaller = createUnmarshaller()) {
unmarshaller.start(response);
e = unmarshaller.readObject(Exception.class);
unmarshaller.finish();
if (version < 3) {
// drain off attachments so the server doesn't complain
while (response.read() != -1) {
response.skip(Long.MAX_VALUE);
}
}
}
glueStackTraces(e, Thread.currentThread().getStackTrace(), 1, "asynchronous invocation");
break;
}
case Protocol.CANCEL_RESPONSE: {
if (version >= 3) {
final XAOutflowHandle outflowHandle = getOutflowHandle();
if (outflowHandle != null) outflowHandle.forgetEnlistment();
}
disassociateRemoteTxIfPossible(clientInvocationContext);
throw Logs.REMOTING.requestCancelled();
}
case Protocol.NO_SUCH_EJB: {
if (version >= 3) {
final XAOutflowHandle outflowHandle = getOutflowHandle();
if (outflowHandle != null) outflowHandle.forgetEnlistment();
}
disassociateRemoteTxIfPossible(clientInvocationContext);
final String message = response.readUTF();
throw new NoSuchEJBException(message + " @ " + getChannel().getConnection().getPeerURI());
}
case Protocol.BAD_VIEW_TYPE: {
if (version >= 3) {
final XAOutflowHandle outflowHandle = getOutflowHandle();
if (outflowHandle != null) outflowHandle.forgetEnlistment();
}
disassociateRemoteTxIfPossible(clientInvocationContext);
final String message = response.readUTF();
throw Logs.REMOTING.invalidViewTypeForInvocation(message);
}
case Protocol.EJB_NOT_STATEFUL: {
if (version >= 3) {
final XAOutflowHandle outflowHandle = getOutflowHandle();
if (outflowHandle != null) outflowHandle.forgetEnlistment();
}
disassociateRemoteTxIfPossible(clientInvocationContext);
final String message = response.readUTF();
throw Logs.REMOTING.ejbNotStateful(message);
}
default: {
throw new EJBException("Invalid EJB creation response (id " + id + ")");
}
}
} catch (ClassNotFoundException | IOException ex) {
throw new EJBException("Failed to read session create response", ex);
} finally {
invocationTracker.remove(this);
}
if (e == null) {
throw new EJBException("Null exception response");
}
// throw the application exception outside of the try block
throw e;
}
private ResponseMessageInputStream removeInvocationResult() {
MessageInputStream mis;
int id;
try {
synchronized (this) {
for (; ; ) {
id = this.getIndex();
if (inputStream != null) {
mis = inputStream;
inputStream = null;
break;
}
if (ex != null) {
throw new EJBException(ex);
}
if (id == -1) {
throw new EJBException("Connection closed");
}
wait();
}
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new EJBException("Session creation interrupted");
}
return new ResponseMessageInputStream(mis, id);
}
}
/**
 * Creates a new unmarshaller from this channel's marshaller factory and marshalling configuration.
 *
 * @return the newly created unmarshaller
 * @throws IOException if the factory fails to create the unmarshaller
 */
Unmarshaller createUnmarshaller() throws IOException {
    final Unmarshaller unmarshaller = marshallerFactory.createUnmarshaller(configuration);
    return unmarshaller;
}
/**
 * Returns the channel held by this object.
 *
 * @return the value of the {@code channel} field
 */
Channel getChannel() {
    return this.channel;
}
/**
 * Allocates a user transaction ID that is not yet registered in {@code userTxnIds} and registers it.
 * <p>
 * The encoded ID consists of the byte {@code 0x01}, one byte holding the length of the UTF-8 encoded
 * remote endpoint name, the name bytes themselves, and a random 32-bit integer in big-endian order.
 * Random integers are drawn until one can be registered.
 *
 * @return the newly registered user transaction ID
 */
UserTransactionID allocateUserTransactionID() {
    final ThreadLocalRandom random = ThreadLocalRandom.current();
    final String nodeName = getChannel().getConnection().getRemoteEndpointName();
    // use the charset constant: no charset lookup by name and no impossible checked exception to handle
    final byte[] nameBytes = nodeName.getBytes(StandardCharsets.UTF_8);
    final int length = nameBytes.length;
    // the length has to fit into the single length byte written below
    if (length > 255) {
        throw Assert.unreachableCode();
    }
    final byte[] target = new byte[6 + length];
    target[0] = 0x01;
    target[1] = (byte) length;
    System.arraycopy(nameBytes, 0, target, 2, length);
    for (;;) {
        int id = random.nextInt();
        if (! userTxnIds.containsKey(id)) {
            target[2 + length] = (byte) (id >> 24);
            target[3 + length] = (byte) (id >> 16);
            target[4 + length] = (byte) (id >> 8);
            target[5 + length] = (byte) id;
            final UserTransactionID uti = (UserTransactionID) TransactionID.createTransactionID(target);
            // another thread may have registered the same ID in the meantime; retry in that case
            if (userTxnIds.putIfAbsent(uti) == null) {
                return uti;
            }
        }
    }
}
/**
 * Marks one part (identified by its bit) as finished.
 * <p>
 * Nothing happens if the bit was already set. Once the combined value reaches {@code 0b11}, the pending
 * future result (if it is still referenced) is cleared from {@code futureResultRef} and completed with
 * this object.
 *
 * @param bit the bit of the part that has finished
 */
void finishPart(int bit) {
    int merged;
    for (;;) {
        final int current = finishedParts.get();
        if (allAreSet(current, bit)) {
            // this part has been recorded before
            return;
        }
        merged = current | bit;
        if (finishedParts.compareAndSet(current, merged)) {
            break;
        }
    }
    if (merged != 0b11) {
        // the other part is still outstanding
        return;
    }
    final FutureResult pending = futureResultRef.get();
    if (pending == null || ! futureResultRef.compareAndSet(pending, null)) {
        return;
    }
    // done!
    pending.setResult(this);
}
/**
 * Invocation record for a single method invocation.
 * <p>
 * Responses are dispatched by message ID in {@code handleResponse}; results and failures are reported to
 * the {@link EJBReceiverInvocationContext} this record was created with.
 */
final class MethodInvocation extends Invocation {
    private final EJBReceiverInvocationContext receiverInvocationContext;
    // starts at one reference; the record is removed from the tracker when the count drops to zero
    private final AtomicInteger refCounter = new AtomicInteger(1);
    private XAOutflowHandle outflowHandle;

    MethodInvocation(final int index, final EJBReceiverInvocationContext receiverInvocationContext) {
        super(index);
        this.receiverInvocationContext = receiverInvocationContext;
    }

    /**
     * Acquires one more reference to this record.
     *
     * @return {@code true} if the reference was acquired, {@code false} if the count had already
     *         dropped to zero
     */
    boolean alloc() {
        final AtomicInteger refCounter = this.refCounter;
        int oldVal;
        do {
            oldVal = refCounter.get();
            if (oldVal == 0) {
                return false;
            }
        } while (! refCounter.compareAndSet(oldVal, oldVal + 1));
        return true;
    }

    /**
     * Releases one reference; the record is removed from the invocation tracker when the last reference
     * is released.
     */
    void free() {
        final AtomicInteger refCounter = this.refCounter;
        final int newVal = refCounter.decrementAndGet();
        if (newVal == 0) {
            invocationTracker.remove(this);
        }
    }

    @Override
    public void handleResponse(int parameter, MessageInputStream inputStream) {
        handleResponse(parameter, new DataInputStream(inputStream));
    }

    /**
     * Dispatches a response by its message ID.
     * <p>
     * Every response except {@code PROCEED_ASYNC_RESPONSE} releases one reference via {@link #free()}.
     * The error responses read their message from the stream, report the failure to the receiver
     * invocation context and close the stream.
     *
     * @param id the response message ID
     * @param inputStream the response body
     */
    private void handleResponse(final int id, final DataInputStream inputStream) {
        switch (id) {
            case Protocol.INVOCATION_RESPONSE: {
                free();
                final EJBClientInvocationContext context = receiverInvocationContext.getClientInvocationContext();
                if (version >= 3) {
                    try {
                        final int cmd = inputStream.readUnsignedByte();
                        final XAOutflowHandle outflowHandle = getOutflowHandle();
                        if (outflowHandle != null) {
                            if (cmd == 0) {
                                outflowHandle.forgetEnlistment();
                            } else if (cmd == 1) {
                                outflowHandle.verifyEnlistment();
                            } else if (cmd == 2) {
                                outflowHandle.nonMasterEnlistment();
                            }
                        }
                        final int updateBits = inputStream.readUnsignedByte();
                        if (allAreSet(updateBits, Protocol.UPDATE_BIT_SESSION_ID)) {
                            byte[] encoded = new byte[PackedInteger.readPackedInteger(inputStream)];
                            inputStream.readFully(encoded);
                            final SessionID sessionID = SessionID.createSessionID(encoded);
                            final Object invokedProxy = context.getInvokedProxy();
                            EJBClient.convertToStateful(invokedProxy, sessionID);
                            context.setLocator(EJBClient.getLocatorFor(invokedProxy));
                            if (Logs.INVOCATION.isDebugEnabled()) {
                                Logs.INVOCATION.debugf("EJBClientChannel.handleResponse: updated Locator (sessionID); new = %s", context.getLocator().getAffinity());
                            }
                        }
                        if (allAreSet(updateBits, Protocol.UPDATE_BIT_WEAK_AFFINITY)) {
                            byte[] b = new byte[PackedInteger.readPackedInteger(inputStream)];
                            inputStream.readFully(b);
                            context.setWeakAffinity(new NodeAffinity(new String(b, StandardCharsets.UTF_8)));
                            if (Logs.INVOCATION.isDebugEnabled()) {
                                Logs.INVOCATION.debugf("EJBClientChannel.handleResponse: updated weak affinity = %s", context.getWeakAffinity());
                            }
                        }
                        if (allAreSet(updateBits, Protocol.UPDATE_BIT_STRONG_AFFINITY)) {
                            byte[] b = new byte[PackedInteger.readPackedInteger(inputStream)];
                            inputStream.readFully(b);
                            context.setLocator(context.getLocator().withNewAffinity(new ClusterAffinity(new String(b, StandardCharsets.UTF_8))));
                            if (Logs.INVOCATION.isDebugEnabled()) {
                                Logs.INVOCATION.debugf("EJBClientChannel.handleResponse: updated strong affinity = %s", context.getLocator().getAffinity());
                            }
                        }
                    } catch (RuntimeException | IOException | RollbackException | SystemException e) {
                        // the header could not be processed: fail the request instead of producing a result
                        receiverInvocationContext.requestFailed(new EJBException(e), getRetryExecutor(receiverInvocationContext) );
                        safeClose(inputStream);
                        break;
                    }
                }
                final NamingProvider provider = context.getProxyAttachment(Keys.NAMING_PROVIDER_ATTACHMENT_KEY);
                // the result itself is unmarshalled lazily by the producer
                receiverInvocationContext.resultReady(new MethodCallResultProducer(provider, inputStream, id));
                break;
            }
            case Protocol.CANCEL_RESPONSE: {
                free();
                forgetEnlistmentIfSupported();
                receiverInvocationContext.requestCancelled();
                break;
            }
            case Protocol.APPLICATION_EXCEPTION: {
                free();
                receiverInvocationContext.resultReady(new ExceptionResultProducer(inputStream, id));
                break;
            }
            case Protocol.NO_SUCH_EJB: {
                free();
                try {
                    forgetEnlistmentIfSupported();
                    disassociateRemoteTxIfPossible(receiverInvocationContext.getClientInvocationContext());
                    final String message = inputStream.readUTF();
                    final EJBModuleIdentifier moduleIdentifier = receiverInvocationContext.getClientInvocationContext().getLocator().getIdentifier().getModuleIdentifier();
                    final NodeInformation nodeInformation = discoveredNodeRegistry.getNodeInformation(getChannel().getConnection().getRemoteEndpointName());
                    // the peer does not have this EJB: drop the module from the node's registry entry
                    nodeInformation.removeModule(EJBClientChannel.this, moduleIdentifier);
                    receiverInvocationContext.requestFailed(new NoSuchEJBException(message + " @ " + getChannel().getConnection().getPeerURI()), getRetryExecutor(receiverInvocationContext));
                } catch (IOException e) {
                    receiverInvocationContext.requestFailed(new EJBException("Failed to read 'No such EJB' response", e), getRetryExecutor(receiverInvocationContext));
                } finally {
                    safeClose(inputStream);
                }
                break;
            }
            case Protocol.BAD_VIEW_TYPE: {
                free();
                try {
                    forgetEnlistmentIfSupported();
                    disassociateRemoteTxIfPossible(receiverInvocationContext.getClientInvocationContext());
                    final String message = inputStream.readUTF();
                    receiverInvocationContext.requestFailed(Logs.REMOTING.invalidViewTypeForInvocation(message), getRetryExecutor(receiverInvocationContext));
                } catch (IOException e) {
                    receiverInvocationContext.requestFailed(new EJBException("Failed to read 'Bad EJB view type' response", e), getRetryExecutor(receiverInvocationContext));
                } finally {
                    safeClose(inputStream);
                }
                break;
            }
            case Protocol.NO_SUCH_METHOD: {
                free();
                try {
                    forgetEnlistmentIfSupported();
                    disassociateRemoteTxIfPossible(receiverInvocationContext.getClientInvocationContext());
                    final String message = inputStream.readUTF();
                    // todo: I don't think this is the best exception type for this case...
                    receiverInvocationContext.requestFailed(new IllegalArgumentException(message), getRetryExecutor(receiverInvocationContext));
                } catch (IOException e) {
                    receiverInvocationContext.requestFailed(new EJBException("Failed to read 'No such EJB method' response", e), getRetryExecutor(receiverInvocationContext));
                } finally {
                    safeClose(inputStream);
                }
                break;
            }
            case Protocol.SESSION_NOT_ACTIVE: {
                free();
                try {
                    forgetEnlistmentIfSupported();
                    disassociateRemoteTxIfPossible(receiverInvocationContext.getClientInvocationContext());
                    final String message = inputStream.readUTF();
                    receiverInvocationContext.requestFailed(new EJBException(message), getRetryExecutor(receiverInvocationContext));
                } catch (IOException e) {
                    receiverInvocationContext.requestFailed(new EJBException("Failed to read 'Session not active' response", e), getRetryExecutor(receiverInvocationContext));
                } finally {
                    safeClose(inputStream);
                }
                break;
            }
            case Protocol.EJB_NOT_STATEFUL: {
                free();
                try {
                    forgetEnlistmentIfSupported();
                    disassociateRemoteTxIfPossible(receiverInvocationContext.getClientInvocationContext());
                    final String message = inputStream.readUTF();
                    receiverInvocationContext.requestFailed(new EJBException(message), getRetryExecutor(receiverInvocationContext));
                } catch (IOException e) {
                    // keep the I/O failure as the cause, like the other error responses do
                    receiverInvocationContext.requestFailed(new EJBException("Failed to read 'EJB not stateful' response", e), getRetryExecutor(receiverInvocationContext));
                } finally {
                    safeClose(inputStream);
                }
                break;
            }
            case Protocol.PROCEED_ASYNC_RESPONSE: {
                // do not free; response is forthcoming
                safeClose(inputStream);
                receiverInvocationContext.proceedAsynchronously();
                break;
            }
            default: {
                free();
                safeClose(inputStream);
                receiverInvocationContext.requestFailed(new EJBException("Unknown protocol response"), getRetryExecutor(receiverInvocationContext));
                break;
            }
        }
    }

    /**
     * Forgets the enlistment on the outflow handle, if one is set and the protocol version is 3 or later.
     */
    private void forgetEnlistmentIfSupported() {
        if (version >= 3) {
            final XAOutflowHandle handle = getOutflowHandle();
            if (handle != null) {
                handle.forgetEnlistment();
            }
        }
    }

    /**
     * Fails the request with an {@link EJBException} wrapping a {@link ClosedChannelException}.
     */
    public void handleClosed() {
        receiverInvocationContext.requestFailed(new EJBException(new ClosedChannelException()), getRetryExecutor(receiverInvocationContext));
    }

    /**
     * Fails the request with an {@link EJBException} wrapping the given cause.
     *
     * @param cause the I/O failure
     */
    public void handleException(IOException cause) {
        receiverInvocationContext.requestFailed(new EJBException(cause), getRetryExecutor(receiverInvocationContext));
    }

    XAOutflowHandle getOutflowHandle() {
        return outflowHandle;
    }

    void setOutflowHandle(final XAOutflowHandle outflowHandle) {
        this.outflowHandle = outflowHandle;
    }

    /**
     * Result producer for a normal invocation response: unmarshals the return value and the context
     * data attachments from the response stream when the result is requested.
     */
    class MethodCallResultProducer implements EJBReceiverInvocationContext.ResultProducer, ExceptionBiFunction {
        private final NamingProvider namingProvider;
        private final InputStream inputStream;
        private final int id;

        MethodCallResultProducer(final NamingProvider provider, final InputStream inputStream, final int id) {
            namingProvider = provider;
            this.inputStream = inputStream;
            this.id = id;
        }

        /**
         * Clean the ContextData on the {@link EJBClientInvocationContext} by removing all keys, contained in the
         * {@link EJBClientInvocationContext#RETURNED_CONTEXT_DATA_KEY}.
         *
         * @param clientInvocationContext the invocation context whose context data is cleaned
         */
        private void cleanContextDataBeforeUnmarshalling(EJBClientInvocationContext clientInvocationContext) {
            Map contextData = clientInvocationContext.getContextData();
            @SuppressWarnings("unchecked")
            Set returnedContextDataKeys = (Set) contextData.get(EJBClientInvocationContext.RETURNED_CONTEXT_DATA_KEY);
            if(returnedContextDataKeys != null) {
                contextData.keySet().removeAll(returnedContextDataKeys);
            }
        }

        /**
         * Unmarshals the invocation result and its attachments.
         *
         * @param ignored0 unused
         * @param ignored1 unused
         * @return the unmarshalled return value
         * @throws Exception an {@link EJBException} if the response cannot be read; the stream is
         *         closed in that case
         */
        public Object apply(final Void ignored0, final Void ignored1) throws Exception {
            final ResponseMessageInputStream response;
            if(inputStream instanceof ResponseMessageInputStream) {
                response = (ResponseMessageInputStream) inputStream;
            } else {
                response = new ResponseMessageInputStream(inputStream, id);
            }
            Object result;
            try (final Unmarshaller unmarshaller = createUnmarshaller()) {
                unmarshaller.start(response);
                result = unmarshaller.readObject();
                int attachments = unmarshaller.readUnsignedByte();
                final EJBClientInvocationContext clientInvocationContext = receiverInvocationContext.getClientInvocationContext();
                // clean the client side ContextData to allow server side removed keys being removed - EJBCLIENT-425
                cleanContextDataBeforeUnmarshalling(clientInvocationContext);
                for (int i = 0; i < attachments; i ++) {
                    String key = unmarshaller.readObject(String.class);
                    if (version < 3 && key.equals(Affinity.WEAK_AFFINITY_CONTEXT_KEY)) {
                        // before protocol version 3 the weak affinity is sent as an attachment
                        final Affinity affinity = unmarshaller.readObject(Affinity.class);
                        clientInvocationContext.putAttachment(AttachmentKeys.WEAK_AFFINITY, affinity);
                        clientInvocationContext.setWeakAffinity(affinity);
                        if (Logs.INVOCATION.isDebugEnabled()) {
                            Logs.INVOCATION.debugf("EJBClientChannel.MethodCallResultProducer: updated weak affinity (version < 3) = %s", affinity);
                        }
                    } else if (key.equals(EJBClientInvocationContext.PRIVATE_ATTACHMENTS_KEY)) {
                        // skip
                        unmarshaller.readObject();
                    } else {
                        final Object value = unmarshaller.readObject();
                        if (value != null) clientInvocationContext.getContextData().put(key, value);
                    }
                }
                unmarshaller.finish();
            } catch (IOException | ClassNotFoundException ex) {
                discardResult();
                throw new EJBException("Failed to read response", ex);
            }
            return result;
        }

        /**
         * Produces the result, running the unmarshalling through the naming provider when one is set.
         *
         * @return the unmarshalled return value
         * @throws Exception if unmarshalling fails
         */
        public Object getResult() throws Exception {
            if (namingProvider != null) {
                return namingProvider.performExceptionAction(this, null, null);
            } else {
                return apply(null, null);
            }
        }

        public void discardResult() {
            safeClose(inputStream);
        }
    }

    /**
     * Result producer for an application exception response: unmarshals the exception from the
     * response stream and throws it when the result is requested.
     */
    class ExceptionResultProducer implements EJBReceiverInvocationContext.ResultProducer {
        private final InputStream inputStream;
        private final int id;

        ExceptionResultProducer(final InputStream inputStream, final int id) {
            this.inputStream = inputStream;
            this.id = id;
        }

        /**
         * Always completes by throwing.
         *
         * @return never returns normally
         * @throws Exception the unmarshalled application exception, or an {@link EJBException} if the
         *         response cannot be read, is {@code null}, or the enlistment cannot be verified
         */
        public Object getResult() throws Exception {
            Exception e;
            try (final ResponseMessageInputStream response = new ResponseMessageInputStream(inputStream, id)) {
                if (version >= 3) {
                    final int cmd = response.readUnsignedByte();
                    final XAOutflowHandle outflowHandle = getOutflowHandle();
                    if (outflowHandle != null) {
                        if (cmd == 0) {
                            outflowHandle.forgetEnlistment();
                        } else if (cmd == 1) {
                            try {
                                outflowHandle.verifyEnlistment();
                            } catch (RollbackException | SystemException e1) {
                                throw new EJBException(e1);
                            }
                        } else if (cmd == 2) {
                            outflowHandle.nonMasterEnlistment();
                        }
                    }
                }
                try (final Unmarshaller unmarshaller = createUnmarshaller()) {
                    unmarshaller.start(response);
                    e = unmarshaller.readObject(Exception.class);
                    if (version < 3) {
                        // discard attachment data, if any
                        int attachments = unmarshaller.readUnsignedByte();
                        for (int i = 0; i < attachments; i ++) {
                            unmarshaller.readObject();
                            unmarshaller.readObject();
                        }
                    }
                    unmarshaller.finish();
                }
            } catch (IOException | ClassNotFoundException ex) {
                discardResult();
                throw new EJBException("Failed to read response", ex);
            }
            if (e == null) {
                throw new EJBException("Null exception response");
            }
            glueStackTraces(e, Thread.currentThread().getStackTrace(), 1, "asynchronous invocation");
            throw e;
        }

        public void discardResult() {
            safeClose(inputStream);
        }
    }
}
/**
 * Tries to clear the location of the context's transaction when that transaction is a
 * {@link RemoteTransaction}; a trace message is logged if the location cannot be cleared.
 *
 * @param context the invocation context whose transaction is examined
 */
private static void disassociateRemoteTxIfPossible(AbstractInvocationContext context) {
    final AbstractTransaction current = context.getTransaction();
    if (! (current instanceof RemoteTransaction)) {
        // only remote transactions carry a location
        return;
    }
    final RemoteTransaction remoteTx = (RemoteTransaction) current;
    final boolean cleared = remoteTx.tryClearLocation();
    if (! cleared) {
        Logs.TXN.tracef("Could not disassociate remote transaction (already in-use or completed) from %s", remoteTx.getLocation());
    }
}
/**
 * Obtains a retry executor from the retry executor wrapper, based on the XNIO worker of this channel's
 * connection endpoint.
 *
 * @return the executor returned by the wrapper
 */
private Executor getRetryExecutor() {
    return retryExecutorWrapper.getExecutor(
            getChannel().getConnection().getEndpoint().getXnioWorker()
    );
}
/**
 * Provides a retry executor which will transfer the given thread contexts to the worker thread before execution.
 *
 * @param ejbReceiverInvocationContext the invocation context supplying the client context, the discovery
 *        and the authentication context handed to the wrapper
 * @return the executor returned by the retry executor wrapper
 */
private Executor getRetryExecutor(EJBReceiverInvocationContext ejbReceiverInvocationContext) {
    final Executor worker = getChannel().getConnection().getEndpoint().getXnioWorker();
    final EJBClientContext clientContext = ejbReceiverInvocationContext.getClientContext();
    final Discovery discovery = ejbReceiverInvocationContext.getDiscovery();
    final AuthenticationContext authenticationContext = ejbReceiverInvocationContext.getAuthenticationContext();
    return retryExecutorWrapper.getExecutor(worker, clientContext, discovery, authenticationContext);
}
static class ResponseMessageInputStream extends MessageInputStream implements ByteInput {
private final InputStream delegate;
private final int id;
ResponseMessageInputStream(final InputStream delegate, final int id) {
this.delegate = delegate;
this.id = id;
}
public int read() throws IOException {
return delegate.read();
}
public int read(final byte[] b, final int off, final int len) throws IOException {
return delegate.read(b, off, len);
}
public long skip(final long n) throws IOException {
return delegate.skip(n);
}
public void close() throws IOException {
delegate.close();
}
public int getId() {
return id;
}
}
}