All downloads are free. Search and download functionality uses the official Maven repository.

io.micronaut.http.netty.websocket.NettyWebSocketSession Maven / Gradle / Ivy

/*
 * Copyright 2017-2020 original authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.micronaut.http.netty.websocket;

import io.micronaut.core.annotation.Internal;
import io.micronaut.core.annotation.NonNull;
import io.micronaut.core.annotation.Nullable;
import io.micronaut.core.convert.ArgumentConversionContext;
import io.micronaut.core.convert.value.ConvertibleMultiValues;
import io.micronaut.core.convert.value.MutableConvertibleValues;
import io.micronaut.core.convert.value.MutableConvertibleValuesMap;
import io.micronaut.http.HttpRequest;
import io.micronaut.http.MediaType;
import io.micronaut.http.codec.MediaTypeCodecRegistry;
import io.micronaut.websocket.CloseReason;
import io.micronaut.websocket.WebSocketSession;
import io.micronaut.websocket.exceptions.WebSocketSessionException;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.util.AttributeKey;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;

import java.net.URI;
import java.util.Collection;
import java.util.Collections;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

/**
 * Implementation of the {@link WebSocketSession} interface for Netty.
 *
 * @author graemerocher
 * @since 1.0
 */
@Internal
public class NettyWebSocketSession implements WebSocketSession {
    /**
     * The WebSocket session is stored within a Channel attribute using the given key.
     */
    public static final AttributeKey WEB_SOCKET_SESSION_KEY = AttributeKey.newInstance("micronaut.websocket.session");

    private final String id;
    private final Channel channel;
    private final HttpRequest request;
    private final String protocolVersion;
    private final boolean isSecure;
    private final MediaTypeCodecRegistry codecRegistry;
    private final MutableConvertibleValues attributes;
    private final WebSocketMessageEncoder messageEncoder;

    /**
     * Creates a new netty web socket session.
     * @param id The ID
     * @param channel The channel
     * @param request The original request used to create the session
     * @param codecRegistry The codec registry
     * @param protocolVersion The protocol version
     * @param isSecure Whether the session is secure
     */
    protected NettyWebSocketSession(
            String id,
            Channel channel,
            HttpRequest request,
            MediaTypeCodecRegistry codecRegistry,
            String protocolVersion,
            boolean isSecure) {
        this.id = id;
        this.channel = channel;
        this.request = request;
        this.protocolVersion = protocolVersion;
        this.isSecure = isSecure;
        this.channel.attr(WEB_SOCKET_SESSION_KEY).set(this);
        this.codecRegistry = codecRegistry;
        this.messageEncoder = new WebSocketMessageEncoder(this.codecRegistry);
        this.attributes = request.getAttribute("micronaut.SESSION", MutableConvertibleValues.class).orElseGet(MutableConvertibleValuesMap::new);
    }

    @Override
    public String getId() {
        return id;
    }

    @Override
    public MutableConvertibleValues getAttributes() {
        return attributes;
    }

    @Override
    public boolean isOpen() {
        return channel.isOpen() && channel.isActive();
    }

    @Override
    public boolean isWritable() {
        return channel.isWritable();
    }

    @Override
    public boolean isSecure() {
        return isSecure;
    }

    @Override
    public Set getOpenSessions() {
        return Collections.emptySet();
    }

    @Override
    public URI getRequestURI() {
        return request.getUri();
    }

    @Override
    public ConvertibleMultiValues getRequestParameters() {
        return request.getParameters();
    }

    @Override
    public String getProtocolVersion() {
        return protocolVersion;
    }

    @Override
    public  CompletableFuture sendAsync(T message, MediaType mediaType) {
        if (isOpen()) {
            if (message != null) {
                CompletableFuture future = new CompletableFuture<>();
                WebSocketFrame frame;
                if (message instanceof WebSocketFrame) {
                    frame = (WebSocketFrame) message;
                } else {
                    frame = messageEncoder.encodeMessage(message, mediaType);
                }
                channel.writeAndFlush(frame).addListener(f -> {
                    if (f.isSuccess()) {
                        future.complete(message);
                    } else {
                        future.completeExceptionally(new WebSocketSessionException("Send Failure: " + f.cause().getMessage(), f.cause()));
                    }
                });
                return future;
            } else {
                return CompletableFuture.completedFuture(null);
            }
        } else {
            throw new WebSocketSessionException("Session closed");
        }
    }

    @Override
    public void sendSync(Object message, MediaType mediaType) {
        if (isOpen()) {
            if (message != null) {
                try {
                    WebSocketFrame frame;
                    if (message instanceof WebSocketFrame) {
                        frame = (WebSocketFrame) message;
                    } else {
                        frame = messageEncoder.encodeMessage(message, mediaType);
                    }
                    channel.writeAndFlush(frame).sync().get();
                } catch (InterruptedException e) {
                    throw new WebSocketSessionException("Send interrupt: " + e.getMessage(), e);
                } catch (ExecutionException e) {
                    throw new WebSocketSessionException("Send Failure: " + e.getMessage(), e);
                }
            }
        } else {
            throw new WebSocketSessionException("Session closed");
        }
    }

    @Override
    public  Flux send(T message, MediaType mediaType) {
        if (message == null) {
            return Flux.empty();
        }

        return Flux.create(emitter -> {
            if (!isOpen()) {
                emitter.error(new WebSocketSessionException("Session closed"));
            } else {
                WebSocketFrame frame;
                if (message instanceof WebSocketFrame) {
                    frame = (WebSocketFrame) message;
                } else {
                    frame = messageEncoder.encodeMessage(message, mediaType);
                }

                ChannelFuture channelFuture = channel.writeAndFlush(frame);
                channelFuture.addListener(future -> {
                    if (future.isSuccess()) {
                        emitter.next(message);
                        emitter.complete();
                    } else {
                        emitter.error(new WebSocketSessionException("Send Failure: " + future.cause().getMessage(), future.cause()));
                    }
                });
            }
        }, FluxSink.OverflowStrategy.ERROR);
    }

    @NonNull
    @Override
    public CompletableFuture sendPingAsync(@NonNull byte[] content) {
        if (isOpen()) {
            ByteBuf messageBuffer = channel.alloc().buffer(content.length);
            messageBuffer.writeBytes(content);
            PingWebSocketFrame frame = new PingWebSocketFrame(messageBuffer);
            CompletableFuture future = new CompletableFuture<>();
            channel.writeAndFlush(frame).addListener(f -> {
                if (f.isSuccess()) {
                    future.complete(null);
                } else {
                    future.completeExceptionally(new WebSocketSessionException("Send Failure: " + f.cause().getMessage(), f.cause()));
                }
            });
            return future;
        } else {
            throw new WebSocketSessionException("Session closed");
        }
    }

    @Override
    public void close() {
        close(CloseReason.NORMAL);
    }

    @Override
    public void close(CloseReason closeReason) {
        if (channel.isOpen()) {
            channel.writeAndFlush(new CloseWebSocketFrame(closeReason.getCode(), closeReason.getReason()))
                    .addListener(future -> channel.close());
        }
    }

    @Override
    public String toString() {
        return "WebSocket Session: " + getId();
    }

    @Override
    public MutableConvertibleValues put(CharSequence key, @Nullable Object value) {
        return attributes.put(key, value);
    }

    @Override
    public MutableConvertibleValues remove(CharSequence key) {
        return attributes.remove(key);
    }

    @Override
    public MutableConvertibleValues clear() {
        return attributes.clear();
    }

    @Override
    public Set names() {
        return attributes.names();
    }

    @Override
    public Collection values() {
        return attributes.values();
    }

    @Override
    public  Optional get(CharSequence name, ArgumentConversionContext conversionContext) {
        return attributes.get(name, conversionContext);
    }
}