All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.github.sseserver.local.LocalController Maven / Gradle / Ivy

package com.github.sseserver.local;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.github.sseserver.ConnectionQueryService;
import com.github.sseserver.SendService;
import com.github.sseserver.qos.Message;
import com.github.sseserver.qos.MessageRepository;
import com.github.sseserver.remote.ServiceDiscoveryService;
import com.github.sseserver.util.AutoTypeBean;
import com.github.sseserver.util.TypeUtil;
import com.github.sseserver.util.WebUtil;
import com.sun.net.httpserver.*;

import java.io.Closeable;
import java.io.IOException;
import java.io.OutputStream;
import java.io.Serializable;
import java.net.InetSocketAddress;
import java.nio.charset.Charset;
import java.util.*;
import java.util.function.Supplier;
import java.util.stream.Collectors;

public class LocalController implements Closeable {
    private static final Charset UTF_8 = Charset.forName("UTF-8");
    private final HttpServer httpServer;
    private final Supplier localConnectionServiceSupplier;
    private final Supplier localMessageRepositorySupplier;
    private final Supplier discoverySupplier;
    private final boolean primary;

    public LocalController(Supplier localConnectionServiceSupplier,
                           Supplier localMessageRepositorySupplier,
                           Supplier discoverySupplier,
                           boolean primary) {
        this(WebUtil.getIPAddress(), localConnectionServiceSupplier, localMessageRepositorySupplier, discoverySupplier, primary);
    }

    public LocalController(String ip,
                           Supplier localConnectionServiceSupplier,
                           Supplier localMessageRepositorySupplier,
                           Supplier discoverySupplier,
                           boolean primary) {
        this.localMessageRepositorySupplier = localMessageRepositorySupplier;
        this.localConnectionServiceSupplier = localConnectionServiceSupplier;
        this.discoverySupplier = discoverySupplier;
        this.primary = primary;
        this.httpServer = createHttpServer(ip);
        configHttpServer(httpServer);
        httpServer.start();
    }

    public boolean isPrimary() {
        return primary;
    }

    public InetSocketAddress getAddress() {
        return httpServer.getAddress();
    }

    public void registerInstance() {
        if (discoverySupplier == null) {
            return;
        }
        InetSocketAddress address = httpServer.getAddress();
        ServiceDiscoveryService discoveryService = discoverySupplier.get();
        if (discoveryService == null) {
            return;
        }
        for (int i = 0, retry = 3; i < retry; i++) {
            try {
                discoveryService.registerInstance(address.getAddress().getHostAddress(), address.getPort());
                return;
            } catch (Exception e) {
                if (i == retry - 1) {
                    throw e;
                } else {
                    try {
                        Thread.sleep(500);
                    } catch (InterruptedException ex) {
                        throw e;
                    }
                }
            }
        }
    }

    protected HttpServer createHttpServer(String ip) {
        while (true) {
            try {
                // 0 = random port
                return HttpServer.create(new InetSocketAddress(ip, 0), 0);
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    protected HttpContext configConnectionQueryService(HttpServer httpServer) {
        return httpServer.createContext("/ConnectionQueryService/",
                new ConnectionQueryServiceHttpHandler(localConnectionServiceSupplier));
    }

    protected HttpContext configSendService(HttpServer httpServer) {
        return httpServer.createContext("/SendService/",
                new SendServiceHttpHandler(localConnectionServiceSupplier));
    }

    protected HttpContext configRemoteConnectionService(HttpServer httpServer) {
        return httpServer.createContext("/RemoteConnectionService/",
                new RemoteConnectionServiceHttpHandler(localConnectionServiceSupplier));
    }

    protected HttpContext configMessageRepository(HttpServer httpServer) {
        return httpServer.createContext("/MessageRepository/",
                new MessageRepositoryHttpHandler(localMessageRepositorySupplier));
    }

    protected void configAuthenticator(List httpContextList) {
        for (HttpContext httpContext : httpContextList) {
            httpContext.setAuthenticator(new AuthorizationHeaderAuthenticator(discoverySupplier));
        }
    }

    protected void configFilters(List httpContextList) {
        for (HttpContext httpContext : httpContextList) {
            List filters = httpContext.getFilters();
            filters.add(new ErrorPageFilter());
        }
    }

    protected void configHttpServer(HttpServer httpServer) {
        List contextList = new ArrayList<>(2);
        contextList.add(configConnectionQueryService(httpServer));
        contextList.add(configSendService(httpServer));
        contextList.add(configRemoteConnectionService(httpServer));
        contextList.add(configMessageRepository(httpServer));

        configAuthenticator(contextList);
        configFilters(contextList);
    }

    @Override
    public void close() {
        httpServer.stop(0);
    }

    public static class RemoteConnectionServiceHttpHandler extends AbstractHttpHandler {
        private final Supplier supplier;

        public RemoteConnectionServiceHttpHandler(Supplier supplier) {
            this.supplier = supplier;
        }

        @Override
        public void handle0(HttpExchange request) throws IOException {
            LocalConnectionService service = supplier != null ? supplier.get() : null;
            String rpcMethodName = getRpcMethodName();
            switch (rpcMethodName) {
                case "active": {
                    if (service instanceof LocalConnectionServiceImpl) {
                        LocalConnectionServiceImpl local = ((LocalConnectionServiceImpl) service);
                        List> activeList = body("activeList");
                        for (Map active : activeList) {
                            local.localActive(
                                    Objects.toString(active.get("userId"), null),
                                    Objects.toString(active.get("accessToken"), null));
                        }
                        writeResponse(request, 1);
                    } else {
                        writeResponse(request, 0);
                    }
                    break;
                }
                case "setDurationByUserId": {
                    if (service != null) {
                        writeResponse(request, service.setDurationByUserId(
                                        body("userId"), body("durationSecond"))
                                .size());
                    } else {
                        writeResponse(request, 0);
                    }
                    break;
                }
                case "setDurationByAccessToken": {
                    if (service != null) {
                        writeResponse(request, service.setDurationByAccessToken(
                                        body("accessToken"), body("durationSecond"))
                                .size());
                    } else {
                        writeResponse(request, 0);
                    }
                    break;
                }
                case "disconnectByUserId": {
                    if (service != null) {
                        writeResponse(request, service.disconnectByUserId(
                                body("userId")
                        ).size());
                    } else {
                        writeResponse(request, 0);
                    }
                    break;
                }
                case "disconnectByAccessToken": {
                    if (service != null) {
                        writeResponse(request, service.disconnectByAccessToken(
                                body("accessToken")
                        ).size());
                    } else {
                        writeResponse(request, 0);
                    }
                    break;
                }
                case "disconnectByConnectionId": {
                    if (service != null) {
                        writeResponse(request, service.disconnectByConnectionId(
                                body("connectionId", Long.class),
                                body("duration", Long.class),
                                body("sessionDuration", Long.class)
                        ) != null ? 1 : 0);
                    } else {
                        writeResponse(request, 0);
                    }
                    break;
                }
                case "disconnectByConnectionIds": {
                    if (service != null) {
                        Collection connectionIds = body("connectionIds", Collection.class);
                        if (connectionIds == null) {
                            writeResponse(request, 0);
                        } else {
                            writeResponse(request, service.disconnectByConnectionIds(connectionIds.stream()
                                    .map(TypeUtil::castToLong)
                                    .collect(Collectors.toList())
                            ).size());
                        }
                    } else {
                        writeResponse(request, 0);
                    }
                    break;
                }
                default: {
                    request.sendResponseHeaders(404, 0);
                    break;
                }
            }
        }
    }

    public static class SendServiceHttpHandler extends AbstractHttpHandler {
        private final Supplier> supplier;

        public SendServiceHttpHandler(Supplier> supplier) {
            this.supplier = supplier;
        }

        @Override
        public void handle0(HttpExchange request) throws IOException {
            String rpcMethodName = getRpcMethodName();
            SendService service = supplier != null ? supplier.get() : null;

            if (service != null) {
                Object scopeOnWriteable = body("scopeOnWriteable");
                if (Boolean.TRUE.equals(scopeOnWriteable)) {
                    service.scopeOnWriteable(() -> {
                        handleCase(request, rpcMethodName, service);
                        return null;
                    });
                } else {
                    handleCase(request, rpcMethodName, service);
                }
            } else {
                writeResponse(request, 0);
            }
        }

        public void handleCase(HttpExchange request, String rpcMethodName, SendService service) throws IOException {
            switch (rpcMethodName) {
                case "sendAll": {
                    writeResponse(request, service.sendAll(
                            body("eventName"),
                            body("body")
                    ));
                    break;
                }
                case "sendAllListening": {
                    writeResponse(request, service.sendAllListening(
                            body("eventName"),
                            body("body")
                    ));
                    break;
                }
                case "sendByChannel": {
                    Object channels = body("channels");
                    if (channels instanceof Collection) {
                        writeResponse(request, service.sendByChannel(
                                (Collection) channels,
                                body("eventName"),
                                body("body")
                        ));
                    } else {
                        writeResponse(request, service.sendByChannel(
                                (String) channels,
                                body("eventName"),
                                body("body")
                        ));
                    }
                    break;
                }
                case "sendByChannelListening": {
                    Object channels = body("channels");
                    if (channels instanceof Collection) {
                        writeResponse(request, service.sendByChannelListening(
                                (Collection) channels,
                                body("eventName"),
                                body("body")
                        ));
                    } else {
                        writeResponse(request, service.sendByChannelListening(
                                (String) channels,
                                body("eventName"),
                                body("body")
                        ));
                    }
                    break;
                }
                case "sendByAccessToken": {
                    Object accessTokens = body("accessTokens");
                    if (accessTokens instanceof Collection) {
                        writeResponse(request, service.sendByAccessToken(
                                (Collection) accessTokens,
                                body("eventName"),
                                body("body")
                        ));
                    } else {
                        writeResponse(request, service.sendByAccessToken(
                                (String) accessTokens,
                                body("eventName"),
                                body("body")
                        ));
                    }
                    break;
                }
                case "sendByAccessTokenListening": {
                    Object accessTokens = body("accessTokens");
                    if (accessTokens instanceof Collection) {
                        writeResponse(request, service.sendByAccessTokenListening(
                                (Collection) accessTokens,
                                body("eventName"),
                                body("body")
                        ));
                    } else {
                        writeResponse(request, service.sendByAccessTokenListening(
                                (String) accessTokens,
                                body("eventName"),
                                body("body")
                        ));
                    }
                    break;
                }
                case "sendByUserId": {
                    Object userIds = body("userIds");
                    if (userIds instanceof Collection) {
                        writeResponse(request, service.sendByUserId(
                                (Collection) userIds,
                                body("eventName"),
                                body("body")
                        ));
                    } else {
                        writeResponse(request, service.sendByUserId(
                                (String) userIds,
                                body("eventName"),
                                body("body")
                        ));
                    }
                    break;
                }
                case "sendByUserIdListening": {
                    Object userIds = body("userIds");
                    if (userIds instanceof Collection) {
                        writeResponse(request, service.sendByUserIdListening(
                                (Collection) userIds,
                                body("eventName"),
                                body("body")
                        ));
                    } else {
                        writeResponse(request, service.sendByUserIdListening(
                                (String) userIds,
                                body("eventName"),
                                body("body")
                        ));
                    }
                    break;
                }
                case "sendByTenantId": {
                    Object tenantIds = body("tenantIds");
                    if (tenantIds instanceof Collection) {
                        writeResponse(request, service.sendByTenantId(
                                (Collection) tenantIds,
                                body("eventName"),
                                body("body")
                        ));
                    } else {
                        writeResponse(request, service.sendByTenantId(
                                (String) tenantIds,
                                body("eventName"),
                                body("body")
                        ));
                    }
                    break;
                }
                case "sendByTenantIdListening": {
                    Object tenantIds = body("tenantIds");
                    if (tenantIds instanceof Collection) {
                        writeResponse(request, service.sendByTenantIdListening(
                                (Collection) tenantIds,
                                body("eventName"),
                                body("body")
                        ));
                    } else {
                        writeResponse(request, service.sendByTenantIdListening(
                                (String) tenantIds,
                                body("eventName"),
                                body("body")
                        ));
                    }
                    break;
                }
                default: {
                    request.sendResponseHeaders(404, 0);
                    break;
                }
            }
        }
    }

    public static class ConnectionQueryServiceHttpHandler extends AbstractHttpHandler {
        private final Supplier supplier;

        public ConnectionQueryServiceHttpHandler(Supplier supplier) {
            this.supplier = supplier;
        }

        @Override
        public void handle0(HttpExchange request) throws IOException {
            String rpcMethodName = getRpcMethodName();
            ConnectionQueryService service = supplier != null ? supplier.get() : null;
            switch (rpcMethodName) {
                case "isOnline": {
                    if (service != null) {
                        writeResponse(request, service.isOnline(
                                query("userId")
                        ));
                    } else {
                        writeResponse(request, false);
                    }
                    break;
                }
                case "getUser": {
                    if (service != null) {
                        writeResponse(request, service.getUser(
                                query("userId")
                        ));
                    } else {
                        writeResponse(request, null);
                    }
                    break;
                }
                case "getUsers": {
                    if (service != null) {
                        writeResponse(request, service.getUsers());
                    } else {
                        writeResponse(request, Collections.emptyList());
                    }
                    break;
                }
                case "getUsersByListening": {
                    if (service != null) {
                        writeResponse(request, service.getUsersByListening(
                                query("sseListenerName")
                        ));
                    } else {
                        writeResponse(request, Collections.emptyList());
                    }
                    break;
                }
                case "getUsersByTenantIdListening": {
                    if (service != null) {
                        writeResponse(request, service.getUsersByTenantIdListening(
                                query("tenantId"),
                                query("sseListenerName")
                        ));
                    } else {
                        writeResponse(request, Collections.emptyList());
                    }
                    break;
                }
                case "getUserIds": {
                    if (service != null) {
                        writeResponse(request, service.getUserIds(String.class));
                    } else {
                        writeResponse(request, Collections.emptyList());
                    }
                    break;
                }
                case "getUserIdsByListening": {
                    if (service != null) {
                        writeResponse(request, service.getUserIdsByListening(
                                query("sseListenerName"),
                                String.class
                        ));
                    } else {
                        writeResponse(request, Collections.emptyList());
                    }
                    break;
                }
                case "getUserIdsByTenantIdListening": {
                    if (service != null) {
                        writeResponse(request, service.getUserIdsByTenantIdListening(
                                query("tenantId"),
                                query("sseListenerName"),
                                String.class
                        ));
                    } else {
                        writeResponse(request, Collections.emptyList());
                    }
                    break;
                }
                case "getAccessTokens": {
                    if (service != null) {
                        writeResponse(request, service.getAccessTokens());
                    } else {
                        writeResponse(request, Collections.emptyList());
                    }
                    break;
                }
                case "getTenantIds": {
                    if (service != null) {
                        writeResponse(request, service.getTenantIds(
                                String.class
                        ));
                    } else {
                        writeResponse(request, Collections.emptyList());
                    }
                    break;
                }
                case "getConnectionDTOAll": {
                    if (service != null) {
                        writeResponse(request, service.getConnectionDTOAll());
                    } else {
                        writeResponse(request, Collections.emptyList());
                    }
                    break;
                }
                case "getConnectionDTOByUserId": {
                    if (service != null) {
                        writeResponse(request, service.getConnectionDTOByUserId(query("userId")));
                    } else {
                        writeResponse(request, Collections.emptyList());
                    }
                    break;
                }
                case "getChannels": {
                    if (service != null) {
                        writeResponse(request, service.getChannels());
                    } else {
                        writeResponse(request, Collections.emptyList());
                    }
                    break;
                }
                case "getAccessTokenCount": {
                    if (service != null) {
                        writeResponse(request, service.getAccessTokenCount());
                    } else {
                        writeResponse(request, 0);
                    }
                    break;
                }
                case "getUserCount": {
                    if (service != null) {
                        writeResponse(request, service.getUserCount());
                    } else {
                        writeResponse(request, 0);
                    }
                    break;
                }
                case "getConnectionCount": {
                    if (service != null) {
                        writeResponse(request, service.getConnectionCount());
                    } else {
                        writeResponse(request, 0);
                    }
                    break;
                }
                default: {
                    request.sendResponseHeaders(404, 0);
                    break;
                }
            }
        }
    }

    public static class MessageRepositoryHttpHandler extends AbstractHttpHandler {
        private final Supplier supplier;

        public MessageRepositoryHttpHandler(Supplier supplier) {
            this.supplier = supplier;
        }

        @Override
        public void handle0(HttpExchange request) throws IOException {
            MessageRepository service = supplier != null ? supplier.get() : null;
            if (service == null) {
                request.sendResponseHeaders(404, 0);
                return;
            }

            String rpcMethodName = getRpcMethodName();
            switch (rpcMethodName) {
                case "insert": {
                    writeResponse(request, service.insert(
                            new RemoteRequestMessage(request.getPrincipal(), body())
                    ), false);
                    break;
                }
                case "list": {
                    writeResponse(request, service.list(), false);
                    break;
                }
                case "select": {
                    writeResponse(request, service.select(
                            new RequestQuery()
                    ), false);
                    break;
                }
                case "delete": {
                    writeResponse(request, service.delete(
                            body("id", String.class)
                    ), false);
                    break;
                }
                default: {
                    request.sendResponseHeaders(404, 0);
                    break;
                }
            }
        }

        public static class RemoteRequestMessage extends AutoTypeBean implements Message {
            private final HttpPrincipal principal;

            private final String listenerName;
            private final Collection userIdList;
            private final Collection tenantIdList;
            private final Collection accessTokenList;
            private final Collection channelList;
            private final Object body;
            private final String eventName;
            private final String id;
            private final int filters;

            public RemoteRequestMessage(HttpPrincipal principal, Map body) {
                this.principal = principal;
                this.listenerName = body(body, "listenerName");
                this.userIdList = body(body, "userIdList");
                this.tenantIdList = body(body, "tenantIdList");
                this.accessTokenList = body(body, "accessTokenList");
                this.channelList = body(body, "channelList");
                this.body = body(body, "body");
                this.eventName = body(body, "eventName");
                this.id = body(body, "id");
                this.filters = body(body, "filters");
                setArrayClassName(body(body, "arrayClassName"));
                setObjectClassName(body(body, "objectClassName"));
            }

            public static  T body(Map body, String name) {
                return (T) body.get(name);
            }

            @Override
            public String getListenerName() {
                return listenerName;
            }

            @Override
            public Collection getUserIdList() {
                return userIdList;
            }

            @Override
            public Collection getTenantIdList() {
                return tenantIdList;
            }

            @Override
            public Collection getAccessTokenList() {
                return accessTokenList;
            }

            @Override
            public Collection getChannelList() {
                return channelList;
            }

            @Override
            public Object getBody() {
                return body;
            }

            @Override
            public String getEventName() {
                return eventName;
            }

            @Override
            public String getId() {
                return id;
            }

            @Override
            public int getFilters() {
                return filters;
            }
        }

        public class RequestQuery implements MessageRepository.Query {
            private Set listeners;

            @Override
            public Serializable getTenantId() {
                return body("tenantId");
            }

            @Override
            public String getChannel() {
                return body("channel");
            }

            @Override
            public String getAccessToken() {
                return body("accessToken");
            }

            @Override
            public Serializable getUserId() {
                return body("userId");
            }

            @Override
            public Set getListeners() {
                if (this.listeners == null) {
                    Object listeners = body("listeners");
                    if (listeners instanceof Set) {
                        this.listeners = (Set) listeners;
                    } else if (listeners instanceof Collection) {
                        this.listeners = new HashSet((Collection) listeners);
                    } else {
                        return null;
                    }
                }
                return this.listeners;
            }
        }

    }

    public static class AuthorizationHeaderAuthenticator extends Authenticator {
        private final Supplier supplier;

        public AuthorizationHeaderAuthenticator(Supplier supplier) {
            this.supplier = supplier;
        }

        @Override
        public Result authenticate(HttpExchange exchange) {
            ServiceDiscoveryService service = supplier != null ? supplier.get() : null;
            if (service == null) {
                return new Failure(401);
            }
            String authorization = exchange.getRequestHeaders().getFirst("Authorization");
            HttpPrincipal principal = service.login(authorization);
            if (principal != null) {
                return new Success(principal);
            } else {
                return new Failure(401);
            }
        }
    }

    public static class ErrorPageFilter extends Filter {
        @Override
        public void doFilter(HttpExchange request, Chain chain) throws IOException {
            try {
                chain.doFilter(request);
            } catch (Throwable e) {
                byte[] body = ("

500 Internal Server Error

" + e).getBytes(UTF_8); if (request.getResponseCode() == -1) { request.getResponseHeaders().add("Content-Type", "text/html; charset=UTF-8"); request.sendResponseHeaders(500, body.length); OutputStream responseBody = request.getResponseBody(); responseBody.write(body); responseBody.close(); } } } @Override public String description() { return getClass().getSimpleName(); } } public static abstract class AbstractHttpHandler implements HttpHandler { private final ObjectMapper objectMapper = new ObjectMapper(); private final ThreadLocal REQUEST_THREAD_LOCAL = new ThreadLocal<>(); private final ThreadLocal BODY_THREAD_LOCAL = new ThreadLocal<>(); private static boolean isKeepAlive(HttpExchange request) { String connection = request.getRequestHeaders().getFirst("Connection"); String protocol = request.getProtocol(); if (protocol.endsWith("1.0")) { return "keep-alive".equalsIgnoreCase(connection); } else { //不包含close就是保持 return !"close".equalsIgnoreCase(connection); } } protected String query(String name) { return WebUtil.getQueryParam(REQUEST_THREAD_LOCAL.get().getRequestURI().getQuery(), name); } protected Map body() { return BODY_THREAD_LOCAL.get(); } protected T body(String name, Class type) { Object body = body(name); return TypeUtil.cast(body, type); } protected T body(String name) { Map body = BODY_THREAD_LOCAL.get(); if (body == null) { return null; } return (T) body.get(name); } public String getRpcMethodName() { HttpExchange request = REQUEST_THREAD_LOCAL.get(); return request.getRequestURI().getPath().substring(request.getHttpContext().getPath().length()); } @Override public final void handle(HttpExchange request) throws IOException { Headers requestHeaders = request.getRequestHeaders(); String contentLength = requestHeaders.getFirst("content-length"); if (contentLength != null && Long.parseLong(contentLength) > 0) { String contentType = requestHeaders.getFirst("content-type"); if (contentType != null && contentType.startsWith("application/json")) { Map body = objectMapper.readValue(request.getRequestBody(), Map.class); BODY_THREAD_LOCAL.set(body); } } try { REQUEST_THREAD_LOCAL.set(request); handle0(request); } finally { REQUEST_THREAD_LOCAL.remove(); BODY_THREAD_LOCAL.remove(); } } public void handle0(HttpExchange httpExchange) throws IOException { } protected void writeResponse(HttpExchange request, Object data) throws IOException { writeResponse(request, data, true); } protected void writeResponse(HttpExchange request, Object data, boolean autoType) throws IOException { request.getResponseHeaders().set("Content-Type", "application/json; charset=UTF-8"); if (isKeepAlive(request)) { request.getResponseHeaders().set("Connection", "keep-alive"); } else { request.getResponseHeaders().set("Connection", "close"); } Response body = new Response(); body.setData(data); if (autoType) { body.retainClassName(data); } request.sendResponseHeaders(200, 0L); OutputStream out = request.getResponseBody(); objectMapper.writeValue(out, body); out.close(); } } public static class Response extends AutoTypeBean { private T data; public T getData() { return data; } public void setData(T data) { this.data = data; } } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy