alluxio.grpc.GrpcSerializationUtils Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of alluxio-core-common Show documentation
Show all versions of alluxio-core-common Show documentation
Common utilities shared in Alluxio core modules
/*
* The Alluxio Open Foundation licenses this work under the Apache License, version 2.0
* (the "License"). You may not use this work except in compliance with the License, which is
* available at www.apache.org/licenses/LICENSE-2.0
*
* This software is distributed on an "AS IS" basis, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
* either express or implied, as more fully set forth in the License.
*
* See the NOTICE file distributed with this work for information regarding copyright ownership.
*/
package alluxio.grpc;
import io.grpc.CallOptions;
import io.grpc.MethodDescriptor;
import io.grpc.ServerMethodDefinition;
import io.grpc.ServerServiceDefinition;
import io.grpc.ServiceDescriptor;
import io.grpc.internal.CompositeReadableBuffer;
import io.grpc.internal.ReadableBuffer;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.PooledByteBufAllocator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
/**
* Utilities for gRPC message serialization.
*/
public class GrpcSerializationUtils {
public static final CallOptions.Key OVERRIDDEN_METHOD_DESCRIPTOR =
CallOptions.Key.create("overridden method descriptor");
private static final Logger LOG = LoggerFactory.getLogger(GrpcSerializationUtils.class);
private static final int TAG_TYPE_BITS = 3;
private static final String BUFFER_INPUT_STREAM_CLASS_NAME =
"io.grpc.internal.ReadableBuffers$BufferInputStream";
private static final String BUFFER_FIELD_NAME = "buffer";
private static final String BUFFERS_FIELD_NAME = "buffers";
private static final String NETTY_WRITABLE_BUFFER_CLASS_NAME =
"io.grpc.netty.NettyWritableBuffer";
private static final String NETTY_READABLE_BUFFER_CLASS_NAME =
"io.grpc.netty.NettyReadableBuffer";
private static final String BUFFER_CHAIN_OUTPUT_STREAM_CLASS_NAME =
"io.grpc.internal.MessageFramer$BufferChainOutputStream";
private static final String BUFFER_LIST_FIELD_NAME = "bufferList";
private static final String CURRENT_FIELD_NAME = "current";
private static Constructor> sNettyWritableBufferConstructor;
private static Field sBufferList;
private static Field sCompositeBuffers = null;
private static Field sCurrent;
private static Field sReadableBufferField = null;
private static Field sReadableByteBuf = null;
private static boolean sZeroCopySendSupported = true;
private static boolean sZeroCopyReceiveSupported = true;
static {
try {
sReadableBufferField = getPrivateField(BUFFER_INPUT_STREAM_CLASS_NAME, BUFFER_FIELD_NAME);
} catch (Exception e) {
LOG.warn("Cannot get gRPC input stream buffer, zero copy send will be disabled.", e);
sZeroCopySendSupported = false;
}
try {
sNettyWritableBufferConstructor =
getPrivateConstructor(NETTY_WRITABLE_BUFFER_CLASS_NAME, ByteBuf.class);
sBufferList = getPrivateField(BUFFER_CHAIN_OUTPUT_STREAM_CLASS_NAME, BUFFER_LIST_FIELD_NAME);
sCurrent = getPrivateField(BUFFER_CHAIN_OUTPUT_STREAM_CLASS_NAME, CURRENT_FIELD_NAME);
sCompositeBuffers =
getPrivateField(CompositeReadableBuffer.class.getName(), BUFFERS_FIELD_NAME);
sReadableByteBuf = getPrivateField(NETTY_READABLE_BUFFER_CLASS_NAME, BUFFER_FIELD_NAME);
} catch (Exception e) {
LOG.warn("Cannot get gRPC output stream buffer, zero copy receive will be disabled.", e);
sZeroCopyReceiveSupported = false;
}
}
private static Field getPrivateField(String className, String fieldName)
throws NoSuchFieldException, ClassNotFoundException {
Class> declaringClass = Class.forName(className);
Field field = declaringClass.getDeclaredField(fieldName);
field.setAccessible(true);
return field;
}
private static Constructor> getPrivateConstructor(String className, Class> ...parameterTypes)
throws ClassNotFoundException, NoSuchMethodException {
Class> declaringClass = Class.forName(className);
Constructor> constructor = declaringClass.getDeclaredConstructor(parameterTypes);
constructor.setAccessible(true);
return constructor;
}
/**
* Makes a gRPC tag for a field.
*
* @param fieldNumber field number
* @param wireType wire type of the field
* @return the gRPC tag
*/
public static int makeTag(final int fieldNumber, final int wireType) {
// This is a public version of WireFormat.makeTag.
return (fieldNumber << TAG_TYPE_BITS) | wireType;
}
/**
* Gets a buffer directly from a gRPC input stream.
*
* @param stream the input stream
* @return the raw data buffer
*/
public static ReadableBuffer getBufferFromStream(InputStream stream) {
if (!sZeroCopyReceiveSupported
|| !stream.getClass().equals(sReadableBufferField.getDeclaringClass())) {
return null;
}
try {
return (ReadableBuffer) sReadableBufferField.get(stream);
} catch (Exception e) {
LOG.warn("Failed to get data buffer from stream.", e);
return null;
}
}
/**
* Gets a Netty buffer directly from a gRPC ReadableBuffer.
*
* @param buffer the input buffer
* @return the raw ByteBuf, or null if the ByteBuf cannot be extracted
*/
public static ByteBuf getByteBufFromReadableBuffer(ReadableBuffer buffer) {
if (!sZeroCopyReceiveSupported) {
return null;
}
try {
if (buffer instanceof CompositeReadableBuffer) {
Queue buffers = (Queue) sCompositeBuffers.get(buffer);
if (buffers.size() == 1) {
return getByteBufFromReadableBuffer(buffers.peek());
} else {
CompositeByteBuf buf = PooledByteBufAllocator.DEFAULT.compositeBuffer();
for (ReadableBuffer readableBuffer : buffers) {
ByteBuf subBuffer = getByteBufFromReadableBuffer(readableBuffer);
if (subBuffer == null) {
return null;
}
buf.addComponent(true, subBuffer);
}
return buf;
}
} else if (buffer.getClass().equals(sReadableByteBuf.getDeclaringClass())) {
return (ByteBuf) sReadableByteBuf.get(buffer);
}
} catch (Exception e) {
LOG.warn("Failed to get data buffer from stream: {}.", e.toString());
return null;
}
return null;
}
/**
* Add the given buffers directly to the gRPC output stream.
*
* @param buffers the buffers to be added
* @param stream the output stream
* @return whether the buffers are added successfully
*/
public static boolean addBuffersToStream(ByteBuf[] buffers, OutputStream stream) {
if (!sZeroCopySendSupported || !stream.getClass().equals(sBufferList.getDeclaringClass())) {
return false;
}
try {
if (sCurrent.get(stream) != null) {
return false;
}
for (ByteBuf buffer : buffers) {
Object nettyBuffer = sNettyWritableBufferConstructor.newInstance(buffer);
List list = (List) sBufferList.get(stream);
list.add(nettyBuffer);
buffer.retain();
sCurrent.set(stream, nettyBuffer);
}
return true;
} catch (Exception e) {
LOG.warn("Failed to add data buffer to stream: {}.", e.toString());
return false;
}
}
/**
* Creates a service definition that uses custom marshallers.
*
* @param service the service to intercept
* @param marshallers a map that specifies which marshaller to use for each method
* @return the new service definition
*/
public static ServerServiceDefinition overrideMethods(
final ServerServiceDefinition service,
final Map marshallers) {
List> newMethods = new ArrayList>();
List> newDescriptors = new ArrayList>();
// intercepts the descriptors
for (final ServerMethodDefinition, ?> definition : service.getMethods()) {
ServerMethodDefinition, ?> newMethod = interceptMethod(definition, marshallers);
newDescriptors.add(newMethod.getMethodDescriptor());
newMethods.add(newMethod);
}
// builds the new service descriptor
final ServerServiceDefinition.Builder serviceBuilder = ServerServiceDefinition
.builder(new ServiceDescriptor(service.getServiceDescriptor().getName(), newDescriptors));
// creates the new service definition
for (ServerMethodDefinition, ?> definition : newMethods) {
serviceBuilder.addMethod(definition);
}
return serviceBuilder.build();
}
private static ServerMethodDefinition interceptMethod(
final ServerMethodDefinition definition,
final Map newMethods) {
MethodDescriptor descriptor = definition.getMethodDescriptor();
MethodDescriptor newMethod = newMethods.get(descriptor);
if (newMethod != null) {
return ServerMethodDefinition.create(newMethod, definition.getServerCallHandler());
}
return definition;
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy