"""
Tracker script for DMLC
Implements the tracker control protocol
 - start dmlc jobs
 - start ps scheduler and rabit tracker
 - help nodes to establish links with each other

Tianqi Chen
"""
# pylint: disable=invalid-name, missing-docstring, too-many-arguments, too-many-locals
# pylint: disable=too-many-branches, too-many-statements
from __future__ import absolute_import

import os
import sys
import socket
import struct
import subprocess
import argparse
import time
import logging
from threading import Thread

class ExSocket(object):
    """
    Extension of socket to handle recv and send of special data
    """
    def __init__(self, sock):
        self.sock = sock
    def recvall(self, nbytes):
        res = []
        nread = 0
        while nread < nbytes:
            chunk = self.sock.recv(min(nbytes - nread, 1024))
            if not chunk:
                # the peer closed the connection before sending everything
                raise RuntimeError('connection closed while expecting %d bytes' % nbytes)
            nread += len(chunk)
            res.append(chunk)
        return b''.join(res)
    def recvint(self):
        return struct.unpack('@i', self.recvall(4))[0]
    def sendint(self, n):
        self.sock.sendall(struct.pack('@i', n))
    def sendstr(self, s):
        self.sendint(len(s))
        self.sock.sendall(s.encode())
    def recvstr(self):
        slen = self.recvint()
        return self.recvall(slen).decode()
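
# Wire format used by ExSocket (for illustration): integers travel as 4-byte
# native-order values ('@i'); strings are a length prefix followed by the raw
# bytes, so sendstr('start') puts struct.pack('@i', 5) + b'start' on the
# wire, which the peer reads back via recvint() and recvall(5).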

# magic number exchanged to verify that both ends speak the tracker protocol
kMagic = 0xff99

def get_some_ip(host):
    return socket.getaddrinfo(host, None)[0][4][0]

def get_family(addr):
    return socket.getaddrinfo(addr, None)[0][0]

class SlaveEntry(object):
    def __init__(self, sock, s_addr):
        slave = ExSocket(sock)
        self.sock = slave
        self.host = get_some_ip(s_addr[0])
        magic = slave.recvint()
        assert magic == kMagic, 'invalid magic number=%d from %s' % (magic, self.host)
        slave.sendint(kMagic)
        self.rank = slave.recvint()
        self.world_size = slave.recvint()
        self.jobid = slave.recvstr()
        self.cmd = slave.recvstr()
        self.wait_accept = 0
        self.port = None
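
    # Handshake summary (as implemented above): the slave first sends kMagic
    # and the tracker echoes it back; the slave then reports its rank,
    # world_size, jobid and a command ('start', 'recover', 'shutdown' or 'print').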

    def decide_rank(self, job_map):
        if self.rank >= 0:
            return self.rank
        if self.jobid != 'NULL' and self.jobid in job_map:
            return job_map[self.jobid]
        return -1

    def assign_rank(self, rank, wait_conn, tree_map, parent_map, ring_map):
        self.rank = rank
        nnset = set(tree_map[rank])
        rprev, rnext = ring_map[rank]
        self.sock.sendint(rank)
        # send parent rank
        self.sock.sendint(parent_map[rank])
        # send world size
        self.sock.sendint(len(tree_map))
        self.sock.sendint(len(nnset))
        # send the rprev and next link
        for r in nnset:
            self.sock.sendint(r)
        # send prev link
        if rprev != -1 and rprev != rank:
            nnset.add(rprev)
            self.sock.sendint(rprev)
        else:
            self.sock.sendint(-1)
        # send next link
        if rnext != -1 and rnext != rank:
            nnset.add(rnext)
            self.sock.sendint(rnext)
        else:
            self.sock.sendint(-1)
        while True:
            ngood = self.sock.recvint()
            goodset = set([])
            for _ in range(ngood):
                goodset.add(self.sock.recvint())
            assert goodset.issubset(nnset)
            badset = nnset - goodset
            conset = []
            for r in badset:
                if r in wait_conn:
                    conset.append(r)
            self.sock.sendint(len(conset))
            self.sock.sendint(len(badset) - len(conset))
            for r in conset:
                self.sock.sendstr(wait_conn[r].host)
                self.sock.sendint(wait_conn[r].port)
                self.sock.sendint(r)
            nerr = self.sock.recvint()
            if nerr != 0:
                continue
            self.port = self.sock.recvint()
            rmset = []
            # all connections were successfully set up
            for r in conset:
                wait_conn[r].wait_accept -= 1
                if wait_conn[r].wait_accept == 0:
                    rmset.append(r)
            for r in rmset:
                wait_conn.pop(r, None)
            self.wait_accept = len(badset) - len(conset)
            return rmset

class RabitTracker(object):
    """
    tracker for rabit
    """
    def __init__(self, hostIP, nslave, port=9091, port_end=9999):
        sock = socket.socket(get_family(hostIP), socket.SOCK_STREAM)
        for port in range(port, port_end):
            try:
                sock.bind((hostIP, port))
                self.port = port
                break
            except socket.error as e:
                # errno 98 (Linux) / 48 (macOS): address already in use
                if e.errno in [98, 48]:
                    continue
                else:
                    raise
        else:
            raise RuntimeError('cannot find an available port below %d for the tracker' % port_end)
        sock.listen(256)
        self.sock = sock
        self.hostIP = hostIP
        self.thread = None
        self.start_time = None
        self.end_time = None
        self.nslave = nslave
        logging.info('start listening on %s:%d', hostIP, self.port)

    def __del__(self):
        self.sock.close()

    @staticmethod
    def get_neighbor(rank, nslave):
        rank = rank + 1
        ret = []
        if rank > 1:
            ret.append(rank // 2 - 1)
        if rank * 2 - 1 < nslave:
            ret.append(rank * 2 - 1)
        if rank * 2 < nslave:
            ret.append(rank * 2)
        return ret
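    # get_neighbor computes binary-heap neighbors using 1-based indexing, e.g.
    # get_neighbor(2, 7) == [0, 5, 6] (parent 0, children 5 and 6) and
    # get_neighbor(0, 7) == [1, 2] (the root has no parent).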

    def slave_envs(self):
        """
        get enviroment variables for slaves
        can be passed in as args or envs
        """
        return {'DMLC_TRACKER_URI': self.hostIP,
                'DMLC_TRACKER_PORT': self.port}

    def get_tree(self, nslave):
        tree_map = {}
        parent_map = {}
        for r in range(nslave):
            tree_map[r] = self.get_neighbor(r, nslave)
            parent_map[r] = (r + 1) // 2 - 1
        return tree_map, parent_map
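    # Worked example: get_tree(3) yields
    #   tree_map   == {0: [1, 2], 1: [0], 2: [0]}
    #   parent_map == {0: -1, 1: 0, 2: 0}
    # i.e. a binary tree rooted at rank 0, with -1 marking "no parent".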

    def find_share_ring(self, tree_map, parent_map, r):
        """
        get a ring structure that tends to share nodes with the tree
        return a list starting from r
        """
        nset = set(tree_map[r])
        cset = nset - set([parent_map[r]])
        if len(cset) == 0:
            return [r]
        rlst = [r]
        cnt = 0
        for v in cset:
            vlst = self.find_share_ring(tree_map, parent_map, v)
            cnt += 1
            if cnt == len(cset):
                vlst.reverse()
            rlst += vlst
        return rlst

    def get_ring(self, tree_map, parent_map):
        """
        get a ring connection used to recover local data
        """
        assert parent_map[0] == -1
        rlst = self.find_share_ring(tree_map, parent_map, 0)
        assert len(rlst) == len(tree_map)
        ring_map = {}
        nslave = len(tree_map)
        for r in range(nslave):
            rprev = (r + nslave - 1) % nslave
            rnext = (r + 1) % nslave
            ring_map[rlst[r]] = (rlst[rprev], rlst[rnext])
        return ring_map
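    # Worked example: for the 3-node tree above, find_share_ring walks the
    # nodes as [0, 1, 2] (with CPython's small-int set ordering), so get_ring
    # returns {0: (2, 1), 1: (0, 2), 2: (1, 0)}, mapping each rank to its
    # (prev, next) neighbors on the ring.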

    def get_link_map(self, nslave):
        """
        get the link map, this is a bit hacky, call for better algorithm
        to place similar nodes together
        """
        tree_map, parent_map = self.get_tree(nslave)
        ring_map = self.get_ring(tree_map, parent_map)
        rmap = {0 : 0}
        k = 0
        for i in range(nslave - 1):
            k = ring_map[k][1]
            rmap[k] = i + 1

        ring_map_ = {}
        tree_map_ = {}
        parent_map_ = {}
        for k, v in ring_map.items():
            ring_map_[rmap[k]] = (rmap[v[0]], rmap[v[1]])
        for k, v in tree_map.items():
            tree_map_[rmap[k]] = [rmap[x] for x in v]
        for k, v in parent_map.items():
            if k != 0:
                parent_map_[rmap[k]] = rmap[v]
            else:
                parent_map_[rmap[k]] = -1
        return tree_map_, parent_map_, ring_map_
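    # After the rmap relabeling, ring neighbors are consecutive ranks: rank
    # r's next link is (r + 1) % nslave, so a pass around the ring visits
    # ranks in order.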

    def accept_slaves(self, nslave):
        # set of nodes that have finished the job
        shutdown = {}
        # set of nodes that are waiting for connections
        wait_conn = {}
        # maps job id to rank
        job_map = {}
        # list of workers that are pending rank assignment
        pending = []
        # lazily initialized tree_map
        tree_map = None

        while len(shutdown) != nslave:
            fd, s_addr = self.sock.accept()
            s = SlaveEntry(fd, s_addr)
            if s.cmd == 'print':
                msg = s.sock.recvstr()
                logging.info(msg.strip())
                continue
            if s.cmd == 'shutdown':
                assert s.rank >= 0 and s.rank not in shutdown
                assert s.rank not in wait_conn
                shutdown[s.rank] = s
                logging.debug('Receive %s signal from %d', s.cmd, s.rank)
                continue
            assert s.cmd == 'start' or s.cmd == 'recover'
            # lazily initialize the slaves
            if tree_map is None:
                assert s.cmd == 'start'
                if s.world_size > 0:
                    nslave = s.world_size
                tree_map, parent_map, ring_map = self.get_link_map(nslave)
                # set of nodes that are pending to start
                todo_nodes = list(range(nslave))
            else:
                assert s.world_size == -1 or s.world_size == nslave
            if s.cmd == 'recover':
                assert s.rank >= 0

            rank = s.decide_rank(job_map)
            # batch assignment of ranks
            if rank == -1:
                assert len(todo_nodes) != 0
                pending.append(s)
                if len(pending) == len(todo_nodes):
                    pending.sort(key=lambda x: x.host)
                    for s in pending:
                        rank = todo_nodes.pop(0)
                        if s.jobid != 'NULL':
                            job_map[s.jobid] = rank
                        s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
                        if s.wait_accept > 0:
                            wait_conn[rank] = s
                        logging.debug('Receive %s signal from %s; assign rank %d',
                                      s.cmd, s.host, s.rank)
                if len(todo_nodes) == 0:
                    logging.info('@tracker All of %d nodes got started', nslave)
                    self.start_time = time.time()
            else:
                s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
                logging.debug('Receive %s signal from %d', s.cmd, s.rank)
                if s.wait_accept > 0:
                    wait_conn[rank] = s
        logging.info('@tracker All nodes finished job')
        self.end_time = time.time()
        logging.info('@tracker %s secs between node start and job finish',
                     str(self.end_time - self.start_time))

    def start(self, nslave):
        def run():
            self.accept_slaves(nslave)
        self.thread = Thread(target=run, args=())
        self.thread.daemon = True
        self.thread.start()

    def join(self):
        while self.thread.is_alive():
            self.thread.join(100)

    def alive(self):
        return self.thread.is_alive()

class PSTracker(object):
    """
    Tracker module for PS
    """
    def __init__(self, hostIP, cmd, port=9091, port_end=9999, envs=None):
        """
        Starts the PS scheduler
        """
        self.cmd = cmd
        if cmd is None:
            return
        envs = {} if envs is None else envs
        self.hostIP = hostIP
        sock = socket.socket(get_family(hostIP), socket.SOCK_STREAM)
        for port in range(port, port_end):
            try:
                # probe for a free port, then release it for the scheduler to use
                sock.bind(('', port))
                self.port = port
                sock.close()
                break
            except socket.error:
                continue
        else:
            raise RuntimeError('cannot find an available port below %d for the PS scheduler' % port_end)
        env = os.environ.copy()

        env['DMLC_ROLE'] = 'scheduler'
        env['DMLC_PS_ROOT_URI'] = str(self.hostIP)
        env['DMLC_PS_ROOT_PORT'] = str(self.port)
        for k, v in envs.items():
            env[k] = str(v)
        self.thread = Thread(
            target=(lambda: subprocess.check_call(self.cmd, env=env, shell=True)), args=())
        self.thread.daemon = True
        self.thread.start()

    def join(self):
        if self.cmd is not None:
            while self.thread.is_alive():
                self.thread.join(100)

    def slave_envs(self):
        if self.cmd is None:
            return {}
        else:
            return {'DMLC_PS_ROOT_URI': self.hostIP,
                    'DMLC_PS_ROOT_PORT': self.port}

    def alive(self):
        if self.cmd is not None:
            return self.thread.is_alive()
        else:
            return False


def get_host_ip(hostIP=None):
    if hostIP is None or hostIP == 'auto':
        hostIP = 'ip'

    if hostIP == 'dns':
        hostIP = socket.getfqdn()
    elif hostIP == 'ip':
        from socket import gaierror
        try:
            hostIP = socket.gethostbyname(socket.getfqdn())
        except gaierror:
            logging.warning('gethostbyname(socket.getfqdn()) failed... trying on hostname()')
            hostIP = socket.gethostbyname(socket.gethostname())
        if hostIP.startswith("127."):
            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            # doesn't have to be reachable
            s.connect(('10.255.255.255', 1))
            hostIP = s.getsockname()[0]
    return hostIP
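
# Resolution behavior (illustrative): get_host_ip() and get_host_ip('auto')
# resolve via DNS, falling back to a UDP connect trick when the lookup
# returns a 127.* loopback address; get_host_ip('dns') returns the FQDN;
# any other value, e.g. get_host_ip('10.0.0.5'), is returned unchanged.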


def submit(nworker, nserver, fun_submit, hostIP='auto', pscmd=None):
    if nserver == 0:
        pscmd = None

    envs = {'DMLC_NUM_WORKER' : nworker,
            'DMLC_NUM_SERVER' : nserver}
    hostIP = get_host_ip(hostIP)

    if nserver == 0:
        rabit = RabitTracker(hostIP=hostIP, nslave=nworker)
        envs.update(rabit.slave_envs())
        rabit.start(nworker)
        if rabit.alive():
            fun_submit(nworker, nserver, envs)
    else:
        pserver = PSTracker(hostIP=hostIP, cmd=pscmd, envs=envs)
        envs.update(pserver.slave_envs())
        if pserver.alive():
            fun_submit(nworker, nserver, envs)

    if nserver == 0:
        rabit.join()
    else:
        pserver.join()
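
# A minimal fun_submit sketch for local testing (hypothetical; real launchers
# come from the dmlc-core submission scripts). It starts nworker local worker
# processes with the tracker environment applied:
#
#   def local_submit(nworker, nserver, envs):
#       for _ in range(nworker):
#           env = os.environ.copy()
#           env.update({k: str(v) for k, v in envs.items()})
#           env['DMLC_ROLE'] = 'worker'  # role name mirrors PSTracker's 'scheduler'
#           subprocess.Popen(['./worker_program'], env=env)  # hypothetical binary
#
#   submit(nworker=4, nserver=0, fun_submit=local_submit)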

def start_rabit_tracker(args):
    """Standalone function to start rabit tracker.

    Parameters
    ----------
    args: arguments to start the rabit tracker.
    """
    envs = {'DMLC_NUM_WORKER' : args.num_workers,
            'DMLC_NUM_SERVER' : args.num_servers}
    rabit = RabitTracker(hostIP=get_host_ip(args.host_ip), nslave=args.num_workers)
    envs.update(rabit.slave_envs())
    rabit.start(args.num_workers)
    sys.stdout.write('DMLC_TRACKER_ENV_START\n')
    # simply write configuration to stdout
    for k, v in envs.items():
        sys.stdout.write('%s=%s\n' % (k, str(v)))
    sys.stdout.write('DMLC_TRACKER_ENV_END\n')
    sys.stdout.flush()
    rabit.join()
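
# Sample stdout produced by start_rabit_tracker (values illustrative):
#
#   DMLC_TRACKER_ENV_START
#   DMLC_NUM_WORKER=4
#   DMLC_NUM_SERVER=0
#   DMLC_TRACKER_URI=192.168.1.10
#   DMLC_TRACKER_PORT=9091
#   DMLC_TRACKER_ENV_END
#
# A launcher can parse this block to build each worker's environment.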


def main():
    """Main function if tracker is executed in standalone mode."""
    parser = argparse.ArgumentParser(description='Rabit Tracker start.')
    parser.add_argument('--num-workers', required=True, type=int,
                        help='Number of worker processes to be launched.')
    parser.add_argument('--num-servers', default=0, type=int,
                        help='Number of server processes to be launched. Only used in PS jobs.')
    parser.add_argument('--host-ip', default=None, type=str,
                        help=('Host IP address; this is only needed ' +
                              'if the host IP cannot be automatically guessed.'))
    parser.add_argument('--log-level', default='INFO', type=str,
                        choices=['INFO', 'DEBUG'],
                        help='Logging level of the logger.')
    args = parser.parse_args()

    fmt = '%(asctime)s %(levelname)s %(message)s'
    if args.log_level == 'INFO':
        level = logging.INFO
    elif args.log_level == 'DEBUG':
        level = logging.DEBUG
    else:
        raise RuntimeError("Unknown logging level %s" % args.log_level)

    logging.basicConfig(format=fmt, level=level)

    if args.num_servers == 0:
        start_rabit_tracker(args)
    else:
        raise RuntimeError("Starting the PS tracker in standalone mode is not yet supported.")

if __name__ == "__main__":
    main()