Please wait. This can take a few minutes...
Downloading a project requires considerable resources. Please understand that we have to cover our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download/modify it as often as you want.
com.clianz.spur.SpurServer Maven / Gradle / Ivy
/*
* Copyright 2017 Ian Chan
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.clianz.spur;
import static com.clianz.spur.helpers.HttpMethods.DELETE;
import static com.clianz.spur.helpers.HttpMethods.GET;
import static com.clianz.spur.helpers.HttpMethods.HEAD;
import static com.clianz.spur.helpers.HttpMethods.OPTIONS;
import static com.clianz.spur.helpers.HttpMethods.PATCH;
import static com.clianz.spur.helpers.HttpMethods.POST;
import static com.clianz.spur.helpers.HttpMethods.PUT;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeSet;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import javax.net.ssl.SSLContext;
import com.clianz.spur.helpers.BasicAuthHandler;
import com.clianz.spur.helpers.CorsHandler;
import com.clianz.spur.helpers.Endpoint;
import com.clianz.spur.helpers.RedirectHttpsHandler;
import com.clianz.spur.helpers.RequestFilter;
import com.clianz.spur.helpers.WebSocketHandler;
import com.clianz.spur.helpers.WebSocketMessageSender;
import com.clianz.spur.helpers.WebSocketOnConnect;
import com.clianz.spur.helpers.WebSocketOnMessage;
import org.jboss.logging.Logger;
import io.undertow.Handlers;
import io.undertow.Undertow;
import io.undertow.UndertowOptions;
import io.undertow.predicate.Predicates;
import io.undertow.server.HttpHandler;
import io.undertow.server.HttpServerExchange;
import io.undertow.server.handlers.PathTemplateHandler;
import io.undertow.server.handlers.encoding.ContentEncodingRepository;
import io.undertow.server.handlers.encoding.EncodingHandler;
import io.undertow.server.handlers.encoding.GzipEncodingProvider;
import io.undertow.server.handlers.sse.ServerSentEventConnection;
import io.undertow.server.handlers.sse.ServerSentEventHandler;
import io.undertow.util.Headers;
import io.undertow.util.HttpString;
import io.undertow.util.StatusCodes;
import io.undertow.websockets.core.AbstractReceiveListener;
import io.undertow.websockets.core.BufferedTextMessage;
import io.undertow.websockets.core.WebSocketChannel;
import io.undertow.websockets.core.WebSockets;
public class SpurServer {
private static final Logger LOGGER = Logger.getLogger(SpurServer.class.getName());
private static final String SERVER_ALREADY_DEFINED = "Server already defined.";
private static final HttpString ACCESS_CONTROL_REQUEST_METHOD = new HttpString("Access-Control-Request-Method");
private static final HttpString ACCESS_CONTROL_ALLOW_METHOD = new HttpString("Access-Control-Allow-Methods");
private Undertow server = null;
private Boolean serverLock = true;
private AtomicBoolean serviceDefined = new AtomicBoolean(false);
private Map> endpointsMap = new HashMap<>();
private Map> webSocketChannelsMap = new HashMap<>();
private Map webSocketHandlerMap = new HashMap<>();
private Map sseHandlerMap = new HashMap<>();
private List requestFilters = new ArrayList<>();
public void start() {
startServer(Undertow.builder(), new SpurOptions());
}
public void start(SpurOptions options) {
startServer(Undertow.builder(), options);
}
public void start(SpurOptions options, Undertow.Builder builder) {
startServer(builder, options);
}
public void stop() {
synchronized (serverLock) {
if (server != null) {
server.stop();
LOGGER.info("Server stopped.");
} else {
throw new IllegalStateException("Server did not exist.");
}
}
}
public void resume() {
synchronized (serverLock) {
if (server != null) {
server.start();
LOGGER.info("Server resumed.");
} else {
throw new IllegalStateException("Server did not exist.");
}
}
}
public SpurServer get(String path, BiConsumer, Res> reqRes) {
return setPathHandler(GET, path, reqRes, null);
}
public SpurServer put(String path, Class requestBodyClass, BiConsumer, Res> reqRes) {
return setPathHandler(PUT, path, reqRes, requestBodyClass);
}
public SpurServer post(String path, Class requestBodyClass, BiConsumer, Res> reqRes) {
return setPathHandler(POST, path, reqRes, requestBodyClass);
}
public SpurServer patch(String path, Class requestBodyClass, BiConsumer, Res> reqRes) {
return setPathHandler(PATCH, path, reqRes, requestBodyClass);
}
public SpurServer delete(String path, BiConsumer, Res> reqRes) {
return setPathHandler(DELETE, path, reqRes, null);
}
public SpurServer preFilterRequests(Predicate assertion, Consumer failureHandler) {
requestFilters.add(new RequestFilter(httpServerExchange -> assertion.test(new Req(httpServerExchange, null)),
httpServerExchange -> failureHandler.accept(new Res(httpServerExchange))));
return this;
}
public SpurServer schedule(long intervalSeconds, Runnable runnable) {
ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1);
scheduler.scheduleAtFixedRate(runnable, 0, intervalSeconds, TimeUnit.SECONDS);
return this;
}
public SpurServer sse(String path) {
sseHandlerMap.put(path, Handlers.serverSentEvents());
return this;
}
public SpurServer websocket(String pathPrefix, WebSocketOnConnect webSocketOnConnect, WebSocketOnMessage webSocketOnMessage) {
webSocketHandlerMap.put(pathPrefix, new WebSocketHandler(pathPrefix, webSocketOnConnect, webSocketOnMessage));
return this;
}
public void broadcastWebsockets(String websocketPath, String msg) {
Set webSocketChannels = webSocketChannelsMap.get(websocketPath);
if (webSocketChannels != null) {
new ArrayList<>(webSocketChannels).forEach(webSocketChannel -> WebSockets.sendText(msg, webSocketChannel, null));
}
}
public void broadcastWebsockets(String websocketPath, String msg, String channelAttributeKey,
Predicate channelAttributeValueTest) {
Set webSocketChannels = webSocketChannelsMap.get(websocketPath);
if (webSocketChannels != null) {
new ArrayList<>(webSocketChannels).stream()
.filter(webSocketChannel -> channelAttributeValueTest.test(webSocketChannel.getAttribute(channelAttributeKey)))
.forEach(webSocketChannel -> WebSockets.sendText(msg, webSocketChannel, null));
}
}
public void broadcastSse(String path, String data) {
ServerSentEventHandler serverSentEventHandler = sseHandlerMap.get(path);
if (serverSentEventHandler != null) {
serverSentEventHandler.getConnections()
.forEach(serverSentEventConnection -> serverSentEventConnection.send(data));
}
}
public void broadcastSse(String path, Consumer action) {
ServerSentEventHandler serverSentEventHandler = sseHandlerMap.get(path);
if (serverSentEventHandler != null) {
serverSentEventHandler.getConnections()
.forEach(action::accept);
}
}
private void startServer(Undertow.Builder builder, SpurOptions options) {
if (serviceDefined.getAndSet(true)) {
throw new IllegalStateException(SERVER_ALREADY_DEFINED);
}
LOGGER.info("Listening to " + options.host + ":" + options.port);
if (options.httpsEnabled) {
SSLContext sslContext = null;
try {
sslContext = options.getSslContext();
} catch (Exception e) {
LOGGER.error("Can not create SSL context from properties: KEYSTORE=" + options.keystorePath + " TRUSTSTORE="
+ options.truststorePath, e);
throw new RuntimeException("Can not create SSL context from properties: KEYSTORE=" + options.keystorePath + " TRUSTSTORE="
+ options.truststorePath);
}
if (sslContext == null) {
throw new IllegalArgumentException(("HTTPS/SSL context must be configured when HTTPS is enabled"));
} else if (options.httpsPort == 0) {
throw new IllegalArgumentException(("HTTPS port must be configured when HTTPS is enabled"));
}
builder = builder.addHttpsListener(options.httpsPort, options.host, sslContext);
LOGGER.info("HTTPS Enabled");
}
synchronized (serverLock) {
server = builder.addHttpListener(options.port, options.host)
.setServerOption(UndertowOptions.REQUEST_PARSE_TIMEOUT, options.requestParseTimeOut)
.setServerOption(UndertowOptions.ENABLE_HTTP2, options.http2Enabled)
.setServerOption(UndertowOptions.MAX_ENTITY_SIZE, options.maxEntitySize)
.setHandler(getHandlers(options))
.build();
server.start();
}
}
private SpurServer setPathHandler(HttpString method, String path, BiConsumer, Res> reqRes, Class classType) {
if (serviceDefined.get()) {
throw new IllegalStateException(SERVER_ALREADY_DEFINED);
}
endpointsMap.putIfAbsent(path, new HashMap<>());
endpointsMap.get(path)
.put(method, new Endpoint(method, path, reqRes, classType));
return this;
}
private HttpHandler getHandlers(SpurOptions options) {
// Path handler
PathTemplateHandler pathTemplateHandler = Handlers.pathTemplate();
endpointsMap.forEach((path, methodEndpointMap) -> pathTemplateHandler.add(path, new HttpHandler() {
public void handleRequest(HttpServerExchange exchange) throws Exception {
// non-blocking
if (options.blockableHandlersEnabled && exchange.isInIoThread()) {
// LOGGER.info("Is in IO thread, dispatching for blockableHandlersEnabled...");
exchange.dispatch(this);
return;
}
// handler code
// LOGGER.info("STARTING Async");
invokePathTemplateHandler(options, methodEndpointMap, exchange);
}
}));
// Websocket
if (!webSocketHandlerMap.isEmpty()) {
webSocketHandlerMap.forEach((pathPrefix, webSocketHandler) -> {
LOGGER.info("Adding WS for path: " + webSocketHandler.getPath());
addWebSocketHandler(pathTemplateHandler, webSocketHandler);
});
}
// SSE
if (!sseHandlerMap.isEmpty()) {
sseHandlerMap.forEach(pathTemplateHandler::add);
}
// CORS
HttpHandler httpHandler = Handlers.predicate(exchange -> isValidCorsOrigin(options, getRequestHeader(exchange, Headers.ORIGIN)),
new CorsHandler(pathTemplateHandler), pathTemplateHandler);
// Gzip
if (options.gzipEnabled) {
httpHandler = new EncodingHandler(new ContentEncodingRepository().addEncodingHandler("gzip", new GzipEncodingProvider(), 50,
Predicates.maxContentSize(options.gzipMaxSize))).setNext(httpHandler);
}
// Custom handlers
for (RequestFilter requestFilter : requestFilters) {
httpHandler = Handlers.predicate(requestFilter.getAssertionPredicate(), httpHandler, requestFilter.getFailureHandler());
}
// Graceful shutdown
httpHandler = Handlers.gracefulShutdown(httpHandler);
// Basic auth
if (!options.basicAuthUser.isEmpty() && !options.basicAuthPassword.isEmpty()) {
httpHandler = new BasicAuthHandler(httpHandler, options.basicAuthUser, options.basicAuthPassword);
}
// Force HTTPS
if (options.forceHttps) {
httpHandler = Handlers.predicate(Predicates.secure(), httpHandler, new RedirectHttpsHandler());
}
return httpHandler;
}
private void addWebSocketHandler(PathTemplateHandler pathTemplateHandler, WebSocketHandler webSocketHandler) {
// TODO: Add security/auth
pathTemplateHandler.add(webSocketHandler.getPath(), Handlers.websocket((exchange, channel) -> {
webSocketChannelsMap.putIfAbsent(webSocketHandler.getPath(), channel.getPeerConnections());
// double randomKey = Math.random();
// LOGGER.info("Setting secret: " + randomKey);
// channel.setAttribute("myKey", "secret key: " + randomKey);
// broadcastToAllWebsockets("A new user connected!");
WebSocketMessageSender sender = new WebSocketMessageSender(channel);
webSocketHandler.getWebSocketOnConnect()
.onConnect(sender);
channel.getReceiveSetter()
.set(new AbstractReceiveListener() {
@Override
protected void onFullTextMessage(WebSocketChannel channel, BufferedTextMessage message) {
// LOGGER.info("Key was: " + channel.getAttribute("myKey"));
webSocketHandler.getWebSocketOnMessage()
.onMessage(message.getData(), sender);
}
});
channel.resumeReceives();
}));
}
private void invokePathTemplateHandler(SpurOptions options, Map methodEndpointsMap, HttpServerExchange exchange) {
HttpString requestMethod = exchange.getRequestMethod();
String requestAccessControlRequestMethod = getRequestHeader(exchange, ACCESS_CONTROL_REQUEST_METHOD);
String requestOrigin = getRequestHeader(exchange, Headers.ORIGIN);
if (requestMethod.equals(HEAD)) {
requestMethod = GET;
} else if (requestMethod.equals(OPTIONS) && requestAccessControlRequestMethod != null && methodEndpointsMap.containsKey(
new HttpString(requestAccessControlRequestMethod)) && isValidCorsOrigin(options, requestOrigin)) {
setCorsMethodHeader(options, methodEndpointsMap, exchange);
exchange.endExchange();
return;
}
Endpoint endpoint = methodEndpointsMap.get(requestMethod);
if (endpoint == null) {
exchange.setStatusCode(StatusCodes.METHOD_NOT_ALLOWED);
exchange.getResponseHeaders()
.put(Headers.ALLOW, getAllowedMethods(methodEndpointsMap, options));
exchange.endExchange();
return;
}
Req req = new Req(exchange, endpoint.getBodyClassType());
req.parseBody((newExchange, body) -> endpoint.getReqResBiConsumer()
.accept(req, new Res(newExchange)));
}
private String getRequestHeader(HttpServerExchange exchange, HttpString headerName) {
if (exchange.getRequestHeaders()
.contains(headerName)) {
return exchange.getRequestHeaders()
.get(headerName)
.getFirst();
}
return null;
}
private boolean isValidCorsOrigin(SpurOptions options, String requestOrigin) {
return options.corsHeaders.contains("*") || options.corsHeaders.contains(requestOrigin);
}
private void setCorsMethodHeader(SpurOptions options, Map methodEndpointMap, HttpServerExchange exchange) {
exchange.getResponseHeaders()
.put(ACCESS_CONTROL_ALLOW_METHOD, getAllowedMethods(methodEndpointMap, options));
}
private String getAllowedMethods(Map methodEndpointMap, SpurOptions options) {
StringBuilder methodsAllowed = new StringBuilder();
Set methodsDefined = new TreeSet<>(methodEndpointMap.keySet());
if (methodsDefined.contains(GET)) {
methodsDefined.add(HEAD);
}
if (!options.corsHeaders.isEmpty()) {
methodsDefined.add(OPTIONS);
}
methodsDefined.forEach(httpString -> methodsAllowed.append(", " + httpString));
return methodsAllowed.toString()
.substring(2);
}
}