 from fairseq.meters import AverageMeter, StopwatchMeter


-def main(args):
+def main(args, init_distributed=False):
     utils.import_user_module(args)

-    if args.max_tokens is None:
-        args.max_tokens = 6000
-    print(args)
+    assert args.max_tokens is not None or args.max_sentences is not None, \
+        'Must specify batch size either with --max-tokens or --max-sentences'

+    # Initialize CUDA and distributed training
     if torch.cuda.is_available() and not args.cpu:
         torch.cuda.set_device(args.device_id)
     torch.manual_seed(args.seed)
+    if init_distributed:
+        args.distributed_rank = distributed_utils.distributed_init(args)
+
+    # Print args
+    print(args)

     # Setup task, e.g., translation, language modeling, etc.
     task = tasks.setup_task(args)
@@ -372,11 +377,11 @@ def load_dataset_splits(args, task):
                 raise e


-def distributed_main(i, args):
+def distributed_main(i, args, start_rank=0):
     args.device_id = i
     if args.distributed_rank is None:  # torch.multiprocessing.spawn
-        args.distributed_rank = i
-    main(args)
+        args.distributed_rank = start_rank + i
+    main(args, init_distributed=True)


 def cli_main():
@@ -388,9 +393,19 @@ def cli_main():

     if args.distributed_init_method is not None:
         # distributed training
-        distributed_main(args.device_id, args)
+        if torch.cuda.device_count() > 1 and not args.distributed_no_spawn:
+            start_rank = args.distributed_rank
+            args.distributed_rank = None  # assign automatically
+            torch.multiprocessing.spawn(
+                fn=distributed_main,
+                args=(args, start_rank),
+                nprocs=torch.cuda.device_count(),
+            )
+        else:
+            distributed_main(args.device_id, args)
     elif args.distributed_world_size > 1:
         # fallback for single node with multiple GPUs
+        assert args.distributed_world_size <= torch.cuda.device_count()
         port = random.randint(10000, 20000)
         args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
         args.distributed_rank = None  # set based on device id
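
The new cli_main path launches one worker process per visible GPU via torch.multiprocessing.spawn and offsets each spawned process's rank by start_rank, so every process ends up with a unique global rank. Below is a minimal standalone sketch of that spawn-per-GPU pattern, not fairseq code: the worker function and its argument names are illustrative, and only the public torch.multiprocessing.spawn / torch.distributed.init_process_group APIs are assumed.

# Minimal sketch of the spawn-per-GPU pattern (illustrative only; `worker`
# and its arguments are hypothetical names, not part of fairseq).
import random

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(i, init_method, world_size, start_rank):
    # torch.multiprocessing.spawn passes the local process index `i` first.
    rank = start_rank + i        # global rank = node offset + local GPU index
    torch.cuda.set_device(i)     # bind this process to GPU i
    dist.init_process_group(
        backend='nccl',
        init_method=init_method,
        world_size=world_size,
        rank=rank,
    )
    # ... build the model/optimizer and run the training loop here ...
    dist.destroy_process_group()


if __name__ == '__main__':
    n_gpus = torch.cuda.device_count()
    port = random.randint(10000, 20000)
    init_method = 'tcp://localhost:{port}'.format(port=port)
    mp.spawn(fn=worker, args=(init_method, n_gpus, 0), nprocs=n_gpus)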