Commit d366ccc

[RFC] [Mistral] FP8 format (#10130)
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
1 parent 870c374 commit d366ccc

4 files changed (+55, -12 lines)

vllm/model_executor/models/llama.py

Lines changed: 16 additions & 4 deletions
@@ -467,6 +467,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     mistral_mapping = {
         "layers": "model.layers",
         "attention": "self_attn",
+        "qscale_act": "input_scale",
+        "qscale_weight": "weight_scale",
+        "kv_fake_quantizer.qscale_act": "kv_scale",
         "wq": "q_proj",
         "wk": "k_proj",
         "wv": "v_proj",
@@ -590,15 +593,24 @@ def permute(w: torch.Tensor, n_heads: int):
         modules = name.split(".")
 
         # rotary embeds should be sliced
-        if "wk" in modules:
+        if "wk" in modules and modules[-1] == "weight":
             loaded_weight = permute(loaded_weight,
                                     self.config.num_key_value_heads)
-        elif "wq" in modules:
+        elif "wq" in modules and modules[-1] == "weight":
             loaded_weight = permute(loaded_weight,
                                     self.config.num_attention_heads)
 
-        for item in modules:
-            if item in mapping and mapping[item] not in name:
+        num_modules = len(modules)
+        for i in range(num_modules):
+            item = modules[i]
+            next_item = modules[i + 1] if i < num_modules - 1 else None
+
+            combined_item = (f"{item}.{next_item}"
+                             if next_item is not None else None)
+
+            if combined_item in mapping:
+                name = name.replace(combined_item, mapping[combined_item])
+            elif item in mapping and mapping[item] not in name:
                 name = name.replace(item, mapping[item])
 
         return name, loaded_weight
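Note on the loop above: iterating by index lets the remapper try a two-token key (e.g. "kv_fake_quantizer.qscale_act" -> "kv_scale") before falling back to single-token replacement. A minimal sketch of the idea, where remap_name is a hypothetical standalone helper rather than the actual loader method:

# Minimal sketch of the combined-key remapping; remap_name is a
# hypothetical helper, not the vLLM loader method itself.
mapping = {
    "layers": "model.layers",
    "attention": "self_attn",
    "qscale_act": "input_scale",
    "kv_fake_quantizer.qscale_act": "kv_scale",
}

def remap_name(name: str) -> str:
    modules = name.split(".")
    for i, item in enumerate(modules):
        next_item = modules[i + 1] if i + 1 < len(modules) else None
        combined = f"{item}.{next_item}" if next_item is not None else None
        # Prefer the two-token key so "kv_fake_quantizer.qscale_act"
        # maps to "kv_scale" instead of being rewritten token by token.
        if combined in mapping:
            name = name.replace(combined, mapping[combined])
        elif item in mapping and mapping[item] not in name:
            name = name.replace(item, mapping[item])
    return name

print(remap_name("layers.0.attention.kv_fake_quantizer.qscale_act"))
# model.layers.0.self_attn.kv_scale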

vllm/model_executor/models/pixtral.py

Lines changed: 5 additions & 2 deletions
@@ -54,8 +54,11 @@ def get_max_pixtral_image_tokens(ctx: InputContext):
         tokenizer_mode=ctx.model_config.tokenizer_mode)
     mm_encoder = tokenizer.instruct.mm_encoder
 
-    max_image_size = mm_encoder.mm_config.max_image_size
-    image_patch_size = mm_encoder.mm_config.image_patch_size
+    image_config = mm_encoder.mm_config if hasattr(
+        mm_encoder, "mm_config") else mm_encoder.image_config
+
+    max_image_size = image_config.max_image_size
+    image_patch_size = image_config.image_patch_size
 
     return ((max_image_size // image_patch_size)**2)
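The hasattr fallback keeps this working across tokenizer versions that expose the image settings as either mm_config or image_config. As a worked example of the return expression, assuming illustrative Pixtral-style values that are not taken from this diff:

# Illustrative values only; the real numbers come from the tokenizer config.
max_image_size = 1024
image_patch_size = 16
print((max_image_size // image_patch_size) ** 2)  # 4096 image tokens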

vllm/transformers_utils/config.py

Lines changed: 32 additions & 5 deletions
@@ -4,7 +4,7 @@
 import json
 import os
 from pathlib import Path
-from typing import Any, Dict, Optional, Type, Union
+from typing import Any, Dict, Literal, Optional, Type, Union
 
 import huggingface_hub
 from huggingface_hub import (file_exists, hf_hub_download, list_repo_files,
@@ -554,7 +554,8 @@ def recurse_elems(elem: Any):
         for key, value in elem.items():
             key = config_mapping.get(key, key)
             config_dict[key] = recurse_elems(value)
-        return PretrainedConfig(**config_dict)
+
+        return config_dict
     else:
         return elem
@@ -566,12 +567,30 @@ def recurse_elems(elem: Any):
     config_dict["max_position_embeddings"] = config_dict.get(
         "max_position_embeddings", 128_000)
 
+    if config_dict.get("quantization") is not None:
+        quantization = config_dict.get("quantization", {})
+        if quantization.get("qformat_weight") == "fp8_e4m3":
+            # This maps to the FP8 static per-tensor quantization scheme
+            quantization_config = {
+                "quant_method": "fp8",
+                "activation_scheme": "static"
+            }
+        else:
+            raise ValueError(
+                f"Found unknown quantization='{quantization}' in config")
+
+        config_dict["quantization_config"] = quantization_config
+
+    config_type: Literal["text",
+                         "multimodal"] = "multimodal" if config_dict.get(
+                             "vision_encoder") is not None else "text"
+
     if config_dict.get("moe") is not None:
         config_dict["architectures"] = ["MixtralForCausalLM"]
     else:
         config_dict["architectures"] = ["MistralForCausalLM"]
 
-    if config_dict.get("vision_encoder") is not None:
+    if config_type == "multimodal":
         multimodal_config = config_dict.pop("vision_encoder")
 
         config_dict = {
@@ -583,8 +602,16 @@ def recurse_elems(elem: Any):
 
     config_dict.update(kwargs)
 
-    config = recurse_elems(config_dict)
-    return config
+    config_dict = recurse_elems(config_dict)
+
+    # transform to HF config format
+    if config_type == "multimodal":
+        config_dict["text_config"] = PretrainedConfig(
+            **config_dict["text_config"])
+        config_dict["vision_config"] = PretrainedConfig(
+            **config_dict["vision_config"])
+
+    return PretrainedConfig(**config_dict)
 
 
 def get_hf_image_processor_config(
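Taken together, these hunks defer PretrainedConfig construction until the dict is fully remapped, and translate Mistral's params.json "quantization" block into an HF-style quantization_config. A hedged sketch of that translation in isolation; to_hf_quantization_config is a hypothetical helper mirroring the new branch, not a vLLM API:

from typing import Optional

# Hypothetical helper mirroring the new quantization branch above;
# not part of the vLLM API.
def to_hf_quantization_config(params: dict) -> Optional[dict]:
    quantization = params.get("quantization")
    if quantization is None:
        return None
    if quantization.get("qformat_weight") == "fp8_e4m3":
        # FP8 weights with static per-tensor activation scales.
        return {"quant_method": "fp8", "activation_scheme": "static"}
    raise ValueError(f"Found unknown quantization='{quantization}' in config")

params = {"quantization": {"qformat_weight": "fp8_e4m3"}}  # illustrative
print(to_hf_quantization_config(params))
# {'quant_method': 'fp8', 'activation_scheme': 'static'}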

vllm/transformers_utils/tokenizers/mistral.py

Lines changed: 2 additions & 1 deletion
@@ -88,7 +88,8 @@ def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
 
 
 def find_tokenizer_file(files: List[str]):
-    file_pattern = re.compile(r"^tokenizer\.model\.v.*$|^tekken\.json$")
+    file_pattern = re.compile(
+        r"^tokenizer\.model\.v.*$|^tekken\.json$|^tokenizer\.mm\.model\.v.*$")
 
     matched_files = [file for file in files if file_pattern.match(file)]
     if len(matched_files) > 1:
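A quick check of what the widened pattern accepts; the filenames here are illustrative:

import re

# Same pattern as the new one above.
file_pattern = re.compile(
    r"^tokenizer\.model\.v.*$|^tekken\.json$|^tokenizer\.mm\.model\.v.*$")

candidates = [
    "tokenizer.model.v3",     # matched before and after this change
    "tekken.json",            # matched before and after this change
    "tokenizer.mm.model.v1",  # newly matched: multimodal tokenizer file
    "config.json",            # never matched
]
print([f for f in candidates if file_pattern.match(f)])
# ['tokenizer.model.v3', 'tekken.json', 'tokenizer.mm.model.v1']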
