// All downloads are free. Search and download functionality uses the official Maven repository.

// net.dongliu.prettypb.rpc.server.RpcServerHandler Maven / Gradle / Ivy

// There is a newer version: 0.3.5
// Show newest version
/**
 *   Copyright 2010-2014 Peter Klauser
 *
 *   Licensed under the Apache License, Version 2.0 (the "License");
 *   you may not use this file except in compliance with the License.
 *   You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 *   Unless required by applicable law or agreed to in writing, software
 *   distributed under the License is distributed on an "AS IS" BASIS,
 *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *   See the License for the specific language governing permissions and
 *   limitations under the License.
 */
package net.dongliu.prettypb.rpc.server;

import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToMessageDecoder;
import net.dongliu.prettypb.rpc.common.MethodInfo;
import net.dongliu.prettypb.rpc.common.ServiceInfo;
import net.dongliu.prettypb.rpc.common.TaskCallBack;
import net.dongliu.prettypb.rpc.common.TaskSet;
import net.dongliu.prettypb.rpc.exception.ServiceException;
import net.dongliu.prettypb.rpc.protocol.RpcCancel;
import net.dongliu.prettypb.rpc.protocol.RpcRequest;
import net.dongliu.prettypb.rpc.protocol.WirePayload;
import net.dongliu.prettypb.runtime.ExtensionRegistry;
import net.dongliu.prettypb.runtime.ProtoBufDecoder;
import net.dongliu.prettypb.runtime.include.RpcCallback;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.reflect.InvocationTargetException;
import java.util.List;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ThreadPoolExecutor;

/**
 * A pipeline handler which handles incoming RpcRequest and RpcCancel payloads towards
 *
 * @author Peter Klauser
 */
public class RpcServerHandler extends MessageToMessageDecoder {

    private static Logger logger = LoggerFactory.getLogger(RpcServerHandler.class);
    private final RpcServerChannel rpcServerChannel;
    private final RpcServiceRegistry rpcServiceRegistry;
    private final ThreadPoolExecutor rpcServiceExecutor;

    private final RpcServerChannelRegistry rpcServerChannelRegistry;
    /**
     * extension registry for rpc messages
     */
    private final ExtensionRegistry extensionRegistry;

    private final TaskSet taskSet = new TaskSet<>();

    public RpcServerHandler(RpcServerChannel rpcServerChannel,
                            RpcServiceRegistry rpcServiceRegistry,
                            ThreadPoolExecutor rpcServiceExecutor,
                            RpcServerChannelRegistry rpcServerChannelRegistry,
                            ExtensionRegistry extensionRegistry) {
        this.rpcServerChannelRegistry = rpcServerChannelRegistry;
        this.rpcServerChannel = rpcServerChannel;
        this.rpcServiceRegistry = rpcServiceRegistry;
        this.rpcServiceExecutor = rpcServiceExecutor;
        this.extensionRegistry = extensionRegistry;
    }

    @Override
    protected void decode(ChannelHandlerContext ctx, WirePayload msg,
                          List out) throws Exception {
        if (msg.hasRpcRequest()) {
            onRequest(msg.getRpcRequest());
        } else if (msg.hasRpcCancel()) {
            cancel(msg.getRpcCancel());
        } else {
            // serverMessage, unsolicitedMessage, rpcResponse, rpcError were consumed further down by RpcClientHandler.
            // everything else is passed through to potentially later channel handlers which are modified by using code.
            out.add(msg);
        }
    }

    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        super.channelInactive(ctx);
        rpcServerChannelRegistry.removeRpcServerChannel(rpcServerChannel);
        onClose();
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
            throws Exception {
        logger.warn("Exception caught during RPC operation.", cause);
        ctx.close();
        onClose();
    }

    /**
     * @return the rpcClientRegistry
     */
    public RpcServerChannelRegistry getRpcServerChannelRegistry() {
        return rpcServerChannelRegistry;
    }

    /**
     * process rpc request
     *
     * @param rpcRequest
     */
    public void onRequest(RpcRequest rpcRequest) {
        long start = System.currentTimeMillis();
        int correlationId = rpcRequest.getCorrelationId();

        ServiceInfo serviceInfo;
        if (rpcRequest.hasServicePackage()) {
            serviceInfo = rpcServiceRegistry.simpleResolveService(rpcRequest.getServiceIdentifier());
        } else {
            // proto-rpc-pro has a issue do not consider the package of proto file. we add it here
            // and keep compatible with the origin proto-rpc-pro
            serviceInfo = rpcServiceRegistry.resolveService(rpcRequest.getServicePackage(),
                    rpcRequest.getMethodIdentifier());
        }

        if (serviceInfo == null) {
            // service not found
            String errorMessage = "Unknown Service: " + rpcRequest.getServiceIdentifier();
            logger.error("service not found: {}", rpcRequest.getServiceIdentifier());
            rpcServerChannel.sendRpcError(correlationId, errorMessage);
            return;
        }

        MethodInfo methodInfo = serviceInfo.getMethodInfo(rpcRequest.getMethodIdentifier());

        if (methodInfo == null) {
            String errorMessage = "Unknown Method: " + rpcRequest.getMethodIdentifier();
            logger.error("rpc method not found: {}", rpcRequest.getMethodIdentifier());
            rpcServerChannel.sendRpcError(correlationId, errorMessage);
            return;
        }
        Object request;
        byte[] requestData = rpcRequest.getRequestBytes();
        try {
            request = ProtoBufDecoder.fromBytes(methodInfo.getRequestType(), requestData,
                    extensionRegistry);

        } catch (RuntimeException e) {
            String errorMessage = "Invalid request protobuf:" + e.getMessage();
            logger.error("invalid request protobuf", e);
            rpcServerChannel.sendRpcError(correlationId, errorMessage);
            return;
        }

        ServerCallTask task = new ServerCallTask(methodInfo, request, start,
                rpcRequest.getTimeoutMs(), correlationId);
        // note this onRequest
        if (!taskSet.add(task)) {
            throw new ServiceException("add task failed");
        }
        // submit to execute
        submit(task);
    }

    /**
     * submit a task
     *
     * @param task
     */
    private void submit(final ServerCallTask task) {
        try {
            rpcServiceExecutor.submit(this.new TaskRunnable(task));
        } catch (RejectedExecutionException e) {
            taskSet.consume(task.id(), new TaskCallBack() {
                @Override
                public void onTask(ServerCallTask task) {
                    rpcServerChannel.sendRpcError(task.id(), "too many request to service");
                }
            });
        }
    }

    /**
     * On cancelFromExecutor from the client, the RpcServer does not expect to receive a
     * callback anymore from the RpcServerCallExecutor.
     *
     * @param rpcCancel
     */
    public void cancel(RpcCancel rpcCancel) {
        taskSet.consume(rpcCancel.getCorrelationId(), new TaskCallBack() {
            @Override
            public void onTask(ServerCallTask task) {
                cancelFromExecutor(task);
            }
        });
    }

    private void cancelFromExecutor(ServerCallTask task) {
        // cancel task just by set canceled to true
        task.cancel();
    }


    /**
     * Cancel any pending server calls due to closure of the RpcClient.
     */
    public void onClose() {
        logger.debug("rpc channel closed.");
        taskSet.close();
    }

    private class TaskRunnable implements Runnable {
        private final ServerCallTask task;

        TaskRunnable(ServerCallTask task) {
            this.task = task;
        }

        @Override
        public void run() {
            if (task.canceled()) {
                //should already be removed
                taskSet.consume(task.id());
                return;
            }
            if (task.timeout()) {
                sendError("Rpc call time out");
            }

            final MethodInfo methodInfo = task.getMethodInfo();
            final ServiceInfo serviceInfo = methodInfo.getServiceInfo();
            if (serviceInfo.isAsync()) {
                RpcCallback rpcCallback = new RpcCallback() {
                    @Override
                    public void onFinished(Object value) {
                        sendResponse(value, methodInfo.getResponseType());
                    }

                    @Override
                    public void onError(Exception e) {
                        sendError(e.getMessage());
                    }
                };
                try {
                    methodInfo.getImplMethod().invoke(serviceInfo.getImpl(), task.getRequest(),
                            rpcCallback);
                } catch (IllegalAccessException | InvocationTargetException e) {
                    sendError(e.getMessage());
                }
            } else {
                Object result;
                try {
                    result = methodInfo.getImplMethod().invoke(serviceInfo.getImpl(),
                            task.getRequest());
                } catch (IllegalAccessException | InvocationTargetException e) {
                    sendError(e.getMessage());
                    return;
                }
                sendResponse(result, methodInfo.getResponseType());
            }
        }

        private void sendError(final String error) {
            taskSet.consume(task.id(), new TaskCallBack() {
                @Override
                public void onTask(ServerCallTask task) throws Exception {
                    rpcServerChannel.sendRpcError(task.id(), error);
                }
            });
        }

        private void sendResponse(final Object result, final Class responseType) {
            taskSet.consume(task.id(), new TaskCallBack() {
                @Override
                public void onTask(ServerCallTask task) throws Exception {
                    rpcServerChannel.sendRpcResponse(task.id(), result, responseType);
                }
            });
        }
    }
}