Commit 38062eb

[TIR] Allreduce broadcast result to each thread in multi-warp case
PR #15327 introduced warp-level primitive support for multi-warp allreduce. However, because of the two-stage shuffle-down reduction it uses in multi-warp scenarios, PR #15327 did not broadcast the allreduce result to every reduction thread. This behavior does not match the semantics of allreduce and is not ideal for many use cases. This PR therefore completes the implementation by inserting a stage that writes the reduction results to shared memory, so that every reduction thread across all the reduction warps can access them. The shared-memory write-back stage is only inserted in the multi-warp case; in single-warp allreduce, a `shfl_sync` broadcasts the reduction results across the reduction threads. Since warp-level primitives cannot broadcast a value across warps in the multi-warp setting, shared memory is the only option. Numerical correctness was verified locally.
Parent: 03fecba

File tree: 2 files changed, +70 −76 lines

src/tir/transforms/lower_thread_allreduce.cc

Lines changed: 38 additions & 56 deletions
@@ -38,27 +38,6 @@
 namespace tvm {
 namespace tir {
 
-class UpdatePointerStorageScopeAllReduce final : public UpdatePointerStorageScope {
- public:
-  explicit UpdatePointerStorageScopeAllReduce(
-      const std::unordered_map<const VarNode*, String>& new_storage_scopes)
-      : UpdatePointerStorageScope(new_storage_scopes) {}
-
-  Stmt VisitStmt_(const AllocateNode* op) final {
-    auto remapped = Downcast<Var>(StmtExprMutator::VisitExpr(op->buffer_var));
-    auto new_scope = GetPtrStorageScope(remapped);
-    if (new_scope != GetPtrStorageScope(op->buffer_var)) {
-      Stmt body = StmtExprMutator::VisitStmt(op->body);
-      if (new_scope == "shared") {
-        // use volatile access to shared buffer.
-        body = AttrStmt(remapped, attr::volatile_scope, 1, body);
-      }
-      return Allocate(remapped, op->dtype, op->extents, op->condition, body, op->annotations);
-    }
-    return StmtExprMutator::VisitStmt_(op);
-  }
-};
-
 class ThreadAllreduceBuilder final : public StmtExprMutator {
  public:
   explicit ThreadAllreduceBuilder(const TargetNode* target)
@@ -98,11 +77,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
 
     if (auto it = alloc_remap_.find(node->buffer_var.get()); it != alloc_remap_.end()) {
       const AllocateNode* repl = it->second.as<AllocateNode>();
-      if (warp_allocs_.count(repl)) {
-        new_storage_scopes_[repl->buffer_var.get()] = "local";
-      } else {
-        new_storage_scopes_[repl->buffer_var.get()] = "shared";
-      }
       auto write_ptr = node.CopyOnWrite();
       write_ptr->buffer_var = repl->buffer_var;
       write_ptr->dtype = repl->dtype;
@@ -161,8 +135,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     return std::move(store);
   }
 
-  std::unordered_map<const VarNode*, String> new_storage_scopes_;
-
  private:
   // Thread entry
   struct ThreadEntry {
@@ -310,6 +282,7 @@
     // In the second stage we use the first 16 lanes of the first warp to reduce
     // the remaining elements, and this reduction can also be optimized by
     // shuffle_down warp-level primitives.
+    PrimExpr zero_index = make_const(reduce_index->dtype, 0);
     if (IsWarpReduction(types, group_extent, reduce_extent, contiguous_reduce_extent)) {
       std::vector<PrimExpr> reduce_results;
       DataType mask_dtype = DataType::UInt(32);
@@ -322,6 +295,18 @@
       }
       std::tie(reduce_results, new_alloc_bufs) = MakeWarpAllreduce(
          values, types, combiner, reduce_index, reduce_extent, group_index, mask, NullOpt, &seq);
+
+      // Broadcast the reduction result from lane 0 to all other lanes.
+      // This avoids emitting predicated stores, as all threads are
+      // uniformly writing the same result.
+      for (size_t i = 0; i < size; ++i) {
+        Buffer buf = Downcast<BufferLoad>(reduce_results[i])->buffer;
+        PrimExpr val = BufferLoad(buf, {zero_index});
+        ICHECK_EQ(val->dtype, types[i]);
+        PrimExpr splat = WarpShuffle(builtin::tvm_warp_shuffle(), new_alloc_bufs.back(), val,
+                                     reduce_extent * group_index);
+        seq.push_back(BufferStore(buf, splat, {zero_index}));
+      }
     } else {
       int n_warps = reduce_extent / warp_size_;
       std::vector<Buffer> local_bufs;
@@ -352,7 +337,7 @@
             /*value=*/reduce_results[i],
             /*indices=*/{group_index * n_warps + floordiv(reduce_index, warp_size_)}));
       }
-      PrimExpr cond = floormod(reduce_index, warp_size_) == make_const(reduce_index->dtype, 0);
+      PrimExpr cond = floormod(reduce_index, warp_size_) == zero_index;
       seq.push_back(IfThenElse(cond, SeqStmt::Flatten(write_staging_buf)));
       seq.push_back(SyncThread("shared"));
 
@@ -369,6 +354,23 @@
           /*predicate=*/reduce_index < make_const(reduce_index->dtype, group_extent * n_warps),
           &seq);
       new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(), local_bufs.end());
+
+      // 5. Create shared memory buffer(s) of `group_extent` elements, storing
+      // the allreduce results so that each thread can access them.
+      std::vector<Stmt> write_result;
+      write_result.reserve(size);
+      for (size_t i = 0; i < size; ++i) {
+        new_alloc_bufs.push_back(Downcast<BufferLoad>(reduce_results[i])->buffer);
+        Buffer broadcast_shared_buf = decl_buffer(
+            /*shape=*/{make_const(reduce_index->dtype, group_extent)},
+            /*dtype=*/buffers[i]->dtype, /*name=*/"red_result", /*storage_scope=*/"shared");
+        write_result.push_back(
+            BufferStore(broadcast_shared_buf, reduce_results[i], {zero_index}));
+        // Update `reduce_results`, pointing to the value loaded from the shared memory buffer.
+        reduce_results[i] = BufferLoad(broadcast_shared_buf, {zero_index});
+      }
+      seq.push_back(IfThenElse(reduce_index == zero_index, SeqStmt::Flatten(write_result)));
+      seq.push_back(SyncThread("shared"));
     }
 
     // Write back allreduce results and update existing allocations.
@@ -379,12 +381,10 @@
       ICHECK_EQ(reduce_results[i]->dtype, types[i]);
       load_remap_[buffers[i]->data.get()] = reduce_results[i];
 
-      Array<PrimExpr> extents{PrimExpr(1)};
-      auto node = Allocate(buf->data, types[i], extents, pred, Evaluate(0));
+      auto node = Allocate(buf->data, types[i], buf->shape, pred, Evaluate(0));
       alloc_remap_[buffers[i]->data.get()] = node;
       var_remap_[buffers[i]->data.get()] = buf->data;
       buf_remap_[buffers[i].get()] = buf;
-      warp_allocs_.insert(node.get());
     }
   } else {
     std::vector<Buffer> shared_bufs(size);
@@ -400,7 +400,7 @@
       // previous iteration on the same buffer.
       seq.emplace_back(SyncThread("shared"));
       for (size_t idx = 0; idx < size; ++idx) {
-        shared_bufs[idx] = decl_buffer({1}, types[idx], "red_buf" + std::to_string(idx));
+        shared_bufs[idx] = decl_buffer({1}, types[idx], "red_buf" + std::to_string(idx), "shared");
         seq.emplace_back(BufferStore(shared_bufs[idx], values[idx],
                                      {BufIndex(reduce_index, group_index, reduce_extent)}));
       }
@@ -426,9 +426,6 @@
     Stmt body = SeqStmt::Flatten(seq);
     for (Buffer buf : new_alloc_bufs) {
       body = Allocate(buf->data, buf->dtype, buf->shape, const_true(buf->dtype.lanes()), body);
-      if (buf.scope() != "shared") {
-        new_storage_scopes_[buf->data.get()] = "local";
-      }
     }
 
     return body;
@@ -457,12 +454,13 @@
     std::vector<Stmt> load_values;
     load_values.reserve(n_buffers);
     for (int idx = 0; idx < n_buffers; ++idx) {
-      shared_bufs.push_back(decl_buffer(shape, dtypes[idx], "red_buf" + std::to_string(idx)));
+      shared_bufs.push_back(
+          decl_buffer(shape, dtypes[idx], "red_buf" + std::to_string(idx), "local"));
       load_values.push_back(BufferStore(shared_bufs[idx], src_values[idx], zero_indices));
 
       // Uses a local variable to store the shuffled data. Later
       // on, an allocation will be built for this local variable.
-      local_bufs.push_back(decl_buffer(shape, dtypes[idx], "t" + std::to_string(idx)));
+      local_bufs.push_back(decl_buffer(shape, dtypes[idx], "t" + std::to_string(idx), "local"));
     }
 
     if (predicate.defined()) {
@@ -474,7 +472,7 @@
     // The mask for this reducer, as this reducer may sit inside
     // a divergent control flow. Here it uses a variable to cache the current
     // active channels.
-    Buffer mask_buffer = decl_buffer(shape, mask->dtype, "mask");
+    Buffer mask_buffer = decl_buffer(shape, mask->dtype, "mask", "local");
     {
       seq->emplace_back(BufferStore(mask_buffer, mask, zero_indices));
       // Push the buffer description. Later this will have an
@@ -543,18 +541,6 @@
       }
     }
 
-    // Broadcast the reduction result from lane 0 to all other lanes.
-    // This avoids to emit predicated stores, as all threads are
-    // uniformly writing the same result.
-    for (int i = 0; i < n_buffers; ++i) {
-      Buffer buf = shared_bufs[i];
-      PrimExpr val = BufferLoad(buf, zero_indices);
-      ICHECK_EQ(val->dtype, dtypes[i]);
-      PrimExpr splat =
-          WarpShuffle(builtin::tvm_warp_shuffle(), mask_buffer, val, reduce_extent * group_index);
-      seq->push_back(BufferStore(buf, splat, zero_indices));
-    }
-
     std::vector<PrimExpr> reduce_results;
     reduce_results.reserve(n_buffers);
     for (int i = 0; i < n_buffers; ++i) {
@@ -791,8 +777,6 @@
   std::unordered_map<const VarNode*, Var> var_remap_;
   // Buffer remap
   std::unordered_map<const BufferNode*, Buffer> buf_remap_;
-  // Allocate from warp reductions
-  std::unordered_set<const void*> warp_allocs_;
   // Internal analyzer
   arith::Analyzer analyzer_;
 };
@@ -806,9 +790,7 @@ Pass LowerThreadAllreduce() {
     ICHECK(target.defined()) << "LowerThreadAllreduce: Require the target attribute";
     const TargetNode* target_node = target.as<TargetNode>();
     ThreadAllreduceBuilder thread_all_reduce(target_node);
-    auto reduce_body = thread_all_reduce(n->body);
-    n->body =
-        UpdatePointerStorageScopeAllReduce(thread_all_reduce.new_storage_scopes_)(reduce_body);
+    n->body = thread_all_reduce(n->body);
     return f;
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {});
