[Bug]: FP8 checkpoints with fused linear modules fail to load scales correctly #5915

@mgoin

Description

Your current environment

The output of `python collect_env.py`

🐛 Describe the bug

Description:
When loading FP8-quantized models that contain merged linear modules (e.g., Phi-3, which merges qkv_proj and up_gate_proj), the per-shard scales are not loaded correctly. vLLM's FP8 config assumes a separate scale for each shard, but these merged layers store a single scale in the checkpoint.

Steps to Reproduce:

  1. Attempt to load an FP8-quantized Phi-3 model (e.g., https://huggingface.co/nm-testing/Phi-3-mini-128k-instruct-FP8)
  2. Observe a shape-mismatch error when the scales are loaded (see the sketch below):
    param_data.shape=torch.Size([2]) loaded_weight.shape=torch.Size([])
    param_data.shape=torch.Size([3]) loaded_weight.shape=torch.Size([])
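
For reference, here is a minimal PyTorch sketch of the mismatch (illustrative only, not vLLM's actual loader code; shapes are taken from the error above):

    import torch

    # vLLM allocates one weight_scale entry per shard of the merged layer
    # (e.g. 3 for qkv_proj, 2 for up_gate_proj), while the FP8 checkpoint
    # stores a single per-tensor scale as a 0-dim tensor.
    param_data = torch.empty(3)        # per-shard scales expected for qkv_proj
    loaded_weight = torch.tensor(2.0)  # single merged scale from the checkpoint

    print(param_data.shape, loaded_weight.shape)  # torch.Size([3]) torch.Size([])
    # The loader's `assert param_data.shape == loaded_weight.shape` then fails.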
    

Expected Behavior:
Scales should be correctly loaded for merged linear modules in FP8 checkpoints.

Proposed Fix:
Modify process_weights_after_loading in MergedColumnParallelLinear and QKVParallelLinear so that the single merged scale is repeated once per shard when these checkpoints are loaded; a sketch of the idea follows below.
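
A minimal sketch of that idea, written as a hypothetical standalone helper (the real change would go into the scale handling of those layer classes, and vLLM's actual signatures differ):

    import torch

    def load_merged_scale(param: torch.nn.Parameter,
                          loaded_weight: torch.Tensor,
                          num_shards: int) -> None:
        # Hypothetical helper: if the checkpoint provides one merged scale
        # (a 0-dim tensor), repeat it once per shard so it matches the
        # per-shard scale parameter that vLLM allocates for the merged layer.
        if loaded_weight.dim() == 0:
            loaded_weight = loaded_weight.repeat(num_shards)
        assert param.data.shape == loaded_weight.shape
        param.data.copy_(loaded_weight)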

Temporary Workaround:
Apply the following patch in vllm/model_executor/layers/linear.py:

- assert param_data.shape == loaded_weight.shape
- param_data.copy_(loaded_weight)
+ temp = loaded_weight.repeat(param_data.shape)
+ assert param_data.shape == temp.shape
+ param_data.copy_(temp)
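
With the patch applied, the FP8 checkpoint should load end to end. A rough verification sketch (the quantization and trust_remote_code flags shown here are assumptions; depending on the vLLM version, the quantization method may also be picked up from the checkpoint config automatically):

    from vllm import LLM

    # Load the FP8 Phi-3 checkpoint referenced above and run a short generation.
    llm = LLM(model="nm-testing/Phi-3-mini-128k-instruct-FP8",
              quantization="fp8",
              trust_remote_code=True)
    out = llm.generate("Hello, my name is")
    print(out[0].outputs[0].text)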

cc @robertgshaw2-neuralmagic @comaniac
