Commit 1b769dc

[Bugfix] Fix Ernie4_5_MoeForCausalLM shared experts (#21717)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
1 parent 2cc5711 commit 1b769dc

File tree

1 file changed (+6, -5 lines)


vllm/model_executor/models/ernie45_moe.py

Lines changed: 6 additions & 5 deletions
@@ -109,8 +109,8 @@ def __init__(
         layer_idx = extract_layer_index(prefix)
         self.layer_idx = layer_idx
         self.tp_size = get_tensor_model_parallel_world_size()
-        self.moe_num_shared_experts = getattr(config, "moe_num_shared_experts",
-                                              None)
+        self.has_shared_experts = (getattr(config, "moe_num_shared_experts", 0)
+                                   > 0)
 
         if self.tp_size > config.moe_num_experts:
             raise ValueError(
@@ -137,7 +137,7 @@ def __init__(
             prefix=f"{prefix}.experts",
             e_score_correction_bias=self.gate.e_score_correction_bias)
 
-        if self.moe_num_shared_experts is not None:
+        if self.has_shared_experts:
             intermediate_size = (config.moe_intermediate_size *
                                  config.moe_num_shared_experts)
             self.shared_experts = Ernie4_5_MoeMLP(
@@ -153,15 +153,16 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         orig_shape = hidden_states.shape
         hidden_dim = hidden_states.shape[-1]
         hidden_states = hidden_states.view(-1, hidden_dim)
-        if self.moe_num_shared_experts is not None:
+        shared_output = None
+        if self.has_shared_experts:
             shared_output = self.shared_experts(hidden_states)
 
         router_logits, _ = self.gate(hidden_states)
 
         final_hidden_states = self.experts(hidden_states=hidden_states,
                                            router_logits=router_logits)
 
-        if self.moe_num_shared_experts is not None and \
+        if self.has_shared_experts and \
                 shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
 
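
For context, a minimal sketch (not part of the commit) of why the check was changed: a config that sets moe_num_shared_experts = 0 still passes the previous "is not None" test, so the layer would attempt to build shared experts with zero intermediate size, whereas the new "> 0" flag skips them. The SimpleNamespace configs below are hypothetical stand-ins for the real model config.

from types import SimpleNamespace

# Hypothetical configs illustrating the behavior change.
cfg_missing = SimpleNamespace()                        # attribute absent
cfg_zero = SimpleNamespace(moe_num_shared_experts=0)   # explicitly zero
cfg_two = SimpleNamespace(moe_num_shared_experts=2)    # two shared experts

for name, cfg in [("missing", cfg_missing), ("zero", cfg_zero), ("two", cfg_two)]:
    old_check = getattr(cfg, "moe_num_shared_experts", None) is not None
    new_check = getattr(cfg, "moe_num_shared_experts", 0) > 0
    print(f"{name}: old={old_check} new={new_check}")

# missing: old=False new=False
# zero:    old=True  new=False  (old check would build empty shared experts)
# two:     old=True  new=True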
167168
