// org.cometd.server.BayeuxServerImpl Maven / Gradle / Ivy
/*
* Copyright (c) 2008-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.cometd.server;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.security.SecureRandom;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.Spliterator;
import java.util.TreeMap;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.Executor;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import javax.servlet.http.HttpServletRequest;
import org.cometd.bayeux.Bayeux;
import org.cometd.bayeux.Channel;
import org.cometd.bayeux.ChannelId;
import org.cometd.bayeux.MarkedReference;
import org.cometd.bayeux.Message;
import org.cometd.bayeux.Promise;
import org.cometd.bayeux.server.Authorizer;
import org.cometd.bayeux.server.BayeuxContext;
import org.cometd.bayeux.server.BayeuxServer;
import org.cometd.bayeux.server.ConfigurableServerChannel;
import org.cometd.bayeux.server.ConfigurableServerChannel.Initializer;
import org.cometd.bayeux.server.LocalSession;
import org.cometd.bayeux.server.SecurityPolicy;
import org.cometd.bayeux.server.ServerChannel;
import org.cometd.bayeux.server.ServerChannel.MessageListener;
import org.cometd.bayeux.server.ServerMessage;
import org.cometd.bayeux.server.ServerMessage.Mutable;
import org.cometd.bayeux.server.ServerSession;
import org.cometd.bayeux.server.ServerTransport;
import org.cometd.common.AsyncFoldLeft;
import org.cometd.server.http.AbstractHttpTransport;
import org.cometd.server.http.AsyncJSONTransport;
import org.cometd.server.http.JSONPTransport;
import org.cometd.server.http.JSONTransport;
import org.eclipse.jetty.util.annotation.ManagedAttribute;
import org.eclipse.jetty.util.annotation.ManagedObject;
import org.eclipse.jetty.util.annotation.ManagedOperation;
import org.eclipse.jetty.util.annotation.Name;
import org.eclipse.jetty.util.component.ContainerLifeCycle;
import org.eclipse.jetty.util.component.Dumpable;
import org.eclipse.jetty.util.component.DumpableCollection;
import org.eclipse.jetty.util.thread.QueuedThreadPool;
import org.eclipse.jetty.util.thread.ScheduledExecutorScheduler;
import org.eclipse.jetty.util.thread.Scheduler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ManagedObject("The CometD server")
public class BayeuxServerImpl extends ContainerLifeCycle implements BayeuxServer, Dumpable {
// Option name: comma-separated names of the transports that clients are allowed to use
// (see initializeServerTransports()).
public static final String ALLOWED_TRANSPORTS_OPTION = "allowedTransports";
// Option name: period of the periodic sweep; negative values fall back to DEFAULT_SWEEP_PERIOD.
public static final String SWEEP_PERIOD_OPTION = "sweepPeriod";
// Option name: number of sweep threads; doStart() raises it to DEFAULT_SWEEP_THREADS
// and then caps it at the number of available processors.
public static final String SWEEP_THREADS_OPTION = "sweepThreads";
// Option name: comma-separated class names of the server transports to create.
public static final String TRANSPORTS_OPTION = "transports";
// Option name: whether incoming message fields are validated (default true).
public static final String VALIDATE_MESSAGE_FIELDS_OPTION = "validateMessageFields";
// Option name: read into _broadcastToPublisher by doStart() (default true).
public static final String BROADCAST_TO_PUBLISHER_OPTION = "broadcastToPublisher";
// Option name: number of threads of the scheduler created by this server (default 1).
public static final String SCHEDULER_THREADS = "schedulerThreads";
// Option name: max threads of the executor created by this server (default 128).
public static final String EXECUTOR_MAX_THREADS = "executorMaxThreads";
// Default sweep period; used as a schedule() delay, i.e. presumably milliseconds
// (assumes getSweepPeriod() returns _sweepPeriod - TODO confirm).
private static final long DEFAULT_SWEEP_PERIOD = 997;
private static final int DEFAULT_SWEEP_THREADS = 2;
// Instance name: simple class name plus identity hash code in hex; also used to name
// the logger and the executor/scheduler created by this server.
private final String _name = getClass().getSimpleName() + "@" + Integer.toHexString(System.identityHashCode(this));
// Per-instance logger, named after the package and _name.
private final Logger _logger = LoggerFactory.getLogger(getClass().getPackage().getName() + "." + _name);
// Source of randomness for randomLong().
private final SecureRandom _random = new SecureRandom();
// Server listeners; copy-on-write so they can be iterated while listeners are added/removed.
private final List _listeners = new CopyOnWriteArrayList<>();
// Server-wide extensions.
private final List _extensions = new CopyOnWriteArrayList<>();
// Sessions, keyed by session id.
private final ConcurrentMap _sessions = new ConcurrentHashMap<>();
// Channels, keyed by (normalized) channel name.
private final ConcurrentMap _channels = new ConcurrentHashMap<>();
// Transports, keyed by transport name; insertion order is significant.
private final Map _transports = new LinkedHashMap<>(); // Order is important
// Names of the transports clients are allowed to use.
private final List _allowedTransports = new ArrayList<>();
// Configuration options, sorted by option name.
private final Map _options = new TreeMap<>();
// NOTE(review): not referenced in this part of the file; presumably used by the sweep - confirm.
private final Sweeper _sweeper = new Sweeper();
// Scheduler and executor. The mark is true when the instance was created by this server
// in doStart(), and false when it was provided via setScheduler()/setExecutor().
// Only marked references are discarded by doStop().
private MarkedReference _scheduler;
private MarkedReference _executor;
// May be set to null; the authorization methods check for null before using it.
private SecurityPolicy _policy = new DefaultSecurityPolicy();
// Assigned by initializeJSONContext().
private JSONContextServer _jsonContext;
// Read from the options in doStart().
private boolean _validation;
private boolean _broadcastToPublisher;
// NOTE(review): not referenced in this part of the file; presumably controls dump detail.
private boolean _detailedDump;
// Computed from the options in doStart().
private long _sweepPeriod;
private int _sweepThreads;
/**
 * Returns this server's instance name, which is also used as a prefix for
 * its logger name and for the names of the executor and scheduler it creates.
 *
 * @return the instance name
 */
public String getName() {
    return this._name;
}
/**
 * Starts this server.
 * <p>
 * Initializes the meta channels, the JSON context and the server transports. It then
 * creates an executor and a scheduler when none were provided and registers both as beans,
 * and reads the sweep, validation and broadcast options. Finally it starts the child beans
 * and schedules the periodic sweep.
 * <p>
 * The sweep task re-schedules itself after each sweep, but only while this server is
 * running. An externally provided scheduler may keep running after this server stops,
 * and without this check it would keep sweeping a stopped server.
 */
@Override
protected void doStart() throws Exception {
    initializeMetaChannels();
    initializeJSONContext();
    initializeServerTransports();

    if (_executor == null) {
        // Not provided via setExecutor(): create one, marked as owned by this server.
        _executor = new MarkedReference<>(newExecutor(), true);
    }
    addBean(_executor.getReference());

    if (_scheduler == null) {
        // Not provided via setScheduler(): create one, marked as owned by this server.
        _scheduler = new MarkedReference<>(newScheduler(), true);
    }
    addBean(_scheduler.getReference());

    long sweepPeriodOption = getOption(SWEEP_PERIOD_OPTION, DEFAULT_SWEEP_PERIOD);
    if (sweepPeriodOption < 0) {
        sweepPeriodOption = DEFAULT_SWEEP_PERIOD;
    }
    _sweepPeriod = sweepPeriodOption;

    long sweepThreads = getOption(SWEEP_THREADS_OPTION, DEFAULT_SWEEP_THREADS);
    if (sweepThreads < DEFAULT_SWEEP_THREADS) {
        sweepThreads = DEFAULT_SWEEP_THREADS;
    }
    // Never use more sweep threads than available processors.
    _sweepThreads = (int)Math.min(sweepThreads, Runtime.getRuntime().availableProcessors());

    _validation = getOption(VALIDATE_MESSAGE_FIELDS_OPTION, true);
    _broadcastToPublisher = getOption(BROADCAST_TO_PUBLISHER_OPTION, true);

    super.doStart();

    schedule(new Runnable() {
        @Override
        public void run() {
            if (!isRunning()) {
                // The server has been stopped: end the sweep cycle.
                return;
            }
            asyncSweep().whenComplete((r, x) -> {
                if (isRunning()) {
                    schedule(this, getSweepPeriod());
                }
            });
        }
    }, getSweepPeriod());
}
/**
 * Stops this server.
 * <p>
 * Stops the child beans, destroys the allowed transports and clears all internal state:
 * listeners, extensions, sessions, channels, transports, allowed transports and options.
 * It then removes the scheduler and executor beans. A reference is forgotten only if the
 * instance was created by this server; instances provided via setters are kept.
 * <p>
 * The scheduler and executor references may still be null when {@code doStart()} failed
 * before creating them, so they are null-checked instead of dereferenced unconditionally.
 */
@Override
protected void doStop() throws Exception {
    super.doStop();

    for (String allowedTransportName : getAllowedTransports()) {
        ServerTransport transport = getTransport(allowedTransportName);
        if (transport instanceof AbstractServerTransport) {
            ((AbstractServerTransport)transport).destroy();
        }
    }

    _listeners.clear();
    _extensions.clear();
    _sessions.clear();
    _channels.clear();
    _transports.clear();
    _allowedTransports.clear();
    _options.clear();

    if (_scheduler != null) {
        removeBean(_scheduler.getReference());
        if (_scheduler.isMarked()) {
            // Created by this server: a new one is created on the next start.
            _scheduler = null;
        }
    }

    if (_executor != null) {
        removeBean(_executor.getReference());
        if (_executor.isMarked()) {
            // Created by this server: a new one is created on the next start.
            _executor = null;
        }
    }
}
/**
 * Ensures the five Bayeux meta channels exist: handshake, connect,
 * subscribe, unsubscribe and disconnect, created in that order.
 */
protected void initializeMetaChannels() {
    String[] metaChannelNames = {
        Channel.META_HANDSHAKE,
        Channel.META_CONNECT,
        Channel.META_SUBSCRIBE,
        Channel.META_UNSUBSCRIBE,
        Channel.META_DISCONNECT
    };
    for (String metaChannelName : metaChannelNames) {
        createChannelIfAbsent(metaChannelName);
    }
}
/**
 * Initializes the JSON context used by this server and stores it back into the options
 * under {@link AbstractServerTransport#JSON_CONTEXT_OPTION}.
 * <p>
 * The option may be:
 * <ul>
 * <li>absent, in which case a {@code JettyJSONContextServer} is used;</li>
 * <li>a {@link JSONContextServer} instance, which is used as-is;</li>
 * <li>the fully qualified name of a {@link JSONContextServer} implementation with a public
 * no-arguments constructor, loaded via the thread context class loader.</li>
 * </ul>
 *
 * @throws IllegalArgumentException if the option is of another type, or names a class that
 * does not implement {@link JSONContextServer}
 * @throws Exception if the named class cannot be loaded or instantiated
 */
protected void initializeJSONContext() throws Exception {
    Object option = getOption(AbstractServerTransport.JSON_CONTEXT_OPTION);
    if (option == null) {
        _jsonContext = new JettyJSONContextServer();
    } else if (option instanceof JSONContextServer) {
        _jsonContext = (JSONContextServer)option;
    } else if (option instanceof String) {
        Class<?> jsonContextClass = Thread.currentThread().getContextClassLoader().loadClass((String)option);
        if (!JSONContextServer.class.isAssignableFrom(jsonContextClass)) {
            throw new IllegalArgumentException("Invalid " + JSONContextServer.class.getName() + " implementation class: " + option);
        }
        _jsonContext = jsonContextClass.asSubclass(JSONContextServer.class).getConstructor().newInstance();
    } else {
        throw new IllegalArgumentException("Invalid " + JSONContextServer.class.getName() + " implementation class: " + option);
    }
    _options.put(AbstractServerTransport.JSON_CONTEXT_OPTION, _jsonContext);
}
/**
 * Initializes the server transports.
 * <p>
 * If no transports are configured yet, they are created in one of two ways. When the
 * {@link #TRANSPORTS_OPTION} option is present, it is read as a comma-separated list of
 * class names. Otherwise the defaults are created, in this order: the JSR 356 WebSocket
 * transport (if available), an HTTP JSON transport and the JSONP transport.
 * <p>
 * If no allowed transports are configured yet, the {@link #ALLOWED_TRANSPORTS_OPTION}
 * option is read as a comma-separated list of transport names. When that option is absent,
 * all configured transports are allowed. Entries of both options are trimmed, so
 * {@code "a, b"} is equivalent to {@code "a,b"}. Unknown or repeated allowed names are ignored.
 * <p>
 * Finally, every allowed {@link AbstractServerTransport} is initialized.
 *
 * @throws IllegalArgumentException if an option does not resolve to at least one valid transport
 */
protected void initializeServerTransports() {
    if (_transports.isEmpty()) {
        String option = (String)getOption(TRANSPORTS_OPTION);
        if (option == null) {
            // Order is important, see #findHttpTransport()
            ServerTransport transport = newWebSocketTransport();
            if (transport != null) {
                addTransport(transport);
            }
            addTransport(newJSONTransport());
            addTransport(new JSONPTransport(this));
        } else {
            for (String className : option.split(",")) {
                ServerTransport transport = newServerTransport(className.trim());
                if (transport != null) {
                    addTransport(transport);
                }
            }
            if (_transports.isEmpty()) {
                throw new IllegalArgumentException("Option '" + TRANSPORTS_OPTION +
                        "' does not contain a valid list of server transport class names");
            }
        }
    }

    if (_allowedTransports.isEmpty()) {
        String option = (String)getOption(ALLOWED_TRANSPORTS_OPTION);
        if (option == null) {
            _allowedTransports.addAll(_transports.keySet());
        } else {
            for (String token : option.split(",")) {
                // Trim like the transports option, so that "a, b" does not silently drop "b".
                String transportName = token.trim();
                // Skip duplicates so that a transport is not initialized twice.
                if (_transports.containsKey(transportName) && !_allowedTransports.contains(transportName)) {
                    _allowedTransports.add(transportName);
                }
            }
            if (_allowedTransports.isEmpty()) {
                throw new IllegalArgumentException("Option '" + ALLOWED_TRANSPORTS_OPTION +
                        "' does not contain at least one configured server transport name");
            }
        }
    }

    List<String> activeTransports = new ArrayList<>();
    for (String transportName : _allowedTransports) {
        ServerTransport serverTransport = getTransport(transportName);
        if (serverTransport instanceof AbstractServerTransport) {
            ((AbstractServerTransport)serverTransport).init();
            activeTransports.add(serverTransport.getName());
        }
    }
    if (_logger.isDebugEnabled()) {
        _logger.debug("Active transports: {}", activeTransports);
    }
}
/**
 * Creates the JSR 356 WebSocket transport.
 *
 * @return the WebSocket transport, or null when either the {@code javax.websocket} API or
 * the CometD WebSocket transport class is not available (the latter case is logged)
 */
private ServerTransport newWebSocketTransport() {
    try {
        ClassLoader contextLoader = Thread.currentThread().getContextClassLoader();
        // Probe for the JSR 356 API; throws if it is not on the classpath.
        contextLoader.loadClass("javax.websocket.server.ServerContainer");
        String transportClass = "org.cometd.server.websocket.javax.WebSocketTransport";
        ServerTransport webSocketTransport = newServerTransport(transportClass);
        if (webSocketTransport != null) {
            return webSocketTransport;
        }
        _logger.info("JSR 356 WebSocket classes available, but " + transportClass +
                " unavailable: JSR 356 WebSocket transport disabled");
        return null;
    } catch (Exception ignored) {
        // No JSR 356 API: the WebSocket transport is simply not offered.
        return null;
    }
}
/**
 * Creates the HTTP JSON transport.
 *
 * @return an {@link AsyncJSONTransport} when {@code javax.servlet.ReadListener} can be
 * loaded and the transport created, otherwise a {@link JSONTransport}
 */
private ServerTransport newJSONTransport() {
    try {
        Thread.currentThread().getContextClassLoader().loadClass("javax.servlet.ReadListener");
        return new AsyncJSONTransport(this);
    } catch (Exception ignored) {
        // Fall back to the JSONTransport variant.
        return new JSONTransport(this);
    }
}
private ServerTransport newServerTransport(String className) {
try {
ClassLoader loader = Thread.currentThread().getContextClassLoader();
@SuppressWarnings("unchecked")
Class extends ServerTransport> klass = (Class extends ServerTransport>)loader.loadClass(className);
Constructor extends ServerTransport> constructor = klass.getConstructor(BayeuxServerImpl.class);
return constructor.newInstance(this);
} catch (Exception x) {
return null;
}
}
/**
 * Sets the executor to use instead of the one this server would otherwise create on start.
 *
 * @param executor the executor, must not be null
 * @throws IllegalStateException if this server is running
 * @throws NullPointerException if the executor is null
 */
public void setExecutor(Executor executor) {
    if (!isRunning()) {
        // Mark false: this server did not create it, so doStop() keeps the reference.
        _executor = new MarkedReference<>(Objects.requireNonNull(executor), false);
        return;
    }
    throw new IllegalStateException("Cannot set executor on a running BayeuxServer instance");
}
/**
 * @return the executor in use, or null if none has been set or created yet
 */
public Executor getExecutor() {
    if (_executor != null) {
        return _executor.getReference();
    }
    return null;
}
/**
 * Creates the executor owned by this server.
 * <p>
 * It is a named {@link QueuedThreadPool} with no minimum and no reserved threads. Its
 * maximum number of threads comes from {@link #EXECUTOR_MAX_THREADS} (default 128).
 *
 * @return a new executor
 */
private Executor newExecutor() {
    int maxThreadCount = (int)getOption(EXECUTOR_MAX_THREADS, 128);
    QueuedThreadPool threadPool = new QueuedThreadPool(maxThreadCount, 0);
    threadPool.setName(_name + "-Executor");
    threadPool.setReservedThreads(0);
    return threadPool;
}
/**
 * Runs the given job on this server's executor.
 *
 * @param job the job to run
 * @throws RejectedExecutionException if there is no executor
 */
public void execute(Runnable job) {
    Executor target = getExecutor();
    if (target != null) {
        target.execute(job);
        return;
    }
    throw new RejectedExecutionException("Cannot execute job, no executor");
}
/**
 * Sets the scheduler to use instead of the one this server would otherwise create on start.
 *
 * @param scheduler the scheduler, must not be null
 * @throws IllegalStateException if this server is running
 * @throws NullPointerException if the scheduler is null
 */
public void setScheduler(Scheduler scheduler) {
    if (!isRunning()) {
        // Mark false: this server did not create it, so doStop() keeps the reference.
        _scheduler = new MarkedReference<>(Objects.requireNonNull(scheduler), false);
        return;
    }
    throw new IllegalStateException("Cannot set scheduler on a running BayeuxServer instance");
}
/**
 * @return the scheduler in use, or null if none has been set or created yet
 */
public Scheduler getScheduler() {
    if (_scheduler != null) {
        return _scheduler.getReference();
    }
    return null;
}
/**
 * Creates the scheduler owned by this server.
 * <p>
 * It is a named, non-daemon {@link ScheduledExecutorScheduler}; its thread count comes
 * from {@link #SCHEDULER_THREADS} (default 1).
 *
 * @return a new scheduler
 */
private Scheduler newScheduler() {
    int threadCount = (int)getOption(SCHEDULER_THREADS, 1);
    return new ScheduledExecutorScheduler(_name + "-Scheduler", false, threadCount);
}
/**
 * Schedules a one-shot task; this is the entry point CometD uses to schedule tasks.
 * Subclasses may override this method to run the task in a
 * {@link java.util.concurrent.Executor} instead of in the scheduler thread.
 *
 * @param task the task to run
 * @param delay how long to wait before running the task, in milliseconds
 * @return the scheduled task, which can be used to cancel it
 * @throws RejectedExecutionException if there is no scheduler
 */
public Scheduler.Task schedule(Runnable task, long delay) {
    Scheduler target = getScheduler();
    if (target != null) {
        return target.schedule(task, delay, TimeUnit.MILLISECONDS);
    }
    throw new RejectedExecutionException("Cannot schedule task, no scheduler");
}
/**
 * Returns a {@link ChannelId} for the given id.
 * <p>
 * The id of an existing channel is reused when the name is found as-is; otherwise a new
 * {@link ChannelId} is created.
 *
 * @param id the channel name
 * @return the channel id
 */
public ChannelId newChannelId(String id) {
    ServerChannelImpl existing = _channels.get(id);
    return existing == null ? new ChannelId(id) : existing.getChannelId();
}
/**
 * Returns the live, mutable options map of this server (not a copy).
 *
 * @return the options map
 */
public Map getOptions() {
    return this._options;
}
/**
 * Looks up a configuration option by name.
 *
 * @param qualifiedName the option name
 * @return the option value, or null if absent
 */
@Override
@ManagedOperation(value = "The value of the given configuration option", impact = "INFO")
public Object getOption(@Name("optionName") String qualifiedName) {
    Object value = this._options.get(qualifiedName);
    return value;
}
/**
 * Reads an option as a long.
 *
 * @param name the option name
 * @param dft the value returned when the option is absent
 * @return the value of a {@link Number} option, or its string form parsed as a long
 * @throws NumberFormatException if the value is neither a number nor a parseable string
 */
protected long getOption(String name, long dft) {
    Object raw = getOption(name);
    if (raw instanceof Number) {
        return ((Number)raw).longValue();
    }
    return raw == null ? dft : Long.parseLong(raw.toString());
}
/**
 * Reads an option as a boolean.
 *
 * @param name the option name
 * @param dft the value returned when the option is absent
 * @return the value of a {@link Boolean} option, or its string form parsed with
 * {@link Boolean#parseBoolean(String)}
 */
protected boolean getOption(String name, boolean dft) {
    Object raw = getOption(name);
    if (raw instanceof Boolean) {
        return (Boolean)raw;
    }
    return raw == null ? dft : Boolean.parseBoolean(raw.toString());
}
/**
 * @return a live view of the option names, in sorted order
 */
@Override
public Set getOptionNames() {
    return this._options.keySet();
}
/**
 * Sets (or replaces) a configuration option.
 *
 * @param qualifiedName the option name
 * @param value the option value
 */
@Override
public void setOption(String qualifiedName, Object value) {
    this._options.put(qualifiedName, value);
}
/**
 * Adds all the given options, replacing existing ones with the same name.
 *
 * @param options the options to add
 */
public void setOptions(Map options) {
    this._options.putAll(options);
}
/**
 * Returns a non-negative random long drawn from this server's {@link java.security.SecureRandom}.
 * <p>
 * The sign bit is masked off rather than negating negative values. Negating
 * {@link Long#MIN_VALUE} overflows back to {@link Long#MIN_VALUE}, so negation
 * cannot guarantee a non-negative result.
 *
 * @return a random value in {@code [0, Long.MAX_VALUE]}
 */
public long randomLong() {
    return _random.nextLong() & Long.MAX_VALUE;
}
/**
 * @return the current security policy, possibly null
 */
@Override
public SecurityPolicy getSecurityPolicy() {
    return this._policy;
}
/**
 * @return the JSON context set up by {@code initializeJSONContext()}, or null before that
 */
public JSONContextServer getJSONContext() {
    return this._jsonContext;
}
/**
 * Returns the channel with the given name, creating it if it does not exist.
 * <p>
 * If the name is not found as given, it is normalized through {@link ChannelId} and looked
 * up again. When this call creates the channel, it applies the steps below in order:
 * <ol>
 * <li>the given initializers configure the channel;</li>
 * <li>the server listeners that are {@link ServerChannel.Initializer}s configure it;</li>
 * <li>the channel is marked initialized, even if a configuration step failed;</li>
 * <li>the {@link BayeuxServer.ChannelListener}s are notified.</li>
 * </ol>
 * If another thread created the channel concurrently, this call waits until that thread
 * has marked it initialized.
 *
 * @param channelName the channel name
 * @param initializers initializers applied only if this call creates the channel
 * @return a reference to the channel, marked true only if this call created it
 */
@Override
public MarkedReference createChannelIfAbsent(String channelName, Initializer... initializers) {
ChannelId channelId;
boolean initialized = false;
ServerChannelImpl channel = _channels.get(channelName);
if (channel == null) {
// Creating the ChannelId will also normalize the channelName.
channelId = new ChannelId(channelName);
String id = channelId.getId();
if (!id.equals(channelName)) {
// Retry the lookup with the normalized name.
channelName = id;
channel = _channels.get(channelName);
}
} else {
channelId = channel.getChannelId();
}
if (channel == null) {
ServerChannelImpl candidate = new ServerChannelImpl(this, channelId);
// Atomic insert: only the thread whose candidate wins configures the channel.
channel = _channels.putIfAbsent(channelName, candidate);
if (channel == null) {
// My candidate channel was added to the map, so I'd better initialize it
channel = candidate;
if (_logger.isDebugEnabled()) {
_logger.debug("Added channel {}", channel);
}
try {
for (Initializer initializer : initializers) {
notifyConfigureChannel(initializer, channel);
}
for (BayeuxServer.BayeuxServerListener listener : _listeners) {
if (listener instanceof ServerChannel.Initializer) {
notifyConfigureChannel((Initializer)listener, channel);
}
}
} finally {
// Always release threads blocked in waitForInitialized().
channel.initialized();
}
// Listeners are notified only after the channel is fully configured.
for (BayeuxServer.BayeuxServerListener listener : _listeners) {
if (listener instanceof BayeuxServer.ChannelListener) {
notifyChannelAdded((ChannelListener)listener, channel);
}
}
initialized = true;
}
} else {
// Existing channel: restart its sweep countdown.
channel.resetSweeperPasses();
// Double check if the sweeper removed this channel between the check at the top and here.
// This is not 100% fool proof (e.g. this thread is preempted long enough for the sweeper
// to remove the channel, but the alternative is to have a global lock)
_channels.putIfAbsent(channelName, channel);
}
// Another thread may add this channel concurrently, so wait until it is initialized
channel.waitForInitialized();
return new MarkedReference<>(channel, initialized);
}
/**
 * Lets the given initializer configure the channel. A failure is logged and does not
 * prevent the other initializers from running.
 */
private void notifyConfigureChannel(Initializer listener, ServerChannel channel) {
    try {
        listener.configureChannel(channel);
    } catch (Throwable failure) {
        _logger.info("Exception while invoking listener " + listener, failure);
    }
}
/**
 * Tells the given listener that a channel was added. A failure is logged and does not
 * prevent the other listeners from being notified.
 */
private void notifyChannelAdded(ChannelListener listener, ServerChannel channel) {
    try {
        listener.channelAdded(channel);
    } catch (Throwable failure) {
        _logger.info("Exception while invoking listener " + listener, failure);
    }
}
/**
 * @return an unmodifiable snapshot of the current sessions
 */
@Override
public List getSessions() {
    var currentSessions = _sessions.values();
    return List.copyOf(currentSessions);
}
/**
 * @param clientId the session id, may be null
 * @return the session with the given id, or null if the id is null or unknown
 */
@Override
public ServerSession getSession(String clientId) {
    if (clientId == null) {
        return null;
    }
    return _sessions.get(clientId);
}
/**
 * Registers a session by its id, notifies the {@link BayeuxServer.SessionListener}s, and
 * then tells the session it has been added.
 *
 * @param session the session to add
 * @param message the message that caused the addition
 */
protected void addServerSession(ServerSessionImpl session, ServerMessage message) {
    if (_logger.isDebugEnabled()) {
        _logger.debug("Adding {}", session);
    }
    _sessions.put(session.getId(), session);
    for (BayeuxServerListener candidate : _listeners) {
        if (!(candidate instanceof BayeuxServer.SessionListener)) {
            continue;
        }
        notifySessionAdded((SessionListener)candidate, session, message);
    }
    session.added(message);
}
/**
 * Tells the given listener that a session was added; a failure is logged and swallowed.
 */
private void notifySessionAdded(SessionListener listener, ServerSession session, ServerMessage message) {
    try {
        listener.sessionAdded(session, message);
    } catch (Throwable failure) {
        _logger.info("Exception while invoking listener " + listener, failure);
    }
}
/**
 * Removes the given session (not due to a timeout).
 *
 * @param session the session to remove
 * @return true if this call removed the session
 */
@Override
public boolean removeSession(ServerSession session) {
    var removal = removeServerSession(session, null, false);
    return removal.getReference() != null;
}
/**
 * Removes the given session.
 *
 * @param session the session to remove
 * @param timeout whether the removal is caused by a timeout
 * @return true only if the session was removed and was connected at that time
 */
public boolean removeServerSession(ServerSession session, boolean timeout) {
    var removal = removeServerSession(session, null, timeout);
    return removal.isMarked();
}
/**
 * Removes the session mapped under the given session's id and, if it is that very
 * session, notifies the session listeners and tells the session it was removed.
 *
 * @return an empty reference if the session was not the one mapped under its id;
 * otherwise the removed session, marked with whether it was connected
 */
private MarkedReference removeServerSession(ServerSession session, ServerMessage message, boolean timeout) {
    if (_logger.isDebugEnabled()) {
        _logger.debug("Removing session timeout: {}, {}, message: {}", timeout, session, message);
    }
    ServerSessionImpl removed = _sessions.remove(session.getId());
    if (removed == session) {
        // Notify BayeuxServer.SessionListeners first, so the application is
        // "pre-notified" of the removal before any channel unsubscription events.
        for (BayeuxServerListener candidate : _listeners) {
            if (candidate instanceof SessionListener) {
                notifySessionRemoved((SessionListener)candidate, removed, message, timeout);
            }
        }
        boolean wasConnected = removed.removed(message, timeout);
        return new MarkedReference<>(removed, wasConnected);
    }
    return MarkedReference.empty();
}
/**
 * Tells the given listener that a session was removed; a failure is logged and swallowed.
 */
private void notifySessionRemoved(SessionListener listener, ServerSession session, ServerMessage message, boolean timeout) {
    try {
        listener.sessionRemoved(session, message, timeout);
    } catch (Throwable failure) {
        _logger.info("Exception while invoking listener " + listener, failure);
    }
}
/**
 * @return a new server session bound to this server; it is not registered yet
 */
public ServerSessionImpl newServerSession() {
    ServerSessionImpl created = new ServerSessionImpl(this);
    return created;
}
/**
 * @param idHint a hint used by the local session for its id
 * @return a new local session bound to this server
 */
@Override
public LocalSession newLocalSession(String idHint) {
    LocalSession created = new LocalSessionImpl(this, idHint);
    return created;
}
/**
 * @return a new, empty mutable server message
 */
@Override
public ServerMessage.Mutable newMessage() {
    ServerMessage.Mutable created = new ServerMessageImpl();
    return created;
}
/**
 * Creates a new mutable message holding a shallow copy of the fields of the given message.
 *
 * @param original the message to copy
 * @return the new message
 */
public ServerMessage.Mutable newMessage(ServerMessage original) {
    ServerMessage.Mutable copy = newMessage();
    copy.putAll(original);
    return copy;
}
/**
 * @param securityPolicy the new security policy; null disables policy checks
 */
@Override
public void setSecurityPolicy(SecurityPolicy securityPolicy) {
    this._policy = securityPolicy;
}
/**
 * Adds a server-wide extension.
 * <p>
 * Null is rejected here, consistently with {@link #addListener}, instead of being stored
 * and failing later.
 *
 * @param extension the extension to add, must not be null
 * @throws NullPointerException if the extension is null
 */
@Override
public void addExtension(Extension extension) {
    Objects.requireNonNull(extension);
    _extensions.add(extension);
}
/**
 * @param extension the server-wide extension to remove
 */
@Override
public void removeExtension(Extension extension) {
    this._extensions.remove(extension);
}
/**
 * @return an unmodifiable snapshot of the server-wide extensions
 */
@Override
public List getExtensions() {
    var currentExtensions = _extensions;
    return List.copyOf(currentExtensions);
}
/**
 * @param listener the server listener to add, must not be null
 * @throws NullPointerException if the listener is null
 */
@Override
public void addListener(BayeuxServerListener listener) {
    _listeners.add(Objects.requireNonNull(listener));
}
/**
 * @param channelId the channel name
 * @return the initialized channel, or null if it does not exist
 */
@Override
public ServerChannel getChannel(String channelId) {
    ServerChannel found = getServerChannel(channelId);
    return found;
}
/**
 * Looks up a channel by name and, if found, waits until it is initialized.
 *
 * @return the channel, or null if it does not exist
 */
private ServerChannelImpl getServerChannel(String channelId) {
    ServerChannelImpl found = _channels.get(channelId);
    if (found == null) {
        return null;
    }
    found.waitForInitialized();
    return found;
}
/**
 * @return a new list with all current channels, each waited on until initialized
 */
@Override
public List getChannels() {
    List<ServerChannel> snapshot = new ArrayList<>();
    _channels.values().forEach(channel -> {
        channel.waitForInitialized();
        snapshot.add(channel);
    });
    return snapshot;
}
/**
 * @param listener the server listener to remove
 */
@Override
public void removeListener(BayeuxServerListener listener) {
    this._listeners.remove(listener);
}
/**
 * Entry point for processing an incoming message.
 * <p>
 * It first creates the reply. When validation is enabled, it validates the message fields;
 * it then runs the server-wide incoming extensions and, if a session is given, the session
 * incoming extensions. Finally it passes the message to {@code handle1()}, which completes
 * the promise.
 * <p>
 * If validation fails, the promise is completed with the reply carrying the validation error.
 * If an extension rejects the message (passes false), the reply gets a "message_deleted"
 * error (unless it was already handled) and the promise is completed with the reply.
 */
public void handle(ServerSessionImpl session, ServerMessage.Mutable message, Promise promise) {
ServerMessageImpl reply = (ServerMessageImpl)createReply(message);
if (_validation) {
String error = validateMessage(message);
if (error != null) {
error(reply, error);
promise.succeed(reply);
return;
}
}
// Server-wide incoming extensions.
extendIncoming(session, message, Promise.from(extPass -> {
if (extPass) {
if (session != null) {
// Session-specific incoming extensions.
session.extendIncoming(message, Promise.from(sessExtPass -> {
if (sessExtPass) {
handle1(session, message, promise);
} else {
if (!reply.isHandled()) {
error(reply, "404::message_deleted");
}
promise.succeed(reply);
}
}, promise::fail));
} else {
handle1(null, message, promise);
}
} else {
if (!reply.isHandled()) {
error(reply, "404::message_deleted");
}
promise.succeed(reply);
}
}, promise::fail));
}
/**
 * Second processing step, after the incoming extensions have run.
 * <p>
 * The reply gets an unknown-session error when the session is missing or disconnected,
 * or when the message clientId does not match the session id (handshakes are exempt from
 * this last check). Otherwise the session expiration is cancelled, a missing channel name
 * is rejected, and a channel that does not exist yet is created if creation is authorized.
 * Processing then continues in {@code handle2()}. The promise is always completed with
 * the reply associated to the message.
 */
private void handle1(ServerSessionImpl session, ServerMessage.Mutable message, Promise promise) {
if (_logger.isDebugEnabled()) {
_logger.debug("> {} {}", message, session);
}
ServerMessage.Mutable reply = message.getAssociated();
if (session == null || session.isDisconnected() ||
(!session.getId().equals(message.getClientId()) && !Channel.META_HANDSHAKE.equals(message.getChannel()))) {
unknownSession(reply);
promise.succeed(reply);
} else {
String channelName = message.getChannel();
// Note: runs before the null-channel check; equals(null) is simply false here.
session.cancelExpiration(Channel.META_CONNECT.equals(channelName));
if (channelName == null) {
error(reply, "400::channel_missing");
promise.succeed(reply);
} else {
ServerChannelImpl channel = getServerChannel(channelName);
if (channel == null) {
// Unknown channel: it must be authorized before being created.
isCreationAuthorized(session, message, channelName, Promise.from(result -> {
if (result instanceof Authorizer.Result.Denied) {
String denyReason = ((Authorizer.Result.Denied)result).getReason();
error(reply, "403:" + denyReason + ":channel_create_denied");
promise.succeed(reply);
} else {
handle2(session, message, (ServerChannelImpl)createChannelIfAbsent(channelName).getReference(), promise);
}
}, promise::fail));
} else {
handle2(session, message, channel, promise);
}
}
}
}
/**
 * Third processing step: publishes the message on its (existing) channel.
 * <p>
 * Messages on non-meta channels must first pass publish authorization. When authorized,
 * the reply is marked successful before publishing; when denied, the reply gets a
 * "publish_denied" error. Meta messages are published directly. The promise is completed
 * with the reply.
 */
private void handle2(ServerSessionImpl session, ServerMessage.Mutable message, ServerChannelImpl channel, Promise promise) {
    ServerMessage.Mutable reply = message.getAssociated();
    if (!channel.isMeta()) {
        isPublishAuthorized(channel, session, message, Promise.from(authorization -> {
            if (authorization instanceof Authorizer.Result.Denied) {
                String reason = ((Authorizer.Result.Denied)authorization).getReason();
                error(reply, "403:" + reason + ":publish_denied");
                promise.succeed(reply);
                return;
            }
            reply.setSuccessful(true);
            publish(session, channel, message, true, Promise.from(published -> promise.succeed(reply), promise::fail));
        }, promise::fail));
        return;
    }
    // Meta messages skip publish authorization.
    publish(session, channel, message, true, Promise.from(published -> promise.succeed(reply), promise::fail));
}
/**
 * Validates the channel and id fields of an incoming message.
 *
 * @param message the message to validate
 * @return null if the message is valid, otherwise the Bayeux error string
 */
protected String validateMessage(Mutable message) {
    String channelName = message.getChannel();
    if (channelName == null) {
        return "400::channel_missing";
    }
    if (!Bayeux.Validator.isValidChannelId(channelName)) {
        return "405::channel_invalid";
    }
    String messageId = message.getId();
    boolean invalidId = messageId != null && !Bayeux.Validator.isValidMessageId(messageId);
    return invalidId ? "405::message_id_invalid" : null;
}
/**
 * Checks whether the session may publish the message on the channel.
 * <p>
 * The security policy, if any, is consulted first; a null or true answer defers to the
 * channel authorizers. The promise is completed with the authorization result.
 */
private void isPublishAuthorized(ServerChannel channel, ServerSession session, ServerMessage message, Promise promise) {
    if (_policy == null) {
        isOperationAuthorized(Authorizer.Operation.PUBLISH, session, message, channel.getChannelId(), promise);
        return;
    }
    _policy.canPublish(this, session, channel, message, Promise.from(can -> {
        boolean policyDenied = can != null && !can;
        if (policyDenied) {
            _logger.info("{} denied publish on channel {} by {}", session, channel.getId(), _policy);
            promise.succeed(Authorizer.Result.deny("denied_by_security_policy"));
        } else {
            isOperationAuthorized(Authorizer.Operation.PUBLISH, session, message, channel.getChannelId(), promise);
        }
    }, promise::fail));
}
/**
 * Checks whether the session may subscribe to the channel.
 * <p>
 * The security policy, if any, is consulted first; a null or true answer defers to the
 * channel authorizers. The promise is completed with the authorization result.
 */
private void isSubscribeAuthorized(ServerChannel channel, ServerSession session, ServerMessage message, Promise promise) {
    if (_policy == null) {
        isOperationAuthorized(Authorizer.Operation.SUBSCRIBE, session, message, channel.getChannelId(), promise);
        return;
    }
    _policy.canSubscribe(this, session, channel, message, Promise.from(can -> {
        boolean policyDenied = can != null && !can;
        if (policyDenied) {
            _logger.info("{} denied Subscribe@{} by {}", session, channel, _policy);
            promise.succeed(Authorizer.Result.deny("denied_by_security_policy"));
        } else {
            isOperationAuthorized(Authorizer.Operation.SUBSCRIBE, session, message, channel.getChannelId(), promise);
        }
    }, promise::fail));
}
/**
 * Checks whether the session may create the channel with the given name.
 * <p>
 * The security policy, if any, is consulted first; a null or true answer defers to the
 * channel authorizers. The promise is completed with the authorization result.
 */
private void isCreationAuthorized(ServerSession session, ServerMessage message, String channel, Promise promise) {
    if (_policy == null) {
        isOperationAuthorized(Authorizer.Operation.CREATE, session, message, new ChannelId(channel), promise);
        return;
    }
    _policy.canCreate(this, session, channel, message, Promise.from(can -> {
        boolean policyDenied = can != null && !can;
        if (policyDenied) {
            _logger.info("{} denied creation of channel {} by {}", session, channel, _policy);
            promise.succeed(Authorizer.Result.deny("denied_by_security_policy"));
        } else {
            isOperationAuthorized(Authorizer.Operation.CREATE, session, message, new ChannelId(channel), promise);
        }
    }, promise::fail));
}
/**
 * Runs the channel authorizers for the operation and turns their combined result into a
 * final decision.
 * <p>
 * No authorizers at all means grant. An explicit grant or deny is kept as-is. A result
 * that is neither granted nor denied becomes a deny with reason "denied_by_not_granting".
 */
private void isOperationAuthorized(Authorizer.Operation operation, ServerSession session, ServerMessage message, ChannelId channelId, Promise promise) {
    isChannelOperationAuthorized(operation, session, message, channelId, Promise.from(authorization -> {
        Authorizer.Result outcome = authorization;
        if (outcome == null) {
            outcome = Authorizer.Result.grant();
            if (_logger.isDebugEnabled()) {
                _logger.debug("No authorizers, {} for channel {} {}", operation, channelId, outcome);
            }
        } else if (outcome.isGranted()) {
            if (_logger.isDebugEnabled()) {
                _logger.debug("No authorizer denied {} for channel {}, authorization {}", operation, channelId, outcome);
            }
        } else if (!outcome.isDenied()) {
            outcome = Authorizer.Result.deny("denied_by_not_granting");
            if (_logger.isDebugEnabled()) {
                _logger.debug("No authorizer granted {} for channel {}, authorization {}", operation, channelId, outcome);
            }
        }
        promise.succeed(outcome);
    }, promise::fail));
}
/**
 * Combines the authorizer results of every existing channel matching {@code channelId}
 * (the channel itself and its wildcard parents, as returned by {@code getAllIds()}),
 * visited through {@code AsyncFoldLeft.reverseRun}.
 * <p>
 * The fold works as follows:
 * <ul>
 * <li>a denied result stops the fold immediately and is the final result;</li>
 * <li>a granted result replaces the accumulated one;</li>
 * <li>any other non-null result is kept only if nothing was accumulated yet.</li>
 * </ul>
 * The promise is completed with null if no visited channel returned a result, for
 * example because none has authorizers.
 */
private void isChannelOperationAuthorized(Authorizer.Operation operation, ServerSession session, ServerMessage message, ChannelId channelId, Promise promise) {
AsyncFoldLeft.reverseRun(channelId.getAllIds(), null, (result, channelName, loop) -> {
ServerChannelImpl channel = _channels.get(channelName);
if (channel != null) {
isChannelOperationAuthorized(channel, operation, session, message, channelId, Promise.from(authz -> {
if (authz != null) {
if (authz.isDenied()) {
loop.leave(authz);
} else if (result == null || authz.isGranted()) {
loop.proceed(authz);
} else {
loop.proceed(result);
}
} else {
// This channel has no authorizers: keep the accumulated result.
loop.proceed(result);
}
}, promise::fail));
} else {
// Channels that do not exist contribute nothing.
loop.proceed(result);
}
}, promise);
}
/**
 * Runs the authorizers of a single channel for the operation on {@code channelId}.
 * <p>
 * The promise is completed with null if the channel has no authorizers. Otherwise the
 * authorizers run in list order, starting from {@code Result.ignore()}. The first denial
 * ends the run and is the result; a grant replaces the accumulated result; any other
 * answer leaves it unchanged.
 */
private void isChannelOperationAuthorized(ServerChannelImpl channel, Authorizer.Operation operation, ServerSession session, ServerMessage message, ChannelId channelId, Promise promise) {
List authorizers = channel.authorizers();
if (authorizers.isEmpty()) {
promise.succeed(null);
} else {
AsyncFoldLeft.run(authorizers, Authorizer.Result.ignore(), (result, authorizer, loop) ->
authorizer.authorize(operation, channelId, session, message, Promise.from(authorization -> {
if (_logger.isDebugEnabled()) {
_logger.debug("Authorizer {} on channel {} {} {} for channel {}", authorizer, channel, authorization, operation, channelId);
}
if (authorization.isDenied()) {
loop.leave(authorization);
} else if (authorization.isGranted()) {
loop.proceed(authorization);
} else {
loop.proceed(result);
}
}, promise::fail)), promise);
}
}
/**
 * First publish step.
 * <p>
 * For broadcast channels, the clientId and id are cleared from the message. The channel
 * message listeners are then notified. If none of them rejects the message, publishing
 * continues in {@code publish1()}. Otherwise the associated reply, if present and not
 * already handled, gets a "message_deleted" error and the promise is completed with false.
 *
 * @param receiving passed through to {@code publish1()}; {@code handle2()} passes true
 */
protected void publish(ServerSessionImpl session, ServerChannelImpl channel, ServerMessage.Mutable message, boolean receiving, Promise promise) {
if (_logger.isDebugEnabled()) {
_logger.debug("< {} {}", message, session);
}
if (channel.isBroadcast()) {
// Do not leak the clientId to other subscribers
// as we are now "sending" this message.
message.setClientId(null);
// Reset the messageId to avoid clashes with message-based transports such
// as websocket whose clients may rely on the messageId to match request/responses.
message.setId(null);
}
notifyListeners(session, channel, message, Promise.from(proceed -> {
if (proceed) {
publish1(session, channel, message, receiving, promise);
} else {
ServerMessageImpl reply = (ServerMessageImpl)message.getAssociated();
if (reply != null && !reply.isHandled()) {
error(reply, "404::message_deleted");
}
promise.succeed(false);
}
}, promise::fail));
}
/**
 * Second publish step.
 * <p>
 * For broadcast channels, or when {@code receiving} is false, the server-wide outgoing
 * extensions run first, and then the message is frozen before {@code publish2()}. If an
 * extension rejects the message, the associated reply gets a "message_deleted" error and
 * the promise is completed with false. In every other case the message goes straight to
 * {@code publish2()}.
 */
private void publish1(ServerSessionImpl session, ServerChannelImpl channel, ServerMessage.Mutable message, boolean receiving, Promise promise) {
if (channel.isBroadcast() || !receiving) {
extendOutgoing(session, null, message, Promise.from(result -> {
if (result) {
// Exactly at this point, we convert the message to JSON and therefore
// any further modification will be lost.
// This is an optimization so that if the message is sent to a million
// subscribers, we generate the JSON only once.
// From now on, user code is passed a ServerMessage reference (and not
// ServerMessage.Mutable), and we attempt to return immutable data
// structures, even if it is not possible to guard against all cases.
// For example, it is impossible to prevent things like
// ((CustomObject)serverMessage.getData()).change() or
// ((Map)serverMessage.getExt().get("map")).put().
freeze(message);
publish2(session, channel, message, promise);
} else {
// NOTE(review): unlike publish(), this neither null-checks the reply nor
// checks isHandled() before setting the error - confirm this is intended.
ServerMessage.Mutable reply = message.getAssociated();
error(reply, "404::message_deleted");
promise.succeed(false);
}
}, promise::fail));
} else {
publish2(session, channel, message, promise);
}
}
/**
 * Final publish step.
 * <p>
 * Meta messages go to the meta handlers and broadcast messages to the subscribers. For
 * any other channel (neither meta nor broadcast) the promise is completed with true.
 */
private void publish2(ServerSessionImpl session, ServerChannelImpl channel, ServerMessage.Mutable message, Promise promise) {
    if (channel.isMeta()) {
        notifyMetaHandlers(session, channel, message, promise);
        return;
    }
    if (channel.isBroadcast()) {
        notifySubscribers(session, channel, message, promise);
        return;
    }
    promise.succeed(true);
}
/**
 * Delivers a broadcast message to the subscribers of the channel and of its matching
 * wildcard channels.
 * <p>
 * A subscriber receives the message at most once. The publisher is skipped on channels
 * whose broadcast-to-publisher flag is off. The promise is completed with true if at
 * least one delivery succeeded. A subscriber is recorded as served only after a successful
 * delivery, so a failed delivery may be retried through another matching channel.
 */
private void notifySubscribers(ServerSessionImpl session, ServerChannelImpl serverChannel, Mutable message, Promise promise) {
// Both the client and the server know their subscriptions, say to /chat/* and /chat/news.
// The server wants to avoid to send the same message multiple times to the same subscriber,
// if that subscriber is subscribed to both /chat/* and /chat/news; that's why a Set is used
// to avoid sending the same message multiple times to the same subscriber.
// When the client receives the message, it can fan out the message to its subscriptions.
Set subscriberIds = new HashSet<>();
// The notification flows from the root of the channel tree,
// i.e. from the /** channel to the exact message channel.
List channels = serverChannel.getChannelId().getAllIds();
AsyncFoldLeft.reverseRun(channels, false, (result, channelName, channelLoop) -> {
ServerChannelImpl channel = _channels.get(channelName);
if (channel == null) {
channelLoop.proceed(result);
} else {
Set subscribers = channel.subscribers();
if (_logger.isDebugEnabled()) {
_logger.debug("Notifying {} subscribers on {}", subscribers.size(), channel);
}
// Inner fold over this channel's subscribers; r accumulates "delivered to anyone".
AsyncFoldLeft.run(subscribers, false, (r, subscriber, loop) -> {
String subscriberId = subscriber.getId();
if (subscriberIds.contains(subscriberId)) {
loop.proceed(r);
} else {
if (subscriber == session && !channel.isBroadcastToPublisher()) {
loop.proceed(r);
} else {
((ServerSessionImpl)subscriber).deliver1(session, message, Promise.from(delivered -> {
if (delivered) {
subscriberIds.add(subscriberId);
}
loop.proceed(r || delivered);
}, loop::fail));
}
}
}, Promise.from(delivered -> channelLoop.proceed(result || delivered), channelLoop::fail));
}
}, promise);
}
/**
 * Notifies the {@link MessageListener}s of the message channel and of every wildcard
 * channel that matches it, walking from the root of the channel tree to the exact channel.
 * <p>
 * If any visited channel is lazy, the message is marked lazy. Listeners that are not
 * {@link MessageListener}s are skipped. As soon as a listener returns {@code false}, the
 * notification stops.
 *
 * @param session the session that published the message
 * @param channel the channel the message was published to
 * @param message the message being published
 * @param promise completed with {@code true} if no listener rejected the message,
 *                {@code false} otherwise
 */
private void notifyListeners(ServerSessionImpl session, ServerChannelImpl channel, Mutable message, Promise<Boolean> promise) {
    AsyncFoldLeft.reverseRun(channel.getChannelId().getAllIds(), true, (channelResult, channelName, channelLoop) -> {
        ServerChannelImpl target = _channels.get(channelName);
        if (target == null) {
            channelLoop.proceed(channelResult);
        } else {
            if (target.isLazy()) {
                message.setLazy(true);
            }
            List<ConfigurableServerChannel.ServerChannelListener> listeners = target.listeners();
            if (_logger.isDebugEnabled()) {
                _logger.debug("Notifying {} listeners on {}", listeners.size(), target);
            }
            AsyncFoldLeft.run(listeners, true, (result, listener, loop) -> {
                if (listener instanceof MessageListener) {
                    // Listeners receive the channel the message was published to,
                    // not the (possibly wildcard) channel they are registered on.
                    notifyOnMessage((MessageListener)listener, session, channel, message, resolveLoop(loop));
                } else {
                    loop.proceed(true);
                }
            }, resolveLoop(channelLoop));
        }
    }, promise);
}
/**
 * Adapts an {@link AsyncFoldLeft.Loop} to a {@link Promise}.
 * <p>
 * A {@code true} result proceeds the loop with {@code true}, a {@code false} result leaves
 * the loop with {@code false}, and a failure fails the loop.
 *
 * @param loop the loop to drive
 * @return a promise that drives the given loop
 */
protected Promise<Boolean> resolveLoop(AsyncFoldLeft.Loop<Boolean> loop) {
    return Promise.from(result -> {
        if (result) {
            loop.proceed(true);
        } else {
            loop.leave(false);
        }
    }, loop::fail);
}
/**
 * Routes a meta message to the handler of its meta channel.
 *
 * @param session the session that sent the message
 * @param channel the meta channel the message was sent to
 * @param message the meta message
 * @param promise completed by the selected handler, or failed with an
 *                {@link IllegalStateException} if the channel is not one of
 *                handshake, connect, subscribe, unsubscribe or disconnect
 */
private void notifyMetaHandlers(ServerSessionImpl session, ServerChannelImpl channel, Mutable message, Promise<Boolean> promise) {
    switch (channel.getId()) {
        case Channel.META_HANDSHAKE:
            handleMetaHandshake(session, message, promise);
            break;
        case Channel.META_CONNECT:
            handleMetaConnect(session, message, promise);
            break;
        case Channel.META_SUBSCRIBE:
            handleMetaSubscribe(session, message, promise);
            break;
        case Channel.META_UNSUBSCRIBE:
            handleMetaUnsubscribe(session, message, promise);
            break;
        case Channel.META_DISCONNECT:
            handleMetaDisconnect(session, message, promise);
            break;
        default:
            promise.fail(new IllegalStateException("Invalid channel " + channel));
            break;
    }
}
/**
 * Freezes the given message with its JSON representation.
 * <p>
 * The JSON is generated with this server's JSON context. Messages that are not
 * {@code ServerMessageImpl} instances, or that are already frozen, are left untouched.
 *
 * @param mutable the message to freeze
 */
public void freeze(Mutable mutable) {
    if (!(mutable instanceof ServerMessageImpl)) {
        return;
    }
    ServerMessageImpl serverMessage = (ServerMessageImpl)mutable;
    if (!serverMessage.isFrozen()) {
        String generated = _jsonContext.generate(serverMessage);
        serverMessage.freeze(generated);
    }
}
/**
 * Invokes a {@link MessageListener} and forwards its result to the given promise.
 * <p>
 * A {@code null} result is treated as {@code true}. Failures reported to the listener's
 * promise, or thrown by the listener, are logged and also treated as {@code true}, so that
 * a misbehaving listener does not stop message processing.
 *
 * @param listener the listener to invoke
 * @param from     the session that published the message
 * @param to       the channel the message was published to
 * @param mutable  the message
 * @param promise  completed with the listener's decision
 */
private void notifyOnMessage(MessageListener listener, ServerSession from, ServerChannel to, Mutable mutable, Promise<Boolean> promise) {
    try {
        listener.onMessage(from, to, mutable, Promise.from(r -> promise.succeed(r == null || r), failure -> {
            _logger.info("Exception reported by listener " + listener, failure);
            promise.succeed(true);
        }));
    } catch (Throwable x) {
        _logger.info("Exception thrown by listener " + listener, x);
        promise.succeed(true);
    }
}
/**
 * Runs the server extensions' {@code incoming} callbacks on a message, in list order.
 * <p>
 * Once an extension returns {@code false}, the remaining extensions are skipped and the
 * result is {@code false}. A {@code null} result, a reported failure or a thrown exception
 * is logged where applicable and treated as {@code true}.
 *
 * @param session the session that sent the message
 * @param message the incoming message
 * @param promise completed with {@code true} if all extensions accepted the message
 */
private void extendIncoming(ServerSessionImpl session, ServerMessage.Mutable message, Promise<Boolean> promise) {
    AsyncFoldLeft.run(_extensions, true, (result, extension, loop) -> {
        if (result) {
            try {
                extension.incoming(session, message, Promise.from(r -> {
                    if (_logger.isDebugEnabled()) {
                        _logger.debug("Extension {}: result {} for incoming message {}", extension, r, message);
                    }
                    loop.proceed(r == null || r);
                }, failure -> {
                    _logger.info("Exception reported by extension " + extension, failure);
                    loop.proceed(true);
                }));
            } catch (Throwable x) {
                _logger.info("Exception thrown by extension " + extension, x);
                loop.proceed(true);
            }
        } else {
            loop.leave(false);
        }
    }, promise);
}
/**
 * Runs the server extensions' {@code outgoing} callbacks on a message, in reverse list order.
 * <p>
 * Once an extension returns {@code false}, the remaining extensions are skipped and the
 * result is {@code false}. A {@code null} result, a reported failure or a thrown exception
 * is logged where applicable and treated as {@code true}.
 *
 * @param sender  the session that sent the message
 * @param session the session the message is sent to
 * @param message the outgoing message
 * @param promise completed with {@code true} if all extensions accepted the message
 */
protected void extendOutgoing(ServerSession sender, ServerSession session, Mutable message, Promise<Boolean> promise) {
    AsyncFoldLeft.reverseRun(_extensions, true, (result, extension, loop) -> {
        if (result) {
            try {
                extension.outgoing(sender, session, message, Promise.from(r -> loop.proceed(r == null || r), failure -> {
                    _logger.info("Exception reported by extension " + extension, failure);
                    loop.proceed(true);
                }));
            } catch (Throwable x) {
                _logger.info("Exception thrown by extension " + extension, x);
                loop.proceed(true);
            }
        } else {
            loop.leave(false);
        }
    }, promise);
}
/**
 * Runs the outgoing extensions on a reply.
 * <p>
 * The server extensions run first. If they accept the reply and {@code session} is not
 * {@code null}, the session's own outgoing extensions run next and complete the promise.
 * Otherwise the promise is completed directly with the reply.
 * If a server extension rejects the reply, the promise is completed with {@code null}.
 *
 * @param sender  the session that sent the message being replied to
 * @param session the session receiving the reply, may be {@code null}
 * @param reply   the reply message
 * @param promise completed with the (possibly extended) reply, or {@code null} if it was rejected
 */
public void extendReply(ServerSessionImpl sender, ServerSessionImpl session, ServerMessage.Mutable reply, Promise<ServerMessage.Mutable> promise) {
    if (_logger.isDebugEnabled()) {
        _logger.debug("<< {} {}", reply, sender);
    }
    extendOutgoing(sender, session, reply, Promise.from(b -> {
        if (b) {
            if (session != null) {
                session.extendOutgoing(sender, reply, promise);
            } else {
                promise.succeed(reply);
            }
        } else {
            promise.succeed(null);
        }
    }, promise::fail));
}
/**
 * Removes the given channel, if it is still the one mapped to its id.
 * <p>
 * On removal, every registered {@link ChannelListener} is notified.
 *
 * @param channel the channel to remove
 * @return {@code true} if the channel was removed, {@code false} otherwise
 */
protected boolean removeServerChannel(ServerChannelImpl channel) {
    boolean removed = _channels.remove(channel.getId(), channel);
    if (!removed) {
        return false;
    }
    if (_logger.isDebugEnabled()) {
        _logger.debug("Removed channel {}", channel);
    }
    for (BayeuxServerListener serverListener : _listeners) {
        if (serverListener instanceof ChannelListener) {
            notifyChannelRemoved((ChannelListener)serverListener, channel);
        }
    }
    return true;
}
/**
 * Notifies a single {@link ChannelListener} that a channel was removed.
 * Any exception thrown by the listener is logged and not propagated.
 *
 * @param listener the listener to notify
 * @param channel  the removed channel
 */
private void notifyChannelRemoved(ChannelListener listener, ServerChannelImpl channel) {
    try {
        String removedChannelId = channel.getId();
        listener.channelRemoved(removedChannelId);
    } catch (Throwable failure) {
        _logger.info("Exception while invoking listener " + listener, failure);
    }
}
/**
 * Returns a snapshot of the {@link BayeuxServerListener}s of this server.
 *
 * @return an unmodifiable copy of the listeners
 */
protected List<BayeuxServerListener> getListeners() {
    return List.copyOf(_listeners);
}
/**
 * {@inheritDoc}
 * <p>
 * The returned set is the key-set view of the transports map.
 * NOTE(review): being a live view, it reflects later changes and allows callers to remove
 * transports; consider returning an unmodifiable copy.
 */
@Override
public Set<String> getKnownTransportNames() {
    return _transports.keySet();
}
/**
 * Looks up a transport by name.
 *
 * @param transport the transport name
 * @return the transport with the given name, or {@code null} if none is known
 */
@Override
public ServerTransport getTransport(String transport) {
    ServerTransport found = _transports.get(transport);
    return found;
}
/**
 * Adds a transport, replacing any transport previously registered under the same name.
 *
 * @param transport the transport to add
 * @return the transport previously registered under the same name, or {@code null}
 */
public ServerTransport addTransport(ServerTransport transport) {
    String transportName = transport.getName();
    ServerTransport previous = _transports.put(transportName, transport);
    if (_logger.isDebugEnabled()) {
        _logger.debug("Added transport {} from {}", transportName, transport.getClass());
    }
    return previous;
}
/**
 * Replaces all the transports of this server with the given ones.
 *
 * @param transports the transports to install; {@code null} elements are rejected by {@link List#of}
 */
public void setTransports(ServerTransport... transports) {
    List<ServerTransport> transportList = List.of(transports);
    setTransports(transportList);
}
/**
 * Replaces all the transports of this server with the given ones.
 * <p>
 * The existing transports are cleared, then each given transport is added in list order.
 *
 * @param transports the transports to install
 */
public void setTransports(List<ServerTransport> transports) {
    _transports.clear();
    for (ServerTransport transport : transports) {
        addTransport(transport);
    }
}
/**
 * Returns the transports of this server.
 *
 * @return a new mutable list containing the current transports
 */
public List<ServerTransport> getTransports() {
    return new ArrayList<>(_transports.values());
}
/**
 * Finds the first allowed HTTP transport that accepts the given request.
 *
 * @param request the HTTP request
 * @return the accepting transport, or {@code null} if no allowed HTTP transport accepts it
 */
protected AbstractHttpTransport findHttpTransport(HttpServletRequest request) {
    for (String allowedName : _allowedTransports) {
        ServerTransport candidate = getTransport(allowedName);
        if (!(candidate instanceof AbstractHttpTransport)) {
            continue;
        }
        AbstractHttpTransport httpTransport = (AbstractHttpTransport)candidate;
        if (httpTransport.accept(request)) {
            return httpTransport;
        }
    }
    return null;
}
/**
 * {@inheritDoc}
 *
 * @return an unmodifiable snapshot of the allowed transport names
 */
@ManagedAttribute(value = "The transports allowed by this CometD server", readonly = true)
@Override
public List<String> getAllowedTransports() {
    return List.copyOf(_allowedTransports);
}
/**
 * Sets the names of the transports allowed by this server.
 *
 * @param allowed the allowed transport names; {@code null} elements are rejected by {@link List#of}
 */
public void setAllowedTransports(String... allowed) {
    List<String> allowedNames = List.of(allowed);
    setAllowedTransports(allowedNames);
}
/**
 * Sets the names of the transports allowed by this server.
 * <p>
 * Only names of transports already known to this server are retained; unknown names are
 * silently dropped.
 *
 * @param allowed the allowed transport names, in order of preference
 */
public void setAllowedTransports(List<String> allowed) {
    if (_logger.isDebugEnabled()) {
        _logger.debug("setAllowedTransport {} of {}", allowed, _transports);
    }
    _allowedTransports.clear();
    for (String transport : allowed) {
        if (_transports.containsKey(transport)) {
            _allowedTransports.add(transport);
        }
    }
    if (_logger.isDebugEnabled()) {
        _logger.debug("allowedTransports {}", _allowedTransports);
    }
}
/**
 * @return whether messages broadcast by a session are also delivered back to that session
 */
@ManagedAttribute(value = "Whether this CometD server broadcast messages to the publisher", readonly = true)
public boolean isBroadcastToPublisher() {
    return this._broadcastToPublisher;
}
/**
 * Marks a reply as failed because its session is unknown.
 * <p>
 * For {@code /meta/handshake} and {@code /meta/connect} replies, the advice also tells the
 * client to re-handshake with a zero interval.
 *
 * @param reply the reply to mark as failed
 */
protected void unknownSession(Mutable reply) {
    error(reply, "402::session_unknown");
    if (Channel.META_HANDSHAKE.equals(reply.getChannel()) || Channel.META_CONNECT.equals(reply.getChannel())) {
        Map<String, Object> advice = reply.getAdvice(true);
        advice.put(Message.RECONNECT_FIELD, Message.RECONNECT_HANDSHAKE_VALUE);
        advice.put(Message.INTERVAL_FIELD, 0L);
    }
}
/**
 * Marks a reply as unsuccessful and records the error string in it.
 * A {@code null} reply is ignored.
 *
 * @param reply the reply to mark, may be {@code null}
 * @param error the error string to store under {@link Message#ERROR_FIELD}
 */
protected void error(ServerMessage.Mutable reply, String error) {
    if (reply == null) {
        return;
    }
    reply.put(Message.ERROR_FIELD, error);
    reply.setSuccessful(false);
}
/**
 * Creates a reply for the given message.
 * <p>
 * The message and the reply are associated with each other. The reply copies the
 * message's server transport, Bayeux context and channel, as well as its id and its
 * {@link Message#SUBSCRIPTION_FIELD} value when those are present.
 *
 * @param message the message to reply to
 * @return the new reply
 */
protected ServerMessage.Mutable createReply(ServerMessage.Mutable message) {
    ServerMessageImpl reply = (ServerMessageImpl)newMessage();

    // Link request and reply in both directions.
    message.setAssociated(reply);
    reply.setAssociated(message);

    // Carry over the context of the request.
    reply.setServerTransport(message.getServerTransport());
    reply.setBayeuxContext(message.getBayeuxContext());
    reply.setChannel(message.getChannel());

    // Optional fields are copied only when present.
    String messageId = message.getId();
    if (messageId != null) {
        reply.setId(messageId);
    }
    Object subscriptionValue = message.get(Message.SUBSCRIPTION_FIELD);
    if (subscriptionValue != null) {
        reply.put(Message.SUBSCRIPTION_FIELD, subscriptionValue);
    }
    return reply;
}
/**
 * Validates subscription channel ids, when validation is enabled.
 *
 * @param subscriptions the channel ids to validate
 * @return {@code true} if validation is disabled or every id is a valid channel id
 */
private boolean validateSubscriptions(List<String> subscriptions) {
    if (!_validation) {
        return true;
    }
    // allMatch short-circuits on the first invalid id, like the explicit loop would.
    return subscriptions.stream().allMatch(Bayeux.Validator::isValidChannelId);
}
/**
 * Sweeps the transports, then the channels, then the sessions of this server.
 * <p>
 * All sessions are swept against the same {@link System#nanoTime()} timestamp, which is
 * taken after the channels have been swept.
 */
@ManagedOperation(value = "Sweeps channels and sessions of this CometD server", impact = "ACTION")
public void sweep() {
    sweepTransports();
    for (ServerChannelImpl serverChannel : _channels.values()) {
        serverChannel.sweep();
    }
    long sweepTime = System.nanoTime();
    _sessions.values().forEach(serverSession -> serverSession.sweep(sweepTime));
}
/**
 * Triggers an asynchronous sweep by delegating to the sweeper.
 * NOTE(review): the future's type argument is not visible here; confirm it and parameterize the return type.
 *
 * @return the future returned by the sweeper
 */
CompletableFuture asyncSweep() {
    CompletableFuture sweepCompletion = _sweeper.asyncSweep();
    return sweepCompletion;
}
/**
 * Sweeps every transport that is an {@link AbstractServerTransport}; other transports are skipped.
 */
private void sweepTransports() {
    for (ServerTransport candidate : _transports.values()) {
        if (!(candidate instanceof AbstractServerTransport)) {
            continue;
        }
        ((AbstractServerTransport)candidate).sweep();
    }
}
/**
 * @return whether the {@code dump()} operation reports additional details
 */
@ManagedAttribute("Reports additional details in the dump() operation")
public boolean isDetailedDump() {
    return this._detailedDump;
}
/**
 * @param detailedDump whether the {@code dump()} operation should report additional details
 */
public void setDetailedDump(boolean detailedDump) {
    this._detailedDump = detailedDump;
}
/**
 * @return the period, in milliseconds, of the sweeping activity
 */
@ManagedAttribute("The period, in milliseconds, of the sweeping activity performed by the server")
public long getSweepPeriod() {
    return this._sweepPeriod;
}
/**
 * Sets the period, in milliseconds, of the sweeping activity.
 *
 * @param sweepPeriod the sweep period; negative values select {@code DEFAULT_SWEEP_PERIOD}
 */
public void setSweepPeriod(long sweepPeriod) {
    _sweepPeriod = sweepPeriod < 0 ? DEFAULT_SWEEP_PERIOD : sweepPeriod;
}
/**
 * @return the maximum number of threads used by the sweeping activity
 */
@ManagedAttribute("The maximum number of threads that can be used by the sweeping activity performed by the server")
public int getSweepThreads() {
    return this._sweepThreads;
}
/**
 * Sets the maximum number of threads used by the sweeping activity.
 *
 * @param sweepThreads the thread count; values below 1 select {@code DEFAULT_SWEEP_THREADS}
 */
public void setSweepThreads(int sweepThreads) {
    _sweepThreads = sweepThreads < 1 ? DEFAULT_SWEEP_THREADS : sweepThreads;
}
@Override
public void dump(Appendable out, String indent) throws IOException {
long before = System.nanoTime();
List
© 2015 - 2025 Weber Informatics LLC | Privacy Policy