// org.springframework.messaging.simp.user.MultiServerUserRegistry Maven / Gradle / Ivy
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.user;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import org.springframework.context.ApplicationEvent;
import org.springframework.context.event.SmartApplicationListener;
import org.springframework.core.Ordered;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.messaging.converter.MessageConverter;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;
/**
* {@code SimpUserRegistry} that looks up users in a "local" user registry as
* well as a set of "remote" user registries. The local registry is provided as
* a constructor argument while remote registries are updated via broadcasts
* handled by {@link UserRegistryMessageHandler} which in turn notifies this
* registry when updates are received.
*
* @author Rossen Stoyanchev
* @since 4.2
*/
@SuppressWarnings("serial")
public class MultiServerUserRegistry implements SimpUserRegistry, SmartApplicationListener {
/** Identifier of this registry instance; broadcasts carrying the same id are ignored. */
private final String id;

/** Registry of the users known to this server. */
private final SimpUserRegistry localRegistry;

/** Latest snapshot received from each remote registry, keyed by the remote registry id. */
private final Map<String, UserRegistrySnapshot> remoteRegistries = new ConcurrentHashMap<>();

/** Whether the local registry is a {@link SmartApplicationListener} that events are forwarded to. */
private final boolean delegateApplicationEvents;

/* Cross-server session lookup (e.g. same user connected to multiple servers) */
private final SessionLookup sessionLookup = new SessionLookup();
/**
 * Create an instance that wraps the given local user registry.
 * <p>A fresh id is generated for this instance, and application events are
 * forwarded to the local registry only if it implements
 * {@link SmartApplicationListener}.
 * @param localRegistry the local user registry (must not be {@code null})
 */
public MultiServerUserRegistry(SimpUserRegistry localRegistry) {
	Assert.notNull(localRegistry, "'localRegistry' is required.");
	this.localRegistry = localRegistry;
	this.id = generateId();
	this.delegateApplicationEvents = (localRegistry instanceof SmartApplicationListener);
}
/**
 * Build an id of the form {@code <host address>-<random UUID>}.
 * <p>If the local host cannot be resolved, the literal {@code "unknown"}
 * is used in place of the host address.
 */
private static String generateId() {
	String host = "unknown";
	try {
		host = InetAddress.getLocalHost().getHostAddress();
	}
	catch (UnknownHostException ex) {
		// Deliberately ignored: keep the "unknown" placeholder
	}
	return host + '-' + UUID.randomUUID();
}
/**
 * Return the order of the local registry when events are delegated to it,
 * or {@link Ordered#LOWEST_PRECEDENCE} otherwise.
 */
@Override
public int getOrder() {
	if (!this.delegateApplicationEvents) {
		return Ordered.LOWEST_PRECEDENCE;
	}
	SmartApplicationListener listener = (SmartApplicationListener) this.localRegistry;
	return listener.getOrder();
}
// SmartApplicationListener methods
/**
 * Delegate to the local registry if it is a {@link SmartApplicationListener}.
 * @param eventType the event type to check
 * @return the local registry's answer, or {@code false} when events are
 * not delegated
 */
@Override
public boolean supportsEventType(Class<? extends ApplicationEvent> eventType) {
	return (this.delegateApplicationEvents &&
			((SmartApplicationListener) this.localRegistry).supportsEventType(eventType));
}
/**
 * Delegate to the local registry if it is a {@link SmartApplicationListener}.
 * @param sourceType the source type to check (may be {@code null})
 * @return the local registry's answer, or {@code false} when events are
 * not delegated
 */
@Override
public boolean supportsSourceType(@Nullable Class<?> sourceType) {
	return (this.delegateApplicationEvents &&
			((SmartApplicationListener) this.localRegistry).supportsSourceType(sourceType));
}
/**
 * Forward the event to the local registry if it is a
 * {@link SmartApplicationListener}; otherwise do nothing.
 * @param event the event to forward
 */
@Override
public void onApplicationEvent(ApplicationEvent event) {
	if (!this.delegateApplicationEvents) {
		return;
	}
	SmartApplicationListener listener = (SmartApplicationListener) this.localRegistry;
	listener.onApplicationEvent(event);
}
// SimpUserRegistry methods
/**
 * Look up a user by name, consulting the remote registry snapshots first
 * and falling back to the local registry when no snapshot has the user.
 * @param userName the name of the user to look up
 * @return the first matching remote user, else whatever the local registry
 * returns for that name
 */
@Override
@Nullable
public SimpUser getUser(String userName) {
	// Prefer remote registries due to cross-server SessionLookup
	SimpUser match = null;
	for (UserRegistrySnapshot snapshot : this.remoteRegistries.values()) {
		match = snapshot.getUserMap().get(userName);
		if (match != null) {
			break;
		}
	}
	return (match != null ? match : this.localRegistry.getUser(userName));
}
/**
 * Return the users from all remote registry snapshots together with the
 * users of the local registry.
 * <p>Remote users are added before local ones. The result is a set, so which
 * entry is kept for a user present in several registries depends on the
 * {@code equals}/{@code hashCode} of the user implementations involved.
 * @return a new, mutable set of users
 */
@Override
public Set<SimpUser> getUsers() {
	// Prefer remote registries due to cross-server SessionLookup
	Set<SimpUser> result = new HashSet<>();
	for (UserRegistrySnapshot registry : this.remoteRegistries.values()) {
		result.addAll(registry.getUserMap().values());
	}
	result.addAll(this.localRegistry.getUsers());
	return result;
}
/**
 * Return the sum of the local registry's user count and the number of
 * users in each remote registry snapshot.
 * <p>The per-registry counts are simply added up; no de-duplication is
 * performed, so a user present in several registries is counted once for
 * each of them.
 */
@Override
public int getUserCount() {
	int total = this.localRegistry.getUserCount();
	for (UserRegistrySnapshot snapshot : this.remoteRegistries.values()) {
		total += snapshot.getUserMap().size();
	}
	return total;
}
/**
 * Find the subscriptions accepted by the given matcher in all remote
 * registry snapshots and in the local registry.
 * @param matcher the matcher to apply to the subscriptions
 * @return a new, mutable set with the matching subscriptions
 */
@Override
public Set<SimpSubscription> findSubscriptions(SimpSubscriptionMatcher matcher) {
	Set<SimpSubscription> result = new HashSet<>();
	for (UserRegistrySnapshot registry : this.remoteRegistries.values()) {
		result.addAll(registry.findSubscriptions(matcher));
	}
	result.addAll(this.localRegistry.findSubscriptions(matcher));
	return result;
}
// Internal methods for UserRegistryMessageHandler to manage broadcasts
/**
 * Create a snapshot of the local registry, tagged with this instance's id.
 * @return the snapshot object
 */
Object getLocalRegistryDto() {
	UserRegistrySnapshot snapshot = new UserRegistrySnapshot(this.id, this.localRegistry);
	return snapshot;
}
/**
 * Convert the given message to a registry snapshot and store it as the
 * current state of the remote registry it came from.
 * <p>Nothing is stored if the conversion yields {@code null} or if the
 * snapshot carries this instance's own id.
 * @param message the message holding the remote registry snapshot
 * @param converter the converter used to extract the snapshot
 * @param expirationPeriod added to the current time in milliseconds to
 * determine when the snapshot expires
 */
void addRemoteRegistryDto(Message<?> message, MessageConverter converter, long expirationPeriod) {
	UserRegistrySnapshot registry = (UserRegistrySnapshot) converter.fromMessage(message, UserRegistrySnapshot.class);
	if (registry != null && !registry.getId().equals(this.id)) {
		// Set the expiration time and link users, sessions and subscriptions
		registry.init(expirationPeriod, this.sessionLookup);
		this.remoteRegistries.put(registry.getId(), registry);
	}
}
/**
 * Remove every remote registry snapshot whose expiration time lies before
 * the current time.
 */
void purgeExpiredRegistries() {
	final long currentTime = System.currentTimeMillis();
	this.remoteRegistries.values().removeIf(snapshot -> snapshot.isExpired(currentTime));
}
/**
 * Return a description made of the local registry and the remote registry
 * snapshots.
 */
@Override
public String toString() {
	StringBuilder description = new StringBuilder("local=[");
	description.append(this.localRegistry);
	description.append("], remote=");
	description.append(this.remoteRegistries);
	return description.toString();
}
/**
 * Holds a copy of a SimpUserRegistry for the purpose of broadcasting to and
 * receiving broadcasts from other application servers.
 */
private static class UserRegistrySnapshot {

	/** Id of the registry this snapshot was taken from. */
	private String id = "";

	/** Users in this snapshot, keyed by user name. */
	private Map<String, TransferSimpUser> users = Collections.emptyMap();

	/** Point in time (epoch milliseconds) after which this snapshot is expired; set by {@link #init}. */
	private long expirationTime;

	/**
	 * Default constructor for JSON deserialization.
	 */
	@SuppressWarnings("unused")
	public UserRegistrySnapshot() {
	}

	/**
	 * Constructor to create DTO from a local user registry.
	 * @param id the id of the registry the snapshot is taken from
	 * @param registry the registry whose users are copied
	 */
	public UserRegistrySnapshot(String id, SimpUserRegistry registry) {
		this.id = id;
		Set<SimpUser> users = registry.getUsers();
		this.users = new HashMap<>(users.size());
		for (SimpUser user : users) {
			this.users.put(user.getName(), new TransferSimpUser(user));
		}
	}

	public void setId(String id) {
		this.id = id;
	}

	public String getId() {
		return this.id;
	}

	public void setUserMap(Map<String, TransferSimpUser> users) {
		this.users = users;
	}

	/**
	 * Return the users in this snapshot, keyed by user name.
	 * The internal map is returned as is, not a copy.
	 */
	public Map<String, TransferSimpUser> getUserMap() {
		return this.users;
	}

	/**
	 * Whether the given time is past this snapshot's expiration time.
	 * @param now the time to compare against, in epoch milliseconds
	 */
	public boolean isExpired(long now) {
		return (now > this.expirationTime);
	}

	/**
	 * Set the expiration time relative to the current time and complete the
	 * initialization of all users in this snapshot.
	 * @param expirationPeriod the number of milliseconds until expiration
	 * @param sessionLookup the cross-server session lookup handed to each user
	 */
	public void init(long expirationPeriod, SessionLookup sessionLookup) {
		this.expirationTime = System.currentTimeMillis() + expirationPeriod;
		for (TransferSimpUser user : this.users.values()) {
			user.afterDeserialization(sessionLookup);
		}
	}

	/**
	 * Find the subscriptions in this snapshot that the given matcher accepts.
	 * <p>Only sessions held by this snapshot's users are examined.
	 * @param matcher the matcher to apply to each subscription
	 * @return a new, mutable set with the matching subscriptions
	 */
	public Set<SimpSubscription> findSubscriptions(SimpSubscriptionMatcher matcher) {
		Set<SimpSubscription> result = new HashSet<>();
		for (TransferSimpUser user : this.users.values()) {
			for (TransferSimpSession session : user.sessions) {
				for (SimpSubscription subscription : session.subscriptions) {
					if (matcher.match(subscription)) {
						result.add(subscription);
					}
				}
			}
		}
		return result;
	}

	@Override
	public String toString() {
		return "id=" + this.id + ", users=" + this.users;
	}
}
/**
 * SimpUser that can be (de)serialized and broadcast to other servers.
 */
private static class TransferSimpUser implements SimpUser {

	private String name = "";

	// User sessions from "this" registry only (i.e. one server)
	private final Set<TransferSimpSession> sessions;

	// Cross-server session lookup (e.g. user connected to multiple servers)
	@Nullable
	private SessionLookup sessionLookup;

	/**
	 * Default constructor for JSON deserialization.
	 */
	@SuppressWarnings("unused")
	public TransferSimpUser() {
		this.sessions = new HashSet<>(1);
	}

	/**
	 * Constructor to create user from a local user.
	 * @param user the local user whose name and sessions are copied
	 */
	public TransferSimpUser(SimpUser user) {
		this.name = user.getName();
		Set<SimpSession> sessions = user.getSessions();
		this.sessions = new HashSet<>(sessions.size());
		for (SimpSession session : sessions) {
			this.sessions.add(new TransferSimpSession(session));
		}
	}

	public void setName(String name) {
		this.name = name;
	}

	@Override
	public String getName() {
		return this.name;
	}

	/**
	 * Whether the user has any sessions. Uses the cross-server lookup when
	 * one has been set, otherwise only the sessions held by this instance.
	 */
	@Override
	public boolean hasSessions() {
		if (this.sessionLookup != null) {
			return !this.sessionLookup.findSessions(getName()).isEmpty();
		}
		return !this.sessions.isEmpty();
	}

	/**
	 * Find the session with the given id. Uses the cross-server lookup when
	 * one has been set, otherwise only the sessions held by this instance.
	 * @param sessionId the id of the session to find
	 * @return the session, or {@code null} if there is none with that id
	 */
	@Override
	@Nullable
	public SimpSession getSession(String sessionId) {
		if (this.sessionLookup != null) {
			return this.sessionLookup.findSessions(getName()).get(sessionId);
		}
		for (TransferSimpSession session : this.sessions) {
			if (session.getId().equals(sessionId)) {
				return session;
			}
		}
		return null;
	}

	/**
	 * Add the given sessions to the ones already held by this instance.
	 */
	public void setSessions(Set<TransferSimpSession> sessions) {
		this.sessions.addAll(sessions);
	}

	/**
	 * Return a new set with the user's sessions. Uses the cross-server
	 * lookup when one has been set, otherwise only the sessions held by
	 * this instance.
	 */
	@Override
	public Set<SimpSession> getSessions() {
		if (this.sessionLookup != null) {
			Map<String, SimpSession> sessions = this.sessionLookup.findSessions(getName());
			return new HashSet<>(sessions.values());
		}
		return new HashSet<>(this.sessions);
	}

	/**
	 * Store the cross-server lookup and point each held session, and through
	 * it each subscription, back at its owner.
	 */
	private void afterDeserialization(SessionLookup sessionLookup) {
		this.sessionLookup = sessionLookup;
		for (TransferSimpSession session : this.sessions) {
			session.setUser(this);
			session.afterDeserialization();
		}
	}

	/**
	 * Put the sessions held by this instance into the given map, keyed by
	 * session id. Existing entries with the same id are replaced.
	 */
	private void addSessions(Map<String, SimpSession> map) {
		for (SimpSession session : this.sessions) {
			map.put(session.getId(), session);
		}
	}

	/**
	 * Equal to any {@link SimpUser} with the same name.
	 */
	@Override
	public boolean equals(Object other) {
		return (this == other || (other instanceof SimpUser && this.name.equals(((SimpUser) other).getName())));
	}

	@Override
	public int hashCode() {
		return this.name.hashCode();
	}

	@Override
	public String toString() {
		return "name=" + this.name + ", sessions=" + this.sessions;
	}
}
/**
 * SimpSession that can be (de)serialized and broadcast to other servers.
 */
private static class TransferSimpSession implements SimpSession {

	private String id;

	// Placeholder until replaced via setUser
	private TransferSimpUser user;

	private final Set<TransferSimpSubscription> subscriptions;

	/**
	 * Default constructor for JSON deserialization.
	 */
	@SuppressWarnings("unused")
	public TransferSimpSession() {
		this.id = "";
		this.user = new TransferSimpUser();
		this.subscriptions = new HashSet<>(4);
	}

	/**
	 * Constructor to create DTO from the local user session.
	 * <p>Only the id and the subscriptions are copied; the user is
	 * initialized with an empty placeholder.
	 * @param session the local session to copy
	 */
	public TransferSimpSession(SimpSession session) {
		this.id = session.getId();
		this.user = new TransferSimpUser();
		Set<SimpSubscription> subscriptions = session.getSubscriptions();
		this.subscriptions = new HashSet<>(subscriptions.size());
		for (SimpSubscription subscription : subscriptions) {
			this.subscriptions.add(new TransferSimpSubscription(subscription));
		}
	}

	public void setId(String id) {
		this.id = id;
	}

	@Override
	public String getId() {
		return this.id;
	}

	public void setUser(TransferSimpUser user) {
		this.user = user;
	}

	@Override
	public TransferSimpUser getUser() {
		return this.user;
	}

	/**
	 * Add the given subscriptions to the ones already held by this session.
	 */
	public void setSubscriptions(Set<TransferSimpSubscription> subscriptions) {
		this.subscriptions.addAll(subscriptions);
	}

	/**
	 * Return a new set with the subscriptions of this session.
	 */
	@Override
	public Set<SimpSubscription> getSubscriptions() {
		return new HashSet<>(this.subscriptions);
	}

	/**
	 * Point each subscription back at this session.
	 */
	private void afterDeserialization() {
		for (TransferSimpSubscription subscription : this.subscriptions) {
			subscription.setSession(this);
		}
	}

	/**
	 * Equal to any {@link SimpSession} with the same id.
	 */
	@Override
	public boolean equals(Object other) {
		return (this == other || (other instanceof SimpSession && getId().equals(((SimpSession) other).getId())));
	}

	@Override
	public int hashCode() {
		return getId().hashCode();
	}

	@Override
	public String toString() {
		return "id=" + this.id + ", subscriptions=" + this.subscriptions;
	}
}
/**
 * A {@link SimpSubscription} that can be serialized, broadcast to other
 * servers and deserialized there.
 */
private static class TransferSimpSubscription implements SimpSubscription {

	private String id;

	// Placeholder until replaced via setSession
	private TransferSimpSession session;

	private String destination;

	/**
	 * Default constructor for JSON deserialization.
	 */
	@SuppressWarnings("unused")
	public TransferSimpSubscription() {
		this.id = "";
		this.destination = "";
		this.session = new TransferSimpSession();
	}

	/**
	 * Create a DTO that copies the id and destination of a local subscription.
	 * The session is initialized with an empty placeholder.
	 * @param subscription the local subscription to copy
	 */
	public TransferSimpSubscription(SimpSubscription subscription) {
		this.id = subscription.getId();
		this.destination = subscription.getDestination();
		this.session = new TransferSimpSession();
	}

	public void setId(String id) {
		this.id = id;
	}

	@Override
	public String getId() {
		return this.id;
	}

	public void setSession(TransferSimpSession session) {
		this.session = session;
	}

	@Override
	public TransferSimpSession getSession() {
		return this.session;
	}

	public void setDestination(String destination) {
		this.destination = destination;
	}

	@Override
	public String getDestination() {
		return this.destination;
	}

	/**
	 * Equal to any {@link SimpSubscription} with the same id and an equal session.
	 */
	@Override
	public boolean equals(Object other) {
		if (this == other) {
			return true;
		}
		if (other instanceof SimpSubscription) {
			SimpSubscription that = (SimpSubscription) other;
			boolean sameId = getId().equals(that.getId());
			return (sameId && ObjectUtils.nullSafeEquals(getSession(), that.getSession()));
		}
		return false;
	}

	@Override
	public int hashCode() {
		int idHash = getId().hashCode();
		return idHash * 31 + ObjectUtils.nullSafeHashCode(getSession());
	}

	@Override
	public String toString() {
		return "destination=" + this.destination;
	}
}
/**
 * Helper class to find user sessions across all servers.
 * <p>This is an inner (non-static) class because it reads the
 * {@code localRegistry} and {@code remoteRegistries} of the enclosing instance.
 */
private class SessionLookup {

	/**
	 * Collect the sessions of the given user from the local registry and
	 * from every remote registry snapshot.
	 * <p>Local sessions are added first; a remote session with the same id
	 * replaces the entry added before it.
	 * @param userName the name of the user whose sessions to find
	 * @return a new map of session id to session, empty if the user has no
	 * sessions in any registry
	 */
	public Map<String, SimpSession> findSessions(String userName) {
		Map<String, SimpSession> map = new HashMap<>(4);
		SimpUser user = localRegistry.getUser(userName);
		if (user != null) {
			for (SimpSession session : user.getSessions()) {
				map.put(session.getId(), session);
			}
		}
		for (UserRegistrySnapshot registry : remoteRegistries.values()) {
			TransferSimpUser transferUser = registry.getUserMap().get(userName);
			if (transferUser != null) {
				transferUser.addSessions(map);
			}
		}
		return map;
	}
}
}