All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.elasticsearch.transport.InboundHandler Maven / Gradle / Ivy

There is a newer version: 8.13.4
Show newest version
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0 and the Server Side Public License, v 1; you may not use this file except
 * in compliance with, at your election, the Elastic License 2.0 or the Server
 * Side Public License, v 1.
 */

package org.elasticsearch.transport;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.ByteBufferStreamInput;
import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.threadpool.ThreadPool;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;

/**
 * Handles inbound messages by first deserializing a {@link TransportMessage} from an {@link InboundMessage} and then passing
 * it to the appropriate handler.
 */
public class InboundHandler {

    /** Class logger; used for the slow-handling warning and for failures while deserializing or replying. */
    private static final Logger logger = LogManager.getLogger(InboundHandler.class);

    /** Source of the thread context, the relative clock and the named executors that handlers are dispatched on. */
    private final ThreadPool threadPool;
    /** Handed to every {@link TcpTransportChannel} created for an inbound request. */
    private final OutboundHandler outboundHandler;
    /** Registry used to wrap message streams in a {@link NamedWriteableAwareStreamInput}. */
    private final NamedWriteableRegistry namedWriteableRegistry;
    /** Processes handshake requests and owns the response handlers of pending handshakes. */
    private final TransportHandshaker handshaker;
    /** Notified whenever a ping message arrives on a channel. */
    private final TransportKeepAlive keepAlive;
    /** Looked up by request id to find the handler for an inbound (non-handshake) response. */
    private final Transport.ResponseHandlers responseHandlers;
    /** Looked up by action name to find the registry entry for an inbound (non-handshake) request. */
    private final Transport.RequestHandlers requestHandlers;

    /** Listener told about received requests and responses; starts as the no-op listener and may be replaced once. */
    private volatile TransportMessageListener messageListener = TransportMessageListener.NOOP_LISTENER;

    /** Handling time, in milliseconds, above which a warning is logged; {@code Long.MAX_VALUE} until configured. */
    private volatile long slowLogThresholdMs = Long.MAX_VALUE;

    /**
     * Creates a handler wired to the given collaborators. Each argument is stored as-is in the field of the same name;
     * no validation or copying is performed.
     *
     * @param threadPool             provides the thread context, relative clock and executors
     * @param outboundHandler        passed to the channels created for inbound requests
     * @param namedWriteableRegistry registry used when wrapping message streams
     * @param handshaker             handles handshake requests and handshake responses
     * @param keepAlive              receives ping notifications
     * @param requestHandlers        lookup of request handlers by action name
     * @param responseHandlers       lookup of response handlers by request id
     */
    InboundHandler(
        ThreadPool threadPool,
        OutboundHandler outboundHandler,
        NamedWriteableRegistry namedWriteableRegistry,
        TransportHandshaker handshaker,
        TransportKeepAlive keepAlive,
        Transport.RequestHandlers requestHandlers,
        Transport.ResponseHandlers responseHandlers
    ) {
        // handler lookups
        this.responseHandlers = responseHandlers;
        this.requestHandlers = requestHandlers;

        // connection-level collaborators
        this.keepAlive = keepAlive;
        this.handshaker = handshaker;

        // infrastructure
        this.namedWriteableRegistry = namedWriteableRegistry;
        this.outboundHandler = outboundHandler;
        this.threadPool = threadPool;
    }

    /**
     * Installs the message listener. This may only be done while the no-op listener is still in place.
     *
     * @param listener the listener to install
     * @throws IllegalStateException if a listener other than the no-op listener is already installed
     */
    void setMessageListener(TransportMessageListener listener) {
        if (messageListener != TransportMessageListener.NOOP_LISTENER) {
            throw new IllegalStateException("Cannot set message listener twice");
        }
        messageListener = listener;
    }

    /**
     * Updates the threshold above which handling an inbound message is reported as slow.
     *
     * @param slowLogThreshold the new threshold; its millisecond value is what gets stored
     */
    void setSlowLogThreshold(TimeValue slowLogThreshold) {
        final long thresholdMillis = slowLogThreshold.getMillis();
        this.slowLogThresholdMs = thresholdMillis;
    }

    /**
     * Entry point for a message read from {@code channel}. Records the access time on the channel's stats, logs the
     * message, and then either forwards a ping to the keep-alive handler or dispatches the message as a request/response.
     *
     * @param channel the channel the message arrived on
     * @param message the inbound message
     * @throws Exception if handling the message fails
     */
    void inboundMessage(TcpChannel channel, InboundMessage message) throws Exception {
        final long receivedAtMillis = threadPool.relativeTimeInMillis();
        channel.getChannelStats().markAccessed(receivedAtMillis);
        TransportLogger.logInboundMessage(channel, message);

        if (message.isPing() == false) {
            // the receive time doubles as the start time for the slow-handling check
            messageReceived(channel, message, receivedAtMillis);
            return;
        }
        keepAlive.receiveKeepAlive(channel);
    }

    /**
     * Shared stream over an empty byte array. It is handed to response handlers when a response carries no content,
     * which avoids instantiating a new stream for every empty message.
     */
    private static final StreamInput EMPTY_STREAM_INPUT = new ByteBufferStreamInput(ByteBuffer.wrap(BytesRef.EMPTY_BYTES));

    /**
     * Dispatches a non-ping message: requests go to {@link #handleRequest}, responses are matched to their pending
     * handler and delivered either as a result or as an error.
     * <p>
     * The message headers and the remote address are placed into the thread context for the duration of the call.
     * Regardless of the outcome, the elapsed time since {@code startTime} is compared against the slow-log threshold
     * and a warning is logged when it is exceeded.
     *
     * @param channel   the channel the message arrived on
     * @param message   the inbound message; its variable header must already have been read
     * @param startTime relative time in milliseconds at which the message was received
     * @throws IOException           if reading from the message stream fails
     * @throws IllegalStateException if a response handler did not consume the whole message
     */
    private void messageReceived(TcpChannel channel, InboundMessage message, long startTime) throws IOException {
        final InetSocketAddress remoteAddress = channel.getRemoteAddress();
        final Header header = message.getHeader();
        assert header.needsToReadVariableHeader() == false;

        ThreadContext threadContext = threadPool.getThreadContext();
        try (ThreadContext.StoredContext existing = threadContext.stashContext()) {
            // Place the context with the headers from the message
            threadContext.setHeaders(header.getHeaders());
            threadContext.putTransient("_remote_address", remoteAddress);
            if (header.isRequest()) {
                handleRequest(channel, header, message);
            } else {
                // Responses do not support short circuiting currently
                assert message.isShortCircuit() == false;
                final TransportResponseHandler<?> handler;
                long requestId = header.getRequestId();
                if (header.isHandshake()) {
                    handler = handshaker.removeHandlerForHandshake(requestId);
                } else {
                    TransportResponseHandler<?> theHandler = responseHandlers.onResponseReceived(
                        requestId,
                        messageListener
                    );
                    // an error response without a regular handler may belong to a pending handshake,
                    // so fall back to the handshaker's handlers in that case
                    if (theHandler == null && header.isError()) {
                        handler = handshaker.removeHandlerForHandshake(requestId);
                    } else {
                        handler = theHandler;
                    }
                }
                // ignore if it's null, the service logs it
                if (handler != null) {
                    final StreamInput streamInput;
                    if (message.getContentLength() > 0 || header.getVersion().equals(Version.CURRENT) == false) {
                        streamInput = namedWriteableStream(message.openOrGetStreamInput());
                        assertRemoteVersion(streamInput, header.getVersion());
                        if (header.isError()) {
                            handlerResponseError(streamInput, handler);
                        } else {
                            handleResponse(remoteAddress, streamInput, handler);
                        }
                        // Check the entire message has been read
                        final int nextByte = streamInput.read();
                        // calling read() is useful to make sure the message is fully read, even if there is an EOS marker
                        if (nextByte != -1) {
                            throw new IllegalStateException(
                                "Message not fully read (response) for requestId ["
                                    + requestId
                                    + "], handler ["
                                    + handler
                                    + "], error ["
                                    + header.isError()
                                    + "]; resetting"
                            );
                        }
                    } else {
                        // empty body from a node on the current version: reuse the shared empty stream
                        assert header.isError() == false;
                        handleResponse(remoteAddress, EMPTY_STREAM_INPUT, handler);
                    }
                }
            }
        } finally {
            final long took = threadPool.relativeTimeInMillis() - startTime;
            // read the volatile threshold once so the comparison and the log line agree
            final long logThreshold = slowLogThresholdMs;
            if (logThreshold > 0 && took > logThreshold) {
                logger.warn(
                    "handling inbound transport message [{}] took [{}ms] which is above the warn threshold of [{}ms]",
                    message,
                    took,
                    logThreshold
                );
            }
        }
    }

    /**
     * Handles an inbound request.
     * <p>
     * Handshake requests are passed to the handshaker. If that fails, an error response is sent when the sender's wire
     * version is compatible with the current version; otherwise the failure is logged and the channel is closed.
     * <p>
     * All other requests are either answered immediately with the message's exception (short-circuited messages) or
     * deserialized through the handler registered for the action and executed, inline for the {@code SAME} executor
     * or forked to the handler's executor otherwise. Any exception raised on this path is sent back to the caller as
     * an error response.
     *
     * @param channel the channel the request arrived on
     * @param header  the request header
     * @param message the inbound message carrying the request body
     * @param <T>     the request type produced by the registered handler
     * @throws IOException if opening the message stream for a handshake fails
     */
    private <T extends TransportRequest> void handleRequest(TcpChannel channel, Header header, InboundMessage message)
        throws IOException {
        final String action = header.getActionName();
        final long requestId = header.getRequestId();
        final Version version = header.getVersion();
        if (header.isHandshake()) {
            messageListener.onRequestReceived(requestId, action);
            // Cannot short circuit handshakes
            assert message.isShortCircuit() == false;
            final StreamInput stream = namedWriteableStream(message.openOrGetStreamInput());
            assertRemoteVersion(stream, version);
            final TransportChannel transportChannel = new TcpTransportChannel(
                outboundHandler,
                channel,
                action,
                requestId,
                version,
                header.getFeatures(),
                header.getCompressionScheme(),
                header.isHandshake(),
                message.takeBreakerReleaseControl()
            );
            try {
                handshaker.handleHandshake(transportChannel, requestId, stream);
            } catch (Exception e) {
                if (Version.CURRENT.isCompatible(version)) {
                    sendErrorResponse(action, transportChannel, e);
                } else {
                    logger.warn(
                        new ParameterizedMessage(
                            "could not send error response to handshake received on [{}] using wire format version [{}], closing channel",
                            channel,
                            version
                        ),
                        e
                    );
                    channel.close();
                }
            }
        } else {
            final TransportChannel transportChannel = new TcpTransportChannel(
                outboundHandler,
                channel,
                action,
                requestId,
                version,
                header.getFeatures(),
                header.getCompressionScheme(),
                header.isHandshake(),
                message.takeBreakerReleaseControl()
            );
            try {
                messageListener.onRequestReceived(requestId, action);
                if (message.isShortCircuit()) {
                    sendErrorResponse(action, transportChannel, message.getException());
                } else {
                    final StreamInput stream = namedWriteableStream(message.openOrGetStreamInput());
                    assertRemoteVersion(stream, version);
                    final RequestHandlerRegistry<T> reg = requestHandlers.getHandler(action);
                    assert reg != null;
                    final T request = reg.newRequest(stream);
                    try {
                        request.remoteAddress(new TransportAddress(channel.getRemoteAddress()));
                        // in case we throw an exception, i.e. when the limit is hit, we don't want to verify
                        final int nextByte = stream.read();
                        // calling read() is useful to make sure the message is fully read, even if there is some kind of EOS marker
                        if (nextByte != -1) {
                            throw new IllegalStateException(
                                "Message not fully read (request) for requestId ["
                                    + requestId
                                    + "], action ["
                                    + action
                                    + "], available ["
                                    + stream.available()
                                    + "]; resetting"
                            );
                        }
                        final String executor = reg.getExecutor();
                        if (ThreadPool.Names.SAME.equals(executor)) {
                            try {
                                reg.processMessageReceived(request, transportChannel);
                            } catch (Exception e) {
                                sendErrorResponse(reg.getAction(), transportChannel, e);
                            }
                        } else {
                            // take an extra reference for the forked task; it is released in onAfter(),
                            // or right here if the task could not be handed to the executor
                            boolean success = false;
                            request.incRef();
                            try {
                                threadPool.executor(executor).execute(new AbstractRunnable() {
                                    @Override
                                    protected void doRun() throws Exception {
                                        reg.processMessageReceived(request, transportChannel);
                                    }

                                    @Override
                                    public boolean isForceExecution() {
                                        return reg.isForceExecution();
                                    }

                                    @Override
                                    public void onFailure(Exception e) {
                                        sendErrorResponse(reg.getAction(), transportChannel, e);
                                    }

                                    @Override
                                    public void onAfter() {
                                        request.decRef();
                                    }
                                });
                                success = true;
                            } finally {
                                if (success == false) {
                                    request.decRef();
                                }
                            }
                        }
                    } finally {
                        // release the reference held by this method since the request was created
                        request.decRef();
                    }
                }
            } catch (Exception e) {
                sendErrorResponse(action, transportChannel, e);
            }
        }
    }

    /**
     * Sends {@code e} back over {@code transportChannel} as the response. If sending itself fails, the original
     * exception is attached to the send failure as suppressed and the failure is logged at warn level; nothing is
     * rethrown.
     *
     * @param actionName       the action name, used only in the log message
     * @param transportChannel the channel to reply on
     * @param e                the exception to report to the caller
     */
    private static void sendErrorResponse(String actionName, TransportChannel transportChannel, Exception e) {
        try {
            transportChannel.sendResponse(e);
        } catch (Exception sendFailure) {
            sendFailure.addSuppressed(e);
            logger.warn(
                () -> new ParameterizedMessage("Failed to send error message back to client for action [{}]", actionName),
                sendFailure
            );
        }
    }

    /**
     * Deserializes a response with {@code handler} and delivers it, inline for the {@code SAME} executor or on the
     * handler's executor otherwise.
     * <p>
     * If reading the response (or setting its remote address) fails, the failure is wrapped in a
     * {@link TransportSerializationException}, logged, and reported to the handler as an exception instead.
     *
     * @param remoteAddress the address the response came from
     * @param stream        the stream positioned at the response body
     * @param handler       the handler waiting for this response
     * @param <T>           the response type read by the handler
     */
    private <T extends TransportResponse> void handleResponse(
        InetSocketAddress remoteAddress,
        final StreamInput stream,
        final TransportResponseHandler<T> handler
    ) {
        final T response;
        try {
            response = handler.read(stream);
            response.remoteAddress(new TransportAddress(remoteAddress));
        } catch (Exception e) {
            final Exception serializationException = new TransportSerializationException(
                "Failed to deserialize response from handler [" + handler + "]",
                e
            );
            logger.warn(new ParameterizedMessage("Failed to deserialize response from [{}]", remoteAddress), serializationException);
            handleException(handler, serializationException);
            return;
        }
        final String executor = handler.executor();
        if (ThreadPool.Names.SAME.equals(executor)) {
            doHandleResponse(handler, response);
        } else {
            boolean success = false;
            try {
                threadPool.executor(executor).execute(() -> doHandleResponse(handler, response));
                success = true;
            } finally {
                // doHandleResponse releases the response; if the task was never accepted, release it here
                if (success == false) {
                    response.decRef();
                }
            }
        }
    }

    /**
     * Passes {@code response} to {@code handler} and always releases the response afterwards. An exception thrown by
     * the handler is wrapped in a {@link ResponseHandlerFailureTransportException} and reported back to the same
     * handler as a failure.
     *
     * @param handler  the handler to invoke
     * @param response the deserialized response
     * @param <T>      the response type
     */
    private <T extends TransportResponse> void doHandleResponse(TransportResponseHandler<T> handler, T response) {
        try {
            handler.handleResponse(response);
        } catch (Exception e) {
            handleException(handler, new ResponseHandlerFailureTransportException(e));
        } finally {
            response.decRef();
        }
    }

    /**
     * Reads the exception carried by an error response and reports it to {@code handler}. If the exception itself
     * cannot be read, a {@link TransportSerializationException} describing that failure is reported instead.
     *
     * @param stream  the stream positioned at the serialized exception
     * @param handler the handler waiting for the response
     */
    private void handlerResponseError(StreamInput stream, final TransportResponseHandler<?> handler) {
        Exception remoteError;
        try {
            remoteError = stream.readException();
        } catch (Exception readFailure) {
            remoteError = new TransportSerializationException(
                "Failed to deserialize exception response from stream for handler [" + handler + "]",
                readFailure
            );
        }
        handleException(handler, remoteError);
    }

    private void handleException(final TransportResponseHandler handler, Throwable error) {
        if ((error instanceof RemoteTransportException) == false) {
            error = new RemoteTransportException(error.getMessage(), error);
        }
        final RemoteTransportException rtx = (RemoteTransportException) error;
        threadPool.executor(handler.executor()).execute(() -> {
            try {
                handler.handleException(rtx);
            } catch (Exception e) {
                logger.error(() -> new ParameterizedMessage("failed to handle exception response [{}]", handler), e);
            }
        });
    }

    /**
     * Wraps {@code delegate} in a {@link NamedWriteableAwareStreamInput} backed by this handler's registry.
     *
     * @param delegate the raw message stream
     * @return the wrapping stream
     */
    private StreamInput namedWriteableStream(StreamInput delegate) {
        final StreamInput registryAware = new NamedWriteableAwareStreamInput(delegate, namedWriteableRegistry);
        return registryAware;
    }

    /**
     * Asserts that the version set on {@code in} equals {@code version}. This is a plain {@code assert}, so it does
     * nothing when assertions are disabled.
     *
     * @param in      the stream whose version is checked
     * @param version the expected version
     */
    static void assertRemoteVersion(StreamInput in, Version version) {
        assert version.equals(in.getVersion())
            : "Stream version [" + in.getVersion() + "] does not match version [" + version + "]";
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy