Skip to content

Commit

Permalink
Fix RuntimeError when using ZeRO Stage3 with mpu: #3564 (#3565)
Browse files (browse the repository at this point in the history)
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
  • Loading branch information
eggiter and tjruwase authored Jun 1, 2023
1 parent 3b29999 commit f5dde36
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1499,7 +1499,12 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2):
grad_norms.append(g.to(get_accelerator().device_name(), non_blocking=True).double().norm(2))

# Sum across all model parallel GPUs.
total_norm_cuda = torch.sum(torch.pow(torch.stack(grad_norms), 2))
if len(grad_norms) == 0:
# FIX https://github.com/microsoft/DeepSpeed/issues/3564
total_norm_cuda = torch.tensor(0,
dtype=gradients[0].dtype).to(get_accelerator().device_name()).double()
else:
total_norm_cuda = torch.sum(torch.pow(torch.stack(grad_norms), 2))

dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=self.dp_process_group)

Expand Down

0 comments on commit f5dde36

Please sign in to comment.