From 9027f861f216249eaacfc6b2cd70a202c72798c3 Mon Sep 17 00:00:00 2001
From: Aman Sanger
Date: Tue, 19 Jul 2022 14:32:06 -0400
Subject: [PATCH] Dont overwrite hook handles in flop profiler (#2106)

Co-authored-by: Olatunji Ruwase
---
 .../profiling/flops_profiler/profiler.py      | 20 +++++++++++++-------
 1 file changed, 13 insertions(+), 7 deletions(-)

diff --git a/deepspeed/profiling/flops_profiler/profiler.py b/deepspeed/profiling/flops_profiler/profiler.py
index 32a68be91d79..7fbfb19c777f 100644
--- a/deepspeed/profiling/flops_profiler/profiler.py
+++ b/deepspeed/profiling/flops_profiler/profiler.py
@@ -74,8 +74,9 @@ def register_module_hooks(module, ignore_list):
 
             # if computing the flops of a module directly
             if type(module) in MODULE_HOOK_MAPPING:
-                module.__flops_handle__ = module.register_forward_hook(
-                    MODULE_HOOK_MAPPING[type(module)])
+                if not hasattr(module, "__flops_handle__"):
+                    module.__flops_handle__ = module.register_forward_hook(
+                        MODULE_HOOK_MAPPING[type(module)])
                 return
 
             # if computing the flops of the functionals in a module
@@ -83,7 +84,8 @@ def pre_hook(module, input):
                 module_flop_count.append([])
                 module_mac_count.append([])
 
-            module.__pre_hook_handle__ = module.register_forward_pre_hook(pre_hook)
+            if not hasattr(module, "__pre_hook_handle__"):
+                module.__pre_hook_handle__ = module.register_forward_pre_hook(pre_hook)
 
             def post_hook(module, input, output):
                 if module_flop_count:
@@ -92,20 +94,24 @@ def post_hook(module, input, output):
                     module.__macs__ += sum([elem[1] for elem in module_mac_count[-1]])
                     module_mac_count.pop()
 
-            module.__post_hook_handle__ = module.register_forward_hook(post_hook)
+            if not hasattr(module, "__post_hook_handle__"):
+                module.__post_hook_handle__ = module.register_forward_hook(post_hook)
 
             def start_time_hook(module, input):
                 torch.cuda.synchronize()
                 module.__start_time__ = time.time()
 
-            module.__start_time_hook_handle__ = module.register_forward_pre_hook(
-                start_time_hook)
+            if not hasattr(module, "__start_time_hook_handle__"):
+                module.__start_time_hook_handle__ = module.register_forward_pre_hook(
+                    start_time_hook)
 
             def end_time_hook(module, input, output):
                 torch.cuda.synchronize()
                 module.__duration__ += time.time() - module.__start_time__
 
-            module.__end_time_hook_handle__ = module.register_forward_hook(end_time_hook)
+            if not hasattr(module, "__end_time_hook_handle__"):
+                module.__end_time_hook_handle__ = module.register_forward_hook(
+                    end_time_hook)
 
         self.model.apply(partial(register_module_hooks, ignore_list=ignore_list))
         self.started = True
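
Note on the pattern: start_profile stores each hook's RemovableHandle as an attribute on the module so that end_profile can later detach it. If start_profile runs twice, an unguarded assignment overwrites the stored handle while the first hook stays registered, so that hook can never be removed and fires on every forward pass, double-counting FLOPs and durations. The hasattr guards make registration idempotent. Below is a minimal standalone sketch of the idea, assuming PyTorch; the helper names (flops_hook, register_once, remove_hooks) are illustrative, not DeepSpeed's API, and the asserts peek at PyTorch's private _forward_hooks dict only for demonstration.

    import torch.nn as nn

    def flops_hook(module, input, output):
        # Illustrative hook body; the real profiler accumulates FLOPs/MACs here.
        pass

    def register_once(module):
        # Guarded registration, as in the patch: without the hasattr check a
        # second call would overwrite __flops_handle__, orphaning the first
        # hook (its handle is lost, so it can never be removed).
        if not hasattr(module, "__flops_handle__"):
            module.__flops_handle__ = module.register_forward_hook(flops_hook)

    def remove_hooks(module):
        # Cleanup in the spirit of end_profile: detach the hook, drop the handle.
        if hasattr(module, "__flops_handle__"):
            module.__flops_handle__.remove()
            del module.__flops_handle__

    model = nn.Linear(4, 4)
    register_once(model)
    register_once(model)  # no-op: the handle already exists
    assert len(model._forward_hooks) == 1  # exactly one hook registered
    remove_hooks(model)
    assert len(model._forward_hooks) == 0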