Commit a00c797

Refactor MLA decode kernel: Replace T.If with native Python if statement (#162)
Simplify the control flow in the MLA decode kernel by replacing TileLang's T.If construct with a standard Python if statement. This change improves code readability and maintains the existing logic for handling sequence length constraints during block-wise computation.
1 parent 9789049 commit a00c797

1 file changed: +1 -1 lines changed

examples/deepseek_mla/example_mla_decode_paged.py

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ def flash_mla_kernel(
             policy=T.GemmWarpPolicy.FullCol)
     T.copy(scores_max, scores_max_prev)
     T.fill(scores_max, -T.infinity(accum_dtype))
-    with T.If(kr == 0), T.Then():
+    if kr == 0:
         for i, j in T.Parallel(block_H, block_N):
             acc_s[i, j] = T.if_then_else(k * block_N + j >= CACHE_SEQLENS[bx], -T.infinity(accum_dtype), acc_s[i, j])
     T.reduce_max(acc_s, scores_max, dim=1, clear=False)
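
For readers unfamiliar with the guarded branch, here is a minimal plain-Python sketch of what the masked block appears to compute: score columns whose global key index `k * block_N + j` falls at or beyond the cached sequence length are set to negative infinity, so they cannot contribute to the row-wise maximum that follows. This is an illustration only, not TileLang code. The function name `mask_and_row_max` and the `apply_mask` flag (standing in for the `kr == 0` condition) are hypothetical, and nested lists stand in for the kernel's fragment buffers.

```python
import math

NEG_INF = -math.inf


def mask_and_row_max(acc_s, k, block_N, cache_seqlen, apply_mask):
    """Hypothetical stand-in for the masked block in the diff.

    acc_s        -- block of attention scores, one list per head row
    k            -- index of the key block being processed
    block_N      -- number of key positions per block
    cache_seqlen -- valid sequence length (CACHE_SEQLENS[bx] in the kernel)
    apply_mask   -- plays the role of the `kr == 0` condition
    """
    if apply_mask:
        # Same predicate as the T.if_then_else call: positions at or past
        # the sequence length are replaced with -inf.
        acc_s = [
            [NEG_INF if k * block_N + j >= cache_seqlen else v
             for j, v in enumerate(row)]
            for row in acc_s
        ]

    # Mirrors T.fill(scores_max, -inf) followed by a reduce_max that
    # accumulates into the existing buffer (clear=False).
    scores_max = [NEG_INF] * len(acc_s)
    scores_max = [max([m, *row]) for m, row in zip(scores_max, acc_s)]
    return acc_s, scores_max


# Block k=1 with block_N=4 covers key positions 4..7; with a sequence
# length of 6, positions 6 and 7 are masked out.
masked, row_max = mask_and_row_max(
    [[1.0, 5.0, 3.0, 9.0]], k=1, block_N=4, cache_seqlen=6, apply_mask=True)
print(masked, row_max)
```

Without the mask, the out-of-range value 9.0 would win the row maximum. That is why the branch has to keep its behavior when `T.If` is replaced with a native `if`.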
