org.springframework.messaging.simp.stomp.StompDecoder Maven / Gradle / Ivy
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.stomp;
import java.io.ByteArrayOutputStream;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.logging.Log;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.messaging.simp.SimpLogging;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderInitializer;
import org.springframework.messaging.support.NativeMessageHeaderAccessor;
import org.springframework.util.InvalidMimeTypeException;
import org.springframework.util.MultiValueMap;
/**
* Decodes one or more STOMP frames contained in a {@link ByteBuffer}.
*
* An attempt is made to read all complete STOMP frames from the buffer, which
* could be zero, one, or more. If there is any left-over content, i.e. an incomplete
* STOMP frame, at the end the buffer is reset to point to the beginning of the
* partial content. The caller is then responsible for dealing with that
* incomplete content by buffering until there is more input available.
*
* @author Andy Wilkinson
* @author Rossen Stoyanchev
* @since 4.0
*/
public class StompDecoder {
static final byte[] HEARTBEAT_PAYLOAD = new byte[] {'\n'};
private static final Log logger = SimpLogging.forLogName(StompDecoder.class);
@Nullable
private MessageHeaderInitializer headerInitializer;
/**
* Configure a {@link MessageHeaderInitializer} to apply to the headers of
* {@link Message Messages} from decoded STOMP frames.
*/
public void setHeaderInitializer(@Nullable MessageHeaderInitializer headerInitializer) {
this.headerInitializer = headerInitializer;
}
/**
* Return the configured {@code MessageHeaderInitializer}, if any.
*/
@Nullable
public MessageHeaderInitializer getHeaderInitializer() {
return this.headerInitializer;
}
/**
* Decodes one or more STOMP frames from the given {@code ByteBuffer} into a
* list of {@link Message Messages}. If the input buffer contains partial STOMP frame
* content, or additional content with a partial STOMP frame, the buffer is
* reset and {@code null} is returned.
* @param byteBuffer the buffer to decode the STOMP frame from
* @return the decoded messages, or an empty list if none
* @throws StompConversionException raised in case of decoding issues
*/
public List> decode(ByteBuffer byteBuffer) {
return decode(byteBuffer, null);
}
/**
* Decodes one or more STOMP frames from the given {@code buffer} and returns
* a list of {@link Message Messages}.
* If the given ByteBuffer contains only partial STOMP frame content and no
* complete STOMP frames, an empty list is returned, and the buffer is reset to
* to where it was.
*
If the buffer contains one ore more STOMP frames, those are returned and
* the buffer reset to point to the beginning of the unused partial content.
*
The output partialMessageHeaders map is used to store successfully parsed
* headers in case of partial content. The caller can then check if a
* "content-length" header was read, which helps to determine how much more
* content is needed before the next attempt to decode.
* @param byteBuffer the buffer to decode the STOMP frame from
* @param partialMessageHeaders an empty output map that will store the last
* successfully parsed partialMessageHeaders in case of partial message content
* in cases where the partial buffer ended with a partial STOMP frame
* @return the decoded messages, or an empty list if none
* @throws StompConversionException raised in case of decoding issues
*/
public List> decode(ByteBuffer byteBuffer,
@Nullable MultiValueMap partialMessageHeaders) {
List> messages = new ArrayList<>();
while (byteBuffer.hasRemaining()) {
Message message = decodeMessage(byteBuffer, partialMessageHeaders);
if (message != null) {
messages.add(message);
}
else {
break;
}
}
return messages;
}
/**
* Decode a single STOMP frame from the given {@code buffer} into a {@link Message}.
*/
@Nullable
private Message decodeMessage(ByteBuffer byteBuffer, @Nullable MultiValueMap headers) {
Message decodedMessage = null;
skipLeadingEol(byteBuffer);
// Explicit mark/reset access via Buffer base type for compatibility
// with covariant return type on JDK 9's ByteBuffer...
Buffer buffer = byteBuffer;
buffer.mark();
String command = readCommand(byteBuffer);
if (command.length() > 0) {
StompHeaderAccessor headerAccessor = null;
byte[] payload = null;
if (byteBuffer.remaining() > 0) {
StompCommand stompCommand = StompCommand.valueOf(command);
headerAccessor = StompHeaderAccessor.create(stompCommand);
initHeaders(headerAccessor);
readHeaders(byteBuffer, headerAccessor);
payload = readPayload(byteBuffer, headerAccessor);
}
if (payload != null) {
if (payload.length > 0) {
StompCommand stompCommand = headerAccessor.getCommand();
if (stompCommand != null && !stompCommand.isBodyAllowed()) {
throw new StompConversionException(stompCommand +
" shouldn't have a payload: length=" + payload.length + ", headers=" + headers);
}
}
headerAccessor.updateSimpMessageHeadersFromStompHeaders();
headerAccessor.setLeaveMutable(true);
decodedMessage = MessageBuilder.createMessage(payload, headerAccessor.getMessageHeaders());
if (logger.isTraceEnabled()) {
logger.trace("Decoded " + headerAccessor.getDetailedLogMessage(payload));
}
}
else {
logger.trace("Incomplete frame, resetting input buffer...");
if (headers != null && headerAccessor != null) {
String name = NativeMessageHeaderAccessor.NATIVE_HEADERS;
@SuppressWarnings("unchecked")
MultiValueMap map = (MultiValueMap) headerAccessor.getHeader(name);
if (map != null) {
headers.putAll(map);
}
}
buffer.reset();
}
}
else {
StompHeaderAccessor headerAccessor = StompHeaderAccessor.createForHeartbeat();
initHeaders(headerAccessor);
headerAccessor.setLeaveMutable(true);
decodedMessage = MessageBuilder.createMessage(HEARTBEAT_PAYLOAD, headerAccessor.getMessageHeaders());
if (logger.isTraceEnabled()) {
logger.trace("Decoded " + headerAccessor.getDetailedLogMessage(null));
}
}
return decodedMessage;
}
private void initHeaders(StompHeaderAccessor headerAccessor) {
MessageHeaderInitializer initializer = getHeaderInitializer();
if (initializer != null) {
initializer.initHeaders(headerAccessor);
}
}
/**
* Skip one ore more EOL characters at the start of the given ByteBuffer.
* Those are STOMP heartbeat frames.
*/
protected void skipLeadingEol(ByteBuffer byteBuffer) {
while (true) {
if (!tryConsumeEndOfLine(byteBuffer)) {
break;
}
}
}
private String readCommand(ByteBuffer byteBuffer) {
ByteArrayOutputStream command = new ByteArrayOutputStream(256);
while (byteBuffer.remaining() > 0 && !tryConsumeEndOfLine(byteBuffer)) {
command.write(byteBuffer.get());
}
return new String(command.toByteArray(), StandardCharsets.UTF_8);
}
private void readHeaders(ByteBuffer byteBuffer, StompHeaderAccessor headerAccessor) {
while (true) {
ByteArrayOutputStream headerStream = new ByteArrayOutputStream(256);
boolean headerComplete = false;
while (byteBuffer.hasRemaining()) {
if (tryConsumeEndOfLine(byteBuffer)) {
headerComplete = true;
break;
}
headerStream.write(byteBuffer.get());
}
if (headerStream.size() > 0 && headerComplete) {
String header = new String(headerStream.toByteArray(), StandardCharsets.UTF_8);
int colonIndex = header.indexOf(':');
if (colonIndex <= 0) {
if (byteBuffer.remaining() > 0) {
throw new StompConversionException("Illegal header: '" + header +
"'. A header must be of the form :[].");
}
}
else {
String headerName = unescape(header.substring(0, colonIndex));
String headerValue = unescape(header.substring(colonIndex + 1));
try {
headerAccessor.addNativeHeader(headerName, headerValue);
}
catch (InvalidMimeTypeException ex) {
if (byteBuffer.remaining() > 0) {
throw ex;
}
}
}
}
else {
break;
}
}
}
/**
* See STOMP Spec 1.2:
* "Value Encoding".
*/
private String unescape(String inString) {
StringBuilder sb = new StringBuilder(inString.length());
int pos = 0; // position in the old string
int index = inString.indexOf('\\');
while (index >= 0) {
sb.append(inString.substring(pos, index));
if (index + 1 >= inString.length()) {
throw new StompConversionException("Illegal escape sequence at index " + index + ": " + inString);
}
Character c = inString.charAt(index + 1);
if (c == 'r') {
sb.append('\r');
}
else if (c == 'n') {
sb.append('\n');
}
else if (c == 'c') {
sb.append(':');
}
else if (c == '\\') {
sb.append('\\');
}
else {
// should never happen
throw new StompConversionException("Illegal escape sequence at index " + index + ": " + inString);
}
pos = index + 2;
index = inString.indexOf('\\', pos);
}
sb.append(inString.substring(pos));
return sb.toString();
}
@Nullable
private byte[] readPayload(ByteBuffer byteBuffer, StompHeaderAccessor headerAccessor) {
Integer contentLength;
try {
contentLength = headerAccessor.getContentLength();
}
catch (NumberFormatException ex) {
if (logger.isDebugEnabled()) {
logger.debug("Ignoring invalid content-length: '" + headerAccessor);
}
contentLength = null;
}
if (contentLength != null && contentLength >= 0) {
if (byteBuffer.remaining() > contentLength) {
byte[] payload = new byte[contentLength];
byteBuffer.get(payload);
if (byteBuffer.get() != 0) {
throw new StompConversionException("Frame must be terminated with a null octet");
}
return payload;
}
else {
return null;
}
}
else {
ByteArrayOutputStream payload = new ByteArrayOutputStream(256);
while (byteBuffer.remaining() > 0) {
byte b = byteBuffer.get();
if (b == 0) {
return payload.toByteArray();
}
else {
payload.write(b);
}
}
}
return null;
}
/**
* Try to read an EOL incrementing the buffer position if successful.
* @return whether an EOL was consumed
*/
private boolean tryConsumeEndOfLine(ByteBuffer byteBuffer) {
if (byteBuffer.remaining() > 0) {
byte b = byteBuffer.get();
if (b == '\n') {
return true;
}
else if (b == '\r') {
if (byteBuffer.remaining() > 0 && byteBuffer.get() == '\n') {
return true;
}
else {
throw new StompConversionException("'\\r' must be followed by '\\n'");
}
}
// Explicit cast for compatibility with covariant return type on JDK 9's ByteBuffer
((Buffer) byteBuffer).position(byteBuffer.position() - 1);
}
return false;
}
}