org.jboss.ejb.protocol.remote.EJBServerChannel Maven / Gradle / Ivy
Go to download
This artifact provides a single JAR that contains all classes required to use remote EJB and JMS, including
all dependencies. It is intended for those not using Maven; Maven users should instead import the EJB and
JMS BOMs (shaded JARs cause many problems with Maven, because it is very easy to inadvertently end up
with different versions of classes on the class path).
/*
* JBoss, Home of Professional Open Source.
* Copyright 2017 Red Hat, Inc., and individual contributors
* as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.jboss.ejb.protocol.remote;
import org.jboss.ejb._private.Logs;
import org.jboss.ejb.client.Affinity;
import org.jboss.ejb.client.AttachmentKeys;
import org.jboss.ejb.client.ClusterAffinity;
import org.jboss.ejb.client.EJBClient;
import org.jboss.ejb.client.EJBClientInvocationContext;
import org.jboss.ejb.client.EJBIdentifier;
import org.jboss.ejb.client.EJBLocator;
import org.jboss.ejb.client.EJBMethodLocator;
import org.jboss.ejb.client.EJBModuleIdentifier;
import org.jboss.ejb.client.NodeAffinity;
import org.jboss.ejb.client.RequestSendFailedException;
import org.jboss.ejb.client.SessionID;
import org.jboss.ejb.client.TransactionID;
import org.jboss.ejb.client.UserTransactionID;
import org.jboss.ejb.client.XidTransactionID;
import org.jboss.ejb.client.annotation.CompressionHint;
import org.jboss.ejb.server.Association;
import org.jboss.ejb.server.CancelHandle;
import org.jboss.ejb.server.ClusterTopologyListener;
import org.jboss.ejb.server.InvocationRequest;
import org.jboss.ejb.server.ListenerHandle;
import org.jboss.ejb.server.ModuleAvailabilityListener;
import org.jboss.ejb.server.Request;
import org.jboss.ejb.server.SessionOpenRequest;
import org.jboss.marshalling.AbstractClassResolver;
import org.jboss.marshalling.Marshaller;
import org.jboss.marshalling.MarshallerFactory;
import org.jboss.marshalling.Marshalling;
import org.jboss.marshalling.MarshallingConfiguration;
import org.jboss.marshalling.Unmarshaller;
import org.jboss.marshalling.river.RiverMarshallerFactory;
import org.jboss.remoting3.Channel;
import org.jboss.remoting3.Connection;
import org.jboss.remoting3.MessageInputStream;
import org.jboss.remoting3.MessageOutputStream;
import org.jboss.remoting3._private.IntIndexHashMap;
import org.jboss.remoting3.util.MessageTracker;
import org.wildfly.common.Assert;
import org.wildfly.common.annotation.NotNull;
import org.wildfly.common.function.ExceptionSupplier;
import org.wildfly.security.auth.server.SecurityIdentity;
import org.wildfly.transaction.client.ContextTransactionManager;
import org.wildfly.transaction.client.ImportResult;
import org.wildfly.transaction.client.LocalTransaction;
import org.wildfly.transaction.client.SimpleXid;
import org.wildfly.transaction.client.provider.remoting.RemotingTransactionServer;
import org.wildfly.transaction.client.spi.SubordinateTransactionControl;
import jakarta.ejb.EJBException;
import jakarta.transaction.HeuristicMixedException;
import jakarta.transaction.HeuristicRollbackException;
import jakarta.transaction.RollbackException;
import jakarta.transaction.SystemException;
import jakarta.transaction.Transaction;
import javax.transaction.xa.XAException;
import javax.transaction.xa.Xid;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InvalidClassException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.net.Inet6Address;
import java.net.SocketAddress;
import java.nio.charset.StandardCharsets;
import java.security.PrivilegedAction;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.function.Function;
import java.util.zip.Deflater;
import java.util.zip.DeflaterOutputStream;
import java.util.zip.InflaterInputStream;
import static java.lang.Math.min;
import static java.security.AccessController.doPrivileged;
import static org.jboss.ejb.protocol.remote.TCCLUtils.getAndSetSafeTCCL;
import static org.jboss.ejb.protocol.remote.TCCLUtils.resetTCCL;
import static org.xnio.IoUtils.safeClose;
/**
* @author David M. Lloyd
* @author Tomasz Adamski
* @author Richard Opalka
* @author Joerg Baesner
*/
@SuppressWarnings("deprecation")
final class EJBServerChannel {
/** Separator between parameter type names in the method signature string of protocol v1/v2 invocation requests. */
private static final char METHOD_PARAM_TYPE_SEPARATOR = ',';
/** Transaction server used to begin, look up, import and remove the transactions referenced by requests. */
private final RemotingTransactionServer transactionServer;
/** Remoting channel from which requests are read and to which responses are written. */
private final Channel channel;
/** EJB protocol version of this channel; versions below 3 use the legacy v1 marshalling tables and wire layout. */
private final int version;
/** Tracker from which every outbound response message is opened. */
private final MessageTracker messageTracker;
/** Factory for the River marshallers/unmarshallers used for message payloads. */
private final MarshallerFactory marshallerFactory;
/** Base marshalling configuration for this protocol version; cloned per invocation so a per-request class resolver can be installed. */
private final MarshallingConfiguration configuration;
/** Invocations currently in progress, indexed by invocation ID so that cancel requests can locate them. */
private final IntIndexHashMap invocations = new IntIndexHashMap<>(InProgress::getInvId);
/** Filter handed to each per-invocation ServerClassResolver. NOTE(review): presumably decides which class names may be resolved while unmarshalling — confirm. */
private final Function classResolverFilter;
/**
 * Creates the server side of an EJB protocol channel.
 *
 * @param transactionServer server used to resolve transactions referenced by requests
 * @param channel the Remoting channel to serve
 * @param version the EJB protocol version in use on {@code channel}
 * @param messageTracker tracker from which outbound messages are opened
 * @param classResolverFilter filter passed to each per-invocation class resolver
 */
EJBServerChannel(final RemotingTransactionServer transactionServer, final Channel channel, final int version, final MessageTracker messageTracker,
        final Function classResolverFilter) {
    this.transactionServer = transactionServer;
    this.channel = channel;
    this.version = version;
    this.messageTracker = messageTracker;
    this.configuration = buildMarshallingConfiguration(channel, version);
    this.marshallerFactory = new RiverMarshallerFactory();
    this.classResolverFilter = classResolverFilter;
}

/**
 * Builds the base marshalling configuration for a protocol version: v3+ uses the v3 object table and
 * marshalling version 4, older versions use the v1 class/object tables and marshalling version 2.
 */
private static MarshallingConfiguration buildMarshallingConfiguration(final Channel channel, final int version) {
    final MarshallingConfiguration config = new MarshallingConfiguration();
    final Connection connection = channel.getConnection();
    if (version >= 3) {
        config.setObjectTable(ProtocolV3ObjectTable.INSTANCE);
        config.setObjectResolver(new ProtocolV3ObjectResolver(connection, true));
        config.setVersion(4);
    } else {
        // legacy peers additionally rely on the v1 class table
        config.setClassTable(ProtocolV1ClassTable.INSTANCE);
        config.setObjectTable(ProtocolV1ObjectTable.INSTANCE);
        config.setObjectResolver(new ProtocolV1ObjectResolver(connection, true));
        config.setVersion(2);
    }
    EENamespaceInteroperability.handleInteroperability(config, version);
    return config;
}
/**
 * Creates the receiver that dispatches inbound protocol messages on this channel.
 *
 * @param association the association that receives decoded requests
 * @param handle1 a listener registration closed when the channel ends or fails
 * @param handle2 a second listener registration closed when the channel ends or fails
 * @return a new receiver bound to this channel
 */
Channel.Receiver getReceiver(final Association association, final ListenerHandle handle1, final ListenerHandle handle2) {
    final ReceiverImpl receiver = new ReceiverImpl(association, handle1, handle2);
    return receiver;
}
/**
 * Returns a fresh {@code ClusterTopologyWriter} to act as this channel's cluster topology listener.
 */
ClusterTopologyListener createTopologyListener() {
    final ClusterTopologyListener listener = new ClusterTopologyWriter();
    return listener;
}
/**
 * Returns a fresh {@code ModuleAvailabilityWriter} to act as this channel's module availability listener.
 */
ModuleAvailabilityListener createModuleListener() {
    final ModuleAvailabilityListener listener = new ModuleAvailabilityWriter();
    return listener;
}
class ReceiverImpl implements Channel.Receiver {
private final Association association;
private final ListenerHandle handle1;
private final ListenerHandle handle2;
ReceiverImpl(final Association association, final ListenerHandle handle1, final ListenerHandle handle2) {
this.association = association;
this.handle1 = handle1;
this.handle2 = handle2;
}
public void handleError(final Channel channel, final IOException error) {
final ClassLoader oldCL = getAndSetSafeTCCL();
try {
handle1.close();
handle2.close();
} finally {
resetTCCL(oldCL);
}
}
public void handleEnd(final Channel channel) {
final ClassLoader oldCL = getAndSetSafeTCCL();
try {
handle1.close();
handle2.close();
} finally {
resetTCCL(oldCL);
}
}
public void handleMessage(final Channel channel, final MessageInputStream message) {
channel.receiveMessage(this);
final ClassLoader oldCL = getAndSetSafeTCCL();
try {
final int code = message.readUnsignedByte();
switch (code) {
case Protocol.COMPRESSED_INVOCATION_MESSAGE:
case Protocol.INVOCATION_REQUEST: {
try (InputStream input = code == Protocol.COMPRESSED_INVOCATION_MESSAGE ? new InflaterInputStream(message) : message) {
// now if we get an error, we can respond.
if(code == Protocol.COMPRESSED_INVOCATION_MESSAGE) {
int verify = input.read();
if(verify != Protocol.INVOCATION_REQUEST) {
throw new RuntimeException();
}
}
final int invId = (input.read() << 8) | input.read();
try {
handleInvocationRequest(invId, input);
} catch (IOException | ClassNotFoundException e) {
// write response back to client
writeFailedResponse(invId, e);
}
}
break;
}
case Protocol.OPEN_SESSION_REQUEST: {
final int invId = message.readUnsignedShort();
try {
handleSessionOpenRequest(invId, message);
} catch (IOException e) {
// write response back to client
writeFailedResponse(invId, e);
}
break;
}
case Protocol.CANCEL_REQUEST: {
final int invId = message.readUnsignedShort();
try {
handleCancelRequest(invId, message);
} catch (IOException e) {
// ignored
}
break;
}
case Protocol.TXN_COMMIT_REQUEST:
case Protocol.TXN_ROLLBACK_REQUEST:
case Protocol.TXN_PREPARE_REQUEST:
case Protocol.TXN_FORGET_REQUEST:
case Protocol.TXN_BEFORE_COMPLETION_REQUEST: {
final int invId = message.readUnsignedShort();
try {
handleTxnRequest(code, invId, message);
} catch (IOException e) {
// ignored
}
break;
}
case Protocol.TXN_RECOVERY_REQUEST: {
final int invId = message.readUnsignedShort();
try {
handleTxnRecoverRequest(invId, message);
} catch (IOException e) {
// ignored
}
break;
}
default: {
// unrecognized
Logs.REMOTING.invalidMessageReceived(code);
break;
}
}
} catch (IOException e) {
// nothing we can do.
} finally {
safeClose(message);
resetTCCL(oldCL);
}
}
private void writeTxnResponse(final int invId, final int flag) {
try (MessageOutputStream os = messageTracker.openMessageUninterruptibly()) {
os.writeByte(Protocol.TXN_RESPONSE);
os.writeShort(invId);
os.writeBoolean(true);
PackedInteger.writePackedInteger(os, flag);
} catch (IOException e) {
// nothing to do at this point; the client doesn't want the response
Logs.REMOTING.ioExceptionOnTransactionResponseWrite(invId, channel, e);
}
}
private void writeTxnResponse(final int invId) {
try (MessageOutputStream os = messageTracker.openMessageUninterruptibly()) {
os.writeByte(Protocol.TXN_RESPONSE);
os.writeShort(invId);
os.writeBoolean(false);
} catch (IOException e) {
// nothing to do at this point; the client doesn't want the response
Logs.REMOTING.ioExceptionOnTransactionResponseWrite(invId, channel, e);
}
}
private void handleTxnRequest(final int code, final int invId, final MessageInputStream message) throws IOException {
final byte[] bytes = new byte[PackedInteger.readPackedInteger(message)];
message.readFully(bytes);
final TransactionID transactionID = TransactionID.createTransactionID(bytes);
if (transactionID instanceof XidTransactionID) try {
final SubordinateTransactionControl control = transactionServer.getTransactionService().getTransactionContext().findOrImportTransaction(((XidTransactionID) transactionID).getXid(), 0).getControl();
switch (code) {
case Protocol.TXN_COMMIT_REQUEST: {
boolean opc = message.readBoolean();
control.commit(opc);
writeTxnResponse(invId);
break;
}
case Protocol.TXN_ROLLBACK_REQUEST: {
control.rollback();
writeTxnResponse(invId);
break;
}
case Protocol.TXN_PREPARE_REQUEST: {
int res = control.prepare();
writeTxnResponse(invId, res);
break;
}
case Protocol.TXN_FORGET_REQUEST: {
control.forget();
writeTxnResponse(invId);
break;
}
case Protocol.TXN_BEFORE_COMPLETION_REQUEST: {
control.beforeCompletion();
writeTxnResponse(invId);
break;
}
default: throw Assert.impossibleSwitchCase(code);
}
} catch (XAException e) {
//EJBCLIENT-373
if (!((version <= 2) && (e.errorCode == XAException.XAER_NOTA))) {
writeFailedResponse(invId, e);
}
} else if (transactionID instanceof UserTransactionID) try {
final LocalTransaction localTransaction = transactionServer.removeTransaction(((UserTransactionID) transactionID).getId());
switch (code) {
case Protocol.TXN_COMMIT_REQUEST: {
// Discard unused parameter
message.readBoolean();
if (localTransaction != null) localTransaction.commit();
writeTxnResponse(invId);
break;
}
case Protocol.TXN_ROLLBACK_REQUEST: {
if (localTransaction != null) localTransaction.rollback();
writeTxnResponse(invId);
break;
}
case Protocol.TXN_PREPARE_REQUEST:
case Protocol.TXN_FORGET_REQUEST:
case Protocol.TXN_BEFORE_COMPLETION_REQUEST: {
writeFailedResponse(invId, Logs.TXN.userTxNotSupportedByTxContext());
break;
}
default: throw Assert.impossibleSwitchCase(code);
}
} catch (SystemException | HeuristicMixedException | RollbackException | HeuristicRollbackException e) {
writeFailedResponse(invId, e);
} catch (Throwable t) {
// Narayana uses Errors, Exceptions, and RuntimeExceptions
writeFailedResponse(invId, Logs.TXN.internalSystemErrorWithTx(t));
} else {
throw Assert.unreachableCode();
}
}
void handleTxnRecoverRequest(final int invId, final MessageInputStream message) throws IOException {
final String parentName = message.readUTF();
final int flags = message.readInt();
final Xid[] xids;
try {
xids = transactionServer.getTransactionService().getTransactionContext().getRecoveryInterface().recover(flags, parentName);
} catch (XAException e) {
writeFailedResponse(invId, e);
return;
}
try (MessageOutputStream os = messageTracker.openMessageUninterruptibly()) {
os.writeByte(Protocol.TXN_RECOVERY_RESPONSE);
os.writeShort(invId);
PackedInteger.writePackedInteger(os, xids.length);
final Marshaller marshaller = marshallerFactory.createMarshaller(configuration);
marshaller.start(new NoFlushByteOutput(Marshalling.createByteOutput(os)));
for (Xid xid : xids) {
marshaller.writeObject(new XidTransactionID(xid));
}
marshaller.finish();
} catch (IOException e) {
// nothing to do at this point; the client doesn't want the response
Logs.REMOTING.ioExceptionOnTransactionRecoveryResponseWrite(invId, channel, e);
}
}
void handleCancelRequest(final int invId, final MessageInputStream message) throws IOException {
final boolean cancelIfRunning = version < 3 || message.readBoolean();
final InProgress inProgress = invocations.get(invId);
if (inProgress != null) {
inProgress.cancel(cancelIfRunning);
}
}
void handleSessionOpenRequest(final int invId, final MessageInputStream inputStream) throws IOException {
final String appName = inputStream.readUTF();
final String moduleName = inputStream.readUTF();
final String distName = inputStream.readUTF();
final String beanName = inputStream.readUTF();
final int securityContext;
final ExceptionSupplier, SystemException> transactionSupplier;
if (version >= 3) {
securityContext = inputStream.readInt();
transactionSupplier = readTransaction(inputStream);
} else {
securityContext = 0;
transactionSupplier = null;
}
final Connection connection = channel.getConnection();
final EJBIdentifier identifier = new EJBIdentifier(appName, moduleName, beanName, distName);
association.receiveSessionOpenRequest(new RemotingSessionOpenRequest(
invId,
identifier,
transactionSupplier,
connection.getLocalIdentity(securityContext)));
}
void handleInvocationRequest(final int invId, final InputStream input) throws IOException, ClassNotFoundException {
final MarshallingConfiguration configuration = EJBServerChannel.this.configuration.clone();
final ServerClassResolver classResolver = new ServerClassResolver(EJBServerChannel.this.classResolverFilter);
configuration.setClassResolver(classResolver);
final Unmarshaller unmarshaller;
final EJBIdentifier identifier;
final EJBMethodLocator methodLocator;
final Connection connection = channel.getConnection();
final SecurityIdentity identity;
if (version >= 3) {
unmarshaller = marshallerFactory.createUnmarshaller(configuration);
unmarshaller.start(Marshalling.createByteInput(input));
identifier = unmarshaller.readObject(EJBIdentifier.class);
methodLocator = unmarshaller.readObject(EJBMethodLocator.class);
int identityId = unmarshaller.readInt();
identity = identityId == 0 ? connection.getLocalIdentity() : connection.getLocalIdentity(identityId);
} else {
assert version <= 2;
DataInputStream data = new DataInputStream(input);
final String methodName = data.readUTF();
// method signature
final String sigString = data.readUTF();
unmarshaller = marshallerFactory.createUnmarshaller(configuration);
unmarshaller.start(Marshalling.createByteInput(data));
String appName = unmarshaller.readObject(String.class);
String moduleName = unmarshaller.readObject(String.class);
String distinctName = unmarshaller.readObject(String.class);
String beanName = unmarshaller.readObject(String.class);
identifier = new EJBIdentifier(appName, moduleName, beanName, distinctName);
// parse out the signature string
final String[] parameterTypeNames;
if (sigString.isEmpty()) {
parameterTypeNames = new String[0];
} else {
parameterTypeNames = sigString.split(String.valueOf(METHOD_PARAM_TYPE_SEPARATOR));
}
methodLocator = new EJBMethodLocator(methodName, parameterTypeNames);
identity = connection.getLocalIdentity();
}
final RemotingInvocationRequest request = new RemotingInvocationRequest(
invId, identifier, methodLocator, classResolver, unmarshaller, identity
);
InProgress value = new InProgress(request);
invocations.put(value);
try {
value.setCancelHandle(association.receiveInvocationRequest(request));
} catch (Throwable t) {
//this should not happen
//but no harm in being defensive
Logs.INVOCATION.unexpectedException(t);
if(t instanceof Exception) {
request.writeException((Exception) t);
} else {
request.writeException(new EJBException(new RuntimeException(t)));
}
}
}
}
ExceptionSupplier, SystemException> readTransaction(final DataInput input) throws IOException {
final int type = input.readUnsignedByte();
if (type == 0) {
return null;
} else if (type == 1) {
// remote user transaction
final int id = input.readInt();
final int timeout = PackedInteger.readPackedInteger(input);
return () -> new ImportResult(transactionServer.getOrBeginTransaction(id, timeout), SubordinateTransactionControl.EMPTY, false);
} else if (type == 2) {
final int fmt = PackedInteger.readPackedInteger(input);
final byte[] gtid = new byte[input.readUnsignedByte()];
input.readFully(gtid);
final byte[] bq = new byte[input.readUnsignedByte()];
input.readFully(bq);
final int timeout = PackedInteger.readPackedInteger(input);
return () -> {
try {
return transactionServer.getTransactionService().getTransactionContext().findOrImportTransaction(new SimpleXid(fmt, gtid, bq), timeout);
} catch (XAException e) {
throw new SystemException(e.getMessage());
}
};
} else {
throw Logs.REMOTING.invalidTransactionType(type);
}
}
/**
 * Reports a failure for {@code invId} as an APPLICATION_EXCEPTION message whose payload is a
 * {@link RequestSendFailedException} wrapping {@code e}, with the peer URI appended to its message.
 * If the response itself cannot be written, the write error (with {@code e} suppressed) is logged.
 */
private void writeFailedResponse(final int invId, final Throwable e) {
    try (MessageOutputStream out = messageTracker.openMessageUninterruptibly()) {
        out.writeByte(Protocol.APPLICATION_EXCEPTION);
        out.writeShort(invId);
        final Marshaller marshaller = marshallerFactory.createMarshaller(configuration);
        marshaller.start(new NoFlushByteOutput(Marshalling.createByteOutput(out)));
        final String description = e.getMessage() + "@" + channel.getConnection().getPeerURI();
        final RequestSendFailedException failure = new RequestSendFailedException(description, e);
        marshaller.writeObject(failure);
        marshaller.writeByte(0);
        marshaller.finish();
    } catch (IOException writeProblem) {
        writeProblem.addSuppressed(e);
        // the client can no longer be told; record the failure locally
        Logs.REMOTING.ioExceptionOnEJBResponseWrite(invId, channel, writeProblem);
    }
}
/**
 * Common base for requests received on this channel. Holds the invocation ID and caller identity,
 * records session-ID and affinity updates made by the request handler, and writes the shared
 * failure responses; every such response also removes the invocation ID from the in-progress table.
 */
abstract class RemotingRequest implements Request {
    final int invId;
    // set once by convertToStateful
    SessionID sessionId;
    final SecurityIdentity identity;
    // affinity updates requested by the handler, if any
    ClusterAffinity strongAffinityUpdate;
    NodeAffinity weakAffinityUpdate;

    RemotingRequest(final int invId, final SecurityIdentity identity) {
        this.invId = invId;
        this.identity = identity;
    }

    /** The Remoting connection underlying this channel. */
    private Connection connection() {
        return channel.getConnection();
    }

    public Executor getRequestExecutor() {
        return connection().getEndpoint().getXnioWorker();
    }

    public SocketAddress getPeerAddress() {
        return connection().getPeerAddress();
    }

    public SocketAddress getLocalAddress() {
        return connection().getLocalAddress();
    }

    public String getProtocol() {
        return connection().getProtocol();
    }

    /** Remote callers never block on the server side. */
    public boolean isBlockingCaller() {
        return false;
    }

    public SecurityIdentity getSecurityIdentity() {
        return identity;
    }

    public void writeNoSuchEJB() {
        writeCodeAndMessage(Protocol.NO_SUCH_EJB, Logs.REMOTING.remoteMessageNoSuchEJB(getEJBIdentifier()));
    }

    public void writeWrongViewType() {
        final String message = Logs.REMOTING.remoteMessageBadViewType(getEJBIdentifier());
        if (version >= 3) {
            writeCodeAndMessage(Protocol.BAD_VIEW_TYPE, message);
            return;
        }
        // protocol v1/v2 has no BAD_VIEW_TYPE message; send the failure as an application exception
        try (MessageOutputStream out = messageTracker.openMessageUninterruptibly()) {
            out.writeByte(Protocol.APPLICATION_EXCEPTION);
            out.writeShort(invId);
            marshalWithTrailer(out, Logs.REMOTING.invalidViewTypeForInvocation(message));
        } catch (IOException e) {
            // nothing to do at this point; the client doesn't want the response
            Logs.REMOTING.ioExceptionOnEJBResponseWrite(invId, channel, e);
        } finally {
            invocations.removeKey(invId);
        }
    }

    public void writeCancelResponse() {
        try (MessageOutputStream out = messageTracker.openMessageUninterruptibly()) {
            out.writeByte(Protocol.CANCEL_RESPONSE);
            out.writeShort(invId);
        } catch (IOException e) {
            // nothing to do at this point; the client doesn't want the response
            Logs.REMOTING.ioExceptionOnEJBResponseWrite(invId, channel, e);
        } finally {
            invocations.removeKey(invId);
        }
    }

    public void writeNotStateful() {
        writeCodeAndMessage(Protocol.EJB_NOT_STATEFUL, Logs.REMOTING.remoteMessageEJBNotStateful(getEJBIdentifier()));
    }

    /**
     * Records the session ID for this request. A second call must pass an equal ID.
     *
     * @throws IllegalStateException if a different session ID was already recorded
     */
    public void convertToStateful(@NotNull final SessionID sessionId) throws IllegalArgumentException, IllegalStateException {
        Assert.checkNotNullParam("sessionId", sessionId);
        final SessionID existing = this.sessionId;
        if (existing == null) {
            this.sessionId = sessionId;
        } else if (! sessionId.equals(existing)) {
            throw new IllegalStateException();
        }
    }

    /** Returns the connection viewed as {@code providerInterfaceType}, or {@code null} if it is not one. */
    public <C> C getProviderInterface(Class<C> providerInterfaceType) {
        final Connection connection = connection();
        if (providerInterfaceType.isInstance(connection)) {
            return providerInterfaceType.cast(connection);
        }
        return null;
    }

    /** Enlistment status byte written in v3+ failure responses. */
    abstract int getEnlistmentStatus();

    /** Sends {@code reason} to the client as an application exception. */
    protected void writeFailure(final Exception reason) {
        try (MessageOutputStream out = messageTracker.openMessageUninterruptibly()) {
            out.writeByte(Protocol.APPLICATION_EXCEPTION);
            out.writeShort(invId);
            if (version >= 3) {
                out.writeByte(getEnlistmentStatus());
            }
            marshalWithTrailer(out, reason);
        } catch (IOException e) {
            e.addSuppressed(reason);
            // nothing to do at this point; the client doesn't want the response
            Logs.REMOTING.ioExceptionOnEJBResponseWrite(invId, channel, e);
        } finally {
            invocations.removeKey(invId);
        }
    }

    /** Only {@link ClusterAffinity} values are recorded as strong affinity updates; others are ignored. */
    public void updateStrongAffinity(@NotNull final Affinity affinity) {
        Assert.checkNotNullParam("affinity", affinity);
        if (! (affinity instanceof ClusterAffinity)) {
            return;
        }
        strongAffinityUpdate = (ClusterAffinity) affinity;
    }

    /** Only {@link NodeAffinity} values are recorded as weak affinity updates; others are ignored. */
    public void updateWeakAffinity(@NotNull final Affinity affinity) {
        Assert.checkNotNullParam("affinity", affinity);
        if (! (affinity instanceof NodeAffinity)) {
            return;
        }
        weakAffinityUpdate = (NodeAffinity) affinity;
    }

    /**
     * Sends a message consisting of {@code code}, the invocation ID and a {@code writeUTF}-encoded text,
     * then removes the invocation from the in-progress table.
     */
    private void writeCodeAndMessage(final int code, final String message) {
        try (MessageOutputStream out = messageTracker.openMessageUninterruptibly()) {
            out.writeByte(code);
            out.writeShort(invId);
            out.writeUTF(message);
        } catch (IOException e) {
            // nothing to do at this point; the client doesn't want the response
            Logs.REMOTING.ioExceptionOnEJBResponseWrite(invId, channel, e);
        } finally {
            invocations.removeKey(invId);
        }
    }

    /**
     * Marshals {@code payload} followed by a single zero byte onto {@code out}.
     * NOTE(review): the trailing zero presumably encodes an empty attachment count — confirm.
     */
    private void marshalWithTrailer(final MessageOutputStream out, final Object payload) throws IOException {
        final Marshaller marshaller = marshallerFactory.createMarshaller(configuration);
        marshaller.start(new NoFlushByteOutput(Marshalling.createByteOutput(out)));
        marshaller.writeObject(payload);
        marshaller.writeByte(0);
        marshaller.finish();
    }
}
final class RemotingSessionOpenRequest extends RemotingRequest implements SessionOpenRequest {
private final EJBIdentifier identifier;
final ExceptionSupplier, SystemException> transactionSupplier;
int txnCmd = 0; // assume nobody will ask about the transaction
RemotingSessionOpenRequest(final int invId, final EJBIdentifier identifier, final ExceptionSupplier, SystemException> transactionSupplier, final SecurityIdentity identity) {
super(invId, identity);
this.transactionSupplier = transactionSupplier;
this.identifier = identifier;
}
@NotNull
public EJBIdentifier getEJBIdentifier() {
return identifier;
}
public boolean hasTransaction() {
return transactionSupplier != null;
}
public Transaction getTransaction() throws SystemException, IllegalStateException {
final ExceptionSupplier, SystemException> transactionSupplier = this.transactionSupplier;
if (transactionSupplier == null) {
return null;
}
if (txnCmd != 0) {
throw new IllegalStateException();
}
final ImportResult> importResult = transactionSupplier.get();
if (importResult.isNew()) {
txnCmd = 1;
} else {
txnCmd = 2;
}
return importResult.getTransaction();
}
int getEnlistmentStatus() {
return txnCmd;
}
public void writeException(@NotNull final Exception exception) {
Assert.checkNotNullParam("exception", exception);
writeFailure(exception);
}
public void convertToStateful(@NotNull final SessionID sessionId) throws IllegalArgumentException, IllegalStateException {
super.convertToStateful(sessionId);
try (MessageOutputStream os = messageTracker.openMessageUninterruptibly()) {
os.writeByte(Protocol.OPEN_SESSION_RESPONSE);
os.writeShort(invId);
final byte[] encodedForm = sessionId.getEncodedForm();
PackedInteger.writePackedInteger(os, encodedForm.length);
os.write(encodedForm);
if (1 <= version && version <= 2) {
final Marshaller marshaller = marshallerFactory.createMarshaller(configuration);
marshaller.start(new NoFlushByteOutput(Marshalling.createByteOutput(os)));
if (strongAffinityUpdate != null) {
marshaller.writeObject(strongAffinityUpdate);
} else {
marshaller.writeObject(new NodeAffinity(channel.getConnection().getEndpoint().getName()));
}
marshaller.finish();
} else {
assert version >= 3;
os.writeByte(txnCmd);
int updateBits = 0;
if (weakAffinityUpdate != null) {
updateBits |= Protocol.UPDATE_BIT_WEAK_AFFINITY;
}
if (strongAffinityUpdate != null) {
updateBits |= Protocol.UPDATE_BIT_STRONG_AFFINITY;
}
os.writeByte(updateBits);
if (weakAffinityUpdate != null) {
final String nodeName = weakAffinityUpdate.getNodeName();
final byte[] bytes = nodeName.getBytes(StandardCharsets.UTF_8);
PackedInteger.writePackedInteger(os, bytes.length);
os.write(bytes);
}
if (strongAffinityUpdate != null) {
final String clusterName = strongAffinityUpdate.getClusterName();
final byte[] bytes = clusterName.getBytes(StandardCharsets.UTF_8);
PackedInteger.writePackedInteger(os, bytes.length);
os.write(bytes);
}
}
} catch (IOException e) {
// nothing to do at this point; the client doesn't want the response
Logs.REMOTING.ioExceptionOnEJBSessionOpenResponseWrite(invId, channel, e);
}
}
}
final class RemotingInvocationRequest extends RemotingRequest implements InvocationRequest {
/** Identifier of the target EJB, decoded from the request header. */
final EJBIdentifier identifier;
/** Locator of the invoked method, decoded from the request header. */
final EJBMethodLocator methodLocator;
/** Class resolver used by {@link #remaining}; its class loader is set when the request content is read. */
final ServerClassResolver classResolver;
/** Unmarshaller positioned after the request header; the rest of the body is read from (and it is closed by) getRequestContent. */
final Unmarshaller remaining;
int txnCmd = 0; // assume nobody will ask about the transaction
/**
 * Captures an invocation whose header has already been decoded; the remaining body stays in
 * {@code remaining} until the request content is read.
 */
RemotingInvocationRequest(final int invId, final EJBIdentifier identifier, final EJBMethodLocator methodLocator, final ServerClassResolver classResolver, final Unmarshaller remaining, final SecurityIdentity identity) {
    super(invId, identity);
    this.remaining = remaining;
    this.classResolver = classResolver;
    this.methodLocator = methodLocator;
    this.identifier = identifier;
}
/**
 * Records the session ID for this invocation. Only supported from protocol version 3 on; older
 * versions reject the conversion with the {@code cannotAddSessionID} error.
 */
public void convertToStateful(final SessionID sessionId) throws IllegalArgumentException, IllegalStateException {
    if (version >= 3) {
        super.convertToStateful(sessionId);
        return;
    }
    throw Logs.REMOTING.cannotAddSessionID();
}
public Resolved getRequestContent(final ClassLoader classLoader) throws IOException, ClassNotFoundException {
classResolver.setClassLoader(classLoader);
SetretainContextDataKeys = new HashSet<>();
int responseCompressLevel = 0;
// resolve the rest of everything here
try (Unmarshaller unmarshaller = remaining) {
Affinity weakAffinity = Affinity.NONE;
ExceptionSupplier, SystemException> transactionSupplier = null;
final EJBLocator> locator;
if (version >= 3) {
weakAffinity = unmarshaller.readObject(Affinity.class);
if (weakAffinity == null) weakAffinity = Affinity.NONE;
int flags = unmarshaller.readUnsignedByte();
responseCompressLevel = flags & Protocol.COMPRESS_RESPONSE;
transactionSupplier = readTransaction(unmarshaller);
locator = unmarshaller.readObject(EJBLocator.class);
// do identity checks for these strings to guarantee integrity.
// noinspection StringEquality
if (identifier != locator.getIdentifier()) {
throw Logs.REMOTING.mismatchedMethodLocation();
}
} else {
assert version <= 2;
locator = unmarshaller.readObject(EJBLocator.class);
// do identity checks for these strings to guarantee integrity. can't check identifier because that class didn't exist in V2
//noinspection StringEquality
if (identifier.getAppName() != locator.getAppName() ||
identifier.getModuleName() != locator.getModuleName() ||
identifier.getBeanName() != locator.getBeanName() ||
identifier.getDistinctName() != locator.getDistinctName()) {
throw Logs.REMOTING.mismatchedMethodLocation();
}
// Protocol version <= 2 is dealing with the WEAK_AFFINITY_CONTEXT_KEY as an attachment, see writeInvocationResult(final Object result)
retainContextDataKeys.add(Affinity.WEAK_AFFINITY_CONTEXT_KEY);
}
Object[] parameters = new Object[methodLocator.getParameterCount()];
for (int i = 0; i < parameters.length; i ++) {
parameters[i] = unmarshaller.readObject();
}
int attachmentCount = PackedInteger.readPackedInteger(unmarshaller);
final Map attachments = new HashMap<>(attachmentCount);
for (int i = 0; i < attachmentCount; i ++) {
String attName = unmarshaller.readObject(String.class);
if (attName.equals(EJBClientInvocationContext.PRIVATE_ATTACHMENTS_KEY)) {
if (version <= 2) {
// only supported for protocol v1/2 - read out transaction ID
@SuppressWarnings("unchecked")
Map