// com.github.sseserver.local.SseWebController Maven / Gradle / Ivy
// Show all versions of sse-server Show documentation
package com.github.sseserver.local;
import com.github.sseserver.AccessUser;
import com.github.sseserver.qos.Message;
import com.github.sseserver.qos.MessageRepository;
import com.github.sseserver.remote.ClusterConnectionService;
import com.github.sseserver.remote.ConnectionByUserIdDTO;
import com.github.sseserver.remote.ConnectionDTO;
import com.github.sseserver.springboot.SseServerProperties;
import com.github.sseserver.util.CompletableFuture;
import com.github.sseserver.util.PageInfo;
import com.github.sseserver.util.PlatformDependentUtil;
import com.github.sseserver.util.WebUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.core.io.InputStreamResource;
import org.springframework.core.io.Resource;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.context.request.async.DeferredResult;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.Part;
import java.io.*;
import java.nio.charset.Charset;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.function.ToLongFunction;
import java.util.stream.Collectors;
/**
* 消息事件推送 (分布式)
* 注: !! 这里是示例代码, 根据自己项目封装的用户逻辑, 继承类或复制到自己项目里都行
*
* 1. 如果用nginx代理, 要加下面的配置
* # 长连接配置
* proxy_buffering off;
* proxy_read_timeout 7200s;
* proxy_pass http://xx.xx.xx.xx:xxx;
* proxy_http_version 1.1; #nginx默认是http1.0, 改为1.1 支持长连接, 和后端保持长连接,复用,防止出现文件句柄打开数量过多的错误
* proxy_set_header Connection ""; # 去掉Connection的close字段
*
* @author hao 2021年12月7日19:29:51
*/
//@RestController
//@RequestMapping("/a/sse")
//@RequestMapping("/b/sse")
public class SseWebController {
/** Path of the endpoint that opens the SSE stream (handled by {@code connect}). */
public static final String API_CONNECT_STREAM = "/connect";
/** Path of the endpoint that adds a listener to an existing connection (handled by {@code addListener}). */
public static final String API_ADD_LISTENER_DO = "/connect/addListener.do";
/** Path of the endpoint that removes a listener from an existing connection (handled by {@code removeListener}). */
public static final String API_REMOVE_LISTENER_DO = "/connect/removeListener.do";
/** Path of the endpoint receiving frontend messages; {path} is forwarded to {@code onMessage}. */
public static final String API_MESSAGE_DO = "/connect/message/{path}.do";
/** Path of the endpoint receiving frontend uploads; {path} is forwarded to {@code onUpload}. */
public static final String API_UPLOAD_DO = "/connect/upload/{path}.do";
/** Path of the endpoint that closes a connection by id (handled by {@code disconnect}). */
public static final String API_DISCONNECT_DO = "/connect/disconnect.do";
/** Path of the ".json" endpoint listing repository messages; ".json" endpoints are gated by {@code enableGetJson}. */
public static final String API_REPOSITORY_MESSAGES_JSON = "/connect/repositoryMessages.json";
/** Path of the ".json" users endpoint; its handler is not in this view (presumably later in the class). */
public static final String API_USER_JSON = "/connect/users.json";
/** Path of the ".json" connections endpoint; its handler is not in this view (presumably later in the class). */
public static final String API_CONNECTIONS_JSON = "/connect/connections.json";
/** UTF-8 bytes appended to sse.js for the "module" script type so the script exports {@code Sse} as its default export. */
private static final byte[] SSE_APPEND_BYTES = "\nexport default Sse".getBytes(Charset.forName("UTF-8"));
/** Logger named after the runtime class, so subclasses log under their own name. */
private final Logger logger = LoggerFactory.getLogger(getClass());
/** Current HTTP request (optional injection); connect() reads headers, cookies, IP and domain from it. */
@Autowired(required = false)
protected HttpServletRequest request;
/** Connection service backing every endpoint; set via setLocalConnectionService or setLocalConnectionServiceMap. */
protected LocalConnectionService localConnectionService;
// The supplier re-reads localConnectionService on every call because the service may be injected after this
// field is initialized; it yields null until then. NOTE(review): consumer not visible here — presumably
// disconnectClientIdMaxConnections; confirm.
private final ClusterBatchDisconnectRunnable batchDisconnectRunnable = new ClusterBatchDisconnectRunnable(() -> localConnectionService != null ? localConnectionService.getCluster() : null);
/** Local server port from the {@code server.port} property, 8080 when unset. */
@Value("${server.port:8080}")
private Integer serverPort;
/** Name of the header carrying the SSE server id; where it is written is not visible in this view. */
private String sseServerIdHeaderName = "X-Sse-Server-Id";
/** Per-client-id connection cap handed to disconnectClientIdMaxConnections by onConnect (default 3). */
private Integer clientIdMaxConnections = 3;
/** Server-side keepalive override; when non-null it takes precedence over the client value (see choseKeepaliveTime). */
private Long keepaliveTime;
/** Whether ".json" endpoints are reachable; false by default (see hasPermission). */
private boolean enableGetJson = false;
/**
 * Tells whether the given endpoint path is one of the ".json" inspection endpoints.
 *
 * @param api endpoint path (one of the API_* constants); may be null
 * @return true when {@code api} is non-null and ends with ".json"
 */
protected boolean isGetJson(String api) {
    if (api == null) {
        return false;
    }
    // All API_*_JSON constants share this suffix.
    return api.endsWith(".json");
}
/**
 * Whether the ".json" endpoints are enabled.
 *
 * @return the current flag (false by default)
 */
public boolean isEnableGetJson() {
    return this.enableGetJson;
}
/**
 * Enables or disables the ".json" endpoints consulted by hasPermission.
 *
 * @param enableGetJson true to allow ".json" endpoints
 */
public void setEnableGetJson(boolean enableGetJson) {
    this.enableGetJson = enableGetJson;
}
/**
 * Server-side keepalive override.
 *
 * @return the configured value, or null when the client-supplied value should be used
 */
public Long getKeepaliveTime() {
    return this.keepaliveTime;
}
/**
 * Sets the server-side keepalive override; null lets the client decide.
 *
 * @param keepaliveTime override value, may be null
 */
public void setKeepaliveTime(Long keepaliveTime) {
    this.keepaliveTime = keepaliveTime;
}
/**
 * Maximum number of connections kept per client id, as passed to disconnectClientIdMaxConnections.
 *
 * @return the cap (3 by default), may be null if set so
 */
public Integer getClientIdMaxConnections() {
    return this.clientIdMaxConnections;
}
/**
 * Sets the per-client-id connection cap.
 *
 * @param clientIdMaxConnections new cap
 */
public void setClientIdMaxConnections(Integer clientIdMaxConnections) {
    this.clientIdMaxConnections = clientIdMaxConnections;
}
/**
 * Name of the header carrying the SSE server id.
 *
 * @return the header name ("X-Sse-Server-Id" by default)
 */
public String getSseServerIdHeaderName() {
    return this.sseServerIdHeaderName;
}
/**
 * Overrides the name of the SSE server id header.
 *
 * @param sseServerIdHeaderName new header name
 */
public void setSseServerIdHeaderName(String sseServerIdHeaderName) {
    this.sseServerIdHeaderName = sseServerIdHeaderName;
}
/**
 * Serves the frontend script at the controller root; identical to {@link #ssejs(String)}.
 *
 * @param type value of the "script-type" request parameter, "module" by default
 * @return the script response built by ssejs
 * @throws IOException if the script cannot be read
 */
@RequestMapping("")
public Object index(@RequestParam(required = false, name = "script-type", defaultValue = "module") String type) throws IOException {
    Object scriptResponse = ssejs(type);
    return scriptResponse;
}
/**
 * Serves the frontend script {@code sse.js} as {@code application/javascript} (UTF-8).
 *
 * @param type value of the "script-type" request parameter; "module" appends a default export
 * @return 200 response whose body is the script
 * @throws IOException if the script cannot be read
 */
@RequestMapping("/sse.js")
public Object ssejs(@RequestParam(required = false, name = "script-type", defaultValue = "module") String type) throws IOException {
    HttpHeaders responseHeaders = new HttpHeaders();
    settingResponseHeader(responseHeaders);
    // Set after settingResponseHeader so the script content type always wins.
    responseHeaders.set("Content-Type", "application/javascript;charset=utf-8");
    Resource script = readSseJs(type);
    return new ResponseEntity<>(script, responseHeaders, HttpStatus.OK);
}
/**
 * Loads the bundled frontend script {@code /sse.js} from the classpath.
 * <p>
 * For the "module" script type (case-insensitive) the whole script is buffered and
 * {@code export default Sse} is appended; the classpath stream is closed once copied.
 * For any other type the raw stream is returned unbuffered and left open for Spring to
 * close after writing the response (presumably via ResourceHttpMessageConverter).
 *
 * @param type requested script type; "module" (any case) selects the ES-module variant
 * @return the script body
 * @throws FileNotFoundException if {@code /sse.js} is not on the classpath
 * @throws IOException           if reading the script fails
 */
protected Resource readSseJs(String type) throws IOException {
    InputStream stream = SseWebController.class.getResourceAsStream("/sse.js");
    if (stream == null) {
        // getResourceAsStream reports a missing resource with null; fail clearly instead of with an NPE.
        throw new FileNotFoundException("classpath resource /sse.js not found");
    }
    if (!"module".equalsIgnoreCase(type)) {
        return new InputStreamResource(stream);
    }
    // Fully consumed here, so close it ourselves (the original leaked it on this branch).
    try (InputStream in = stream) {
        int bufferSize = Math.max(in.available(), 4096);
        ByteArrayOutputStream out = new ByteArrayOutputStream(bufferSize + SSE_APPEND_BYTES.length);
        copy(in, out, bufferSize);
        out.write(SSE_APPEND_BYTES);
        return new ByteArrayResource(out.toByteArray());
    }
}
/**
 * Sets the connection service used by every endpoint of this controller explicitly.
 *
 * @param localConnectionService service to use
 */
public void setLocalConnectionService(LocalConnectionService localConnectionService) {
    this.localConnectionService = localConnectionService;
}
/**
 * Receives the available connection services from Spring (presumably keyed by bean name) and
 * lets {@link #choseLocalConnectionService(Map)} pick one. A null or empty map leaves the
 * current service untouched.
 *
 * @param localConnectionServiceMap candidate services, may be null
 */
@Autowired(required = false)
public void setLocalConnectionServiceMap(Map localConnectionServiceMap) {
    boolean hasCandidates = localConnectionServiceMap != null && !localConnectionServiceMap.isEmpty();
    if (hasCandidates) {
        this.localConnectionService = choseLocalConnectionService(localConnectionServiceMap);
    }
}
/**
 * Chooses the connection service this SseWebController should use.
 * <p>
 * The default takes whichever value the map's iterator yields first; override to select by key
 * when several services are registered. Called by setLocalConnectionServiceMap only with a
 * non-empty map (an empty map would make {@code next()} throw NoSuchElementException).
 *
 * @param localConnectionServiceMap non-empty candidate services
 * @return the LocalConnectionService for this controller
 * @since 1.2.16
 */
protected LocalConnectionService choseLocalConnectionService(Map localConnectionServiceMap) {
    return localConnectionServiceMap
            .values()
            .iterator()
            .next();
}
/**
 * Resolves the currently logged-in user. Whatever is returned here becomes available through
 * {@link SseEmitter#getAccessUser()} on connections opened by this user. The default returns
 * null, which every endpoint treats as "not logged in"; override with your own user lookup.
 *
 * @param api which endpoint is asking for the user
 * @return the user object of the integrating system, or null when nobody is logged in
 * @see #API_CONNECT_STREAM
 * @see #API_ADD_LISTENER_DO
 * @see #API_REMOVE_LISTENER_DO
 * @see #API_MESSAGE_DO
 * @see #API_UPLOAD_DO
 * @see #API_DISCONNECT_DO
 * @see #API_REPOSITORY_MESSAGES_JSON
 * @see #API_USER_JSON
 * @see #API_CONNECTIONS_JSON
 */
protected ACCESS_USER getAccessUser(String api) {
    // No user resolution by default.
    return null;
}
/**
 * Whether the current user may access the given endpoint. The default allows everything except
 * ".json" endpoints, which require {@link #isEnableGetJson()}.
 *
 * @param currentUser the logged-in user
 * @param api         which endpoint is being accessed
 * @return true = allowed; false = denied, in which case
 *         {@link #buildPermissionRejectResponse(Object, String)} is returned to the client
 * @see #API_CONNECT_STREAM
 * @see #API_ADD_LISTENER_DO
 * @see #API_REMOVE_LISTENER_DO
 * @see #API_MESSAGE_DO
 * @see #API_UPLOAD_DO
 * @see #API_DISCONNECT_DO
 * @see #API_REPOSITORY_MESSAGES_JSON
 * @see #API_USER_JSON
 * @see #API_CONNECTIONS_JSON
 */
protected boolean hasPermission(ACCESS_USER currentUser, String api) {
    if (isEnableGetJson()) {
        return true;
    }
    return !isGetJson(api);
}
/**
 * Runs the permission check for an endpoint.
 *
 * @param currentUser the logged-in user
 * @param api         endpoint being accessed
 * @return null when access is allowed, otherwise the rejection response to send back
 */
protected ResponseEntity buildIfPermissionErrorResponse(ACCESS_USER currentUser, String api) {
    return hasPermission(currentUser, api) ? null : buildPermissionRejectResponse(currentUser, api);
}
/**
 * Builds the 401 response sent when {@link #hasPermission} denies access.
 * <p>
 * The body is {@code {"error": message}} and the connection is marked "close". The get-json hint
 * is only given when the rejection matches the default rule (a ".json" endpoint while
 * {@link #isEnableGetJson()} is off). Rejections from an overridden hasPermission get a generic
 * message instead of a misleading hint.
 *
 * @param currentUser the user that was rejected
 * @param api         endpoint that was rejected
 * @return 401 response with an error body
 */
protected ResponseEntity buildPermissionRejectResponse(ACCESS_USER currentUser, String api) {
    HttpHeaders headers = new HttpHeaders();
    headers.setConnection("close");
    settingResponseHeader(headers);
    String message;
    if (isGetJson(api) && !isEnableGetJson()) {
        message = "get json api default is disabled. if you need use, please use SseWebController#setEnableGetJson(true);";
    } else {
        message = "permission denied";
    }
    Map body = Collections.singletonMap("error", message);
    return new ResponseEntity<>(body, headers, HttpStatus.UNAUTHORIZED);
}
/**
 * Wraps a successful result in a {@code ResponseWrap}. NOTE(review): caller not visible in this
 * view — presumably responseEntity(...); confirm.
 *
 * @param result payload to wrap
 * @return the wrapped payload
 */
protected Object wrapOkResponse(Object result) {
    Object wrapped = new ResponseWrap<>(result);
    return wrapped;
}
/**
 * Hook invoked by the message endpoint. The default echoes {@code path} back.
 *
 * @param path       the {path} URL segment
 * @param connection sending connection, or null when not found
 * @param message    merged query and body parameters (without connectionId)
 * @return value used as the response body
 */
protected Object onMessage(String path, SseEmitter connection, Map message) {
    Object echo = path;
    return echo;
}
/**
 * Hook invoked by the upload endpoint. The default echoes {@code path} back.
 *
 * @param path       the {path} URL segment
 * @param connection uploading connection, or null when not found
 * @param message    merged query and body parameters (without connectionId)
 * @param files      multipart parts of the request
 * @return value used as the response body
 */
protected Object onUpload(String path, SseEmitter connection, Map message, Collection files) {
    Object echo = path;
    return echo;
}
/**
 * Hook invoked after a connection is established. The default enforces the per-client-id
 * connection cap from {@link #getClientIdMaxConnections()}.
 *
 * @param emitter the new connection
 * @param query   merged request attributes of the connect call
 */
protected void onConnect(SseEmitter emitter, Map query) {
    // Kept boxed: passed through exactly as the getter returns it.
    Integer maxConnections = getClientIdMaxConnections();
    disconnectClientIdMaxConnections(emitter, maxConnections);
}
/**
 * Login check for the connect endpoint; override to add checks that may use the request data.
 * The default only rejects a missing user.
 *
 * @param accessUser    user returned by getAccessUser, may be null
 * @param query         request parameters
 * @param body          optional request body
 * @param keepaliveTime keepalive chosen for the connection
 * @return null to continue, otherwise the error response to send back
 */
protected ResponseEntity buildIfLoginVerifyErrorResponse(ACCESS_USER accessUser,
                                                         Map query, Map body,
                                                         Long keepaliveTime) {
    return accessUser != null ? null : buildUnauthorizedResponse();
}
/**
 * Post-connect verification hook, run after the emitter has been filled with request data.
 * The default accepts every connection.
 *
 * @param emitter the freshly created connection
 * @return null to accept, otherwise the error response to send back
 */
protected ResponseEntity buildIfConnectVerifyErrorResponse(SseEmitter emitter) {
    ResponseEntity noError = null;
    return noError;
}
/**
 * Builds an empty-bodied 401 response with {@code Connection: close}.
 *
 * @return the unauthorized response
 */
protected ResponseEntity buildUnauthorizedResponse() {
    HttpHeaders unauthorizedHeaders = new HttpHeaders();
    // Set before settingResponseHeader, which is free to adjust it.
    unauthorizedHeaders.setConnection("close");
    settingResponseHeader(unauthorizedHeaders);
    return new ResponseEntity<>("", unauthorizedHeaders, HttpStatus.UNAUTHORIZED);
}
/**
 * Picks the keepalive time for a new connection: the server value wins when configured.
 *
 * @param clientKeepaliveTime value requested by the client, may be null
 * @param serverKeepaliveTime server-side override, may be null
 * @return {@code serverKeepaliveTime} if non-null, else {@code clientKeepaliveTime}
 */
protected Long choseKeepaliveTime(Long clientKeepaliveTime, Long serverKeepaliveTime) {
    return serverKeepaliveTime == null ? clientKeepaliveTime : serverKeepaliveTime;
}
/**
 * Opens an SSE connection (GET or POST).
 * <p>
 * Query parameters and the optional body are merged into the connection attributes, with body
 * values overriding query values of the same key. After login, permission and connect
 * verification, the emitter is filled with request metadata (channel, user agent, IP, domain,
 * cookies, parameters, headers) and returned as the streaming response.
 *
 * @param query           request parameters
 * @param body            optional request body
 * @param keepaliveTime   keepalive requested by the client; a server-side value overrides it
 * @param sessionDuration stored on the emitter as-is
 * @return the SseEmitter on success, otherwise an error ResponseEntity
 */
@RequestMapping(value = API_CONNECT_STREAM, method = {RequestMethod.GET, RequestMethod.POST})
public Object connect(@RequestParam Map query, @RequestBody(required = false) Map body,
                      Long keepaliveTime, Long sessionDuration) {
    Map attributeMap = new LinkedHashMap<>(query);
    if (body != null) {
        attributeMap.putAll(body);
    }
    Long effectiveKeepaliveTime = choseKeepaliveTime(keepaliveTime, getKeepaliveTime());

    // Step 1: login and permission.
    ACCESS_USER currentUser = getAccessUser(API_CONNECT_STREAM);
    ResponseEntity loginError = buildIfLoginVerifyErrorResponse(currentUser, query, body, effectiveKeepaliveTime);
    if (loginError != null) {
        return loginError;
    }
    ResponseEntity permissionError = buildIfPermissionErrorResponse(currentUser, API_CONNECT_STREAM);
    if (permissionError != null) {
        return permissionError;
    }

    // Step 2: create the connection.
    SseEmitter emitter = localConnectionService.connect(currentUser, effectiveKeepaliveTime, attributeMap);
    if (emitter == null) {
        return buildUnauthorizedResponse();
    }

    // Step 3: record request metadata on the emitter; a blank channel is stored as null.
    String channel = Objects.toString(attributeMap.get("channel"), null);
    boolean blankChannel = channel == null || channel.isEmpty();
    emitter.setChannel(blankChannel ? null : channel);
    emitter.setUserAgent(request.getHeader("User-Agent"));
    emitter.setSessionDuration(sessionDuration);
    emitter.setRequestIp(getRequestIpAddr(request));
    emitter.setRequestDomain(getRequestDomain(request));
    emitter.setHttpCookies(request.getCookies());
    emitter.getHttpParameters().putAll(attributeMap);
    Enumeration<String> headerNames = request.getHeaderNames();
    while (headerNames.hasMoreElements()) {
        String headerName = headerNames.nextElement();
        emitter.getHttpHeaders().put(headerName, request.getHeader(headerName));
    }

    // Step 4: connection-level verification.
    // NOTE(review): on rejection the emitter from step 2 is not disconnected here — confirm whether
    // LocalConnectionService cleans it up.
    ResponseEntity connectError = buildIfConnectVerifyErrorResponse(emitter);
    if (connectError != null) {
        return connectError;
    }

    onConnect(emitter, attributeMap);
    settingResponseHeader(emitter.getResponseHeaders());
    return emitter;
}
/**
 * Adds a listener to an existing connection.
 * <p>
 * An invalid request is answered with {@code {"listener": null}} before any login check.
 * NOTE(review): the connection is not checked to belong to {@code currentUser}; confirm this is intended.
 *
 * @param req listener request (connection id and listener)
 * @return raw HTTP response: the connection's listeners, an error map, or 401
 */
@PostMapping(API_ADD_LISTENER_DO)
public ResponseEntity addListener(@RequestBody ListenerReq req) {
    boolean invalidRequest = req == null || req.isInvalid();
    if (invalidRequest) {
        return responseEntity(Collections.singletonMap("listener", null));
    }
    ACCESS_USER currentUser = getAccessUser(API_ADD_LISTENER_DO);
    if (currentUser == null) {
        return buildUnauthorizedResponse();
    }
    ResponseEntity permissionError = buildIfPermissionErrorResponse(currentUser, API_ADD_LISTENER_DO);
    if (permissionError != null) {
        return permissionError;
    }
    SseEmitter target = localConnectionService.getConnectionById(req.getConnectionId());
    if (target == null) {
        return responseEntity(Collections.singletonMap("error", "connectionId not exist"));
    }
    target.addListener(req.getListener());
    return responseEntity(Collections.singletonMap("listener", target.getListeners()));
}
/**
 * Removes a listener from an existing connection.
 * <p>
 * An invalid request is answered with {@code {"listener": null}} before any login check.
 * NOTE(review): the connection is not checked to belong to {@code currentUser}; confirm this is intended.
 *
 * @param req listener request (connection id and listener)
 * @return raw HTTP response: the connection's remaining listeners, an error map, or 401
 */
@PostMapping(API_REMOVE_LISTENER_DO)
public ResponseEntity removeListener(@RequestBody ListenerReq req) {
    boolean invalidRequest = req == null || req.isInvalid();
    if (invalidRequest) {
        return responseEntity(Collections.singletonMap("listener", null));
    }
    ACCESS_USER currentUser = getAccessUser(API_REMOVE_LISTENER_DO);
    if (currentUser == null) {
        return buildUnauthorizedResponse();
    }
    ResponseEntity permissionError = buildIfPermissionErrorResponse(currentUser, API_REMOVE_LISTENER_DO);
    if (permissionError != null) {
        return permissionError;
    }
    SseEmitter target = localConnectionService.getConnectionById(req.getConnectionId());
    if (target == null) {
        return responseEntity(Collections.singletonMap("error", "connectionId not exist"));
    }
    target.removeListener(req.getListener());
    return responseEntity(Collections.singletonMap("listener", target.getListeners()));
}
/**
 * Receives a message posted by the frontend and hands it to {@link #onMessage}.
 * <p>
 * Query parameters and the optional body are merged into one map (body wins on key clashes), and
 * the {@code connectionId} routing parameter is removed from it. When a connection with that id
 * is found, {@code requestMessage()} is invoked on it before the callback.
 *
 * @param path         the {path} URL segment, passed to onMessage
 * @param connectionId id of the sending connection; optional
 * @param query        request parameters
 * @param body         optional request body
 * @return raw HTTP response built from onMessage's result, or 401 / a permission rejection
 */
@PostMapping(API_MESSAGE_DO)
public ResponseEntity message(@PathVariable String path, Long connectionId, @RequestParam Map query, @RequestBody(required = false) Map body) {
    ACCESS_USER currentUser = getAccessUser(API_MESSAGE_DO);
    if (currentUser == null) {
        return buildUnauthorizedResponse();
    }
    ResponseEntity permissionErrorResponse = buildIfPermissionErrorResponse(currentUser, API_MESSAGE_DO);
    if (permissionErrorResponse != null) {
        return permissionErrorResponse;
    }
    // connectionId is optional: skip the lookup when absent (as disconnect does) rather than hand null to the service.
    SseEmitter emitter = connectionId == null ? null : localConnectionService.getConnectionById(connectionId);
    Map message = new LinkedHashMap<>(query);
    message.remove("connectionId");
    if (body != null) {
        message.putAll(body);
    }
    if (emitter != null) {
        emitter.requestMessage();
    }
    Object responseBody = onMessage(path, emitter, message);
    return responseEntity(responseBody);
}
/**
 * Receives data uploaded by the frontend and hands it, with the multipart parts, to {@link #onUpload}.
 * <p>
 * Query parameters and the optional body are merged into one map (body wins on key clashes), and
 * the {@code connectionId} routing parameter is removed from it. When a connection with that id
 * is found, {@code requestUpload()} is invoked on it before the callback.
 *
 * @param path         the {path} URL segment, passed to onUpload
 * @param request      the servlet request whose parts are forwarded
 * @param connectionId id of the uploading connection; optional
 * @param query        request parameters
 * @param body         optional request body
 * @return raw HTTP response built from onUpload's result, or 401 / a permission rejection
 * @throws IOException      if reading the parts fails
 * @throws ServletException from {@code getParts()}, e.g. when the request is not multipart
 */
@PostMapping(API_UPLOAD_DO)
public ResponseEntity upload(@PathVariable String path, HttpServletRequest request, Long connectionId, @RequestParam Map query, @RequestBody(required = false) Map body) throws IOException, ServletException {
    ACCESS_USER currentUser = getAccessUser(API_UPLOAD_DO);
    if (currentUser == null) {
        return buildUnauthorizedResponse();
    }
    ResponseEntity permissionErrorResponse = buildIfPermissionErrorResponse(currentUser, API_UPLOAD_DO);
    if (permissionErrorResponse != null) {
        return permissionErrorResponse;
    }
    // connectionId is optional: skip the lookup when absent (as disconnect does) rather than hand null to the service.
    SseEmitter emitter = connectionId == null ? null : localConnectionService.getConnectionById(connectionId);
    Map message = new LinkedHashMap<>(query);
    message.remove("connectionId");
    if (body != null) {
        message.putAll(body);
    }
    if (emitter != null) {
        emitter.requestUpload();
    }
    Object responseBody = onUpload(path, emitter, message, request.getParts());
    return responseEntity(responseBody);
}
/**
 * Closes a connection by id.
 * <p>
 * The connection is first disconnected locally. When it was not found locally and cluster mode
 * applies, the other nodes are asked asynchronously via a DeferredResult. If they do not answer
 * within {@code timeout} ms, the local count is reported with the second flag of
 * buildDisconnectResult set to true (presumably a timeout marker — confirm). A null
 * {@code cluster} means "use cluster if enabled"; false forces local only.
 *
 * @param connectionId    connection to close; null returns a zero-count result
 * @param query           request parameters (unused here)
 * @param cluster         whether to fall back to the cluster, null = when enabled
 * @param timeout         cluster wait in milliseconds, 5000 by default
 * @param duration        passed through to the disconnect calls
 * @param sessionDuration passed through to the disconnect calls
 * @return a ResponseEntity, or a DeferredResult when the cluster is consulted
 */
@PostMapping(API_DISCONNECT_DO)
public Object disconnect(Long connectionId, @RequestParam Map query,
                         Boolean cluster,
                         @RequestParam(required = false, defaultValue = "5000") Long timeout,
                         Long duration,
                         Long sessionDuration) {
    if (connectionId == null) {
        return responseEntity(buildDisconnectResult(0, false));
    }
    SseEmitter localEmitter = localConnectionService.disconnectByConnectionId(connectionId, duration, sessionDuration);
    int localCount = localEmitter == null ? 0 : 1;
    boolean useCluster = (cluster == null || cluster) && localConnectionService.isEnableCluster();
    if (!useCluster || localCount != 0) {
        return responseEntity(buildDisconnectResult(localCount, false));
    }
    DeferredResult result = new DeferredResult<>(timeout, responseEntity(buildDisconnectResult(localCount, true)));
    localConnectionService.getCluster().disconnectByConnectionId(connectionId, duration, sessionDuration)
            .whenComplete((remoteCount, throwable) -> {
                if (throwable == null) {
                    result.setResult(responseEntity(buildDisconnectResult(remoteCount + localCount, false)));
                } else {
                    logger.warn("disconnectConnection exception = {}", throwable, throwable);
                    result.setResult(responseEntity(buildDisconnectResult(0, false)));
                }
            });
    return result;
}
@GetMapping(API_REPOSITORY_MESSAGES_JSON)
public Object repositoryMessages(RepositoryMessagesReq req) {
ACCESS_USER currentUser = getAccessUser(API_REPOSITORY_MESSAGES_JSON);
if (currentUser == null) {
return buildUnauthorizedResponse();
}
ResponseEntity permissionErrorResponse = buildIfPermissionErrorResponse(currentUser, API_REPOSITORY_MESSAGES_JSON);
if (permissionErrorResponse != null) {
return permissionErrorResponse;
}
Integer pageNum = req.getPageNum();
Integer pageSize = req.getPageSize();
Long timeout = req.getTimeout();
Boolean cluster = req.getCluster();
if (cluster == null || cluster) {
cluster = localConnectionService.isEnableCluster();
}
CompletableFuture> future;
if (req.isEmptyCondition()) {
if (cluster) {
future = localConnectionService.getClusterMessageRepository().listAsync();
} else {
future = CompletableFuture.completedFuture(localConnectionService.getLocalMessageRepository().list());
}
} else {
if (cluster) {
future = localConnectionService.getClusterMessageRepository().selectAsync(req);
} else {
future = CompletableFuture.completedFuture(localConnectionService.getLocalMessageRepository().select(req));
}
}
DeferredResult result = new DeferredResult<>(timeout, () -> responseEntity(PageInfo.timeout()));
future.whenComplete((messages, throwable) -> {
if (throwable != null) {
result.setErrorResult(throwable);
} else {
List