Commit 437e6ff

[Fix][TIR] LowerThreadAllreduce with correct thread mask
This PR fixes a bug in the LowerThreadAllreduce pass. Prior to this PR, the thread mask was not set correctly in multi-group settings: when the reduction extent is 32, the thread mask was always 0. The bug went unnoticed because, even with a zero mask, the CUDA program still produces correct results. Nevertheless, a zero mask is dangerous and should be fixed.
1 parent b6502f4 commit 437e6ff
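
For illustration only, here is a minimal standalone C++ sketch (not code from this PR) of the constant arithmetic behind the fix. It mirrors the shape of the changed expression, assuming a 32-bit mask dtype, a reduce extent of 32, and group_index = 0:

    #include <cstdint>
    #include <cstdio>

    int main() {
      const int reduce_extent = 32;  // one full warp per reduction group
      const int group_index = 0;     // assumed group id for this illustration

      // Old expression: (1 << reduce_extent) - 1 shifts a 32-bit constant by 32,
      // which overflows; per the commit message, the resulting mask was 0.
      // uint32_t old_mask =
      //     ((1 << reduce_extent) - 1) << (reduce_extent * group_index);  // broken

      // New expression: build the constant in 64 bits first, then narrow to the
      // 32-bit mask dtype. (1ll << 32) - 1 == 0xFFFFFFFF, the full-warp mask.
      uint32_t new_mask = static_cast<uint32_t>((1ll << reduce_extent) - 1)
                          << (reduce_extent * group_index);

      std::printf("mask = 0x%08X\n", new_mask);  // prints mask = 0xFFFFFFFF
      return 0;
    }
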

File tree

2 files changed: 69 additions & 4 deletions


src/tir/transforms/lower_thread_allreduce.cc

Lines changed: 4 additions & 4 deletions
@@ -333,8 +333,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       {
         PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
         if (group_extent > 1) {
-          mask = mask &
-                 (((1 << reduce_extent) - 1) << (reduce_extent * cast(mask_dtype, group_index)));
+          mask = mask & (make_const(mask_dtype, (1ll << reduce_extent) - 1)
+                         << (reduce_extent * cast(mask_dtype, group_index)));
         }
         seq.emplace_back(BufferStore(mask_buffer, mask, zero_indices));
         // Push the buffer description. Later this will have an
@@ -392,7 +392,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       // During the sub-warp reduction, values from inactive threads could be read,
       // which is an undefined behavior according to the cuda document.
       //
-      // In practise, the return value are usually 0, which does no harm to sum reduction.
+      // In practice, the return value are usually 0, which does no harm to sum reduction.
      // However, the result can be incorrect in max or prod reduction.
      // Therefore an additional range check has to be performed to ensure the correctness.
      if (offset * 2 > reduce_extent) {
@@ -405,7 +405,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {

      // Broadcast the reduction result from lane 0 to all other lanes.
      // This avoids to emit predicated stores, as all threads are
-     // uniformly writting the same result.
+     // uniformly writing the same result.
      //
      for (size_t i = 0; i < size; ++i) {
        Buffer buf = shared_bufs[i];

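As a complementary illustration (again not PR code; an assumed example), the corrected expression carves the warp's active mask into per-group lane windows when the reduce extent is smaller than the warp size, e.g. four groups of eight lanes:

    #include <cstdint>
    #include <cstdio>

    int main() {
      // Assumed sub-warp case: 4 reduction groups of 8 lanes share one 32-lane warp.
      const int reduce_extent = 8;
      const uint32_t activemask = 0xFFFFFFFFu;  // suppose all 32 lanes are active

      for (int group_index = 0; group_index < 4; ++group_index) {
        // Same shape as the fixed expression: a 64-bit constant narrowed to the
        // mask dtype, then shifted to this group's lane window.
        uint32_t window = static_cast<uint32_t>((1ll << reduce_extent) - 1)
                          << (reduce_extent * group_index);
        uint32_t mask = activemask & window;
        std::printf("group %d: mask = 0x%08X\n", group_index, mask);
      }
      // Prints 0x000000FF, 0x0000FF00, 0x00FF0000, 0xFF000000.
      return 0;
    }
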
tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py

Lines changed: 65 additions & 0 deletions
@@ -235,5 +235,70 @@ def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer(128, "float32")):
         B[i] = reduce[0]


+class TestMultiGroupMask(BaseCompare):
+    @T.prim_func
+    def before(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32,), "float32")):
+        T.func_attr({"target": T.target("cuda", host="llvm")})
+        threadIdx_y = T.launch_thread("threadIdx.y", 32)
+        cross_thread_B = T.allocate([1], "float32", "local")
+        threadIdx_x = T.launch_thread("threadIdx.x", 32)
+        cross_thread_B_1 = T.Buffer((1,), data=cross_thread_B, scope="local")
+        with T.attr(
+            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+            "reduce_scope",
+            T.reinterpret("handle", T.uint64(0)),
+        ):
+            A_1 = T.Buffer((1024,), data=A.data)
+            T.tvm_thread_allreduce(
+                T.uint32(1),
+                A_1[threadIdx_y * 32 + threadIdx_x],
+                T.bool(True),
+                cross_thread_B_1[0],
+                threadIdx_x,
+            )
+        if threadIdx_x == 0:
+            B_1 = T.Buffer((32,), data=B.data)
+            B_1[threadIdx_y] = cross_thread_B_1[0]
+
+    @T.prim_func
+    def expected(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32,), "float32")):
+        T.func_attr({"target": T.target("cuda", host="llvm")})
+        threadIdx_y = T.launch_thread("threadIdx.y", 32)
+        red_buf0 = T.allocate([1], "float32", "local")
+        threadIdx_x = T.launch_thread("threadIdx.x", 32)
+        red_buf0_1 = T.Buffer((1,), data=red_buf0, scope="local")
+        with T.attr(
+            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+            "reduce_scope",
+            T.reinterpret("handle", T.uint64(0)),
+        ):
+            mask = T.allocate([1], "uint32", "local")
+            t0 = T.allocate([1], "float32", "local")
+            A_1 = T.Buffer((1024,), data=A.data)
+            red_buf0_1[0] = A_1[threadIdx_y * 32 + threadIdx_x]
+
+            mask_1 = T.Buffer((1,), "uint32", data=mask, scope="local")
+            mask_1[0] = T.bitwise_and(
+                T.tvm_warp_activemask(),
+                T.shift_left(T.uint32(4294967295), T.uint32(32) * T.Cast("uint32", threadIdx_y)),
+            )
+
+            t0_1 = T.Buffer((1,), data=t0, scope="local")
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 16, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 8, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 4, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 2, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 1, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            red_buf0_1[0] = T.tvm_warp_shuffle(mask_1[0], red_buf0_1[0], 32 * threadIdx_y, 32, 32)
+        if threadIdx_x == 0:
+            B_1 = T.Buffer((32,), data=B.data)
+            B_1[threadIdx_y] = red_buf0_1[0]
+
+
 if __name__ == "__main__":
     tvm.testing.main()
