All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.cassandra.net.Message Maven / Gradle / Ivy

Go to download

The Apache Cassandra Project develops a highly scalable second-generation distributed database, bringing together Dynamo's fully distributed design and Bigtable's ColumnFamily-based data model.

There is a newer version: 5.0.2
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.cassandra.net;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Collections;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import javax.annotation.Nullable;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.primitives.Ints;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.cassandra.config.DatabaseDescriptor;
import org.apache.cassandra.exceptions.RequestFailureReason;
import org.apache.cassandra.io.IVersionedAsymmetricSerializer;
import org.apache.cassandra.io.IVersionedSerializer;
import org.apache.cassandra.io.util.DataInputBuffer;
import org.apache.cassandra.io.util.DataInputPlus;
import org.apache.cassandra.io.util.DataOutputPlus;
import org.apache.cassandra.locator.InetAddressAndPort;
import org.apache.cassandra.tracing.Tracing;
import org.apache.cassandra.tracing.Tracing.TraceType;
import org.apache.cassandra.transport.Dispatcher;
import org.apache.cassandra.utils.MonotonicClockTranslation;
import org.apache.cassandra.utils.NoSpamLogger;
import org.apache.cassandra.utils.TimeUUID;

import static java.util.concurrent.TimeUnit.MINUTES;
import static java.util.concurrent.TimeUnit.NANOSECONDS;

import static org.apache.cassandra.db.TypeSizes.sizeof;
import static org.apache.cassandra.db.TypeSizes.sizeofUnsignedVInt;
import static org.apache.cassandra.locator.InetAddressAndPort.Serializer.inetAddressAndPortSerializer;
import static org.apache.cassandra.net.MessagingService.VERSION_3014;
import static org.apache.cassandra.net.MessagingService.VERSION_30;
import static org.apache.cassandra.net.MessagingService.VERSION_40;
import static org.apache.cassandra.net.MessagingService.instance;
import static org.apache.cassandra.utils.FBUtilities.getBroadcastAddressAndPort;
import static org.apache.cassandra.utils.MonotonicClock.Global.approxTime;
import static org.apache.cassandra.utils.vint.VIntCoding.computeUnsignedVIntSize;
import static org.apache.cassandra.utils.vint.VIntCoding.getUnsignedVInt;
import static org.apache.cassandra.utils.vint.VIntCoding.skipUnsignedVInt;

/**
 * Immutable main unit of internode communication - what used to be {@code MessageIn} and {@code MessageOut} fused
 * in one class.
 *
 * @param <T> The type of the message payload.
 */
public class Message
{
    private static final Logger logger = LoggerFactory.getLogger(Message.class);
    // NoSpamLogger view over `logger` with a one-minute interval, for rate-limited logging
    private static final NoSpamLogger noSpam1m = NoSpamLogger.getLogger(logger, 1, TimeUnit.MINUTES);

    /** Metadata of this message: id, verb, sender, creation/expiry times, flags and params. */
    public final Header header;
    /** The body of this message. */
    public final T payload;

    /** Pairs an existing header with a payload as-is; performs no validation. */
    Message(Header messageHeader, T messagePayload)
    {
        header = messageHeader;
        payload = messagePayload;
    }

    /** Returns the sender of this message. */
    public InetAddressAndPort from()
    {
        return this.header.from;
    }

    /**
     * Returns {@code true} if this message has crossed the node boundary, i.e. its sender is not this node's
     * broadcast address.
     */
    public boolean isCrossNode()
    {
        InetAddressAndPort sender = from();
        return !sender.equals(getBroadcastAddressAndPort());
    }

    /**
     * Returns the id of the request/message. From 4.0 onwards a single id can be shared by several messages that
     * belong to the same logical request, whereas pre-4.0 versions allocated a new id for every message sent.
     */
    public long id()
    {
        return this.header.id;
    }

    /** Returns the verb of this message. */
    public Verb verb()
    {
        return this.header.verb;
    }

    /** Returns {@code true} if this message is a {@code FAILURE_RSP} response. */
    boolean isFailureResponse()
    {
        return Verb.FAILURE_RSP == verb();
    }

    /**
     * Returns the creation time of this message, in {@code approxTime} nanoseconds. If cross-node timeouts are enabled
     * ({@link DatabaseDescriptor#hasCrossNodeTimeout()}), {@code deserialize()} uses the marshalled value; otherwise
     * it uses the current time on the deserializing machine.
     */
    public long createdAtNanos()
    {
        return this.header.createdAtNanos;
    }

    /** Returns the expiry deadline of this message, in the same time base as {@link #createdAtNanos()}. */
    public long expiresAtNanos()
    {
        return this.header.expiresAtNanos;
    }

    /** Returns how long this message has lived so far, expressed in {@code units}. */
    public long elapsedSinceCreated(TimeUnit units)
    {
        long ageNanos = approxTime.now() - createdAtNanos();
        return units.convert(ageNanos, NANOSECONDS);
    }

    /** Returns the creation time of this message converted to milliseconds since the epoch. */
    public long creationTimeMillis()
    {
        MonotonicClockTranslation translation = approxTime.translate();
        return translation.toMillisSinceEpoch(createdAtNanos());
    }

    /** Whether a failure response should be sent back if this message fails ({@code CALL_BACK_ON_FAILURE} flag). */
    boolean callBackOnFailure()
    {
        return this.header.callBackOnFailure();
    }

    /** Whether the {@code TRACK_WARNINGS} flag is set on this message. */
    public boolean trackWarnings()
    {
        return this.header.trackWarnings();
    }

    /** Whether the {@code TRACK_REPAIRED_DATA} flag is set on this message. See CASSANDRA-14145. */
    public boolean trackRepairedData()
    {
        return this.header.trackRepairedData();
    }

    /**
     * Used for the cross-DC write optimisation: one node in the remote DC is picked and relays the write to its local
     * peers. Returns {@code null} when no {@code FORWARD_TO} param is present.
     */
    @Nullable
    public ForwardingInfo forwardTo()
    {
        return this.header.forwardTo();
    }

    /**
     * The originator of the request. Used when forwarding, in which case it differs from {@link #from()}; when no
     * {@code RESPOND_TO} param is present the header falls back to the sender.
     */
    @Nullable
    public InetAddressAndPort respondTo()
    {
        return this.header.respondTo();
    }

    /** The tracing session id from the {@code TRACE_SESSION} param, or {@code null} if absent. */
    @Nullable
    public TimeUUID traceSession()
    {
        return this.header.traceSession();
    }

    /** The trace type from the {@code TRACE_TYPE} param, defaulting to {@code QUERY} when the param is absent. */
    @Nullable
    public TraceType traceType()
    {
        return this.header.traceType();
    }

    /*
     * Request/response convenience factories.
     */

    /**
     * Creates a request {@link Message} for {@code verb} carrying {@code payload}. A fresh id is allocated and the
     * remaining header fields are filled in automatically.
     *
     * If you already know the message will need particular flags or params, prefer the {@code out()} variants that
     * accept them at construction time ({@code outWithFlag()}, {@code outWithFlags()} and the {@code outWithParam()}
     * family). Building the message here and then deriving a copy with the flags or params added allocates extra
     * messages.
     */
    public static  Message out(Verb verb, T payload)
    {
        assert !verb.isResponse();

        long id = nextId();
        return outWithParam(id, verb, payload, null, null);
    }

    /**
     * Creates a message that appears to come from {@code from}. It has placeholder id and timestamps (-1), no flags
     * and no params.
     */
    public static  Message synthetic(InetAddressAndPort from, Verb verb, T payload)
    {
        Header syntheticHeader = new Header(-1, verb, from, -1, -1, 0, NO_PARAMS);
        return new Message<>(syntheticHeader, payload);
    }

    /**
     * Creates a message with a fresh id and an explicit expiry deadline. A deadline of {@code 0} means it is derived
     * from the verb.
     */
    public static  Message out(Verb verb, T payload, long expiresAtNanos)
    {
        long id = nextId();
        return outWithParam(id, verb, expiresAtNanos, payload, 0, null, null);
    }

    /** Creates a request with a fresh id and the single flag {@code flag} set. */
    public static  Message outWithFlag(Verb verb, T payload, MessageFlag flag)
    {
        assert !verb.isResponse();
        long id = nextId();
        int encodedFlags = flag.addTo(0);
        return outWithParam(id, verb, 0, payload, encodedFlags, null, null);
    }

    /** Creates a request with a fresh id and both {@code flag1} and {@code flag2} set. */
    public static  Message outWithFlags(Verb verb, T payload, MessageFlag flag1, MessageFlag flag2)
    {
        assert !verb.isResponse();
        long id = nextId();
        int encodedFlags = flag2.addTo(flag1.addTo(0));
        return outWithParam(id, verb, 0, payload, encodedFlags, null, null);
    }

    /**
     * Creates a request with a fresh id and every flag in {@code flags} set. The creation time and deadline are taken
     * from {@code requestTime`; the deadline is computed from the verb's {@code expiresAfterNanos()}.
     */
    public static  Message outWithFlags(Verb verb, T payload, Dispatcher.RequestTime requestTime, List flags)
    {
        assert !verb.isResponse();

        int combined = 0;
        for (MessageFlag flag : flags)
            combined = flag.addTo(combined);

        long id = nextId();
        InetAddressAndPort self = getBroadcastAddressAndPort();
        long startedAtNanos = requestTime.startedAtNanos();
        long deadlineNanos = requestTime.computeDeadline(verb.expiresAfterNanos());
        Header requestHeader = new Header(id, verb, self, startedAtNanos, deadlineNanos, combined, buildParams(null, null));
        return new Message<>(requestHeader, payload);
    }

    /** Creates a message from this node with the given id and (optional) param, expiring per the verb's default. */
    @VisibleForTesting
    static  Message outWithParam(long id, Verb verb, T payload, ParamType paramType, Object paramValue)
    {
        long deriveFromVerb = 0; // 0 tells withParam() to compute the deadline from the verb
        return outWithParam(id, verb, deriveFromVerb, payload, paramType, paramValue);
    }

    /** As above, with an explicit expiry deadline ({@code 0} means derive it from the verb) and no flags. */
    private static  Message outWithParam(long id, Verb verb, long expiresAtNanos, T payload, ParamType paramType, Object paramValue)
    {
        int noFlags = 0;
        return outWithParam(id, verb, expiresAtNanos, payload, noFlags, paramType, paramValue);
    }

    /** As above, with explicit encoded flags; the sender is this node's broadcast address. */
    private static  Message outWithParam(long id, Verb verb, long expiresAtNanos, T payload, int flags, ParamType paramType, Object paramValue)
    {
        InetAddressAndPort self = getBroadcastAddressAndPort();
        return withParam(self, id, verb, expiresAtNanos, payload, flags, paramType, paramValue);
    }

    /**
     * Core factory. The creation time is {@code approxTime.now()}. A zero {@code expiresAtNanos} is replaced by
     * {@code verb.expiresAtNanos(createdAt)}. Params are built via {@code buildParams}, which adds tracing headers when
     * tracing is active.
     *
     * @throws IllegalArgumentException if {@code payload} is null
     */
    private static  Message withParam(InetAddressAndPort from, long id, Verb verb, long expiresAtNanos, T payload, int flags, ParamType paramType, Object paramValue)
    {
        if (null == payload)
            throw new IllegalArgumentException();

        long now = approxTime.now();
        long deadline = expiresAtNanos != 0 ? expiresAtNanos : verb.expiresAtNanos(now);

        return new Message<>(new Header(id, verb, from, now, deadline, flags, buildParams(paramType, paramValue)), payload);
    }

    /** Creates a response message originating from this node, with id {@code 0}. */
    public static  Message internalResponse(Verb verb, T payload)
    {
        assert verb.isResponse();
        return outWithParam(NO_ID, verb, payload, null, null);
    }

    /**
     * Used by the {@code MultiRangeReadCommand} to split multi-range responses from a replica
     * into single-range responses. The resulting message appears to come from {@code from}, has id {@code 0}, no
     * flags and no params, and expires per the verb's default relative to now.
     */
    public static  Message remoteResponse(InetAddressAndPort from, Verb verb, T payload)
    {
        assert verb.isResponse();
        long now = approxTime.now();
        Header responseHeader = new Header(0, verb, from, now, verb.expiresAtNanos(now), 0, NO_PARAMS);
        return new Message<>(responseHeader, payload);
    }

    /**
     * Builds a response to this (request) message carrying {@code payload}. The id and deadline are copied from the
     * request, and the verb is the request verb's {@code responseVerb}.
     */
    public  Message responseWith(T payload)
    {
        Verb responseVerb = verb().responseVerb;
        return outWithParam(id(), responseVerb, expiresAtNanos(), payload, null, null);
    }

    /** Builds a response to this message with no payload, inferring all other fields from this message. */
    public Message emptyResponse()
    {
        return this.responseWith(NoPayload.noPayload);
    }

    /** Builds a failure response carrying {@code reason}, reusing this message's id and deadline. */
    public Message failureResponse(RequestFailureReason reason)
    {
        long requestId = id();
        long deadline = expiresAtNanos();
        return failureResponse(requestId, deadline, reason);
    }

    /** Builds a {@code FAILURE_RSP} message from this node with the given id, deadline and {@code reason} payload. */
    static Message failureResponse(long id, long expiresAtNanos, RequestFailureReason reason)
    {
        Verb failureVerb = Verb.FAILURE_RSP;
        return outWithParam(id, failureVerb, expiresAtNanos, reason, null, null);
    }

    /** Returns a message that shares this message's header but carries {@code newPayload}. */
    public  Message withPayload(V newPayload)
    {
        return new Message<>(this.header, newPayload);
    }

    /** Returns a copy of this message with the {@code CALL_BACK_ON_FAILURE} flag set. */
    Message withCallBackOnFailure()
    {
        return withFlag(MessageFlag.CALL_BACK_ON_FAILURE);
    }

    /** Returns a copy of this message whose {@code FORWARD_TO} param is {@code peers}. */
    public Message withForwardTo(ForwardingInfo peers)
    {
        return withParam(ParamType.FORWARD_TO, peers);
    }

    /** Returns a copy of this message with {@code flag} added to its header. */
    public Message withFlag(MessageFlag flag)
    {
        Header flagged = this.header.withFlag(flag);
        return new Message<>(flagged, this.payload);
    }

    /** Returns a copy of this message with param {@code type} set to {@code value}. */
    public Message withParam(ParamType type, Object value)
    {
        Header updated = this.header.withParam(type, value);
        return new Message<>(updated, this.payload);
    }

    /** Returns a copy of this message with {@code values} merged into its params; returns {@code this} if there are none. */
    public Message withParams(Map values)
    {
        if (values != null && !values.isEmpty())
            return new Message<>(this.header.withParams(values), this.payload);
        return this;
    }

    // Shared empty params map for headers without params. The code visible in this class copies it before adding
    // entries (buildParams/addParam/addParams) rather than writing into it.
    private static final EnumMap NO_PARAMS = new EnumMap<>(ParamType.class);

    /**
     * Builds the params for a new outbound message. Tracing headers are included when tracing is active, plus
     * {@code type -> value} when {@code type} is non-null. Returns the shared {@code NO_PARAMS} when there is nothing
     * to add.
     */
    private static Map buildParams(ParamType type, Object value)
    {
        Map params;
        if (Tracing.isTracing())
            params = Tracing.instance.addTraceHeaders(new EnumMap<>(ParamType.class));
        else
            params = NO_PARAMS;

        if (type == null)
            return params;

        // never write into the shared empty map: allocate a private one instead
        if (params.isEmpty())
            params = new EnumMap<>(ParamType.class);
        params.put(type, value);
        return params;
    }

    /**
     * Returns a copy of {@code params} with {@code type -> value} added. If {@code type} is null, returns
     * {@code params} itself.
     */
    private static Map addParam(Map params, ParamType type, Object value)
    {
        if (type != null)
        {
            Map copy = new EnumMap<>(params);
            copy.put(type, value);
            return copy;
        }
        return params;
    }

    /**
     * Returns a copy of {@code params} with all of {@code values} added. If {@code values} is null or empty, returns
     * {@code params} itself.
     */
    private static Map addParams(Map params, Map values)
    {
        boolean nothingToAdd = values == null || values.isEmpty();
        if (nothingToAdd)
            return params;

        Map copy = new EnumMap<>(params);
        copy.putAll(values);
        return copy;
    }

    /*
     * Message id allocation.
     */

    // Id value never handed out by nextId(). It is still a valid id for pre-4.0 peers.
    private static final long NO_ID = 0L;

    private static final AtomicInteger nextId = new AtomicInteger(0);

    /**
     * Allocates the next message id, skipping {@code NO_ID}. The counter is an {@code int}, so after overflow it wraps
     * around to negative values, which are sign-extended into the returned {@code long}.
     */
    private static long nextId()
    {
        while (true)
        {
            long candidate = nextId.incrementAndGet();
            if (candidate != NO_ID)
                return candidate;
        }
    }

    /**
     * Whether this message carries an allocated id.
     * WARNING: this is inaccurate for messages from pre-4.0 nodes, which can (rarely) use 0 as an id.
     */
    @VisibleForTesting
    boolean hasId()
    {
        return NO_ID != id();
    }

    /** Magic number written at the start of every legacy (pre-4.0) message so the recipient can sanity-check the sender. */
    static final int PROTOCOL_MAGIC = 0xCA552DFA;

    /** Throws {@link InvalidLegacyProtocolMagic} unless {@code magic} equals {@link #PROTOCOL_MAGIC}. */
    static void validateLegacyProtocolMagic(int magic) throws InvalidLegacyProtocolMagic
    {
        if (magic == PROTOCOL_MAGIC)
            return;
        throw new InvalidLegacyProtocolMagic(magic);
    }

    /** Signals that a legacy message did not start with {@link #PROTOCOL_MAGIC}; {@link #read} holds the value seen. */
    public static final class InvalidLegacyProtocolMagic extends IOException
    {
        public final int read;

        private InvalidLegacyProtocolMagic(int actual)
        {
            super(String.format("Read %d, Expected %d", actual, PROTOCOL_MAGIC));
            this.read = actual;
        }
    }

    @Override
    public String toString()
    {
        Verb messageVerb = verb();
        return "(from:" + from() + ", type:" + messageVerb.stage + " verb:" + messageVerb + ')';
    }

    /**
     * The metadata part of a {@link Message}: id, verb, sender, creation/expiry times, flags and params.
     *
     * Kept as a separate object so that a message can be partially deserialized (header only) without wasting work or
     * allocation. The rest can be deserialized afterwards if the entire message is needed and available.
     */
    public static class Header
    {
        public final long id;
        public final Verb verb;
        public final InetAddressAndPort from;
        public final long createdAtNanos;
        public final long expiresAtNanos;
        // flags encoded via MessageFlag#addTo and queried via MessageFlag#isIn
        private final int flags;
        private final Map params;

        private Header(long id, Verb verb, InetAddressAndPort from, long createdAtNanos, long expiresAtNanos, int flags, Map params)
        {
            this.id = id;
            this.verb = verb;
            this.from = from;
            this.createdAtNanos = createdAtNanos;
            this.expiresAtNanos = expiresAtNanos;
            this.flags = flags;
            this.params = params;
        }

        /** Copies this header, keeping id, verb, sender and timing but using the given flags and params. */
        private Header copyWith(int newFlags, Map newParams)
        {
            return new Header(id, verb, from, createdAtNanos, expiresAtNanos, newFlags, newParams);
        }

        /** Returns a copy of this header with {@code flag} added. */
        Header withFlag(MessageFlag flag)
        {
            return copyWith(flag.addTo(flags), params);
        }

        /** Returns a copy of this header with param {@code type} set to {@code value} (params shared if type is null). */
        Header withParam(ParamType type, Object value)
        {
            return copyWith(flags, addParam(params, type, value));
        }

        /** Returns a copy of this header with {@code values} merged into the params (shared if values is null/empty). */
        Header withParams(Map values)
        {
            return copyWith(flags, addParams(params, values));
        }

        private boolean isSet(MessageFlag flag)
        {
            return flag.isIn(flags);
        }

        boolean callBackOnFailure()
        {
            return isSet(MessageFlag.CALL_BACK_ON_FAILURE);
        }

        boolean trackRepairedData()
        {
            return isSet(MessageFlag.TRACK_REPAIRED_DATA);
        }

        boolean trackWarnings()
        {
            return isSet(MessageFlag.TRACK_WARNINGS);
        }

        /** The {@code FORWARD_TO} param, or {@code null} if absent. */
        @Nullable
        ForwardingInfo forwardTo()
        {
            return (ForwardingInfo) params.get(ParamType.FORWARD_TO);
        }

        /** The {@code RESPOND_TO} param, falling back to {@link #from} when it is absent. */
        @Nullable
        InetAddressAndPort respondTo()
        {
            InetAddressAndPort explicit = (InetAddressAndPort) params.get(ParamType.RESPOND_TO);
            return explicit != null ? explicit : from;
        }

        /** The {@code TRACE_SESSION} param, or {@code null} if absent. */
        @Nullable
        public TimeUUID traceSession()
        {
            return (TimeUUID) params.get(ParamType.TRACE_SESSION);
        }

        /** The {@code TRACE_TYPE} param, or {@code TraceType.QUERY} if the param is absent. */
        @Nullable
        public TraceType traceType()
        {
            return (TraceType) params.getOrDefault(ParamType.TRACE_TYPE, TraceType.QUERY);
        }

        /** Read-only view of all params. */
        public Map params()
        {
            return Collections.unmodifiableMap(params);
        }

        /** The {@code CUSTOM_MAP} param (custom name -> bytes entries), or {@code null} if absent. */
        @Nullable
        public Map customParams()
        {
            Object custom = params.get(ParamType.CUSTOM_MAP);
            return (Map) custom;
        }
    }

    /**
     * Mutable builder for {@link Message} instances.
     *
     * {@link #build()} snapshots the builder's params, so a message that has already been built is not affected by
     * later changes to the builder. Custom params are copied on write for the same reason: a builder populated from an
     * existing message must never mutate that message's custom-param map.
     */
    @SuppressWarnings("WeakerAccess")
    public static class Builder
    {
        private Verb verb;
        private InetAddressAndPort from;
        private T payload;
        private int flags = 0;
        private final Map params = new EnumMap<>(ParamType.class);
        private long createdAtNanos;
        private long expiresAtNanos;
        private long id;

        // whether withId() was called; otherwise build() allocates a fresh id
        private boolean hasId;

        private Builder()
        {
        }

        /** Sets the sender of the message. */
        public Builder from(InetAddressAndPort from)
        {
            this.from = from;
            return this;
        }

        /** Sets the payload; {@link #build()} rejects a null payload. */
        public Builder withPayload(T payload)
        {
            this.payload = payload;
            return this;
        }

        /** Adds {@code flag} to the flags accumulated so far. */
        public Builder withFlag(MessageFlag flag)
        {
            flags = flag.addTo(flags);
            return this;
        }

        /** Replaces all flags with the already-encoded {@code flags}. */
        public Builder withFlags(int flags)
        {
            this.flags = flags;
            return this;
        }

        /** Sets (or overwrites) the param {@code type}. */
        public Builder withParam(ParamType type, Object value)
        {
            params.put(type, value);
            return this;
        }

        /**
         * Adds a custom {@code name -> value} entry to the {@code CUSTOM_MAP} param.
         *
         * The custom map is replaced by an updated copy instead of being modified in place. The current map may be
         * shared with the message this builder was populated from, or with a message already built from this builder.
         */
        public Builder withCustomParam(String name, byte[] value)
        {
            @SuppressWarnings("unchecked")
            Map<String, byte[]> current = (Map<String, byte[]>) params.get(ParamType.CUSTOM_MAP);
            Map<String, byte[]> customParams = new HashMap<>();
            if (current != null)
                customParams.putAll(current);
            customParams.put(name, value);
            params.put(ParamType.CUSTOM_MAP, customParams);
            return this;
        }

        /**
         * A shortcut to add tracing params.
         * Effectively the same as calling {@link #withParam(ParamType, Object)} with the tracing params; it does
         * nothing when tracing is not active. Existing tracing params are overridden.
         */
        public Builder withTracingParams()
        {
            if (Tracing.isTracing())
                Tracing.instance.addTraceHeaders(params);
            return this;
        }

        /** Removes the param {@code type}, if present. */
        public Builder withoutParam(ParamType type)
        {
            params.remove(type);
            return this;
        }

        /** Copies all of {@code params} into this builder's params, overwriting existing entries. */
        public Builder withParams(Map params)
        {
            this.params.putAll(params);
            return this;
        }

        /**
         * Sets the verb. When the verb is non-null:
         * - if a creation time is set and no expiry has been set yet, the expiry is derived from the verb;
         * - for a request verb with no sender set yet, the sender defaults to this node.
         *
         * A null verb is accepted here; {@link #build()} rejects it.
         */
        public Builder ofVerb(Verb verb)
        {
            this.verb = verb;
            if (verb != null)
            {
                if (expiresAtNanos == 0 && createdAtNanos != 0)
                    expiresAtNanos = verb.expiresAtNanos(createdAtNanos);
                // default to sending from self if we're a request verb
                if (!verb.isResponse() && from == null)
                    from = getBroadcastAddressAndPort();
            }
            return this;
        }

        /** Sets the creation time and, if a verb is set and no expiry has been set yet, derives the expiry from it. */
        public Builder withCreatedAt(long createdAtNanos)
        {
            this.createdAtNanos = createdAtNanos;
            if (expiresAtNanos == 0 && verb != null)
                expiresAtNanos = verb.expiresAtNanos(createdAtNanos);
            return this;
        }

        /** Sets the expiry deadline explicitly, overriding any value derived from the verb. */
        public Builder withExpiresAt(long expiresAtNanos)
        {
            this.expiresAtNanos = expiresAtNanos;
            return this;
        }

        /** Uses {@code id} instead of allocating a fresh id in {@link #build()}. */
        public Builder withId(long id)
        {
            this.id = id;
            hasId = true;
            return this;
        }

        /**
         * Builds the message. A fresh id is allocated unless {@link #withId(long)} was called.
         *
         * @throws IllegalArgumentException if the verb, sender or payload has not been set
         */
        public Message build()
        {
            if (verb == null)
                throw new IllegalArgumentException("Message verb must be set");
            if (from == null)
                throw new IllegalArgumentException("Message sender (from) must be set");
            if (payload == null)
                throw new IllegalArgumentException("Message payload must be set");

            // Snapshot the params so later changes to this builder cannot leak into the built message. `params` is
            // always an EnumMap, so the copy constructor is safe even when it is empty.
            return new Message<>(new Header(hasId ? id : nextId(), verb, from, createdAtNanos, expiresAtNanos, flags, new EnumMap<>(params)), payload);
        }
    }

    /**
     * Creates a builder pre-populated with every field of {@code message}, including its id, so that {@code build()}
     * reuses that id. Setters are applied in a fixed order: {@code ofVerb} and {@code withCreatedAt} may derive an
     * expiry, which {@code withExpiresAt} then overrides with the message's own value.
     */
    public static  Builder builder(Message message)
    {
        Builder copy = new Builder();
        copy.from(message.from());
        copy.withId(message.id());
        copy.ofVerb(message.verb());
        copy.withCreatedAt(message.createdAtNanos());
        copy.withExpiresAt(message.expiresAtNanos());
        copy.withFlags(message.header.flags);
        copy.withParams(message.header.params);
        copy.withPayload(message.payload);
        return copy;
    }

    /** Creates a builder for a new message of {@code verb} carrying {@code payload}, created now. */
    public static  Builder builder(Verb verb, T payload)
    {
        Builder fresh = new Builder();
        fresh.ofVerb(verb);
        fresh.withCreatedAt(approxTime.now());
        fresh.withPayload(payload);
        return fresh;
    }

    /** Shared (de)serializer for messages. */
    public static final Serializer serializer = new Serializer();

    /**
     * Each message contains a header with several fixed fields, an optional key-value params section, and then
     * the message payload itself. Below is a visualization of the layout.
     *
     *  The params are prefixed by the count of key-value pairs; this value is encoded as unsigned vint.
     *  An individual param has an unsigned vint id (more specifically, a {@link ParamType}), and a byte array value.
     *  The param value is prefixed with its length, encoded as an unsigned vint, followed by the value's bytes.
     *
     * Legacy Notes (see {@link Serializer#serialize(Message, DataOutputPlus, int)} for complete details):
     * - pre-4.0, the IP address was sent along in the header, before the verb. The IP address may be either IPv4 (4 bytes) or IPv6 (16 bytes)
     * - pre-4.0, the verb was encoded as a 4-byte integer; in 4.0 and up it is an unsigned vint
     * - pre-4.0, the payloadSize was encoded as a 4-byte integer; in 4.0 and up it is an unsigned vint
     * - pre-4.0, the count of param key-value pairs was encoded as a 4-byte integer; in 4.0 and up it is an unsigned vint
     * - pre-4.0, param names were encoded as strings; in 4.0 they are encoded as enum id vints
     * - pre-4.0, expiry time wasn't encoded at all; in 4.0 it's an unsigned vint
     * - pre-4.0, message id was an int; in 4.0 and up it's an unsigned vint
     * - pre-4.0, messages included PROTOCOL MAGIC BYTES; post-4.0, we rely on frame CRCs instead
     * - pre-4.0, messages would serialize boolean params as dummy ONE_BYTEs; post-4.0 we have a dedicated 'flags' vint
     *
     * <pre>
     * {@code
     *            1 1 1 1 1 2 2 2 2 2 3
     *  0 2 4 6 8 0 2 4 6 8 0 2 4 6 8 0
     * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
     * | Message ID (vint)             |
     * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
     * | Creation timestamp (int)      |
     * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
     * | Expiry (vint)                 |
     * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
     * | Verb (vint)                   |
     * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
     * | Flags (vint)                  |
     * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
     * | Param count (vint)            |
     * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
     * |                               /
     * /           Params              /
     * /                               |
     * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
     * | Payload size (vint)           |
     * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
     * |                               /
     * /           Payload             /
     * /                               |
     * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
     * }
     * </pre>
*/ public static final class Serializer { private static final int CREATION_TIME_SIZE = 4; private Serializer() { } public void serialize(Message message, DataOutputPlus out, int version) throws IOException { if (version >= VERSION_40) serializePost40(message, out, version); else serializePre40(message, out, version); } public Message deserialize(DataInputPlus in, InetAddressAndPort peer, int version) throws IOException { return version >= VERSION_40 ? deserializePost40(in, peer, version) : deserializePre40(in, version); } /** * A partial variant of deserialize, taking in a previously deserialized {@link Header} as an argument. * * Skip deserializing the {@link Header} from the input stream in favour of using the provided header. */ public Message deserialize(DataInputPlus in, Header header, int version) throws IOException { return version >= VERSION_40 ? deserializePost40(in, header, version) : deserializePre40(in, header, version); } private int serializedSize(Message message, int version) { return version >= VERSION_40 ? serializedSizePost40(message, version) : serializedSizePre40(message, version); } /** * Size of the next message in the stream. Returns -1 if there aren't sufficient bytes read yet to determine size. */ int inferMessageSize(ByteBuffer buf, int index, int limit, int version) throws InvalidLegacyProtocolMagic { int size = version >= VERSION_40 ? inferMessageSizePost40(buf, index, limit) : inferMessageSizePre40(buf, index, limit); if (size > DatabaseDescriptor.getInternodeMaxMessageSizeInBytes()) throw new OversizedMessageException(size); return size; } /** * Partially deserialize the message - by only extracting the header and leaving the payload alone. * * To get the rest of the message without repeating the work done here, use {@link #deserialize(DataInputPlus, Header, int)} * method. * * It's assumed that the provided buffer contains all the bytes necessary to deserialize the header fully. 
*/ Header extractHeader(ByteBuffer buf, InetAddressAndPort from, long currentTimeNanos, int version) throws IOException { return version >= VERSION_40 ? extractHeaderPost40(buf, from, currentTimeNanos, version) : extractHeaderPre40(buf, currentTimeNanos, version); } private static long getExpiresAtNanos(long createdAtNanos, long expirationPeriodNanos) { return createdAtNanos + expirationPeriodNanos; } /* * 4.0 ser/deser */ private void serializeHeaderPost40(Header header, DataOutputPlus out, int version) throws IOException { out.writeUnsignedVInt(header.id); // int cast cuts off the high-order half of the timestamp, which we can assume remains // the same between now and when the recipient reconstructs it. out.writeInt((int) approxTime.translate().toMillisSinceEpoch(header.createdAtNanos)); out.writeUnsignedVInt(NANOSECONDS.toMillis(header.expiresAtNanos - header.createdAtNanos)); out.writeUnsignedVInt(header.verb.id); out.writeUnsignedVInt(header.flags); serializeParams(header.params, out, version); } private Header deserializeHeaderPost40(DataInputPlus in, InetAddressAndPort peer, int version) throws IOException { long id = in.readUnsignedVInt(); long currentTimeNanos = approxTime.now(); MonotonicClockTranslation timeSnapshot = approxTime.translate(); long creationTimeNanos = calculateCreationTimeNanos(in.readInt(), timeSnapshot, currentTimeNanos); long expiresAtNanos = getExpiresAtNanos(creationTimeNanos, TimeUnit.MILLISECONDS.toNanos(in.readUnsignedVInt())); Verb verb = Verb.fromId(Ints.checkedCast(in.readUnsignedVInt())); int flags = Ints.checkedCast(in.readUnsignedVInt()); Map params = deserializeParams(in, version); return new Header(id, verb, peer, creationTimeNanos, expiresAtNanos, flags, params); } private void skipHeaderPost40(DataInputPlus in) throws IOException { skipUnsignedVInt(in); // id in.skipBytesFully(4); // createdAt skipUnsignedVInt(in); // expiresIn skipUnsignedVInt(in); // verb skipUnsignedVInt(in); // flags skipParamsPost40(in); // params } 
/**
 * Number of bytes serializeHeaderPost40 writes for {@code header}: the id, the fixed-size creation time,
 * the expiry period in millis, the verb id and the flags, followed by the params map.
 */
private int serializedHeaderSizePost40(Header header, int version)
{
    long size = 0;
    size += sizeofUnsignedVInt(header.id);
    size += CREATION_TIME_SIZE;
    size += sizeofUnsignedVInt(NANOSECONDS.toMillis(header.expiresAtNanos - header.createdAtNanos));
    size += sizeofUnsignedVInt(header.verb.id);
    size += sizeofUnsignedVInt(header.flags);
    size += serializedParamsSize(header.params, version);
    return Ints.checkedCast(size);
}

/**
 * Parses a post-4.0 header from {@code buf}, starting at its current position and walking the fields with an
 * explicit {@code index} instead of relative reads. The params map is read via extractParams, which restores
 * the buffer position afterwards.
 */
private Header extractHeaderPost40(ByteBuffer buf, InetAddressAndPort from, long currentTimeNanos, int version) throws IOException
{
    MonotonicClockTranslation timeSnapshot = approxTime.translate();

    int index = buf.position();

    long id = getUnsignedVInt(buf, index);
    index += computeUnsignedVIntSize(id);

    int createdAtMillis = buf.getInt(index);
    index += sizeof(createdAtMillis);

    long expiresInMillis = getUnsignedVInt(buf, index);
    index += computeUnsignedVIntSize(expiresInMillis);

    Verb verb = Verb.fromId(Ints.checkedCast(getUnsignedVInt(buf, index)));
    // assumes Verb.fromId(n).id == n, so the vint's encoded size can be recomputed from verb.id
    index += computeUnsignedVIntSize(verb.id);

    int flags = Ints.checkedCast(getUnsignedVInt(buf, index));
    index += computeUnsignedVIntSize(flags);

    Map params = extractParams(buf, index, version);

    long createdAtNanos = calculateCreationTimeNanos(createdAtMillis, timeSnapshot, currentTimeNanos);
    long expiresAtNanos = getExpiresAtNanos(createdAtNanos, TimeUnit.MILLISECONDS.toNanos(expiresInMillis));

    return new Header(id, verb, from, createdAtNanos, expiresAtNanos, flags, params);
}

/**
 * Writes a complete post-4.0 message: the header, the payload size as an unsigned vint, then the payload
 * serialized with the verb's serializer.
 */
private void serializePost40(Message message, DataOutputPlus out, int version) throws IOException
{
    serializeHeaderPost40(message.header, out, version);
    out.writeUnsignedVInt(message.payloadSize(version));
    message.verb().serializer().serialize(message.payload, out, version);
}

/**
 * Reads a complete post-4.0 message whose header has not yet been consumed from {@code in}.
 *
 * NOTE(review): {@code T} is used below, but no type parameter is declared on this signature as it appears
 * here, and {@code Message}/{@code Map} are raw throughout this section. The generic declarations
 * (presumably {@code <T> Message<T>}) look to have been lost from this copy; confirm against upstream.
 */
private Message deserializePost40(DataInputPlus in, InetAddressAndPort peer, int version) throws IOException
{
    Header header = deserializeHeaderPost40(in, peer, version);
    skipUnsignedVInt(in); // payload size, not needed by payload deserializer
    T payload = (T) header.verb.serializer().deserialize(in, version);
    return new Message<>(header, payload);
}

/**
 * Reads a complete post-4.0 message when {@code header} has already been parsed. The header bytes are
 * skipped and the supplied header is reused.
 * NOTE(review): as above, {@code T} is undeclared in this copy of the signature.
 */
private Message deserializePost40(DataInputPlus in, Header header, int version) throws IOException
{
    skipHeaderPost40(in);
    skipUnsignedVInt(in); // payload size, not needed by payload deserializer
    T payload = (T) header.verb.serializer().deserialize(in, version);
    return new Message<>(header, payload);
}

/**
 * Total post-4.0 serialized size: header, payload-size vint and payload bytes.
 */
private int serializedSizePost40(Message message, int version)
{
    long size = 0;
    size += serializedHeaderSizePost40(message.header, version);
    int payloadSize = message.payloadSize(version);
    size += sizeofUnsignedVInt(payloadSize) + payloadSize;
    return Ints.checkedCast(size);
}

/**
 * Determines the total size of the post-4.0 message starting at {@code readerIndex}, using explicit indices
 * bounded by {@code readerLimit}.
 *
 * @return the message size in bytes, or -1 if the header fields and payload-size prefix cannot all be
 *         decoded within {@code readerLimit}
 */
private int inferMessageSizePost40(ByteBuffer buf, int readerIndex, int readerLimit)
{
    int index = readerIndex;

    int idSize = computeUnsignedVIntSize(buf, index, readerLimit);
    if (idSize < 0)
        return -1; // not enough bytes to read id
    index += idSize;

    index += CREATION_TIME_SIZE;
    if (index > readerLimit)
        return -1;

    int expirationSize = computeUnsignedVIntSize(buf, index, readerLimit);
    if (expirationSize < 0)
        return -1;
    index += expirationSize;

    int verbIdSize = computeUnsignedVIntSize(buf, index, readerLimit);
    if (verbIdSize < 0)
        return -1;
    index += verbIdSize;

    int flagsSize = computeUnsignedVIntSize(buf, index, readerLimit);
    if (flagsSize < 0)
        return -1;
    index += flagsSize;

    int paramsSize = extractParamsSizePost40(buf, index, readerLimit);
    if (paramsSize < 0)
        return -1;
    index += paramsSize;

    long payloadSize = getUnsignedVInt(buf, index, readerLimit);
    if (payloadSize < 0)
        return -1;
    // Only the payload length prefix must be readable; the payload bytes are counted but not bounds-checked.
    // NOTE(review): this compound assignment narrows the long sum to int without an overflow check.
    index += computeUnsignedVIntSize(payloadSize) + payloadSize;

    return index - readerIndex;
}

/*
 * legacy ser/deser
 */

/**
 * Writes a legacy (pre-4.0) header: protocol magic, id as an int, the low 32 bits of the creation time in
 * epoch millis, the sender address, the pre-4.0 verb id, then the params map with message flags re-encoded
 * as legacy params. No expiry is written; the reader derives it from the verb.
 */
private void serializeHeaderPre40(Header header, DataOutputPlus out, int version) throws IOException
{
    out.writeInt(PROTOCOL_MAGIC);
    out.writeInt(Ints.checkedCast(header.id));
    // int cast cuts off the high-order half of the timestamp, which we can assume remains
    // the same between now and when the recipient reconstructs it.
    out.writeInt((int) approxTime.translate().toMillisSinceEpoch(header.createdAtNanos));
    inetAddressAndPortSerializer.serialize(header.from, out, version);
    out.writeInt(header.verb.toPre40Verb().id);
    serializeParams(addFlagsToLegacyParams(header.params, header.flags), out, version);
}

/**
 * Reads a legacy header, validating the protocol magic first. Legacy headers carry no expiry, so it is
 * computed with {@code verb.expiresAtNanos}. Message flags are recovered from, and removed from, the
 * legacy params.
 */
private Header deserializeHeaderPre40(DataInputPlus in, int version) throws IOException
{
    validateLegacyProtocolMagic(in.readInt());
    int id = in.readInt();
    long currentTimeNanos = approxTime.now();
    MonotonicClockTranslation timeSnapshot = approxTime.translate();
    long creationTimeNanos = calculateCreationTimeNanos(in.readInt(), timeSnapshot, currentTimeNanos);
    InetAddressAndPort from = inetAddressAndPortSerializer.deserialize(in, version);
    Verb verb = Verb.fromId(in.readInt());
    Map params = deserializeParams(in, version);
    int flags = removeFlagsFromLegacyParams(params);
    return new Header(id, verb, from, creationTimeNanos, verb.expiresAtNanos(creationTimeNanos), flags, params);
}

private static final int PRE_40_MESSAGE_PREFIX_SIZE = 12; // protocol magic + id + createdAt (three ints)

/**
 * Advances {@code in} past a legacy header, mirroring the layout written by serializeHeaderPre40.
 */
private void skipHeaderPre40(DataInputPlus in) throws IOException
{
    in.skipBytesFully(PRE_40_MESSAGE_PREFIX_SIZE); // magic, id, createdAt
    in.skipBytesFully(in.readByte()); // from: one length byte, then that many address bytes
    in.skipBytesFully(4); // verb
    skipParamsPre40(in); // params
}

/**
 * Number of bytes serializeHeaderPre40 writes for {@code header}.
 */
private int serializedHeaderSizePre40(Header header, int version)
{
    long size = 0;
    size += PRE_40_MESSAGE_PREFIX_SIZE;
    size += inetAddressAndPortSerializer.serializedSize(header.from, version);
    size += sizeof(header.verb.id);
    size += serializedParamsSize(addFlagsToLegacyParams(header.params, header.flags), version);
    return Ints.checkedCast(size);
}

/**
 * Parses a legacy header from {@code buf}, starting at its current position and tracking offsets in a local
 * index. Flags are recovered from the legacy params, and the expiry is derived from the verb.
 * NOTE(review): unlike deserializeHeaderPre40, the protocol magic is skipped here, not validated
 * (presumably inferMessageSizePre40 has already checked it; confirm at call sites).
 */
private Header extractHeaderPre40(ByteBuffer buf, long currentTimeNanos, int version) throws IOException
{
    MonotonicClockTranslation timeSnapshot = approxTime.translate();

    int index = buf.position();

    index += 4; // protocol magic

    long id = buf.getInt(index);
    index += 4;

    int createdAtMillis = buf.getInt(index);
    index += 4;

    InetAddressAndPort from = inetAddressAndPortSerializer.extract(buf, index);
    index += 1 + buf.get(index); // address length byte + address bytes

    Verb verb = Verb.fromId(buf.getInt(index));
    index += 4;

    Map params = extractParams(buf, index, version);
    int flags = removeFlagsFromLegacyParams(params);

    long createdAtNanos = calculateCreationTimeNanos(createdAtMillis, timeSnapshot, currentTimeNanos);
    long expiresAtNanos = verb.expiresAtNanos(createdAtNanos);

    return new Header(id, verb, from, createdAtNanos, expiresAtNanos, flags, params);
}

/**
 * Writes a complete legacy message. Failure responses are first converted with toPre40FailureResponse.
 * A null or NoPayload payload is written as a zero payload size with no payload bytes.
 */
private void serializePre40(Message message, DataOutputPlus out, int version) throws IOException
{
    if (message.isFailureResponse())
        message = toPre40FailureResponse(message);

    serializeHeaderPre40(message.header, out, version);

    if (message.payload != null && message.payload != NoPayload.noPayload)
    {
        int payloadSize = message.payloadSize(version);
        out.writeInt(payloadSize);
        message.getPayloadSerializer().serialize(message.payload, out, version);
    }
    else
    {
        out.writeInt(0);
    }
}

/**
 * Reads a legacy header from {@code in}, then the remainder of the message.
 */
private Message deserializePre40(DataInputPlus in, int version) throws IOException
{
    Header header = deserializeHeaderPre40(in, version);
    return deserializePre40(in, header, false, version);
}

/**
 * Reads a legacy message whose {@code header} has already been parsed; the header bytes are skipped.
 */
private Message deserializePre40(DataInputPlus in, Header header, int version) throws IOException
{
    return deserializePre40(in, header, true, version);
}

/**
 * Shared tail of legacy deserialization. It optionally skips the header, reads the int payload size and the
 * payload, and converts legacy failure responses (marked by the FAILURE_RESPONSE param) back into post-4.0 form.
 * NOTE(review): {@code T} is undeclared in this copy of the signature, and the {@code (Message)} cast is raw;
 * the generics appear to have been lost, so confirm against upstream.
 */
private Message deserializePre40(DataInputPlus in, Header header, boolean skipHeader, int version) throws IOException
{
    if (skipHeader)
        skipHeaderPre40(in);

    int payloadSize = in.readInt();
    T payload = deserializePayloadPre40(in, version, getPayloadSerializer(header.verb, header.id, header.from), payloadSize);
    Message message = new Message<>(header, payload);

    return header.params.containsKey(ParamType.FAILURE_RESPONSE)
         ? (Message) toPost40FailureResponse(message)
         : message;
}

/**
 * Deserializes a legacy payload with {@code serializer}. When the payload is empty or no serializer is
 * available, the method instead skips {@code payloadSize} bytes and returns null.
 * NOTE(review): {@code T} is undeclared and {@code serializer} is raw in this copy of the signature.
 */
private T deserializePayloadPre40(DataInputPlus in, int version, IVersionedAsymmetricSerializer serializer, int payloadSize) throws IOException
{
    if (payloadSize == 0 || serializer == null)
    {
        // if there's no deserializer for the verb, skip the payload bytes to leave
        // the stream in a clean state (for the next message)
        in.skipBytesFully(payloadSize);
        return null;
    }

    return serializer.deserialize(in, version);
}

/**
 * Total legacy serialized size, including the 4-byte payload length. As in serializePre40, failure responses
 * are sized in their converted pre-4.0 form.
 */
private int serializedSizePre40(Message message, int version)
{
    if (message.isFailureResponse())
        message = toPre40FailureResponse(message);

    long size = 0;
    size += serializedHeaderSizePre40(message.header, version);
    int payloadSize = message.payloadSize(version);
    size += sizeof(payloadSize);
    size += payloadSize;
    return Ints.checkedCast(size);
}

/**
 * Determines the total size of the legacy message starting at {@code readerIndex}, using explicit indices
 * bounded by {@code readerLimit}. The protocol magic is validated as soon as it is readable.
 *
 * @return the message size in bytes, or -1 if more bytes are needed to determine it
 * @throws InvalidLegacyProtocolMagic as declared; presumably raised by validateLegacyProtocolMagic on a mismatch
 */
private int inferMessageSizePre40(ByteBuffer buf, int readerIndex, int readerLimit) throws InvalidLegacyProtocolMagic
{
    int index = readerIndex;
    // protocol magic
    index += 4;
    if (index > readerLimit)
        return -1;
    validateLegacyProtocolMagic(buf.getInt(index - 4));

    // rest of prefix
    index += PRE_40_MESSAGE_PREFIX_SIZE - 4;
    // ip address
    index += 1;
    if (index > readerLimit)
        return -1;
    index += buf.get(index - 1); // NOTE(review): signed byte; assumes address lengths stay below 128
    // verb
    index += 4;
    if (index > readerLimit)
        return -1;

    int paramsSize = extractParamsSizePre40(buf, index, readerLimit);
    if (paramsSize < 0)
        return -1;
    index += paramsSize;

    // payload: only the 4-byte length must be readable; the payload bytes are counted but not bounds-checked
    index += 4;
    if (index > readerLimit)
        return -1;
    index += buf.getInt(index - 4);

    return index - readerIndex;
}

/**
 * Converts a post-4.0 failure response into its legacy shape. The result uses the pre-4.0 verb, carries the
 * FAILURE_RESPONSE marker and a FAILURE_REASON param holding the original payload, clears flags to 0, and
 * has no payload.
 */
private Message toPre40FailureResponse(Message post40)
{
    Map params = new EnumMap<>(ParamType.class);
    params.putAll(post40.header.params);

    params.put(ParamType.FAILURE_RESPONSE, LegacyFlag.instance);
    params.put(ParamType.FAILURE_REASON, post40.payload);

    Header header = new Header(post40.id(), post40.verb().toPre40Verb(), post40.from(), post40.createdAtNanos(), post40.expiresAtNanos(), 0, params);
    return new Message<>(header, NoPayload.noPayload);
}

/**
 * Inverse of toPre40FailureResponse. It strips the legacy failure params and re-labels the message as
 * FAILURE_RSP. The payload becomes the FAILURE_REASON param, or RequestFailureReason.UNKNOWN if that param
 * was absent.
 */
private Message toPost40FailureResponse(Message pre40)
{
    Map params = new EnumMap<>(ParamType.class);
    params.putAll(pre40.header.params);

    params.remove(ParamType.FAILURE_RESPONSE);

    RequestFailureReason reason = (RequestFailureReason) params.remove(ParamType.FAILURE_REASON);
    if (null == reason)
        reason = RequestFailureReason.UNKNOWN;

    Header header = new Header(pre40.id(), Verb.FAILURE_RSP, pre40.from(), pre40.createdAtNanos(), pre40.expiresAtNanos(), pre40.header.flags, params);
    return new Message<>(header, reason);
}

/*
 * created at + cross-node
 */

// The last 15 minutes before the low 32 bits of an epoch-millis value wrap around...
private static final long TIMESTAMP_WRAPAROUND_GRACE_PERIOD_START = 0xFFFFFFFFL - MINUTES.toMillis(15L);
// ...and the first 15 minutes after the wrap.
private static final long TIMESTAMP_WRAPAROUND_GRACE_PERIOD_END = MINUTES.toMillis(15L);

/**
 * Rebuilds a message's creation time from the low 32 bits of the sender's epoch-millis timestamp.
 * <p>
 * If cross-node timeouts are disabled, the remote value is ignored and {@code currentTimeNanos} is returned.
 * Otherwise the high 32 bits are borrowed from the local clock and adjusted by one wrap when sender and
 * receiver straddle a 32-bit wraparound. Results more than 15 minutes from local time are logged
 * (rate-limited) and replaced with the local time.
 *
 * @param messageTimestampMillis low 32 bits of the sender's creation time, in epoch millis
 * @param timeSnapshot           translation between the monotonic clock and epoch millis
 * @param currentTimeNanos       local monotonic time used as the reference point
 * @return the creation time on the local monotonic clock, in nanoseconds
 */
@VisibleForTesting
static long calculateCreationTimeNanos(int messageTimestampMillis, MonotonicClockTranslation timeSnapshot, long currentTimeNanos)
{
    // We do not trust external time source, so we override their value with current time
    if (!DatabaseDescriptor.hasCrossNodeTimeout())
        return currentTimeNanos;

    long currentTimeMillis = timeSnapshot.toMillisSinceEpoch(currentTimeNanos);
    // Reconstruct the message construction time sent by the remote host (we sent only the lower 4 bytes, assuming the
    // higher 4 bytes wouldn't change between the sender and receiver)
    long highBits = currentTimeMillis & 0xFFFFFFFF00000000L;

    long sentLowBits = messageTimestampMillis & 0x00000000FFFFFFFFL;
    long currentLowBits = currentTimeMillis & 0x00000000FFFFFFFFL;

    // if our sent bits occur within a grace period of a wrap around event,
    // and our current bits are no more than the same grace period after a wrap around event,
    // assume a wrap around has occurred, and deduct one highBit
    if (sentLowBits > TIMESTAMP_WRAPAROUND_GRACE_PERIOD_START && currentLowBits < TIMESTAMP_WRAPAROUND_GRACE_PERIOD_END)
    {
        highBits -= 0x0000000100000000L;
    }
    // if the message timestamp wrapped, but we still haven't, add one highBit
    else if (sentLowBits < TIMESTAMP_WRAPAROUND_GRACE_PERIOD_END && currentLowBits > TIMESTAMP_WRAPAROUND_GRACE_PERIOD_START)
    {
        highBits += 0x0000000100000000L;
    }

    long sentTimeMillis = (highBits | sentLowBits);

    if (Math.abs(currentTimeMillis - sentTimeMillis) > MINUTES.toMillis(15))
    {
        noSpam1m.warn("Bad timestamp {} generated, overriding with currentTimeMillis = {}", sentTimeMillis, currentTimeMillis);
        sentTimeMillis = currentTimeMillis;
    }

    return timeSnapshot.fromMillisSinceEpoch(sentTimeMillis);
}

/*
 * param ser/deser
 */

/**
 * Encodes the CALL_BACK_ON_FAILURE and TRACK_REPAIRED_DATA flags as legacy marker params (LegacyFlag.instance).
 * If {@code flags} is 0 the input map itself is returned; otherwise a new EnumMap copy is returned and
 * {@code params} is not modified.
 */
private Map addFlagsToLegacyParams(Map params, int flags)
{
    if (flags == 0)
        return params;

    Map extended = new EnumMap<>(ParamType.class);
    extended.putAll(params);

    if (MessageFlag.CALL_BACK_ON_FAILURE.isIn(flags))
        extended.put(ParamType.FAILURE_CALLBACK, LegacyFlag.instance);

    if (MessageFlag.TRACK_REPAIRED_DATA.isIn(flags))
        extended.put(ParamType.TRACK_REPAIRED_DATA, LegacyFlag.instance);

    return extended;
}

/**
 * Removes the legacy flag marker params from {@code params} (mutating it) and returns the equivalent
 * MessageFlag bit set.
 */
private int removeFlagsFromLegacyParams(Map params)
{
    int flags = 0;

    if (null != params.remove(ParamType.FAILURE_CALLBACK))
        flags = MessageFlag.CALL_BACK_ON_FAILURE.addTo(flags);

    if (null != params.remove(ParamType.TRACK_REPAIRED_DATA))
        flags = MessageFlag.TRACK_REPAIRED_DATA.addTo(flags);

    return flags;
}

/**
 * Writes {@code params} in two parts:
 * <ul>
 * <li>the entry count (unsigned vint post-4.0, int pre-4.0);</li>
 * <li>for each entry, the param type (vint id post-4.0, UTF legacy alias pre-4.0), the value length
 * (vint / int), and the serialized value.</li>
 * </ul>
 */
private void serializeParams(Map params, DataOutputPlus out, int version) throws IOException
{
    if (version >= VERSION_40)
        out.writeUnsignedVInt(params.size());
    else
        out.writeInt(params.size());

    for (Map.Entry kv : params.entrySet())
    {
        ParamType type = kv.getKey();
        if (version >= VERSION_40)
            out.writeUnsignedVInt(type.id);
        else
            out.writeUTF(type.legacyAlias);

        IVersionedSerializer serializer = type.serializer;
        Object value = kv.getValue();

        int length = Ints.checkedCast(serializer.serializedSize(value, version));
        if (version >= VERSION_40)
            out.writeUnsignedVInt(length);
        else
            out.writeInt(length);

        serializer.serialize(value, out, version);
    }
}

/**
 * Reads a params map written by serializeParams and returns NO_PARAMS when the count is zero. Entries whose
 * type lookup returns null are skipped using their transmitted length.
 */
private Map deserializeParams(DataInputPlus in, int version) throws IOException
{
    int count = version >= VERSION_40 ? Ints.checkedCast(in.readUnsignedVInt()) : in.readInt();

    if (count == 0)
        return NO_PARAMS;

    Map params = new EnumMap<>(ParamType.class);

    for (int i = 0; i < count; i++)
    {
        ParamType type = version >= VERSION_40
            ? ParamType.lookUpById(Ints.checkedCast(in.readUnsignedVInt()))
            : ParamType.lookUpByAlias(in.readUTF());

        int length = version >= VERSION_40
            ? Ints.checkedCast(in.readUnsignedVInt())
            : in.readInt();

        if (null != type)
        {
            // Have to special case deserializer as pre-4.0 needs length to decode correctly
            if (version < VERSION_40 && type == ParamType.RESPOND_TO)
            {
                params.put(type, InetAddressAndPort.FwdFrmSerializer.fwdFrmSerializer.pre40DeserializeWithLength(in, version, length));
            }
            else
            {
                params.put(type, type.serializer.deserialize(in, version));
            }
        }
        else
        {
            in.skipBytesFully(length); // forward compatibility with minor version changes
        }
    }

    return params;
}

/**
 * Extracts the params map (post-4.0 or legacy format, selected by {@code version}) from {@code buf}, starting
 * at {@code readerIndex}. The buffer's position is temporarily moved and restored in a finally block, and
 * NO_PARAMS is returned when the encoded count is zero.
 */
private Map extractParams(ByteBuffer buf, int readerIndex, int version) throws IOException
{
    long count = version >= VERSION_40 ? getUnsignedVInt(buf, readerIndex) : buf.getInt(readerIndex);

    if (count == 0)
        return NO_PARAMS;

    final int position = buf.position();
    buf.position(readerIndex);

    try (DataInputBuffer in = new DataInputBuffer(buf, false))
    {
        return deserializeParams(in, version);
    }
    finally
    {
        buf.position(position);
    }
}

/**
 * Skips a post-4.0 params map. Each entry consists of a type-id vint and a vint-length-prefixed value.
 */
private void skipParamsPost40(DataInputPlus in) throws IOException
{
    int count = Ints.checkedCast(in.readUnsignedVInt());

    for (int i = 0; i < count; i++)
    {
        skipUnsignedVInt(in);
        in.skipBytesFully(Ints.checkedCast(in.readUnsignedVInt()));
    }
}

/**
 * Skips a legacy params map. Each entry consists of the UTF alias (2-byte length plus bytes) and an
 * int-length-prefixed value.
 * NOTE(review): writeUTF's length prefix is unsigned, but readShort is signed. This is only correct while
 * aliases are shorter than 32768 bytes, which is presumably always the case.
 */
private void skipParamsPre40(DataInputPlus in) throws IOException
{
    int count = in.readInt();

    for (int i = 0; i < count; i++)
    {
        in.skipBytesFully(in.readShort());
        in.skipBytesFully(in.readInt());
    }
}

/**
 * Number of bytes serializeParams writes for {@code params} at {@code version}.
 */
private long serializedParamsSize(Map params, int version)
{
    long size = version >= VERSION_40
              ? computeUnsignedVIntSize(params.size())
              : sizeof(params.size());

    for (Map.Entry kv : params.entrySet())
    {
        ParamType type = kv.getKey();
        Object value = kv.getValue();

        long valueLength = type.serializer.serializedSize(value, version);

        if (version >= VERSION_40)
            size += sizeofUnsignedVInt(type.id) + sizeofUnsignedVInt(valueLength);
        else
            size += sizeof(type.legacyAlias) + 4; // presumably sizeof(String) includes the UTF length prefix

        size += valueLength;
    }

    return size;
}

/**
 * @return the encoded size of the post-4.0 params map at {@code readerIndex}, or -1 if a count, type or
 *         length vint cannot be read within {@code readerLimit}. The bytes of the last value are counted
 *         but not bounds-checked here.
 */
private int extractParamsSizePost40(ByteBuffer buf, int readerIndex, int readerLimit)
{
    int index = readerIndex;

    long paramsCount = getUnsignedVInt(buf, index, readerLimit);
    if (paramsCount < 0)
        return -1;
    index += computeUnsignedVIntSize(paramsCount);

    for (int i = 0; i < paramsCount; i++)
    {
        long type = getUnsignedVInt(buf, index, readerLimit);
        if (type < 0)
            return -1;
        index += computeUnsignedVIntSize(type);

        long length = getUnsignedVInt(buf, index, readerLimit);
        if (length < 0)
            return -1;
        index += computeUnsignedVIntSize(length) + length;
    }

    return index - readerIndex;
}

/**
 * @return the encoded size of the legacy params map at {@code readerIndex}, or -1 if a count or length
 *         prefix lies beyond {@code readerLimit}. The bytes of the last value are counted but not
 *         bounds-checked here.
 */
private int extractParamsSizePre40(ByteBuffer buf, int readerIndex, int readerLimit)
{
    int index = readerIndex;

    index += 4;
    if (index > readerLimit)
        return -1;
    int paramsCount = buf.getInt(index - 4);

    for (int i = 0; i < paramsCount; i++)
    {
        // try to read length and skip to the end of the param name
        index += 2;

        if (index > readerLimit)
            return -1;
        index += buf.getShort(index - 2);

        // try to read length and skip to the end of the param value
        index += 4;
        if (index > readerLimit)
            return -1;
        index += buf.getInt(index - 4);
    }

    return index - readerIndex;
}

/**
 * Serialized payload size for {@code version}, or 0 when the payload is null or NoPayload.
 */
private int payloadSize(Message message, int version)
{
    long payloadSize = message.payload != null && message.payload != NoPayload.noPayload
                     ? message.getPayloadSerializer().serializedSize(message.payload, version)
                     : 0;
    return Ints.checkedCast(payloadSize);
}
}

/**
 * Payload serializer for this message, resolved from its verb, id and sender via the static overload.
 */
private IVersionedAsymmetricSerializer getPayloadSerializer()
{
    return getPayloadSerializer(verb(), id(), from());
}

// Verb#serializer() is null for legacy response messages.
Once all Verbs with null handlers // are removed in a future major, this method can be replaced with a call to verb.serializer. private static IVersionedAsymmetricSerializer getPayloadSerializer(Verb verb, long id, InetAddressAndPort from) { return null != verb.serializer() ? verb.serializer() : instance().callbacks.responseSerializer(id, from); } private int serializedSize30; private int serializedSize3014; private int serializedSize40; /** * Serialized size of the entire message, for the provided messaging version. Caches the calculated value. */ public int serializedSize(int version) { switch (version) { case VERSION_30: if (serializedSize30 == 0) serializedSize30 = serializer.serializedSize(this, VERSION_30); return serializedSize30; case VERSION_3014: if (serializedSize3014 == 0) serializedSize3014 = serializer.serializedSize(this, VERSION_3014); return serializedSize3014; case VERSION_40: if (serializedSize40 == 0) serializedSize40 = serializer.serializedSize(this, VERSION_40); return serializedSize40; default: throw new IllegalStateException(); } } private int payloadSize30 = -1; private int payloadSize3014 = -1; private int payloadSize40 = -1; private int payloadSize(int version) { switch (version) { case VERSION_30: if (payloadSize30 < 0) payloadSize30 = serializer.payloadSize(this, VERSION_30); return payloadSize30; case VERSION_3014: if (payloadSize3014 < 0) payloadSize3014 = serializer.payloadSize(this, VERSION_3014); return payloadSize3014; case VERSION_40: if (payloadSize40 < 0) payloadSize40 = serializer.payloadSize(this, VERSION_40); return payloadSize40; default: throw new IllegalStateException(); } } static class OversizedMessageException extends RuntimeException { OversizedMessageException(int size) { super("Message of size " + size + " bytes exceeds allowed maximum of " + DatabaseDescriptor.getInternodeMaxMessageSizeInBytes() + " bytes"); } } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy