All downloads are free. Search and download functionality uses the official Maven repository.

com.tencent.angel.ipc.NettyTransceiver Maven / Gradle / Ivy

/*
 * Tencent is pleased to support the open source community by making Angel available.
 *
 * Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in 
 * compliance with the License. You may obtain a copy of the License at
 *
 * https://opensource.org/licenses/Apache-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 *
 */


package com.tencent.angel.ipc;

import com.google.protobuf.Message;
import com.google.protobuf.Message.Builder;
import com.tencent.angel.conf.AngelConf;
import com.tencent.angel.exception.RemoteException;
import com.tencent.angel.exception.StandbyException;
import com.tencent.angel.ipc.NettyTransportCodec.NettyDataPack;
import com.tencent.angel.ipc.NettyTransportCodec.NettyFrameDecoder;
import com.tencent.angel.ipc.NettyTransportCodec.NettyFrameEncoder;
import com.tencent.angel.protobuf.generated.RPCProtos;
import com.tencent.angel.protobuf.generated.RPCProtos.*;
import com.tencent.angel.protobuf.generated.RPCProtos.RpcResponseHeader.Status;
import com.tencent.angel.utils.ByteBufferInputStream;
import com.tencent.angel.utils.ByteBufferOutputStream;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.*;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.timeout.ReadTimeoutHandler;
import org.apache.hadoop.conf.Configuration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.EOFException;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * A Netty-based {@link Transceiver} implementation.
 */
public class NettyTransceiver extends Transceiver {
  private static final Logger LOG = LoggerFactory.getLogger(NettyTransceiver.class.getName());

  private final AtomicInteger serialGenerator = new AtomicInteger(0);
  private final Map>> requests =
    new ConcurrentHashMap>>();

  private final int connectTimeoutMillis;
  private final Bootstrap bootstrap;
  private final InetSocketAddress remoteAddr;

  private volatile ChannelFuture channelFuture;
  private volatile boolean stopping;
  private volatile Channel channel; // Synchronized on stateLock
  private final Object channelFutureLock = new Object();

  private volatile int refCount = 1;
  private Configuration conf;

  public NettyTransceiver(Configuration conf, InetSocketAddress addr, EventLoopGroup workerGroup,
    PooledByteBufAllocator pooledAllocator, Class socketChannelClass,
    int connectTimeoutMillis) throws IOException {
    this.conf = conf;
    this.connectTimeoutMillis = connectTimeoutMillis;

    bootstrap = new Bootstrap();
    bootstrap.group(workerGroup).channel(socketChannelClass)
      // Disable Nagle's Algorithm since we don't want packets to wait
      .option(ChannelOption.TCP_NODELAY, true).option(ChannelOption.SO_KEEPALIVE, true)
      .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeoutMillis)
      .option(ChannelOption.ALLOCATOR, pooledAllocator);

    // Configure the event pipeline factory.
    bootstrap.handler(new ChannelInitializer() {
      @Override protected void initChannel(SocketChannel ch) throws Exception {
        ch.pipeline().addLast("encoder", NettyFrameEncoder.INSTANCE)
          .addLast("frameDecoder", NettyUtils.createFrameDecoder())
          .addLast("decoder", NettyFrameDecoder.INSTANCE).addLast("readTimeout",
          new ReadTimeoutHandler(NettyTransceiver.this.conf
            .getInt(AngelConf.CONNECTION_READ_TIMEOUT_SEC,
              AngelConf.DEFAULT_CONNECTION_READ_TIMEOUT_SEC)))
          .addLast("handler", new MLClientMLHandler());
      }
    });
    remoteAddr = addr;
    // Make a new connection.
    try {
      getChannel();
    } catch (IOException e) {
      LOG.debug("connect error, e: " + e);
      throw e;
    }
  }

  /**
   * Tests whether the given channel is ready for writing.
   *
   * @return true if the channel is open and ready; false otherwise.
   */
  private static boolean isChannelReady(Channel channel) {
    return (channel != null) && channel.isOpen() && channel.isRegistered() && channel.isActive();
  }

  /**
   * Gets the Netty channel. If the channel is not connected, first attempts to connect. NOTE: The
   * stateLock read lock *must* be acquired before calling this method.
   *
   * @return the Netty channel
   * @throws java.io.IOException if an error occurs connecting the channel.
   */
  private synchronized Channel getChannel() throws IOException {
    if (!isChannelReady(channel)) {
      synchronized (channelFutureLock) {
        if (!stopping) {
          if (LOG.isDebugEnabled()) {
            LOG.debug("Connecting to " + remoteAddr);
          }
          channelFuture = bootstrap.connect(remoteAddr);
        }
      }
      if (channelFuture != null) {
        try {
          channelFuture.await(connectTimeoutMillis);
          LOG.debug("waiting connect timeout! connectTimeoutMillis: " + connectTimeoutMillis);
        } catch (InterruptedException e) {
          stopping = false;
          throw new IOException("Request has been interrupted.", e);
        }

        synchronized (channelFutureLock) {
          if (!channelFuture.isSuccess()) {
            channelFuture.cancel(true);
            throw new IOException("Error connecting to " + remoteAddr, channelFuture.cause());
          }
          channel = channelFuture.channel();
          if (LOG.isDebugEnabled()) {
            LOG.debug("new channel is {} ", channel);
          }
          channelFuture = null;
        }
      }
    }
    return channel;
  }

  /**
   * Closes the connection to the remote peer if connected.
   *
   * @param awaitCompletion       if true, will block until the close has completed.
   * @param cancelPendingRequests if true, will drain the requests map and send an IOException to
   *                              all Callbacks.
   * @param cause                 if non-null and cancelPendingRequests is true, this Throwable will be passed to
   *                              all Callbacks.
   */
  private synchronized void disconnect(Channel channel, boolean awaitCompletion,
    boolean cancelPendingRequests, Throwable cause) {
    if (LOG.isDebugEnabled()) {
      LOG.debug("disconnecting channel: " + channel);
    }
    Channel channelToClose = null;
    Map>> requestsToCancel = null;

    ChannelFuture channelFutureToCancel = null;
    synchronized (channelFutureLock) {
      if (stopping && channelFuture != null) {
        channelFutureToCancel = channelFuture;
        channelFuture = null;
      }
    }
    if (channelFutureToCancel != null) {
      channelFutureToCancel.cancel(true);
    }
    if (channel != null) {
      if (cause != null) {
        LOG.debug("Disconnect {} due to {}", channel,
          cause.getClass().getName() + cause.getMessage());
      } else {
        if (LOG.isDebugEnabled()) {
          LOG.debug("Disconnect {}", this.channel);
        }
      }
      channelToClose = channel;
      this.channel = null;

      if (cancelPendingRequests) {
        // Remove all pending requests (will be canceled after relinquishing
        // write lock).
        requestsToCancel = new ConcurrentHashMap>>(requests);
        requests.clear();
      }
    }

    // Cancel any pending requests by sending errors to the callbacks:
    if ((requestsToCancel != null) && !requestsToCancel.isEmpty()) {
      if (LOG.isDebugEnabled()) {
        LOG.debug("Removing " + requestsToCancel.size() + " pending request(s).");
      }
      for (Callback> request : requestsToCancel.values()) {
        request.handleError(
          cause != null ? cause : new IOException(getClass().getSimpleName() + " closed"));
      }
    }

    // Close the channel:
    if (channelToClose != null) {
      ChannelFuture closeFuture = channelToClose.close();
      if (awaitCompletion && (closeFuture != null)) {
        closeFuture.awaitUninterruptibly(connectTimeoutMillis);
      }
    }
  }

  /**
   * Netty channels are thread-safe, so there is no need to acquire locks. This method is a no-op.
   */
  @Override public void lockChannel() {

  }

  /**
   * Netty channels are thread-safe, so there is no need to acquire locks. This method is a no-op.
   */
  @Override public void unlockChannel() {

  }

  public synchronized void close() {
    if (stopping) {
      return;
    }

    if (LOG.isDebugEnabled()) {
      LOG.debug("Closing the netty transceiver...");
    }
    try {
      // Close the connection:
      stopping = true;
      disconnect(this.channel, true, true, null);
    } finally {
      // Shut down all thread pools to exit.
      if (LOG.isDebugEnabled()) {
        LOG.debug("release channelFactory resource for " + remoteAddr);
      }
    }
  }

  @Override public String getRemoteName() throws IOException {
    return NettyUtils.getRemoteAddress(getChannel());
  }

  /**
   * Make a call, passing param, to the IPC server running at address
   * which is servicing the protocol protocol, with the ticket
   * credentials, returning the value. Throws exceptions if there are network problems or if the
   * remote code threw an exception.
   */
  public Message call(RpcRequestBody requestBody, Class protocol,
    int rpcTimeout, Callback callback) throws Exception {
    ConnectionHeader.Builder builder = ConnectionHeader.newBuilder();
    builder.setProtocol(protocol == null ? "" : protocol.getName());
    ConnectionHeader connectionHeader = builder.build();

    RpcRequestHeader.Builder headerBuilder = RPCProtos.RpcRequestHeader.newBuilder();

    RpcRequestHeader rpcHeader = headerBuilder.build();

    ByteBufferOutputStream bbo = new ByteBufferOutputStream();
    connectionHeader.writeDelimitedTo(bbo);
    rpcHeader.writeDelimitedTo(bbo);
    requestBody.writeDelimitedTo(bbo);
    CallFuture future = new CallFuture(callback);
    if (LOG.isDebugEnabled()) {
      LOG.debug("send message, " + requestBody.getMethodName() + " , channel: " + channel);
    }

    transceive(bbo.getBufferList(),
      new TransceiverCallback(requestBody, protocol, future));

    if (callback == null) {
      try {
        return future.get(
          conf.getLong(AngelConf.ANGEL_READ_TIMEOUT_SEC, AngelConf.DEFAULT_ANGEL_READ_TIMEOUT_SEC),
          TimeUnit.SECONDS);
      } catch (java.util.concurrent.TimeoutException e) {
        if (LOG.isDebugEnabled()) {
          LOG.debug(
            "timeout for: send message, " + requestBody.getMethodName() + " , channel: " + channel);
        }
        disconnect(this.channel, true, true, e);
        throw e;
      }
    }
    return null;
  }

  /**
   * Override as non-synchronized method because the method is thread safe.
   */
  @Override public List transceive(List request) throws IOException {
    try {
      CallFuture> transceiverFuture = new CallFuture>();
      transceive(request, transceiverFuture);
      return transceiverFuture.get();
    } catch (InterruptedException e) {
      LOG.info("failed to get the response", e);
      throw new IOException(e);
    } catch (ExecutionException e) {
      LOG.warn("failed to get the response", e);
      throw new IOException(e);
    }
  }

  @Override public void transceive(List request, Callback> callback) {
    int serial = serialGenerator.incrementAndGet();
    try {
      NettyDataPack dataPack = new NettyDataPack(serial, request);
      requests.put(serial, callback);
      if (LOG.isDebugEnabled()) {
        LOG.debug("send message, serial: " + serial + ", channel: " + channel);
      }
      // LOG.info("serial " + serial + "start time = " + System.currentTimeMillis());
      NettyDataPack.writeDataPack(getChannel(), dataPack);
    } catch (IOException e) {
      if (LOG.isDebugEnabled()) {
        LOG.debug("write Data error, serial: " + serial + ", channel: " + channel + " due to:", e);
      }
      requests.remove(serial);
      callback.handleError(e);
    }
  }

  @Override public void writeBuffers(List buffers) throws IOException {
    NettyDataPack dataPack = new NettyDataPack(serialGenerator.incrementAndGet(), buffers);
    NettyDataPack.writeDataPack(getChannel(), dataPack);
  }

  @Override public List readBuffers() throws IOException {
    throw new UnsupportedOperationException();
  }

  class TransceiverCallback implements Callback> {
    private final RpcRequestBody requestBody;
    private final Class protocol;
    private final Callback callback;

    /**
     * Creates a TransceiverCallback.
     *
     * @param callback the callback to set.
     */
    public TransceiverCallback(RpcRequestBody requestBody,
      Class protocol, Callback callback) {
      this.requestBody = requestBody;
      this.protocol = protocol;
      this.callback = callback;
    }

    @Override @SuppressWarnings("unchecked")
    public void handleResult(List responseBytes) {
      ByteBufferInputStream in = new ByteBufferInputStream(responseBytes);
      try {
        // See NettyServer.prepareResponse for where we write out the response.
        // It writes the call.id (int), a boolean signifying any error (and if
        // so the exception name/trace), and the response bytes

        // Read the call id.
        RpcResponseHeader response = RpcResponseHeader.parseDelimitedFrom(in);
        if (response == null) {
          LOG.error("response is null");
          // When the stream is closed, protobuf doesn't raise an EOFException,
          // instead, it returns a null message object.
          throw new EOFException();
        }

        Status status = response.getStatus();
        if (status == Status.SUCCESS) {
          Message rpcResponseType;
          try {
            rpcResponseType = ProtobufRpcEngine.Invoker.getReturnProtoType(
              ProtobufRpcEngine.Server.getMethod(protocol, requestBody.getMethodName()));
          } catch (Exception e) {
            throw new RuntimeException(e); // local exception
          }
          Builder builder = rpcResponseType.newBuilderForType();
          builder.mergeDelimitedFrom(in);
          Message value = builder.build();

          if (callback != null) {
            LOG.debug("to execute callback, method: " + requestBody.getMethodName());
            callback.handleResult((T) value);
          } else {
            if (LOG.isDebugEnabled()) {
              LOG.debug("callback is null, method: " + requestBody.getMethodName());
            }
          }
        } else if (status == Status.ERROR || status == Status.FATAL) {
          RpcException exceptionResponse = RpcException.parseDelimitedFrom(in);
          String exceptionName = exceptionResponse.getExceptionName();
          String exceptionReason = exceptionResponse.getStackTrace();
          if (exceptionName != null && exceptionName.equals(StandbyException.class.getName())) {
            handleError(new StandbyException(exceptionReason));
          } else {
            RemoteException remoteException = new RemoteException(exceptionName, exceptionReason);
            handleError(remoteException.unwrapRemoteException());
          }
        } else {
          handleError(new IOException("response status is " + status));
        }
      } catch (Exception e) {
        LOG.error("Error handling transceiver callback: " + e, e);
        handleError(e);
      }
    }

    @Override public void handleError(Throwable error) {
      callback.handleError(error);
    }
  }


  /**
   * ML client handler for the Netty transport
   */
  class MLClientMLHandler extends ChannelInboundHandlerAdapter {

    @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception {
      if (LOG.isDebugEnabled()) {
        LOG.debug("Remote peer " + remoteAddr + " closed channel: " + ctx.channel());
      }
      disconnect(ctx.channel(), false, true, null);
      super.channelInactive(ctx);
    }

    @Override public void channelRead(ChannelHandlerContext ctx, Object request) throws Exception {
      if (!(request instanceof NettyDataPack)) {
        ctx.fireChannelRead(request);
        return;
      }
      NettyDataPack dataPack = (NettyDataPack) request;
      if (LOG.isDebugEnabled()) {
        LOG.debug(
          "messageReceived, serail: " + dataPack.getSerial() + ", channel: " + ctx.channel());
      }

      // LOG.info("method " + dataPack.getSerial() + " received ts = " +
      // System.currentTimeMillis());

      Callback> callback = requests.get(dataPack.getSerial());
      if (callback == null) {
        LOG.error(
          "Missing previous call info, serail: " + dataPack.getSerial() + ", channel: " + ctx
            .channel());
        throw new RuntimeException("Missing previous call info");
      }
      try {
        callback.handleResult(dataPack.getDatas());
      } finally {
        requests.remove(dataPack.getSerial());
      }
    }

    @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
      throws Exception {
      if (LOG.isDebugEnabled()) {
        LOG.debug("Netty Transceiver error." + "channel: " + ctx.channel(), cause);
      }
      disconnect(ctx.channel(), false, true, cause);
    }
  }

  /**
   * Increment this client's reference count
   */
  synchronized void incCount() {
    refCount++;
  }

  /**
   * Decrement this client's reference count
   */
  synchronized void decCount() {
    refCount--;
  }

  /**
   * Return if this client has no reference
   *
   * @return true if this client has no reference; false otherwise
   */
  synchronized boolean isZeroReference() {
    return refCount == 0;
  }

  /**
   * @return the remoteAddr
   */
  public InetSocketAddress getRemoteAddr() {
    return remoteAddr;
  }

  @Override public Configuration getConf() {
    return this.conf;
  }

  @Override public void setConf(Configuration conf) {
    this.conf = conf;
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy