Skip to content

Commit 4a74b62

Browse files
committed
Refactor attention sink examples to simplify index calculations
- Updated index handling in `example_gqa_sink_bwd_bhsd.py` and `example_mha_sink_bwd_bhsd.py` to eliminate unnecessary local allocations and streamline the logic for determining start and end indices.
- Improved readability by using direct calculations instead of local variables for index bounds in pipelined loops.
1 parent 05b68d0 commit 4a74b62

File tree

4 files changed

+34
-24
lines changed

4 files changed

+34
-24
lines changed

examples/attention_sink/example_gqa_sink_bwd_bhsd.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,10 @@ def flash_fwd(
8181
sinks[i] = Sinks[by]
8282

8383
end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
84-
start = T.alloc_local([1], 'int32')
85-
if window_size is not None:
86-
start[0] = T.max(0, (bx * block_M - window_size) // block_N)
87-
else:
88-
start[0] = 0
84+
start = T.max(0,
85+
(bx * block_M - window_size) // block_N) if window_size is not None else 0
8986

90-
for k in T.Pipelined(start[0], end, num_stages=num_stages):
87+
for k in T.Pipelined(start, end, num_stages=num_stages):
9188
T.copy(K[bz, by // groups, k * block_N:(k + 1) * block_N, :], K_shared)
9289
for i, j in T.Parallel(block_M, block_N):
9390
q_idx = bx * block_M + i

examples/attention_sink/example_mha_sink_bwd_bhsd.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -267,14 +267,10 @@ def flash_bwd(
267267
T.clear(dk)
268268

269269
loop_st = T.floordiv(by * block_M, block_N)
270-
loop_ed = T.alloc_local([1], 'int32')
271-
if window_size is not None:
272-
loop_ed[0] = T.min(
273-
T.ceildiv((by + 1) * block_M + window_size, block_N),
274-
T.ceildiv(seq_len, block_N))
275-
else:
276-
loop_ed[0] = T.ceildiv(seq_len, block_N)
277-
for k in T.Pipelined(loop_st, loop_ed[0], num_stages=num_stages):
270+
loop_ed = T.min(
271+
T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(
272+
seq_len, block_N)) if window_size is not None else T.ceildiv(seq_len, block_N)
273+
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
278274
T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q)
279275
T.clear(qkT)
280276
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

src/transform/legalize_negative_index.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ class NegativeIndexAnalyzer : public IRVisitorWithAnalyzer {
5151
states.push_back(IndexSignState::kUnknown);
5252
needs_record = true;
5353
DLOG(WARNING) << "LegalizeNegativeIndex: cannot prove non-negative index "
54-
<< simplified << " for buffer " << load->buffer->name
55-
<< " (axis " << i << ").";
54+
<< simplified << " for buffer " << load->buffer->name
55+
<< " (axis " << i << ").";
5656
}
5757

5858
if (needs_record) {

tilelang/intrinsics/wgmma_macro_generator.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@ def wgmma(self,
273273
def _warp_mma(A_ptr, B_ptr, C_buf):
274274
tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
275275

276+
scale_out = T.alloc_var("int32")
276277
desc_a = T.alloc_wgmma_desc()
277278
desc_b = T.alloc_wgmma_desc()
278279
T.initialize_wgmma_descriptor(desc_a, A_ptr, a_swizzle_mode,
@@ -283,12 +284,16 @@ def _warp_mma(A_ptr, B_ptr, C_buf):
283284
int(b_stride_byte_offset >> 4))
284285
T.warpgroup_fence_operand(C_buf, num_regs=accum_regs)
285286
T.warpgroup_arrive()
286-
for j in T.serial(num_inst_n):
287-
for i in T.serial(num_inst_m):
288-
for ki in T.serial(k_dim // micro_size_k):
287+
288+
if clear_accum:
289+
scale_out = 0
290+
else:
291+
scale_out = 1
292+
for j in T.unroll(num_inst_n):
293+
for i in T.unroll(num_inst_m):
294+
for ki in T.unroll(k_dim // micro_size_k):
289295
warp_i = (warp_m // 4) * num_inst_m + i
290296
warp_j = warp_n * num_inst_n + j
291-
scale_out = T.if_then_else(ki != 0, 1, T.if_then_else(clear_accum, 0, 1))
292297
A_offset = (
293298
ki % ak_atom_size
294299
) * micro_size_k + warp_i * 64 * a_swizzle_atom_elems + (
@@ -305,6 +310,9 @@ def _warp_mma(A_ptr, B_ptr, C_buf):
305310
(A_offset * elems_in_bytes) >> 4, desc_b.data,
306311
(B_offset * elems_in_bytes) >> 4, C_buf.data, C_offset,
307312
scale_out, scale_in_a, scale_in_b)
313+
if clear_accum:
314+
scale_out = 1
315+
308316
T.warpgroup_commit_batch()
309317
if wg_wait >= 0:
310318
T.warpgroup_wait(wg_wait)
@@ -387,6 +395,7 @@ def wgmma_rs(self,
387395
def _warp_mma(A_buf, B_ptr, C_buf):
388396
tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
389397

398+
scale_out = T.alloc_var("int32")
390399
desc_b = T.alloc_wgmma_desc()
391400
T.initialize_wgmma_descriptor(desc_b, B_ptr, b_swizzle_mode,
392401
int(b_leading_byte_offset >> 4),
@@ -395,11 +404,16 @@ def _warp_mma(A_buf, B_ptr, C_buf):
395404
T.warpgroup_fence_operand(C_buf, num_regs=accum_regs)
396405
T.warpgroup_arrive()
397406

398-
for j in T.serial(0, num_inst_n):
399-
for i in T.serial(num_inst_m):
400-
for ki in T.serial(0, (k_dim // micro_size_k)):
407+
if clear_accum:
408+
scale_out = 0
409+
else:
410+
scale_out = 1
411+
412+
for j in T.unroll(0, num_inst_n):
413+
for i in T.unroll(num_inst_m):
414+
for ki in T.unroll(0, (k_dim // micro_size_k)):
401415
warp_j = warp_n * num_inst_n + j
402-
scale_out = T.if_then_else(ki != 0, 1, T.if_then_else(clear_accum, 0, 1))
416+
403417
A_offset = ki * warp_rows * local_size_a + i * local_size_a
404418
B_offset = (
405419
ki // bk_atom_size
@@ -425,6 +439,9 @@ def _warp_mma(A_buf, B_ptr, C_buf):
425439
scale_in_a,
426440
scale_in_b,
427441
)
442+
if clear_accum:
443+
scale_out = 1
444+
428445
T.warpgroup_commit_batch()
429446
if wg_wait >= 0:
430447
T.warpgroup_wait(wg_wait)

0 commit comments

Comments (0)