All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jboss.ejb.protocol.remote.EJBClientChannel Maven / Gradle / Ivy

Go to download

This artifact provides a single jar that contains all classes required to use remote EJB and JMS, including all dependencies. It is intended for those not using Maven; Maven users should just import the EJB and JMS BOMs instead (shaded JARs cause lots of problems with Maven, as it is very easy to inadvertently end up with different versions of classes on the class path).

There is a newer version: 34.0.0.Final
Show newest version
/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2017 Red Hat, Inc., and individual contributors
 * as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.jboss.ejb.protocol.remote;

import static java.lang.Math.min;
import static org.jboss.ejb.protocol.remote.TCCLUtils.getAndSetSafeTCCL;
import static org.jboss.ejb.protocol.remote.TCCLUtils.resetTCCL;
import static org.xnio.Bits.allAreClear;
import static org.xnio.Bits.allAreSet;
import static org.xnio.IoUtils.safeClose;

import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.UnsupportedEncodingException;
import java.lang.reflect.Method;
import java.net.InetSocketAddress;
import java.net.URI;
import java.nio.channels.ClosedChannelException;
import java.nio.charset.StandardCharsets;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.zip.Deflater;
import java.util.zip.DeflaterOutputStream;
import java.util.zip.InflaterInputStream;
import javax.transaction.xa.Xid;

import jakarta.ejb.CreateException;
import jakarta.ejb.EJBException;
import jakarta.ejb.NoSuchEJBException;
import jakarta.transaction.InvalidTransactionException;
import jakarta.transaction.RollbackException;
import jakarta.transaction.SystemException;
import jakarta.transaction.Transaction;
import org.jboss.ejb._private.Keys;
import org.jboss.ejb._private.Logs;
import org.jboss.ejb.client.AbstractInvocationContext;
import org.jboss.ejb.client.Affinity;
import org.jboss.ejb.client.AttachmentKey;
import org.jboss.ejb.client.AttachmentKeys;
import org.jboss.ejb.client.ClusterAffinity;
import org.jboss.ejb.client.EJBClient;
import org.jboss.ejb.client.EJBClientContext;
import org.jboss.ejb.client.EJBClientInvocationContext;
import org.jboss.ejb.client.EJBLocator;
import org.jboss.ejb.client.EJBModuleIdentifier;
import org.jboss.ejb.client.EJBReceiverInvocationContext;
import org.jboss.ejb.client.EJBSessionCreationInvocationContext;
import org.jboss.ejb.client.NodeAffinity;
import org.jboss.ejb.client.RequestSendFailedException;
import org.jboss.ejb.client.SessionID;
import org.jboss.ejb.client.StatefulEJBLocator;
import org.jboss.ejb.client.StatelessEJBLocator;
import org.jboss.ejb.client.TransactionID;
import org.jboss.ejb.client.UserTransactionID;
import org.jboss.ejb.client.XidTransactionID;
import org.jboss.marshalling.ByteInput;
import org.jboss.marshalling.Marshaller;
import org.jboss.marshalling.MarshallerFactory;
import org.jboss.marshalling.Marshalling;
import org.jboss.marshalling.MarshallingConfiguration;
import org.jboss.marshalling.Unmarshaller;
import org.jboss.remoting3.Channel;
import org.jboss.remoting3.Connection;
import org.jboss.remoting3.ConnectionPeerIdentity;
import org.jboss.remoting3.MessageInputStream;
import org.jboss.remoting3.MessageOutputStream;
import org.jboss.remoting3.RemotingOptions;
import org.jboss.remoting3._private.IntIndexHashMap;
import org.jboss.remoting3._private.IntIndexMap;
import org.jboss.remoting3.util.Invocation;
import org.jboss.remoting3.util.InvocationTracker;
import org.jboss.remoting3.util.StreamUtils;
import org.wildfly.common.Assert;
import org.wildfly.common.function.ExceptionBiFunction;
import org.wildfly.common.net.CidrAddress;
import org.wildfly.discovery.Discovery;
import org.wildfly.naming.client.NamingProvider;
import org.wildfly.security.auth.client.AuthenticationContext;
import org.wildfly.transaction.client.AbstractTransaction;
import org.wildfly.transaction.client.LocalTransaction;
import org.wildfly.transaction.client.RemoteTransaction;
import org.wildfly.transaction.client.RemoteTransactionContext;
import org.wildfly.transaction.client.XAOutflowHandle;
import org.wildfly.transaction.client.provider.remoting.SimpleIdResolver;
import org.xnio.Cancellable;
import org.xnio.FutureResult;
import org.xnio.IoFuture;

/**
 * @author David M. Lloyd
 * @author Tomasz Adamski
 * @author Richard Opalka
 */
@SuppressWarnings("deprecation")
class EJBClientChannel {

    private final MarshallerFactory marshallerFactory;

    private final Channel channel;
    private final int version;
    private final DiscoveredNodeRegistry discoveredNodeRegistry;

    private final InvocationTracker invocationTracker;

    private final MarshallingConfiguration configuration;
    private final IntIndexMap userTxnIds = new IntIndexHashMap(UserTransactionID::getId);

    private final RemoteTransactionContext transactionContext;
    private final AtomicInteger finishedParts = new AtomicInteger(0);
    private final AtomicReference> futureResultRef;

    private final RetryExecutorWrapper retryExecutorWrapper;

    /**
     * Builds the client side of an EJB channel: selects the marshalling configuration that matches the
     * protocol version, creates the invocation tracker, and registers this channel with the node
     * information of the remote endpoint.
     *
     * @param channel                the Remoting channel to communicate over
     * @param version                the protocol version; values below 3 use the V1 class/object tables
     * @param discoveredNodeRegistry the registry in which the remote node is looked up
     * @param futureResult           the future result stored for later completion
     * @param retryExecutorWrapper   the retry executor wrapper to retain
     */
    EJBClientChannel(final Channel channel, final int version, final DiscoveredNodeRegistry discoveredNodeRegistry, final FutureResult futureResult, RetryExecutorWrapper retryExecutorWrapper) {
        this.channel = channel;
        this.version = version;
        this.discoveredNodeRegistry = discoveredNodeRegistry;
        this.retryExecutorWrapper = retryExecutorWrapper;
        this.marshallerFactory = Marshalling.getProvidedMarshallerFactory("river");

        final MarshallingConfiguration config = new MarshallingConfiguration();
        config.setClassResolver(ProtocolClassResolver.INSTANCE);
        final Connection conn = channel.getConnection();
        final boolean legacyProtocol = version < 3;
        if (legacyProtocol) {
            config.setClassTable(ProtocolV1ClassTable.INSTANCE);
            config.setObjectTable(ProtocolV1ObjectTable.INSTANCE);
            config.setObjectResolver(new ProtocolV1ObjectResolver(conn, true));
            config.setObjectPreResolver(ProtocolV1ObjectPreResolver.INSTANCE);
            config.setVersion(2);
            // the cluster topology report is not waited for on legacy protocols, so mark it done up front
            this.finishedParts.set(0b10);
        } else {
            // the server does not present v3 unless the transaction service is also present
            config.setObjectTable(ProtocolV3ObjectTable.INSTANCE);
            config.setObjectResolver(new ProtocolV3ObjectResolver(conn, true));
            config.setVersion(4);
        }
        EENamespaceInteroperability.handleInteroperability(config, version);
        this.transactionContext = RemoteTransactionContext.getInstance();
        this.configuration = config;

        final int maxOutboundMessages = channel.getOption(RemotingOptions.MAX_OUTBOUND_MESSAGES).intValue();
        this.invocationTracker = new InvocationTracker(this.channel, maxOutboundMessages, EJBClientChannel::mask);
        this.futureResultRef = new AtomicReference<>(futureResult);

        // attach this channel to the remote node's information and detach it again when the channel closes
        final String remoteNodeName = conn.getRemoteEndpointName();
        final NodeInformation remoteNode = discoveredNodeRegistry.getNodeInformation(remoteNodeName);
        remoteNode.addAddress(this);
        remoteNode.setInvalid(false);
        channel.addCloseHandler((closedChannel, failure) -> remoteNode.removeConnection(this));
    }

    /**
     * Reduces a value to its low 16 bits.
     *
     * @param original the value to reduce
     * @return {@code original} with everything above bit 15 cleared (always in the range 0..65535)
     */
    static int mask(int original) {
        final int lowSixteenBits = 0xFFFF;
        return original & lowSixteenBits;
    }

    /**
     * Reads one incoming message and dispatches it by its leading message-type byte.
     * <p>
     * Response-type messages are handed to the invocation tracker together with the invocation id that
     * follows the type byte. Module and cluster topology messages update the discovered node registry.
     * Unknown message types are ignored. Unless the invocation tracker asks for the stream to be left
     * open, the message is closed before this method returns.
     *
     * @param message the incoming message stream, positioned at the message-type byte
     */
    private void processMessage(final MessageInputStream message) {
        // set when the invocation tracker reports that the stream must not be closed here
        boolean leaveOpen = false;
        try {
            final int msg = message.readUnsignedByte();
            // only used in debug output below
            String remoteEndpoint = channel.getConnection().getRemoteEndpointName();
            switch (msg) {
                case Protocol.TXN_RESPONSE:
                case Protocol.INVOCATION_RESPONSE:
                case Protocol.OPEN_SESSION_RESPONSE:
                case Protocol.APPLICATION_EXCEPTION:
                case Protocol.CANCEL_RESPONSE:
                case Protocol.NO_SUCH_EJB:
                case Protocol.NO_SUCH_METHOD:
                case Protocol.SESSION_NOT_ACTIVE:
                case Protocol.EJB_NOT_STATEFUL:
                case Protocol.BAD_VIEW_TYPE:
                case Protocol.PROCEED_ASYNC_RESPONSE:{
                    final int invId = message.readUnsignedShort();
                    leaveOpen = invocationTracker.signalResponse(invId, msg, message, false);
                    break;
                }
                case Protocol.COMPRESSED_INVOCATION_MESSAGE: {
                    // the real message type and invocation id are inside the deflated payload
                    DataInputStream inputStream = new DataInputStream(new InflaterInputStream(message));
                    // NOTE(review): read as a signed byte here, while the outer type byte is read unsigned - confirm intended
                    final int realMessageId = inputStream.readByte();
                    final int invId = inputStream.readUnsignedShort();
                    leaveOpen = invocationTracker.signalResponse(invId, realMessageId, new ResponseMessageInputStream(inputStream, invId), false);
                    break;
                }
                case Protocol.MODULE_AVAILABLE: {
                    int count = PackedInteger.readPackedInteger(message);
                    final NodeInformation nodeInformation = discoveredNodeRegistry.getNodeInformation(getChannel().getConnection().getRemoteEndpointName());
                    final EJBModuleIdentifier[] moduleList = new EJBModuleIdentifier[count];
                    for (int i = 0; i < count; i ++) {
                        // each module is sent as three UTF strings: application, module and distinct name
                        final String appName = message.readUTF();
                        final String moduleName = message.readUTF();
                        final String distinctName = message.readUTF();
                        final EJBModuleIdentifier moduleIdentifier = new EJBModuleIdentifier(appName, moduleName, distinctName);
                        moduleList[i] = moduleIdentifier;

                        if (Logs.INVOCATION.isDebugEnabled()) {
                            Logs.INVOCATION.debugf("Received MODULE_AVAILABLE(%x) message from node %s for module %s", msg, remoteEndpoint, moduleIdentifier);
                        }
                    }
                    nodeInformation.addModules(this, moduleList);
                    // module availability report received
                    finishPart(0b01);
                    break;
                }
                case Protocol.MODULE_UNAVAILABLE: {
                    int count = PackedInteger.readPackedInteger(message);
                    final NodeInformation nodeInformation = discoveredNodeRegistry.getNodeInformation(getChannel().getConnection().getRemoteEndpointName());
                    final HashSet<EJBModuleIdentifier> set = new HashSet<>(count);
                    for (int i = 0; i < count; i ++) {
                        final String appName = message.readUTF();
                        final String moduleName = message.readUTF();
                        final String distinctName = message.readUTF();
                        final EJBModuleIdentifier moduleIdentifier = new EJBModuleIdentifier(appName, moduleName, distinctName);
                        set.add(moduleIdentifier);

                        if (Logs.INVOCATION.isDebugEnabled()) {
                            Logs.INVOCATION.debugf("Received MODULE_UNAVAILABLE(%x) message from node %s for module %s", msg, remoteEndpoint, moduleIdentifier);
                        }
                    }
                    nodeInformation.removeModules(this, set);
                    break;
                }
                case Protocol.CLUSTER_TOPOLOGY_ADDITION:
                case Protocol.CLUSTER_TOPOLOGY_COMPLETE: {
                    int clusterCount = PackedInteger.readPackedInteger(message);
                    for (int i = 0; i < clusterCount; i ++) {
                        final String clusterName = message.readUTF();
                        int memberCount = PackedInteger.readPackedInteger(message);
                        for (int j = 0; j < memberCount; j ++) {
                            final String nodeName = message.readUTF();
                            discoveredNodeRegistry.addNode(clusterName, nodeName, channel.getConnection().getPeerURI());
                            final NodeInformation nodeInformation = discoveredNodeRegistry.getNodeInformation(nodeName);

                            if (Logs.INVOCATION.isDebugEnabled()) {
                                Logs.INVOCATION.debugf("Received CLUSTER_TOPOLOGY(%x) message from node %s, registering cluster %s to node %s", msg, remoteEndpoint, clusterName, nodeName);
                            }

                            // create and register the concrete ServiceURLs for each client mapping
                            int mappingCount = PackedInteger.readPackedInteger(message);
                            for (int k = 0; k < mappingCount; k ++) {
                                // low bit clear => 16-byte (IPv6) source address, set => 4-byte; remaining bits are the netmask length
                                int b = message.readUnsignedByte();
                                final boolean ip6 = allAreClear(b, 1);
                                int netmaskBits = b >>> 1;
                                final byte[] sourceIpBytes = new byte[ip6 ? 16 : 4];
                                message.readFully(sourceIpBytes);
                                final CidrAddress block = CidrAddress.create(sourceIpBytes, netmaskBits);
                                final String destHost = message.readUTF();
                                final int destPort = message.readUnsignedShort();
                                final InetSocketAddress destUnresolved = InetSocketAddress.createUnresolved(destHost, destPort);
                                nodeInformation.addAddress(channel.getConnection().getProtocol(), clusterName, block, destUnresolved);
                                if (Logs.INVOCATION.isDebugEnabled()) {
                                    Logs.INVOCATION.debugf("Received CLUSTER_TOPOLOGY(%x) message block from %s, registering block %s to address %s", msg, remoteEndpoint, block, destUnresolved);
                                }
                            }
                        }
                    }
                    // cluster topology report received
                    finishPart(0b10);
                    break;
                }
                case Protocol.CLUSTER_TOPOLOGY_REMOVAL: {
                    int clusterCount = PackedInteger.readPackedInteger(message);
                    for (int i = 0; i < clusterCount; i ++) {
                        String clusterName = message.readUTF();
                        discoveredNodeRegistry.removeCluster(clusterName);

                        if (Logs.INVOCATION.isDebugEnabled()) {
                            Logs.INVOCATION.debugf("Received CLUSTER_TOPOLOGY_REMOVAL(%x) message from node %s for cluster %s", msg, remoteEndpoint, clusterName);
                        }

                        // drop the cluster from every known node as well
                        for (NodeInformation nodeInformation : discoveredNodeRegistry.getAllNodeInformation()) {
                            nodeInformation.removeCluster(clusterName);
                        }
                    }
                    break;
                }
                case Protocol.CLUSTER_TOPOLOGY_NODE_REMOVAL: {
                    int clusterCount = PackedInteger.readPackedInteger(message);
                    for (int i = 0; i < clusterCount; i ++) {
                        String clusterName = message.readUTF();
                        int memberCount = PackedInteger.readPackedInteger(message);
                        for (int j = 0; j < memberCount; j ++) {
                            String nodeName = message.readUTF();
                            discoveredNodeRegistry.removeNode(clusterName, nodeName);
                            final NodeInformation nodeInformation = discoveredNodeRegistry.getNodeInformation(nodeName);
                            nodeInformation.removeCluster(clusterName);

                            if (Logs.INVOCATION.isDebugEnabled()) {
                                Logs.INVOCATION.debugf("Received CLUSTER_TOPOLOGY_NODE_REMOVAL(%x) message from node %s for (cluster, node) = (%s, %s)", msg, remoteEndpoint, clusterName, nodeName);
                            }
                        }
                    }
                    break;
                }
                default: {
                    // ignore message
                }
            }
        } catch (IOException e) {
            // Message processing stays best-effort (nothing is rethrown), but the failure is no longer invisible.
            if (Logs.INVOCATION.isDebugEnabled()) {
                Logs.INVOCATION.debugf(e, "Failed to process incoming message from node %s", channel.getConnection().getRemoteEndpointName());
            }
        } finally {
            if (! leaveOpen) {
                safeClose(message);
            }
        }
    }

    /** Attachment key under which the in-flight {@link MethodInvocation} is stored on the client invocation context. */
    private static final AttachmentKey<MethodInvocation> INV_KEY = new AttachmentKey<>();

    public void processInvocation(final EJBReceiverInvocationContext receiverContext, final ConnectionPeerIdentity peerIdentity) {
        MethodInvocation invocation = invocationTracker.addInvocation(id -> new MethodInvocation(id, receiverContext));
        final EJBClientInvocationContext invocationContext = receiverContext.getClientInvocationContext();
        invocationContext.putAttachment(INV_KEY, invocation);
        final EJBLocator locator = invocationContext.getLocator();
        final int peerIdentityId;
        if (version >= 3) {
            peerIdentityId = peerIdentity.getId();
        } else {
            peerIdentityId = 0; // unused
        }
        try (MessageOutputStream underlying = invocationTracker.allocateMessage()) {
            MessageOutputStream out = handleCompression(invocationContext, underlying);
            try {
                out.write(Protocol.INVOCATION_REQUEST);
                out.writeShort(invocation.getIndex());

                Marshaller marshaller = getMarshaller();
                marshaller.start(new NoFlushByteOutput(Marshalling.createByteOutput(out)));

                final Method invokedMethod = invocationContext.getInvokedMethod();
                final Object[] parameters = invocationContext.getParameters();

                if (version < 3) {
                    // method name as UTF string
                    out.writeUTF(invokedMethod.getName());

                    // write the method signature as UTF string
                    out.writeUTF(invocationContext.getMethodSignatureString());

                    // protocol 1 & 2 redundant locator objects
                    marshaller.writeObject(locator.getAppName());
                    marshaller.writeObject(locator.getModuleName());
                    marshaller.writeObject(locator.getDistinctName());
                    marshaller.writeObject(locator.getBeanName());
                } else {

                    // write identifier to allow the peer to find the class loader
                    marshaller.writeObject(locator.getIdentifier());

                    // write method locator
                    marshaller.writeObject(invocationContext.getMethodLocator());

                    // write sec context
                    marshaller.writeInt(peerIdentityId);

                    // write weak affinity
                    marshaller.writeObject(invocationContext.getWeakAffinity());

                    // write response compression info
                    if (invocationContext.isCompressResponse()) {
                        int compressionLevel = invocationContext.getCompressionLevel() > 0 ? invocationContext.getCompressionLevel() : 9;
                        marshaller.writeByte(compressionLevel);
                    } else {
                        marshaller.writeByte(0);
                    }

                    // write txn context
                    invocation.setOutflowHandle(writeTransaction(invocationContext.getTransaction(), marshaller, invocationContext.getAuthenticationContext()));
                }
                // write the invocation locator itself
                marshaller.writeObject(locator);

                // and the parameters
                if (parameters != null && parameters.length > 0) {
                    for (final Object methodParam : parameters) {
                        marshaller.writeObject(methodParam);
                    }
                }

                // now, attachments
                // we write out the private (a.k.a JBoss specific) attachments as well as public invocation context data
                // (a.k.a user application specific data)
                final Map, ?> privateAttachments = invocationContext.getAttachments();
                final Map contextData = invocationContext.getContextData();

                // write the attachment count which is the sum of invocation context data + 1 (since we write
                // out the private attachments under a single key with the value being the entire attachment map)
                int totalContextData = contextData.size();
                if (version >= 3) {
                    // Just write the attachments.
                    PackedInteger.writePackedInteger(marshaller, totalContextData);

                    for (Map.Entry invocationContextData : contextData.entrySet()) {
                        marshaller.writeObject(invocationContextData.getKey());
                        marshaller.writeObject(invocationContextData.getValue());
                    }
                } else {
                    final Transaction transaction = invocationContext.getTransaction();

                    // We are only marshalling those attachments whose keys are present in the object table
                    final Map, Object> marshalledPrivateAttachments = new HashMap<>();
                    for (final Map.Entry, ?> entry : privateAttachments.entrySet()) {
                        final AttachmentKey key = entry.getKey();
                        if (key == AttachmentKeys.TRANSACTION_ID_KEY) {
                            // skip!
                        } else if (ProtocolV1ObjectTable.INSTANCE.getObjectWriter(key) != null) {
                            marshalledPrivateAttachments.put(key, entry.getValue());
                        }
                    }

                    if (transaction != null) {
                        marshalledPrivateAttachments.put(AttachmentKeys.TRANSACTION_ID_KEY, calculateTransactionId(transaction));
                    }

                    final boolean hasPrivateAttachments = ! marshalledPrivateAttachments.isEmpty();
                    if (hasPrivateAttachments) {
                        totalContextData++;
                    }
                    // Note: The code here is just for backward compatibility of 1.x and 2.x versions of EJB client project.
                    // Attach legacy transaction ID, if there is an active txn.

                    if (transaction != null) {
                        // we additionally add/duplicate the transaction id under a different attachment key
                        // to preserve backward compatibility. This is here just for 1.0.x backward compatibility
                        totalContextData++;
                    }
                    // backward compatibility code block for transaction id ends here.

                    PackedInteger.writePackedInteger(marshaller, totalContextData);
                    // write out public (application specific) context data
                    ProtocolObjectResolver.enableNonSerReplacement();
                    try {
                        for (Map.Entry invocationContextData : contextData.entrySet()) {
                            marshaller.writeObject(invocationContextData.getKey());
                            marshaller.writeObject(invocationContextData.getValue());
                        }
                    } finally {
                        ProtocolObjectResolver.disableNonSerReplacement();
                    }
                    if (hasPrivateAttachments) {
                        // now write out the JBoss specific attachments under a single key and the value will be the
                        // entire map of JBoss specific attachments
                        marshaller.writeObject(EJBClientInvocationContext.PRIVATE_ATTACHMENTS_KEY);
                        marshaller.writeObject(marshalledPrivateAttachments);
                    }

                    // Note: The code here is just for backward compatibility of 1.0.x version of EJB client project
                    // against AS7 7.1.x releases. Discussion here https://github.com/wildfly/jboss-ejb-client/pull/11#issuecomment-6573863
                    if (transaction != null) {
                        // we additionally add/duplicate the transaction id under a different attachment key
                        // to preserve backward compatibility. This is here just for 1.0.x backward compatibility
                        marshaller.writeObject(TransactionID.PRIVATE_DATA_KEY);
                        // This transaction id attachment duplication *won't* cause increase in EJB protocol message payload
                        // since we rely on JBoss Marshalling to use back references for the same transaction id object being
                        // written out
                        marshaller.writeObject(marshalledPrivateAttachments.get(AttachmentKeys.TRANSACTION_ID_KEY));
                    }
                    // backward compatibility code block for transaction id ends here.
                }

                // finished
                marshaller.finish();
            } catch (IOException e) {
                underlying.cancel();
                throw e;
            } finally {
                out.close();
            }
        } catch (IOException e) {
            receiverContext.requestFailed(new RequestSendFailedException(e.getMessage() + " @ " + peerIdentity.getConnection().getPeerURI(), e, true), getRetryExecutor(receiverContext) );
        } catch (RollbackException | SystemException | RuntimeException e) {
            receiverContext.requestFailed(new EJBException(e.getMessage(), e), getRetryExecutor(receiverContext));
            return;
        }
    }

    /**
     * Chooses the stream the invocation request is written to, wrapping the given
     * {@link MessageOutputStream message output stream} in a compressing stream when the invocation asks
     * for request compression. For protocol version 2 this also records the response-compression
     * attachments on the invocation context.
     *
     * @param invocationContext         The Enterprise Beans client invocation context
     * @param messageOutputStream       The message output stream that may need to be wrapped
     * @return the given stream itself, or a stream that compresses everything subsequently written to it
     * @throws IOException if a problem occurs
     */
    private MessageOutputStream handleCompression(final EJBClientInvocationContext invocationContext, final MessageOutputStream messageOutputStream) throws IOException {
        // protocol version 1 has no compressed messages at all
        if (version == 1) {
            return messageOutputStream;
        }

        // with "hints" switched off no CompressionHint is looked at
        final Boolean hintsOff = invocationContext.getProxyAttachment(AttachmentKeys.HINTS_DISABLED);
        if (Boolean.TRUE.equals(hintsOff)) {
            if (Logs.REMOTING.isTraceEnabled()) {
                Logs.REMOTING.trace("Hints are disabled. Ignoring any CompressionHint on methods being invoked on view " + invocationContext.getViewClass());
            }
            return messageOutputStream;
        }

        final int level = invocationContext.getCompressionLevel();

        // v2 only: the request for a compressed response travels as attachments
        final boolean announceCompressedResponse = version == 2 && invocationContext.isCompressResponse();
        if (announceCompressedResponse) {
            invocationContext.putAttachment(AttachmentKeys.COMPRESS_RESPONSE, true);
            invocationContext.putAttachment(AttachmentKeys.RESPONSE_COMPRESSION_LEVEL, level);
            if (Logs.REMOTING.isTraceEnabled()) {
                Logs.REMOTING.trace("Letting the server know that the response of method " + invocationContext.getInvokedMethod() + " has to be compressed with compression level = " + level);
            }
        }

        // a method may ask for a compressed response only, in which case the request goes out uncompressed
        if (! invocationContext.isCompressRequest()) {
            return messageOutputStream;
        }

        // the header byte is written uncompressed; everything written afterwards passes through the deflater
        messageOutputStream.write(Protocol.COMPRESSED_INVOCATION_MESSAGE);
        final Deflater requestDeflater = new Deflater(level);
        final DeflaterOutputStream compressingStream = new DeflaterOutputStream(messageOutputStream, requestDeflater);
        if (Logs.REMOTING.isTraceEnabled()) {
            Logs.REMOTING.trace("Using a compressing stream with compression level = " + level + " for request data for EJB invocation on method " + invocationContext.getInvokedMethod());
        }
        return new WrapperMessageOutputStream(messageOutputStream, compressingStream);
    }

    /**
     * Derives the transaction id that is sent to the peer for the given transaction.
     * <p>
     * A remote transaction is pinned to this channel's peer URI and yields a {@link UserTransactionID};
     * a local transaction is outflowed to the peer URI, its enlistment is verified, and it yields an
     * {@link XidTransactionID}. Any other transaction type cannot be enlisted.
     *
     * @param transaction the transaction to identify (asserted to be non-null)
     * @return the transaction id for the peer
     */
    private TransactionID calculateTransactionId(final Transaction transaction) throws RollbackException, SystemException, InvalidTransactionException {
        final URI peerUri = channel.getConnection().getPeerURI();
        Assert.assertNotNull(transaction);

        if (transaction instanceof RemoteTransaction) {
            final RemoteTransaction remote = (RemoteTransaction) transaction;
            remote.setLocation(peerUri);
            final SimpleIdResolver idResolver = remote.getProviderInterface(SimpleIdResolver.class);
            if (idResolver == null) {
                throw Logs.TXN.cannotEnlistTx();
            }
            final String endpointName = channel.getConnection().getRemoteEndpointName();
            final int remoteTxnId = idResolver.getTransactionId(channel.getConnection());
            return new UserTransactionID(endpointName, remoteTxnId);
        }

        if (transaction instanceof LocalTransaction) {
            final LocalTransaction local = (LocalTransaction) transaction;
            final XAOutflowHandle handle = transactionContext.outflowTransaction(peerUri, local);
            // always verify V1/2 outflows
            handle.verifyEnlistment();
            return new XidTransactionID(handle.getXid());
        }

        throw Logs.TXN.cannotEnlistTx();
    }

    /**
     * Writes the transaction context to the given output, running the write under the supplied
     * authentication context when one is present.
     * <p>
     * When run under an authentication context, an {@link IOException}, {@link RollbackException} or
     * {@link SystemException} wrapped in the resulting {@link PrivilegedActionException} is unwrapped and
     * rethrown as-is; any other cause is rethrown wrapped in a {@link RuntimeException}.
     *
     * @param transaction           the transaction to write, may be {@code null}
     * @param dataOutput            the output the transaction context is written to
     * @param authenticationContext the authentication context to run under, or {@code null} to write directly
     * @return the handle returned by the underlying transaction write, possibly {@code null}
     * @throws IOException       if writing fails
     * @throws RollbackException if raised by the underlying transaction write
     * @throws SystemException   if raised by the underlying transaction write
     */
    private XAOutflowHandle writeTransaction(final Transaction transaction, final DataOutput dataOutput,
            final AuthenticationContext authenticationContext) throws IOException, RollbackException, SystemException {
        if (authenticationContext != null) {
            if (Logs.MAIN.isDebugEnabled()) {
                Logs.MAIN.debug("Using existing AuthenticationContext for writeTransaction(...)");
            }
            try {
                // the typed cast selects the PrivilegedExceptionAction overload and fixes the result type
                return authenticationContext.run((PrivilegedExceptionAction<XAOutflowHandle>) () -> _writeTransaction(transaction, dataOutput));
            } catch (PrivilegedActionException e) {
                // surface the declared checked exceptions unchanged
                Throwable cause = e.getCause();
                if (cause instanceof IOException) {
                    throw (IOException) cause;
                } else if (cause instanceof RollbackException) {
                    throw (RollbackException) cause;
                } else if (cause instanceof SystemException) {
                    throw (SystemException) cause;
                }
                throw new RuntimeException(e);
            }
        } else {
            if (Logs.MAIN.isDebugEnabled()) {
                Logs.MAIN.debug("No existing AuthenticationContext for writeTransaction(...)");
            }
            return _writeTransaction(transaction, dataOutput);
        }
    }

    /**
     * Encode the given transaction onto the output.
     * <p>
     * The first byte written selects the form that follows:
     * <ul>
     *     <li>{@code 0}: no transaction; nothing else is written</li>
     *     <li>{@code 1}: remote transaction; an {@code int} ID followed by the packed remaining time</li>
     *     <li>{@code 2}: local transaction; the packed XID format ID, the length-prefixed global transaction ID,
     *     the length-prefixed branch qualifier, then the packed remaining time</li>
     * </ul>
     *
     * @param transaction the transaction to write, or {@code null} for none
     * @param dataOutput the output to write to
     * @return the outflow handle for a local transaction, otherwise {@code null}
     * @throws IOException if writing fails
     * @throws RollbackException declared by the outflow calls
     * @throws SystemException if the transaction cannot be enlisted or its remaining time is zero
     */
    private XAOutflowHandle _writeTransaction(final Transaction transaction, final DataOutput dataOutput) throws IOException, RollbackException, SystemException {
        final URI peerUri = channel.getConnection().getPeerURI();

        if (transaction == null) {
            dataOutput.writeByte(0);
            return null;
        }

        if (transaction instanceof RemoteTransaction) {
            final RemoteTransaction remote = (RemoteTransaction) transaction;
            remote.setLocation(peerUri);
            dataOutput.writeByte(1);
            final SimpleIdResolver resolver = remote.getProviderInterface(SimpleIdResolver.class);
            if (resolver == null) {
                throw Logs.TXN.cannotEnlistTx();
            }
            dataOutput.writeInt(resolver.getTransactionId(channel.getConnection()));
            final int remaining = remote.getEstimatedRemainingTime();
            if (remaining == 0) {
                throw Logs.TXN.outflowTransactionTimeoutElapsed(transaction);
            }
            PackedInteger.writePackedInteger(dataOutput, remaining);
            return null;
        }

        if (transaction instanceof LocalTransaction) {
            final LocalTransaction local = (LocalTransaction) transaction;
            final XAOutflowHandle handle = transactionContext.outflowTransaction(peerUri, local);
            final Xid xid = handle.getXid();
            dataOutput.writeByte(2);
            PackedInteger.writePackedInteger(dataOutput, xid.getFormatId());
            final byte[] globalId = xid.getGlobalTransactionId();
            dataOutput.writeByte(globalId.length);
            dataOutput.write(globalId);
            final byte[] branch = xid.getBranchQualifier();
            // this will normally be zero, but write it anyway just in case we need to change it later
            dataOutput.writeByte(branch.length);
            dataOutput.write(branch);
            final int remaining = handle.getRemainingTime();
            if (remaining == 0) {
                throw Logs.TXN.outflowTransactionTimeoutElapsed(transaction);
            }
            PackedInteger.writePackedInteger(dataOutput, remaining);
            return handle;
        }

        throw Logs.TXN.cannotEnlistTx();
    }

    /**
     * Ask the peer to cancel an in-flight invocation and wait for the outcome.
     * <p>
     * For protocol versions below 3 a cancel is only sent when {@code cancelIfRunning} is set. Sending the
     * cancel request is best effort: an {@link IOException} while writing it is ignored and the method still
     * waits for the cancellation result.
     *
     * @param receiverContext the receiver context of the invocation to cancel
     * @param cancelIfRunning {@code true} to request cancellation even if the invocation is already running
     * @return {@code false} if no cancel was attempted (legacy protocol without {@code cancelIfRunning}, or the
     *      invocation attachment is missing), otherwise the awaited cancellation result
     */
    boolean cancelInvocation(final EJBReceiverInvocationContext receiverContext, boolean cancelIfRunning) {
        if (! cancelIfRunning && version < 3) {
            // keep legacy behavior
            return false;
        }
        final MethodInvocation tracked = receiverContext.getClientInvocationContext().getAttachment(INV_KEY);
        if (tracked == null) {
            // lost it somehow
            return false;
        }
        // only send the request while we can still take a reference on the invocation
        if (tracked.alloc()) {
            try {
                final int invocationIndex = tracked.getIndex();
                try (MessageOutputStream request = invocationTracker.allocateMessage()) {
                    request.write(Protocol.CANCEL_REQUEST);
                    request.writeShort(invocationIndex);
                    if (version >= 3) {
                        request.writeBoolean(cancelIfRunning);
                    }
                } catch (IOException ignored) {
                    // best effort only; the result is awaited below regardless
                }
            } finally {
                tracked.free();
            }
        }
        // now await the result
        return tracked.receiverInvocationContext.getClientInvocationContext().awaitCancellationResult();
    }

    /**
     * Create a new marshaller from this channel's marshaller factory and configuration.
     *
     * @return the new marshaller
     * @throws IOException if the factory fails to create it
     */
    private Marshaller getMarshaller() throws IOException {
        final Marshaller created = marshallerFactory.createMarshaller(configuration);
        return created;
    }

    /**
     * Send an "open session" request for the given stateless locator and wait for the response.
     * <p>
     * The request carries the invocation index and the raw EJB identifier; for protocol version 3 and later it
     * also carries the peer identity ID and the transaction information, and any outflow handle produced while
     * writing the transaction is stored on the invocation for use when the response is processed.
     * <p>
     * If the request cannot be written for any reason, the invocation is removed from the tracker before the
     * failure is propagated: {@code getResult()} is the only other place that removes it, and it is never reached
     * on that path.
     *
     * @param statelessLocator the locator of the EJB to open a session on
     * @param identity the peer identity whose ID is sent for protocol version 3 and later
     * @param clientInvocationContext the session creation context supplying the transaction and authentication context
     * @return the stateful locator produced from the response
     * @throws Exception a {@code CreateException} (with the {@link IOException} as its cause) if the request could
     *      not be written; otherwise whatever writing the transaction or reading the result throws
     */
    public  StatefulEJBLocator openSession(final StatelessEJBLocator statelessLocator, final ConnectionPeerIdentity identity, EJBSessionCreationInvocationContext clientInvocationContext) throws Exception {
        SessionOpenInvocation invocation = invocationTracker.addInvocation(id -> new SessionOpenInvocation<>(id, statelessLocator, clientInvocationContext));
        boolean requestSent = false;
        try {
            try (MessageOutputStream out = invocationTracker.allocateMessage()) {
                out.write(Protocol.OPEN_SESSION_REQUEST);
                out.writeShort(invocation.getIndex());
                writeRawIdentifier(statelessLocator, out);
                if (version >= 3) {
                    out.writeInt(identity.getId());
                    invocation.setOutflowHandle(writeTransaction(clientInvocationContext.getTransaction(), out, clientInvocationContext.getAuthenticationContext()));
                }
            } catch (IOException e) {
                // CreateException has no cause-taking constructor, so attach the cause explicitly
                CreateException createException = new CreateException(e.getMessage());
                createException.initCause(e);
                throw createException;
            }
            requestSent = true;
        } finally {
            if (! requestSent) {
                // nobody will ever call getResult() for this invocation, so release its tracker slot here
                invocationTracker.remove(invocation);
            }
        }
        // await the response
        return invocation.getResult();
    }

    /**
     * Write the four-part EJB identifier of a locator as UTF strings, in the order application name, module name,
     * distinct name, bean name. A {@code null} application or distinct name is written as the empty string.
     *
     * @param statelessLocator the locator whose identifier is written
     * @param out the message stream to write to
     * @throws IOException if writing fails
     */
    private static  void writeRawIdentifier(final EJBLocator statelessLocator, final MessageOutputStream out) throws IOException {
        String application = statelessLocator.getAppName();
        if (application == null) {
            application = "";
        }
        out.writeUTF(application);

        out.writeUTF(statelessLocator.getModuleName());

        String distinct = statelessLocator.getDistinctName();
        if (distinct == null) {
            distinct = "";
        }
        out.writeUTF(distinct);

        out.writeUTF(statelessLocator.getBeanName());
    }

    /**
     * Begin the opening negotiation on a channel and return a future for the resulting {@code EJBClientChannel}.
     * <p>
     * A one-shot receiver is registered for the server greeting. When the greeting arrives, the protocol version
     * is chosen as the smaller of {@code Protocol.LATEST_VERSION} and the byte the server sent, the rest of the
     * greeting is drained, and the chosen version plus the string {@code "river"} are written back. An
     * {@code EJBClientChannel} is then created (it is handed the same {@code FutureResult}) and a second,
     * self-re-registering receiver forwards every subsequent message to {@code processMessage}.
     * <p>
     * This method itself never calls {@code futureResult.setResult}; the result is not completed here.
     * Every receiver callback swaps in a safe thread context class loader for its duration and restores the
     * previous one afterwards.
     *
     * @param channel the channel to negotiate on
     * @param discoveredNodeRegistry passed through to the {@code EJBClientChannel} constructor
     * @param retryExecutorWrapper passed through to the {@code EJBClientChannel} constructor
     * @return the future from the internal {@code FutureResult}; cancelling it closes the channel
     */
    static IoFuture construct(final Channel channel, final DiscoveredNodeRegistry discoveredNodeRegistry, RetryExecutorWrapper retryExecutorWrapper) {
        FutureResult futureResult = new FutureResult<>();
        // now perform opening negotiation: receive server greeting
        channel.receiveMessage(new Channel.Receiver() {
            // error before the greeting arrived: fail the future
            public void handleError(final Channel channel, final IOException error) {
                final ClassLoader oldCL = getAndSetSafeTCCL();
                try {
                    futureResult.setException(error);
                } finally {
                    resetTCCL(oldCL);
                }
            }

            // channel ended before the greeting arrived: reported as a cancellation
            public void handleEnd(final Channel channel) {
                final ClassLoader oldCL = getAndSetSafeTCCL();
                try {
                    futureResult.setCancelled();
                } finally {
                    resetTCCL(oldCL);
                }
            }

            public void handleMessage(final Channel channel, final MessageInputStream message) {
                final ClassLoader oldCL = getAndSetSafeTCCL();
                // receive message body
                try {
                    // never negotiate a version newer than this client supports
                    final int version = min(Protocol.LATEST_VERSION, StreamUtils.readInt8(message));
                    // drain the rest of the message because it's just garbage really
                    while (message.read() != -1) {
                        message.skip(Long.MAX_VALUE);
                    }
                    // send back result
                    // NOTE(review): "river" is presumably the name of the marshalling strategy - confirm
                    try (MessageOutputStream out = channel.writeMessage()) {
                        out.write(version);
                        out.writeUTF("river");
                    }
                    // almost done; wait for initial module available report
                    final EJBClientChannel ejbClientChannel = new EJBClientChannel(channel, version, discoveredNodeRegistry, futureResult, retryExecutorWrapper);
                    channel.receiveMessage(new Channel.Receiver() {
                        // error after negotiation: fail the future and close the channel
                        public void handleError(final Channel channel, final IOException error) {
                            final ClassLoader oldCL = getAndSetSafeTCCL();
                            try {
                                futureResult.setException(error);
                            } finally {
                                safeClose(channel);
                                resetTCCL(oldCL);
                            }
                        }

                        // unlike the greeting stage, end-of-channel here is reported as an EOFException
                        public void handleEnd(final Channel channel) {
                            final ClassLoader oldCL = getAndSetSafeTCCL();
                            try {
                                futureResult.setException(new EOFException());
                            } finally {
                                safeClose(channel);
                                resetTCCL(oldCL);
                            }
                        }

                        public void handleMessage(final Channel channel, final MessageInputStream message) {
                            final ClassLoader oldCL = getAndSetSafeTCCL();
                            try {
                                ejbClientChannel.processMessage(message);
                            } finally {
                                // re-register for the next message even if processing this one failed
                                channel.receiveMessage(this);
                                resetTCCL(oldCL);
                            }
                        }
                    });
                } catch (final IOException e) {
                    // negotiation failed: start closing, and fail the future from the close handler
                    channel.closeAsync();
                    channel.addCloseHandler((closed, exception) -> futureResult.setException(e));
                } finally {
                    resetTCCL(oldCL);
                }
            }
        });
        // cancelling the returned future closes the channel, but only if the cancellation actually took effect
        futureResult.addCancelHandler(new Cancellable() {
            public Cancellable cancel() {
                if (futureResult.setCancelled()) {
                    safeClose(channel);
                }
                return this;
            }
        });
        return futureResult.getIoFuture();
    }

    /**
     * Get the invocation tracker of this channel.
     *
     * @return the invocation tracker
     */
    InvocationTracker getInvocationTracker() {
        return this.invocationTracker;
    }

    /**
     * Glue two stack traces together.
     * <p>
     * The exception's own frames come first, then a single separator frame carrying {@code msg}, then the frames
     * of {@code userStackTrace} after the first {@code trimCount} have been dropped. The result replaces the
     * exception's stack trace.
     * <p>
     * {@code trimCount} is clamped to the range {@code 0..userStackTrace.length}, so an empty or short user
     * stack trace results in just the separator being appended instead of an array index failure.
     *
     * @param exception      the exception which occurred in another thread
     * @param userStackTrace the stack trace of the current thread from {@link Thread#getStackTrace()}
     * @param trimCount      the number of frames to trim
     * @param msg            the message to use
     */
    private static void glueStackTraces(final Throwable exception, final StackTraceElement[] userStackTrace,
                                        final int trimCount, final String msg) {
        final StackTraceElement[] est = exception.getStackTrace();
        // never skip more frames than exist, and never a negative number of them
        final int skipped = Math.max(0, Math.min(trimCount, userStackTrace.length));
        final int kept = userStackTrace.length - skipped;
        final StackTraceElement[] fst = Arrays.copyOf(est, est.length + 1 + kept);
        // the method name is empty, so only the declaring-class text carries the message
        fst[est.length] = new StackTraceElement("..." + msg + "..", "", null, -1);
        System.arraycopy(userStackTrace, skipped, fst, est.length + 1, kept);
        exception.setStackTrace(fst);
    }

    final class SessionOpenInvocation extends Invocation {

        private final StatelessEJBLocator statelessLocator;
        private final EJBSessionCreationInvocationContext clientInvocationContext;
        private int id;
        private MessageInputStream inputStream;
        private XAOutflowHandle outflowHandle;
        private IOException ex;

        protected SessionOpenInvocation(final int index, final StatelessEJBLocator statelessLocator, EJBSessionCreationInvocationContext clientInvocationContext) {
            super(index);
            this.statelessLocator = statelessLocator;
            this.clientInvocationContext = clientInvocationContext;
        }

        public void handleResponse(final int id, final MessageInputStream inputStream) {
            synchronized (this) {
                this.id = id;
                this.inputStream = inputStream;
                notifyAll();
            }
        }

        public void handleClosed() {
            synchronized (this) {
                notifyAll();
            }
        }

        public void handleException(IOException cause) {
            synchronized (this) {
                this.ex = cause;
                notifyAll();
            }
        }

        void setOutflowHandle(final XAOutflowHandle outflowHandle) {
            this.outflowHandle = outflowHandle;
        }

        XAOutflowHandle getOutflowHandle() {
            return outflowHandle;
        }

        StatefulEJBLocator getResult() throws Exception {
            Exception e;
            try (ResponseMessageInputStream response = removeInvocationResult()) {
                switch (id) {
                    case Protocol.OPEN_SESSION_RESPONSE: {
                        Affinity affinity;
                        int size = PackedInteger.readPackedInteger(response);
                        byte[] bytes = new byte[size];
                        response.readFully(bytes);
                        // todo: pool unmarshallers?  use very small instance count config?
                        if (1 <= version && version <= 2) {
                            try (final Unmarshaller unmarshaller = createUnmarshaller()) {
                                unmarshaller.start(response);
                                affinity = unmarshaller.readObject(Affinity.class);
                                unmarshaller.finish();
                            }
                        } else {
                            affinity = statelessLocator.getAffinity();
                            final int cmd = response.readUnsignedByte();
                            final XAOutflowHandle outflowHandle = getOutflowHandle();
                            if (outflowHandle != null) {
                                if (cmd == 0) {
                                    outflowHandle.forgetEnlistment();
                                } else if (cmd == 1) {
                                    try {
                                        outflowHandle.verifyEnlistment();
                                    } catch (RollbackException | SystemException e1) {
                                        throw new EJBException(e1);
                                    }
                                } else if (cmd == 2) {
                                    outflowHandle.nonMasterEnlistment();
                                }
                            }
                            final int updateBits = response.readUnsignedByte();
                            if (allAreSet(updateBits, Protocol.UPDATE_BIT_WEAK_AFFINITY)) {
                                final byte[] b = new byte[PackedInteger.readPackedInteger(response)];
                                response.readFully(b);
                                clientInvocationContext.setWeakAffinity(new NodeAffinity(new String(b, StandardCharsets.UTF_8)));
                            }
                            if (allAreSet(updateBits, Protocol.UPDATE_BIT_STRONG_AFFINITY)) {
                                final byte[] b = new byte[PackedInteger.readPackedInteger(response)];
                                response.readFully(b);
                                String clusterName = new String(b, StandardCharsets.UTF_8);
                                affinity = new ClusterAffinity(clusterName);
                            }
                        }
                        StatefulEJBLocator locator = statelessLocator.withSessionAndAffinity(SessionID.createSessionID(bytes), affinity);

                        if (Logs.INVOCATION.isDebugEnabled()) {
                            Logs.INVOCATION.debugf("EJBClientChannel.SessionOpenInvocation.getResult(): updating Locator (sessionID), new = %s; weakAffinity = %s",
                                    locator, clientInvocationContext.getWeakAffinity());
                        }

                        clientInvocationContext.setLocator(locator);
                        return locator;
                    }
                    case Protocol.APPLICATION_EXCEPTION: {
                        if (version >= 3) {
                            final int cmd = response.readUnsignedByte();
                            final XAOutflowHandle outflowHandle = getOutflowHandle();
                            if (outflowHandle != null) {
                                if (cmd == 0) {
                                    outflowHandle.forgetEnlistment();
                                } else if (cmd == 1) {
                                    try {
                                        outflowHandle.verifyEnlistment();
                                    } catch (RollbackException | SystemException e1) {
                                        throw new EJBException(e1);
                                    }
                                } else if (cmd == 2) {
                                    outflowHandle.nonMasterEnlistment();
                                }
                            }
                        }
                        try (final Unmarshaller unmarshaller = createUnmarshaller()) {
                            unmarshaller.start(response);
                            e = unmarshaller.readObject(Exception.class);
                            unmarshaller.finish();
                            if (version < 3) {
                                // drain off attachments so the server doesn't complain
                                while (response.read() != -1) {
                                    response.skip(Long.MAX_VALUE);
                                }
                            }
                        }
                        glueStackTraces(e, Thread.currentThread().getStackTrace(), 1, "asynchronous invocation");
                        break;
                    }
                    case Protocol.CANCEL_RESPONSE: {
                        if (version >= 3) {
                            final XAOutflowHandle outflowHandle = getOutflowHandle();
                            if (outflowHandle != null) outflowHandle.forgetEnlistment();
                        }
                        disassociateRemoteTxIfPossible(clientInvocationContext);
                        throw Logs.REMOTING.requestCancelled();
                    }
                    case Protocol.NO_SUCH_EJB: {
                        if (version >= 3) {
                            final XAOutflowHandle outflowHandle = getOutflowHandle();
                            if (outflowHandle != null) outflowHandle.forgetEnlistment();
                        }
                        disassociateRemoteTxIfPossible(clientInvocationContext);
                        final String message = response.readUTF();
                        throw new NoSuchEJBException(message + " @ " + getChannel().getConnection().getPeerURI());
                    }
                    case Protocol.BAD_VIEW_TYPE: {
                        if (version >= 3) {
                            final XAOutflowHandle outflowHandle = getOutflowHandle();
                            if (outflowHandle != null) outflowHandle.forgetEnlistment();
                        }
                        disassociateRemoteTxIfPossible(clientInvocationContext);
                        final String message = response.readUTF();
                        throw Logs.REMOTING.invalidViewTypeForInvocation(message);
                    }
                    case Protocol.EJB_NOT_STATEFUL: {
                        if (version >= 3) {
                            final XAOutflowHandle outflowHandle = getOutflowHandle();
                            if (outflowHandle != null) outflowHandle.forgetEnlistment();
                        }
                        disassociateRemoteTxIfPossible(clientInvocationContext);
                        final String message = response.readUTF();
                        throw Logs.REMOTING.ejbNotStateful(message);
                    }
                    default: {
                        throw new EJBException("Invalid EJB creation response (id " + id + ")");
                    }
                }
            } catch (ClassNotFoundException | IOException ex) {
                throw new EJBException("Failed to read session create response", ex);
            } finally {
                invocationTracker.remove(this);
            }
            if (e == null) {
                throw new EJBException("Null exception response");
            }
            // throw the application exception outside of the try block
            throw e;
        }

        private ResponseMessageInputStream removeInvocationResult() {
            MessageInputStream mis;
            int id;
            try {
                synchronized (this) {
                    for (; ; ) {
                        id = this.getIndex();
                        if (inputStream != null) {
                            mis = inputStream;
                            inputStream = null;
                            break;
                        }
                        if (ex != null) {
                            throw new EJBException(ex);
                        }
                        if (id == -1) {
                            throw new EJBException("Connection closed");
                        }
                        wait();
                    }
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new EJBException("Session creation interrupted");
            }
            return new ResponseMessageInputStream(mis, id);
        }
    }

    /**
     * Create a new unmarshaller from this channel's marshaller factory and configuration.
     *
     * @return the new unmarshaller
     * @throws IOException if the factory fails to create it
     */
    Unmarshaller createUnmarshaller() throws IOException {
        final Unmarshaller created = marshallerFactory.createUnmarshaller(configuration);
        return created;
    }

    /**
     * Get the underlying channel.
     *
     * @return the channel
     */
    Channel getChannel() {
        return this.channel;
    }

    /**
     * Allocate a user transaction ID that is not yet registered in {@code userTxnIds}.
     * <p>
     * The encoded form handed to {@code TransactionID.createTransactionID} is laid out as:
     * <ul>
     *     <li>byte 0: the constant {@code 0x01}</li>
     *     <li>byte 1: the length in bytes of the UTF-8 encoded remote endpoint name</li>
     *     <li>the UTF-8 encoded remote endpoint name</li>
     *     <li>four bytes: a random {@code int}, most significant byte first</li>
     * </ul>
     * Random values are drawn until one is found whose ID could be registered.
     *
     * @return the newly registered transaction ID
     */
    UserTransactionID allocateUserTransactionID() {
        final ThreadLocalRandom random = ThreadLocalRandom.current();
        final String nodeName = getChannel().getConnection().getRemoteEndpointName();
        // a Charset constant instead of a charset name: no lookup, no checked exception to handle
        final byte[] nameBytes = nodeName.getBytes(StandardCharsets.UTF_8);
        final int length = nameBytes.length;
        // the length has to fit in the single length byte written below
        if (length > 255) {
            throw Assert.unreachableCode();
        }
        final byte[] target = new byte[6 + length];
        target[0] = 0x01;
        target[1] = (byte) length;
        System.arraycopy(nameBytes, 0, target, 2, length);
        for (;;) {
            int id = random.nextInt();
            if (! userTxnIds.containsKey(id)) {
                target[2 + length] = (byte) (id >> 24);
                target[3 + length] = (byte) (id >> 16);
                target[4 + length] = (byte) (id >> 8);
                target[5 + length] = (byte) id;
                final UserTransactionID uti = (UserTransactionID) TransactionID.createTransactionID(target);
                // the check above is only a fast path; registration decides who actually owns the ID
                if (userTxnIds.putIfAbsent(uti) == null) {
                    return uti;
                }
            }
        }
    }

    /**
     * Mark one part of the channel setup as finished.
     * <p>
     * The given bit is atomically OR-ed into {@code finishedParts}; if it was already set nothing happens.
     * When the update makes the value exactly {@code 0b11}, the pending {@code FutureResult} (if still present)
     * is cleared from {@code futureResultRef} and completed with this channel.
     *
     * @param bit the bit (or bits) identifying the finished part
     */
    void finishPart(int bit) {
        int updated;
        for (;;) {
            final int current = finishedParts.get();
            if (allAreSet(current, bit)) {
                // this part was already reported
                return;
            }
            updated = current | bit;
            if (finishedParts.compareAndSet(current, updated)) {
                break;
            }
        }
        if (updated != 0b11) {
            return;
        }
        final FutureResult pending = futureResultRef.get();
        if (pending == null) {
            return;
        }
        if (futureResultRef.compareAndSet(pending, null)) {
            // done!
            pending.setResult(this);
        }
    }

    final class MethodInvocation extends Invocation {
        private final EJBReceiverInvocationContext receiverInvocationContext;
        private final AtomicInteger refCounter = new AtomicInteger(1);
        private XAOutflowHandle outflowHandle;

        MethodInvocation(final int index, final EJBReceiverInvocationContext receiverInvocationContext) {
            super(index);
            this.receiverInvocationContext = receiverInvocationContext;
        }

        boolean alloc() {
            final AtomicInteger refCounter = this.refCounter;
            int oldVal;
            do {
                oldVal = refCounter.get();
                if (oldVal == 0) {
                    return false;
                }
            } while (! refCounter.compareAndSet(oldVal, oldVal + 1));
            return true;
        }

        void free() {
            final AtomicInteger refCounter = this.refCounter;
            final int newVal = refCounter.decrementAndGet();
            if (newVal == 0) {
                invocationTracker.remove(this);
            }
        }

        @Override
        public void handleResponse(int parameter, MessageInputStream inputStream) {
            handleResponse(parameter, new DataInputStream(inputStream));
        }

        private void handleResponse(final int id, final DataInputStream inputStream) {
            switch (id) {
                case Protocol.INVOCATION_RESPONSE: {
                    free();
                    final EJBClientInvocationContext context = receiverInvocationContext.getClientInvocationContext();
                    if (version >= 3) try {
                        final int cmd = inputStream.readUnsignedByte();
                        final XAOutflowHandle outflowHandle = getOutflowHandle();
                        if (outflowHandle != null) {
                            if (cmd == 0) {
                                outflowHandle.forgetEnlistment();
                            } else if (cmd == 1) {
                                outflowHandle.verifyEnlistment();
                            } else if (cmd == 2) {
                                outflowHandle.nonMasterEnlistment();
                            }
                        }
                        final int updateBits = inputStream.readUnsignedByte();
                        if (allAreSet(updateBits, Protocol.UPDATE_BIT_SESSION_ID)) {
                            byte[] encoded = new byte[PackedInteger.readPackedInteger(inputStream)];
                            inputStream.readFully(encoded);
                            final SessionID sessionID = SessionID.createSessionID(encoded);
                            final Object invokedProxy = context.getInvokedProxy();
                            EJBClient.convertToStateful(invokedProxy, sessionID);
                            context.setLocator(EJBClient.getLocatorFor(invokedProxy));

                            if (Logs.INVOCATION.isDebugEnabled()) {
                                Logs.INVOCATION.debugf("EJBClientChannel.handleResponse: updated Locator (sessionID); new = %s", context.getLocator().getAffinity());
                            }
                        }
                        if (allAreSet(updateBits, Protocol.UPDATE_BIT_WEAK_AFFINITY)) {
                            byte[] b = new byte[PackedInteger.readPackedInteger(inputStream)];
                            inputStream.readFully(b);
                            context.setWeakAffinity(new NodeAffinity(new String(b, StandardCharsets.UTF_8)));

                            if (Logs.INVOCATION.isDebugEnabled()) {
                                Logs.INVOCATION.debugf("EJBClientChannel.handleResponse: updated weak affinity = %s", context.getWeakAffinity());
                            }
                        }
                        if (allAreSet(updateBits, Protocol.UPDATE_BIT_STRONG_AFFINITY)) {
                            byte[] b = new byte[PackedInteger.readPackedInteger(inputStream)];
                            inputStream.readFully(b);
                            context.setLocator(context.getLocator().withNewAffinity(new ClusterAffinity(new String(b, StandardCharsets.UTF_8))));

                            if (Logs.INVOCATION.isDebugEnabled()) {
                                Logs.INVOCATION.debugf("EJBClientChannel.handleResponse: updated strong affinity = %s", context.getLocator().getAffinity());
                            }
                        }
                    } catch (RuntimeException | IOException | RollbackException | SystemException e) {
                        receiverInvocationContext.requestFailed(new EJBException(e), getRetryExecutor(receiverInvocationContext) );
                        safeClose(inputStream);
                        break;
                    }
                    final NamingProvider provider = context.getProxyAttachment(Keys.NAMING_PROVIDER_ATTACHMENT_KEY);
                    receiverInvocationContext.resultReady(new MethodCallResultProducer(provider, inputStream, id));
                    break;
                }
                case Protocol.CANCEL_RESPONSE: {
                    free();
                    if (version >= 3) {
                        final XAOutflowHandle outflowHandle = getOutflowHandle();
                        if (outflowHandle != null) outflowHandle.forgetEnlistment();
                    }
                    receiverInvocationContext.requestCancelled();
                    break;
                }
                case Protocol.APPLICATION_EXCEPTION: {
                    free();
                    receiverInvocationContext.resultReady(new ExceptionResultProducer(inputStream, id));
                    break;
                }
                case Protocol.NO_SUCH_EJB: {
                    free();
                    try {
                        if (version >= 3) {
                            final XAOutflowHandle outflowHandle = getOutflowHandle();
                            if (outflowHandle != null) outflowHandle.forgetEnlistment();
                        }
                        disassociateRemoteTxIfPossible(receiverInvocationContext.getClientInvocationContext());
                        final String message = inputStream.readUTF();
                        final EJBModuleIdentifier moduleIdentifier = receiverInvocationContext.getClientInvocationContext().getLocator().getIdentifier().getModuleIdentifier();
                        final NodeInformation nodeInformation = discoveredNodeRegistry.getNodeInformation(getChannel().getConnection().getRemoteEndpointName());
                        nodeInformation.removeModule(EJBClientChannel.this, moduleIdentifier);
                        receiverInvocationContext.requestFailed(new NoSuchEJBException(message + " @ " + getChannel().getConnection().getPeerURI()), getRetryExecutor(receiverInvocationContext));
                    } catch (IOException e) {
                        receiverInvocationContext.requestFailed(new EJBException("Failed to read 'No such EJB' response", e), getRetryExecutor(receiverInvocationContext));
                    } finally {
                        safeClose(inputStream);
                    }
                    break;
                }
                case Protocol.BAD_VIEW_TYPE: {
                    free();
                    try {
                        if (version >= 3) {
                            final XAOutflowHandle outflowHandle = getOutflowHandle();
                            if (outflowHandle != null) outflowHandle.forgetEnlistment();
                        }
                        disassociateRemoteTxIfPossible(receiverInvocationContext.getClientInvocationContext());
                        final String message = inputStream.readUTF();
                        receiverInvocationContext.requestFailed(Logs.REMOTING.invalidViewTypeForInvocation(message), getRetryExecutor(receiverInvocationContext));
                    } catch (IOException e) {
                        receiverInvocationContext.requestFailed(new EJBException("Failed to read 'Bad EJB view type' response", e), getRetryExecutor(receiverInvocationContext));
                    } finally {
                        safeClose(inputStream);
                    }
                    break;
                }
                case Protocol.NO_SUCH_METHOD: {
                    free();
                    try {
                        if (version >= 3) {
                            final XAOutflowHandle outflowHandle = getOutflowHandle();
                            if (outflowHandle != null) outflowHandle.forgetEnlistment();
                        }
                        disassociateRemoteTxIfPossible(receiverInvocationContext.getClientInvocationContext());
                        final String message = inputStream.readUTF();
                        // todo: I don't think this is the best exception type for this case...
                        receiverInvocationContext.requestFailed(new IllegalArgumentException(message), getRetryExecutor(receiverInvocationContext));
                    } catch (IOException e) {
                        receiverInvocationContext.requestFailed(new EJBException("Failed to read 'No such EJB method' response", e), getRetryExecutor(receiverInvocationContext));
                    } finally {
                        safeClose(inputStream);
                    }
                    break;
                }
                case Protocol.SESSION_NOT_ACTIVE: {
                    free();
                    try {
                        if (version >= 3) {
                            final XAOutflowHandle outflowHandle = getOutflowHandle();
                            if (outflowHandle != null) outflowHandle.forgetEnlistment();
                        }
                        disassociateRemoteTxIfPossible(receiverInvocationContext.getClientInvocationContext());
                        final String message = inputStream.readUTF();
                        receiverInvocationContext.requestFailed(new EJBException(message), getRetryExecutor(receiverInvocationContext));
                    } catch (IOException e) {
                        receiverInvocationContext.requestFailed(new EJBException("Failed to read 'Session not active' response", e), getRetryExecutor(receiverInvocationContext));
                    } finally {
                        safeClose(inputStream);
                    }
                    break;
                }
                case Protocol.EJB_NOT_STATEFUL: {
                    free();
                    try {
                        if (version >= 3) {
                            final XAOutflowHandle outflowHandle = getOutflowHandle();
                            if (outflowHandle != null) outflowHandle.forgetEnlistment();
                        }
                        disassociateRemoteTxIfPossible(receiverInvocationContext.getClientInvocationContext());
                        final String message = inputStream.readUTF();
                        receiverInvocationContext.requestFailed(new EJBException(message), getRetryExecutor(receiverInvocationContext));
                    } catch (IOException e) {
                        receiverInvocationContext.requestFailed(new EJBException("Failed to read 'EJB not stateful' response"), getRetryExecutor(receiverInvocationContext));
                    } finally {
                        safeClose(inputStream);
                    }
                    break;
                }
                case Protocol.PROCEED_ASYNC_RESPONSE: {
                    // do not free; response is forthcoming
                    safeClose(inputStream);
                    receiverInvocationContext.proceedAsynchronously();
                    break;
                }
                default: {
                    free();
                    safeClose(inputStream);
                    receiverInvocationContext.requestFailed(new EJBException("Unknown protocol response"), getRetryExecutor(receiverInvocationContext));
                    break;
                }
            }
        }

        /**
         * Reports the pending invocation as failed with the "channel timeout or closed" failure, built around a
         * fresh {@link ClosedChannelException}, and supplies the retry executor for this invocation's context.
         */
        public void handleClosed() {
            final ClosedChannelException channelClosed = new ClosedChannelException();
            receiverInvocationContext.requestFailed(
                    Logs.REMOTING.channelTimeoutOrClosed(channelClosed),
                    getRetryExecutor(receiverInvocationContext));
        }

        /**
         * Reports the pending invocation as failed because of an I/O problem.
         *
         * @param cause the I/O failure; it is wrapped in an {@link EJBException} before being passed on
         */
        public void handleException(IOException cause) {
            final EJBException failure = new EJBException(cause);
            final Executor retryExecutor = getRetryExecutor(receiverInvocationContext);
            receiverInvocationContext.requestFailed(failure, retryExecutor);
        }

        /**
         * Returns the outflow handle currently stored on this invocation.
         *
         * @return the stored handle, or {@code null} if none has been set (callers in this file null-check it)
         */
        XAOutflowHandle getOutflowHandle() {
            return this.outflowHandle;
        }

        /**
         * Stores the given outflow handle on this invocation so that response handling can later
         * retrieve it through {@link #getOutflowHandle()}.
         *
         * @param outflowHandle the handle to store
         */
        void setOutflowHandle(final XAOutflowHandle outflowHandle) {
            this.outflowHandle = outflowHandle;
        }

        /**
         * Result producer for a method call response: unmarshals the invocation result followed by the
         * attachments sent with it, updating the client invocation context as it goes.
         * <p>
         * The producer doubles as the action handed to {@link NamingProvider#performExceptionAction} when a
         * naming provider is present, hence the {@code Void, Void} argument types on {@link #apply}.
         */
        class MethodCallResultProducer implements EJBReceiverInvocationContext.ResultProducer, ExceptionBiFunction<Void, Void, Object, Exception> {

            private final NamingProvider namingProvider;
            private final InputStream inputStream;
            private final int id;

            /**
             * @param provider the naming provider through which the result is produced, or {@code null} to unmarshal directly
             * @param inputStream the stream carrying the response body
             * @param id the identifier passed to {@link ResponseMessageInputStream} when the stream needs wrapping
             */
            MethodCallResultProducer(final NamingProvider provider, final InputStream inputStream, final int id) {
                namingProvider = provider;
                this.inputStream = inputStream;
                this.id = id;
            }

            /**
             * Clean the ContextData on the {@link EJBClientInvocationContext} by removing every entry except the
             * one stored under {@link EJBClientInvocationContext#RETURNED_CONTEXT_DATA_KEY}.
             *
             * @param clientInvocationContext the invocation context whose context data is cleaned
             */
            private void cleanContextDataBeforeUnmarshalling(EJBClientInvocationContext clientInvocationContext) {
                Map<String, Object> contextData = clientInvocationContext.getContextData();
                contextData.keySet().removeIf(k -> (!k.equals(EJBClientInvocationContext.RETURNED_CONTEXT_DATA_KEY)));
            }

            /**
             * Unmarshals the invocation result and the attachments that follow it.
             *
             * @param ignored0 unused
             * @param ignored1 unused
             * @return the unmarshalled invocation result
             * @throws EJBException if reading fails with an {@link IOException} or {@link ClassNotFoundException};
             *         the input stream is closed before this is thrown
             */
            public Object apply(final Void ignored0, final Void ignored1) throws Exception {
                // Reuse the stream if it is already the expected wrapper type; otherwise wrap it.
                final ResponseMessageInputStream response;
                if (inputStream instanceof ResponseMessageInputStream) {
                    response = (ResponseMessageInputStream) inputStream;
                } else {
                    response = new ResponseMessageInputStream(inputStream, id);
                }
                Object result;
                try (final Unmarshaller unmarshaller = createUnmarshaller()) {
                    unmarshaller.start(response);
                    result = unmarshaller.readObject();
                    int attachments = unmarshaller.readUnsignedByte();
                    final EJBClientInvocationContext clientInvocationContext = receiverInvocationContext.getClientInvocationContext();
                    // clean the client side ContextData to allow server side removed keys being removed - EJBCLIENT-425
                    cleanContextDataBeforeUnmarshalling(clientInvocationContext);
                    for (int i = 0; i < attachments; i ++) {
                        String key = unmarshaller.readObject(String.class);
                        if (version < 3 && key.equals(Affinity.WEAK_AFFINITY_CONTEXT_KEY)) {
                            // Before protocol version 3 the weak affinity arrives as an attachment.
                            final Affinity affinity = unmarshaller.readObject(Affinity.class);
                            clientInvocationContext.putAttachment(AttachmentKeys.WEAK_AFFINITY, affinity);
                            clientInvocationContext.setWeakAffinity(affinity);

                            if (Logs.INVOCATION.isDebugEnabled()) {
                                Logs.INVOCATION.debugf("EJBClientChannel.MethodCallResultProducer: updated weak affinity (version < 3) = %s", affinity);
                            }
                        } else if (key.equals(EJBClientInvocationContext.PRIVATE_ATTACHMENTS_KEY)) {
                            // skip: the value must still be consumed to keep the stream position correct
                            unmarshaller.readObject();
                        } else {
                            // Null values are read but not stored in the context data.
                            final Object value = unmarshaller.readObject();
                            if (value != null) clientInvocationContext.getContextData().put(key, value);
                        }
                    }
                    unmarshaller.finish();
                } catch (IOException | ClassNotFoundException ex) {
                    discardResult();
                    throw new EJBException("Failed to read response", ex);
                }
                return result;
            }

            /**
             * Produces the invocation result, going through the naming provider when one was supplied and
             * unmarshalling directly otherwise.
             *
             * @return the unmarshalled invocation result
             * @throws Exception if producing the result fails
             */
            public Object getResult() throws Exception {
                if (namingProvider != null) {
                    return namingProvider.performExceptionAction(this, null, null);
                } else {
                    return apply(null, null);
                }
            }

            /**
             * Discards the response by closing the underlying input stream, ignoring any close failure.
             */
            public void discardResult() {
                safeClose(inputStream);
            }
        }

        /**
         * Result producer for an exceptional response: reads the exception sent in the response and
         * rethrows it to the caller of {@link #getResult()}.
         */
        class ExceptionResultProducer implements EJBReceiverInvocationContext.ResultProducer {

            private final InputStream inputStream;
            private final int id;

            /**
             * @param inputStream the stream carrying the response body
             * @param id the identifier passed to the {@link ResponseMessageInputStream} wrapping the stream
             */
            ExceptionResultProducer(final InputStream inputStream, final int id) {
                this.inputStream = inputStream;
                this.id = id;
            }

            /**
             * Reads the exception from the response and throws it; this method never returns normally.
             *
             * @return never returns a value
             * @throws EJBException if the response cannot be read, the enlistment cannot be verified, or the
             *         unmarshalled exception is {@code null}
             * @throws Exception the exception unmarshalled from the response
             */
            public Object getResult() throws Exception {
                Exception remoteFailure;
                try (final ResponseMessageInputStream response = new ResponseMessageInputStream(inputStream, id)) {
                    if (version >= 3) {
                        // From protocol version 3 on, a command byte precedes the marshalled exception.
                        final int cmd = response.readUnsignedByte();
                        final XAOutflowHandle handle = getOutflowHandle();
                        if (handle != null) {
                            switch (cmd) {
                                case 0:
                                    handle.forgetEnlistment();
                                    break;
                                case 1:
                                    try {
                                        handle.verifyEnlistment();
                                    } catch (RollbackException | SystemException enlistmentFailure) {
                                        throw new EJBException(enlistmentFailure);
                                    }
                                    break;
                                case 2:
                                    handle.nonMasterEnlistment();
                                    break;
                                default:
                                    // any other command value leaves the enlistment untouched
                                    break;
                            }
                        }
                    }
                    try (final Unmarshaller unmarshaller = createUnmarshaller()) {
                        unmarshaller.start(response);
                        remoteFailure = unmarshaller.readObject(Exception.class);
                        if (version < 3) {
                            // discard attachment data, if any: each attachment is a pair of objects
                            int pending = unmarshaller.readUnsignedByte();
                            while (pending-- > 0) {
                                unmarshaller.readObject();
                                unmarshaller.readObject();
                            }
                        }
                        unmarshaller.finish();
                    }
                } catch (IOException | ClassNotFoundException readFailure) {
                    discardResult();
                    throw new EJBException("Failed to read response", readFailure);
                }
                if (remoteFailure == null) {
                    throw new EJBException("Null exception response");
                }
                glueStackTraces(remoteFailure, Thread.currentThread().getStackTrace(), 1, "asynchronous invocation");
                throw remoteFailure;
            }

            /**
             * Discards the response by closing the underlying input stream, ignoring any close failure.
             */
            public void discardResult() {
                safeClose(inputStream);
            }
        }
    }

    /**
     * Attempts to clear the location of the context's transaction when that transaction is a
     * {@link RemoteTransaction}; a failed attempt is only logged at trace level.
     *
     * @param context the invocation context whose transaction is inspected
     */
    private static void disassociateRemoteTxIfPossible(AbstractInvocationContext context) {
        final AbstractTransaction current = context.getTransaction();
        if (!(current instanceof RemoteTransaction)) {
            // nothing to do for other (or absent) transactions
            return;
        }
        final RemoteTransaction remoteTx = (RemoteTransaction) current;
        if (remoteTx.tryClearLocation()) {
            return;
        }
        Logs.TXN.tracef("Could not disassociate remote transaction (already in-use or completed) from %s", remoteTx.getLocation());
    }

    /**
     * Returns a retry executor obtained from the retry executor wrapper, backed by the XNIO worker of
     * this channel's connection endpoint.
     *
     * @return the retry executor
     */
    private Executor getRetryExecutor() {
        final Executor worker = getChannel().getConnection().getEndpoint().getXnioWorker();
        return retryExecutorWrapper.getExecutor(worker);
    }

    /**
     * Provides a retry executor which will transfer the given thread contexts to the worker thread before execution.
     *
     * @param ejbReceiverInvocationContext the invocation context supplying the client context, discovery and
     *        authentication context handed to the retry executor wrapper
     * @return the retry executor, backed by the XNIO worker of this channel's connection endpoint
     */
    private Executor getRetryExecutor(EJBReceiverInvocationContext ejbReceiverInvocationContext) {
        final EJBClientContext clientContext = ejbReceiverInvocationContext.getClientContext();
        final Discovery discoveryContext = ejbReceiverInvocationContext.getDiscovery();
        final AuthenticationContext authenticationContext = ejbReceiverInvocationContext.getAuthenticationContext();
        final Executor worker = getChannel().getConnection().getEndpoint().getXnioWorker();

        return retryExecutorWrapper.getExecutor(worker, clientContext, discoveryContext, authenticationContext);
    }

    /**
     * A {@link MessageInputStream} that is also a {@link ByteInput}, forwarding reads, skips and close to a
     * wrapped {@link InputStream} and carrying the identifier it was constructed with.
     */
    static class ResponseMessageInputStream extends MessageInputStream implements ByteInput {
        private final InputStream delegate;
        private final int id;

        /**
         * @param delegate the stream all I/O operations are forwarded to
         * @param id the identifier returned by {@link #getId()}
         */
        ResponseMessageInputStream(final InputStream delegate, final int id) {
            this.delegate = delegate;
            this.id = id;
        }

        /**
         * @return the identifier supplied at construction
         */
        public int getId() {
            return this.id;
        }

        /** Reads a single byte from the delegate. */
        public int read() throws IOException {
            return this.delegate.read();
        }

        /** Reads up to {@code length} bytes from the delegate into {@code buffer} starting at {@code offset}. */
        public int read(final byte[] buffer, final int offset, final int length) throws IOException {
            return this.delegate.read(buffer, offset, length);
        }

        /** Asks the delegate to skip {@code count} bytes and returns what the delegate reports. */
        public long skip(final long count) throws IOException {
            return this.delegate.skip(count);
        }

        /** Closes the delegate stream. */
        public void close() throws IOException {
            this.delegate.close();
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy