-import warnings
 from functools import partial
 from typing import Callable, Dict, List, Optional, Union

@@ -21,7 +20,6 @@


 class OpenMoePolicy(Policy):
-
     def config_sanity_check(self):
         pass

@@ -43,7 +41,8 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
         if self.shard_config.enable_sequence_parallelism:
             self.shard_config.enable_sequence_parallelism = False
             raise NotImplementedError(
-                "openmoe doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
+                "openmoe doesn't support sequence parallelism now, will ignore the sequence parallelism flag."
+            )

         if self.shard_config.enable_tensor_parallelism:
             raise NotImplementedError("Tensor parallelism is not supported for openmoe model now.")
@@ -143,7 +142,6 @@ def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:


 class OpenMoeModelPolicy(OpenMoePolicy):
-
     def __init__(self) -> None:
         super().__init__()

@@ -169,21 +167,21 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:


 class OpenMoeForCausalLMPolicy(OpenMoePolicy):
-
     def module_policy(self):
         policy = super().module_policy()

         if self.shard_config.enable_tensor_parallelism:
             # add a new item for casual lm
             new_item = {
-                OpenMoeForCausalLM:
-                    ModulePolicyDescription(sub_module_replacement=[
+                OpenMoeForCausalLM: ModulePolicyDescription(
+                    sub_module_replacement=[
                         SubModuleReplacementDescription(
                             suffix="lm_head",
                             target_module=Linear1D_Col,
                             kwargs=dict(gather_output=True),
                         )
-                    ])
+                    ]
+                )
             }
             policy.update(new_item)

@@ -208,13 +206,17 @@ def get_held_layers(self) -> List[Module]:
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
         llama_model = self.model.model
         if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
-            if (id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
-                    and self.pipeline_stage_manager.num_stages > 1):
+            if (
+                id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
+                and self.pipeline_stage_manager.num_stages > 1
+            ):
                 # tie weights
-                return [{
-                    0: llama_model.embed_tokens.weight,
-                    self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
-                }]
+                return [
+                    {
+                        0: llama_model.embed_tokens.weight,
+                        self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
+                    }
+                ]
         return []


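For reference, the tie-weights logic in this hunk can be exercised outside the policy with a minimal standalone sketch; TinyLM, its sizes, and the two-stage layout below are made-up stand-ins, not OpenMoE code, and only the id() check and the {stage: tensor} mapping mirror the diff.

# Minimal sketch of the tied-weight check above (illustrative only; TinyLM is a made-up stand-in).
import torch.nn as nn


class TinyLM(nn.Module):
    def __init__(self, vocab_size=100, hidden_size=16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        # tie the output projection to the input embedding
        self.lm_head.weight = self.embed_tokens.weight


def shared_params(model, num_stages):
    # same identity test as the policy: tied parameters are one and the same object
    if num_stages > 1 and id(model.embed_tokens.weight) == id(model.lm_head.weight):
        # stage 0 holds the embedding, the last stage holds lm_head
        return [{0: model.embed_tokens.weight, num_stages - 1: model.lm_head.weight}]
    return []


print(shared_params(TinyLM(), num_stages=2))  # one shared group when weights are tied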
@@ -247,12 +249,13 @@ def openmoe_model_forward(

     logger = logging.get_logger(__name__)

-    output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
-    output_hidden_states = (output_hidden_states
-                            if output_hidden_states is not None else self.config.output_hidden_states)
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
     use_cache = use_cache if use_cache is not None else self.config.use_cache

-    return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

     # retrieve input_ids and inputs_embeds
     if stage_manager.is_first_stage():
@@ -320,7 +323,8 @@ def openmoe_model_forward(
     if self.gradient_checkpointing and self.training:
         if use_cache:
             logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+            )
             use_cache = False

     # decoder layers
@@ -333,12 +337,11 @@ def openmoe_model_forward(
         if output_hidden_states:
             all_hidden_states += (hidden_states,)

-        past_key_value = (past_key_values[idx] if past_key_values is not None else None)
+        past_key_value = past_key_values[idx] if past_key_values is not None else None

         if self.gradient_checkpointing and self.training:

             def create_custom_forward(module):
-
                 def custom_forward(*inputs):
                     # None for past_key_value
                     return module(*inputs, output_attentions, None)
@@ -384,14 +387,16 @@ def custom_forward(*inputs):
         router_z_loss = past_router_z_loss + router_z_loss

     if stage_manager.is_last_stage():
-        return tuple([
-            hidden_states,
-            next_cache,
-            all_hidden_states,
-            all_self_attns,
-            router_aux_loss,
-            router_z_loss,
-        ])
+        return tuple(
+            [
+                hidden_states,
+                next_cache,
+                all_hidden_states,
+                all_self_attns,
+                router_aux_loss,
+                router_z_loss,
+            ]
+        )
     # always return dict for imediate stage
     return {
         "hidden_states": hidden_states,
@@ -445,10 +450,11 @@ def llama_for_causal_lm_forward(
445450 "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
446451 ```"""
447452 logger = logging .get_logger (__name__ )
448- output_attentions = (output_attentions if output_attentions is not None else self .config .output_attentions )
449- output_hidden_states = (output_hidden_states
450- if output_hidden_states is not None else self .config .output_hidden_states )
451- return_dict = (return_dict if return_dict is not None else self .config .use_return_dict )
453+ output_attentions = output_attentions if output_attentions is not None else self .config .output_attentions
454+ output_hidden_states = (
455+ output_hidden_states if output_hidden_states is not None else self .config .output_hidden_states
456+ )
457+ return_dict = return_dict if return_dict is not None else self .config .use_return_dict
452458
453459 # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
454460 if output_attentions :
@@ -504,7 +510,6 @@ def llama_for_causal_lm_forward(
             if chunk_head == True:

                 def create_custom_forward(module):
-
                     def custom_forward(*inputs):
                         logits = module(inputs[0])
                         logits = logits.float()
@@ -522,8 +527,8 @@ def custom_forward(*inputs):
                 for batch_idx in range(hidden_states.shape[0]):
                     loss = loss + torch.utils.checkpoint.checkpoint(
                         create_custom_forward(self.lm_head),
-                        hidden_states[batch_idx:batch_idx + 1, :],
-                        labels[batch_idx:batch_idx + 1, :],
+                        hidden_states[batch_idx : batch_idx + 1, :],
+                        labels[batch_idx : batch_idx + 1, :],
                     )
                 logits = None
             else:
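This last hunk only reflows the chunk-and-checkpoint loss computation. As a standalone illustration of that memory-saving pattern, here is a minimal sketch with made-up sizes, where a plain CrossEntropyLoss stands in for OpenMoE's own loss helper; only the per-sample checkpointing loop mirrors the diff.

# Standalone sketch of the chunk-and-checkpoint loss pattern above (made-up sizes;
# a plain CrossEntropyLoss stands in for OpenMoE's own loss helper).
import torch
import torch.nn as nn
import torch.utils.checkpoint

vocab_size, hidden_size = 50, 8
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
hidden_states = torch.randn(4, 10, hidden_size, requires_grad=True)  # (batch, seq, dim)
labels = torch.randint(0, vocab_size, (4, 10))


def create_custom_forward(module):
    def custom_forward(*inputs):
        logits = module(inputs[0]).float()
        # shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = inputs[1][..., 1:].contiguous()
        return nn.CrossEntropyLoss()(shift_logits.view(-1, vocab_size), shift_labels.view(-1))

    return custom_forward


# checkpoint one sample at a time so the full (batch, seq, vocab) logits tensor
# is never kept alive for the whole batch during the forward pass
loss = torch.zeros(())
for batch_idx in range(hidden_states.shape[0]):
    loss = loss + torch.utils.checkpoint.checkpoint(
        create_custom_forward(lm_head),
        hidden_states[batch_idx : batch_idx + 1, :],
        labels[batch_idx : batch_idx + 1, :],
    )
loss.backward()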