All downloads are FREE. Search and download functionality uses the official Maven repository.

io.quarkus.websockets.next.runtime.WebSocketConnectionImpl Maven / Gradle / Ivy

There is a newer version: 3.17.2
Show newest version
package io.quarkus.websockets.next.runtime;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import io.quarkus.websockets.next.HandshakeRequest;
import io.quarkus.websockets.next.WebSocketConnection;
import io.smallrye.mutiny.Uni;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.http.ServerWebSocket;
import io.vertx.core.http.WebSocketBase;
import io.vertx.ext.web.RoutingContext;

class WebSocketConnectionImpl extends WebSocketConnectionBase implements WebSocketConnection {

    private final String generatedEndpointClass;

    private final String endpointId;

    private final ServerWebSocket webSocket;

    private final ConnectionManager connectionManager;

    private final BroadcastSender defaultBroadcast;

    WebSocketConnectionImpl(String generatedEndpointClass, String endpointClass, ServerWebSocket webSocket,
            ConnectionManager connectionManager,
            Codecs codecs, RoutingContext ctx, TrafficLogger trafficLogger) {
        super(Map.copyOf(ctx.pathParams()), codecs, new HandshakeRequestImpl(webSocket, ctx), trafficLogger);
        this.generatedEndpointClass = generatedEndpointClass;
        this.endpointId = endpointClass;
        this.webSocket = Objects.requireNonNull(webSocket);
        this.connectionManager = Objects.requireNonNull(connectionManager);
        this.defaultBroadcast = new BroadcastImpl(null);
    }

    @Override
    WebSocketBase webSocket() {
        return webSocket;
    }

    @Override
    public String endpointId() {
        return endpointId;
    }

    @Override
    public BroadcastSender broadcast() {
        return defaultBroadcast;
    }

    @Override
    public Set getOpenConnections() {
        return connectionManager.getConnections(generatedEndpointClass).stream().filter(WebSocketConnection::isOpen)
                .collect(Collectors.toUnmodifiableSet());
    }

    @Override
    public String subprotocol() {
        return webSocket.subProtocol();
    }

    @Override
    public String toString() {
        return "WebSocket connection [endpointId=" + endpointId + ", path=" + webSocket.path() + ", id=" + identifier + "]";
    }

    @Override
    public int hashCode() {
        return Objects.hash(identifier);
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj)
            return true;
        if (obj == null)
            return false;
        if (getClass() != obj.getClass())
            return false;
        WebSocketConnectionImpl other = (WebSocketConnectionImpl) obj;
        return Objects.equals(identifier, other.identifier);
    }

    private static class HandshakeRequestImpl implements HandshakeRequest {

        private final ServerWebSocket webSocket;

        private final Map> headers;

        HandshakeRequestImpl(ServerWebSocket webSocket, RoutingContext ctx) {
            this.webSocket = webSocket;
            this.headers = initHeaders(ctx);
        }

        @Override
        public String header(String name) {
            List values = headers(name);
            return values.isEmpty() ? null : values.get(0);
        }

        @Override
        public List headers(String name) {
            return headers.getOrDefault(Objects.requireNonNull(name).toLowerCase(), List.of());
        }

        @Override
        public Map> headers() {
            return headers;
        }

        @Override
        public String scheme() {
            return webSocket.scheme();
        }

        @Override
        public String host() {
            return webSocket.authority().host();
        }

        @Override
        public int port() {
            return webSocket.authority().port();
        }

        @Override
        public String path() {
            return webSocket.path();
        }

        @Override
        public String query() {
            return webSocket.query();
        }

        static Map> initHeaders(RoutingContext ctx) {
            Map> headers = new HashMap<>();
            for (Entry e : ctx.request().headers()) {
                String key = e.getKey().toLowerCase();
                List values = headers.get(key);
                if (values == null) {
                    values = new ArrayList<>();
                    headers.put(key, values);
                }
                values.add(e.getValue());
            }
            for (Entry> e : headers.entrySet()) {
                // Make the list of values immutable
                e.setValue(List.copyOf(e.getValue()));
            }
            return Map.copyOf(headers);
        }

    }

    private class BroadcastImpl implements WebSocketConnection.BroadcastSender {

        private static final BiFunction> SEND_TEXT_STR = new BiFunction<>() {
            @Override
            public Uni apply(WebSocketConnection c, String s) {
                return c.sendText(s);
            }
        };
        private static final BiFunction> SEND_TEXT_POJO = new BiFunction<>() {
            @Override
            public Uni apply(WebSocketConnection c, Object o) {
                return c.sendText(o);
            }
        };
        private static final BiFunction> SEND_BINARY = new BiFunction<>() {
            @Override
            public Uni apply(WebSocketConnection c, Buffer b) {
                return c.sendBinary(b);
            }
        };

        private final Predicate filter;

        BroadcastImpl(Predicate filter) {
            this.filter = filter;
        }

        @Override
        public BroadcastSender filter(Predicate predicate) {
            return new BroadcastImpl(Objects.requireNonNull(predicate));
        }

        @Override
        public Uni sendText(String message) {
            return doSend(SEND_TEXT_STR, message);
        }

        @Override
        public  Uni sendText(M message) {
            return doSend(SEND_TEXT_POJO, message);
        }

        @Override
        public Uni sendBinary(Buffer message) {
            return doSend(SEND_BINARY, message);
        }

        @Override
        public Uni sendPing(Buffer data) {
            throw new UnsupportedOperationException();
        }

        @Override
        public Uni sendPong(Buffer data) {
            throw new UnsupportedOperationException();
        }

        private  Uni doSend(BiFunction> sendFunction, M message) {
            Set connections = connectionManager.getConnections(generatedEndpointClass);
            if (connections.isEmpty()) {
                return Uni.createFrom().voidItem();
            }
            List> unis = new ArrayList<>(connections.size());
            for (WebSocketConnection connection : connections) {
                if (connection.isOpen()
                        && (filter == null || filter.test(connection))) {
                    unis.add(sendFunction.apply(connection, message)
                            // Intentionally ignore 'WebSocket is closed' failures
                            // It might happen that the connection is closed in the mean time
                            .onFailure(t -> Endpoints.isWebSocketIsClosedFailure(t, (WebSocketConnectionBase) connection))
                            .recoverWithNull());
                }
            }
            if (unis.isEmpty()) {
                return Uni.createFrom().voidItem();
            }
            return Uni.join().all(unis).andCollectFailures().replaceWithVoid();
        }

    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy