// io.grpc.internal.MessageDeframer Maven / Gradle / Ivy
/*
* Copyright 2014 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.internal;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import com.google.common.annotations.VisibleForTesting;
import io.grpc.Codec;
import io.grpc.Decompressor;
import io.grpc.Status;
import java.io.Closeable;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Locale;
import java.util.zip.DataFormatException;
import javax.annotation.Nullable;
import javax.annotation.concurrent.NotThreadSafe;
/**
 * Deframer for gRPC frames.
 *
 * <p>This class is not thread-safe. Unless otherwise stated, all calls to public methods should be
 * made in the deframing thread.
 */
@NotThreadSafe
public class MessageDeframer implements Closeable, Deframer {
/** Size in bytes of a frame header: one flags byte followed by a four-byte body length. */
private static final int HEADER_LENGTH = 5;

/** Bit of the header flags byte that marks the frame body as compressed. */
private static final int COMPRESSED_FLAG_MASK = 0x01;

/** Bits of the header flags byte that must be zero; any set bit makes the header malformed. */
private static final int RESERVED_MASK = 0xFE;

/** Largest scratch array, in bytes, allocated at once for full-stream inflation (2 MiB). */
private static final int MAX_BUFFER_SIZE = 2 * 1024 * 1024;
/**
 * A listener of deframing events. These methods will be invoked from the deframing thread.
 */
public interface Listener {
/**
 * Called when the given number of bytes has been read from the input source of the deframer.
 * This is typically used to indicate to the underlying transport that more data can be
 * accepted.
 *
 * <p>The deframer only reports strictly positive counts.
 *
 * @param numBytes the number of bytes read from the deframer's input source.
 */
void bytesRead(int numBytes);
/**
 * Called to deliver the next complete message.
 *
 * @param producer single message producer wrapping the message; it yields the message stream
 *     once and {@code null} afterwards.
 */
void messagesAvailable(StreamListener.MessageProducer producer);
/**
 * Called when the deframer closes.
 *
 * @param hasPartialMessage whether the deframer contained an incomplete message at closing.
 */
void deframerClosed(boolean hasPartialMessage);
/**
 * Called when a {@link MessageDeframer#deframe(ReadableBuffer)} operation failed.
 *
 * <p>NOTE(review): no code visible in this file invokes this callback; presumably wrappers
 * around the deframer report failures through it -- confirm against the callers.
 *
 * @param cause the actual failure
 */
void deframeFailed(Throwable cause);
}
/** Which part of a frame the deframer is currently accumulating bytes for. */
private enum State {
  /** Collecting the fixed-size frame header. */
  HEADER,
  /** Collecting the message body whose length the header announced. */
  BODY
}
// Recipient of deframing callbacks; can be swapped through setListener().
private Listener listener;
// Upper bound checked against the frame length in processHeader() and handed to
// SizeEnforcingInputStream for compressed bodies.
private int maxInboundMessageSize;
private final StatsTraceContext statsTraceCtx;
private final TransportTracer transportTracer;
// Per-message decompressor; Codec.Identity.NONE means compressed frames are rejected.
private Decompressor decompressor;
// Non-null only in full-stream gzip mode; installing it sets 'unprocessed' to null.
private GzipInflatingBuffer fullStreamDecompressor;
// Scratch array that inflated bytes are written into (full-stream mode only), and the
// offset of the next free position in it.
private byte[] inflatedBuffer;
private int inflatedIndex;
// Part of the frame currently being accumulated, and the number of bytes nextFrame must
// hold before that part can be processed.
private State state = State.HEADER;
private int requiredLength = HEADER_LENGTH;
// Compression bit taken from the most recently parsed header.
private boolean compressedFlag;
// Bytes gathered for the current header or body; null until the first read and again
// after each delivered body.
private CompositeReadableBuffer nextFrame;
// Bytes received through deframe() that have not been moved into nextFrame yet; null in
// full-stream mode and after close().
private CompositeReadableBuffer unprocessed = new CompositeReadableBuffer();
// Messages requested through request() that have not been delivered yet.
private long pendingDeliveries;
// Reentrancy guard for deliver().
private boolean inDelivery = false;
// Index of the message whose header was parsed last; -1 before the first header.
private int currentMessageSeqNo = -1;
// Wire bytes accumulated for the body being read; reset when the body is delivered.
private int inboundBodyWireSize;
// Set by closeWhenComplete() when the deframer still has input to work through.
private boolean closeWhenComplete = false;
// Volatile because stopDelivery() may be called outside the deframing thread.
private volatile boolean stopDelivery = false;
/**
 * Create a deframer.
 *
 * @param listener listener for deframer events.
 * @param decompressor the compression used if a compressed frame is encountered, with
 *     {@code NONE} meaning unsupported
 * @param maxMessageSize the maximum allowed size for received messages.
 * @param statsTraceCtx receives per-message and byte-count statistics as frames are read.
 * @param transportTracer notified once for every message header that is parsed.
 * @throws NullPointerException if {@code listener}, {@code decompressor},
 *     {@code statsTraceCtx} or {@code transportTracer} is null.
 */
public MessageDeframer(
    Listener listener,
    Decompressor decompressor,
    int maxMessageSize,
    StatsTraceContext statsTraceCtx,
    TransportTracer transportTracer) {
  // Label each check with the real parameter name so a failure points at the right argument
  // (the listener check used to be labelled "sink").
  this.listener = checkNotNull(listener, "listener");
  this.decompressor = checkNotNull(decompressor, "decompressor");
  this.maxInboundMessageSize = maxMessageSize;
  this.statsTraceCtx = checkNotNull(statsTraceCtx, "statsTraceCtx");
  this.transportTracer = checkNotNull(transportTracer, "transportTracer");
}
/**
 * Replaces the listener that receives deframing events.
 *
 * <p>Unlike the constructor, this setter performs no null check on its argument.
 *
 * @param listener the listener subsequent events are delivered to.
 */
void setListener(Listener listener) {
this.listener = listener;
}
/**
 * Updates the size limit applied to inbound messages. The new value is used for every frame
 * header parsed after this call.
 *
 * @param messageSize the maximum allowed message size, in bytes.
 */
@Override
public void setMaxInboundMessageSize(int messageSize) {
  this.maxInboundMessageSize = messageSize;
}
/**
 * Installs the decompressor used for frames whose header has the compressed flag set.
 *
 * @param decompressor the per-message decompressor; must not be null.
 * @throws IllegalStateException if a full-stream decompressor is currently installed.
 * @throws NullPointerException if {@code decompressor} is null.
 */
@Override
public void setDecompressor(Decompressor decompressor) {
  boolean fullStreamModeOff = (fullStreamDecompressor == null);
  checkState(fullStreamModeOff, "Already set full stream decompressor");
  checkNotNull(decompressor, "Can't pass an empty decompressor");
  this.decompressor = decompressor;
}
/**
 * Switches the deframer to full-stream decompression: from now on incoming buffers are handed
 * to the given inflater instead of being queued as raw bytes.
 *
 * @param fullStreamDecompressor the inflater for the whole stream; must not be null.
 * @throws IllegalStateException if a per-message decompressor other than
 *     {@code Codec.Identity.NONE} is set, or a full-stream decompressor is already installed.
 * @throws NullPointerException if {@code fullStreamDecompressor} is null.
 */
@Override
public void setFullStreamDecompressor(GzipInflatingBuffer fullStreamDecompressor) {
  checkState(decompressor == Codec.Identity.NONE, "per-message decompressor already set");
  checkState(this.fullStreamDecompressor == null, "full stream decompressor already set");
  checkNotNull(fullStreamDecompressor, "Can't pass a null full stream decompressor");
  this.fullStreamDecompressor = fullStreamDecompressor;
  // The raw-byte queue is not used in full-stream mode. It is dropped without being closed;
  // presumably this is called before any data has been deframed -- confirm against callers.
  unprocessed = null;
}
/**
 * Allows {@code numMessages} further messages to be delivered and immediately tries to deliver
 * whatever is already buffered. Has no effect once the deframer is closed.
 *
 * @param numMessages number of additional messages requested; must be positive.
 * @throws IllegalArgumentException if {@code numMessages} is not positive.
 */
@Override
public void request(int numMessages) {
  checkArgument(numMessages > 0, "numMessages must be > 0");
  if (!isClosed()) {
    pendingDeliveries += numMessages;
    deliver();
  }
}
/**
 * Accepts a buffer of inbound bytes and delivers any messages that become complete.
 *
 * <p>The buffer is queued (or passed to the full-stream inflater) unless the deframer is closed
 * or scheduled to close. If it was not handed over, including when handing it over throws, it
 * is closed here.
 *
 * @param data the received bytes; must not be null.
 * @throws NullPointerException if {@code data} is null.
 */
@Override
public void deframe(ReadableBuffer data) {
  checkNotNull(data, "data");
  boolean handedOver = false;
  try {
    if (isClosedOrScheduledToClose()) {
      return;
    }
    if (fullStreamDecompressor == null) {
      unprocessed.addBuffer(data);
    } else {
      fullStreamDecompressor.addGzippedBytes(data);
    }
    // From here on the buffer belongs to the queue / inflater, even if delivery fails.
    handedOver = true;
    deliver();
  } finally {
    if (!handedOver) {
      data.close();
    }
  }
}
/**
 * Closes the deframer right away if it has no buffered input left to process; otherwise flags
 * it so that {@code deliver()} closes it once it stalls. Does nothing if already closed.
 */
@Override
public void closeWhenComplete() {
  if (isClosed()) {
    return;
  }
  if (!isStalled()) {
    closeWhenComplete = true;
    return;
  }
  close();
}
/**
 * Sets a flag to interrupt delivery of any currently queued messages. This may be invoked outside
 * of the deframing thread, and must be followed by a call to {@link #close()} in the deframing
 * thread. Without a subsequent call to {@link #close()}, the deframer may hang waiting for
 * additional messages before noticing that the {@code stopDelivery} flag has been set.
 *
 * <p>The flag is never reset by this class.
 */
void stopDelivery() {
// The field is volatile, so the delivery loop observes the write made from another thread.
stopDelivery = true;
}
/**
 * Returns whether messages that were requested through {@link #request(int)} are still
 * waiting to be delivered.
 */
boolean hasPendingDeliveries() {
return pendingDeliveries != 0;
}
/**
 * Releases all buffered data and notifies the listener through
 * {@code deframerClosed(hasPartialMessage)}. Calling it on a closed deframer does nothing.
 *
 * <p>A partial message is reported when the frame being assembled still holds bytes or, in
 * full-stream mode, when the inflater reports partial data.
 */
@Override
public void close() {
  if (isClosed()) {
    return;
  }
  boolean partial = nextFrame != null && nextFrame.readableBytes() > 0;
  try {
    if (fullStreamDecompressor != null) {
      // Only ask the inflater when the frame buffer has not already settled the answer.
      if (!partial) {
        partial = fullStreamDecompressor.hasPartialData();
      }
      fullStreamDecompressor.close();
    }
    if (unprocessed != null) {
      unprocessed.close();
    }
    if (nextFrame != null) {
      nextFrame.close();
    }
  } finally {
    // Clearing both input holders is what makes isClosed() report true.
    nextFrame = null;
    unprocessed = null;
    fullStreamDecompressor = null;
  }
  listener.deframerClosed(partial);
}
/**
 * Indicates whether or not this deframer has been closed.
 *
 * <p>The deframer counts as closed when it holds neither the raw-byte queue nor a full-stream
 * decompressor; {@link #close()} clears both.
 */
public boolean isClosed() {
return unprocessed == null && fullStreamDecompressor == null;
}
/**
 * Returns true if this deframer has already been closed, or {@link #closeWhenComplete()} has
 * scheduled it to close once its buffered input is consumed.
 */
private boolean isClosedOrScheduledToClose() {
  return closeWhenComplete || isClosed();
}
/**
 * Returns whether there is no buffered input left to work on: the inflater reports itself
 * stalled in full-stream mode, or the raw-byte queue is empty otherwise. Must not be called
 * on a closed deframer, because both holders are null then.
 */
private boolean isStalled() {
  if (fullStreamDecompressor == null) {
    return unprocessed.readableBytes() == 0;
  }
  return fullStreamDecompressor.isStalled();
}
/**
 * Reads and delivers as many messages to the listener as possible.
 *
 * <p>Delivery stops when {@code stopDelivery} is set, when no requested messages remain, or
 * when the bytes needed for the next header or body are not available yet. Afterwards the
 * deframer is closed if delivery was stopped, or if a close was scheduled and no input is left.
 */
private void deliver() {
// We can have reentrancy here when using a direct executor, triggered by calls to
// request more messages. This is safe as we simply loop until pendingDeliveries = 0:
// the nested call returns at once and the outer loop re-reads the updated counter.
if (inDelivery) {
return;
}
inDelivery = true;
try {
// Process the uncompressed bytes.
while (!stopDelivery && pendingDeliveries > 0 && readRequiredBytes()) {
switch (state) {
case HEADER:
// Parses the header and moves the state to BODY.
processHeader();
break;
case BODY:
// Read the body and deliver the message.
processBody();
// Since we've delivered a message, decrement the number of pending
// deliveries remaining.
pendingDeliveries--;
break;
default:
throw new AssertionError("Invalid state: " + state);
}
}
if (stopDelivery) {
close();
return;
}
/*
 * We are stalled when there are no more bytes to process. This allows delivering errors as
 * soon as the buffered input has been consumed, independent of whether the application
 * has requested another message. At this point in the function, either all frames have been
 * delivered, or unprocessed is empty. If there is a partial message, it will be inside next
 * frame and not in unprocessed. If there is extra data but no pending deliveries, it will
 * be in unprocessed.
 */
if (closeWhenComplete && isStalled()) {
close();
}
} finally {
// Always drop the guard, including when header or body processing throws.
inDelivery = false;
}
}
/**
 * Attempts to read the required bytes into nextFrame.
 *
 * <p>Bytes come from the full-stream inflater when one is installed, otherwise from the
 * raw-byte queue. Whatever the outcome, the number of input bytes consumed by this call is
 * reported to the listener, and while a body is being read it is also added to the wire-size
 * statistics.
 *
 * @return {@code true} if all of the required bytes have been read.
 * @throws RuntimeException wrapping an {@link IOException} or {@link DataFormatException}
 *     raised while inflating.
 */
private boolean readRequiredBytes() {
// Input bytes consumed during this call; reported through listener.bytesRead in finally.
int totalBytesRead = 0;
// Full-stream mode only: deflated bytes consumed, as reported by the inflater.
int deflatedBytesRead = 0;
try {
if (nextFrame == null) {
nextFrame = new CompositeReadableBuffer();
}
// Read until the buffer contains all the required bytes.
// (A required length of zero, i.e. an empty body, skips the loop and succeeds at once.)
int missingBytes;
while ((missingBytes = requiredLength - nextFrame.readableBytes()) > 0) {
if (fullStreamDecompressor != null) {
try {
// Allocate a fresh scratch array only once the previous one is completely used up;
// slices already wrapped into nextFrame keep referring to the old array.
if (inflatedBuffer == null || inflatedIndex == inflatedBuffer.length) {
inflatedBuffer = new byte[Math.min(missingBytes, MAX_BUFFER_SIZE)];
inflatedIndex = 0;
}
int bytesToRead = Math.min(missingBytes, inflatedBuffer.length - inflatedIndex);
int n = fullStreamDecompressor.inflateBytes(inflatedBuffer, inflatedIndex, bytesToRead);
// Input may have been consumed even when no inflated output was produced, so the
// counters are collected before the n == 0 check.
totalBytesRead += fullStreamDecompressor.getAndResetBytesConsumed();
deflatedBytesRead += fullStreamDecompressor.getAndResetDeflatedBytesConsumed();
if (n == 0) {
// No more inflated data is available.
return false;
}
nextFrame.addBuffer(ReadableBuffers.wrap(inflatedBuffer, inflatedIndex, n));
inflatedIndex += n;
} catch (IOException e) {
throw new RuntimeException(e);
} catch (DataFormatException e) {
throw new RuntimeException(e);
}
} else {
if (unprocessed.readableBytes() == 0) {
// No more data is available.
return false;
}
int toRead = Math.min(missingBytes, unprocessed.readableBytes());
totalBytesRead += toRead;
nextFrame.addBuffer(unprocessed.readBytes(toRead));
}
}
return true;
} finally {
if (totalBytesRead > 0) {
listener.bytesRead(totalBytesRead);
// Only body bytes count towards the wire size; header bytes are left out.
if (state == State.BODY) {
if (fullStreamDecompressor != null) {
// With compressed streams, totalBytesRead can include gzip header and trailer metadata
statsTraceCtx.inboundWireSize(deflatedBytesRead);
inboundBodyWireSize += deflatedBytesRead;
} else {
statsTraceCtx.inboundWireSize(totalBytesRead);
inboundBodyWireSize += totalBytesRead;
}
}
}
}
}
/**
 * Parses the gRPC frame header held in {@code nextFrame}: a flags byte carrying the
 * compression bit, followed by the length of the message body. On success the deframer
 * switches to reading the body.
 *
 * @throws io.grpc.StatusRuntimeException with {@code INTERNAL} if a reserved flag bit is set,
 *     or with {@code RESOURCE_EXHAUSTED} if the announced length is negative or larger than
 *     the maximum inbound message size.
 */
private void processHeader() {
  int flags = nextFrame.readUnsignedByte();
  boolean reservedBitsSet = (flags & RESERVED_MASK) != 0;
  if (reservedBitsSet) {
    throw Status.INTERNAL.withDescription(
            "gRPC frame header malformed: reserved bits not zero")
        .asRuntimeException();
  }
  compressedFlag = (flags & COMPRESSED_FLAG_MASK) != 0;

  // The body length becomes the number of bytes the next read has to gather.
  requiredLength = nextFrame.readInt();
  boolean lengthAcceptable = requiredLength >= 0 && requiredLength <= maxInboundMessageSize;
  if (!lengthAcceptable) {
    String description =
        String.format(
            Locale.US,
            "gRPC message exceeds maximum size %d: %d",
            maxInboundMessageSize,
            requiredLength);
    throw Status.RESOURCE_EXHAUSTED.withDescription(description).asRuntimeException();
  }

  currentMessageSeqNo++;
  statsTraceCtx.inboundMessage(currentMessageSeqNo);
  transportTracer.reportMessageReceived();

  // Header consumed; the body comes next.
  state = State.BODY;
}
/**
 * Hands the completed message body to the listener as a stream, decompressing it first when
 * the frame header flagged it as compressed, and then resets the deframer to expect the next
 * header.
 */
private void processBody() {
  // The uncompressed size of a compressed message cannot be reported here: the bytes are only
  // produced as the returned InputStream is read, and its total size is unknown until then.
  // -1 stands for "unknown" in that case.
  boolean compressedOnWire = compressedFlag || fullStreamDecompressor != null;
  int uncompressedSize = compressedOnWire ? -1 : inboundBodyWireSize;
  statsTraceCtx.inboundMessageRead(currentMessageSeqNo, inboundBodyWireSize, uncompressedSize);
  inboundBodyWireSize = 0;

  InputStream body;
  if (compressedFlag) {
    body = getCompressedBody();
  } else {
    body = getUncompressedBody();
  }
  nextFrame.touch();
  // The stream now owns the frame's bytes; a fresh buffer is created for the next read.
  nextFrame = null;
  listener.messagesAvailable(new SingleMessageProducer(body));

  // Frame finished: go back to waiting for a header.
  state = State.HEADER;
  requiredLength = HEADER_LENGTH;
}
/**
 * Reports the body size as the uncompressed size and exposes the frame's bytes as a stream
 * that owns the underlying buffer.
 */
private InputStream getUncompressedBody() {
  int bodySize = nextFrame.readableBytes();
  statsTraceCtx.inboundUncompressedSize(bodySize);
  return ReadableBuffers.openStream(nextFrame, true);
}
/**
 * Wraps the frame's bytes in the configured decompressor and caps the decompressed output at
 * the maximum inbound message size.
 *
 * @throws io.grpc.StatusRuntimeException with {@code INTERNAL} if no decompressor is
 *     configured ({@code Codec.Identity.NONE}).
 * @throws RuntimeException wrapping any {@link IOException} raised while setting up
 *     decompression.
 */
private InputStream getCompressedBody() {
  if (decompressor == Codec.Identity.NONE) {
    throw Status.INTERNAL.withDescription(
            "Can't decode compressed gRPC message as compression not configured")
        .asRuntimeException();
  }

  try {
    InputStream compressedBytes = ReadableBuffers.openStream(nextFrame, true);
    InputStream inflated = decompressor.decompress(compressedBytes);
    // The header check only limits the compressed length, so limit the inflated bytes too.
    return new SizeEnforcingInputStream(inflated, maxInboundMessageSize, statsTraceCtx);
  } catch (IOException e) {
    throw new RuntimeException(e);
  }
}
/**
 * An {@link InputStream} that enforces the {@link #maxMessageSize} limit for compressed frames.
 *
 * <p>It counts the bytes read or skipped from the wrapped stream, fails with
 * {@code RESOURCE_EXHAUSTED} as soon as the count exceeds the limit, and reports every newly
 * reached byte to the stats context as uncompressed size. Bytes re-read after a
 * {@link #reset()} are not reported a second time.
 */
@VisibleForTesting
static final class SizeEnforcingInputStream extends FilterInputStream {
  private final int maxMessageSize;
  private final StatsTraceContext statsTraceCtx;
  /** Highest value of {@link #count} already reported to the stats context. */
  private long maxCount;
  /** Bytes consumed so far, rewound by {@link #reset()}. */
  private long count;
  /** Value of {@link #count} at the last {@link #mark(int)}, or -1 if never marked. */
  private long mark = -1;

  SizeEnforcingInputStream(InputStream in, int maxMessageSize, StatsTraceContext statsTraceCtx) {
    super(in);
    this.maxMessageSize = maxMessageSize;
    this.statsTraceCtx = statsTraceCtx;
  }

  @Override
  public int read() throws IOException {
    int value = in.read();
    count += (value == -1) ? 0 : 1;
    enforceLimitThenReport();
    return value;
  }

  @Override
  public int read(byte[] b, int off, int len) throws IOException {
    int bytesRead = in.read(b, off, len);
    count += (bytesRead == -1) ? 0 : bytesRead;
    enforceLimitThenReport();
    return bytesRead;
  }

  @Override
  public long skip(long n) throws IOException {
    long skipped = in.skip(n);
    count += skipped;
    enforceLimitThenReport();
    return skipped;
  }

  @Override
  public synchronized void mark(int readlimit) {
    // Marking is harmless even when the wrapped stream lacks mark support: reset() refuses to
    // run in that case.
    in.mark(readlimit);
    mark = count;
  }

  @Override
  public synchronized void reset() throws IOException {
    if (!in.markSupported()) {
      throw new IOException("Mark not supported");
    }
    if (mark == -1) {
      throw new IOException("Mark not set");
    }

    in.reset();
    count = mark;
  }

  /**
   * Throws if the byte count exceeds the limit; otherwise reports the bytes that go beyond the
   * previously reported high-water mark.
   */
  private void enforceLimitThenReport() {
    if (count > maxMessageSize) {
      throw Status.RESOURCE_EXHAUSTED
          .withDescription("Decompressed gRPC message exceeds maximum size " + maxMessageSize)
          .asRuntimeException();
    }
    if (count > maxCount) {
      statsTraceCtx.inboundUncompressedSize(count - maxCount);
      maxCount = count;
    }
  }
}
/**
 * A {@link StreamListener.MessageProducer} holding exactly one message: the first call to
 * {@link #next()} returns it, every later call returns {@code null}.
 */
private static class SingleMessageProducer implements StreamListener.MessageProducer {
  private InputStream message;

  private SingleMessageProducer(InputStream message) {
    this.message = message;
  }

  @Nullable
  @Override
  public InputStream next() {
    InputStream pending = this.message;
    // Forget the stream so that it is handed out only once.
    this.message = null;
    return pending;
  }
}
}