All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.tree.xgboost.rabit.RabitTrackerH2O Maven / Gradle / Ivy

package hex.tree.xgboost.rabit;

import hex.tree.xgboost.rabit.util.LinkMap;
import ai.h2o.xgboost4j.java.IRabitTracker;
import org.apache.log4j.Logger;
import water.H2O;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.util.*;

public class RabitTrackerH2O implements IRabitTracker {

    private static final Logger LOG = Logger.getLogger(RabitTrackerH2O.class);

    public static final int MAGIC = 0xff99;

    private ServerSocketChannel sock;
    private int port = 9091;

    private final int workers;

    private final Map envs = new HashMap<>();

    private volatile RabitTrackerH2OThread trackerThread;

    public RabitTrackerH2O(int workers) {
        if (workers < 1) {
            throw new IllegalStateException("workers must be greater than or equal to one (1).");
        }
        this.workers = workers;
    }

    @Override
    public Map getWorkerEnvs() {
        envs.put("DMLC_NUM_WORKER", String.valueOf(workers));
        envs.put("DMLC_NUM_SERVER", "0");
        envs.put("DMLC_TRACKER_URI", H2O.SELF_ADDRESS.getHostAddress());
        envs.put("DMLC_TRACKER_PORT", Integer.toString(port));
        envs.put("DMLC_TASK_ID", Integer.toString(H2O.SELF.index()));
        envs.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false");
        envs.put("rabit_world_size", Integer.toString(workers));
        return envs;
    }

    @Override
    public boolean start(long timeout) {
        boolean tryToBind = true;
        while (tryToBind) {
            try {
                this.sock = ServerSocketChannel.open();
                this.sock.socket().setReceiveBufferSize(64 * 1024);
                InetSocketAddress isa = new InetSocketAddress(H2O.SELF_ADDRESS, this.port);
                this.sock.socket().bind(isa);
                LOG.debug("Started listening on port " + this.port);
                tryToBind = false;
            } catch (java.io.IOException e) {
                this.port++;
                try {
                    this.sock.close();
                } catch (IOException socketCloseException) {
                    LOG.warn("Failed to close Rabit Tracker socket on port " + sock.socket().getLocalPort());
                }
                if (this.port > 9999) {
                    throw new RuntimeException("Failed to bind Rabit tracker to a socket in range 9091-9999", e);
                }
            }
        }

        if (null != this.trackerThread) {
            throw new IllegalStateException("Rabit tracker already started.");
        }
        this.trackerThread = new RabitTrackerH2OThread(this);
        this.trackerThread.start();
        return true;
    }

    @Override
    public void stop() {
        assert this.trackerThread != null;
        try {
            this.trackerThread.interrupt();
        } catch (SecurityException e){
            LOG.error("Could not interrupt a thread in RabitTrackerH2O: " + trackerThread.toString());
        }
        this.trackerThread.terminateSocketChannels();
        this.trackerThread = null;


        try {
            this.sock.close();
            this.port = 9091;
        } catch (IOException e) {
            LOG.error("Failed to close Rabit tracker socket.", e);
        }
    }

    private static class RabitTrackerH2OThread extends Thread {

        private final RabitTrackerH2O tracker;

        private final Map jobToRankMap = new HashMap<>();
        private final List socketChannels = new ArrayList<>();

        private RabitTrackerH2OThread(RabitTrackerH2O tracker) {
            setPriority(MAX_PRIORITY-1);
            this.setName("TCP-" + tracker.sock);
            this.tracker = tracker;
        }

        private void terminateSocketChannels(){
            for (SocketChannel channel : socketChannels) {
                try {
                    channel.close();
                } catch (IOException e) {
                    LOG.warn("Unable to close RabitTracerH2O SocketChannel on port " + channel.socket().getPort());
                }
            }
        }

        private static final String PRINT_CMD = "print";
        private static final String SHUTDOWN_CMD = "shutdown";
        private static final String START_CMD = "start";
        private static final String RECOVER_CMD = "recover";
        private static final String NULL_STR = "null";

        @Override
        public void run() {
            LOG.debug("Rabit tracker started on port " + tracker.port);
            LinkMap linkMap = null;
            Set shutdown = new HashSet<>();
            Map waitConn = new HashMap<>();
            List pending = new ArrayList<>();
            Queue todoNodes = new ArrayDeque<>(tracker.workers);
            while (!interrupted() && shutdown.size() != tracker.workers) {
                try{
                    final SocketChannel channel = tracker.sock.accept(); // Does not proceed when interrupt() is called.
                    socketChannels.add(channel);
                    final RabitWorker worker = new RabitWorker(channel);

                    if (PRINT_CMD.equals(worker.cmd)) {
                        String msg = worker.receiver().getStr();
                        LOG.warn("Rabit worker: " + msg);
                        continue;
                    } else if (SHUTDOWN_CMD.equals(worker.cmd)) {
                        assert worker.rank >= 0 && !shutdown.contains(worker.rank);
                        assert !waitConn.containsKey(worker.rank);
                        shutdown.add(worker.rank);
                        channel.socket().close();
                        LOG.debug("Received " + worker.cmd + " signal from " + worker.rank);
                        continue;
                    }
                    assert START_CMD.equals(worker.cmd) || RECOVER_CMD.equals(worker.cmd);

                    if (null == linkMap) {
                        assert START_CMD.equals(worker.cmd);
                        linkMap = new LinkMap(tracker.workers);
                        for (int i = 0; i < tracker.workers; i++) {
                            todoNodes.add(i);
                        }
                    } else {
                        assert worker.worldSize == -1 || worker.worldSize == tracker.workers;
                    }

                    if (RECOVER_CMD.equals(worker.cmd)) {
                        assert worker.rank >= 0;
                    }

                    int rank = worker.decideRank(jobToRankMap);
                    if (-1 == rank) {
                        assert todoNodes.size() != 0;
                        pending.add(worker);
                        if (pending.size() == todoNodes.size()) {
                            Collections.sort(pending);
                            for (RabitWorker p : pending) {
                                rank = todoNodes.poll();
                                if (!NULL_STR.equals(p.jobId)) {
                                    jobToRankMap.put(p.jobId, rank);
                                }
                                p.assignRank(rank, waitConn, linkMap);

                                if (p.waitAccept > 0) {
                                    waitConn.put(rank, p);
                                }

                                LOG.debug("Received " + p.cmd +
                                        " signal from " + p.host + ":" + p.workerPort +
                                        ". Assigned rank " + p.rank
                                );
                            }
                        }
                        if (todoNodes.isEmpty()) {
                            LOG.debug("All " + tracker.workers + " Rabit workers are getting started.");
                        }
                    } else {
                        worker.assignRank(rank, waitConn, linkMap);
                        if (worker.waitAccept > 0) {
                            waitConn.put(rank, worker);
                        }
                    }
                } catch (IOException e) {
                    LOG.error("Exception in Rabit tracker.", e);
                }
            }
            LOG.debug("All Rabit nodes finished.");
        }
    }

    @Override
    public int waitFor(long timeout) {
        while(null != this.trackerThread && this.trackerThread.isAlive()) {
            try {
                this.trackerThread.join(timeout);
            } catch (InterruptedException e) {
                LOG.debug("Rabit tracker thread got suddenly interrupted.", e);
                Thread.currentThread().interrupt();
            }
        }
        return 0;
    }

    @Override
    public void uncaughtException(Thread t, Throwable e) {
        LOG.error("Uncaught exception occurred on Rabit tracker thread " + t.getName(), e);
        stop();
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy