All downloads are free. The search and download functionality uses the official Maven repository.

com.github.netty.protocol.servlet.websocket.WebSocketSession Maven / Gradle / Ivy

package com.github.netty.protocol.servlet.websocket;

import com.github.netty.core.util.IOUtil;
import com.github.netty.core.util.TypeUtil;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.handler.codec.http.websocketx.*;
import io.netty.util.Attribute;
import io.netty.util.AttributeKey;
import io.netty.util.internal.PlatformDependent;

import javax.websocket.*;
import javax.websocket.server.ServerEndpointConfig;
import java.io.IOException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.security.Principal;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Websocket session
 *
 * @author wangzihao
 * 2018/10/13/013
 */
public class WebSocketSession implements Session {
    public static final AttributeKey CHANNEL_ATTR_KEY_SESSION = AttributeKey.valueOf(WebSocketSession.class + "#WebSocketSession");
    public static final SendResult SEND_RESULT_OK = new SendResult();
    private static final AtomicLong ids = new AtomicLong(0);
    private final Channel channel;
    private final WebSocketServerHandshaker13Extension webSocketServerHandshaker;
    private final Endpoint localEndpoint;
    private final WebSocketServerContainer webSocketContainer;
    private final String id;
    private final URI requestUri;
    private final Map> requestParameterMap;
    private final String queryString;
    private final Principal userPrincipal;
    private final String httpSessionId;
    private final List negotiatedExtensions;
    private final Map pathParameters;
    private final ServerEndpointConfig serverEndpointConfig;
    private final Set messageHandlers = new LinkedHashSet<>();
    private final List encoderEntries = new ArrayList<>();
    private final Map userProperties = new ConcurrentHashMap<>();
    private AsyncRemoteEndpoint asyncRemoteEndpoint;
    private BasicRemoteEndpoint basicRemoteEndpoint;
    private volatile State state = State.OPEN;
    private int maxBinaryMessageBufferSize;
    private int maxTextMessageBufferSize;
    private long maxIdleTimeout;
    private long asyncSendTimeout;
    private int rsv;

    /**
     * Creates a session for the given channel, instantiates the encoders declared by the
     * endpoint configuration and registers the session with the container.
     *
     * <p>The session is registered only after every encoder has been created, so a failing
     * encoder does not leave a half-constructed session registered with the container.
     *
     * @param channel                   channel the session writes frames to
     * @param webSocketContainer        container supplying the default timeouts and tracking open sessions
     * @param webSocketServerHandshaker handshaker supplying the request URI and the max frame payload length
     * @param requestParameterMap       request parameters exposed through {@link #getRequestParameterMap()}
     * @param queryString               query string exposed through {@link #getQueryString()}
     * @param userPrincipal             principal exposed through {@link #getUserPrincipal()}
     * @param httpSessionId             id exposed through {@link #getHttpSessionId()}
     * @param negotiatedExtensions      extensions exposed through {@link #getNegotiatedExtensions()}
     * @param pathParameters            path parameters exposed through {@link #getPathParameters()}
     * @param localEndpoint             endpoint that receives the close and error callbacks
     * @param serverEndpointConfig      endpoint configuration; its encoder classes are instantiated here
     * @throws DeploymentException if an encoder class cannot be instantiated reflectively
     */
    public WebSocketSession(Channel channel, WebSocketServerContainer webSocketContainer,
                            WebSocketServerHandshaker13Extension webSocketServerHandshaker,
                            Map<String, List<String>> requestParameterMap,
                            String queryString, Principal userPrincipal, String httpSessionId,
                            List<Extension> negotiatedExtensions, Map<String, String> pathParameters,
                            Endpoint localEndpoint, ServerEndpointConfig serverEndpointConfig) throws DeploymentException {
        this.id = Long.toHexString(ids.getAndIncrement());
        this.webSocketContainer = webSocketContainer;
        this.maxIdleTimeout = webSocketContainer.getDefaultMaxSessionIdleTimeout();
        this.asyncSendTimeout = webSocketContainer.getDefaultAsyncSendTimeout();
        this.channel = channel;
        this.webSocketServerHandshaker = webSocketServerHandshaker;
        // rsv is deliberately left at its default (0); it is not read from the handshaker.
        this.maxTextMessageBufferSize = webSocketServerHandshaker.maxFramePayloadLength();
        this.maxBinaryMessageBufferSize = webSocketServerHandshaker.maxFramePayloadLength();
        this.localEndpoint = localEndpoint;

        this.requestUri = URI.create(webSocketServerHandshaker.uri());
        this.requestParameterMap = requestParameterMap;
        this.queryString = queryString;
        this.userPrincipal = userPrincipal;
        this.httpSessionId = httpSessionId;
        this.negotiatedExtensions = negotiatedExtensions;
        this.pathParameters = pathParameters;
        this.serverEndpointConfig = serverEndpointConfig;

        for (Class<? extends Encoder> encoderClazz : serverEndpointConfig.getEncoders()) {
            Encoder instance;
            try {
                instance = encoderClazz.getConstructor().newInstance();
                instance.init(serverEndpointConfig);
            } catch (ReflectiveOperationException e) {
                throw new DeploymentException("The specified encoder of type [" + encoderClazz.getName() + "] could not be instantiated",
                        e);
            }
            // The encoder's declared type argument decides which objects it is offered; fall
            // back to Object (accept anything) when it cannot be resolved.
            TypeUtil.TypeResult typeResult = TypeUtil.getGenericType(Encoder.class, encoderClazz);
            Class<?> type = typeResult == null ? Object.class : typeResult.getClazz();
            encoderEntries.add(new EncoderEntry(type, instance));
        }

        // Register last: from here on construction can no longer fail with DeploymentException.
        webSocketContainer.registerSession(localEndpoint, this);

        // A channel that closes while the session is still open is treated as an abnormal close.
        channel.closeFuture().addListener(closed -> {
            if (isOpen()) {
                closeByAbort();
            }
        });
    }

    /**
     * Looks up the session bound to a channel through {@link #CHANNEL_ATTR_KEY_SESSION}.
     *
     * @param channel channel to inspect; may be {@code null}
     * @return the bound session, or {@code null} when the channel is {@code null}, inactive,
     *         or has no session attribute set
     */
    public static WebSocketSession getSession(Channel channel) {
        if (!isChannelActive(channel)) {
            return null;
        }
        // Typed attribute: a raw Attribute would hand back Object from get().
        Attribute<WebSocketSession> attribute = channel.attr(CHANNEL_ATTR_KEY_SESSION);
        return attribute == null ? null : attribute.get();
    }

    /**
     * Stores a session on the channel under {@link #CHANNEL_ATTR_KEY_SESSION}.
     * Nothing is stored when the channel is {@code null} or not active.
     *
     * @param channel          channel that should carry the session
     * @param websocketSession session to bind
     */
    public static void setSession(Channel channel, WebSocketSession websocketSession) {
        if (!isChannelActive(channel)) {
            return;
        }
        channel.attr(CHANNEL_ATTR_KEY_SESSION).set(websocketSession);
    }

    /**
     * Null-safe check of {@link Channel#isActive()}.
     *
     * @param channel channel to test; may be {@code null}
     * @return {@code true} only for a non-null channel that reports itself active
     */
    public static boolean isChannelActive(Channel channel) {
        if (channel == null) {
            return false;
        }
        return channel.isActive();
    }

    /**
     * Translates a received close frame into a {@link CloseReason}.
     * Status codes outside 1000..4999 are reported as a normal closure with an empty reason.
     *
     * @param frame close frame received from the peer
     * @return the close reason to pass to the endpoint
     */
    private static CloseReason valueOf(CloseWebSocketFrame frame) {
        final int statusCode = frame.statusCode();
        final boolean withinRange = statusCode >= 1000 && statusCode <= 4999;
        if (withinRange) {
            return new CloseReason(CloseReason.CloseCodes.getCloseCode(statusCode), frame.reasonText());
        }
        return new CloseReason(CloseReason.CloseCodes.NORMAL_CLOSURE, "");
    }

    /**
     * Returns the endpoint configuration supplied when this session was constructed.
     *
     * @return the endpoint configuration
     */
    public ServerEndpointConfig getServerEndpointConfig() {
        return serverEndpointConfig;
    }

    /**
     * Returns the handshaker supplied when this session was constructed.
     *
     * @return the handshaker
     */
    public WebSocketServerHandshaker13Extension getWebSocketServerHandshaker() {
        return webSocketServerHandshaker;
    }

    /**
     * Two sessions are equal when their ids are equal.
     *
     * @param obj object to compare with
     * @return {@code true} if {@code obj} is a {@code WebSocketSession} with the same id
     */
    @Override
    public boolean equals(Object obj) {
        if (!(obj instanceof WebSocketSession)) {
            return false;
        }
        WebSocketSession other = (WebSocketSession) obj;
        return Objects.equals(other.id, this.id);
    }

    /**
     * Hash code derived from the session id, consistent with {@link #equals(Object)}.
     *
     * @return the hash code of the id
     */
    @Override
    public int hashCode() {
        return id.hashCode();
    }

    /**
     * Returns the container supplied when this session was constructed.
     * This accessor does not check the session state.
     *
     * @return the owning container
     */
    @Override
    public WebSocketServerContainer getContainer() {
        return webSocketContainer;
    }

    /**
     * Returns the set of registered message handlers.
     * The internal set itself is returned, not a copy.
     *
     * @return the registered message handlers
     * @throws IllegalStateException if the session is closed
     */
    @Override
    public Set getMessageHandlers() {
        checkState();
        return messageHandlers;
    }

    /**
     * Removes a previously registered message handler.
     *
     * @param listener handler to remove
     * @throws IllegalStateException if the session is closed
     */
    @Override
    public void removeMessageHandler(MessageHandler listener) {
        checkState();
        messageHandlers.remove(listener);
    }

    /**
     * Returns the WebSocket version reported by the handshaker, in its HTTP header form.
     *
     * @return the protocol version string
     * @throws IllegalStateException if the session is closed
     */
    @Override
    public String getProtocolVersion() {
        checkState();
        return webSocketServerHandshaker.version().toHttpHeaderValue();
    }

    /**
     * Returns the sub-protocol that was selected during the handshake.
     *
     * <p>The handshaker's {@code subprotocols()} is the list the server supports, not the
     * outcome of the negotiation, so the selected protocol is read instead.
     *
     * @return the selected sub-protocol, or an empty string when none was selected
     * @throws IllegalStateException if the session is closed
     */
    @Override
    public String getNegotiatedSubprotocol() {
        checkState();
        String selected = webSocketServerHandshaker.selectedSubprotocol();
        return selected == null ? "" : selected;
    }

    /**
     * Returns the extension list supplied when this session was constructed.
     *
     * @return the negotiated extensions
     * @throws IllegalStateException if the session is closed
     */
    @Override
    public List getNegotiatedExtensions() {
        checkState();
        return negotiatedExtensions;
    }

    /**
     * Returns the HTTP session id supplied when this session was constructed.
     * This accessor does not check the session state.
     *
     * @return the HTTP session id
     */
    public String getHttpSessionId() {
        return httpSessionId;
    }

    /**
     * Reports whether the request URI uses the {@code wss} scheme (case-insensitive).
     *
     * @return {@code true} if the request URI scheme is {@code wss}
     * @throws IllegalStateException if the session is closed
     */
    @Override
    public boolean isSecure() {
        checkState();
        return "wss".equalsIgnoreCase(requestUri.getScheme());
    }

    /**
     * Reports whether the session is still in the {@link State#OPEN} state.
     *
     * @return {@code true} while the state is {@code OPEN}
     */
    @Override
    public boolean isOpen() {
        return state == State.OPEN;
    }

    /**
     * Returns the idle timeout stored for this session. It is initialised from the
     * container's default and, in the code visible here, only stored — nothing in this
     * class enforces it.
     *
     * @return the stored idle timeout
     * @throws IllegalStateException if the session is closed
     */
    @Override
    public long getMaxIdleTimeout() {
        checkState();
        return maxIdleTimeout;
    }

    /**
     * Replaces the stored idle timeout.
     *
     * @param timeout new idle timeout
     * @throws IllegalStateException if the session is closed
     */
    @Override
    public void setMaxIdleTimeout(long timeout) {
        checkState();
        this.maxIdleTimeout = timeout;
    }

    /**
     * Returns the stored binary message buffer size. It is initialised from the
     * handshaker's max frame payload length.
     *
     * @return the stored binary buffer size
     * @throws IllegalStateException if the session is closed
     */
    @Override
    public int getMaxBinaryMessageBufferSize() {
        checkState();
        return maxBinaryMessageBufferSize;
    }

    /**
     * Replaces the stored binary message buffer size.
     *
     * @param max new binary buffer size
     * @throws IllegalStateException if the session is closed
     */
    @Override
    public void setMaxBinaryMessageBufferSize(int max) {
        checkState();
        this.maxBinaryMessageBufferSize = max;
    }

    /**
     * Returns the stored text message buffer size. It is initialised from the
     * handshaker's max frame payload length.
     *
     * @return the stored text buffer size
     * @throws IllegalStateException if the session is closed
     */
    @Override
    public int getMaxTextMessageBufferSize() {
        checkState();
        return maxTextMessageBufferSize;
    }

    /**
     * Replaces the stored text message buffer size.
     *
     * @param max new text buffer size
     * @throws IllegalStateException if the session is closed
     */
    @Override
    public void setMaxTextMessageBufferSize(int max) {
        checkState();
        this.maxTextMessageBufferSize = max;
    }

    /**
     * Returns the asynchronous remote endpoint, creating it on first use.
     *
     * @return the lazily created asynchronous endpoint
     * @throws IllegalStateException if the session is closed
     */
    @Override
    public RemoteEndpoint.Async getAsyncRemote() {
        checkState();
        AsyncRemoteEndpoint existing = asyncRemoteEndpoint;
        if (existing != null) {
            return existing;
        }
        synchronized (this) {
            if (asyncRemoteEndpoint == null) {
                asyncRemoteEndpoint = new AsyncRemoteEndpoint();
            }
            return asyncRemoteEndpoint;
        }
    }

    /**
     * Returns the blocking remote endpoint, creating it on first use.
     *
     * @return the lazily created blocking endpoint
     * @throws IllegalStateException if the session is closed
     */
    @Override
    public RemoteEndpoint.Basic getBasicRemote() {
        checkState();
        BasicRemoteEndpoint existing = basicRemoteEndpoint;
        if (existing != null) {
            return existing;
        }
        synchronized (this) {
            if (basicRemoteEndpoint == null) {
                basicRemoteEndpoint = new BasicRemoteEndpoint();
            }
            return basicRemoteEndpoint;
        }
    }

    /**
     * Returns the session id: the hexadecimal form of a counter value taken at construction.
     *
     * @return the session id
     */
    @Override
    public String getId() {
        return id;
    }

    /**
     * Closes the session with {@code NORMAL_CLOSURE} and an empty reason phrase by
     * delegating to {@link #close(CloseReason)}.
     *
     * @throws IOException declared by the interface; propagated from {@link #close(CloseReason)}
     */
    @Override
    public void close() throws IOException {
        close(new CloseReason(CloseReason.CloseCodes.NORMAL_CLOSURE, ""));
    }

    /**
     * Closes the session: tears down local state via {@link #destroy(CloseReason)} and then
     * sends a close frame carrying the reason's code and phrase through the handshaker.
     *
     * <p>Calling this on a session that is no longer open is a no-op. That keeps a second
     * {@code close()} — including one issued from inside the endpoint's {@code onClose}
     * callback — from repeating the teardown or sending another close frame.
     *
     * @param closeReason reason reported to the endpoint and sent to the peer; must not be {@code null}
     * @throws NullPointerException if {@code closeReason} is {@code null} on an open session
     * @throws IOException          declared by the interface
     */
    @Override
    public void close(CloseReason closeReason) throws IOException {
        if (!isOpen()) {
            return;
        }
        // Fail before any teardown happens rather than halfway through it.
        Objects.requireNonNull(closeReason, "closeReason");
        destroy(closeReason);
        webSocketServerHandshaker.close(channel, new CloseWebSocketFrame(true, rsv, closeReason.getCloseCode().getCode(), closeReason.getReasonPhrase()));
    }

    /**
     * Handles a close frame received from the peer: tears down the session with the reason
     * derived from the frame, then closes the channel.
     *
     * @param frame close frame received from the peer
     */
    void closeByClient(CloseWebSocketFrame frame) {
        destroy(valueOf(frame));
        channel.close();
    }

    /**
     * Tears down the session with {@code CLOSED_ABNORMALLY} and an empty reason.
     * The constructor's channel close listener calls this when the channel closes while
     * the session is still open. No close frame is sent and the channel is not closed here.
     */
    void closeByAbort() {
        destroy(new CloseReason(CloseReason.CloseCodes.CLOSED_ABNORMALLY, ""));
    }

    /**
     * Tears down the session: notifies the endpoint, destroys the encoders, unregisters the
     * session from the container and marks it {@link State#CLOSED}.
     *
     * <p>Only the first call on an open session does anything, so the endpoint's
     * {@code onClose} runs at most once even when several close paths fire (an explicit
     * close, a close frame from the peer, the channel close listener). Cleanup runs in
     * {@code finally} blocks so an exception thrown by {@code onClose} still leaves the
     * session unregistered and closed; that exception is propagated to the caller.
     *
     * @param closeReason reason passed to the endpoint's {@code onClose}
     */
    void destroy(CloseReason closeReason) {
        // Claim the OPEN -> OUTPUT_CLOSED transition atomically; callbacks run outside the lock.
        synchronized (this) {
            if (state != State.OPEN) {
                return;
            }
            state = State.OUTPUT_CLOSED;
        }
        try {
            localEndpoint.onClose(this, closeReason);
        } finally {
            try {
                destroyEncoders();
            } finally {
                webSocketContainer.unregisterSession(localEndpoint, this);
                state = State.CLOSED;
            }
        }
    }

    /** Destroys every encoder, reporting individual failures through {@link #onError(Throwable)}. */
    private void destroyEncoders() {
        for (EncoderEntry entry : encoderEntries) {
            try {
                entry.getEncoder().destroy();
            } catch (Throwable e) {
                onError(e);
            }
        }
    }

    /**
     * Forwards an error to the local endpoint's {@code onError} callback.
     *
     * @param thr error to report
     */
    public void onError(Throwable thr) {
        localEndpoint.onError(this, thr);
    }

    /**
     * Returns the request URI, built in the constructor from the handshaker's URI string.
     *
     * @return the request URI
     * @throws IllegalStateException if the session is closed
     */
    @Override
    public URI getRequestURI() {
        checkState();
        return requestUri;
    }

    @Override
    public Map> getRequestParameterMap() {
        checkState();
        return requestParameterMap;
    }

    /**
     * Returns the query string supplied when this session was constructed.
     *
     * @return the query string
     * @throws IllegalStateException if the session is closed
     */
    @Override
    public String getQueryString() {
        checkState();
        return queryString;
    }

    /**
     * Returns the path parameter map supplied when this session was constructed.
     *
     * @return the path parameters
     * @throws IllegalStateException if the session is closed
     */
    @Override
    public Map getPathParameters() {
        checkState();
        return pathParameters;
    }

    /**
     * Returns the mutable per-session property map (a {@code ConcurrentHashMap}).
     * The internal map itself is returned, not a copy.
     *
     * @return the user properties
     * @throws IllegalStateException if the session is closed
     */
    @Override
    public Map getUserProperties() {
        checkState();
        return userProperties;
    }

    /**
     * Returns the principal supplied when this session was constructed.
     *
     * @return the user principal
     * @throws IllegalStateException if the session is closed
     */
    @Override
    public Principal getUserPrincipal() {
        checkState();
        return userPrincipal;
    }

    /**
     * Returns the sessions the container reports as open for this session's endpoint.
     * This accessor does not check the session state.
     *
     * @return the container's open sessions for the local endpoint
     */
    @Override
    public Set getOpenSessions() {
        return webSocketContainer.getOpenSessions(localEndpoint);
    }

    /**
     * Registers a message handler with this session.
     *
     * @param handler handler to register
     * @throws IllegalStateException if the session is closed
     */
    @Override
    public void addMessageHandler(MessageHandler handler) throws IllegalStateException {
        checkState();
        messageHandlers.add(handler);
    }

    /**
     * Registers a partial-message handler. The declared message type is not used by this
     * implementation; the handler is registered exactly as in
     * {@link #addMessageHandler(MessageHandler)}.
     *
     * @param clazz   message type the handler declares (ignored here)
     * @param handler handler to register
     * @param <T>     message type
     * @throws IllegalStateException if the session is closed
     */
    @Override
    public <T> void addMessageHandler(Class<T> clazz, MessageHandler.Partial<T> handler) throws IllegalStateException {
        addMessageHandler(handler);
    }

    /**
     * Registers a whole-message handler. The declared message type is not used by this
     * implementation; the handler is registered exactly as in
     * {@link #addMessageHandler(MessageHandler)}.
     *
     * @param clazz   message type the handler declares (ignored here)
     * @param handler handler to register
     * @param <T>     message type
     * @throws IllegalStateException if the session is closed
     */
    @Override
    public <T> void addMessageHandler(Class<T> clazz, MessageHandler.Whole<T> handler) throws IllegalStateException {
        addMessageHandler(handler);
    }

    /**
     * Guard used by most accessors: rejects calls once the session has reached
     * {@link State#CLOSED}. Sessions that are open or still closing pass.
     *
     * @throws IllegalStateException if the session is closed
     */
    private void checkState() {
        if (state != State.CLOSED) {
            return;
        }
        // As per RFC 6455, a WebSocket connection is considered to be closed
        // once a peer has sent and received a WebSocket close frame.
        throw new IllegalStateException("The WebSocket session [" + id + "] has been closed and no method (apart from close()) may be called on a closed session");
    }

    /**
     * Picks the first registered encoder whose declared type accepts the object's class.
     * Entries are tried in registration order.
     *
     * @param obj object about to be sent
     * @return the matching encoder, or {@code null} when no entry accepts the object
     */
    private Encoder findEncoder(Object obj) {
        Encoder match = null;
        for (EncoderEntry candidate : encoderEntries) {
            Class<?> supported = candidate.getClazz();
            if (supported.isAssignableFrom(obj.getClass())) {
                match = candidate.getEncoder();
                break;
            }
        }
        return match;
    }

    /**
     * Encodes {@code obj} and sends it to the peer.
     *
     * <p>Without a matching encoder, primitives and their wrappers are sent as text via
     * {@code toString()}, and {@code ByteBuffer} / {@code byte[]} values as binary. With an
     * encoder, the encoder's kind (text, text stream, binary, binary stream) decides how the
     * object is written. Anything else fails with an {@link EncodeException}.
     *
     * <p>Failures are not thrown: they are reported through the endpoint's {@code onError},
     * through {@code completion} (when given) and through the returned future.
     * {@code completion} is called exactly once per call.
     *
     * @param obj        object to send; must not be {@code null}
     * @param completion handler called with the outcome; may be {@code null}
     * @return the channel write future when a frame was written, otherwise an already
     *         completed future carrying the failure (or nothing, for the stream encoders)
     * @throws IllegalArgumentException if {@code obj} is {@code null}
     */
    @SuppressWarnings("unchecked") // encoder casts: findEncoder matched the entry's declared type against obj
    public Future<Void> sendObjectImpl(Object obj, SendHandler completion) {
        if (obj == null) {
            throw new IllegalArgumentException("Invalid null data argument");
        }

        /*
         * Note that the implementation will convert primitives and their object
         * equivalents by default but that users are free to specify their own
         * encoders and decoders for this if they wish.
         */
        Encoder encoder = findEncoder(obj);
        ChannelFuture future = null;
        Exception exception = null;
        try {
            if (encoder == null && TypeUtil.isPrimitive(obj.getClass())) {
                future = channel.writeAndFlush(new TextWebSocketFrame(true, rsv, obj.toString()));
            } else if (encoder == null && obj instanceof ByteBuffer) {
                future = channel.writeAndFlush(new BinaryWebSocketFrame(true, rsv, Unpooled.wrappedBuffer((ByteBuffer) obj)));
            } else if (encoder == null && obj instanceof byte[]) {
                future = channel.writeAndFlush(new BinaryWebSocketFrame(true, rsv, Unpooled.wrappedBuffer((byte[]) obj)));
            } else if (encoder instanceof Encoder.Text) {
                String msg = ((Encoder.Text<Object>) encoder).encode(obj);
                future = channel.writeAndFlush(new TextWebSocketFrame(true, rsv, msg));
            } else if (encoder instanceof Encoder.TextStream) {
                try (Writer w = getBasicRemote().getSendWriter()) {
                    ((Encoder.TextStream<Object>) encoder).encode(obj, w);
                }
            } else if (encoder instanceof Encoder.Binary) {
                ByteBuffer msg = ((Encoder.Binary<Object>) encoder).encode(obj);
                // The completion listener is attached once, below, for every branch.
                future = channel.writeAndFlush(new BinaryWebSocketFrame(true, rsv, Unpooled.wrappedBuffer(msg)));
            } else if (encoder instanceof Encoder.BinaryStream) {
                try (OutputStream os = getBasicRemote().getSendStream()) {
                    ((Encoder.BinaryStream<Object>) encoder).encode(obj, os);
                }
            } else {
                throw new EncodeException(obj, "No encoder specified for object of class [" + obj.getClass() + "]");
            }
        } catch (Exception e) {
            exception = e;
        }

        if (completion != null) {
            if (future != null) {
                future.addListener((ChannelFutureListener) written -> completion.onResult(newSendResult(written)));
            } else {
                // No write future: either a stream encoder finished inline or encoding failed.
                completion.onResult(exception == null ? SEND_RESULT_OK : new SendResult(exception));
            }
        }
        if (exception != null) {
            onError(exception);
        }
        if (future != null) {
            return future;
        }
        return new SimpleFuture(exception);
    }

    /**
     * Maps a completed write future to a {@link SendResult}: the shared OK instance on
     * success, otherwise a result wrapping the future's cause.
     *
     * @param future completed channel future
     * @return the send result describing the outcome
     */
    private SendResult newSendResult(ChannelFuture future) {
        return future.isSuccess() ? SEND_RESULT_OK : new SendResult(future.cause());
    }

    /**
     * Waits for a channel future to complete, unless the caller is running on the channel's
     * event loop, in which case it returns immediately instead of blocking that loop.
     *
     * <p>If the wait is interrupted, the thread's interrupt flag is restored before the
     * {@link InterruptedException} is rethrown, so callers further up can still see it.
     *
     * @param future future to wait for
     */
    private void sync(ChannelFuture future) {
        if (future.channel().eventLoop().inEventLoop()) {
            return;
        }
        try {
            future.sync();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            PlatformDependent.throwException(e);
        }
    }

    /** Lifecycle states of a session. */
    public enum State {
        /** Initial state; {@code isOpen()} is true only in this state. */
        OPEN,
        /** Set at the start of {@code destroy}, while the close callbacks and cleanup run. */
        OUTPUT_CLOSED,
        /** Set at the end of {@code destroy}; {@code checkState()} rejects calls from here on. */
        CLOSED
    }

    /**
     * An already completed {@link Future}: it either succeeded with no value or failed with
     * the exception given at construction. It cannot be cancelled and never blocks.
     */
    public static class SimpleFuture implements Future<Void> {
        /** Failure to report from {@code get()}, or {@code null} for success. */
        private final Exception exception;

        SimpleFuture(Exception exception) {
            this.exception = exception;
        }

        /** Always {@code false}: the future is already complete. */
        @Override
        public boolean cancel(boolean mayInterruptIfRunning) {
            return false;
        }

        @Override
        public boolean isCancelled() {
            return false;
        }

        @Override
        public boolean isDone() {
            return true;
        }

        /**
         * Returns immediately.
         *
         * @return always {@code null}
         * @throws ExecutionException wrapping the stored exception, if there is one
         */
        @Override
        public Void get() throws InterruptedException, ExecutionException {
            if (exception != null) {
                throw new ExecutionException(exception);
            }
            return null;
        }

        /** Same as {@link #get()}; the timeout is irrelevant because the future is complete. */
        @Override
        public Void get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
            return get();
        }
    }

    /** Pairs an encoder instance with the class of objects it is offered. */
    public static class EncoderEntry {
        // Type the encoder accepts; findEncoder matches it with isAssignableFrom.
        private final Class clazz;
        // Encoder instance created and initialised in the session constructor.
        private final Encoder encoder;

        EncoderEntry(Class clazz, Encoder encoder) {
            this.clazz = clazz;
            this.encoder = encoder;
        }

        /** @return the class of objects this entry's encoder accepts */
        public Class getClazz() {
            return clazz;
        }

        /** @return the encoder instance */
        public Encoder getEncoder() {
            return encoder;
        }
    }

    /**
     * Non-blocking remote endpoint. Every send writes and flushes one complete frame on the
     * session's channel and reports the outcome through the returned future or the handler.
     */
    class AsyncRemoteEndpoint implements RemoteEndpoint.Async {
        // Only stored and reported back: every send below flushes immediately.
        private final AtomicBoolean batchingAllowed = new AtomicBoolean(false);

        @Override
        public long getSendTimeout() {
            return asyncSendTimeout;
        }

        @Override
        public void setSendTimeout(long timeout) {
            // TODO: 10-16/0016 The socket send timeout setting is not implemented
            asyncSendTimeout = timeout;
        }

        /**
         * Sends a text message and reports the outcome to {@code completion}.
         *
         * @throws IllegalArgumentException if {@code completion} is {@code null}
         */
        @Override
        public void sendText(String text, SendHandler completion) {
            // Same up-front check as sendBinary/sendObject; otherwise the failure would only
            // surface later inside the write listener.
            if (completion == null) {
                throw new IllegalArgumentException("Invalid null handler argument");
            }
            channel.writeAndFlush(new TextWebSocketFrame(true, rsv, text))
                    .addListener((ChannelFutureListener) future -> completion.onResult(newSendResult(future)));
        }

        /** Sends a text message; the returned future completes with the write. */
        @Override
        public Future<Void> sendText(String text) {
            return channel.writeAndFlush(new TextWebSocketFrame(true, rsv, text));
        }

        /** Sends a binary message; the returned future completes with the write. */
        @Override
        public Future<Void> sendBinary(ByteBuffer data) {
            return channel.writeAndFlush(new BinaryWebSocketFrame(true, rsv, Unpooled.wrappedBuffer(data)));
        }

        /**
         * Sends a binary message and reports the outcome to {@code completion}.
         *
         * @throws IllegalArgumentException if {@code completion} is {@code null}
         */
        @Override
        public void sendBinary(ByteBuffer data, SendHandler completion) {
            if (completion == null) {
                throw new IllegalArgumentException("Invalid null handler argument");
            }
            channel.writeAndFlush(new BinaryWebSocketFrame(true, rsv, Unpooled.wrappedBuffer(data)))
                    .addListener((ChannelFutureListener) future -> completion.onResult(newSendResult(future)));
        }

        /** Encodes and sends an object; see {@code sendObjectImpl} for the encoding rules. */
        @Override
        public Future<Void> sendObject(Object obj) {
            return sendObjectImpl(obj, null);
        }

        /**
         * Encodes and sends an object, reporting the outcome to {@code completion}.
         *
         * @throws IllegalArgumentException if {@code completion} is {@code null}
         */
        @Override
        public void sendObject(Object obj, SendHandler completion) {
            if (completion == null) {
                throw new IllegalArgumentException("Invalid null handler argument");
            }
            sendObjectImpl(obj, completion);
        }

        @Override
        public boolean getBatchingAllowed() {
            return batchingAllowed.get();
        }

        /** Stores the flag; switching batching off flushes the channel. */
        @Override
        public void setBatchingAllowed(boolean batchingAllowed) throws IOException {
            boolean oldValue = this.batchingAllowed.getAndSet(batchingAllowed);
            if (oldValue && !batchingAllowed) {
                flushBatch();
            }
        }

        @Override
        public void flushBatch() throws IOException {
            channel.flush();
        }

        /** Sends a ping frame and waits for the write unless called on the event loop. */
        @Override
        public void sendPing(ByteBuffer applicationData) throws IOException, IllegalArgumentException {
            ChannelFuture future = channel.writeAndFlush(new PingWebSocketFrame(true, rsv, Unpooled.wrappedBuffer(applicationData)));
            sync(future);
        }

        /** Sends a pong frame and waits for the write unless called on the event loop. */
        @Override
        public void sendPong(ByteBuffer applicationData) throws IOException, IllegalArgumentException {
            ChannelFuture future = channel.writeAndFlush(new PongWebSocketFrame(true, rsv, Unpooled.wrappedBuffer(applicationData)));
            sync(future);
        }
    }

    /**
     * Blocking remote endpoint. Each send writes and flushes on the session's channel and
     * then waits for the write to finish, unless the caller is on the channel's event loop
     * (see {@code sync}).
     */
    class BasicRemoteEndpoint implements RemoteEndpoint.Basic {
        // Only stored and reported back: every send below flushes immediately.
        private final AtomicBoolean batchingAllowed = new AtomicBoolean(false);
        // One shared stream is enough: BasicRemoteOutputStream keeps no state of its own
        // and its close() only flushes, so it stays usable after being closed.
        private final OutputStream outputStream = new BasicRemoteOutputStream();

        @Override
        public void sendText(String text) throws IOException {
            ChannelFuture future = channel.writeAndFlush(new TextWebSocketFrame(true, rsv, text));
            sync(future);
        }

        @Override
        public void sendBinary(ByteBuffer data) throws IOException {
            ChannelFuture future = channel.writeAndFlush(new BinaryWebSocketFrame(true, rsv, Unpooled.wrappedBuffer(data)));
            sync(future);
        }

        /** Sends one text frame whose final-fragment flag is {@code isLast}. */
        @Override
        public void sendText(String fragment, boolean isLast) throws IOException {
            ChannelFuture future = channel.writeAndFlush(new TextWebSocketFrame(isLast, rsv, fragment));
            sync(future);
        }

        /** Sends one binary frame whose final-fragment flag is {@code isLast}. */
        @Override
        public void sendBinary(ByteBuffer partialByte, boolean isLast) throws IOException {
            ChannelFuture future = channel.writeAndFlush(new BinaryWebSocketFrame(isLast, rsv, Unpooled.wrappedBuffer(partialByte)));
            sync(future);
        }

        /** Returns the shared output stream that writes to the channel. */
        @Override
        public OutputStream getSendStream() throws IOException {
            return outputStream;
        }

        /**
         * Returns a new UTF-8 writer on top of the send stream.
         *
         * <p>A fresh writer is created per call because callers close it (for example with
         * try-with-resources), and a closed {@link OutputStreamWriter} rejects all further
         * writes; a cached instance would only work for the first message.
         */
        @Override
        public Writer getSendWriter() throws IOException {
            return new OutputStreamWriter(getSendStream(), Charset.forName("UTF-8"));
        }

        /**
         * Encodes and sends an object, then waits for the outcome and rethrows any failure.
         * Like the other sends, it does not block when called on the channel's event loop.
         */
        @Override
        public void sendObject(Object data) throws IOException, EncodeException {
            Future<Void> future = sendObjectImpl(data, null);
            if (future instanceof ChannelFuture) {
                // Same waiting rule as the other blocking sends: never block the event loop.
                sync((ChannelFuture) future);
                return;
            }
            try {
                future.get();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                PlatformDependent.throwException(e);
            } catch (ExecutionException e) {
                Throwable cause = e.getCause();
                PlatformDependent.throwException(cause == null ? e : cause);
            }
        }

        @Override
        public boolean getBatchingAllowed() {
            return batchingAllowed.get();
        }

        /** Stores the flag; switching batching off flushes the channel. */
        @Override
        public void setBatchingAllowed(boolean batchingAllowed) throws IOException {
            boolean oldValue = this.batchingAllowed.getAndSet(batchingAllowed);
            if (oldValue && !batchingAllowed) {
                flushBatch();
            }
        }

        @Override
        public void flushBatch() throws IOException {
            channel.flush();
        }

        @Override
        public void sendPing(ByteBuffer applicationData) throws IOException, IllegalArgumentException {
            ChannelFuture future = channel.writeAndFlush(new PingWebSocketFrame(true, rsv, Unpooled.wrappedBuffer(applicationData)));
            sync(future);
        }

        @Override
        public void sendPong(ByteBuffer applicationData) throws IOException, IllegalArgumentException {
            ChannelFuture future = channel.writeAndFlush(new PongWebSocketFrame(true, rsv, Unpooled.wrappedBuffer(applicationData)));
            sync(future);
        }

        /**
         * Output stream that copies written bytes into a buffer and writes it to the channel.
         *
         * <p>NOTE(review): the bytes are written as a plain {@code ByteBuf}, not wrapped in a
         * WebSocket frame — confirm that the pipeline frames them before they reach the peer.
         */
        class BasicRemoteOutputStream extends OutputStream {
            @Override
            public void write(int b) throws IOException {
                int byteLen = 1;
                byte[] bytes = new byte[byteLen];
                IOUtil.setByte(bytes, 0, b);
                write(bytes, 0, byteLen);
            }

            @Override
            public void write(byte[] b, int off, int len) throws IOException {
                // Copy: the caller may reuse its array as soon as this method returns.
                ByteBuf ioByteBuf = channel.alloc().heapBuffer(len);
                try {
                    ioByteBuf.writeBytes(b, off, len);
                } catch (RuntimeException e) {
                    // The buffer never reached the channel, so it must be released here.
                    ioByteBuf.release();
                    throw e;
                }
                // Write AND flush: a write that is not flushed does not complete, so
                // waiting on it would block until some other caller flushed the channel.
                sync(channel.writeAndFlush(ioByteBuf));
            }

            @Override
            public void flush() throws IOException {
                channel.flush();
            }

            /** Only flushes; the stream stays usable because it is shared between callers. */
            @Override
            public void close() throws IOException {
                flush();
            }
        }
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy