Skip to content

Commit f1aaf8a

Browse files
H-Huangfacebook-github-bot
authored andcommitted
Revert D30745960: [DDP] Remove SPMD from self.modules_buffers
Test Plan: revert-hammer Differential Revision: D30745960 (pytorch@1553259) Original commit changeset: 66a8f9847e9f fbshipit-source-id: d3f3fb813c45ac1b0ff15c6154b2e99e5dbab433
1 parent 3bf93d7 commit f1aaf8a

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

torch/nn/parallel/distributed.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -748,15 +748,21 @@ def _assign_modules_buffers(self):
748748
"""
749749
# Collect buffers for modules, filtering out buffers that should be ignored.
750750
named_module_buffers = [
751-
(buffer, buffer_name)
752-
for buffer_name, buffer in self.module.named_buffers()
751+
[
752+
(buffer, buffer_name)
753+
for buffer_name, buffer in self.module.named_buffers()
754+
]
753755
]
754756
self.modules_buffers = [
755-
buffer
756-
for (buffer, buffer_name) in named_module_buffers
757-
if buffer_name not in self.parameters_to_ignore
757+
[
758+
buffer
759+
for (buffer, buffer_name) in module_buffers
760+
if buffer_name not in self.parameters_to_ignore
761+
]
762+
for module_buffers in named_module_buffers
758763
]
759764

765+
760766
def _build_param_to_name_mapping(self, parameters):
761767
param_to_param_index = {parameters[0][i]: i for i in range(len(parameters[0]))}
762768
param_set = set(parameters[0])
@@ -1033,7 +1039,7 @@ def _check_and_sync_module_buffers(self):
10331039
if self.will_sync_module_buffers():
10341040
authoritative_rank = self._find_common_rank(self._distributed_rank, False)
10351041
self._distributed_broadcast_coalesced(
1036-
self.modules_buffers, self.broadcast_bucket_size, authoritative_rank
1042+
self.modules_buffers[0], self.broadcast_bucket_size, authoritative_rank
10371043
)
10381044

10391045
# When running in join model, agrees upon a common rank and broadcast model
@@ -1339,7 +1345,7 @@ def will_sync_module_buffers(self):
13391345
return (
13401346
self.require_forward_param_sync
13411347
and self.broadcast_buffers
1342-
and len(self.modules_buffers) > 0
1348+
and len(self.modules_buffers[0]) > 0
13431349
)
13441350

13451351
def _find_common_rank(self, input_rank, rank_cond):
@@ -1377,7 +1383,7 @@ def _sync_params(self):
13771383
# reassigned.
13781384
self._assign_modules_buffers()
13791385
self._distributed_broadcast_coalesced(
1380-
self.modules_buffers,
1386+
self.modules_buffers[0],
13811387
self.broadcast_bucket_size,
13821388
authoritative_rank,
13831389
)

0 commit comments

Comments
 (0)