// org.cometd.client.websocket.common.AbstractWebSocketTransport (Maven / Gradle / Ivy artifact listing)
/*
* Copyright (c) 2008 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.cometd.client.websocket.common;
import java.io.EOFException;
import java.io.IOException;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import org.cometd.bayeux.Channel;
import org.cometd.bayeux.Message;
import org.cometd.bayeux.Message.Mutable;
import org.cometd.bayeux.Promise;
import org.cometd.client.transport.HttpClientTransport;
import org.cometd.client.transport.MessageClientTransport;
import org.cometd.client.transport.TransportListener;
import org.eclipse.jetty.util.thread.AutoLock;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public abstract class AbstractWebSocketTransport extends HttpClientTransport implements MessageClientTransport {
/** Option prefix installed by the constructor via {@code setOptionPrefix}. */
public static final String PREFIX = "ws";
/** Transport name handed to the superclass constructor. */
public static final String NAME = "websocket";
/** Option name read in {@link #init()} to populate {@link #getProtocol()}. */
public static final String PROTOCOL_OPTION = "protocol";
/** Option name read in {@link #init()} to populate {@link #isPerMessageDeflateEnabled()}; defaults to {@code false}. */
public static final String PERMESSAGE_DEFLATE_OPTION = "permessageDeflate";
/** Option name read by {@link #getConnectTimeout()}. */
public static final String CONNECT_TIMEOUT_OPTION = "connectTimeout";
/** Option name read by {@link #getIdleTimeout()}. */
public static final String IDLE_TIMEOUT_OPTION = "idleTimeout";
/** Option name read in {@link #init()} to populate {@link #isStickyReconnect()}; defaults to {@code true}. */
public static final String STICKY_RECONNECT_OPTION = "stickyReconnect";
/** Maximum number of characters kept by {@code Delegate.trimCloseReason(String)}. */
public static final int MAX_CLOSE_REASON_LENGTH = 30;
/** Not referenced in this class; presumably the WebSocket "normal closure" status code for subclasses - confirm. */
public static final int NORMAL_CLOSE_CODE = 1000;
/** Not referenced in this class; presumably the HTTP header name used by subclasses during the upgrade - confirm. */
protected static final String COOKIE_HEADER = "Cookie";
private static final Logger LOGGER = LoggerFactory.getLogger(AbstractWebSocketTransport.class);
// Lock taken by the locked(...) helpers around accesses to _open and _delegate.
private final AutoLock _lock = new AutoLock();
// Set to true by init(), to false by abort() and terminate(); read through Delegate.isOpen().
private boolean _open;
private String _protocol;
private boolean _perMessageDeflate;
// Default assigned in init(); refreshed from the options every time getConnectTimeout() is called.
private long _connectTimeout;
// Default assigned in init(); refreshed from the options every time getIdleTimeout() is called.
private long _idleTimeout;
private boolean _stickyReconnect;
// The delegate currently attached to this transport; installed by send(), cleared by Delegate.detach().
private Delegate _delegate;
// Receives the messages that Delegate.onMessages() does not classify as replies.
private TransportListener _listener;
/**
 * Creates a transport named {@link #NAME} whose options are looked up
 * under the {@link #PREFIX} prefix.
 *
 * @param url the URL handed to the superclass constructor
 * @param options the transport options handed to the superclass constructor
 * @param scheduler the scheduler handed to the superclass constructor
 */
protected AbstractWebSocketTransport(String url, Map<String, Object> options, ScheduledExecutorService scheduler) {
    super(NAME, url, options, scheduler);
    setOptionPrefix(PREFIX);
}
/**
 * Stores the listener that is notified of messages which are not replies
 * to a request sent through this transport.
 *
 * @param listener the listener to store
 */
@Override
public void setMessageTransportListener(TransportListener listener) {
    this._listener = listener;
}
/**
 * Sets the transport URL after rewriting a leading {@code http} into {@code ws}
 * (which also turns a leading {@code https} into {@code wss}).
 *
 * @param url the URL to rewrite and store
 */
@Override
public void setURL(String url) {
    String webSocketURL = url.replaceFirst("^http", "ws");
    super.setURL(webSocketURL);
}
/**
 * Initializes the transport: reads the protocol, per-message-deflate and
 * sticky-reconnect options, installs the default network delay and timeouts,
 * then, while holding the lock, marks the transport open and initializes the scheduler.
 */
@Override
public void init() {
    super.init();

    // Values read once from the options.
    _protocol = getOption(PROTOCOL_OPTION, _protocol);
    _perMessageDeflate = getOption(PERMESSAGE_DEFLATE_OPTION, false);

    // Defaults; the two timeouts are re-read from the options by their getters.
    setMaxNetworkDelay(15_000L);
    _connectTimeout = 30_000L;
    _idleTimeout = 60_000L;

    _stickyReconnect = getOption(STICKY_RECONNECT_OPTION, true);

    Runnable open = () -> {
        _open = true;
        initScheduler();
    };
    locked(open);
}
/**
 * Runs the given block while holding this transport's lock.
 *
 * @param block the code to run under the lock
 */
protected void locked(Runnable block) {
    // Adapt the Runnable so that the Supplier-based overload does the locking.
    Supplier<Object> adapter = () -> {
        block.run();
        return null;
    };
    locked(adapter);
}
/**
 * Evaluates the given block while holding this transport's lock and returns its result.
 * The lock is released when the block completes, whether normally or by throwing.
 *
 * @param block the code to evaluate under the lock
 * @param <T> the type of the value produced by the block
 * @return the value produced by the block
 */
protected <T> T locked(Supplier<T> block) {
    try (AutoLock ignored = _lock.lock()) {
        return block.get();
    }
}
/**
 * @return the value read from the {@link #PROTOCOL_OPTION} option by {@link #init()}
 */
public String getProtocol() {
    return this._protocol;
}
/**
 * @return the value read from the {@link #PERMESSAGE_DEFLATE_OPTION} option by {@link #init()}
 */
public boolean isPerMessageDeflateEnabled() {
    return this._perMessageDeflate;
}
/**
 * Re-reads the {@link #IDLE_TIMEOUT_OPTION} option, using the currently stored
 * value as the default, stores the result and returns it.
 *
 * @return the idle timeout
 */
public long getIdleTimeout() {
    long idleTimeout = getOption(IDLE_TIMEOUT_OPTION, _idleTimeout);
    _idleTimeout = idleTimeout;
    return idleTimeout;
}
/**
 * Re-reads the {@link #CONNECT_TIMEOUT_OPTION} option, using the currently stored
 * value as the default, stores the result and returns it.
 *
 * @return the connect timeout
 */
public long getConnectTimeout() {
    long connectTimeout = getOption(CONNECT_TIMEOUT_OPTION, _connectTimeout);
    _connectTimeout = connectTimeout;
    return connectTimeout;
}
/**
 * @return the value read from the {@link #STICKY_RECONNECT_OPTION} option by {@link #init()}
 */
public boolean isStickyReconnect() {
    return this._stickyReconnect;
}
/**
 * Marks the transport closed and shuts the scheduler down while holding the lock,
 * then, outside the lock, aborts the delegate that was current at that time (if any).
 *
 * @param failure the failure to pass on to the delegate
 */
@Override
public void abort(Throwable failure) {
    Supplier<Delegate> closeAndGet = () -> {
        _open = false;
        shutdownScheduler();
        return getDelegate();
    };
    Delegate current = locked(closeAndGet);
    if (current == null) {
        return;
    }
    current.abort(failure);
}
/**
 * Marks the transport closed and shuts the scheduler down while holding the lock,
 * then, outside the lock, terminates the delegate that was current at that time (if any),
 * and finally delegates to the superclass.
 */
@Override
public void terminate() {
    Supplier<Delegate> closeAndGet = () -> {
        _open = false;
        shutdownScheduler();
        return getDelegate();
    };
    Delegate current = locked(closeAndGet);
    if (current != null) {
        current.terminate();
    }
    super.terminate();
}
/**
 * @return the delegate currently attached to this transport, read under the lock;
 * {@code null} when no delegate is attached
 */
protected Delegate getDelegate() {
    Supplier<Delegate> current = () -> _delegate;
    return locked(current);
}
/**
 * Sends the given messages over the current delegate, connecting a new
 * delegate first when none is attached.
 * <p>
 * When {@link #connect(String, TransportListener, List)} returns {@code null}
 * this method returns without doing anything else (presumably the connect
 * implementation has already reported the failure - confirm against subclasses).
 * Any exception thrown while registering, serializing or sending fails the delegate.
 *
 * @param listener the listener notified of the progress of the send
 * @param messages the messages to send
 */
@Override
public void send(TransportListener listener, List<Message.Mutable> messages) {
    Delegate delegate = getDelegate();
    if (delegate == null) {
        // Mangle the URL
        String url = getURL().replaceFirst("^http", "ws");
        Delegate newDelegate = connect(url, listener, messages);
        if (newDelegate == null) {
            return;
        }
        delegate = locked(() -> {
            if (_delegate == null) {
                return _delegate = newDelegate;
            } else {
                // We connected concurrently, keep only one.
                newDelegate.shutdown("Extra");
                return _delegate;
            }
        });
    }
    try {
        delegate.registerMessages(listener, messages);
        String content = generateJSON(messages);
        // The onSending() callback must be invoked before the actual send
        // otherwise we may have a race condition where the response is so
        // fast that it arrives before the onSending() is called.
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("Sending messages {}", content);
        }
        listener.onSending(messages);
        delegate.send(content);
    } catch (Throwable x) {
        delegate.fail(x, "Exception");
    }
}
/**
 * Opens a connection to the given URI and returns the delegate bound to it.
 * Called by {@link #send(TransportListener, List)} when no delegate is attached.
 *
 * @param uri the URI to connect to, already rewritten to the {@code ws} scheme
 * @param listener the listener of the send that triggered the connect
 * @param messages the messages of the send that triggered the connect
 * @return the new delegate, or {@code null}, in which case the triggering send returns immediately
 */
protected abstract Delegate connect(String uri, TransportListener listener, List<Mutable> messages);
protected abstract class Delegate {
// Pending request/reply exchanges, keyed by message id.
private final Map<String, WebSocketExchange> _exchanges = new ConcurrentHashMap<>();
// True from the registration of a /meta/connect message until its reply is deregistered.
private boolean _connected;
// Set once a /meta/disconnect message has been deregistered.
private boolean _disconnected;
// Last advice carrying a timeout, taken from a successful /meta/connect reply.
private Map<String, Object> _advice;
/**
 * Handles the closing of the connection: if this delegate was still attached,
 * it is detached, closed, and every pending message is failed with an {@link EOFException}.
 * Does nothing when this delegate had already been detached.
 *
 * @param code the close code
 * @param reason the close reason
 */
protected void onClose(int code, String reason) {
    if (!detach()) {
        return;
    }
    if (LOGGER.isDebugEnabled()) {
        LOGGER.debug("Closed websocket connection {}/{}", code, reason);
    }
    close();
    EOFException failure = new EOFException("Connection closed " + code + " " + reason);
    failMessages(failure);
}
/**
 * Handles text received from the connection: parses it into messages and,
 * if this delegate is still attached, dispatches them; otherwise the messages
 * are discarded. A parse failure fails this delegate.
 *
 * @param data the received text
 */
protected void onData(String data) {
    try {
        List<Mutable> messages = parseMessages(data);
        if (isAttached()) {
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug("Received messages {}", data);
            }
            onMessages(messages);
        } else {
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug("Discarded messages {}", data);
            }
        }
    } catch (ParseException x) {
        fail(x, "Exception");
    }
}
/**
 * Dispatches received messages one at a time.
 * <p>
 * A reply is delivered to the listener of the exchange registered for its id,
 * or dropped when no such exchange exists. Any other message is delivered to
 * the transport-level message listener. Each listener receives a fresh,
 * single-element mutable list.
 *
 * @param messages the received messages
 */
protected void onMessages(List<Mutable> messages) {
    for (Mutable message : messages) {
        if (isReply(message)) {
            // Remembering the advice must be done before we notify listeners
            // otherwise we risk that listeners send a connect message that does
            // not take into account the timeout to calculate the maxNetworkDelay
            if (Channel.META_CONNECT.equals(message.getChannel()) && message.isSuccessful()) {
                Map<String, Object> advice = message.getAdvice();
                if (advice != null) {
                    // Remember the advice so that we can properly calculate the max network delay
                    if (advice.get(Message.TIMEOUT_FIELD) != null) {
                        _advice = advice;
                    }
                }
            }
            WebSocketExchange exchange = deregisterMessage(message);
            if (exchange != null) {
                exchange.listener.onMessages(new ArrayList<>(List.of(message)));
            } else {
                // If the exchange is missing, then the message has expired, and we do not notify
                if (LOGGER.isDebugEnabled()) {
                    LOGGER.debug("Could not find request for reply {}", message);
                }
            }
            // A disconnect reply was seen and no connect is outstanding anymore.
            if (_disconnected && !_connected) {
                disconnect("Disconnect");
            }
        } else {
            _listener.onMessages(new ArrayList<>(List.of(message)));
        }
    }
}
/**
 * Tells whether the message is a reply to a request: publish replies and meta
 * messages are replies, except a {@code /meta/disconnect} without an id, which
 * is treated as a server-side disconnect.
 *
 * @param message the message to classify
 * @return whether the message is a reply
 */
private boolean isReply(Message message) {
    if (message.isPublishReply()) {
        return true;
    }
    if (!message.isMeta()) {
        return false;
    }
    boolean disconnect = Channel.META_DISCONNECT.equals(message.getChannel());
    // Check if it's a server-side disconnect.
    return !disconnect || message.getId() != null;
}
/**
 * Registers an exchange for every message, provided the transport is open.
 * The open check and the registrations happen atomically under the lock.
 * When the transport is not open, nothing is registered and the listener is
 * notified, outside the lock, of an {@link IOException} for all the messages.
 *
 * @param listener the listener associated with the messages
 * @param messages the messages about to be sent
 */
protected void registerMessages(TransportListener listener, List<Mutable> messages) {
    boolean open = locked(() -> {
        // Check whether it is active and register messages atomically.
        if (isOpen()) {
            for (Mutable message : messages) {
                registerMessage(message, listener);
            }
            return true;
        }
        return false;
    });
    if (!open) {
        listener.onFailure(new IOException("Unconnected"), messages);
    }
}
/**
 * Registers the exchange for a single outgoing message and schedules the task
 * that fires if no reply arrives within the maximum network delay.
 * <p>
 * For {@code /meta/connect} messages the delay is extended by the advised
 * timeout, taken from the message's own advice or, failing that, from the
 * last remembered advice.
 *
 * @param message the message about to be sent
 * @param listener the listener to notify about the message's outcome
 * @throws IllegalStateException if an exchange was already registered for the message id
 * @throws NumberFormatException if the advised timeout is neither a number nor a parseable string
 */
private void registerMessage(Message.Mutable message, TransportListener listener) {
    // Calculate max network delay
    long maxNetworkDelay = getMaxNetworkDelay();
    if (Channel.META_CONNECT.equals(message.getChannel())) {
        Map<String, Object> advice = message.getAdvice();
        if (advice == null) {
            advice = _advice;
        }
        if (advice != null) {
            // Same key that onMessages() uses when remembering the advice.
            Object timeout = advice.get(Message.TIMEOUT_FIELD);
            if (timeout instanceof Number) {
                // The delay is a long: do not truncate the advised timeout to int.
                maxNetworkDelay += ((Number)timeout).longValue();
            } else if (timeout != null) {
                maxNetworkDelay += Long.parseLong(timeout.toString());
            }
        }
        _connected = true;
    }
    // Schedule a task to expire if the maxNetworkDelay elapses.
    long delay = maxNetworkDelay;
    AtomicReference<ScheduledFuture<?>> timeoutTaskRef = new AtomicReference<>();
    ScheduledFuture<?> newTask = getScheduler().schedule(() -> onTimeout(listener, message, delay, timeoutTaskRef), maxNetworkDelay, TimeUnit.MILLISECONDS);
    timeoutTaskRef.set(newTask);
    if (LOGGER.isDebugEnabled()) {
        LOGGER.debug("Started waiting for message reply, {} ms, task@{}", maxNetworkDelay, Integer.toHexString(newTask.hashCode()));
    }
    // Register the exchange
    // Message responses must have the same messageId as the requests
    WebSocketExchange exchange = new WebSocketExchange(message, listener, timeoutTaskRef);
    if (LOGGER.isDebugEnabled()) {
        LOGGER.debug("Registering {}", exchange);
    }
    Object existing = _exchanges.put(message.getId(), exchange);
    // Paranoid check
    if (existing != null) {
        throw new IllegalStateException("Could not register exchange " + exchange + ", existing exchange is " + existing + " for message " + message);
    }
}
private void onTimeout(TransportListener listener, Message message, long delay, AtomicReference> timeoutTaskRef) {
listener.onTimeout(List.of(message), Promise.from(result -> {
if (result > 0) {
ScheduledFuture> newTask = getScheduler().schedule(() -> onTimeout(listener, message, delay + result, timeoutTaskRef), result, TimeUnit.MILLISECONDS);
ScheduledFuture> oldTask = timeoutTaskRef.getAndSet(newTask);
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Extended waiting for message reply, {} ms, oldTask@{}, newTask@{}", result, Integer.toHexString(oldTask.hashCode()), Integer.toHexString(newTask.hashCode()));
}
} else {
fail(new TimeoutException("Network delay expired: " + delay + " ms"), "Expired");
}
}, failure -> fail(failure, "Failure")));
}
/**
 * Removes the exchange registered for the message's id, if any, and cancels
 * its timeout task. Also updates the connect/disconnect bookkeeping:
 * a {@code /meta/connect} message clears the connected flag and a
 * {@code /meta/disconnect} message sets the disconnected flag.
 *
 * @param message the message whose exchange must be removed
 * @return the removed exchange, or {@code null} if the message has no id or no exchange was registered
 */
private WebSocketExchange deregisterMessage(Message message) {
    if (Channel.META_CONNECT.equals(message.getChannel())) {
        _connected = false;
    } else if (Channel.META_DISCONNECT.equals(message.getChannel())) {
        _disconnected = true;
    }
    WebSocketExchange exchange = null;
    String messageId = message.getId();
    if (messageId != null) {
        exchange = _exchanges.remove(messageId);
    }
    if (LOGGER.isDebugEnabled()) {
        LOGGER.debug("Deregistering {} for message {}", exchange, message);
    }
    if (exchange != null) {
        ScheduledFuture<?> task = exchange.taskRef.get();
        // Do not interrupt a timeout task that is already running.
        task.cancel(false);
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("Cancelled waiting for message replies, task@{}", Integer.toHexString(task.hashCode()));
        }
    }
    return exchange;
}
/**
 * Shortens a close reason to at most {@link #MAX_CLOSE_REASON_LENGTH} characters.
 * The cut never falls between the two halves of a surrogate pair: in that case
 * one character less is kept, so the result stays well-formed UTF-16.
 *
 * @param reason the close reason, possibly {@code null}
 * @return the possibly shortened reason, or {@code null} if {@code reason} is {@code null}
 */
protected String trimCloseReason(String reason) {
    if (reason == null) {
        return null;
    }
    int end = Math.min(reason.length(), MAX_CLOSE_REASON_LENGTH);
    // Only relevant when actually truncating: back off if the cut would split a pair.
    if (end > 0 && end < reason.length()
            && Character.isSurrogatePair(reason.charAt(end - 1), reason.charAt(end))) {
        end--;
    }
    return reason.substring(0, end);
}
/**
 * Sends the given text over the connection. The enclosing transport calls this
 * with the JSON generated for a batch of messages, after the messages have
 * been registered and the listener's {@code onSending} has been invoked.
 *
 * @param content the text to send
 */
protected abstract void send(String content);
/**
 * Fails this delegate: first disconnects it (detaching and shutting it down if
 * it was still attached), then fails every pending message with the given failure.
 *
 * @param failure the failure reported to the pending messages' listeners
 * @param reason the reason passed to the shutdown
 */
protected void fail(Throwable failure, String reason) {
    this.disconnect(reason);
    this.failMessages(failure);
}
/**
 * Fails every pending message, notifying each message's listener separately.
 * A message is only failed by the caller that actually deregisters its
 * exchange, so concurrent invocations do not notify twice.
 *
 * @param cause the failure reported to the listeners
 */
protected void failMessages(Throwable cause) {
    // Iterate over a snapshot: deregisterMessage() removes from _exchanges.
    List<WebSocketExchange> snapshot = new ArrayList<>(_exchanges.values());
    for (WebSocketExchange exchange : snapshot) {
        Mutable message = exchange.message;
        if (deregisterMessage(message) == exchange) {
            // Hand every listener its own list, so that a listener that keeps
            // the reference never observes it being emptied or reused.
            List<Mutable> failed = new ArrayList<>(1);
            failed.add(message);
            exchange.listener.onFailure(cause, failed);
        }
    }
}
/**
 * Fails this delegate with the given failure, using "Aborted" as the shutdown reason.
 *
 * @param failure the failure reported to the pending messages' listeners
 */
private void abort(Throwable failure) {
    this.fail(failure, "Aborted");
}
/**
 * Detaches this delegate from the transport and, only if it was attached, shuts it down.
 *
 * @param reason the reason passed to the shutdown
 */
private void disconnect(String reason) {
    boolean wasAttached = detach();
    if (!wasAttached) {
        return;
    }
    shutdown(reason);
}
/**
 * @return whether this delegate is the transport's current delegate, checked under the lock
 */
private boolean isAttached() {
    Supplier<Boolean> attached = () -> _delegate == this;
    return locked(attached);
}
/**
 * Atomically, under the lock, clears the transport's current delegate if it is this one.
 *
 * @return {@code true} if this delegate was attached and has now been detached,
 * {@code false} if it was not the current delegate
 */
private boolean detach() {
    return locked(() -> {
        if (this != _delegate) {
            return false;
        }
        _delegate = null;
        return true;
    });
}
/**
 * @return whether the enclosing transport is currently marked open, read under the lock
 */
protected boolean isOpen() {
    Supplier<Boolean> open = () -> _open;
    return locked(open);
}
/**
 * Implementation-specific close, invoked by {@code onClose(int, String)} after
 * this delegate has been detached and before the pending messages are failed.
 */
protected abstract void close();
/**
 * Implementation-specific shutdown of the connection. Invoked when this
 * delegate is disconnected after being detached, and by the enclosing
 * transport (with reason "Extra") on a delegate that lost a concurrent connect.
 *
 * @param reason a short description of why the connection is shut down
 */
protected abstract void shutdown(String reason);
/**
 * Fails this delegate with a message-less {@link EOFException}, using
 * "Terminate" as the shutdown reason.
 */
private void terminate() {
    EOFException failure = new EOFException();
    fail(failure, "Terminate");
}
}
private static class WebSocketExchange {
private final Mutable message;
private final TransportListener listener;
private final AtomicReference> taskRef;
private WebSocketExchange(Mutable message, TransportListener listener, AtomicReference> taskRef) {
this.message = message;
this.listener = listener;
this.taskRef = taskRef;
}
@Override
public String toString() {
return getClass().getSimpleName() + " " + message;
}
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy (footer of the page this listing was copied from)