
Commit c7ea0b5 (1 parent: 29fa5ca)

[AMD] [Quantization] Add override flag for attention dtype instead of using kv_cache_dtype trigger (#17331)

Signed-off-by: Randall Smith <Randall.Smith@amd.com>

File tree: 3 files changed (+21 −1 lines)

vllm/attention/backends/rocm_flash_attn.py (+9 −1)

```diff
@@ -17,6 +17,7 @@
                                      CommonMetadataBuilder)
 from vllm.attention.ops.paged_attn import (PagedAttention,
                                            PagedAttentionMetadata)
+from vllm.config import get_current_vllm_config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.platforms.rocm import use_rocm_custom_paged_attention
@@ -584,6 +585,10 @@ def __init__(
             logger.debug("Using naive (SDPA) attention in ROCmBackend")
 
         self.aiter_kv_scales_initialized = False
+        self.force_fp8_attention = (
+            get_current_vllm_config() is not None
+            and get_current_vllm_config().model_config.override_attention_dtype
+            == "fp8")
 
     def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
         """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
@@ -770,9 +775,12 @@ def forward(
                     query.dtype,
                     seq_lens,
                     make_attn_mask=causal_mask)  # type: ignore
+
                 use_fp8_scales = (layer._q_scale and layer._k_scale
                                   and layer._v_scale and layer._prob_scale
-                                  and self.kv_cache_dtype == "fp8")
+                                  and (self.kv_cache_dtype == "fp8"
+                                       or self.force_fp8_attention))
+
                 full_scales = (
                     layer._q_scale.item(), layer._k_scale.item(),
                     layer._v_scale.item(),
```
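For reference, here is a minimal standalone sketch of the gating logic this hunk introduces: FP8 scales are used when all four per-layer scales are present and either the KV cache dtype is already "fp8" or the new override forces FP8 attention. `AttnScales` and `should_use_fp8_scales` below are illustrative names, not vLLM APIs.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class AttnScales:
    """Illustrative container for the per-layer scales referenced in the diff."""
    q_scale: Optional[float]
    k_scale: Optional[float]
    v_scale: Optional[float]
    prob_scale: Optional[float]


def should_use_fp8_scales(scales: AttnScales,
                          kv_cache_dtype: str,
                          force_fp8_attention: bool) -> bool:
    # All four scales must be truthy, mirroring the diff's
    # `layer._q_scale and layer._k_scale and ...` check.
    have_all_scales = bool(scales.q_scale and scales.k_scale
                           and scales.v_scale and scales.prob_scale)
    # Previously only `kv_cache_dtype == "fp8"` enabled the FP8 path;
    # the commit adds `force_fp8_attention` as an independent trigger.
    return have_all_scales and (kv_cache_dtype == "fp8"
                                or force_fp8_attention)


# Example: the override enables FP8 attention even with an "auto" KV cache.
scales = AttnScales(q_scale=1.0, k_scale=1.0, v_scale=1.0, prob_scale=1.0)
assert should_use_fp8_scales(scales, kv_cache_dtype="auto",
                             force_fp8_attention=True)
```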

vllm/config.py (+8 −0)

```diff
@@ -417,6 +417,8 @@ class ModelConfig:
     available.\n
     - "vllm" will use the vLLM model implementation.\n
     - "transformers" will use the Transformers model implementation."""
+    override_attention_dtype: Optional[str] = None
+    """Override dtype for attention"""
 
     def compute_hash(self) -> str:
         """
@@ -517,6 +519,12 @@ def __post_init__(self) -> None:
 
         from vllm.platforms import current_platform
 
+        if (self.override_attention_dtype is not None
+                and not current_platform.is_rocm()):
+            warnings.warn(
+                "override-attention-dtype is set but not using ROCm platform",
+                stacklevel=2)
+
         if (self.enable_sleep_mode
                 and not current_platform.is_sleep_mode_available()):
             raise ValueError(
```
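The `__post_init__` addition only warns, it does not raise, when the override is set on a non-ROCm platform. A rough standalone sketch of that behavior, with `is_rocm` standing in for `current_platform.is_rocm()`:

```python
import warnings
from typing import Optional


def check_override_attention_dtype(override_attention_dtype: Optional[str],
                                    is_rocm: bool) -> None:
    """Mirror of the post-init check: warn, but do not fail, off ROCm."""
    if override_attention_dtype is not None and not is_rocm:
        warnings.warn(
            "override-attention-dtype is set but not using ROCm platform",
            stacklevel=2)


# On a CUDA (non-ROCm) platform the flag is tolerated with a warning.
check_override_attention_dtype("fp8", is_rocm=False)
```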

vllm/engine/arg_utils.py (+4 −0)

```diff
@@ -429,6 +429,7 @@ class EngineArgs:
     override_generation_config: dict[str, Any] = \
         get_field(ModelConfig, "override_generation_config")
     model_impl: str = ModelConfig.model_impl
+    override_attention_dtype: str = ModelConfig.override_attention_dtype
 
     calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
 
@@ -549,6 +550,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         model_group.add_argument("--model-impl",
                                  choices=[f.value for f in ModelImpl],
                                  **model_kwargs["model_impl"])
+        model_group.add_argument("--override-attention-dtype",
+                                 **model_kwargs["override_attention_dtype"])
 
         # Model loading arguments
         load_kwargs = get_kwargs(LoadConfig)
@@ -946,6 +949,7 @@ def create_model_config(self) -> ModelConfig:
             override_generation_config=self.override_generation_config,
             enable_sleep_mode=self.enable_sleep_mode,
             model_impl=self.model_impl,
+            override_attention_dtype=self.override_attention_dtype,
         )
 
     def create_load_config(self) -> LoadConfig:
```
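As a usage sketch (not part of the commit): the new engine argument can be passed on the command line as `--override-attention-dtype fp8`, or through the Python entrypoint, assuming the `LLM` constructor forwards extra keyword arguments to `EngineArgs` as it generally does. The model name below is a placeholder.

```python
# Hedged usage sketch: force FP8 attention on ROCm while leaving
# --kv-cache-dtype at its default ("auto"). Placeholder model name.
from vllm import LLM

llm = LLM(
    model="<your-model>",             # placeholder, not from the commit
    override_attention_dtype="fp8",   # new ModelConfig/EngineArgs field
)
outputs = llm.generate("Hello, world!")
print(outputs[0].outputs[0].text)
```

This keeps the KV cache dtype untouched, which is the decoupling the commit title describes: FP8 attention no longer has to be triggered through `kv_cache_dtype`.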
