All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jboss.ejb.client.EJBClientInvocationContext Maven / Gradle / Ivy

The newest version!
/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2017 Red Hat, Inc., and individual contributors
 * as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.jboss.ejb.client;

import static java.lang.Math.max;
import static java.lang.Thread.holdsLock;

import java.lang.reflect.Method;
import java.lang.reflect.UndeclaredThrowableException;
import java.net.URI;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Supplier;

import org.jboss.ejb._private.Logs;
import org.jboss.ejb.client.annotation.ClientTransactionPolicy;
import org.wildfly.common.Assert;
import org.wildfly.common.annotation.NotNull;
import org.wildfly.discovery.Discovery;
import org.wildfly.security.auth.client.AuthenticationContext;

/**
 * An invocation context for Enterprise Beans invocations from an Enterprise Bean client
 *
 * @author David M. Lloyd
 * @author Jaikiran Pai
 */
public final class EJBClientInvocationContext extends AbstractInvocationContext {

    private static final Logs log = Logs.MAIN;

    public static final String PRIVATE_ATTACHMENTS_KEY = "org.jboss.ejb.client.invocation.attachments";
    /**
     * A context data key that may contain a Set of Strings. Any context data on the server side invocation context
     * stored under these keys will be returned to the client.
     */
    public static final String RETURNED_CONTEXT_DATA_KEY = "jboss.returned.keys";

    // Contextual stuff
    private final EJBInvocationHandler invocationHandler;
    private final Discovery discoveryContext;

    // Invocation data
    private final Object invokedProxy;
    private final Object[] parameters;
    private final EJBProxyInformation.ProxyMethodInfo methodInfo;
    private final EJBReceiverInvocationContext receiverInvocationContext = new EJBReceiverInvocationContext(this);
    private final EJBClientContext.InterceptorList interceptorList;
    private final long startTime = System.nanoTime();
    private long timeout;

    // Invocation state
    private final Object lock = new Object();
    private EJBReceiverInvocationContext.ResultProducer resultProducer;

    private volatile boolean cancelRequested;
    private boolean retryRequested;
    private EJBClientInvocationContext.State state = EJBClientInvocationContext.State.SENDING;
    private int remainingRetries;
    private Supplier pendingFailure;
    private List> suppressedExceptions;
    private Object cachedResult;

    private int interceptorChainIndex;
    private boolean blockingCaller;
    private int waiters = 0;

    EJBClientInvocationContext(final EJBInvocationHandler invocationHandler, final EJBClientContext ejbClientContext, final Object invokedProxy, final Object[] parameters, final EJBProxyInformation.ProxyMethodInfo methodInfo, final int allowedRetries, final Supplier authenticationContextSupplier, final Discovery discoveryContext) {
        super(invocationHandler.getLocator(), ejbClientContext, authenticationContextSupplier != null ? authenticationContextSupplier.get() : AuthenticationContext.captureCurrent());
        this.invocationHandler = invocationHandler;
        this.discoveryContext = discoveryContext;
        this.invokedProxy = invokedProxy;
        this.parameters = parameters;
        this.methodInfo = methodInfo;
        long timeout = invocationHandler.getInvocationTimeout();
        if (timeout == -1) {
            timeout = ejbClientContext.getInvocationTimeout();
        }
        this.timeout = timeout;
        remainingRetries = allowedRetries;
        interceptorList = getClientContext().getInterceptors(getViewClass(), getInvokedMethod());
    }

    enum State {
        // waiting states

        SENDING(true),
        SENT(true),
        WAITING(true),

        // completion states

        READY(false),
        CONSUMING(false),
        DONE(false),
        ;

        private final boolean waiting;

        State(final boolean waiting) {
            this.waiting = waiting;
        }

        boolean isWaiting() {
            return waiting;
        }
    }

    /**
     * Indicates to the server that a client side interceptor is interested in the context data that is stored
     * under the given key. Any object on the context data map (i.e. {@code InvocationContext.getContextData()} that is
     * present at the end of the server invocation will be serialised and returned to the client.
     *
     * If an object is present under this key and is not serializable then the request will fail.
     *
     * @param key The context data key the client interceptor is interested in
     */
    public void addReturnedContextDataKey(String key) {
        Object returnedData = getContextData().get(RETURNED_CONTEXT_DATA_KEY);
        if(returnedData == null) {
            getContextData().put(RETURNED_CONTEXT_DATA_KEY, returnedData = new HashSet<>());
        } else if(!(returnedData instanceof Set)) {
            throw Logs.INVOCATION.returnedContextDataKeyOfWrongType();
        }
        Set keys = (Set) returnedData;
        keys.add(key);
    }

    /**
     * Get a value attached to the proxy.
     *
     * @param key the attachment key
     * @param  the value type
     * @return the value, or {@code null} if there is none
     */
    public  T getProxyAttachment(AttachmentKey key) {
        return invocationHandler.getAttachment(key);
    }

    /**
     * Remove a value attached to the proxy.
     *
     * @param key the attachment key
     * @param  the value type
     * @return the value, or {@code null} if there is none
     */
    public  T removeProxyAttachment(final AttachmentKey key) {
        return invocationHandler.removeAttachment(key);
    }

    /**
     * Determine whether the method is marked client-asynchronous, meaning that invocation should be asynchronous regardless
     * of whether the server-side method is asynchronous.
     *
     * @return {@code true} if the method is marked client-asynchronous, {@code false} otherwise
     */
    public boolean isClientAsync() {
        return invocationHandler.isAsyncHandler() || methodInfo.isClientAsync();
    }

    /**
     * Determine whether the method is definitely synchronous, that is, it is not marked client-async, and the return
     * value of the method is not {@code void} or {@code Future}.
     *
     * @return {@code true} if the method is definitely synchronous, {@code false} if the method may be asynchronous
     */
    public boolean isSynchronous() {
        return ! isClientAsync() && methodInfo.isSynchronous();
    }

    /**
     * Determine whether the method is marked idempotent, meaning that the method may be invoked more than one time with
     * no additional effect.
     *
     * @return {@code true} if the method is marked idempotent, {@code false} otherwise
     */
    public boolean isIdempotent() {
        return methodInfo.isIdempotent();
    }

    /**
     * Determine whether the method has an explicit transaction policy set.
     *
     * @return the transaction policy, if any, or {@code null} if none was explicitly set
     */
    public ClientTransactionPolicy getTransactionPolicy() {
        return methodInfo.getTransactionPolicy();
    }

    /**
     * Determine whether the request is expected to be compressed.
     *
     * @return {@code true} if the request is expected to be compressed, {@code false} otherwise
     */
    public boolean isCompressRequest() {
        if (methodInfo.getCompressRequestHint() == EJBProxyInformation.ProxyMethodInfo.CompressionHint.NONE) {
            return getClientContext().getDefaultCompression() != -1;
        }
        return methodInfo.getCompressRequestHint() == EJBProxyInformation.ProxyMethodInfo.CompressionHint.TRUE;
    }

    /**
     * Determine whether the response is expected to be compressed.
     *
     * @return {@code true} if the response is expected to be compressed, {@code false} otherwise
     */
    public boolean isCompressResponse() {
        if (methodInfo.getCompressResponseHint() == EJBProxyInformation.ProxyMethodInfo.CompressionHint.NONE) {
            return getClientContext().getDefaultCompression() != -1;
        }
        return methodInfo.getCompressResponseHint() == EJBProxyInformation.ProxyMethodInfo.CompressionHint.TRUE;
    }

    /**
     * Get the compression hint level.  If no compression hint is given, -1 is returned.
     *
     * @return the compression hint level, or -1 for no compression hint
     */
    public int getCompressionLevel() {
        if(methodInfo.getCompressionLevel() != -1){
            return methodInfo.getCompressionLevel();
        }
        return getClientContext().getDefaultCompression();
    }

    /**
     * Get the method type signature string, used to identify the method.
     *
     * @return the method signature string
     */
    public String getMethodSignatureString() {
        return methodInfo.getSignature();
    }

    /**
     * Get the Enterprise Beans method locator.
     *
     * @return the Enterprise Beans method locator
     */
    public EJBMethodLocator getMethodLocator() {
        return methodInfo.getMethodLocator();
    }

    /**
     * Determine whether this invocation is currently blocking the calling thread.
     *
     * @return {@code true} if the calling thread is being blocked; {@code false} otherwise
     */
    public boolean isBlockingCaller() {
        synchronized (lock) {
            return blockingCaller;
        }
    }

    /**
     * Establish whether this invocation is currently blocking the calling thread.
     *
     * @param blockingCaller {@code true} if the calling thread is being blocked; {@code false} otherwise
     */
    public void setBlockingCaller(final boolean blockingCaller) {
        synchronized (lock) {
            this.blockingCaller = blockingCaller;
        }
    }

    /**
     * Add a suppressed exception to the request.
     *
     * @param cause the suppressed exception (must not be {@code null})
     */
    public void addSuppressed(Throwable cause) {
        Assert.checkNotNullParam("cause", cause);
        synchronized (lock) {
            if (state == State.DONE) {
                return;
            }
            if (suppressedExceptions == null) {
                suppressedExceptions = new ArrayList<>();
            }
            suppressedExceptions.add(() -> cause);
            checkStateInvariants();
        }
    }

    /**
     * Add a suppressed exception to the request.
     *
     * @param cause the suppressed exception (must not be {@code null})
     */
    public void addSuppressed(Supplier cause) {
        Assert.checkNotNullParam("cause", cause);
        synchronized (lock) {
            if (state == State.DONE) {
                return;
            }
            if (suppressedExceptions == null) {
                suppressedExceptions = new ArrayList<>();
            }
            suppressedExceptions.add(cause);
            checkStateInvariants();
        }
    }

    public void requestRetry() {
        if (Logs.INVOCATION.isDebugEnabled()) {
            Logs.INVOCATION.debugf("Requesting retry of invocation!");
        }
        synchronized (lock) {
            retryRequested = true;
        }
    }

    void sendRequestInitial() {
        assert checkState() == State.SENDING;
        for (;;) {
            assert interceptorChainIndex == 0;
            try {
                getAuthenticationContext().runExConsumer(EJBClientInvocationContext::sendRequest, this);
                // back to the start of the chain; decide what to do next.
                synchronized (lock) {
                    try {
                        assert state == State.SENT || state == State.READY;
                        // from here we can go to: READY, or WAITING, or retry SENDING.
                        Supplier pendingFailure = this.pendingFailure;
                        EJBReceiverInvocationContext.ResultProducer resultProducer = this.resultProducer;
                        // now see if we're retrying or returning.
                        if (pendingFailure != null) {
                            if (resultProducer == null) {
                                this.resultProducer = new ThrowableResult(pendingFailure);
                            } else {
                                addSuppressed(pendingFailure);
                            }

                            // Run the result interceptor chain to see if a retry is necessary
                            this.pendingFailure = null;
                            transition(State.CONSUMING);
                            try {
                                getResult(true);
                            } catch (Throwable t) {
                                if (state == State.SENDING) {
                                    // Retry has been requested
                                    continue;
                                }
                            }

                            // Return with getResult's set state
                            return;
                        }

                        if (resultProducer != null) {
                            transition(State.READY);
                            return;
                        }

                        transition(State.WAITING);
                        return;
                    } finally {
                        checkStateInvariants();
                    }
                }
                // not reachable
            } catch (Throwable t) {
                log.tracef("Encountered exception when calling sendRequestInitial: exception = %s)", t.getMessage());
                // back to the start of the chain; decide what to do next.
                synchronized (lock) {
                    if (state == State.SENDING) {
                        // didn't make it to the end of the chain even... but we won't suppress the thrown exception
                        transition(State.SENT);
                    }
                    assert state == State.SENT || state == State.READY;
                    try {
                        // from here we can go to: FAILED, READY, or retry SENDING.
                        Supplier pendingFailure = this.pendingFailure;
                        EJBReceiverInvocationContext.ResultProducer resultProducer = this.resultProducer;
                        if (resultProducer != null) {
                            // READY, even if we have a pending failure.
                            if (pendingFailure != null) {
                                addSuppressed(t);
                                addSuppressed(pendingFailure);
                                this.pendingFailure = null;
                            }
                            transition(State.READY);
                            return;
                        }
                        // FAILED, or retry SENDING.
                        if (! retryRequested || remainingRetries == 0) {
                            // nobody wants retry, or there are none left; go to FAILED
                            if (pendingFailure != null) {
                                addSuppressed(pendingFailure);
                            }
                            if (t instanceof Exception) {
                                this.resultProducer = new EJBReceiverInvocationContext.ResultProducer.Failed((Exception) t);
                            } else {
                                this.resultProducer = new EJBReceiverInvocationContext.ResultProducer.Failed(new UndeclaredThrowableException(t));
                            }
                            this.pendingFailure = null;
                            transition(State.READY);
                            return;
                        }
                        log.tracef("retrying the invocation!: remaining retries = %d", remainingRetries);
                        // retry SENDING
                        if (pendingFailure != null) {
                            addSuppressed(pendingFailure);
                        }
                        setReceiver(null);
                        this.pendingFailure = null;
                        transition(State.SENDING);
                        retryRequested = false;
                        remainingRetries --;
                    } finally {
                        checkStateInvariants();
                    }
                }
                // record for later
                addSuppressed(t);
                // redo the loop
                //noinspection UnnecessaryContinue
                continue;
            }
        }
    }

    State checkState() {
        synchronized (lock) {
            return state;
        }
    }

    /**
     * Proceed with sending the request normally.
     *
     * @throws Exception if the request was not successfully sent
     */
    public void sendRequest() throws Exception {
        final Object lock = this.lock;
        Assert.assertNotHoldsLock(lock);
        final EJBClientInterceptorInformation[] chain = interceptorList.getInformation();
        synchronized (lock) {
            if (state != State.SENDING) {
                throw Logs.MAIN.sendRequestCalledDuringWrongPhase();
            }
        }
        final int idx = interceptorChainIndex ++;
        try {
            if (cancelRequested) {
                synchronized (lock) {
                    transition(State.SENT);
                    resultReady(CANCELLED);
                    checkStateInvariants();
                }
            } else if (chain.length == idx) {
                // End of the chain processing; deliver to receiver or throw an exception.
                final URI destination = getDestination();
                final EJBReceiver receiver;
                try {
                    if (Logs.INVOCATION.isDebugEnabled()) {
                        Logs.INVOCATION.debugf("sendRequest: setting receiver, strong affinity = %s, weak affinity = %s, remote destination is: %s",
                                getLocator().getAffinity(), invocationHandler.getWeakAffinity(), destination);
                    }
                    receiver = getClientContext().resolveReceiver(destination, getLocator());
                } catch (Throwable t) {
                    synchronized (lock) {
                        if (state != State.SENT) {
                            transition(State.SENT);
                        }
                        checkStateInvariants();
                    }
                    throw t;
                }
                setReceiver(receiver);
                synchronized (lock) {
                    transition(State.SENT);
                    checkStateInvariants();
                }
                try {
                    receiver.processInvocation(receiverInvocationContext);
                } catch (Throwable t) {
                    synchronized (lock) {
                        if (state != State.SENT) {
                            transition(State.SENT);
                        }
                        checkStateInvariants();
                    }
                    throw t;
                }
            } else {
                try {
                    if (Logs.INVOCATION.isDebugEnabled()) {
                        Logs.INVOCATION.debugf("sendRequest: calling interceptor: %s", chain[idx].getInterceptorInstance());
                    }
                    chain[idx].getInterceptorInstance().handleInvocation(this);
                } catch (Throwable t) {
                    synchronized (lock) {
                        if (state != State.SENT) {
                            transition(State.SENT);
                        }
                        checkStateInvariants();
                    }
                    throw t;
                }
                synchronized (lock) {
                    try {
                        if (state != State.SENT && state != State.READY) {
                            assert state == State.SENDING;
                            transition(State.SENT);
                            throw Logs.INVOCATION.requestNotSent();
                        }
                    } finally {
                        checkStateInvariants();
                    }
                }
            }
        } finally {
            interceptorChainIndex--;
        }
        // return to enclosing interceptor
        return;
    }

    /**
     * Get the invocation result from this request.  The result is not actually acquired unless all interceptors
     * call this method.  Should only be called from {@link EJBClientInterceptor#handleInvocationResult(EJBClientInvocationContext)}.
     *
     * @return the invocation result
     * @throws Exception if the invocation did not succeed
     */
    public Object getResult() throws Exception {
        return getResult(false);
    }

    /**
     * Get the invocation result (internal operation).
     *
     * @param retry {@code true} if the caller is the retry process, {@code false} for user call
     * @return the invocation result
     * @throws Exception if the invocation did not succeed
     */
    Object getResult(boolean retry) throws Exception {
        final EJBClientContext.InterceptorList list = getClientContext().getInterceptors(getViewClass(), getInvokedMethod());
        final EJBClientInterceptorInformation[] chain = list.getInformation();
        final EJBReceiverInvocationContext.ResultProducer resultProducer;
        Throwable fail = null;
        final int idx = this.interceptorChainIndex;
        final Object lock = this.lock;
        synchronized (lock) {
            try {
                if (idx == 0) {
                    if (retry) {
                        assert state == State.CONSUMING;
                    } else {
                        while (state == State.CONSUMING) try {
                            checkStateInvariants();
                            try {
                                waiters++;
                                lock.wait();
                            } finally {
                                waiters--;
                            }
                        } catch (InterruptedException e) {
                            Thread.currentThread().interrupt();
                            throw Logs.MAIN.operationInterrupted();
                        }
                        if (state == State.DONE) {
                            Supplier pendingFailure = this.pendingFailure;
                            if (pendingFailure != null) {
                                fail = pendingFailure.get();
                                if (fail == null) {
                                    return cachedResult;
                                }
                            } else {
                                return cachedResult;
                            }
                        } else if (state != State.READY) {
                            throw Logs.MAIN.getResultCalledDuringWrongPhase();
                        } else {
                            transition(State.CONSUMING);
                        }
                    }
                }
                resultProducer = this.resultProducer;
            } finally {
                checkStateInvariants();
            }
        }
        if (fail != null) try {
            throw fail;
        } catch (Exception | Error e) {
            throw e;
        } catch (Throwable t) {
            throw new UndeclaredThrowableException(t);
        }
        this.interceptorChainIndex = idx + 1;
        try {
            final Object result;
            try {
                if (idx == chain.length) {
                    result = resultProducer.getResult();
                } else {
                    result = chain[idx].getInterceptorInstance().handleInvocationResult(this);
                }
                if (idx == 0) {
                    synchronized (lock) {
                        transition(State.DONE);
                        pendingFailure = null;
                        suppressedExceptions = null;
                        cachedResult = result;
                        this.resultProducer = null;
                        checkStateInvariants();
                    }
                }
                return result;
            } catch (Throwable t) {
                if (idx == 0) {
                    log.tracef("Encountered exception while calling getResult(): exception = %s", t.toString());
                    synchronized (lock) {
                        // retry if we can
                        this.resultProducer = null;
                        List> suppressedExceptions = this.suppressedExceptions;
                        final int remainingRetries = this.remainingRetries;
                        final boolean retryRequested = this.retryRequested;
                        if (retryRequested && remainingRetries > 0) {
                            log.tracef("Will retry (requested = %s, remaining = %d)", retryRequested, remainingRetries);
                            if (suppressedExceptions == null) {
                                suppressedExceptions = this.suppressedExceptions = new ArrayList<>();
                            }
                            suppressedExceptions.add(() -> t);
                            this.remainingRetries --;
                            this.retryRequested = false;
                            this.cachedResult = null;
                            this.pendingFailure = null;
                            setReceiver(null);
                            transition(State.SENDING);
                            checkStateInvariants();
                        } else {
                            log.tracef("Will not retry (requested = %s, remaining = %d)", retryRequested, remainingRetries);
                            pendingFailure = () -> t;
                            if (suppressedExceptions != null) {
                                this.suppressedExceptions = null;
                                for (Supplier supplier : suppressedExceptions) {
                                    try {
                                        t.addSuppressed(supplier.get());
                                    } catch (Throwable ignored) {}
                                }
                            }
                            transition(State.DONE);
                            checkStateInvariants();
                        }
                    }
                }
                throw t;
            } finally {
                if (idx == 0) {
                    if (Logs.INVOCATION.isDebugEnabled()) {
                        Logs.INVOCATION.debugf("getResult(): invocation returned, relocating EJB: strong affinity = %s, weak affinity = %s", getLocator().getAffinity(), getWeakAffinity());
                    }
                    // relocate the EJB
                    invocationHandler.setWeakAffinity(getWeakAffinity());
                    invocationHandler.setStrongAffinity(getLocator().getAffinity());
                }
            }
        } finally {
            interceptorChainIndex = idx;
        }
    }

    /**
     * Discard the result from this request.  Should only be called from {@link EJBClientInterceptor#handleInvocationResult(EJBClientInvocationContext)}.
     *
     * @throws IllegalStateException if there is no result to discard
     */
    public void discardResult() throws IllegalStateException {
        resultReady(EJBClientInvocationContext.ONE_WAY);
    }

    void resultReady(EJBReceiverInvocationContext.ResultProducer resultProducer) {
        Assert.checkNotNullParam("resultProducer", resultProducer);
        synchronized (lock) {
            if (state.isWaiting() && this.resultProducer == null) {
                this.resultProducer = resultProducer;
                if (state == State.WAITING || state == State.SENT) {
                    transition(State.READY);
                }
                checkStateInvariants();
                if (log.isTraceEnabled()) {
                    log.tracef("Result is ready for %s: result producer: %s, current state: %s", this, resultProducer, state);
                }
                return;
            }
            checkStateInvariants();
            if (log.isTraceEnabled()) {
                log.tracef("Result discarded for %s: result producer: %s, current state: %s", this, resultProducer, state);
            }
        }
        // for whatever reason, we don't care
        resultProducer.discardResult();
    }

    /**
     * Get the invoked proxy object.
     *
     * @return the invoked proxy
     */
    public Object getInvokedProxy() {
        return invokedProxy;
    }

    /**
     * Get the invoked proxy method.
     *
     * @return the invoked method
     */
    public Method getInvokedMethod() {
        return methodInfo.getMethod();
    }

    /**
     * Get the invocation method parameters.
     *
     * @return the invocation method parameters
     */
    public Object[] getParameters() {
        return parameters;
    }

    /**
     * Get the remaining invocation time in the given unit.  If there is no invocation timeout, {@link Long#MAX_VALUE} is
     * always returned.  If the invocation time has elapsed, 0 is returned.
     *
     * @param timeUnit the time unit (must not be {@code null})
     * @return the invocation's remaining time in the provided unit
     */
    public long getRemainingInvocationTime(TimeUnit timeUnit) {
        Assert.checkNotNullParam("timeUnit", timeUnit);
        final long timeout = this.timeout;
        if (timeout <= 0L) {
            return Long.MAX_VALUE;
        }
        return max(0L, timeUnit.convert(timeout - (System.nanoTime() - startTime) / 1_000_000L, TimeUnit.MILLISECONDS));
    }

    @Override
    public  void setLocator(EJBLocator locator) {
        super.setLocator(locator);

        Affinity affinity = locator.getAffinity();
        if (affinity instanceof ClusterAffinity) {
            ClusterAffinityInterest interest = invocationHandler.getAttachment(ClusterAffinityInterest.KEY);
            if (interest != null) {
                interest.notifyAssignment((ClusterAffinity)affinity);
            }
        }
    }

    @NotNull
    Discovery getDiscovery() {
        return discoveryContext;
    }

    Future getFutureResponse() {
        return new FutureResponse();
    }

    void proceedAsynchronously() {
        if (getInvokedMethod().getReturnType() == void.class) {
            resultReady(EJBReceiverInvocationContext.ResultProducer.NULL);
        }
    }

    /**
     * Transition to a new state, guarded by assertions.
     *
     * @param newState the state to transition to (must not be {@code null})
     */
    private void transition(State newState) {
        final Object lock = this.lock;
        Assert.assertHoldsLock(lock);
        final State oldState = this.state;
        if (oldState == newState) {
            return;
        }
        if (log.isTraceEnabled()) {
            StackTraceElement caller = (new Exception()).getStackTrace()[1];
            log.tracef("Transitioning %s from %s to %s (%s)", this, oldState, newState, caller);
        }
        switch (oldState) {
            case SENDING: {
                assert newState == State.SENT;
                break;
            }
            case SENT: {
                assert newState == State.READY || newState == State.SENDING || newState == State.DONE || newState == State.WAITING || newState == State.CONSUMING;
                break;
            }
            case WAITING: {
                assert newState == State.DONE || newState == State.READY || newState == State.CONSUMING;
                break;
            }
            case READY: {
                assert newState == State.CONSUMING;
                break;
            }
            case CONSUMING: {
                assert newState == State.SENDING || newState == State.DONE || newState == State.READY;
                break;
            }
            default: {
                assert false;
                break;
            }
        }
        switch (newState) {
            case READY:
            case DONE: {
                this.remainingRetries = 0;
                // fall thru
            }
            case WAITING:{
                if(waiters > 0) {
                    lock.notifyAll();
                }
                break;
            }
        }
        // everything is OK
        this.state = newState;
    }

    /**
     * Check the invariants of the current state with assertions before the caller releases the lock.
     */
    private void checkStateInvariants() {
        final Object lock = this.lock;
        Assert.assertHoldsLock(lock);
        final State state = this.state;
        switch (state) {
            case SENDING: {
                assert resultProducer == null && cachedResult == null && getReceiver() == null;
                break;
            }
            case SENT: {
                assert cachedResult == null;
                break;
            }
            case WAITING: {
                assert resultProducer == null && pendingFailure == null && cachedResult == null;
                break;
            }
            case READY: {
                assert resultProducer != null && pendingFailure == null && cachedResult == null && remainingRetries == 0;
                break;
            }
            case CONSUMING: {
                assert resultProducer != null && pendingFailure == null && cachedResult == null;
                break;
            }
            case DONE: {
                assert resultProducer == null && (pendingFailure == null || pendingFailure != null && cachedResult == null) && remainingRetries == 0;
                break;
            }
            default: {
                assert false;
                break;
            }
        }
    }

    /**
     * Wait to determine whether this invocation was cancelled.
     *
     * @return {@code true} if the invocation was cancelled; {@code false} if it completed or failed or the thread was
     *  interrupted
     */
    public boolean awaitCancellationResult() {
        final Object lock = this.lock;
        Assert.assertNotHoldsLock(lock);
        synchronized (lock) {
            for (;;) {
                if (resultProducer == CANCELLED) {
                    return true;
                } else if (! state.isWaiting()) {
                    return false;
                }
                try {
                    checkStateInvariants();
                    try {
                        waiters++;
                        lock.wait();
                    } finally {
                        waiters--;
                    }
                } catch (InterruptedException ex) {
                    Thread.currentThread().interrupt();
                    return false;
                }
            }
        }
    }

    Object awaitResponse() throws Exception {
        Assert.assertNotHoldsLock(lock);
        boolean intr = false, timedOut = false;
        try {
            final Object lock = this.lock;
            final long timeout = this.timeout;
            synchronized (lock) {
                try {
                    out: for (;;) {
                        switch (state) {
                            case SENDING:
                            case SENT:
                            case CONSUMING:
                            case WAITING: {
                                if (timeout <= 0 || timedOut) {
                                    // no timeout; lighter code path
                                    try {
                                        checkStateInvariants();
                                        try {
                                            waiters++;
                                            lock.wait();
                                        } finally {
                                            waiters--;
                                        }
                                    } catch (InterruptedException e) {
                                        intr = true;
                                    }
                                } else {
                                    // timeout in ms, elapsed time in nanosecs
                                    long remaining = max(0L, timeout * 1_000_000L - max(0L, System.nanoTime() - startTime));
                                    if (remaining == 0L) {
                                        // timed out
                                        timedOut = true;
                                        this.timeout = 0;
                                        if (state == State.CONSUMING) {
                                            addSuppressed(new TimeoutException("No invocation response received in " + timeout + " milliseconds"));
                                            transition(State.READY);
                                        } else {
                                            resultReady(new ThrowableResult(() -> new TimeoutException("No invocation response received in " + timeout + " milliseconds")));
                                        }
                                    } else try {
                                        checkStateInvariants();
                                        try {
                                            waiters++;
                                            lock.wait(remaining / 1_000_000L, (int) (remaining % 1_000_000L));
                                        } finally {
                                            waiters--;
                                        }
                                    } catch (InterruptedException e) {
                                        intr = true;
                                    }
                                }
                                break;
                            }
                            case READY: {
                                // we have to get the result, so break out of here.
                                checkStateInvariants();
                                break out;
                            }
                            case DONE: {
                                checkStateInvariants();
                                if (pendingFailure != null) {
                                    try {
                                        throw pendingFailure.get();
                                    } catch (Error | Exception e) {
                                        throw e;
                                    } catch (Throwable t) {
                                        throw new UndeclaredThrowableException(t);
                                    }
                                }
                                return cachedResult;
                            }
                            default: {
                                throw new IllegalStateException();
                            }
                        }
                    }
                } finally {
                    blockingCaller = false;
                }
            }
            return getResult();
        } finally {
            if (intr) Thread.currentThread().interrupt();
            if (timedOut) {
                final EJBReceiver receiver = getReceiver();
                if (receiver != null) receiver.cancelInvocation(receiverInvocationContext, true);
            }
        }
    }

    void setDiscardResult() {
        final Object lock = this.lock;
        assert !holdsLock(lock);
        final EJBReceiverInvocationContext.ResultProducer resultProducer;
        synchronized (lock) {
            resultProducer = this.resultProducer;
            this.resultProducer = EJBReceiverInvocationContext.ResultProducer.NULL;
            // result is waiting, discard it
            if (state == State.WAITING) {
                transition(State.DONE);
            }
            // fall out of the lock to discard the old result (if any)
            checkStateInvariants();
        }
        if (resultProducer != null) resultProducer.discardResult();
    }

    void cancelled() {
        resultReady(CANCELLED);
    }

    void failed(Exception exception, Executor retryExecutor) {
        final Object lock = this.lock;
        synchronized (lock) {
            log.tracef("Invocation marked failed, state is currently: %s", state);
            switch (state) {
                case CONSUMING:
                case DONE: {
                    // ignore
                    return;
                }
                case SENDING: {
                    throw new IllegalStateException();
                }
                case SENT: {
                    final Supplier pendingFailure = this.pendingFailure;
                    if (pendingFailure != null) {
                        addSuppressed(pendingFailure);
                    }
                    this.pendingFailure = () -> exception;
                    return;
                }
                case READY: {
                    addSuppressed(exception);
                    return;
                }
                case WAITING: {
                    // moving to CONSUMING, which requires a resultProducer.
                    this.resultProducer = new ThrowableResult(() -> exception);
                    this.pendingFailure = null;
                    // process result immediately, possibly retrying at the end
                    // retry SENDING via CONSUMING
                    transition(State.CONSUMING);
                    // redo the request
                    checkStateInvariants();
                    break;
                }
                default: {
                    throw Assert.impossibleSwitchCase(state);
                }
            }
        }
        retryExecutor.execute(this::retryOperation);
        return;
    }

    void retryOperation() {
        try {
            getResult(true);
        } catch (Throwable t) {
            final boolean retry;
            synchronized (lock) {
                retry = state == EJBClientInvocationContext.State.SENDING;
            }
            if (retry) sendRequestInitial();
        }
    }

    final class FutureResponse implements Future {

        FutureResponse() {
        }

        public boolean cancel(final boolean mayInterruptIfRunning) {
            final Object lock = EJBClientInvocationContext.this.lock;
            assert !holdsLock(lock);
            synchronized (lock) {
                if (state == EJBClientInvocationContext.State.DONE) {
                    // cannot cancel now; also resultProducer is gone
                    return pendingFailure == CANCELLED_PRODUCER;
                } else if (! state.isWaiting()) {
                    // cannot cancel now
                    return resultProducer == CANCELLED;
                } else {
                    if (resultProducer == CANCELLED) {
                        return true;
                    }
                    // at this point the task is running and we are allowed to interrupt it. So set
                    // the cancel request flag and a fall out to send the request
                    cancelRequested = true;
                }
            }
            final EJBReceiver receiver = getReceiver();
            final boolean result = receiver != null && receiver.cancelInvocation(receiverInvocationContext, mayInterruptIfRunning);
            if (! result) {
                synchronized (lock) {
                    if (resultProducer == CANCELLED || state == EJBClientInvocationContext.State.DONE && pendingFailure == CANCELLED_PRODUCER) {
                        return true;
                    }
                }
            }
            return result;
        }

        public boolean isCancelled() {
            final Object lock = EJBClientInvocationContext.this.lock;
            assert !holdsLock(lock);
            synchronized (lock) {
                return state == EJBClientInvocationContext.State.DONE ? pendingFailure == CANCELLED_PRODUCER : resultProducer == CANCELLED;
            }
        }

        public boolean isDone() {
            final Object lock = EJBClientInvocationContext.this.lock;
            assert !holdsLock(lock);
            synchronized (lock) {
                if (state == EJBClientInvocationContext.State.CONSUMING) {
                    return retryRequested && remainingRetries > 0 && resultProducer instanceof ThrowableResult;
                } else {
                    // TODO: we should also calculate whether the invocation timed out
                    return ! state.isWaiting();
                }
            }
        }

        public Object get() throws InterruptedException, ExecutionException {
            try {
                return awaitResponse();
            } catch (ExecutionException | InterruptedException e) {
                throw e;
            } catch (Exception e) {
                throw log.remoteInvFailed(e);
            }
        }

        public Object get(final long timeout, final TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
            final Object lock = EJBClientInvocationContext.this.lock;
            assert !holdsLock(lock);
            final long handlerInvTimeout = invocationHandler.getInvocationTimeout();
            final long invocationTimeout = handlerInvTimeout != -1 ? handlerInvTimeout : getClientContext().getInvocationTimeout();
            final long ourStart = System.nanoTime();
            // invocationTimeout in ms, elapsed time in nanosecs
            if (unit.convert(max(0L, invocationTimeout * 1_000_000L - max(0L, ourStart - startTime)), TimeUnit.NANOSECONDS) <= timeout) {
                // the invocation will expire before we finish
                return get();
            }
            long remaining = unit.toNanos(timeout);
            synchronized (lock) {
                out: for (;;) {
                    switch (state) {
                        case SENDING:
                        case SENT:
                        case CONSUMING:
                        case WAITING: {
                            checkStateInvariants();
                            if (remaining <= 0L) {
                                throw log.timedOut();
                            }
                            try {
                                waiters++;
                                lock.wait(remaining / 1_000_000L, (int) (remaining % 1_000_000L));
                            } finally {
                                waiters--;
                            }
                            remaining = unit.toNanos(timeout) - (System.nanoTime() - ourStart);
                            break;
                        }
                        case READY: {
                            // we have to get the result, so break out of here.
                            checkStateInvariants();
                            break out;
                        }
                        case DONE: {
                            checkStateInvariants();
                            if (pendingFailure != null) {
                                throw log.remoteInvFailed(pendingFailure.get());
                            }
                            return cachedResult;
                        }
                        default:
                            throw new IllegalStateException();
                    }
                }
            }
            // we've gotten the result
            try {
                return getResult();
            } catch (ExecutionException | InterruptedException e) {
                throw e;
            } catch (Exception e) {
                throw log.remoteInvFailed(e);
            }
        }
    }

    static final class ThrowableResult implements EJBReceiverInvocationContext.ResultProducer {
        private final Supplier pendingFailure;

        ThrowableResult(final Supplier pendingFailure) {
            this.pendingFailure = pendingFailure;
        }

        public Object getResult() throws Exception {
            try {
                throw pendingFailure.get();
            } catch (Exception | Error e) {
                throw e;
            } catch (Throwable t) {
                throw new UndeclaredThrowableException(t);
            }
        }

        public void discardResult() {
            // ignored
        }
    }

    static final Supplier CANCELLED_PRODUCER = Logs.INVOCATION::requestCancelled;

    static final ThrowableResult CANCELLED = new ThrowableResult(CANCELLED_PRODUCER);

    static final ThrowableResult ONE_WAY = new ThrowableResult(Logs.INVOCATION::oneWayInvocation);
}