Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
"csrc/cutlass_extensions/common.cpp"
"csrc/attention/mla/cutlass_mla_entry.cu"
"csrc/quantization/fp8/per_token_group_quant.cu")

set_gencode_flags_for_srcs(
Expand Down Expand Up @@ -585,7 +584,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS)
set(SRCS
"csrc/attention/mla/cutlass_mla_kernels.cu"
"csrc/attention/mla/sm100_cutlass_mla_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
Expand Down
38 changes: 0 additions & 38 deletions csrc/attention/mla/cutlass_mla_entry.cu

This file was deleted.

225 changes: 0 additions & 225 deletions csrc/attention/mla/cutlass_mla_kernels.cu

This file was deleted.

7 changes: 0 additions & 7 deletions csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -510,13 +510,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("cutlass_sparse_compress(Tensor a) -> Tensor[]");
ops.impl("cutlass_sparse_compress", &cutlass_sparse_compress);

// CUTLASS MLA decode
ops.def(
"cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
" Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
" Tensor page_table, float scale) -> ()");
ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);

// SM100 CUTLASS MLA decode
ops.def(
"sm100_cutlass_mla_decode(Tensor! out, Tensor! lse, Tensor q_nope,"
Expand Down
9 changes: 0 additions & 9 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1823,15 +1823,6 @@ def flash_mla_with_kvcache(
return out, softmax_lse


def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor,
q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
seq_lens: torch.Tensor, page_table: torch.Tensor,
scale: float) -> torch.Tensor:
torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache,
seq_lens, page_table, scale)
return out


def sm100_cutlass_mla_decode(out: torch.Tensor, lse: torch.Tensor,
q_nope: torch.Tensor, q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
Expand Down
Loading