All downloads are free. Search and download functionality uses the official Maven repository.

akka.remote.transport.netty.NettyTransport.scala Maven / Gradle / Ivy

The newest version!
/**
 * Copyright (C) 2009-2014 Typesafe Inc. 
 */
package akka.remote.transport.netty

import akka.actor.{ Address, ExtendedActorSystem }
import akka.dispatch.ThreadPoolConfig
import akka.event.Logging
import akka.remote.transport.AssociationHandle.HandleEventListener
import akka.remote.transport.Transport._
import akka.remote.transport.netty.NettyTransportSettings.{ Udp, Tcp, Mode }
import akka.remote.transport.{ AssociationHandle, Transport }
import akka.{ OnlyCauseStackTrace, ConfigurationException }
import com.typesafe.config.Config
import java.net.{ UnknownHostException, SocketAddress, InetAddress, InetSocketAddress, ConnectException }
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.{ ConcurrentHashMap, Executors, CancellationException }
import org.jboss.netty.bootstrap.{ ConnectionlessBootstrap, Bootstrap, ClientBootstrap, ServerBootstrap }
import org.jboss.netty.buffer.{ ChannelBuffers, ChannelBuffer }
import org.jboss.netty.channel._
import org.jboss.netty.channel.group.{ DefaultChannelGroup, ChannelGroup, ChannelGroupFuture, ChannelGroupFutureListener }
import org.jboss.netty.channel.socket.nio.{ NioWorkerPool, NioDatagramChannelFactory, NioServerSocketChannelFactory, NioClientSocketChannelFactory }
import org.jboss.netty.handler.codec.frame.{ LengthFieldBasedFrameDecoder, LengthFieldPrepender }
import org.jboss.netty.handler.ssl.SslHandler
import scala.concurrent.duration.{ Duration, FiniteDuration, MILLISECONDS }
import scala.concurrent.{ ExecutionContext, Promise, Future, blocking }
import scala.util.{ Failure, Success, Try }
import scala.util.control.{ NoStackTrace, NonFatal }
import akka.util.Helpers.Requiring
import akka.util.Helpers
import akka.remote.RARP
import org.jboss.netty.util.HashedWheelTimer

/**
 * Protocol selectors for [[NettyTransportSettings]].
 *
 * Each mode's `toString` is the lowercase protocol name; the transport concatenates it into its
 * scheme identifier and error messages.
 */
object NettyTransportSettings {

  /** The wire protocol spoken by a Netty transport. */
  sealed trait Mode

  /** TCP transport; renders as `"tcp"`. */
  case object Tcp extends Mode {
    override def toString: String = "tcp"
  }

  /** UDP transport; renders as `"udp"`. */
  case object Udp extends Mode {
    override def toString: String = "udp"
  }
}

/**
 * Adapters that expose Netty 3 completion futures as Scala [[scala.concurrent.Future]]s.
 *
 * A cancelled Netty future becomes a failed Scala future carrying a
 * [[java.util.concurrent.CancellationException]]; any other failure carries Netty's reported cause.
 */
object NettyFutureBridge {

  /**
   * Bridges a single-channel Netty future.
   *
   * @param nettyFuture the Netty operation to observe
   * @return a future completed with the operation's channel on success, or failed otherwise
   */
  def apply(nettyFuture: ChannelFuture): Future[Channel] = {
    val bridged = Promise[Channel]()
    nettyFuture.addListener(new ChannelFutureListener {
      def operationComplete(completed: ChannelFuture): Unit = {
        // Evaluated inside Try so that thrown failures become a failed Future.
        def channelOrThrow: Channel =
          if (completed.isSuccess) completed.getChannel
          else if (completed.isCancelled) throw new CancellationException
          else throw completed.getCause
        bridged.complete(Try(channelOrThrow))
      }
    })
    bridged.future
  }

  /**
   * Bridges a Netty channel-group future.
   *
   * @param nettyFuture the group operation to observe
   * @return a future completed with the group when every member succeeded; otherwise failed with
   *         the first cancellation/failure found among the member futures, or an
   *         `IllegalStateException` when no member reports one
   */
  def apply(nettyFuture: ChannelGroupFuture): Future[ChannelGroup] = {
    import scala.collection.JavaConverters._
    val bridged = Promise[ChannelGroup]()
    nettyFuture.addListener(new ChannelGroupFutureListener {
      def operationComplete(completed: ChannelGroupFuture): Unit = {
        // Scan the member futures for the first one that did not succeed.
        def firstMemberError: Throwable =
          completed.iterator.asScala.collectFirst {
            case f if f.isCancelled ⇒ new CancellationException
            case f if !f.isSuccess  ⇒ f.getCause
          }.getOrElse(new IllegalStateException("Error reported in ChannelGroupFuture, but no error found in individual futures."))

        def groupOrThrow: ChannelGroup =
          if (completed.isCompleteSuccess) completed.getGroup
          else throw firstMemberError

        bridged.complete(Try(groupOrThrow))
      }
    })
    bridged.future
  }
}

/**
 * Failure raised by the Netty transport (e.g. an unbound transport or an unknown address type).
 *
 * Mixes in [[akka.OnlyCauseStackTrace]]; see that trait for how the stack trace is reported.
 *
 * @param msg   human-readable description of the failure
 * @param cause underlying exception, or `null` when there is none
 */
@SerialVersionUID(1L)
class NettyTransportException(msg: String, cause: Throwable)
  extends RuntimeException(msg, cause)
  with OnlyCauseStackTrace {

  /** Creates an exception that has no underlying cause. */
  def this(msg: String) = this(msg, null: Throwable)
}

/**
 * Typed, eagerly validated view of the Netty transport configuration section.
 *
 * Every setting is read when the instance is constructed. Invalid configuration therefore fails
 * fast, with a [[akka.ConfigurationException]] or an `IllegalArgumentException` (from `require` /
 * `requiring`).
 *
 * @param config the Netty transport configuration section
 */
class NettyTransportSettings(config: Config) {

  import akka.util.Helpers.ConfigOps
  import config._

  /** Protocol selected by `transport-protocol`: "tcp" or "udp". */
  val TransportMode: Mode = getString("transport-protocol") match {
    case "tcp"   ⇒ Tcp
    case "udp"   ⇒ Udp
    case unknown ⇒ throw new ConfigurationException(s"Unknown transport: [$unknown]")
  }

  /** `enable-ssl`; SSL may only be enabled together with TCP. */
  val EnableSsl: Boolean = getBoolean("enable-ssl") requiring (!_ || TransportMode == Tcp, s"$TransportMode does not support SSL")

  /** Dispatcher named by `use-dispatcher-for-io`, or `None` when the setting is empty. */
  val UseDispatcherForIo: Option[String] = getString("use-dispatcher-for-io") match {
    case "" | null  ⇒ None
    case dispatcher ⇒ Some(dispatcher)
  }

  /**
   * Reads a byte-size setting in which 0 means "not set".
   *
   * The value is range-checked as a `Long` *before* it is narrowed to `Int`. Narrowing first would
   * let sizes above `Int.MaxValue` wrap around to an unrelated (possibly positive) value that
   * would then be accepted silently.
   *
   * @throws ConfigurationException if the value is negative or does not fit in an `Int`
   */
  private[this] def optionSize(s: String): Option[Int] = {
    val bytes: Long = getBytes(s)
    if (bytes < 0L || bytes > Int.MaxValue)
      throw new ConfigurationException(s"Setting '$s' must be 0 or positive (and fit in an Int)")
    else if (bytes == 0L) None
    else Some(bytes.toInt)
  }

  /** `connection-timeout`, read as a millisecond-precision duration. */
  val ConnectionTimeout: FiniteDuration = config.getMillisDuration("connection-timeout")

  val WriteBufferHighWaterMark: Option[Int] = optionSize("write-buffer-high-water-mark")

  val WriteBufferLowWaterMark: Option[Int] = optionSize("write-buffer-low-water-mark")

  val SendBufferSize: Option[Int] = optionSize("send-buffer-size")

  /** `receive-buffer-size`; mandatory (non-zero) when the transport mode is UDP. */
  val ReceiveBufferSize: Option[Int] = optionSize("receive-buffer-size") requiring (s ⇒
    s.isDefined || TransportMode != Udp, "receive-buffer-size must be specified for UDP")

  /**
   * `maximum-frame-size` in bytes. It must be at least 32000 and must fit in an `Int`. The upper
   * bound is checked on the `Long` value so that oversized settings are rejected rather than
   * wrapped around.
   */
  val MaxFrameSize: Int = {
    val bytes: Long = getBytes("maximum-frame-size")
    require(bytes <= Int.MaxValue, "Setting 'maximum-frame-size' must fit in an Int")
    bytes.toInt requiring (
      _ >= 32000,
      s"Setting 'maximum-frame-size' must be at least 32000 bytes")
  }

  val Backlog: Int = getInt("backlog")

  val TcpNodelay: Boolean = getBoolean("tcp-nodelay")

  val TcpKeepalive: Boolean = getBoolean("tcp-keepalive")

  /** `tcp-reuse-addr`; the special value "off-for-windows" enables it everywhere except on Windows. */
  val TcpReuseAddr: Boolean = getString("tcp-reuse-addr") match {
    case "off-for-windows" ⇒ !Helpers.isWindows
    case _                 ⇒ getBoolean("tcp-reuse-addr")
  }

  /** `hostname`; an empty value falls back to the IP address of the local host. */
  val Hostname: String = getString("hostname") match {
    case ""    ⇒ InetAddress.getLocalHost.getHostAddress
    case value ⇒ value
  }

  @deprecated("WARNING: This should only be used by professionals.", "2.0")
  val PortSelector: Int = getInt("port")

  /** SSL settings read from the `security` sub-section, present only when SSL is enabled. */
  val SslSettings: Option[SSLSettings] = if (EnableSsl) Some(new SSLSettings(config.getConfig("security"))) else None

  val ServerSocketWorkerPoolSize: Int = computeWPS(config.getConfig("server-socket-worker-pool"))

  val ClientSocketWorkerPoolSize: Int = computeWPS(config.getConfig("client-socket-worker-pool"))

  // Scales the worker pool size from the pool-size-min / pool-size-factor / pool-size-max triple.
  private def computeWPS(config: Config): Int =
    ThreadPoolConfig.scaledPoolSize(
      config.getInt("pool-size-min"),
      config.getDouble("pool-size-factor"),
      config.getInt("pool-size-max"))

}

/**
 * INTERNAL API
 *
 * Logic shared by the server- and client-side Netty handlers. It registers opened channels with
 * the transport's channel group and turns a connected channel into an [[AssociationHandle]].
 */
private[netty] trait CommonHandlers extends NettyHelpers {
  protected val transport: NettyTransport

  /** Tracks every newly opened channel in the transport's channel group. */
  final override def onOpen(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
    val openedChannel = e.getChannel
    transport.channelGroup.add(openedChannel)
  }

  /** Builds the transport-specific handle for `channel`. */
  protected def createHandle(channel: Channel, localAddress: Address, remoteAddress: Address): AssociationHandle

  /**
   * Attaches `listener` to `channel`. `msg` is passed through unchanged from `init`; what the
   * implementation does with it is defined by the concrete handler (not visible here).
   */
  protected def registerListener(channel: Channel,
                                 listener: HandleEventListener,
                                 msg: ChannelBuffer,
                                 remoteSocketAddress: InetSocketAddress): Unit

  /**
   * Creates a handle for `channel` and hands it to `op`. Once the handle's read-handler promise is
   * completed, the listener is registered and reads are re-enabled on the channel. If the
   * channel's local address cannot be mapped to an [[Address]], the channel is closed instead.
   */
  final protected def init(channel: Channel, remoteSocketAddress: SocketAddress, remoteAddress: Address, msg: ChannelBuffer)(
    op: (AssociationHandle ⇒ Any)): Unit = {
    import transport.executionContext

    val localAddressOpt = NettyTransport.addressFromSocketAddress(
      channel.getLocalAddress, transport.schemeIdentifier, transport.system.name, Some(transport.settings.Hostname))

    localAddressOpt match {
      case None ⇒
        // Unknown local address type: nothing sensible can be built, so drop the channel.
        NettyTransport.gracefulClose(channel)

      case Some(localAddress) ⇒
        val handle = createHandle(channel, localAddress, remoteAddress)
        // Reads resume only after a listener has been registered for this handle.
        handle.readHandlerPromise.future.onSuccess {
          case listener: HandleEventListener ⇒
            registerListener(channel, listener, msg, remoteSocketAddress.asInstanceOf[InetSocketAddress])
            channel.setReadable(true)
        }
        op(handle)
    }
  }
}

/**
 * INTERNAL API
 *
 * Base class for server-side handlers. Inbound channels are announced to the transport's
 * association listener as [[InboundAssociation]] events.
 */
private[netty] abstract class ServerHandler(protected final val transport: NettyTransport,
                                            private final val associationListenerFuture: Future[AssociationEventListener])
  extends NettyServerHelpers with CommonHandlers {

  import transport.executionContext

  /**
   * Suspends reads on `channel` until the association listener is available. It then resolves
   * the remote [[Address]] and notifies the listener of the new inbound association.
   *
   * @throws NettyTransportException (inside the listener callback) if the remote socket address
   *                                 is not an `InetSocketAddress`
   */
  final protected def initInbound(channel: Channel, remoteSocketAddress: SocketAddress, msg: ChannelBuffer): Unit = {
    channel.setReadable(false)
    associationListenerFuture.onSuccess {
      case listener: AssociationEventListener ⇒
        val inboundAddress = NettyTransport.addressFromSocketAddress(remoteSocketAddress, transport.schemeIdentifier,
          transport.system.name, hostName = None) match {
            case Some(resolved) ⇒ resolved
            case None ⇒
              throw new NettyTransportException(s"Unknown inbound remote address type [${remoteSocketAddress.getClass.getName}]")
          }
        init(channel, remoteSocketAddress, inboundAddress, msg) { handle ⇒ listener notify InboundAssociation(handle) }
    }
  }

}

/**
 * INTERNAL API
 *
 * Base class for client-side handlers. The resulting [[AssociationHandle]] is published through
 * [[statusFuture]].
 */
private[netty] abstract class ClientHandler(protected final val transport: NettyTransport, remoteAddress: Address)
  extends NettyClientHelpers with CommonHandlers {

  /** Completed with the handle once the outbound channel has been initialised. */
  final protected val statusPromise = Promise[AssociationHandle]()

  /** Read-only view of [[statusPromise]]. */
  def statusFuture: Future[AssociationHandle] = statusPromise.future

  /** Initialises an outbound channel towards `remoteAddress` and completes [[statusPromise]]. */
  final protected def initOutbound(channel: Channel, remoteSocketAddress: SocketAddress, msg: ChannelBuffer): Unit =
    init(channel, remoteSocketAddress, remoteAddress, msg) { handle ⇒ statusPromise.success(handle) }

}

/**
 * INTERNAL API
 *
 * Helpers shared by [[NettyTransport]] and its handlers.
 */
private[transport] object NettyTransport {
  // 4 bytes will be used to represent the frame length. Used by netty LengthFieldPrepender downstream handler.
  val FrameLengthFieldLength = 4

  /**
   * Closes `channel` after a best-effort flush. An empty buffer is written and the channel is
   * disconnected, each step waiting for the previous one. Failures of these steps are ignored,
   * so `close()` is always attempted.
   */
  def gracefulClose(channel: Channel)(implicit ec: ExecutionContext): Unit = {
    def always(c: ChannelFuture) = NettyFutureBridge(c) recover { case _ ⇒ c.getChannel }
    for {
      _ ← always { channel.write(ChannelBuffers.buffer(0)) } // Force flush by waiting on a final dummy write
      _ ← always { channel.disconnect() }
    } channel.close()
  }

  /** Source of unique suffixes for channel group names. */
  val uniqueIdCounter = new AtomicInteger(0)

  /**
   * Converts a socket address to an Akka [[Address]].
   *
   * @param hostName host to put into the address; when `None`, the socket's IP address is used, or
   *                 its host name if the socket address is unresolved (no `InetAddress` available)
   * @return `None` if `addr` is not an `InetSocketAddress`
   */
  def addressFromSocketAddress(addr: SocketAddress, schemeIdentifier: String, systemName: String,
                               hostName: Option[String]): Option[Address] = addr match {
    case sa: InetSocketAddress ⇒
      // getAddress is null for unresolved socket addresses; getHostName then returns the
      // configured name without performing a lookup. (Perhaps use getHostString in jdk 1.7.)
      val host = hostName getOrElse Option(sa.getAddress).map(_.getHostAddress).getOrElse(sa.getHostName)
      Some(Address(schemeIdentifier, systemName, host, sa.getPort))
    case _ ⇒ None
  }
}

/**
 * Netty 3 based [[akka.remote.transport.Transport]] that speaks TCP (optionally wrapped in SSL)
 * or UDP, depending on `settings.TransportMode`.
 *
 * - `listen` binds the server channel and keeps reads suspended until the association listener
 *   promise is completed.
 * - `associate` opens outbound associations; it fails if the server channel is not bound.
 * - `shutdown` unbinds, flushes, disconnects and closes every channel in [[channelGroup]], then
 *   releases the Netty channel factories.
 *
 * @param settings parsed transport configuration
 * @param system   owning actor system (dispatchers, thread factory, logging, system name)
 */
// FIXME: Split into separate UDP and TCP classes
class NettyTransport(val settings: NettyTransportSettings, val system: ExtendedActorSystem) extends Transport {

  /** Alternative constructor that parses `conf` into [[NettyTransportSettings]]. */
  def this(system: ExtendedActorSystem, conf: Config) = this(new NettyTransportSettings(conf), system)

  import NettyTransport._
  import settings._

  /**
   * Execution context for this transport's Future callbacks. Uses the dispatcher named by
   * `use-dispatcher-for-io`, else the remoting dispatcher when its name is non-empty, else the
   * system dispatcher.
   */
  implicit val executionContext: ExecutionContext =
    settings.UseDispatcherForIo.orElse(RARP(system).provider.remoteSettings.Dispatcher match {
      case ""             ⇒ None
      case dispatcherName ⇒ Some(dispatcherName)
    }).map(system.dispatchers.lookup).getOrElse(system.dispatcher)

  // e.g. "tcp", "udp" or "ssl.tcp": the mode's toString is the protocol name.
  override val schemeIdentifier: String = (if (EnableSsl) "ssl." else "") + TransportMode
  override def maximumPayloadBytes: Int = settings.MaxFrameSize

  private final val isDatagram = TransportMode == Udp

  // Both are assigned by `listen`; they remain null until the transport has been bound.
  @volatile private var localAddress: Address = _
  @volatile private var serverChannel: Channel = _

  private val log = Logging(system, this.getClass)

  /**
   * INTERNAL API
   *
   * Maps remote socket addresses to the listener of their UDP association. `associate` fills it
   * for outbound UDP associations (other users, presumably the UDP handlers, live elsewhere).
   */
  private[netty] final val udpConnectionTable = new ConcurrentHashMap[SocketAddress, HandleEventListener]()

  // Executor for Netty's boss/worker threads: the configured I/O dispatcher if any, otherwise a
  // fresh cached thread pool built with the actor system's thread factory.
  private def createExecutorService() =
    UseDispatcherForIo.map(system.dispatchers.lookup) getOrElse Executors.newCachedThreadPool(system.threadFactory)

  /*
   * Be aware, that the close() method of DefaultChannelGroup is racy, because it uses an iterator over a ConcurrentHashMap.
   * In the old remoting this was handled by using a custom subclass, guarding the close() method with a write-lock.
   * The usage of this class is safe in the new remoting, as close() is called after unbind() is finished, and no
   * outbound connections are initiated in the shutdown phase.
   */
  val channelGroup = new DefaultChannelGroup("akka-netty-transport-driver-channelgroup-" +
    uniqueIdCounter.getAndIncrement)

  private val clientChannelFactory: ChannelFactory = TransportMode match {
    case Tcp ⇒
      val boss, worker = createExecutorService()
      // We need to create a HashedWheelTimer here since Netty creates one with a thread that
      // doesn't respect the akka.daemonic setting
      new NioClientSocketChannelFactory(boss, 1, new NioWorkerPool(worker, ClientSocketWorkerPoolSize),
        new HashedWheelTimer(system.threadFactory))
    case Udp ⇒
      // This does not create a HashedWheelTimer internally
      new NioDatagramChannelFactory(createExecutorService(), ClientSocketWorkerPoolSize)
  }

  private val serverChannelFactory: ChannelFactory = TransportMode match {
    case Tcp ⇒
      val boss, worker = createExecutorService()
      // This does not create a HashedWheelTimer internally
      new NioServerSocketChannelFactory(boss, worker, ServerSocketWorkerPoolSize)
    case Udp ⇒
      // This does not create a HashedWheelTimer internally
      new NioDatagramChannelFactory(createExecutorService(), ServerSocketWorkerPoolSize)
  }

  // Base pipeline. For stream (TCP) transports it adds length-prefixed framing with a
  // FrameLengthFieldLength-byte header, which the decoder strips.
  private def newPipeline: DefaultChannelPipeline = {
    val pipeline = new DefaultChannelPipeline

    if (!isDatagram) {
      pipeline.addLast("FrameDecoder", new LengthFieldBasedFrameDecoder(
        maximumPayloadBytes,
        0,
        FrameLengthFieldLength,
        0,
        FrameLengthFieldLength, // Strip the header
        true))
      pipeline.addLast("FrameEncoder", new LengthFieldPrepender(FrameLengthFieldLength))
    }

    pipeline
  }

  // Completed by the owner of this transport (returned from `listen`); inbound reads wait for it.
  private val associationListenerPromise: Promise[AssociationEventListener] = Promise()

  private def sslHandler(isClient: Boolean): SslHandler = {
    val handler = NettySSLSupport(settings.SslSettings.get, log, isClient)
    handler.setCloseOnSSLException(true)
    handler
  }

  private val serverPipelineFactory: ChannelPipelineFactory = new ChannelPipelineFactory {
    override def getPipeline: ChannelPipeline = {
      val pipeline = newPipeline
      if (EnableSsl) pipeline.addFirst("SslHandler", sslHandler(isClient = false))
      val handler = if (isDatagram) new UdpServerHandler(NettyTransport.this, associationListenerPromise.future)
      else new TcpServerHandler(NettyTransport.this, associationListenerPromise.future)
      pipeline.addLast("ServerHandler", handler)
      pipeline
    }
  }

  private def clientPipelineFactory(remoteAddress: Address): ChannelPipelineFactory =
    new ChannelPipelineFactory {
      override def getPipeline: ChannelPipeline = {
        val pipeline = newPipeline
        if (EnableSsl) pipeline.addFirst("SslHandler", sslHandler(isClient = true))
        val handler = if (isDatagram) new UdpClientHandler(NettyTransport.this, remoteAddress)
        else new TcpClientHandler(NettyTransport.this, remoteAddress)
        pipeline.addLast("clienthandler", handler)
        pipeline
      }
    }

  // Applies the pipeline factory and all socket/buffer options shared by inbound and outbound bootstraps.
  private def setupBootstrap[B <: Bootstrap](bootstrap: B, pipelineFactory: ChannelPipelineFactory): B = {
    bootstrap.setPipelineFactory(pipelineFactory)
    bootstrap.setOption("backlog", settings.Backlog)
    bootstrap.setOption("tcpNoDelay", settings.TcpNodelay)
    bootstrap.setOption("child.keepAlive", settings.TcpKeepalive)
    bootstrap.setOption("reuseAddress", settings.TcpReuseAddr)
    if (isDatagram) bootstrap.setOption("receiveBufferSizePredictorFactory", new FixedReceiveBufferSizePredictorFactory(ReceiveBufferSize.get))
    settings.ReceiveBufferSize.foreach(sz ⇒ bootstrap.setOption("receiveBufferSize", sz))
    settings.SendBufferSize.foreach(sz ⇒ bootstrap.setOption("sendBufferSize", sz))
    settings.WriteBufferHighWaterMark.foreach(sz ⇒ bootstrap.setOption("writeBufferHighWaterMark", sz))
    settings.WriteBufferLowWaterMark.foreach(sz ⇒ bootstrap.setOption("writeBufferLowWaterMark", sz))
    bootstrap
  }

  private val inboundBootstrap: Bootstrap = settings.TransportMode match {
    case Tcp ⇒ setupBootstrap(new ServerBootstrap(serverChannelFactory), serverPipelineFactory)
    case Udp ⇒ setupBootstrap(new ConnectionlessBootstrap(serverChannelFactory), serverPipelineFactory)
  }

  // Client bootstrap for one outbound association. setupBootstrap already applies tcpNoDelay and
  // the buffer-size options; only the client-specific options are added here.
  private def outboundBootstrap(remoteAddress: Address): ClientBootstrap = {
    val bootstrap = setupBootstrap(new ClientBootstrap(clientChannelFactory), clientPipelineFactory(remoteAddress))
    bootstrap.setOption("connectTimeoutMillis", settings.ConnectionTimeout.toMillis)
    bootstrap.setOption("keepAlive", settings.TcpKeepalive)
    bootstrap
  }

  override def isResponsibleFor(address: Address): Boolean = true //TODO: Add configurable subnet filtering

  /**
   * Resolves the host/port of `addr` into an `InetSocketAddress`. The blocking name lookup runs in
   * a `blocking` section of a Future.
   *
   * @return a failed future with `IllegalArgumentException` if `addr` lacks host or port
   */
  // TODO: This should be factored out to an async (or thread-isolated) name lookup service #2960
  def addressToSocketAddress(addr: Address): Future[InetSocketAddress] = addr match {
    case Address(_, _, Some(host), Some(port)) ⇒ Future { blocking { new InetSocketAddress(InetAddress.getByName(host), port) } }
    case _                                     ⇒ Future.failed(new IllegalArgumentException(s"Address [$addr] does not contain host or port information."))
  }

  /**
   * Binds the server channel to the configured hostname/port.
   *
   * Reads stay suspended until the returned promise is completed with an association listener.
   * If binding fails, the failure is logged with its cause, the transport is shut down (best
   * effort) and the returned future fails with the original exception.
   */
  override def listen: Future[(Address, Promise[AssociationEventListener])] = {
    for {
      address ← addressToSocketAddress(Address("", "", settings.Hostname, settings.PortSelector))
    } yield {
      try {
        val newServerChannel = inboundBootstrap match {
          case b: ServerBootstrap         ⇒ b.bind(address)
          case b: ConnectionlessBootstrap ⇒ b.bind(address)
        }

        // Block reads until a handler actor is registered
        newServerChannel.setReadable(false)
        channelGroup.add(newServerChannel)

        serverChannel = newServerChannel

        addressFromSocketAddress(newServerChannel.getLocalAddress, schemeIdentifier, system.name, Some(settings.Hostname)) match {
          case Some(boundAddress) ⇒
            localAddress = boundAddress
            associationListenerPromise.future.onSuccess { case listener ⇒ newServerChannel.setReadable(true) }
            (boundAddress, associationListenerPromise)
          case None ⇒ throw new NettyTransportException(s"Unknown local address type [${newServerChannel.getLocalAddress.getClass.getName}]")
        }
      } catch {
        case NonFatal(e) ⇒
          // Log the cause as well; otherwise the reason for the bind failure is lost here.
          log.error(e, "failed to bind to {}, shutting down Netty transport", address)
          try { shutdown() } catch { case NonFatal(_) ⇒ } // ignore possible exception during shutdown
          throw e
      }
    }
  }

  /**
   * Opens an outbound association to `remoteAddress`.
   *
   * @return a failed future with [[NettyTransportException]] if the transport has not been bound
   *         (including before `listen` has run). Resolution and connect failures are mapped to
   *         [[InvalidAssociationException]] or [[NettyTransportException]].
   */
  override def associate(remoteAddress: Address): Future[AssociationHandle] = {
    // Read the volatile once; it is null until `listen` has bound the server channel, and
    // dereferencing it would throw synchronously instead of returning a failed Future.
    val boundServerChannel = serverChannel
    if ((boundServerChannel eq null) || !boundServerChannel.isBound) Future.failed(new NettyTransportException("Transport is not bound"))
    else {
      val bootstrap: ClientBootstrap = outboundBootstrap(remoteAddress)

      (for {
        socketAddress ← addressToSocketAddress(remoteAddress)
        readyChannel ← NettyFutureBridge(bootstrap.connect(socketAddress)) map {
          channel ⇒
            if (EnableSsl)
              blocking {
                channel.getPipeline.get(classOf[SslHandler]).handshake().awaitUninterruptibly()
              }
            if (!isDatagram) channel.setReadable(false)
            channel
        }
        handle ← if (isDatagram)
          Future {
            readyChannel.getRemoteAddress match {
              case addr: InetSocketAddress ⇒
                val handle = new UdpAssociationHandle(localAddress, remoteAddress, readyChannel, NettyTransport.this)
                handle.readHandlerPromise.future.onSuccess {
                  case listener ⇒ udpConnectionTable.put(addr, listener)
                }
                handle
              case unknown ⇒ throw new NettyTransportException(s"Unknown outbound remote address type [${unknown.getClass.getName}]")
            }
          }
        else
          readyChannel.getPipeline.get(classOf[ClientHandler]).statusFuture
      } yield handle) recover {
        case c: CancellationException ⇒ throw new NettyTransportException("Connection was cancelled") with NoStackTrace
        case u @ (_: UnknownHostException | _: SecurityException | _: ConnectException) ⇒ throw new InvalidAssociationException(u.getMessage, u.getCause)
        case NonFatal(t) ⇒ throw new NettyTransportException(t.getMessage, t.getCause) with NoStackTrace
      }
    }
  }

  /**
   * Unbinds, flushes, disconnects and closes all tracked channels, in that order. It then releases
   * the channel factories.
   *
   * @return `true` only if every channel-group step succeeded
   */
  override def shutdown(): Future[Boolean] = {
    def always(c: ChannelGroupFuture) = NettyFutureBridge(c).map(_ ⇒ true) recover { case _ ⇒ false }
    for {
      // Force flush by trying to write an empty buffer and wait for success
      unbindStatus ← always(channelGroup.unbind())
      lastWriteStatus ← always(channelGroup.write(ChannelBuffers.buffer(0)))
      disconnectStatus ← always(channelGroup.disconnect())
      closeStatus ← always(channelGroup.close())
    } yield {
      // Release the selectors, but don't try to kill the dispatcher
      if (UseDispatcherForIo.isDefined) {
        clientChannelFactory.shutdown()
        serverChannelFactory.shutdown()
      } else {
        clientChannelFactory.releaseExternalResources()
        serverChannelFactory.releaseExternalResources()
      }
      lastWriteStatus && unbindStatus && disconnectStatus && closeStatus
    }

  }

}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy