// org.jgroups.protocols.STOMP Maven / Gradle / Ivy
package org.jgroups.protocols;
import org.jgroups.*;
import org.jgroups.annotations.MBean;
import org.jgroups.annotations.ManagedAttribute;
import org.jgroups.annotations.Property;
import org.jgroups.stack.Protocol;
import org.jgroups.util.MessageBatch;
import org.jgroups.util.StackType;
import org.jgroups.util.UUID;
import org.jgroups.util.Util;
import java.io.*;
import java.net.*;
import java.nio.ByteBuffer;
import java.util.*;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Supplier;
/**
* Protocol which provides STOMP (http://stomp.codehaus.org/) support. Very simple implementation, with a
* one-thread-per-connection model. Use for a few hundred clients max.
* The intended use for this protocol is pub-sub with clients which handle text messages, e.g. stock updates,
* SMS messages to mobile clients, SNMP traps etc.
* Note that the full STOMP protocol has not yet been implemented, e.g. transactions are not supported.
* todo: use a thread pool to handle incoming frames and to send messages to clients
*
* todo: add PING to test health of client connections
*
* @author Bela Ban
* @since 2.11
*/
@MBean(description="Server side STOPM protocol, STOMP clients can connect to it")
public class STOMP extends Protocol implements Runnable {
/* ----------------------------------------- Properties ----------------------------------------------- */
// Address handed to Util.createServerSocket() in start() when the server socket is created
@Property(name="bind_addr",
description="The bind address which should be used by the server socket. The following special values " +
"are also recognized: GLOBAL, SITE_LOCAL, LINK_LOCAL and NON_LOOPBACK",
defaultValueIPv4="0.0.0.0", defaultValueIPv6="::", writable=false)
protected InetAddress bind_addr;
// When non-null, start() uses this value as the endpoint instead of deriving one from the server socket
@Property(description="If set, then endpoint will be set to this address")
protected String endpoint_addr;
// start() passes port and port+50 to Util.createServerSocket(); presumably the range of ports
// that is tried - TODO confirm against Util
@Property(description="Port on which the STOMP protocol listens for requests",writable=false)
protected int port=8787;
// Read by sendToClients(): when false, every subscription whose key starts with the message's
// destination receives the message; when true only the subscription with exactly that key does
@Property(description="If set to false, then a destination of /a/b match /a/b/c, a/b/d, a/b/c/d etc")
protected boolean exact_destination_match=true;
// Gates Connection.sendInfo() and the INFO frame written by up(Message) when an endpoint is received
@Property(description="If true, information such as a list of endpoints, or views, will be sent to all clients " +
"(via the INFO command). This allows for example intelligent clients to connect to " +
"a different server should a connection be closed.")
protected boolean send_info=true;
// Read by up(Message) and up(MessageBatch): messages without a StompHeader are then also sent to
// the clients, with a "sender" header set to the message's source address
@Property(description="Forward received messages which don't have a StompHeader to clients")
protected boolean forward_non_client_generated_msgs=false;
/* --------------------------------------------- JMX ---------------------------------------------------*/
/** Returns the number of client connections currently registered with this protocol. */
@ManagedAttribute(description="Number of client connections")
public int getNumConnections() {
    // connections is a plain list which is guarded by its own monitor everywhere else in this class
    synchronized(connections) {
        return connections.size();
    }
}
/** Returns the number of destinations which currently have at least one subscription entry. */
@ManagedAttribute(description="Number of subscriptions")
public int getNumSubscriptions() {
    return subscriptions.size();
}
/** Returns a printable list of all destinations that are currently subscribed to. */
@ManagedAttribute(description="Print subscriptions")
public String getSubscriptions() {
    return subscriptions.keySet().toString();
}
/** Returns a printable form of the member-address to endpoint map. */
@ManagedAttribute
public String getEndpoints() {
    // endpoints is a plain HashMap; all writers hold its monitor, so readers have to as well,
    // otherwise toString() can fail while the map is being modified
    synchronized(endpoints) {
        return endpoints.toString();
    }
}
/* --------------------------------------------- Fields ------------------------------------------------------ */
/** Server socket on which client connections are accepted; created in start() */
protected ServerSocket srv_sock;
/** The endpoint string which is broadcast to the cluster; assigned in start() */
@ManagedAttribute
protected String endpoint;
/** Thread running the accept loop (see run()); null when no acceptor is running */
protected Thread acceptor;
/** All client connections; every access synchronizes on this list */
protected final List<Connection> connections=new LinkedList<>();
/** Endpoints received from cluster members, keyed by member address; every access synchronizes on this map */
protected final Map<Address,String> endpoints=new HashMap<>();
/** The view most recently passed to handleView() */
protected View view;
// Subscriptions and connections which are subscribed
protected final ConcurrentMap<String,Set<Connection>> subscriptions=Util.createConcurrentMap(20);
/** Verbs accepted from clients */
public enum ClientVerb {CONNECT, SEND, SUBSCRIBE, UNSUBSCRIBE, BEGIN, COMMIT, ABORT, ACK, DISCONNECT}
/** Verbs sent to clients */
public enum ServerVerb {MESSAGE, RECEIPT, ERROR, CONNECTED, INFO}
/** Byte which terminates every frame written to a client */
public static final byte NULL_BYTE=0;
public STOMP() {
}
/**
 * Starts the protocol: creates the server socket, launches the acceptor thread unless one is
 * already registered, and determines the endpoint string of this node.
 * @throws Exception if the superclass or the creation of the server socket fails
 */
public void start() throws Exception {
    super.start();
    srv_sock=Util.createServerSocket(getSocketFactory(), "jgroups.stomp.srv_sock", bind_addr,
                                     port, port+50, 0);
    if(log.isDebugEnabled())
        log.debug("server socket listening on " + srv_sock.getLocalSocketAddress());

    if(acceptor == null) {
        Thread acceptor_thread=getThreadFactory().newThread(this, "STOMP acceptor");
        acceptor_thread.setDaemon(true);
        // the field has to be set before the thread runs: run() loops only while acceptor != null
        acceptor=acceptor_thread;
        acceptor_thread.start();
    }

    // an explicitly configured endpoint address wins over the one derived from the server socket
    if(endpoint_addr != null)
        endpoint=endpoint_addr;
    else
        endpoint=getAddress(Util.getIpStackType());
}
/**
 * Stops the protocol: closes the server socket (which terminates the acceptor loop), then stops
 * and forgets all client connections.
 */
public void stop() {
    // srv_sock is null if start() failed before the socket could be created
    if(srv_sock != null) {
        if(log.isDebugEnabled())
            log.debug("closing server socket " + srv_sock.getLocalSocketAddress());
        try {
            // Closed unconditionally: if the acceptor has already terminated on its own the socket
            // would otherwise stay open. Closing makes a pending accept() in run() fail, which
            // ends the acceptor loop
            getSocketFactory().close(srv_sock);
        }
        catch(Exception ex) {
            // best effort: we are shutting down anyway, but leave a trace of the failure
            if(log.isTraceEnabled())
                log.trace("failed closing server socket: " + ex);
        }
    }
    synchronized(connections) {
        connections.forEach(Connection::stop);
        connections.clear();
    }
    acceptor=null;
    super.stop();
}
// Acceptor loop
/**
 * Accepts client connections until the server socket fails or the acceptor is reset. For every
 * accepted socket a {@link Connection} is created, registered and served by its own daemon thread.
 * A failure to set up a single client connection is logged and does not end the loop.
 */
public void run() {
    while(acceptor != null && srv_sock != null) {
        Socket client_sock;
        try {
            client_sock=srv_sock.accept();
        }
        catch(IOException io_ex) {
            // accept() fails when the server socket is closed (see stop()): leave the loop
            break;
        }

        if(log.isTraceEnabled())
            log.trace("accepted connection from " + client_sock.getInetAddress() + ':' + client_sock.getPort());

        Connection conn;
        try {
            conn=new Connection(client_sock);
        }
        catch(IOException io_ex) {
            // only this client is affected: release its socket and keep accepting others
            log.error("failed creating connection for " + client_sock.getRemoteSocketAddress(), io_ex);
            Util.close(client_sock);
            continue;
        }

        Thread thread=getThreadFactory().newThread(conn, "STOMP client connection");
        thread.setDaemon(true);
        synchronized(connections) {
            connections.add(conn);
        }
        thread.start();
        conn.sendInfo();
    }
    acceptor=null;
}
/**
 * Intercepts view changes travelling down the stack (see handleView()) and then forwards the
 * event to the protocol below.
 * @param evt the event
 * @return whatever the protocol below returns
 */
public Object down(Event evt) {
    if(evt.getType() == Event.VIEW_CHANGE)
        handleView(evt.getArg());
    return down_prot.down(evt);
}
/**
 * Intercepts view changes travelling up the stack (see handleView()) and then forwards the
 * event to the protocol above.
 * @param evt the event
 * @return whatever the protocol above returns
 */
public Object up(Event evt) {
    if(evt.getType() == Event.VIEW_CHANGE)
        handleView(evt.getArg());
    return up_prot.up(evt);
}
/**
 * Handles a message received from the cluster.
 * <ul>
 * <li>No StompHeader: the message is optionally forwarded to all clients (see
 *     forward_non_client_generated_msgs) and then passed up.</li>
 * <li>MESSAGE: the payload is sent to the matching clients and the message is passed up.</li>
 * <li>ENDPOINT: the sender's endpoint is recorded; clients are informed only if the endpoint
 *     is new or has changed. The message is consumed (null is returned).</li>
 * </ul>
 * @param msg the received message
 * @return the result of passing the message up, or null for ENDPOINT messages
 * @throws IllegalArgumentException if the header carries an unknown type
 */
public Object up(Message msg) {
    StompHeader hdr=msg.getHeader(id);
    if(hdr == null) {
        if(forward_non_client_generated_msgs) {
            Map<String,String> hdrs=new HashMap<>();
            hdrs.put("sender", msg.getSrc().toString());
            sendToClients(hdrs, msg.getArray(), msg.getOffset(), msg.getLength());
        }
        return up_prot.up(msg);
    }

    switch(hdr.type) {
        case MESSAGE:
            sendToClients(hdr.headers, msg.getArray(), msg.getOffset(), msg.getLength());
            break;
        case ENDPOINT:
            String tmp_endpoint=hdr.headers.get("endpoint");
            if(tmp_endpoint != null) {
                String old_endpoint;
                synchronized(endpoints) {
                    // put() returns the previous mapping, which tells us whether anything changed
                    old_endpoint=endpoints.put(msg.getSrc(), tmp_endpoint);
                }
                boolean update_clients=!Objects.equals(old_endpoint, tmp_endpoint);
                if(update_clients && this.send_info) {
                    // computed once, outside of the connections lock
                    String all_endpoints=getAllEndpoints();
                    synchronized(connections) {
                        for(Connection conn: connections)
                            conn.writeResponse(ServerVerb.INFO, "endpoints", all_endpoints);
                    }
                }
            }
            return null;
        default:
            throw new IllegalArgumentException("type " + hdr.type + " is not known");
    }
    return up_prot.up(msg);
}
/**
 * Handles a batch of received messages. Messages which carry a StompHeader (or all messages when
 * forward_non_client_generated_msgs is set) are removed from the batch and processed one by one
 * via {@link #up(Message)}; whatever remains in the batch is passed up unchanged.
 * @param batch the received batch
 */
public void up(MessageBatch batch) {
    Iterator<Message> it=batch.iterator();
    while(it.hasNext()) {
        Message msg=it.next();
        StompHeader hdr=msg.getHeader(id);
        if(hdr == null && !forward_non_client_generated_msgs)
            continue; // not ours: stays in the batch
        try {
            it.remove();
            up(msg);
        }
        catch(Throwable t) {
            // a failure of one message must not prevent the rest of the batch from being processed
            log.error(Util.getMessage("FailedPassingUpMessage"), t);
        }
    }
    if(!batch.isEmpty())
        up_prot.up(batch);
}
/**
 * Reads one frame from the stream: a verb line, header lines ("key:value") up to an empty line,
 * and then the body. If a "content-length" header is present exactly that many bytes are read as
 * the body, otherwise the body extends up to (and excluding) the next 0 byte or the end of stream.
 * @param in the stream to read from
 * @return the frame, or null if the verb line was empty
 * @throws EOFException if the stream ends while reading the verb, a header, or a body whose
 *                      length was announced by "content-length"
 * @throws IOException if reading fails
 * @throws NumberFormatException if the "content-length" header is not an integer
 */
public static Frame readFrame(DataInputStream in) throws IOException {
    String verb=Util.readLine(in);
    if(verb == null)
        throw new EOFException("reading verb");
    if(verb.isEmpty())
        return null;
    verb=verb.trim();

    Map<String,String> headers=new HashMap<>();
    for(;;) {
        String header=Util.readLine(in);
        if(header == null)
            throw new EOFException("reading header");
        if(header.isEmpty())
            break;
        int index=header.indexOf(':');
        if(index != -1) // lines without a colon are ignored
            headers.put(header.substring(0, index).trim(), header.substring(index+1).trim());
    }

    byte[] body;
    String content_length=headers.get("content-length");
    if(content_length != null) {
        body=new byte[Integer.parseInt(content_length)];
        // read() may return fewer bytes than requested, which would leave the rest of the body in
        // the stream (to be misread as the next frame); readFully() blocks until all bytes arrived
        in.readFully(body, 0, body.length);
    }
    else {
        ByteArrayOutputStream bytes=new ByteArrayOutputStream(500);
        // the body ends at the frame terminator (0) or at the end of the stream (-1)
        for(int c=in.read(); c != -1 && c != 0; c=in.read())
            bytes.write(c);
        body=bytes.toByteArray();
    }
    return new Frame(verb, headers, body);
}
/**
 * Reacts to a new view: broadcasts the local endpoint, remembers the view, drops the endpoints
 * of members which are not part of the view and sends an INFO update to every client connection.
 * @param view the new view
 */
protected void handleView(View view) {
    broadcastEndpoint();
    List<Address> members=view.getMembers();
    this.view=view;

    synchronized(endpoints) {
        // keep only the endpoints of current members
        endpoints.keySet().retainAll(members);
    }

    synchronized(connections) {
        for(Connection conn: connections)
            conn.sendInfo();
    }
}
/**
 * Computes the "host:port" string for the server socket. If the socket is bound to a specific
 * address that address is used; if it is bound to the wildcard address, the first address which
 * Util.getAddress() yields for one of the address scopes (in declaration order) is used.
 * @param ip_version the IP stack type passed to Util.getAddress()
 * @return the address string, or null if no address could be determined
 */
private String getAddress(StackType ip_version) {
    InetSocketAddress local=(InetSocketAddress)srv_sock.getLocalSocketAddress();
    InetAddress bound=local.getAddress();
    if(!bound.isAnyLocalAddress())
        return bound.getHostAddress() + ":" + srv_sock.getLocalPort();

    for(Util.AddressScope scope: Util.AddressScope.values()) {
        InetAddress candidate;
        try {
            candidate=Util.getAddress(scope, ip_version);
        }
        catch(SocketException e) {
            continue; // try the next scope
        }
        if(candidate != null)
            return candidate.getHostAddress() + ":" + srv_sock.getLocalPort();
    }
    return null;
}
/**
 * Returns all known endpoints as a single comma-separated string.
 * @return the endpoints, joined with ","
 */
protected String getAllEndpoints() {
    String joined;
    synchronized(endpoints) {
        joined=Util.printListWithDelimiter(endpoints.values(), ",");
    }
    return joined;
}
/**
 * Sends the local endpoint down the stack in an empty message which carries an ENDPOINT header.
 * Does nothing as long as no endpoint has been set.
 */
protected void broadcastEndpoint() {
    if(endpoint == null)
        return;
    StompHeader hdr=StompHeader.createHeader(StompHeader.Type.ENDPOINT, "endpoint", endpoint);
    Message msg=new EmptyMessage().putHeader(id, hdr);
    down_prot.down(msg);
}
/**
 * Builds a MESSAGE frame (verb line, headers, optional "content-length" header, empty line, body,
 * terminating 0 byte) and writes it to the target client connections.
 * <p>
 * Targets: all connections if the headers contain no "destination"; otherwise the connections
 * subscribed to exactly that destination, or - if exact_destination_match is false - to any
 * subscription whose key starts with the destination.
 * @param headers the headers to be sent, may be null
 * @param buffer the body, may be null (the frame then has no body and no "content-length" header)
 * @param offset the offset of the body in buffer
 * @param length the number of body bytes
 */
private void sendToClients(Map<String,String> headers, byte[] buffer, int offset, int length) {
    StringBuilder sb=new StringBuilder(ServerVerb.MESSAGE.name()).append("\n");
    if(headers != null) {
        for(Map.Entry<String,String> entry: headers.entrySet())
            sb.append(entry.getKey()).append(": ").append(entry.getValue()).append("\n");
    }
    if(buffer != null)
        sb.append("content-length: ").append(length).append("\n");
    sb.append("\n");

    // The frame is sized from the encoded header bytes (not from character counts, which are too
    // small for multi-byte characters), and the header part is written even when there is no body
    byte[] head=sb.toString().getBytes();
    int body_length=buffer != null? length : 0;
    byte[] frame=new byte[head.length + body_length + 1];
    System.arraycopy(head, 0, frame, 0, head.length);
    if(buffer != null)
        System.arraycopy(buffer, offset, frame, head.length, length);
    frame[frame.length -1]=NULL_BYTE;

    final Set<Connection> target_connections=new HashSet<>();
    final String destination=headers != null? headers.get("destination") : null;
    if(destination == null) {
        synchronized(connections) {
            target_connections.addAll(connections);
        }
    }
    else if(exact_destination_match) {
        Set<Connection> conns=subscriptions.get(destination);
        if(conns != null)
            target_connections.addAll(conns);
    }
    else {
        for(Map.Entry<String,Set<Connection>> entry: subscriptions.entrySet()) {
            if(entry.getKey().startsWith(destination))
                target_connections.addAll(entry.getValue());
        }
    }

    for(Connection conn: target_connections)
        conn.writeResponse(frame, 0, frame.length);
}
/**
 * Class which handles a connection to a client: run() reads frames from the client socket and
 * dispatches them to handleFrame(); the writeResponse() methods send frames to the client.
 * Responses are written by several threads (the connection's own thread, the acceptor and the
 * threads delivering cluster messages), therefore every frame is written while holding the
 * monitor of the output stream.
 */
public class Connection implements Runnable {
    protected final Socket sock;
    protected final DataInputStream in;
    protected final DataOutputStream out;
    protected final UUID session_id=UUID.randomUUID();

    /**
     * @param sock the accepted client socket
     * @throws IOException if the socket's streams cannot be obtained
     */
    public Connection(Socket sock) throws IOException {
        this.sock=sock;
        this.in=new DataInputStream(sock.getInputStream());
        this.out=new DataOutputStream(sock.getOutputStream());
    }

    /** Closes the streams and the socket; this also ends the loop in run() */
    public void stop() {
        if(log.isTraceEnabled())
            log.trace("closing connection to " + sock.getRemoteSocketAddress());
        Util.close(in);
        Util.close(out);
        Util.close(sock);
    }

    /** Removes this connection from the connection list and from all subscriptions */
    protected void remove() {
        synchronized(connections) {
            connections.remove(this);
        }
        for(Set<Connection> conns: subscriptions.values()) {
            conns.remove(this);
        }
        // drop destinations which have no subscribers left
        subscriptions.entrySet().removeIf(entry -> entry.getValue().isEmpty());
    }

    /**
     * Reads and handles frames until the socket is closed. An I/O error closes and removes the
     * connection; any other failure is logged and reading continues.
     */
    public void run() {
        while(!sock.isClosed()) {
            try {
                Frame frame=readFrame(in);
                if(frame != null) {
                    if(log.isTraceEnabled())
                        log.trace(frame);
                    handleFrame(frame);
                }
            }
            catch(IOException ex) {
                stop();
                remove();
            }
            catch(Throwable t) {
                log.error(Util.getMessage("FailureReadingFrame"), t);
            }
        }
    }

    /**
     * Executes a frame received from the client.
     * @param frame the frame
     * @throws IllegalArgumentException if the verb is not a {@link ClientVerb}
     */
    protected void handleFrame(Frame frame) {
        Map<String,String> headers=frame.getHeaders();
        ClientVerb verb=ClientVerb.valueOf(frame.getVerb());
        String destination;

        switch(verb) {
            case CONNECT:
                writeResponse(ServerVerb.CONNECTED,
                              "session-id", session_id.toString(),
                              "password-check", "none");
                break;
            case SEND:
                if(!headers.containsKey("sender")) {
                    headers.put("sender", session_id.toString());
                }
                // the frame is sent to the cluster; delivery to clients happens in up(Message)
                Message msg=new BytesMessage(null, frame.getBody());
                Header hdr=StompHeader.createHeader(StompHeader.Type.MESSAGE, headers);
                msg.putHeader(id, hdr);
                down_prot.down(msg);
                String receipt=headers.get("receipt");
                if(receipt != null)
                    writeResponse(ServerVerb.RECEIPT, "receipt-id", receipt);
                break;
            case SUBSCRIBE:
                destination=headers.get("destination");
                if(destination != null)
                    subscriptions.computeIfAbsent(destination, k -> new HashSet<>()).add(this);
                break;
            case UNSUBSCRIBE:
                destination=headers.get("destination");
                if(destination != null) {
                    Set<Connection> conns=subscriptions.get(destination);
                    if(conns != null && conns.remove(this) && conns.isEmpty())
                        subscriptions.remove(destination);
                }
                break;
            case BEGIN:
            case COMMIT:
            case ABORT:
            case ACK:
            case DISCONNECT:
                break; // accepted, but not implemented
            default:
                log.error("Verb " + frame.getVerb() + " is not handled");
                break;
        }
    }

    /** Sends an INFO frame with the local address, the current view and all endpoints, if send_info is set */
    public void sendInfo() {
        if(send_info) {
            // a client may connect before the first view has been received: view is null then
            writeResponse(ServerVerb.INFO,
                          "local_addr", local_addr != null? local_addr.toString() : "n/a",
                          "view", view != null? view.toString() : "n/a",
                          "endpoints", getAllEndpoints());
            // "clients", getAllClients());
        }
    }

    /**
     * Sends back a response. The keys_and_values vararg array needs to have an even number of elements
     * @param response the verb of the response
     * @param keys_and_values the headers of the response, as alternating keys and values
     */
    private void writeResponse(ServerVerb response, String ... keys_and_values) {
        String tmp=response.name();
        try {
            // the frame consists of several writes; the lock keeps frames of other threads from
            // getting interleaved with it
            synchronized(out) {
                out.write(tmp.getBytes());
                out.write('\n');
                for(int i=0; i < keys_and_values.length; i+=2) {
                    String key=keys_and_values[i];
                    String val=keys_and_values[i+1];
                    out.write((key + ": " + val + "\n").getBytes());
                }
                out.write("\n".getBytes());
                out.write(NULL_BYTE);
                out.flush();
            }
        }
        catch(IOException ex) {
            log.error(Util.getMessage("FailedWritingResponse") + response + ": " + ex);
        }
    }

    /**
     * Writes an already encoded frame to the client.
     * @param response the buffer which contains the frame
     * @param offset the offset of the frame in the buffer
     * @param length the length of the frame
     */
    private void writeResponse(byte[] response, int offset, int length) {
        try {
            synchronized(out) {
                out.write(response, offset, length);
                out.flush();
            }
        }
        catch(IOException ex) {
            log.error(Util.getMessage("FailedWritingResponse") + ex);
        }
    }
}
public static class Frame {
final String verb;
final Map headers;
final byte[] body;
public Frame(String verb, Map headers, byte[] body) {
this.verb=verb;
this.headers=headers;
this.body=body;
}
public byte[] getBody() {
return body;
}
public Map getHeaders() {
return headers;
}
public String getVerb() {
return verb;
}
public String toString() {
StringBuilder sb=new StringBuilder();
sb.append(verb).append("\n");
if(headers != null && !headers.isEmpty()) {
for(Map.Entry entry: headers.entrySet())
sb.append(entry.getKey()).append(": ").append(entry.getValue()).append("\n");
}
if(body != null && body.length > 0) {
sb.append("body: ");
if(body.length < 50)
sb.append(new String(body)).append(" (").append(body.length).append(" bytes)");
else
sb.append(body.length).append(" bytes");
}
return sb.toString();
}
}
public static class StompHeader extends org.jgroups.Header {
public enum Type {MESSAGE, ENDPOINT}
protected Type type;
protected final Map headers=new HashMap<>();
public StompHeader() {
}
public Supplier extends Header> create() {return StompHeader::new;}
public short getMagicId() {return 71;}
private StompHeader(Type type) {
this.type=type;
}
/**
* Creates a new header
* @param type
* @param headers Keys and values to be added to the header hashmap. Needs to be an even number
* @return
*/
public static StompHeader createHeader(Type type, String ... headers) {
StompHeader retval=new StompHeader(type);
if(headers != null) {
for(int i=0; i < headers.length; i++) {
String key=headers[i];
String value=headers[++i];
retval.headers.put(key, value);
}
}
return retval;
}
public static StompHeader createHeader(Type type, Map headers) {
StompHeader retval=new StompHeader(type);
if(headers != null)
retval.headers.putAll(headers);
return retval;
}
@Override
public int serializedSize() {
int retval=Global.INT_SIZE *2; // type + size of hashmap
for(Map.Entry entry: headers.entrySet()) {
retval+=entry.getKey().length() +2;
retval+=entry.getValue().length() +2;
}
return retval;
}
@Override
public void writeTo(DataOutput out) throws IOException {
out.writeInt(type.ordinal());
out.writeInt(headers.size());
for(Map.Entry entry: headers.entrySet()) {
out.writeUTF(entry.getKey());
out.writeUTF(entry.getValue());
}
}
@Override
public void readFrom(DataInput in) throws IOException {
type=Type.values()[in.readInt()];
int size=in.readInt();
for(int i=0; i < size; i++) {
String key=in.readUTF();
String value=in.readUTF();
headers.put(key, value);
}
}
public String toString() {
StringBuilder sb=new StringBuilder(type.toString());
sb.append("headers: ").append(headers);
return sb.toString();
}
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy