stage_1_and_2: optimize clip calculation to use clamp (deepspeedai#5632)
instead of "if" that causes host/device synchronization and introduces a
bubble, while clamp is hapenning on the device
nelyahu authored Jun 10, 2024
1 parent 6e2899f commit 1ef9b02
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions deepspeed/runtime/zero/stage_1_and_2.py
@@ -1962,8 +1962,8 @@ def unscale_and_clip_grads(self, grad_groups_flat, total_norm):
         if self.clip_grad > 0.:
             # norm is in fact norm*scale
             clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
-            if clip > 1:
-                combined_scale = clip * self.loss_scale
+            clip = torch.clamp(clip, min=1.0)
+            combined_scale = clip * self.loss_scale
 
         for grad in grad_groups_flat:
             if isinstance(grad, list):
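For illustration, here is a minimal standalone sketch of the pattern this commit replaces (not part of the commit; the function names are hypothetical, and `total_norm` is assumed to be a 0-dim tensor on the GPU):

```python
import torch

def combined_scale_old(total_norm, loss_scale, clip_grad):
    # Old pattern: `if clip > 1` converts a device tensor to a Python bool,
    # which blocks the host until all queued GPU work finishes (a bubble).
    clip = ((total_norm / loss_scale) + 1e-6) / clip_grad
    combined_scale = loss_scale
    if clip > 1:  # implicit readback -> host/device synchronization
        combined_scale = clip * loss_scale
    return combined_scale

def combined_scale_new(total_norm, loss_scale, clip_grad):
    # New pattern: clamp is enqueued on the device like any other kernel;
    # no value is read back, so the host keeps submitting work without stalling.
    clip = ((total_norm / loss_scale) + 1e-6) / clip_grad
    clip = torch.clamp(clip, min=1.0)
    return clip * loss_scale
```

The two are numerically equivalent: when `clip <= 1`, `torch.clamp(clip, min=1.0)` returns 1.0 and the combined scale reduces to `loss_scale`, matching the untaken branch of the old `if`.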
