Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 39 additions & 6 deletions vllm/v1/attention/backends/mla/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,6 +942,7 @@ def __init__(
qk_head_dim: int,
v_head_dim: int,
kv_b_proj: ColumnParallelLinear,
q_pad_num_heads: Optional[int] = None,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported for MLA")
Expand All @@ -959,6 +960,7 @@ def __init__(
self.qk_head_dim = qk_head_dim
self.v_head_dim = v_head_dim
self.kv_b_proj = kv_b_proj
self.q_pad_num_heads = q_pad_num_heads

if use_flashinfer_prefill():
logger.debug_once("Using FlashInfer prefill for MLA")
Expand Down Expand Up @@ -1134,7 +1136,7 @@ def _run_prefill_context_chunk_cudnn(self,
True, #Indicates actual_seq_lens are on GPU or CPU.
)

def _v_up_proj(self, x):
def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
if is_rocm_aiter_fp8bmm_enabled():
Expand All @@ -1146,12 +1148,23 @@ def _v_up_proj(self, x):
transpose_bm=True)
# Convert from (B, N, V) to (B, N * V)
x = x.reshape(-1, self.num_heads * self.v_head_dim)
# Copy result
out.copy_(x)
else:
# Convert from (B, N * V) to (N, B, V)
out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)

# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x = torch.bmm(x, self.W_UV)
torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot"

# Convert from (N, B, V) to (B, N * V)
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
return x
out_new = out.transpose(0, 1).reshape(
-1, self.num_heads * self.v_head_dim)

# Adjust output buffer shape back to the original (B, N * V)
N, B, V = out.shape
out.resize_((B, N * V))
out.copy_(out_new) # Copy result

def process_weights_after_loading(self, act_dtype: torch.dtype):

Expand Down Expand Up @@ -1559,6 +1572,15 @@ def forward(
# Convert from (B, N, P) to (N, B, P)
decode_q_nope = decode_q_nope.transpose(0, 1)

# Pads the head_dim if necessary (for the underlying kernel)
if self.q_pad_num_heads is not None:
B, N, L = decode_q_pe.shape
decode_pe_padded = decode_q_pe.new_empty(
(B, self.q_pad_num_heads, L))
decode_pe_padded.resize_((B, N, L))
decode_pe_padded.copy_(decode_q_pe)
decode_q_pe = decode_pe_padded

if is_rocm_aiter_fp8bmm_enabled():
# Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
decode_ql_nope = aiter_triton_fp8_bmm(decode_q_nope,
Expand All @@ -1567,8 +1589,19 @@ def forward(
group_size=128,
transpose_bm=True)
else:
# Pads the head_dim if necessary (for the underlying kernel)
N, B, P = decode_q_nope.shape
_, _, L = self.W_UK_T.shape
if self.q_pad_num_heads is not None:
decode_ql_nope = decode_q_nope.new_empty(
(self.q_pad_num_heads, B, L))
decode_ql_nope.resize_((N, B, L))

else:
decode_ql_nope = decode_q_nope.new_empty((N, B, L))

# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope)
# Convert from (N, B, L) to (B, N, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1)

Expand Down Expand Up @@ -1603,5 +1636,5 @@ def forward(
attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())

# v_up projection
output[:num_decode_tokens] = self._v_up_proj(attn_out)
self._v_up_proj(attn_out, out=output[:num_decode_tokens])
return output_padded
30 changes: 16 additions & 14 deletions vllm/v1/attention/backends/mla/cutlass_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ def ensure_size(self, attn_metadata: MLACommonMetadata,

g_sm100_workspace = SM100Workspace(128 * 1024 * 1024) # 128MB

MAX_HEADS = 128


class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
can_return_lse_for_decode: bool = True
Expand All @@ -92,10 +94,18 @@ def __init__(
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **mla_args)
super().__init__(num_heads,
head_size,
scale,
num_kv_heads,
alibi_slopes,
sliding_window,
kv_cache_dtype,
logits_soft_cap,
attn_type,
kv_sharing_target_layer_name,
q_pad_num_heads=MAX_HEADS,
**mla_args)

unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
if any(unsupported_features):
Expand Down Expand Up @@ -157,14 +167,6 @@ def _sm100_cutlass_mla_decode(

MAX_HEADS = 128
assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
if H < MAX_HEADS:
q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope))
q_nope_padded[:, :H] = q_nope
q_nope = q_nope_padded

q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe))
q_pe_padded[:, :H] = q_pe
q_pe = q_pe_padded

assert len(page_table.shape) == 2
B_block_table, block_num = page_table.shape
Expand Down Expand Up @@ -206,9 +208,9 @@ def _sm100_cutlass_mla_decode(
)

if H < MAX_HEADS:
# Extract the subsets of the outputs
lse = lse[:, :H] if self.need_to_return_lse_for_decode else lse
out = out[:, :H]
if self.need_to_return_lse_for_decode:
lse = lse[:, :H].contiguous()

return out, lse

Expand Down
Loading