All downloads are free. Search and download functionality uses the official Maven repository.

org.springframework.messaging.simp.stomp.DefaultStompSession Maven / Gradle / Ivy

There is a newer version: 6.1.6
Show newest version
/*
 * Copyright 2002-2018 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.messaging.simp.stomp;

import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.commons.logging.Log;

import org.springframework.core.ResolvableType;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageDeliveryException;
import org.springframework.messaging.converter.MessageConversionException;
import org.springframework.messaging.converter.MessageConverter;
import org.springframework.messaging.converter.SimpleMessageConverter;
import org.springframework.messaging.simp.SimpLogging;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.messaging.tcp.TcpConnection;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.AlternativeJdkIdGenerator;
import org.springframework.util.Assert;
import org.springframework.util.IdGenerator;
import org.springframework.util.StringUtils;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.util.concurrent.ListenableFutureCallback;
import org.springframework.util.concurrent.SettableListenableFuture;

/**
 * Default implementation of {@link ConnectionHandlingStompSession}.
 *
 * @author Rossen Stoyanchev
 * @since 4.2
 */
public class DefaultStompSession implements ConnectionHandlingStompSession {

	private static final Log logger = SimpLogging.forLogName(DefaultStompSession.class);

	private static final IdGenerator idGenerator = new AlternativeJdkIdGenerator();

	/**
	 * An empty payload.
	 */
	public static final byte[] EMPTY_PAYLOAD = new byte[0];

	/* STOMP spec: receiver SHOULD take into account an error margin */
	private static final long HEARTBEAT_MULTIPLIER = 3;

	private static final Message HEARTBEAT;

	static {
		StompHeaderAccessor accessor = StompHeaderAccessor.createForHeartbeat();
		HEARTBEAT = MessageBuilder.createMessage(StompDecoder.HEARTBEAT_PAYLOAD, accessor.getMessageHeaders());
	}


	private final String sessionId;

	private final StompSessionHandler sessionHandler;

	private final StompHeaders connectHeaders;

	private final SettableListenableFuture sessionFuture = new SettableListenableFuture<>();

	private MessageConverter converter = new SimpleMessageConverter();

	@Nullable
	private TaskScheduler taskScheduler;

	private long receiptTimeLimit = TimeUnit.SECONDS.toMillis(15);

	private volatile boolean autoReceiptEnabled;


	@Nullable
	private volatile TcpConnection connection;

	@Nullable
	private volatile String version;

	private final AtomicInteger subscriptionIndex = new AtomicInteger();

	private final Map subscriptions = new ConcurrentHashMap<>(4);

	private final AtomicInteger receiptIndex = new AtomicInteger();

	private final Map receiptHandlers = new ConcurrentHashMap<>(4);

	/* Whether the client is willfully closing the connection */
	private volatile boolean closing = false;


	/**
	 * Create a new session.
	 * @param sessionHandler the application handler for the session
	 * @param connectHeaders headers for the STOMP CONNECT frame
	 */
	public DefaultStompSession(StompSessionHandler sessionHandler, StompHeaders connectHeaders) {
		Assert.notNull(sessionHandler, "StompSessionHandler must not be null");
		Assert.notNull(connectHeaders, "StompHeaders must not be null");
		this.sessionId = idGenerator.generateId().toString();
		this.sessionHandler = sessionHandler;
		this.connectHeaders = connectHeaders;
	}


	@Override
	public String getSessionId() {
		return this.sessionId;
	}

	/**
	 * Return the configured session handler.
	 */
	public StompSessionHandler getSessionHandler() {
		return this.sessionHandler;
	}

	@Override
	public ListenableFuture getSessionFuture() {
		return this.sessionFuture;
	}

	/**
	 * Set the {@link MessageConverter} to use to convert the payload of incoming
	 * and outgoing messages to and from {@code byte[]} based on object type, or
	 * expected object type, and the "content-type" header.
	 * 

By default, {@link SimpleMessageConverter} is configured. * @param messageConverter the message converter to use */ public void setMessageConverter(MessageConverter messageConverter) { Assert.notNull(messageConverter, "MessageConverter must not be null"); this.converter = messageConverter; } /** * Return the configured {@link MessageConverter}. */ public MessageConverter getMessageConverter() { return this.converter; } /** * Configure the TaskScheduler to use for receipt tracking. */ public void setTaskScheduler(@Nullable TaskScheduler taskScheduler) { this.taskScheduler = taskScheduler; } /** * Return the configured TaskScheduler to use for receipt tracking. */ @Nullable public TaskScheduler getTaskScheduler() { return this.taskScheduler; } /** * Configure the time in milliseconds before a receipt expires. *

By default set to 15,000 (15 seconds). */ public void setReceiptTimeLimit(long receiptTimeLimit) { Assert.isTrue(receiptTimeLimit > 0, "Receipt time limit must be larger than zero"); this.receiptTimeLimit = receiptTimeLimit; } /** * Return the configured time limit before a receipt expires. */ public long getReceiptTimeLimit() { return this.receiptTimeLimit; } @Override public void setAutoReceipt(boolean autoReceiptEnabled) { this.autoReceiptEnabled = autoReceiptEnabled; } /** * Whether receipt headers should be automatically added. */ public boolean isAutoReceiptEnabled() { return this.autoReceiptEnabled; } @Override public boolean isConnected() { return (this.connection != null); } @Override public Receiptable send(String destination, Object payload) { StompHeaders headers = new StompHeaders(); headers.setDestination(destination); return send(headers, payload); } @Override public Receiptable send(StompHeaders headers, Object payload) { Assert.hasText(headers.getDestination(), "Destination header is required"); String receiptId = checkOrAddReceipt(headers); Receiptable receiptable = new ReceiptHandler(receiptId); StompHeaderAccessor accessor = createHeaderAccessor(StompCommand.SEND); accessor.addNativeHeaders(headers); Message message = createMessage(accessor, payload); execute(message); return receiptable; } @Nullable private String checkOrAddReceipt(StompHeaders headers) { String receiptId = headers.getReceipt(); if (isAutoReceiptEnabled() && receiptId == null) { receiptId = String.valueOf(DefaultStompSession.this.receiptIndex.getAndIncrement()); headers.setReceipt(receiptId); } return receiptId; } private StompHeaderAccessor createHeaderAccessor(StompCommand command) { StompHeaderAccessor accessor = StompHeaderAccessor.create(command); accessor.setSessionId(this.sessionId); accessor.setLeaveMutable(true); return accessor; } @SuppressWarnings("unchecked") private Message createMessage(StompHeaderAccessor accessor, @Nullable Object payload) { 
accessor.updateSimpMessageHeadersFromStompHeaders(); Message message; if (payload == null) { message = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders()); } else if (payload instanceof byte[]) { message = MessageBuilder.createMessage((byte[]) payload, accessor.getMessageHeaders()); } else { message = (Message) getMessageConverter().toMessage(payload, accessor.getMessageHeaders()); accessor.updateStompHeadersFromSimpMessageHeaders(); if (message == null) { throw new MessageConversionException("Unable to convert payload with type='" + payload.getClass().getName() + "', contentType='" + accessor.getContentType() + "', converter=[" + getMessageConverter() + "]"); } } return message; } private void execute(Message message) { if (logger.isTraceEnabled()) { StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); if (accessor != null) { logger.trace("Sending " + accessor.getDetailedLogMessage(message.getPayload())); } } TcpConnection conn = this.connection; Assert.state(conn != null, "Connection closed"); try { conn.send(message).get(); } catch (ExecutionException ex) { throw new MessageDeliveryException(message, ex.getCause()); } catch (Throwable ex) { throw new MessageDeliveryException(message, ex); } } @Override public Subscription subscribe(String destination, StompFrameHandler handler) { StompHeaders headers = new StompHeaders(); headers.setDestination(destination); return subscribe(headers, handler); } @Override public Subscription subscribe(StompHeaders headers, StompFrameHandler handler) { Assert.hasText(headers.getDestination(), "Destination header is required"); Assert.notNull(handler, "StompFrameHandler must not be null"); String subscriptionId = headers.getId(); if (!StringUtils.hasText(subscriptionId)) { subscriptionId = String.valueOf(DefaultStompSession.this.subscriptionIndex.getAndIncrement()); headers.setId(subscriptionId); } checkOrAddReceipt(headers); Subscription subscription = new 
DefaultSubscription(headers, handler); StompHeaderAccessor accessor = createHeaderAccessor(StompCommand.SUBSCRIBE); accessor.addNativeHeaders(headers); Message message = createMessage(accessor, EMPTY_PAYLOAD); execute(message); return subscription; } @Override public Receiptable acknowledge(String messageId, boolean consumed) { StompHeaders headers = new StompHeaders(); if ("1.1".equals(this.version)) { headers.setMessageId(messageId); } else { headers.setId(messageId); } return acknowledge(headers, consumed); } @Override public Receiptable acknowledge(StompHeaders headers, boolean consumed) { String receiptId = checkOrAddReceipt(headers); Receiptable receiptable = new ReceiptHandler(receiptId); StompCommand command = (consumed ? StompCommand.ACK : StompCommand.NACK); StompHeaderAccessor accessor = createHeaderAccessor(command); accessor.addNativeHeaders(headers); Message message = createMessage(accessor, null); execute(message); return receiptable; } private void unsubscribe(String id, @Nullable StompHeaders headers) { StompHeaderAccessor accessor = createHeaderAccessor(StompCommand.UNSUBSCRIBE); if (headers != null) { accessor.addNativeHeaders(headers); } accessor.setSubscriptionId(id); Message message = createMessage(accessor, EMPTY_PAYLOAD); execute(message); } @Override public void disconnect() { this.closing = true; try { StompHeaderAccessor accessor = createHeaderAccessor(StompCommand.DISCONNECT); Message message = createMessage(accessor, EMPTY_PAYLOAD); execute(message); } finally { resetConnection(); } } // TcpConnectionHandler @Override public void afterConnected(TcpConnection connection) { this.connection = connection; if (logger.isDebugEnabled()) { logger.debug("Connection established in session id=" + this.sessionId); } StompHeaderAccessor accessor = createHeaderAccessor(StompCommand.CONNECT); accessor.addNativeHeaders(this.connectHeaders); if (this.connectHeaders.getAcceptVersion() == null) { accessor.setAcceptVersion("1.1,1.2"); } Message message = 
createMessage(accessor, EMPTY_PAYLOAD); execute(message); } @Override public void afterConnectFailure(Throwable ex) { if (logger.isDebugEnabled()) { logger.debug("Failed to connect session id=" + this.sessionId, ex); } this.sessionFuture.setException(ex); this.sessionHandler.handleTransportError(this, ex); } @Override public void handleMessage(Message message) { StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); Assert.state(accessor != null, "No StompHeaderAccessor"); accessor.setSessionId(this.sessionId); StompCommand command = accessor.getCommand(); Map> nativeHeaders = accessor.getNativeHeaders(); StompHeaders headers = StompHeaders.readOnlyStompHeaders(nativeHeaders); boolean isHeartbeat = accessor.isHeartbeat(); if (logger.isTraceEnabled()) { logger.trace("Received " + accessor.getDetailedLogMessage(message.getPayload())); } try { if (StompCommand.MESSAGE.equals(command)) { DefaultSubscription subscription = this.subscriptions.get(headers.getSubscription()); if (subscription != null) { invokeHandler(subscription.getHandler(), message, headers); } else if (logger.isDebugEnabled()) { logger.debug("No handler for: " + accessor.getDetailedLogMessage(message.getPayload()) + ". 
Perhaps just unsubscribed?"); } } else { if (StompCommand.RECEIPT.equals(command)) { String receiptId = headers.getReceiptId(); ReceiptHandler handler = this.receiptHandlers.get(receiptId); if (handler != null) { handler.handleReceiptReceived(); } else if (logger.isDebugEnabled()) { logger.debug("No matching receipt: " + accessor.getDetailedLogMessage(message.getPayload())); } } else if (StompCommand.CONNECTED.equals(command)) { initHeartbeatTasks(headers); this.version = headers.getFirst("version"); this.sessionFuture.set(this); this.sessionHandler.afterConnected(this, headers); } else if (StompCommand.ERROR.equals(command)) { invokeHandler(this.sessionHandler, message, headers); } else if (!isHeartbeat && logger.isTraceEnabled()) { logger.trace("Message not handled."); } } } catch (Throwable ex) { this.sessionHandler.handleException(this, command, headers, message.getPayload(), ex); } } private void invokeHandler(StompFrameHandler handler, Message message, StompHeaders headers) { if (message.getPayload().length == 0) { handler.handleFrame(headers, null); return; } Type payloadType = handler.getPayloadType(headers); Class resolvedType = ResolvableType.forType(payloadType).resolve(); if (resolvedType == null) { throw new MessageConversionException("Unresolvable payload type [" + payloadType + "] from handler type [" + handler.getClass() + "]"); } Object object = getMessageConverter().fromMessage(message, resolvedType); if (object == null) { throw new MessageConversionException("No suitable converter for payload type [" + payloadType + "] from handler type [" + handler.getClass() + "]"); } handler.handleFrame(headers, object); } private void initHeartbeatTasks(StompHeaders connectedHeaders) { long[] connect = this.connectHeaders.getHeartbeat(); long[] connected = connectedHeaders.getHeartbeat(); if (connect == null || connected == null) { return; } TcpConnection con = this.connection; Assert.state(con != null, "No TcpConnection available"); if (connect[0] > 0 && 
connected[1] > 0) { long interval = Math.max(connect[0], connected[1]); con.onWriteInactivity(new WriteInactivityTask(), interval); } if (connect[1] > 0 && connected[0] > 0) { long interval = Math.max(connect[1], connected[0]) * HEARTBEAT_MULTIPLIER; con.onReadInactivity(new ReadInactivityTask(), interval); } } @Override public void handleFailure(Throwable ex) { try { this.sessionFuture.setException(ex); // no-op if already set this.sessionHandler.handleTransportError(this, ex); } catch (Throwable ex2) { if (logger.isDebugEnabled()) { logger.debug("Uncaught failure while handling transport failure", ex2); } } } @Override public void afterConnectionClosed() { if (logger.isDebugEnabled()) { logger.debug("Connection closed in session id=" + this.sessionId); } if (!this.closing) { resetConnection(); handleFailure(new ConnectionLostException("Connection closed")); } } private void resetConnection() { TcpConnection conn = this.connection; this.connection = null; if (conn != null) { try { conn.close(); } catch (Throwable ex) { // ignore } } } private class ReceiptHandler implements Receiptable { @Nullable private final String receiptId; private final List receiptCallbacks = new ArrayList<>(2); private final List receiptLostCallbacks = new ArrayList<>(2); @Nullable private ScheduledFuture future; @Nullable private Boolean result; public ReceiptHandler(@Nullable String receiptId) { this.receiptId = receiptId; if (receiptId != null) { initReceiptHandling(); } } private void initReceiptHandling() { Assert.notNull(getTaskScheduler(), "To track receipts, a TaskScheduler must be configured"); DefaultStompSession.this.receiptHandlers.put(this.receiptId, this); Date startTime = new Date(System.currentTimeMillis() + getReceiptTimeLimit()); this.future = getTaskScheduler().schedule(this::handleReceiptNotReceived, startTime); } @Override @Nullable public String getReceiptId() { return this.receiptId; } @Override public void addReceiptTask(Runnable task) { addTask(task, true); } 
@Override public void addReceiptLostTask(Runnable task) { addTask(task, false); } private void addTask(Runnable task, boolean successTask) { Assert.notNull(this.receiptId, "To track receipts, set autoReceiptEnabled=true or add 'receiptId' header"); synchronized (this) { if (this.result != null && this.result == successTask) { invoke(Collections.singletonList(task)); } else { if (successTask) { this.receiptCallbacks.add(task); } else { this.receiptLostCallbacks.add(task); } } } } private void invoke(List callbacks) { for (Runnable runnable : callbacks) { try { runnable.run(); } catch (Throwable ex) { // ignore } } } public void handleReceiptReceived() { handleInternal(true); } public void handleReceiptNotReceived() { handleInternal(false); } private void handleInternal(boolean result) { synchronized (this) { if (this.result != null) { return; } this.result = result; invoke(result ? this.receiptCallbacks : this.receiptLostCallbacks); DefaultStompSession.this.receiptHandlers.remove(this.receiptId); if (this.future != null) { this.future.cancel(true); } } } } private class DefaultSubscription extends ReceiptHandler implements Subscription { private final StompHeaders headers; private final StompFrameHandler handler; public DefaultSubscription(StompHeaders headers, StompFrameHandler handler) { super(headers.getReceipt()); Assert.notNull(headers.getDestination(), "Destination must not be null"); Assert.notNull(handler, "StompFrameHandler must not be null"); this.headers = headers; this.handler = handler; DefaultStompSession.this.subscriptions.put(headers.getId(), this); } @Override @Nullable public String getSubscriptionId() { return this.headers.getId(); } @Override public StompHeaders getSubscriptionHeaders() { return this.headers; } public StompFrameHandler getHandler() { return this.handler; } @Override public void unsubscribe() { unsubscribe(null); } @Override public void unsubscribe(@Nullable StompHeaders headers) { String id = this.headers.getId(); if (id != null) 
{ DefaultStompSession.this.subscriptions.remove(id); DefaultStompSession.this.unsubscribe(id, headers); } } @Override public String toString() { return "Subscription [id=" + getSubscriptionId() + ", destination='" + this.headers.getDestination() + "', receiptId='" + getReceiptId() + "', handler=" + getHandler() + "]"; } } private class WriteInactivityTask implements Runnable { @Override public void run() { TcpConnection conn = connection; if (conn != null) { conn.send(HEARTBEAT).addCallback( new ListenableFutureCallback() { public void onSuccess(@Nullable Void result) { } public void onFailure(Throwable ex) { handleFailure(ex); } }); } } } private class ReadInactivityTask implements Runnable { @Override public void run() { closing = true; String error = "Server has gone quiet. Closing connection in session id=" + sessionId + "."; if (logger.isDebugEnabled()) { logger.debug(error); } resetConnection(); handleFailure(new IllegalStateException(error)); } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy