All downloads are free. Search and download functionality uses the official Maven repository.

net.dongliu.prettypb.rpc.server.RpcServerHandler Maven / Gradle / Ivy

There is a newer version: 0.3.5
Show newest version
/**
 *   Copyright 2010-2014 Peter Klauser
 *
 *   Licensed under the Apache License, Version 2.0 (the "License");
 *   you may not use this file except in compliance with the License.
 *   You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 *   Unless required by applicable law or agreed to in writing, software
 *   distributed under the License is distributed on an "AS IS" BASIS,
 *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *   See the License for the specific language governing permissions and
 *   limitations under the License.
 */
package net.dongliu.prettypb.rpc.server;

import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToMessageDecoder;
import net.dongliu.prettypb.rpc.protocol.RpcCancel;
import net.dongliu.prettypb.rpc.protocol.RpcRequest;
import net.dongliu.prettypb.rpc.protocol.WirePayload;
import net.dongliu.prettypb.rpc.info.MethodInfo;
import net.dongliu.prettypb.rpc.info.ServiceInfo;
import net.dongliu.prettypb.runtime.ExtensionRegistry;
import net.dongliu.prettypb.runtime.ProtobufDeSerializer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ThreadPoolExecutor;

/**
 * Netty pipeline handler for the server side of an RPC channel.
 *
 * <p>Incoming {@link RpcRequest} payloads are resolved against the {@link RpcServiceRegistry},
 * deserialized and submitted to the RPC service executor as {@code ServerCallTask}s.
 * {@link RpcCancel} payloads mark the matching pending task as cancelled. Every other
 * {@link WirePayload} is passed on unchanged to the next handler in the pipeline.
 *
 * @author Peter Klauser
 */
public class RpcServerHandler extends MessageToMessageDecoder<WirePayload> {

    private static final Logger logger = LoggerFactory.getLogger(RpcServerHandler.class);
    private final RpcServerChannel rpcServerChannel;
    private final RpcServiceRegistry rpcServiceRegistry;
    private final ThreadPoolExecutor rpcServiceExecutor;

    private final RpcServerChannelRegistry rpcServerChannelRegistry;
    /**
     * extension registry used when deserializing rpc request messages
     */
    private final ExtensionRegistry extensionRegistry;

    /**
     * Submitted calls that have not yet been removed, keyed by correlation id. The map is also
     * handed to every ServerCallTask (which presumably removes its own entry when it finishes),
     * so it is accessed from executor threads as well as this handler.
     */
    private final Map<Integer, ServerCallTask> serverCallTaskMap = new ConcurrentHashMap<>();

    public RpcServerHandler(RpcServerChannel rpcServerChannel,
                            RpcServiceRegistry rpcServiceRegistry,
                            ThreadPoolExecutor rpcServiceExecutor,
                            RpcServerChannelRegistry rpcServerChannelRegistry,
                            ExtensionRegistry extensionRegistry) {
        this.rpcServerChannelRegistry = rpcServerChannelRegistry;
        this.rpcServerChannel = rpcServerChannel;
        this.rpcServiceRegistry = rpcServiceRegistry;
        this.rpcServiceExecutor = rpcServiceExecutor;
        this.extensionRegistry = extensionRegistry;
    }

    /**
     * Dispatches rpc requests and cancellations. Every other payload is forwarded unchanged.
     */
    @Override
    protected void decode(ChannelHandlerContext ctx, WirePayload msg,
                          List<Object> out) throws Exception {
        if (msg.hasRpcRequest()) {
            onRequest(msg.getRpcRequest());
        } else if (msg.hasRpcCancel()) {
            cancel(msg.getRpcCancel());
        } else {
            // serverMessage, unsolicitedMessage, rpcResponse and rpcError are consumed further along
            // the pipeline by RpcClientHandler. Anything else is passed through to any later
            // channel handlers added by the using code.
            out.add(msg);
        }
    }

    /**
     * Unregisters this server channel and cancels all pending calls once the channel closes.
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        super.channelInactive(ctx);
        rpcServerChannelRegistry.removeRpcServerChannel(rpcServerChannel);
        handleClosure();
    }

    /**
     * Logs the failure, closes the channel and cancels all pending calls.
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
            throws Exception {
        logger.warn("Exception caught during RPC operation.", cause);
        ctx.close();
        handleClosure();
    }

    /**
     * @return the rpcServerChannelRegistry passed to the constructor
     */
    public RpcServerChannelRegistry getRpcServerChannelRegistry() {
        return rpcServerChannelRegistry;
    }

    /**
     * Processes an rpc request. It resolves the target service and method, deserializes the
     * request message and submits a {@code ServerCallTask} to the service executor. If the
     * service or method is unknown, or the request bytes cannot be parsed, an rpc error is sent
     * back on the channel instead.
     *
     * @param rpcRequest the incoming request
     * @throws IllegalStateException if a call with the same correlation id is still pending
     */
    public void onRequest(RpcRequest rpcRequest) {
        long start = System.currentTimeMillis();
        int correlationId = rpcRequest.getCorrelationId();

        if (serverCallTaskMap.containsKey(correlationId)) {
            throw new IllegalStateException("correlationId " + correlationId
                    + " already registered as PendingServerCall.");
        }

        ServiceInfo serviceInfo = lookupService(rpcRequest);
        if (serviceInfo == null) {
            // service not found
            String errorMessage = "Unknown Service: " + rpcRequest.getServiceIdentifier();
            logger.error("service not found: {}", rpcRequest.getServiceIdentifier());
            rpcServerChannel.sendRpcError(correlationId, errorMessage);
            return;
        }

        MethodInfo methodInfo = serviceInfo.getMethodInfo(rpcRequest.getMethodIdentifier());
        if (methodInfo == null) {
            String errorMessage = "Unknown Method: " + rpcRequest.getMethodIdentifier();
            logger.error("rpc method not found: {}", rpcRequest.getMethodIdentifier());
            rpcServerChannel.sendRpcError(correlationId, errorMessage);
            return;
        }

        Object request;
        byte[] requestData = rpcRequest.getRequestBytes();
        try {
            request = ProtobufDeSerializer.fromBytes(methodInfo.getRequestType(), requestData,
                    extensionRegistry);
        } catch (RuntimeException e) {
            String errorMessage = "Invalid request protobuf:" + e.getMessage();
            logger.error("invalid request protobuf", e);
            rpcServerChannel.sendRpcError(correlationId, errorMessage);
            return;
        }

        ServerCallTask task = new ServerCallTask(serviceInfo, methodInfo, request, start,
                rpcRequest.getTimeoutMs(), correlationId, rpcServerChannel, serverCallTaskMap);
        // Register before submitting. The task is handed the map and presumably removes its own
        // entry when done, so registering after submit could leave a stale entry behind.
        serverCallTaskMap.put(correlationId, task);
        submit(task);
    }

    /**
     * Looks up the service targeted by a request.
     *
     * <p>Requests carrying a service package are resolved by package and service name. The
     * original proto-rpc-pro does not consider the proto file's package, so requests without
     * one fall back to resolution by service name alone.
     *
     * @param rpcRequest the incoming request
     * @return the service info, or {@code null} when the registry has no match
     */
    private ServiceInfo lookupService(RpcRequest rpcRequest) {
        if (rpcRequest.hasServicePackage()) {
            return rpcServiceRegistry.resolveService(rpcRequest.getServicePackage(),
                    rpcRequest.getServiceIdentifier());
        }
        return rpcServiceRegistry.simpleResolveService(rpcRequest.getServiceIdentifier());
    }

    /**
     * Submits a task to the rpc service executor. If the executor rejects it, the task is
     * unregistered and an rpc error is sent back to the client.
     *
     * @param task the already-registered call to run
     */
    private void submit(final ServerCallTask task) {
        try {
            rpcServiceExecutor.submit(task);
        } catch (RejectedExecutionException e) {
            // Unregister first so the map stays consistent even if sending the error fails.
            serverCallTaskMap.remove(task.getCorrelationId());
            logger.warn("rpc executor rejected request({})", task.getCorrelationId());
            rpcServerChannel.sendRpcError(task.getCorrelationId(),
                    "too many request to service: " + e.getMessage());
        }
    }

    /**
     * Handles a cancellation from the client. It removes the pending call with the matching
     * correlation id and marks it as cancelled. It does nothing if no executor is configured or
     * no call with that id is pending.
     *
     * @param rpcCancel the cancellation message
     */
    public void cancel(RpcCancel rpcCancel) {
        int correlationId = rpcCancel.getCorrelationId();

        if (rpcServiceExecutor == null) {
            return;
        }
        ServerCallTask task = serverCallTaskMap.remove(correlationId);
        if (task != null) {
            cancelFromExecutor(task);
        }
    }

    /**
     * Marks a task as cancelled. Whether a running task actually stops is up to ServerCallTask.
     */
    private void cancelFromExecutor(ServerCallTask task) {
        task.setCancel(true);
    }

    /**
     * Cancels all pending server calls. This is invoked when the channel becomes inactive or
     * after an exception.
     */
    public void handleClosure() {
        logger.debug("rpc channel closed.");
        // ConcurrentHashMap iteration is weakly consistent, so repeat until the map is drained.
        do {
            for (Integer correlationId : serverCallTaskMap.keySet()) {
                ServerCallTask task = serverCallTaskMap.remove(correlationId);
                if (task != null) {
                    cancelFromExecutor(task);
                    logger.debug("request({}) cancel on close.", correlationId);
                }
            }
        } while (!serverCallTaskMap.isEmpty());
    }
}