
Commit fc41463

[BugFix] Robust gemm policy for sparse_mla_fwd in Hopper and Ada Lovelace architectures (#984)
* [BugFix] Robust gemm policy for sparse_mla_fwd in Hopper and Ada Lovelace architectures
* [Lint]
1 parent b0b5347 commit fc41463

File tree: 1 file changed, +35 -9 lines


examples/deepseek_v32/sparse_mla_fwd.py

Lines changed: 35 additions & 9 deletions
@@ -136,14 +136,14 @@ def main(
                     KV_shared,
                     acc_s,
                     transpose_B=True,
-                    policy=T.GemmWarpPolicy.FullCol,
+                    policy=T.GemmWarpPolicy.FullRow,
                 )
                 T.gemm(
                     Q_tail_shared,
                     K_tail_shared,
                     acc_s,
                     transpose_B=True,
-                    policy=T.GemmWarpPolicy.FullCol,
+                    policy=T.GemmWarpPolicy.FullRow,
                 )
                 T.copy(m_i, m_i_prev)
                 T.reduce_max(acc_s, m_i, dim=1, clear=False)
@@ -158,7 +158,7 @@ def main(
                     acc_o[h_i, d_i] = acc_o[h_i, d_i] * alpha[h_i]
 
                 T.copy(acc_s, S_shared)
-                T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
+                T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
 
                 # Rescale
                 for h_i, d_i in T.Parallel(H_per_block, D):
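For readers unfamiliar with the knob being flipped above: `policy` selects how tilelang distributes the warps of a thread block over the output tile of a tile-level GEMM (`FullRow` roughly favors partitioning along the row/M dimension, `FullCol` along the column/N dimension). The sketch below is a minimal GEMM modeled on tilelang's quickstart example, not on this kernel, and is only meant to show where `GemmWarpPolicy` plugs in; exact API details may vary between tilelang versions.

```python
# A minimal GEMM sketch modeled on tilelang's quickstart, shown only to point
# out where GemmWarpPolicy plugs in. Exact API details (T.Tensor vs. T.Buffer,
# tilelang.jit vs. tilelang.compile) are assumptions and vary by version.
import tilelang
import tilelang.language as T


def matmul(M, N, K, block_M=128, block_N=128, block_K=32,
           dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                # Same knob as in the diff above: FullRow is the policy this
                # commit standardizes on for Hopper / Ada Lovelace.
                T.gemm(A_shared, B_shared, C_local, policy=T.GemmWarpPolicy.FullRow)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
```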
@@ -174,7 +174,15 @@ def main(
     return main
 
 
-def sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, return_p_sum: bool = False, d_v=512):
+def sparse_mla_fwd_interface(q,
+                             kv,
+                             indices,
+                             sm_scale=None,
+                             return_p_sum: bool = False,
+                             d_v=512,
+                             block_I=64,
+                             num_stages=2,
+                             threads=256):
     is_casual = True
     assert return_p_sum == False, "This kernel file is for fwd only"
     assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous()
@@ -190,7 +198,17 @@ def sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, return_p_sum: bool =
     _, _, _, topk = indices.shape
     assert indices.shape == (batch, seq_len, kv_group, topk)
 
-    kernel = sparse_mla_fwd(heads, dim, tail_dim, topk, kv_group, sm_scale, is_casual)
+    kernel = sparse_mla_fwd(
+        heads,
+        dim,
+        tail_dim,
+        topk,
+        kv_group,
+        sm_scale,
+        is_casual,
+        block_I=block_I,
+        num_stages=num_stages,
+        threads=threads)
     out, lse = kernel(q, kv, indices)
     return out, lse
 
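The new `block_I`, `num_stages`, and `threads` keywords are forwarded unchanged into `sparse_mla_fwd(...)`, so the kernel specialization can now be tuned from the call site. A hypothetical usage sketch follows; `DV=512` and `topk=2048` come from the test defaults in this diff, while every other concrete size, the `int32` index dtype, and the import path are assumptions for illustration only.

```python
# Hypothetical call into the updated interface. Every concrete size not visible
# in this diff (B, S, SKV, H, HKV, DQK) and the int32 index dtype are assumed
# for illustration only; topk=2048 and d_v=512 mirror the defaults.
import torch

from sparse_mla_fwd import sparse_mla_fwd_interface  # assumed import path

B, S, SKV = 1, 1024, 4096        # batch, query length, KV length (assumed)
H, HKV = 64, 1                   # query heads, KV groups (assumed)
DQK, topk = 576, 2048            # QK head dim assumed; topk from the test default

q = torch.randn((B, S, H, DQK), dtype=torch.bfloat16, device="cuda")
kv = torch.randn((B, SKV, HKV, DQK), dtype=torch.bfloat16, device="cuda")

# Per-query top-k indices into the KV sequence, shaped (batch, seq_len, kv_group, topk)
# as required by the shape assertion in sparse_mla_fwd_interface.
indices = torch.zeros((B, S, HKV, topk), dtype=torch.int32, device="cuda")
for t in range(S):
    i_i = torch.randperm(max(1, t))[:topk]
    indices[0, t, 0, :len(i_i)] = i_i

# block_I / num_stages / threads are forwarded straight into sparse_mla_fwd(...),
# so the compiled kernel is specialized for these values.
out, lse = sparse_mla_fwd_interface(
    q, kv, indices, block_I=64, num_stages=2, threads=256)
```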

@@ -241,7 +259,10 @@ def test_sparse_mla_fwd(B=1,
                         DV=512,
                         topk=2048,
                         dtype=torch.bfloat16,
-                        check_correctness=True):
+                        check_correctness=True,
+                        block_I=64,
+                        num_stages=2,
+                        threads=256):
     torch.random.manual_seed(0)
     q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True)
     kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True)
@@ -253,7 +274,8 @@ def test_sparse_mla_fwd(B=1,
             i_i = torch.randperm(max(1, t))[:topk]
             indices[b, t, h, :len(i_i)] = i_i
 
-    tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices)
+    tl_out, tl_lse = sparse_mla_fwd_interface(
+        q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads)
 
     if check_correctness:
         # otherwise may cause out of memory
@@ -262,7 +284,8 @@
         print("assert_tensors_similar passed")
 
     def fn():
-        return sparse_mla_fwd_interface(q, kv, indices)
+        return sparse_mla_fwd_interface(
+            q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads)
 
     from tilelang.profiler import do_bench
 
@@ -287,4 +310,7 @@ def fn():
         DV=512,
         topk=2048,
         dtype=torch.bfloat16,
-        check_correctness=True)
+        check_correctness=True,
+        block_I=64,
+        num_stages=2,
+        threads=256)
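Since `test_sparse_mla_fwd` exposes the same knobs, a small sweep is an easy way to compare configurations on a given GPU. A sketch under the assumption that the module is importable as `sparse_mla_fwd`; the candidate configurations are illustrative, not tuned recommendations.

```python
# Hypothetical sweep over the newly exposed tuning knobs; candidate values are
# illustrative only, and the explicit arguments reuse defaults visible in this
# diff (B=1, DV=512, topk=2048, bfloat16).
import torch

from sparse_mla_fwd import test_sparse_mla_fwd  # assumed import path

for block_I, num_stages, threads in [
    (64, 2, 256),  # the new defaults
    (64, 3, 256),  # deeper software pipeline
    (64, 2, 512),  # more threads per block
]:
    # check_correctness=False skips the reference comparison, which the test
    # notes "may cause out of memory", so each run only exercises the kernel
    # and the do_bench timing path.
    test_sparse_mla_fwd(
        B=1,
        DV=512,
        topk=2048,
        dtype=torch.bfloat16,
        check_correctness=False,
        block_I=block_I,
        num_stages=num_stages,
        threads=threads)
```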
