ZeRO-1 empty grads fix + tests #1273

Merged: 3 commits, Aug 2, 2021

Changes from 1 commit
prevent None grads from being reduced
jeffra committed Aug 2, 2021
commit e249c12540ab7cd6d38bdd425387bc1d1dff8a6f
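
For context on why a gradient can be missing here: parameters that are frozen or never used in the forward pass receive no gradient from backward(), so they reach the optimizer with .grad set to None. A minimal, illustrative PyTorch sketch (not part of this PR; TinyModel is a made-up name):

```python
import torch

# Illustrative only: a module with one layer that never participates in the
# forward pass, so its parameters never receive gradients.
class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.used = torch.nn.Linear(4, 4)
        self.unused = torch.nn.Linear(4, 4)  # not touched in forward()

    def forward(self, x):
        return self.used(x)

model = TinyModel()
model(torch.randn(2, 4)).sum().backward()

# After backward(), grads exist for the used layer but stay None for the
# unused one; any reduction loop over parameters has to tolerate this.
assert all(p.grad is not None for p in model.used.parameters())
assert all(p.grad is None for p in model.unused.parameters())
```

The diff below adds exactly that tolerance to the reduction path in stage2.py and annotates the places where grads are intentionally dropped.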
deepspeed/runtime/zero/stage2.py (19 changes: 10 additions & 9 deletions)
@@ -473,7 +473,7 @@ def initialize_optimizer_states(self):
 
         if not self.cpu_offload:
             for group in self.single_partition_of_fp32_groups:
-                group.grad = None
+                group.grad = None #class init
 
         return
 
@@ -497,7 +497,8 @@ def reduce_gradients(self, pipeline_parallel=False):
         if not self.overlap_comm:
             for i, group in enumerate(self.fp16_groups):
                 for param in group:
-                    self.reduce_ready_partitions_and_remove_grads(param, i)
+                    if param.grad is not None:
+                        self.reduce_ready_partitions_and_remove_grads(param, i)
 
         # reduce any pending grads in either hook/non-hook case
         self.overlapping_partition_gradients_reduce_epilogue()
@@ -974,7 +975,7 @@ def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param):
 
         src_tensor = param.grad.view(-1).narrow(0, source_offset, num_elements).float()
         dest_tensor.copy_(src_tensor, non_blocking=True)
-        param.grad = None
+        param.grad = None #offload only
 
     def complete_grad_norm_calculation_for_cpu_offload(self, params):
         total_norm = 0.0
@@ -1105,7 +1106,7 @@ def reduce_ipg_grads(self):
                                 self.previous_reduced_grads = []
                             self.previous_reduced_grads.append(param)
                         else:
-                            param.grad = None
+                            param.grad = None #only if self.partition_gradients
                     elif self.contiguous_gradients:
                         self.copy_grads_in_partition(param)
 
@@ -1127,7 +1128,7 @@ def are_all_related_partitions_reduced(params_id):
 
         for params_id in self.is_grad_computed[i][partition_id]:
             if are_all_related_partitions_reduced(params_id):
-                self.param_dict[params_id].grad = None
+                self.param_dict[params_id].grad = None # dead code
 
     def flatten_and_print(self, message, tensors, start=0, n=5):
         flatten_tensor = self.flatten(tensors)
@@ -1216,7 +1217,7 @@ def allreduce_bucket(self, bucket, allreduce_always_fp32=False, rank=None, log=None):
     def _clear_previous_reduced_grads(self):
         if self.previous_reduced_grads is not None:
             for param in self.previous_reduced_grads:
-                param.grad = None
+                param.grad = None # overlap enabled
             self.previous_reduced_grads = None
 
     # if rank is specified do a reduction instead of an allreduce
@@ -1333,7 +1334,7 @@ def zero_grad(self, set_grads_to_None=True):
         for group in self.fp16_groups:
             for p in group:
                 if set_grads_to_None:
-                    p.grad = None
+                    p.grad = None # epilogue and in step
                 else:
                     if p.grad is not None:
                         p.grad.detach_()
@@ -1459,7 +1460,7 @@ def get_flat_partition(self,
 
     def free_grad_in_param_list(self, param_list):
         for p in param_list:
-            p.grad = None
+            p.grad = None # in step
 
     def reset_cpu_buffers(self):
         self.norm_for_param_grads = {}
@@ -1585,7 +1586,7 @@ def step(self, closure=None):
         # get rid of the fp32 gradients. Not needed anymore
         if not self.cpu_offload:
            for group in self.single_partition_of_fp32_groups:
-                group.grad = None
+                group.grad = None # in step
 
         for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
             fp16_partitions[partition_id].data.copy_(fp32_partition.data)
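
The if param.grad is not None guard added in reduce_gradients (second hunk above) follows the general pattern of skipping absent gradients before a collective operation. A hedged sketch of that pattern in plain torch.distributed, not DeepSpeed code; the helper name reduce_existing_grads is invented for illustration:

```python
import torch
import torch.distributed as dist

def reduce_existing_grads(params, group=None):
    """Sketch: average only the gradients that actually exist.
    Assumes torch.distributed.init_process_group() has been called."""
    world_size = dist.get_world_size(group=group)
    for p in params:
        if p.grad is None:
            continue  # frozen/unused parameter: nothing to reduce
        dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM, group=group)
        p.grad.data.div_(world_size)

# Usage, once a process group is initialized:
#   reduce_existing_grads(model.parameters())
```

In a collective setting every rank has to skip the same set of parameters, otherwise the all_reduce calls would mismatch across ranks; unused or frozen parameters satisfy this because the model is identical on every worker.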