
Commit

Support for only performing norm weights allreduce in last microbatch…
jiecaoyu committed Mar 22, 2024
1 parent d0b506f commit 74a7313
Showing 2 changed files with 20 additions and 0 deletions.
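
Taken together, the two changes add an opt-in tensor-parallel path: FullyShardedDataParallel accepts an optional tensor_parallel_group, FlatParameter records the flat-index ranges of parameters flagged with norm_allreduce_last_microbatch, and the post-backward hook allreduces only those gradient slices when gradient sync runs (i.e. on the last microbatch). A minimal usage sketch, assuming torch.distributed is already initialized and a hypothetical two-rank tensor-parallel group; only the tensor_parallel_group argument and the norm_allreduce_last_microbatch attribute come from this commit, everything else is illustrative:

import torch
import torch.distributed as dist
import torch.nn as nn

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Assumes torch.distributed is already initialized (e.g. via torchrun);
# the ranks below are placeholders for an actual tensor-parallel group.
tp_group = dist.new_group(ranks=[0, 1])

model = nn.TransformerEncoderLayer(d_model=1024, nhead=16)

# Flag replicated norm weights so FlatParameter records their index ranges
# and the post-backward hook allreduces their gradient slices across tp_group.
for module in model.modules():
    if isinstance(module, nn.LayerNorm):
        for p in module.parameters():
            p.norm_allreduce_last_microbatch = True

sharded_model = FSDP(model, tensor_parallel_group=tp_group)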
13 changes: 13 additions & 0 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -369,6 +369,7 @@ def __init__(
        limit_all_gather_events: bool = False,
        limit_reduce_scatter_events: bool = False,
        should_validate_process_group: bool = True,
        tensor_parallel_group: Optional[ProcessGroup] = None,
    ):
        try:
            import torch._C
@@ -380,6 +381,7 @@ def __init__(
        init_start = time.time()
        super().__init__()
        self.process_group = process_group or get_process_group_cached()
        self.tensor_parallel_group = tensor_parallel_group
        # If ProcessGroupName.default is passed in, the reduce_scatter will use the same process group with
        # the rest of operations. The overlap feature in the backward propagation is disabled.
        if process_group_reduce_scatter == ProcessGroupName.default:
@@ -1726,6 +1728,17 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
        if not self._require_backward_grad_sync:
            return

        # run allreduce on param if necessary
        if self.tensor_parallel_group.size() > 1:
            if self.fp32_reduce_scatter:
                orig_grad_data = param.unsharded_main_grad.data
            else:
                orig_grad_data = param.grad.data
            for idx_pair in param._param_require_tp_allreduce:
                param_allreduce = orig_grad_data[idx_pair[0]:idx_pair[1]].contiguous()
                torch.distributed.all_reduce(param_allreduce, group=self.tensor_parallel_group)
                orig_grad_data[idx_pair[0]:idx_pair[1]].copy_(param_allreduce)

        # Wait for all work in the current stream to finish, then start the
        # reductions in post_backward stream.
        self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
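The new branch runs inside _post_backward_hook, after the early return for skipped grad sync, so the allreduce only happens on the microbatch that actually reduces gradients. Note that it calls self.tensor_parallel_group.size() directly, which appears to assume a group is always supplied (the constructor default is None). The slice handling itself is simple: each recorded (start, end) pair of the flattened gradient is allreduced across the tensor-parallel group and copied back. A standalone, single-process sketch of that logic (names are illustrative; group=None skips the collective so it runs locally):

import torch
import torch.distributed as dist

def allreduce_norm_slices(flat_grad, index_pairs, group=None):
    # Mirrors the hook: allreduce each flagged slice of the flattened
    # gradient across the tensor-parallel group and copy it back in place.
    for start, end in index_pairs:
        grad_slice = flat_grad[start:end].contiguous()
        if group is not None:
            dist.all_reduce(grad_slice, group=group)
        flat_grad[start:end].copy_(grad_slice)

# Single-process demo: elements 3..5 of a 10-element flat gradient stand in
# for a flagged LayerNorm weight; with group=None the collective is skipped
# but the slicing and copy-back match the hook above.
flat_grad = torch.arange(10, dtype=torch.float32)
allreduce_norm_slices(flat_grad, [(3, 6)])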
7 changes: 7 additions & 0 deletions fairscale/nn/misc/flatten_params_wrapper.py
@@ -70,6 +70,13 @@ def __new__(cls, params: Sequence[nn.Parameter], requires_grad: bool = True) ->
    def __init__(self, params: Sequence[nn.Parameter], requires_grad: bool = True):
        """Initialize the _param_numels and _param_shapes lists."""
        self._param_numels = [p.numel() for p in params]
        self._param_require_tp_allreduce = []
        for idx in range(len(params)):
            p = params[idx]
            if hasattr(p, "norm_allreduce_last_microbatch") and p.norm_allreduce_last_microbatch:
                self._param_require_tp_allreduce.append(
                    [sum(self._param_numels[0:idx]), sum(self._param_numels[0:idx+1])]
                )
        assert self.numel() <= sum(
            self._param_numels
        ), f"Something wrong with __new__ method, {self.numel()} vs. {sum(self._param_numels)}"
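Each recorded pair is a prefix sum of the flattened parameter sizes: parameter idx occupies elements sum(numels[:idx]) through sum(numels[:idx + 1]) of the flat tensor. A small worked example with three hypothetical parameters, where only the middle one (standing in for a norm weight) carries the flag:

import torch
import torch.nn as nn

params = [
    nn.Parameter(torch.zeros(4, 8)),  # 32 elements
    nn.Parameter(torch.zeros(8)),     # 8 elements, stand-in for a norm weight
    nn.Parameter(torch.zeros(8, 2)),  # 16 elements
]
params[1].norm_allreduce_last_microbatch = True

numels = [p.numel() for p in params]
require_tp_allreduce = []
for idx, p in enumerate(params):
    if getattr(p, "norm_allreduce_last_microbatch", False):
        require_tp_allreduce.append([sum(numels[:idx]), sum(numels[:idx + 1])])

print(require_tp_allreduce)  # [[32, 40]] -> slice [32:40] of the flat parameter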
