All Downloads are FREE. Search and download functionalities are using the official Maven repository.

akka.remote.transport.ThrottlerTransportAdapter.scala Maven / Gradle / Ivy

The newest version!
/**
 * Copyright (C) 2009-2014 Typesafe Inc. 
 */
package akka.remote.transport

import akka.actor._
import akka.pattern.{ PromiseActorRef, ask, pipe }
import akka.remote.transport.ActorTransportAdapter.AssociateUnderlying
import akka.remote.transport.AkkaPduCodec.Associate
import akka.remote.transport.AssociationHandle.{ DisassociateInfo, ActorHandleEventListener, Disassociated, InboundPayload, HandleEventListener }
import akka.remote.transport.ThrottlerManager.{ Listener, Handle, ListenerAndMode, Checkin }
import akka.remote.transport.ThrottlerTransportAdapter._
import akka.remote.transport.Transport._
import akka.util.{ Timeout, ByteString }
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicReference
import scala.annotation.tailrec
import scala.collection.immutable.Queue
import scala.concurrent.{ Future, Promise }
import scala.concurrent.duration._
import scala.math.min
import scala.util.{ Success, Failure }
import scala.util.control.NonFatal
import akka.dispatch.sysmsg.{ Unwatch, Watch }
import akka.dispatch.{ UnboundedMessageQueueSemantics, RequiresMessageQueue }
import akka.remote.RARP

/**
 * [[TransportAdapterProvider]] that layers a [[ThrottlerTransportAdapter]] on top of the given transport.
 */
class ThrottlerProvider extends TransportAdapterProvider {

  override def create(wrappedTransport: Transport, system: ExtendedActorSystem): Transport = {
    val throttled = new ThrottlerTransportAdapter(wrappedTransport, system)
    throttled
  }

}

object ThrottlerTransportAdapter {
  val SchemeIdentifier = "trttl"
  val UniqueId = new java.util.concurrent.atomic.AtomicInteger(0)

  /**
   * Traffic direction(s) a throttle setting applies to.
   */
  sealed trait Direction {
    /** Whether this direction covers `other`. */
    def includes(other: Direction): Boolean
  }

  object Direction {

    @SerialVersionUID(1L)
    case object Send extends Direction {
      override def includes(other: Direction): Boolean = other match {
        case Send ⇒ true
        case _    ⇒ false
      }

      /**
       * Java API: get the singleton instance
       */
      def getInstance = this
    }

    @SerialVersionUID(1L)
    case object Receive extends Direction {
      override def includes(other: Direction): Boolean = other match {
        case Receive ⇒ true
        case _       ⇒ false
      }

      /**
       * Java API: get the singleton instance
       */
      def getInstance = this
    }

    @SerialVersionUID(1L)
    case object Both extends Direction {
      override def includes(other: Direction): Boolean = true

      /**
       * Java API: get the singleton instance
       */
      def getInstance = this
    }
  }

  /**
   * Management command that sets `mode` for traffic with `address` in the given `direction`.
   */
  @SerialVersionUID(1L)
  case class SetThrottle(address: Address, direction: Direction, mode: ThrottleMode)

  @SerialVersionUID(1L)
  case object SetThrottleAck {
    /**
     * Java API: get the singleton instance
     */
    def getInstance = this
  }

  /**
   * Strategy deciding whether a payload costing a given number of tokens may pass at a given time.
   */
  sealed trait ThrottleMode extends NoSerializationVerificationNeeded {
    /**
     * Tries to take `tokens` tokens at `nanoTimeOfSend` (a `System.nanoTime()` reading).
     *
     * @return the mode to use afterwards and whether the payload may pass
     */
    def tryConsumeTokens(nanoTimeOfSend: Long, tokens: Int): (ThrottleMode, Boolean)

    /** Estimated delay from `currentNanoTime` before a payload costing `tokens` should be retried. */
    def timeToAvailable(currentNanoTime: Long, tokens: Int): FiniteDuration
  }

  /**
   * Token bucket refilled at `tokensPerSecond` up to `capacity` tokens. Callers in this file charge `payload.length`
   * tokens per payload.
   *
   * A payload larger than `capacity` is let through as soon as at least one token is available. The overdraft is
   * recorded as a negative balance that has to be refilled before anything else passes.
   *
   * @param capacity           maximum number of tokens the bucket holds
   * @param tokensPerSecond    refill rate
   * @param nanoTimeOfLastSend `System.nanoTime()` of the last successful consumption; refill is measured from here
   * @param availableTokens    balance after the last consumption (negative after an oversized payload)
   */
  @SerialVersionUID(1L)
  case class TokenBucket(capacity: Int, tokensPerSecond: Double, nanoTimeOfLastSend: Long, availableTokens: Int)
    extends ThrottleMode {

    private def isAvailable(nanoTimeOfSend: Long, tokens: Int): Boolean =
      if (tokens > capacity) {
        // Allow messages larger than capacity through once any token is available; the overdraft is recorded as
        // negative tokens. Refilled tokens must count here, or a drained bucket rejects such messages forever.
        tokensAt(nanoTimeOfSend) > 0L
      } else min(tokensAt(nanoTimeOfSend), capacity.toLong) >= tokens

    override def tryConsumeTokens(nanoTimeOfSend: Long, tokens: Int): (ThrottleMode, Boolean) = {
      if (isAvailable(nanoTimeOfSend, tokens))
        (this.copy(
          nanoTimeOfLastSend = nanoTimeOfSend,
          availableTokens = min(tokensAt(nanoTimeOfSend) - tokens, capacity.toLong).toInt), true)
      else (this, false)
    }

    override def timeToAvailable(currentNanoTime: Long, tokens: Int): FiniteDuration = {
      // Same requirement as isAvailable: a single token for an oversized message, `tokens` otherwise. The current
      // balance is taken into account so the estimate is neither too long (positive balance) nor too short
      // (overdrawn bucket).
      val required: Long = if (tokens > capacity) 1L else tokens.toLong
      val needed = required - tokensAt(currentNanoTime)
      if (needed <= 0L) Duration.Zero else (needed / tokensPerSecond).seconds
    }

    // Balance at `nanoTime` before capping at `capacity`. Computed as Long because tokensGenerated can saturate at
    // Int.MaxValue after a long idle period, and an Int sum would then wrap around to a negative value.
    private def tokensAt(nanoTime: Long): Long =
      availableTokens.toLong + tokensGenerated(nanoTime)

    // Tokens refilled since nanoTimeOfLastSend, at millisecond resolution. Elapsed time is clamped at zero;
    // Double.toInt saturates instead of overflowing.
    private def tokensGenerated(nanoTimeOfSend: Long): Int = {
      val elapsedMillis = math.max(0L, TimeUnit.NANOSECONDS.toMillis(nanoTimeOfSend - nanoTimeOfLastSend))
      (elapsedMillis * tokensPerSecond / 1000.0).toInt
    }
  }

  /**
   * Mode that lets every payload through immediately.
   */
  @SerialVersionUID(1L)
  case object Unthrottled extends ThrottleMode {
    override def tryConsumeTokens(nanoTimeOfSend: Long, tokens: Int): (ThrottleMode, Boolean) = (this, true)
    override def timeToAvailable(currentNanoTime: Long, tokens: Int): FiniteDuration = Duration.Zero

    /**
     * Java API: get the singleton instance
     */
    def getInstance = this

  }

  /**
   * Mode that never lets a payload through.
   */
  @SerialVersionUID(1L)
  case object Blackhole extends ThrottleMode {
    override def tryConsumeTokens(nanoTimeOfSend: Long, tokens: Int): (ThrottleMode, Boolean) = (this, false)
    override def timeToAvailable(currentNanoTime: Long, tokens: Int): FiniteDuration = Duration.Zero

    /**
     * Java API: get the singleton instance
     */
    def getInstance = this
  }

  /**
   * Management Command to force disassociation of an address.
   */
  @SerialVersionUID(1L)
  case class ForceDisassociate(address: Address)

  /**
   * Management Command to force disassociation of an address with an explicit error.
   */
  @SerialVersionUID(1L)
  case class ForceDisassociateExplicitly(address: Address, reason: DisassociateInfo)

  @SerialVersionUID(1L)
  case object ForceDisassociateAck {
    /**
     * Java API: get the singleton instance
     */
    def getInstance = this
  }
}

/**
 * Transport adapter that throttles or black-holes traffic per remote address.
 *
 * `SetThrottle`, `ForceDisassociate` and `ForceDisassociateExplicitly` management commands are handled by the
 * adapter's manager actor. Every other command is forwarded to the wrapped transport.
 */
class ThrottlerTransportAdapter(_wrappedTransport: Transport, _system: ExtendedActorSystem) extends ActorTransportAdapter(_wrappedTransport, _system) {

  override protected def addedSchemeIdentifier = SchemeIdentifier

  // Payloads are passed to the wrapped handle unchanged.
  override protected def maximumOverhead = 0

  // Every adapter instance gets a distinct manager name via the shared UniqueId counter.
  protected def managerName: String = {
    val scheme = wrappedTransport.schemeIdentifier
    val id = UniqueId.getAndIncrement
    s"throttlermanager.${scheme}${id}"
  }

  protected def managerProps: Props = Props(classOf[ThrottlerManager], wrappedTransport)

  override def managementCommand(cmd: Any): Future[Boolean] = {
    import ActorTransportAdapter.AskTimeout

    // Asks the manager and maps the expected acknowledgement to `true`; any other reply fails the future.
    def confirmed(request: Any, expectedAck: Any): Future[Boolean] =
      (manager ? request) map { case reply if reply == expectedAck ⇒ true }

    cmd match {
      case _: SetThrottle                                         ⇒ confirmed(cmd, SetThrottleAck)
      case _: ForceDisassociate | _: ForceDisassociateExplicitly ⇒ confirmed(cmd, ForceDisassociateAck)
      case _                                                      ⇒ wrappedTransport.managementCommand(cmd)
    }
  }
}

/**
 * INTERNAL API
 *
 * Messages exchanged between [[ThrottlerManager]] and the [[ThrottledAssociation]] actors it creates.
 */
private[transport] object ThrottlerManager {
  /** Hands the exposed [[ThrottlerHandle]] to the throttler actor of a new inbound association. */
  case class Handle(handle: ThrottlerHandle) extends NoSerializationVerificationNeeded

  /** Upstream read handler registered for an inbound association. */
  case class Listener(listener: HandleEventListener) extends NoSerializationVerificationNeeded

  /** Sent by an inbound throttler actor once it has learned the origin address of the remote endpoint. */
  case class Checkin(origin: Address, handle: ThrottlerHandle) extends NoSerializationVerificationNeeded

  /** Upstream read handler plus initial inbound mode for an outbound association. */
  case class ListenerAndMode(listener: HandleEventListener, mode: ThrottleMode) extends NoSerializationVerificationNeeded

  /** Self-message carrying an established underlying outbound handle and the promise to complete with its wrapper. */
  case class AssociateResult(handle: AssociationHandle, statusPromise: Promise[AssociationHandle])
    extends NoSerializationVerificationNeeded
}

/**
 * INTERNAL API
 *
 * Manager actor of the throttler transport. It wraps every inbound and outbound association of `wrappedTransport`
 * in a [[ThrottlerHandle]] backed by a child [[ThrottledAssociation]] actor. It remembers the throttle mode set for
 * each remote address, compared without protocol and actor system name. Throttle and forced-disassociation commands
 * are applied to all handles known for that address.
 *
 * @param wrappedTransport the underlying transport whose associations are throttled
 */
private[transport] class ThrottlerManager(wrappedTransport: Transport) extends ActorTransportAdapterManager {

  import ThrottlerManager._
  import context.dispatcher

  // Last mode and direction set per naked address; also consulted for handles created later on.
  private var throttlingModes = Map[Address, (ThrottleMode, Direction)]()
  // Known handles per naked address. An entry is dropped once its throttler actor terminates, so the table does not
  // grow without bound over the lifetime of the transport.
  private var handleTable = List[(Address, ThrottlerHandle)]()

  // Ignores the protocol and the actor system name so modes match regardless of the transport scheme in use.
  private def nakedAddress(address: Address): Address = address.copy(protocol = "", system = "")

  override def ready: Receive = {
    case InboundAssociation(handle) ⇒
      val wrappedHandle = wrapHandle(handle, associationListener, inbound = true)
      wrappedHandle.throttlerActor ! Handle(wrappedHandle)
    case AssociateUnderlying(remoteAddress, statusPromise) ⇒
      wrappedTransport.associate(remoteAddress) onComplete {
        // Slight modification of pipe, only success is sent, failure is propagated to a separate future
        case Success(handle) ⇒ self ! AssociateResult(handle, statusPromise)
        case Failure(e)      ⇒ statusPromise.failure(e)
      }
    // Finished outbound association and got back the handle
    case AssociateResult(handle, statusPromise) ⇒
      val wrappedHandle = wrapHandle(handle, associationListener, inbound = false)
      val naked = nakedAddress(handle.remoteAddress)
      val inMode = getInboundMode(naked)
      wrappedHandle.outboundThrottleMode.set(getOutboundMode(naked))
      // Delivered to the throttler actor once the upstream layer registers its read handler on the wrapped handle.
      wrappedHandle.readHandlerPromise.future map { ListenerAndMode(_, inMode) } pipeTo wrappedHandle.throttlerActor
      handleTable ::= naked -> wrappedHandle
      statusPromise.success(wrappedHandle)
    case SetThrottle(address, direction, mode) ⇒
      val naked = nakedAddress(address)
      throttlingModes = throttlingModes.updated(naked, (mode, direction))
      val ok = Future.successful(SetThrottleAck)
      // Acknowledge only after every matching handle has applied the new mode.
      Future.sequence(handleTable map {
        case (`naked`, handle) ⇒ setMode(handle, mode, direction)
        case _                 ⇒ ok
      }).map(_ ⇒ SetThrottleAck) pipeTo sender()
    case ForceDisassociate(address) ⇒
      val naked = nakedAddress(address)
      handleTable foreach {
        case (`naked`, handle) ⇒ handle.disassociate()
        case _                 ⇒
      }
      sender() ! ForceDisassociateAck
    case ForceDisassociateExplicitly(address, reason) ⇒
      val naked = nakedAddress(address)
      handleTable foreach {
        case (`naked`, handle) ⇒ handle.disassociateWithFailure(reason)
        case _                 ⇒
      }
      sender() ! ForceDisassociateAck

    // An inbound throttler learned its origin address: register the handle and push the current mode to it.
    case Checkin(origin, handle) ⇒
      val naked: Address = nakedAddress(origin)
      handleTable ::= naked -> handle
      setMode(naked, handle)

    // A throttler actor watched in wrapHandle has stopped; forget its handle.
    case Terminated(throttler) ⇒
      handleTable = handleTable filterNot { case (_, handle) ⇒ handle.throttlerActor == throttler }

  }

  private def getInboundMode(nakedAddress: Address): ThrottleMode = {
    throttlingModes.get(nakedAddress) match {
      case Some((mode, direction)) if direction.includes(Direction.Receive) ⇒ mode
      case _ ⇒ Unthrottled
    }
  }

  private def getOutboundMode(nakedAddress: Address): ThrottleMode = {
    throttlingModes.get(nakedAddress) match {
      case Some((mode, direction)) if direction.includes(Direction.Send) ⇒ mode
      case _ ⇒ Unthrottled
    }
  }

  // Applies the mode stored for `nakedAddress` (Unthrottled in both directions when none was set) to `handle`.
  private def setMode(nakedAddress: Address, handle: ThrottlerHandle): Future[SetThrottleAck.type] = {
    throttlingModes.get(nakedAddress) match {
      case Some((mode, direction)) ⇒ setMode(handle, mode, direction)
      case None                    ⇒ setMode(handle, Unthrottled, Direction.Both)
    }
  }

  // Outbound mode is set directly on the handle; inbound mode is sent to the throttler actor and acknowledged.
  private def setMode(handle: ThrottlerHandle, mode: ThrottleMode, direction: Direction): Future[SetThrottleAck.type] = {
    if (direction.includes(Direction.Send))
      handle.outboundThrottleMode.set(mode)
    if (direction.includes(Direction.Receive))
      askModeWithDeathCompletion(handle.throttlerActor, mode)(ActorTransportAdapter.AskTimeout)
    else
      Future.successful(SetThrottleAck)
  }

  // Sends `mode` to `target` and completes with SetThrottleAck when it acknowledges or terminates. A temporary
  // PromiseActorRef watches the target so a dead throttler cannot stall SetThrottle. It fails if the ask times out.
  private def askModeWithDeathCompletion(target: ActorRef, mode: ThrottleMode)(implicit timeout: Timeout): Future[SetThrottleAck.type] = {
    if (target.isTerminated) Future successful SetThrottleAck
    else {
      val internalTarget = target.asInstanceOf[InternalActorRef]
      val ref = PromiseActorRef(internalTarget.provider, timeout, target.toString)
      internalTarget.sendSystemMessage(Watch(internalTarget, ref))
      target.tell(mode, ref)
      ref.result.future.transform({
        case Terminated(t) if t.path == target.path ⇒ SetThrottleAck
        case SetThrottleAck                         ⇒ { internalTarget.sendSystemMessage(Unwatch(target, ref)); SetThrottleAck }
      }, t ⇒ { internalTarget.sendSystemMessage(Unwatch(target, ref)); t })(ref.internalCallingThreadExecutionContext)
    }
  }

  // Creates the child throttler actor for `originalHandle` and the ThrottlerHandle exposed upstream.
  private def wrapHandle(originalHandle: AssociationHandle, listener: AssociationEventListener, inbound: Boolean): ThrottlerHandle = {
    val managerRef = self
    val throttler = context.actorOf(
      RARP(context.system).configureDispatcher(
        Props(classOf[ThrottledAssociation], managerRef, listener, originalHandle, inbound)).withDeploy(Deploy.local),
      "throttler" + nextId())
    // Watched so that its handle can be removed from handleTable once it stops (handled by `Terminated` in `ready`).
    context.watch(throttler)
    ThrottlerHandle(originalHandle, throttler)
  }
}

/**
 * INTERNAL API
 *
 * States, state data and private messages of the [[ThrottledAssociation]] FSM.
 */
private[transport] object ThrottledAssociation {
  private final val DequeueTimerName = "dequeue"

  /** Tick asking the throttler to deliver the head of its buffer. */
  case object Dequeue

  /** Request to fail the association with the given reason and stop the throttler. */
  case class FailWith(reason: DisassociateInfo)

  sealed trait ThrottlerState

  /*
   * Inbound associations move through
   *   WaitExposedHandle -> WaitOrigin -> WaitMode -> WaitUpstreamListener -> Throttling
   * Outbound associations move through
   *   WaitModeAndUpstreamListener -> Throttling
   */

  /** Inbound: waiting for the ThrottlerHandle that pairs this actor with the handle exposed upstream. */
  case object WaitExposedHandle extends ThrottlerState

  /** Inbound: waiting for the ASSOCIATE message carrying the origin address of the remote endpoint. */
  case object WaitOrigin extends ThrottlerState

  /** Inbound: origin known and Checkin sent to the manager; waiting for the ThrottleMode for that address. */
  case object WaitMode extends ThrottlerState

  /** Inbound: everything else known; waiting for the upstream listener before messages can be forwarded. */
  case object WaitUpstreamListener extends ThrottlerState

  /** Outbound: waiting for the upstream listener together with the ThrottleMode. */
  case object WaitModeAndUpstreamListener extends ThrottlerState

  /** Fully initialized: payloads are forwarded according to the inbound mode. */
  case object Throttling extends ThrottlerState

  sealed trait ThrottlerData

  /** No exposed handle known yet. */
  case object Uninitialized extends ThrottlerData

  /** The handle that was exposed upstream for an inbound association. */
  case class ExposedHandle(handle: ThrottlerHandle) extends ThrottlerData
}

/**
 * INTERNAL API
 *
 * FSM actor behind a [[ThrottlerHandle]]. It applies the inbound [[ThrottleMode]] to payloads read from
 * `originalHandle` before handing them to the upstream [[HandleEventListener]]. Payloads are buffered in
 * `throttledMessages` until the upstream listener is known, and afterwards whenever the mode does not admit them
 * yet. Stopping this actor disassociates `originalHandle`.
 *
 * @param manager            the manager that created this actor; receives `Checkin` for inbound associations
 * @param associationHandler notified of the wrapped inbound association
 * @param originalHandle     handle of the underlying transport
 * @param inbound            `true` for an inbound association, `false` for an outbound one
 */
private[transport] class ThrottledAssociation(
  val manager: ActorRef,
  val associationHandler: AssociationEventListener,
  val originalHandle: AssociationHandle,
  val inbound: Boolean)
  extends Actor with LoggingFSM[ThrottledAssociation.ThrottlerState, ThrottledAssociation.ThrottlerData]
  with RequiresMessageQueue[UnboundedMessageQueueSemantics] {
  import ThrottledAssociation._
  import context.dispatcher

  // Current inbound mode; null until the first mode is received.
  var inboundThrottleMode: ThrottleMode = _
  // Payloads not yet delivered upstream, in arrival order.
  var throttledMessages = Queue.empty[ByteString]
  // Upstream read handler; null until the Throttling state is entered.
  var upstreamListener: HandleEventListener = _

  // Tear down the underlying association, then let FSM run its own shutdown (termination handling and cancellation
  // of the dequeue timer), which is otherwise skipped when the actor is stopped by PoisonPill.
  override def postStop(): Unit =
    try originalHandle.disassociate()
    finally super.postStop()

  if (inbound) startWith(WaitExposedHandle, Uninitialized) else {
    originalHandle.readHandlerPromise.success(ActorHandleEventListener(self))
    startWith(WaitModeAndUpstreamListener, Uninitialized)
  }

  when(WaitExposedHandle) {
    case Event(Handle(handle), Uninitialized) ⇒
      // register to downstream layer and wait for origin
      originalHandle.readHandlerPromise.success(ActorHandleEventListener(self))
      goto(WaitOrigin) using ExposedHandle(handle)
  }

  when(WaitOrigin) {
    case Event(InboundPayload(p), ExposedHandle(exposedHandle)) ⇒
      throttledMessages = throttledMessages enqueue p
      peekOrigin(p) match {
        case Some(origin) ⇒
          manager ! Checkin(origin, exposedHandle)
          goto(WaitMode)
        case None ⇒ stay()
      }
  }

  when(WaitMode) {
    case Event(InboundPayload(p), _) ⇒
      throttledMessages = throttledMessages enqueue p
      stay()
    case Event(mode: ThrottleMode, ExposedHandle(exposedHandle)) ⇒
      inboundThrottleMode = mode
      // The mode is always acknowledged, whether the association is dropped or exposed upstream.
      try if (mode == Blackhole) {
        throttledMessages = Queue.empty[ByteString]
        exposedHandle.disassociate()
        stop()
      } else {
        associationHandler notify InboundAssociation(exposedHandle)
        exposedHandle.readHandlerPromise.future.map(Listener(_)) pipeTo self
        goto(WaitUpstreamListener)
      } finally sender() ! SetThrottleAck
  }

  when(WaitUpstreamListener) {
    case Event(InboundPayload(p), _) ⇒
      throttledMessages = throttledMessages enqueue p
      stay()
    case Event(Listener(listener), _) ⇒
      upstreamListener = listener
      self ! Dequeue
      goto(Throttling)
  }

  when(WaitModeAndUpstreamListener) {
    // NOTE(review): a ThrottleMode received before this message is stored by whenUnhandled, but is overwritten here
    // by the mode the manager captured when it created the handle. Confirm whether that is intended.
    case Event(ListenerAndMode(listener: HandleEventListener, mode: ThrottleMode), _) ⇒
      upstreamListener = listener
      inboundThrottleMode = mode
      self ! Dequeue
      goto(Throttling)
    case Event(InboundPayload(p), _) ⇒
      throttledMessages = throttledMessages enqueue p
      stay()
  }

  when(Throttling) {
    case Event(mode: ThrottleMode, _) ⇒
      inboundThrottleMode = mode
      if (mode == Blackhole) throttledMessages = Queue.empty[ByteString]
      // Re-plan the next delivery under the new mode.
      cancelTimer(DequeueTimerName)
      if (throttledMessages.nonEmpty)
        scheduleDequeue(inboundThrottleMode.timeToAvailable(System.nanoTime(), throttledMessages.head.length))
      sender() ! SetThrottleAck
      stay()
    case Event(InboundPayload(p), _) ⇒
      forwardOrDelay(p)
      stay()

    case Event(Dequeue, _) ⇒
      if (throttledMessages.nonEmpty) {
        val (payload, newqueue) = throttledMessages.dequeue
        upstreamListener notify InboundPayload(payload)
        throttledMessages = newqueue
        inboundThrottleMode = inboundThrottleMode.tryConsumeTokens(System.nanoTime(), payload.length)._1
        if (throttledMessages.nonEmpty)
          scheduleDequeue(inboundThrottleMode.timeToAvailable(System.nanoTime(), throttledMessages.head.length))
      }
      stay()

  }

  whenUnhandled {
    // we should always set the throttling mode
    case Event(mode: ThrottleMode, _) ⇒
      inboundThrottleMode = mode
      sender() ! SetThrottleAck
      stay()
    case Event(Disassociated(info), _) ⇒
      stop() // not notifying the upstream handler is intentional: we are relying on heartbeating
    case Event(FailWith(reason), _) ⇒
      // A forced failure can arrive before the upstream listener is known (the manager already tracks this handle
      // in WaitMode / WaitModeAndUpstreamListener); there is nobody to notify yet, so only stop.
      if (upstreamListener ne null) upstreamListener notify Disassociated(reason)
      stop()
  }

  // This method captures ASSOCIATE packets and extracts the origin address
  private def peekOrigin(b: ByteString): Option[Address] = {
    try {
      AkkaPduProtobufCodec.decodePdu(b) match {
        case Associate(info) ⇒ Some(info.origin)
        case _               ⇒ None
      }
    } catch {
      // This layer should not care about malformed packets. Also, this also useful for testing, because
      // arbitrary payload could be passed in
      case NonFatal(e) ⇒ None
    }
  }

  /**
   * Delivers `payload` upstream right away when nothing is buffered and the mode admits it; otherwise buffers it
   * (scheduling a dequeue if it is the first buffered payload). In [[Blackhole]] mode the payload is dropped.
   */
  def forwardOrDelay(payload: ByteString): Unit = {
    if (inboundThrottleMode == Blackhole) {
      // Do nothing
    } else {
      if (throttledMessages.isEmpty) {
        val tokens = payload.length
        val (newbucket, success) = inboundThrottleMode.tryConsumeTokens(System.nanoTime(), tokens)
        if (success) {
          inboundThrottleMode = newbucket
          upstreamListener notify InboundPayload(payload)
        } else {
          throttledMessages = throttledMessages.enqueue(payload)
          scheduleDequeue(inboundThrottleMode.timeToAvailable(System.nanoTime(), tokens))
        }
      } else {
        // Preserve ordering: a dequeue is already pending for the buffered payloads.
        throttledMessages = throttledMessages.enqueue(payload)
      }
    }
  }

  /** Sends Dequeue now for a non-positive delay, after `delay` otherwise; never in [[Blackhole]] mode. */
  def scheduleDequeue(delay: FiniteDuration): Unit = inboundThrottleMode match {
    case Blackhole                   ⇒ // Do nothing
    case _ if delay <= Duration.Zero ⇒ self ! Dequeue
    case _                           ⇒ setTimer(DequeueTimerName, Dequeue, delay, repeat = false)
  }

}

/**
 * INTERNAL API
 *
 * Handle exposed by the throttler transport. Outbound writes are checked synchronously against
 * `outboundThrottleMode`; inbound throttling and teardown are delegated to `throttlerActor`.
 */
private[transport] case class ThrottlerHandle(_wrappedHandle: AssociationHandle, throttlerActor: ActorRef)
  extends AbstractTransportAdapterHandle(_wrappedHandle, SchemeIdentifier) {

  // Replaced via compare-and-set so that concurrent writers share a single token bucket.
  private[transport] val outboundThrottleMode = new AtomicReference[ThrottleMode](Unthrottled)

  override val readHandlerPromise: Promise[HandleEventListener] = Promise()

  /**
   * Writes `payload` to the wrapped handle if the outbound mode admits it.
   *
   * @return `true` when the payload was written or silently dropped in [[Blackhole]] mode; `false` when the
   *         mode rejected it or the wrapped handle refused the write
   */
  override def write(payload: ByteString): Boolean = {
    val cost = payload.length

    // Retries on contention and gives up as soon as the observed mode refuses the tokens.
    @tailrec def reserve(observed: ThrottleMode): Boolean = {
      val (next, admitted) = observed.tryConsumeTokens(System.nanoTime(), cost)
      if (!admitted) false
      else if (outboundThrottleMode.compareAndSet(observed, next)) true
      else reserve(outboundThrottleMode.get())
    }

    // FIXME: this depletes the token bucket even when no write happened!! See #2825
    if (outboundThrottleMode.get == Blackhole) true
    else reserve(outboundThrottleMode.get()) && wrappedHandle.write(payload)
  }

  override def disassociate(): Unit = throttlerActor ! PoisonPill

  def disassociateWithFailure(reason: DisassociateInfo): Unit =
    throttlerActor ! ThrottledAssociation.FailWith(reason)

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy