All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.datastax.driver.core.Requests Maven / Gradle / Ivy

There is a newer version: 3.6.0-1
Show newest version
/*
 *      Copyright (C) 2012-2015 DataStax Inc.
 *
 *   Licensed under the Apache License, Version 2.0 (the "License");
 *   you may not use this file except in compliance with the License.
 *   You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *   Unless required by applicable law or agreed to in writing, software
 *   distributed under the License is distributed on an "AS IS" BASIS,
 *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *   See the License for the specific language governing permissions and
 *   limitations under the License.
 */
package com.datastax.driver.core;

import com.datastax.driver.$internal.com.google.common.base.Preconditions;
import com.datastax.driver.$internal.com.google.common.collect.ImmutableMap;
import io.netty.buffer.ByteBuf;

import java.nio.ByteBuffer;
import java.util.Collections;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;

class Requests {

    private Requests() {
    }

    static class Startup extends Message.Request {
        private static final String CQL_VERSION_OPTION = "CQL_VERSION";
        private static final String CQL_VERSION = "3.0.0";

        static final String COMPRESSION_OPTION = "COMPRESSION";

        static final Message.Coder coder = new Message.Coder() {
            @Override
            public void encode(Startup msg, ByteBuf dest, ProtocolVersion version) {
                CBUtil.writeStringMap(msg.options, dest);
            }

            @Override
            public int encodedSize(Startup msg, ProtocolVersion version) {
                return CBUtil.sizeOfStringMap(msg.options);
            }
        };

        private final Map options;
        private final ProtocolOptions.Compression compression;

        Startup(ProtocolOptions.Compression compression) {
            super(Message.Request.Type.STARTUP);
            this.compression = compression;

            ImmutableMap.Builder map = new ImmutableMap.Builder();
            map.put(CQL_VERSION_OPTION, CQL_VERSION);
            if (compression != ProtocolOptions.Compression.NONE)
                map.put(COMPRESSION_OPTION, compression.toString());
            this.options = map.build();
        }

        @Override
        protected Request copyInternal() {
            return new Startup(compression);
        }

        @Override
        public String toString() {
            return "STARTUP " + options;
        }
    }

    // Only for protocol v1
    static class Credentials extends Message.Request {

        static final Message.Coder coder = new Message.Coder() {

            @Override
            public void encode(Credentials msg, ByteBuf dest, ProtocolVersion version) {
                assert version == ProtocolVersion.V1;
                CBUtil.writeStringMap(msg.credentials, dest);
            }

            @Override
            public int encodedSize(Credentials msg, ProtocolVersion version) {
                assert version == ProtocolVersion.V1;
                return CBUtil.sizeOfStringMap(msg.credentials);
            }
        };

        private final Map credentials;

        Credentials(Map credentials) {
            super(Message.Request.Type.CREDENTIALS);
            this.credentials = credentials;
        }

        @Override
        protected Request copyInternal() {
            return new Credentials(credentials);
        }
    }

    static class Options extends Message.Request {

        static final Message.Coder coder = new Message.Coder() {
            @Override
            public void encode(Options msg, ByteBuf dest, ProtocolVersion version) {
            }

            @Override
            public int encodedSize(Options msg, ProtocolVersion version) {
                return 0;
            }
        };

        Options() {
            super(Message.Request.Type.OPTIONS);
        }

        @Override
        protected Request copyInternal() {
            return new Options();
        }

        @Override
        public String toString() {
            return "OPTIONS";
        }
    }

    static class Query extends Message.Request {

        static final Message.Coder coder = new Message.Coder() {
            @Override
            public void encode(Query msg, ByteBuf dest, ProtocolVersion version) {
                CBUtil.writeLongString(msg.query, dest);
                msg.options.encode(dest, version);
            }

            @Override
            public int encodedSize(Query msg, ProtocolVersion version) {
                return CBUtil.sizeOfLongString(msg.query)
                        + msg.options.encodedSize(version);
            }
        };

        final String query;
        final QueryProtocolOptions options;

        Query(String query) {
            this(query, QueryProtocolOptions.DEFAULT, false);
        }

        Query(String query, QueryProtocolOptions options, boolean tracingRequested) {
            super(Type.QUERY, tracingRequested);
            this.query = query;
            this.options = options;
        }

        @Override
        protected Request copyInternal() {
            return new Query(this.query, options, isTracingRequested());
        }

        @Override
        protected Request copyInternal(ConsistencyLevel newConsistencyLevel) {
            return new Query(this.query, options.copy(newConsistencyLevel), isTracingRequested());
        }

        @Override
        public String toString() {
            return "QUERY " + query + '(' + options + ')';
        }
    }

    static class Execute extends Message.Request {

        static final Message.Coder coder = new Message.Coder() {
            @Override
            public void encode(Execute msg, ByteBuf dest, ProtocolVersion version) {
                CBUtil.writeBytes(msg.statementId.bytes, dest);
                msg.options.encode(dest, version);
            }

            @Override
            public int encodedSize(Execute msg, ProtocolVersion version) {
                return CBUtil.sizeOfBytes(msg.statementId.bytes)
                        + msg.options.encodedSize(version);
            }
        };

        final MD5Digest statementId;
        final QueryProtocolOptions options;

        Execute(MD5Digest statementId, QueryProtocolOptions options, boolean tracingRequested) {
            super(Message.Request.Type.EXECUTE, tracingRequested);
            this.statementId = statementId;
            this.options = options;
        }

        @Override
        protected Request copyInternal() {
            return new Execute(statementId, options, isTracingRequested());
        }

        @Override
        protected Request copyInternal(ConsistencyLevel newConsistencyLevel) {
            return new Execute(statementId, options.copy(newConsistencyLevel), isTracingRequested());
        }

        @Override
        public String toString() {
            return "EXECUTE " + statementId + " (" + options + ')';
        }
    }

    enum QueryFlag {
        // The order of that enum matters!!
        VALUES,
        SKIP_METADATA,
        PAGE_SIZE,
        PAGING_STATE,
        SERIAL_CONSISTENCY,
        DEFAULT_TIMESTAMP,
        VALUE_NAMES;

        static EnumSet deserialize(int flags) {
            EnumSet set = EnumSet.noneOf(QueryFlag.class);
            QueryFlag[] values = QueryFlag.values();
            for (int n = 0; n < values.length; n++) {
                if ((flags & (1 << n)) != 0)
                    set.add(values[n]);
            }
            return set;
        }

        static int serialize(EnumSet flags) {
            int i = 0;
            for (QueryFlag flag : flags)
                i |= 1 << flag.ordinal();
            return i;
        }
    }

    static class QueryProtocolOptions {

        static final QueryProtocolOptions DEFAULT = new QueryProtocolOptions(
                Message.Request.Type.QUERY,
                ConsistencyLevel.ONE,
                Collections.emptyList(),
                Collections.emptyMap(),
                false,
                -1,
                null,
                ConsistencyLevel.SERIAL, Long.MIN_VALUE);

        private final EnumSet flags = EnumSet.noneOf(QueryFlag.class);
        private final Message.Request.Type requestType;
        final ConsistencyLevel consistency;
        final List positionalValues;
        final Map namedValues;
        final boolean skipMetadata;
        final int pageSize;
        final ByteBuffer pagingState;
        final ConsistencyLevel serialConsistency;
        final long defaultTimestamp;

        QueryProtocolOptions(Message.Request.Type requestType,
                             ConsistencyLevel consistency,
                             List positionalValues,
                             Map namedValues,
                             boolean skipMetadata,
                             int pageSize,
                             ByteBuffer pagingState,
                             ConsistencyLevel serialConsistency,
                             long defaultTimestamp) {

            Preconditions.checkArgument(positionalValues.isEmpty() || namedValues.isEmpty());

            this.requestType = requestType;
            this.consistency = consistency;
            this.positionalValues = positionalValues;
            this.namedValues = namedValues;
            this.skipMetadata = skipMetadata;
            this.pageSize = pageSize;
            this.pagingState = pagingState;
            this.serialConsistency = serialConsistency;
            this.defaultTimestamp = defaultTimestamp;

            // Populate flags
            if (!positionalValues.isEmpty())
                flags.add(QueryFlag.VALUES);
            if (!namedValues.isEmpty()) {
                flags.add(QueryFlag.VALUES);
                flags.add(QueryFlag.VALUE_NAMES);
            }
            if (skipMetadata)
                flags.add(QueryFlag.SKIP_METADATA);
            if (pageSize >= 0)
                flags.add(QueryFlag.PAGE_SIZE);
            if (pagingState != null)
                flags.add(QueryFlag.PAGING_STATE);
            if (serialConsistency != ConsistencyLevel.SERIAL)
                flags.add(QueryFlag.SERIAL_CONSISTENCY);
            if (defaultTimestamp != Long.MIN_VALUE)
                flags.add(QueryFlag.DEFAULT_TIMESTAMP);
        }

        QueryProtocolOptions copy(ConsistencyLevel newConsistencyLevel) {
            return new QueryProtocolOptions(requestType, newConsistencyLevel, positionalValues, namedValues, skipMetadata, pageSize, pagingState, serialConsistency, defaultTimestamp);
        }

        void encode(ByteBuf dest, ProtocolVersion version) {
            switch (version) {
                case V1:
                    // only EXECUTE messages have variables in V1, and their list must be written
                    // even if it is empty; and they are never named
                    if (requestType == Message.Request.Type.EXECUTE)
                        CBUtil.writeValueList(positionalValues, dest);
                    CBUtil.writeConsistencyLevel(consistency, dest);
                    break;
                case V2:
                case V3:
                case V4:
                    CBUtil.writeConsistencyLevel(consistency, dest);
                    dest.writeByte((byte) QueryFlag.serialize(flags));
                    if (flags.contains(QueryFlag.VALUES)) {
                        if (flags.contains(QueryFlag.VALUE_NAMES)) {
                            assert version.compareTo(ProtocolVersion.V3) >= 0;
                            CBUtil.writeNamedValueList(namedValues, dest);
                        } else {
                            CBUtil.writeValueList(positionalValues, dest);
                        }
                    }
                    if (flags.contains(QueryFlag.PAGE_SIZE))
                        dest.writeInt(pageSize);
                    if (flags.contains(QueryFlag.PAGING_STATE))
                        CBUtil.writeValue(pagingState, dest);
                    if (flags.contains(QueryFlag.SERIAL_CONSISTENCY))
                        CBUtil.writeConsistencyLevel(serialConsistency, dest);
                    if (version.compareTo(ProtocolVersion.V3) >= 0 && flags.contains(QueryFlag.DEFAULT_TIMESTAMP))
                        dest.writeLong(defaultTimestamp);
                    break;
                default:
                    throw version.unsupported();
            }
        }

        int encodedSize(ProtocolVersion version) {
            switch (version) {
                case V1:
                    // only EXECUTE messages have variables in V1, and their list must be written
                    // even if it is empty; and they are never named
                    return (requestType == Message.Request.Type.EXECUTE ? CBUtil.sizeOfValueList(positionalValues) : 0)
                            + CBUtil.sizeOfConsistencyLevel(consistency);
                case V2:
                case V3:
                case V4:
                    int size = 0;
                    size += CBUtil.sizeOfConsistencyLevel(consistency);
                    size += 1; // flags
                    if (flags.contains(QueryFlag.VALUES)) {
                        if (flags.contains(QueryFlag.VALUE_NAMES)) {
                            assert version.compareTo(ProtocolVersion.V3) >= 0;
                            size += CBUtil.sizeOfNamedValueList(namedValues);
                        } else {
                            size += CBUtil.sizeOfValueList(positionalValues);
                        }
                    }
                    if (flags.contains(QueryFlag.PAGE_SIZE))
                        size += 4;
                    if (flags.contains(QueryFlag.PAGING_STATE))
                        size += CBUtil.sizeOfValue(pagingState);
                    if (flags.contains(QueryFlag.SERIAL_CONSISTENCY))
                        size += CBUtil.sizeOfConsistencyLevel(serialConsistency);
                    if (version == ProtocolVersion.V3 && flags.contains(QueryFlag.DEFAULT_TIMESTAMP))
                        size += 8;
                    return size;
                default:
                    throw version.unsupported();
            }
        }

        @Override
        public String toString() {
            return String.format("[cl=%s, positionalVals=%s, namedVals=%s, skip=%b, psize=%d, state=%s, serialCl=%s]",
                    consistency, positionalValues, namedValues, skipMetadata, pageSize, pagingState, serialConsistency);
        }
    }

    static class Batch extends Message.Request {

        static final Message.Coder coder = new Message.Coder() {
            @Override
            public void encode(Batch msg, ByteBuf dest, ProtocolVersion version) {
                int queries = msg.queryOrIdList.size();
                assert queries <= 0xFFFF;

                dest.writeByte(fromType(msg.type));
                dest.writeShort(queries);

                for (int i = 0; i < queries; i++) {
                    Object q = msg.queryOrIdList.get(i);
                    dest.writeByte((byte) (q instanceof String ? 0 : 1));
                    if (q instanceof String)
                        CBUtil.writeLongString((String) q, dest);
                    else
                        CBUtil.writeBytes(((MD5Digest) q).bytes, dest);

                    CBUtil.writeValueList(msg.values.get(i), dest);
                }

                msg.options.encode(dest, version);
            }

            @Override
            public int encodedSize(Batch msg, ProtocolVersion version) {
                int size = 3; // type + nb queries
                for (int i = 0; i < msg.queryOrIdList.size(); i++) {
                    Object q = msg.queryOrIdList.get(i);
                    size += 1 + (q instanceof String
                            ? CBUtil.sizeOfLongString((String) q)
                            : CBUtil.sizeOfBytes(((MD5Digest) q).bytes));

                    size += CBUtil.sizeOfValueList(msg.values.get(i));
                }
                size += msg.options.encodedSize(version);
                return size;
            }

            private byte fromType(BatchStatement.Type type) {
                switch (type) {
                    case LOGGED:
                        return 0;
                    case UNLOGGED:
                        return 1;
                    case COUNTER:
                        return 2;
                    default:
                        throw new AssertionError();
                }
            }
        };

        final BatchStatement.Type type;
        final List queryOrIdList;
        final List> values;
        final BatchProtocolOptions options;

        Batch(BatchStatement.Type type, List queryOrIdList, List> values, BatchProtocolOptions options, boolean tracingRequested) {
            super(Message.Request.Type.BATCH, tracingRequested);
            this.type = type;
            this.queryOrIdList = queryOrIdList;
            this.values = values;
            this.options = options;
        }

        @Override
        protected Request copyInternal() {
            return new Batch(type, queryOrIdList, values, options, isTracingRequested());
        }

        @Override
        protected Request copyInternal(ConsistencyLevel newConsistencyLevel) {
            return new Batch(type, queryOrIdList, values, options.copy(newConsistencyLevel), isTracingRequested());
        }

        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append("BATCH of [");
            for (int i = 0; i < queryOrIdList.size(); i++) {
                if (i > 0) sb.append(", ");
                sb.append(queryOrIdList.get(i)).append(" with ").append(values.get(i).size()).append(" values");
            }
            sb.append("] with options ").append(options);
            return sb.toString();
        }
    }

    static class BatchProtocolOptions {
        private final EnumSet flags = EnumSet.noneOf(QueryFlag.class);
        final ConsistencyLevel consistency;
        final ConsistencyLevel serialConsistency;
        final long defaultTimestamp;

        BatchProtocolOptions(ConsistencyLevel consistency, ConsistencyLevel serialConsistency, long defaultTimestamp) {
            this.consistency = consistency;
            this.serialConsistency = serialConsistency;
            this.defaultTimestamp = defaultTimestamp;

            if (serialConsistency != ConsistencyLevel.SERIAL)
                flags.add(QueryFlag.SERIAL_CONSISTENCY);
            if (defaultTimestamp != Long.MIN_VALUE)
                flags.add(QueryFlag.DEFAULT_TIMESTAMP);
        }

        BatchProtocolOptions copy(ConsistencyLevel newConsistencyLevel) {
            return new BatchProtocolOptions(newConsistencyLevel, serialConsistency, defaultTimestamp);
        }

        void encode(ByteBuf dest, ProtocolVersion version) {
            switch (version) {
                case V2:
                    CBUtil.writeConsistencyLevel(consistency, dest);
                    break;
                case V3:
                case V4:
                    CBUtil.writeConsistencyLevel(consistency, dest);
                    dest.writeByte((byte) QueryFlag.serialize(flags));
                    if (flags.contains(QueryFlag.SERIAL_CONSISTENCY))
                        CBUtil.writeConsistencyLevel(serialConsistency, dest);
                    if (flags.contains(QueryFlag.DEFAULT_TIMESTAMP))
                        dest.writeLong(defaultTimestamp);
                    break;
                default:
                    throw version.unsupported();
            }
        }

        int encodedSize(ProtocolVersion version) {
            switch (version) {
                case V2:
                    return CBUtil.sizeOfConsistencyLevel(consistency);
                case V3:
                case V4:
                    int size = 0;
                    size += CBUtil.sizeOfConsistencyLevel(consistency);
                    size += 1; // flags
                    if (flags.contains(QueryFlag.SERIAL_CONSISTENCY))
                        size += CBUtil.sizeOfConsistencyLevel(serialConsistency);
                    if (flags.contains(QueryFlag.DEFAULT_TIMESTAMP))
                        size += 8;
                    return size;
                default:
                    throw version.unsupported();
            }
        }

        @Override
        public String toString() {
            return String.format("[cl=%s, serialCl=%s, defaultTs=%d]",
                    consistency, serialConsistency, defaultTimestamp);
        }
    }

    static class Prepare extends Message.Request {

        static final Message.Coder coder = new Message.Coder() {

            @Override
            public void encode(Prepare msg, ByteBuf dest, ProtocolVersion version) {
                CBUtil.writeLongString(msg.query, dest);
            }

            @Override
            public int encodedSize(Prepare msg, ProtocolVersion version) {
                return CBUtil.sizeOfLongString(msg.query);
            }
        };

        private final String query;

        Prepare(String query) {
            super(Message.Request.Type.PREPARE);
            this.query = query;
        }

        @Override
        protected Request copyInternal() {
            return new Prepare(query);
        }

        @Override
        public String toString() {
            return "PREPARE " + query;
        }
    }

    static class Register extends Message.Request {

        static final Message.Coder coder = new Message.Coder() {
            @Override
            public void encode(Register msg, ByteBuf dest, ProtocolVersion version) {
                dest.writeShort(msg.eventTypes.size());
                for (ProtocolEvent.Type type : msg.eventTypes)
                    CBUtil.writeEnumValue(type, dest);
            }

            @Override
            public int encodedSize(Register msg, ProtocolVersion version) {
                int size = 2;
                for (ProtocolEvent.Type type : msg.eventTypes)
                    size += CBUtil.sizeOfEnumValue(type);
                return size;
            }
        };

        private final List eventTypes;

        Register(List eventTypes) {
            super(Message.Request.Type.REGISTER);
            this.eventTypes = eventTypes;
        }

        @Override
        protected Request copyInternal() {
            return new Register(eventTypes);
        }

        @Override
        public String toString() {
            return "REGISTER " + eventTypes;
        }
    }

    static class AuthResponse extends Message.Request {

        static final Message.Coder coder = new Message.Coder() {

            @Override
            public void encode(AuthResponse response, ByteBuf dest, ProtocolVersion version) {
                CBUtil.writeValue(response.token, dest);
            }

            @Override
            public int encodedSize(AuthResponse response, ProtocolVersion version) {
                return CBUtil.sizeOfValue(response.token);
            }
        };

        private final byte[] token;

        AuthResponse(byte[] token) {
            super(Message.Request.Type.AUTH_RESPONSE);
            this.token = token;
        }

        @Override
        protected Request copyInternal() {
            return new AuthResponse(token);
        }
    }
}