All downloads are free. The search and download functionality uses the official Maven repository.

com.github.netty.protocol.servlet.ServletAsyncContext Maven / Gradle / Ivy

The newest version!
package com.github.netty.protocol.servlet;

import com.github.netty.core.util.ExpiryLRUMap;
import com.github.netty.core.util.LoggerFactoryX;
import com.github.netty.core.util.LoggerX;
import com.github.netty.core.util.Recyclable;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Context for asynchronous processing.
 * <p>
 * Tracks the lifecycle of one asynchronous request (init, start, dispatch, complete),
 * keeps the registered {@link AsyncListener}s and notifies them on start, timeout,
 * error and completion. Timeouts are driven by a shared {@link ExpiryLRUMap}: when an
 * entry expires, the owning context's timeout task is submitted to its executor.
 *
 * @author wangzihao
 * 2018/7/15/015
 */
public class ServletAsyncContext implements AsyncContext, Recyclable {
    private static final LoggerX logger = LoggerFactoryX.getLogger(ServletAsyncContext.class);
    private static final int STATUS_INIT = 0;
    private static final int STATUS_START = 1;
    private static final int STATUS_DISPATCH = 2;
    private static final int STATUS_COMPLETE = 3;
    private static final AtomicInteger TASK_ID_INCR = new AtomicInteger();
    /**
     * Pending timeout registrations, keyed by timeout task id.
     */
    private static final ExpiryLRUMap<Integer, ServletAsyncContext> TIMEOUT_TASK_MAP = new ExpiryLRUMap<>(256, Long.MAX_VALUE, Long.MAX_VALUE, null);

    static {
        TIMEOUT_TASK_MAP.setOnExpiryConsumer(e -> {
            ServletAsyncContext asyncContext = e.getData();
            // A context that was already dispatched or completed must not be timed out.
            if (asyncContext.status.get() >= STATUS_DISPATCH) {
                return;
            }
            asyncContext.executor.execute(asyncContext.timeoutTask);
        });
    }

    private final Executor executor;
    private final AtomicBoolean timeoutFlag = new AtomicBoolean();
    /**
     * Has it been recycled
     */
    private final AtomicBoolean recycleFlag = new AtomicBoolean(false);
    /**
     * Whether the IO thread has finished executing
     */
    private final AtomicBoolean ioThreadExecuteOverFlag = new AtomicBoolean(false);
    /**
     * Set once completion has begun. The status only becomes COMPLETE at the end of
     * {@link #complete(Throwable)}, so this flag is what stops a listener that calls
     * complete() from inside onComplete from re-entering the notification loop.
     */
    private final AtomicBoolean completingFlag = new AtomicBoolean(false);
    /**
     * 0=init, 1=start, 2=dispatch, 3=complete
     */
    private final AtomicInteger status = new AtomicInteger(STATUS_INIT);
    /**
     * Timeout time -> ms
     */
    private long timeout;
    private List<ServletAsyncListenerWrapper> asyncListenerWrapperList;
    private final ServletContext servletContext;
    private final ServletHttpExchange servletHttpExchange;
    /**
     * Task that notifies the listeners of a timeout; runs the notification at most once.
     */
    private final Runnable timeoutTask = this::notifyTimeout;
    private ServletRequest servletRequest;
    private ServletResponse servletResponse;
    private /*volatile*/ Integer timeoutTaskId;
    private /*volatile*/ long startTimestamp;

    /**
     * @param servletHttpExchange the exchange this context belongs to, must not be null
     * @param servletContext      the servlet context used to look up dispatchers, must not be null
     * @param executor            executor used for {@link #start(Runnable)} and timeout tasks, must not be null
     * @throws NullPointerException if any argument is null
     */
    public ServletAsyncContext(ServletHttpExchange servletHttpExchange, ServletContext servletContext, Executor executor) {
        this.servletHttpExchange = Objects.requireNonNull(servletHttpExchange);
        this.servletContext = Objects.requireNonNull(servletContext);
        this.executor = Objects.requireNonNull(executor);
    }

    public ServletContext getServletContext() {
        return servletContext;
    }

    /**
     * @return true once the timeout task has started notifying (or found no listeners to notify)
     */
    public boolean isTimeout() {
        return timeoutFlag.get();
    }

    @Override
    public ServletRequest getRequest() {
        return servletRequest;
    }

    @Override
    public ServletResponse getResponse() {
        return servletResponse;
    }

    @Override
    public boolean hasOriginalRequestAndResponse() {
        return servletHttpExchange.getRequest() == servletRequest && servletHttpExchange.getResponse() == servletResponse;
    }

    public void setServletResponse(ServletResponse servletResponse) {
        this.servletResponse = servletResponse;
    }

    public void setServletRequest(ServletRequest servletRequest) {
        this.servletRequest = servletRequest;
    }

    /**
     * Dispatches to the request URI of the current request, relative to its context path.
     */
    @Override
    public void dispatch() {
        String path;
        String contextPath;
        ServletRequest servletRequest = this.servletRequest;
        if (servletRequest instanceof HttpServletRequest) {
            path = ((HttpServletRequest) servletRequest).getRequestURI();
            contextPath = ((HttpServletRequest) servletRequest).getContextPath();
        } else {
            path = servletHttpExchange.getRequest().getRequestURI();
            contextPath = servletHttpExchange.getRequest().getContextPath();
        }
        // Null-safe, matching the null check dispatch(context, path) applies to the context path.
        if (contextPath != null && contextPath.length() > 1) {
            path = path.substring(contextPath.length());
        }
        dispatch(servletContext, path);
    }

    @Override
    public void dispatch(String path) {
        dispatch(servletContext, path);
    }

    /**
     * Dispatches the request to {@code path}, prefixed with the context path of {@code context}.
     * If the current thread is still inside the request handling, the dispatch is queued
     * instead of being executed immediately.
     *
     * @throws IllegalStateException if this context was already completed or dispatched
     */
    @Override
    public void dispatch(javax.servlet.ServletContext context, String path) {
        int statusInt = status.get();
        if (statusInt == STATUS_COMPLETE) {
            throw new IllegalStateException("The request associated with the AsyncContext has already completed processing.");
        }
        if (statusInt == STATUS_DISPATCH) {
            throw new IllegalStateException("Asynchronous dispatch operation has already been called. Additional asynchronous dispatch operation within the same asynchronous cycle is not allowed.");
        }
        status.set(STATUS_DISPATCH);
        String contextPath = context.getContextPath();
        String dispatcherPath = contextPath == null || contextPath.isEmpty() ? path : contextPath + path;
        HttpServletRequest httpServletRequest;
        HttpServletResponse httpServletResponse;
        // Fall back to the exchange's own request/response when a non-HTTP wrapper was supplied.
        if (servletRequest instanceof HttpServletRequest) {
            httpServletRequest = (HttpServletRequest) servletRequest;
        } else {
            httpServletRequest = servletHttpExchange.getRequest();
        }
        if (servletResponse instanceof HttpServletResponse) {
            httpServletResponse = (HttpServletResponse) servletResponse;
        } else {
            httpServletResponse = servletHttpExchange.getResponse();
        }

        ServletRequestDispatcher dispatcher = servletContext.getRequestDispatcher(dispatcherPath, DispatcherType.ASYNC);
        if (NettyMessageToServletRunnable.isCurrentRunAtRequesting()) {
            NettyMessageToServletRunnable.addAsyncContextDispatch(() -> dispatchAsync(dispatcher, httpServletRequest, httpServletResponse, path));
        } else {
            dispatchAsync(dispatcher, httpServletRequest, httpServletResponse, path);
        }
    }

    /**
     * Runs the async dispatch, reports a failure to the error listeners and always
     * completes this context afterwards. A missing dispatcher results in a 404 status.
     */
    private void dispatchAsync(ServletRequestDispatcher dispatcher,
                               HttpServletRequest httpServletRequest,
                               HttpServletResponse httpServletResponse,
                               String path) {
        Throwable throwable = null;
        try {
            try {
                if (dispatcher == null) {
                    logger.warn("not found dispatcher. path={}", path);
                    httpServletResponse.setStatus(HttpServletResponse.SC_NOT_FOUND);
                } else {
                    dispatcher.dispatchAsync(httpServletRequest, httpServletResponse, this);
                }
            } catch (Throwable e) {
                throwable = e;
                onError(e);
            }
        } finally {
            complete(throwable);
        }
    }

    /**
     * Notifies every registered listener of an error.
     *
     * @param throwable the failure handed to the listeners through the {@link AsyncEvent}
     */
    public void onError(Throwable throwable) {
        if (asyncListenerWrapperList != null) {
            notifyListeners("onError", throwable, new ArrayList<>(asyncListenerWrapperList), AsyncListener::onError);
        }
    }

    @Override
    public void complete() {
        complete(null);
    }

    /**
     * Completes this context: cancels the pending timeout, notifies the listeners and,
     * if the IO thread has already finished, recycles the exchange. Calls made after
     * completion, or while a completion is already in progress, are ignored.
     *
     * @param rootThrowable failure that ended the processing, or null on normal completion
     */
    public void complete(Throwable rootThrowable) {
        // The second condition rejects re-entrant/overlapping calls made before status is COMPLETE.
        if (isComplete() || !completingFlag.compareAndSet(false, true)) {
            return;
        }
        Integer timeoutTaskId = this.timeoutTaskId;
        if (timeoutTaskId != null) {
            TIMEOUT_TASK_MAP.remove(timeoutTaskId);
        }
        try {
            //Notify the complete
            if (asyncListenerWrapperList != null) {
                notifyListeners("onComplete", rootThrowable, new ArrayList<>(asyncListenerWrapperList), AsyncListener::onComplete);
            }
            //If the handler has finished, recycle it yourself
            if (ioThreadExecuteOverFlag.get()) {
                recycle();
            }
        } finally {
            status.set(STATUS_COMPLETE);
        }
    }

    /**
     * Records that the IO thread has finished executing, so that {@link #complete(Throwable)}
     * recycles the exchange itself.
     */
    public void markIoThreadOverFlag() {
        this.ioThreadExecuteOverFlag.compareAndSet(false, true);
    }

    /**
     * Recycles the underlying exchange; only the first call has an effect.
     */
    @Override
    public void recycle() {
        //If not, recycle
        if (recycleFlag.compareAndSet(false, true)) {
            servletHttpExchange.recycle();
        }
    }

    @Override
    public void start(Runnable runnable) {
        executor.execute(runnable);
    }

    /**
     * Moves this context into the started state, (re)registers the timeout and notifies
     * the current listeners via onStartAsync. The listener list is cleared before the
     * notification, so listeners have to be registered again for the new cycle.
     *
     * @throws IllegalStateException if the context was already dispatched, completed or timed out
     */
    void setStart() {
        if (status.get() >= STATUS_DISPATCH || timeoutFlag.get()) {
            throw new IllegalStateException("The request associated with the AsyncContext has already completed processing.");
        }
        startTimestamp = System.currentTimeMillis();
        if (timeoutTaskId == null) {
            timeoutTaskId = TASK_ID_INCR.getAndIncrement();
        } else {
            TIMEOUT_TASK_MAP.remove(timeoutTaskId);
        }
        status.set(STATUS_START);
        // A timeout of zero or less means "no timeout" (see setTimeout), so nothing is registered.
        if (timeout > 0) {
            TIMEOUT_TASK_MAP.put(timeoutTaskId, this, timeout);
        }

        //Notify the start
        if (asyncListenerWrapperList != null) {
            List<ServletAsyncListenerWrapper> listeners = new ArrayList<>(asyncListenerWrapperList);
            asyncListenerWrapperList.clear();
            notifyListeners("onStartAsync", null, listeners, AsyncListener::onStartAsync);
        }
    }

    @Override
    public void addListener(AsyncListener listener) {
        addListener(listener, servletRequest, servletResponse);
    }

    @Override
    public void addListener(AsyncListener listener, ServletRequest servletRequest, ServletResponse servletResponse) {
        if (asyncListenerWrapperList == null) {
            asyncListenerWrapperList = new ArrayList<>(3);
        }
        asyncListenerWrapperList.add(new ServletAsyncListenerWrapper(listener, servletRequest, servletResponse));
    }

    /**
     * Instantiates a listener through its no-argument constructor.
     *
     * @throws ServletException if the class has no usable no-argument constructor or the constructor fails
     */
    @Override
    public <T extends AsyncListener> T createListener(Class<T> clazz) throws ServletException {
        try {
            // Class.newInstance() is deprecated and leaks checked constructor exceptions unwrapped.
            return clazz.getDeclaredConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            throw new ServletException("asyncContext createListener error=" + e, e);
        }
    }

    @Override
    public long getTimeout() {
        return timeout;
    }

    /**
     * Sets the timeout in milliseconds; zero or less removes a pending timeout.
     * When the context was already started, the remaining time is computed from the
     * start timestamp and the timeout task runs immediately if it has already elapsed.
     * Ignored once the context is complete.
     */
    @Override
    public void setTimeout(long timeout) {
        if (isComplete()) {
            return;
        }
        this.timeout = timeout;
        if (timeout <= 0 && timeoutTaskId != null) {
            TIMEOUT_TASK_MAP.remove(timeoutTaskId);
            return;
        }

        if (startTimestamp != 0) {
            if (timeoutTaskId == null) {
                timeoutTaskId = TASK_ID_INCR.getAndIncrement();
            } else {
                TIMEOUT_TASK_MAP.remove(timeoutTaskId);
            }
            long startDiff = System.currentTimeMillis() - startTimestamp;
            long expiryTime = timeout - startDiff;
            if (expiryTime > 0) {
                TIMEOUT_TASK_MAP.put(timeoutTaskId, this, expiryTime);
            } else {
                timeoutTask.run();
            }
        }
    }

    public boolean isStarted() {
        return status.get() >= STATUS_START;
    }

    public boolean isComplete() {
        return status.get() == STATUS_COMPLETE;
    }

    public ServletHttpExchange getExchange() {
        return servletHttpExchange;
    }

    public boolean isChannelActive() {
        return servletHttpExchange.isChannelActive();
    }

    /**
     * Body of {@link #timeoutTask}: flips the timeout flag and, on the first call only,
     * notifies the listeners that were registered at that moment.
     */
    private void notifyTimeout() {
        //Notice the timeout
        List<ServletAsyncListenerWrapper> listeners = this.asyncListenerWrapperList;
        if (timeoutFlag.compareAndSet(false, true) && listeners != null) {
            notifyListeners("onTimeout", null, new ArrayList<>(listeners), AsyncListener::onTimeout);
        }
    }

    /**
     * Invokes one listener callback on every wrapper in {@code listeners}.
     * <p>
     * A failure thrown by one listener is passed to the following listeners as the
     * throwable of their {@link AsyncEvent}; earlier failures are attached as suppressed.
     * A warning is logged only when a failure exists at the end and none had been
     * recorded yet when the last listener was invoked (or there were no listeners).
     *
     * @param eventName callback name used in the log message, e.g. {@code onTimeout}
     * @param initial   throwable handed to the first listener, may be null
     * @param listeners snapshot of the wrappers to notify
     * @param call      the listener callback to invoke
     */
    private void notifyListeners(String eventName,
                                 Throwable initial,
                                 List<ServletAsyncListenerWrapper> listeners,
                                 ListenerCall call) {
        Throwable throwable = initial;
        boolean eventNotify = false;
        for (ServletAsyncListenerWrapper listenerWrapper : listeners) {
            eventNotify = throwable != null;
            AsyncEvent event = new AsyncEvent(this, listenerWrapper.servletRequest, listenerWrapper.servletResponse, throwable);
            try {
                call.invoke(listenerWrapper.asyncListener, event);
            } catch (Throwable e) {
                // addSuppressed rejects self-suppression, which happens when a listener
                // rethrows the throwable it received through the event.
                if (throwable != null && throwable != e) {
                    e.addSuppressed(throwable);
                }
                throwable = e;
            }
        }
        if (throwable != null && !eventNotify) {
            logger.warn("asyncContext notifyEvent." + eventName + "() error={}", throwable.toString(), throwable);
        }
    }

    /**
     * One {@link AsyncListener} callback, e.g. {@code AsyncListener::onComplete}.
     */
    @FunctionalInterface
    private interface ListenerCall {
        void invoke(AsyncListener listener, AsyncEvent event) throws Exception;
    }

    /**
     * A registered listener together with the request/response it was registered with.
     */
    private static class ServletAsyncListenerWrapper {
        AsyncListener asyncListener;
        ServletRequest servletRequest;
        ServletResponse servletResponse;

        ServletAsyncListenerWrapper(AsyncListener asyncListener, ServletRequest servletRequest, ServletResponse servletResponse) {
            this.asyncListener = asyncListener;
            this.servletRequest = servletRequest;
            this.servletResponse = servletResponse;
        }
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy