From 00c351d1ed5008faa6221315b30b23fa6ca960e8 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Wed, 19 Jun 2024 05:38:26 -0400 Subject: [PATCH] [Model] Add FP8 kv cache for Qwen2 (#5656) --- vllm/model_executor/models/qwen2.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 9a4829a27873e..b5d13bb6b937c 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -46,6 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput +from vllm.utils import print_warning_once class Qwen2MLP(nn.Module): @@ -375,6 +376,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + print_warning_once( + f"Found kv scale in the checkpoint (e.g. {name}), " + "but not found the expected name in the model " + f"(e.g. {remapped_kv_scale_name}). kv-scale is " + "not loaded.") + continue + else: + name = remapped_kv_scale_name param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader)