Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix zero stage2 cpu_offload when some model trainable parameters skipped in training #861

Merged
merged 5 commits into from
Mar 27, 2021
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions deepspeed/runtime/zero/stage2.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,8 +878,12 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
for p in params:
if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
param_id = self.get_param_id(p)
param_norm = self.norm_for_param_grads[param_id]
total_norm += param_norm.item()**2
# as some model have trainable parameters but skipped in training,
# their backward hooks in self.create_reduce_and_remove_grad_hooks() will not run,
# so they have no norm_for_param_grads
if param_id in self.norm_for_param_grads:
param_norm = self.norm_for_param_grads[param_id]
total_norm += param_norm.item()**2

# Sum across all model parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
Expand Down