All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.talend.sdk.component.server.configuration.WebSocketBroadcastSetup Maven / Gradle / Ivy

/**
 * Copyright (C) 2006-2021 Talend Inc. - www.talend.com
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.talend.sdk.component.server.configuration;

import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;
import static java.util.Collections.singletonList;
import static java.util.Locale.ENGLISH;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.TreeMap;
import java.util.logging.Logger;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

import javax.enterprise.context.Dependent;
import javax.enterprise.inject.Instance;
import javax.inject.Inject;
import javax.servlet.ServletConfig;
import javax.servlet.ServletContext;
import javax.servlet.ServletContextEvent;
import javax.servlet.ServletContextListener;
import javax.servlet.ServletException;
import javax.servlet.annotation.WebListener;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.websocket.CloseReason;
import javax.websocket.DeploymentException;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.Session;
import javax.websocket.server.ServerContainer;
import javax.websocket.server.ServerEndpointConfig;
import javax.ws.rs.ApplicationPath;
import javax.ws.rs.core.Application;
import javax.ws.rs.core.HttpHeaders;
import javax.ws.rs.core.MediaType;
import javax.xml.namespace.QName;

import org.apache.cxf.Bus;
import org.apache.cxf.common.logging.LogUtils;
import org.apache.cxf.continuations.Continuation;
import org.apache.cxf.continuations.ContinuationCallback;
import org.apache.cxf.continuations.ContinuationProvider;
import org.apache.cxf.endpoint.ServerRegistry;
import org.apache.cxf.jaxrs.JAXRSServiceFactoryBean;
import org.apache.cxf.message.ExchangeImpl;
import org.apache.cxf.message.Message;
import org.apache.cxf.message.MessageImpl;
import org.apache.cxf.service.model.EndpointInfo;
import org.apache.cxf.transport.AbstractDestination;
import org.apache.cxf.transport.Conduit;
import org.apache.cxf.transport.MessageObserver;
import org.apache.cxf.transport.http.AbstractHTTPDestination;
import org.apache.cxf.transport.http.ContinuationProviderFactory;
import org.apache.cxf.transport.http.DestinationRegistry;
import org.apache.cxf.transport.http.HTTPSession;
import org.apache.cxf.transport.servlet.ServletController;
import org.apache.cxf.transport.servlet.ServletDestination;
import org.apache.cxf.transport.servlet.servicelist.ServiceListGeneratorServlet;
import org.apache.cxf.transports.http.configuration.HTTPServerPolicy;
import org.apache.cxf.ws.addressing.EndpointReferenceType;
import org.talend.sdk.component.server.front.cxf.CxfExtractor;
import org.talend.sdk.component.server.front.memory.InMemoryRequest;
import org.talend.sdk.component.server.front.memory.InMemoryResponse;
import org.talend.sdk.component.server.front.memory.MemoryInputStream;
import org.talend.sdk.component.server.front.memory.SimpleServletConfig;

import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;

// ensure any JAX-RS command can use websockets
@Slf4j
@Dependent
@WebListener
public class WebSocketBroadcastSetup implements ServletContextListener {

    /**
     * End-of-message marker: it terminates the payload of incoming WebSocket messages and is written after each
     * outgoing response.
     */
    private static final String EOM = "^@";

    /** CXF bus; its first registered server provides the JAX-RS resource model to expose. */
    @Inject
    private Bus bus;

    /** Gives access to the CXF HTTP destination registry the WebSocket transport delegates to. */
    @Inject
    private CxfExtractor cxf;

    /**
     * JAX-RS applications; the first one annotated with {@link ApplicationPath} defines the base path of the
     * WebSocket endpoints. A parameterized type is required: CDI rejects raw {@code Instance} injection points.
     */
    @Inject
    private Instance<Application> applications;

    @Override
    public void contextInitialized(final ServletContextEvent sce) {
        final ServerContainer container =
                ServerContainer.class.cast(sce.getServletContext().getAttribute(ServerContainer.class.getName()));

        final JAXRSServiceFactoryBean factory = JAXRSServiceFactoryBean.class
                .cast(bus
                        .getExtension(ServerRegistry.class)
                        .getServers()
                        .iterator()
                        .next()
                        .getEndpoint()
                        .get(JAXRSServiceFactoryBean.class.getName()));

        final String appBase = StreamSupport
                .stream(Spliterators.spliteratorUnknownSize(applications.iterator(), Spliterator.IMMUTABLE), false)
                .filter(a -> a.getClass().isAnnotationPresent(ApplicationPath.class))
                .map(a -> a.getClass().getAnnotation(ApplicationPath.class))
                .map(ApplicationPath::value)
                .findFirst()
                .map(s -> !s.startsWith("/") ? "/" + s : s)
                .orElse("/api/v1");
        final String version = appBase.replaceFirst("/api", "");

        final DestinationRegistry registry = cxf.getRegistry();
        final ServletContext servletContext = sce.getServletContext();

        final WebSocketRegistry webSocketRegistry = new WebSocketRegistry(registry);
        final ServletController controller = new ServletController(webSocketRegistry,
                new SimpleServletConfig(servletContext, "Talend Component Kit Websocket Transport"),
                new ServiceListGeneratorServlet(registry, bus));
        webSocketRegistry.controller = controller;

        Stream
                .concat(factory
                        .getClassResourceInfo()
                        .stream()
                        .flatMap(cri -> cri.getMethodDispatcher().getOperationResourceInfos().stream())
                        .filter(cri -> cri.getAnnotatedMethod().getDeclaringClass().getName().startsWith("org.talend."))
                        .map(ori -> {
                            final String uri = ori.getClassResourceInfo().getURITemplate().getValue()
                                    + ori.getURITemplate().getValue();
                            return ServerEndpointConfig.Builder
                                    .create(Endpoint.class,
                                            "/websocket" + version + "/"
                                                    + String.valueOf(ori.getHttpMethod()).toLowerCase(ENGLISH) + uri)
                                    .configurator(new ServerEndpointConfig.Configurator() {

                                        @Override
                                        public  T getEndpointInstance(final Class clazz)
                                                throws InstantiationException {
                                            final Map> headers = new HashMap<>();
                                            if (!ori.getProduceTypes().isEmpty()) {
                                                headers
                                                        .put(HttpHeaders.CONTENT_TYPE, singletonList(
                                                                ori.getProduceTypes().iterator().next().toString()));
                                            }
                                            if (!ori.getConsumeTypes().isEmpty()) {
                                                headers
                                                        .put(HttpHeaders.ACCEPT, singletonList(
                                                                ori.getConsumeTypes().iterator().next().toString()));
                                            }
                                            return (T) new JAXRSEndpoint(appBase, controller, servletContext,
                                                    ori.getHttpMethod(), uri, headers);
                                        }
                                    })
                                    .build();
                        }),
                        Stream
                                .of(ServerEndpointConfig.Builder
                                        .create(Endpoint.class, "/websocket" + version + "/bus")
                                        .configurator(new ServerEndpointConfig.Configurator() {

                                            @Override
                                            public  T getEndpointInstance(final Class clazz)
                                                    throws InstantiationException {

                                                return (T) new JAXRSEndpoint(appBase, controller, servletContext, "GET",
                                                        "/", emptyMap());
                                            }
                                        })
                                        .build()))
                .sorted(Comparator.comparing(ServerEndpointConfig::getPath))
                .peek(e -> log.info("Deploying WebSocket(path={})", e.getPath()))
                .forEach(config -> {
                    try {
                        container.addEndpoint(config);
                    } catch (final DeploymentException e) {
                        throw new IllegalStateException(e);
                    }
                });
    }

    @Data
    @EqualsAndHashCode(callSuper = false)
    private static class JAXRSEndpoint extends Endpoint {

        private final String appBase;

        private final ServletController controller;

        private final ServletContext context;

        private final String defaultMethod;

        private final String defaultUri;

        private final Map> baseHeaders;

        @Override
        public void onOpen(final Session session, final EndpointConfig endpointConfig) {
            log.debug("Opened session {}", session.getId());
            session.addMessageHandler(InputStream.class, message -> {
                final Map> headers = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
                headers.putAll(baseHeaders);

                final StringBuilder buffer = new StringBuilder(128);
                try { // read headers from the message
                    if (!"SEND".equalsIgnoreCase(readLine(buffer, message))) {
                        throw new IllegalArgumentException("not a message");
                    }

                    String line;
                    int del;
                    while ((line = readLine(buffer, message)) != null) {
                        final boolean done = line.endsWith(EOM);
                        if (done) {
                            line = line.substring(0, line.length() - EOM.length());
                        }
                        if (!line.isEmpty()) {
                            del = line.indexOf(':');
                            if (del < 0) {
                                headers.put(line.trim(), emptyList());
                            } else {
                                headers
                                        .put(line.substring(0, del).trim(),
                                                singletonList(line.substring(del + 1).trim()));
                            }
                        }
                        if (done) {
                            break;
                        }
                    }
                } catch (final IOException ioe) {
                    throw new IllegalStateException(ioe);
                }

                final List uris = headers.get("destination");
                final String uri;
                if (uris == null || uris.isEmpty()) {
                    uri = defaultUri;
                } else {
                    uri = uris.iterator().next();
                }

                final List methods = headers.get("destinationMethod");
                final String method;
                if (methods == null || methods.isEmpty()) {
                    method = defaultMethod;
                } else {
                    method = methods.iterator().next();
                }

                final String queryString;
                final String path;
                final int query = uri.indexOf('?');
                if (query > 0) {
                    queryString = uri.substring(query + 1);
                    path = uri.substring(0, query);
                } else {
                    queryString = null;
                    path = uri;
                }

                try {
                    final InMemoryRequest request = new InMemoryRequest(method.toUpperCase(ENGLISH), headers, path,
                            appBase + path, appBase, queryString, 8080, context, new WebSocketInputStream(message),
                            session::getUserPrincipal, controller);
                    final InMemoryResponse response = new InMemoryResponse(session::isOpen, () -> {
                        if (session.getBasicRemote().getBatchingAllowed()) {
                            try {
                                session.getBasicRemote().flushBatch();
                            } catch (final IOException e) {
                                throw new IllegalStateException(e);
                            }
                        }
                    }, bytes -> {
                        try {
                            session.getBasicRemote().sendBinary(ByteBuffer.wrap(bytes));
                        } catch (final IOException e) {
                            throw new IllegalStateException(e);
                        }
                    }, (status, responseHeaders) -> {
                        final StringBuilder top = new StringBuilder("MESSAGE\r\n");
                        top.append("status: ").append(status).append("\r\n");
                        responseHeaders
                                .forEach((k,
                                        v) -> top.append(k).append(": ").append(String.join(",", v)).append("\r\n"));
                        top.append("\r\n");// empty line, means the next bytes are the payload
                        return top.toString();
                    }) {

                        @Override
                        protected void onClose(final OutputStream stream) throws IOException {
                            stream.write(EOM.getBytes(StandardCharsets.UTF_8));
                        }
                    };
                    request.setResponse(response);
                    controller.invoke(request, response);
                } catch (final ServletException e) {
                    throw new IllegalArgumentException(e);
                }
            });
        }

        @Override
        public void onClose(final Session session, final CloseReason closeReason) {
            log.debug("Closed session {}", session.getId());
        }

        @Override
        public void onError(final Session session, final Throwable throwable) {
            log.warn("Error for session {}", session.getId(), throwable);
        }

        private static String readLine(final StringBuilder buffer, final InputStream in) throws IOException {
            int c;
            while ((c = in.read()) != -1) {
                if (c == '\n') {
                    break;
                } else if (c != '\r') {
                    buffer.append((char) c);
                }
            }

            if (buffer.length() == 0) {
                return null;
            }
            final String string = buffer.toString();
            buffer.setLength(0);
            return string;
        }
    }

    /**
     * Request body stream that reads the rest of a WebSocket message. It reports end of stream on the {@code ^@}
     * end-of-message marker, which is consumed but never returned, or on the end of the underlying stream.
     */
    private static class WebSocketInputStream extends MemoryInputStream {

        /** Sentinel meaning "no byte was read ahead". */
        private static final int NONE = Integer.MAX_VALUE;

        /** Byte read ahead while checking whether a {@code '^'} starts the end-of-message marker. */
        private int previous = NONE;

        private WebSocketInputStream(final InputStream delegate) {
            super(delegate);
        }

        @Override
        public int read() throws IOException {
            if (finished) {
                return -1;
            }
            final int current;
            if (previous != NONE) {
                // replay the look-ahead byte (the original code lost it and returned the sentinel instead)
                current = previous;
                previous = NONE;
            } else {
                current = delegate.read();
            }
            if (current == '^') {
                // a replayed '^' can also start the marker (e.g. "^^@"), so always look one byte ahead
                final int next = delegate.read();
                if (next == '@') {
                    finished = true;
                    return -1;
                }
                previous = next;
            }
            if (current < 0) {
                finished = true;
            }
            return current;
        }
    }

    /**
     * Read-only {@link DestinationRegistry} decorating the CXF one. Lookups return the matching destination
     * wrapped into a new {@link WebSocketDestination}, or {@code null} when nothing matches. Listing methods are
     * forwarded unchanged, and registration changes throw {@link UnsupportedOperationException}.
     */
    private static class WebSocketRegistry implements DestinationRegistry {

        private final DestinationRegistry delegate;

        /**
         * Controller serving WebSocket requests, used to re-dispatch resumed requests. It is assigned after
         * construction because the controller itself is built on top of this registry.
         */
        private ServletController controller;

        private WebSocketRegistry(final DestinationRegistry registry) {
            this.delegate = registry;
        }

        @Override
        public void addDestination(final AbstractHTTPDestination destination) {
            throw new UnsupportedOperationException();
        }

        @Override
        public void removeDestination(final String path) {
            throw new UnsupportedOperationException();
        }

        @Override
        public AbstractHTTPDestination getDestinationForPath(final String path) {
            return wrap(delegate.getDestinationForPath(path));
        }

        @Override
        public AbstractHTTPDestination getDestinationForPath(final String path, final boolean tryDecoding) {
            return wrap(delegate.getDestinationForPath(path, tryDecoding));
        }

        @Override
        public AbstractHTTPDestination checkRestfulRequest(final String address) {
            return wrap(delegate.checkRestfulRequest(address));
        }

        @Override
        public Collection<AbstractHTTPDestination> getDestinations() {
            return delegate.getDestinations();
        }

        @Override
        public AbstractDestination[] getSortedDestinations() {
            return delegate.getSortedDestinations();
        }

        @Override
        public Set<String> getDestinationsPaths() {
            return delegate.getDestinationsPaths();
        }

        /** Wraps a resolved destination (a new wrapper per call); {@code null} stays {@code null}. */
        private AbstractHTTPDestination wrap(final AbstractHTTPDestination destination) {
            if (destination == null) {
                return null;
            }
            try {
                return new WebSocketDestination(destination, this);
            } catch (final IOException e) {
                throw new IllegalStateException(e);
            }
        }
    }

    /**
     * Destination decorating a CXF {@link AbstractHTTPDestination} for requests coming from WebSockets. It
     * installs a {@link WebSocketContinuationFactory} so {@code @Suspended} JAX-RS methods can be resumed, and
     * forwards the rest to the delegate. {@link #shutdown()} and {@link #setMessageObserver(MessageObserver)} are
     * rejected so the wrapper cannot alter the delegate's lifecycle.
     */
    private static class WebSocketDestination extends AbstractHTTPDestination {

        private static final Logger LOG = LogUtils.getL7dLogger(ServletDestination.class);

        private final AbstractHTTPDestination delegate;

        private WebSocketDestination(final AbstractHTTPDestination delegate, final WebSocketRegistry registry)
                throws IOException {
            super(delegate.getBus(), registry, new EndpointInfo(), delegate.getPath(), false);
            this.delegate = delegate;
            this.cproviderFactory = new WebSocketContinuationFactory(registry);
        }

        @Override
        public EndpointReferenceType getAddress() {
            return delegate.getAddress();
        }

        @Override
        public Conduit getBackChannel(final Message inMessage) throws IOException {
            return delegate.getBackChannel(inMessage);
        }

        @Override
        public EndpointInfo getEndpointInfo() {
            return delegate.getEndpointInfo();
        }

        @Override
        public void shutdown() {
            throw new UnsupportedOperationException();
        }

        @Override
        public void setMessageObserver(final MessageObserver observer) {
            throw new UnsupportedOperationException();
        }

        @Override
        public MessageObserver getMessageObserver() {
            return delegate.getMessageObserver();
        }

        @Override
        protected Logger getLogger() {
            return LOG;
        }

        @Override
        public Bus getBus() {
            return delegate.getBus();
        }

        /**
         * Prepares a message bound to this destination (unless one is already attached to the request for a
         * continuation), then lets the delegate process the request.
         */
        @Override
        public void invoke(final ServletConfig config, final ServletContext context, final HttpServletRequest req,
                final HttpServletResponse resp) throws IOException {
            // eager create the message to ensure we set our continuation for @Suspended
            // NOTE(review): this presumably relies on CXF's setupMessage() asking cproviderFactory for a provider;
            // WebSocketContinuation then stores the in-message as a request attribute the delegate can retrieve.
            Message inMessage = retrieveFromContinuation(req);
            if (inMessage == null) {
                inMessage = new MessageImpl();

                final ExchangeImpl exchange = new ExchangeImpl();
                exchange.setInMessage(inMessage);
                setupMessage(inMessage, config, context, req, resp);

                exchange.setSession(new HTTPSession(req));
                MessageImpl.class.cast(inMessage).setDestination(this);
            }

            delegate.invoke(config, context, req, resp);
        }

        @Override
        public void finalizeConfig() {
            delegate.finalizeConfig();
        }

        @Override
        public String getBeanName() {
            return delegate.getBeanName();
        }

        @Override
        public EndpointReferenceType getAddressWithId(final String id) {
            return delegate.getAddressWithId(id);
        }

        @Override
        public String getId(final Map<String, Object> context) {
            return delegate.getId(context);
        }

        @Override
        public String getContextMatchStrategy() {
            return delegate.getContextMatchStrategy();
        }

        @Override
        public boolean isFixedParameterOrder() {
            return delegate.isFixedParameterOrder();
        }

        @Override
        public boolean isMultiplexWithAddress() {
            return delegate.isMultiplexWithAddress();
        }

        @Override
        public HTTPServerPolicy getServer() {
            return delegate.getServer();
        }

        @Override
        public void assertMessage(final Message message) {
            delegate.assertMessage(message);
        }

        @Override
        public boolean canAssert(final QName type) {
            return delegate.canAssert(type);
        }

        @Override
        public String getPath() {
            return delegate.getPath();
        }
    }

    /**
     * Creates a {@link WebSocketContinuation} for each message handled by a {@link WebSocketDestination}.
     */
    private static class WebSocketContinuationFactory implements ContinuationProviderFactory {

        /**
         * Request attribute consulted by {@link #retrieveFromContinuation(HttpServletRequest)}.
         * NOTE(review): no code in this file sets it, so that lookup presumably always yields {@code null}.
         */
        private static final String KEY = WebSocketContinuationFactory.class.getName();

        /** Registry whose controller re-dispatches resumed requests. */
        private final WebSocketRegistry owner;

        private WebSocketContinuationFactory(final WebSocketRegistry registry) {
            owner = registry;
        }

        @Override
        public ContinuationProvider createContinuationProvider(final Message inMessage, final HttpServletRequest req,
                final HttpServletResponse resp) {
            final WebSocketContinuation continuation = new WebSocketContinuation(inMessage, req, resp, owner);
            return continuation;
        }

        @Override
        public Message retrieveFromContinuation(final HttpServletRequest req) {
            final Object stored = req.getAttribute(KEY);
            return Message.class.cast(stored);
        }
    }

    /**
     * Continuation for requests served over WebSockets. Suspending only suspends the CXF interceptor chain of the
     * in-message. Resuming re-dispatches the same request/response pair through the registry's controller.
     * Timeouts are not supported ({@link #isTimeout()} is always {@code false}), and the continuation is always
     * ready for writes.
     */
    private static class WebSocketContinuation implements ContinuationProvider, Continuation {

        private final Message message;

        private final HttpServletRequest request;

        private final HttpServletResponse response;

        private final WebSocketRegistry registry;

        /** Callback found on the exchange at creation time, may be {@code null}. */
        private final ContinuationCallback callback;

        /** Object attached through {@link #setObject(Object)}. */
        private Object attachment;

        private boolean resumedFlag;

        private boolean pendingFlag;

        private boolean newFlag;

        private WebSocketContinuation(final Message message, final HttpServletRequest request,
                final HttpServletResponse response, final WebSocketRegistry registry) {
            this.message = message;
            this.request = request;
            this.response = response;
            this.registry = registry;
            // expose the in-message on the request, presumably for CXF's retrieveFromContinuation() - confirm
            request
                    .setAttribute(AbstractHTTPDestination.CXF_CONTINUATION_MESSAGE,
                            message.getExchange().getInMessage());
            this.callback = message.getExchange().get(ContinuationCallback.class);
        }

        @Override
        public Continuation getContinuation() {
            return this;
        }

        /**
         * Detaches the continuation marker from the in-message, notifies the callback, then closes the response
         * writer.
         */
        @Override
        public void complete() {
            message.getExchange().getInMessage().remove(AbstractHTTPDestination.CXF_CONTINUATION_MESSAGE);
            notifyCallback();
            closeResponse();
        }

        /** Reports the exchange outcome (an {@link Exception} stored on it means failure) to the callback. */
        private void notifyCallback() {
            if (callback == null) {
                return;
            }
            final Exception failure = message.getExchange().get(Exception.class);
            if (failure != null) {
                callback.onError(failure);
            } else {
                callback.onComplete();
            }
        }

        private void closeResponse() {
            try {
                response.getWriter().close();
            } catch (final IOException e) {
                throw new IllegalStateException(e);
            }
        }

        /** Marks the continuation pending and suspends the in-message interceptor chain; the timeout is ignored. */
        @Override
        public boolean suspend(final long timeout) {
            pendingFlag = true;
            resumedFlag = false;
            newFlag = false;
            message.getExchange().getInMessage().getInterceptorChain().suspend();
            return true;
        }

        /** Marks the continuation resumed and synchronously re-dispatches the request. */
        @Override
        public void resume() {
            resumedFlag = true;
            try {
                registry.controller.invoke(request, response);
            } catch (final ServletException e) {
                throw new IllegalStateException(e);
            }
        }

        @Override
        public void reset() {
            attachment = null;
            newFlag = false;
            resumedFlag = false;
            pendingFlag = false;
        }

        @Override
        public boolean isNew() {
            return newFlag;
        }

        @Override
        public boolean isPending() {
            return pendingFlag;
        }

        @Override
        public boolean isResumed() {
            return resumedFlag;
        }

        @Override
        public boolean isTimeout() {
            return false;
        }

        @Override
        public Object getObject() {
            return attachment;
        }

        @Override
        public void setObject(final Object o) {
            attachment = o;
        }

        @Override
        public boolean isReadyForWrite() {
            return true;
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy