Commit cf17068

Myle Ott authored and facebook-github-bot committed
Initialize distributed using multiproc with all visible GPUs
Summary:
Pull Request resolved: facebookresearch#695

Differential Revision: D15182613

Pulled By: myleott

fbshipit-source-id: 4196346517d8e75ed9e903e9e01ab943d086f6f1
1 parent 96ac28d commit cf17068
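
For context (an illustration, not part of the commit): when SLURM starts one task per node, the changes below have that task spawn one worker process per visible GPU, and a worker's global rank becomes the node's starting rank (node_id * gpus_per_node) plus its local GPU index. A small sketch of that arithmetic with hypothetical numbers (2 nodes, 8 GPUs each):

# Hypothetical job: 2 nodes, 8 GPUs per node, world size 16.
distributed_world_size = 16
nnodes = 2
gpus_per_node = distributed_world_size // nnodes  # 8

for node_id in range(nnodes):
    start_rank = node_id * gpus_per_node      # what infer_init_method() assigns to distributed_rank
    for local_gpu in range(gpus_per_node):
        global_rank = start_rank + local_gpu  # what distributed_main() assigns after spawning
        print('node {} gpu {} -> rank {}'.format(node_id, local_gpu, global_rank))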

File tree: 3 files changed (+45 -12 lines)


fairseq/distributed_utils.py (+20 -4)
@@ -8,6 +8,7 @@
 from collections import namedtuple
 import os
 import pickle
+import socket
 import subprocess
 import warnings
 
@@ -42,9 +43,20 @@ def infer_init_method(args):
                 hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', node_list])
                 args.distributed_init_method = 'tcp://{host}:{port}'.format(
                     host=hostnames.split()[0].decode('utf-8'),
-                    port=args.distributed_port)
-                args.distributed_rank = int(os.environ.get('SLURM_PROCID'))
-                args.device_id = int(os.environ.get('SLURM_LOCALID'))
+                    port=args.distributed_port,
+                )
+                nnodes = int(os.environ.get('SLURM_NNODES'))
+                ntasks_per_node = int(os.environ.get('SLURM_NTASKS_PER_NODE'))
+                if ntasks_per_node == 1:
+                    assert args.distributed_world_size % nnodes == 0
+                    gpus_per_node = args.distributed_world_size // nnodes
+                    node_id = int(os.environ.get('SLURM_NODEID'))
+                    args.distributed_rank = node_id * gpus_per_node
+                else:
+                    assert ntasks_per_node == args.distributed_world_size // nnodes
+                    args.distributed_no_spawn = True
+                    args.distributed_rank = int(os.environ.get('SLURM_PROCID'))
+                    args.device_id = int(os.environ.get('SLURM_LOCALID'))
             except subprocess.CalledProcessError as e:  # scontrol failed
                 raise e
             except FileNotFoundError:  # Slurm is not installed
@@ -60,13 +72,17 @@ def distributed_init(args):
     else:
         print('| distributed init (rank {}): {}'.format(
             args.distributed_rank, args.distributed_init_method), flush=True)
-
         dist.init_process_group(
             backend=args.distributed_backend,
             init_method=args.distributed_init_method,
             world_size=args.distributed_world_size,
             rank=args.distributed_rank,
         )
+        print('| initialized host {} as rank {}'.format(
+            socket.gethostname(), args.distributed_rank), flush=True)
+
+        # perform a dummy all-reduce to initialize the NCCL communicator
+        dist.all_reduce(torch.rand(1).cuda())
 
     suppress_output(is_master(args))
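
As a companion to the hunk above, a minimal standalone sketch (not fairseq code) of the same initialization sequence: set up the process group, report the host, then run a dummy all-reduce so the NCCL communicator is created eagerly instead of on the first real gradient synchronization. The backend, address, and world size below are placeholders.

import socket

import torch
import torch.distributed as dist


def init_worker(rank, world_size, init_method='tcp://localhost:12345'):
    # placeholder backend/init_method/world_size; fairseq takes these from args
    dist.init_process_group(
        backend='nccl',
        init_method=init_method,
        world_size=world_size,
        rank=rank,
    )
    print('| initialized host {} as rank {}'.format(socket.gethostname(), rank), flush=True)
    # the dummy all-reduce forces the NCCL communicator to be built before training starts
    dist.all_reduce(torch.rand(1).cuda())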

fairseq/options.py (+2 -0)
@@ -266,6 +266,8 @@ def add_distributed_training_args(parser):
                        help='port number (not required if using --distributed-init-method)')
     group.add_argument('--device-id', '--local_rank', default=0, type=int,
                        help='which GPU to use (usually configured automatically)')
+    group.add_argument('--distributed-no-spawn', action='store_true',
+                       help='do not spawn multiple processes even if multiple GPUs are visible')
     group.add_argument('--ddp-backend', default='c10d', type=str,
                        choices=['c10d', 'no_c10d'],
                        help='DistributedDataParallel backend')
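
A note on the new flag: with action='store_true' it defaults to False, so spawning one process per visible GPU remains the default and is only disabled when the flag is passed on the command line (or set programmatically, as the Slurm branch in distributed_utils.py above does). A minimal argparse sketch of the same pattern, outside fairseq:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--distributed-no-spawn', action='store_true',
                    help='do not spawn multiple processes even if multiple GPUs are visible')

print(parser.parse_args([]).distributed_no_spawn)                          # False: spawning stays enabled
print(parser.parse_args(['--distributed-no-spawn']).distributed_no_spawn)  # True: caller manages processes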

train.py (+23 -8)
@@ -23,16 +23,21 @@
 from fairseq.meters import AverageMeter, StopwatchMeter
 
 
-def main(args):
+def main(args, init_distributed=False):
     utils.import_user_module(args)
 
-    if args.max_tokens is None:
-        args.max_tokens = 6000
-    print(args)
+    assert args.max_tokens is not None or args.max_sentences is not None, \
+        'Must specify batch size either with --max-tokens or --max-sentences'
 
+    # Initialize CUDA and distributed training
     if torch.cuda.is_available() and not args.cpu:
         torch.cuda.set_device(args.device_id)
     torch.manual_seed(args.seed)
+    if init_distributed:
+        args.distributed_rank = distributed_utils.distributed_init(args)
+
+    # Print args
+    print(args)
 
     # Setup task, e.g., translation, language modeling, etc.
     task = tasks.setup_task(args)
@@ -372,11 +377,11 @@ def load_dataset_splits(args, task):
         raise e
 
 
-def distributed_main(i, args):
+def distributed_main(i, args, start_rank=0):
     args.device_id = i
     if args.distributed_rank is None:  # torch.multiprocessing.spawn
-        args.distributed_rank = i
-    main(args)
+        args.distributed_rank = start_rank + i
+    main(args, init_distributed=True)
 
 
 def cli_main():
@@ -388,9 +393,19 @@ def cli_main():
 
     if args.distributed_init_method is not None:
         # distributed training
-        distributed_main(args.device_id, args)
+        if torch.cuda.device_count() > 1 and not args.distributed_no_spawn:
+            start_rank = args.distributed_rank
+            args.distributed_rank = None  # assign automatically
+            torch.multiprocessing.spawn(
+                fn=distributed_main,
+                args=(args, start_rank),
+                nprocs=torch.cuda.device_count(),
+            )
+        else:
+            distributed_main(args.device_id, args)
     elif args.distributed_world_size > 1:
         # fallback for single node with multiple GPUs
+        assert args.distributed_world_size <= torch.cuda.device_count()
         port = random.randint(10000, 20000)
         args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
         args.distributed_rank = None  # set based on device id
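
For reference on the spawn call above: torch.multiprocessing.spawn runs the target function once per process and passes the process index as its first argument, followed by the entries of args, which is how distributed_main(i, args, start_rank) receives its local index and derives the global rank start_rank + i. A standalone sketch (the worker function and its output are illustrative, not fairseq code):

import torch
import torch.multiprocessing as mp


def worker(i, start_rank):
    # i is the local process index supplied by spawn(); the global rank is start_rank + i
    print('worker {} -> global rank {}'.format(i, start_rank + i))


if __name__ == '__main__':
    nprocs = torch.cuda.device_count() or 2  # fall back to 2 processes on a CPU-only machine
    mp.spawn(fn=worker, args=(0,), nprocs=nprocs)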
