
[Quantization][FP8] Add support for FP8 models with input_scale for output projection and QK quantization #15734

Merged: 97 commits, merged Apr 25, 2025

Commits
2702132
Add kernel that supports variable sequence length
rasmith Jan 30, 2025
cb79d54
isort
rasmith Jan 30, 2025
f346cd2
codespell
rasmith Jan 30, 2025
0ffdfae
ruff
rasmith Jan 30, 2025
71a072e
ruff/codespell
rasmith Jan 30, 2025
1f3729d
ruff
rasmith Jan 31, 2025
d6652ac
codespell
rasmith Jan 31, 2025
8f702c8
ruff
rasmith Jan 31, 2025
1f718b2
ruff
rasmith Jan 31, 2025
f9f2aba
ruff
rasmith Jan 31, 2025
107a7a5
ruff
rasmith Jan 31, 2025
ecb3320
ruff
rasmith Jan 31, 2025
4ef102d
Merge branch 'vllm-project:main' into ransmith_triton_fav2_vsl
rasmith Jan 31, 2025
71f89c5
yapf
rasmith Jan 31, 2025
94258b0
resolve merge
rasmith Feb 7, 2025
239f1d1
unit tests work
rasmith Feb 8, 2025
fc903c3
remove bwd
rasmith Feb 8, 2025
1780946
reformatting
rasmith Feb 8, 2025
aaa0d9e
reformatting
rasmith Feb 8, 2025
9e0d8ce
reformatting
rasmith Feb 8, 2025
9a2afda
reformatting
rasmith Feb 8, 2025
2161514
reformatting
rasmith Feb 8, 2025
1366205
ruff
rasmith Feb 8, 2025
3ed0e91
ruff
rasmith Feb 8, 2025
8b66ddc
ruff
rasmith Feb 8, 2025
bd3b2c7
ruff
rasmith Feb 8, 2025
dd63b79
add unit tests
rasmith Feb 8, 2025
6350ed7
everything seems to work
rasmith Feb 8, 2025
5eff339
codespell
rasmith Feb 10, 2025
0727349
fix incorrect function call
rasmith Feb 11, 2025
f6b2001
add spdx identifier
rasmith Feb 11, 2025
f1a29d3
Merge branch 'vllm-project:main' into ransmith_triton_fav2_vsl
rasmith Feb 11, 2025
5f478de
Merge branch 'vllm-project:main' into ransmith_triton_fav2_vsl
rasmith Feb 12, 2025
65ff486
update unit tests
rasmith Feb 13, 2025
95d1795
Try using newer fa kernel
rasmith Feb 22, 2025
e8f1ed7
Merge branch 'main' into ransmith_triton_fav2_vsl
rasmith Feb 22, 2025
44eb67e
revert back to current triton fa
rasmith Feb 22, 2025
e8e6fef
use older triton fa
rasmith Feb 22, 2025
c370096
Merge branch 'vllm-project:main' into ransmith_triton_fav2_vsl
rasmith Mar 13, 2025
3423c44
Update unit tests, make it work with fp8 llama for ROCm
rasmith Mar 13, 2025
301811b
yapf
rasmith Mar 13, 2025
eda000c
isort
rasmith Mar 13, 2025
a3ca4f1
ruff/yapf
rasmith Mar 13, 2025
90022be
remove main from test file
rasmith Mar 13, 2025
67adb34
mypy
rasmith Mar 13, 2025
3b5ba1a
mypy
rasmith Mar 13, 2025
b8ba91b
isort
rasmith Mar 13, 2025
4bfacad
don't use fp8_out_scale parameter in abstract
rasmith Mar 13, 2025
b911e3f
ruff
rasmith Mar 13, 2025
ae7f6c6
remove cpa if statement
rasmith Mar 14, 2025
345ec5d
eight bit dtypes
rasmith Mar 14, 2025
5c399ea
eight bit dtype
rasmith Mar 14, 2025
787eb33
eight bit dtype
rasmith Mar 14, 2025
d39aee9
init fp8_out_scale in layer.py
rasmith Mar 14, 2025
b22250f
remove parameter
rasmith Mar 14, 2025
9508130
update autotune
rasmith Mar 14, 2025
3a7048c
merge main
rasmith Mar 21, 2025
5e4a79f
remove variable redefinitions
rasmith Mar 21, 2025
ea67811
add back in vscale float
rasmith Mar 21, 2025
6364b73
Merge branch 'vllm-project:main' into ransmith_triton_fav2_vsl
rasmith Mar 21, 2025
62c2efb
merge main
rasmith Mar 28, 2025
fb94876
remove extra q_scale and is_navi
rasmith Mar 28, 2025
f15c554
add booloean
rasmith Mar 28, 2025
94ac0b1
use kv_cache_dtype
rasmith Mar 28, 2025
2534cef
remove variable
rasmith Mar 28, 2025
9742688
rename _fp8_out_scale
rasmith Mar 28, 2025
07abdab
use the old kernel in this branch
rasmith Mar 28, 2025
79eed6c
remove unit tests
rasmith Mar 28, 2025
068b7a0
keep scales on device and use is_fp8_fnuz
rasmith Apr 4, 2025
b08573b
removing out_scale and adding fp8 scale kernel
rasmith Apr 5, 2025
2853082
Remove _out_scale
rasmith Apr 5, 2025
0073fb4
add in-place scale kernel for fp8
rasmith Apr 5, 2025
2ceedcb
use scaled_fp8_quant and dont check for fp8
rasmith Apr 7, 2025
e863a92
remove kernel
rasmith Apr 7, 2025
456a13a
add back in else
rasmith Apr 7, 2025
7f4f1a5
Merge branch 'main' into ransmith_triton_fp8_integration
rasmith Apr 8, 2025
57fa3e1
type and update rocm backend to send tensor instead of float
rasmith Apr 9, 2025
ef180b0
works with FP8-KV models now
rasmith Apr 9, 2025
54da3c9
remove unnecessary renaming
rasmith Apr 10, 2025
9da6f61
use is_fp8_fnuz instead of is_rocm
rasmith Apr 10, 2025
1a6e018
avoid cpu trip and remove extraneous check
rasmith Apr 10, 2025
8d360b6
add back empty line
rasmith Apr 10, 2025
b2e969f
Cleanup ROCm output passing
ProExpertProg Apr 10, 2025
b04ee94
Fix output for ROCm FA output
ProExpertProg Apr 13, 2025
402c564
Fix sdpa arg type
ProExpertProg Apr 13, 2025
3516018
Remove out param from FA path
ProExpertProg Apr 13, 2025
1a81fc8
update comment and removed unused code
rasmith Apr 22, 2025
3fe98e4
Merge branch 'main' of github.com:rasmith/vllm
rasmith Apr 22, 2025
3e7e199
merge main
rasmith Apr 22, 2025
6019575
check if tensor is fp and do arg_utils check for V1 vs fp8 for ROCm
rasmith Apr 23, 2025
642489b
mypy
rasmith Apr 23, 2025
e974ec4
make sure works if certain layers don't exist and avoid model_config …
rasmith Apr 23, 2025
a515af5
rearrange the to_dict
rasmith Apr 23, 2025
37f3287
mypy
rasmith Apr 23, 2025
57bb189
Merge branch 'vllm-project:main' into ransmith_triton_fp8_integration
rasmith Apr 24, 2025
9338de2
check q_proj_q_config is none
rasmith Apr 24, 2025
9229d9a
Merge branch 'ransmith_triton_fp8_integration' of github.com:rasmith/…
rasmith Apr 24, 2025
1 change: 1 addition & 0 deletions vllm/attention/backends/abstract.py
@@ -237,6 +237,7 @@ class AttentionLayer(Protocol):
_v_scale: torch.Tensor
_k_scale_float: float
_v_scale_float: float
_prob_scale: torch.Tensor
Collaborator: QQ: why is this called _prob_scale?

ProExpertProg (Collaborator), Apr 22, 2025: I think this is the scale for P, the tensor resulting from the softmax(Q@K) calculation; O is softmax(Q@K)@V.

rasmith (Author): It comes from a parameter in the model, something like self_attn.prob_output_scale, which gets remapped to .attn.prob_scale; see also @ProExpertProg's comment above.


def forward(
self,
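
A minimal, illustrative sketch of where these four per-tensor scales enter the attention math, following the thread above. This is not the Triton kernel this PR wires up; the function and tensor names are assumptions for illustration only, and saturation/clamping is omitted.

import torch

def fp8_attention_reference(q_fp8, k_fp8, v_fp8,
                            q_scale, k_scale, v_scale, prob_scale,
                            sm_scale):
    """Dequantize-compute-requantize reference with per-tensor scales."""
    # _q_scale/_k_scale/_v_scale dequantize the FP8 inputs.
    q = q_fp8.to(torch.float32) * q_scale
    k = k_fp8.to(torch.float32) * k_scale
    v = v_fp8.to(torch.float32) * v_scale

    # P = softmax(Q @ K^T * sm_scale); _prob_scale is the per-tensor scale
    # applied when P itself is carried in FP8 into the P @ V matmul.
    p = torch.softmax(q @ k.transpose(-1, -2) * sm_scale, dim=-1)
    p_fp8 = (p / prob_scale).to(torch.float8_e4m3fn)

    # O = P @ V, folding the P dequantization back in.
    return (p_fp8.to(torch.float32) * prob_scale) @ v

The real kernel fuses these steps; the sketch only shows that _q_scale, _k_scale and _v_scale act on the inputs while _prob_scale acts on the softmax output P.
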
7 changes: 7 additions & 0 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -766,6 +766,12 @@ def forward(
query.dtype,
seq_lens,
make_attn_mask=causal_mask) # type: ignore
use_fp8_scales = (layer._q_scale and layer._k_scale
and layer._v_scale and layer._prob_scale
and self.kv_cache_dtype == "fp8")
full_scales = (
layer._q_scale, layer._k_scale, layer._v_scale,
layer._prob_scale) if use_fp8_scales else None
self.triton_attn_func(
query,
key,
@@ -779,6 +785,7 @@
self.scale,
attn_masks[0][None]
if attn_masks is not None else None,
full_scales,
)
elif self.use_naive_attn:
if self.num_kv_heads != self.num_heads:
1 change: 1 addition & 0 deletions vllm/attention/layer.py
@@ -90,6 +90,7 @@ def __init__(
# FlashAttn doesn't support quantizing the kv-cache only
# but requires q to be quantized as well.
self._q_scale = torch.tensor(1.0, dtype=torch.float32)
self._prob_scale = torch.tensor(1.0, dtype=torch.float32)

# We also keep the float32 versions of k/v_scale for attention
# backends that don't support tensors (Flashinfer)
11 changes: 11 additions & 0 deletions vllm/config.py
@@ -3766,6 +3766,17 @@ def _get_quantization_config(
return quant_config
return None

@staticmethod
def get_quantization_config(
model_config: ModelConfig,
load_config: LoadConfig) -> Optional[QuantizationConfig]:
import copy

# For some reason, the _ version of this modifies the model_config
# object, so using deepcopy to avoid this problem.
return VllmConfig._get_quantization_config(copy.deepcopy(model_config),
load_config)

def with_hf_config(
self,
hf_config: PretrainedConfig,
17 changes: 17 additions & 0 deletions vllm/engine/arg_utils.py
@@ -1377,6 +1377,23 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
recommend_to_remove=False)
return False

if current_platform.is_rocm():
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
load_config = self.create_load_config()
quantization_config = VllmConfig.get_quantization_config(
model_config, load_config)
if isinstance(quantization_config, Fp8Config):
_raise_or_fallback(feature_name="fp8 for ROCm",
recommend_to_remove=False)
return False
from vllm.model_executor.layers.quantization.quark.quark import (
QuarkConfig)

if isinstance(quantization_config, QuarkConfig
) and quantization_config.has_fp8_layer_weights():
_raise_or_fallback(feature_name="Quark fp8 for ROCm",
recommend_to_remove=False)

# No Fp8 KV cache so far.
if self.kv_cache_dtype != "auto":
fp8_attention = self.kv_cache_dtype.startswith("fp8")
5 changes: 5 additions & 0 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -140,6 +140,11 @@ def get_cache_scale(self, name: str) -> Optional[str]:
return name.replace(".k_proj.output_scale", ".attn.k_scale")
if name.endswith(".output_scale") and ".v_proj" in name:
return name.replace(".v_proj.output_scale", ".attn.v_scale")
if name.endswith(".output_scale") and ".q_proj" in name:
Contributor: Is it necessary to extend this part if it is already handled in quark.py? cc @kewang-xlnx. The same question applies to the change to compressed_tensors.py above.

rasmith (Author): It was necessary; I had to add this for it to work.

rasmith (Author), Apr 10, 2025: OK, just checked again: the compressed-tensors modification is not necessary, unless maybe there is an int8 w8a8 output_scale, but that work isn't happening yet. However, some FP8 models are quantized as "fp8" in the quantization config and some just as "quark", so the renaming needs to happen in both places.

return name.replace(".q_proj.output_scale", ".attn.q_scale")
if name.endswith("self_attn.prob_output_scale"):
return name.replace(".prob_output_scale", ".attn.prob_scale")
# If no matches, return None
return None


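
As a quick illustration of the renaming rules added above, a sketch with hypothetical checkpoint parameter names (following the usual Llama layout):

from typing import Optional

def remap_scale_name(name: str) -> Optional[str]:
    # Same two rules as the get_cache_scale additions above.
    if name.endswith(".output_scale") and ".q_proj" in name:
        return name.replace(".q_proj.output_scale", ".attn.q_scale")
    if name.endswith("self_attn.prob_output_scale"):
        return name.replace(".prob_output_scale", ".attn.prob_scale")
    return None

print(remap_scale_name("model.layers.0.self_attn.q_proj.output_scale"))
# -> model.layers.0.self_attn.attn.q_scale
print(remap_scale_name("model.layers.0.self_attn.prob_output_scale"))
# -> model.layers.0.self_attn.attn.prob_scale
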
36 changes: 36 additions & 0 deletions vllm/model_executor/layers/quantization/kv_cache.py
@@ -38,6 +38,9 @@ def create_weights(self, layer: torch.nn.Module):
requires_grad=False)
layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0),
requires_grad=False)
# Initialize P = softmax(QK^T) scales
layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0),
requires_grad=False)

def apply(self, layer: torch.nn.Module) -> torch.Tensor:
raise RuntimeError(
@@ -97,5 +100,38 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"may cause accuracy issues. Please make sure k/v_scale "
"scaling factors are available in the fp8 checkpoint.")

if layer.q_scale > 0.0:
q_scale = layer.q_scale
if current_platform.is_fp8_fnuz():
q_scale *= 2
Collaborator: Why multiply by 2 here? Is it because of the extra bit in fnuz? If so, we should be checking current_platform.is_fnuz(), not just ROCm.

Collaborator: Also, why can't we just keep the scales on device?

rasmith (Author): Changed to use is_fp8_fnuz() and to keep the scales on device (not sure why the CPU round-trip was there).

layer.calculate_kv_scales = False
else:
q_scale = 1.0
if layer.prob_scale > 0.0:
prob_scale = layer.prob_scale
if current_platform.is_fp8_fnuz():
prob_scale *= 2
else:
prob_scale = 1.0

is_singleton_float = lambda x: isinstance(x, float) or isinstance(
x, torch.Tensor) and x.numel() == 1 and x.is_floating_point()
if not is_singleton_float(q_scale) or not is_singleton_float(
prob_scale):
raise ValueError("Only support per-tensor scaling factor"
"for fp8-quantized Q/prob")

# These are used in the final Attention.forward()
layer._q_scale.copy_(q_scale)
layer._prob_scale.copy_(prob_scale)
if q_scale == 1.0 or prob_scale == 1.0:
logger.warning_once(
f"Using Q scale {q_scale} and prob scale {prob_scale} "
"with fp8 attention. This may cause accuracy issues. "
"Please make sure Q/prob scaling factors are "
"available in the fp8 checkpoint.")

del layer.k_scale
del layer.v_scale
del layer.q_scale
del layer.prob_scale
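
On the "multiply by 2" question in the thread above: float8_e4m3fnuz (the fnuz variant used on ROCm) has roughly half the representable range of OCP float8_e4m3fn, so per-tensor scales calibrated against e4m3fn checkpoints are doubled on fnuz platforms. A quick way to see the difference, assuming a PyTorch build that exposes both FP8 dtypes:

import torch

# e4m3fn (OCP) saturates at 448, e4m3fnuz at 240, about half the range,
# which is why e4m3fn-calibrated scales get doubled on fnuz hardware.
print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0
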
47 changes: 27 additions & 20 deletions vllm/model_executor/layers/quantization/quark/quark.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

import fnmatch
import re
from typing import Any, Dict, List, Optional, cast

import torch
@@ -125,6 +124,13 @@ def from_config(cls, config: Dict[str, Any]) -> "QuarkConfig":
for q_config in q_configs:
q_config["output_tensors"] = None

# In case q_proj output is also quantized, remove the configuration
# to keep qkv consistency.
q_proj_q_config = cast(Dict[str, Any],
layer_quant_config.get("*q_proj"))
if q_proj_q_config is not None:
q_proj_q_config["output_tensors"] = None

return cls(quant_config=config,
kv_cache_group=kv_cache_group,
kv_cache_config=kv_cache_config,
@@ -289,29 +295,30 @@ def get_cache_scale(self, name: str) -> Optional[str]:
:param name: param name
:return: matching param name for KV cache scale in vLLM
"""
if self.kv_cache_group is None or len(self.kv_cache_group) == 0:
return None

kv_proj_names = [
re.split(r"[*.]", kv_cache)[-1] for kv_cache in self.kv_cache_group
]
if name.endswith(".output_scale"):
if len(kv_proj_names) == 1 and kv_proj_names[0] in name:
kv_output_scale_name = "." + kv_proj_names[0] + ".output_scale"
return name.replace(kv_output_scale_name, ".attn.k_scale")

elif len(kv_proj_names) == 2:
for kv_proj_name in kv_proj_names:
if kv_proj_name in name and kv_proj_name == "k_proj":
return name.replace(".k_proj.output_scale",
".attn.k_scale")
elif kv_proj_name in name and kv_proj_name == "v_proj":
return name.replace(".v_proj.output_scale",
".attn.v_scale")
if name.endswith(".output_scale") and ".k_proj" in name:
return name.replace(".k_proj.output_scale", ".attn.k_scale")
if name.endswith(".output_scale") and ".v_proj" in name:
return name.replace(".v_proj.output_scale", ".attn.v_scale")
if name.endswith(".output_scale") and ".q_proj" in name:
return name.replace(".q_proj.output_scale", ".attn.q_scale")
if name.endswith("self_attn.prob_output_scale"):
return name.replace(".prob_output_scale", ".attn.prob_scale")

# If no matches, return None
return None

def has_fp8_layer_weights(self):
layer_quant_config = self.quant_config.get("layer_quant_config")
to_dict = lambda obj: cast(Dict[str, Any], obj) or {}
return any([
'fp8' in cast(
str,
to_dict(
to_dict(to_dict(layer_quant_config).get(layer_name)).get(
"weight")).get("dtype"))
for layer_name in ["*v_proj", "*k_proj", "*q_proj"]
])


class QuarkLinearMethod(LinearMethodBase):

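
For context on has_fp8_layer_weights above, a small sketch with a hypothetical layer_quant_config fragment showing what the nested lookups probe for: any of the q/k/v projection entries declaring an fp8 weight dtype. The dtype strings below are assumptions, not necessarily Quark's exact spelling.

from typing import Any, Dict, cast

layer_quant_config: Dict[str, Any] = {
    "*q_proj": {"weight": {"dtype": "fp8_e4m3"}},
    "*k_proj": {"weight": {"dtype": "fp8_e4m3"}},
    "*v_proj": {"weight": {"dtype": "fp8_e4m3"}},
}

# Roughly mirrors the .get() chain above; missing keys fall back to {}.
to_dict = lambda obj: cast(Dict[str, Any], obj) or {}
has_fp8 = any(
    "fp8" in str(to_dict(to_dict(layer_quant_config.get(n)).get("weight")).get("dtype"))
    for n in ["*v_proj", "*k_proj", "*q_proj"])
print(has_fp8)  # True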