
Commit 6be6496

Remove old cutlass MLA kernel
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
1 parent 5679399 commit 6be6496

File tree

6 files changed, +4 -338 lines changed

CMakeLists.txt

Lines changed: 0 additions & 2 deletions
@@ -298,7 +298,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
       "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu"
       "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
       "csrc/cutlass_extensions/common.cpp"
-      "csrc/attention/mla/cutlass_mla_entry.cu"
       "csrc/quantization/fp8/per_token_group_quant.cu")

     set_gencode_flags_for_srcs(
@@ -585,7 +584,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}")
     if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS)
       set(SRCS
-        "csrc/attention/mla/cutlass_mla_kernels.cu"
         "csrc/attention/mla/sm100_cutlass_mla_kernel.cu")
       set_gencode_flags_for_srcs(
         SRCS "${SRCS}"

csrc/attention/mla/cutlass_mla_entry.cu

Lines changed: 0 additions & 38 deletions
This file was deleted.

csrc/attention/mla/cutlass_mla_kernels.cu

Lines changed: 0 additions & 225 deletions
This file was deleted.

csrc/torch_bindings.cpp

Lines changed: 0 additions & 7 deletions
@@ -510,13 +510,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("cutlass_sparse_compress(Tensor a) -> Tensor[]");
   ops.impl("cutlass_sparse_compress", &cutlass_sparse_compress);

-  // CUTLASS MLA decode
-  ops.def(
-      "cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
-      "                   Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
-      "                   Tensor page_table, float scale) -> ()");
-  ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
-
   // SM100 CUTLASS MLA decode
   ops.def(
       "sm100_cutlass_mla_decode(Tensor! out, Tensor! lse, Tensor q_nope,"

vllm/_custom_ops.py

Lines changed: 2 additions & 11 deletions
@@ -1823,17 +1823,8 @@ def flash_mla_with_kvcache(
     return out, softmax_lse


-def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor,
-                       q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
-                       seq_lens: torch.Tensor, page_table: torch.Tensor,
-                       scale: float) -> torch.Tensor:
-    torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache,
-                                    seq_lens, page_table, scale)
-    return out
-
-
-def sm100_cutlass_mla_decode(out: torch.Tensor, lse: torch.Tensor,
-                             q_nope: torch.Tensor, q_pe: torch.Tensor,
+def sm100_cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor,
+                             q_pe: torch.Tensor,
                              kv_c_and_k_pe_cache: torch.Tensor,
                              seq_lens: torch.Tensor, page_table: torch.Tensor,
                              workspace: torch.Tensor, scale: float,
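The deleted cutlass_mla_decode wrapper followed the usual _custom_ops.py convention: the caller pre-allocates out, the registered op writes into it in place, and the wrapper returns that same tensor. A self-contained sketch of that convention, using a stand-in function in place of the (now removed) registered op:

import torch


def _fake_decode_op(out: torch.Tensor, q: torch.Tensor, scale: float) -> None:
    # Stand-in for a torch.ops._C.* op: the result is written into out in place.
    out.copy_(q * scale)


def fake_decode(out: torch.Tensor, q: torch.Tensor, scale: float) -> torch.Tensor:
    # Same shape as the removed wrapper: forward to the op, then return out.
    _fake_decode_op(out, q, scale)
    return out


q = torch.randn(4, 16, 512)
out = torch.empty_like(q)
assert fake_decode(out, q, 0.5) is out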

vllm/v1/attention/backends/mla/cutlass_mla.py

Lines changed: 2 additions & 55 deletions
@@ -219,12 +219,13 @@ def _sm100_cutlass_mla_decode(

         return out, returned_lse

-    def _sm100_forward_decode(
+    def _forward_decode(
         self,
         q_nope: torch.Tensor,
         q_pe: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: MLACommonMetadata,
+        layer: AttentionLayer,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         assert kv_c_and_k_pe_cache.numel() > 0
         assert attn_metadata.decode is not None
@@ -245,57 +246,3 @@ def _sm100_forward_decode(
         )

         return o, (lse if self.need_to_return_lse_for_decode else None)
-
-    # TODO: Currently we leave it here only for backup in case something is
-    # wrong with the new SM100 CUTLASS MLA kernel
-    def _old_forward_decode(
-        self,
-        q_nope: torch.Tensor,
-        q_pe: torch.Tensor,
-        kv_c_and_k_pe_cache: torch.Tensor,
-        attn_metadata: MLACommonMetadata,
-    ) -> torch.Tensor:
-        assert kv_c_and_k_pe_cache.numel() > 0
-        assert attn_metadata.decode is not None
-
-        if is_quantized_kv_cache(self.kv_cache_dtype):
-            raise NotImplementedError(
-                "FP8 Cutlass MLA not supported with FORCE_OLD_CUTLASS_MLA")
-
-        B = q_nope.shape[0]
-
-        o = torch.empty((B, self.num_heads, self.kv_lora_rank),
-                        dtype=q_nope.dtype,
-                        device=q_nope.device)
-
-        # Run MLA
-        # Clone q_nope and q_pe to make sure strides computation is correct.
-        q_nope = q_nope.clone()
-        q_pe = q_pe.clone()
-
-        ops.cutlass_mla_decode(o, q_nope, q_pe, kv_c_and_k_pe_cache,
-                               attn_metadata.decode.seq_lens,
-                               attn_metadata.decode.block_table, self.scale)
-
-        return o
-
-    def _forward_decode(
-        self,
-        q: torch.Tensor,
-        kv_c_and_k_pe_cache: torch.Tensor,
-        attn_metadata: MLACommonMetadata,
-        layer: AttentionLayer,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-        if type(q) is tuple:
-            q_nope, q_pe = q
-        else:
-            q_nope, q_pe = torch.split(
-                q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        if self._use_old_cutlass_mla:
-            # TODO: Remove the old cutlass MLA kernel after more extensive
-            # testing
-            return self._old_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
-                                            attn_metadata), None
-
-        return self._sm100_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
-                                          attn_metadata)
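On the backend side only the SM100 path is left: _sm100_forward_decode is renamed to _forward_decode, picks up the layer argument, and the old dispatcher is deleted together with _old_forward_decode. The one piece of work that dispatcher still did was splitting a fused decode query into its latent and RoPE parts before choosing a kernel; a standalone sketch of that split is below, with illustrative sizes standing in for the model's kv_lora_rank and qk_rope_head_dim:

import torch

# Illustrative sizes only; in vLLM these come from the model configuration.
batch, num_heads = 4, 16
kv_lora_rank, qk_rope_head_dim = 512, 64

# Fused decode query: latent (no-RoPE) part and RoPE part packed on the last dim.
q = torch.randn(batch, num_heads, kv_lora_rank + qk_rope_head_dim)

# The same split the deleted _forward_decode performed when handed a single
# tensor rather than a (q_nope, q_pe) tuple.
q_nope, q_pe = torch.split(q, [kv_lora_rank, qk_rope_head_dim], dim=-1)

assert q_nope.shape == (batch, num_heads, kv_lora_rank)
assert q_pe.shape == (batch, num_heads, qk_rope_head_dim)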
