This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
/
Copy pathlaunch.py
executable file
·88 lines (80 loc) · 3.19 KB
/
launch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#!/usr/bin/env python
"""
Launch a distributed job
"""
import argparse
import os, sys
import signal
import logging
# Absolute path of the directory that contains this script.
curr_path = os.path.abspath(os.path.dirname(__file__))
# Extend the import search path with the tracker directory located relative
# to this script (presumably where the dmlc_tracker package imported further
# down lives -- confirm against the repository layout).
sys.path.append(os.path.join(curr_path, "../dmlc-core/tracker"))
def dmlc_opts(opts):
    """Convert mxnet's launcher options into dmlc's options.

    Builds a dmlc-tracker style argument list from the parsed launcher
    namespace and passes it to ``dmlc_tracker.opts.get_opts``.

    Parameters
    ----------
    opts : argparse.Namespace
        Must provide ``num_workers``, ``num_servers``, ``launcher``,
        ``hostfile``, ``sync_dst_dir`` and ``command`` (a list of strings).
        ``hostfile`` and ``sync_dst_dir`` may be ``None``.

    Returns
    -------
    The object returned by ``dmlc_tracker.opts.get_opts`` for that list.

    Raises
    ------
    ImportError
        If the ``dmlc_tracker`` package cannot be imported; a hint about
        initialising the git submodules is printed before re-raising.
    """
    args = ['--num-workers', str(opts.num_workers),
            '--num-servers', str(opts.num_servers),
            '--cluster', opts.launcher]
    # Forward the optional settings only when they were actually supplied.
    # Putting a bare None into an argument list makes the list non-string and
    # can surface downstream as the literal text 'None'.
    # NOTE(review): assumes dmlc_tracker treats an omitted flag as "unset" --
    # confirm against dmlc_tracker.opts.
    for flag, value in (('--host-file', opts.hostfile),
                        ('--sync-dst-dir', opts.sync_dst_dir)):
        if value is not None:
            args += [flag, value]
    args += opts.command
    try:
        # Aliased so the module does not rebind the ``opts`` parameter.
        from dmlc_tracker import opts as tracker_opts
    except ImportError:
        print("Can't load dmlc_tracker package. Perhaps you need to run")
        print(" git submodule update --init --recursive")
        raise
    return tracker_opts.get_opts(args)
def main():
    """Parse the command line and submit the job through the chosen launcher.

    Arguments are read from ``sys.argv``. Anything the parser does not
    recognise is appended to the job command. When ``--num-servers`` is not
    given it defaults to the number of workers. The parsed options are
    converted with ``dmlc_opts`` and handed to the ``submit`` function of the
    ``dmlc_tracker`` module named after the selected cluster type.

    Raises
    ------
    RuntimeError
        If the ssh or mpi launcher is selected without a hostfile, if a
        hostfile is combined with the yarn, local or sge launcher, or if the
        cluster type is not one of the known launchers.
    """
    # Function-scope import: only this function needs it.
    import importlib

    parser = argparse.ArgumentParser(description='Launch a distributed job')
    parser.add_argument('-n', '--num-workers', required=True, type=int,
                        help='number of worker nodes to be launched')
    parser.add_argument('-s', '--num-servers', type=int,
                        help='number of server nodes to be launched, '
                             'in default it is equal to NUM_WORKERS')
    parser.add_argument('-H', '--hostfile', type=str,
                        help='the hostfile of slave machines which will run '
                             'the job. Required for ssh and mpi launcher')
    parser.add_argument('--sync-dst-dir', type=str,
                        help='if specificed, it will sync the current '
                             "directory into slave machines's SYNC_DST_DIR if ssh "
                             'launcher is used')
    parser.add_argument('--launcher', type=str, default='ssh',
                        choices=['local', 'ssh', 'mpi', 'sge', 'yarn'],
                        help='the launcher to use')
    parser.add_argument('command', nargs='+',
                        help='command for launching the program')
    args, unknown = parser.parse_known_args()
    # Options meant for the launched program end up in ``unknown``; keep them
    # as part of the command line to run.
    args.command += unknown
    if args.num_servers is None:
        args.num_servers = args.num_workers

    args = dmlc_opts(args)

    # The converted options may carry either a real None or the text 'None'
    # for a missing hostfile; treat both as "no hostfile".
    has_host_file = args.host_file not in (None, 'None')
    cluster = args.cluster
    if cluster in ('ssh', 'mpi'):
        if not has_host_file:
            raise RuntimeError(
                'The %s launcher requires a hostfile (-H/--hostfile)' % cluster)
    elif cluster in ('yarn', 'local', 'sge'):
        if has_host_file:
            raise RuntimeError(
                'The %s launcher cannot be combined with a hostfile '
                '(-H/--hostfile)' % cluster)
    else:
        raise RuntimeError('Unknown submission cluster type %s' % cluster)

    # Each supported cluster type maps to the dmlc_tracker module of the same
    # name; import it only once the options are known to be valid.
    importlib.import_module('dmlc_tracker.' + cluster).submit(args)
def signal_handler(signal, frame):
    """Log that the launcher is stopping, then exit with status 0.

    Both parameters are accepted for the signal-handler calling convention
    and are not used. Note that ``signal`` shadows the module of the same
    name inside this function.
    """
    logging.info('Stop launcher')
    raise SystemExit(0)
if __name__ == '__main__':
    # Timestamp, level and message on every record, from INFO upwards.
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(levelname)s %(message)s')
    # Route SIGINT (Ctrl-C) to signal_handler before starting the launcher.
    signal.signal(signal.SIGINT, signal_handler)
    main()