[Bugfix][V1][ROCm] Fix AITER Flash Attention Backend (Fix API Break and Local Attention Logic: affecting Llama4) #19904

Merged · 4 commits · Jun 26, 2025

14 changes: 9 additions & 5 deletions vllm/attention/layer.py
@@ -306,12 +306,16 @@ def __init__(
block_size=16,
is_attention_free=False)
backend = backend_name_to_enum(attn_backend.get_name())
-        if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
-            backend = _Backend.XFORMERS
+        if current_platform.is_rocm():
+            # currently, only torch_sdpa is supported on rocm
+            self.attn_backend = _Backend.TORCH_SDPA
+        else:
+            if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
+                backend = _Backend.XFORMERS

-        self.attn_backend = backend if backend in {
-            _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
-        } else _Backend.TORCH_SDPA
+            self.attn_backend = backend if backend in {
+                _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
+            } else _Backend.TORCH_SDPA
Comment on lines +309 to +318
Contributor review (severity: medium):

The modification to set self.attn_backend = _Backend.TORCH_SDPA for MultiHeadAttention when running on ROCm is a clear and sensible platform-specific adjustment. Given the comment # currently, only torch_sdpa is supported on rocm, this change directly addresses compatibility or support limitations on ROCm, ensuring that a known-working backend is utilized.


def forward(
self,
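
To make the fallback order in the hunk above concrete, here is a condensed, standalone sketch of the selection logic as it reads after this change. `_Backend`, `current_platform.is_rocm()`, and the ROCm-only TORCH_SDPA restriction come from the diff; the `pick_vit_attn_backend` helper and the reduced enum are illustrative stand-ins, not vLLM APIs.

# Illustrative sketch only: mirrors the post-change selection order for
# MultiHeadAttention. `pick_vit_attn_backend` is a hypothetical helper and
# this _Backend enum is a reduced stand-in for vLLM's real one.
from enum import Enum, auto


class _Backend(Enum):
    FLASH_ATTN = auto()
    FLASH_ATTN_VLLM_V1 = auto()
    XFORMERS = auto()
    TORCH_SDPA = auto()
    PALLAS_VLLM_V1 = auto()


def pick_vit_attn_backend(backend: _Backend, is_rocm: bool) -> _Backend:
    if is_rocm:
        # currently, only torch_sdpa is supported on rocm
        return _Backend.TORCH_SDPA
    # Off ROCm: FlashAttention variants fall back to xFormers first, then
    # anything outside the supported set collapses to torch SDPA.
    if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
        backend = _Backend.XFORMERS
    return backend if backend in {
        _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
    } else _Backend.TORCH_SDPA


assert pick_vit_attn_backend(_Backend.FLASH_ATTN, is_rocm=True) == _Backend.TORCH_SDPA
assert pick_vit_attn_backend(_Backend.FLASH_ATTN, is_rocm=False) == _Backend.XFORMERS
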
55 changes: 37 additions & 18 deletions vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -243,8 +243,8 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
self.runner.device, non_blocking=True)
local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
self.runner.device, non_blocking=True)
-            local_max_query_len = seqlens_q_local_np.max()
-            local_max_seq_len = virt_k_seqlens_np.max()
+            local_max_query_len = int(seqlens_q_local_np.max())
+            local_max_seq_len = int(virt_k_seqlens_np.max())
Comment on lines +246 to +247
Contributor review (severity: medium):

Explicitly casting the results of seqlens_q_local_np.max() and virt_k_seqlens_np.max() to int is a good practice. It ensures that local_max_query_len and local_max_seq_len are Python integers, which can prevent potential type mismatches with downstream operations or library calls that expect standard integer types rather than NumPy scalar types (e.g., numpy.int64). This enhances type safety and robustness.
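
A quick illustration (not part of the PR) of why the cast matters: NumPy reductions return NumPy scalar types such as numpy.int32/int64, which are not instances of Python's int, so downstream code that checks for plain integers can reject them; int(...) normalizes the value.

import numpy as np

# Hypothetical per-request key lengths; int32 dtype as used in the builder code above.
virt_k_seqlens_np = np.array([3, 7, 5], dtype=np.int32)

raw_max = virt_k_seqlens_np.max()        # numpy.int32 scalar, not a Python int
cast_max = int(virt_k_seqlens_np.max())  # plain Python int

print(type(raw_max))              # <class 'numpy.int32'>
print(isinstance(raw_max, int))   # False: strict `int` checks would reject it
print(isinstance(cast_max, int))  # True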

local_scheduler_metadata = schedule(
batch_size=local_query_start_loc.shape[0] - 1,
cu_query_lens=local_query_start_loc,
@@ -253,13 +253,25 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
max_seq_len=local_max_seq_len,
causal=True)

+            local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
+                                            dtype=torch.int32,
+                                            device=self.runner.device)
+            local_cu_seq_lens[1:] = torch.cumsum(
+                torch.from_numpy(virt_k_seqlens_np).to(
+                    device=self.runner.device,
+                    dtype=torch.int32,
+                    non_blocking=True),
+                dim=0)
+
+
local_attn_metadata = \
AiterFlashAttentionMetadata.LocalAttentionMetadata(
local_query_start_loc=local_query_start_loc,
local_seqused_k=local_seqused_k,
local_block_table=virt_block_table_tensor,
local_max_query_len=local_max_query_len,
local_max_seq_len=local_max_seq_len,
+                local_cu_seq_lens=local_cu_seq_lens,
local_scheduler_metadata=local_scheduler_metadata,
)

@@ -368,6 +380,7 @@ class LocalAttentionMetadata:
local_block_table: torch.Tensor
local_max_query_len: int
local_max_seq_len: int
+        local_cu_seq_lens: torch.Tensor
local_scheduler_metadata: Optional[torch.Tensor]

local_attn_metadata: Optional[LocalAttentionMetadata] = None
@@ -387,6 +400,7 @@ def __init__(
blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
+        kv_sharing_target_layer_name: Optional[int] = None,
use_irope: bool = False,
) -> None:
if blocksparse_params is not None:
@@ -408,6 +422,7 @@ def __init__(
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0.
self.logits_soft_cap = logits_soft_cap
+        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -478,22 +493,25 @@ def forward(
# performance to make sure it does not introduce any overhead.

num_actual_tokens = attn_metadata.num_actual_tokens
-        # Reshape the input keys and values and store them in the cache.
-        # NOTE(woosuk): Here, key and value are padded while slot_mapping is
-        # not padded. However, we don't need to do key[:num_actual_tokens] and
-        # value[:num_actual_tokens] because the reshape_and_cache_flash op uses
-        # the slot_mapping's shape to determine the number of actual tokens.
key_cache, value_cache = kv_cache.unbind(0)
-        torch.ops._C_cache_ops.reshape_and_cache_flash(
-            key,
-            value,
-            key_cache,
-            value_cache,
-            attn_metadata.slot_mapping,
-            self.kv_cache_dtype,
-            layer._k_scale,
-            layer._v_scale,
-        )
+        if self.kv_sharing_target_layer_name is None:
+            # Reshape the input keys and values and store them in the cache.
+            # Skip this if sharing KV cache with an earlier attention layer.
+            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
+            # not padded. However, we don't need to do key[:num_actual_tokens]
+            # and value[:num_actual_tokens] because the reshape_and_cache_flash
+            # op uses the slot_mapping's shape to determine the number of
+            # actual tokens.
+            torch.ops._C_cache_ops.reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                attn_metadata.slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )
Comment on lines +497 to +514
Contributor review (severity: high):

The introduction of the kv_sharing_target_layer_name parameter (initialized in the __init__ method at line 390 and stored at line 412) and its use here to conditionally call torch.ops._C_cache_ops.reshape_and_cache_flash is a clean implementation for KV cache sharing. By skipping the cache update when kv_sharing_target_layer_name is set, it correctly avoids redundant writes and ensures that the shared KV cache from a target layer is utilized. The added comment # Skip this if sharing KV cache with an earlier attention layer. effectively clarifies the logic. This change is crucial for models leveraging KV cache sharing.
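
As a minimal sketch of the gating described above (not vLLM code): `maybe_cache_kv` is a hypothetical helper that mimics how the new kv_sharing_target_layer_name check wraps the cache write, with a simplified scatter standing in for torch.ops._C_cache_ops.reshape_and_cache_flash.

from typing import Optional

import torch


def maybe_cache_kv(
    key: torch.Tensor,           # [num_tokens, num_kv_heads, head_size]
    value: torch.Tensor,         # same shape as key
    key_cache: torch.Tensor,     # [num_blocks, block_size, num_kv_heads, head_size]
    value_cache: torch.Tensor,   # same layout as key_cache
    slot_mapping: torch.Tensor,  # [num_tokens], long dtype, flat slot per token
    kv_sharing_target_layer_name: Optional[str],
) -> None:
    if kv_sharing_target_layer_name is not None:
        # This layer reads the cache already written by the target layer,
        # so writing it again here would be redundant work.
        return
    # Simplified stand-in for reshape_and_cache_flash, assuming contiguous
    # caches: scatter each token's K/V into its assigned cache slot.
    key_cache.flatten(0, 1)[slot_mapping] = key
    value_cache.flatten(0, 1)[slot_mapping] = value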


if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(torch.float8_e4m3fnuz)
@@ -541,7 +559,8 @@ def forward(
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
-            cu_seqlens_k=cu_seq_lens,
+            cu_seqlens_k=(cu_seq_lens if not use_local_attn else
+                          local_metadata.local_cu_seq_lens),
)

_, num_heads, head_size = query.shape
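
For context on the cu_seqlens_k change just above: the varlen attention call needs cumulative key lengths consistent with the layout it is actually running over, so when local (chunked) attention is active the metadata built from the per-chunk ("virtual") key lengths must be passed instead of the global cu_seq_lens. Below is a small self-contained example, with made-up chunk lengths, of what that cumulative tensor looks like when built the same way as local_cu_seq_lens in the diff (zeros plus cumsum).

import numpy as np
import torch

# Hypothetical per-chunk ("virtual") key lengths for 4 local-attention chunks.
virt_k_seqlens_np = np.array([128, 128, 96, 64], dtype=np.int32)

local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1, dtype=torch.int32)
local_cu_seq_lens[1:] = torch.cumsum(
    torch.from_numpy(virt_k_seqlens_np).to(dtype=torch.int32), dim=0)

print(local_cu_seq_lens)  # tensor([  0, 128, 256, 352, 416], dtype=torch.int32)
# Chunk i's keys occupy the half-open range
# [local_cu_seq_lens[i], local_cu_seq_lens[i + 1]) in the packed key layout,
# which is the form the cu_seqlens_k argument expects.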