All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.aoju.bus.socket.spring.extension.StompEndpointRegistry Maven / Gradle / Ivy

/*********************************************************************************
 *                                                                               *
 * The MIT License                                                               *
 *                                                                               *
 * Copyright (c) 2015-2020 aoju.org and other contributors.                      *
 *                                                                               *
 * Permission is hereby granted, free of charge, to any person obtaining a copy  *
 * of this software and associated documentation files (the "Software"), to deal *
 * in the Software without restriction, including without limitation the rights  *
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell     *
 * copies of the Software, and to permit persons to whom the Software is         *
 * furnished to do so, subject to the following conditions:                      *
 *                                                                               *
 * The above copyright notice and this permission notice shall be included in    *
 * all copies or substantial portions of the Software.                           *
 *                                                                               *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR    *
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,      *
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE   *
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER        *
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, *
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN     *
 * THE SOFTWARE.                                                                 *
 ********************************************************************************/
package org.aoju.bus.socket.spring.extension;

import org.aoju.bus.core.lang.exception.SocketException;
import org.aoju.bus.logger.Logger;
import org.aoju.bus.socket.spring.intercept.FromClientInterceptor;
import org.aoju.bus.socket.spring.intercept.ToClientInterceptor;
import org.springframework.context.ApplicationContext;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.MultiValueMap;
import org.springframework.web.HttpRequestHandler;
import org.springframework.web.servlet.handler.AbstractHandlerMapping;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.config.annotation.StompWebSocketEndpointRegistration;
import org.springframework.web.socket.config.annotation.WebMvcStompEndpointRegistry;
import org.springframework.web.socket.config.annotation.WebMvcStompWebSocketEndpointRegistration;
import org.springframework.web.socket.config.annotation.WebSocketTransportRegistration;
import org.springframework.web.socket.handler.WebSocketHandlerDecorator;
import org.springframework.web.socket.messaging.StompSubProtocolErrorHandler;
import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler;
import org.springframework.web.socket.server.support.WebSocketHandlerMapping;
import org.springframework.web.util.UrlPathHelper;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * 代替{@link WebMvcStompEndpointRegistry}
 *
 * @author Kimi Liu
 * @version 5.8.5
 * @since JDK 1.8+
 */
public class StompEndpointRegistry extends WebMvcStompEndpointRegistry {

    // 代替父类中的属性
    private final WebSocketHandler webSocketHandler;

    private final TaskScheduler sockJsScheduler;
    private final SubProtocolWebSocketHandler subProtocolWebSocketHandler;
    // 使用CustomizeStompSubProtocolHandler代替StompSubProtocolHandler
    private final StompSubProtocolHandler stompHandler;
    private final List registrations = new ArrayList<>();
    private int order = 1;
    private UrlPathHelper urlPathHelper;

    public StompEndpointRegistry(WebSocketHandler webSocketHandler, WebSocketTransportRegistration transportRegistration, TaskScheduler defaultSockJsTaskScheduler) {
        super(webSocketHandler, transportRegistration, defaultSockJsTaskScheduler);
        this.webSocketHandler = webSocketHandler;
        this.subProtocolWebSocketHandler = unwrapSubProtocolWebSocketHandler(webSocketHandler);
        // 使用类反射获取transportRegistration中的属性
        Integer sendTimeLimit = getTransportRegistrationValue(transportRegistration, "sendTimeLimit");
        Integer sendBufferSizeLimit = getTransportRegistrationValue(transportRegistration, "sendBufferSizeLimit");
        Integer timeToFirstMessage = getTransportRegistrationValue(transportRegistration, "timeToFirstMessage");
        Integer messageSizeLimit = getTransportRegistrationValue(transportRegistration, "messageSizeLimit");
        if (sendTimeLimit != null) {
            this.subProtocolWebSocketHandler.setSendTimeLimit(sendTimeLimit);
        }
        if (sendBufferSizeLimit != null) {
            this.subProtocolWebSocketHandler.setSendBufferSizeLimit(sendBufferSizeLimit);
        }
        if (timeToFirstMessage != null) {
            this.subProtocolWebSocketHandler.setTimeToFirstMessage(timeToFirstMessage);
        }
        // 替换为自定义的stompHandler
        this.stompHandler = new StompSubProtocolHandler();
        if (messageSizeLimit != null) {
            this.stompHandler.setMessageSizeLimit(messageSizeLimit);
        }
        this.sockJsScheduler = defaultSockJsTaskScheduler;
    }

    private static SubProtocolWebSocketHandler unwrapSubProtocolWebSocketHandler(WebSocketHandler handler) {
        WebSocketHandler actual = WebSocketHandlerDecorator.unwrap(handler);
        if (!(actual instanceof SubProtocolWebSocketHandler)) {
            throw new IllegalArgumentException("No SubProtocolWebSocketHandler in " + handler);
        }
        return (SubProtocolWebSocketHandler) actual;
    }

    // 使用反射获取WebSocketTransportRegistration实例的属性
    private static Integer getTransportRegistrationValue(WebSocketTransportRegistration transportRegistration, String fieldName) {
        Integer ret = null;
        try {
            Field limitField = WebSocketTransportRegistration.class.getDeclaredField(fieldName);
            limitField.setAccessible(true);
            Object value = limitField.get(transportRegistration);
            if (value != null) {
                ret = (Integer) value;
            }
        } catch (NoSuchFieldException | IllegalAccessException e) {
            Logger.error(e.getMessage(), e);
            throw new SocketException("获取" + fieldName + "的值出错", e);
        }
        return ret;
    }

    @Override
    public StompWebSocketEndpointRegistration addEndpoint(String... paths) {
        this.subProtocolWebSocketHandler.addProtocolHandler(this.stompHandler);
        WebMvcStompWebSocketEndpointRegistration registration =
                new WebMvcStompWebSocketEndpointRegistration(paths, this.webSocketHandler, this.sockJsScheduler);
        this.registrations.add(registration);
        return registration;
    }

    @Override
    protected int getOrder() {
        return this.order;
    }

    @Override
    public void setOrder(int order) {
        this.order = order;
    }

    @Override
    protected UrlPathHelper getUrlPathHelper() {
        return this.urlPathHelper;
    }

    @Override
    public void setUrlPathHelper(UrlPathHelper urlPathHelper) {
        this.urlPathHelper = urlPathHelper;
    }

    @Override
    public StompEndpointRegistry setErrorHandler(StompSubProtocolErrorHandler errorHandler) {
        this.stompHandler.setErrorHandler(errorHandler);
        return this;
    }

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) {
        super.setApplicationContext(applicationContext);
    }

    public AbstractHandlerMapping getHandlerMapping() {
        Map urlMap = new LinkedHashMap<>();
        for (WebMvcStompWebSocketEndpointRegistration registration : this.registrations) {
            MultiValueMap mappings = registration.getMappings();
            mappings.forEach((httpHandler, patterns) -> {
                for (String pattern : patterns) {
                    urlMap.put(pattern, httpHandler);
                }
            });
        }
        WebSocketHandlerMapping hm = new WebSocketHandlerMapping();
        hm.setUrlMap(urlMap);
        hm.setOrder(this.order);
        if (this.urlPathHelper != null) {
            hm.setUrlPathHelper(this.urlPathHelper);
        }
        return hm;
    }

    // 增加消息拦截器
    public StompEndpointRegistry addFromClientInterceptor(FromClientInterceptor interceptor) {
        this.stompHandler.addFromClientInterceptor(interceptor);
        return this;
    }

    // 增加消息拦截器
    public StompEndpointRegistry addToClientInterceptor(ToClientInterceptor interceptor) {
        this.stompHandler.addToClientInterceptor(interceptor);
        return this;
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy