// org.elasticsearch.transport.netty.MessageChannelHandler (Maven / Gradle / Ivy)
// The newest version!
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport.netty;
import org.apache.lucene.util.IOUtils;
import org.elasticsearch.Version;
import org.elasticsearch.common.component.Lifecycle;
import org.elasticsearch.common.compress.Compressor;
import org.elasticsearch.common.compress.CompressorFactory;
import org.elasticsearch.common.compress.NotCompressedException;
import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.logging.ESLogger;
import org.elasticsearch.common.transport.InetSocketTransportAddress;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.*;
import org.elasticsearch.transport.support.TransportStatus;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.channel.*;
import java.io.IOException;
import java.net.InetSocketAddress;
/**
* A handler (must be the last one!) that does size based frame decoding and forwards the actual message
* to the relevant action.
*/
public class MessageChannelHandler extends SimpleChannelUpstreamHandler {
protected final ESLogger logger;
protected final ThreadPool threadPool;
protected final TransportServiceAdapter transportServiceAdapter;
protected final NettyTransport transport;
protected final String profileName;
public MessageChannelHandler(NettyTransport transport, ESLogger logger, String profileName) {
this.threadPool = transport.threadPool();
this.transportServiceAdapter = transport.transportServiceAdapter();
this.transport = transport;
this.logger = logger;
this.profileName = profileName;
}
@Override
public void writeComplete(ChannelHandlerContext ctx, WriteCompletionEvent e) throws Exception {
transportServiceAdapter.sent(e.getWrittenAmount());
super.writeComplete(ctx, e);
}
@Override
public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) throws Exception {
Transports.assertTransportThread();
Object m = e.getMessage();
if (!(m instanceof ChannelBuffer)) {
ctx.sendUpstream(e);
return;
}
ChannelBuffer buffer = (ChannelBuffer) m;
Marker marker = new Marker(buffer);
int size = marker.messageSizeWithRemainingHeaders();
transportServiceAdapter.received(marker.messageSizeWithAllHeaders());
// we have additional bytes to read, outside of the header
boolean hasMessageBytesToRead = marker.messageSize() != 0;
// netty always copies a buffer, either in NioWorker in its read handler, where it copies to a fresh
// buffer, or in the cumulation buffer, which is cleaned each time
StreamInput streamIn = ChannelBufferStreamInputFactory.create(buffer, size);
boolean success = false;
try {
long requestId = streamIn.readLong();
byte status = streamIn.readByte();
Version version = Version.fromId(streamIn.readInt());
if (TransportStatus.isCompress(status) && hasMessageBytesToRead && buffer.readable()) {
Compressor compressor;
try {
compressor = CompressorFactory.compressor(buffer);
} catch (NotCompressedException ex) {
int maxToRead = Math.min(buffer.readableBytes(), 10);
int offset = buffer.readerIndex();
StringBuilder sb = new StringBuilder("stream marked as compressed, but no compressor found, first [").append(maxToRead).append("] content bytes out of [").append(buffer.readableBytes()).append("] readable bytes with message size [").append(size).append("] ").append("] are [");
for (int i = 0; i < maxToRead; i++) {
sb.append(buffer.getByte(offset + i)).append(",");
}
sb.append("]");
throw new IllegalStateException(sb.toString());
}
streamIn = compressor.streamInput(streamIn);
}
if (version.onOrAfter(Version.CURRENT.minimumCompatibilityVersion()) == false || version.major != Version.CURRENT.major) {
throw new IllegalStateException("Received message from unsupported version: [" + version
+ "] minimal compatible version is: [" +Version.CURRENT.minimumCompatibilityVersion() + "]");
}
streamIn.setVersion(version);
if (TransportStatus.isRequest(status)) {
handleRequest(ctx.getChannel(), marker, streamIn, requestId, size, version);
} else {
TransportResponseHandler> handler = transportServiceAdapter.onResponseReceived(requestId);
// ignore if its null, the adapter logs it
if (handler != null) {
if (TransportStatus.isError(status)) {
handlerResponseError(streamIn, handler);
} else {
handleResponse(ctx.getChannel(), streamIn, handler);
}
marker.validateResponse(streamIn, requestId, handler, TransportStatus.isError(status));
}
}
success = true;
} finally {
try {
if (success) {
IOUtils.close(streamIn);
} else {
IOUtils.closeWhileHandlingException(streamIn);
}
} finally {
// Set the expected position of the buffer, no matter what happened
buffer.readerIndex(marker.expectedReaderIndex());
}
}
}
protected void handleResponse(Channel channel, StreamInput buffer, final TransportResponseHandler handler) {
buffer = new NamedWriteableAwareStreamInput(buffer, transport.namedWriteableRegistry);
final TransportResponse response = handler.newInstance();
response.remoteAddress(new InetSocketTransportAddress((InetSocketAddress) channel.getRemoteAddress()));
response.remoteAddress();
try {
response.readFrom(buffer);
} catch (Throwable e) {
handleException(handler, new TransportSerializationException("Failed to deserialize response of type [" + response.getClass().getName() + "]", e));
return;
}
try {
if (ThreadPool.Names.SAME.equals(handler.executor())) {
//noinspection unchecked
handler.handleResponse(response);
} else {
threadPool.executor(handler.executor()).execute(new ResponseHandler(handler, response));
}
} catch (Throwable e) {
handleException(handler, new ResponseHandlerFailureTransportException(e));
}
}
private void handlerResponseError(StreamInput buffer, final TransportResponseHandler handler) {
Throwable error;
try {
error = buffer.readThrowable();
} catch (Throwable e) {
error = new TransportSerializationException("Failed to deserialize exception response from stream", e);
}
handleException(handler, error);
}
private void handleException(final TransportResponseHandler handler, Throwable error) {
if (!(error instanceof RemoteTransportException)) {
error = new RemoteTransportException(error.getMessage(), error);
}
final RemoteTransportException rtx = (RemoteTransportException) error;
if (ThreadPool.Names.SAME.equals(handler.executor())) {
try {
handler.handleException(rtx);
} catch (Throwable e) {
logger.error("failed to handle exception response [{}]", e, handler);
}
} else {
threadPool.executor(handler.executor()).execute(new Runnable() {
@Override
public void run() {
try {
handler.handleException(rtx);
} catch (Throwable e) {
logger.error("failed to handle exception response [{}]", e, handler);
}
}
});
}
}
protected String handleRequest(Channel channel, Marker marker, StreamInput buffer, long requestId, int messageLengthBytes,
Version version) throws IOException {
buffer = new NamedWriteableAwareStreamInput(buffer, transport.namedWriteableRegistry);
final String action = buffer.readString();
transportServiceAdapter.onRequestReceived(requestId, action);
NettyTransportChannel transportChannel = null;
try {
final RequestHandlerRegistry reg = transportServiceAdapter.getRequestHandler(action);
if (reg == null) {
throw new ActionNotFoundTransportException(action);
}
if (reg.canTripCircuitBreaker()) {
transport.inFlightRequestsBreaker().addEstimateBytesAndMaybeBreak(messageLengthBytes, "");
} else {
transport.inFlightRequestsBreaker().addWithoutBreaking(messageLengthBytes);
}
transportChannel = new NettyTransportChannel(transport, transportServiceAdapter, action, channel,
requestId, version, profileName, messageLengthBytes);
final TransportRequest request = reg.newRequest();
request.remoteAddress(new InetSocketTransportAddress((InetSocketAddress) channel.getRemoteAddress()));
request.readFrom(buffer);
// in case we throw an exception, i.e. when the limit is hit, we don't want to verify
validateRequest(marker, buffer, requestId, request, action);
if (ThreadPool.Names.SAME.equals(reg.getExecutor())) {
//noinspection unchecked
reg.processMessageReceived(request, transportChannel);
} else {
threadPool.executor(reg.getExecutor()).execute(new RequestHandler(reg, request, transportChannel));
}
} catch (Throwable e) {
// the circuit breaker tripped
if (transportChannel == null) {
transportChannel = new NettyTransportChannel(transport, transportServiceAdapter, action, channel,
requestId, version, profileName, 0);
}
try {
transportChannel.sendResponse(e);
} catch (IOException e1) {
logger.warn("Failed to send error message back to client for action [" + action + "]", e);
logger.warn("Actual Exception", e1);
}
}
return action;
}
// This template method is needed to inject custom error checking logic in tests.
protected void validateRequest(Marker marker, StreamInput buffer, long requestId, TransportRequest request, String action) throws IOException {
marker.validateRequest(buffer, requestId, action);
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) throws Exception {
transport.exceptionCaught(ctx, e);
}
class ResponseHandler implements Runnable {
private final TransportResponseHandler handler;
private final TransportResponse response;
public ResponseHandler(TransportResponseHandler handler, TransportResponse response) {
this.handler = handler;
this.response = response;
}
@SuppressWarnings({"unchecked"})
@Override
public void run() {
try {
handler.handleResponse(response);
} catch (Throwable e) {
handleException(handler, new ResponseHandlerFailureTransportException(e));
}
}
}
class RequestHandler extends AbstractRunnable {
private final RequestHandlerRegistry reg;
private final TransportRequest request;
private final NettyTransportChannel transportChannel;
public RequestHandler(RequestHandlerRegistry reg, TransportRequest request, NettyTransportChannel transportChannel) {
this.reg = reg;
this.request = request;
this.transportChannel = transportChannel;
}
@SuppressWarnings({"unchecked"})
@Override
protected void doRun() throws Exception {
reg.processMessageReceived(request, transportChannel);
}
@Override
public boolean isForceExecution() {
return reg.isForceExecution();
}
@Override
public void onFailure(Throwable e) {
if (transport.lifecycleState() == Lifecycle.State.STARTED) {
// we can only send a response transport is started....
try {
transportChannel.sendResponse(e);
} catch (Throwable e1) {
logger.warn("Failed to send error message back to client for action [" + reg.getAction() + "]", e1);
logger.warn("Actual Exception", e);
}
}
}
}
/**
* Internal helper class to store characteristic offsets of a buffer during processing
*/
protected static final class Marker {
private final ChannelBuffer buffer;
private final int remainingMessageSize;
private final int expectedReaderIndex;
public Marker(ChannelBuffer buffer) {
this.buffer = buffer;
// when this constructor is called, we have read already two parts of the message header: the marker bytes and the message
// message length (see SizeHeaderFrameDecoder). Hence we have to rewind the index for MESSAGE_LENGTH_SIZE bytes to read the
// remaining message length again.
this.remainingMessageSize = buffer.getInt(buffer.readerIndex() - NettyHeader.MESSAGE_LENGTH_SIZE);
this.expectedReaderIndex = buffer.readerIndex() + remainingMessageSize;
}
/**
* @return the number of bytes that have yet to be read from the buffer
*/
public int messageSizeWithRemainingHeaders() {
return remainingMessageSize;
}
/**
* @return the number in bytes for the message including all headers (even the ones that have been read from the buffer already)
*/
public int messageSizeWithAllHeaders() {
return remainingMessageSize + NettyHeader.MARKER_BYTES_SIZE + NettyHeader.MESSAGE_LENGTH_SIZE;
}
/**
* @return the number of bytes for the message itself (excluding all headers).
*/
public int messageSize() {
return messageSizeWithAllHeaders() - NettyHeader.HEADER_SIZE;
}
/**
* @return the expected index of the buffer's reader after the message has been consumed entirely.
*/
public int expectedReaderIndex() {
return expectedReaderIndex;
}
/**
* Validates that a request has been fully read (not too few bytes but also not too many bytes).
*
* @param stream A stream that is associated with the buffer that is tracked by this marker.
* @param requestId The current request id.
* @param action The currently executed action.
* @throws IOException Iff the stream could not be read.
* @throws IllegalStateException Iff the request has not been fully read.
*/
public void validateRequest(StreamInput stream, long requestId, String action) throws IOException {
final int nextByte = stream.read();
// calling read() is useful to make sure the message is fully read, even if there some kind of EOS marker
if (nextByte != -1) {
throw new IllegalStateException("Message not fully read (request) for requestId [" + requestId + "], action [" + action
+ "], readerIndex [" + buffer.readerIndex() + "] vs expected [" + expectedReaderIndex + "]; resetting");
}
if (buffer.readerIndex() < expectedReaderIndex) {
throw new IllegalStateException("Message is fully read (request), yet there are "
+ (expectedReaderIndex - buffer.readerIndex()) + " remaining bytes; resetting");
}
if (buffer.readerIndex() > expectedReaderIndex) {
throw new IllegalStateException(
"Message read past expected size (request) for requestId [" + requestId + "], action [" + action
+ "], readerIndex [" + buffer.readerIndex() + "] vs expected [" + expectedReaderIndex + "]; resetting");
}
}
/**
* Validates that a response has been fully read (not too few bytes but also not too many bytes).
*
* @param stream A stream that is associated with the buffer that is tracked by this marker.
* @param requestId The corresponding request id for this response.
* @param handler The current response handler.
* @param error Whether validate an error response.
* @throws IOException Iff the stream could not be read.
* @throws IllegalStateException Iff the request has not been fully read.
*/
public void validateResponse(StreamInput stream, long requestId,
TransportResponseHandler> handler, boolean error) throws IOException {
// Check the entire message has been read
final int nextByte = stream.read();
// calling read() is useful to make sure the message is fully read, even if there is an EOS marker
if (nextByte != -1) {
throw new IllegalStateException("Message not fully read (response) for requestId [" + requestId + "], handler ["
+ handler + "], error [" + error + "]; resetting");
}
if (buffer.readerIndex() < expectedReaderIndex) {
throw new IllegalStateException("Message is fully read (response), yet there are "
+ (expectedReaderIndex - buffer.readerIndex()) + " remaining bytes; resetting");
}
if (buffer.readerIndex() > expectedReaderIndex) {
throw new IllegalStateException("Message read past expected size (response) for requestId [" + requestId
+ "], handler [" + handler + "], error [" + error + "]; resetting");
}
}
}
}