All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.hekate.network.netty.NettyServerClient Maven / Gradle / Ivy

/*
 * Copyright 2022 The Hekate Project
 *
 * The Hekate Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */

package io.hekate.network.netty;

import io.hekate.codec.Codec;
import io.hekate.codec.CodecException;
import io.hekate.core.HekateIoException;
import io.hekate.core.internal.util.Utils;
import io.hekate.network.NetworkEndpoint;
import io.hekate.network.NetworkEndpointClosedException;
import io.hekate.network.NetworkFuture;
import io.hekate.network.NetworkSendCallback;
import io.hekate.network.NetworkServerHandler;
import io.hekate.network.NetworkTimeoutException;
import io.hekate.network.internal.NettyChannelSupport;
import io.hekate.network.netty.NetworkProtocol.HandshakeAccept;
import io.hekate.network.netty.NetworkProtocol.HandshakeReject;
import io.hekate.network.netty.NetworkProtocol.HandshakeRequest;
import io.hekate.network.netty.NetworkProtocol.Heartbeat;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.EventLoop;
import io.netty.channel.EventLoopGroup;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import io.netty.handler.timeout.IdleStateHandler;
import io.netty.handler.traffic.ChannelTrafficShapingHandler;
import io.netty.handler.traffic.TrafficCounter;
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.internal.ThrowableUtil;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

class NettyServerClient extends ChannelInboundHandlerAdapter implements NetworkEndpoint, NettyChannelSupport {
    private static final NetworkEndpointClosedException WRITE_CLOSED_CHANNEL_EXCEPTION = ThrowableUtil.unknownStackTrace(
        new NetworkEndpointClosedException("Connection closed."), NettyServerClient.class, "send()"
    );

    private final Map handlers;

    private final InetSocketAddress remoteAddress;

    private final InetSocketAddress localAddress;

    private final EventLoopGroup eventLoopGroup;

    private final NettyWriteQueue writeQueue = new NettyWriteQueue();

    private final boolean ssl;

    private final int hbInterval;

    private final int hbLossThreshold;

    private final boolean hbDisabled;

    private final GenericFutureListener> hbFlushListener;

    private Logger log = LoggerFactory.getLogger(NettyServerClient.class);

    private boolean debug = log.isDebugEnabled();

    private boolean trace = log.isTraceEnabled();

    private boolean hbFlushed = true;

    private int ignoreReadTimeouts;

    private NetworkServerHandler serverHandler;

    private NettyServerHandler handlerReg;

    private NettyMetricsSink metrics;

    private String protocol;

    private Codec codec;

    private EventLoop eventLoop;

    private boolean connectNotified;

    private volatile Object userContext;

    private volatile ChannelHandlerContext handlerCtx;

    public NettyServerClient(
        InetSocketAddress remoteAddress,
        InetSocketAddress localAddress,
        boolean ssl,
        int hbInterval,
        int hbLossThreshold,
        boolean hbDisabled,
        Map handlers,
        EventLoopGroup eventLoopGroup
    ) {
        this.remoteAddress = remoteAddress;
        this.localAddress = localAddress;
        this.ssl = ssl;
        this.hbInterval = hbInterval;
        this.hbLossThreshold = hbLossThreshold;
        this.hbDisabled = hbDisabled;
        this.handlers = handlers;
        this.eventLoopGroup = eventLoopGroup;

        hbFlushListener = future -> hbFlushed = true;
    }

    @Override
    public String protocol() {
        return protocol;
    }

    @Override
    public InetSocketAddress remoteAddress() {
        return remoteAddress;
    }

    @Override
    public InetSocketAddress localAddress() {
        return localAddress;
    }

    @Override
    public boolean isSecure() {
        return ssl;
    }

    @Override
    public Object getContext() {
        return userContext;
    }

    @Override
    public void setContext(Object ctx) {
        this.userContext = ctx;
    }

    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        Channel channel = ctx.channel();

        if (trace) {
            log.trace("Got connection [from={}]", channel.remoteAddress());
        }

        this.handlerCtx = ctx;

        mayBeCreateIdleStateHandler().ifPresent(handler ->
            ctx.pipeline().addFirst(IdleStateHandler.class.getName(), handler)
        );

        if (metrics != null) {
            metrics.onConnect();
        }
    }

    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        if (handlerReg != null) {
            handlerReg.remove(this);
        }

        if (serverHandler != null && connectNotified) {
            try {
                serverHandler.onDisconnect(this);
            } finally {
                serverHandler = null;
            }
        }

        if (metrics != null) {
            metrics.onDisconnect();
        }

        this.handlerCtx = null;

        if (debug) {
            log.debug("Closed connection [from={}]", address());
        }
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (isHandshakeDone()) {
            if (msg instanceof Heartbeat) {
                if (trace) {
                    log.trace("Received network heartbeat from client [from={}]", address());
                }
            } else {
                NettyMessage netMsg = (NettyMessage)msg;

                netMsg.prepare(log);

                if (trace) {
                    log.trace("Message buffer prepared [from={}, message={}]", address(), netMsg);
                }

                if (metrics != null) {
                    metrics.onMessageReceived();
                }

                try {
                    serverHandler.onMessage(netMsg, this);
                } finally {
                    netMsg.release();
                }
            }
        } else {
            if (trace) {
                log.trace("Received network handshake request [from={}, message={}]", address(), msg);
            }

            HandshakeRequest handshake = (HandshakeRequest)msg;

            String protocol;
            NettyServerHandler handlerReg;

            if (handshake == null) {
                protocol = null;
                handlerReg = null;
            } else {
                this.protocol = protocol = handshake.protocol();

                handlerReg = handlers.get(protocol);
            }

            if (handlerReg == null) {
                if (debug) {
                    log.debug("Closing connection with unsupported protocol [from={}, protocol={}]", address(), protocol);
                }

                HandshakeReject reject = new HandshakeReject("Unsupported protocol [protocol=" + protocol + ']');

                ctx.writeAndFlush(reject).addListener(ChannelFutureListener.CLOSE);
            } else {
                // Map connection to a thread.
                EventLoop eventLoop = mapToThread(handshake.threadAffinity(), handlerReg);

                // Check if we need to re-bind this channel to another thread.
                if (eventLoop.inEventLoop()) {
                    // No need to rebind.
                    init(ctx.channel(), handshake, handlerReg);
                } else {
                    if (trace) {
                        log.trace("Registering channel to a custom NIO thread [from={}, protocol={}]", address(), protocol);
                    }

                    // Unregister and then re-register IdleStateHandler in order to prevent RejectedExecutionException if same
                    // instance is used on different threads.
                    ctx.pipeline().remove(IdleStateHandler.class.getName());

                    Channel channel = ctx.channel();

                    channel.deregister().addListener(deregister -> {
                        if (deregister.isSuccess()) {
                            if (!eventLoop.isShutdown() && channel.isOpen()) {
                                eventLoop.register(channel).addListener(register -> {
                                    if (register.isSuccess() && channel.isOpen()) {
                                        if (trace) {
                                            log.trace("Registered channel to a custom NIO thread [from={}, protocol={}]",
                                                address(), protocol);
                                        }

                                        mayBeCreateIdleStateHandler().ifPresent(handler ->
                                            ctx.pipeline().addFirst(IdleStateHandler.class.getName(), handler)
                                        );

                                        init(channel, handshake, handlerReg);
                                    }
                                });
                            }
                        }
                    });
                }
            }
        }
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable error) throws Exception {
        Throwable realError = NettyErrorUtils.unwrap(error);

        boolean disconnect = true;

        if (realError instanceof CodecException) {
            disconnect = false;
        } else if (realError instanceof HekateIoException) {
            if (debug) {
                log.debug("Closing inbound network connection due to I/O error "
                    + "[protocol={}, from={}, reason={}]", protocol, address(), realError.toString());
            }
        } else if (log.isErrorEnabled()) {
            log.error("Inbound network connection failure [protocol={}, from={}]", protocol, address(), realError);
        }

        if (disconnect) {
            if (serverHandler != null) {
                serverHandler.onFailure(this, realError);
            }

            ctx.close();
        }
    }

    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (evt instanceof AutoReadChangeEvent) {
            if (evt == AutoReadChangeEvent.PAUSE) {
                // Completely ignore read timeouts.
                ignoreReadTimeouts = -1;
            } else {
                // Ignore next timeout.
                ignoreReadTimeouts = 1;
            }
        } else if (evt instanceof IdleStateEvent) {
            IdleStateEvent idle = (IdleStateEvent)evt;

            if (idle.state() == IdleState.WRITER_IDLE) {
                if (hbFlushed && isHandshakeDone()) {
                    // Make sure that we don't push multiple heartbeats to the network buffer simultaneously.
                    // Need to perform this check since remote peer can hang and stop reading
                    // while this channel will still be trying to put more and more heartbeats on its send buffer.
                    hbFlushed = false;

                    ctx.writeAndFlush(Heartbeat.INSTANCE).addListener(hbFlushListener);
                }
            } else {
                // Reader idle.
                // Ignore if auto-reading was disabled since in such case we will not read any heartbeats.
                if (ignoreReadTimeouts != -1 && ctx.channel().config().isAutoRead()) {
                    // Check if timeout should be ignored.
                    if (ignoreReadTimeouts > 0) {
                        // Decrement the counter of ignored timeouts.
                        ignoreReadTimeouts--;
                    } else {
                        throw new NetworkTimeoutException("Connection timed out [protocol=" + protocol() + ", address=" + address() + ']');
                    }
                }
            }
        } else {
            super.userEventTriggered(ctx, evt);
        }
    }

    @Override
    public void send(Object msg) {
        doSend(msg, null);
    }

    @Override
    public void send(Object msg, NetworkSendCallback onSend) {
        doSend(msg, onSend);
    }

    @Override
    public void pauseReceiving(Consumer> callback) {
        pauseReceiver(true, callback);
    }

    @Override
    public void resumeReceiving(Consumer> callback) {
        pauseReceiver(false, callback);
    }

    @Override
    public boolean isReceiving() {
        ChannelHandlerContext localCtx = this.handlerCtx;

        return localCtx != null && localCtx.channel().config().isAutoRead();
    }

    @Override
    public NetworkFuture disconnect() {
        NetworkFuture future = new NetworkFuture<>();

        ChannelHandlerContext localCtx = this.handlerCtx;

        if (localCtx == null) {
            future.complete(this);
        } else {
            this.handlerCtx = null;

            ChannelFuture closeFuture = localCtx.close();

            closeFuture.addListener(completed -> future.complete(this));
        }

        return future;
    }

    @Override
    public Optional nettyChannel() {
        ChannelHandlerContext ctx = this.handlerCtx;

        return ctx != null ? Optional.of(ctx.channel()) : Optional.empty();
    }

    private void pauseReceiver(boolean pause, Consumer> callback) {
        ChannelHandlerContext localCtx = this.handlerCtx;

        if (localCtx != null) {
            if (pause) {
                if (debug) {
                    log.debug("Pausing inbound receiver [from={}, protocol={}]", address(), protocol);
                }
            } else {
                if (debug) {
                    log.debug("Resuming Pausing inbound receiver [from={}, protocol={}]", address(), protocol);
                }
            }

            Channel channel = localCtx.channel();
            EventLoop eventLoop = channel.eventLoop();

            if (eventLoop.inEventLoop()) {
                channel.config().setAutoRead(!pause);

                notifyOnReceivePause(pause, callback, channel);
            } else {
                eventLoop.execute(() -> {
                    channel.config().setAutoRead(!pause);

                    notifyOnReceivePause(pause, callback, channel);
                });
            }
        } else if (callback != null) {
            callback.accept(this);
        }
    }

    private void notifyOnReceivePause(boolean pause, Consumer> callback, Channel channel) {
        channel.pipeline().fireUserEventTriggered(pause ? AutoReadChangeEvent.PAUSE : AutoReadChangeEvent.RESUME);

        if (callback != null) {
            try {
                callback.accept(this);
            } catch (Throwable e) {
                log.error("Got an unexpected error while notifying callback on network inbound receive status change "
                    + "[pause={}, from={}, protocol={}]", pause, address(), protocol, e);
            }
        }
    }

    private Optional mayBeCreateIdleStateHandler() {
        if (hbInterval > 0 && hbLossThreshold > 0) {
            int readerIdle = hbInterval * hbLossThreshold;
            int writerIdle = hbDisabled ? 0 : hbInterval;

            if (trace) {
                log.trace("Registering heartbeat handler [from={}, reader-idle={}, writer-idle={}]", address(), readerIdle, writerIdle);
            }

            return Optional.of(new IdleStateHandler(readerIdle, writerIdle, 0, TimeUnit.MILLISECONDS));
        } else {
            return Optional.empty();
        }
    }

    private void init(Channel channel, HandshakeRequest request, NettyServerHandler handlerReg) {
        NettyServerHandlerConfig cfg = handlerReg.config();

        if (cfg.getLoggerCategory() != null) {
            log = LoggerFactory.getLogger(cfg.getLoggerCategory());

            debug = log.isDebugEnabled();
            trace = log.isTraceEnabled();
        }

        this.eventLoop = channel.eventLoop();
        this.serverHandler = cfg.getHandler();
        this.handlerReg = handlerReg;
        this.metrics = handlerReg.metrics();
        this.codec = request.codec();

        // Register this client.
        handlerReg.add(this);

        // Configure metrics.
        if (metrics != null) {
            channel.pipeline().addFirst(new ChannelTrafficShapingHandler(0, 0, NettyInternalConst.TRAFFIC_SHAPING_INTERVAL) {
                @Override
                protected void doAccounting(TrafficCounter counter) {
                    metrics.onBytesReceived(counter.lastReadBytes());
                    metrics.onBytesSent(counter.lastWrittenBytes());
                }
            });
        }

        if (debug) {
            log.debug("Accepted connection [from={}, protocol={}]", address(), cfg.getProtocol());
        }

        // Accept handshake.
        HandshakeAccept accept = new HandshakeAccept(hbInterval, hbLossThreshold, hbDisabled);

        channel.writeAndFlush(accept).addListener(future -> {
                if (channel.isOpen()) {
                    connectNotified = true;

                    // Notify on connect.
                    serverHandler.onConnect(request.payload(), this);
                }
            }
        );
    }

    private void doSend(Object msg, NetworkSendCallback onSend) {
        ChannelHandlerContext localCtx = this.handlerCtx;

        if (localCtx == null) {
            // Notify on failure.
            if (onSend != null) {
                NettyUtils.runAtAllCost(eventLoop, () ->
                    notifyOnError(msg, onSend, WRITE_CLOSED_CHANNEL_EXCEPTION)
                );
            }
        } else {
            // Write message to the channel.
            write(msg, onSend, localCtx);
        }
    }

    private void write(Object msg, NetworkSendCallback onSend, ChannelHandlerContext localCtx) {
        if (!validateMessageType(msg, onSend)) {
            return;
        }

        if (debug) {
            log.debug("Sending to client [to={}, message={}]", address(), msg);
        }

        if (metrics != null) {
            metrics.onMessageEnqueue();
        }

        Channel channel = localCtx.channel();

        // Prepare deferred message.
        DeferredMessage deferredMsg;

        boolean failed = false;

        // Maybe pre-encode message.
        if (codec.isStateful()) {
            deferredMsg = new DeferredMessage(msg, channel);
        } else {
            if (trace) {
                log.trace("Pre-encoding message [to={}, message={}]", address(), msg);
            }

            try {
                ByteBuf buf = NetworkProtocolCodec.preEncode(msg, codec, localCtx.alloc());

                deferredMsg = new DeferredEncodedMessage(buf, msg, channel);
            } catch (CodecException e) {
                deferredMsg = fail(msg, channel, e);

                failed = true;
            }
        }

        deferredMsg.addListener(result -> {
            if (metrics != null) {
                metrics.onMessageDequeue();
            }

            if (result.isSuccess()) {
                // Successful operation.
                if (debug) {
                    log.debug("Done sending to client [to={}, message={}]", address(), msg);
                }

                if (metrics != null) {
                    metrics.onMessageSent();
                }
            } else {
                // Failed operation.
                Throwable cause = result.cause();

                if (debug) {
                    log.debug("Failed to send to client [to={}, message={}]", address(), msg, cause);
                }

                // Notify channel pipeline on error (ignore if channel is already closed).
                if (channel.isOpen()) {
                    channel.pipeline().fireExceptionCaught(cause);
                }
            }

            // Notify user callback.
            if (onSend != null) {
                try {
                    Throwable mayBeError = NettyErrorUtils.unwrap(result.cause());

                    onSend.onComplete(msg, mayBeError);
                } catch (Throwable t) {
                    if (log.isErrorEnabled()) {
                        log.error("Failed to notify network message callback [message={}]", msg, t);
                    }
                }
            }
        });

        // Enqueue write operation.
        if (!failed) {
            writeQueue.enqueue(deferredMsg, localCtx.executor());
        }
    }

    private boolean validateMessageType(Object msg, NetworkSendCallback onSend) {
        if (codec.baseType().isInstance(msg)) {
            return true;
        } else {
            String expected = codec.baseType().getName();
            String real = msg.getClass().getName();

            CodecException err = new CodecException("Unsupported message type [expected=" + expected + ", real=" + real + ']');

            if (onSend == null) {
                if (log.isErrorEnabled()) {
                    log.error("Message sending failed.", err);
                }
            } else {
                notifyOnError(msg, onSend, err);
            }

            return false;
        }
    }

    private EventLoop mapToThread(int affinity, NettyServerHandler handler) {
        EventLoopGroup group;

        // Check if a dedicated thread pool is defined for this protocol.
        if (handler.config().getEventLoop() == null) {
            // Use core thread pool.
            group = eventLoopGroup;
        } else {
            // Use dedicated thread pool.
            group = handler.config().getEventLoop();
        }

        List eventLoops = new ArrayList<>();

        // Assumes that the same group always returns its event loops in the same order.
        for (Iterator it = group.iterator(); it.hasNext(); ) {
            eventLoops.add((EventLoop)it.next());
        }

        return eventLoops.get(Utils.mod(affinity, eventLoops.size()));
    }

    private DeferredMessage fail(Object msg, Channel channel, Throwable error) {
        DeferredMessage promise = new DeferredMessage(msg, channel);

        promise.setFailure(error);

        return promise;
    }

    private void notifyOnError(Object msg, NetworkSendCallback onSend, Throwable error) {
        try {
            onSend.onComplete(msg, error);
        } catch (Throwable e) {
            log.error("Failed to notify callback on network operation failure "
                + "[protocol={}, from={}, message={}]", protocol, address(), msg, e);
        }
    }

    private boolean isHandshakeDone() {
        return serverHandler != null;
    }

    private InetSocketAddress address() {
        return remoteAddress;
    }

    @Override
    public String toString() {
        return getClass().getSimpleName()
            + "[from=" + address()
            + ", protocol=" + protocol
            + ']';
    }
}