All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.datastax.driver.core.Requests Maven / Gradle / Ivy

/*
 * Copyright DataStax, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.datastax.driver.core;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import io.netty.buffer.ByteBuf;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Collections;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;

class Requests {

  /** Not instantiable: this class only groups the request message types. */
  private Requests() {}

  /** Shared zero-length value array, e.g. the positional values of QueryProtocolOptions.DEFAULT. */
  static final ByteBuffer[] EMPTY_BB_ARRAY = {};

  /**
   * STARTUP request. Its body is a single string map holding the CQL version, the optional
   * COMPRESSION and NO_COMPACT settings, and the driver name and version.
   */
  static class Startup extends Message.Request {
    private static final String CQL_VERSION_OPTION = "CQL_VERSION";
    private static final String CQL_VERSION = "3.0.0";
    private static final String DRIVER_VERSION_OPTION = "DRIVER_VERSION";
    private static final String DRIVER_NAME_OPTION = "DRIVER_NAME";
    private static final String DRIVER_NAME = "DataStax Java Driver";

    static final String COMPRESSION_OPTION = "COMPRESSION";
    static final String NO_COMPACT_OPTION = "NO_COMPACT";

    /** Encodes the body as the options string map. */
    static final Message.Coder<Startup> coder =
        new Message.Coder<Startup>() {
          @Override
          public void encode(Startup msg, ByteBuf dest, ProtocolVersion version) {
            CBUtil.writeStringMap(msg.options, dest);
          }

          @Override
          public int encodedSize(Startup msg, ProtocolVersion version) {
            return CBUtil.sizeOfStringMap(msg.options);
          }
        };

    // Immutable map of the options sent in the body; built once in the constructor.
    private final Map<String, String> options;
    // compression and noCompact are retained only so copyInternal() can rebuild the request.
    private final ProtocolOptions.Compression compression;
    private final boolean noCompact;

    /**
     * @param compression compression to advertise; {@code NONE} omits the COMPRESSION option.
     * @param noCompact when true, adds {@code NO_COMPACT=true} to the options.
     */
    Startup(ProtocolOptions.Compression compression, boolean noCompact) {
      super(Message.Request.Type.STARTUP);
      this.compression = compression;
      this.noCompact = noCompact;

      ImmutableMap.Builder<String, String> map = new ImmutableMap.Builder<String, String>();
      map.put(CQL_VERSION_OPTION, CQL_VERSION);
      if (compression != ProtocolOptions.Compression.NONE)
        map.put(COMPRESSION_OPTION, compression.toString());
      if (noCompact) map.put(NO_COMPACT_OPTION, "true");

      map.put(DRIVER_VERSION_OPTION, Cluster.getDriverVersion());
      map.put(DRIVER_NAME_OPTION, DRIVER_NAME);

      this.options = map.build();
    }

    @Override
    protected Request copyInternal() {
      return new Startup(compression, noCompact);
    }

    @Override
    public String toString() {
      return "STARTUP " + options;
    }
  }

  /**
   * CREDENTIALS request, whose body is a string map of credentials. Only for protocol v1: both
   * coder methods assert {@code version == V1}.
   */
  static class Credentials extends Message.Request {

    static final Message.Coder<Credentials> coder =
        new Message.Coder<Credentials>() {

          @Override
          public void encode(Credentials msg, ByteBuf dest, ProtocolVersion version) {
            assert version == ProtocolVersion.V1;
            CBUtil.writeStringMap(msg.credentials, dest);
          }

          @Override
          public int encodedSize(Credentials msg, ProtocolVersion version) {
            assert version == ProtocolVersion.V1;
            return CBUtil.sizeOfStringMap(msg.credentials);
          }
        };

    // Stored by reference (not copied); the same map is shared by copies of this request.
    private final Map<String, String> credentials;

    Credentials(Map<String, String> credentials) {
      super(Message.Request.Type.CREDENTIALS);
      this.credentials = credentials;
    }

    @Override
    protected Request copyInternal() {
      return new Credentials(credentials);
    }
  }

  /** OPTIONS request; it has an empty body. */
  static class Options extends Message.Request {

    /** Writes nothing: the OPTIONS body is zero bytes long. */
    static final Message.Coder<Options> coder =
        new Message.Coder<Options>() {
          @Override
          public void encode(Options msg, ByteBuf dest, ProtocolVersion version) {}

          @Override
          public int encodedSize(Options msg, ProtocolVersion version) {
            return 0;
          }
        };

    Options() {
      super(Message.Request.Type.OPTIONS);
    }

    @Override
    protected Request copyInternal() {
      return new Options();
    }

    @Override
    public String toString() {
      return "OPTIONS";
    }
  }

  /** QUERY request: a CQL string followed by its {@link QueryProtocolOptions}. */
  static class Query extends Message.Request {

    /** Encodes the query as a long string, then appends the encoded options. */
    static final Message.Coder<Query> coder =
        new Message.Coder<Query>() {
          @Override
          public void encode(Query msg, ByteBuf dest, ProtocolVersion version) {
            CBUtil.writeLongString(msg.query, dest);
            msg.options.encode(dest, version);
          }

          @Override
          public int encodedSize(Query msg, ProtocolVersion version) {
            return CBUtil.sizeOfLongString(msg.query) + msg.options.encodedSize(version);
          }
        };

    final String query;
    final QueryProtocolOptions options;

    /** Creates an untraced query that uses {@link QueryProtocolOptions#DEFAULT}. */
    Query(String query) {
      this(query, QueryProtocolOptions.DEFAULT, false);
    }

    Query(String query, QueryProtocolOptions options, boolean tracingRequested) {
      super(Type.QUERY, tracingRequested);
      this.query = query;
      this.options = options;
    }

    @Override
    protected Request copyInternal() {
      return new Query(this.query, options, isTracingRequested());
    }

    /** Copies this request with the options' consistency level replaced. */
    @Override
    protected Request copyInternal(ConsistencyLevel newConsistencyLevel) {
      return new Query(this.query, options.copy(newConsistencyLevel), isTracingRequested());
    }

    @Override
    public String toString() {
      return "QUERY " + query + '(' + options + ')';
    }
  }

  /**
   * EXECUTE request for a prepared statement. The body is the statement id, then (when the
   * protocol supports PREPARED_METADATA_CHANGES) the result metadata id, then the query options.
   */
  static class Execute extends Message.Request {

    static final Message.Coder<Execute> coder =
        new Message.Coder<Execute>() {
          @Override
          public void encode(Execute msg, ByteBuf dest, ProtocolVersion version) {
            CBUtil.writeShortBytes(msg.statementId.bytes, dest);
            if (ProtocolFeature.PREPARED_METADATA_CHANGES.isSupportedBy(version))
              CBUtil.writeShortBytes(msg.resultMetadataId.bytes, dest);
            msg.options.encode(dest, version);
          }

          @Override
          public int encodedSize(Execute msg, ProtocolVersion version) {
            int size = CBUtil.sizeOfShortBytes(msg.statementId.bytes);
            if (ProtocolFeature.PREPARED_METADATA_CHANGES.isSupportedBy(version))
              size += CBUtil.sizeOfShortBytes(msg.resultMetadataId.bytes);
            size += msg.options.encodedSize(version);
            return size;
          }
        };

    final MD5Digest statementId;
    // May be null (see toString), but the coders dereference it whenever the protocol version
    // supports PREPARED_METADATA_CHANGES, so it must be set in that case.
    final MD5Digest resultMetadataId;
    final QueryProtocolOptions options;

    Execute(
        MD5Digest statementId,
        MD5Digest resultMetadataId,
        QueryProtocolOptions options,
        boolean tracingRequested) {
      super(Message.Request.Type.EXECUTE, tracingRequested);
      this.statementId = statementId;
      this.resultMetadataId = resultMetadataId;
      this.options = options;
    }

    @Override
    protected Request copyInternal() {
      return new Execute(statementId, resultMetadataId, options, isTracingRequested());
    }

    /** Copies this request with the options' consistency level replaced. */
    @Override
    protected Request copyInternal(ConsistencyLevel newConsistencyLevel) {
      return new Execute(
          statementId, resultMetadataId, options.copy(newConsistencyLevel), isTracingRequested());
    }

    @Override
    public String toString() {
      if (resultMetadataId != null)
        return "EXECUTE preparedId: "
            + statementId
            + " resultMetadataId: "
            + resultMetadataId
            + " ("
            + options
            + ')';
      else return "EXECUTE preparedId: " + statementId + " (" + options + ')';
    }
  }

  /**
   * Flags that mark which optional fields are present in the encoded query and batch options.
   * Each constant holds its bit mask.
   */
  enum QueryFlag {
    VALUES(0x00000001),
    SKIP_METADATA(0x00000002),
    PAGE_SIZE(0x00000004),
    PAGING_STATE(0x00000008),
    SERIAL_CONSISTENCY(0x00000010),
    DEFAULT_TIMESTAMP(0x00000020),
    VALUE_NAMES(0x00000040),
    NOW_IN_SECONDS(0x00000100),
    ;

    private final int mask;

    QueryFlag(int mask) {
      this.mask = mask;
    }

    /** Decodes a bit field into the flags whose mask bits are set; unknown bits are ignored. */
    static EnumSet<QueryFlag> deserialize(int flags) {
      EnumSet<QueryFlag> set = EnumSet.noneOf(QueryFlag.class);
      for (QueryFlag flag : values()) {
        if ((flags & flag.mask) != 0) set.add(flag);
      }
      return set;
    }

    /**
     * Writes {@code flags} as a 4-byte int for V5 and later, or as a single byte otherwise. In the
     * single-byte form, masks above 0xFF (currently only NOW_IN_SECONDS) are truncated away.
     */
    static void serialize(EnumSet<QueryFlag> flags, ByteBuf dest, ProtocolVersion version) {
      int i = 0;
      for (QueryFlag flag : flags) i |= flag.mask;
      if (version.compareTo(ProtocolVersion.V5) >= 0) {
        dest.writeInt(i);
      } else {
        dest.writeByte((byte) i);
      }
    }

    /** Number of bytes {@link #serialize} writes for {@code version}. */
    static int serializedSize(ProtocolVersion version) {
      return version.compareTo(ProtocolVersion.V5) >= 0 ? 4 : 1;
    }
  }

  /**
   * Options that follow a QUERY or EXECUTE body: consistency, bound values and the optional fields
   * described by {@link QueryFlag}. The constructor derives the flag set from the field values.
   * It treats these sentinel values as "absent": a negative page size, a null paging state,
   * {@code SERIAL} serial consistency, {@code Long.MIN_VALUE} default timestamp and
   * {@code Integer.MIN_VALUE} now-in-seconds.
   */
  static class QueryProtocolOptions {

    /** Consistency ONE, no bound values, and every optional field absent. */
    static final QueryProtocolOptions DEFAULT =
        new QueryProtocolOptions(
            Message.Request.Type.QUERY,
            ConsistencyLevel.ONE,
            EMPTY_BB_ARRAY,
            Collections.<String, ByteBuffer>emptyMap(),
            false,
            -1,
            null,
            ConsistencyLevel.SERIAL,
            Long.MIN_VALUE,
            Integer.MIN_VALUE);

    private final EnumSet<QueryFlag> flags = EnumSet.noneOf(QueryFlag.class);
    // Only consulted by the V1 encoding, where EXECUTE bodies always carry a value list.
    private final Message.Request.Type requestType;
    final ConsistencyLevel consistency;
    final ByteBuffer[] positionalValues;
    final Map<String, ByteBuffer> namedValues;
    final boolean skipMetadata;
    final int pageSize;
    final ByteBuffer pagingState;
    final ConsistencyLevel serialConsistency;
    final long defaultTimestamp;
    final int nowInSeconds;

    /**
     * @throws IllegalArgumentException if both positional and named values are supplied.
     */
    QueryProtocolOptions(
        Message.Request.Type requestType,
        ConsistencyLevel consistency,
        ByteBuffer[] positionalValues,
        Map<String, ByteBuffer> namedValues,
        boolean skipMetadata,
        int pageSize,
        ByteBuffer pagingState,
        ConsistencyLevel serialConsistency,
        long defaultTimestamp,
        int nowInSeconds) {

      Preconditions.checkArgument(
          positionalValues.length == 0 || namedValues.isEmpty(),
          "Cannot use both positional and named values");

      this.requestType = requestType;
      this.consistency = consistency;
      this.positionalValues = positionalValues;
      this.namedValues = namedValues;
      this.skipMetadata = skipMetadata;
      this.pageSize = pageSize;
      this.pagingState = pagingState;
      this.serialConsistency = serialConsistency;
      this.defaultTimestamp = defaultTimestamp;
      this.nowInSeconds = nowInSeconds;

      // Populate flags
      if (positionalValues.length > 0) {
        flags.add(QueryFlag.VALUES);
      }
      if (!namedValues.isEmpty()) {
        flags.add(QueryFlag.VALUES);
        flags.add(QueryFlag.VALUE_NAMES);
      }
      if (skipMetadata) flags.add(QueryFlag.SKIP_METADATA);
      if (pageSize >= 0) flags.add(QueryFlag.PAGE_SIZE);
      if (pagingState != null) flags.add(QueryFlag.PAGING_STATE);
      if (serialConsistency != ConsistencyLevel.SERIAL) flags.add(QueryFlag.SERIAL_CONSISTENCY);
      if (defaultTimestamp != Long.MIN_VALUE) flags.add(QueryFlag.DEFAULT_TIMESTAMP);
      if (nowInSeconds != Integer.MIN_VALUE) flags.add(QueryFlag.NOW_IN_SECONDS);
    }

    /** Returns a copy of these options with only the consistency level replaced. */
    QueryProtocolOptions copy(ConsistencyLevel newConsistencyLevel) {
      return new QueryProtocolOptions(
          requestType,
          newConsistencyLevel,
          positionalValues,
          namedValues,
          skipMetadata,
          pageSize,
          pagingState,
          serialConsistency,
          defaultTimestamp,
          nowInSeconds);
    }

    /**
     * Returns the flags to advertise on the wire for {@code version}. Below V3, DEFAULT_TIMESTAMP
     * is dropped: {@link #encode} only writes the timestamp value for V3 and later. Keeping the
     * flag would advertise a field that is never written.
     */
    private EnumSet<QueryFlag> wireFlags(ProtocolVersion version) {
      if (version.compareTo(ProtocolVersion.V3) >= 0
          || !flags.contains(QueryFlag.DEFAULT_TIMESTAMP)) {
        return flags;
      }
      EnumSet<QueryFlag> copy = EnumSet.copyOf(flags);
      copy.remove(QueryFlag.DEFAULT_TIMESTAMP);
      return copy;
    }

    /** Writes these options to {@code dest} using the layout of {@code version}. */
    void encode(ByteBuf dest, ProtocolVersion version) {
      switch (version) {
        case V1:
          // only EXECUTE messages have variables in V1, and their list must be written
          // even if it is empty; and they are never named
          if (requestType == Message.Request.Type.EXECUTE)
            CBUtil.writeValueList(positionalValues, dest);
          CBUtil.writeConsistencyLevel(consistency, dest);
          break;
        case V2:
        case V3:
        case V4:
        case V5:
          CBUtil.writeConsistencyLevel(consistency, dest);
          QueryFlag.serialize(wireFlags(version), dest, version);
          if (flags.contains(QueryFlag.VALUES)) {
            if (flags.contains(QueryFlag.VALUE_NAMES)) {
              assert version.compareTo(ProtocolVersion.V3) >= 0;
              CBUtil.writeNamedValueList(namedValues, dest);
            } else {
              CBUtil.writeValueList(positionalValues, dest);
            }
          }
          if (flags.contains(QueryFlag.PAGE_SIZE)) dest.writeInt(pageSize);
          if (flags.contains(QueryFlag.PAGING_STATE)) CBUtil.writeValue(pagingState, dest);
          if (flags.contains(QueryFlag.SERIAL_CONSISTENCY))
            CBUtil.writeConsistencyLevel(serialConsistency, dest);
          if (version.compareTo(ProtocolVersion.V3) >= 0
              && flags.contains(QueryFlag.DEFAULT_TIMESTAMP)) dest.writeLong(defaultTimestamp);
          if (version.compareTo(ProtocolVersion.V5) >= 0
              && flags.contains(QueryFlag.NOW_IN_SECONDS)) dest.writeInt(nowInSeconds);
          break;
        default:
          throw version.unsupported();
      }
    }

    /** Number of bytes {@link #encode} writes for {@code version}. */
    int encodedSize(ProtocolVersion version) {
      switch (version) {
        case V1:
          // only EXECUTE messages have variables in V1, and their list must be written
          // even if it is empty; and they are never named
          return (requestType == Message.Request.Type.EXECUTE
                  ? CBUtil.sizeOfValueList(positionalValues)
                  : 0)
              + CBUtil.sizeOfConsistencyLevel(consistency);
        case V2:
        case V3:
        case V4:
        case V5:
          int size = 0;
          size += CBUtil.sizeOfConsistencyLevel(consistency);
          size += QueryFlag.serializedSize(version);
          if (flags.contains(QueryFlag.VALUES)) {
            if (flags.contains(QueryFlag.VALUE_NAMES)) {
              assert version.compareTo(ProtocolVersion.V3) >= 0;
              size += CBUtil.sizeOfNamedValueList(namedValues);
            } else {
              size += CBUtil.sizeOfValueList(positionalValues);
            }
          }
          if (flags.contains(QueryFlag.PAGE_SIZE)) size += 4;
          if (flags.contains(QueryFlag.PAGING_STATE)) size += CBUtil.sizeOfValue(pagingState);
          if (flags.contains(QueryFlag.SERIAL_CONSISTENCY))
            size += CBUtil.sizeOfConsistencyLevel(serialConsistency);
          if (version.compareTo(ProtocolVersion.V3) >= 0
              && flags.contains(QueryFlag.DEFAULT_TIMESTAMP)) size += 8;
          if (version.compareTo(ProtocolVersion.V5) >= 0
              && flags.contains(QueryFlag.NOW_IN_SECONDS)) size += 4;
          return size;
        default:
          throw version.unsupported();
      }
    }

    @Override
    public String toString() {
      return String.format(
          "[cl=%s, positionalVals=%s, namedVals=%s, skip=%b, psize=%d, state=%s, serialCl=%s]",
          consistency,
          Arrays.toString(positionalValues),
          namedValues,
          skipMetadata,
          pageSize,
          pagingState,
          serialConsistency);
    }
  }

  /**
   * BATCH request. Each entry of {@code queryOrIdList} is either a query {@code String} or the
   * {@link MD5Digest} id of a prepared statement. {@code values[i]} holds the bound values for
   * entry {@code i}.
   */
  static class Batch extends Message.Request {

    static final Message.Coder<Batch> coder =
        new Message.Coder<Batch>() {
          @Override
          public void encode(Batch msg, ByteBuf dest, ProtocolVersion version) {
            int queries = msg.queryOrIdList.size();
            // The count is written as an unsigned short.
            assert queries <= 0xFFFF;

            dest.writeByte(fromType(msg.type));
            dest.writeShort(queries);

            for (int i = 0; i < queries; i++) {
              Object q = msg.queryOrIdList.get(i);
              // kind byte: 0 = query string, 1 = prepared statement id
              dest.writeByte((byte) (q instanceof String ? 0 : 1));
              if (q instanceof String) CBUtil.writeLongString((String) q, dest);
              else CBUtil.writeShortBytes(((MD5Digest) q).bytes, dest);

              CBUtil.writeValueList(msg.values[i], dest);
            }

            msg.options.encode(dest, version);
          }

          @Override
          public int encodedSize(Batch msg, ProtocolVersion version) {
            int size = 3; // type + nb queries
            for (int i = 0; i < msg.queryOrIdList.size(); i++) {
              Object q = msg.queryOrIdList.get(i);
              size +=
                  1
                      + (q instanceof String
                          ? CBUtil.sizeOfLongString((String) q)
                          : CBUtil.sizeOfShortBytes(((MD5Digest) q).bytes));

              size += CBUtil.sizeOfValueList(msg.values[i]);
            }
            size += msg.options.encodedSize(version);
            return size;
          }

          /** Maps the batch type to its wire byte. */
          private byte fromType(BatchStatement.Type type) {
            switch (type) {
              case LOGGED:
                return 0;
              case UNLOGGED:
                return 1;
              case COUNTER:
                return 2;
              default:
                throw new AssertionError("Unknown batch type: " + type);
            }
          }
        };

    final BatchStatement.Type type;
    final List<Object> queryOrIdList;
    final ByteBuffer[][] values;
    final BatchProtocolOptions options;

    Batch(
        BatchStatement.Type type,
        List<Object> queryOrIdList,
        ByteBuffer[][] values,
        BatchProtocolOptions options,
        boolean tracingRequested) {
      super(Message.Request.Type.BATCH, tracingRequested);
      this.type = type;
      this.queryOrIdList = queryOrIdList;
      this.values = values;
      this.options = options;
    }

    @Override
    protected Request copyInternal() {
      return new Batch(type, queryOrIdList, values, options, isTracingRequested());
    }

    /** Copies this request with the options' consistency level replaced. */
    @Override
    protected Request copyInternal(ConsistencyLevel newConsistencyLevel) {
      return new Batch(
          type, queryOrIdList, values, options.copy(newConsistencyLevel), isTracingRequested());
    }

    @Override
    public String toString() {
      StringBuilder sb = new StringBuilder();
      sb.append("BATCH of [");
      for (int i = 0; i < queryOrIdList.size(); i++) {
        if (i > 0) sb.append(", ");
        sb.append(queryOrIdList.get(i)).append(" with ").append(values[i].length).append(" values");
      }
      sb.append("] with options ").append(options);
      return sb.toString();
    }
  }

  /**
   * Options that follow a BATCH body. V2 writes only the consistency level. V3 and later also
   * write the flags and whichever optional fields are present. As in QueryProtocolOptions, the
   * sentinels {@code SERIAL}, {@code Long.MIN_VALUE} and {@code Integer.MIN_VALUE} mean "absent".
   */
  static class BatchProtocolOptions {
    private final EnumSet<QueryFlag> flags = EnumSet.noneOf(QueryFlag.class);
    final ConsistencyLevel consistency;
    final ConsistencyLevel serialConsistency;
    final long defaultTimestamp;
    final int nowInSeconds;

    BatchProtocolOptions(
        ConsistencyLevel consistency,
        ConsistencyLevel serialConsistency,
        long defaultTimestamp,
        int nowInSeconds) {
      this.consistency = consistency;
      this.serialConsistency = serialConsistency;
      this.defaultTimestamp = defaultTimestamp;
      this.nowInSeconds = nowInSeconds;

      if (serialConsistency != ConsistencyLevel.SERIAL) flags.add(QueryFlag.SERIAL_CONSISTENCY);
      if (defaultTimestamp != Long.MIN_VALUE) flags.add(QueryFlag.DEFAULT_TIMESTAMP);
      if (nowInSeconds != Integer.MIN_VALUE) flags.add(QueryFlag.NOW_IN_SECONDS);
    }

    /** Returns a copy of these options with only the consistency level replaced. */
    BatchProtocolOptions copy(ConsistencyLevel newConsistencyLevel) {
      return new BatchProtocolOptions(
          newConsistencyLevel, serialConsistency, defaultTimestamp, nowInSeconds);
    }

    /** Writes these options to {@code dest} using the layout of {@code version}. */
    void encode(ByteBuf dest, ProtocolVersion version) {
      switch (version) {
        case V2:
          CBUtil.writeConsistencyLevel(consistency, dest);
          break;
        case V3:
        case V4:
        case V5:
          CBUtil.writeConsistencyLevel(consistency, dest);
          QueryFlag.serialize(flags, dest, version);
          if (flags.contains(QueryFlag.SERIAL_CONSISTENCY))
            CBUtil.writeConsistencyLevel(serialConsistency, dest);
          if (flags.contains(QueryFlag.DEFAULT_TIMESTAMP)) dest.writeLong(defaultTimestamp);
          if (version.compareTo(ProtocolVersion.V5) >= 0
              && flags.contains(QueryFlag.NOW_IN_SECONDS)) dest.writeInt(nowInSeconds);
          break;
        default:
          throw version.unsupported();
      }
    }

    /** Number of bytes {@link #encode} writes for {@code version}. */
    int encodedSize(ProtocolVersion version) {
      switch (version) {
        case V2:
          return CBUtil.sizeOfConsistencyLevel(consistency);
        case V3:
        case V4:
        case V5:
          int size = 0;
          size += CBUtil.sizeOfConsistencyLevel(consistency);
          size += QueryFlag.serializedSize(version);
          if (flags.contains(QueryFlag.SERIAL_CONSISTENCY))
            size += CBUtil.sizeOfConsistencyLevel(serialConsistency);
          if (flags.contains(QueryFlag.DEFAULT_TIMESTAMP)) size += 8;
          if (version.compareTo(ProtocolVersion.V5) >= 0
              && flags.contains(QueryFlag.NOW_IN_SECONDS)) size += 4;
          return size;
        default:
          throw version.unsupported();
      }
    }

    @Override
    public String toString() {
      return String.format(
          "[cl=%s, serialCl=%s, defaultTs=%d]", consistency, serialConsistency, defaultTimestamp);
    }
  }

  /**
   * PREPARE request: the query as a long string, followed for V5 and later by an int flags field
   * that is always 0 (no keyspace is set).
   */
  static class Prepare extends Message.Request {

    static final Message.Coder<Prepare> coder =
        new Message.Coder<Prepare>() {

          @Override
          public void encode(Prepare msg, ByteBuf dest, ProtocolVersion version) {
            CBUtil.writeLongString(msg.query, dest);

            if (version.compareTo(ProtocolVersion.V5) >= 0) {
              // Write empty flags for now, to communicate that no keyspace is being set.
              dest.writeInt(0);
            }
          }

          @Override
          public int encodedSize(Prepare msg, ProtocolVersion version) {
            int size = CBUtil.sizeOfLongString(msg.query);

            if (version.compareTo(ProtocolVersion.V5) >= 0) {
              size += 4; // flags
            }
            return size;
          }
        };

    private final String query;

    Prepare(String query) {
      super(Message.Request.Type.PREPARE);
      this.query = query;
    }

    @Override
    protected Request copyInternal() {
      return new Prepare(query);
    }

    @Override
    public String toString() {
      return "PREPARE " + query;
    }
  }

  /**
   * REGISTER request listing event types. Encoded as a short count followed by each type written
   * as an enum value.
   */
  static class Register extends Message.Request {

    static final Message.Coder<Register> coder =
        new Message.Coder<Register>() {
          @Override
          public void encode(Register msg, ByteBuf dest, ProtocolVersion version) {
            dest.writeShort(msg.eventTypes.size());
            for (ProtocolEvent.Type type : msg.eventTypes) CBUtil.writeEnumValue(type, dest);
          }

          @Override
          public int encodedSize(Register msg, ProtocolVersion version) {
            int size = 2; // count
            for (ProtocolEvent.Type type : msg.eventTypes) size += CBUtil.sizeOfEnumValue(type);
            return size;
          }
        };

    private final List<ProtocolEvent.Type> eventTypes;

    Register(List<ProtocolEvent.Type> eventTypes) {
      super(Message.Request.Type.REGISTER);
      this.eventTypes = eventTypes;
    }

    @Override
    protected Request copyInternal() {
      return new Register(eventTypes);
    }

    @Override
    public String toString() {
      return "REGISTER " + eventTypes;
    }
  }

  /** AUTH_RESPONSE request whose body is a single authentication token written as a value. */
  static class AuthResponse extends Message.Request {

    static final Message.Coder<AuthResponse> coder =
        new Message.Coder<AuthResponse>() {

          @Override
          public void encode(AuthResponse response, ByteBuf dest, ProtocolVersion version) {
            CBUtil.writeValue(response.token, dest);
          }

          @Override
          public int encodedSize(AuthResponse response, ProtocolVersion version) {
            return CBUtil.sizeOfValue(response.token);
          }
        };

    // Stored by reference (not copied); copies of this request share the same array.
    private final byte[] token;

    AuthResponse(byte[] token) {
      super(Message.Request.Type.AUTH_RESPONSE);
      this.token = token;
    }

    @Override
    protected Request copyInternal() {
      return new AuthResponse(token);
    }
  }
}