All downloads are free. Search and download functionality uses the official Maven repository.

org.cometd.websocket.server.WebSocketTransport Maven / Gradle / Ivy

/*
 * Copyright (c) 2008-2019 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.cometd.websocket.server;

import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import javax.servlet.ServletContext;
import javax.servlet.http.HttpSession;
import javax.websocket.CloseReason;
import javax.websocket.DeploymentException;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.Extension;
import javax.websocket.HandshakeResponse;
import javax.websocket.MessageHandler;
import javax.websocket.SendHandler;
import javax.websocket.SendResult;
import javax.websocket.Session;
import javax.websocket.server.HandshakeRequest;
import javax.websocket.server.ServerContainer;
import javax.websocket.server.ServerEndpointConfig;

import org.cometd.bayeux.server.ServerMessage;
import org.cometd.bayeux.server.ServerSession;
import org.cometd.server.AbstractServerTransport;
import org.cometd.server.BayeuxServerImpl;
import org.cometd.websocket.server.common.AbstractBayeuxContext;
import org.cometd.websocket.server.common.AbstractWebSocketTransport;
import org.eclipse.jetty.util.Callback;

/**
 * CometD server transport that carries Bayeux messages over WebSocket using the
 * standard JSR 356 ({@code javax.websocket}) server APIs.
 * <p>
 * {@link #init()} registers one {@link ServerEndpointConfig} per normalized CometD URL
 * mapping on the {@link ServerContainer} stored in the {@link ServletContext}.
 * State gathered by the {@link Configurator} callbacks during a handshake is handed to
 * the endpoint instance through the thread-local {@link ContextHolder}.
 */
public class WebSocketTransport extends AbstractWebSocketTransport {
    /**
     * Maximum number of UTF-16 chars kept in a close reason phrase.
     * RFC 6455 caps the UTF-8 encoded close reason at 123 bytes; 30 chars encode to
     * at most 90 UTF-8 bytes, so a truncated reason always fits.
     */
    private static final int MAX_CLOSE_REASON_LENGTH = 30;

    public WebSocketTransport(BayeuxServerImpl bayeux) {
        super(bayeux);
    }

    /**
     * Initializes the transport and registers its WebSocket endpoints with the
     * container's {@link ServerContainer}.
     *
     * @throws IllegalArgumentException if the {@link ServletContext} option, the
     *         {@code COMETD_URL_MAPPING_OPTION} option or the {@link ServerContainer}
     *         servlet context attribute is missing
     * @throws RuntimeException wrapping the {@link DeploymentException} if an endpoint
     *         cannot be registered
     */
    @Override
    public void init() {
        super.init();

        final ServletContext context = (ServletContext)getOption(ServletContext.class.getName());
        if (context == null) {
            throw new IllegalArgumentException("Missing ServletContext");
        }

        String cometdURLMapping = (String)getOption(COMETD_URL_MAPPING_OPTION);
        if (cometdURLMapping == null) {
            throw new IllegalArgumentException("Missing '" + COMETD_URL_MAPPING_OPTION + "' parameter");
        }

        ServerContainer container = (ServerContainer)context.getAttribute(ServerContainer.class.getName());
        if (container == null) {
            throw new IllegalArgumentException("Missing WebSocket ServerContainer");
        }

        // JSR 356 does not support an input buffer size option.
        // A negative max message size falls back to the container's current default.
        int maxMessageSize = getMaxMessageSize();
        if (maxMessageSize < 0) {
            maxMessageSize = container.getDefaultMaxTextMessageBufferSize();
        }
        container.setDefaultMaxTextMessageBufferSize(maxMessageSize);

        long idleTimeout = getOption(IDLE_TIMEOUT_OPTION, container.getDefaultMaxSessionIdleTimeout());
        container.setDefaultMaxSessionIdleTimeout(idleTimeout);

        // Advertise the configured subprotocol, if any; an empty list means "no subprotocol".
        String protocol = getProtocol();
        List<String> protocols = protocol == null ? Collections.<String>emptyList() : Collections.singletonList(protocol);

        Configurator configurator = new Configurator(context);

        for (String mapping : normalizeURLMapping(cometdURLMapping)) {
            ServerEndpointConfig config = ServerEndpointConfig.Builder.create(WebSocketScheduler.class, mapping)
                    .subprotocols(protocols)
                    .configurator(configurator)
                    .build();
            try {
                container.addEndpoint(config);
            } catch (DeploymentException x) {
                throw new RuntimeException(x);
            }
        }
    }

    /**
     * Hook invoked during the handshake to validate the {@code Origin} header.
     * The default implementation accepts every origin; override to restrict it.
     *
     * @param origin the value of the {@code Origin} request header
     * @return whether the handshake may proceed
     */
    protected boolean checkOrigin(String origin) {
        return true;
    }

    /**
     * Hook invoked during the handshake, after the Bayeux context has been captured,
     * that lets subclasses inspect the request or modify the response.
     * The default implementation does nothing.
     *
     * @param request the handshake request
     * @param response the handshake response
     */
    protected void modifyHandshake(HandshakeRequest request, HandshakeResponse response) {
    }

    /**
     * Writes {@code data} asynchronously as a text message.
     * On failure, {@code handleException} is invoked before the callback is failed.
     */
    @Override
    protected void send(final Session wsSession, final ServerSession session, String data, final Callback callback) {
        if (_logger.isDebugEnabled()) {
            _logger.debug("Sending {}", data);
        }

        // Async write.
        wsSession.getAsyncRemote().sendText(data, new SendHandler() {
            @Override
            public void onResult(SendResult result) {
                Throwable failure = result.getException();
                if (failure == null) {
                    callback.succeeded();
                } else {
                    handleException(wsSession, session, failure);
                    callback.failed(failure);
                }
            }
        });
    }

    /**
     * JSR 356 endpoint, one instance per WebSocket connection, that forwards lifecycle
     * events and text messages to an {@code AbstractWebSocketScheduler} delegate.
     */
    private class WebSocketScheduler extends Endpoint implements AbstractServerTransport.Scheduler, MessageHandler.Whole<String> {
        private final AbstractWebSocketScheduler delegate;
        // Assigned in onOpen(); read by the delegate callbacks.
        private volatile Session _wsSession;

        private WebSocketScheduler(WebSocketContext context) {
            delegate = new AbstractWebSocketScheduler(context) {
                @Override
                protected void close(final int code, String reason) {
                    try {
                        // Limits of the WebSocket APIs, otherwise an exception is thrown.
                        reason = truncateCloseReason(reason);
                        if (_logger.isDebugEnabled()) {
                            _logger.debug("Closing {}/{}", code, reason);
                        }
                        _wsSession.close(new CloseReason(CloseReason.CloseCodes.getCloseCode(code), reason));
                    } catch (Throwable x) {
                        // Best effort: failing to close is only traced.
                        _logger.trace("Could not close WebSocket session " + _wsSession, x);
                    }
                }

                @Override
                protected void schedule(boolean timeout, ServerMessage.Mutable expiredConnectReply) {
                    schedule(_wsSession, timeout, expiredConnectReply);
                }
            };
        }

        @Override
        public void onOpen(Session wsSession, EndpointConfig config) {
            _wsSession = wsSession;
            wsSession.addMessageHandler(this);
        }

        @Override
        public void onClose(Session wsSession, CloseReason closeReason) {
            delegate.onClose(closeReason.getCloseCode().getCode(), closeReason.getReasonPhrase());
        }

        @Override
        public void onError(Session wsSession, Throwable failure) {
            delegate.onError(failure);
        }

        @Override
        public void cancel() {
            delegate.cancel();
        }

        @Override
        public void schedule() {
            delegate.schedule();
        }

        @Override
        public void onMessage(String data) {
            if (_logger.isDebugEnabled()) {
                _logger.debug("WebSocket Text message on {}@{}/{}@{}",
                        WebSocketTransport.this.getClass().getSimpleName(),
                        Integer.toHexString(WebSocketTransport.this.hashCode()),
                        getClass().getSimpleName(),
                        Integer.toHexString(hashCode()));
            }
            delegate.onMessage(_wsSession, data);
        }
    }

    /**
     * Bayeux context built from a JSR 356 handshake request.
     */
    private class WebSocketContext extends AbstractBayeuxContext {
        private WebSocketContext(ServletContext context, HandshakeRequest request, Map<String, Object> userProperties) {
            super(context, request.getRequestURI().toString(), request.getQueryString(), request.getHeaders(),
                    request.getParameterMap(), request.getUserPrincipal(), (HttpSession)request.getHttpSession(),
                    // Hopefully these will become a standard, for now they are Jetty specific.
                    (InetSocketAddress)userProperties.get("javax.websocket.endpoint.localAddress"),
                    (InetSocketAddress)userProperties.get("javax.websocket.endpoint.remoteAddress"),
                    retrieveLocales(userProperties));
        }
    }

    /**
     * Returns the locales stored under {@code javax.websocket.upgrade.locales}, or a
     * singleton list with the JVM default locale when that property is absent or empty.
     */
    private static List<Locale> retrieveLocales(Map<String, Object> userProperties) {
        // The property is populated by the container (Jetty-specific); its element type is trusted.
        @SuppressWarnings("unchecked")
        List<Locale> locales = (List<Locale>)userProperties.get("javax.websocket.upgrade.locales");
        if (locales == null || locales.isEmpty()) {
            return Collections.singletonList(Locale.getDefault());
        }
        return locales;
    }

    /**
     * Normalizes a close reason phrase so that {@link CloseReason} accepts it.
     * A {@code null} reason becomes the empty string. A longer reason is cut to
     * {@value #MAX_CLOSE_REASON_LENGTH} chars without splitting a UTF-16 surrogate pair.
     */
    private static String truncateCloseReason(String reason) {
        if (reason == null) {
            return "";
        }
        if (reason.length() <= MAX_CLOSE_REASON_LENGTH) {
            return reason;
        }
        int end = MAX_CLOSE_REASON_LENGTH;
        // Cutting between a high and a low surrogate would leave malformed UTF-16.
        if (Character.isHighSurrogate(reason.charAt(end - 1))) {
            --end;
        }
        return reason.substring(0, end);
    }

    /**
     * Handshake configurator shared by all registered endpoints.
     * It records per-handshake state in the thread-local {@link ContextHolder}, which
     * {@link #getEndpointInstance(Class)} consumes and clears.
     */
    private class Configurator extends ServerEndpointConfig.Configurator {
        private final ServletContext servletContext;

        private Configurator(ServletContext servletContext) {
            this.servletContext = servletContext;
        }

        @Override
        public void modifyHandshake(ServerEndpointConfig sec, HandshakeRequest request, HandshakeResponse response) {
            ContextHolder context = provideContext();
            try {
                context.bayeuxContext = new WebSocketContext(servletContext, request, sec.getUserProperties());
                WebSocketTransport.this.modifyHandshake(request, response);
            } catch (RuntimeException | Error x) {
                // The endpoint will not be created, so do not leave stale state on this thread.
                context.clear();
                throw x;
            }
        }

        @Override
        public boolean checkOrigin(String originHeaderValue) {
            return WebSocketTransport.this.checkOrigin(originHeaderValue);
        }

        @Override
        public String getNegotiatedSubprotocol(List<String> supported, List<String> requested) {
            ContextHolder context = provideContext();
            context.protocolMatches = checkProtocol(supported, requested);
            if (context.protocolMatches) {
                return super.getNegotiatedSubprotocol(supported, requested);
            }
            _logger.warn("Could not negotiate WebSocket SubProtocols: server{} != client{}", supported, requested);
            return null;
        }

        /**
         * Returns the requested extensions, in request order, that are both installed on
         * the server and enabled via the {@code ENABLE_EXTENSION_PREFIX_OPTION + name}
         * option (enabled by default).
         */
        @Override
        public List<Extension> getNegotiatedExtensions(List<Extension> installed, List<Extension> requested) {
            List<Extension> negotiated = new ArrayList<>();
            for (Extension extension : requested) {
                String name = extension.getName();
                boolean enabled = getOption(ENABLE_EXTENSION_PREFIX_OPTION + name, true);
                if (enabled && isInstalled(installed, name)) {
                    negotiated.add(extension);
                }
            }
            return negotiated;
        }

        /**
         * Returns whether an extension with the given name is among the installed ones.
         */
        private boolean isInstalled(List<Extension> installed, String name) {
            for (Extension extension : installed) {
                if (extension.getName().equals(name)) {
                    return true;
                }
            }
            return false;
        }

        /**
         * Creates the endpoint for the current handshake from the thread-local context.
         *
         * @throws InstantiationException if this transport is not among the allowed
         *         transports, or if subprotocol negotiation failed
         */
        @Override
        @SuppressWarnings("unchecked")
        public <T> T getEndpointInstance(Class<T> endpointClass) throws InstantiationException {
            ContextHolder context = provideContext();
            try {
                if (!getBayeux().getAllowedTransports().contains(getName())) {
                    throw new InstantiationException("Transport not allowed");
                }
                if (!context.protocolMatches) {
                    throw new InstantiationException("Could not negotiate WebSocket SubProtocols");
                }
                return (T)new WebSocketScheduler(context.bayeuxContext);
            } finally {
                // Always release the thread-local state, so a rejected handshake cannot leak
                // e.g. protocolMatches == false into the next handshake on this thread.
                context.clear();
            }
        }

        /**
         * Returns this thread's {@link ContextHolder}, creating and registering a fresh
         * one with default values if none is present.
         */
        private ContextHolder provideContext() {
            ContextHolder result = ContextHolder.holder.get();
            if (result == null) {
                result = new ContextHolder();
                result.clear();
                ContextHolder.holder.set(result);
            }
            return result;
        }
    }

    /**
     * Per-thread carrier of handshake state between the {@link Configurator} callbacks.
     */
    private static class ContextHolder {
        private static final ThreadLocal<ContextHolder> holder = new ThreadLocal<>();
        private WebSocketContext bayeuxContext;
        private boolean protocolMatches;

        /**
         * Detaches any holder from the current thread and resets this holder's fields.
         */
        public void clear() {
            // remove() rather than set(null), so no thread-local map entry is left behind.
            ContextHolder.holder.remove();
            bayeuxContext = null;
            // Use a sensible default in case getNegotiatedSubprotocol() is not invoked.
            protocolMatches = true;
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy