io.undertow.websockets.jsr.ServerWebSocketContainer Maven / Gradle / Ivy
/*
* JBoss, Home of Professional Open Source.
* Copyright 2014 Red Hat, Inc., and individual contributors
* as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.undertow.websockets.jsr;
import static java.lang.System.currentTimeMillis;
import java.io.Closeable;
import java.io.IOException;
import java.io.InterruptedIOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.nio.channels.ClosedChannelException;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import javax.net.ssl.SSLContext;
import javax.servlet.DispatcherType;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.websocket.ClientEndpoint;
import javax.websocket.ClientEndpointConfig;
import javax.websocket.CloseReason;
import javax.websocket.DeploymentException;
import javax.websocket.Endpoint;
import javax.websocket.Extension;
import javax.websocket.HandshakeResponse;
import javax.websocket.Session;
import javax.websocket.server.ServerContainer;
import javax.websocket.server.ServerEndpoint;
import javax.websocket.server.ServerEndpointConfig;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.EventLoopGroup;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionData;
import io.netty.handler.codec.http.websocketx.extensions.WebSocketServerExtensionHandshaker;
import io.undertow.httpcore.StatusCodes;
import io.undertow.server.HttpServerExchange;
import io.undertow.servlet.api.ClassIntrospecter;
import io.undertow.servlet.api.InstanceFactory;
import io.undertow.servlet.api.InstanceHandle;
import io.undertow.servlet.api.ThreadSetupHandler;
import io.undertow.servlet.spec.ServletContextImpl;
import io.undertow.servlet.util.ConstructorInstanceFactory;
import io.undertow.servlet.util.ImmediateInstanceHandle;
import io.undertow.util.CopyOnWriteMap;
import io.undertow.util.PathTemplate;
import io.undertow.websockets.jsr.annotated.AnnotatedEndpointFactory;
import io.undertow.websockets.jsr.handshake.Handshake;
import io.undertow.websockets.jsr.handshake.HandshakeUtil;
/**
* {@link ServerContainer} implementation that allows endpoints to be deployed on a server.
*
* @author Norman Maurer
*/
public class ServerWebSocketContainer implements ServerContainer, Closeable {
/** Key of a {@link ClientEndpointConfig} user property holding a {@link Number} connect timeout, interpreted as seconds. */
public static final String TIMEOUT = "io.undertow.websocket.CONNECT_TIMEOUT";
/** Connect timeout, in seconds, used when the {@link #TIMEOUT} user property is absent. */
public static final int DEFAULT_WEB_SOCKET_TIMEOUT_SECONDS = 10;
/** Max frame size used by the constructors that do not take one explicitly (unit presumably bytes; TODO confirm). */
public static final int DEFAULT_MAX_FRAME_SIZE = 65536;
// NOTE(review): the generic type arguments on several declarations below (e.g. "Map, ConfiguredClientEndpoint>",
// "Set> annotatedEndpointClasses" and the raw Supplier/List fields) appear to have been lost when this listing
// was extracted; restore them from the upstream source before compiling.
private final ClassIntrospecter classIntrospecter;
// Configured client endpoints keyed by endpoint class; entries are added lazily when a client session is created.
private final Map, ConfiguredClientEndpoint> clientEndpoints = new CopyOnWriteMap<>();
private final List configuredServerEndpoints = new ArrayList<>();
private final Set> annotatedEndpointClasses = new HashSet<>();
/**
* Set of all deployed server endpoint paths. Because of the set's comparison function, overlapping
* paths can be detected.
*/
private final TreeSet seenPaths = new TreeSet<>();
private final boolean dispatchToWorker;
// Local address that client connections are bound to (passed to WebsocketConnectionBuilder.setBindAddress); may be null.
private final InetSocketAddress clientBindAddress;
private final WebSocketReconnectHandler webSocketReconnectHandler;
// Supplies the event loop handed to SSL providers and to client WebsocketConnectionBuilders (presumably an EventLoopGroup).
private final Supplier eventLoopSupplier;
// Supplies the executor used by client sessions and for dispatching onOpen. The constructors that take no executor
// supplier pass null here. NOTE(review): client connects would then fail on executorSupplier.get(); confirm intended.
private final Supplier executorSupplier;
// Container-wide defaults; volatile so that updates are visible across threads.
private volatile long defaultAsyncSendTimeout;
private volatile long defaultMaxSessionIdleTimeout;
private volatile int defaultMaxBinaryMessageBufferSize;
private volatile int defaultMaxTextMessageBufferSize;
private volatile boolean deploymentComplete = false;
private final List deploymentExceptions = new ArrayList<>();
private ServletContextImpl contextToAddFilter = null;
private final List pauseListeners = new ArrayList<>();
// Defensive copy of the extensions passed to the constructor.
private final List installedExtensions;
// WebsocketClientSslProvider services discovered via ServiceLoader at construction time (unmodifiable).
private final List clientSslProviders;
private final int maxFrameSize;
// Action chain built from the ThreadSetupHandlers passed to the constructor; the innermost step just runs the task.
private final ThreadSetupHandler.Action invokeEndpointTask;
// When true, the connectToServer methods fail fast with ClosedChannelException.
private volatile boolean closed = false;
/**
 * Creates a container that discovers {@link WebsocketClientSslProvider}s through this class's own class loader
 * and has neither a client bind address nor a reconnect handler.
 * <p>
 * NOTE(review): {@code clientMode} is not used. This path also configures no executor supplier, because
 * {@code null} is passed down the constructor chain. Confirm that client connections are not expected from
 * containers built this way.
 *
 * @param classIntrospecter   used to create endpoint and encoder/decoder instances
 * @param eventLoopSupplier   supplier of the event loop used for client connections
 * @param threadSetupHandlers handlers wrapped around endpoint invocations
 * @param dispatchToWorker    stored on the container as-is
 * @param clientMode          ignored
 */
public ServerWebSocketContainer(final ClassIntrospecter classIntrospecter, final Supplier eventLoopSupplier, List threadSetupHandlers, boolean dispatchToWorker, boolean clientMode) {
    this(classIntrospecter,
            ServerWebSocketContainer.class.getClassLoader(),
            eventLoopSupplier,
            threadSetupHandlers,
            dispatchToWorker,
            (InetSocketAddress) null,
            (WebSocketReconnectHandler) null);
}
/**
 * Creates a container with an explicit class loader and executor supplier, but without a client bind address,
 * a reconnect handler or any installed extensions.
 *
 * @param classIntrospecter   used to create endpoint and encoder/decoder instances
 * @param classLoader         class loader used to discover {@link WebsocketClientSslProvider} services
 * @param eventLoopSupplier   supplier of the event loop used for client connections
 * @param threadSetupHandlers handlers wrapped around endpoint invocations
 * @param dispatchToWorker    stored on the container as-is
 * @param executorSupplier    supplier of the executor used by client sessions
 */
public ServerWebSocketContainer(final ClassIntrospecter classIntrospecter, final ClassLoader classLoader, Supplier eventLoopSupplier, List threadSetupHandlers, boolean dispatchToWorker, Supplier executorSupplier) {
    this(classIntrospecter,
            classLoader,
            eventLoopSupplier,
            threadSetupHandlers,
            dispatchToWorker,
            (InetSocketAddress) null,
            (WebSocketReconnectHandler) null,
            executorSupplier,
            Collections.emptyList());
}
/**
 * Creates a container with a client bind address and reconnect handler, but without an executor supplier
 * or installed extensions.
 * <p>
 * NOTE(review): the executor supplier is {@code null} on this path; confirm that client connections are not
 * expected from containers built this way.
 *
 * @param classIntrospecter   used to create endpoint and encoder/decoder instances
 * @param classLoader         class loader used to discover {@link WebsocketClientSslProvider} services
 * @param eventLoopSupplier   supplier of the event loop used for client connections
 * @param threadSetupHandlers handlers wrapped around endpoint invocations
 * @param dispatchToWorker    stored on the container as-is
 * @param clientBindAddress   local address client connections bind to, or {@code null}
 * @param reconnectHandler    reconnect handler, or {@code null}
 */
public ServerWebSocketContainer(final ClassIntrospecter classIntrospecter, final ClassLoader classLoader, Supplier eventLoopSupplier, List threadSetupHandlers, boolean dispatchToWorker, InetSocketAddress clientBindAddress, WebSocketReconnectHandler reconnectHandler) {
    this(classIntrospecter,
            classLoader,
            eventLoopSupplier,
            threadSetupHandlers,
            dispatchToWorker,
            clientBindAddress,
            reconnectHandler,
            null,
            Collections.emptyList());
}
/**
 * Creates a container using {@link #DEFAULT_MAX_FRAME_SIZE} as its maximum frame size.
 *
 * @param classIntrospecter   used to create endpoint and encoder/decoder instances
 * @param classLoader         class loader used to discover {@link WebsocketClientSslProvider} services
 * @param eventLoopSupplier   supplier of the event loop used for client connections
 * @param threadSetupHandlers handlers wrapped around endpoint invocations
 * @param dispatchToWorker    stored on the container as-is
 * @param clientBindAddress   local address client connections bind to, or {@code null}
 * @param reconnectHandler    reconnect handler, or {@code null}
 * @param executorSupplier    supplier of the executor used by client sessions, or {@code null}
 * @param installedExtensions extensions installed on the container (copied)
 */
public ServerWebSocketContainer(final ClassIntrospecter classIntrospecter, final ClassLoader classLoader, Supplier eventLoopSupplier, List threadSetupHandlers, boolean dispatchToWorker, InetSocketAddress clientBindAddress, WebSocketReconnectHandler reconnectHandler, Supplier executorSupplier, List installedExtensions) {
    this(classIntrospecter, classLoader, eventLoopSupplier, threadSetupHandlers, dispatchToWorker,
            clientBindAddress, reconnectHandler, executorSupplier, installedExtensions,
            ServerWebSocketContainer.DEFAULT_MAX_FRAME_SIZE);
}
/**
* Creates a fully configured container. The other constructors shown here delegate to this one, directly or
* indirectly.
*
* @param classIntrospecter used to create endpoint and encoder/decoder instances
* @param classLoader class loader used to discover {@link WebsocketClientSslProvider} services via {@link ServiceLoader}
* @param eventLoopSupplier supplier of the event loop used for client connections and handed to SSL providers
* @param threadSetupHandlers handlers applied in list order; each one wraps the action stored in {@code invokeEndpointTask}
* @param dispatchToWorker stored on the container as-is
* @param clientBindAddress local address client connections bind to, or {@code null}
* @param reconnectHandler reconnect handler, or {@code null}
* @param executorSupplier supplier of the executor used by client sessions; may be {@code null} (see the field note)
* @param installedExtensions extensions installed on the container; copied defensively
* @param maxFrameSize maximum frame size stored on the container
*/
// NOTE(review): the raw Supplier/List parameter types look like extraction damage (the enhanced for-loop over
// threadSetupHandlers needs List<ThreadSetupHandler>); restore the generics from the upstream source.
public ServerWebSocketContainer(final ClassIntrospecter classIntrospecter, final ClassLoader classLoader, Supplier eventLoopSupplier, List threadSetupHandlers, boolean dispatchToWorker, InetSocketAddress clientBindAddress, WebSocketReconnectHandler reconnectHandler, Supplier executorSupplier, List installedExtensions, int maxFrameSize) {
this.classIntrospecter = classIntrospecter;
this.eventLoopSupplier = eventLoopSupplier;
this.dispatchToWorker = dispatchToWorker;
this.clientBindAddress = clientBindAddress;
this.executorSupplier = executorSupplier;
this.installedExtensions = new ArrayList<>(installedExtensions);
this.webSocketReconnectHandler = reconnectHandler;
this.maxFrameSize = maxFrameSize;
// Innermost action: simply run the supplied callback on the calling thread.
ThreadSetupHandler.Action task = new ThreadSetupHandler.Action() {
@Override
public Void call(HttpServerExchange exchange, Runnable context) throws Exception {
context.run();
return null;
}
};
// Discover client SSL providers visible to the given class loader; the result is exposed read-only.
List clientSslProviders = new ArrayList<>();
for (WebsocketClientSslProvider provider : ServiceLoader.load(WebsocketClientSslProvider.class, classLoader)) {
clientSslProviders.add(provider);
}
this.clientSslProviders = Collections.unmodifiableList(clientSslProviders);
// Each handler's create() receives the action built so far, so later handlers presumably wrap earlier ones.
for (ThreadSetupHandler handler : threadSetupHandlers) {
task = handler.create(task);
}
this.invokeEndpointTask = task;
}
/**
 * Returns the container-wide default timeout applied to asynchronous sends.
 *
 * @return the current default async send timeout (milliseconds, per the JSR-356 contract)
 */
@Override
public long getDefaultAsyncSendTimeout() {
    final long current = this.defaultAsyncSendTimeout;
    return current;
}
/**
 * Replaces the container-wide default timeout applied to asynchronous sends.
 *
 * @param defaultAsyncSendTimeout the new default (milliseconds, per the JSR-356 contract)
 */
@Override
public void setAsyncSendTimeout(long defaultAsyncSendTimeout) {
    // single volatile write; readers observe the new value without further synchronization
    this.defaultAsyncSendTimeout = defaultAsyncSendTimeout;
}
/**
 * Connects an already-instantiated annotated client endpoint using a caller-supplied connection builder.
 *
 * @param annotatedEndpointInstance the annotated endpoint object that will receive the callbacks
 * @param connectionBuilder         builder describing the connection to open
 * @return the opened session
 * @throws ClosedChannelException if this container has been closed
 * @throws DeploymentException    if the instance's class is not a valid client endpoint
 * @throws IOException            if the connection cannot be established
 */
public Session connectToServer(final Object annotatedEndpointInstance, WebsocketConnectionBuilder connectionBuilder) throws DeploymentException, IOException {
    if (closed) {
        throw new ClosedChannelException();
    }
    final Class<?> endpointType = annotatedEndpointInstance.getClass();
    final ConfiguredClientEndpoint configured = getClientEndpoint(endpointType, false);
    if (configured == null) {
        throw JsrWebSocketMessages.MESSAGES.notAValidClientEndpointType(endpointType);
    }
    final Endpoint endpoint = configured.getFactory().createInstance(new ImmediateInstanceHandle<>(annotatedEndpointInstance));
    return connectToServerInternal(endpoint, configured, connectionBuilder);
}
/**
 * Connects an already-instantiated annotated client endpoint to {@code path}.
 * <p>
 * For {@code wss} URIs the registered {@link WebsocketClientSslProvider}s are consulted in order. The first
 * non-null {@link SSLContext} wins; if none supplies one, {@link SSLContext#getDefault()} is tried.
 *
 * @param annotatedEndpointInstance the annotated endpoint object that will receive the callbacks
 * @param path                      the server URI
 * @return the opened session
 * @throws ClosedChannelException if this container has been closed
 * @throws DeploymentException    if the instance's class is not a valid client endpoint
 * @throws IOException            if the connection cannot be established
 */
@Override
public Session connectToServer(final Object annotatedEndpointInstance, final URI path) throws DeploymentException, IOException {
    if (closed) {
        throw new ClosedChannelException();
    }
    final Class<?> endpointType = annotatedEndpointInstance.getClass();
    final ConfiguredClientEndpoint configured = getClientEndpoint(endpointType, false);
    if (configured == null) {
        throw JsrWebSocketMessages.MESSAGES.notAValidClientEndpointType(endpointType);
    }
    final Endpoint endpoint = configured.getFactory().createInstance(new ImmediateInstanceHandle<>(annotatedEndpointInstance));

    SSLContext sslContext = null;
    if (path.getScheme().equals("wss")) {
        // Ask each provider in turn until one of them supplies a context.
        for (int i = 0; sslContext == null && i < clientSslProviders.size(); i++) {
            sslContext = clientSslProviders.get(i).getSsl(eventLoopSupplier.get(), annotatedEndpointInstance, path);
        }
        if (sslContext == null) {
            try {
                sslContext = SSLContext.getDefault();
            } catch (NoSuchAlgorithmException ignored) {
                // Best effort: continue without an SSL context.
            }
        }
    }
    return connectToServerInternal(endpoint, sslContext, configured, path);
}
public Session connectToServer(Class> aClass, WebsocketConnectionBuilder connectionBuilder) throws DeploymentException, IOException {
if (closed) {
throw new ClosedChannelException();
}
ConfiguredClientEndpoint config = getClientEndpoint(aClass, true);
if (config == null) {
throw JsrWebSocketMessages.MESSAGES.notAValidClientEndpointType(aClass);
}
try {
AnnotatedEndpointFactory factory = config.getFactory();
InstanceHandle> instance = config.getInstanceFactory().createInstance();
return connectToServerInternal(factory.createInstance(instance), config, connectionBuilder);
} catch (InstantiationException e) {
throw new RuntimeException(e);
}
}
@Override
public Session connectToServer(Class> aClass, URI uri) throws DeploymentException, IOException {
if (closed) {
throw new ClosedChannelException();
}
ConfiguredClientEndpoint config = getClientEndpoint(aClass, true);
if (config == null) {
throw JsrWebSocketMessages.MESSAGES.notAValidClientEndpointType(aClass);
}
try {
AnnotatedEndpointFactory factory = config.getFactory();
InstanceHandle> instance = config.getInstanceFactory().createInstance();
SSLContext ssl = null;
if (uri.getScheme().equals("wss")) {
for (WebsocketClientSslProvider provider : clientSslProviders) {
ssl = provider.getSsl(eventLoopSupplier.get(), aClass, uri);
if (ssl != null) {
break;
}
}
if (ssl == null) {
try {
ssl = SSLContext.getDefault();
} catch (NoSuchAlgorithmException e) {
//ignore
}
}
}
return connectToServerInternal(factory.createInstance(instance), ssl, config, uri);
} catch (InstantiationException e) {
throw new RuntimeException(e);
}
}
/**
 * Connects a programmatic endpoint instance to {@code path}.
 * <p>
 * For {@code wss} URIs the registered {@link WebsocketClientSslProvider}s are asked in order and the first
 * non-null {@link SSLContext} is used, falling back to {@link SSLContext#getDefault()}. The handshake offers
 * the preferred subprotocols and extensions of the (possibly defaulted) endpoint config. The actual connect is
 * delegated to the {@link WebsocketConnectionBuilder} overload.
 *
 * @param endpointInstance the endpoint that will receive the session callbacks
 * @param config           the client endpoint configuration, or {@code null} for a default one
 * @param path             the server URI
 * @return the opened session
 * @throws ClosedChannelException if this container has been closed
 * @throws DeploymentException    if the negotiated configuration is invalid
 * @throws IOException            if the connection cannot be established
 */
@Override
public Session connectToServer(final Endpoint endpointInstance, final ClientEndpointConfig config, final URI path) throws DeploymentException, IOException {
    if (closed) {
        throw new ClosedChannelException();
    }
    final ClientEndpointConfig effectiveConfig;
    if (config == null) {
        effectiveConfig = ClientEndpointConfig.Builder.create().build();
    } else {
        effectiveConfig = config;
    }

    SSLContext sslContext = null;
    if (path.getScheme().equals("wss")) {
        // Ask each provider in turn until one of them supplies a context.
        for (int i = 0; sslContext == null && i < clientSslProviders.size(); i++) {
            sslContext = clientSslProviders.get(i).getSsl(eventLoopSupplier.get(), endpointInstance, effectiveConfig, path);
        }
        if (sslContext == null) {
            try {
                sslContext = SSLContext.getDefault();
            } catch (NoSuchAlgorithmException ignored) {
                // Best effort: continue without an SSL context.
            }
        }
    }

    // In theory we should not be able to connect until the deployment is complete, but the definition of
    // when a deployment is complete is a bit nebulous.
    final ClientNegotiation negotiation = new ClientNegotiation(
            effectiveConfig.getPreferredSubprotocols(),
            toExtensionList(effectiveConfig.getExtensions()),
            effectiveConfig);
    final WebsocketConnectionBuilder builder = new WebsocketConnectionBuilder(path, eventLoopSupplier.get())
            .setSsl(sslContext)
            .setBindAddress(clientBindAddress)
            .setClientNegotiation(negotiation);
    return connectToServer(endpointInstance, config, builder);
}
/**
 * Converts JSR-356 extension descriptors into Netty extension data for the client handshake.
 * <p>
 * Each extension's parameters are collected into a {@link HashMap}, so parameter order is not preserved.
 * A repeated parameter name keeps the last value seen.
 *
 * @param extensions the extensions offered by the client endpoint configuration
 * @return one {@link WebSocketExtensionData} per input extension, in input order
 */
private static List<WebSocketExtensionData> toExtensionList(final List<Extension> extensions) {
    final List<WebSocketExtensionData> ret = new ArrayList<>(extensions.size());
    for (Extension e : extensions) {
        final Map<String, String> parameters = new HashMap<>();
        for (Extension.Parameter p : e.getParameters()) {
            parameters.put(p.getName(), p.getValue());
        }
        ret.add(new WebSocketExtensionData(e.getName(), parameters));
    }
    return ret;
}
public Session connectToServer(final Endpoint endpointInstance, final ClientEndpointConfig config, WebsocketConnectionBuilder connectionBuilder) throws DeploymentException, IOException {
if (closed) {
throw new ClosedChannelException();
}
ClientEndpointConfig cec = config != null ? config : ClientEndpointConfig.Builder.create().build();
WebSocketClientNegotiation clientNegotiation = connectionBuilder.getClientNegotiation();
CompletableFuture sessionCompletableFuture = new CompletableFuture<>();
EndpointSessionHandler sessionHandler = new EndpointSessionHandler(this);
final List extensions = new ArrayList<>();
final Map extMap = new HashMap<>();
for (Extension ext : cec.getExtensions()) {
extMap.put(ext.getName(), ext);
}
for (WebSocketExtensionData e : clientNegotiation.getSelectedExtensions()) {
Extension ext = extMap.get(e.name());
if (ext == null) {
throw JsrWebSocketMessages.MESSAGES.extensionWasNotPresentInClientHandshake(e.name(), clientNegotiation.getSupportedExtensions());
}
extensions.add(new ExtensionImpl(e));
}
CompletableFuture session = connectionBuilder
.connect(new Function() {
@Override
public UndertowSession apply(Channel channel) {
channel.config().setAutoRead(false);
ConfiguredClientEndpoint configured = clientEndpoints.get(endpointInstance.getClass());
if (configured == null) {
synchronized (clientEndpoints) {
configured = clientEndpoints.get(endpointInstance.getClass());
if (configured == null) {
clientEndpoints.put(endpointInstance.getClass(), configured = new ConfiguredClientEndpoint());
}
}
}
EncodingFactory encodingFactory = null;
try {
encodingFactory = EncodingFactory.createFactory(classIntrospecter, cec.getDecoders(), cec.getEncoders());
} catch (DeploymentException e) {
throw new RuntimeException(e);
}
UndertowSession undertowSession = new UndertowSession(channel, connectionBuilder.getUri(), Collections.emptyMap(), Collections.>emptyMap(), sessionHandler, null, new ImmediateInstanceHandle<>(endpointInstance), cec, connectionBuilder.getUri().getQuery(), encodingFactory.createEncoding(cec), configured, clientNegotiation.getSelectedSubProtocol(), extensions, connectionBuilder, executorSupplier.get());
invokeEndpointMethod(executorSupplier.get(), new Runnable() {
@Override
public void run() {
try {
endpointInstance.onOpen(undertowSession, cec);
} finally {
undertowSession.getFrameHandler().start();
channel.config().setAutoRead(true);
channel.read();
sessionCompletableFuture.complete(undertowSession);
}
}
});
return undertowSession;
}
}).exceptionally(new Function() {
@Override
public UndertowSession apply(Throwable throwable) {
sessionCompletableFuture.completeExceptionally(throwable);
return null;
}
});
Number timeout = (Number) cec.getUserProperties().get(TIMEOUT);
try {
return sessionCompletableFuture.get(timeout == null ? DEFAULT_WEB_SOCKET_TIMEOUT_SECONDS : timeout.intValue(), TimeUnit.SECONDS);
} catch (Exception e) {
session.cancel(true);
throw new IOException(e);
}
}
@Override
public Session connectToServer(final Class extends Endpoint> endpointClass, final ClientEndpointConfig cec, final URI path) throws DeploymentException, IOException {
if (closed) {
throw new ClosedChannelException();
}
try {
Endpoint endpoint = classIntrospecter.createInstanceFactory(endpointClass).createInstance().getInstance();
return connectToServer(endpoint, cec, path);
} catch (InstantiationException | NoSuchMethodException e) {
throw new RuntimeException(e);
}
}
public void doUpgrade(HttpServletRequest request,
HttpServletResponse response, final ServerEndpointConfig sec,
Map pathParams)
throws ServletException, IOException {
ServerEndpointConfig.Configurator configurator = sec.getConfigurator();
try {
EncodingFactory encodingFactory = EncodingFactory.createFactory(classIntrospecter, sec.getDecoders(), sec.getEncoders());
PathTemplate pt = PathTemplate.create(sec.getPath());
InstanceFactory> instanceFactory = null;
try {
instanceFactory = classIntrospecter.createInstanceFactory(sec.getEndpointClass());
} catch (Exception e) {
//so it is possible that this is still valid if a custom configurator is in use
if (configurator == null || configurator.getClass() == ServerEndpointConfig.Configurator.class) {
throw JsrWebSocketMessages.MESSAGES.couldNotDeploy(e);
} else {
instanceFactory = new InstanceFactory
© 2015 - 2025 Weber Informatics LLC | Privacy Policy