All downloads are free. Search and download functionality uses the official Maven repository.

com.github.netty.protocol.nrpc.RpcServerChannelHandler Maven / Gradle / Ivy

package com.github.netty.protocol.nrpc;

import com.github.netty.annotation.NRpcMethod;
import com.github.netty.annotation.NRpcService;
import com.github.netty.core.AbstractChannelHandler;
import com.github.netty.core.util.*;
import com.github.netty.protocol.nrpc.codec.DataCodec;
import com.github.netty.protocol.nrpc.codec.DataCodecUtil;
import com.github.netty.protocol.nrpc.exception.RpcResponseException;
import com.github.netty.protocol.nrpc.exception.RpcTimeoutException;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;

import java.lang.reflect.Method;
import java.net.InetSocketAddress;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.function.Supplier;

import static com.github.netty.protocol.nrpc.RpcPacket.*;
import static com.github.netty.protocol.nrpc.RpcPacket.ResponsePacket.*;
import static com.github.netty.protocol.nrpc.RpcServerAop.CONTEXT_LOCAL;
import static com.github.netty.protocol.nrpc.codec.DataCodec.Encode.BINARY;

/**
 * Server side processor.
 * <p>
 * Looks up the {@link RpcServerInstance} registered for each incoming {@link RequestPacket}, invokes the target
 * {@link RpcMethod} either inline on the calling thread or on the executor obtained from the executor supplier
 * when the channel becomes active, and writes the resulting {@link ResponsePacket}(s) back. It also tracks
 * server-side method timeouts ({@link #rpcServerMethodDoneMap}) and pending chunk acks
 * ({@link #rpcChunkAckCallbackMap}), and notifies every registered {@link RpcServerAop} of lifecycle events.
 * <p>
 * NOTE(review): several type arguments appear to be missing from this copy of the file (for example
 * {@code ACK_TYPE} is used by {@link ChunkAckCallback#cast(Object)} but never declared); restore them from the
 * upstream source.
 *
 * @author wangzihao
 * 2018/9/16/016
 */
public class RpcServerChannelHandler extends AbstractChannelHandler {
    private static final LoggerX logger = LoggerFactoryX.getLogger(RpcServerChannelHandler.class);

    /**
     * Executor-run invocations that have a positive timeout, keyed by the runnable itself.
     * On expiry the runnable's {@link RpcRunnable#onTimeout()} is scheduled and, if configured,
     * its thread is interrupted (see the consumer installed in the constructor).
     */
    protected final ExpiryLRUMap rpcServerMethodDoneMap = new ExpiryLRUMap<>(512, Long.MAX_VALUE, Long.MAX_VALUE, null);
    /**
     * Pending chunk-ack callbacks keyed by chunk id. Removed when the ack arrives;
     * on expiry the callback's {@link ChunkAckCallback#onTimeout()} is scheduled.
     */
    protected final ExpiryLRUMap rpcChunkAckCallbackMap = new ExpiryLRUMap<>(512, Long.MAX_VALUE, Long.MAX_VALUE, null);
    private final Map serviceInstanceMap = new ConcurrentHashMap<>(8);
    private final List nettyRpcServerAopList = new CopyOnWriteArrayList<>();
    private final AtomicInteger chunkIdIncr = new AtomicInteger();
    /**
     * Data encoder decoder. (Serialization or Deserialization)
     */
    private DataCodec dataCodec;
    private ChannelHandlerContext context;
    private Supplier executorSupplier;
    private Executor executor;

    /**
     * Create a handler using the default codec from {@link DataCodecUtil#newDataCodec()}.
     */
    public RpcServerChannelHandler() {
        this(DataCodecUtil.newDataCodec());
    }

    /**
     * Create a handler with the given codec.
     * <p>
     * Registers a consumer on the codec's encode-request consumer list that forwards to every aop's
     * {@code onDecodeRequestBefore} with the context currently bound to {@code CONTEXT_LOCAL}, and installs
     * the expiry handlers of {@link #rpcServerMethodDoneMap} and {@link #rpcChunkAckCallbackMap}.
     *
     * @param dataCodec codec used for request/response (de)serialization
     */
    public RpcServerChannelHandler(DataCodec dataCodec) {
        super(true);
        this.dataCodec = dataCodec;
        dataCodec.getEncodeRequestConsumerList().add(params -> {
            RpcContext rpcContext = CONTEXT_LOCAL.get();
            for (RpcServerAop aop : nettyRpcServerAopList) {
                aop.onDecodeRequestBefore(rpcContext, params);
            }
        });
        rpcServerMethodDoneMap.setOnExpiryConsumer(node -> {
            try {
                RpcRunnable runnable = node.getData();
                if (!runnable.done) {
                    // notify the timeout exactly once, even though the entry may be re-armed below
                    if (runnable.timeoutNotifyFlag.compareAndSet(false, true)) {
                        runnable.executor.execute(runnable::onTimeout);
                    }
                    if (runnable.timeoutInterrupt) {
                        // taskThread is still null while the runnable is queued in the executor
                        Thread taskThread = runnable.taskThread;
                        if (taskThread != null) {
                            taskThread.interrupt();
                            runnable.interruptCount++;
                        }
                        // re-arm: check again (and interrupt again) in 100ms until the task reports done
                        rpcServerMethodDoneMap.put(node.getKey(), runnable, 100);
                    }
                }
            } catch (Exception e) {
                logger.warn("doneTimeout exception. server = {}, message = {}.", this, e.toString(), e);
            }
        });
        rpcChunkAckCallbackMap.setOnExpiryConsumer(node -> {
            try {
                ChunkAckCallback callback = node.getData();
                if (!callback.done) {
                    if (callback.timeoutNotifyFlag.compareAndSet(false, true)) {
                        callback.executor.execute(callback::onTimeout);
                    }
                }
            } catch (Exception e) {
                logger.warn("doneTimeout exception. server = {}, message = {}.", this, e.toString(), e);
            }
        });
    }

    /**
     * Get the service name declared by {@link NRpcService} on the class (or found via
     * {@code ReflectUtil.findAnnotation}).
     *
     * @param instanceClass instanceClass
     * @return requestMappingName, or null when no {@link NRpcService} annotation is found
     */
    public static String getRequestMappingName(Class instanceClass) {
        String requestMappingName = null;
        NRpcService rpcInterfaceAnn = ReflectUtil.findAnnotation(instanceClass, NRpcService.class);
        if (rpcInterfaceAnn != null) {
            requestMappingName = rpcInterfaceAnn.value();
        }
        return requestMappingName;
    }

    /**
     * Generate a service name: {@code '/'} followed by the lower-first-char simple name of the first
     * interface returned by {@code ReflectUtil.getInterfaces}, or of the class itself if it has none.
     *
     * @param instanceClass instanceClass
     * @return requestMappingName
     */
    public static String generateRequestMappingName(Class instanceClass) {
        String requestMappingName;
        Class[] classes = ReflectUtil.getInterfaces(instanceClass);
        if (classes.length > 0) {
            requestMappingName = '/' + StringUtil.firstLowerCase(classes[0].getSimpleName());
        } else {
            requestMappingName = '/' + StringUtil.firstLowerCase(instanceClass.getSimpleName());
        }
        return requestMappingName;
    }

    /**
     * @return a new, empty {@link RpcContext}
     */
    public static RpcContext newRpcContext() {
        return new RpcContext<>();
    }

    /**
     * Fill in and write the response for one invocation result (or one emitted chunk).
     * <p>
     * An {@link RpcEmitter} result is handed the request/response and writes on its own; a
     * {@link CompletableFuture} result re-enters this method when it completes. Otherwise the response is built
     * now: a {@code throwable} becomes a {@code SERVER_ERROR} on {@code lastResponse}, a {@code byte[]} result is
     * sent as {@code BINARY}, anything else is encoded with the handler's {@link DataCodec}. For
     * {@code WRITE_CHUNK} a new chunk packet with a fresh chunk id is used instead of {@code lastResponse}, and a
     * non-null {@code ackCallback} is registered under that chunk id.
     *
     * @param request        the request being answered
     * @param lastResponse   the final response packet of the call
     * @param rpcContext     context of the call; receives the result and/or throwable
     * @param channelHandler handler supplying the codec, the aop list and the channel to write to
     * @param rpcMethod      the invoked method
     * @param result         invocation result; a Throwable result is encoded as its {@code toString()}
     * @param throwable      invocation failure, or null
     * @param state          {@code WRITE_FINISH} for the last packet, {@code WRITE_CHUNK} for a chunk
     * @param ackCallback    for chunks: if non-null the chunk is sent with {@code ACK_YES} and this callback awaits the ack
     * @param rpcRunnable    executing runnable to mark done when the last packet is built; may be null
     * @param chunkIndex     chunk index passed to {@code RpcServerAop#onChunkAfter}
     * @param parentEmitter  emitter passed to {@code RpcServerAop#onChunkAfter}
     * @return true if the response will be written later (emitter or future result), false otherwise
     */
    static boolean buildAndWriteAndFlush(RequestPacket request, ResponseLastPacket lastResponse, RpcContext rpcContext, RpcServerChannelHandler channelHandler, RpcMethod rpcMethod, Object result, Throwable throwable, State state, ChunkAckCallback ackCallback, RpcRunnable rpcRunnable, int chunkIndex, RpcEmitter parentEmitter) {
        rpcContext.setResult(result);
        if (result instanceof Throwable) {
            result = result.toString();
        }

        ResponsePacket response;
        if (throwable != null) {
            rpcContext.setThrowable(throwable);
            response = lastResponse;
            response.setEncode(DataCodec.Encode.BINARY);
            response.setData(null);
            response.setStatus(SERVER_ERROR);
            response.setMessage(channelHandler.dataCodec.buildThrowableRpcMessage(throwable));
            logger.warn("invoke error = {}", throwable.toString(), throwable);
        } else if (result instanceof RpcEmitter) {
            RpcEmitter emitter = (RpcEmitter) result;
            emitter.usable(request, lastResponse, rpcContext, channelHandler, rpcMethod, rpcRunnable);
            return true;
        } else if (result instanceof CompletableFuture) {
            // NOTE(review): ackCallback is not forwarded here, so a chunk whose value is a CompletableFuture
            // is written with ACK_NO — confirm this is intended.
            ((CompletableFuture) result).whenComplete((result1, throwable1) -> buildAndWriteAndFlush(request, lastResponse, rpcContext, channelHandler, rpcMethod, result1, throwable1, state, null, rpcRunnable, chunkIndex, parentEmitter));
            return true;
        } else {
            if (state == RpcContext.RpcState.WRITE_CHUNK) {
                int chunkId = channelHandler.newChunkId();
                for (RpcServerAop aop : channelHandler.getAopList()) {
                    try {
                        aop.onChunkAfter(rpcContext, result, chunkIndex, chunkId, parentEmitter);
                    } catch (Exception e) {
                        rpcMethod.getLog().warn(rpcMethod + " server.aop.onChunkAfter() exception = {}", e.toString(), e);
                    }
                }
                response = ResponsePacket.newChunkPacket(request.getRequestId(), chunkId);
                if (ackCallback != null) {
                    response.setAck(ACK_YES);
                    channelHandler.rpcChunkAckCallbackMap.put(chunkId, ackCallback, ackCallback.timeout);
                } else {
                    response.setAck(ACK_NO);
                }
            } else {
                if (rpcRunnable != null) {
                    rpcRunnable.done = true;
                }
                response = lastResponse;
            }
            if (result instanceof byte[]) {
                response.setEncode(DataCodec.Encode.BINARY);
                response.setData((byte[]) result);
            } else {
                response.setEncode(DataCodec.Encode.APP);
                if (state == RpcContext.RpcState.WRITE_CHUNK) {
                    response.setData(channelHandler.dataCodec.encodeChunkResponseData(result));
                } else {
                    response.setData(channelHandler.dataCodec.encodeResponseData(result, rpcMethod));
                }
            }
            response.setStatus(OK);
            response.setMessage("ok");
        }
        channelHandler.writeAndFlush(request.getAck(), response, rpcContext, state);
        return false;
    }

    public List getAopList() {
        return nettyRpcServerAopList;
    }

    public DataCodec getDataCodec() {
        return dataCodec;
    }

    /**
     * @return the context captured in {@link #channelActive(ChannelHandlerContext)}; used for all writes
     */
    public ChannelHandlerContext getContext() {
        return context;
    }

    public Supplier getExecutorSupplier() {
        return executorSupplier;
    }

    /**
     * @param executorSupplier queried once per {@link #channelActive(ChannelHandlerContext)}; when it is null,
     *                         methods are invoked on the calling (channel) thread
     */
    public void setExecutorSupplier(Supplier executorSupplier) {
        this.executorSupplier = executorSupplier;
    }

    /**
     * Capture the channel context, notify {@code onConnectAfter} with a temporary context bound to
     * {@code CONTEXT_LOCAL}, and obtain the executor from the supplier (if any).
     */
    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        this.context = ctx;

        RpcContext rpcContext = newRpcContext();
        rpcContext.setRemoteAddress((InetSocketAddress) ctx.channel().remoteAddress());
        rpcContext.setLocalAddress((InetSocketAddress) ctx.channel().localAddress());
        CONTEXT_LOCAL.set(rpcContext);
        try {
            for (RpcServerAop aop : nettyRpcServerAopList) {
                aop.onConnectAfter(this);
            }
            if (executorSupplier != null) {
                this.executor = executorSupplier.get();
            }
        } finally {
            CONTEXT_LOCAL.remove();
            super.channelActive(ctx);
        }
    }

    /**
     * Notify {@code onDisconnectAfter} with a temporary context bound to {@code CONTEXT_LOCAL}.
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        RpcContext rpcContext = newRpcContext();
        rpcContext.setRemoteAddress((InetSocketAddress) ctx.channel().remoteAddress());
        rpcContext.setLocalAddress((InetSocketAddress) ctx.channel().localAddress());
        CONTEXT_LOCAL.set(rpcContext);
        try {
            for (RpcServerAop aop : nettyRpcServerAopList) {
                aop.onDisconnectAfter(this);
            }
        } finally {
            CONTEXT_LOCAL.remove();
        }
        super.channelInactive(ctx);
    }

    /**
     * Dispatch an inbound packet: requests are handled by {@link #handleRequestPacket}, chunk acks complete
     * their pending {@link ChunkAckCallback}. The packet (and request context) are recycled afterwards unless
     * the request is still being processed asynchronously.
     */
    @Override
    protected void onMessageReceived(ChannelHandlerContext ctx, RpcPacket packet) throws Exception {
        boolean async = false;
        RpcContext rpcContext = null;
        try {
            if (packet instanceof RequestPacket) {
                RequestPacket request = (RequestPacket) packet;
                rpcContext = newRpcContext();
                async = handleRequestPacket(rpcContext, request, ctx);
            } else if (packet instanceof ResponseChunkAckPacket) {
                ResponseChunkAckPacket response = (ResponseChunkAckPacket) packet;
                ChunkAckCallback callback = rpcChunkAckCallbackMap.remove(response.getAckChunkId());
                if (callback != null) {
                    callback.onAck(response);
                }
            }
        } finally {
            // recycle
            if (!async) {
                packet.recycle();
                if (rpcContext != null) {
                    rpcContext.recycle();
                }
            }
        }
    }

    /**
     * Handle one request: answer {@code NO_SUCH_SERVICE}/{@code NO_SUCH_METHOD} when the target is unknown,
     * otherwise invoke it on the executor (registering a timeout when positive) or inline.
     *
     * @return true if the request/context are still in use by asynchronous processing and must not be recycled
     */
    private boolean handleRequestPacket(RpcContext rpcContext, RequestPacket request, ChannelHandlerContext ctx) {
        final Executor threadPool = this.executor;
        boolean async = false;
        try {
            rpcContext.setRemoteAddress((InetSocketAddress) ctx.channel().remoteAddress());
            rpcContext.setLocalAddress((InetSocketAddress) ctx.channel().localAddress());
            rpcContext.setRequest(request);
            rpcContext.setRpcBeginTimestamp(System.currentTimeMillis());

            // not found instance
            String serverInstanceKey = RpcServerInstance.getServerInstanceKey(request.getRequestMappingName(), request.getVersion());
            RpcServerInstance rpcInstance = serviceInstanceMap.get(serverInstanceKey);
            if (rpcInstance == null) {
                if (request.getAck() == ACK_YES) {
                    ResponseLastPacket response = ResponsePacket.newLastPacket();
                    rpcContext.setResponse(response);
                    boolean release = true;
                    try {
                        response.setRequestId(request.getRequestId());
                        response.setEncode(BINARY);
                        response.setStatus(ResponsePacket.NO_SUCH_SERVICE);
                        response.setMessage("not found service " + serverInstanceKey);
                        ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
                        release = false;
                    } finally {
                        if (release) {
                            RecyclableUtil.release(response);
                        }
                    }
                }
            } else {
                RpcMethod rpcMethod = rpcInstance.getRpcMethod(request.getMethodName());
                rpcContext.setRpcMethod(rpcMethod);
                ResponseLastPacket response = ResponsePacket.newLastPacket();
                rpcContext.setResponse(response);
                response.setRequestId(request.getRequestId());
                // not found method
                if (rpcMethod == null) {
                    response.setEncode(DataCodec.Encode.BINARY);
                    response.setStatus(NO_SUCH_METHOD);
                    response.setMessage("not found method [" + request.getMethodName() + "]");
                    response.setData(null);
                    writeAndFlush(request.getAck(), response, rpcContext, RpcContext.RpcState.WRITE_FINISH);
                } else if (threadPool != null) {
                    // invoke method by async and call event
                    int timeout = choseTimeout(rpcInstance.getTimeout(), rpcMethod.getTimeout(), request.getTimeout());
                    rpcContext.setTimeout(timeout);
                    RpcRunnable runnable = new RpcRunnable(threadPool, rpcMethod, timeout, response, request, dataCodec, this, rpcContext);
                    if (timeout > 0) {
                        rpcServerMethodDoneMap.put(runnable, runnable, timeout);
                    }
                    // execute by rpc thread pool
                    threadPool.execute(runnable);
                    async = true;
                } else {
                    // invoke method by sync
                    CONTEXT_LOCAL.set(rpcContext);
                    Object result = null;
                    Throwable throwable = null;
                    try {
                        result = rpcInstance.invoke(rpcMethod, request, rpcContext, this);
                    } catch (Throwable t) {
                        throwable = t;
                    }
                    async = buildAndWriteAndFlush(request, response, rpcContext, this, rpcMethod, result, throwable, RpcContext.RpcState.WRITE_FINISH, null, null, -1, null);
                }
            }
        } finally {
            try {
                // call event
                if (!async) {
                    rpcContext.setRpcEndTimestamp(System.currentTimeMillis());
                    CONTEXT_LOCAL.set(rpcContext);
                    onResponseAfter(rpcContext);
                }
            } finally {
                // never leave this context bound to the channel thread: the caller recycles it right after
                CONTEXT_LOCAL.remove();
            }
        }
        return async;
    }

    /**
     * Next chunk id: 0, 1, ..., Integer.MAX_VALUE, then wraps to 0.
     * The update is a single atomic step, so concurrent callers never observe negative or duplicate ids
     * caused by a separate reset.
     */
    private int newChunkId() {
        return chunkIdIncr.getAndUpdate(id -> id == Integer.MAX_VALUE ? 0 : id + 1);
    }

    /**
     * @return the executor obtained in {@link #channelActive(ChannelHandlerContext)}, or null for inline invocation
     */
    public Executor getExecutor() {
        return executor;
    }

    /**
     * Choose the method timeout. The method-level server timeout wins over the service-level one; the chosen
     * server value is used unless it is 0 (or both are null), in which case the client timeout is used.
     * <p>
     * timeout is -1 then never timeout (only timeouts &gt; 0 are registered)
     * timeout is 0 then use client timeout
     * timeout other then use server timeout
     *
     * @param serverServiceTimeout serverServiceTimeout
     * @param serverMethodTimeout  serverMethodTimeout
     * @param clientTimeout        clientTimeout
     * @return method timeout
     */
    public int choseTimeout(Integer serverServiceTimeout, Integer serverMethodTimeout, int clientTimeout) {
        Integer serverTimeout = serverMethodTimeout != null ? serverMethodTimeout : serverServiceTimeout;
        if (serverTimeout == null || serverTimeout == 0) {
            return clientTimeout;
        }
        return serverTimeout;
    }

    private void onResponseAfter(RpcContext rpcContext) {
        for (RpcServerAop aop : nettyRpcServerAopList) {
            aop.onResponseAfter(rpcContext);
        }
    }

    /**
     * Write the response if the client asked for one ({@code ACK_YES}); otherwise release it immediately.
     * State updates ({@code rpcState}, plus {@code END} after {@code WRITE_FINISH}) fire after a successful
     * write, or immediately when no write happens; a failed write closes the channel.
     */
    private void writeAndFlush(int ack, ResponsePacket response, RpcContext rpcContext, State rpcState) {
        boolean release = true;
        try {
            if (ack == ACK_YES) {
                context.writeAndFlush(response)
                        .addListener((ChannelFutureListener) future -> {
                            if (future.isSuccess()) {
                                onStateUpdate(rpcContext, rpcState);
                                if (rpcState == RpcContext.RpcState.WRITE_FINISH) {
                                    onStateUpdate(rpcContext, RpcContext.RpcState.END);
                                }
                            } else {
                                future.channel().close();
                            }
                        });
                // ownership passed to the channel pipeline
                release = false;
            } else {
                onStateUpdate(rpcContext, rpcState);
                if (rpcState == RpcContext.RpcState.WRITE_FINISH) {
                    onStateUpdate(rpcContext, RpcContext.RpcState.END);
                }
            }
        } finally {
            if (release) {
                RecyclableUtil.release(response);
            }
        }
    }

    /**
     * Move the context to {@code toState} and notify every aop, unless the current state is already complete.
     */
    public void onStateUpdate(RpcContext rpcContext, State toState) {
        State fromState = rpcContext.getState();
        if (fromState != null && fromState.isComplete()) {
            return;
        }
        rpcContext.setState(toState);
        for (RpcServerAop aop : nettyRpcServerAopList) {
            aop.onStateUpdate(rpcContext, fromState, toState);
        }
    }

    /**
     * Register a RpcServerInstance. A null or empty mapping name is generated from the instance class; an
     * instance without its own codec gets this handler's codec. An existing registration under the same
     * name/version key is replaced (and logged).
     *
     * @param requestMappingName requestMappingName
     * @param version            rpc version
     * @param rpcServerInstance  RpcServerInstance
     */
    public void addRpcServerInstance(String requestMappingName, String version, RpcServerInstance rpcServerInstance) {
        Object instance = rpcServerInstance.getInstance();
        if (requestMappingName == null || requestMappingName.isEmpty()) {
            requestMappingName = generateRequestMappingName(instance.getClass());
        }
        String serverInstanceKey = RpcServerInstance.getServerInstanceKey(requestMappingName, version);
        if (rpcServerInstance.getDataCodec() == null) {
            rpcServerInstance.setDataCodec(dataCodec);
        }
        RpcServerInstance oldServerInstance = serviceInstanceMap.put(serverInstanceKey, rpcServerInstance);
        if (oldServerInstance != null) {
            Object oldInstance = oldServerInstance.getInstance();
            logger.warn("override instance old={}, new={}",
                    oldInstance.getClass().getSimpleName() + "@" + Integer.toHexString(oldInstance.hashCode()),
                    instance.getClass().getSimpleName() + "@" + Integer.toHexString(instance.hashCode()));
        }
        logger.trace("addInstance({}, {}, {})",
                serverInstanceKey,
                instance.getClass().getSimpleName(),
                rpcServerInstance.getMethodToParameterNamesFunction().getClass().getSimpleName());
    }

    /**
     * Register an instance under the name from its {@link NRpcService} annotation (generated if absent),
     * with method overwrite checking enabled.
     *
     * @param instance The implementation class
     */
    public void addInstance(Object instance) {
        addInstance(instance, getRequestMappingName(instance.getClass()), true);
    }

    /**
     * Register an instance, resolving parameter names from class files and method names from {@link NRpcMethod}.
     *
     * @param instance             The implementation class
     * @param requestMappingName   requestMappingName
     * @param methodOverwriteCheck methodOverwriteCheck
     */
    public void addInstance(Object instance, String requestMappingName, boolean methodOverwriteCheck) {
        String version = RpcServerInstance.getVersion(instance.getClass(), "");
        addInstance(instance, requestMappingName, version, new ClassFileMethodToParameterNamesFunction(), new AnnotationMethodToMethodNameFunction(NRpcMethod.class), methodOverwriteCheck);
    }

    /**
     * Register an instance.
     *
     * @param instance                       The implementation class
     * @param requestMappingName             requestMappingName
     * @param version                        version
     * @param methodToParameterNamesFunction Method to a function with a parameter name
     * @param methodToNameFunction           Method of extracting remote call method name
     * @param methodOverwriteCheck           methodOverwriteCheck
     */
    public void addInstance(Object instance, String requestMappingName, String version, Function methodToParameterNamesFunction, Function methodToNameFunction, boolean methodOverwriteCheck) {
        Integer timeout = RpcServerInstance.getTimeout(instance.getClass());
        RpcServerInstance rpcServerInstance = new RpcServerInstance(instance, dataCodec, version, timeout, methodToParameterNamesFunction, methodToNameFunction, methodOverwriteCheck);
        addRpcServerInstance(requestMappingName, version, rpcServerInstance);
    }

    /**
     * Is this exact object (compared by identity) registered.
     *
     * @param instance instance
     * @return boolean existInstance
     */
    public boolean existInstance(Object instance) {
        if (serviceInstanceMap.isEmpty()) {
            return false;
        }
        Collection values = serviceInstanceMap.values();
        for (RpcServerInstance rpcServerInstance : values) {
            if (rpcServerInstance.getInstance() == instance) {
                return true;
            }
        }
        return false;
    }

    /**
     * @return read-only view of the registered instances keyed by server instance key
     */
    public Map getServiceInstanceMap() {
        return Collections.unmodifiableMap(serviceInstanceMap);
    }

    /**
     * Future completed by the client's ack of a chunk: normally with the decoded ack data, exceptionally on
     * a non-OK ack status or on timeout.
     */
    public static class ChunkAckCallback extends CompletableFuture {
        final AtomicBoolean timeoutNotifyFlag = new AtomicBoolean();
        final long startTimestamp = System.currentTimeMillis();
        /**
         * Set by {@link #onAck} on the channel thread, read by the expiry consumer on another thread.
         */
        volatile boolean done = false;
        int timeout;
        Executor executor;
        Class type;
        RpcEmitter emitter;

        public void onTimeout() {
            if (done) {
                return;
            }
            long expiryTimestamp = System.currentTimeMillis();
            completeExceptionally(new RpcTimeoutException("RpcRequestTimeout : maxTimeout = [" + timeout +
                    "], timeout = [" + (expiryTimestamp - startTimestamp) + "], [" + toString() + "]", true,
                    startTimestamp, expiryTimestamp));
        }

        public void onAck(ResponseChunkAckPacket packet) {
            done = true;
            Integer status = packet.getStatus();
            if (status == null || status != OK) {
                completeExceptionally(new RpcResponseException(status, "Failure rpc response. status=" + status + ",message=" + packet.getMessage() + ",response=" + packet, true));
            } else {
                RpcServerInstance instance = (RpcServerInstance) emitter.rpcMethod.getInstance();
                Object data = instance.getDataCodec().decodeChunkResponseData(packet.getData(), emitter.rpcMethod);
                complete(cast(data));
            }
        }

        public ACK_TYPE cast(Object data) {
            return TypeUtil.cast(data, type);
        }
    }

    /**
     * Executes one method invocation on the rpc executor and writes its response.
     */
    public static class RpcRunnable implements Runnable {
        final AtomicBoolean timeoutNotifyFlag = new AtomicBoolean();
        RpcMethod rpcMethod;
        RpcServerChannelHandler channelHandler;
        RequestPacket request;
        ResponseLastPacket response;
        DataCodec dataCodec;
        RpcContext rpcContext;
        int interruptCount = 0;
        /**
         * Thread running this task; null until {@link #run()} starts. Read by the expiry consumer thread.
         */
        volatile Thread taskThread;
        /**
         * Set once the invocation returned; read by the expiry consumer thread.
         */
        volatile boolean done = false;
        boolean timeoutInterrupt;
        int timeout;
        Executor executor;

        RpcRunnable(Executor executor, RpcMethod rpcMethod,
                    int timeout,
                    ResponseLastPacket response, RequestPacket request,
                    DataCodec dataCodec,
                    RpcServerChannelHandler channelHandler, RpcContext rpcContext) {
            this.executor = executor;
            this.rpcMethod = rpcMethod;
            this.timeout = timeout;
            this.response = response;
            this.timeoutInterrupt = rpcMethod.isTimeoutInterrupt();
            this.channelHandler = channelHandler;
            this.dataCodec = dataCodec;
            this.request = request;
            this.rpcContext = rpcContext;
        }

        public void onTimeout() {
            if (done) {
                return;
            }
            channelHandler.onStateUpdate(rpcContext, RpcContext.RpcState.TIMEOUT);
            for (RpcServerAop aop : channelHandler.nettyRpcServerAopList) {
                aop.onTimeout(rpcContext);
            }
        }

        // identity semantics on purpose: the runnable is its own key in rpcServerMethodDoneMap
        @Override
        public int hashCode() {
            return super.hashCode();
        }

        @Override
        public boolean equals(Object obj) {
            return super.equals(obj);
        }

        @Override
        public void run() {
            taskThread = Thread.currentThread();
            CONTEXT_LOCAL.set(rpcContext);
            try {
                Object result = null;
                Throwable throwable = null;
                try {
                    result = rpcMethod.getInstance().invoke(rpcMethod, request, rpcContext, channelHandler);
                } catch (Throwable t) {
                    throwable = t;
                }
                done = true;
                // finished: no further timeout handling needed, free the slot
                channelHandler.rpcServerMethodDoneMap.remove(this);
                boolean async = buildAndWriteAndFlush(request, response, rpcContext, channelHandler, rpcMethod, result, throwable, RpcContext.RpcState.WRITE_FINISH, null, this, -1, null);
                rpcContext.setRpcEndTimestamp(System.currentTimeMillis());
                try {
                    channelHandler.onResponseAfter(rpcContext);
                } finally {
                    // a pending CompletableFuture callback or RpcEmitter still reads the request
                    // (request id / ack flag), so it must not be recycled yet
                    if (!async) {
                        request.recycle();
                    }
                }
            } finally {
                // always unbind, even if building/writing the response failed, so pool threads stay clean
                CONTEXT_LOCAL.remove();
            }
        }
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy