Commit 1ca29da

conditional stacked_params_mapping
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
1 parent: 1dbba64 · commit: 1ca29da

File tree

1 file changed (+14, −2)

vllm/model_executor/models/deepseek_v2.py

Lines changed: 14 additions & 2 deletions
@@ -1286,6 +1286,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.config = config
         self.quant_config = quant_config
 
+        qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
+        qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
+        self.use_mha = config.model_type == "deepseek" or all(
+            dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim)
+        )
+
         # `packed_modules_mapping` needs to be modified before
         # initializing DeepseekV2Model, as it is passed inplace to
         # quantization config init and may be used to select the
@@ -1414,14 +1420,20 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
-            # MLA
+        ]
+        mla_params_mapping = [
             ("fused_qkv_a_proj", "q_a_proj", 0),
             ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
-            # MHA
+        ]
+        mha_params_mapping = [
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
         ]
+        if self.use_mha:
+            stacked_params_mapping.extend(mha_params_mapping)
+        else:
+            stacked_params_mapping.extend(mla_params_mapping)
 
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
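
In effect, the commit derives a `use_mha` flag at init time (the original dense-attention `deepseek` model type, or any config with no MLA head dims, uses plain MHA) and then extends `stacked_params_mapping` with only the relevant QKV rows at load time, instead of always carrying both the MLA and MHA entries. Below is a minimal self-contained sketch of that selection logic; `select_stacked_params_mapping` and the `SimpleNamespace` configs are hypothetical stand-ins for illustration, not vLLM's actual API (in the real code the logic lives inline in `__init__` and `load_weights`):

from types import SimpleNamespace

def select_stacked_params_mapping(config):
    """Hypothetical standalone version of the commit's selection logic."""
    qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
    qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
    # MHA path: the original dense "deepseek" model type, or any config
    # that defines no MLA head dims at all.
    use_mha = config.model_type == "deepseek" or all(
        dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim)
    )
    # (param_name, shard_name, shard_id), as in the diff above.
    stacked_params_mapping = [
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    if use_mha:
        stacked_params_mapping += [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
    else:
        stacked_params_mapping += [
            ("fused_qkv_a_proj", "q_a_proj", 0),
            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
        ]
    return stacked_params_mapping

# DeepSeek-V2-style config with MLA head dims present -> MLA mapping.
mla_cfg = SimpleNamespace(model_type="deepseek_v2",
                          qk_nope_head_dim=128, qk_rope_head_dim=64)
# Original dense DeepSeek config (no MLA dims) -> MHA mapping.
mha_cfg = SimpleNamespace(model_type="deepseek")

assert any(s == "q_a_proj" for _, s, _ in select_stacked_params_mapping(mla_cfg))
assert any(s == "q_proj" for _, s, _ in select_stacked_params_mapping(mha_cfg))

During loading, each (param_name, shard_name, shard_id) row tells the weight loader to fold a checkpoint shard such as q_proj into the fused parameter such as qkv_proj at the given shard index, so the mapping only needs the rows that match the checkpoint's actual attention layout.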
