All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.hive.spark.client.rpc.KryoMessageCodec Maven / Gradle / Ivy

There is a newer version: 2.3.10
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hive.spark.client.rpc;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.List;

import org.objenesis.strategy.StdInstantiatorStrategy;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.io.ByteBufferInputStream;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import com.google.common.base.Preconditions;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageCodec;

/**
 * Codec that serializes / deserializes objects using Kryo. Objects are encoded with a 4-byte
 * header with the length of the serialized data.
 */
class KryoMessageCodec extends ByteToMessageCodec {

  private static final Logger LOG = LoggerFactory.getLogger(KryoMessageCodec.class);

  // Kryo docs say 0-8 are taken. Strange things happen if you don't set an ID when registering
  // classes.
  private static final int REG_ID_BASE = 16;

  private final int maxMessageSize;
  private final List> messages;
  private final ThreadLocal kryos = new ThreadLocal() {
    @Override
    protected Kryo initialValue() {
      Kryo kryo = new Kryo();
      int count = 0;
      for (Class klass : messages) {
        kryo.register(klass, REG_ID_BASE + count);
        count++;
      }
      kryo.setInstantiatorStrategy(new Kryo.DefaultInstantiatorStrategy(new StdInstantiatorStrategy()));
      return kryo;
    }
  };

  private volatile EncryptionHandler encryptionHandler;

  public KryoMessageCodec(int maxMessageSize, Class... messages) {
    this.maxMessageSize = maxMessageSize;
    this.messages = Arrays.asList(messages);
    this.encryptionHandler = null;
  }

  @Override
  protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out)
      throws Exception {
    if (in.readableBytes() < 4) {
      return;
    }

    in.markReaderIndex();
    int msgSize = in.readInt();
    checkSize(msgSize);

    if (in.readableBytes() < msgSize) {
      // Incomplete message in buffer.
      in.resetReaderIndex();
      return;
    }

    try {
      ByteBuffer nioBuffer = maybeDecrypt(in.nioBuffer(in.readerIndex(), msgSize));
      Input kryoIn = new Input(new ByteBufferInputStream(nioBuffer));

      Object msg = kryos.get().readClassAndObject(kryoIn);
      LOG.debug("Decoded message of type {} ({} bytes)",
          msg != null ? msg.getClass().getName() : msg, msgSize);
      out.add(msg);
    } finally {
      in.skipBytes(msgSize);
    }
  }

  @Override
  protected void encode(ChannelHandlerContext ctx, Object msg, ByteBuf buf)
      throws Exception {
    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    Output kryoOut = new Output(bytes);
    kryos.get().writeClassAndObject(kryoOut, msg);
    kryoOut.flush();

    byte[] msgData = maybeEncrypt(bytes.toByteArray());
    LOG.debug("Encoded message of type {} ({} bytes)", msg.getClass().getName(), msgData.length);
    checkSize(msgData.length);

    buf.ensureWritable(msgData.length + 4);
    buf.writeInt(msgData.length);
    buf.writeBytes(msgData);
  }

  @Override
  public void channelInactive(ChannelHandlerContext ctx) throws Exception {
    if (encryptionHandler != null) {
      encryptionHandler.dispose();
    }
    super.channelInactive(ctx);
  }

  private void checkSize(int msgSize) {
    Preconditions.checkArgument(msgSize > 0, "Message size (%s bytes) must be positive.", msgSize);
    Preconditions.checkArgument(maxMessageSize <= 0 || msgSize <= maxMessageSize,
        "Message (%s bytes) exceeds maximum allowed size (%s bytes).", msgSize, maxMessageSize);
  }

  private byte[] maybeEncrypt(byte[] data) throws Exception {
    return (encryptionHandler != null) ? encryptionHandler.wrap(data, 0, data.length) : data;
  }

  private ByteBuffer maybeDecrypt(ByteBuffer data) throws Exception {
    if (encryptionHandler != null) {
      byte[] encrypted;
      int len = data.limit() - data.position();
      int offset;
      if (data.hasArray()) {
        encrypted = data.array();
        offset = data.position() + data.arrayOffset();
        data.position(data.limit());
      } else {
        encrypted = new byte[len];
        offset = 0;
        data.get(encrypted);
      }
      return ByteBuffer.wrap(encryptionHandler.unwrap(encrypted, offset, len));
    } else {
      return data;
    }
  }

  void setEncryptionHandler(EncryptionHandler handler) {
    this.encryptionHandler = handler;
  }

  interface EncryptionHandler {

    byte[] wrap(byte[] data, int offset, int len) throws IOException;

    byte[] unwrap(byte[] data, int offset, int len) throws IOException;

    void dispose() throws IOException;

  }

}