All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.elasticsearch.transport.InboundHandler Maven / Gradle / Ivy

There is a newer version: 8.13.2
Show newest version
/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.elasticsearch.transport;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.Version;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.collect.MapBuilder;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.metrics.MeanMetric;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.threadpool.ThreadPool;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.Collections;
import java.util.Map;
import java.util.Set;

public class InboundHandler {

    private static final Logger logger = LogManager.getLogger(InboundHandler.class);

    private final MeanMetric readBytesMetric = new MeanMetric();
    private final ThreadPool threadPool;
    private final OutboundHandler outboundHandler;
    private final CircuitBreakerService circuitBreakerService;
    private final InboundMessage.Reader reader;
    private final TransportHandshaker handshaker;
    private final TransportKeepAlive keepAlive;

    private final Transport.ResponseHandlers responseHandlers = new Transport.ResponseHandlers();
    private volatile Map> requestHandlers = Collections.emptyMap();
    private volatile TransportMessageListener messageListener = TransportMessageListener.NOOP_LISTENER;

    InboundHandler(ThreadPool threadPool, OutboundHandler outboundHandler, InboundMessage.Reader reader,
                   CircuitBreakerService circuitBreakerService, TransportHandshaker handshaker, TransportKeepAlive keepAlive) {
        this.threadPool = threadPool;
        this.outboundHandler = outboundHandler;
        this.circuitBreakerService = circuitBreakerService;
        this.reader = reader;
        this.handshaker = handshaker;
        this.keepAlive = keepAlive;
    }

    synchronized  void registerRequestHandler(RequestHandlerRegistry reg) {
        if (requestHandlers.containsKey(reg.getAction())) {
            throw new IllegalArgumentException("transport handlers for action " + reg.getAction() + " is already registered");
        }
        requestHandlers = MapBuilder.newMapBuilder(requestHandlers).put(reg.getAction(), reg).immutableMap();
    }

    @SuppressWarnings("unchecked")
    final  RequestHandlerRegistry getRequestHandler(String action) {
        return (RequestHandlerRegistry) requestHandlers.get(action);
    }

    final Transport.ResponseHandlers getResponseHandlers() {
        return responseHandlers;
    }

    MeanMetric getReadBytes() {
        return readBytesMetric;
    }

    void setMessageListener(TransportMessageListener listener) {
        if (messageListener == TransportMessageListener.NOOP_LISTENER) {
            messageListener = listener;
        } else {
            throw new IllegalStateException("Cannot set message listener twice");
        }
    }

    void inboundMessage(TcpChannel channel, BytesReference message) throws Exception {
        channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis());
        TransportLogger.logInboundMessage(channel, message);
        readBytesMetric.inc(message.length() + TcpHeader.MARKER_BYTES_SIZE + TcpHeader.MESSAGE_LENGTH_SIZE);
        // Message length of 0 is a ping
        if (message.length() != 0) {
            messageReceived(message, channel);
        } else {
            keepAlive.receiveKeepAlive(channel);
        }
    }

    private void messageReceived(BytesReference reference, TcpChannel channel) throws IOException {
        InetSocketAddress remoteAddress = channel.getRemoteAddress();

        ThreadContext threadContext = threadPool.getThreadContext();
        try (ThreadContext.StoredContext existing = threadContext.stashContext();
             InboundMessage message = reader.deserialize(reference)) {
            // Place the context with the headers from the message
            message.getStoredContext().restore();
            threadContext.putTransient("_remote_address", remoteAddress);
            if (message.isRequest()) {
                handleRequest(channel, (InboundMessage.Request) message, reference.length());
            } else {
                final TransportResponseHandler handler;
                long requestId = message.getRequestId();
                if (message.isHandshake()) {
                    handler = handshaker.removeHandlerForHandshake(requestId);
                } else {
                    TransportResponseHandler theHandler =
                        responseHandlers.onResponseReceived(requestId, messageListener);
                    if (theHandler == null && message.isError()) {
                        handler = handshaker.removeHandlerForHandshake(requestId);
                    } else {
                        handler = theHandler;
                    }
                }
                // ignore if its null, the service logs it
                if (handler != null) {
                    if (message.isError()) {
                        handlerResponseError(message.getStreamInput(), handler);
                    } else {
                        handleResponse(remoteAddress, message.getStreamInput(), handler);
                    }
                    // Check the entire message has been read
                    final int nextByte = message.getStreamInput().read();
                    // calling read() is useful to make sure the message is fully read, even if there is an EOS marker
                    if (nextByte != -1) {
                        throw new IllegalStateException("Message not fully read (response) for requestId [" + requestId + "], handler ["
                            + handler + "], error [" + message.isError() + "]; resetting");
                    }
                }
            }
        }
    }

    private  void handleRequest(TcpChannel channel, InboundMessage.Request message, int messageLengthBytes) {
        final Set features = message.getFeatures();
        final String action = message.getActionName();
        final long requestId = message.getRequestId();
        final StreamInput stream = message.getStreamInput();
        final Version version = message.getVersion();
        TransportChannel transportChannel = null;
        try {
            messageListener.onRequestReceived(requestId, action);
            if (message.isHandshake()) {
                handshaker.handleHandshake(version, features, channel, requestId, stream);
            } else {
                final RequestHandlerRegistry reg = getRequestHandler(action);
                if (reg == null) {
                    throw new ActionNotFoundTransportException(action);
                }
                CircuitBreaker breaker = circuitBreakerService.getBreaker(CircuitBreaker.IN_FLIGHT_REQUESTS);
                if (reg.canTripCircuitBreaker()) {
                    breaker.addEstimateBytesAndMaybeBreak(messageLengthBytes, "");
                } else {
                    breaker.addWithoutBreaking(messageLengthBytes);
                }
                transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, features,
                    circuitBreakerService, messageLengthBytes, message.isCompress());
                final T request = reg.newRequest(stream);
                request.remoteAddress(new TransportAddress(channel.getRemoteAddress()));
                // in case we throw an exception, i.e. when the limit is hit, we don't want to verify
                final int nextByte = stream.read();
                // calling read() is useful to make sure the message is fully read, even if there some kind of EOS marker
                if (nextByte != -1) {
                    throw new IllegalStateException("Message not fully read (request) for requestId [" + requestId + "], action [" + action
                        + "], available [" + stream.available() + "]; resetting");
                }
                threadPool.executor(reg.getExecutor()).execute(new RequestHandler<>(reg, request, transportChannel));
            }
        } catch (Exception e) {
            // the circuit breaker tripped
            if (transportChannel == null) {
                transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, features,
                    circuitBreakerService, 0, message.isCompress());
            }
            try {
                transportChannel.sendResponse(e);
            } catch (IOException inner) {
                inner.addSuppressed(e);
                logger.warn(() -> new ParameterizedMessage("Failed to send error message back to client for action [{}]", action), inner);
            }
        }
    }

    private  void handleResponse(InetSocketAddress remoteAddress, final StreamInput stream,
                                                              final TransportResponseHandler handler) {
        final T response;
        try {
            response = handler.read(stream);
            response.remoteAddress(new TransportAddress(remoteAddress));
        } catch (Exception e) {
            handleException(handler, new TransportSerializationException(
                "Failed to deserialize response from handler [" + handler.getClass().getName() + "]", e));
            return;
        }
        threadPool.executor(handler.executor()).execute(new AbstractRunnable() {
            @Override
            public void onFailure(Exception e) {
                handleException(handler, new ResponseHandlerFailureTransportException(e));
            }

            @Override
            protected void doRun() {
                handler.handleResponse(response);
            }
        });
    }

    private void handlerResponseError(StreamInput stream, final TransportResponseHandler handler) {
        Exception error;
        try {
            error = stream.readException();
        } catch (Exception e) {
            error = new TransportSerializationException("Failed to deserialize exception response from stream", e);
        }
        handleException(handler, error);
    }

    private void handleException(final TransportResponseHandler handler, Throwable error) {
        if (!(error instanceof RemoteTransportException)) {
            error = new RemoteTransportException(error.getMessage(), error);
        }
        final RemoteTransportException rtx = (RemoteTransportException) error;
        threadPool.executor(handler.executor()).execute(() -> {
            try {
                handler.handleException(rtx);
            } catch (Exception e) {
                logger.error(() -> new ParameterizedMessage("failed to handle exception response [{}]", handler), e);
            }
        });
    }

    private static class RequestHandler extends AbstractRunnable {
        private final RequestHandlerRegistry reg;
        private final T request;
        private final TransportChannel transportChannel;

        RequestHandler(RequestHandlerRegistry reg, T request, TransportChannel transportChannel) {
            this.reg = reg;
            this.request = request;
            this.transportChannel = transportChannel;
        }

        @Override
        protected void doRun() throws Exception {
            reg.processMessageReceived(request, transportChannel);
        }

        @Override
        public boolean isForceExecution() {
            return reg.isForceExecution();
        }

        @Override
        public void onFailure(Exception e) {
            try {
                transportChannel.sendResponse(e);
            } catch (Exception inner) {
                inner.addSuppressed(e);
                logger.warn(() -> new ParameterizedMessage(
                    "Failed to send error message back to client for action [{}]", reg.getAction()), inner);
            }
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy