Please wait. This may take a few minutes...
Downloading a project requires a lot of resources. Please understand that we need to cover our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download/modify it as often as you want.
com.github.sseserver.local.SseEmitter Maven / Gradle / Ivy
package com.github.sseserver.local;
import com.github.sseserver.AccessToken;
import com.github.sseserver.AccessUser;
import com.github.sseserver.TenantAccessUser;
import com.github.sseserver.qos.MessageRepository;
import com.github.sseserver.remote.ConnectionDTO;
import com.github.sseserver.util.SnowflakeIdWorker;
import com.github.sseserver.util.WebUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.server.ServerHttpResponse;
import javax.servlet.http.Cookie;
import java.io.IOException;
import java.io.Serializable;
import java.nio.channels.ClosedChannelException;
import java.nio.charset.Charset;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
/**
 * Server-sent event push for a single client connection.
 * <p>
 * Extends Spring's {@code SseEmitter} with:
 * <ul>
 *   <li>a queue for events sent before the connection is write-ready
 *       ({@link #writeableReady()});</li>
 *   <li>connect and disconnect listeners;</li>
 *   <li>per-connection attributes;</li>
 *   <li>client metadata (HTTP parameters, headers, cookies, IP, domain, user agent).</li>
 * </ul>
 * <p>
 * NOTE(review): generic type arguments appear to have been stripped from this listing.
 * For example, {@code List>> connectListeners} is not valid Java, and the type variable
 * {@code ACCESS_USER} is used but never declared. Several declarations are therefore not
 * valid Java as shown. They are kept verbatim below; restore them from the upstream source
 * rather than guessing.
 *
 * @author wangzihaogithub 2022-11-12
 */
public class SseEmitter extends org.springframework.web.servlet.mvc.method.annotation.SseEmitter implements MessageRepository.Query {
    private final static Logger log = LoggerFactory.getLogger(SseEmitter.class);
    private static final MediaType TEXT_PLAIN = new MediaType("text", "plain", Charset.forName("UTF-8"));
    /** Connection id generated by {@link SnowflakeIdWorker}; used by equals/hashCode/toString. */
    private final long id = SnowflakeIdWorker.INSTANCE.nextId();
    private final ACCESS_USER accessUser;
    /** Flipped once by {@link #disconnect(boolean)} so the disconnect work runs only once. */
    private final AtomicBoolean disconnect = new AtomicBoolean();
    /** Events sent before {@link #writeableReady()}; flushed when the connection becomes writable. */
    private final Queue earlySendQueue = new LinkedList<>();
    private final List>> connectListeners = new ArrayList<>(2);
    private final List>> disconnectListeners = new ArrayList<>(2);
    private final List>>> listenersWatchList = new ArrayList<>(2);
    private final Map attributeMap = new LinkedHashMap<>(3);
    private final long createTime = System.currentTimeMillis();
    private final Map httpParameters = new LinkedHashMap<>(6);
    private final Map httpHeaders = new LinkedHashMap<>(6);
    private String serverId;
    /** Set to true in {@link #extendResponse(ServerHttpResponse)}. */
    private boolean connect = false;
    private volatile boolean complete = false;
    private boolean writeable = false;
    /** Records a disconnect requested before the response was set up (while {@code connect} was false). */
    private boolean earlyDisconnect = false;
    /** Incremented on every {@link #send(SseEventBuilder)} call that is not queued into {@link #earlySendQueue}. */
    private int count;
    private int requestUploadCount;
    private int requestMessageCount;
    private long lastRequestTimestamp;
    private String channel;
    private String requestIp;
    private String requestDomain;
    private String userAgent;
    private Long sessionDuration;
    private Cookie[] httpCookies;
    /**
     * Event hooks the front end is already listening for; the values are
     * {@link SseEventBuilder#name(String)} event names.
     */
    private Set listeners;
    private ScheduledFuture> timeoutCheckFuture;
    private HttpHeaders responseHeaders;
    /** The first IOException recorded by a failed send; once set, all further sends fail with it. */
    private IOException sendError;
    /** Counter used to assign default event ids. */
    private int defaultId;

    /**
     * Creates an emitter without an access user.
     *
     * @param timeout timeout in milliseconds; 0 means never expire
     */
    public SseEmitter(Long timeout) {
        this(timeout, null);
    }

    /**
     * Creates an emitter for the given access user.
     *
     * @param timeout    timeout in milliseconds; 0 means never expire
     * @param accessUser the authenticated user of this connection, may be null
     */
    public SseEmitter(Long timeout, ACCESS_USER accessUser) {
        super(timeout);
        this.accessUser = accessUser;
    }

    /**
     * Creates an empty event builder that is also a future completed after the write.
     *
     * @return a new event builder
     */
    public static SseEventBuilderFuture event() {
        return new SseEventBuilderFuture<>();
    }

    /**
     * Creates an event builder with the given event name and data.
     *
     * @param name event name ({@code event:} field)
     * @param data event payload
     * @return a new event builder
     */
    public static SseEventBuilderFuture> event(String name, Object data) {
        return new SseEventBuilderFuture>().name(name).data(data);
    }

    /**
     * Converts a parameter value to a Long.
     *
     * @param value null, "", a {@link Date}, a {@link Number}, or a value whose
     *              {@code toString()} is a decimal long
     * @return null for null or ""; the epoch millis for a Date; otherwise the long value
     * @throws NumberFormatException if the text form is not a valid long
     */
    private static Long castLong(Object value) {
        if (value == null || "".equals(value)) {
            return null;
        }
        if (value instanceof Date) {
            return ((Date) value).getTime();
        }
        if (value instanceof Number) {
            // Convert directly: the toString() round trip rejects values such as Double ("1.0").
            return ((Number) value).longValue();
        }
        return Long.valueOf(value.toString());
    }

    /** Counts an upload request from the client and updates {@link #getLastRequestTimestamp()}. */
    void requestUpload() {
        this.requestUploadCount++;
        this.lastRequestTimestamp = System.currentTimeMillis();
    }

    /** Counts a message request from the client and updates {@link #getLastRequestTimestamp()}. */
    void requestMessage() {
        this.requestMessageCount++;
        this.lastRequestTimestamp = System.currentTimeMillis();
    }

    /**
     * Registers a callback notified by {@link #addListener(Collection)} and
     * {@link #removeListener(Collection)}.
     *
     * @param watch the callback to register
     */
    public void addListeningWatch(Consumer>> watch) {
        listenersWatchList.add(watch);
    }

    /** @return the server id set by the framework, or null */
    public String getServerId() {
        return serverId;
    }

    void setServerId(String serverId) {
        this.serverId = serverId;
    }

    /** @return the IOException recorded by a failed send, or null if no send has failed */
    public IOException getSendError() {
        return sendError;
    }

    /**
     * Whether this connection is usable.
     *
     * @return true if not completed and no send has failed
     */
    public boolean isActive() {
        return !complete && sendError == null;
    }

    /**
     * Whether data can be written.
     * <p>
     * If the connection is writable, sends are performed synchronously right away.
     * Otherwise they are placed in the waiting queue {@link #earlySendQueue} and sent once
     * {@link #writeableReady()} is signalled.
     *
     * @return true if writable
     * @see #earlySendQueue queue holding writes issued too early
     * @see #writeableReady() write-ready event
     * @see SseEventBuilderFuture asynchronous callback for write completion
     */
    public boolean isWriteable() {
        return writeable;
    }

    /**
     * Headers copied onto the HTTP response in {@link #extendResponse(ServerHttpResponse)}.
     * They are created lazily on first access.
     *
     * @return the mutable response headers
     */
    public HttpHeaders getResponseHeaders() {
        if (responseHeaders == null) {
            responseHeaders = new HttpHeaders();
        }
        return responseHeaders;
    }

    public int getRequestUploadCount() {
        return requestUploadCount;
    }

    public int getRequestMessageCount() {
        return requestMessageCount;
    }

    /** @return time (epoch millis) of the last upload/message request, 0 if none */
    public long getLastRequestTimestamp() {
        return lastRequestTimestamp;
    }

    /** @return the mutable map of HTTP parameters captured for this connection */
    public Map getHttpParameters() {
        return httpParameters;
    }

    public Cookie[] getHttpCookies() {
        return httpCookies;
    }

    public void setHttpCookies(Cookie[] httpCookies) {
        this.httpCookies = httpCookies;
    }

    /** @return the mutable map of HTTP headers captured for this connection */
    public Map getHttpHeaders() {
        return httpHeaders;
    }

    public String getRequestIp() {
        return requestIp;
    }

    public void setRequestIp(String requestIp) {
        this.requestIp = requestIp;
    }

    public String getRequestDomain() {
        return requestDomain;
    }

    public void setRequestDomain(String requestDomain) {
        this.requestDomain = requestDomain;
    }

    /**
     * Whether more than {@link #getTimeout()} milliseconds have passed since creation.
     *
     * @return false when the timeout is &lt;= 0 (never times out); a null timeout is
     *         treated as 30 seconds
     */
    public boolean isTimeout() {
        Long timeout = getTimeout();
        if (timeout == null) {
            // servlet default is 30 seconds
            timeout = 30_000L;
        } else if (timeout <= 0L) {
            // 0 means never time out
            return false;
        }
        return (System.currentTimeMillis() - createTime) > timeout;
    }

    public String getUserAgent() {
        return userAgent;
    }

    public void setUserAgent(String userAgent) {
        this.userAgent = userAgent;
    }

    public void setSessionDuration(Long sessionDuration) {
        this.sessionDuration = sessionDuration;
    }

    public Long getSessionDuration() {
        return sessionDuration;
    }

    /** @return number of non-queued send attempts on this connection */
    public int getCount() {
        return count;
    }

    /** @return creation time of this emitter (epoch millis) */
    public long getCreateTime() {
        return createTime;
    }

    public long getId() {
        return id;
    }

    /** @return the {@code accessTime} HTTP parameter as a Date, or null if absent */
    public Date getAccessTime() {
        Long accessTime = castLong(httpParameters.get("accessTime"));
        return accessTime != null ? new Date(accessTime) : null;
    }

    /**
     * Gets the client's page URL address.
     *
     * @return the {@code locationHref} HTTP parameter, or null
     */
    public String getLocationHref() {
        return (String) httpParameters.get("locationHref");
    }

    /**
     * Event hooks the front-end JS is already listening for; the values are
     * {@link SseEventBuilder#name(String)} event names.
     * <p>
     * Parsed lazily from the comma-separated {@code listeners} HTTP parameter.
     *
     * @return the mutable set of listener names, empty if the parameter is absent or empty
     */
    public Set getListeners() {
        if (this.listeners == null) {
            String listeners = (String) httpParameters.get("listeners");
            this.listeners = listeners != null && listeners.length() > 0 ? new LinkedHashSet<>(Arrays.asList(listeners.split(","))) : new LinkedHashSet<>();
        }
        return this.listeners;
    }

    /**
     * Whether the front-end JS is listening for this event; the value is a
     * {@link SseEventBuilder#name(String)} event name.
     * If the front end is not listening, there is no need to push the message.
     *
     * @return true if listening
     */
    @Override
    public boolean existListener(String sseListenerName) {
        return getListeners().contains(sseListenerName);
    }

    /**
     * The browser session id.
     *
     * @return browser session id = clientId(36) + accessTime(13) = 49 characters
     */
    public String getBrowserSessionId() {
        return ConnectionDTO.browserSessionId(getClientId(), getAccessTime());
    }

    /**
     * Whether the client version is valid.
     *
     * @param minVersion the minimum required version (pass empty for no check; everything is valid)
     * @return true if valid (&gt;= minVersion); false if invalid (&lt; minVersion)
     */
    public boolean isInVersion(String minVersion) {
        return WebUtil.isInVersion(getClientVersion(), minVersion);
    }

    public Long getClientImportModuleTime() {
        return castLong(httpParameters.get("clientImportModuleTime"));
    }

    public Long getClientInstanceTime() {
        return castLong(httpParameters.get("clientInstanceTime"));
    }

    public String getClientInstanceId() {
        return (String) httpParameters.get("clientInstanceId");
    }

    public String getClientId() {
        return (String) httpParameters.get("clientId");
    }

    public String getScreen() {
        return (String) httpParameters.get("screen");
    }

    public Long getTotalJSHeapSize() {
        return castLong(httpParameters.get("totalJSHeapSize"));
    }

    public Long getUsedJSHeapSize() {
        return castLong(httpParameters.get("usedJSHeapSize"));
    }

    public Long getJsHeapSizeLimit() {
        return castLong(httpParameters.get("jsHeapSizeLimit"));
    }

    public String getClientVersion() {
        return (String) httpParameters.get("clientVersion");
    }

    /** @return the user id if the access user is an {@link AccessUser}, otherwise null */
    public Serializable getUserId() {
        return accessUser instanceof AccessUser ? ((AccessUser) accessUser).getId() : null;
    }

    /** @return the access token if the access user is an {@link AccessToken}, otherwise null */
    public String getAccessToken() {
        return accessUser instanceof AccessToken ? ((AccessToken) accessUser).getAccessToken() : null;
    }

    /** @return the tenant id if the access user is a {@link TenantAccessUser}, otherwise null */
    public Serializable getTenantId() {
        return accessUser instanceof TenantAccessUser ? ((TenantAccessUser) accessUser).getTenantId() : null;
    }

    public ACCESS_USER getAccessUser() {
        return accessUser;
    }

    /** @return the mutable per-connection attribute map */
    public Map getAttributeMap() {
        return attributeMap;
    }

    /** @return the attribute stored under {@code key} (unchecked cast), or null */
    public T getAttribute(String key) {
        return (T) attributeMap.get(key);
    }

    /** @return the previous value stored under {@code key} (unchecked cast), or null */
    public T setAttribute(String key, Object value) {
        return (T) attributeMap.put(key, value);
    }

    /** @return the removed value (unchecked cast), or null */
    public T removeAttribute(String key) {
        return (T) attributeMap.remove(key);
    }

    public String getChannel() {
        return channel;
    }

    public void setChannel(String channel) {
        this.channel = channel;
    }

    /** @return true once {@link #extendResponse(ServerHttpResponse)} has run */
    public boolean isConnect() {
        return connect;
    }

    /**
     * Registers a connect listener.
     * <p>
     * If already connected, the consumer is invoked immediately and any exception is logged.
     * Otherwise the consumer is stored and invoked by {@link #writeableReady()}.
     *
     * @param consumer the listener
     */
    public void addConnectListener(Consumer> consumer) {
        if (connect) {
            try {
                consumer.accept(this);
            } catch (Exception e) {
                if (log.isWarnEnabled()) {
                    log.warn("addConnectListener connectListener error = {} {}", e.toString(), consumer, e);
                }
            }
        } else {
            connectListeners.add(consumer);
        }
    }

    /**
     * Registers a disconnect listener.
     * <p>
     * If already disconnected, the consumer is invoked immediately and any exception is
     * logged. Otherwise the consumer is stored and invoked by {@link #disconnect(boolean)}.
     *
     * @param consumer the listener
     */
    public void addDisConnectListener(Consumer> consumer) {
        if (isDisconnect()) {
            try {
                consumer.accept(this);
            } catch (Exception e) {
                if (log.isWarnEnabled()) {
                    log.warn("addDisConnectListener disconnectListener error = {} {}", e.toString(), consumer, e);
                }
            }
        } else {
            disconnectListeners.add(consumer);
        }
    }

    /**
     * Copies {@link #getResponseHeaders()} onto the response and marks this emitter connected.
     * A disconnect requested before this point is performed here.
     */
    @Override
    protected void extendResponse(ServerHttpResponse outputMessage) {
        super.extendResponse(outputMessage);
        HttpHeaders responseHeaders = this.responseHeaders;
        if (responseHeaders != null) {
            outputMessage.getHeaders().putAll(responseHeaders);
        }
        // Mark as connected BEFORE honouring an early disconnect. While connect is false,
        // disconnect(boolean) only re-sets earlyDisconnect and returns, so calling it first
        // would silently drop the pending disconnect.
        connect = true;
        if (earlyDisconnect) {
            disconnect();
        }
    }

    /**
     * override for spring
     * <p>
     * Marks the connection writable, flushes {@link #earlySendQueue}, then runs and clears
     * the pending connect listeners.
     */
    public void writeableReady() {
        this.writeable = true;
        SseEventBuilder builder;
        while ((builder = earlySendQueue.poll()) != null) {
            try {
                send(builder);
            } catch (IOException ignored) {
                // send() already recorded the failure in sendError and, for
                // SseEventBuilderFuture, completed that future exceptionally.
            }
        }
        for (Consumer> connectListener : new ArrayList<>(connectListeners)) {
            try {
                connectListener.accept(this);
            } catch (Exception e) {
                if (log.isWarnEnabled()) {
                    log.warn("connectListener error = {} {}", e.toString(), connectListener, e);
                }
            }
        }
        connectListeners.clear();
    }

    /** Marks this emitter inactive, then completes it. */
    @Override
    public void complete() {
        this.complete = true;
        super.complete();
    }

    /** Marks this emitter inactive, then completes it with an error. */
    @Override
    public void completeWithError(Throwable ex) {
        this.complete = true;
        super.completeWithError(ex);
    }

    /**
     * Sends a message. The event gets the next default id.
     *
     * @param name event name
     * @param data event payload
     * @return SseEventBuilderFuture, the callback invoked on completion
     * @throws IOException if currently write-ready ({@link #isWriteable()}), the exception is
     *                     thrown on the calling thread. If not yet write-ready, the exception
     *                     is reported through the asynchronous callback {@link SseEventBuilderFuture}.
     */
    public SseEventBuilderFuture> send(String name, Object data) throws IOException {
        SseEventBuilderFuture event = event();
        send(event.defaultId(++defaultId).name(name).data(data));
        return event;
    }

    /**
     * Sends a message.
     * <p>
     * While not yet writable but still active, the builder is queued and sent by
     * {@link #writeableReady()}. If the builder is a {@link CompletableFuture}, it is
     * completed with this emitter on success, or exceptionally on failure.
     *
     * @param builder the event to send
     * @throws IOException if currently write-ready ({@link #isWriteable()}), the exception is
     *                     thrown on the calling thread. If not yet write-ready, the exception
     *                     is reported through the asynchronous callback {@link SseEventBuilderFuture}.
     */
    @Override
    public void send(SseEventBuilder builder) throws IOException {
        boolean active = isActive();
        if (!writeable && active) {
            earlySendQueue.add(builder);
            return;
        }
        count++;
        CompletableFuture> future;
        if (builder instanceof CompletableFuture) {
            future = (CompletableFuture) builder;
        } else {
            future = null;
        }
        if (log.isDebugEnabled()) {
            if (builder instanceof SseEmitter.SseEventBuilderFuture) {
                log.debug("sse connection send {} : {}, id = {}, name = {}, active = {}",
                        count, this, ((SseEventBuilderFuture) builder).id, ((SseEventBuilderFuture) builder).name, active);
            } else {
                log.debug("sse connection send {} : {}, active = {}", count, this, active);
            }
        }
        // A previous failure is sticky: every later send fails with the same exception.
        if (sendError != null) {
            if (future != null) {
                future.completeExceptionally(sendError);
            }
            throw sendError;
        }
        if (!active) {
            sendError = new ClosedChannelException();
            if (future != null) {
                future.completeExceptionally(sendError);
            }
            throw sendError;
        }
        try {
            super.send(builder);
            if (future != null) {
                future.complete(this);
            }
        } catch (IllegalStateException e) {
            /* tomcat recycle bug. socketWrapper is null. is read op cancel then recycle()
             * Http11OutputBuffer line 254: the peer's network connection is closed but neither
             * onError nor onTimeout fired, so at this point we cannot tell whether the connection
             * is unusable.
             *
             * Caused by: java.lang.NullPointerException
             * at org.apache.coyote.http11.Http11OutputBuffer$SocketOutputBuffer.doWrite(Http11OutputBuffer.java:530)
             * at org.apache.coyote.http11.filters.ChunkedOutputFilter.doWrite(ChunkedOutputFilter.java:110)
             * at org.apache.coyote.http11.Http11OutputBuffer.doWrite(Http11OutputBuffer.java:189)
             * at org.apache.coyote.Response.doWrite(Response.java:599)
             * at org.apache.catalina.connector.OutputBuffer.realWriteBytes(OutputBuffer.java:329)
             * at org.apache.catalina.connector.OutputBuffer.flushByteBuffer(OutputBuffer.java:766)
             * at org.apache.catalina.connector.OutputBuffer.doFlush(OutputBuffer.java:288)
             * at org.apache.catalina.connector.OutputBuffer.flush(OutputBuffer.java:262)
             * at org.apache.catalina.connector.CoyoteOutputStream.flush(CoyoteOutputStream.java:118)
             * at sun.nio.cs.StreamEncoder.implFlush(StreamEncoder.java:297)
             * at sun.nio.cs.StreamEncoder.flush(StreamEncoder.java:141)
             * at java.io.OutputStreamWriter.flush(OutputStreamWriter.java:229)
             * at org.springframework.util.StreamUtils.copy(StreamUtils.java:124)
             * at org.springframework.http.converter.StringHttpMessageConverter.writeInternal(StringHttpMessageConverter.java:106)
             * at org.springframework.http.converter.StringHttpMessageConverter.writeInternal(StringHttpMessageConverter.java:43)
             * at org.springframework.http.converter.AbstractHttpMessageConverter.write(AbstractHttpMessageConverter.java:227)
             * at org.springframework.web.servlet.mvc.method.annotation.ResponseBodyEmitterReturnValueHandler$HttpMessageConvertingHandler.sendInternal(ResponseBodyEmitterReturnValueHandler.java:191)
             * at org.springframework.web.servlet.mvc.method.annotation.ResponseBodyEmitterReturnValueHandler$HttpMessageConvertingHandler.send(ResponseBodyEmitterReturnValueHandler.java:184)
             * at org.springframework.web.servlet.mvc.method.annotation.ResponseBodyEmitter.sendInternal(ResponseBodyEmitter.java:189)
             */
            ClosedChannelException exception = new ClosedChannelException();
            // Keep the container's exception as the cause for diagnostics.
            exception.initCause(e);
            this.sendError = exception;
            if (future != null) {
                future.completeExceptionally(sendError);
            }
            disconnect();
            throw exception;
        } catch (IOException e) {
            this.sendError = e;
            if (future != null) {
                future.completeExceptionally(sendError);
            }
            throw e;
        }
    }

    /** @return true once {@link #disconnect(boolean)} has performed the disconnect */
    public boolean isDisconnect() {
        return disconnect.get();
    }

    /** Cancels (without interrupting) and forgets the scheduled timeout check, if any. */
    private void cancelTimeoutTask() {
        ScheduledFuture future = this.timeoutCheckFuture;
        if (future != null) {
            this.timeoutCheckFuture = null;
            future.cancel(false);
        }
    }

    void setTimeoutCheckFuture(ScheduledFuture> timeoutCheckFuture) {
        this.timeoutCheckFuture = timeoutCheckFuture;
    }

    /** Disconnects without sending the {@code connect-close} event. */
    void disconnectByTimeoutCheck() {
        disconnect(false);
    }

    /**
     * Disconnects and sends the {@code connect-close} event.
     *
     * @return true if this call performed the disconnect
     * @see #disconnect(boolean)
     */
    public boolean disconnect() {
        return disconnect(true);
    }

    /**
     * Closes this connection.
     * <p>
     * If the response has not been set up yet, the request is only recorded and
     * performed by {@link #extendResponse(ServerHttpResponse)}; false is returned.
     * Otherwise the timeout check is cancelled. On the first call only, the disconnect
     * listeners run, the {@code connect-close} event is sent (best effort, if
     * {@code sendClose} and still active), and the emitter is completed.
     *
     * @param sendClose whether to send the {@code connect-close} event
     * @return true if this call performed the disconnect
     */
    public boolean disconnect(boolean sendClose) {
        if (!connect) {
            this.earlyDisconnect = true;
            return false;
        }
        cancelTimeoutTask();
        if (disconnect.compareAndSet(false, true)) {
            for (Consumer> disconnectListener : new ArrayList<>(disconnectListeners)) {
                try {
                    disconnectListener.accept(this);
                } catch (Exception e) {
                    if (log.isWarnEnabled()) {
                        log.warn("disconnectListener error = {} {}", e.toString(), disconnectListener, e);
                    }
                }
            }
            disconnectListeners.clear();
            if (sendClose && isActive()) {
                try {
                    SseEventBuilderFuture event = event();
                    super.send(event.defaultId(++defaultId).name("connect-close").data("{\"connectionId\": \"" + id + "\"}"));
                } catch (IOException | IllegalStateException ignored) {
                    // best effort: the connection is being closed anyway
                }
            }
            this.writeable = false;
            try {
                complete();
            } catch (NullPointerException ignored) {
                // Tomcat throws NullPointerException while the server is shutting down;
                // ignored because it only happens at Tomcat shutdown.
            } catch (Exception e) {
                if (log.isWarnEnabled()) {
                    log.warn("sse connection disconnect exception : {}. {}", e.toString(), this, e);
                }
            }
            return true;
        } else {
            return false;
        }
    }

    /**
     * Stores {@code newMessage} under the attribute key {@code messageType} if it differs
     * (per {@link Objects#equals}) from the value currently stored there.
     *
     * @return true if the value changed and was stored, false if it was equal
     */
    public boolean isMessageChange(Object newMessage, String messageType) {
        Object oldMessage = getAttribute(messageType);
        if (Objects.equals(oldMessage, newMessage)) {
            return false;
        }
        setAttribute(messageType, newMessage);
        return true;
    }

    /**
     * Filters out messages whose id (per {@code idGetter}) was already seen for
     * {@code messageType}, and records the ids of the messages kept.
     * <p>
     * The seen ids are kept in a Set stored in the attribute map under {@code messageType}.
     * The value under that key must therefore be a Set. Duplicates within
     * {@code messageList} itself are also removed; order is preserved.
     *
     * @return a new list of the messages not seen before
     */
    public List distinctMessageList(List messageList,
                                       Function idGetter,
                                       String messageType) {
        Set distinctSet = (Set) getAttributeMap().computeIfAbsent(messageType, o -> new HashSet<>());
        // Set.add returns false for an id already present. One filter therefore both tests
        // and records the id, without a side-effecting peek() and with one idGetter call.
        return messageList.stream()
                .filter(e -> distinctSet.add(idGetter.apply(e)))
                .collect(Collectors.toList());
    }

    /** @return {@code id + "#"}, followed by the access user if present */
    @Override
    public String toString() {
        if (accessUser == null) {
            return id + "#";
        } else {
            return id + "#" + accessUser;
        }
    }

    /** Two emitters are equal when their connection ids are equal. */
    @Override
    public boolean equals(Object obj) {
        if (obj instanceof SseEmitter) {
            return ((SseEmitter) obj).id == this.id;
        }
        return false;
    }

    @Override
    public int hashCode() {
        return Long.hashCode(this.id);
    }

    /**
     * Adds listener names to {@link #getListeners()}.
     * Every watch registered via {@link #addListeningWatch} is notified with an
     * {@code EVENT_ADD_LISTENER} event carrying before/after snapshots.
     */
    public void addListener(Collection addListener) {
        Set listeners = getListeners();
        Set beforeCopy = new LinkedHashSet<>(listeners);
        listeners.addAll(addListener);
        Set afterCopy = new LinkedHashSet<>(listeners);
        SseChangeEvent> event = new SseChangeEvent<>(this, SseChangeEvent.EVENT_ADD_LISTENER, beforeCopy, afterCopy);
        for (Consumer>> changeEventConsumer : new ArrayList<>(listenersWatchList)) {
            changeEventConsumer.accept(event);
        }
    }

    /**
     * Removes listener names from {@link #getListeners()}.
     * Every watch registered via {@link #addListeningWatch} is notified with an
     * {@code EVENT_REMOVE_LISTENER} event carrying before/after snapshots.
     */
    public void removeListener(Collection removeListener) {
        Set listeners = getListeners();
        Set beforeCopy = new LinkedHashSet<>(listeners);
        listeners.removeAll(removeListener);
        Set afterCopy = new LinkedHashSet<>(listeners);
        SseChangeEvent> event = new SseChangeEvent<>(this, SseChangeEvent.EVENT_REMOVE_LISTENER, beforeCopy, afterCopy);
        for (Consumer>> changeEventConsumer : new ArrayList<>(listenersWatchList)) {
            changeEventConsumer.accept(event);
        }
    }

    /**
     * SSE event object; also a Future that completes once the event has been written
     * to the front end.
     *
     * @see SseEventBuilder
     * @see #send(String, Object)
     * @see #send(SseEventBuilder)
     * @see #isWriteable()
     * @see #writeableReady()
     */
    public static class SseEventBuilderFuture extends CompletableFuture> implements SseEventBuilder {
        private final Set dataToSend = new LinkedHashSet<>(3);
        /** Id written by data(...) when id(String) was not called. */
        private int defaultId;
        private String id;
        private String name;
        /** Pending text-plain SSE field text not yet moved into dataToSend. */
        private StringBuilder sb;

        public SseEventBuilderFuture() {
        }

        /**
         * Sets the id used by {@link #data(Object, MediaType)} if {@link #id(String)} was not called.
         *
         * @param defaultId the fallback event id
         * @return this builder
         */
        public SseEventBuilderFuture defaultId(int defaultId) {
            this.defaultId = defaultId;
            return this;
        }

        public String getId() {
            return id;
        }

        public String getName() {
            return name;
        }

        /** Appends an {@code id:} field. */
        @Override
        public SseEventBuilderFuture id(String id) {
            this.id = id;
            append("id:").append(id).append("\n");
            return this;
        }

        /** Appends an {@code event:} field. */
        @Override
        public SseEventBuilderFuture name(String name) {
            this.name = name;
            append("event:").append(name).append("\n");
            return this;
        }

        /** Appends a {@code retry:} field. */
        @Override
        public SseEventBuilderFuture reconnectTime(long reconnectTimeMillis) {
            append("retry:").append(String.valueOf(reconnectTimeMillis)).append("\n");
            return this;
        }

        /** Appends a comment line ({@code :text}). */
        @Override
        public SseEventBuilderFuture comment(String comment) {
            append(":").append(comment).append("\n");
            return this;
        }

        @Override
        public SseEventBuilderFuture data(Object object) {
            return data(object, null);
        }

        /**
         * Appends a {@code data:} field with the given payload.
         * If no id has been set yet, the default id is written first.
         */
        @Override
        public SseEventBuilderFuture data(Object object, MediaType mediaType) {
            if (id == null) {
                id(Integer.toString(defaultId));
            }
            append("data:");
            saveAppendedText();
            this.dataToSend.add(new DataWithMediaType(object, mediaType));
            append("\n");
            return this;
        }

        SseEventBuilderFuture append(String text) {
            if (this.sb == null) {
                this.sb = new StringBuilder();
            }
            this.sb.append(text);
            return this;
        }

        /**
         * Finishes the event with a terminating blank line.
         *
         * @return the parts to write, or an empty set if nothing was added
         */
        @Override
        public Set build() {
            if ((sb == null || sb.length() == 0) && this.dataToSend.isEmpty()) {
                return Collections.emptySet();
            }
            append("\n");
            saveAppendedText();
            return this.dataToSend;
        }

        /** Moves the pending text into dataToSend as a text/plain part. */
        private void saveAppendedText() {
            if (this.sb != null) {
                this.dataToSend.add(new DataWithMediaType(this.sb.toString(), TEXT_PLAIN));
                this.sb = null;
            }
        }
    }
}