Skip to content

Commit

Permalink
Dont overwrite hook handles in flop profiler (#2106)
Browse files Browse the repository at this point in the history
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
  • Loading branch information
Sanger2000 and tjruwase authored Jul 19, 2022
1 parent 16699d8 commit 9027f86
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions deepspeed/profiling/flops_profiler/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,18 @@ def register_module_hooks(module, ignore_list):

# if computing the flops of a module directly
if type(module) in MODULE_HOOK_MAPPING:
module.__flops_handle__ = module.register_forward_hook(
MODULE_HOOK_MAPPING[type(module)])
if not hasattr(module, "__flops_handle__"):
module.__flops_handle__ = module.register_forward_hook(
MODULE_HOOK_MAPPING[type(module)])
return

# if computing the flops of the functionals in a module
def pre_hook(module, input):
module_flop_count.append([])
module_mac_count.append([])

module.__pre_hook_handle__ = module.register_forward_pre_hook(pre_hook)
if not hasattr(module, "__pre_hook_handle__"):
module.__pre_hook_handle__ = module.register_forward_pre_hook(pre_hook)

def post_hook(module, input, output):
if module_flop_count:
Expand All @@ -92,20 +94,24 @@ def post_hook(module, input, output):
module.__macs__ += sum([elem[1] for elem in module_mac_count[-1]])
module_mac_count.pop()

module.__post_hook_handle__ = module.register_forward_hook(post_hook)
if not hasattr(module, "__post_hook_handle__"):
module.__post_hook_handle__ = module.register_forward_hook(post_hook)

def start_time_hook(module, input):
torch.cuda.synchronize()
module.__start_time__ = time.time()

module.__start_time_hook_handle__ = module.register_forward_pre_hook(
start_time_hook)
if not hasattr(module, "__start_time_hook_handle"):
module.__start_time_hook_handle__ = module.register_forward_pre_hook(
start_time_hook)

def end_time_hook(module, input, output):
torch.cuda.synchronize()
module.__duration__ += time.time() - module.__start_time__

module.__end_time_hook_handle__ = module.register_forward_hook(end_time_hook)
if not hasattr(module, "__end_time_hook_handle__"):
module.__end_time_hook_handle__ = module.register_forward_hook(
end_time_hook)

self.model.apply(partial(register_module_hooks, ignore_list=ignore_list))
self.started = True
Expand Down

0 comments on commit 9027f86

Please sign in to comment.