All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.springframework.web.socket.messaging.StompSubProtocolHandler Maven / Gradle / Ivy

There is a newer version: 6.1.6
Show newest version
/*
 * Copyright 2002-2015 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.web.socket.messaging;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.security.Principal;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.context.ApplicationEvent;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ApplicationEventPublisherAware;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.SimpAttributes;
import org.springframework.messaging.simp.SimpAttributesContextHolder;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.stomp.BufferingStompDecoder;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompDecoder;
import org.springframework.messaging.simp.stomp.StompEncoder;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.simp.user.DestinationUserNameProvider;
import org.springframework.messaging.support.AbstractMessageChannel;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.ImmutableMessageChannelInterceptor;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.messaging.support.MessageHeaderInitializer;
import org.springframework.util.Assert;
import org.springframework.util.MimeTypeUtils;
import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.SessionLimitExceededException;
import org.springframework.web.socket.handler.WebSocketSessionDecorator;
import org.springframework.web.socket.sockjs.transport.SockJsSession;

/**
 * A {@link SubProtocolHandler} for STOMP that supports versions 1.0, 1.1, and 1.2
 * of the STOMP specification.
 *
 * @author Rossen Stoyanchev
 * @author Andy Wilkinson
 * @since 4.0
 */
public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationEventPublisherAware {

	 * This handler supports assembling large STOMP messages split into multiple
	 * WebSocket messages and STOMP clients (like stomp.js) indeed split large STOMP
	 * messages at 16K boundaries. Therefore the WebSocket server input message
	 * buffer size must allow 16K at least plus a little extra for SockJS framing.
	public static final int MINIMUM_WEBSOCKET_MESSAGE_SIZE = 16 * 1024 + 256;

	 * The name of the header set on the CONNECTED frame indicating the name
	 * of the user authenticated on the WebSocket session.
	public static final String CONNECTED_USER_HEADER = "user-name";

	private static final Log logger = LogFactory.getLog(StompSubProtocolHandler.class);

	private static final byte[] EMPTY_PAYLOAD = new byte[0];

	private StompSubProtocolErrorHandler errorHandler;

	private int messageSizeLimit = 64 * 1024;

	private org.springframework.messaging.simp.user.UserSessionRegistry userSessionRegistry;

	private final StompEncoder stompEncoder = new StompEncoder();

	private final StompDecoder stompDecoder = new StompDecoder();

	private final Map decoders = new ConcurrentHashMap();

	private MessageHeaderInitializer headerInitializer;

	private Boolean immutableMessageInterceptorPresent;

	private ApplicationEventPublisher eventPublisher;

	private final Stats stats = new Stats();

	 * Configure a handler for error messages sent to clients which allows
	 * customizing the error messages or preventing them from being sent.

By default this isn't configured in which case an ERROR frame is sent * with a message header reflecting the error. * @param errorHandler the error handler */ public void setErrorHandler(StompSubProtocolErrorHandler errorHandler) { this.errorHandler = errorHandler; } /** * Return the configured error handler. */ public StompSubProtocolErrorHandler getErrorHandler() { return this.errorHandler; } /** * Configure the maximum size allowed for an incoming STOMP message. * Since a STOMP message can be received in multiple WebSocket messages, * buffering may be required and therefore it is necessary to know the maximum * allowed message size. *

By default this property is set to 64K. * @since 4.0.3 */ public void setMessageSizeLimit(int messageSizeLimit) { this.messageSizeLimit = messageSizeLimit; } /** * Get the configured message buffer size limit in bytes. * @since 4.0.3 */ public int getMessageSizeLimit() { return this.messageSizeLimit; } /** * Provide a registry with which to register active user session ids. * @see org.springframework.messaging.simp.user.UserDestinationMessageHandler * @deprecated as of 4.2 in favor of {@link DefaultSimpUserRegistry} which relies * on the ApplicationContext events published by this class and is created via * {@link org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurationSupport#createLocalUserRegistry * WebSocketMessageBrokerConfigurationSupport.createLocalUserRegistry} */ @Deprecated public void setUserSessionRegistry(org.springframework.messaging.simp.user.UserSessionRegistry registry) { this.userSessionRegistry = registry; } /** * @deprecated as of 4.2 */ @Deprecated public org.springframework.messaging.simp.user.UserSessionRegistry getUserSessionRegistry() { return this.userSessionRegistry; } /** * Configure a {@link MessageHeaderInitializer} to apply to the headers of all * messages created from decoded STOMP frames and other messages sent to the * client inbound channel. *

By default this property is not set. */ public void setHeaderInitializer(MessageHeaderInitializer headerInitializer) { this.headerInitializer = headerInitializer; this.stompDecoder.setHeaderInitializer(headerInitializer); } /** * Return the configured header initializer. */ public MessageHeaderInitializer getHeaderInitializer() { return this.headerInitializer; } @Override public List getSupportedProtocols() { return Arrays.asList("v10.stomp", "v11.stomp", "v12.stomp"); } @Override public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) { this.eventPublisher = applicationEventPublisher; } /** * Return a String describing internal state and counters. */ public String getStatsInfo() { return this.stats.toString(); } /** * Handle incoming WebSocket messages from clients. */ public void handleMessageFromClient(WebSocketSession session, WebSocketMessage webSocketMessage, MessageChannel outputChannel) { List> messages; try { ByteBuffer byteBuffer; if (webSocketMessage instanceof TextMessage) { byteBuffer = ByteBuffer.wrap(((TextMessage) webSocketMessage).asBytes()); } else if (webSocketMessage instanceof BinaryMessage) { byteBuffer = ((BinaryMessage) webSocketMessage).getPayload(); } else { return; } BufferingStompDecoder decoder = this.decoders.get(session.getId()); if (decoder == null) { throw new IllegalStateException("No decoder for session id '" + session.getId() + "'"); } messages = decoder.decode(byteBuffer); if (messages.isEmpty()) { if (logger.isTraceEnabled()) { logger.trace("Incomplete STOMP frame content received in session " + session + ", bufferSize=" + decoder.getBufferSize() + ", bufferSizeLimit=" + decoder.getBufferSizeLimit() + "."); } return; } } catch (Throwable ex) { if (logger.isErrorEnabled()) { logger.error("Failed to parse " + webSocketMessage + " in session " + session.getId() + ". 
Sending STOMP ERROR to client.", ex); } handleError(session, ex, null); return; } for (Message message : messages) { try { StompHeaderAccessor headerAccessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); Principal user = session.getPrincipal(); headerAccessor.setSessionId(session.getId()); headerAccessor.setSessionAttributes(session.getAttributes()); headerAccessor.setUser(user); headerAccessor.setHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER, headerAccessor.getHeartbeat()); if (!detectImmutableMessageInterceptor(outputChannel)) { headerAccessor.setImmutable(); } if (logger.isTraceEnabled()) { logger.trace("From client: " + headerAccessor.getShortLogMessage(message.getPayload())); } if (StompCommand.CONNECT.equals(headerAccessor.getCommand())) { this.stats.incrementConnectCount(); } else if (StompCommand.DISCONNECT.equals(headerAccessor.getCommand())) { this.stats.incrementDisconnectCount(); } try { SimpAttributesContextHolder.setAttributesFromMessage(message); boolean sent = outputChannel.send(message); if (sent && this.eventPublisher != null) { if (StompCommand.CONNECT.equals(headerAccessor.getCommand())) { publishEvent(new SessionConnectEvent(this, message, user)); } else if (StompCommand.SUBSCRIBE.equals(headerAccessor.getCommand())) { publishEvent(new SessionSubscribeEvent(this, message, user)); } else if (StompCommand.UNSUBSCRIBE.equals(headerAccessor.getCommand())) { publishEvent(new SessionUnsubscribeEvent(this, message, user)); } } } finally { SimpAttributesContextHolder.resetAttributes(); } } catch (Throwable ex) { logger.error("Failed to send client message to application via MessageChannel" + " in session " + session.getId() + ". 
Sending STOMP ERROR to client.", ex); handleError(session, ex, message); } } } @SuppressWarnings("deprecation") private void handleError(WebSocketSession session, Throwable ex, Message clientMessage) { if (getErrorHandler() == null) { sendErrorMessage(session, ex); return; } Message message = getErrorHandler().handleClientMessageProcessingError(clientMessage, ex); if (message == null) { return; } StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); Assert.notNull(accessor, "Expected STOMP headers"); sendToClient(session, accessor, message.getPayload()); } /** * Invoked when no * {@link #setErrorHandler(StompSubProtocolErrorHandler) errorHandler} * is configured to send an ERROR frame to the client. * @deprecated as of Spring 4.2, in favor of * {@link #setErrorHandler(StompSubProtocolErrorHandler) configuring} * a {@code StompSubProtocolErrorHandler} */ @Deprecated protected void sendErrorMessage(WebSocketSession session, Throwable error) { StompHeaderAccessor headerAccessor = StompHeaderAccessor.create(StompCommand.ERROR); headerAccessor.setMessage(error.getMessage()); byte[] bytes = this.stompEncoder.encode(headerAccessor.getMessageHeaders(), EMPTY_PAYLOAD); try { session.sendMessage(new TextMessage(bytes)); } catch (Throwable ex) { // Could be part of normal workflow (e.g. 
browser tab closed) logger.debug("Failed to send STOMP ERROR to client", ex); } } private boolean detectImmutableMessageInterceptor(MessageChannel channel) { if (this.immutableMessageInterceptorPresent != null) { return this.immutableMessageInterceptorPresent; } if (channel instanceof AbstractMessageChannel) { for (ChannelInterceptor interceptor : ((AbstractMessageChannel) channel).getInterceptors()) { if (interceptor instanceof ImmutableMessageChannelInterceptor) { this.immutableMessageInterceptorPresent = true; return true; } } } this.immutableMessageInterceptorPresent = false; return false; } private void publishEvent(ApplicationEvent event) { try { this.eventPublisher.publishEvent(event); } catch (Throwable ex) { logger.error("Error publishing " + event, ex); } } /** * Handle STOMP messages going back out to WebSocket clients. */ @Override @SuppressWarnings("unchecked") public void handleMessageToClient(WebSocketSession session, Message message) { if (!(message.getPayload() instanceof byte[])) { logger.error("Expected byte[] payload. 
Ignoring " + message + "."); return; } StompHeaderAccessor stompAccessor = getStompHeaderAccessor(message); StompCommand command = stompAccessor.getCommand(); if (StompCommand.MESSAGE.equals(command)) { if (stompAccessor.getSubscriptionId() == null) { logger.warn("No STOMP \"subscription\" header in " + message); } String origDestination = stompAccessor.getFirstNativeHeader(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION); if (origDestination != null) { stompAccessor = toMutableAccessor(stompAccessor, message); stompAccessor.removeNativeHeader(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION); stompAccessor.setDestination(origDestination); } } else if (StompCommand.CONNECTED.equals(command)) { this.stats.incrementConnectedCount(); stompAccessor = afterStompSessionConnected(message, stompAccessor, session); if (this.eventPublisher != null && StompCommand.CONNECTED.equals(command)) { try { SimpAttributes simpAttributes = new SimpAttributes(session.getId(), session.getAttributes()); SimpAttributesContextHolder.setAttributes(simpAttributes); Principal user = session.getPrincipal(); publishEvent(new SessionConnectedEvent(this, (Message) message, user)); } finally { SimpAttributesContextHolder.resetAttributes(); } } } byte[] payload = (byte[]) message.getPayload(); if (StompCommand.ERROR.equals(command) && getErrorHandler() != null) { Message errorMessage = getErrorHandler().handleErrorMessageToClient((Message) message); stompAccessor = MessageHeaderAccessor.getAccessor(errorMessage, StompHeaderAccessor.class); Assert.notNull(stompAccessor, "Expected STOMP headers"); payload = errorMessage.getPayload(); } sendToClient(session, stompAccessor, payload); } private void sendToClient(WebSocketSession session, StompHeaderAccessor stompAccessor, byte[] payload) { StompCommand command = stompAccessor.getCommand(); try { byte[] bytes = this.stompEncoder.encode(stompAccessor.getMessageHeaders(), payload); boolean useBinary = (payload.length > 0 && !(session instanceof SockJsSession) 
&& MimeTypeUtils.APPLICATION_OCTET_STREAM.isCompatibleWith(stompAccessor.getContentType())); if (useBinary) { session.sendMessage(new BinaryMessage(bytes)); } else { session.sendMessage(new TextMessage(bytes)); } } catch (SessionLimitExceededException ex) { // Bad session, just get out throw ex; } catch (Throwable ex) { // Could be part of normal workflow (e.g. browser tab closed) logger.debug("Failed to send WebSocket message to client in session " + session.getId(), ex); command = StompCommand.ERROR; } finally { if (StompCommand.ERROR.equals(command)) { try { session.close(CloseStatus.PROTOCOL_ERROR); } catch (IOException ex) { // Ignore } } } } private StompHeaderAccessor getStompHeaderAccessor(Message message) { MessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class); if (accessor == null) { // Shouldn't happen (only broker broadcasts directly to clients) throw new IllegalStateException("No header accessor in " + message); } StompHeaderAccessor stompAccessor; if (accessor instanceof StompHeaderAccessor) { stompAccessor = (StompHeaderAccessor) accessor; } else if (accessor instanceof SimpMessageHeaderAccessor) { stompAccessor = StompHeaderAccessor.wrap(message); SimpMessageType messageType = SimpMessageHeaderAccessor.getMessageType(message.getHeaders()); if (SimpMessageType.CONNECT_ACK.equals(messageType)) { stompAccessor = convertConnectAcktoStompConnected(stompAccessor); } else if (SimpMessageType.DISCONNECT_ACK.equals(messageType)) { stompAccessor = StompHeaderAccessor.create(StompCommand.ERROR); stompAccessor.setMessage("Session closed."); } else if (SimpMessageType.HEARTBEAT.equals(messageType)) { stompAccessor = StompHeaderAccessor.createForHeartbeat(); } else if (stompAccessor.getCommand() == null || StompCommand.SEND.equals(stompAccessor.getCommand())) { stompAccessor.updateStompCommandAsServerMessage(); } } else { // Shouldn't happen (only broker broadcasts directly to clients) throw new 
IllegalStateException( "Unexpected header accessor type: " + accessor.getClass() + " in " + message); } return stompAccessor; } /** * The simple broker produces {@code SimpMessageType.CONNECT_ACK} that's not STOMP * specific and needs to be turned into a STOMP CONNECTED frame. */ private StompHeaderAccessor convertConnectAcktoStompConnected(StompHeaderAccessor connectAckHeaders) { String name = StompHeaderAccessor.CONNECT_MESSAGE_HEADER; Message message = (Message) connectAckHeaders.getHeader(name); Assert.notNull(message, "Original STOMP CONNECT not found in " + connectAckHeaders); StompHeaderAccessor connectHeaders = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); StompHeaderAccessor connectedHeaders = StompHeaderAccessor.create(StompCommand.CONNECTED); Set acceptVersions = connectHeaders.getAcceptVersion(); if (acceptVersions.contains("1.2")) { connectedHeaders.setVersion("1.2"); } else if (acceptVersions.contains("1.1")) { connectedHeaders.setVersion("1.1"); } else if (!acceptVersions.isEmpty()) { throw new IllegalArgumentException("Unsupported STOMP version '" + acceptVersions + "'"); } long[] heartbeat = (long[]) connectAckHeaders.getHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER); if (heartbeat != null) { connectedHeaders.setHeartbeat(heartbeat[0], heartbeat[1]); } else { connectedHeaders.setHeartbeat(0, 0); } return connectedHeaders; } protected StompHeaderAccessor toMutableAccessor(StompHeaderAccessor headerAccessor, Message message) { return (headerAccessor.isMutable() ? 
headerAccessor : StompHeaderAccessor.wrap(message)); } @SuppressWarnings("deprecation") private StompHeaderAccessor afterStompSessionConnected(Message message, StompHeaderAccessor accessor, WebSocketSession session) { Principal principal = session.getPrincipal(); if (principal != null) { accessor = toMutableAccessor(accessor, message); accessor.setNativeHeader(CONNECTED_USER_HEADER, principal.getName()); if (this.userSessionRegistry != null) { String userName = getSessionRegistryUserName(principal); this.userSessionRegistry.registerSessionId(userName, session.getId()); } } long[] heartbeat = accessor.getHeartbeat(); if (heartbeat[1] > 0) { session = WebSocketSessionDecorator.unwrap(session); if (session instanceof SockJsSession) { ((SockJsSession) session).disableHeartbeat(); } } return accessor; } private String getSessionRegistryUserName(Principal principal) { String userName = principal.getName(); if (principal instanceof DestinationUserNameProvider) { userName = ((DestinationUserNameProvider) principal).getDestinationUserName(); } return userName; } @Override public String resolveSessionId(Message message) { return SimpMessageHeaderAccessor.getSessionId(message.getHeaders()); } @Override public void afterSessionStarted(WebSocketSession session, MessageChannel outputChannel) { if (session.getTextMessageSizeLimit() < MINIMUM_WEBSOCKET_MESSAGE_SIZE) { session.setTextMessageSizeLimit(MINIMUM_WEBSOCKET_MESSAGE_SIZE); } this.decoders.put(session.getId(), new BufferingStompDecoder(this.stompDecoder, getMessageSizeLimit())); } @Override @SuppressWarnings("deprecation") public void afterSessionEnded(WebSocketSession session, CloseStatus closeStatus, MessageChannel outputChannel) { this.decoders.remove(session.getId()); Principal principal = session.getPrincipal(); if (principal != null && this.userSessionRegistry != null) { String userName = getSessionRegistryUserName(principal); this.userSessionRegistry.unregisterSessionId(userName, session.getId()); } Message message 
= createDisconnectMessage(session); SimpAttributes simpAttributes = SimpAttributes.fromMessage(message); try { SimpAttributesContextHolder.setAttributes(simpAttributes); if (this.eventPublisher != null) { Principal user = session.getPrincipal(); publishEvent(new SessionDisconnectEvent(this, message, session.getId(), closeStatus, user)); } outputChannel.send(message); } finally { SimpAttributesContextHolder.resetAttributes(); simpAttributes.sessionCompleted(); } } private Message createDisconnectMessage(WebSocketSession session) { StompHeaderAccessor headerAccessor = StompHeaderAccessor.create(StompCommand.DISCONNECT); if (getHeaderInitializer() != null) { getHeaderInitializer().initHeaders(headerAccessor); } headerAccessor.setSessionId(session.getId()); headerAccessor.setSessionAttributes(session.getAttributes()); headerAccessor.setUser(session.getPrincipal()); return MessageBuilder.createMessage(EMPTY_PAYLOAD, headerAccessor.getMessageHeaders()); } @Override public String toString() { return "StompSubProtocolHandler" + getSupportedProtocols(); } private static class Stats { private final AtomicInteger connect = new AtomicInteger(); private final AtomicInteger connected = new AtomicInteger(); private final AtomicInteger disconnect = new AtomicInteger(); public void incrementConnectCount() { this.connect.incrementAndGet(); } public void incrementConnectedCount() { this.connected.incrementAndGet(); } public void incrementDisconnectCount() { this.disconnect.incrementAndGet(); } public String toString() { return "processed CONNECT(" + this.connect.get() + ")-CONNECTED(" + this.connected.get() + ")-DISCONNECT(" + this.disconnect.get() + ")"; } } }

© 2015 - 2024 Weber Informatics LLC | Privacy Policy