All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.cassandra.transport.Envelope Maven / Gradle / Ivy

Go to download

The Apache Cassandra Project develops a highly scalable second-generation distributed database, bringing together Dynamo's fully distributed design and Bigtable's ColumnFamily-based data model.

There is a newer version: 5.0-rc1
Show newest version

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.cassandra.transport;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.EnumSet;
import java.util.List;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.*;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.MessageToMessageDecoder;
import io.netty.handler.codec.MessageToMessageEncoder;
import org.apache.cassandra.config.DatabaseDescriptor;
import org.apache.cassandra.exceptions.InvalidRequestException;
import org.apache.cassandra.metrics.ClientMessageSizeMetrics;
import org.apache.cassandra.transport.messages.ErrorMessage;
import org.apache.cassandra.utils.ByteBufferUtil;

public class Envelope
{
    public static final byte PROTOCOL_VERSION_MASK = 0x7f;

    public final Header header;
    public final ByteBuf body;

    /**
     * An on-wire message envelope consists of a header and a body.
     *
     * The header is defined the following way in native protocol version 3 and later:
     *
     *   0         8        16        24        32         40
     *   +---------+---------+---------+---------+---------+
     *   | version |  flags  |      stream       | opcode  |
     *   +---------+---------+---------+---------+---------+
     *   |                length                 |
     *   +---------+---------+---------+---------+
     */
    public Envelope(Header header, ByteBuf body)
    {
        this.header = header;
        this.body = body;
    }

    public void retain()
    {
        body.retain();
    }

    public boolean release()
    {
        return body.release();
    }

    @VisibleForTesting
    public Envelope clone()
    {
        return new Envelope(header, Unpooled.wrappedBuffer(ByteBufferUtil.clone(body.nioBuffer())));
    }

    public static Envelope create(Message.Type type, int streamId, ProtocolVersion version, EnumSet flags, ByteBuf body)
    {
        Header header = new Header(version, flags, streamId, type, body.readableBytes());
        return new Envelope(header, body);
    }

    // used by V4 and earlier in Encoder.encode
    public ByteBuf encodeHeader()
    {
        ByteBuf buf = CBUtil.allocator.buffer(Header.LENGTH);

        Message.Type type = header.type;
        buf.writeByte(type.direction.addToVersion(header.version.asInt()));
        buf.writeByte(Header.Flag.serialize(header.flags));

        // Continue to support writing pre-v3 headers so that we can give proper error messages to drivers that
        // connect with the v1/v2 protocol. See CASSANDRA-11464.
        if (header.version.isGreaterOrEqualTo(ProtocolVersion.V3))
            buf.writeShort(header.streamId);
        else
            buf.writeByte(header.streamId);

        buf.writeByte(type.opcode);
        buf.writeInt(body.readableBytes());
        return buf;
    }

    // Used by V5 and later
    public void encodeHeaderInto(ByteBuffer buf)
    {
        buf.put((byte) header.type.direction.addToVersion(header.version.asInt()));
        buf.put((byte) Envelope.Header.Flag.serialize(header.flags));

        if (header.version.isGreaterOrEqualTo(ProtocolVersion.V3))
            buf.putShort((short) header.streamId);
        else
            buf.put((byte) header.streamId);

        buf.put((byte) header.type.opcode);
        buf.putInt(body.readableBytes());
    }

    // Used by V5 and later
    public void encodeInto(ByteBuffer buf)
    {
        encodeHeaderInto(buf);
        buf.put(body.nioBuffer());
    }

    public static class Header
    {
        // 9 bytes in protocol version 3 and later
        public static final int LENGTH = 9;

        public static final int BODY_LENGTH_SIZE = 4;

        public final ProtocolVersion version;
        public final EnumSet flags;
        public final int streamId;
        public final Message.Type type;
        public final long bodySizeInBytes;

        private Header(ProtocolVersion version, EnumSet flags, int streamId, Message.Type type, long bodySizeInBytes)
        {
            this.version = version;
            this.flags = flags;
            this.streamId = streamId;
            this.type = type;
            this.bodySizeInBytes = bodySizeInBytes;
        }

        public enum Flag
        {
            // The order of that enum matters!!
            COMPRESSED,
            TRACING,
            CUSTOM_PAYLOAD,
            WARNING,
            USE_BETA;

            private static final Flag[] ALL_VALUES = values();

            public static EnumSet deserialize(int flags)
            {
                EnumSet set = EnumSet.noneOf(Flag.class);
                for (int n = 0; n < ALL_VALUES.length; n++)
                {
                    if ((flags & (1 << n)) != 0)
                        set.add(ALL_VALUES[n]);
                }
                return set;
            }

            public static int serialize(EnumSet flags)
            {
                int i = 0;
                for (Flag flag : flags)
                    i |= 1 << flag.ordinal();
                return i;
            }
        }
    }

    public Envelope with(ByteBuf newBody)
    {
        return new Envelope(header, newBody);
    }

    public static class Decoder extends ByteToMessageDecoder
    {
        private static final int MAX_TOTAL_LENGTH = DatabaseDescriptor.getNativeTransportMaxFrameSize();

        private boolean discardingTooLongMessage;
        private long tooLongTotalLength;
        private long bytesToDiscard;
        private int tooLongStreamId;

        /**
         * Used by protocol V5 and later to extract a CQL message header from the buffer containing
         * it, without modifying the position of the underlying buffer. This essentially mirrors the
         * pre-V5 code in {@link Decoder#decode(ByteBuf)}, with three differences:
         * 
    *
  • The input is a ByteBuffer rather than a ByteBuf
  • *
  • This cannot return null, as V5 always deals with entire CQL messages. Coalescing of bytes * off the wire happens at the layer below, in {@link org.apache.cassandra.net.FrameDecoder}
  • *
  • This method never throws {@link ProtocolException}. Instead, a subclass of * {@link HeaderExtractionResult} is returned which may provide either a {@link Header} or a * {@link ProtocolException},depending on the result of its {@link HeaderExtractionResult#isSuccess()} * method.
  • *
* * @param buffer ByteBuffer containing the message envelope * @return The result of attempting to extract a header from the input buffer. */ HeaderExtractionResult extractHeader(ByteBuffer buffer) { Preconditions.checkArgument(buffer.remaining() >= Header.LENGTH, "Undersized buffer supplied. Expected %s, actual %s", Header.LENGTH, buffer.remaining()); int idx = buffer.position(); int firstByte = buffer.get(idx++); int versionNum = firstByte & PROTOCOL_VERSION_MASK; int flags = buffer.get(idx++); int streamId = buffer.getShort(idx); idx += 2; int opcode = buffer.get(idx++); long bodyLength = buffer.getInt(idx); // if a negative length is read, return error but report length as 0 so we don't attempt to skip if (bodyLength < 0) return new HeaderExtractionResult.Error(new ProtocolException("Invalid value for envelope header body length field: " + bodyLength), streamId, bodyLength); Message.Direction direction = Message.Direction.extractFromVersion(firstByte); Message.Type type; ProtocolVersion version; EnumSet decodedFlags; try { // This throws a protocol exception if the version number is unsupported, // the opcode is unknown or invalid flags are set for the version version = ProtocolVersion.decode(versionNum, DatabaseDescriptor.getNativeTransportAllowOlderProtocols()); decodedFlags = decodeFlags(version, flags); type = Message.Type.fromOpcode(opcode, direction); return new HeaderExtractionResult.Success(new Header(version, decodedFlags, streamId, type, bodyLength)); } catch (ProtocolException e) { // Including the streamId and bodyLength is a best effort to allow the caller // to send a meaningful response to the client and continue processing the // rest of the frame. It's possible that these are bogus and may have contributed // to the ProtocolException. If so, the upstream CQLMessageHandler should run into // further errors and once it breaches its threshold for consecutive errors, it will // cause the channel to be closed. return new HeaderExtractionResult.Error(e, streamId, bodyLength); } } public static abstract class HeaderExtractionResult { enum Outcome { SUCCESS, ERROR }; private final Outcome outcome; private final int streamId; private final long bodyLength; private HeaderExtractionResult(Outcome outcome, int streamId, long bodyLength) { this.outcome = outcome; this.streamId = streamId; this.bodyLength = bodyLength; } boolean isSuccess() { return outcome == Outcome.SUCCESS; } int streamId() { return streamId; } long bodyLength() { return bodyLength; } Header header() { throw new IllegalStateException(String.format("Unable to provide header from extraction result : %s", outcome)); }; ProtocolException error() { throw new IllegalStateException(String.format("Unable to provide error from extraction result : %s", outcome)); } private static class Success extends HeaderExtractionResult { private final Header header; Success(Header header) { super(Outcome.SUCCESS, header.streamId, header.bodySizeInBytes); this.header = header; } @Override Header header() { return header; } } private static class Error extends HeaderExtractionResult { private final ProtocolException error; private Error(ProtocolException error, int streamId, long bodyLength) { super(Outcome.ERROR, streamId, bodyLength); this.error = error; } @Override ProtocolException error() { return error; } } } @VisibleForTesting Envelope decode(ByteBuf buffer) { if (discardingTooLongMessage) { bytesToDiscard = discard(buffer, bytesToDiscard); // If we have discarded everything, throw the exception if (bytesToDiscard <= 0) fail(); return null; } int readableBytes = buffer.readableBytes(); if (readableBytes == 0) return null; int idx = buffer.readerIndex(); // Check the first byte for the protocol version before we wait for a complete header. Protocol versions // 1 and 2 use a shorter header, so we may never have a complete header's worth of bytes. int firstByte = buffer.getByte(idx++); Message.Direction direction = Message.Direction.extractFromVersion(firstByte); int versionNum = firstByte & PROTOCOL_VERSION_MASK; ProtocolVersion version; try { version = ProtocolVersion.decode(versionNum, DatabaseDescriptor.getNativeTransportAllowOlderProtocols()); } catch (ProtocolException e) { // Skip the remaining useless bytes. Otherwise the channel closing logic may try to decode again. buffer.skipBytes(readableBytes); throw e; } // Wait until we have the complete header if (readableBytes < Header.LENGTH) return null; int flags = buffer.getByte(idx++); EnumSet decodedFlags = decodeFlags(version, flags); int streamId = buffer.getShort(idx); idx += 2; // This throws a protocol exceptions if the opcode is unknown Message.Type type; try { type = Message.Type.fromOpcode(buffer.getByte(idx++), direction); } catch (ProtocolException e) { throw ErrorMessage.wrap(e, streamId); } long bodyLength = buffer.getUnsignedInt(idx); idx += Header.BODY_LENGTH_SIZE; long totalLength = bodyLength + Header.LENGTH; if (totalLength > MAX_TOTAL_LENGTH) { // Enter the discard mode and discard everything received so far. discardingTooLongMessage = true; tooLongStreamId = streamId; tooLongTotalLength = totalLength; bytesToDiscard = discard(buffer, totalLength); if (bytesToDiscard <= 0) fail(); return null; } if (buffer.readableBytes() < totalLength) return null; ClientMessageSizeMetrics.bytesReceived.inc(totalLength); ClientMessageSizeMetrics.bytesReceivedPerRequest.update(totalLength); // extract body ByteBuf body = buffer.slice(idx, (int) bodyLength); body.retain(); idx += bodyLength; buffer.readerIndex(idx); return new Envelope(new Header(version, decodedFlags, streamId, type, bodyLength), body); } private EnumSet decodeFlags(ProtocolVersion version, int flags) { EnumSet decodedFlags = Header.Flag.deserialize(flags); if (version.isBeta() && !decodedFlags.contains(Header.Flag.USE_BETA)) throw new ProtocolException(String.format("Beta version of the protocol used (%s), but USE_BETA flag is unset", version), version); return decodedFlags; } @Override protected void decode(ChannelHandlerContext ctx, ByteBuf buffer, List results) { Envelope envelope = decode(buffer); if (envelope == null) return; results.add(envelope); } private void fail() { // Reset to the initial state and throw the exception long tooLongTotalLength = this.tooLongTotalLength; this.tooLongTotalLength = 0; discardingTooLongMessage = false; String msg = String.format("Request is too big: length %d exceeds maximum allowed length %d.", tooLongTotalLength, MAX_TOTAL_LENGTH); throw ErrorMessage.wrap(new InvalidRequestException(msg), tooLongStreamId); } } // How much remains to be discarded private static long discard(ByteBuf buffer, long remainingToDiscard) { int availableToDiscard = (int) Math.min(remainingToDiscard, buffer.readableBytes()); buffer.skipBytes(availableToDiscard); return remainingToDiscard - availableToDiscard; } @ChannelHandler.Sharable public static class Encoder extends MessageToMessageEncoder { public static final Encoder instance = new Envelope.Encoder(); private Encoder(){} public void encode(ChannelHandlerContext ctx, Envelope source, List results) { ByteBuf serializedHeader = source.encodeHeader(); int messageSize = serializedHeader.readableBytes() + source.body.readableBytes(); ClientMessageSizeMetrics.bytesSent.inc(messageSize); ClientMessageSizeMetrics.bytesSentPerResponse.update(messageSize); results.add(serializedHeader); results.add(source.body); } } @ChannelHandler.Sharable public static class Decompressor extends MessageToMessageDecoder { public static Decompressor instance = new Envelope.Decompressor(); private Decompressor(){} public void decode(ChannelHandlerContext ctx, Envelope source, List results) throws IOException { Connection connection = ctx.channel().attr(Connection.attributeKey).get(); if (!source.header.flags.contains(Header.Flag.COMPRESSED) || connection == null) { results.add(source); return; } org.apache.cassandra.transport.Compressor compressor = connection.getCompressor(); if (compressor == null) { results.add(source); return; } results.add(compressor.decompress(source)); } } @ChannelHandler.Sharable public static class Compressor extends MessageToMessageEncoder { public static Compressor instance = new Compressor(); private Compressor(){} public void encode(ChannelHandlerContext ctx, Envelope source, List results) throws IOException { Connection connection = ctx.channel().attr(Connection.attributeKey).get(); // Never compress STARTUP messages if (source.header.type == Message.Type.STARTUP || connection == null) { results.add(source); return; } org.apache.cassandra.transport.Compressor compressor = connection.getCompressor(); if (compressor == null) { results.add(source); return; } source.header.flags.add(Header.Flag.COMPRESSED); results.add(compressor.compress(source)); } } }