@@ -748,15 +748,21 @@ def _assign_modules_buffers(self):
748
748
"""
749
749
# Collect buffers for modules, filtering out buffers that should be ignored.
750
750
named_module_buffers = [
751
- (buffer , buffer_name )
752
- for buffer_name , buffer in self .module .named_buffers ()
751
+ [
752
+ (buffer , buffer_name )
753
+ for buffer_name , buffer in self .module .named_buffers ()
754
+ ]
753
755
]
754
756
self .modules_buffers = [
755
- buffer
756
- for (buffer , buffer_name ) in named_module_buffers
757
- if buffer_name not in self .parameters_to_ignore
757
+ [
758
+ buffer
759
+ for (buffer , buffer_name ) in module_buffers
760
+ if buffer_name not in self .parameters_to_ignore
761
+ ]
762
+ for module_buffers in named_module_buffers
758
763
]
759
764
765
+
760
766
def _build_param_to_name_mapping (self , parameters ):
761
767
param_to_param_index = {parameters [0 ][i ]: i for i in range (len (parameters [0 ]))}
762
768
param_set = set (parameters [0 ])
@@ -1033,7 +1039,7 @@ def _check_and_sync_module_buffers(self):
1033
1039
if self .will_sync_module_buffers ():
1034
1040
authoritative_rank = self ._find_common_rank (self ._distributed_rank , False )
1035
1041
self ._distributed_broadcast_coalesced (
1036
- self .modules_buffers , self .broadcast_bucket_size , authoritative_rank
1042
+ self .modules_buffers [ 0 ] , self .broadcast_bucket_size , authoritative_rank
1037
1043
)
1038
1044
1039
1045
# When running in join model, agrees upon a common rank and broadcast model
@@ -1339,7 +1345,7 @@ def will_sync_module_buffers(self):
1339
1345
return (
1340
1346
self .require_forward_param_sync
1341
1347
and self .broadcast_buffers
1342
- and len (self .modules_buffers ) > 0
1348
+ and len (self .modules_buffers [ 0 ] ) > 0
1343
1349
)
1344
1350
1345
1351
def _find_common_rank (self , input_rank , rank_cond ):
@@ -1377,7 +1383,7 @@ def _sync_params(self):
1377
1383
# reassigned.
1378
1384
self ._assign_modules_buffers ()
1379
1385
self ._distributed_broadcast_coalesced (
1380
- self .modules_buffers ,
1386
+ self .modules_buffers [ 0 ] ,
1381
1387
self .broadcast_bucket_size ,
1382
1388
authoritative_rank ,
1383
1389
)
0 commit comments