All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.tomcat.websocket.WsSession Maven / Gradle / Ivy

There is a newer version: 11.0.2
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.tomcat.websocket;

import java.io.IOException;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.Principal;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;

import javax.naming.NamingException;
import javax.websocket.ClientEndpointConfig;
import javax.websocket.CloseReason;
import javax.websocket.CloseReason.CloseCode;
import javax.websocket.CloseReason.CloseCodes;
import javax.websocket.DeploymentException;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.Extension;
import javax.websocket.MessageHandler;
import javax.websocket.MessageHandler.Partial;
import javax.websocket.MessageHandler.Whole;
import javax.websocket.PongMessage;
import javax.websocket.RemoteEndpoint;
import javax.websocket.SendResult;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
import javax.websocket.server.ServerEndpointConfig;
import javax.websocket.server.ServerEndpointConfig.Configurator;

import org.apache.juli.logging.Log;
import org.apache.juli.logging.LogFactory;
import org.apache.tomcat.InstanceManager;
import org.apache.tomcat.util.ExceptionUtils;
import org.apache.tomcat.util.res.StringManager;
import org.apache.tomcat.websocket.pojo.PojoEndpointServer;
import org.apache.tomcat.websocket.server.DefaultServerEndpointConfigurator;

public class WsSession implements Session {

    private final Log log = LogFactory.getLog(WsSession.class); // must not be static
    private static final StringManager sm = StringManager.getManager(WsSession.class);

    // An ellipsis is a single character that looks like three periods in a row
    // and is used to indicate a continuation.
    private static final byte[] ELLIPSIS_BYTES = "\u2026".getBytes(StandardCharsets.UTF_8);
    // An ellipsis is three bytes in UTF-8
    private static final int ELLIPSIS_BYTES_LEN = ELLIPSIS_BYTES.length;

    private static final boolean SEC_CONFIGURATOR_USES_IMPL_DEFAULT;

    private static AtomicLong ids = new AtomicLong(0);

    static {
        // Use fake end point and path. They are never used, they just need to
        // be sufficient to pass the validation tests.
        ServerEndpointConfig.Builder builder = ServerEndpointConfig.Builder.create(Object.class, "/");
        ServerEndpointConfig sec = builder.build();
        SEC_CONFIGURATOR_USES_IMPL_DEFAULT = sec.getConfigurator().getClass()
                .equals(DefaultServerEndpointConfigurator.class);
    }

    private final Endpoint localEndpoint;
    private final WsRemoteEndpointImplBase wsRemoteEndpoint;
    private final RemoteEndpoint.Async remoteEndpointAsync;
    private final RemoteEndpoint.Basic remoteEndpointBasic;
    private final ClassLoader applicationClassLoader;
    private final WsWebSocketContainer webSocketContainer;
    private final URI requestUri;
    private final Map> requestParameterMap;
    private final String queryString;
    private final Principal userPrincipal;
    private final EndpointConfig endpointConfig;

    private final List negotiatedExtensions;
    private final String subProtocol;
    private final Map pathParameters;
    private final boolean secure;
    private final String httpSessionId;
    private final String id;

    // Expected to handle message types of  only
    private volatile MessageHandler textMessageHandler = null;
    // Expected to handle message types of  only
    private volatile MessageHandler binaryMessageHandler = null;
    private volatile MessageHandler.Whole pongMessageHandler = null;
    private AtomicReference state = new AtomicReference<>(State.OPEN);
    private final Map userProperties = new ConcurrentHashMap<>();
    private volatile int maxBinaryMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE;
    private volatile int maxTextMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE;
    private volatile long maxIdleTimeout = 0;
    private volatile long lastActiveRead = System.currentTimeMillis();
    private volatile long lastActiveWrite = System.currentTimeMillis();
    private Map futures = new ConcurrentHashMap<>();
    private volatile Long sessionCloseTimeoutExpiry;


    /**
     * Creates a new WebSocket session for communication between the provided client and remote end points. The result
     * of {@link Thread#getContextClassLoader()} at the time this constructor is called will be used when calling
     * {@link Endpoint#onClose(Session, CloseReason)}.
     *
     * @param clientEndpointHolder The end point managed by this code
     * @param wsRemoteEndpoint     The other / remote end point
     * @param wsWebSocketContainer The container that created this session
     * @param negotiatedExtensions The agreed extensions to use for this session
     * @param subProtocol          The agreed sub-protocol to use for this session
     * @param pathParameters       The path parameters associated with the request that initiated this session or
     *                                 null if this is a client session
     * @param secure               Was this session initiated over a secure connection?
     * @param clientEndpointConfig The configuration information for the client end point
     *
     * @throws DeploymentException if an invalid encode is specified
     */
    public WsSession(ClientEndpointHolder clientEndpointHolder, WsRemoteEndpointImplBase wsRemoteEndpoint,
            WsWebSocketContainer wsWebSocketContainer, List negotiatedExtensions, String subProtocol,
            Map pathParameters, boolean secure, ClientEndpointConfig clientEndpointConfig)
            throws DeploymentException {
        this.wsRemoteEndpoint = wsRemoteEndpoint;
        this.wsRemoteEndpoint.setSession(this);
        this.remoteEndpointAsync = new WsRemoteEndpointAsync(wsRemoteEndpoint);
        this.remoteEndpointBasic = new WsRemoteEndpointBasic(wsRemoteEndpoint);
        this.webSocketContainer = wsWebSocketContainer;
        applicationClassLoader = Thread.currentThread().getContextClassLoader();
        wsRemoteEndpoint.setSendTimeout(wsWebSocketContainer.getDefaultAsyncSendTimeout());
        this.maxBinaryMessageBufferSize = webSocketContainer.getDefaultMaxBinaryMessageBufferSize();
        this.maxTextMessageBufferSize = webSocketContainer.getDefaultMaxTextMessageBufferSize();
        this.maxIdleTimeout = webSocketContainer.getDefaultMaxSessionIdleTimeout();
        this.requestUri = null;
        this.requestParameterMap = Collections.emptyMap();
        this.queryString = null;
        this.userPrincipal = null;
        this.httpSessionId = null;
        this.negotiatedExtensions = negotiatedExtensions;
        if (subProtocol == null) {
            this.subProtocol = "";
        } else {
            this.subProtocol = subProtocol;
        }
        this.pathParameters = pathParameters;
        this.secure = secure;
        this.wsRemoteEndpoint.setEncoders(clientEndpointConfig);
        this.endpointConfig = clientEndpointConfig;

        this.userProperties.putAll(endpointConfig.getUserProperties());
        this.id = Long.toHexString(ids.getAndIncrement());

        this.localEndpoint = clientEndpointHolder.getInstance(getInstanceManager());

        if (log.isTraceEnabled()) {
            log.trace(sm.getString("wsSession.created", id));
        }
    }


    /**
     * Creates a new WebSocket session for communication between the provided server and remote end points. The result
     * of {@link Thread#getContextClassLoader()} at the time this constructor is called will be used when calling
     * {@link Endpoint#onClose(Session, CloseReason)}.
     *
     * @param wsRemoteEndpoint     The other / remote end point
     * @param wsWebSocketContainer The container that created this session
     * @param requestUri           The URI used to connect to this end point or null if this is a client
     *                                 session
     * @param requestParameterMap  The parameters associated with the request that initiated this session or
     *                                 null if this is a client session
     * @param queryString          The query string associated with the request that initiated this session or
     *                                 null if this is a client session
     * @param userPrincipal        The principal associated with the request that initiated this session or
     *                                 null if this is a client session
     * @param httpSessionId        The HTTP session ID associated with the request that initiated this session or
     *                                 null if this is a client session
     * @param negotiatedExtensions The agreed extensions to use for this session
     * @param subProtocol          The agreed sub-protocol to use for this session
     * @param pathParameters       The path parameters associated with the request that initiated this session or
     *                                 null if this is a client session
     * @param secure               Was this session initiated over a secure connection?
     * @param serverEndpointConfig The configuration information for the server end point
     *
     * @throws DeploymentException if an invalid encode is specified
     */
    public WsSession(WsRemoteEndpointImplBase wsRemoteEndpoint, WsWebSocketContainer wsWebSocketContainer,
            URI requestUri, Map> requestParameterMap, String queryString, Principal userPrincipal,
            String httpSessionId, List negotiatedExtensions, String subProtocol,
            Map pathParameters, boolean secure, ServerEndpointConfig serverEndpointConfig)
            throws DeploymentException {

        this.wsRemoteEndpoint = wsRemoteEndpoint;
        this.wsRemoteEndpoint.setSession(this);
        this.remoteEndpointAsync = new WsRemoteEndpointAsync(wsRemoteEndpoint);
        this.remoteEndpointBasic = new WsRemoteEndpointBasic(wsRemoteEndpoint);
        this.webSocketContainer = wsWebSocketContainer;
        applicationClassLoader = Thread.currentThread().getContextClassLoader();
        wsRemoteEndpoint.setSendTimeout(wsWebSocketContainer.getDefaultAsyncSendTimeout());
        this.maxBinaryMessageBufferSize = webSocketContainer.getDefaultMaxBinaryMessageBufferSize();
        this.maxTextMessageBufferSize = webSocketContainer.getDefaultMaxTextMessageBufferSize();
        this.maxIdleTimeout = webSocketContainer.getDefaultMaxSessionIdleTimeout();
        this.requestUri = requestUri;
        if (requestParameterMap == null) {
            this.requestParameterMap = Collections.emptyMap();
        } else {
            this.requestParameterMap = requestParameterMap;
        }
        this.queryString = queryString;
        this.userPrincipal = userPrincipal;
        this.httpSessionId = httpSessionId;
        this.negotiatedExtensions = negotiatedExtensions;
        if (subProtocol == null) {
            this.subProtocol = "";
        } else {
            this.subProtocol = subProtocol;
        }
        this.pathParameters = pathParameters;
        this.secure = secure;
        this.wsRemoteEndpoint.setEncoders(serverEndpointConfig);
        this.endpointConfig = serverEndpointConfig;

        this.userProperties.putAll(endpointConfig.getUserProperties());
        this.id = Long.toHexString(ids.getAndIncrement());

        InstanceManager instanceManager = getInstanceManager();
        Configurator configurator = serverEndpointConfig.getConfigurator();
        Class clazz = serverEndpointConfig.getEndpointClass();

        Object endpointInstance;
        try {
            if (instanceManager == null || !isDefaultConfigurator(configurator)) {
                endpointInstance = configurator.getEndpointInstance(clazz);
                if (instanceManager != null) {
                    try {
                        instanceManager.newInstance(endpointInstance);
                    } catch (ReflectiveOperationException | NamingException e) {
                        throw new DeploymentException(sm.getString("wsSession.instanceNew"), e);
                    }
                }
            } else {
                endpointInstance = instanceManager.newInstance(clazz);
            }
        } catch (ReflectiveOperationException | NamingException e) {
            throw new DeploymentException(sm.getString("wsSession.instanceCreateFailed"), e);
        }

        if (endpointInstance instanceof Endpoint) {
            this.localEndpoint = (Endpoint) endpointInstance;
        } else {
            this.localEndpoint = new PojoEndpointServer(pathParameters, endpointInstance);
        }

        if (log.isTraceEnabled()) {
            log.trace(sm.getString("wsSession.created", id));
        }
    }


    private boolean isDefaultConfigurator(Configurator configurator) {
        if (configurator.getClass().equals(DefaultServerEndpointConfigurator.class)) {
            return true;
        }
        if (SEC_CONFIGURATOR_USES_IMPL_DEFAULT &&
                configurator.getClass().equals(ServerEndpointConfig.Configurator.class)) {
            return true;
        }
        return false;
    }


    /**
     * Creates a new WebSocket session for communication between the two provided end points. The result of
     * {@link Thread#getContextClassLoader()} at the time this constructor is called will be used when calling
     * {@link Endpoint#onClose(Session, CloseReason)}.
     *
     * @param localEndpoint        The end point managed by this code
     * @param wsRemoteEndpoint     The other / remote endpoint
     * @param wsWebSocketContainer The container that created this session
     * @param requestUri           The URI used to connect to this endpoint or null is this is a client
     *                                 session
     * @param requestParameterMap  The parameters associated with the request that initiated this session or
     *                                 null if this is a client session
     * @param queryString          The query string associated with the request that initiated this session or
     *                                 null if this is a client session
     * @param userPrincipal        The principal associated with the request that initiated this session or
     *                                 null if this is a client session
     * @param httpSessionId        The HTTP session ID associated with the request that initiated this session or
     *                                 null if this is a client session
     * @param negotiatedExtensions The agreed extensions to use for this session
     * @param subProtocol          The agreed subprotocol to use for this session
     * @param pathParameters       The path parameters associated with the request that initiated this session or
     *                                 null if this is a client session
     * @param secure               Was this session initiated over a secure connection?
     * @param endpointConfig       The configuration information for the endpoint
     *
     * @throws DeploymentException if an invalid encode is specified
     *
     * @deprecated Unused. This will be removed in Tomcat 10.1
     */
    @Deprecated
    public WsSession(Endpoint localEndpoint, WsRemoteEndpointImplBase wsRemoteEndpoint,
            WsWebSocketContainer wsWebSocketContainer, URI requestUri, Map> requestParameterMap,
            String queryString, Principal userPrincipal, String httpSessionId, List negotiatedExtensions,
            String subProtocol, Map pathParameters, boolean secure, EndpointConfig endpointConfig)
            throws DeploymentException {
        this.localEndpoint = localEndpoint;
        this.wsRemoteEndpoint = wsRemoteEndpoint;
        this.wsRemoteEndpoint.setSession(this);
        this.remoteEndpointAsync = new WsRemoteEndpointAsync(wsRemoteEndpoint);
        this.remoteEndpointBasic = new WsRemoteEndpointBasic(wsRemoteEndpoint);
        this.webSocketContainer = wsWebSocketContainer;
        applicationClassLoader = Thread.currentThread().getContextClassLoader();
        wsRemoteEndpoint.setSendTimeout(wsWebSocketContainer.getDefaultAsyncSendTimeout());
        this.maxBinaryMessageBufferSize = webSocketContainer.getDefaultMaxBinaryMessageBufferSize();
        this.maxTextMessageBufferSize = webSocketContainer.getDefaultMaxTextMessageBufferSize();
        this.maxIdleTimeout = webSocketContainer.getDefaultMaxSessionIdleTimeout();
        this.requestUri = requestUri;
        if (requestParameterMap == null) {
            this.requestParameterMap = Collections.emptyMap();
        } else {
            this.requestParameterMap = requestParameterMap;
        }
        this.queryString = queryString;
        this.userPrincipal = userPrincipal;
        this.httpSessionId = httpSessionId;
        this.negotiatedExtensions = negotiatedExtensions;
        if (subProtocol == null) {
            this.subProtocol = "";
        } else {
            this.subProtocol = subProtocol;
        }
        this.pathParameters = pathParameters;
        this.secure = secure;
        this.wsRemoteEndpoint.setEncoders(endpointConfig);
        this.endpointConfig = endpointConfig;

        this.userProperties.putAll(endpointConfig.getUserProperties());
        this.id = Long.toHexString(ids.getAndIncrement());

        InstanceManager instanceManager = getInstanceManager();
        if (instanceManager != null) {
            try {
                instanceManager.newInstance(localEndpoint);
            } catch (Exception e) {
                throw new DeploymentException(sm.getString("wsSession.instanceNew"), e);
            }
        }

        if (log.isDebugEnabled()) {
            log.debug(sm.getString("wsSession.created", id));
        }
    }


    public InstanceManager getInstanceManager() {
        return webSocketContainer.getInstanceManager(applicationClassLoader);
    }


    @Override
    public WebSocketContainer getContainer() {
        checkState();
        return webSocketContainer;
    }


    @Override
    public void addMessageHandler(MessageHandler listener) {
        Class target = Util.getMessageType(listener);
        doAddMessageHandler(target, listener);
    }


    @Override
    public  void addMessageHandler(Class clazz, Partial handler) throws IllegalStateException {
        doAddMessageHandler(clazz, handler);
    }


    @Override
    public  void addMessageHandler(Class clazz, Whole handler) throws IllegalStateException {
        doAddMessageHandler(clazz, handler);
    }


    @SuppressWarnings("unchecked")
    private void doAddMessageHandler(Class target, MessageHandler listener) {
        checkState();

        // Message handlers that require decoders may map to text messages,
        // binary messages, both or neither.

        // The frame processing code expects binary message handlers to
        // accept ByteBuffer

        // Use the POJO message handler wrappers as they are designed to wrap
        // arbitrary objects with MessageHandlers and can wrap MessageHandlers
        // just as easily.

        Set mhResults = Util.getMessageHandlers(target, listener, endpointConfig, this);

        for (MessageHandlerResult mhResult : mhResults) {
            switch (mhResult.getType()) {
                case TEXT: {
                    if (textMessageHandler != null) {
                        throw new IllegalStateException(sm.getString("wsSession.duplicateHandlerText"));
                    }
                    textMessageHandler = mhResult.getHandler();
                    break;
                }
                case BINARY: {
                    if (binaryMessageHandler != null) {
                        throw new IllegalStateException(sm.getString("wsSession.duplicateHandlerBinary"));
                    }
                    binaryMessageHandler = mhResult.getHandler();
                    break;
                }
                case PONG: {
                    if (pongMessageHandler != null) {
                        throw new IllegalStateException(sm.getString("wsSession.duplicateHandlerPong"));
                    }
                    MessageHandler handler = mhResult.getHandler();
                    if (handler instanceof MessageHandler.Whole) {
                        pongMessageHandler = (MessageHandler.Whole) handler;
                    } else {
                        throw new IllegalStateException(sm.getString("wsSession.invalidHandlerTypePong"));
                    }

                    break;
                }
                default: {
                    throw new IllegalArgumentException(
                            sm.getString("wsSession.unknownHandlerType", listener, mhResult.getType()));
                }
            }
        }
    }


    @Override
    public Set getMessageHandlers() {
        checkState();
        Set result = new HashSet<>();
        if (binaryMessageHandler != null) {
            result.add(binaryMessageHandler);
        }
        if (textMessageHandler != null) {
            result.add(textMessageHandler);
        }
        if (pongMessageHandler != null) {
            result.add(pongMessageHandler);
        }
        return result;
    }


    @Override
    public void removeMessageHandler(MessageHandler listener) {
        checkState();
        if (listener == null) {
            return;
        }

        MessageHandler wrapped = null;

        if (listener instanceof WrappedMessageHandler) {
            wrapped = ((WrappedMessageHandler) listener).getWrappedHandler();
        }

        if (wrapped == null) {
            wrapped = listener;
        }

        boolean removed = false;
        if (wrapped.equals(textMessageHandler) || listener.equals(textMessageHandler)) {
            textMessageHandler = null;
            removed = true;
        }

        if (wrapped.equals(binaryMessageHandler) || listener.equals(binaryMessageHandler)) {
            binaryMessageHandler = null;
            removed = true;
        }

        if (wrapped.equals(pongMessageHandler) || listener.equals(pongMessageHandler)) {
            pongMessageHandler = null;
            removed = true;
        }

        if (!removed) {
            // ISE for now. Could swallow this silently / log this if the ISE
            // becomes a problem
            throw new IllegalStateException(sm.getString("wsSession.removeHandlerFailed", listener));
        }
    }


    @Override
    public String getProtocolVersion() {
        checkState();
        return Constants.WS_VERSION_HEADER_VALUE;
    }


    @Override
    public String getNegotiatedSubprotocol() {
        checkState();
        return subProtocol;
    }


    @Override
    public List getNegotiatedExtensions() {
        checkState();
        return negotiatedExtensions;
    }


    @Override
    public boolean isSecure() {
        checkState();
        return secure;
    }


    @Override
    public boolean isOpen() {
        return state.get() == State.OPEN || state.get() == State.OUTPUT_CLOSING || state.get() == State.CLOSING;
    }


    public boolean isClosed() {
        return state.get() == State.CLOSED;
    }


    @Override
    public long getMaxIdleTimeout() {
        checkState();
        return maxIdleTimeout;
    }


    @Override
    public void setMaxIdleTimeout(long timeout) {
        checkState();
        this.maxIdleTimeout = timeout;
    }


    @Override
    public void setMaxBinaryMessageBufferSize(int max) {
        checkState();
        this.maxBinaryMessageBufferSize = max;
    }


    @Override
    public int getMaxBinaryMessageBufferSize() {
        checkState();
        return maxBinaryMessageBufferSize;
    }


    @Override
    public void setMaxTextMessageBufferSize(int max) {
        checkState();
        this.maxTextMessageBufferSize = max;
    }


    @Override
    public int getMaxTextMessageBufferSize() {
        checkState();
        return maxTextMessageBufferSize;
    }


    @Override
    public Set getOpenSessions() {
        checkState();
        return webSocketContainer.getOpenSessions(getSessionMapKey());
    }


    @Override
    public RemoteEndpoint.Async getAsyncRemote() {
        checkState();
        return remoteEndpointAsync;
    }


    @Override
    public RemoteEndpoint.Basic getBasicRemote() {
        checkState();
        return remoteEndpointBasic;
    }


    @Override
    public void close() throws IOException {
        close(new CloseReason(CloseCodes.NORMAL_CLOSURE, ""));
    }


    @Override
    public void close(CloseReason closeReason) throws IOException {
        doClose(closeReason, closeReason);
    }


    /**
     * WebSocket 1.0. Section 2.1.5. Need internal close method as spec requires that the local endpoint receives a 1006
     * on timeout.
     *
     * @param closeReasonMessage The close reason to pass to the remote endpoint
     * @param closeReasonLocal   The close reason to pass to the local endpoint
     */
    public void doClose(CloseReason closeReasonMessage, CloseReason closeReasonLocal) {
        doClose(closeReasonMessage, closeReasonLocal, false);
    }


    /**
     * WebSocket 1.0. Section 2.1.5. Need internal close method as spec requires that the local endpoint receives a 1006
     * on timeout.
     *
     * @param closeReasonMessage The close reason to pass to the remote endpoint
     * @param closeReasonLocal   The close reason to pass to the local endpoint
     * @param closeSocket        Should the socket be closed immediately rather than waiting for the server to respond
     */
    public void doClose(CloseReason closeReasonMessage, CloseReason closeReasonLocal, boolean closeSocket) {

        if (!state.compareAndSet(State.OPEN, State.OUTPUT_CLOSING)) {
            // Close process has already been started. Don't start it again.
            return;
        }

        if (log.isTraceEnabled()) {
            log.trace(sm.getString("wsSession.doClose", id));
        }

        // Flush any batched messages not yet sent.
        try {
            wsRemoteEndpoint.setBatchingAllowed(false);
        } catch (Throwable t) {
            ExceptionUtils.handleThrowable(t);
            log.warn(sm.getString("wsSession.flushFailOnClose"), t);
            fireEndpointOnError(t);
        }

        // Send the close message to the remote endpoint.
        sendCloseMessage(closeReasonMessage);
        fireEndpointOnClose(closeReasonLocal);
        if (!state.compareAndSet(State.OUTPUT_CLOSING, State.OUTPUT_CLOSED) || closeSocket) {
            /*
             * A close message was received in another thread or this is handling an error condition. Either way, no
             * further close message is expected to be received. Mark the session as fully closed...
             */
            state.set(State.CLOSED);
            // ... and close the network connection.
            closeConnection();
        } else {
            /*
             * Set close timeout. If the client fails to send a close message response within the timeout, the session
             * and the connection will be closed when the timeout expires.
             */
            sessionCloseTimeoutExpiry =
                    Long.valueOf(System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(getSessionCloseTimeout()));
        }

        // Fail any uncompleted messages.
        IOException ioe = new IOException(sm.getString("wsSession.messageFailed"));
        SendResult sr = new SendResult(ioe);
        for (FutureToSendHandler f2sh : futures.keySet()) {
            f2sh.onResult(sr);
        }
    }


    /**
     * Called when a close message is received. Should only ever happen once. Also called after a protocol error when
     * the ProtocolHandler needs to force the closing of the connection.
     *
     * @param closeReason The reason contained within the received close message.
     */
    public void onClose(CloseReason closeReason) {
        if (log.isTraceEnabled()) {
            log.trace(sm.getString("wsSession.onClose.entry", closeReason, getId(), state));
        }
        if (state.compareAndSet(State.OPEN, State.CLOSING)) {
            // Standard close.

            // Flush any batched messages not yet sent.
            try {
                wsRemoteEndpoint.setBatchingAllowed(false);
            } catch (Throwable t) {
                ExceptionUtils.handleThrowable(t);
                log.warn(sm.getString("wsSession.flushFailOnClose"), t);
                fireEndpointOnError(t);
            }

            // Send the close message response to the remote endpoint.
            sendCloseMessage(closeReason);
            fireEndpointOnClose(closeReason);

            // Mark the session as fully closed.
            state.set(State.CLOSED);

            // Close the network connection.
            closeConnection();
        } else if (state.compareAndSet(State.OUTPUT_CLOSING, State.CLOSING)) {
            /*
             * The local endpoint sent a close message the the same time as the remote endpoint. The local close is
             * still being processed. Update the state so the the local close process will also close the network
             * connection once it has finished sending a close message.
             */
        } else if (state.compareAndSet(State.OUTPUT_CLOSED, State.CLOSED)) {
            /*
             * The local endpoint sent the first close message. The remote endpoint has now responded with its own close
             * message so mark the session as fully closed and close the network connection.
             */
            closeConnection();
        }
        // CLOSING and CLOSED are NO-OPs
    }


    private void closeConnection() {
        /*
         * Close the network connection.
         */
        wsRemoteEndpoint.close();
        /*
         * Don't unregister the session until the connection is fully closed since webSocketContainer is responsible for
         * tracking the session close timeout.
         */
        webSocketContainer.unregisterSession(getSessionMapKey(), this);
    }


    /*
     * Returns the session close timeout in milliseconds
     */
    protected long getSessionCloseTimeout() {
        long result = 0;
        Object obj = userProperties.get(Constants.SESSION_CLOSE_TIMEOUT_PROPERTY);
        if (obj instanceof Long) {
            result = ((Long) obj).intValue();
        }
        if (result <= 0) {
            result = Constants.DEFAULT_SESSION_CLOSE_TIMEOUT;
        }
        return result;
    }


    /*
     * Returns the session close timeout in milliseconds
     */
    private long getAbnormalSessionCloseSendTimeout() {
        long result = 0;
        Object obj = userProperties.get(Constants.ABNORMAL_SESSION_CLOSE_SEND_TIMEOUT_PROPERTY);
        if (obj instanceof Long) {
            result = ((Long) obj).longValue();
        }
        if (result <= 0) {
            result = Constants.DEFAULT_ABNORMAL_SESSION_CLOSE_SEND_TIMEOUT;
        }
        return result;
    }


    protected void checkCloseTimeout() {
        // Skip the check if no session close timeout has been set.
        if (sessionCloseTimeoutExpiry != null) {
            // Check if the timeout has expired.
            if (System.nanoTime() - sessionCloseTimeoutExpiry.longValue() > 0) {
                // Check if the session has been closed in another thread while the timeout was being processed.
                if (state.compareAndSet(State.OUTPUT_CLOSED, State.CLOSED)) {
                    closeConnection();
                }
            }
        }
    }


    private void fireEndpointOnClose(CloseReason closeReason) {

        // Fire the onClose event
        Throwable throwable = null;
        InstanceManager instanceManager = getInstanceManager();
        Thread t = Thread.currentThread();
        ClassLoader cl = t.getContextClassLoader();
        t.setContextClassLoader(applicationClassLoader);
        try {
            localEndpoint.onClose(this, closeReason);
        } catch (Throwable t1) {
            ExceptionUtils.handleThrowable(t1);
            throwable = t1;
        } finally {
            if (instanceManager != null) {
                try {
                    instanceManager.destroyInstance(localEndpoint);
                } catch (Throwable t2) {
                    ExceptionUtils.handleThrowable(t2);
                    if (throwable == null) {
                        throwable = t2;
                    }
                }
            }
            t.setContextClassLoader(cl);
        }

        if (throwable != null) {
            fireEndpointOnError(throwable);
        }
    }


    private void fireEndpointOnError(Throwable throwable) {

        // Fire the onError event
        Thread t = Thread.currentThread();
        ClassLoader cl = t.getContextClassLoader();
        t.setContextClassLoader(applicationClassLoader);
        try {
            localEndpoint.onError(this, throwable);
        } finally {
            t.setContextClassLoader(cl);
        }
    }


    private void sendCloseMessage(CloseReason closeReason) {
        // 125 is maximum size for the payload of a control message
        ByteBuffer msg = ByteBuffer.allocate(125);
        CloseCode closeCode = closeReason.getCloseCode();
        // CLOSED_ABNORMALLY should not be put on the wire
        if (closeCode == CloseCodes.CLOSED_ABNORMALLY) {
            // PROTOCOL_ERROR is probably better than GOING_AWAY here
            msg.putShort((short) CloseCodes.PROTOCOL_ERROR.getCode());
        } else {
            msg.putShort((short) closeCode.getCode());
        }

        String reason = closeReason.getReasonPhrase();
        if (reason != null && reason.length() > 0) {
            appendCloseReasonWithTruncation(msg, reason);
        }
        msg.flip();
        try {
            if (closeCode == CloseCodes.NORMAL_CLOSURE) {
                wsRemoteEndpoint.sendMessageBlock(Constants.OPCODE_CLOSE, msg, true);
            } else {
                wsRemoteEndpoint.sendMessageBlock(Constants.OPCODE_CLOSE, msg, true,
                        getAbnormalSessionCloseSendTimeout());
            }
        } catch (IOException | IllegalStateException e) {
            // Failed to send close message. Close the socket and let the caller
            // deal with the Exception
            if (log.isDebugEnabled()) {
                log.debug(sm.getString("wsSession.sendCloseFail", id), e);
            }
            closeConnection();
            // Failure to send a close message is not unexpected in the case of
            // an abnormal closure (usually triggered by a failure to read/write
            // from/to the client. In this case do not trigger the endpoint's
            // error handling
            if (closeCode != CloseCodes.CLOSED_ABNORMALLY) {
                localEndpoint.onError(this, e);
            }
        }
    }


    private Object getSessionMapKey() {
        if (endpointConfig instanceof ServerEndpointConfig) {
            // Server
            return ((ServerEndpointConfig) endpointConfig).getPath();
        } else {
            // Client
            return localEndpoint;
        }
    }

    /**
     * Use protected so unit tests can access this method directly.
     *
     * @param msg    The message
     * @param reason The reason
     */
    protected static void appendCloseReasonWithTruncation(ByteBuffer msg, String reason) {
        // Once the close code has been added there are a maximum of 123 bytes
        // left for the reason phrase. If it is truncated then care needs to be
        // taken to ensure the bytes are not truncated in the middle of a
        // multi-byte UTF-8 character.
        byte[] reasonBytes = reason.getBytes(StandardCharsets.UTF_8);

        if (reasonBytes.length <= 123) {
            // No need to truncate
            msg.put(reasonBytes);
        } else {
            // Need to truncate
            int remaining = 123 - ELLIPSIS_BYTES_LEN;
            int pos = 0;
            byte[] bytesNext = reason.substring(pos, pos + 1).getBytes(StandardCharsets.UTF_8);
            while (remaining >= bytesNext.length) {
                msg.put(bytesNext);
                remaining -= bytesNext.length;
                pos++;
                bytesNext = reason.substring(pos, pos + 1).getBytes(StandardCharsets.UTF_8);
            }
            msg.put(ELLIPSIS_BYTES);
        }
    }


    /**
     * Make the session aware of a {@link FutureToSendHandler} that will need to be forcibly closed if the session
     * closes before the {@link FutureToSendHandler} completes.
     *
     * @param f2sh The handler
     */
    protected void registerFuture(FutureToSendHandler f2sh) {
        // Ideally, this code should sync on stateLock so that the correct
        // action is taken based on the current state of the connection.
        // However, a sync on stateLock can't be used here as it will create the
        // possibility of a dead-lock. See BZ 61183.
        // Therefore, a slightly less efficient approach is used.

        // Always register the future.
        futures.put(f2sh, f2sh);

        if (isOpen()) {
            // The session is open. The future has been registered with the open
            // session. Normal processing continues.
            return;
        }

        // The session is closing / closed. The future may or may not have been registered
        // in time for it to be processed during session closure.

        if (f2sh.isDone()) {
            // The future has completed. It is not known if the future was
            // completed normally by the I/O layer or in error by doClose(). It
            // doesn't matter which. There is nothing more to do here.
            return;
        }

        // The session is closing / closed. The Future had not completed when last checked.
        // There is a small timing window that means the Future may have been
        // completed since the last check. There is also the possibility that
        // the Future was not registered in time to be cleaned up during session
        // close.
        // Attempt to complete the Future with an error result as this ensures
        // that the Future completes and any client code waiting on it does not
        // hang. It is slightly inefficient since the Future may have been
        // completed in another thread or another thread may be about to
        // complete the Future but knowing if this is the case requires the sync
        // on stateLock (see above).
        // Note: If multiple attempts are made to complete the Future, the
        // second and subsequent attempts are ignored.

        IOException ioe = new IOException(sm.getString("wsSession.messageFailed"));
        SendResult sr = new SendResult(ioe);
        f2sh.onResult(sr);
    }


    /**
     * Remove a {@link FutureToSendHandler} from the set of tracked instances.
     *
     * @param f2sh The handler
     */
    protected void unregisterFuture(FutureToSendHandler f2sh) {
        futures.remove(f2sh);
    }


    @Override
    public URI getRequestURI() {
        checkState();
        return requestUri;
    }


    @Override
    public Map> getRequestParameterMap() {
        checkState();
        return requestParameterMap;
    }


    @Override
    public String getQueryString() {
        checkState();
        return queryString;
    }


    @Override
    public Principal getUserPrincipal() {
        checkState();
        return getUserPrincipalInternal();
    }


    public Principal getUserPrincipalInternal() {
        return userPrincipal;
    }


    @Override
    public Map getPathParameters() {
        checkState();
        return pathParameters;
    }


    @Override
    public String getId() {
        return id;
    }


    @Override
    public Map getUserProperties() {
        checkState();
        return userProperties;
    }


    public Endpoint getLocal() {
        return localEndpoint;
    }


    public String getHttpSessionId() {
        return httpSessionId;
    }


    protected MessageHandler getTextMessageHandler() {
        return textMessageHandler;
    }


    protected MessageHandler getBinaryMessageHandler() {
        return binaryMessageHandler;
    }


    protected MessageHandler.Whole getPongMessageHandler() {
        return pongMessageHandler;
    }


    protected void updateLastActiveRead() {
        lastActiveRead = System.currentTimeMillis();
    }


    protected void updateLastActiveWrite() {
        lastActiveWrite = System.currentTimeMillis();
    }


    protected void checkExpiration() {
        // Local copies to ensure consistent behaviour during method execution
        long timeout = maxIdleTimeout;
        long timeoutRead = getMaxIdleTimeoutRead();
        long timeoutWrite = getMaxIdleTimeoutWrite();

        long currentTime = System.currentTimeMillis();
        String key = null;

        if (timeoutRead > 0 && (currentTime - lastActiveRead) > timeoutRead) {
            key = "wsSession.timeoutRead";
        } else if (timeoutWrite > 0 && (currentTime - lastActiveWrite) > timeoutWrite) {
            key = "wsSession.timeoutWrite";
        } else if (timeout > 0 && (currentTime - lastActiveRead) > timeout &&
                (currentTime - lastActiveWrite) > timeout) {
            key = "wsSession.timeout";
        }

        if (key != null) {
            String msg = sm.getString(key, getId());
            if (log.isDebugEnabled()) {
                log.debug(msg);
            }
            doClose(new CloseReason(CloseCodes.GOING_AWAY, msg), new CloseReason(CloseCodes.CLOSED_ABNORMALLY, msg));
        }
    }


    private long getMaxIdleTimeoutRead() {
        Object timeout = userProperties.get(Constants.READ_IDLE_TIMEOUT_MS);
        if (timeout instanceof Long) {
            return ((Long) timeout).longValue();
        }
        return 0;
    }


    private long getMaxIdleTimeoutWrite() {
        Object timeout = userProperties.get(Constants.WRITE_IDLE_TIMEOUT_MS);
        if (timeout instanceof Long) {
            return ((Long) timeout).longValue();
        }
        return 0;
    }


    private void checkState() {
        if (isClosed()) {
            /*
             * As per RFC 6455, a WebSocket connection is considered to be closed once a peer has sent and received a
             * WebSocket close frame.
             */
            throw new IllegalStateException(sm.getString("wsSession.closed", id));
        }
    }

    private enum State {
        OPEN,
        OUTPUT_CLOSING,
        OUTPUT_CLOSED,
        CLOSING,
        CLOSED
    }


    private WsFrameBase wsFrame;

    void setWsFrame(WsFrameBase wsFrame) {
        this.wsFrame = wsFrame;
    }


    /**
     * Suspends the reading of the incoming messages.
     */
    public void suspend() {
        wsFrame.suspend();
    }


    /**
     * Resumes the reading of the incoming messages.
     */
    public void resume() {
        wsFrame.resume();
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy