Enable hpz when running with torch.no_grad (#4232)
* enable hpz when running with torch.no_grad

* change the way to detect no_grad

* fix format

---------

Co-authored-by: Ammar Ahmad Awan <ammar.awan@microsoft.com>
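
For context on the "change the way to detect no_grad" item, a minimal sketch (not part of the commit): inside a body decorated with @torch.no_grad(), torch.is_grad_enabled() is always False, so a hook written that way cannot tell whether its caller was actually running under torch.no_grad; reading the flag before disabling gradients preserves that information.

import torch

@torch.no_grad()
def old_style_hook():
    # Inside the decorator the flag is always False,
    # regardless of the caller's gradient mode.
    return torch.is_grad_enabled()

def new_style_hook():
    # Read the caller's gradient mode first, then disable grad for the body.
    prev_grad_state = torch.is_grad_enabled()
    with torch.no_grad():
        pass  # parameter fetching would run here without building a graph
    return prev_grad_state

print(old_style_hook(), new_style_hook())      # False True   (grad-enabled caller)
with torch.no_grad():
    print(old_style_hook(), new_style_hook())  # False False  (no_grad caller)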
HeyangQin and awan-10 authored Aug 31, 2023
1 parent 6cbf666 commit 462def4
Showing 1 changed file with 5 additions and 4 deletions.
deepspeed/runtime/zero/parameter_offload.py (5 additions, 4 deletions)
@@ -490,19 +490,20 @@ def _run_after_backward_function(sub_module):
         # post backward hook
         self.backward_hooks.append(module.register_forward_pre_hook(_post_backward_module_hook))
 
-    @torch.no_grad()
     def pre_sub_module_forward_function(self, sub_module):
         see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}", force=False)
-
+        prev_grad_state = torch.is_grad_enabled(
+        )  # we don't want to enable grad for sub modules fetching, yet the subfunction need to know if grad is enabled
+        torch.set_grad_enabled(False)
         global FWD_MODULE_STACK
         FWD_MODULE_STACK.append(sub_module)
 
         param_coordinator = self.get_param_coordinator(training=sub_module.training)
         param_coordinator.trace_prologue(sub_module)
         if param_coordinator.is_record_trace():
             param_coordinator.record_module(sub_module)
-        param_coordinator.fetch_sub_module(sub_module, forward=True)
-
+        param_coordinator.fetch_sub_module(sub_module, forward=prev_grad_state)
+        torch.set_grad_enabled(prev_grad_state)
         see_memory_usage(f"Before sub module function {sub_module.__class__.__name__} after fetch", force=False)
 
     @torch.no_grad()
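
A self-contained sketch of the pattern the diff applies (assumed names; fetch_sub_module below is a stand-in, not the DeepSpeed coordinator): capture the caller's gradient mode, disable gradients while fetching partitioned parameters, pass the captured mode along so the hpz/fetch logic knows whether it is serving training or a torch.no_grad forward, then restore the mode.

import torch

def fetch_sub_module(sub_module, forward):
    # Stand-in for the parameter coordinator's fetch; 'forward' carries the
    # caller's gradient mode instead of a hard-coded True.
    print(f"fetch {sub_module.__class__.__name__}: caller grad enabled = {forward}")

def pre_forward_hook(sub_module):
    prev_grad_state = torch.is_grad_enabled()    # remember the caller's mode
    torch.set_grad_enabled(False)                # fetch without autograd tracking
    try:
        fetch_sub_module(sub_module, forward=prev_grad_state)
    finally:
        torch.set_grad_enabled(prev_grad_state)  # restore the caller's mode

layer = torch.nn.Linear(8, 8)
pre_forward_hook(layer)        # caller grad enabled = True
with torch.no_grad():
    pre_forward_hook(layer)    # caller grad enabled = False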
