Commit d59c1a8

[BugFix][TIR] Fix multi-grouped multi-warp allreduce
PRs #15327 and #15373 introduced the multi-warp allreduce implementation. At the time, I tested correctness numerically with the workload "take a matrix of ones as input and compute the sum over each row". Both PRs passed this numerical test, but I did not realize that the test is incomplete and cannot guarantee correctness. The previous implementation has a bug that can be exposed by switching the input matrix from ones to random floating-point numbers. This PR therefore fixes the issues and adds numerical tests for multi-warp allreduce to `test_allreduce_cuda.py`. By dropping some redundant tests in that file, we hope to reduce the testing time a bit while still guaranteeing correctness. Sorry for not testing the implementation completely before.
1 parent d6407be · commit d59c1a8
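For reference, below is a minimal sketch of the kind of numerical check that catches this class of bug: a grouped row-sum bound to threadIdx axes, compared against NumPy on random input rather than all-ones. It is illustrative only (the helper name and schedule here are assumptions, not the exact contents of `test_allreduce_cuda.py`) and assumes a CUDA-enabled TVM build.

import numpy as np
import tvm
from tvm import te

def check_sum_random(d1: int, d2: int, d3: int):
    # Row-wise sum over the last axis; all three extents map to one thread block.
    A = te.placeholder((d1, d2, d3), name="A", dtype="float32")
    k = te.reduce_axis((0, d3), name="k")
    B = te.compute((d1, d2), lambda i, j: te.sum(A[i, j, k], axis=k), name="B")

    s = te.create_schedule(B.op)
    i, j = s[B].op.axis
    s[B].bind(i, te.thread_axis("threadIdx.z"))
    s[B].bind(j, te.thread_axis("threadIdx.y"))
    # Binding the reduction axis to threadIdx.x triggers the allreduce lowering.
    s[B].bind(s[B].op.reduce_axis[0], te.thread_axis("threadIdx.x"))

    f = tvm.build(s, [A, B], target="cuda")
    dev = tvm.cuda(0)
    # Random input, not all-ones: wrong staging indices or masks now change the result.
    a_np = np.random.uniform(-1.0, 1.0, size=(d1, d2, d3)).astype("float32")
    a = tvm.nd.array(a_np, dev)
    b = tvm.nd.empty((d1, d2), "float32", dev)
    f(a, b)
    np.testing.assert_allclose(b.numpy(), a_np.sum(axis=-1), rtol=1e-5, atol=1e-5)

With all-ones input, many mis-indexed schedules still produce the right sums, since every group's partial results are identical; this is why the original test passed despite the bug.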

3 files changed: +38 −28 lines

src/tir/transforms/lower_thread_allreduce.cc

Lines changed: 21 additions & 17 deletions
@@ -76,12 +76,17 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     auto node = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
 
     if (auto it = alloc_remap_.find(node->buffer_var.get()); it != alloc_remap_.end()) {
-      const AllocateNode* repl = it->second.as<AllocateNode>();
+      Buffer buf = Downcast<Buffer>(it->second);
       auto write_ptr = node.CopyOnWrite();
-      write_ptr->buffer_var = repl->buffer_var;
-      write_ptr->dtype = repl->dtype;
-      write_ptr->extents = repl->extents;
-      write_ptr->condition = repl->condition;
+      write_ptr->buffer_var = buf->data;
+      write_ptr->dtype = buf->dtype;
+      write_ptr->extents = buf->shape;
+      write_ptr->condition = const_true(buf->dtype.lanes());
+
+      if (buf.scope() == "shared") {
+        // Use volatile access to shared buffer.
+        write_ptr->body = AttrStmt(buf->data, attr::volatile_scope, 1, write_ptr->body);
+      }
     }
     return std::move(node);
   }
@@ -344,15 +349,15 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     // 4. Load staging buffer.
     // Second round of allreduce.
     for (size_t i = 0; i < size; ++i) {
-      values[i] = BufferLoad(/*buffer=*/staging_shared_bufs[i], /*indices=*/{reduce_index});
+      values[i] = BufferLoad(/*buffer=*/staging_shared_bufs[i],
+                             /*indices=*/{group_index * n_warps + reduce_index});
     }
     if (n_warps < warp_size_) {
-      mask = mask & (((1 << n_warps) - 1) << group_index);
+      mask = mask & (((1 << n_warps) - 1) << (group_index * n_warps));
     }
     std::tie(reduce_results, local_bufs) = MakeWarpAllreduce(
         values, types, combiner, reduce_index, n_warps, group_index, mask,
-        /*predicate=*/reduce_index < make_const(reduce_index->dtype, group_extent * n_warps),
-        &seq);
+        /*predicate=*/reduce_index < make_const(reduce_index->dtype, n_warps), &seq);
     new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(), local_bufs.end());
 
     // 5. Create shared memory buffer(s) of `group_extent` elements, storing
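For intuition, the corrected indices above assume the following layout (a plain-Python sketch, not part of the patch): with `group_extent` groups of `n_warps` warps each, warp `w` of group `g` keeps its partial result in staging slot `g * n_warps + w`, the second round is done by the first `n_warps` lanes of each group (hence the new predicate `reduce_index < n_warps`), and each group's shuffle mask covers exactly its own `n_warps` lanes.

# Illustrative only: the slot and mask arithmetic the corrected lowering relies on.
def staging_slot(group_index: int, warp_in_group: int, n_warps: int) -> int:
    # One contiguous run of n_warps staging slots per group.
    return group_index * n_warps + warp_in_group

def group_mask(group_index: int, n_warps: int) -> int:
    # Shift by whole groups (group_index * n_warps), not by group_index alone,
    # which was the bug fixed here.
    return ((1 << n_warps) - 1) << (group_index * n_warps)

# Example: 4 groups of 4 warps (128 threads per group).
assert staging_slot(group_index=2, warp_in_group=1, n_warps=4) == 9
assert group_mask(group_index=2, n_warps=4) == 0b0000111100000000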
@@ -365,9 +370,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
           /*shape=*/{make_const(reduce_index->dtype, group_extent)},
           /*dtype=*/buffers[i]->dtype, /*name=*/"red_result", /*storage_scope=*/"shared");
       write_result.push_back(
-          BufferStore(broadcast_shared_buf, reduce_results[i], {zero_index}));
+          BufferStore(broadcast_shared_buf, reduce_results[i], {group_index}));
       // Update `reduce_results`, pointing to the value loaded from the shared memory buffer.
-      reduce_results[i] = BufferLoad(broadcast_shared_buf, {zero_index});
+      reduce_results[i] = BufferLoad(broadcast_shared_buf, {group_index});
     }
     seq.push_back(IfThenElse(reduce_index == zero_index, SeqStmt::Flatten(write_result)));
     seq.push_back(SyncThread("shared"));
@@ -382,7 +387,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       load_remap_[buffers[i]->data.get()] = reduce_results[i];
 
       auto node = Allocate(buf->data, types[i], buf->shape, pred, Evaluate(0));
-      alloc_remap_[buffers[i]->data.get()] = node;
+      alloc_remap_[buffers[i]->data.get()] = buf;
       var_remap_[buffers[i]->data.get()] = buf->data;
       buf_remap_[buffers[i].get()] = buf;
     }
@@ -400,7 +405,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     // previous iteration on the same buffer.
     seq.emplace_back(SyncThread("shared"));
     for (size_t idx = 0; idx < size; ++idx) {
-      shared_bufs[idx] = decl_buffer({1}, types[idx], "red_buf" + std::to_string(idx), "shared");
+      shared_bufs[idx] = decl_buffer({IntImm(group_index->dtype, group_extent * reduce_extent)},
+                                     types[idx], "red_buf" + std::to_string(idx), "shared");
       seq.emplace_back(BufferStore(shared_bufs[idx], values[idx],
                                    {BufIndex(reduce_index, group_index, reduce_extent)}));
     }
@@ -414,9 +420,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
           {BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent)});
       ICHECK_EQ(load->dtype, types[idx]);
       load_remap_[buffers[idx]->data.get()] = load;
-      alloc_remap_[buffers[idx]->data.get()] =
-          Allocate(shared_bufs[idx]->data, types[idx],
-                   {PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred, Evaluate(0));
+      alloc_remap_[buffers[idx]->data.get()] = shared_bufs[idx];
       var_remap_[buffers[idx]->data.get()] = shared_bufs[idx]->data;
       buf_remap_[buffers[idx].get()] = shared_bufs[idx];
     }
@@ -772,7 +776,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   // The load remap
   std::unordered_map<const VarNode*, PrimExpr> load_remap_;
   // Allocate remap
-  std::unordered_map<const VarNode*, Stmt> alloc_remap_;
+  std::unordered_map<const VarNode*, Buffer> alloc_remap_;
   // BufferVar remap
   std::unordered_map<const VarNode*, Var> var_remap_;
   // Buffer remap

tests/python/unittest/test_subwarp_reduction_cuda.py renamed to tests/python/unittest/test_allreduce_cuda.py

Lines changed: 3 additions & 1 deletion
@@ -95,7 +95,9 @@ def check_max(d1: int, d2: int, d3: int):
 
 for d1 in range(1, 5):
     for d2 in range(1, 5):
-        for d3 in range(2, 33):
+        for d3 in [2, 4, 8, 12, 16, 32, 48, 64, 100, 128, 201, 256, 512, 1024]:
+            if d1 * d2 * d3 > 1024:
+                continue
             check_sum(d1, d2, d3)
             check_max(d1, d2, d3)
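A note on the new guard (assuming, per the test setup, that d1, d2, and d3 are all bound to threadIdx dimensions of a single block): CUDA allows at most 1024 threads per block, so configurations whose product exceeds that cannot launch and are skipped rather than failed. A sketch of the configuration set the loop above actually exercises:

# Illustrative: the configurations that survive the guard, assuming all three
# extents are mapped to threadIdx dimensions of one block (max 1024 threads).
MAX_THREADS_PER_BLOCK = 1024
configs = [
    (d1, d2, d3)
    for d1 in range(1, 5)
    for d2 in range(1, 5)
    for d3 in [2, 4, 8, 12, 16, 32, 48, 64, 100, 128, 201, 256, 512, 1024]
    if d1 * d2 * d3 <= MAX_THREADS_PER_BLOCK
]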

tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py

Lines changed: 14 additions & 10 deletions
@@ -387,6 +387,7 @@ def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128,), "float32"))
         for i in range(128):
             threadIdx_x = T.launch_thread("threadIdx.x", 128)
             red_result = T.allocate([1], "float32", "shared")
+            T.attr(red_result, "volatile_scope", 1)
             red_result_1 = T.Buffer((1,), data=red_result, scope="shared")
             with T.attr(
                 T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
@@ -463,6 +464,7 @@ def expected(A: T.Buffer((1, 1024), "float32"), B: T.Buffer((1,), "float32")):
         T.func_attr({"target": T.target("cuda", host="llvm")})
         threadIdx_x = T.launch_thread("threadIdx.x", 1024)
         red_result = T.allocate([1], "float32", "shared")
+        T.attr(red_result, "volatile_scope", 1)
         red_result_1 = T.Buffer((1,), data=red_result, scope="shared")
         with T.attr(
             T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
@@ -550,6 +552,7 @@ def expected(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")):
         T.func_attr({"target": T.target("cuda", host="llvm")})
         threadIdx_y = T.launch_thread("threadIdx.y", 4)
         red_result = T.allocate([4], "float32", "shared")
+        T.attr(red_result, "volatile_scope", 1)
         threadIdx_x = T.launch_thread("threadIdx.x", 128)
         red_result_1 = T.Buffer((4,), data=red_result, scope="shared")
         with T.attr(
@@ -585,23 +588,23 @@ def expected(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")):
             red_buf_staging_1[threadIdx_y * 4 + threadIdx_x // 32] = red_buf0_2[0]
             T.tvm_storage_sync("shared")
             red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
-            if threadIdx_x < 16:
-                red_buf0_3[0] = red_buf_staging_1[threadIdx_x]
+            if threadIdx_x < 4:
+                red_buf0_3[0] = red_buf_staging_1[threadIdx_y * 4 + threadIdx_x]
             mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")
             mask_3[0] = T.bitwise_and(
-                T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(15, threadIdx_y))
+                T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(15, threadIdx_y * 4))
             )
             t0_3 = T.Buffer((1,), data=t0, scope="local")
             t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 2, 32, 32)
             red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
             t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32)
             red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
             if threadIdx_x == 0:
-                red_result_1[0] = red_buf0_3[0]
+                red_result_1[threadIdx_y] = red_buf0_3[0]
             T.tvm_storage_sync("shared")
             if threadIdx_x == 0:
                 B_1 = T.Buffer((4,), data=B.data)
-                B_1[threadIdx_y] = red_result_1[0]
+                B_1[threadIdx_y] = red_result_1[threadIdx_y]
 
 
 class TestMultiGroupMultiWarpPredicatedReduction(BaseCompare):
@@ -636,6 +639,7 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
         threadIdx_y = T.launch_thread("threadIdx.y", 2)
         in_thread_B = T.allocate([1], "float32", "local")
         red_result = T.allocate([2], "float32", "shared")
+        T.attr(red_result, "volatile_scope", 1)
         threadIdx_x = T.launch_thread("threadIdx.x", 512)
         in_thread_B_1 = T.Buffer((1,), data=in_thread_B, scope="local")
         in_thread_B_1[0] = T.float32(0)
@@ -675,11 +679,11 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
             red_buf_staging_1[threadIdx_y * 16 + threadIdx_x // 32] = red_buf0_2[0]
             T.tvm_storage_sync("shared")
             red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
-            if threadIdx_x < 32:
-                red_buf0_3[0] = red_buf_staging_1[threadIdx_x]
+            if threadIdx_x < 16:
+                red_buf0_3[0] = red_buf_staging_1[threadIdx_y * 16 + threadIdx_x]
             mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")
             mask_3[0] = T.bitwise_and(
-                T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(65535, threadIdx_y))
+                T.tvm_warp_activemask(), T.Cast("uint32", T.shift_left(65535, threadIdx_y * 16))
             )
             t0_3 = T.Buffer((1,), data=t0, scope="local")
             t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 8, 32, 32)
@@ -691,11 +695,11 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
             t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32)
             red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
             if threadIdx_x == 0:
-                red_result_1[0] = red_buf0_3[0]
+                red_result_1[threadIdx_y] = red_buf0_3[0]
             T.tvm_storage_sync("shared")
             if threadIdx_x == 0:
                 B_1 = T.Buffer((2,), data=B.data)
-                B_1[threadIdx_y] = red_result_1[0]
+                B_1[threadIdx_y] = red_result_1[threadIdx_y]
 
 
 if __name__ == "__main__":
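As a sanity check on the mask constants in the updated expectations (illustrative, not part of the test file): with 4 groups of 4 warps the per-group masks `15 << (threadIdx_y * 4)` are disjoint and together cover all 16 warps of the block, and with 2 groups of 16 warps the masks `65535 << (threadIdx_y * 16)` cover all 32 bits of the mask word.

# Illustrative check that the per-group masks in the expected TIR are disjoint
# and jointly cover every warp of the block.
def masks(n_groups: int, n_warps: int):
    return [((1 << n_warps) - 1) << (g * n_warps) for g in range(n_groups)]

# TestMultiGroupMultiWarpReduction: 4 groups x 128 threads -> 4 warps per group.
assert masks(4, 4) == [15 << (g * 4) for g in range(4)]
assert sum(masks(4, 4)) == (1 << 16) - 1

# TestMultiGroupMultiWarpPredicatedReduction: 2 groups x 512 threads -> 16 warps per group.
assert masks(2, 16) == [65535 << (g * 16) for g in range(2)]
assert sum(masks(2, 16)) == (1 << 32) - 1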
