Making vLLM compatible with Mistral fp8 weights. #10229

Status: Closed. akllm wants to merge 20 commits into vllm-project:main from akllm:vllmfp8mistral (changes shown from all commits).

Commits (20)
- c8865cf: Making vLLM compatible with Mistral fp8 weights. (akllm, Nov 11, 2024)
- 6937cc7: remove extra whitespace (akllm, Nov 11, 2024)
- 3518639: Merge branch 'vllm-project:main' into vllmfp8mistral (akllm, Nov 11, 2024)
- cd1ca2c: Merge branch 'vllm-project:main' into vllmfp8mistral (akllm, Nov 12, 2024)
- f61bc32: Merge branch 'vllm-project:main' into vllmfp8mistral (akllm, Nov 12, 2024)
- 5a027f1: Merge branch 'vllm-project:main' into vllmfp8mistral (akllm, Nov 13, 2024)
- 869923f: Merge branch 'vllm-project:main' into vllmfp8mistral (akllm, Nov 14, 2024)
- 6cb6bad: Commiting the value of the parameter (akllm, Nov 14, 2024)
- 40e37e2: Rest of the changes (akllm, Nov 14, 2024)
- cf88baf: Merge branch 'vllm-project:main' into vllmfp8mistral (akllm, Nov 14, 2024)
- 1723734: Merge branch 'vllm-project:main' into vllmfp8mistral (akllm, Nov 15, 2024)
- 133881a: Merge branch 'vllm-project:main' into vllmfp8mistral (akllm, Nov 15, 2024)
- 27595b7: Merge branch 'vllm-project:main' into vllmfp8mistral (akllm, Nov 17, 2024)
- 3ad46bc: Merge branch 'vllm-project:main' into vllmfp8mistral (akllm, Nov 18, 2024)
- c98fa33: Merge branch 'vllm-project:main' into vllmfp8mistral (akllm, Nov 18, 2024)
- dcbd25f: Merge branch 'vllm-project:main' into vllmfp8mistral (akllm, Nov 19, 2024)
- ed1cb8c: Merge branch 'vllm-project:main' into vllmfp8mistral (akllm, Nov 19, 2024)
- 226be71: Merge branch 'vllm-project:main' into vllmfp8mistral (akllm, Nov 19, 2024)
- cdbaf7e: Merge branch 'vllm-project:main' into vllmfp8mistral (akllm, Nov 20, 2024)
- 4e756a7: Merge branch 'vllm-project:main' into vllmfp8mistral (akllm, Nov 21, 2024)
vllm/model_executor/models/llama.py (14 additions, 4 deletions)
@@ -480,6 +480,9 @@
     mistral_mapping = {
         "layers": "model.layers",
         "attention": "self_attn",
+        "qscale_act": "input_scale",
+        "qscale_weight": "weight_scale",
+        "kv_fake_quantizer.qscale_act": "kv_scale",
         "wq": "q_proj",
         "wk": "k_proj",
         "wv": "v_proj",
@@ -614,15 +617,22 @@
         modules = name.split(".")

         # rotary embeds should be sliced
-        if "wk" in modules:
+        if "wk" in modules and modules[-1] == "weight":
             loaded_weight = permute(loaded_weight,
                                     self.config.num_key_value_heads)
-        elif "wq" in modules:
+        elif "wq" in modules and modules[-1] == "weight":
             loaded_weight = permute(loaded_weight,
                                     self.config.num_attention_heads)

-        for item in modules:
-            if item in mapping and mapping[item] not in name:
+        num_modules = len(modules)
+        for i in range(num_modules):
+            item = modules[i]
+            next_item = modules[i + 1] if i < num_modules - 1 else None
+
+            combined_item = f"{item}.{next_item}" if next_item is not None else None
+            if combined_item in mapping:
+                name = name.replace(combined_item, mapping[combined_item])
+            elif item in mapping and mapping[item] not in name:
                 name = name.replace(item, mapping[item])

         return name, loaded_weight

GitHub Actions check failure, ruff (3.12): vllm/model_executor/models/llama.py:632:81: E501 Line too long (84 > 80).
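Two things are going on above: the new `modules[-1] == "weight"` guard keeps the rotary permutation off the scalar qscale tensors (only actual `weight` tensors get permuted), and because one of the new mapping keys is a dotted two-token entry ("kv_fake_quantizer.qscale_act"), the rename loop now looks ahead one token. A self-contained sketch of that walk, runnable outside vLLM; the mapping is abbreviated and the sample names are the hypothetical ones from above:

```python
# Standalone sketch of the combined-token rename walk shown in the diff above.
# Abbreviated mapping; the full table lives in LlamaForCausalLM.
mapping = {
    "layers": "model.layers",
    "attention": "self_attn",
    "qscale_act": "input_scale",
    "qscale_weight": "weight_scale",
    "kv_fake_quantizer.qscale_act": "kv_scale",
    "wq": "q_proj",
}


def remap(name: str) -> str:
    modules = name.split(".")
    num_modules = len(modules)
    for i in range(num_modules):
        item = modules[i]
        next_item = modules[i + 1] if i < num_modules - 1 else None
        combined_item = f"{item}.{next_item}" if next_item is not None else None
        # A dotted two-token key such as "kv_fake_quantizer.qscale_act" wins
        # over its single-token pieces, while a bare "qscale_act" elsewhere
        # still maps to "input_scale".
        if combined_item in mapping:
            name = name.replace(combined_item, mapping[combined_item])
        elif item in mapping and mapping[item] not in name:
            name = name.replace(item, mapping[item])
    return name


assert (remap("layers.0.attention.kv_fake_quantizer.qscale_act")
        == "model.layers.0.self_attn.kv_scale")
assert (remap("layers.0.attention.wq.qscale_act")
        == "model.layers.0.self_attn.q_proj.input_scale")
```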
vllm/transformers_utils/config.py (8 additions, 4 deletions)
@@ -490,12 +490,12 @@ def load_params_config(model: Union[str, Path],
         "hidden_dim": "intermediate_size",
     }

-    def recurse_elems(elem: Any):
-        if isinstance(elem, dict):
+    def recurse_elems(elem: Any, wrap_to_hf_config: bool=True):
+        if isinstance(elem, dict) and wrap_to_hf_config:
             config_dict = {}
             for key, value in elem.items():
                 key = config_mapping.get(key, key)
-                config_dict[key] = recurse_elems(value)
+                config_dict[key] = recurse_elems(value, wrap_to_hf_config=False)
             return PretrainedConfig(**config_dict)
         else:
             return elem
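A quick illustration of what the new wrap_to_hf_config flag changes (a sketch mirroring the helper, not PR code): only the top-level dict is wrapped into a PretrainedConfig, and nested sections such as "quantization" now pass through as plain dicts instead of being wrapped and key-remapped. The params.json contents below are hypothetical:

```python
from typing import Any

from transformers import PretrainedConfig

config_mapping = {"hidden_dim": "intermediate_size"}  # abbreviated


def recurse_elems(elem: Any, wrap_to_hf_config: bool = True):
    # Same shape as the PR's helper: wrap into PretrainedConfig only when asked.
    if isinstance(elem, dict) and wrap_to_hf_config:
        config_dict = {}
        for key, value in elem.items():
            key = config_mapping.get(key, key)
            config_dict[key] = recurse_elems(value, wrap_to_hf_config=False)
        return PretrainedConfig(**config_dict)
    return elem


# Hypothetical params.json contents; "qformat_weight" is an assumed field name.
params = {"hidden_dim": 14336, "quantization": {"qformat_weight": "fp8_e4m3"}}
config = recurse_elems(params)
assert config.intermediate_size == 14336      # top-level keys are remapped
assert isinstance(config.quantization, dict)  # nested section stays a plain dict
```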
@@ -507,7 +507,11 @@ def recurse_elems(elem: Any):
     config_dict["max_seq_len"] = config_dict.get("max_seq_len", 128_000)
     config_dict["max_position_embeddings"] = config_dict.get(
         "max_position_embeddings", 128_000)
-
+    if config_dict.get("quantization") is not None:
+        config_dict["quantization_config"] = {
+            "quant_method": "fp8",
+            "activation_scheme": "static"
+        }
     if config_dict.get("moe") is not None:
         config_dict["architectures"] = ["MixtralForCausalLM"]
     else:
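So any "quantization" section in a Mistral params.json is surfaced to vLLM as an HF-style fp8 quantization_config with static activation scales. With that and the weight-name plumbing in place, loading a Mistral-format fp8 checkpoint should go through the usual Mistral entry points; a hedged usage sketch, where the model id is a placeholder rather than a tested checkpoint:

```python
from vllm import LLM

# Placeholder model id: any repo shipping Mistral-format (consolidated) fp8
# weights with a "quantization" section in params.json.
llm = LLM(
    model="org/mistral-model-fp8",
    tokenizer_mode="mistral",
    config_format="mistral",
    load_format="mistral",
)
outputs = llm.generate("Hello, ")
```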