All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.github.netty.springboot.server.NettyRequestUpgradeStrategy Maven / Gradle / Ivy

The newest version!
package com.github.netty.springboot.server;

import com.github.netty.protocol.servlet.DispatcherChannelHandler;
import com.github.netty.protocol.servlet.ServletHttpExchange;
import com.github.netty.protocol.servlet.ServletHttpServletRequest;
import com.github.netty.protocol.servlet.util.HttpHeaderConstants;
import com.github.netty.protocol.servlet.util.ServletUtil;
import com.github.netty.protocol.servlet.websocket.NettyMessageToWebSocketRunnable;
import com.github.netty.protocol.servlet.websocket.WebSocketServerContainer;
import com.github.netty.protocol.servlet.websocket.WebSocketServerHandshaker13Extension;
import com.github.netty.protocol.servlet.websocket.WebSocketSession;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.EmptyHttpHeaders;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.websocketx.WebSocketVersion;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.util.LinkedCaseInsensitiveMap;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.web.socket.server.HandshakeFailureException;
import org.springframework.web.socket.server.standard.AbstractStandardUpgradeStrategy;
import org.springframework.web.socket.server.standard.ServerEndpointRegistration;

import javax.servlet.http.HttpServletRequest;
import javax.websocket.Endpoint;
import javax.websocket.Extension;
import javax.websocket.WebSocketContainer;
import javax.websocket.server.ServerEndpointConfig;
import java.security.Principal;
import java.util.*;

/**
 * Websocket version number: the version number of draft 8 to draft 12 is 8, and the version number of draft 13 and later is the same as the draft number
 *
 * @author wangzihao
 */
public class NettyRequestUpgradeStrategy extends AbstractStandardUpgradeStrategy {
    private int maxFramePayloadLength;
    private static final String[] SUPPORTED_VERSIONS = new String[]{WebSocketVersion.V13.toHttpHeaderValue()};

    public NettyRequestUpgradeStrategy() {
        this(65536);
    }

    public NettyRequestUpgradeStrategy(int maxFramePayloadLength) {
        this.maxFramePayloadLength = maxFramePayloadLength;
    }

    @Override
    public String[] getSupportedVersions() {
        return SUPPORTED_VERSIONS;
    }

    @Override
    protected void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, String selectedProtocol,
                                   List selectedExtensions, Endpoint endpoint) throws HandshakeFailureException {
        HttpServletRequest servletRequest = getHttpServletRequest(request);
        ServletHttpServletRequest httpServletRequest = ServletUtil.unWrapper(servletRequest);
        if (httpServletRequest == null) {
            throw new HandshakeFailureException(
                    "Servlet request failed to upgrade to WebSocket: " + servletRequest.getRequestURL());
        }

        WebSocketServerContainer serverContainer = getContainer(servletRequest);
        Principal principal = request.getPrincipal();
        Map pathParams = new LinkedHashMap<>(3);

        ServerEndpointRegistration endpointConfig = new ServerEndpointRegistration(servletRequest.getRequestURI(), endpoint);
        List subprotocols = new ArrayList<>();
        subprotocols.add("*");
        if (selectedProtocol != null && !subprotocols.contains(selectedProtocol)) {
            subprotocols.add(selectedProtocol);
        }
        endpointConfig.setSubprotocols(subprotocols);
        if (selectedExtensions != null) {
            endpointConfig.setExtensions(selectedExtensions);
        }

        try {
            handshakeToWebsocket(httpServletRequest, selectedProtocol, maxFramePayloadLength, principal,
                    selectedExtensions, pathParams, endpoint,
                    endpointConfig, serverContainer);
        } catch (Exception e) {
            throw new HandshakeFailureException(
                    "Servlet request failed to upgrade to WebSocket: " + servletRequest.getRequestURL(), e);
        }
    }

    @Override
    protected List getInstalledExtensions(WebSocketContainer container) {
        List result = new ArrayList<>();
        for (Extension extension : container.getInstalledExtensions()) {
            Map parameters = new LinkedCaseInsensitiveMap<>(Locale.ENGLISH);
            for (Extension.Parameter parameter : extension.getParameters()) {
                parameters.put(parameter.getName(), parameter.getValue());
            }
            result.add(new WebSocketExtension(extension.getName(), parameters));
        }
        return result;
    }

    @Override
    protected WebSocketServerContainer getContainer(HttpServletRequest request) {
        return (WebSocketServerContainer) super.getContainer(request);
    }

    /**
     * The WebSocket handshake
     *
     * @param servletRequest        servletRequest
     * @param subprotocols          subprotocols
     * @param maxFramePayloadLength maxFramePayloadLength
     * @param userPrincipal         userPrincipal
     * @param negotiatedExtensions  negotiatedExtensions
     * @param pathParameters        pathParameters
     * @param localEndpoint         localEndpoint
     * @param endpointConfig        endpointConfig
     * @param webSocketContainer    webSocketContainer
     */
    protected void handshakeToWebsocket(ServletHttpServletRequest servletRequest, String subprotocols, int maxFramePayloadLength, Principal userPrincipal,
                                        List negotiatedExtensions, Map pathParameters,
                                        Endpoint localEndpoint, ServerEndpointConfig endpointConfig, WebSocketServerContainer webSocketContainer) {
        FullHttpRequest nettyRequest = convertFullHttpRequest(servletRequest);
        ServletHttpExchange exchange = servletRequest.getHttpExchange();
        exchange.setWebsocket(true);
        String queryString = servletRequest.getQueryString();
        String httpSessionId = servletRequest.getRequestedSessionId();
        String webSocketURL = getWebSocketLocation(servletRequest);
        Map> requestParameterMap = getRequestParameterMap(servletRequest);

        WebSocketServerHandshaker13Extension wsHandshaker = new WebSocketServerHandshaker13Extension(webSocketURL, subprotocols, true, maxFramePayloadLength);
        ChannelFuture handshakelFuture = wsHandshaker.handshake(exchange.getChannelHandlerContext().channel(), nettyRequest);
        handshakelFuture.addListener((ChannelFutureListener) future -> {
            if (future.isSuccess()) {
                Channel channel = future.channel();
                DispatcherChannelHandler.setMessageToRunnable(channel, new NettyMessageToWebSocketRunnable(DispatcherChannelHandler.getMessageToRunnable(channel)));
                WebSocketSession websocketSession = new WebSocketSession(
                        channel, webSocketContainer, wsHandshaker,
                        requestParameterMap,
                        queryString, userPrincipal, httpSessionId,
                        negotiatedExtensions, pathParameters, localEndpoint, endpointConfig);

                WebSocketSession.setSession(channel, websocketSession);

                localEndpoint.onOpen(websocketSession, endpointConfig);
            } else {
                logger.warn("The Websocket handshake failed : " + webSocketURL, future.cause());
            }
        });
    }

    private FullHttpRequest convertFullHttpRequest(ServletHttpServletRequest request) {
        HttpRequest nettyRequest = request.getNettyRequest();
        if (nettyRequest instanceof FullHttpRequest) {
            return (FullHttpRequest) nettyRequest;
        }
        return new DefaultFullHttpRequest(nettyRequest.protocolVersion(), nettyRequest.method(), nettyRequest.uri(), Unpooled.buffer(0), nettyRequest.headers(), EmptyHttpHeaders.INSTANCE);
    }

    protected Map> getRequestParameterMap(HttpServletRequest request) {
        MultiValueMap requestParameterMap = new LinkedMultiValueMap<>();
        for (Map.Entry entry : request.getParameterMap().entrySet()) {
            for (String value : entry.getValue()) {
                requestParameterMap.add(entry.getKey(), value);
            }
        }
        return requestParameterMap;
    }

    protected String getWebSocketLocation(HttpServletRequest req) {
        String host = req.getHeader(HttpHeaderConstants.HOST.toString());
        if (host == null || host.isEmpty()) {
            host = req.getServerName();
        }
        String scheme = req.isSecure() ? "wss://" : "ws://";
        return scheme + host + req.getRequestURI();
    }

    public int getMaxFramePayloadLength() {
        return maxFramePayloadLength;
    }

    public void setMaxFramePayloadLength(int maxFramePayloadLength) {
        this.maxFramePayloadLength = maxFramePayloadLength;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy