All downloads are free. The search and download features use the official Maven repository.

ml.dmlc.xgboost4j.scala.rabit.util.LinkMap.scala Maven / Gradle / Ivy

The newest version!
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.rabit.util

import java.nio.{ByteBuffer, ByteOrder}

/**
  * The assigned rank to a connecting Rabit worker, along with the information of the ranks of
  * its linked peer workers, which are critical to perform Allreduce.
  * When RabitWorkerHandler delegates "start" or "recover" commands from the connecting worker
  * client, RabitTrackerHandler uses LinkMap to work out the linkage relationships and responds
  * with an instance of this class as a message, which is later encoded to a byte string and sent
  * over the socket connection to the worker client.
  *
  * @param rank assigned rank (ranked by worker connection order: first worker connecting to the
  *             tracker is assigned rank 0, second with rank 1, etc.)
  * @param neighbors ranks of neighboring workers in a tree map.
  * @param ring ranks of neighboring workers in a ring map.
  * @param parent rank of the parent worker.
  */
private[rabit] case class AssignedRank(rank: Int, neighbors: Seq[Int],
                                       ring: (Int, Int), parent: Int) {
  /**
    * Serialize this message into a sequence of 32-bit integers, written in the platform's native
    * byte order, for transmission to the Rabit worker client.
    *
    * Layout (one Int each): rank, parent, worldSize, number of tree neighbors, every tree
    * neighbor, previous ring peer, next ring peer.
    *
    * @param worldSize the number of total distributed workers. Must match `numWorkers` used in
    *                  LinkMap.
    * @return a ByteBuffer, already flipped for reading, holding the encoded message.
    */
  def toByteBuffer(worldSize: Int): ByteBuffer = {
    val (ringPrev, ringNext) = ring

    // A ring peer equal to this worker's own rank is reported as -1; a peer that is already -1
    // passes through unchanged, so the result is -1 in both cases.
    def ringLink(peer: Int): Int = if (peer == rank) -1 else peer

    val header = Seq(rank, parent, worldSize, neighbors.length)
    val ringLinks = Seq(ringLink(ringPrev), ringLink(ringNext))
    val payload = header ++ neighbors ++ ringLinks

    // 4 bytes per encoded Int.
    val encoded = ByteBuffer.allocate(4 * payload.length).order(ByteOrder.nativeOrder())
    payload.foreach { value => encoded.putInt(value) }

    encoded.flip()
    encoded
  }
}

/**
  * Computes the tree and ring link topologies for a fixed number of workers.
  *
  * Workers are first laid out as a binary tree in heap order (node `r` has parent
  * `(r + 1) / 2 - 1` and children `2r + 1`, `2r + 2`). A ring visiting every node is then derived
  * from that tree, and finally all nodes are relabeled by their position in the ring, so that in
  * the published maps the ring successor of rank `i` is rank `(i + 1) % numWorkers`.
  *
  * @param numWorkers total number of workers; must be positive.
  */
private[rabit] class LinkMap(numWorkers: Int) {
  // Fail fast with a clear message instead of a NoSuchElementException from a lookup of node 0.
  require(numWorkers > 0, s"numWorkers must be positive, but got $numWorkers")

  /**
    * Tree neighbors (parent, left child, right child, in that order) of a node in the heap-ordered
    * binary tree, keeping only those that fall inside `[0, numWorkers)`.
    *
    * @param rank node index in the heap-ordered tree.
    * @return indices of the existing neighbors.
    */
  private def getNeighbors(rank: Int): Seq[Int] = {
    val rank1 = rank + 1
    Vector(rank1 / 2 - 1, rank1 * 2 - 1, rank1 * 2).filter { r =>
      r >= 0 && r < numWorkers
    }
  }

  /**
    * Construct a ring structure that tends to share nodes with the tree.
    *
    * The subtree rooted at `rank` is walked depth-first. The sub-ring produced by the last child
    * is appended in reverse, so that it ends at that child, which is a direct tree neighbor of
    * `rank`; the edge that closes the ring is then a tree edge as well.
    *
    * @param treeMap node -> its tree neighbors (parent and children).
    * @param parentMap node -> its parent (-1 for the root).
    * @param rank root of the subtree to walk.
    * @return Seq[Int] instance starting from rank.
    */
  private def constructShareRing(treeMap: Map[Int, Seq[Int]],
                                 parentMap: Map[Int, Int],
                                 rank: Int = 0): Seq[Int] = {
    val parent = parentMap(rank)
    // Keep an ordered sequence rather than a Set so the traversal order is explicit.
    val children = treeMap(rank).filter(_ != parent).distinct
    val lastIdx = children.length - 1

    children.zipWithIndex.foldLeft(Vector(rank)) {
      case (ringSeq, (child, idx)) =>
        val subRing = constructShareRing(treeMap, parentMap, child)
        // Only the last child's sub-ring is reversed.
        ringSeq ++ (if (idx == lastIdx) subRing.reverse else subRing)
    }
  }

  /**
    * Construct a ring connection used to recover local data.
    *
    * @param treeMap node -> its tree neighbors (parent and children).
    * @param parentMap node -> its parent; node 0 must map to -1.
    * @return node -> (previous node in the ring, next node in the ring).
    */
  private def constructRingMap(treeMap: Map[Int, Seq[Int]],
                               parentMap: Map[Int, Int]): Map[Int, (Int, Int)] = {
    assert(parentMap(0) == -1)

    val sharedRing = constructShareRing(treeMap, parentMap, 0).toVector
    assert(sharedRing.length == treeMap.size)

    (0 until numWorkers).map { r =>
      val rPrev = (r + numWorkers - 1) % numWorkers
      val rNext = (r + 1) % numWorkers
      sharedRing(r) -> (sharedRing(rPrev), sharedRing(rNext))
    }.toMap
  }

  // Topologies expressed in heap-order node indices (before relabeling).
  private[this] val treeMap_ = (0 until numWorkers).map { r => r -> getNeighbors(r) }.toMap
  private[this] val parentMap_ = (0 until numWorkers).map { r => r -> ((r + 1) / 2 - 1) }.toMap
  private[this] val ringMap_ = constructRingMap(treeMap_, parentMap_)

  /** Relabeling: heap-order node index -> position in the ring, starting with node 0 at 0. */
  val rMap_ : Map[Int, Int] = (0 until (numWorkers - 1)).foldLeft((Map(0 -> 0), 0)) {
    case ((rmap, k), i) =>
      // Follow the "next" link of the ring, numbering nodes in visiting order.
      val kNext = ringMap_(k)._2
      (rmap + (kNext -> (i + 1)), kNext)
  }._1

  /** Relabeled rank -> (previous rank in the ring, next rank in the ring). */
  val ringMap: Map[Int, (Int, Int)] = ringMap_.map {
    case (k, (v0, v1)) => rMap_(k) -> (rMap_(v0), rMap_(v1))
  }

  /** Relabeled rank -> relabeled ranks of its tree neighbors. */
  val treeMap: Map[Int, Seq[Int]] = treeMap_.map {
    case (k, vSeq) => rMap_(k) -> vSeq.map { v => rMap_(v) }
  }

  /** Relabeled rank -> relabeled rank of its parent; the root maps to -1. */
  val parentMap: Map[Int, Int] = parentMap_.map {
    case (k, _) if k == 0 =>
      // The root has no parent, and -1 is not a key of rMap_, so it is passed through as is.
      rMap_(k) -> -1
    case (k, v) =>
      rMap_(k) -> rMap_(v)
  }

  /**
    * Bundle the link information of a worker into an AssignedRank message.
    *
    * @param rank rank of the worker, in the relabeled (ring-position) numbering used by
    *             `treeMap`, `ringMap` and `parentMap`.
    * @return the worker's rank together with its tree neighbors, ring peers and parent.
    */
  def assignRank(rank: Int): AssignedRank = {
    AssignedRank(rank, treeMap(rank), ringMap(rank), parentMap(rank))
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy