@@ -40,7 +40,13 @@ def __init__(
         max_scale: float = 2**32,
     ) -> None:
         super().__init__(
-            initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale
+            initial_scale,
+            min_scale,
+            growth_factor,
+            backoff_factor,
+            growth_interval,
+            hysteresis,
+            max_scale,
         )
         self.num_working_param_groups = num_working_param_groups
         self.grad_store = grad_store
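Note: the arguments forwarded to `super().__init__` here are the usual dynamic-loss-scaling knobs for mixed-precision training (grow the scale after a run of overflow-free steps, back it off on overflow). Below is a minimal sketch of that general technique; the `DynamicScaler` class and its `update` method are hypothetical illustrations, not ColossalAI's actual scaler API.

```python
# Hypothetical sketch of dynamic loss scaling; not ColossalAI's scaler class.
class DynamicScaler:
    def __init__(self, initial_scale=2**16, min_scale=1.0, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=1000, hysteresis=2, max_scale=2**32):
        self.scale = initial_scale
        self.min_scale, self.max_scale = min_scale, max_scale
        self.growth_factor, self.backoff_factor = growth_factor, backoff_factor
        self.growth_interval, self.hysteresis = growth_interval, hysteresis
        self._good_steps = 0  # overflow-free steps since the last scale change
        self._bad_steps = 0   # consecutive overflowing steps

    def update(self, found_overflow: bool) -> None:
        if found_overflow:
            self._good_steps = 0
            self._bad_steps += 1
            # tolerate `hysteresis` consecutive overflows before backing off
            if self._bad_steps >= self.hysteresis:
                self.scale = max(self.scale * self.backoff_factor, self.min_scale)
                self._bad_steps = 0
        else:
            self._bad_steps = 0
            self._good_steps += 1
            # grow the scale after `growth_interval` clean steps
            if self._good_steps >= self.growth_interval:
                self.scale = min(self.scale * self.growth_factor, self.max_scale)
                self._good_steps = 0
```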
@@ -273,11 +279,10 @@ def _create_master_param_current_rank(self, param_list):
     # Backward Reduction Hook #
     ###########################

-    def _grad_handler(self, param, group_id, grad):
+    def _grad_handler(self, group_id, param):
         # if run with no_sync context, would not sync grad when backward
         if self.require_grad_sync:
             self._add_to_bucket(param, group_id)
-        return grad

     def _attach_reduction_hook(self):
         # we iterate over the working params
@@ -286,7 +291,7 @@ def _attach_reduction_hook(self):
             param_group = self._working_param_groups[group_id]
             for param in param_group:
                 if param.requires_grad:
-                    param.register_hook(partial(self._grad_handler, param, group_id))
+                    param.register_post_accumulate_grad_hook(partial(self._grad_handler, group_id))

     #######################
     # Reduction Functions #
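The two hunks above carry the substantive change in this section: `Tensor.register_hook` passes the incoming gradient to the hook (and expects it back), while `Tensor.register_post_accumulate_grad_hook` (available since PyTorch 2.1) fires after the gradient has been accumulated into `param.grad` and passes the parameter instead, which is why `_grad_handler` drops both the `grad` argument and the `return grad`. A small standalone sketch of the two hook signatures; the handler below is an illustrative stand-in, not the optimizer's own:

```python
import torch
from functools import partial

p = torch.randn(4, requires_grad=True)

# Old style: the hook receives the gradient before it is accumulated into .grad
# and must hand it (or a replacement) back.
def tensor_hook(grad: torch.Tensor) -> torch.Tensor:
    return grad

p.register_hook(tensor_hook)

# New style (PyTorch >= 2.1): the hook fires after .grad has been accumulated
# and receives the parameter itself; it returns nothing.
def post_accumulate_handler(group_id: int, param: torch.Tensor) -> None:
    print(f"group {group_id}: accumulated grad norm = {param.grad.norm():.4f}")

p.register_post_accumulate_grad_hook(partial(post_accumulate_handler, 0))

(p * 2).sum().backward()
```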
@@ -415,15 +420,22 @@ def _run_reduction(self):
                     recieved_grad = torch.zeros_like(flat_grads_list[0])
                     dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
                     self._update_partitoned_grad(
-                        non_moe_grad_in_bucket_current_rank, recieved_grad, group_id, 1
+                        non_moe_grad_in_bucket_current_rank,
+                        recieved_grad,
+                        group_id,
+                        1,
                     )

                     if len(moe_grad_list) > 0:
                         flat_grads_list = list(
                             moe_flat_grads.split(len(moe_flat_grads) // self.moe_extra_dp_pg_size)
                         )
                         recieved_grad = torch.zeros_like(flat_grads_list[0])
-                        dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.moe_extra_dp_pg)
+                        dist.reduce_scatter(
+                            recieved_grad,
+                            flat_grads_list,
+                            group=self.moe_extra_dp_pg,
+                        )
                         param_slice = self._world_size // self.moe_extra_dp_pg_size
                         recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice))
                         for split_recieved_grad in recieved_grad:
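For reference, `dist.reduce_scatter(output, input_list, group=...)` element-wise sums `input_list[i]` across all ranks and writes the reduced chunk `i` into rank `i`'s `output`, which is how each rank ends up holding only its own partition of the flattened gradient bucket. A minimal sketch of that call, assuming a multi-GPU NCCL job launched with `torchrun` (tensor values are made up for illustration):

```python
# Launch with e.g.:  torchrun --nproc_per_node=2 reduce_scatter_demo.py
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

# Pretend this is the flattened gradient bucket on every rank.
flat_grads = torch.arange(8, dtype=torch.float32, device="cuda") * (rank + 1)

# Split into world_size equal chunks; chunk i belongs to rank i.
flat_grads_list = list(flat_grads.split(len(flat_grads) // world_size))

# Rank i receives the elementwise sum over all ranks of chunk i.
received = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter(received, flat_grads_list, group=dist.group.WORLD)
print(f"rank {rank} owns partition: {received.tolist()}")

dist.destroy_process_group()
```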
@@ -444,14 +456,25 @@ def _update_unpartitoned_grad(self, origin_grad_list: List, flat_grad_list: List
                 self._add_grad(grad, self._world_size, group_id, param_id, rank)

     def _update_partitoned_grad(
-        self, origin_grad_list: List, flat_grad: torch.Tensor, group_id: int, partition_num: int
+        self,
+        origin_grad_list: List,
+        flat_grad: torch.Tensor,
+        group_id: int,
+        partition_num: int,
     ) -> None:
         sync_tensor(flat_grad, origin_grad_list)
         for grad in origin_grad_list:
             param_id = self._bucket_store.get_param_id_of_grad(grad)
             self._add_grad(grad, partition_num, group_id, param_id)

-    def _add_grad(self, grad: torch.Tensor, partition_num: int, group_id: int, param_id: int, rank: int = 0) -> None:
+    def _add_grad(
+        self,
+        grad: torch.Tensor,
+        partition_num: int,
+        group_id: int,
+        param_id: int,
+        rank: int = 0,
+    ) -> None:
         if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num:
             self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
         else:
@@ -534,6 +557,7 @@ def zero_grad(self, set_to_none=True):
                     if param.grad is not None:
                         param.grad.detach()
                         param.grad.zero_()
+        self._bucket_store.reset_all()

     ####################
     # Update Parameter #
@@ -655,14 +679,20 @@ def step(self, closure=None):
                         for _ in range(self.moe_extra_dp_pg_size)
                     ]
                     dist.all_gather(
-                        all_splited_param, splited_param.to(device).to(self._dtype), group=self.moe_extra_dp_pg
+                        all_splited_param,
+                        splited_param.to(device).to(self._dtype),
+                        group=self.moe_extra_dp_pg,
                     )
                 else:
                     all_splited_param = [
                         torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
                         for _ in range(self._world_size)
                     ]
-                    dist.all_gather(all_splited_param, splited_param.to(device).to(self._dtype), group=self.dp_pg)
+                    dist.all_gather(
+                        all_splited_param,
+                        splited_param.to(device).to(self._dtype),
+                        group=self.dp_pg,
+                    )
                 working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
             self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]

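`dist.all_gather(output_list, input, group=...)` is the inverse step at parameter-update time: every rank contributes its updated master shard and receives everyone else's, and flattening the gathered shards (then truncating to `working_param.numel()` to drop padding) reassembles the full working parameter. A sketch under the same `torchrun`/NCCL assumptions as the previous example; `flatten` here is `torch._utils._flatten_dense_tensors`, standing in for the optimizer's own flatten helper:

```python
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors as flatten  # stand-in helper

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

# Each rank holds its own updated shard of the (padded) master parameter.
shard_numel = 6
splited_param = torch.full((shard_numel,), float(rank), device="cuda")

# Gather every rank's shard onto every rank.
all_splited_param = [torch.zeros(shard_numel, device="cuda") for _ in range(world_size)]
dist.all_gather(all_splited_param, splited_param, group=dist.group.WORLD)

# Flatten and drop any tail padding before copying back into the working param.
working_numel = shard_numel * world_size - 2  # pretend the real param was padded by 2
full_param = flatten(all_splited_param)[:working_numel]
print(f"rank {rank}: reassembled {full_param.numel()} elements")

dist.destroy_process_group()
```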
@@ -685,7 +715,9 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> flo
         if norm_type == inf:
             total_norm = max(grad.data.abs().max() for grad in gradients)
             total_norm_cuda = torch.tensor(
-                [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float
+                [float(total_norm)],
+                device=get_accelerator().get_current_device(),
+                dtype=torch.float,
             )
             dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
             total_norm = total_norm_cuda.item()
@@ -698,10 +730,14 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> flo

             # Sum across all model parallel GPUs.
             total_norm_exponentiated_cuda = torch.tensor(
-                [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float
+                [float(total_norm_exponentiated)],
+                device=get_accelerator().get_current_device(),
+                dtype=torch.float,
             )
             torch.distributed.all_reduce(
-                total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg
+                total_norm_exponentiated_cuda,
+                op=torch.distributed.ReduceOp.SUM,
+                group=self.dp_pg,
             )
             total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)

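Both `_compute_grad_norm` hunks are pure formatting; the surrounding logic computes the global norm over sharded gradients by all-reducing either the local max (infinity norm, `ReduceOp.MAX`) or the locally summed `|g|^p` terms (`ReduceOp.SUM`), taking the p-th root only afterwards. The single-process sketch below just checks that identity, with a list of shards standing in for what each rank would hold:

```python
import torch

norm_type = 2.0
shards = [torch.randn(5) for _ in range(4)]  # stand-ins for per-rank gradient shards

# What each "rank" contributes before the SUM all-reduce: sum of |g|^p over its shard.
total_norm_exponentiated = sum(
    shard.double().norm(norm_type).item() ** norm_type for shard in shards
)
# The p-th root is taken only after the all-reduce.
total_norm = total_norm_exponentiated ** (1.0 / norm_type)

# Identical to the norm of the concatenated (unsharded) gradient.
reference = torch.cat(shards).double().norm(norm_type).item()
assert abs(total_norm - reference) < 1e-6
```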
@@ -920,5 +956,8 @@ def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:

     def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
         if hasattr(self, "moe_master_to_working_map"):
-            return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map}
+            return {
+                **self._param_store.master_to_working_param,
+                **self.moe_master_to_working_map,
+            }
         return self._param_store.master_to_working_param