All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.hazelcast.org.apache.calcite.avatica.remote.ProtobufTranslationImpl Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to you under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.hazelcast.org.apache.calcite.avatica.remote;

import com.hazelcast.org.apache.calcite.avatica.proto.Common.WireMessage;
import com.hazelcast.org.apache.calcite.avatica.proto.Requests.CatalogsRequest;
import com.hazelcast.org.apache.calcite.avatica.proto.Requests.CloseConnectionRequest;
import com.hazelcast.org.apache.calcite.avatica.proto.Requests.CloseStatementRequest;
import com.hazelcast.org.apache.calcite.avatica.proto.Requests.ColumnsRequest;
import com.hazelcast.org.apache.calcite.avatica.proto.Requests.CommitRequest;
import com.hazelcast.org.apache.calcite.avatica.proto.Requests.ConnectionSyncRequest;
import com.hazelcast.org.apache.calcite.avatica.proto.Requests.CreateStatementRequest;
import com.hazelcast.org.apache.calcite.avatica.proto.Requests.DatabasePropertyRequest;
import com.hazelcast.org.apache.calcite.avatica.proto.Requests.ExecuteBatchRequest;
import com.hazelcast.org.apache.calcite.avatica.proto.Requests.ExecuteRequest;
import com.hazelcast.org.apache.calcite.avatica.proto.Requests.FetchRequest;
import com.hazelcast.org.apache.calcite.avatica.proto.Requests.OpenConnectionRequest;
import com.hazelcast.org.apache.calcite.avatica.proto.Requests.PrepareAndExecuteBatchRequest;
import com.hazelcast.org.apache.calcite.avatica.proto.Requests.PrepareAndExecuteRequest;
import com.hazelcast.org.apache.calcite.avatica.proto.Requests.PrepareRequest;
import com.hazelcast.org.apache.calcite.avatica.proto.Requests.RollbackRequest;
import com.hazelcast.org.apache.calcite.avatica.proto.Requests.SchemasRequest;
import com.hazelcast.org.apache.calcite.avatica.proto.Requests.SyncResultsRequest;
import com.hazelcast.org.apache.calcite.avatica.proto.Requests.TableTypesRequest;
import com.hazelcast.org.apache.calcite.avatica.proto.Requests.TablesRequest;
import com.hazelcast.org.apache.calcite.avatica.proto.Requests.TypeInfoRequest;
import com.hazelcast.org.apache.calcite.avatica.proto.Responses.CloseConnectionResponse;
import com.hazelcast.org.apache.calcite.avatica.proto.Responses.CloseStatementResponse;
import com.hazelcast.org.apache.calcite.avatica.proto.Responses.CommitResponse;
import com.hazelcast.org.apache.calcite.avatica.proto.Responses.ConnectionSyncResponse;
import com.hazelcast.org.apache.calcite.avatica.proto.Responses.CreateStatementResponse;
import com.hazelcast.org.apache.calcite.avatica.proto.Responses.DatabasePropertyResponse;
import com.hazelcast.org.apache.calcite.avatica.proto.Responses.ErrorResponse;
import com.hazelcast.org.apache.calcite.avatica.proto.Responses.ExecuteBatchResponse;
import com.hazelcast.org.apache.calcite.avatica.proto.Responses.ExecuteResponse;
import com.hazelcast.org.apache.calcite.avatica.proto.Responses.FetchResponse;
import com.hazelcast.org.apache.calcite.avatica.proto.Responses.OpenConnectionResponse;
import com.hazelcast.org.apache.calcite.avatica.proto.Responses.PrepareResponse;
import com.hazelcast.org.apache.calcite.avatica.proto.Responses.ResultSetResponse;
import com.hazelcast.org.apache.calcite.avatica.proto.Responses.RollbackResponse;
import com.hazelcast.org.apache.calcite.avatica.proto.Responses.RpcMetadata;
import com.hazelcast.org.apache.calcite.avatica.proto.Responses.SyncResultsResponse;
import com.hazelcast.org.apache.calcite.avatica.remote.Service.Request;
import com.hazelcast.org.apache.calcite.avatica.remote.Service.Response;
import com.hazelcast.org.apache.calcite.avatica.remote.Service.RpcMetadataResponse;
import com.hazelcast.org.apache.calcite.avatica.util.UnsynchronizedBuffer;

import com.hazelcast.com.google.protobuf.ByteString;
import com.hazelcast.com.google.protobuf.CodedInputStream;
import com.hazelcast.com.google.protobuf.InvalidProtocolBufferException;
import com.hazelcast.com.google.protobuf.Message;
import com.hazelcast.com.google.protobuf.Parser;
import com.hazelcast.com.google.protobuf.TextFormat;
import com.hazelcast.com.google.protobuf.UnsafeByteOperations;

import com.hazelcast.org.slf4j.Logger;
import com.hazelcast.org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import static java.nio.charset.StandardCharsets.UTF_8;

/**
 * Implementation of {@link ProtobufTranslationImpl} that translates
 * protobuf requests to POJO requests.
 */
public class ProtobufTranslationImpl implements ProtobufTranslation {
  private static final Logger LOG = LoggerFactory.getLogger(ProtobufTranslationImpl.class);

  /**
   * Encapsulate the logic of transforming a protobuf Request message into the Avatica POJO request.
   */
  static class RequestTranslator {

    private final Parser parser;
    private final Service.Request impl;

    RequestTranslator(Parser parser, Service.Request impl) {
      this.parser = parser;
      this.impl = impl;
    }

    public Service.Request transform(ByteString serializedMessage) throws
        InvalidProtocolBufferException {
      // This should already be an aliased CodedInputStream from the WireMessage parsing.
      Message msg = parser.parseFrom(serializedMessage.newCodedInput());
      if (LOG.isTraceEnabled()) {
        LOG.trace("Deserialized request '{}'", TextFormat.shortDebugString(msg));
      }
      return impl.deserialize(msg);
    }
  }

  /**
   * Encapsulate the logic of transforming a protobuf Response message into the Avatica POJO
   * Response.
   */
  static class ResponseTranslator {

    private final Parser parser;
    private final Service.Response impl;

    ResponseTranslator(Parser parser, Service.Response impl) {
      this.parser = parser;
      this.impl = impl;
    }

    public Service.Response transform(ByteString serializedMessage) throws
        InvalidProtocolBufferException {
      Message msg = parser.parseFrom(serializedMessage);
      if (LOG.isTraceEnabled()) {
        LOG.trace("Deserialized response '{}'", TextFormat.shortDebugString(msg));
      }
      return impl.deserialize(msg);
    }
  }

  // Extremely ugly mapping of PB class name into a means to convert it to the POJO
  private static final Map REQUEST_PARSERS;
  private static final Map RESPONSE_PARSERS;
  private static final Map, ByteString> MESSAGE_CLASSES;

  static {
    Map reqParsers = new ConcurrentHashMap<>();
    reqParsers.put(CatalogsRequest.class.getName(),
        new RequestTranslator(CatalogsRequest.parser(), new Service.CatalogsRequest()));
    reqParsers.put(OpenConnectionRequest.class.getName(),
        new RequestTranslator(OpenConnectionRequest.parser(), new Service.OpenConnectionRequest()));
    reqParsers.put(CloseConnectionRequest.class.getName(),
        new RequestTranslator(CloseConnectionRequest.parser(),
          new Service.CloseConnectionRequest()));
    reqParsers.put(CloseStatementRequest.class.getName(),
        new RequestTranslator(CloseStatementRequest.parser(), new Service.CloseStatementRequest()));
    reqParsers.put(ColumnsRequest.class.getName(),
        new RequestTranslator(ColumnsRequest.parser(), new Service.ColumnsRequest()));
    reqParsers.put(ConnectionSyncRequest.class.getName(),
        new RequestTranslator(ConnectionSyncRequest.parser(), new Service.ConnectionSyncRequest()));
    reqParsers.put(CreateStatementRequest.class.getName(),
        new RequestTranslator(CreateStatementRequest.parser(),
          new Service.CreateStatementRequest()));
    reqParsers.put(DatabasePropertyRequest.class.getName(),
        new RequestTranslator(DatabasePropertyRequest.parser(),
            new Service.DatabasePropertyRequest()));
    reqParsers.put(FetchRequest.class.getName(),
        new RequestTranslator(FetchRequest.parser(), new Service.FetchRequest()));
    reqParsers.put(PrepareAndExecuteRequest.class.getName(),
        new RequestTranslator(PrepareAndExecuteRequest.parser(),
            new Service.PrepareAndExecuteRequest()));
    reqParsers.put(PrepareRequest.class.getName(),
        new RequestTranslator(PrepareRequest.parser(), new Service.PrepareRequest()));
    reqParsers.put(SchemasRequest.class.getName(),
        new RequestTranslator(SchemasRequest.parser(), new Service.SchemasRequest()));
    reqParsers.put(TablesRequest.class.getName(),
        new RequestTranslator(TablesRequest.parser(), new Service.TablesRequest()));
    reqParsers.put(TableTypesRequest.class.getName(),
        new RequestTranslator(TableTypesRequest.parser(), new Service.TableTypesRequest()));
    reqParsers.put(TypeInfoRequest.class.getName(),
        new RequestTranslator(TypeInfoRequest.parser(), new Service.TypeInfoRequest()));
    reqParsers.put(ExecuteRequest.class.getName(),
        new RequestTranslator(ExecuteRequest.parser(), new Service.ExecuteRequest()));
    reqParsers.put(SyncResultsRequest.class.getName(),
        new RequestTranslator(SyncResultsRequest.parser(), new Service.SyncResultsRequest()));
    reqParsers.put(CommitRequest.class.getName(),
        new RequestTranslator(CommitRequest.parser(), new Service.CommitRequest()));
    reqParsers.put(RollbackRequest.class.getName(),
        new RequestTranslator(RollbackRequest.parser(), new Service.RollbackRequest()));
    reqParsers.put(PrepareAndExecuteBatchRequest.class.getName(),
        new RequestTranslator(PrepareAndExecuteBatchRequest.parser(),
            new Service.PrepareAndExecuteBatchRequest()));
    reqParsers.put(ExecuteBatchRequest.class.getName(),
        new RequestTranslator(ExecuteBatchRequest.parser(),
            new Service.ExecuteBatchRequest()));

    REQUEST_PARSERS = Collections.unmodifiableMap(reqParsers);

    Map respParsers = new ConcurrentHashMap<>();
    respParsers.put(OpenConnectionResponse.class.getName(),
        new ResponseTranslator(OpenConnectionResponse.parser(),
            new Service.OpenConnectionResponse()));
    respParsers.put(CloseConnectionResponse.class.getName(),
        new ResponseTranslator(CloseConnectionResponse.parser(),
            new Service.CloseConnectionResponse()));
    respParsers.put(CloseStatementResponse.class.getName(),
        new ResponseTranslator(CloseStatementResponse.parser(),
            new Service.CloseStatementResponse()));
    respParsers.put(ConnectionSyncResponse.class.getName(),
        new ResponseTranslator(ConnectionSyncResponse.parser(),
            new Service.ConnectionSyncResponse()));
    respParsers.put(CreateStatementResponse.class.getName(),
        new ResponseTranslator(CreateStatementResponse.parser(),
            new Service.CreateStatementResponse()));
    respParsers.put(DatabasePropertyResponse.class.getName(),
        new ResponseTranslator(DatabasePropertyResponse.parser(),
            new Service.DatabasePropertyResponse()));
    respParsers.put(ExecuteResponse.class.getName(),
        new ResponseTranslator(ExecuteResponse.parser(), new Service.ExecuteResponse()));
    respParsers.put(FetchResponse.class.getName(),
        new ResponseTranslator(FetchResponse.parser(), new Service.FetchResponse()));
    respParsers.put(PrepareResponse.class.getName(),
        new ResponseTranslator(PrepareResponse.parser(), new Service.PrepareResponse()));
    respParsers.put(ResultSetResponse.class.getName(),
        new ResponseTranslator(ResultSetResponse.parser(), new Service.ResultSetResponse()));
    respParsers.put(ErrorResponse.class.getName(),
        new ResponseTranslator(ErrorResponse.parser(), new Service.ErrorResponse()));
    respParsers.put(SyncResultsResponse.class.getName(),
        new ResponseTranslator(SyncResultsResponse.parser(), new Service.SyncResultsResponse()));
    respParsers.put(RpcMetadata.class.getName(),
        new ResponseTranslator(RpcMetadata.parser(), new RpcMetadataResponse()));
    respParsers.put(CommitResponse.class.getName(),
        new ResponseTranslator(CommitResponse.parser(), new Service.CommitResponse()));
    respParsers.put(RollbackResponse.class.getName(),
        new ResponseTranslator(RollbackResponse.parser(), new Service.RollbackResponse()));
    respParsers.put(ExecuteBatchResponse.class.getName(),
        new ResponseTranslator(ExecuteBatchResponse.parser(), new Service.ExecuteBatchResponse()));

    RESPONSE_PARSERS = Collections.unmodifiableMap(respParsers);

    Map, ByteString> messageClassNames = new ConcurrentHashMap<>();
    for (Class msgClz : getAllMessageClasses()) {
      messageClassNames.put(msgClz, wrapClassName(msgClz));
    }
    MESSAGE_CLASSES = Collections.unmodifiableMap(messageClassNames);
  }

  private static List> getAllMessageClasses() {
    List> messageClasses = new ArrayList<>();
    messageClasses.add(CatalogsRequest.class);
    messageClasses.add(CloseConnectionRequest.class);
    messageClasses.add(CloseStatementRequest.class);
    messageClasses.add(ColumnsRequest.class);
    messageClasses.add(CommitRequest.class);
    messageClasses.add(ConnectionSyncRequest.class);
    messageClasses.add(CreateStatementRequest.class);
    messageClasses.add(DatabasePropertyRequest.class);
    messageClasses.add(ExecuteRequest.class);
    messageClasses.add(FetchRequest.class);
    messageClasses.add(OpenConnectionRequest.class);
    messageClasses.add(PrepareAndExecuteRequest.class);
    messageClasses.add(PrepareRequest.class);
    messageClasses.add(RollbackRequest.class);
    messageClasses.add(SchemasRequest.class);
    messageClasses.add(SyncResultsRequest.class);
    messageClasses.add(TableTypesRequest.class);
    messageClasses.add(TablesRequest.class);
    messageClasses.add(TypeInfoRequest.class);
    messageClasses.add(PrepareAndExecuteBatchRequest.class);
    messageClasses.add(ExecuteBatchRequest.class);

    messageClasses.add(CloseConnectionResponse.class);
    messageClasses.add(CloseStatementResponse.class);
    messageClasses.add(CommitResponse.class);
    messageClasses.add(ConnectionSyncResponse.class);
    messageClasses.add(CreateStatementResponse.class);
    messageClasses.add(DatabasePropertyResponse.class);
    messageClasses.add(ErrorResponse.class);
    messageClasses.add(ExecuteResponse.class);
    messageClasses.add(FetchResponse.class);
    messageClasses.add(OpenConnectionResponse.class);
    messageClasses.add(PrepareResponse.class);
    messageClasses.add(ResultSetResponse.class);
    messageClasses.add(RollbackResponse.class);
    messageClasses.add(RpcMetadata.class);
    messageClasses.add(SyncResultsResponse.class);
    messageClasses.add(ExecuteBatchResponse.class);

    return messageClasses;
  }

  private static ByteString wrapClassName(Class clz) {
    return UnsafeByteOperations.unsafeWrap(clz.getName().getBytes(UTF_8));
  }

  private final ThreadLocal threadLocalBuffer =
      new ThreadLocal() {
        @Override protected UnsynchronizedBuffer initialValue() {
          return new UnsynchronizedBuffer();
        }
      };

  /**
   * Fetches the concrete message's Parser implementation.
   *
   * @param className The protocol buffer class name
   * @return The Parser for the class
   * @throws IllegalArgumentException If the argument is null or if a Parser for the given
   *     class name is not found.
   */
  public static RequestTranslator getParserForRequest(String className) {
    if (null == className || className.isEmpty()) {
      throw new IllegalArgumentException("Cannot fetch parser for Request with "
          + (null == className ? "null" : "missing") + " class name");
    }

    RequestTranslator translator = REQUEST_PARSERS.get(className);
    if (null == translator) {
      throw new IllegalArgumentException("Cannot find request parser for " + className);
    }

    return translator;
  }

  /**
   * Fetches the concrete message's Parser implementation.
   *
   * @param className The protocol buffer class name
   * @return The Parser for the class
   * @throws IllegalArgumentException If the argument is null or if a Parser for the given
   *     class name is not found.
   */
  public static ResponseTranslator getParserForResponse(String className) {
    if (null == className || className.isEmpty()) {
      throw new IllegalArgumentException("Cannot fetch parser for Response with "
          + (null == className ? "null" : "missing") + " class name");
    }

    ResponseTranslator translator = RESPONSE_PARSERS.get(className);
    if (null == translator) {
      throw new IllegalArgumentException("Cannot find response parser for " + className);
    }

    return translator;
  }

  @Override public byte[] serializeResponse(Response response) throws IOException {
    // Avoid BAOS for its synchronized write methods, we don't need that concurrency control
    UnsynchronizedBuffer out = threadLocalBuffer.get();
    try {
      Message responseMsg = response.serialize();
      // Serialization of the response may be large
      if (LOG.isTraceEnabled()) {
        LOG.trace("Serializing response '{}'", TextFormat.shortDebugString(responseMsg));
      }
      serializeMessage(out, responseMsg);
      return out.toArray();
    } finally {
      out.reset();
    }
  }

  @Override public byte[] serializeRequest(Request request) throws IOException {
    // Avoid BAOS for its synchronized write methods, we don't need that concurrency control
    UnsynchronizedBuffer out = threadLocalBuffer.get();
    try {
      Message requestMsg = request.serialize();
      // Serialization of the request may be large
      if (LOG.isTraceEnabled()) {
        LOG.trace("Serializing request '{}'", TextFormat.shortDebugString(requestMsg));
      }
      serializeMessage(out, requestMsg);
      return out.toArray();
    } finally {
      out.reset();
    }
  }

  void serializeMessage(OutputStream out, Message msg) throws IOException {
    // Serialize the protobuf message
    UnsynchronizedBuffer buffer = threadLocalBuffer.get();
    ByteString serializedMsg;
    try {
      msg.writeTo(buffer);
      // Make a bytestring from it
      serializedMsg = UnsafeByteOperations.unsafeWrap(buffer.toArray());
    } finally {
      buffer.reset();
    }

    // Wrap the serialized message in a WireMessage
    WireMessage wireMsg = WireMessage.newBuilder().setNameBytes(getClassNameBytes(msg.getClass()))
        .setWrappedMessage(serializedMsg).build();

    // Write the WireMessage to the provided OutputStream
    wireMsg.writeTo(out);
  }

  ByteString getClassNameBytes(Class clz) {
    ByteString byteString = MESSAGE_CLASSES.get(clz);
    if (null == byteString) {
      throw new IllegalArgumentException("Missing ByteString for " + clz.getName());
    }
    return byteString;
  }

  @Override public Request parseRequest(byte[] bytes) throws IOException {
    ByteString byteString = UnsafeByteOperations.unsafeWrap(bytes);
    CodedInputStream inputStream = byteString.newCodedInput();
    // Enable aliasing to avoid an extra copy to get at the serialized Request inside of the
    // WireMessage.
    inputStream.enableAliasing(true);
    WireMessage wireMsg = WireMessage.parseFrom(inputStream);

    String serializedMessageClassName = wireMsg.getName();

    try {
      RequestTranslator translator = getParserForRequest(serializedMessageClassName);

      // The ByteString should be logical offsets into the original byte array
      return translator.transform(wireMsg.getWrappedMessage());
    } catch (RuntimeException e) {
      if (LOG.isDebugEnabled()) {
        LOG.debug("Failed to parse request message '{}'", TextFormat.shortDebugString(wireMsg));
      }
      throw e;
    }
  }

  @Override public Response parseResponse(byte[] bytes) throws IOException {
    ByteString byteString = UnsafeByteOperations.unsafeWrap(bytes);
    CodedInputStream inputStream = byteString.newCodedInput();
    // Enable aliasing to avoid an extra copy to get at the serialized Response inside of the
    // WireMessage.
    inputStream.enableAliasing(true);
    WireMessage wireMsg = WireMessage.parseFrom(inputStream);

    String serializedMessageClassName = wireMsg.getName();
    try {
      ResponseTranslator translator = getParserForResponse(serializedMessageClassName);

      return translator.transform(wireMsg.getWrappedMessage());
    } catch (RuntimeException e) {
      if (LOG.isDebugEnabled()) {
        LOG.debug("Failed to parse response message '{}'", TextFormat.shortDebugString(wireMsg));
      }
      throw e;
    }
  }
}

// End ProtobufTranslationImpl.java




© 2015 - 2024 Weber Informatics LLC | Privacy Policy