All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.eclipse.jetty.websocket.javax.common.JavaxWebSocketFrameHandler Maven / Gradle / Ivy

The newest version!
//
// ========================================================================
// Copyright (c) 1995 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under the
// terms of the Eclipse Public License v. 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
// which is available at https://www.apache.org/licenses/LICENSE-2.0.
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//

package org.eclipse.jetty.websocket.javax.common;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodType;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import javax.websocket.ClientEndpointConfig;
import javax.websocket.CloseReason;
import javax.websocket.Decoder;
import javax.websocket.EndpointConfig;
import javax.websocket.MessageHandler;
import javax.websocket.PongMessage;
import javax.websocket.server.ServerEndpointConfig;

import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.thread.AutoLock;
import org.eclipse.jetty.websocket.core.CloseStatus;
import org.eclipse.jetty.websocket.core.CoreSession;
import org.eclipse.jetty.websocket.core.Frame;
import org.eclipse.jetty.websocket.core.FrameHandler;
import org.eclipse.jetty.websocket.core.OpCode;
import org.eclipse.jetty.websocket.core.exception.CloseException;
import org.eclipse.jetty.websocket.core.exception.ProtocolException;
import org.eclipse.jetty.websocket.core.exception.WebSocketException;
import org.eclipse.jetty.websocket.core.internal.messages.MessageSink;
import org.eclipse.jetty.websocket.core.internal.messages.PartialByteArrayMessageSink;
import org.eclipse.jetty.websocket.core.internal.messages.PartialByteBufferMessageSink;
import org.eclipse.jetty.websocket.core.internal.messages.PartialStringMessageSink;
import org.eclipse.jetty.websocket.core.internal.util.InvokerUtils;
import org.eclipse.jetty.websocket.javax.common.decoders.AvailableDecoders;
import org.eclipse.jetty.websocket.javax.common.decoders.RegisteredDecoder;
import org.eclipse.jetty.websocket.javax.common.messages.DecodedBinaryMessageSink;
import org.eclipse.jetty.websocket.javax.common.messages.DecodedBinaryStreamMessageSink;
import org.eclipse.jetty.websocket.javax.common.messages.DecodedTextMessageSink;
import org.eclipse.jetty.websocket.javax.common.messages.DecodedTextStreamMessageSink;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class JavaxWebSocketFrameHandler implements FrameHandler
{
    private final AutoLock lock = new AutoLock();
    private final Logger logger;
    private final JavaxWebSocketContainer container;
    private final Object endpointInstance;
    private final AtomicBoolean closeNotified = new AtomicBoolean();

    private MethodHandle openHandle;
    private MethodHandle closeHandle;
    private MethodHandle errorHandle;
    private MethodHandle pongHandle;
    private JavaxWebSocketMessageMetadata textMetadata;
    private JavaxWebSocketMessageMetadata binaryMetadata;
    private final UpgradeRequest upgradeRequest;
    private EndpointConfig endpointConfig;
    private final Map messageHandlerMap = new HashMap<>();
    private MessageSink textSink;
    private MessageSink binarySink;
    private MessageSink activeMessageSink;
    private JavaxWebSocketSession session;
    private CoreSession coreSession;
    protected byte dataType = OpCode.UNDEFINED;

    public JavaxWebSocketFrameHandler(JavaxWebSocketContainer container,
                                      UpgradeRequest upgradeRequest,
                                      Object endpointInstance,
                                      MethodHandle openHandle, MethodHandle closeHandle, MethodHandle errorHandle,
                                      JavaxWebSocketMessageMetadata textMetadata,
                                      JavaxWebSocketMessageMetadata binaryMetadata,
                                      MethodHandle pongHandle,
                                      EndpointConfig endpointConfig)
    {
        this.logger = LoggerFactory.getLogger(endpointInstance.getClass());

        this.container = container;
        this.upgradeRequest = upgradeRequest;
        if (endpointInstance instanceof ConfiguredEndpoint)
        {
            RuntimeException oops = new RuntimeException("ConfiguredEndpoint needs to be unwrapped");
            logger.warn("Unexpected ConfiguredEndpoint", oops);
            throw oops;
        }
        this.endpointInstance = endpointInstance;
        this.openHandle = openHandle;
        this.closeHandle = closeHandle;
        this.errorHandle = errorHandle;
        this.textMetadata = textMetadata;
        this.binaryMetadata = binaryMetadata;
        this.pongHandle = pongHandle;

        this.endpointConfig = endpointConfig;
    }

    public Object getEndpoint()
    {
        return endpointInstance;
    }

    public EndpointConfig getEndpointConfig()
    {
        return endpointConfig;
    }

    public JavaxWebSocketSession getSession()
    {
        return session;
    }

    @Override
    public void onOpen(CoreSession coreSession, Callback callback)
    {
        this.coreSession = coreSession;

        try
        {
            // Rewire EndpointConfig to call CoreSession setters if Jetty specific properties are set.
            endpointConfig = getWrappedEndpointConfig();
            session = new JavaxWebSocketSession(container, coreSession, this, endpointConfig);
            if (!session.isOpen())
                throw new IllegalStateException("Session is not open");

            openHandle = InvokerUtils.bindTo(openHandle, session, endpointConfig);
            closeHandle = InvokerUtils.bindTo(closeHandle, session);
            errorHandle = InvokerUtils.bindTo(errorHandle, session);
            pongHandle = InvokerUtils.bindTo(pongHandle, session);

            JavaxWebSocketMessageMetadata actualTextMetadata = JavaxWebSocketMessageMetadata.copyOf(textMetadata);
            if (actualTextMetadata != null)
            {
                if (actualTextMetadata.isMaxMessageSizeSet())
                    session.setMaxTextMessageBufferSize(actualTextMetadata.getMaxMessageSize());

                MethodHandle methodHandle = actualTextMetadata.getMethodHandle();
                methodHandle = InvokerUtils.bindTo(methodHandle, endpointInstance, endpointConfig, session);
                methodHandle = JavaxWebSocketFrameHandlerFactory.wrapNonVoidReturnType(methodHandle, session);
                actualTextMetadata.setMethodHandle(methodHandle);

                textSink = JavaxWebSocketFrameHandlerFactory.createMessageSink(session, actualTextMetadata);
                textMetadata = actualTextMetadata;
            }

            JavaxWebSocketMessageMetadata actualBinaryMetadata = JavaxWebSocketMessageMetadata.copyOf(binaryMetadata);
            if (actualBinaryMetadata != null)
            {
                if (actualBinaryMetadata.isMaxMessageSizeSet())
                    session.setMaxBinaryMessageBufferSize(actualBinaryMetadata.getMaxMessageSize());

                MethodHandle methodHandle = actualBinaryMetadata.getMethodHandle();
                methodHandle = InvokerUtils.bindTo(methodHandle, endpointInstance, endpointConfig, session);
                methodHandle = JavaxWebSocketFrameHandlerFactory.wrapNonVoidReturnType(methodHandle, session);
                actualBinaryMetadata.setMethodHandle(methodHandle);

                binarySink = JavaxWebSocketFrameHandlerFactory.createMessageSink(session, actualBinaryMetadata);
                binaryMetadata = actualBinaryMetadata;
            }

            if (openHandle != null)
                openHandle.invoke();

            if (session.isOpen())
                container.notifySessionListeners((listener) -> listener.onJavaxWebSocketSessionOpened(session));

            callback.succeeded();
            coreSession.demand(1);
        }
        catch (Throwable cause)
        {
            Exception wse = new WebSocketException(endpointInstance.getClass().getSimpleName() + " OPEN method error: " + cause.getMessage(), cause);
            callback.failed(wse);
        }
    }

    private EndpointConfig getWrappedEndpointConfig()
    {
        final Map listenerMap = new PutListenerMap(this.endpointConfig.getUserProperties(), this::configListener);

        EndpointConfig wrappedConfig;
        if (endpointConfig instanceof ServerEndpointConfig)
        {
            wrappedConfig = new ServerEndpointConfigWrapper((ServerEndpointConfig)endpointConfig)
            {
                @Override
                public Map getUserProperties()
                {
                    return listenerMap;
                }
            };
        }
        else if (endpointConfig instanceof ClientEndpointConfig)
        {
            wrappedConfig = new ClientEndpointConfigWrapper((ClientEndpointConfig)endpointConfig)
            {
                @Override
                public Map getUserProperties()
                {
                    return listenerMap;
                }
            };
        }
        else
        {
            wrappedConfig = new EndpointConfigWrapper(endpointConfig)
            {
                @Override
                public Map getUserProperties()
                {
                    return listenerMap;
                }
            };
        }

        return wrappedConfig;
    }

    @Override
    public void onFrame(Frame frame, Callback callback)
    {
        switch (frame.getOpCode())
        {
            case OpCode.TEXT:
                dataType = OpCode.TEXT;
                onText(frame, callback);
                break;
            case OpCode.BINARY:
                dataType = OpCode.BINARY;
                onBinary(frame, callback);
                break;
            case OpCode.CONTINUATION:
                onContinuation(frame, callback);
                break;
            case OpCode.PING:
                onPing(frame, callback);
                break;
            case OpCode.PONG:
                onPong(frame, callback);
                break;
            case OpCode.CLOSE:
                onClose(frame, callback);
                break;
            default:
                callback.failed(new IllegalStateException());
        }

        if (frame.isFin() && !frame.isControlFrame())
            dataType = OpCode.UNDEFINED;
    }

    public void onClose(Frame frame, Callback callback)
    {
        notifyOnClose(CloseStatus.getCloseStatus(frame), callback);
    }

    @Override
    public void onClosed(CloseStatus closeStatus, Callback callback)
    {
        if (activeMessageSink != null)
        {
            activeMessageSink.fail(new CloseException(closeStatus.getCode(), closeStatus.getCause()));
            activeMessageSink = null;
        }

        notifyOnClose(closeStatus, callback);
        container.notifySessionListeners((listener) -> listener.onJavaxWebSocketSessionClosed(session));

        // Close AvailableEncoders and AvailableDecoders to call destroy() on any instances of Encoder/Encoder created.
        session.getDecoders().close();
        session.getEncoders().close();
    }

    private void notifyOnClose(CloseStatus closeStatus, Callback callback)
    {
        // Make sure onClose is only notified once.
        if (!closeNotified.compareAndSet(false, true))
        {
            callback.succeeded();
            return;
        }

        try
        {
            if (closeHandle != null)
            {
                CloseReason closeReason = new CloseReason(CloseReason.CloseCodes.getCloseCode(closeStatus.getCode()), closeStatus.getReason());
                closeHandle.invoke(closeReason);
            }

            callback.succeeded();
        }
        catch (Throwable cause)
        {
            callback.failed(new WebSocketException(endpointInstance.getClass().getSimpleName() + " CLOSE method error: " + cause.getMessage(), cause));
        }
    }

    @Override
    public void onError(Throwable cause, Callback callback)
    {
        try
        {
            if (errorHandle != null)
                errorHandle.invoke(cause);
            else
                logger.warn("Unhandled Error: " + endpointInstance, cause);
            callback.succeeded();
        }
        catch (Throwable t)
        {
            WebSocketException wsError = new WebSocketException(endpointInstance.getClass().getSimpleName() + " ERROR method error: " + cause.getMessage(), t);
            wsError.addSuppressed(cause);
            callback.failed(wsError);
        }
    }

    @Override
    public boolean isDemanding()
    {
        return true;
    }

    public Set getMessageHandlers()
    {
        return messageHandlerMap.values().stream()
            .map(RegisteredMessageHandler::getMessageHandler)
            .collect(Collectors.toUnmodifiableSet());
    }

    public Map getMessageHandlerMap()
    {
        return messageHandlerMap;
    }

    public JavaxWebSocketMessageMetadata getBinaryMetadata()
    {
        return binaryMetadata;
    }

    public JavaxWebSocketMessageMetadata getTextMetadata()
    {
        return textMetadata;
    }

    public  void addMessageHandler(Class clazz, MessageHandler.Partial handler)
    {
        try
        {
            MethodHandle methodHandle = JavaxWebSocketFrameHandlerFactory.getServerMethodHandleLookup()
                .findVirtual(MessageHandler.Partial.class, "onMessage", MethodType.methodType(void.class, Object.class, boolean.class))
                .bindTo(handler);

            JavaxWebSocketMessageMetadata metadata = new JavaxWebSocketMessageMetadata();
            metadata.setMethodHandle(methodHandle);
            byte basicType;
            // MessageHandler.Partial has no decoder support!
            if (byte[].class.isAssignableFrom(clazz))
            {
                basicType = OpCode.BINARY;
                metadata.setSinkClass(PartialByteArrayMessageSink.class);
            }
            else if (ByteBuffer.class.isAssignableFrom(clazz))
            {
                basicType = OpCode.BINARY;
                metadata.setSinkClass(PartialByteBufferMessageSink.class);
            }
            else if (String.class.isAssignableFrom(clazz))
            {
                basicType = OpCode.TEXT;
                metadata.setSinkClass(PartialStringMessageSink.class);
            }
            else
            {
                throw new RuntimeException(
                    "Unable to add " + handler.getClass().getName() + " with type " + clazz + ": only supported types byte[], " + ByteBuffer.class.getName() +
                        ", " + String.class.getName());
            }

            // Register the Metadata as a MessageHandler.
            registerMessageHandler(clazz, handler, basicType, metadata);
        }
        catch (NoSuchMethodException e)
        {
            throw new IllegalStateException("Unable to find method", e);
        }
        catch (IllegalAccessException e)
        {
            throw new IllegalStateException("Unable to access " + handler.getClass().getName(), e);
        }
    }

    public  void addMessageHandler(Class clazz, MessageHandler.Whole handler)
    {
        try
        {
            MethodHandle methodHandle = JavaxWebSocketFrameHandlerFactory.getServerMethodHandleLookup()
                .findVirtual(MessageHandler.Whole.class, "onMessage", MethodType.methodType(void.class, Object.class))
                .bindTo(handler);

            if (PongMessage.class.isAssignableFrom(clazz))
            {
                assertBasicTypeNotRegistered(OpCode.PONG, handler);
                this.pongHandle = methodHandle;
                registerMessageHandler(OpCode.PONG, clazz, handler, null);
                return;
            }

            AvailableDecoders availableDecoders = session.getDecoders();
            RegisteredDecoder registeredDecoder = availableDecoders.getFirstRegisteredDecoder(clazz);
            if (registeredDecoder == null)
                throw new IllegalStateException("Unable to find Decoder for type: " + clazz);

            // Create the message metadata specific to the MessageHandler type.
            JavaxWebSocketMessageMetadata metadata = new JavaxWebSocketMessageMetadata();
            metadata.setMethodHandle(methodHandle);
            byte basicType;
            if (registeredDecoder.implementsInterface(Decoder.Binary.class))
            {
                basicType = OpCode.BINARY;
                metadata.setRegisteredDecoders(availableDecoders.getBinaryDecoders(clazz));
                metadata.setSinkClass(DecodedBinaryMessageSink.class);
            }
            else if (registeredDecoder.implementsInterface(Decoder.BinaryStream.class))
            {
                basicType = OpCode.BINARY;
                metadata.setRegisteredDecoders(availableDecoders.getBinaryStreamDecoders(clazz));
                metadata.setSinkClass(DecodedBinaryStreamMessageSink.class);
            }
            else if (registeredDecoder.implementsInterface(Decoder.Text.class))
            {
                basicType = OpCode.TEXT;
                metadata.setRegisteredDecoders(availableDecoders.getTextDecoders(clazz));
                metadata.setSinkClass(DecodedTextMessageSink.class);
            }
            else if (registeredDecoder.implementsInterface(Decoder.TextStream.class))
            {
                basicType = OpCode.TEXT;
                metadata.setRegisteredDecoders(availableDecoders.getTextStreamDecoders(clazz));
                metadata.setSinkClass(DecodedTextStreamMessageSink.class);
            }
            else
            {
                throw new RuntimeException("Unable to add " + handler.getClass().getName() + ": type " + clazz + " is unrecognized by declared decoders");
            }

            // Register the Metadata as a MessageHandler.
            registerMessageHandler(clazz, handler, basicType, metadata);
        }
        catch (NoSuchMethodException e)
        {
            throw new IllegalStateException("Unable to find method", e);
        }
        catch (IllegalAccessException e)
        {
            throw new IllegalStateException("Unable to access " + handler.getClass().getName(), e);
        }
    }

    private void assertBasicTypeNotRegistered(byte basicWebSocketType, MessageHandler replacement)
    {
        Object messageImpl;
        switch (basicWebSocketType)
        {
            case OpCode.TEXT:
                messageImpl = textSink;
                break;
            case OpCode.BINARY:
                messageImpl = binarySink;
                break;
            case OpCode.PONG:
                messageImpl = pongHandle;
                break;
            default:
                throw new IllegalStateException();
        }

        if (messageImpl != null)
        {
            throw new IllegalStateException("Cannot register " + replacement.getClass().getName() +
                ": Basic WebSocket type " + OpCode.name(basicWebSocketType) + " is already registered");
        }
    }

    private void registerMessageHandler(Class clazz, MessageHandler handler, byte basicMessageType, JavaxWebSocketMessageMetadata metadata)
    {
        assertBasicTypeNotRegistered(basicMessageType, handler);
        MessageSink messageSink = JavaxWebSocketFrameHandlerFactory.createMessageSink(session, metadata);
        switch (basicMessageType)
        {
            case OpCode.TEXT:
                this.textSink = registerMessageHandler(OpCode.TEXT, clazz, handler, messageSink);
                this.textMetadata = metadata;
                break;
            case OpCode.BINARY:
                this.binarySink = registerMessageHandler(OpCode.BINARY, clazz, handler, messageSink);
                this.binaryMetadata = metadata;
                break;
            default:
                throw new IllegalStateException();
        }
    }

    private  MessageSink registerMessageHandler(byte basicWebSocketMessageType, Class handlerType, MessageHandler handler, MessageSink messageSink)
    {
        try (AutoLock l = lock.lock())
        {
            RegisteredMessageHandler registeredHandler = messageHandlerMap.get(basicWebSocketMessageType);
            if (registeredHandler != null)
            {
                throw new IllegalStateException(String.format("Cannot register %s: Basic WebSocket type %s is already registered to %s",
                    handler.getClass().getName(),
                    OpCode.name(basicWebSocketMessageType),
                    registeredHandler.getMessageHandler().getClass().getName()
                ));
            }

            registeredHandler = new RegisteredMessageHandler(basicWebSocketMessageType, handlerType, handler);
            getMessageHandlerMap().put(registeredHandler.getWebsocketMessageType(), registeredHandler);
            return messageSink;
        }
    }

    public void removeMessageHandler(MessageHandler handler)
    {
        try (AutoLock l = lock.lock())
        {
            Optional> optionalEntry = messageHandlerMap.entrySet().stream()
                .filter((entry) -> entry.getValue().getMessageHandler().equals(handler))
                .findFirst();

            if (optionalEntry.isPresent())
            {
                byte key = optionalEntry.get().getKey();
                messageHandlerMap.remove(key);
                switch (key)
                {
                    case OpCode.PONG:
                        this.pongHandle = null;
                        break;
                    case OpCode.TEXT:
                        this.textMetadata = null;
                        this.textSink = null;
                        break;
                    case OpCode.BINARY:
                        this.binaryMetadata = null;
                        this.binarySink = null;
                        break;
                    default:
                        throw new IllegalStateException("Invalid MessageHandler type " + OpCode.name(key));
                }
            }
        }
    }

    public String toString()
    {
        StringBuilder ret = new StringBuilder();
        ret.append(this.getClass().getSimpleName());
        ret.append('@').append(Integer.toHexString(this.hashCode()));
        ret.append("[endpoint=");
        if (endpointInstance == null)
        {
            ret.append("");
        }
        else
        {
            ret.append(endpointInstance.getClass().getName());
        }
        ret.append(']');
        return ret.toString();
    }

    private void acceptMessage(Frame frame, Callback callback)
    {
        // No message sink is active
        if (activeMessageSink == null)
        {
            callback.succeeded();
            coreSession.demand(1);
            return;
        }

        // Accept the payload into the message sink
        MessageSink messageSink = activeMessageSink;
        if (frame.isFin())
            activeMessageSink = null;
        messageSink.accept(frame, callback);
    }

    public void onPing(Frame frame, Callback callback)
    {
        coreSession.sendFrame(new Frame(OpCode.PONG).setPayload(frame.getPayload()), Callback.from(() ->
        {
            callback.succeeded();
            coreSession.demand(1);
        }), false);
    }

    public void onPong(Frame frame, Callback callback)
    {
        if (pongHandle != null)
        {
            try
            {
                ByteBuffer payload = frame.getPayload();
                if (payload == null)
                    payload = BufferUtil.EMPTY_BUFFER;

                // Use JSR356 PongMessage interface
                JavaxWebSocketPongMessage pongMessage = new JavaxWebSocketPongMessage(payload);
                pongHandle.invoke(pongMessage);
            }
            catch (Throwable cause)
            {
                throw new WebSocketException(endpointInstance.getClass().getSimpleName() + " PONG method error: " + cause.getMessage(), cause);
            }
        }
        callback.succeeded();
        coreSession.demand(1);
    }

    public void onText(Frame frame, Callback callback)
    {
        if (activeMessageSink == null)
            activeMessageSink = textSink;

        acceptMessage(frame, callback);
    }

    public void onBinary(Frame frame, Callback callback)
    {
        if (activeMessageSink == null)
            activeMessageSink = binarySink;

        acceptMessage(frame, callback);
    }

    public void onContinuation(Frame frame, Callback callback)
    {
        switch (dataType)
        {
            case OpCode.TEXT:
                onText(frame, callback);
                break;
            case OpCode.BINARY:
                onBinary(frame, callback);
                break;
            default:
                throw new ProtocolException("Unable to process continuation during dataType " + dataType);
        }
    }

    public UpgradeRequest getUpgradeRequest()
    {
        return upgradeRequest;
    }

    private void configListener(String key, Object value)
    {
        if (!key.startsWith("org.eclipse.jetty.websocket."))
            return;

        switch (key)
        {
            case "org.eclipse.jetty.websocket.autoFragment":
                coreSession.setAutoFragment((Boolean)value);
                break;

            case "org.eclipse.jetty.websocket.maxFrameSize":
                coreSession.setMaxFrameSize((Long)value);
                break;

            case "org.eclipse.jetty.websocket.outputBufferSize":
                coreSession.setOutputBufferSize((Integer)value);
                break;

            case "org.eclipse.jetty.websocket.inputBufferSize":
                coreSession.setInputBufferSize((Integer)value);
                break;
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy