Z3: optimizations for grad norm calculation and gradient clipping (#5504)

This PR adds the following functionality:
1. complete_grad_norm_calculation_for_cpu_offload: move total_norm to CPU, since the expected device in this case is CPU.
2. Replace get_global_norm() with torch.linalg.norm for better performance.
3. unscale_and_clip_grads: replace the if-statement-based clipping with torch.clamp for better performance.

Change (3) is taken from #5547 (which was closed).
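For illustration, here is a minimal sketch of the pattern behind changes (2) and (3); change (1) simply returns total_norm.cpu() so the result lands on the device the CPU-offload path expects. The tensor values, loss_scale, and clip_grad below are placeholders, not DeepSpeed's internal state:

```python
import torch

# Placeholder inputs standing in for the optimizer's internal state.
per_group_norms = [torch.tensor(3.0), torch.tensor(4.0)]  # per-parameter-group grad norms
loss_scale = 1024.0
clip_grad = 1.0  # corresponds to the "gradient_clipping" config value

# Change (2): a single torch.linalg.norm over the stacked per-group norms
# yields the global norm (sqrt of the sum of squares) in one kernel call.
scaled_global_grad_norm = torch.linalg.norm(torch.stack(per_group_norms))

# Change (3): torch.clamp(clip, min=1.0) replaces the `if clip > 1:` branch,
# avoiding a host-side comparison on a device tensor.
clip = ((scaled_global_grad_norm / loss_scale) + 1e-6) / clip_grad
combined_scale = torch.clamp(clip, min=1.0) * loss_scale

# The flat fp32 gradient partition is then unscaled/clipped in place.
flat_grads = torch.randn(8)  # stand-in for fp32_partitioned_groups_flat[i].grad
flat_grads.mul_(1.0 / combined_scale)
```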

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Lev Kurilenko <113481193+lekurile@users.noreply.github.com>
Co-authored-by: Liran Bachar <lbachar@habana.ai>
5 people authored Aug 14, 2024
1 parent 19b01e1 commit 6eed634
Showing 2 changed files with 6 additions and 5 deletions.
10 changes: 5 additions & 5 deletions deepspeed/runtime/zero/stage3.py
@@ -15,7 +15,7 @@
from deepspeed.utils import logger
from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce
- from deepspeed.runtime.utils import inf, get_global_norm, is_model_parallel_parameter, get_only_unique_item
+ from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item
from deepspeed.runtime.zero.partition_parameters import *
from deepspeed.runtime.zero.config import ZeroStageEnum
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
@@ -1413,7 +1413,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
err = torch.tensor(-1.0, device=inf_or_nan.device, dtype=torch.float)
total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm

- return total_norm
+ return total_norm.cpu()

@instrument_w_nvtx
def partition_grads(self, params_to_release: List[Parameter], grad_partitions: List[Tensor]) -> None:
@@ -2028,7 +2028,7 @@ def step(self, closure=None):
return

norm_groups = self._get_norm_groups()
- scaled_global_grad_norm = get_global_norm(norm_list=norm_groups)
+ scaled_global_grad_norm = torch.linalg.norm(torch.stack(norm_groups))

# Stash unscaled gradient norm
self._global_grad_norm = scaled_global_grad_norm / self.loss_scale
@@ -2112,8 +2112,8 @@ def unscale_and_clip_grads(self, sub_group_id, total_norm):
if self.clip_grad > 0.:
# norm is in fact norm*scale
clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
- if clip > 1:
- combined_scale = clip * self.loss_scale
+ clip = torch.clamp(clip, min=1.0)
+ combined_scale = clip * self.loss_scale

self.fp32_partitioned_groups_flat[sub_group_id].grad.mul_(1. / combined_scale)

1 change: 1 addition & 0 deletions tests/unit/runtime/zero/test_zero_offloadpp.py
@@ -43,6 +43,7 @@ def test(self, h_dim: int, n_layers: int) -> None:
config_dict = {
"train_batch_size": 256,
"steps_per_print": 1,
"gradient_clipping": 1.0,
"optimizer": {
"type": "Adam",
"params": {
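The test change above enables gradient clipping so that the clamp-based path is actually exercised. As a reference, a config along these lines would trigger it; apart from the fields visible in the diff, the optimizer params and zero_optimization block shown here are illustrative assumptions, not the test's exact values:

```python
# Sketch of a DeepSpeed config that exercises the clamp-based clipping path.
# Only the first three fields mirror the diff above; the optimizer params and
# zero_optimization block are illustrative assumptions.
config_dict = {
    "train_batch_size": 256,
    "steps_per_print": 1,
    "gradient_clipping": 1.0,  # routes step() through unscale_and_clip_grads
    "optimizer": {
        "type": "Adam",
        "params": {"lr": 1e-3},
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu"},
    },
}
```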
