-
-
Notifications
You must be signed in to change notification settings - Fork 8.4k
[FEAT][ROCm] Upgrade AITER MLA v1 backend #18338
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -53,6 +53,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata): | |||||||||||||||||||||||||||||||
# The number of entries in the last page of each request in | ||||||||||||||||||||||||||||||||
# the paged kv cache, shape: [batch_size] | ||||||||||||||||||||||||||||||||
paged_kv_last_page_len: Optional[torch.Tensor] = None | ||||||||||||||||||||||||||||||||
# The query indptr, shape : [num_decode + 1] | ||||||||||||||||||||||||||||||||
qo_indptr: Optional[torch.Tensor] = None | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): | ||||||||||||||||||||||||||||||||
|
@@ -75,27 +77,33 @@ def _get_paged_kv_tensors( | |||||||||||||||||||||||||||||||
seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]: | ||||||||||||||||||||||||||||||||
page_size = self.kv_cache_spec.block_size | ||||||||||||||||||||||||||||||||
block_table_bounds = (seq_lens + page_size - 1) // page_size | ||||||||||||||||||||||||||||||||
device = self.runner.device | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
mask = (torch.arange(block_table.size(1), | ||||||||||||||||||||||||||||||||
dtype=block_table.dtype, | ||||||||||||||||||||||||||||||||
device=block_table.device).unsqueeze(0) | ||||||||||||||||||||||||||||||||
device=device).unsqueeze(0) | ||||||||||||||||||||||||||||||||
< block_table_bounds.unsqueeze(1)) | ||||||||||||||||||||||||||||||||
paged_kv_indices = block_table[mask] | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
paged_kv_indptr = torch.cat([ | ||||||||||||||||||||||||||||||||
torch.zeros(1, | ||||||||||||||||||||||||||||||||
dtype=block_table_bounds.dtype, | ||||||||||||||||||||||||||||||||
device=block_table_bounds.device), | ||||||||||||||||||||||||||||||||
torch.zeros(1, dtype=block_table_bounds.dtype, device=device), | ||||||||||||||||||||||||||||||||
block_table_bounds.cumsum(dim=0, dtype=torch.int32) | ||||||||||||||||||||||||||||||||
]) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
paged_kv_last_page_len = seq_lens % page_size | ||||||||||||||||||||||||||||||||
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, | ||||||||||||||||||||||||||||||||
page_size, paged_kv_last_page_len) | ||||||||||||||||||||||||||||||||
qo_indptr = torch.arange(0, | ||||||||||||||||||||||||||||||||
self._num_decodes + 1, | ||||||||||||||||||||||||||||||||
step=1, | ||||||||||||||||||||||||||||||||
dtype=torch.int32, | ||||||||||||||||||||||||||||||||
device=device) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
return ( | ||||||||||||||||||||||||||||||||
paged_kv_indices, | ||||||||||||||||||||||||||||||||
paged_kv_indptr, | ||||||||||||||||||||||||||||||||
paged_kv_last_page_len, | ||||||||||||||||||||||||||||||||
qo_indptr, | ||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
def _build_decode(self, block_table_tensor: torch.Tensor, | ||||||||||||||||||||||||||||||||
|
@@ -105,14 +113,16 @@ def _build_decode(self, block_table_tensor: torch.Tensor, | |||||||||||||||||||||||||||||||
paged_kv_indices, | ||||||||||||||||||||||||||||||||
paged_kv_indptr, | ||||||||||||||||||||||||||||||||
paged_last_page_len, | ||||||||||||||||||||||||||||||||
qo_indptr, | ||||||||||||||||||||||||||||||||
) = self._get_paged_kv_tensors(block_table_tensor, seq_lens) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
attn_metadata = AiterMLADecodeMetadata( | ||||||||||||||||||||||||||||||||
block_table=block_table_tensor, | ||||||||||||||||||||||||||||||||
seq_lens=seq_lens, | ||||||||||||||||||||||||||||||||
paged_kv_indptr=paged_kv_indptr, | ||||||||||||||||||||||||||||||||
paged_kv_indices=paged_kv_indices, | ||||||||||||||||||||||||||||||||
paged_kv_last_page_len=paged_last_page_len) | ||||||||||||||||||||||||||||||||
paged_kv_last_page_len=paged_last_page_len, | ||||||||||||||||||||||||||||||||
qo_indptr=qo_indptr) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
return attn_metadata | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
|
@@ -137,7 +147,10 @@ def __init__( | |||||||||||||||||||||||||||||||
alibi_slopes, sliding_window, kv_cache_dtype, | ||||||||||||||||||||||||||||||||
blocksparse_params, logits_soft_cap, attn_type, | ||||||||||||||||||||||||||||||||
**mla_args) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
assert (num_heads == 16 or num_heads == 128), ( | ||||||||||||||||||||||||||||||||
f"Aiter MLA only supports 16 or 128 number of heads.\n" | ||||||||||||||||||||||||||||||||
f"Provided {num_heads} number of heads.\n" | ||||||||||||||||||||||||||||||||
"Try adjusting tensor_parallel_size value.") | ||||||||||||||||||||||||||||||||
unsupported_features = [ | ||||||||||||||||||||||||||||||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap | ||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||
|
@@ -189,7 +202,18 @@ def _forward_decode( | |||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
if self.num_heads == 16: | ||||||||||||||||||||||||||||||||
# AITER MLA decode kernel only supports | ||||||||||||||||||||||||||||||||
# max_seqlen_q=1 when using 16 heads. | ||||||||||||||||||||||||||||||||
max_seqlen_qo = 1 | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we assert here? seems dangerous to just change this since it may cause a mismatch with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @LucasWilkinson in decode forward ideally the kernel processes one token at a time per sequence and is not constraint to length of With that being said, in vllm v1 engine the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. my preference would be to store a separate |
||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||
# AITER MLA decode Kernel handles arbitrary | ||||||||||||||||||||||||||||||||
# max_seqlen_q values when using 128 heads. | ||||||||||||||||||||||||||||||||
assert attn_metadata.prefill is not None | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this always set? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @ProExpertProg as the metadata required here is vllm/vllm/v1/attention/backends/mla/common.py Lines 452 to 454 in 0c15c2e
vllm/vllm/v1/attention/backends/mla/common.py Lines 539 to 544 in 0c15c2e
or we could simply pass this argument to vllm/vllm/v1/attention/backends/mla/common.py Lines 445 to 450 in 0c15c2e
perhaps @LucasWilkinson have better ideas to suggest here as well. would appreciate your suggestion. |
||||||||||||||||||||||||||||||||
max_seqlen_qo = attn_metadata.prefill.max_query_len | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
aiter_mla_decode_fwd(q, kv_buffer, o, self.scale, | ||||||||||||||||||||||||||||||||
attn_metadata.decode.qo_indptr, max_seqlen_qo, | ||||||||||||||||||||||||||||||||
attn_metadata.decode.paged_kv_indptr, | ||||||||||||||||||||||||||||||||
attn_metadata.decode.paged_kv_indices, | ||||||||||||||||||||||||||||||||
attn_metadata.decode.paged_kv_last_page_len) | ||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is
q_indptr
?Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ProExpertProg it's a pointer array used to manage query sequences in AITER
mla_deocde_fwd
kernel. The variable name is used here is the same as it is referred to it in AITER kernel (see more).