@@ -1286,6 +1286,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.config = config
         self.quant_config = quant_config

+        qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
+        qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
+        self.use_mha = config.model_type == "deepseek" or all(
+            dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim)
+        )
+
         # `packed_modules_mapping` needs to be modified before
         # initializing DeepseekV2Model, as it is passed inplace to
         # quantization config init and may be used to select the
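The new flag distinguishes plain multi-head attention (MHA) checkpoints from multi-head latent attention (MLA) ones: DeepSeek V1 configs (`model_type == "deepseek"`) never define the MLA head dims, while V2/V3 configs carry nonzero `qk_nope_head_dim`/`qk_rope_head_dim`. A minimal sketch of the same check, using hypothetical `SimpleNamespace` configs purely for illustration:

from types import SimpleNamespace

def uses_mha(config) -> bool:
    # MLA configs (DeepSeek V2/V3) define nonzero qk_nope_head_dim /
    # qk_rope_head_dim; MHA configs (DeepSeek V1) omit them entirely.
    qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
    qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
    return config.model_type == "deepseek" or all(
        dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim)
    )

# Hypothetical configs, for illustration only:
v1 = SimpleNamespace(model_type="deepseek")
v2 = SimpleNamespace(model_type="deepseek_v2",
                     qk_nope_head_dim=128, qk_rope_head_dim=64)
print(uses_mha(v1))  # True  -> use the MHA qkv_proj mapping
print(uses_mha(v2))  # False -> use the MLA fused_qkv_a_proj mapping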
@@ -1414,14 +1420,20 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
-            # MLA
+        ]
+        mla_params_mapping = [
             ("fused_qkv_a_proj", "q_a_proj", 0),
             ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
-            # MHA
+        ]
+        mha_params_mapping = [
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
         ]
+        if self.use_mha:
+            stacked_params_mapping.extend(mha_params_mapping)
+        else:
+            stacked_params_mapping.extend(mla_params_mapping)

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
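Splitting the mapping means the loader only ever sees the entries that match the attention variant actually in use, so an MHA checkpoint's `q_proj` cannot be mis-routed through the MLA `fused_qkv_a_proj` path (and vice versa). A rough sketch of how a vLLM-style loader consumes these `(param_name, weight_name, shard_id)` tuples; the `params_dict` structure and `weight_loader` attribute below are assumptions for illustration, not the exact upstream code:

def load_stacked_weight(name, loaded_weight, params_dict, stacked_params_mapping):
    # Redirect a checkpoint weight whose name contains `weight_name` to the
    # fused module `param_name`, loading it into shard `shard_id`
    # (e.g. "q"/"k"/"v" for qkv_proj, 0/1 for fused_qkv_a_proj).
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name:
            continue
        fused_name = name.replace(weight_name, param_name)
        param = params_dict[fused_name]
        param.weight_loader(param, loaded_weight, shard_id)
        return fused_name
    return None  # not a stacked weight; fall through to the default loader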