
Commit b04fcf1

[Fix][TIR] LowerThreadAllreduce warp reduction mask
The warp reduction implemented with the "shuffle down" primitive takes a mask denoting the active threads within the warp that participate in the shuffle. Previously we computed a narrowed mask, but in practice we found that setting this mask results in a "CUDA illegal instruction" error on the NVIDIA H100 GPU, and the issue goes away if we do not update the mask. Therefore, this PR updates the allreduce lowering to remove the mask update. Correctness is confirmed on the following devices:

* NVIDIA H100
* NVIDIA RTX 4090
* AMD Radeon 7900 XTX
* Apple M2 Ultra
1 parent: 99defd2
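For context, below is a minimal hand-written CUDA sketch of the shuffle-down warp reduction pattern that this pass targets; the kernel name warp_sum and the host setup are illustrative assumptions, not the code TVM emits. It passes the full __activemask() result to __shfl_down_sync, which is the behavior this commit adopts.

// Sketch only: a 32-lane warp sum using shuffle-down with the full active
// mask, mirroring the pattern LowerThreadAllreduce lowers reductions to.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void warp_sum(const float* in, float* out) {
  int lane = threadIdx.x;            // single warp: lane in [0, 32)
  float val = in[lane];
  unsigned mask = __activemask();    // mask of currently active lanes
  // Tree reduction via shuffle-down; lane 0 accumulates the full sum.
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_down_sync(mask, val, offset);
  }
  if (lane == 0) *out = val;
}

int main() {
  float h_in[32], h_out = 0.0f;
  for (int i = 0; i < 32; ++i) h_in[i] = 1.0f;  // expected sum: 32
  float *d_in, *d_out;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  warp_sum<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %f\n", h_out);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}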

2 files changed: +4 additions, -18 deletions


src/tir/transforms/lower_thread_allreduce.cc

Lines changed: 0 additions & 7 deletions
@@ -294,10 +294,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});

       if (reduce_extent <= warp_size_) {
-        if (group_extent > 1 && reduce_extent < warp_size_) {
-          mask = mask &
-                 (((1 << reduce_extent) - 1) << (reduce_extent * cast(mask_dtype, group_index)));
-        }
         std::tie(reduce_results, new_alloc_bufs) = MakeWarpAllreduce(
             values, types, combiner, reduce_index, reduce_extent, group_index, mask, NullOpt, &seq);

@@ -352,9 +348,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
         values[i] = BufferLoad(/*buffer=*/staging_shared_bufs[i],
                                /*indices=*/{group_index * n_warps + reduce_index});
       }
-      if (n_warps < warp_size_) {
-        mask = mask & (((1 << n_warps) - 1) << (group_index * n_warps));
-      }
       std::tie(reduce_results, local_bufs) = MakeWarpAllreduce(
           values, types, combiner, reduce_index, n_warps, group_index, mask,
           /*predicate=*/reduce_index < make_const(reduce_index->dtype, n_warps), &seq);
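For reference, the removed narrowing restricted the mask to the lanes of the current reduction group. As an illustrative example (values chosen here, not taken from the tests): with reduce_extent = 8 and group_index = 2, the expression (((1 << 8) - 1) << (8 * 2)) evaluates to 0x00FF0000, i.e. lanes 16-23. After this change, the unmodified result of tvm_warp_activemask() is passed to MakeWarpAllreduce instead.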

tests/python/tir-transform/test_tir_transform_lower_thread_all_reduce.py

Lines changed: 4 additions & 11 deletions
@@ -342,10 +342,7 @@ def expected(A: T.Buffer((32, 8), "float32"), B: T.Buffer((32,), "float32")):
         t0 = T.decl_buffer([1], "float32", scope="local")
         A_1 = T.Buffer((256,), data=A.data)
         red_buf0_1[0] = A_1[threadIdx_y * 8 + threadIdx_x]
-        mask[0] = T.bitwise_and(
-            T.tvm_warp_activemask(),
-            T.shift_left(T.uint32(255), T.uint32(8) * T.Cast("uint32", threadIdx_y)),
-        )
+        mask[0] = T.tvm_warp_activemask()
         t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0_1[0], 4, 32, 32)
         red_buf0_1[0] = red_buf0_1[0] + t0[0]
         t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0_1[0], 2, 32, 32)
@@ -421,7 +418,7 @@ def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128,), "float32"))
         T.tvm_storage_sync("shared")
         if threadIdx_x < 4:
             red_buf0[0] = red_buf_staging[threadIdx_x]
-            mask[0] = T.bitwise_and(T.tvm_warp_activemask(), T.uint32(15))
+            mask[0] = T.tvm_warp_activemask()
             t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 2, 32, 32)
             red_buf0[0] = red_buf0[0] + t0[0]
             t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 1, 32, 32)
@@ -573,9 +570,7 @@ def expected(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")):
         T.tvm_storage_sync("shared")
         if threadIdx_x < 4:
             red_buf0[0] = red_buf_staging[threadIdx_y * 4 + threadIdx_x]
-            mask[0] = T.bitwise_and(
-                T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(15, threadIdx_y * 4))
-            )
+            mask[0] = T.tvm_warp_activemask()
             t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 2, 32, 32)
             red_buf0[0] = red_buf0[0] + t0[0]
             t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 1, 32, 32)
@@ -657,9 +652,7 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
         T.tvm_storage_sync("shared")
         if threadIdx_x < 16:
             red_buf0[0] = red_buf_staging[threadIdx_y * 16 + threadIdx_x]
-            mask[0] = T.bitwise_and(
-                T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(65535, threadIdx_y * 16))
-            )
+            mask[0] = T.tvm_warp_activemask()
             t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 8, 32, 32)
             red_buf0[0] = red_buf0[0] + t0[0]
             t0[0] = T.tvm_warp_shuffle_down(mask[0], red_buf0[0], 4, 32, 32)
