Support for only performing norm weights allreduce in last microbatch…
jiecaoyu committed Mar 18, 2024
1 parent d0b506f commit 9eba19b
Showing 2 changed files with 13 additions and 0 deletions.
10 changes: 10 additions & 0 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -369,6 +369,7 @@ def __init__(
         limit_all_gather_events: bool = False,
         limit_reduce_scatter_events: bool = False,
         should_validate_process_group: bool = True,
+        tensor_parallel_group: Optional[ProcessGroup] = None,
     ):
         try:
             import torch._C
@@ -380,6 +381,7 @@ def __init__(
         init_start = time.time()
         super().__init__()
         self.process_group = process_group or get_process_group_cached()
+        self.tensor_parallel_group = tensor_parallel_group
         # If ProcessGroupName.default is passed in, the reduce_scatter will use the same process group with
         # the rest of operations. The overlap feature in the backward propagation is disabled.
         if process_group_reduce_scatter == ProcessGroupName.default:
@@ -1737,6 +1739,14 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         else:
             orig_grad_data = param.grad.data
 
+        for idx in range(len(param._param_require_allreduce)):
+            if param._param_require_allreduce[idx] and (self.tensor_parallel_group.size() > 1):
+                start_idx = sum(param._param_numels[0 : idx])
+                end_idx = sum(param._param_numels[0 : idx + 1])
+                param_allreduce = orig_grad_data[start_idx:end_idx].contiguous()
+                torch.distributed.all_reduce(param_allreduce, group=self.tensor_parallel_group)
+                orig_grad_data[start_idx:end_idx].copy_(param_allreduce)
+
         if self.gradient_predivide_factor > 1:
             # Average grad by world_size for consistency with PyTorch DDP.
             if getattr(param, "unsharded_main_grad", None) is not None:
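
The hunks above add the mechanism (store a tensor_parallel_group on the wrapper, then all-reduce the flagged slices of the flat gradient inside _post_backward_hook) but no example of driving it. A minimal usage sketch follows; the only names taken from the commit are tensor_parallel_group and norm_allreduce_last_microbatch, while the 4-GPU torchrun launch, process-group layout, and toy model are assumptions for illustration only.

    import torch
    import torch.distributed as dist
    from torch import nn
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    # Hypothetical layout: 4 GPUs launched via `torchrun --nproc_per_node=4`,
    # arranged as 2 tensor-parallel pairs x 2 data-parallel pairs.
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    # new_group must be called identically on every rank; each rank then keeps
    # the groups it belongs to.
    tp_groups = [dist.new_group([0, 1]), dist.new_group([2, 3])]
    dp_groups = [dist.new_group([0, 2]), dist.new_group([1, 3])]
    tp_group = tp_groups[0] if rank in (0, 1) else tp_groups[1]
    dp_group = dp_groups[0] if rank in (0, 2) else dp_groups[1]

    model = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16)).cuda()

    # Flag the norm parameters; FlatParameter picks the flag up when FSDP
    # flattens the module, and the new hook all-reduces their gradient slices.
    for p in model[1].parameters():
        p.norm_allreduce_last_microbatch = True

    fsdp_model = FSDP(model, process_group=dp_group, tensor_parallel_group=tp_group)

    loss = fsdp_model(torch.randn(8, 16, device="cuda")).sum()
    loss.backward()  # fires _post_backward_hook, all-reducing the flagged slices over tp_group
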
3 changes: 3 additions & 0 deletions fairscale/nn/misc/flatten_params_wrapper.py
@@ -70,6 +70,9 @@ def __new__(cls, params: Sequence[nn.Parameter], requires_grad: bool = True) ->
     def __init__(self, params: Sequence[nn.Parameter], requires_grad: bool = True):
         """Initialize the _param_numels and _param_shapes lists."""
         self._param_numels = [p.numel() for p in params]
+        self._param_require_allreduce = [
+            p.norm_allreduce_last_microbatch if hasattr(p, "norm_allreduce_last_microbatch") else False for p in params
+        ]
         assert self.numel() <= sum(
             self._param_numels
         ), f"Something wrong with __new__ method, {self.numel()} vs. {sum(self._param_numels)}"
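
The new _param_require_allreduce list is kept positionally aligned with _param_numels, which is what lets the hook in the first file map a flag on an original parameter to a slice of the flattened gradient. A self-contained sketch of that offset arithmetic, with made-up sizes and a plain print standing in for the real all-reduce:

    import torch

    # Mirrors FlatParameter._param_numels / _param_require_allreduce for a
    # hypothetical [linear.weight, norm.weight, norm.bias] flattening order.
    param_numels = [256, 16, 16]
    require_allreduce = [False, True, False]

    flat_grad = torch.randn(sum(param_numels))

    for idx, numel in enumerate(param_numels):
        # Same offsets the post-backward hook computes via sum(_param_numels[0:idx]).
        start = sum(param_numels[:idx])
        end = start + numel
        if require_allreduce[idx]:
            segment = flat_grad[start:end].contiguous()
            # The real hook all-reduces `segment` over the tensor-parallel group;
            # here we only report which elements would be touched.
            print(f"would all-reduce flat grad elements [{start}, {end})")
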
