deepspeed/runtime/engine.py (29 additions, 0 deletions)

@@ -36,6 +36,7 @@
ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS, ZERO_OPTIMIZATION_WEIGHTS
from deepspeed.runtime.csr_tensor import CSRTensor
import deepspeed.runtime.lr_schedules as lr_schedules
from deepspeed.runtime.utils import get_grad_norm
from deepspeed.utils import logger, log_dist, init_distributed
from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer
from deepspeed.utils.debug import debug_extract_module_and_param_names
@@ -123,6 +124,7 @@ def __init__(self,
self.gas_boundary_ctr = 0
self.dist_backend = "nccl"
self._step_applied = False
self._global_grad_norm = None

# for debug purposes - can then debug print: debug_get_module_name(module)
debug_extract_module_and_param_names(model)
@@ -256,6 +258,30 @@ def set_train_batch_size(self, train_batch_size):
self._config.train_batch_size = train_batch_size
self._config.gradient_accumulation_steps = new_gas

def _compute_global_grad_norm(self):
params = [p for p in self.module.parameters() if p.grad is not None]

stas00 (Collaborator) commented on Aug 7, 2021:

Would it be a bit more efficient if tied params don't get calculated more than once? Probably something like:

params = dict((p.data_ptr(), p) for p in self.module.parameters() if p.grad is not None).values()

Or it might help to have a wrapper to do that, since this is a handy util.

But please verify that I got it right. Thanks.

return get_grad_norm(params, mpu=self.mpu)
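
As a follow-up to the suggestion above, a minimal sketch of such a wrapper; the helper name `dedup_params` and its placement are hypothetical and not part of this PR:

```python
def dedup_params(params):
    """Drop duplicate parameter objects (e.g. tied weights) so their gradients
    are not counted twice in the norm. Parameters sharing the same storage
    report the same data_ptr(), so keying on it removes the duplicates."""
    return list(dict((p.data_ptr(), p) for p in params).values())


# Hypothetical use inside _compute_global_grad_norm:
#     params = dedup_params(p for p in self.module.parameters() if p.grad is not None)
#     return get_grad_norm(params, mpu=self.mpu)
```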

def get_global_grad_norm(self, force_compute=False) -> float:
"""Return the 2-norm of all gradients. If there is model parallelism,
the norm will be global.

The computed norm will be cached and reused until the next step()
pass unless ``force_compute=True``.

.. note::
    In the presence of model parallelism, this is a collective call
    and acts as a barrier among ``mpu.get_model_parallel_group()``.

Args:
    force_compute (bool, optional): Force a recomputation of the norm. Defaults to False.

Returns:
    float: norm
"""
# Check for an outdated gradient norm.
if force_compute or self._global_grad_norm is None:
self._global_grad_norm = self._compute_global_grad_norm()

return self._global_grad_norm
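
For context, a short usage sketch of the new accessor from a training loop; the variable names (`model_engine`, `batch`) are illustrative and not from this PR:

```python
# Inside a typical DeepSpeed training loop
loss = model_engine(batch)
model_engine.backward(loss)
model_engine.step()

# Returns the norm cached during the last step(); pass force_compute=True
# to recompute it (a collective call under model parallelism).
grad_norm = model_engine.get_global_grad_norm()
```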

def checkpoint_tag_validation_enabled(self):
return self._config.checkpoint_tag_validation_enabled

@@ -1315,6 +1341,9 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}):

self.optimizer.step()

if hasattr(self.optimizer, '_global_grad_norm'):
self._global_grad_norm = self.optimizer._global_grad_norm

# Quantize the updated parameter if there is no overflow
if self.quantizer:
self.quantizer.quantize(

deepspeed/runtime/fp16/fused_optimizer.py (4 additions, 0 deletions)

@@ -45,6 +45,8 @@ def __init__(self,
self.fp16_groups_flat = []
self.fp32_groups_flat = []

self._global_grad_norm = 0.

# loop to deal with groups
for i, param_group in enumerate(self.optimizer.param_groups):
# push this group to list before modify
@@ -251,6 +253,8 @@ def step(self, closure=None):
all_groups_norm = get_grad_norm(self.fp32_groups_flat, mpu=self.mpu)
self.stop_timers([COMPUTE_NORM])

self._global_grad_norm = all_groups_norm

self.start_timers([UNSCALE_AND_CLIP])
self.unscale_and_clip_grads(grads_groups_flat, [all_groups_norm])
self.stop_timers([UNSCALE_AND_CLIP])

deepspeed/runtime/fp16/unfused_optimizer.py (2 additions, 0 deletions)

@@ -32,6 +32,7 @@ def __init__(self,
fused_lamb_legacy=False):

self.fused_lamb_legacy = fused_lamb_legacy
self._global_grad_norm = 0.

if torch.distributed.get_rank() == 0:
logger.info(f'Fused Lamb Legacy : {self.fused_lamb_legacy} ')
@@ -217,6 +218,7 @@ def unscale_and_clip_grads(self, norm_groups, apply_scale=True):
for norm in norm_groups:
total_norm += norm**2.0
total_norm = math.sqrt(total_norm)
self._global_grad_norm = total_norm

# compute combined scale factor for this group
combined_scale = self.cur_scale
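
To make the reduction above explicit: the per-group norms are folded into a single global 2-norm as the square root of the sum of squares. A standalone sketch (the helper name is illustrative, not part of this PR):

```python
import math

def combine_group_norms(norm_groups):
    """Fold per-parameter-group 2-norms into one global 2-norm:
    ||g|| = sqrt(sum_i ||g_i||^2), mirroring the loop in
    unscale_and_clip_grads that fills self._global_grad_norm."""
    return math.sqrt(sum(norm ** 2.0 for norm in norm_groups))

# Example: groups with norms 3.0 and 4.0 combine to a global norm of 5.0.
assert combine_group_norms([3.0, 4.0]) == 5.0
```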