/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.rabit.handler

import java.nio.{ByteBuffer, ByteOrder}

import akka.io.Tcp
import akka.actor._
import akka.util.ByteString
import ml.dmlc.xgboost4j.scala.rabit.util.{AssignedRank, RabitTrackerHelpers}

import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
import scala.util.Try

/**
  * Actor to handle socket communication from a worker node.
  * To handle fragmentation in the received data, this class acts as an FSM
  * (finite-state machine) to keep track of its internal states.
  *
  * @param host IP address of the remote worker
  * @param worldSize total number of workers
  * @param tracker the RabitTrackerHandler actor reference
  * @param connection the Akka IO Tcp connection actor for the worker's socket
  */
private[scala] class RabitWorkerHandler(host: String, worldSize: Int, tracker: ActorRef,
                                        connection: ActorRef)
  extends FSM[RabitWorkerHandler.State, RabitWorkerHandler.DataStruct]
    with ActorLogging with Stash {

  import RabitWorkerHandler._
  import RabitTrackerHelpers._
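  // RabitTrackerHelpers supplies the asNativeOrderByteBuffer and getString extension
  // methods used below to decode the ByteStrings and buffers received from the worker.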

  private[this] var rank: Int = 0
  private[this] var port: Int = 0

  // indicates whether the connection is transient (e.g. handling only a "print" or "shutdown" command)
  private[this] var transient: Boolean = false
  private[this] var peerClosed: Boolean = false

  // number of workers pending acceptance of current worker
  private[this] var awaitingAcceptance: Int = 0
  private[this] var neighboringWorkers = Set.empty[Int]

  // TODO: use a single memory allocation to host all buffers,
  // including the transient ones for writing.
  private[this] val readBuffer = ByteBuffer.allocate(4096)
    .order(ByteOrder.nativeOrder())
  // if the received message is longer than needed, stash the spilled-over part
  // in this buffer and re-send it to self when a state transition occurs.
  private[this] val spillOverBuffer = ByteBuffer.allocate(4096)
    .order(ByteOrder.nativeOrder())
  // once setup is complete, peer handlers must be notified so that they can
  // decrement their awaiting-connection counters.
  private[this] var pendingAcknowledgement: Option[AcknowledgeAcceptance] = None

  private def resetBuffers(): Unit = {
    readBuffer.clear()
    if (spillOverBuffer.position() > 0) {
      spillOverBuffer.flip()
      self ! Tcp.Received(ByteString.fromByteBuffer(spillOverBuffer))
      spillOverBuffer.clear()
    }
  }

  private def stashSpillOver(buf: ByteBuffer): Unit = {
    if (buf.remaining() > 0) spillOverBuffer.put(buf)
  }

  def getNeighboringWorkers: Set[Int] = neighboringWorkers

  def decodeCommand(buffer: ByteBuffer): TrackerCommand = {
    val readBuffer = buffer.duplicate().order(ByteOrder.nativeOrder())
    readBuffer.flip()

    val rank = readBuffer.getInt()
    val worldSize = readBuffer.getInt()
    val jobId = readBuffer.getString

    val command = readBuffer.getString
    val trackerCommand = command match {
      case "start" => WorkerStart(rank, worldSize, jobId)
      case "shutdown" =>
        transient = true
        WorkerShutdown(rank, worldSize, jobId)
      case "recover" =>
        require(rank >= 0, "Invalid rank for recovering worker.")
        WorkerRecover(rank, worldSize, jobId)
      case "print" =>
        transient = true
        WorkerTrackerPrint(rank, worldSize, jobId, readBuffer.getString)
    }

    stashSpillOver(readBuffer)
    trackerCommand
  }

  startWith(AwaitingHandshake, DataStruct())
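
  // Overall flow: AwaitingHandshake -> AwaitingCommand -> BuildingLinkMap -> AwaitingErrorCount
  // -> AwaitingPortNumber -> SetupComplete. AwaitingErrorCount loops back to BuildingLinkMap
  // when the worker reports connection errors (see the state definitions in the companion object).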

  when(AwaitingHandshake) {
    case Event(Tcp.Received(magic), _) =>
      assert(magic.length == 4)
      val purportedMagic = magic.asNativeOrderByteBuffer.getInt
      assert(purportedMagic == MAGIC_NUMBER, s"invalid magic number $purportedMagic from $host")

      // echo back the magic number
      connection ! Tcp.Write(magic)
      goto(AwaitingCommand) using StructTrackerCommand
  }

  when(AwaitingCommand) {
    case Event(Tcp.Received(bytes), validator) =>
      bytes.asByteBuffers.foreach { buf => readBuffer.put(buf) }
      if (validator.verify(readBuffer)) {
        Try(decodeCommand(readBuffer)) match {
          case scala.util.Success(decodedCommand) =>
            tracker ! decodedCommand
          case scala.util.Failure(th: java.nio.BufferUnderflowException) =>
            // BufferUnderflowException would occur if the message to print has not arrived yet.
            // Do nothing, wait for next Tcp.Received event
          case scala.util.Failure(th: Throwable) => throw th
        }
      }

      stay
    // when a rank has been assigned to the worker, send the encoded rank information
    // back to the worker over the Tcp socket.
    case Event(aRank @ AssignedRank(assignedRank, neighbors, ring, parent), _) =>
      log.debug(s"Assigned rank [$assignedRank] for $host, T: $neighbors, R: $ring, P: $parent")

      rank = assignedRank
      // ranks from the ring
      val ringRanks = List(
        // ringPrev
        if (ring._1 != -1 && ring._1 != rank) ring._1 else -1,
        // ringNext
        if (ring._2 != -1 && ring._2 != rank) ring._2 else -1
      )

      // update the set of all workers linked to the current worker.
      neighboringWorkers = neighbors.toSet ++ ringRanks.filterNot(_ == -1).toSet

      connection ! Tcp.Write(ByteString.fromByteBuffer(aRank.toByteBuffer(worldSize)))
      // to prevent reading before state transition
      connection ! Tcp.SuspendReading
      goto(BuildingLinkMap) using StructNodes
  }

  when(BuildingLinkMap) {
    case Event(Tcp.Received(bytes), validator) =>
      bytes.asByteBuffers.foreach { buf =>
        readBuffer.put(buf)
      }

      if (validator.verify(readBuffer)) {
        readBuffer.flip()
        // for a freshly started worker, numConnected should be 0.
        val numConnected = readBuffer.getInt()
        val toConnectSet = neighboringWorkers.diff(
          (0 until numConnected).map { index => readBuffer.getInt() }.toSet)

        // check which workers are currently awaiting connections
        tracker ! RequestAwaitConnWorkers(rank, toConnectSet)
      }
      stay

    // got a Future from the tracker (resolver) describing the workers that are
    // currently awaiting connections (particularly from this node).
    case Event(future: Future[_], _) =>
      // blocks execution until all dependencies for the current worker are resolved.
      Await.result(future, 1 minute).asInstanceOf[AwaitingConnections] match {
        // numNotReachable is the number of workers that currently
        // cannot be connected to (pending connection or setup). Instead, this worker will AWAIT
        // connections from those currently non-reachable nodes in the future.
        case AwaitingConnections(waitConnNodes, numNotReachable) =>
          log.debug(s"Rank $rank needs to connect to: $waitConnNodes, # bad: $numNotReachable")
          val buf = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder())
          buf.putInt(waitConnNodes.size).putInt(numNotReachable)
          buf.flip()

          // cache this message until the final state (SetupComplete)
          pendingAcknowledgement = Some(AcknowledgeAcceptance(
            waitConnNodes, numNotReachable))

          connection ! Tcp.Write(ByteString.fromByteBuffer(buf))
          if (waitConnNodes.isEmpty) {
            connection ! Tcp.SuspendReading
            goto(AwaitingErrorCount)
          }
          else {
            waitConnNodes.foreach { case (peerRank, peerRef) =>
              peerRef ! RequestWorkerHostPort
            }

            // a countdown for DivulgedWorkerHostPort messages: starts at size - 1 so that
            // the state transition fires once the last peer's host/port has been relayed.
            stay using DataStruct(Seq.empty[DataField], waitConnNodes.size - 1)
          }
      }

    case Event(DivulgedWorkerHostPort(peerRank, peerHost, peerPort), data) =>
      val hostBytes = peerHost.getBytes()
      val buffer = ByteBuffer.allocate(4 * 3 + hostBytes.length)
        .order(ByteOrder.nativeOrder())
      buffer.putInt(peerHost.length).put(hostBytes)
        .putInt(peerPort).putInt(peerRank)

      buffer.flip()
      connection ! Tcp.Write(ByteString.fromByteBuffer(buffer))

      if (data.counter == 0) {
        // to prevent reading before state transition
        connection ! Tcp.SuspendReading
        goto(AwaitingErrorCount)
      }
      else {
        stay using data.decrement()
      }
  }

  when(AwaitingErrorCount) {
    case Event(Tcp.Received(numErrors), _) =>
      val buf = numErrors.asNativeOrderByteBuffer

      buf.getInt match {
        case 0 =>
          stashSpillOver(buf)
          goto(AwaitingPortNumber)
        case _ =>
          stashSpillOver(buf)
          goto(BuildingLinkMap) using StructNodes
      }
  }

  when(AwaitingPortNumber) {
    case Event(Tcp.Received(assignedPort), _) =>
      assert(assignedPort.length == 4)
      port = assignedPort.asNativeOrderByteBuffer.getInt
      log.debug(s"Rank $rank listening @ $host:$port")
      // wait until the worker closes connection.
      if (peerClosed) goto(SetupComplete) else stay

    case Event(Tcp.PeerClosed, _) =>
      peerClosed = true
      if (port == 0) stay else goto(SetupComplete)
  }

  when(SetupComplete) {
    case Event(ReduceWaitCount(count: Int), _) =>
      awaitingAcceptance -= count
      // check peerClosed to avoid prematurely stopping this actor (which sends RST to worker)
      if (awaitingAcceptance == 0 && peerClosed) {
        tracker ! DropFromWaitingList(rank)
        // no longer needed.
        context.stop(self)
      }
      stay

    case Event(AcknowledgeAcceptance(peers, numBad), _) =>
      awaitingAcceptance = numBad
      tracker ! WorkerStarted(host, rank, awaitingAcceptance)
      peers.values.foreach { peer =>
        peer ! ReduceWaitCount(1)
      }

      if (awaitingAcceptance == 0 && peerClosed) self ! PoisonPill

      stay

    // can only divulge the complete host and port information once this worker
    // is declared fully connected (otherwise the port information is still missing).
    case Event(RequestWorkerHostPort, _) =>
      sender() ! DivulgedWorkerHostPort(rank, host, port)
      stay
  }

  onTransition {
    // reset buffer when state transitions as data becomes stale
    case _ -> SetupComplete =>
      connection ! Tcp.ResumeReading
      resetBuffers()
      if (pendingAcknowledgement.isDefined) {
        self ! pendingAcknowledgement.get
      }
    case _ =>
      connection ! Tcp.ResumeReading
      resetBuffers()
  }

  // default message handler
  whenUnhandled {
    case Event(Tcp.PeerClosed, _) =>
      peerClosed = true
      if (transient) context.stop(self)
      stay
  }
}

private[scala] object RabitWorkerHandler {
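  // Handshake value that each worker sends upon connecting; it is echoed back verbatim
  // (see the AwaitingHandshake state above).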
  val MAGIC_NUMBER = 0xff99

  // Finite states of this actor, which acts as an FSM.
  // The following states are defined in order as the FSM progresses.
  sealed trait State

  // [1] Initial state, awaiting worker to send magic number per protocol.
  case object AwaitingHandshake extends State
  // [2] Awaiting worker to send command (start/print/recover/shutdown etc.)
  case object AwaitingCommand extends State
  // [3] Brokers connections between workers per ring/tree/parent link map.
  case object BuildingLinkMap extends State
  // [4] A transient state in which the worker reports the number of errors in establishing
  // connections to other peer workers. If no errors, transition to next state.
  case object AwaitingErrorCount extends State
  // [5] Awaiting the worker to report its port number for accepting connections from peer workers.
  // This port number information is later forwarded to linked workers.
  case object AwaitingPortNumber extends State
  // [6] Final state after completing the setup with the connecting worker. At this stage, the
  // worker will have closed the Tcp connection. The actor remains alive to handle messages from
  // peer actors representing workers with pending setups.
  case object SetupComplete extends State

  // Building blocks describing the expected wire layout of a message; used by
  // DataStruct.verify to check whether a (possibly fragmented) packet is complete.
  sealed trait DataField
  case object IntField extends DataField
  // a string prefixed by an Int holding its byte length
  case object StringField extends DataField
  case object IntSeqField extends DataField

  object DataStruct {
    def apply(): DataStruct = DataStruct(Seq.empty[DataField], 0)
  }

  // Internal data pertaining to an individual state, used to verify the validity of packets
  // sent by workers.
  case class DataStruct(fields: Seq[DataField], counter: Int) {
    /**
      * Validate whether the provided buffer is complete (i.e., contains
      * all data fields specified for this DataStruct).
      *
      * @param buf a byte buffer containing received data.
      */
    def verify(buf: ByteBuffer): Boolean = {
      if (fields.isEmpty) return true

      val dupBuf = buf.duplicate().order(ByteOrder.nativeOrder())
      dupBuf.flip()

      Try(fields.foldLeft(true) {
        case (complete, field) =>
          val remBytes = dupBuf.remaining()
          // for each field, advance past its encoded bytes and check that the bytes
          // available before reading cover the field's full encoded size.
          complete && (remBytes > 0) && (remBytes >= (field match {
            case IntField =>
              dupBuf.position(dupBuf.position() + 4)
              4
            case StringField =>
              // an Int length followed by that many bytes
              val strLen = dupBuf.getInt
              dupBuf.position(dupBuf.position() + strLen)
              4 + strLen
            case IntSeqField =>
              // an Int count followed by that many Ints
              val seqLen = dupBuf.getInt
              dupBuf.position(dupBuf.position() + seqLen * 4)
              4 + seqLen * 4
          }))
      }).getOrElse(false)
    }

    def increment(): DataStruct = DataStruct(fields, counter + 1)
    def decrement(): DataStruct = DataStruct(fields, counter - 1)
  }

  // Expected payload in BuildingLinkMap: an Int count followed by that many Int ranks.
  val StructNodes = DataStruct(List(IntSeqField), 0)
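  // Expected payload in AwaitingCommand (native byte order), matching decodeCommand:
  // rank: Int, worldSize: Int, jobId: length-prefixed String, command: length-prefixed String
  // (a "print" command carries one additional length-prefixed message String).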
  val StructTrackerCommand = DataStruct(List(
    IntField, IntField, StringField, StringField
  ), 0)

  // ---- Messages between RabitTrackerHandler and RabitTrackerConnectionHandler ----

  // RabitWorkerHandler --> RabitTrackerHandler
  sealed trait RabitWorkerRequest
  // RabitWorkerHandler <-- RabitTrackerHandler
  sealed trait RabitWorkerResponse

  // Representations of decoded worker commands.
  abstract class TrackerCommand(val command: String) extends RabitWorkerRequest {
    def rank: Int
    def worldSize: Int
    def jobId: String

    def encode: ByteString = {
      val buf = ByteBuffer.allocate(4 * 4 + jobId.length + command.length)
        .order(ByteOrder.nativeOrder())

      buf.putInt(rank).putInt(worldSize).putInt(jobId.length).put(jobId.getBytes())
        .putInt(command.length).put(command.getBytes()).flip()

      ByteString.fromByteBuffer(buf)
    }
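
    // Example: WorkerStart(rank = 0, worldSize = 4, jobId = "0") encodes to
    // 4 + 4 + (4 + 1) + (4 + 5) = 22 bytes: 0, 4, 1, "0", 5, "start".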
  }

  case class WorkerStart(rank: Int, worldSize: Int, jobId: String)
    extends TrackerCommand("start")
  case class WorkerShutdown(rank: Int, worldSize: Int, jobId: String)
    extends TrackerCommand("shutdown")
  case class WorkerRecover(rank: Int, worldSize: Int, jobId: String)
    extends TrackerCommand("recover")
  case class WorkerTrackerPrint(rank: Int, worldSize: Int, jobId: String, msg: String)
    extends TrackerCommand("print") {

    override def encode: ByteString = {
      val buf = ByteBuffer.allocate(4 * 5 + jobId.length + command.length + msg.length)
        .order(ByteOrder.nativeOrder())

      buf.putInt(rank).putInt(worldSize).putInt(jobId.length).put(jobId.getBytes())
        .putInt(command.length).put(command.getBytes())
        .putInt(msg.length).put(msg.getBytes()).flip()

      ByteString.fromByteBuffer(buf)
    }
  }

  // Request to remove the worker of given rank from the list of workers awaiting peer connections.
  case class DropFromWaitingList(rank: Int) extends RabitWorkerRequest
  // Notify the tracker that the worker of given rank has finished setup and started.
  case class WorkerStarted(host: String, rank: Int, awaitingAcceptance: Int)
    extends RabitWorkerRequest
  // Request the set of workers to connect to, according to the LinkMap structure.
  case class RequestAwaitConnWorkers(rank: Int, toConnectSet: Set[Int])
    extends RabitWorkerRequest

  // Response from the tracker containing the workers currently awaiting connections,
  // plus the number of workers that are not yet reachable.
  case class AwaitingConnections(workers: Map[Int, ActorRef], numBad: Int)
    extends RabitWorkerResponse

  // ---- Messages between ConnectionHandler actors ----
  sealed trait IntraWorkerMessage

  // Notify neighboring workers to decrease the counter of awaiting workers by `count`.
  case class ReduceWaitCount(count: Int) extends IntraWorkerMessage
  // Request host and port information from peer ConnectionHandler actors (acting on behalf
  // of connecting workers). This message is brokered by RabitTrackerHandler.
  case object RequestWorkerHostPort extends IntraWorkerMessage
  // Response to the above request
  case class DivulgedWorkerHostPort(rank: Int, host: String, port: Int) extends IntraWorkerMessage
  // A reminder to send ReduceWaitCount messages once the actor is in state "SetupComplete".
  case class AcknowledgeAcceptance(peers: Map[Int, ActorRef], numBad: Int)
    extends IntraWorkerMessage

  // ---- End of message definitions ----
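
  /**
    * Props factory for this handler. A minimal wiring sketch (the names `remoteHost`,
    * `numWorkers`, `trackerHandler` and `conn` are illustrative; `conn` is assumed to be
    * the Akka IO Tcp connection actor for the accepted worker socket):
    *
    * {{{
    * val handler = context.actorOf(
    *   RabitWorkerHandler.props(remoteHost, numWorkers, trackerHandler, conn))
    * conn ! Tcp.Register(handler)
    * }}}
    */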

  def props(host: String, worldSize: Int, tracker: ActorRef, connection: ActorRef): Props = {
    Props(new RabitWorkerHandler(host, worldSize, tracker, connection))
  }
}



