All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ml.dmlc.xgboost4j.scala.rabit.handler.RabitTrackerHandler.scala Maven / Gradle / Ivy

The newest version!
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.rabit.handler

import java.net.InetSocketAddress
import java.util.UUID

import scala.concurrent.duration._
import scala.collection.mutable
import scala.concurrent.{Promise, TimeoutException}
import akka.io.{IO, Tcp}
import akka.actor._
import ml.dmlc.xgboost4j.java.XGBoostError
import ml.dmlc.xgboost4j.scala.rabit.util.{AssignedRank, LinkMap}

import scala.util.{Failure, Random, Success, Try}

/** The Akka actor for handling and coordinating Rabit worker connections.
  * This is the main actor for handling socket connections, interacting with the synchronous
  * tracker interface, and resolving tree/ring/parent dependencies between workers.
  *
  * @param numWorkers Number of workers to track.
  */
private[scala] class RabitTrackerHandler(numWorkers: Int)
  extends Actor with ActorLogging {

  import context.system
  import RabitWorkerHandler._
  import RabitTrackerHandler._

  // Completed with the worker environment variables once the TCP listener is bound.
  private[this] val promisedWorkerEnvs = Promise[Map[String, String]]()
  // Completed with the number of shut-down workers on success, or failed on
  // interruption, timeout, or an invalid start signal.
  private[this] val promisedShutdownWorkers = Promise[Int]()
  private[this] val tcpManager = IO(Tcp)

  // resolves worker connection dependency.
  val resolver = context.actorOf(Props(classOf[WorkerDependencyResolver], self), "Resolver")

  // workers that have sent "shutdown" signal
  private[this] val shutdownWorkers = mutable.Set.empty[Int]
  // job id -> rank previously handed out for that job id
  private[this] val jobToRankMap = mutable.HashMap.empty[String, Int]
  // worker handler actor -> host name of the remote peer
  private[this] val actorRefToHost = mutable.HashMap.empty[ActorRef, String]
  // ranks that have not been handed out yet, in ascending order
  private[this] val ranksToAssign = mutable.ListBuffer(0 until numWorkers: _*)
  private[this] var maxPortTrials = 0
  private[this] var workerConnectionTimeout: Duration = Duration.Inf
  private[this] var portTrials = 0
  // ranks of workers that have reported WorkerStarted
  private[this] val startedWorkers = mutable.Set.empty[Int]
  // largest valid TCP port; InetSocketAddress rejects anything above it.
  private[this] val maxPortNumber = 65535

  val linkMap = new LinkMap(numWorkers)

  /**
    * Decides the rank of a worker from the rank it reported and its job id.
    *
    * @param rank  rank reported by the worker; a negative value means "not assigned yet".
    * @param jobId job id reported by the worker; the literal "NULL" means "no job id".
    * @return the reported rank when it is non-negative, otherwise the rank previously recorded
    *         for `jobId` (if any); None when no rank can be derived.
    */
  def decideRank(rank: Int, jobId: String = "NULL"): Option[Int] = {
    if (rank >= 0) {
      Some(rank)
    } else if (jobId == "NULL") {
      None
    } else {
      jobToRankMap.get(jobId)
    }
  }

  /**
    * Handler for all Akka Tcp connection/binding events. Read/write over the socket is handled
    * by the RabitWorkerHandler.
    *
    * @param event Generic Tcp.Event
    */
  private def handleTcpEvents(event: Tcp.Event): Unit = event match {
    case Tcp.Bound(local) =>
      // expect all workers to connect within timeout
      log.info(s"Tracker listening @ ${local.getAddress.getHostAddress}:${local.getPort}")
      log.info(s"Worker connection timeout is $workerConnectionTimeout.")

      context.setReceiveTimeout(workerConnectionTimeout)
      promisedWorkerEnvs.success(Map(
        "DMLC_TRACKER_URI" -> local.getAddress.getHostAddress,
        "DMLC_TRACKER_PORT" -> local.getPort.toString,
        // not required because the world size will be communicated to the
        // worker node after the rank is assigned.
        "rabit_world_size" -> numWorkers.toString
      ))

    case Tcp.CommandFailed(cmd: Tcp.Bind) =>
      // retry on the next port, as long as trials remain and the port stays in range.
      val nextPort = cmd.localAddress.getPort + 1
      if (portTrials < maxPortTrials && nextPort <= maxPortNumber) {
        portTrials += 1
        tcpManager ! Tcp.Bind(self,
          new InetSocketAddress(cmd.localAddress.getAddress, nextPort),
          backlog = 256)
      } else {
        // The bound promise is intentionally left pending (as before); callers waiting on it
        // observe their own timeout. Logging makes the cause visible.
        log.error(s"Failed to bind tracker to ${cmd.localAddress} after $portTrials " +
          "retries; giving up.")
      }

    case Tcp.Connected(remote, local) =>
      log.debug(s"Incoming connection from worker @ ${remote.getAddress.getHostAddress}")
      // the sender of Tcp.Connected is the connection actor; hand it to a dedicated
      // worker handler, which is then registered as the listener for that connection.
      val connection = sender()
      val workerHandler = context.actorOf(RabitWorkerHandler.props(
        remote.getAddress.getHostAddress, numWorkers, self, connection
      ), s"ConnectionHandler-${UUID.randomUUID().toString}")
      connection ! Tcp.Register(workerHandler, keepOpenOnPeerClosed = true)

      actorRefToHost.put(workerHandler, remote.getAddress.getHostName)

    case other =>
      // Any other Tcp event is not acted upon; without this case it would raise a MatchError
      // inside receive.
      log.debug(s"Ignoring unexpected Tcp event: $other")
  }

  /**
    * Handles external tracker control messages sent by RabitTracker (usually in ask patterns)
    * to interact with the tracker interface.
    *
    * @param trackerMsg control messages sent by RabitTracker class.
    */
  private def handleTrackerControlMessage(trackerMsg: TrackerControlMessage): Unit =
    trackerMsg match {

    case msg: StartTracker =>
      maxPortTrials = msg.maxPortTrials
      workerConnectionTimeout = msg.connectionTimeout

      // if the port number is missing, try binding to a random ephemeral port.
      if (msg.addr.getPort == 0) {
        tcpManager ! Tcp.Bind(self,
          new InetSocketAddress(msg.addr.getAddress, new Random().nextInt(61000 - 32768) + 32768),
          backlog = 256)
      } else {
        tcpManager ! Tcp.Bind(self, msg.addr, backlog = 256)
      }
      sender() ! true

    case RequestBoundFuture =>
      sender() ! promisedWorkerEnvs.future

    case RequestCompletionFuture =>
      sender() ! promisedShutdownWorkers.future

    case InterruptTracker(e) =>
      log.error(e, "Uncaught exception thrown by worker.")
      // make sure that waitFor() does not hang indefinitely.
      // tryFailure: the promise may already be completed, and failure() would then throw.
      promisedShutdownWorkers.tryFailure(e)
      context.stop(self)
  }

  /**
    * Handles messages sent by child actors representing connecting Rabit workers, by brokering
    * messages to the dependency resolver, and processing worker commands.
    *
    * @param workerMsg Message sent by RabitWorkerHandler actors.
    */
  private def handleRabitWorkerMessage(workerMsg: RabitWorkerRequest): Unit = workerMsg match {
    case req @ RequestAwaitConnWorkers(_, _) =>
      // since the requester may request to connect to other workers
      // that have not fully set up, delegate this request to the
      // dependency resolver which handles the dependencies properly.
      resolver forward req

    // ---- Rabit worker commands: start/recover/shutdown/print ----
    case WorkerTrackerPrint(_, _, _, msg) =>
      log.info(msg.trim)

    case WorkerShutdown(rank, _, _) =>
      assert(rank >= 0, "Invalid rank.")
      assert(!shutdownWorkers.contains(rank))
      shutdownWorkers.add(rank)

      log.info(s"Received shutdown signal from $rank")

      if (shutdownWorkers.size == numWorkers) {
        // trySuccess: a no-op if the promise has already been failed elsewhere.
        promisedShutdownWorkers.trySuccess(shutdownWorkers.size)
      }

    case WorkerRecover(prevRank, _, _) =>
      assert(prevRank >= 0)
      sender() ! linkMap.assignRank(prevRank)

    case WorkerStart(rank, worldSize, jobId) =>
      assert(worldSize == numWorkers || worldSize == -1,
        s"Purported worldSize ($worldSize) does not match worker count ($numWorkers)."
      )

      val worker = sender()
      // ranksToAssign.remove(0) throws IndexOutOfBoundsException once every rank is taken.
      Try(decideRank(rank, jobId).getOrElse(ranksToAssign.remove(0))) match {
        case Success(wkRank) =>
          if (jobId != "NULL") {
            jobToRankMap.put(jobId, wkRank)
          }

          // the same AssignedRank message goes to the worker and to the resolver.
          val assignedRank = linkMap.assignRank(wkRank)
          worker ! assignedRank
          resolver ! assignedRank

          log.info("Received start signal from " +
            s"${actorRefToHost.getOrElse(worker, "")} [rank: $wkRank]")

        case Failure(_: IndexOutOfBoundsException) =>
          // More than worldSize workers have connected, likely due to executor loss.
          // Since Rabit currently does not support crash recovery (because the Allreduce results
          // are not cached by the tracker, and because existing workers cannot reestablish
          // connections to newly spawned executor/worker), the most reasonable action here is to
          // interrupt the tracker immediately with failure state.
          // numWorkers is logged rather than the reported worldSize, which may be -1.
          log.error("Received invalid start signal from " +
            s"${actorRefToHost.getOrElse(worker, "")}: all $numWorkers workers have started."
          )
          // tryFailure: several surplus workers may hit this branch one after another.
          promisedShutdownWorkers.tryFailure(new XGBoostError("Invalid start signal" +
            " received from worker, likely due to executor loss."))

        case Failure(ex) =>
          log.error(ex, "Unexpected error")
          promisedShutdownWorkers.tryFailure(ex)
      }


    // ---- Dependency resolving related messages ----
    case msg @ WorkerStarted(host, rank, _) =>
      log.info(s"Worker $host (rank: $rank) has started.")
      resolver forward msg

      startedWorkers.add(rank)
      if (startedWorkers.size == numWorkers) {
        log.info("All workers have started.")
      }

    case req @ DropFromWaitingList(_) =>
      // all peer workers in dependency link map have connected;
      // forward message to resolver to update dependencies.
      resolver forward req

    case _ =>
  }

  /**
    * Dispatches incoming messages by type to the Tcp, tracker-control and worker handlers, and
    * fails the completion promise when the receive timeout fires before all workers started.
    */
  def receive: Actor.Receive = {
    case tcpEvent: Tcp.Event => handleTcpEvents(tcpEvent)
    case trackerMsg: TrackerControlMessage => handleTrackerControlMessage(trackerMsg)
    case workerMsg: RabitWorkerRequest => handleRabitWorkerMessage(workerMsg)

    case akka.actor.ReceiveTimeout =>
      if (startedWorkers.size < numWorkers) {
        promisedShutdownWorkers.tryFailure(
          new TimeoutException("Timed out waiting for workers to connect: " +
            s"${numWorkers - startedWorkers.size} of $numWorkers did not start/connect.")
        )
        context.stop(self)
      }

      // disable further receive timeouts.
      context.setReceiveTimeout(Duration.Undefined)
  }
}

/**
  * Resolve the dependency between nodes as they connect to the tracker.
  * The dependency is enforced such that a worker of rank K depends on its neighbors (from the
  * treeMap and ringMap) whose ranks are smaller than K. Since ranks are assigned in the order in
  * which workers connect, this dependency constraint assumes that a worker node that connects
  * first is likely to finish setup first.
  */
private[rabit] class WorkerDependencyResolver(handler: ActorRef) extends Actor with ActorLogging {
  import RabitWorkerHandler._

  // stop together with the tracker handler (see the Terminated case below).
  context.watch(handler)

  /** A pending RequestAwaitConnWorkers: the requested link set and the promise to complete. */
  case class Fulfillment(toConnectSet: Set[Int], promise: Promise[AwaitingConnections])

  // worker nodes that have connected, but have not sent the WorkerStarted message,
  // mapped to the ranks they still depend on.
  private val dependencyMap = mutable.Map.empty[Int, Set[Int]]
  private val startedWorkers = mutable.Set.empty[Int]
  // worker nodes that have started, and await for connections.
  private val awaitConnWorkers = mutable.Map.empty[Int, ActorRef]
  // requests whose dependencies were not yet satisfied when they arrived, keyed by rank.
  private val pendingFulfillment = mutable.Map.empty[Int, Fulfillment]

  /**
    * Builds the reply for a worker that wants to connect to the ranks in `linkSet`.
    *
    * @param linkSet ranks the requesting worker wants to connect to.
    * @return the members of `linkSet` currently awaiting connections (with their actor refs),
    *         together with the count of `linkSet` members that are not in that set.
    */
  def awaitingWorkers(linkSet: Set[Int]): AwaitingConnections = {
    // strict filter: yields a concrete immutable Map rather than the lazy view that
    // filterKeys returns.
    val connSet = awaitConnWorkers.toMap.filter { case (rank, _) => linkSet.contains(rank) }
    AwaitingConnections(connSet, linkSet.size - connSet.size)
  }

  def receive: Actor.Receive = {
    // a copy of the AssignedRank message that is also sent to the worker
    case AssignedRank(rank, tree_neighbors, ring, parent) =>
      // the workers that the worker of given `rank` depends on:
      // worker of rank K only depends on workers with rank smaller than K.
      val dependentWorkers = (tree_neighbors.toSet ++ Set(ring._1, ring._2))
        .filter{ r => r != -1 && r < rank}

      log.debug(s"Rank $rank connected, dependencies: $dependentWorkers")
      dependencyMap.put(rank, dependentWorkers)

    case RequestAwaitConnWorkers(rank, toConnectSet) =>
      val promise = Promise[AwaitingConnections]()

      assert(dependencyMap.contains(rank))

      val updatedDependency = dependencyMap(rank) diff startedWorkers
      if (updatedDependency.isEmpty) {
        // all dependencies are satisfied
        log.debug(s"Rank $rank has all dependencies satisfied.")
        promise.success(awaitingWorkers(toConnectSet))
      } else {
        log.debug(s"Rank $rank's request for AwaitConnWorkers is pending fulfillment.")
        // promise is pending fulfillment due to unresolved dependency
        pendingFulfillment.put(rank, Fulfillment(toConnectSet, promise))
      }

      // the requester always receives a future; it is completed now or on a later WorkerStarted.
      sender() ! promise.future

    case WorkerStarted(_, started, awaitingAcceptance) =>
      startedWorkers.add(started)
      if (awaitingAcceptance > 0) {
        awaitConnWorkers.put(started, sender())
      }

      // remove the started rank from all dependencies.
      dependencyMap.remove(started)
      // iterate over a snapshot so that dependencyMap is not updated while being traversed.
      dependencyMap.toSeq.foreach { case (r, dset) =>
        val updatedDependency = dset diff startedWorkers
        // fulfill the future if all dependencies are met (started.)
        if (updatedDependency.isEmpty) {
          log.debug(s"Rank $r has all dependencies satisfied.")
          pendingFulfillment.remove(r).foreach { fulfillment =>
            fulfillment.promise.success(awaitingWorkers(fulfillment.toConnectSet))
          }
        }

        dependencyMap.update(r, updatedDependency)
      }

    case DropFromWaitingList(rank) =>
      // perform the removal outside assert(): assertions can be elided at compile time, which
      // would otherwise drop the side effect together with the check.
      val removed = awaitConnWorkers.remove(rank)
      assert(removed.isDefined, s"Rank $rank is not in the waiting list.")

    case Terminated(ref) =>
      if (ref == handler) {
        context.stop(self)
      }
  }
}

private[scala] object RabitTrackerHandler {

  /** Base type of the messages sent by RabitTracker to the RabitTrackerHandler actor. */
  trait TrackerControlMessage

  /** Asks the handler for the future that is completed when all workers have shut down. */
  case object RequestCompletionFuture extends TrackerControlMessage

  /** Asks the handler for the future that is completed once the tracker socket is bound. */
  case object RequestBoundFuture extends TrackerControlMessage

  /**
    * Start the Rabit tracker at given socket address awaiting worker connections.
    * All workers must connect to the tracker before connectionTimeout, otherwise the tracker
    * will shut down due to timeout.
    *
    * @param addr              socket address to bind to.
    * @param maxPortTrials     upper bound on the number of re-bind attempts after a bind failure.
    * @param connectionTimeout time allowed for the workers to connect.
    */
  case class StartTracker(
      addr: InetSocketAddress,
      maxPortTrials: Int,
      connectionTimeout: Duration
  ) extends TrackerControlMessage

  /**
    * To interrupt the tracker handler due to uncaught exception thrown by the thread acting as
    * driver for the distributed training.
    *
    * @param e the uncaught exception.
    */
  case class InterruptTracker(e: Throwable) extends TrackerControlMessage

  /**
    * Creates the Props for a RabitTrackerHandler actor.
    *
    * @param numWorkers Number of workers to track.
    */
  def props(numWorkers: Int): Props = {
    Props(new RabitTrackerHandler(numWorkers))
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy