Skip to content

[AMD] [Quantization] Add override flag for attention dtype instead of using kv_cache_dtype trigger #17331

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion vllm/attention/backends/rocm_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
CommonMetadataBuilder)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.platforms.rocm import use_rocm_custom_paged_attention
Expand Down Expand Up @@ -580,6 +581,10 @@ def __init__(
logger.debug("Using naive (SDPA) attention in ROCmBackend")

self.aiter_kv_scales_initialized = False
self.force_fp8_attention = (
get_current_vllm_config() is not None
and get_current_vllm_config().model_config.override_attention_dtype
== "fp8")

def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
Expand Down Expand Up @@ -766,9 +771,12 @@ def forward(
query.dtype,
seq_lens,
make_attn_mask=causal_mask) # type: ignore

use_fp8_scales = (layer._q_scale and layer._k_scale
and layer._v_scale and layer._prob_scale
and self.kv_cache_dtype == "fp8")
and (self.kv_cache_dtype == "fp8"
or self.force_fp8_attention))

full_scales = (
layer._q_scale.item(), layer._k_scale.item(),
layer._v_scale.item(),
Expand Down
8 changes: 8 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,8 @@ class ModelConfig:
available.\n
- "vllm" will use the vLLM model implementation.\n
- "transformers" will use the Transformers model implementation."""
override_attention_dtype: Optional[str] = None
"""Override dtype for attention"""

def compute_hash(self) -> str:
"""
Expand Down Expand Up @@ -515,6 +517,12 @@ def __post_init__(self) -> None:

from vllm.platforms import current_platform

if (self.override_attention_dtype is not None
and not current_platform.is_rocm()):
warnings.warn(
"override-attention-dtype is set but not using ROCm platform",
stacklevel=2)

if (self.enable_sleep_mode
and not current_platform.is_sleep_mode_available()):
raise ValueError(
Expand Down
4 changes: 4 additions & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ class EngineArgs:
override_generation_config: dict[str, Any] = \
get_field(ModelConfig, "override_generation_config")
model_impl: str = ModelConfig.model_impl
override_attention_dtype: str = ModelConfig.override_attention_dtype

calculate_kv_scales: bool = CacheConfig.calculate_kv_scales

Expand Down Expand Up @@ -533,6 +534,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
model_group.add_argument("--model-impl",
choices=[f.value for f in ModelImpl],
**model_kwargs["model_impl"])
model_group.add_argument("--override-attention-dtype",
**model_kwargs["override_attention_dtype"])

# Model loading arguments
load_kwargs = get_kwargs(LoadConfig)
Expand Down Expand Up @@ -927,6 +930,7 @@ def create_model_config(self) -> ModelConfig:
override_generation_config=self.override_generation_config,
enable_sleep_mode=self.enable_sleep_mode,
model_impl=self.model_impl,
override_attention_dtype=self.override_attention_dtype,
)

def create_load_config(self) -> LoadConfig:
Expand Down