@@ -694,7 +694,7 @@ def forward(
     return forward


-def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
+def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group, zero_stage=0):
     logger = logging.get_logger(__name__)

     # Copied from transformers.models.bart.modeling_bart._make_causal_mask
@@ -804,10 +804,6 @@ def forward(
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

-        # sp: modify seq_length when using sequence parallel
-        if sp_mode in ["ring", "all_to_all"]:
-            seq_length *= sp_size
-
         seq_length_with_past = seq_length
         past_key_values_length = 0
@@ -827,13 +823,12 @@ def forward(
             position_ids = position_ids.view(-1, seq_length).long()

         if inputs_embeds is None:
-            if sp_mode == "ring":
-                input_ids = _gather(input_ids, 1, sp_group)
-                inputs_embeds = self.embed_tokens(input_ids)
-                input_ids = input_ids.chunk(sp_size, dim=1)[torch.distributed.get_rank(sp_group)]
-                inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
-            else:
-                inputs_embeds = self.embed_tokens(input_ids)
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if sp_mode in ["ring", "split_gather"]:
+            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
+        elif sp_mode == "all_to_all":
+            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 'down')

         # TODO use_distributed_mask
         use_distributed_mask = True if sp_mode in ["ring", "all_to_all"] else False
@@ -864,8 +859,6 @@ def forward(
             attention_mask = _gather(attention_mask, 1, sp_group)

         hidden_states = inputs_embeds
-        if sp_mode == "split_gather":
-            hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)

         if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
             if use_cache:
@@ -922,7 +915,10 @@ def custom_forward(*inputs):
         hidden_states = self.norm(hidden_states)

         # Todo: Maybe this line can be optimized
-        hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale="up")
+        if sp_mode == "ring" or sp_mode == "split_gather" or (sp_mode == "all_to_all" and zero_stage == 0):
+            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+        elif sp_mode == "all_to_all" and zero_stage in [1, 2]:
+            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale="up")

         # add hidden states from the last decoder layer
         if output_hidden_states:
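Note: the helpers exercised in this diff (split_forward_gather_backward / gather_forward_split_backward) come from ColossalAI's shardformer layers. As a rough orientation only, below is a minimal, hypothetical sketch of the "split forward, gather backward" pattern along the sequence dimension, including the optional "up"/"down" gradient rescaling that the zero_stage branch above chooses between (the gather-forward variant is simply the mirror image). The names and signatures here are illustrative assumptions, not the library's implementation or API.

# Illustrative sketch only -- NOT ColossalAI's implementation.
# Forward keeps this rank's slice of the sequence dimension; backward
# all-gathers the gradient slices back into a full-sequence gradient,
# optionally rescaling it by the sequence-parallel world size.
import torch
import torch.distributed as dist


class _SplitForwardGatherBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, dim, group, grad_scale=None):
        ctx.dim = dim
        ctx.group = group
        ctx.grad_scale = grad_scale
        rank = dist.get_rank(group)
        world_size = dist.get_world_size(group)
        # Keep only this rank's chunk (assumes the dimension is divisible
        # by the sequence-parallel world size).
        return x.chunk(world_size, dim=dim)[rank].contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.group)
        # Optional rescaling, mirroring the 'up'/'down' arguments in the diff.
        if ctx.grad_scale == "up":
            grad_output = grad_output * world_size
        elif ctx.grad_scale == "down":
            grad_output = grad_output / world_size
        # Reassemble the full-sequence gradient from every rank's slice.
        chunks = [torch.empty_like(grad_output) for _ in range(world_size)]
        dist.all_gather(chunks, grad_output.contiguous(), group=ctx.group)
        return torch.cat(chunks, dim=ctx.dim), None, None, None


def split_forward_gather_backward_sketch(x, dim, group, grad_scale=None):
    # Hypothetical wrapper mirroring how the modeling code calls the helper,
    # e.g. split_forward_gather_backward_sketch(inputs_embeds, 1, sp_group, "down").
    return _SplitForwardGatherBackward.apply(x, dim, group, grad_scale)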