All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.activemq.transport.ws.jetty9.WSServlet Maven / Gradle / Ivy

There is a newer version: 6.1.2
Show newest version
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.activemq.transport.ws.jetty9;

import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.activemq.broker.BrokerService;
import org.apache.activemq.broker.BrokerServiceAware;
import org.apache.activemq.transport.Transport;
import org.apache.activemq.transport.TransportAcceptListener;
import org.apache.activemq.transport.TransportFactory;
import org.apache.activemq.transport.util.HttpTransportUtils;
import org.apache.activemq.transport.ws.WSTransportProxy;
import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse;
import org.eclipse.jetty.websocket.servlet.WebSocketCreator;
import org.eclipse.jetty.websocket.servlet.WebSocketServlet;
import org.eclipse.jetty.websocket.servlet.WebSocketServletFactory;

/**
 * Handle connection upgrade requests and creates web sockets
 */
public class WSServlet extends WebSocketServlet implements BrokerServiceAware {

    private static final long serialVersionUID = -4716657876092884139L;

    private TransportAcceptListener listener;

    private final static Map stompProtocols = new ConcurrentHashMap<>();
    private final static Map mqttProtocols = new ConcurrentHashMap<>();

    private Map transportOptions;
    private BrokerService brokerService;

    private enum Protocol {
        MQTT, STOMP, UNKNOWN
    }

    static {
        stompProtocols.put("v12.stomp", 3);
        stompProtocols.put("v11.stomp", 2);
        stompProtocols.put("v10.stomp", 1);
        stompProtocols.put("stomp", 0);

        mqttProtocols.put("mqttv3.1", 1);
        mqttProtocols.put("mqtt", 0);
    }

    @Override
    public void init() throws ServletException {
        super.init();
        listener = (TransportAcceptListener) getServletContext().getAttribute("acceptListener");
        if (listener == null) {
            throw new ServletException("No such attribute 'acceptListener' available in the ServletContext");
        }
    }

    @Override
    protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        //return empty response - AMQ-6491
    }

    @Override
    public void configure(WebSocketServletFactory factory) {
        factory.setCreator(new WebSocketCreator() {
            @Override
            public Object createWebSocket(ServletUpgradeRequest req, ServletUpgradeResponse resp) {
                WebSocketListener socket;
                Protocol requestedProtocol = Protocol.UNKNOWN;

                // When no sub-protocol is requested we default to STOMP for legacy reasons.
                if (!req.getSubProtocols().isEmpty()) {
                    for (String subProtocol : req.getSubProtocols()) {
                        if (subProtocol.startsWith("mqtt")) {
                            requestedProtocol = Protocol.MQTT;
                        } else if (subProtocol.contains("stomp")) {
                            requestedProtocol = Protocol.STOMP;
                        }
                    }
                } else {
                    requestedProtocol = Protocol.STOMP;
                }

                switch (requestedProtocol) {
                    case MQTT:
                        socket = new MQTTSocket(HttpTransportUtils.generateWsRemoteAddress(req.getHttpServletRequest()));
                        ((MQTTSocket) socket).setTransportOptions(new HashMap<>(transportOptions));
                        ((MQTTSocket) socket).setPeerCertificates(req.getCertificates());
                        resp.setAcceptedSubProtocol(getAcceptedSubProtocol(mqttProtocols, req.getSubProtocols(), "mqtt"));
                        break;
                    case UNKNOWN:
                        socket = findWSTransport(req, resp);
                        if (socket != null) {
                            break;
                        }
                    case STOMP:
                        socket = new StompSocket(HttpTransportUtils.generateWsRemoteAddress(req.getHttpServletRequest()));
                        ((StompSocket) socket).setPeerCertificates(req.getCertificates());
                        resp.setAcceptedSubProtocol(getAcceptedSubProtocol(stompProtocols, req.getSubProtocols(), "stomp"));
                        break;
                    default:
                        socket = null;
                        listener.onAcceptError(new IOException("Unknown protocol requested"));
                        break;
                }

                if (socket != null) {
                    listener.onAccept((Transport) socket);
                }

                return socket;
            }
        });
    }

    private WebSocketListener findWSTransport(ServletUpgradeRequest request, ServletUpgradeResponse response) {
        WSTransportProxy proxy = null;

        for (String subProtocol : request.getSubProtocols()) {
            try {
                String remoteAddress = HttpTransportUtils.generateWsRemoteAddress(request.getHttpServletRequest(), subProtocol);
                URI remoteURI = new URI(remoteAddress);

                TransportFactory factory = TransportFactory.findTransportFactory(remoteURI);

                if (factory instanceof BrokerServiceAware) {
                    ((BrokerServiceAware) factory).setBrokerService(brokerService);
                }

                Transport transport = factory.doConnect(remoteURI);

                proxy = new WSTransportProxy(remoteAddress, transport);
                proxy.setPeerCertificates(request.getCertificates());
                proxy.setTransportOptions(transportOptions);

                response.setAcceptedSubProtocol(proxy.getSubProtocol());
            } catch (Exception e) {
                proxy = null;

                // Keep going and try any other sub-protocols present.
                continue;
            }
        }

        return proxy;
    }

    private String getAcceptedSubProtocol(final Map protocols, List subProtocols, String defaultProtocol) {
        List matchedProtocols = new ArrayList<>();
        if (subProtocols != null && subProtocols.size() > 0) {
            // detect which subprotocols match accepted protocols and add to the
            // list
            for (String subProtocol : subProtocols) {
                Integer priority = protocols.get(subProtocol);
                if (subProtocol != null && priority != null) {
                    // only insert if both subProtocol and priority are not null
                    matchedProtocols.add(new SubProtocol(subProtocol, priority));
                }
            }
            // sort the list by priority
            if (matchedProtocols.size() > 0) {
                Collections.sort(matchedProtocols, new Comparator() {
                    @Override
                    public int compare(SubProtocol s1, SubProtocol s2) {
                        return s2.priority.compareTo(s1.priority);
                    }
                });
                return matchedProtocols.get(0).protocol;
            }
        }
        return defaultProtocol;
    }

    private class SubProtocol {
        private String protocol;
        private Integer priority;

        public SubProtocol(String protocol, Integer priority) {
            this.protocol = protocol;
            this.priority = priority;
        }
    }

    public void setTransportOptions(Map transportOptions) {
        this.transportOptions = transportOptions;
    }

    @Override
    public void setBrokerService(BrokerService brokerService) {
        this.brokerService = brokerService;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy