[Bugfix][V1][ROCm] Fix AITER Flash Attention Backend (Fix API Break and Local Attention Logic: affecting Llama4) #19904
Changes from all commits
```diff
@@ -243,8 +243,8 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                     self.runner.device, non_blocking=True)
             local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
                 self.runner.device, non_blocking=True)
-            local_max_query_len = seqlens_q_local_np.max()
-            local_max_seq_len = virt_k_seqlens_np.max()
+            local_max_query_len = int(seqlens_q_local_np.max())
+            local_max_seq_len = int(virt_k_seqlens_np.max())
             local_scheduler_metadata = schedule(
                 batch_size=local_query_start_loc.shape[0] - 1,
                 cu_query_lens=local_query_start_loc,
```

Comment on lines +246 to +247: Explicitly casting the results of `seqlens_q_local_np.max()` and `virt_k_seqlens_np.max()` to `int` ensures these values are plain Python integers, matching the `int` annotations on `local_max_query_len` and `local_max_seq_len`, rather than NumPy scalar types.
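For context, here is a standalone sketch (not part of the PR) of why the explicit cast matters: `ndarray.max()` returns a NumPy scalar, while the metadata fields are annotated as plain `int` (see the `LocalAttentionMetadata` hunk further down). The array values are made up for illustration.

```python
import numpy as np

virt_k_seqlens_np = np.array([3, 7, 5], dtype=np.int32)  # made-up lengths

raw_max = virt_k_seqlens_np.max()        # numpy.int32 scalar
cast_max = int(virt_k_seqlens_np.max())  # plain Python int

print(type(raw_max), isinstance(raw_max, int))    # <class 'numpy.int32'> False
print(type(cast_max), isinstance(cast_max, int))  # <class 'int'> True
```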
```diff
@@ -253,13 +253,25 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                 max_seq_len=local_max_seq_len,
                 causal=True)
 
+            local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
+                                            dtype=torch.int32,
+                                            device=self.runner.device)
+            local_cu_seq_lens[1:] = torch.cumsum(
+                torch.from_numpy(virt_k_seqlens_np).to(
+                    device=self.runner.device,
+                    dtype=torch.int32,
+                    non_blocking=True),
+                dim=0)
+
             local_attn_metadata = \
                 AiterFlashAttentionMetadata.LocalAttentionMetadata(
                     local_query_start_loc=local_query_start_loc,
                     local_seqused_k=local_seqused_k,
                     local_block_table=virt_block_table_tensor,
                     local_max_query_len=local_max_query_len,
                     local_max_seq_len=local_max_seq_len,
+                    local_cu_seq_lens=local_cu_seq_lens,
                     local_scheduler_metadata=local_scheduler_metadata,
                 )
```
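The new `local_cu_seq_lens` tensor is an exclusive prefix sum over the per-sequence key lengths, i.e. the `[0, len_0, len_0 + len_1, ...]` layout that varlen attention kernels expect for `cu_seqlens_k`. A minimal standalone sketch of the same construction, with made-up lengths and no device handling:

```python
import torch

k_seqlens = torch.tensor([4, 2, 5], dtype=torch.int32)  # made-up lengths

# cu_seq_lens[i] is the start offset of sequence i in the packed K layout;
# the final entry is the total number of keys.
cu_seq_lens = torch.zeros(k_seqlens.numel() + 1, dtype=torch.int32)
cu_seq_lens[1:] = torch.cumsum(k_seqlens, dim=0)

print(cu_seq_lens)  # tensor([ 0,  4,  6, 11], dtype=torch.int32)
```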
```diff
@@ -368,6 +380,7 @@ class LocalAttentionMetadata:
         local_block_table: torch.Tensor
         local_max_query_len: int
         local_max_seq_len: int
+        local_cu_seq_lens: torch.Tensor
         local_scheduler_metadata: Optional[torch.Tensor]
 
     local_attn_metadata: Optional[LocalAttentionMetadata] = None
```
```diff
@@ -387,6 +400,7 @@ def __init__(
         blocksparse_params: Optional[dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: AttentionType = AttentionType.DECODER,
+        kv_sharing_target_layer_name: Optional[int] = None,
         use_irope: bool = False,
     ) -> None:
         if blocksparse_params is not None:
@@ -408,6 +422,7 @@ def __init__(
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0.
         self.logits_soft_cap = logits_soft_cap
+        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
 
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
```
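The `kv_sharing_target_layer_name` parameter is the API break referenced in the PR title: assuming the attention layer now passes this keyword to every backend constructor, any backend `__init__` that does not accept it fails with a `TypeError`. A rough, self-contained sketch of that failure mode (class names and arguments are illustrative, not vLLM's):

```python
from typing import Optional

class OldImpl:  # hypothetical backend missing the new keyword
    def __init__(self, num_heads: int) -> None:
        self.num_heads = num_heads

class FixedImpl:  # hypothetical backend accepting and storing it
    def __init__(self, num_heads: int,
                 kv_sharing_target_layer_name: Optional[str] = None) -> None:
        self.num_heads = num_heads
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

kwargs = dict(num_heads=8, kv_sharing_target_layer_name=None)
FixedImpl(**kwargs)   # works
try:
    OldImpl(**kwargs)
except TypeError as exc:
    print(exc)        # ... unexpected keyword argument 'kv_sharing_target_layer_name'
```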
```diff
@@ -478,22 +493,25 @@ def forward(
         # performance to make sure it does not introduce any overhead.
 
         num_actual_tokens = attn_metadata.num_actual_tokens
-        # Reshape the input keys and values and store them in the cache.
-        # NOTE(woosuk): Here, key and value are padded while slot_mapping is
-        # not padded. However, we don't need to do key[:num_actual_tokens] and
-        # value[:num_actual_tokens] because the reshape_and_cache_flash op uses
-        # the slot_mapping's shape to determine the number of actual tokens.
         key_cache, value_cache = kv_cache.unbind(0)
-        torch.ops._C_cache_ops.reshape_and_cache_flash(
-            key,
-            value,
-            key_cache,
-            value_cache,
-            attn_metadata.slot_mapping,
-            self.kv_cache_dtype,
-            layer._k_scale,
-            layer._v_scale,
-        )
+        if self.kv_sharing_target_layer_name is None:
+            # Reshape the input keys and values and store them in the cache.
+            # Skip this if sharing KV cache with an earlier attention layer.
+            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
+            # not padded. However, we don't need to do key[:num_actual_tokens]
+            # and value[:num_actual_tokens] because the reshape_and_cache_flash
+            # op uses the slot_mapping's shape to determine the number of
+            # actual tokens.
+            torch.ops._C_cache_ops.reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                attn_metadata.slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )
 
         if self.kv_cache_dtype.startswith("fp8"):
             key_cache = key_cache.view(torch.float8_e4m3fnuz)
```

Comment on lines +497 to +514: The introduction of the `self.kv_sharing_target_layer_name is None` guard ensures the KV cache is only written when this layer does not share its cache with an earlier attention layer; otherwise the `reshape_and_cache_flash` call is skipped.
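A minimal standalone sketch of the guard's behavior, with hypothetical names: only a layer that owns its KV cache writes new blocks, while a layer configured to share another layer's cache skips the write.

```python
from typing import Callable, Optional

def maybe_write_kv_cache(kv_sharing_target_layer_name: Optional[str],
                         write_fn: Callable[[], None]) -> None:
    # Mirrors the new guard: write K/V only if this layer owns its cache.
    if kv_sharing_target_layer_name is None:
        write_fn()

maybe_write_kv_cache(None, lambda: print("wrote K/V"))       # prints
maybe_write_kv_cache("model.layers.0.self_attn.attn",
                     lambda: print("never reached"))         # skipped
```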
```diff
@@ -541,7 +559,8 @@ def forward(
                 alibi_slopes=self.alibi_slopes,
                 window_size=self.sliding_window,
                 block_table=block_table,
-                cu_seqlens_k=cu_seq_lens,
+                cu_seqlens_k=(cu_seq_lens if not use_local_attn else
+                              local_metadata.local_cu_seq_lens),
             )
 
         _, num_heads, head_size = query.shape
```
The modification to set `self.attn_backend = _Backend.TORCH_SDPA` for `MultiHeadAttention` when running on ROCm is a clear and sensible platform-specific adjustment. Given the comment `# currently, only torch_sdpa is supported on rocm`, this change directly addresses compatibility and support limitations on ROCm, ensuring that a known-working backend is used.
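The diff for that change is not included in the excerpt above; the following self-contained sketch only illustrates the pattern the reviewer describes. The `_Backend` enum and `is_rocm()` helper are stand-ins for vLLM's internals, redefined here so the snippet runs on its own.

```python
import enum

class _Backend(enum.Enum):  # stand-in for vLLM's backend enum
    FLASH_ATTN = enum.auto()
    TORCH_SDPA = enum.auto()

def is_rocm() -> bool:
    # Stand-in for a platform check such as current_platform.is_rocm();
    # hard-coded here for illustration.
    return True

class MultiHeadAttention:  # simplified stand-in
    def __init__(self) -> None:
        backend = _Backend.FLASH_ATTN
        if is_rocm():
            # currently, only torch_sdpa is supported on rocm
            backend = _Backend.TORCH_SDPA
        self.attn_backend = backend

print(MultiHeadAttention().attn_backend)  # _Backend.TORCH_SDPA
```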