
[FEAT][ROCm] Upgrade AITER MLA v1 backend #18338

Merged · 3 commits · May 21, 2025
2 changes: 1 addition & 1 deletion docker/Dockerfile.rocm_base
@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="1a7f4dfa"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG AITER_BRANCH="5a77249"
ARG AITER_BRANCH="c1debd8"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"

FROM ${BASE_IMAGE} AS base
36 changes: 30 additions & 6 deletions vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -53,6 +53,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
# The number of entries in the last page of each request in
# the paged kv cache, shape: [batch_size]
paged_kv_last_page_len: Optional[torch.Tensor] = None
# The query indptr, shape : [num_decode + 1]
qo_indptr: Optional[torch.Tensor] = None


class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
@@ -75,27 +77,33 @@ def _get_paged_kv_tensors(
seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]:
page_size = self.kv_cache_spec.block_size
block_table_bounds = (seq_lens + page_size - 1) // page_size
device = self.runner.device

mask = (torch.arange(block_table.size(1),
dtype=block_table.dtype,
-device=block_table.device).unsqueeze(0)
+device=device).unsqueeze(0)
< block_table_bounds.unsqueeze(1))
paged_kv_indices = block_table[mask]

paged_kv_indptr = torch.cat([
-torch.zeros(1,
-dtype=block_table_bounds.dtype,
-device=block_table_bounds.device),
+torch.zeros(1, dtype=block_table_bounds.dtype, device=device),
block_table_bounds.cumsum(dim=0, dtype=torch.int32)
])

paged_kv_last_page_len = seq_lens % page_size
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
page_size, paged_kv_last_page_len)
qo_indptr = torch.arange(0,
self._num_decodes + 1,
step=1,
dtype=torch.int32,
device=device)

Collaborator

What is q_indptr?

Contributor Author
@vllmellm vllmellm May 21, 2025

@ProExpertProg it's a pointer array used to manage query sequences in the AITER mla_decode_fwd kernel. The variable name used here is the same as the one used in the AITER kernel (see more).

return (
paged_kv_indices,
paged_kv_indptr,
paged_kv_last_page_len,
qo_indptr,
)
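
For illustration (not part of the diff): a minimal, self-contained sketch of the tensors built above, using hypothetical values (page_size=16 and two decode requests with seq_lens=[5, 35]). Since each decode request contributes exactly one query token, qo_indptr simply enumerates the decode queries.

import torch

page_size = 16
seq_lens = torch.tensor([5, 35], dtype=torch.int32)
block_table = torch.tensor([[4, 0, 0],
                            [2, 7, 9]], dtype=torch.int32)

# Pages each request occupies in the paged KV cache: ceil(seq_len / page_size).
block_table_bounds = (seq_lens + page_size - 1) // page_size             # [1, 3]
mask = (torch.arange(block_table.size(1)).unsqueeze(0)
        < block_table_bounds.unsqueeze(1))
paged_kv_indices = block_table[mask]                                     # [4, 2, 7, 9]
paged_kv_indptr = torch.cat([
    torch.zeros(1, dtype=torch.int32),
    block_table_bounds.cumsum(dim=0, dtype=torch.int32),
])                                                                       # [0, 1, 4]
paged_kv_last_page_len = seq_lens % page_size
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
                                     page_size, paged_kv_last_page_len)  # [5, 3]
qo_indptr = torch.arange(0, seq_lens.numel() + 1, dtype=torch.int32)     # [0, 1, 2]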

def _build_decode(self, block_table_tensor: torch.Tensor,
@@ -105,14 +113,16 @@ def _build_decode(self, block_table_tensor: torch.Tensor,
paged_kv_indices,
paged_kv_indptr,
paged_last_page_len,
qo_indptr,
) = self._get_paged_kv_tensors(block_table_tensor, seq_lens)

attn_metadata = AiterMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens,
paged_kv_indptr=paged_kv_indptr,
paged_kv_indices=paged_kv_indices,
-paged_kv_last_page_len=paged_last_page_len)
+paged_kv_last_page_len=paged_last_page_len,
+qo_indptr=qo_indptr)

return attn_metadata

@@ -137,7 +147,10 @@ def __init__(
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
**mla_args)

assert (num_heads == 16 or num_heads == 128), (
f"Aiter MLA only supports 16 or 128 number of heads.\n"
f"Provided {num_heads} number of heads.\n"
"Try adjusting tensor_parallel_size value.")
unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
]
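
As a rough illustration of the head-count constraint in the assert above (assuming an MLA model with 128 total query heads, e.g. DeepSeek-V2/V3, and that the backend sees the per-tensor-parallel-rank head count):

total_query_heads = 128  # e.g. DeepSeek-V2/V3
for tp_size in (1, 4, 8):
    heads_per_rank = total_query_heads // tp_size
    supported = heads_per_rank in (16, 128)
    print(f"tp_size={tp_size}: {heads_per_rank} heads/rank -> "
          f"{'supported' if supported else 'unsupported'} by AITER MLA")

With these numbers, tensor_parallel_size 1 and 8 satisfy the assert (128 and 16 heads per rank), while tensor_parallel_size 4 (32 heads per rank) would trip it.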
@@ -189,7 +202,18 @@ def _forward_decode(

kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)

if self.num_heads == 16:
# AITER MLA decode kernel only supports
# max_seqlen_q=1 when using 16 heads.
max_seqlen_qo = 1
Collaborator

Should we assert here? It seems dangerous to just change this since it may cause a mismatch with qo_indptr?

Contributor Author
@vllmellm vllmellm May 21, 2025

@LucasWilkinson in the decode forward the kernel ideally processes one token at a time per sequence and is not constrained by the length of qo_indptr, which is a pointer array used to manage the decode query sequences; it identifies the start and end of each decode query sequence (see more). max_seqlen_q, on the other hand, determines the grid dimension for parallel processing, grid = (bs, nhead, max_seqlen_q=1), which is the grid size the kernel is launched with (see more), and it is constrained to the 16-head case only, which means one token is processed at a time during decode. This is a hard constraint set in the kernel (see more).

With that being said, in the vLLM v1 engine the max_query_len stored in prefill metadata is not equal to 1, which is why this change is needed given the kernel logic described above. It would also be safer if we did not need to access max_query_len from prefill metadata but could read it from common metadata instead.

Collaborator

My preference would be to store a separate max_seqlen_qo for decode (we do this for most attention backends that split prefill/decode), but that's just personal opinion.

else:
# AITER MLA decode Kernel handles arbitrary
# max_seqlen_q values when using 128 heads.
assert attn_metadata.prefill is not None
Collaborator

Is this always set?

Contributor Author

@ProExpertProg since the metadata required here is max_seqlen_q, which is stored as part of the prefill metadata, I think it would be better to modify the common module and store it in the common metadata so we wouldn't need to access it through prefill. Generally speaking, max_query_len is passed to the build function in the common module as shown below:

def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata) -> M:

prefill_metadata = MLACommonPrefillMetadata(
block_table=block_table_tensor[reqs_start:, ...],
query_start_loc=prefill_query_start_loc,
max_query_len=max_query_len,
chunked_context=chunked_context_metadata,
)

Or we could simply pass this argument to the _build_decode function so that it determines its own max_decode_query_len metadata:

def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor):
return MLACommonDecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens,
)

Perhaps @LucasWilkinson has better ideas to suggest here as well. Would appreciate your suggestion.

max_seqlen_qo = attn_metadata.prefill.max_query_len

aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
attn_metadata.decode.qo_indptr, max_seqlen_qo,
attn_metadata.decode.paged_kv_indptr,
attn_metadata.decode.paged_kv_indices,
attn_metadata.decode.paged_kv_last_page_len)
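
For reference only (not part of this PR): a minimal sketch of the alternative discussed in the review thread above, where a decode-specific max_seqlen_qo is carried in the decode metadata itself instead of being read from attn_metadata.prefill. The names DecodeMetadataSketch and build_decode_sketch are hypothetical stand-ins.

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class DecodeMetadataSketch:
    """Stripped-down stand-in for AiterMLADecodeMetadata; other fields elided."""
    qo_indptr: Optional[torch.Tensor] = None
    max_seqlen_qo: int = 1


def build_decode_sketch(seq_lens: torch.Tensor,
                        max_query_len: int,
                        num_heads: int) -> DecodeMetadataSketch:
    """Hypothetical _build_decode variant that receives max_query_len directly."""
    num_decodes = seq_lens.numel()
    qo_indptr = torch.arange(0, num_decodes + 1, dtype=torch.int32)
    # The 16-head path launches with grid = (bs, nhead, max_seqlen_q=1), i.e. one
    # query token per sequence, so max_seqlen_qo is pinned to 1; the 128-head
    # path can use the real maximum decode query length.
    max_seqlen_qo = 1 if num_heads == 16 else max_query_len
    return DecodeMetadataSketch(qo_indptr=qo_indptr,
                                max_seqlen_qo=max_seqlen_qo)

_forward_decode would then read max_seqlen_qo from attn_metadata.decode rather than from attn_metadata.prefill.max_query_len.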