Commit fa13a8b

Merge pull request vllm-project#13 from vllm-model-0920/lwilkinson/build-sparse-flash-mla
Build and bind sparse-FlashMLA kernels
2 parents 446c0de + 840f205 commit fa13a8b

File tree

4 files changed: +325, −2 lines

cmake/external_projects/flashmla.cmake

Lines changed: 88 additions & 2 deletions
@@ -18,8 +18,8 @@ if(FLASH_MLA_SRC_DIR)
 else()
   FetchContent_Declare(
     flashmla
-    GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
-    GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de
+    GIT_REPOSITORY https://github.com/vllm-model-0920/FlashMLA
+    GIT_TAG a25b977fae6925c45c3d0404c98c6ce6f4563dac
     GIT_PROGRESS TRUE
     CONFIGURE_COMMAND ""
     BUILD_COMMAND ""
@@ -35,6 +35,10 @@ message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}")
 # sm90a
 cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
 if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
+  #######################################################################
+  # FlashMLA Dense -- _flashmla_C
+  #######################################################################
+
   set(FlashMLA_SOURCES
     ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
     ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu
@@ -60,8 +64,90 @@ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
     INCLUDE_DIRECTORIES ${FlashMLA_INCLUDES}
     USE_SABI 3
     WITH_SOABI)
+
+  #######################################################################
+  # FlashMLA Sparse -- _flashmla_sparse_C
+  #######################################################################
+
+  # We use separate object libraries to avoid cross-contaminating includes,
+  # namely kernels/utils.h
+
+  set(DECODE_FOLDER ${flashmla_SOURCE_DIR}/csrc/sparse/decode)
+  set(PREFILL_FOLDER ${flashmla_SOURCE_DIR}/csrc/sparse/prefill)
+
+  # ---- Decode object library ----
+  set(SPARSE_FLASHMLA_DECODE_SOURCES
+    ${DECODE_FOLDER}/flash_api.cpp
+    ${DECODE_FOLDER}/kernels/get_mla_metadata.cu
+    ${DECODE_FOLDER}/kernels/mla_combine.cu
+    ${DECODE_FOLDER}/kernels/fp8_sparse/splitkv_mla.cu
+  )
+
+  add_library(_flashmla_sparse_decode OBJECT ${SPARSE_FLASHMLA_DECODE_SOURCES})
+  set_property(TARGET _flashmla_sparse_decode PROPERTY POSITION_INDEPENDENT_CODE ON)
+
+  set_gencode_flags_for_srcs(
+    SRCS "${SPARSE_FLASHMLA_DECODE_SOURCES}"
+    CUDA_ARCHS "${FLASH_MLA_ARCHS}"
+  )
+
+  # Include paths for decode ONLY (do not leak DECODE_FOLDER to others)
+  target_include_directories(_flashmla_sparse_decode
+    PRIVATE
+      ${flashmla_SOURCE_DIR}/csrc/cutlass/include
+      ${TORCH_INCLUDE_DIRS}
+      ${Python_INCLUDE_DIRS}
+      ${DECODE_FOLDER}
+  )
+  target_compile_options(_flashmla_sparse_decode PRIVATE
+    $<$<COMPILE_LANGUAGE:CUDA>:${VLLM_GPU_FLAGS}>)
+
+  # ---- Prefill object library ----
+  set(SPARSE_FLASHMLA_PREFILL_SOURCES
+    ${PREFILL_FOLDER}/api.cpp
+    ${PREFILL_FOLDER}/kernels/sm90/fwd/fwd.cu
+  )
+
+  add_library(_flashmla_sparse_prefill OBJECT ${SPARSE_FLASHMLA_PREFILL_SOURCES})
+  set_property(TARGET _flashmla_sparse_prefill PROPERTY POSITION_INDEPENDENT_CODE ON)
+
+  set_gencode_flags_for_srcs(
+    SRCS "${SPARSE_FLASHMLA_PREFILL_SOURCES}"
+    CUDA_ARCHS "${FLASH_MLA_ARCHS}"
+  )
+
+  target_include_directories(_flashmla_sparse_prefill
+    PRIVATE
+      ${flashmla_SOURCE_DIR}/csrc/cutlass/include
+      ${TORCH_INCLUDE_DIRS}
+      ${Python_INCLUDE_DIRS}
+      ${PREFILL_FOLDER}
+  )
+  target_compile_options(_flashmla_sparse_prefill PRIVATE
+    $<$<COMPILE_LANGUAGE:CUDA>:${VLLM_GPU_FLAGS}>)
+
+  # ---- Final extension target with unified API ----
+  define_gpu_extension_target(
+    _flashmla_sparse_C
+    DESTINATION vllm
+    LANGUAGE ${VLLM_GPU_LANG}
+    SOURCES
+      ${flashmla_SOURCE_DIR}/csrc/sparse/api.cpp
+      $<TARGET_OBJECTS:_flashmla_sparse_decode>
+      $<TARGET_OBJECTS:_flashmla_sparse_prefill>
+    COMPILE_FLAGS ${VLLM_GPU_FLAGS}
+    ARCHITECTURES ${VLLM_GPU_ARCHES}
+    # Only the common/public includes here; do NOT add decode/prefill folders
+    INCLUDE_DIRECTORIES
+      csrc/
+      ${CUTLASS_INCLUDE_DIR}
+      ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
+    USE_SABI 3
+    WITH_SOABI
+  )
 else()
   # Create an empty target for setup.py when not targeting sm90a systems
   add_custom_target(_flashmla_C)
+  add_custom_target(_flashmla_sparse_C)
 endif()
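
A quick way to confirm that the _flashmla_sparse_C target above was built and bound is to import the extension and look up the ops it registers. This is a minimal Python sketch, not part of the commit; the op names are taken from the bindings in vllm/attention/ops/flashmla.py later in this diff.

# Minimal sketch (not in this PR): check that the sparse FlashMLA extension
# built by the CMake target above loads and registers its custom ops.
import torch

try:
    import vllm._flashmla_sparse_C  # noqa: F401  (built only for SM90 with CUDA > 12.3)
    loaded = True
except ImportError:
    loaded = False

if loaded:
    ns = torch.ops._flashmla_sparse_C
    for op_name in ("get_mla_metadata", "fwd_kvcache_mla", "sparse_topk_attn_fwd"):
        # hasattr resolves the op schema; False means the op was not registered.
        print(op_name, "registered:", hasattr(ns, op_name))
else:
    print("_flashmla_sparse_C not available (empty target on non-SM90 builds)")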

setup.py

Lines changed: 3 additions & 0 deletions
@@ -322,6 +322,7 @@ def extract_precompiled_and_patch_package(wheel_url_or_path: str) -> dict:
         "vllm/_C.abi3.so",
         "vllm/_moe_C.abi3.so",
         "vllm/_flashmla_C.abi3.so",
+        "vllm/_flashmla_sparse_C.abi3.so",
         "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
         "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
         "vllm/cumem_allocator.abi3.so",
@@ -589,6 +590,8 @@ def _read_requirements(filename: str) -> list[str]:
         # not targeting a hopper system
         ext_modules.append(
             CMakeExtension(name="vllm._flashmla_C", optional=True))
+        ext_modules.append(
+            CMakeExtension(name="vllm._flashmla_sparse_C", optional=True))
     ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))

 if _build_custom_ops():
Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
+import pytest
+import torch
+
+
+def _cuda_sm90_available() -> bool:
+    if not torch.cuda.is_available():
+        return False
+    major, _ = torch.cuda.get_device_capability()
+    return major == 9
+
+
+@pytest.mark.cuda
+def test_sparse_flashmla_imports_and_flags():
+    import vllm.attention.ops.flashmla as fm
+    # Functions should exist
+    assert hasattr(fm, "get_sparse_mla_metadata")
+    assert hasattr(fm, "flash_mla_sparse_with_kvcache")
+    assert hasattr(fm, "flash_mla_sparse_prefill")
+    # Support check should return a (bool, reason)
+    ok, reason = fm.is_flashmla_supported()
+    assert isinstance(ok, bool)
+    assert (reason is None) or isinstance(reason, str)
+
+
+def test_sparse_flashmla_metadata_smoke():
+    import vllm.attention.ops.flashmla as fm
+    ok, reason = fm.is_flashmla_supported()
+    if not ok or not _cuda_sm90_available():
+        pytest.skip(reason or "SM90 not available")
+
+    device = torch.device("cuda")
+    batch_size = 1
+    seqlen_q = 1
+    num_heads_q = 128
+    num_heads_k = 1
+    q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k
+    q_heads_per_hk = num_heads_q // num_heads_k
+    topk = 128
+
+    cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
+
+    tile_md, num_splits = fm.get_sparse_mla_metadata(cache_seqlens,
+                                                     q_seq_per_hk,
+                                                     num_heads_k,
+                                                     topk,
+                                                     q_heads_per_hk)
+    assert tile_md.dtype == torch.int32
+    assert num_splits.dtype == torch.int32
+
+
+def test_sparse_flashmla_decode_smoke():
+    import vllm.attention.ops.flashmla as fm
+    ok, reason = fm.is_flashmla_supported()
+    if not ok or not _cuda_sm90_available():
+        pytest.skip(reason or "SM90 not available")
+
+    device = torch.device("cuda")
+    batch_size = 1
+    seqlen_q = 1
+    num_heads_q = 1
+    head_dim_k = 576
+    head_dim_v = 512
+    num_heads_k = 1
+    page_block_size = 64
+    bytes_per_token = 656
+    topk = 128
+
+    # Metadata
+    q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k
+    q_heads_per_hk = num_heads_q // num_heads_k
+    cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
+    tile_md, num_splits = fm.get_sparse_mla_metadata(cache_seqlens,
+                                                     q_seq_per_hk,
+                                                     num_heads_k,
+                                                     topk,
+                                                     q_heads_per_hk)
+
+    # Inputs
+    q = torch.zeros((batch_size, seqlen_q, num_heads_q, head_dim_k),
+                    dtype=torch.bfloat16,
+                    device=device)
+    k_cache = torch.zeros((1, page_block_size, num_heads_k, bytes_per_token),
+                          dtype=torch.uint8,
+                          device=device)
+    indices = torch.zeros((batch_size, seqlen_q, topk),
+                          dtype=torch.int32,
+                          device=device)
+
+    out, lse = fm.flash_mla_sparse_with_kvcache(q, k_cache, cache_seqlens,
+                                                head_dim_v, tile_md,
+                                                num_splits, indices)
+    assert out.shape[0] == batch_size
+    assert out.shape[-1] == head_dim_v
+    assert lse.shape[0] == batch_size
+
+
+def test_sparse_flashmla_prefill_smoke():
+    import vllm.attention.ops.flashmla as fm
+    ok, reason = fm.is_flashmla_supported()
+    if not ok or not _cuda_sm90_available():
+        pytest.skip(reason or "SM90 not available")
+
+    device = torch.device("cuda")
+    s_q = 1
+    s_kv = 1
+    h_q = 64  # kernel expects multiple of 64
+    h_kv = 1
+    d_qk = 576
+    d_v = 512
+    topk = 128
+
+    q = torch.zeros((s_q, h_q, d_qk), dtype=torch.bfloat16, device=device)
+    kv = torch.zeros((s_kv, h_kv, d_qk), dtype=torch.bfloat16, device=device)
+    indices = torch.zeros((s_q, h_kv, topk), dtype=torch.int32, device=device)
+
+    out, max_logits, lse = fm.flash_mla_sparse_prefill(q, kv, indices, 1.0, d_v)
+    assert out.shape == (s_q, h_q, d_v)
+    assert max_logits.shape == (s_q, h_q)
+    assert lse.shape == (s_q, h_q)

vllm/attention/ops/flashmla.py

Lines changed: 114 additions & 0 deletions
@@ -13,6 +13,7 @@
 if current_platform.is_cuda():
     try:
         import vllm._flashmla_C  # noqa: F401
+        import vllm._flashmla_sparse_C  # noqa: F401
         _flashmla_C_AVAILABLE = True
     except ImportError:
         _flashmla_C_AVAILABLE = False
@@ -110,6 +111,119 @@ def flash_mla_with_kvcache(
     return out.squeeze(1), softmax_lse.squeeze(-1)
 
 
+# ------------------------ Sparse FlashMLA bindings -------------------------
+
+
+def get_sparse_mla_metadata(
+        cache_seqlens: torch.Tensor,
+        q_seq_per_hk: int,
+        num_heads_k: int,
+        topk: int,
+        q_heads_per_hk: Optional[int] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Arguments:
+        cache_seqlens: (batch_size), dtype torch.int32.
+        q_seq_per_hk: Equal to seq_len_q * num_heads_q // num_heads_k.
+        num_heads_k: num_heads_k.
+        topk: topk.
+        q_heads_per_hk: Equal to num_heads_q // num_heads_k. Only needs to be
+            specified when topk is not None.
+
+    Return:
+        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
+            dtype torch.int32.
+        num_splits: (batch_size + 1), dtype torch.int32.
+    """
+    return torch.ops._flashmla_sparse_C.get_mla_metadata(
+        cache_seqlens, q_seq_per_hk, num_heads_k, topk, q_heads_per_hk)
+
+
+def flash_mla_sparse_with_kvcache(
+        q: torch.Tensor,
+        k_cache: torch.Tensor,
+        cache_seqlens: torch.Tensor,
+        head_dim_v: int,
+        tile_scheduler_metadata: torch.Tensor,
+        num_splits: torch.Tensor,
+        indices_in_kvcache: torch.Tensor,
+        softmax_scale: Optional[float] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Arguments:
+        q: (batch_size, seq_len_q, num_heads_q, head_dim).
+        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
+        cache_seqlens: (batch_size), torch.int32.
+        head_dim_v: Head dim of v.
+        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
+            torch.int32, returned by get_sparse_mla_metadata.
+        num_splits: (batch_size + 1), torch.int32, returned by
+            get_sparse_mla_metadata.
+        indices_in_kvcache: (batch_size, seq_len_q, topk). KV indices when
+            sparse attention is enabled. Note that
+            indices_in_kvcache[i][j][k] =
+            (the index of the page block where token t resides) *
+            page_block_size + (the offset of token t within that page block),
+            where t is the k-th token of the j-th q-sequence in the i-th batch.
+        softmax_scale: float. Scaling of QK^T before softmax.
+            Defaults to 1 / sqrt(head_dim).
+
+    Explanation of the K/V cache layout:
+        The NoPE part of each token is quantized at 1x128 granularity,
+        yielding 512 float8_e4m3 values and 4 float32 scale factors; the
+        RoPE part is kept as 64 bfloat16 values. Each token therefore
+        occupies 656 bytes:
+        - First 512 bytes: quantized NoPE (512 x float8_e4m3)
+        - Next 16 bytes: scale factors (4 x float32)
+        - Last 128 bytes: RoPE (64 x bfloat16)
+
+    Return:
+        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
+        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
+    """
+    if softmax_scale is None:
+        softmax_scale = q.shape[-1]**(-0.5)
+    # Strict shape checks like the reference implementation
+    assert k_cache.shape[-1] == 656, (
+        "The last dim of k_cache must be 656 (= 512 + 4*4 + 2*64) when "
+        "is_fp8_kvcache is True")
+    assert k_cache.shape[-2] == 1, (
+        "The number of K heads must be 1 when is_fp8_kvcache is True")
+
+    out, softmax_lse = torch.ops._flashmla_sparse_C.fwd_kvcache_mla(
+        q, k_cache, head_dim_v, cache_seqlens, softmax_scale,
+        tile_scheduler_metadata, num_splits, indices_in_kvcache)
+    return out, softmax_lse
+
+
+def flash_mla_sparse_prefill(
+        q: torch.Tensor,
+        kv: torch.Tensor,
+        indices: torch.Tensor,
+        sm_scale: float,
+        d_v: int = 512,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Sparse attention forward operator, for prefill.
+
+    Args:
+        q: [s_q, h_q, d_qk], bfloat16
+        kv: [s_kv, h_kv, d_qk], bfloat16
+        indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1
+            or to a number >= s_kv.
+        sm_scale: float, scaling factor for the attention scores
+        d_v: dimension of the value; 512 is the default and the only
+            supported value
+
+    Returns:
+        (output, max_logits, lse)
+        - output: [s_q, h_q, d_v], bfloat16, the attention result
+        - max_logits: [s_q, h_q], float
+        - lse: [s_q, h_q], float, base-2 log-sum-exp
+    """
+    results = torch.ops._flashmla_sparse_C.sparse_topk_attn_fwd(
+        q, kv, indices, sm_scale, d_v)
+    return results
+
+
 #
 # TODO: Add fake functions
 #
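
The indices_in_kvcache formula in the flash_mla_sparse_with_kvcache docstring above can be made concrete with a small helper. This is an illustrative sketch, not part of the commit: block_table and topk_positions are assumed inputs (a per-request page table and the selected top-k token positions), and masking of padded or invalid selections is not handled here.

# Illustrative sketch (not in this PR): build indices_in_kvcache from
# top-k token positions using the rule from the docstring above:
#   index = (page block holding the token) * page_block_size
#           + (offset of the token within that page block)
import torch


def build_indices_in_kvcache(
        block_table: torch.Tensor,     # (batch_size, max_blocks), int32
        topk_positions: torch.Tensor,  # (batch_size, seq_len_q, topk), int32
        page_block_size: int = 64,
) -> torch.Tensor:
    logical_block = topk_positions // page_block_size  # slot in the block table
    offset = topk_positions % page_block_size          # position inside the block
    # Look up the physical page block for each selected token.
    physical_block = torch.gather(
        block_table.unsqueeze(1).expand(-1, topk_positions.shape[1], -1),
        dim=2,
        index=logical_block.long())
    return (physical_block * page_block_size + offset).to(torch.int32)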

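The 656-byte per-token cache layout described in that docstring can likewise be written down as plain offsets. The sketch below (also not part of the commit) slices one cached token for inspection; it assumes each group of 128 NoPE values is dequantized by multiplying with its float32 scale, which follows from the stated 1x128 quantization granularity.

# Illustrative sketch (not in this PR): split one cached token (656 uint8
# bytes) into its components, following the layout in the docstring above:
#   bytes [0, 512)   -> 512 x float8_e4m3 quantized NoPE
#   bytes [512, 528) ->   4 x float32 scale factors
#   bytes [528, 656) ->  64 x bfloat16 RoPE
import torch


def unpack_kv_token(token_bytes: torch.Tensor):
    assert token_bytes.dtype == torch.uint8 and token_bytes.numel() == 656
    nope_fp8 = token_bytes[:512].view(torch.float8_e4m3fn)  # (512,)
    scales = token_bytes[512:528].view(torch.float32)       # (4,)
    rope = token_bytes[528:656].view(torch.bfloat16)        # (64,)
    # Assumed dequantization: one scale per 1x128 group of NoPE values.
    nope = nope_fp8.to(torch.float32).reshape(4, 128) * scales[:, None]
    return nope.reshape(512), rope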