@@ -1,4 +1,20 @@
-"""Inference-only PanGuMoE model compatible with HuggingFace weights."""
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
@@ -40,7 +56,7 @@
 _ROUTER_SCALE = None
 
 
-class PanGuMoeMLP(nn.Module):
+class PanguProMoEMLP(nn.Module):
 
     def __init__(
         self,
@@ -79,7 +95,7 @@ def forward(self, x):
         return x
 
 
-class PanGuMoeSparseMoeBlock(nn.Module):
+class PanguProMoESparseMoeBlock(nn.Module):
 
     @staticmethod
     def pangu_group8_topk(
@@ -152,7 +168,7 @@ def __init__(
             intermediate_size=config.moe_intermediate_size,
             reduce_results=False,
             quant_config=quant_config,
-            custom_routing_function=PanGuMoeSparseMoeBlock.pangu_group8_topk,
+            custom_routing_function=PanguProMoESparseMoeBlock.pangu_group8_topk,
             prefix=f"{prefix}.experts",
         )
 
@@ -165,7 +181,7 @@ def __init__(
         )
 
         if config.shared_expert_intermediate_size > 0:
-            self.shared_expert = PanGuMoeMLP(
+            self.shared_expert = PanguProMoEMLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=config.shared_expert_intermediate_size,
                 hidden_act=config.hidden_act,
@@ -201,7 +217,7 @@ def forward(
         return final_hidden_states.view(num_tokens, hidden_dim)
 
 
-class PanGuMoeAttention(nn.Module):
+class PanguProMoEAttention(nn.Module):
 
     def __init__(
         self,
@@ -288,7 +304,7 @@ def forward(
         return output
 
 
-class PanGuMoeDecoderLayer(nn.Module):
+class PanguProMoEDecoderLayer(nn.Module):
 
     def __init__(
         self,
@@ -304,7 +320,7 @@ def __init__(
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
 
-        self.self_attn = PanGuMoeAttention(
+        self.self_attn = PanguProMoEAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
@@ -322,13 +338,13 @@ def __init__(
                            config.mlp_only_layers)
         if (layer_idx
                 not in mlp_only_layers) and (config.num_experts > 0):  ### ???
-            self.mlp = PanGuMoeSparseMoeBlock(
+            self.mlp = PanguProMoESparseMoeBlock(
                 config=config,
                 quant_config=quant_config,
                 prefix=f"{prefix}.mlp",
             )
         else:
-            self.mlp = PanGuMoeMLP(
+            self.mlp = PanguProMoEMLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=config.intermediate_size,
                 hidden_act=config.hidden_act,
@@ -370,7 +386,7 @@ def forward(
 
 
 @support_torch_compile
-class PanGuMoEModel(nn.Module):
+class PanguProMoEModel(nn.Module):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -390,7 +406,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: PanGuMoeDecoderLayer(config=config,
+            lambda prefix: PanguProMoEDecoderLayer(config=config,
                                                 cache_config=cache_config,
                                                 quant_config=quant_config,
                                                 prefix=prefix),
@@ -439,7 +455,7 @@ def forward(
         return hidden_states
 
 
-class PanGuMoEForCausalLM(nn.Module, SupportsPP):
+class PanguProMoEForCausalLM(nn.Module, SupportsPP):
 
     fall_back_to_pt_during_load = False
 
@@ -456,7 +472,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         quant_config = vllm_config.quant_config
         self.config = config
         self.quant_config = quant_config
-        self.model = PanGuMoEModel(vllm_config=vllm_config,
+        self.model = PanguProMoEModel(vllm_config=vllm_config,
                                    prefix=maybe_prefix(prefix, "model"))
        self.lm_head = ParallelLMHead(
             config.vocab_size,
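
Note on the grouped routing wired in above: the renamed classes pass pangu_group8_topk to FusedMoE as custom_routing_function, i.e. the experts are treated as 8 equal groups and the router selects the best expert within each group. The sketch below is a minimal illustration of that idea only, not the diff's actual implementation: the function name grouped_topk_sketch, the assumed (hidden_states, gating_output, topk, renormalize) signature, and the 64-expert example are assumptions for illustration, and the real kernel's scaling, dtype handling, and normalization may differ.

    import torch


    def grouped_topk_sketch(hidden_states: torch.Tensor,
                            gating_output: torch.Tensor,
                            topk: int,
                            renormalize: bool,
                            num_groups: int = 8):
        """Illustrative group-wise top-1 routing (assumed signature, not the real kernel).

        hidden_states and topk are kept only for signature parity; in this grouped
        scheme the effective top-k equals num_groups (one expert per group).
        """
        num_tokens, num_experts = gating_output.shape  # [num_tokens, num_experts]
        assert num_experts % num_groups == 0, "experts must split evenly into groups"
        experts_per_group = num_experts // num_groups

        scores = torch.softmax(gating_output, dim=-1)
        grouped = scores.view(num_tokens, num_groups, experts_per_group)

        # Best expert inside each group -> one expert per group per token.
        group_weights, group_idx = grouped.max(dim=-1)

        # Map the within-group index back to a global expert id.
        offsets = torch.arange(num_groups,
                               device=gating_output.device) * experts_per_group
        topk_ids = group_idx + offsets        # [num_tokens, num_groups]
        topk_weights = group_weights

        if renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

        return topk_weights, topk_ids.to(torch.int32)


    # Example: 8 tokens routed over 64 experts arranged as 8 groups of 8.
    weights, ids = grouped_topk_sketch(torch.randn(8, 16),
                                       torch.randn(8, 64),
                                       topk=8,
                                       renormalize=True)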