All downloads are free. Search and download functionality uses the official Maven repository.

net.dongliu.prettypb.rpc.client.RpcClientChannel Maven / Gradle / Ivy

There is a newer version: 0.3.5
Show newest version
package net.dongliu.prettypb.rpc.client;

import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelPipeline;
import io.netty.handler.codec.compression.ZlibCodecFactory;
import io.netty.handler.codec.compression.ZlibWrapper;
import net.dongliu.prettypb.rpc.listener.TcpConnectionEventListener;
import net.dongliu.prettypb.rpc.protocol.*;
import net.dongliu.prettypb.rpc.info.PeerInfo;
import net.dongliu.prettypb.rpc.utils.Handlers;
import net.dongliu.prettypb.runtime.ExtensionRegistry;
import net.dongliu.prettypb.runtime.ProtobufSerializer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeoutException;

/**
 * Client side of one RPC connection to a server.
 *
 * An RpcClientChannel is constructed once when a connection is established
 * with a server ( on both client and server ). The same RpcClientChannel
 * instance is used for the duration of the connection.
 *
 * If a client reconnects, a new RpcClientChannel instance is constructed.
 *
 * Typical use: construct, call {@link #connect()} once, issue requests through
 * {@link #registerPendingRequest(int, ClientCallTask)} / {@link #writeRequest(RpcRequest)},
 * and finally {@link #close()}.
 *
 * @author dongliu
 */
public class RpcClientChannel implements AutoCloseable {

    private static final Logger logger = LoggerFactory.getLogger(RpcClientChannel.class);

    /**
     * pause, in milliseconds, between two scans of the pending requests for timed out calls
     */
    private static final long TIMEOUT_CHECK_INTERVAL_MILLIS = 100;

    /**
     * calls that have been registered and not yet removed, keyed by their sequence id.
     * Declared with the concrete type so the atomic putIfAbsent is available on every
     * Java version this file can be compiled with.
     */
    private final ConcurrentHashMap<Integer, ClientCallTask> pendingRequestMap =
            new ConcurrentHashMap<>();

    // below fields are updated when connect() has been called
    /**
     * client peer info
     */
    private volatile PeerInfo clientPeer;

    /**
     * server peer info. available after connect
     */
    private volatile PeerInfo serverPeer;

    /**
     * determined by if client want to use compress and if server support compress
     */
    private volatile boolean useCompress;

    /**
     * the socket channel this rpc channel hold. Volatile because it is assigned in
     * connect() and read by whichever threads write requests or close the channel.
     */
    private volatile Channel channel;

    private final String host;
    private final int port;
    private final Bootstrap bootstrap;
    private final int correlationId;
    private final boolean clientUseCompress;
    private final int connectTimeout;
    private final List<TcpConnectionEventListener> listeners;
    private final ExtensionRegistry extensionRegistry;

    /**
     * Stores the connection settings; no network activity happens until {@link #connect()}.
     *
     * @param host              server host to connect to
     * @param port              server port to connect to
     * @param bootstrap         netty bootstrap used to open the connection
     * @param correlationId     id sent in the connect request and expected back in the response
     * @param clientUseCompress whether the client asks for compression
     * @param connectTimeout    timeout handed to the connect response handler
     *                          (unit is defined by that handler)
     * @param listeners         listeners informed when the connection is opened or closed
     * @param extensionRegistry registry handed to the rpc client handler
     */
    public RpcClientChannel(String host, int port, Bootstrap bootstrap, int correlationId,
                            boolean clientUseCompress, int connectTimeout,
                            List<TcpConnectionEventListener> listeners,
                            ExtensionRegistry extensionRegistry) {
        this.host = host;
        this.port = port;
        this.bootstrap = bootstrap;
        this.correlationId = correlationId;
        this.clientUseCompress = clientUseCompress;
        this.connectTimeout = connectTimeout;
        this.listeners = listeners;
        this.extensionRegistry = extensionRegistry;
    }

    /**
     * Connect to the server, perform the connect handshake and set up the pipeline.
     * If anything fails after the socket has been opened, the socket is closed again
     * before the exception is propagated.
     *
     * @throws IOException           if the tcp connect fails, or the connect response carries
     *                               an error code or a missing / mismatching correlation id
     * @throws TimeoutException      if no connect response was obtained
     * @throws InterruptedException  presumably if interrupted while waiting for the connect
     *                               response (thrown by the connect response handler)
     * @throws IllegalStateException if the pipeline has no connect response handler
     */
    public void connect() throws IOException, TimeoutException, InterruptedException {
        // Make a new connection.
        InetSocketAddress remoteAddress = new InetSocketAddress(host, port);
        ChannelFuture connectFuture = bootstrap.connect(remoteAddress).awaitUninterruptibly();

        if (!connectFuture.isSuccess()) {
            throw new IOException("Failed to connect to " + remoteAddress, connectFuture.cause());
        }

        final Channel ch = connectFuture.channel();
        channel = ch;

        boolean established = false;
        try {
            InetSocketAddress localAddress = (InetSocketAddress) ch.localAddress();
            PeerInfo client = new PeerInfo(localAddress.getHostName(), localAddress.getPort());
            clientPeer = client;
            sendConnectMsg(ch, client, correlationId, clientUseCompress);

            // connect response
            ClientConnectResponseHandler connectResponseHandler = (ClientConnectResponseHandler)
                    ch.pipeline().get(Handlers.CLIENT_CONNECT);
            if (connectResponseHandler == null) {
                throw new IllegalStateException("No connectResponse handler in channel pipeline.");
            }

            ConnectResponse connectResponse = connectResponseHandler
                    .blockGetConnectResponse(connectTimeout);
            checkConnectResponse(correlationId, connectResponse);

            if (connectResponse.hasServerPID()) {
                this.serverPeer = new PeerInfo(remoteAddress.getHostName(),
                        remoteAddress.getPort(), connectResponse.getServerPID());
            } else {
                this.serverPeer = new PeerInfo(remoteAddress.getHostName(),
                        remoteAddress.getPort());
            }

            boolean compress = connectResponse.isCompress();
            this.useCompress = compress;
            logger.info("connect to {}:{}, use compress: {}", host, port, compress);

            RpcClientHandler rpcClientHandler = completePipeline(compress, ch.pipeline());
            rpcClientHandler.notifyOpened();

            // start timeout checker
            startCheckTimeout(ch);
            established = true;
        } finally {
            if (!established) {
                // do not leak the socket when the handshake could not be completed
                ch.close().awaitUninterruptibly();
            }
        }
    }

    /**
     * Validate the server's answer to the connect request. Closing the channel on failure
     * is left to the caller.
     *
     * @param correlationId   the correlation id that was sent in the connect request
     * @param connectResponse the response received, or null if none was obtained
     * @throws TimeoutException if connectResponse is null
     * @throws IOException      if the response has an error code, has no correlation id,
     *                          or its correlation id differs from the one sent
     */
    private void checkConnectResponse(int correlationId, ConnectResponse connectResponse)
            throws IOException, TimeoutException {
        if (connectResponse == null) {
            throw new TimeoutException("connect to rpc server error, response is null");
        }
        if (connectResponse.hasErrorCode()) {
            throw new IOException("connect response error: " + connectResponse.getErrorCode());
        }
        if (!connectResponse.hasCorrelationId()) {
            throw new IOException("connect response missing correlationId.");
        }
        if (connectResponse.getCorrelationId() != correlationId) {
            throw new IOException(
                    "DuplexTcpServer CONNECT_RESPONSE correlationId mismatch. TcpClient sent "
                            + correlationId + " received "
                            + connectResponse.getCorrelationId() + " from TcpServer.");
        }
    }

    /**
     * Build the connect request from the client info and write it to the channel.
     *
     * @return the request that was written
     */
    private ConnectRequest sendConnectMsg(Channel channel, PeerInfo clientInfo,
                                          int correlationId, boolean useCompress) {
        ConnectRequest connectRequest = new ConnectRequest();
        connectRequest.setClientHostName(clientInfo.getHostName());
        connectRequest.setClientPort(clientInfo.getPort());
        connectRequest.setClientPID(clientInfo.getPid());
        connectRequest.setCorrelationId(correlationId);
        connectRequest.setCompress(useCompress);

        WirePayload payload = new WirePayload();
        payload.setConnectRequest(connectRequest);
        channel.writeAndFlush(payload);
        return connectRequest;
    }

    /**
     * After RPC handshake has taken place, replace the RPC handshake
     * {@link ClientConnectResponseHandler} with a {@link RpcClientHandler}
     * to complete the Netty client side pipeline. When compression was agreed on,
     * a gzip encoder and decoder are inserted in front of the frame decoder first.
     *
     * @param compress whether compression handlers have to be installed
     * @param p        the pipeline of the connected channel
     * @return the handler that was installed
     */
    private RpcClientHandler completePipeline(boolean compress, ChannelPipeline p) {
        if (compress) {
            p.addBefore(Handlers.FRAME_DECODER, Handlers.COMPRESSOR,
                    ZlibCodecFactory.newZlibEncoder(ZlibWrapper.GZIP));
            p.addAfter(Handlers.COMPRESSOR, Handlers.DECOMPRESSOR,
                    ZlibCodecFactory.newZlibDecoder(ZlibWrapper.GZIP));
        }

        // fans the open / close events out to every registered listener
        TcpConnectionEventListener informer = new TcpConnectionEventListener() {
            @Override
            public void connectionClosed(RpcClientChannel rpcClientChannel) {
                for (TcpConnectionEventListener listener : listeners) {
                    listener.connectionClosed(rpcClientChannel);
                }
            }

            @Override
            public void connectionOpened(RpcClientChannel rpcClientChannel) {
                for (TcpConnectionEventListener listener : listeners) {
                    listener.connectionOpened(rpcClientChannel);
                }
            }
        };
        RpcClientHandler rpcClientHandler = new RpcClientHandler(informer, pendingRequestMap, this,
                extensionRegistry);
        p.replace(Handlers.CLIENT_CONNECT, Handlers.RPC_CLIENT, rpcClientHandler);

        return rpcClientHandler;
    }

    /**
     * Remember a call so its response (or timeout) can be matched later.
     *
     * @param seqId sequence id of the call
     * @param state the call task to register
     * @throws IllegalArgumentException if a task is already registered for seqId
     */
    public void registerPendingRequest(int seqId, ClientCallTask state) {
        // single atomic step: a separate containsKey + put lets two threads both "win"
        if (pendingRequestMap.putIfAbsent(seqId, state) != null) {
            throw new IllegalArgumentException("State already registered");
        }
    }

    /**
     * Wrap the request into a wire payload and write it to the server.
     *
     * @throws IllegalStateException if {@link #connect()} has not opened a channel yet
     */
    public void writeRequest(RpcRequest rpcRequest) {
        WirePayload payload = new WirePayload();
        payload.setRpcRequest(rpcRequest);
        requireChannel().writeAndFlush(payload);
    }

    /**
     * Serialize the given message and send it as an out-of-band message.
     *
     * @param message the message bean to send
     * @return the future of the write
     * @throws IllegalStateException if {@link #connect()} has not opened a channel yet
     */
    @SuppressWarnings({"unchecked", "rawtypes"}) // class token is taken from the bean itself
    public ChannelFuture sendOobMessage(Object message) {
        OobMessage msg = new OobMessage();
        // serialize the caller's message, not the still empty wrapper
        msg.setMessageBytes(ProtobufSerializer.toBytes(message, (Class) message.getClass()));
        WirePayload payload = new WirePayload();
        payload.setOobMessage(msg);
        return requireChannel().writeAndFlush(payload);
    }

    /**
     * For use by ServerRpcController to send a server out-of-band message
     * back to the client.
     *
     * @param correlationId correlation id put into the response
     * @param oobMessage    the message bean to serialize and send
     * @return the future of the write
     * @throws IllegalStateException if {@link #connect()} has not opened a channel yet
     */
    @SuppressWarnings({"unchecked", "rawtypes"}) // class token is taken from the bean itself
    public ChannelFuture sendOobResponse(int correlationId, Object oobMessage) {
        OobResponse msg = new OobResponse();
        msg.setCorrelationId(correlationId);
        msg.setMessageBytes(ProtobufSerializer.toBytes(oobMessage,
                (Class) oobMessage.getClass()));

        WirePayload payload = new WirePayload();
        payload.setOobResponse(msg);

        return requireChannel().writeAndFlush(payload);
    }

    public PeerInfo getServerPeer() {
        return serverPeer;
    }

    /**
     * Close the underlying socket channel and wait until it is closed.
     * Does nothing when no channel was ever opened.
     */
    @Override
    public void close() {
        Channel ch = this.channel;
        if (ch != null) {
            ch.close().awaitUninterruptibly();
        }
    }

    /**
     * @return the channel opened by {@link #connect()}
     * @throws IllegalStateException if no channel has been opened yet
     */
    private Channel requireChannel() {
        Channel ch = this.channel;
        if (ch == null) {
            throw new IllegalStateException("Channel is not connected, call connect() first.");
        }
        return ch;
    }

    /**
     * Start a daemon thread that scans the pending requests for timed out calls every
     * {@link #TIMEOUT_CHECK_INTERVAL_MILLIS} ms for as long as the given channel is open.
     *
     * @param watched the channel whose lifetime bounds the checker thread
     */
    private void startCheckTimeout(final Channel watched) {
        Thread checker = new Thread("rpc-timeout-checker-" + host + ":" + port) {
            @Override
            public void run() {
                while (watched.isOpen()) {
                    try {
                        Thread.sleep(TIMEOUT_CHECK_INTERVAL_MILLIS);
                    } catch (InterruptedException e) {
                        // asked to stop: keep the interrupt flag and leave
                        Thread.currentThread().interrupt();
                        return;
                    }
                    try {
                        checkTimeout();
                    } catch (RuntimeException e) {
                        // a failing callback must not end timeout checking for all other calls
                        logger.warn("error while checking rpc call timeout", e);
                    }
                }
            }
        };
        // must not keep the JVM alive on its own
        checker.setDaemon(true);
        checker.start();
    }

    /**
     * Remove every pending call that reports a timeout, mark it canceled and hand a
     * {@link TimeoutException} to its callback.
     */
    private void checkTimeout() {
        for (ClientCallTask task : pendingRequestMap.values()) {
            if (!task.isTimeout()) {
                continue;
            }
            // null means the call was completed / removed concurrently; nothing to report then
            ClientCallTask removed = pendingRequestMap.remove(task.getCorrelationId());
            if (removed != null) {
                removed.setCanceled(true);
                removed.getCallback().onError(new TimeoutException("time out"));
            }
        }
    }
}