All downloads are free. The search and download functionality uses the official Maven repository.

org.apache.tomcat.websocket.WsWebSocketContainer Maven / Gradle / Ivy

There is a newer version: 11.0.1
Show newest version
/*
 *  Licensed to the Apache Software Foundation (ASF) under one or more
 *  contributor license agreements.  See the NOTICE file distributed with
 *  this work for additional information regarding copyright ownership.
 *  The ASF licenses this file to You under the Apache License, Version 2.0
 *  (the "License"); you may not use this file except in compliance with
 *  the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */
package org.apache.tomcat.websocket;

import java.io.EOFException;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.net.ProxySelector;
import java.net.SocketAddress;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import java.nio.channels.AsynchronousChannelGroup;
import java.nio.channels.AsynchronousSocketChannel;
import java.nio.charset.StandardCharsets;
import java.security.KeyStore;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.TrustManagerFactory;
import javax.websocket.ClientEndpoint;
import javax.websocket.ClientEndpointConfig;
import javax.websocket.CloseReason;
import javax.websocket.CloseReason.CloseCodes;
import javax.websocket.DeploymentException;
import javax.websocket.Endpoint;
import javax.websocket.Extension;
import javax.websocket.HandshakeResponse;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;

import org.apache.juli.logging.Log;
import org.apache.juli.logging.LogFactory;
import org.apache.tomcat.InstanceManager;
import org.apache.tomcat.InstanceManagerBindings;
import org.apache.tomcat.util.buf.StringUtils;
import org.apache.tomcat.util.collections.CaseInsensitiveKeyMap;
import org.apache.tomcat.util.res.StringManager;
import org.apache.tomcat.util.security.KeyStoreUtil;

public class WsWebSocketContainer implements WebSocketContainer, BackgroundProcess {

    private static final StringManager sm = StringManager.getManager(WsWebSocketContainer.class);
    private static final Random RANDOM = new Random();
    private static final byte[] CRLF = new byte[] { 13, 10 };

    private static final byte[] GET_BYTES = "GET ".getBytes(StandardCharsets.ISO_8859_1);
    private static final byte[] ROOT_URI_BYTES = "/".getBytes(StandardCharsets.ISO_8859_1);
    private static final byte[] HTTP_VERSION_BYTES = " HTTP/1.1\r\n".getBytes(StandardCharsets.ISO_8859_1);

    private volatile AsynchronousChannelGroup asynchronousChannelGroup = null;
    private final Object asynchronousChannelGroupLock = new Object();

    private final Log log = LogFactory.getLog(WsWebSocketContainer.class); // must not be static
    // Server side uses the endpoint path as the key
    // Client side uses the client endpoint instance
    private final Map> endpointSessionMap = new HashMap<>();
    private final Map sessions = new ConcurrentHashMap<>();
    private final Object endPointSessionMapLock = new Object();

    private long defaultAsyncTimeout = -1;
    private int maxBinaryMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE;
    private int maxTextMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE;
    private volatile long defaultMaxSessionIdleTimeout = 0;
    private int backgroundProcessCount = 0;
    private int processPeriod = Constants.DEFAULT_PROCESS_PERIOD;

    private InstanceManager instanceManager;

    /**
     * Obtains the InstanceManager to use for the given class loader.
     *
     * @param classLoader the class loader used for the lookup in {@link InstanceManagerBindings} when no
     *                        InstanceManager has been set explicitly on this container
     *
     * @return the explicitly configured InstanceManager if there is one, otherwise whatever
     *             {@link InstanceManagerBindings#get(ClassLoader)} returns for the class loader
     */
    protected InstanceManager getInstanceManager(ClassLoader classLoader) {
        return instanceManager != null ? instanceManager : InstanceManagerBindings.get(classLoader);
    }

    /**
     * Sets the InstanceManager that {@link #getInstanceManager(ClassLoader)} returns in preference to the one
     * bound to the class loader.
     *
     * @param instanceManager the InstanceManager to use, or {@code null} to fall back to the class loader binding
     */
    protected void setInstanceManager(InstanceManager instanceManager) {
        this.instanceManager = instanceManager;
    }

    /**
     * Connects an already instantiated, {@link ClientEndpoint} annotated object to the server at the given URI.
     * The endpoint configuration is derived from the annotation on the object's class.
     */
    @Override
    public Session connectToServer(Object pojo, URI path) throws DeploymentException {
        final ClientEndpointConfig endpointConfig = createClientEndpointConfig(pojo.getClass());
        return connectToServerRecursive(new PojoHolder(pojo, endpointConfig), endpointConfig, path,
                new HashSet<>());
    }


    /**
     * Connects an endpoint, described by its {@link ClientEndpoint} annotated class, to the server at the given
     * URI. The endpoint configuration is derived from the annotation on the class.
     */
    @Override
    public Session connectToServer(Class<?> annotatedEndpointClass, URI path) throws DeploymentException {
        ClientEndpointConfig config = createClientEndpointConfig(annotatedEndpointClass);
        ClientEndpointHolder holder = new PojoClassHolder(annotatedEndpointClass, config);
        return connectToServerRecursive(holder, config, path, new HashSet<>());
    }


    /**
     * Builds a {@link ClientEndpointConfig} from the {@link ClientEndpoint} annotation found on the given class.
     *
     * @param annotatedEndpointClass the class expected to carry a {@link ClientEndpoint} annotation
     *
     * @return a configuration populated with the decoders, encoders, preferred sub-protocols and (if one other
     *             than the default type is declared) the configurator named by the annotation
     *
     * @throws DeploymentException if the class is not annotated or the declared configurator cannot be
     *                                 instantiated through its no-argument constructor
     */
    private ClientEndpointConfig createClientEndpointConfig(Class<?> annotatedEndpointClass)
            throws DeploymentException {
        ClientEndpoint annotation = annotatedEndpointClass.getAnnotation(ClientEndpoint.class);
        if (annotation == null) {
            throw new DeploymentException(
                    sm.getString("wsWebSocketContainer.missingAnnotation", annotatedEndpointClass.getName()));
        }

        Class<? extends ClientEndpointConfig.Configurator> configuratorClazz = annotation.configurator();

        // Only instantiate a configurator when the annotation names something other than the default type
        ClientEndpointConfig.Configurator configurator = null;
        if (!ClientEndpointConfig.Configurator.class.equals(configuratorClazz)) {
            try {
                configurator = configuratorClazz.getConstructor().newInstance();
            } catch (ReflectiveOperationException e) {
                throw new DeploymentException(sm.getString("wsWebSocketContainer.defaultConfiguratorFail"), e);
            }
        }

        ClientEndpointConfig.Builder builder = ClientEndpointConfig.Builder.create();
        // Avoid NPE when using RI API JAR - see BZ 56343
        if (configurator != null) {
            builder.configurator(configurator);
        }

        return builder.decoders(Arrays.asList(annotation.decoders()))
                .encoders(Arrays.asList(annotation.encoders()))
                .preferredSubprotocols(Arrays.asList(annotation.subprotocols())).build();
    }


    /**
     * Connects an endpoint, described by its {@link Endpoint} subclass, to the server at the given URI using
     * the supplied configuration.
     */
    @Override
    public Session connectToServer(Class<? extends Endpoint> clazz, ClientEndpointConfig clientEndpointConfiguration,
            URI path) throws DeploymentException {
        ClientEndpointHolder holder = new EndpointClassHolder(clazz);
        return connectToServerRecursive(holder, clientEndpointConfiguration, path, new HashSet<>());
    }


    /**
     * Connects an already instantiated {@link Endpoint} to the server at the given URI using the supplied
     * configuration.
     */
    @Override
    public Session connectToServer(Endpoint endpoint, ClientEndpointConfig clientEndpointConfiguration, URI path)
            throws DeploymentException {
        return connectToServerRecursive(new EndpointHolder(endpoint), clientEndpointConfiguration, path,
                new HashSet<>());
    }


    /**
     * Performs the client side of the WebSocket opening handshake and, on success, creates and registers the
     * session. The method calls itself (directly or via
     * {@code processAuthenticationChallenge}) to follow redirects and to retry after an authentication
     * challenge.
     *
     * @param clientEndpointHolder        holder for the client endpoint that will receive the session
     * @param clientEndpointConfiguration configuration for the connection; its user properties supply the
     *                                        optional I/O timeout, redirect limit and proxy authorization value
     * @param path                        the {@code ws} or {@code wss} URI to connect to
     * @param redirectSet                 the redirect locations followed so far, used to detect loops and to
     *                                        enforce the redirect limit
     *
     * @return the newly opened session
     *
     * @throws DeploymentException if the URI is invalid, the connection or handshake fails, the server responds
     *                                 with an unexpected status, or the redirect limit is exceeded
     */
    private Session connectToServerRecursive(ClientEndpointHolder clientEndpointHolder,
            ClientEndpointConfig clientEndpointConfiguration, URI path, Set<URI> redirectSet)
            throws DeploymentException {

        if (log.isTraceEnabled()) {
            log.trace(sm.getString("wsWebSocketContainer.connect.entry", clientEndpointHolder.getClassName(), path));
        }

        boolean secure = false;
        ByteBuffer proxyConnect = null;
        URI proxyPath;

        // Validate scheme (and build proxyPath)
        String scheme = path.getScheme();
        if ("ws".equalsIgnoreCase(scheme)) {
            proxyPath = URI.create("http" + path.toString().substring(2));
        } else if ("wss".equalsIgnoreCase(scheme)) {
            proxyPath = URI.create("https" + path.toString().substring(3));
            secure = true;
        } else {
            throw new DeploymentException(sm.getString("wsWebSocketContainer.pathWrongScheme", scheme));
        }

        // Validate host
        String host = path.getHost();
        if (host == null) {
            throw new DeploymentException(sm.getString("wsWebSocketContainer.pathNoHost"));
        }
        int port = path.getPort();

        SocketAddress sa = null;

        // Check to see if a proxy is configured. Javadoc indicates return value
        // will never be null
        List<Proxy> proxies = ProxySelector.getDefault().select(proxyPath);
        Proxy selectedProxy = null;
        for (Proxy proxy : proxies) {
            if (proxy.type().equals(Proxy.Type.HTTP)) {
                sa = proxy.address();
                if (sa instanceof InetSocketAddress) {
                    InetSocketAddress inet = (InetSocketAddress) sa;
                    if (inet.isUnresolved()) {
                        // Re-create the address so that the host name gets resolved
                        sa = new InetSocketAddress(inet.getHostName(), inet.getPort());
                    }
                }
                // Only the first HTTP proxy is used
                selectedProxy = proxy;
                break;
            }
        }

        // If the port is not explicitly specified, compute it based on the
        // scheme
        if (port == -1) {
            if ("ws".equalsIgnoreCase(scheme)) {
                port = 80;
            } else {
                // Must be wss due to scheme validation above
                port = 443;
            }
        }

        Map<String,Object> userProperties = clientEndpointConfiguration.getUserProperties();

        // If sa is null, no proxy is configured so need to create sa
        if (sa == null) {
            sa = new InetSocketAddress(host, port);
        } else {
            proxyConnect = createProxyRequest(host, port,
                    (String) userProperties.get(Constants.PROXY_AUTHORIZATION_HEADER_NAME));
        }

        // Create the initial HTTP request to open the WebSocket connection
        Map<String,List<String>> reqHeaders = createRequestHeaders(host, port, secure, clientEndpointConfiguration);
        clientEndpointConfiguration.getConfigurator().beforeRequest(reqHeaders);
        if (Constants.DEFAULT_ORIGIN_HEADER_VALUE != null && !reqHeaders.containsKey(Constants.ORIGIN_HEADER_NAME)) {
            List<String> originValues = new ArrayList<>(1);
            originValues.add(Constants.DEFAULT_ORIGIN_HEADER_VALUE);
            reqHeaders.put(Constants.ORIGIN_HEADER_NAME, originValues);
        }
        ByteBuffer request = createRequest(path, reqHeaders);

        // Get the connection timeout
        long timeout = Constants.IO_TIMEOUT_MS_DEFAULT;
        String timeoutValue = (String) userProperties.get(Constants.IO_TIMEOUT_MS_PROPERTY);
        if (timeoutValue != null) {
            // Parse as a long so that large values are not truncated to an int
            timeout = Long.parseLong(timeoutValue);
        }

        AsynchronousSocketChannel socketChannel;
        try {
            socketChannel = AsynchronousSocketChannel.open(getAsynchronousChannelGroup());
        } catch (IOException ioe) {
            throw new DeploymentException(sm.getString("wsWebSocketContainer.asynchronousSocketChannelFail"), ioe);
        }

        // Set-up
        // Same size as the WsFrame input buffer
        ByteBuffer response = ByteBuffer.allocate(getDefaultMaxBinaryMessageBufferSize());
        String subProtocol;
        boolean success = false;
        List<Extension> extensionsAgreed = new ArrayList<>();
        Transformation transformation = null;
        AsyncChannelWrapper channel = null;

        try {
            // Open the connection
            Future<Void> fConnect = socketChannel.connect(sa);

            if (proxyConnect != null) {
                fConnect.get(timeout, TimeUnit.MILLISECONDS);
                // Proxy CONNECT is clear text
                channel = new AsyncChannelWrapperNonSecure(socketChannel);
                writeRequest(channel, proxyConnect, timeout);
                HttpResponse httpResponse = processResponse(response, channel, timeout);
                if (httpResponse.status == Constants.PROXY_AUTHENTICATION_REQUIRED) {
                    // The finally block closes this connection; the retry opens a new one
                    return processAuthenticationChallenge(clientEndpointHolder, clientEndpointConfiguration, path,
                            redirectSet, userProperties, request, httpResponse, AuthenticationType.PROXY);
                } else if (httpResponse.getStatus() != 200) {
                    throw new DeploymentException(sm.getString("wsWebSocketContainer.proxyConnectFail", selectedProxy,
                            Integer.toString(httpResponse.getStatus())));
                }
            }

            if (secure) {
                // Regardless of whether a non-secure wrapper was created for a
                // proxy CONNECT, need to use TLS from this point on so wrap the
                // original AsynchronousSocketChannel
                SSLEngine sslEngine = createSSLEngine(clientEndpointConfiguration, host, port);
                channel = new AsyncChannelWrapperSecure(socketChannel, sslEngine);
            } else if (channel == null) {
                // Only need to wrap as this point if it wasn't wrapped to process a
                // proxy CONNECT
                channel = new AsyncChannelWrapperNonSecure(socketChannel);
            }

            fConnect.get(timeout, TimeUnit.MILLISECONDS);

            Future<Void> fHandshake = channel.handshake();
            fHandshake.get(timeout, TimeUnit.MILLISECONDS);

            if (log.isTraceEnabled()) {
                SocketAddress localAddress = null;
                try {
                    localAddress = channel.getLocalAddress();
                } catch (IOException ioe) {
                    // Ignore
                }
                log.trace(sm.getString("wsWebSocketContainer.connect.write", Integer.valueOf(request.position()),
                        Integer.valueOf(request.limit()), localAddress));
            }
            writeRequest(channel, request, timeout);

            HttpResponse httpResponse = processResponse(response, channel, timeout);

            // Check maximum permitted redirects
            int maxRedirects = Constants.MAX_REDIRECTIONS_DEFAULT;
            String maxRedirectsValue = (String) userProperties.get(Constants.MAX_REDIRECTIONS_PROPERTY);
            if (maxRedirectsValue != null) {
                maxRedirects = Integer.parseInt(maxRedirectsValue);
            }

            if (httpResponse.status != 101) {
                if (isRedirectStatus(httpResponse.status)) {
                    List<String> locationHeader = httpResponse.getHandshakeResponse().getHeaders()
                            .get(Constants.LOCATION_HEADER_NAME);

                    if (locationHeader == null || locationHeader.isEmpty() || locationHeader.get(0) == null ||
                            locationHeader.get(0).isEmpty()) {
                        throw new DeploymentException(sm.getString("wsWebSocketContainer.missingLocationHeader",
                                Integer.toString(httpResponse.status)));
                    }

                    URI redirectLocation = URI.create(locationHeader.get(0)).normalize();

                    if (!redirectLocation.isAbsolute()) {
                        redirectLocation = path.resolve(redirectLocation);
                    }

                    // Locale independent lower-casing of the scheme
                    String redirectScheme = redirectLocation.getScheme().toLowerCase(Locale.ENGLISH);

                    // Map http/https redirect targets to their ws/wss equivalents
                    if (redirectScheme.startsWith("http")) {
                        redirectLocation = new URI(redirectScheme.replace("http", "ws"), redirectLocation.getUserInfo(),
                                redirectLocation.getHost(), redirectLocation.getPort(), redirectLocation.getPath(),
                                redirectLocation.getQuery(), redirectLocation.getFragment());
                    }

                    // Fail on a redirect loop or once the permitted number of redirects is exceeded
                    if (!redirectSet.add(redirectLocation) || redirectSet.size() > maxRedirects) {
                        throw new DeploymentException(
                                sm.getString("wsWebSocketContainer.redirectThreshold", redirectLocation,
                                        Integer.toString(redirectSet.size()), Integer.toString(maxRedirects)));
                    }

                    return connectToServerRecursive(clientEndpointHolder, clientEndpointConfiguration, redirectLocation,
                            redirectSet);

                } else if (httpResponse.status == Constants.UNAUTHORIZED) {
                    return processAuthenticationChallenge(clientEndpointHolder, clientEndpointConfiguration, path,
                            redirectSet, userProperties, request, httpResponse, AuthenticationType.WWW);

                } else {
                    throw new DeploymentException(
                            sm.getString("wsWebSocketContainer.invalidStatus", Integer.toString(httpResponse.status)));
                }
            }
            HandshakeResponse handshakeResponse = httpResponse.getHandshakeResponse();
            clientEndpointConfiguration.getConfigurator().afterResponse(handshakeResponse);

            // Sub-protocol
            List<String> protocolHeaders = handshakeResponse.getHeaders().get(Constants.WS_PROTOCOL_HEADER_NAME);
            if (protocolHeaders == null || protocolHeaders.size() == 0) {
                subProtocol = null;
            } else if (protocolHeaders.size() == 1) {
                subProtocol = protocolHeaders.get(0);
            } else {
                throw new DeploymentException(sm.getString("wsWebSocketContainer.invalidSubProtocol"));
            }

            // Extensions
            // Should normally only be one header but handle the case of
            // multiple headers
            List<String> extHeaders = handshakeResponse.getHeaders().get(Constants.WS_EXTENSIONS_HEADER_NAME);
            if (extHeaders != null) {
                for (String extHeader : extHeaders) {
                    Util.parseExtensionHeader(extensionsAgreed, extHeader);
                }
            }

            // Build the transformations
            TransformationFactory factory = TransformationFactory.getInstance();
            for (Extension extension : extensionsAgreed) {
                List<List<Extension.Parameter>> wrapper = new ArrayList<>(1);
                wrapper.add(extension.getParameters());
                Transformation t = factory.create(extension.getName(), wrapper, false);
                if (t == null) {
                    throw new DeploymentException(sm.getString("wsWebSocketContainer.invalidExtensionParameters"));
                }
                if (transformation == null) {
                    transformation = t;
                } else {
                    transformation.setNext(t);
                }
            }

            success = true;
        } catch (ExecutionException | InterruptedException | SSLException | EOFException | TimeoutException
                | URISyntaxException | AuthenticationException e) {
            if (e instanceof InterruptedException) {
                // Restore the interrupt status so callers can still observe the interruption
                Thread.currentThread().interrupt();
            }
            throw new DeploymentException(sm.getString("wsWebSocketContainer.httpRequestFailed", path), e);
        } finally {
            // Any exit other than a completed handshake (including redirect and authentication retries, which
            // open their own connection) must release this connection
            if (!success) {
                if (channel != null) {
                    channel.close();
                } else {
                    try {
                        socketChannel.close();
                    } catch (IOException ioe) {
                        // Ignore
                    }
                }
            }
        }

        // Switch to WebSocket
        WsRemoteEndpointImplClient wsRemoteEndpointClient = new WsRemoteEndpointImplClient(channel);

        WsSession wsSession = new WsSession(clientEndpointHolder, wsRemoteEndpointClient, this, extensionsAgreed,
                subProtocol, Collections.emptyMap(), secure, clientEndpointConfiguration);

        WsFrameClient wsFrameClient = new WsFrameClient(response, channel, wsSession, transformation);
        // WsFrame adds the necessary final transformations. Copy the
        // completed transformation chain to the remote end point.
        wsRemoteEndpointClient.setTransformation(wsFrameClient.getTransformation());

        wsSession.getLocal().onOpen(wsSession, clientEndpointConfiguration);
        registerSession(wsSession.getLocal(), wsSession);

        /*
         * It is possible that the server sent one or more messages as soon as the WebSocket connection was established.
         * Depending on the exact timing of when those messages were sent they could be sat in the input buffer waiting
         * to be read and will not trigger a "data available to read" event. Therefore, it is necessary to process the
         * input buffer here. Note that this happens on the current thread which means that this thread will be used for
         * any onMessage notifications. This is a special case. Subsequent "data available to read" events will be
         * handled by threads from the AsyncChannelGroup's executor.
         */
        wsFrameClient.startInputProcessing();

        return wsSession;
    }


    /**
     * Responds to an authentication challenge by computing an authorization value, storing it in the user
     * properties and retrying the connection.
     *
     * @param clientEndpointHolder        holder for the client endpoint, passed through to the retry
     * @param clientEndpointConfiguration configuration for the connection, passed through to the retry
     * @param path                        the URI being connected to
     * @param redirectSet                 the redirect locations followed so far, passed through to the retry
     * @param userProperties              user properties supplying the credentials; updated with the computed
     *                                        authorization value
     * @param request                     the handshake request that was challenged; the request target is
     *                                        extracted from its request line
     * @param httpResponse                the response carrying the challenge
     * @param authenticationType          identifies the header and property names to use for this kind of
     *                                        challenge
     *
     * @return the session opened by the retried connection
     *
     * @throws DeploymentException     if an authorization value had already been supplied (i.e. it was
     *                                     rejected), the challenge header is missing or empty, or the
     *                                     authentication scheme is not supported
     * @throws AuthenticationException if the authenticator cannot produce an authorization value
     */
    private Session processAuthenticationChallenge(ClientEndpointHolder clientEndpointHolder,
            ClientEndpointConfig clientEndpointConfiguration, URI path, Set<URI> redirectSet,
            Map<String,Object> userProperties, ByteBuffer request, HttpResponse httpResponse,
            AuthenticationType authenticationType) throws DeploymentException, AuthenticationException {

        // An authorization value is already present so the challenge means it was not accepted
        if (userProperties.get(authenticationType.getAuthorizationHeaderName()) != null) {
            throw new DeploymentException(sm.getString("wsWebSocketContainer.failedAuthentication",
                    Integer.valueOf(httpResponse.status), authenticationType.getAuthorizationHeaderName()));
        }

        List<String> authenticateHeaders = httpResponse.getHandshakeResponse().getHeaders()
                .get(authenticationType.getAuthenticateHeaderName());

        if (authenticateHeaders == null || authenticateHeaders.isEmpty() || authenticateHeaders.get(0) == null ||
                authenticateHeaders.get(0).isEmpty()) {
            throw new DeploymentException(sm.getString("wsWebSocketContainer.missingAuthenticateHeader",
                    Integer.toString(httpResponse.status), authenticationType.getAuthenticateHeaderName()));
        }

        // Only the first challenge is considered; the scheme is its first whitespace-delimited token
        String challenge = authenticateHeaders.get(0);
        String authScheme = challenge.split("\\s+", 2)[0];

        Authenticator auth = AuthenticatorFactory.getAuthenticator(authScheme);

        if (auth == null) {
            throw new DeploymentException(sm.getString("wsWebSocketContainer.unsupportedAuthScheme",
                    Integer.valueOf(httpResponse.status), authScheme));
        }

        // The request target is the second whitespace-delimited token of the request line
        String requestUri = new String(request.array(), StandardCharsets.ISO_8859_1).split("\\s", 3)[1];

        userProperties.put(authenticationType.getAuthorizationHeaderName(),
                auth.getAuthorization(requestUri, challenge,
                        (String) userProperties.get(authenticationType.getUserNameProperty()),
                        (String) userProperties.get(authenticationType.getUserPasswordProperty()),
                        (String) userProperties.get(authenticationType.getUserRealmProperty())));

        return connectToServerRecursive(clientEndpointHolder, clientEndpointConfiguration, path, redirectSet);
    }


    /**
     * Writes the remaining content of the buffer to the channel, blocking until it has all been written.
     *
     * @param channel the channel to write to
     * @param request the buffer holding the bytes to write, from its current position to its limit
     * @param timeout the maximum time, in milliseconds, to wait for each individual write to complete
     *
     * @throws TimeoutException     if a write does not complete within the timeout
     * @throws InterruptedException if the thread is interrupted while waiting for a write
     * @throws ExecutionException   if a write fails
     */
    private static void writeRequest(AsyncChannelWrapper channel, ByteBuffer request, long timeout)
            throws TimeoutException, InterruptedException, ExecutionException {
        // Count from remaining() rather than limit() so a buffer with a non-zero position cannot cause the
        // loop to wait for bytes that will never be written
        int toWrite = request.remaining();

        // At least one write is always issued, even for an empty buffer
        do {
            Future<Integer> fWrite = channel.write(request);
            Integer thisWrite = fWrite.get(timeout, TimeUnit.MILLISECONDS);
            toWrite -= thisWrite.intValue();
        } while (toWrite > 0);
    }


    /**
     * Tests whether an HTTP status code is one of the redirect codes this client follows.
     *
     * @param httpResponseCode the status code from the HTTP response
     *
     * @return {@code true} if the code matches one of the redirect constants listed below, otherwise
     *             {@code false}
     */
    private static boolean isRedirectStatus(int httpResponseCode) {
        switch (httpResponseCode) {
            case Constants.MULTIPLE_CHOICES:
            case Constants.MOVED_PERMANENTLY:
            case Constants.FOUND:
            case Constants.SEE_OTHER:
            case Constants.USE_PROXY:
            case Constants.TEMPORARY_REDIRECT:
                return true;
            default:
                return false;
        }
    }


    /**
     * Builds the HTTP CONNECT request used to ask a proxy to open a tunnel to the target host.
     *
     * @param host                the target host
     * @param port                the target port
     * @param authorizationHeader the value for the proxy authorization header, or {@code null} to omit the
     *                                header
     *
     * @return a buffer, ready to be written, containing the ISO-8859-1 encoded request
     */
    private static ByteBuffer createProxyRequest(String host, int port, String authorizationHeader) {
        StringBuilder request = new StringBuilder();
        request.append("CONNECT ");
        request.append(host);
        request.append(':');
        request.append(port);

        // "keep-alive" is the connection option token; the previous "keepalive" is not a recognised option
        request.append(" HTTP/1.1\r\nProxy-Connection: keep-alive\r\nConnection: keep-alive\r\nHost: ");
        request.append(host);
        request.append(':');
        request.append(port);

        if (authorizationHeader != null) {
            request.append("\r\n");
            request.append(Constants.PROXY_AUTHORIZATION_HEADER_NAME);
            request.append(':');
            request.append(authorizationHeader);
        }

        // Terminate the last header line and the header block
        request.append("\r\n\r\n");

        byte[] bytes = request.toString().getBytes(StandardCharsets.ISO_8859_1);
        return ByteBuffer.wrap(bytes);
    }

    /**
     * Records an open session against the given key. Registering the first session also registers this
     * container with the {@link BackgroundProcessManager}.
     *
     * @param key       the key the session is grouped under
     * @param wsSession the session to record; ignored if it is no longer open
     */
    protected void registerSession(Object key, WsSession wsSession) {
        // A session closed during onOpen is not registered
        if (wsSession.isOpen()) {
            synchronized (endPointSessionMapLock) {
                if (endpointSessionMap.isEmpty()) {
                    BackgroundProcessManager.getInstance().register(this);
                }
                endpointSessionMap.computeIfAbsent(key, unused -> new HashSet<>()).add(wsSession);
            }
            sessions.put(wsSession, wsSession);
        }
    }


    /**
     * Removes a session previously recorded against the given key. When no sessions remain for any key this
     * container is unregistered from the {@link BackgroundProcessManager}.
     *
     * @param key       the key the session was grouped under
     * @param wsSession the session to remove
     */
    protected void unregisterSession(Object key, WsSession wsSession) {

        synchronized (endPointSessionMapLock) {
            Set<WsSession> wsSessions = endpointSessionMap.get(key);
            if (wsSessions != null) {
                wsSessions.remove(wsSession);
                // Drop the key once its last session has gone
                if (wsSessions.isEmpty()) {
                    endpointSessionMap.remove(key);
                }
            }
            if (endpointSessionMap.isEmpty()) {
                BackgroundProcessManager.getInstance().unregister(this);
            }
        }
        sessions.remove(wsSession);
    }


    /**
     * Returns the sessions recorded against the given key that are currently open.
     *
     * @param key the key the sessions are grouped under
     *
     * @return a new set containing the open sessions; empty if there are none
     */
    Set<Session> getOpenSessions(Object key) {
        HashSet<Session> result = new HashSet<>();
        synchronized (endPointSessionMapLock) {
            // Named so that it does not shadow the 'sessions' field
            Set<WsSession> sessionsForKey = endpointSessionMap.get(key);
            if (sessionsForKey != null) {
                // Some sessions may be in the process of closing
                for (WsSession session : sessionsForKey) {
                    if (session.isOpen()) {
                        result.add(session);
                    }
                }
            }
        }
        return result;
    }

    private static Map> createRequestHeaders(String host, int port, boolean secure,
            ClientEndpointConfig clientEndpointConfiguration) {

        Map> headers = new HashMap<>();
        List extensions = clientEndpointConfiguration.getExtensions();
        List subProtocols = clientEndpointConfiguration.getPreferredSubprotocols();
        Map userProperties = clientEndpointConfiguration.getUserProperties();

        if (userProperties.get(Constants.AUTHORIZATION_HEADER_NAME) != null) {
            List authValues = new ArrayList<>(1);
            authValues.add((String) userProperties.get(Constants.AUTHORIZATION_HEADER_NAME));
            headers.put(Constants.AUTHORIZATION_HEADER_NAME, authValues);
        }

        // Host header
        List hostValues = new ArrayList<>(1);
        if (port == 80 && !secure || port == 443 && secure) {
            // Default ports. Do not include port in host header
            hostValues.add(host);
        } else {
            hostValues.add(host + ':' + port);
        }

        headers.put(Constants.HOST_HEADER_NAME, hostValues);

        // Upgrade header
        List upgradeValues = new ArrayList<>(1);
        upgradeValues.add(Constants.UPGRADE_HEADER_VALUE);
        headers.put(Constants.UPGRADE_HEADER_NAME, upgradeValues);

        // Connection header
        List connectionValues = new ArrayList<>(1);
        connectionValues.add(Constants.CONNECTION_HEADER_VALUE);
        headers.put(Constants.CONNECTION_HEADER_NAME, connectionValues);

        // WebSocket version header
        List wsVersionValues = new ArrayList<>(1);
        wsVersionValues.add(Constants.WS_VERSION_HEADER_VALUE);
        headers.put(Constants.WS_VERSION_HEADER_NAME, wsVersionValues);

        // WebSocket key
        List wsKeyValues = new ArrayList<>(1);
        wsKeyValues.add(generateWsKeyValue());
        headers.put(Constants.WS_KEY_HEADER_NAME, wsKeyValues);

        // WebSocket sub-protocols
        if (subProtocols != null && subProtocols.size() > 0) {
            headers.put(Constants.WS_PROTOCOL_HEADER_NAME, subProtocols);
        }

        // WebSocket extensions
        if (extensions != null && extensions.size() > 0) {
            headers.put(Constants.WS_EXTENSIONS_HEADER_NAME, generateExtensionHeaders(extensions));
        }

        return headers;
    }


    /**
     * Converts each extension into a header value of the form {@code name;param=value;param}.
     *
     * @param extensions the extensions to convert
     *
     * @return one header value per extension, in the same order; a parameter whose value is {@code null} or
     *             empty is written as its name only
     */
    private static List<String> generateExtensionHeaders(List<Extension> extensions) {
        List<String> result = new ArrayList<>(extensions.size());
        for (Extension extension : extensions) {
            StringBuilder header = new StringBuilder();
            header.append(extension.getName());
            for (Extension.Parameter param : extension.getParameters()) {
                header.append(';');
                header.append(param.getName());
                String value = param.getValue();
                if (value != null && value.length() > 0) {
                    header.append('=');
                    header.append(value);
                }
            }
            result.add(header.toString());
        }
        return result;
    }


    /**
     * Generates a WebSocket key value for a handshake request.
     *
     * @return the Base64 encoding of 16 bytes obtained from {@code RANDOM}
     */
    private static String generateWsKeyValue() {
        final byte[] nonce = new byte[16];
        RANDOM.nextBytes(nonce);
        return Base64.getEncoder().encodeToString(nonce);
    }


    private static ByteBuffer createRequest(URI uri, Map> reqHeaders) {
        ByteBuffer result = ByteBuffer.allocate(4 * 1024);

        // Request line
        result.put(GET_BYTES);
        final String path = uri.getPath();
        if (null == path || path.isEmpty()) {
            result.put(ROOT_URI_BYTES);
        } else {
            result.put(uri.getRawPath().getBytes(StandardCharsets.ISO_8859_1));
        }
        String query = uri.getRawQuery();
        if (query != null) {
            result.put((byte) '?');
            result.put(query.getBytes(StandardCharsets.ISO_8859_1));
        }
        result.put(HTTP_VERSION_BYTES);

        // Headers
        for (Entry> entry : reqHeaders.entrySet()) {
            result = addHeader(result, entry.getKey(), entry.getValue());
        }

        // Terminating CRLF
        result.put(CRLF);

        result.flip();

        return result;
    }


    /**
     * Appends a single header line ({@code key: joined-values CRLF}) to the buffer, expanding the buffer if
     * required.
     *
     * @param result the buffer to append to
     * @param key    the header name
     * @param values the header values, combined into one line with {@code StringUtils.join}; if empty no header
     *                   is written
     *
     * @return the buffer containing the header, which may be a new, larger buffer than the one passed in
     */
    private static ByteBuffer addHeader(ByteBuffer result, String key, List<String> values) {
        if (values.isEmpty()) {
            return result;
        }

        result = putWithExpand(result, key.getBytes(StandardCharsets.ISO_8859_1));
        result = putWithExpand(result, ": ".getBytes(StandardCharsets.ISO_8859_1));
        result = putWithExpand(result, StringUtils.join(values).getBytes(StandardCharsets.ISO_8859_1));
        result = putWithExpand(result, CRLF);

        return result;
    }


    /**
     * Appends {@code bytes} to {@code input}. When {@code input} does not have enough room left, a new heap buffer is
     * allocated, the content written so far is copied into it, and the new buffer is used instead.
     *
     * @param input buffer in write mode that receives the data
     * @param bytes data to append
     *
     * @return the buffer containing the appended data: {@code input} itself if it had room, otherwise the new buffer
     */
    private static ByteBuffer putWithExpand(ByteBuffer input, byte[] bytes) {
        ByteBuffer target = input;
        if (input.remaining() < bytes.length) {
            // Twice the larger of "current capacity" and "incoming data" always fits old content plus new data.
            final int grownCapacity = 2 * Math.max(bytes.length, input.capacity());
            target = ByteBuffer.allocate(grownCapacity);
            // Switch the old buffer to read mode so that what was written so far can be copied across.
            input.flip();
            target.put(input);
        }
        target.put(bytes);
        return target;
    }


    /**
     * Process response, blocking until HTTP response has been fully received.
     *
     * @param response buffer used to read from the channel; it is cleared before every read, and any bytes that
     *                     follow the blank line terminating the headers are left unread in it
     * @param channel  channel the response is read from
     * @param timeout  maximum time, in milliseconds, to wait for each individual read to complete
     *
     * @return the parsed status code together with the parsed response headers
     *
     * @throws ExecutionException   if there is an exception reading the response
     * @throws InterruptedException if the thread is interrupted while reading the response
     * @throws DeploymentException  if the response status line is not correctly formatted
     * @throws EOFException         if the channel reports end of stream before the headers have been fully read
     * @throws TimeoutException     if the response was not read within the expected timeout
     */
    private HttpResponse processResponse(ByteBuffer response, AsyncChannelWrapper channel, long timeout)
            throws InterruptedException, ExecutionException, DeploymentException, EOFException, TimeoutException {

        Map<String,List<String>> headers = new CaseInsensitiveKeyMap<>();

        int status = 0;
        boolean readStatus = false;
        boolean readHeaders = false;
        // Holds a partially received line between reads; null when the next line starts fresh.
        String line = null;
        while (!readHeaders) {
            // On entering loop buffer will be empty and at the start of a new
            // loop the buffer will have been fully read.
            response.clear();
            // Blocking read
            Future<Integer> read = channel.read(response);
            Integer bytesRead;
            try {
                bytesRead = read.get(timeout, TimeUnit.MILLISECONDS);
            } catch (TimeoutException e) {
                // Re-throw with a message describing what had been received so far, keeping the original as cause.
                TimeoutException te = new TimeoutException(
                        sm.getString("wsWebSocketContainer.responseFail", Integer.toString(status), headers));
                te.initCause(e);
                throw te;
            }
            if (bytesRead.intValue() == -1) {
                throw new EOFException(
                        sm.getString("wsWebSocketContainer.responseFail", Integer.toString(status), headers));
            }
            response.flip();
            while (response.hasRemaining() && !readHeaders) {
                if (line == null) {
                    line = readLine(response);
                } else {
                    // Continue a line that was split across two reads.
                    line += readLine(response);
                }
                if ("\r\n".equals(line)) {
                    // Blank line: end of the header section.
                    readHeaders = true;
                } else if (line.endsWith("\r\n")) {
                    if (readStatus) {
                        parseHeaders(line, headers);
                    } else {
                        status = parseStatus(line);
                        readStatus = true;
                    }
                    line = null;
                }
            }
        }

        return new HttpResponse(status, new WsHandshakeResponse(headers));
    }


    /**
     * Extracts the numeric status code from an HTTP status line.
     *
     * @param line the status line as received; surrounding whitespace (including the line terminator) is trimmed
     *
     * @return the status code parsed from the second space-separated token
     *
     * @throws DeploymentException if the line has fewer than two tokens, the first token is neither
     *                                 {@code HTTP/1.0} nor {@code HTTP/1.1}, or the second token is not an integer
     */
    private int parseStatus(String line) throws DeploymentException {
        // This client only understands HTTP 1.
        // RFC2616 is case specific
        String[] parts = line.trim().split(" ");
        // CONNECT for proxy may return a 1.0 response
        if (parts.length < 2 || !("HTTP/1.0".equals(parts[0]) || "HTTP/1.1".equals(parts[0]))) {
            throw new DeploymentException(sm.getString("wsWebSocketContainer.invalidStatus", line));
        }
        try {
            return Integer.parseInt(parts[1]);
        } catch (NumberFormatException nfe) {
            // Keep the parse failure as the cause so the offending token is visible in the stack trace.
            throw new DeploymentException(sm.getString("wsWebSocketContainer.invalidStatus", line), nfe);
        }
    }


    private void parseHeaders(String line, Map> headers) {
        // Treat headers as single values by default.

        int index = line.indexOf(':');
        if (index == -1) {
            log.warn(sm.getString("wsWebSocketContainer.invalidHeader", line));
            return;
        }
        // Header names are case insensitive so always use lower case
        String headerName = line.substring(0, index).trim().toLowerCase(Locale.ENGLISH);
        // Multi-value headers are stored as a single header and the client is
        // expected to handle splitting into individual values
        String headerValue = line.substring(index + 1).trim();

        List values = headers.computeIfAbsent(headerName, k -> new ArrayList<>(1));
        values.add(headerValue);
    }

    /**
     * Reads bytes from the buffer up to and including the next line feed, or until the buffer is exhausted,
     * decoding each byte as one ISO-8859-1 character.
     *
     * @param response buffer in read mode; its position is advanced past the bytes consumed
     *
     * @return the characters read; ends with a line feed only if one was found before the buffer ran out
     */
    private String readLine(ByteBuffer response) {
        // All ISO-8859-1
        StringBuilder sb = new StringBuilder();

        while (response.hasRemaining()) {
            // Mask to an unsigned value first: casting a negative byte straight to char sign-extends it
            // (0xE9 would become U+FFE9 rather than U+00E9).
            char c = (char) (response.get() & 0xFF);
            sb.append(c);
            if (c == '\n') {
                break;
            }
        }

        return sb.toString();
    }


    /**
     * Creates a client-mode {@link SSLEngine} for the given peer, with the endpoint identification algorithm set to
     * {@code HTTPS}.
     * <p>
     * If the user properties hold an {@link SSLContext} under {@code Constants.SSL_CONTEXT_PROPERTY} that context is
     * used. Otherwise a {@code TLS} context is created; when {@code Constants.SSL_TRUSTSTORE_PROPERTY} is set, the
     * named file is loaded as a JKS store and used as the source of trust managers, else the context is initialised
     * with all-null arguments. When {@code Constants.SSL_PROTOCOLS_PROPERTY} is set, its comma-separated
     * value is applied as the engine's enabled protocols.
     *
     * @param clientEndpointConfig configuration whose user properties supply the optional SSL settings
     * @param host                 peer host passed to {@link SSLContext#createSSLEngine(String, int)}
     * @param port                 peer port passed to {@link SSLContext#createSSLEngine(String, int)}
     *
     * @return the configured engine
     *
     * @throws DeploymentException if any step of creating or configuring the engine fails; the original exception is
     *                                 attached as the cause
     */
    private SSLEngine createSSLEngine(ClientEndpointConfig clientEndpointConfig, String host, int port)
            throws DeploymentException {

        Map<String,Object> userProperties = clientEndpointConfig.getUserProperties();
        try {
            // See if a custom SSLContext has been provided
            SSLContext sslContext = (SSLContext) userProperties.get(Constants.SSL_CONTEXT_PROPERTY);

            if (sslContext == null) {
                // Create the SSL Context
                sslContext = SSLContext.getInstance("TLS");

                // Trust store
                String sslTrustStoreValue = (String) userProperties.get(Constants.SSL_TRUSTSTORE_PROPERTY);
                if (sslTrustStoreValue != null) {
                    String sslTrustStorePwdValue = (String) userProperties.get(Constants.SSL_TRUSTSTORE_PWD_PROPERTY);
                    if (sslTrustStorePwdValue == null) {
                        sslTrustStorePwdValue = Constants.SSL_TRUSTSTORE_PWD_DEFAULT;
                    }

                    File keyStoreFile = new File(sslTrustStoreValue);
                    KeyStore ks = KeyStore.getInstance("JKS");
                    try (InputStream is = new FileInputStream(keyStoreFile)) {
                        KeyStoreUtil.load(ks, is, sslTrustStorePwdValue.toCharArray());
                    }

                    TrustManagerFactory tmf = TrustManagerFactory
                            .getInstance(TrustManagerFactory.getDefaultAlgorithm());
                    tmf.init(ks);

                    sslContext.init(null, tmf.getTrustManagers(), null);
                } else {
                    sslContext.init(null, null, null);
                }
            }

            SSLEngine engine = sslContext.createSSLEngine(host, port);

            String sslProtocolsValue = (String) userProperties.get(Constants.SSL_PROTOCOLS_PROPERTY);
            if (sslProtocolsValue != null) {
                engine.setEnabledProtocols(sslProtocolsValue.split(","));
            }

            engine.setUseClientMode(true);

            // Enable host verification
            // Start with current settings (returns a copy)
            SSLParameters sslParams = engine.getSSLParameters();
            // Use HTTPS since WebSocket starts over HTTP(S)
            sslParams.setEndpointIdentificationAlgorithm("HTTPS");
            // Write the parameters back
            engine.setSSLParameters(sslParams);

            return engine;
        } catch (Exception e) {
            throw new DeploymentException(sm.getString("wsWebSocketContainer.sslEngineFail"), e);
        }
    }


    /**
     * {@inheritDoc} Returns the value currently held by this container.
     */
    @Override
    public long getDefaultMaxSessionIdleTimeout() {
        return this.defaultMaxSessionIdleTimeout;
    }


    /**
     * {@inheritDoc} The value is stored as given, without validation.
     */
    @Override
    public void setDefaultMaxSessionIdleTimeout(long timeout) {
        defaultMaxSessionIdleTimeout = timeout;
    }


    /**
     * {@inheritDoc} Returns the value currently held by this container.
     */
    @Override
    public int getDefaultMaxBinaryMessageBufferSize() {
        return this.maxBinaryMessageBufferSize;
    }


    /**
     * {@inheritDoc} The value is stored as given, without validation.
     */
    @Override
    public void setDefaultMaxBinaryMessageBufferSize(int max) {
        this.maxBinaryMessageBufferSize = max;
    }


    /**
     * {@inheritDoc} Returns the value currently held by this container.
     */
    @Override
    public int getDefaultMaxTextMessageBufferSize() {
        return this.maxTextMessageBufferSize;
    }


    /**
     * {@inheritDoc} The value is stored as given, without validation.
     */
    @Override
    public void setDefaultMaxTextMessageBufferSize(int max) {
        this.maxTextMessageBufferSize = max;
    }


    /**
     * {@inheritDoc} Currently, this implementation does not support any extensions.
     *
     * @return always an empty set
     */
    @Override
    public Set<Extension> getInstalledExtensions() {
        return Collections.emptySet();
    }


    /**
     * {@inheritDoc} The default value for this implementation is -1.
     *
     * @return the asynchronous send timeout currently held by this container
     */
    @Override
    public long getDefaultAsyncSendTimeout() {
        return this.defaultAsyncTimeout;
    }


    /**
     * {@inheritDoc} The default value for this implementation is -1.
     *
     * @param timeout the new asynchronous send timeout; stored as given, without validation
     */
    @Override
    public void setAsyncSendTimeout(long timeout) {
        defaultAsyncTimeout = timeout;
    }


    /**
     * Releases what this container still holds on behalf of its WebSocket sessions: every registered session is
     * closed with a {@code GOING_AWAY} close reason and, if this instance registered with
     * {@code AsyncChannelGroupUtil}, that registration is undone.
     * <p>
     * A failure to close an individual session is logged at debug level and does not stop the remaining sessions from
     * being closed.
     */
    public void destroy() {
        final CloseReason shutdownReason =
                new CloseReason(CloseCodes.GOING_AWAY, sm.getString("wsWebSocketContainer.shutdown"));

        for (WsSession wsSession : sessions.keySet()) {
            try {
                wsSession.close(shutdownReason);
            } catch (IOException e) {
                log.debug(sm.getString("wsWebSocketContainer.sessionCloseFail", wsSession.getId()), e);
            }
        }

        // Nothing to unregister unless this instance registered with AsyncChannelGroupUtil
        if (asynchronousChannelGroup == null) {
            return;
        }
        synchronized (asynchronousChannelGroupLock) {
            // Re-check while holding the lock before unregistering
            if (asynchronousChannelGroup != null) {
                AsyncChannelGroupUtil.unregister();
                asynchronousChannelGroup = null;
            }
        }
    }


    /**
     * Returns the channel group used by this container, registering with {@code AsyncChannelGroupUtil} on first use
     * so that a common group is shared amongst all WebSocket clients.
     *
     * @return the channel group held by this container
     */
    private AsynchronousChannelGroup getAsynchronousChannelGroup() {
        final AsynchronousChannelGroup existing = asynchronousChannelGroup;
        if (existing != null) {
            return existing;
        }
        synchronized (asynchronousChannelGroupLock) {
            // Re-check while holding the lock so that registration happens at most once per unset field
            if (asynchronousChannelGroup == null) {
                asynchronousChannelGroup = AsyncChannelGroupUtil.register();
            }
            return asynchronousChannelGroup;
        }
    }


    // ----------------------------------------------- BackgroundProcess methods

    /**
     * Counts invocations and, once the count reaches the configured process period, resets the count and asks every
     * registered session to check its expiration and its close timeout.
     */
    @Override
    public void backgroundProcess() {
        // This method gets called once a second.
        backgroundProcessCount++;
        if (backgroundProcessCount < processPeriod) {
            // Not yet time for the periodic check
            return;
        }
        backgroundProcessCount = 0;

        // Check all registered sessions.
        for (WsSession registered : sessions.keySet()) {
            registered.checkExpiration();
            registered.checkCloseTimeout();
        }
    }


    /**
     * {@inheritDoc} The value is stored as given, without validation.
     */
    @Override
    public void setProcessPeriod(int period) {
        processPeriod = period;
    }


    /**
     * {@inheritDoc} The default value is 10 which means session expirations are processed every 10 seconds.
     *
     * @return the process period currently held by this container
     */
    @Override
    public int getProcessPeriod() {
        return this.processPeriod;
    }


    /**
     * Immutable holder pairing an integer status with a {@link HandshakeResponse}.
     */
    private static class HttpResponse {
        private final int status;
        private final HandshakeResponse handshakeResponse;

        /**
         * @param status            status value to hold
         * @param handshakeResponse handshake response to hold
         */
        HttpResponse(int status, HandshakeResponse handshakeResponse) {
            this.handshakeResponse = handshakeResponse;
            this.status = status;
        }


        /**
         * @return the status supplied at construction
         */
        public int getStatus() {
            return this.status;
        }


        /**
         * @return the handshake response supplied at construction
         */
        public HandshakeResponse getHandshakeResponse() {
            return this.handshakeResponse;
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy