@@ -109,8 +109,8 @@ def __init__(
         layer_idx = extract_layer_index(prefix)
         self.layer_idx = layer_idx
         self.tp_size = get_tensor_model_parallel_world_size()
-        self.moe_num_shared_experts = getattr(config, "moe_num_shared_experts",
-                                              None)
+        self.has_shared_experts = (getattr(config, "moe_num_shared_experts", 0)
+                                   > 0)
 
         if self.tp_size > config.moe_num_experts:
             raise ValueError(
@@ -137,7 +137,7 @@ def __init__(
             prefix=f"{prefix}.experts",
             e_score_correction_bias=self.gate.e_score_correction_bias)
 
-        if self.moe_num_shared_experts is not None:
+        if self.has_shared_experts:
             intermediate_size = (config.moe_intermediate_size *
                                  config.moe_num_shared_experts)
             self.shared_experts = Ernie4_5_MoeMLP(
@@ -153,15 +153,16 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         orig_shape = hidden_states.shape
         hidden_dim = hidden_states.shape[-1]
         hidden_states = hidden_states.view(-1, hidden_dim)
-        if self.moe_num_shared_experts is not None:
+        shared_output = None
+        if self.has_shared_experts:
             shared_output = self.shared_experts(hidden_states)
 
         router_logits, _ = self.gate(hidden_states)
 
         final_hidden_states = self.experts(hidden_states=hidden_states,
                                            router_logits=router_logits)
 
-        if self.moe_num_shared_experts is not None and \
+        if self.has_shared_experts and \
             shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
 
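For illustration only, a minimal sketch of what the new check changes, using a hypothetical SimpleNamespace config rather than a vLLM fixture: a config that sets moe_num_shared_experts to 0 is now treated the same as one that omits the attribute, whereas the old `is not None` test would have taken the shared-expert branch for a value of 0.

    from types import SimpleNamespace

    # Hypothetical configs for illustration (not vLLM test fixtures).
    cfg_missing = SimpleNamespace()                        # attribute absent
    cfg_zero = SimpleNamespace(moe_num_shared_experts=0)   # present but zero
    cfg_two = SimpleNamespace(moe_num_shared_experts=2)    # shared experts enabled

    def has_shared_experts(config) -> bool:
        # Mirrors the new check: a missing attribute and 0 both mean "disabled".
        return getattr(config, "moe_num_shared_experts", 0) > 0

    # The old `is not None` check would have treated cfg_zero as enabled.
    assert has_shared_experts(cfg_missing) is False
    assert has_shared_experts(cfg_zero) is False
    assert has_shared_experts(cfg_two) is True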