Commit 34726d5

Myle Ott authored and facebook-github-bot committed
Move distributed_init into DistributedFairseqModel (facebookresearch#687)
Summary: This should make rendezvous happen as lazily as possible.

Pull Request resolved: facebookresearch#687
Differential Revision: D15151145
Pulled By: myleott
fbshipit-source-id: d70816a85414c5d509a6b12e2b339b4736db2c88
1 parent fb18be0 · commit 34726d5

3 files changed: +23 −18 lines


fairseq/distributed_utils.py (+15, −10)
@@ -9,6 +9,7 @@
 import os
 import pickle
 import subprocess
+import warnings
 
 import torch
 import torch.distributed as dist
@@ -54,18 +55,22 @@ def distributed_init(args):
     if args.distributed_world_size == 1:
         raise ValueError('Cannot initialize distributed with distributed_world_size=1')
 
-    print('| distributed init (rank {}): {}'.format(
-        args.distributed_rank, args.distributed_init_method), flush=True)
-
-    dist.init_process_group(
-        backend=args.distributed_backend,
-        init_method=args.distributed_init_method,
-        world_size=args.distributed_world_size,
-        rank=args.distributed_rank,
-    )
+    if torch.distributed.is_initialized():
+        warnings.warn('Distributed is already initialized, cannot initialize twice!')
+    else:
+        print('| distributed init (rank {}): {}'.format(
+            args.distributed_rank, args.distributed_init_method), flush=True)
+
+        dist.init_process_group(
+            backend=args.distributed_backend,
+            init_method=args.distributed_init_method,
+            world_size=args.distributed_world_size,
+            rank=args.distributed_rank,
+        )
 
-    suppress_output(is_master(args))
+        suppress_output(is_master(args))
 
+    args.distributed_rank = torch.distributed.get_rank()
     return args.distributed_rank
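The new guard makes distributed_init effectively idempotent: whichever caller gets there first performs the rendezvous, and any later caller just reads the rank back from the already-initialized process group. Below is a minimal, self-contained sketch of that pattern outside fairseq; the init_once name, the gloo backend, the file:// rendezvous, and the world_size=1 setup are all illustrative choices of this note (fairseq's own distributed_init rejects distributed_world_size=1).

import tempfile
import warnings

import torch.distributed as dist


def init_once(init_method, world_size, rank):
    # Mirror of the guard above: only the first call initializes the group.
    if dist.is_initialized():
        warnings.warn('Distributed is already initialized, cannot initialize twice!')
    else:
        dist.init_process_group(
            backend='gloo',
            init_method=init_method,
            world_size=world_size,
            rank=rank,
        )
    # Either way, return the rank the process group actually assigned.
    return dist.get_rank()


init_method = 'file://' + tempfile.mktemp()
print(init_once(init_method, world_size=1, rank=0))  # first call: rendezvous happens here
print(init_once(init_method, world_size=1, rank=0))  # second call: warns, same rank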

fairseq/models/distributed_fairseq_model.py (+6)
@@ -6,8 +6,11 @@
 # can be found in the PATENTS file in the same directory.
 
 import inspect
+import socket
+
 from torch.nn import parallel
 
+from fairseq import distributed_utils
 from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel
 
 from . import BaseFairseqModel
@@ -26,6 +29,9 @@ def DistributedFairseqModel(args, model):
         args (argparse.Namespace): fairseq args
         model (BaseFairseqModel): model to wrap
     """
+    # rendezvous with other workers
+    args.distributed_rank = distributed_utils.distributed_init(args)
+    print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank))
 
     # determine which DDP class to extend
     assert isinstance(model, BaseFairseqModel)
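With the rendezvous moved into the wrapper, a trainer no longer calls distributed_init directly; constructing DistributedFairseqModel is what triggers it, as late as possible. A hedged sketch of the resulting call sequence (build_ddp_model is a hypothetical helper of this note, and a fully populated args namespace plus a constructed model are assumed):

from fairseq.models.distributed_fairseq_model import DistributedFairseqModel


def build_ddp_model(args, model):
    # Rendezvous with the other workers happens lazily, inside the wrapper
    # constructor, instead of eagerly at the top of train.py's main().
    return DistributedFairseqModel(args, model)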

train.py (+2, −8)
@@ -23,7 +23,7 @@
 from fairseq.meters import AverageMeter, StopwatchMeter
 
 
-def main(args, init_distributed=False):
+def main(args):
     utils.import_user_module(args)
 
     if args.max_tokens is None:
@@ -82,12 +82,6 @@ def main(args, init_distributed=False):
         num_workers=args.num_workers,
     )
 
-    # Initialize distributed training (after data loading)
-    if init_distributed:
-        import socket
-        args.distributed_rank = distributed_utils.distributed_init(args)
-        print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank))
-
     # Load the latest checkpoint if one is available
     if not load_checkpoint(args, trainer, epoch_itr):
         trainer.dummy_train_step([dummy_batch])
@@ -390,7 +384,7 @@ def distributed_main(i, args):
     args.device_id = i
     if args.distributed_rank is None:  # torch.multiprocessing.spawn
         args.distributed_rank = i
-    main(args, init_distributed=True)
+    main(args)
 
 
 def cli_main():
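For context on the retained "# torch.multiprocessing.spawn" comment: spawn invokes the target as fn(i, *args) with the process index i as the first argument, which is why distributed_main falls back to using i as the rank when none was assigned. A generic, self-contained illustration (the worker function and nprocs value are this note's, not fairseq's):

import torch.multiprocessing as mp


def worker(i, world_size):
    # Each spawned child receives its index i in [0, nprocs).
    print('| worker {} of {}'.format(i, world_size))


if __name__ == '__main__':
    mp.spawn(worker, args=(2,), nprocs=2)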
