All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.qpid.proton.engine.impl.WebSocketHandlerImpl Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.qpid.proton.engine.impl;

import org.apache.qpid.proton.engine.WebSocketHandler;
import org.apache.qpid.proton.engine.WebSocketHeader;

import java.io.*;

import java.nio.BufferOverflowException;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;

import java.security.SecureRandom;

import java.util.*;

/**
 * Client-side WebSocket support used by the transport.
 *
 * <p>Creates the HTTP upgrade request and validates its reply (delegating to
 * {@link WebSocketUpgrade}), wraps outgoing data into single, final, masked binary
 * frames, decodes the header of incoming frames, and turns a received ping frame
 * into a pong frame.
 */
public class WebSocketHandlerImpl implements WebSocketHandler
{
    /**
     * Shared generator for frame masking keys. SecureRandom is safe for concurrent use,
     * and sharing one instance avoids creating and seeding a new generator per frame.
     */
    private static final SecureRandom MASKING_KEY_RANDOM = new SecureRandom();

    /** Number of masking-key bytes written after the payload length of a masked frame. */
    private static final int MASKING_KEY_LENGTH = 4;

    /**
     * Pending upgrade handshake: set by createUpgradeRequest, cleared once a
     * non-empty reply has been passed to validateUpgradeReply.
     */
    private WebSocketUpgrade _webSocketUpgrade = null;

    /**
     * Creates the HTTP upgrade request and remembers the handshake so that the
     * reply can later be checked by {@link #validateUpgradeReply(ByteBuffer)}.
     *
     * @return the upgrade request text produced by {@link WebSocketUpgrade#createUpgradeRequest()}
     */
    @Override
    public String createUpgradeRequest(String hostName, String webSocketPath, int webSocketPort, String webSocketProtocol, Map additionalHeaders)
    {
        _webSocketUpgrade = createWebSocketUpgrade(hostName, webSocketPath, webSocketPort, webSocketProtocol, additionalHeaders);
        return _webSocketUpgrade.createUpgradeRequest();
    }

    /**
     * Builds a pong frame from a received ping frame.
     *
     * <p>The content of {@code ping} from index 0 up to its capacity (independent of
     * its position and limit) is copied into {@code pong}, and the first byte is
     * replaced with FIN + PONG opcode. Every other byte is copied verbatim. The ping
     * buffer's content, position and limit are not modified. On return {@code pong}
     * has been cleared and filled, so its position is just past the copied bytes.
     * If {@code ping} has no remaining bytes, {@code pong} is left empty
     * (position and limit 0).
     *
     * <p>NOTE(review): the second header byte (mask bit / length) and any bytes
     * between the ping's limit and its capacity are copied unchanged. Confirm this
     * is what the sender of the pong expects.
     *
     * @param ping buffer holding the received ping frame
     * @param pong output buffer; its capacity must be at least that of {@code ping}
     * @throws IllegalArgumentException if either buffer is null or {@code pong} is too small
     */
    @Override
    public void createPong(ByteBuffer ping, ByteBuffer pong)
    {
        if ((ping == null) || (pong == null))
        {
            throw new IllegalArgumentException("input parameter cannot be null");
        }

        if (ping.capacity() > pong.capacity())
        {
            throw new IllegalArgumentException("insufficient output buffer size");
        }

        pong.clear();

        if (ping.remaining() > 0)
        {
            // Copy through a duplicate so the caller's ping is neither mutated nor
            // repositioned. This also works for direct, read-only and sliced buffers,
            // where array() is unavailable or does not start at the buffer's index 0.
            ByteBuffer source = ping.duplicate();
            source.clear();
            byte[] frame = new byte[source.remaining()];
            source.get(frame);

            frame[0] = (byte) (WebSocketHeader.FINBIT_MASK | WebSocketHeader.OPCODE_PONG);
            pong.put(frame);
        }
        else
        {
            pong.limit(0);
        }
    }

    /**
     * Validates the server's reply to the upgrade request.
     *
     * <p>All remaining bytes of {@code buffer} are consumed and passed to the pending
     * {@link WebSocketUpgrade}. The pending handshake is then discarded.
     *
     * @param buffer bytes received from the server
     * @return the validation result, or false if no upgrade request is pending or
     *         {@code buffer} has no remaining bytes (in which case nothing is consumed)
     */
    @Override
    public Boolean validateUpgradeReply(ByteBuffer buffer)
    {
        if ((_webSocketUpgrade == null) || (buffer.remaining() <= 0))
        {
            return Boolean.FALSE;
        }

        byte[] data = new byte[buffer.remaining()];
        buffer.get(data);

        Boolean valid = _webSocketUpgrade.validateUpgradeReply(data);
        _webSocketUpgrade = null;
        return valid;
    }

    /**
     * Wraps the remaining bytes of {@code srcBuffer} into a single WebSocket frame
     * written to {@code dstBuffer}.
     *
     * <p>The frame is always final (FIN set), binary (AMQP payload) and masked with
     * a fresh random key, as RFC 6455 requires for client-to-server frames. The
     * payload length is encoded in 7, 7+16 or 7+64 bits depending on its size.
     *
     * <p>{@code dstBuffer} is always cleared first. On success its position is just
     * past the frame and {@code srcBuffer} has been fully consumed. If
     * {@code srcBuffer} is empty, nothing is written.
     *
     * @param srcBuffer payload to send
     * @param dstBuffer destination for the complete frame
     * @throws IllegalArgumentException if either buffer is null
     * @throws OutOfMemoryError if the frame does not fit in {@code dstBuffer}'s
     *         capacity (kept for compatibility); {@code srcBuffer} is left unconsumed
     */
    @Override
    public void wrapBuffer(ByteBuffer srcBuffer, ByteBuffer dstBuffer)
    {
        if ((srcBuffer == null) || (dstBuffer == null))
        {
            throw new IllegalArgumentException("input parameter is null");
        }

        dstBuffer.clear();

        final int dataLength = srcBuffer.remaining();
        if (dataLength <= 0)
        {
            return;
        }

        // Check the full frame fits before touching the source, so a failed call
        // consumes no input. Use long to rule out int overflow near Integer.MAX_VALUE.
        final long frameLength = (long) maskedHeaderLength(dataLength) + dataLength;
        if (dstBuffer.capacity() < frameLength)
        {
            throw new OutOfMemoryError("insufficient output buffer size");
        }

        final byte[] maskingKey = createRandomMaskingKey();

        // First byte: FIN (single, final frame) + binary opcode (AMQP).
        dstBuffer.put((byte) (WebSocketHeader.FINBIT_MASK | WebSocketHeader.OPCODE_BINARY));

        // Second byte: mask bit (RFC: "client MUST mask all frames that it sends to
        // the server") combined with the payload length or an extended-length marker.
        if (dataLength <= WebSocketHeader.PAYLOAD_SHORT_MAX)
        {
            // 0-125: the length fits directly in the second byte.
            dstBuffer.put((byte) (WebSocketHeader.MASKBIT_MASK | dataLength));
        }
        else if (dataLength <= WebSocketHeader.PAYLOAD_MEDIUM_MAX)
        {
            // 126: the next 2 bytes hold the length as a 16-bit unsigned integer.
            dstBuffer.put((byte) (WebSocketHeader.MASKBIT_MASK | WebSocketHeader.PAYLOAD_EXTENDED_16));
            putBigEndian(dstBuffer, dataLength, 2);
        }
        else
        {
            // 127: the next 8 bytes hold the length as a 64-bit integer (MSB 0).
            // The value is widened to long before shifting: shifting an int by 32+
            // bits wraps the shift distance and would corrupt the high word.
            dstBuffer.put((byte) (WebSocketHeader.MASKBIT_MASK | WebSocketHeader.PAYLOAD_EXTENDED_64));
            putBigEndian(dstBuffer, dataLength, 8);
        }

        dstBuffer.put(maskingKey, 0, MASKING_KEY_LENGTH);

        // Payload, XOR-masked byte by byte with the cycling 4-byte key.
        for (int i = 0; i < dataLength; i++)
        {
            dstBuffer.put((byte) (srcBuffer.get() ^ maskingKey[i % MASKING_KEY_LENGTH]));
        }
    }

    /**
     * Decodes the header of the WebSocket frame at {@code srcBuffer}'s position.
     *
     * <p>The first two header bytes and any extended payload-length field are
     * consumed. The masking key (if any) and the payload are not.
     *
     * <p>NOTE(review): decoding only happens when strictly more than
     * MIN_HEADER_LENGTH bytes remain. A complete frame consisting of just that
     * many bytes (for example an empty close or ping frame) is reported as UNKNOWN.
     * Confirm the caller relies on this.
     *
     * @param srcBuffer received bytes
     * @return the payload length and message type. If the extended length field is
     *         incomplete, the type is HEADER_CHUNK with length -1. If too few bytes
     *         remain to decode anything, the type is UNKNOWN with length 0.
     * @throws IllegalArgumentException if {@code srcBuffer} is null
     */
    @Override
    public WebsocketTuple unwrapBuffer(ByteBuffer srcBuffer)
    {
        if (srcBuffer == null)
        {
            throw new IllegalArgumentException("input parameter is null");
        }

        WebsocketTuple result = new WebsocketTuple(0, WebSocketMessageType.WEB_SOCKET_MESSAGE_TYPE_UNKNOWN);

        if (srcBuffer.remaining() <= WebSocketHeader.MIN_HEADER_LENGTH)
        {
            return result;
        }

        final byte firstByte = srcBuffer.get();
        final byte secondByte = srcBuffer.get();

        final byte opcode = (byte) (firstByte & WebSocketHeader.OPCODE_MASK);
        // NOTE(review): the mask bit of the second byte is not inspected. For a masked
        // inbound frame, the 4-byte masking key stays in the buffer ahead of the payload.
        final byte payloadLength = (byte) (secondByte & WebSocketHeader.PAYLOAD_MASK);

        long finalPayloadLength = -1;
        boolean headerComplete = true;

        if (payloadLength <= WebSocketHeader.PAYLOAD_SHORT_MAX)
        {
            finalPayloadLength = payloadLength;
        }
        else if (payloadLength == WebSocketHeader.PAYLOAD_EXTENDED_16)
        {
            if (srcBuffer.remaining() >= Short.BYTES)
            {
                // Mask to read the 16-bit field as unsigned.
                finalPayloadLength = srcBuffer.getShort() & 0xFFFF;
            }
            else
            {
                headerComplete = false;
            }
        }
        else if (payloadLength == WebSocketHeader.PAYLOAD_EXTENDED_64)
        {
            if (srcBuffer.remaining() >= Long.BYTES)
            {
                // NOTE(review): a peer that sets the MSB (forbidden by RFC 6455)
                // yields a negative length here.
                finalPayloadLength = srcBuffer.getLong();
            }
            else
            {
                headerComplete = false;
            }
        }

        result.setLength(finalPayloadLength);
        result.setType(headerComplete
                ? messageTypeForOpcode(opcode)
                : WebSocketMessageType.WEB_SOCKET_MESSAGE_TYPE_HEADER_CHUNK);

        return result;
    }

    /**
     * Factory for the upgrade handshake helper; protected so it can be replaced.
     */
    protected WebSocketUpgrade createWebSocketUpgrade(String hostName, String webSocketPath, int webSocketPort, String webSocketProtocol, Map additionalHeaders)
    {
        return new WebSocketUpgrade(hostName, webSocketPath, webSocketPort, webSocketProtocol, additionalHeaders);
    }

    /**
     * Returns a new random 4-byte masking key drawn from a shared SecureRandom.
     * Protected so it can be replaced (e.g. with a fixed key).
     */
    protected byte[] createRandomMaskingKey()
    {
        final byte[] maskingKey = new byte[MASKING_KEY_LENGTH];
        MASKING_KEY_RANDOM.nextBytes(maskingKey);
        return maskingKey;
    }

    /**
     * Returns the size of the masked frame header that wrapBuffer writes for a
     * payload of {@code payloadSize} bytes, or 0 if {@code payloadSize} is not positive.
     */
    public int calculateHeaderSize(int payloadSize)
    {
        return (payloadSize > 0) ? maskedHeaderLength(payloadSize) : 0;
    }

    /** Masked header length (base bytes + extended length + masking key) for a positive payload size. */
    private static int maskedHeaderLength(int payloadSize)
    {
        if (payloadSize <= WebSocketHeader.PAYLOAD_SHORT_MAX)
        {
            return WebSocketHeader.MIN_HEADER_LENGTH_MASKED;
        }
        if (payloadSize <= WebSocketHeader.PAYLOAD_MEDIUM_MAX)
        {
            return WebSocketHeader.MED_HEADER_LENGTH_MASKED;
        }
        return WebSocketHeader.MAX_HEADER_LENGTH_MASKED;
    }

    /** Writes the low {@code byteCount} bytes of {@code value}, most significant first, independent of the buffer's byte order. */
    private static void putBigEndian(ByteBuffer dst, long value, int byteCount)
    {
        for (int shift = (byteCount - 1) * Byte.SIZE; shift >= 0; shift -= Byte.SIZE)
        {
            dst.put((byte) (value >>> shift));
        }
    }

    /** Maps a frame opcode to the message type reported to the caller; unrecognized opcodes map to UNKNOWN. */
    private static WebSocketMessageType messageTypeForOpcode(byte opcode)
    {
        if (opcode == WebSocketHeader.OPCODE_BINARY)
        {
            return WebSocketMessageType.WEB_SOCKET_MESSAGE_TYPE_AMQP;
        }
        if (opcode == WebSocketHeader.OPCODE_PING)
        {
            return WebSocketMessageType.WEB_SOCKET_MESSAGE_TYPE_PING;
        }
        if (opcode == WebSocketHeader.OPCODE_CLOSE)
        {
            return WebSocketMessageType.WEB_SOCKET_MESSAGE_TYPE_CLOSE;
        }
        if (opcode == WebSocketHeader.OPCODE_CONTINUATION)
        {
            return WebSocketMessageType.WEB_SOCKET_MESSAGE_TYPE_CHUNK;
        }
        return WebSocketMessageType.WEB_SOCKET_MESSAGE_TYPE_UNKNOWN;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy