// All downloads are free. Search and download functionality uses the official Maven repository.

// hex.tree.xgboost.rabit.util.LinkMap (Maven / Gradle / Ivy)

package hex.tree.xgboost.rabit.util;

import water.util.Pair;

import java.util.*;

/**
 * Java implementation of ai.h2o.xgboost4j.scala.rabit.util.LinkMap
 *
 * Naming left for consistency. In reality this is a simple binary tree data structure, which is used for communication
 * between Rabit workers.
 *
 */
public class LinkMap {
    private int numWorkers;
    public Map> treeMap = new LinkedHashMap<>();
    public Map parentMap = new LinkedHashMap<>();
    public Map> ringMap = new LinkedHashMap<>();

    public LinkMap(int numWorkers) {
        this.numWorkers = numWorkers;

        Map> treeMap_ = initTreeMap();
        Map parentMap_ = initParentMap();
        Map> ringMap_ = constructRingMap(treeMap_, parentMap_);

        Map rMap_ = new LinkedHashMap<>(numWorkers - 1);
        rMap_.put(0, 0);
        int k = 0;
        for(int i = 0; i < numWorkers - 1; i++) {
            int kNext = ringMap_.get(k)._2();
            k = kNext;
            rMap_.put(kNext, (i + 1));
        }

        for (Map.Entry> kv : ringMap_.entrySet()) {
            this.ringMap.put(
                    rMap_.get(kv.getKey()),
                    new Pair<>(rMap_.get(kv.getValue()._1()), rMap_.get(kv.getValue()._2()))
            );
        }

        for (Map.Entry> kv : treeMap_.entrySet()) {
            List mapped = new ArrayList<>(kv.getValue().size());
            for(Integer v : kv.getValue()) {
                mapped.add(rMap_.get(v));
            }
            treeMap.put(
                    rMap_.get(kv.getKey()),
                    mapped
            );
        }

        for (Map.Entry kv : parentMap_.entrySet()) {
            if(kv.getKey() == 0) {
                parentMap.put(rMap_.get(kv.getKey()), -1);
            } else {
                parentMap.put(rMap_.get(kv.getKey()), rMap_.get(kv.getValue()));
            }
        }
    }

    /**
     * Generates a mapping node -> neighbours(node)
     */
    Map> initTreeMap() {
        Map> treeMap = new LinkedHashMap<>(numWorkers);
        for(int r = 0; r < numWorkers; r++) {
            treeMap.put(r, getNeighbours(r));
        }
        return treeMap;
    }

    /**
     * Generates a mapping node -> parent (parent of root is -1)
     */
    Map initParentMap() {
        Map parentMap = new LinkedHashMap<>(numWorkers);
        for(int r = 0; r < numWorkers; r++) {
            parentMap.put(r, ((r + 1) / 2 - 1) );
        }
        return parentMap;
    }

    /**
     * Returns a list containing existing neighbours of a node, this includes at most 3 nodes: parent, left and right child.
     */
    List getNeighbours(int rank) {
        if(rank < 0) {
            throw new IllegalStateException("Rank should be non negative");
        }

        if(rank >= numWorkers) {
            throw new IllegalStateException("Rank ["+rank+"] too high for the number of workers ["+numWorkers+"]");
        }

        rank += 1;
        List neighbour = new ArrayList<>();

        if(rank > 1) {
            neighbour.add(rank / 2 - 1);
        }
        if(rank * 2 - 1 < numWorkers) {
            neighbour.add(rank * 2 - 1);
        }
        if(rank * 2 < numWorkers) {
            neighbour.add(rank * 2);
        }

        return neighbour;
    }

    /**
     * Returns a DFS (root, DFS(left_child), DFS(right)child) order from root with given rank.
     */
    List constructShareRing(Map> treeMap,
                                            Map parentMap,
                                            int rank) {
        Set connectionSet = new LinkedHashSet<>(treeMap.get(rank));
        connectionSet.remove(parentMap.get(rank));
        if(connectionSet.isEmpty()) {
            return Collections.singletonList(rank);
        } else {
            List ringSeq = new LinkedList<>();
            ringSeq.add(rank);
            int cnt = 0;
            for(Integer n : connectionSet) {
                List vConnSeq = constructShareRing(treeMap, parentMap, n);
                cnt++;
                if(connectionSet.size() == cnt) {
                    Collections.reverse(vConnSeq);
                }
                ringSeq.addAll(vConnSeq);
            }
            return ringSeq;
        }

    }

    /**
     * Returns for each node with "rank" the previous and next node in DFS order. For the root the "previous"
     * entry will be the last element, which will create a ring type structure.
     */
    Map> constructRingMap(Map> treeMap,
                                                                  Map parentMap) {
        assert parentMap.get(0) == -1;

        List sharedRing = constructShareRing(treeMap, parentMap, 0);
        assert sharedRing.size() == treeMap.size();

        Map> ringMap = new LinkedHashMap<>(numWorkers);
        for(int r = 0; r < numWorkers; r++) {
            int rPrev = (r + numWorkers - 1) % numWorkers;
            int rNext = (r + 1) % numWorkers;
            ringMap.put(sharedRing.get(r), new Pair<>(sharedRing.get(rPrev), sharedRing.get(rNext)));
        }
        return ringMap;
    }
}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy