 namespace tvm {
 namespace tir {
 
-class UpdatePointerStorageScopeAllReduce final : public UpdatePointerStorageScope {
- public:
-  explicit UpdatePointerStorageScopeAllReduce(
-      const std::unordered_map<const VarNode*, String>& new_storage_scopes)
-      : UpdatePointerStorageScope(new_storage_scopes) {}
-
-  Stmt VisitStmt_(const AllocateNode* op) final {
-    auto remapped = Downcast<Var>(StmtExprMutator::VisitExpr(op->buffer_var));
-    auto new_scope = GetPtrStorageScope(remapped);
-    if (new_scope != GetPtrStorageScope(op->buffer_var)) {
-      Stmt body = StmtExprMutator::VisitStmt(op->body);
-      if (new_scope == "shared") {
-        // use volatile access to shared buffer.
-        body = AttrStmt(remapped, attr::volatile_scope, 1, body);
-      }
-      return Allocate(remapped, op->dtype, op->extents, op->condition, body, op->annotations);
-    }
-    return StmtExprMutator::VisitStmt_(op);
-  }
-};
-
 class ThreadAllreduceBuilder final : public StmtExprMutator {
  public:
   explicit ThreadAllreduceBuilder(const TargetNode* target)
@@ -98,11 +77,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
 
     if (auto it = alloc_remap_.find(node->buffer_var.get()); it != alloc_remap_.end()) {
       const AllocateNode* repl = it->second.as<AllocateNode>();
-      if (warp_allocs_.count(repl)) {
-        new_storage_scopes_[repl->buffer_var.get()] = "local";
-      } else {
-        new_storage_scopes_[repl->buffer_var.get()] = "shared";
-      }
       auto write_ptr = node.CopyOnWrite();
       write_ptr->buffer_var = repl->buffer_var;
       write_ptr->dtype = repl->dtype;
@@ -161,8 +135,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     return std::move(store);
   }
 
-  std::unordered_map<const VarNode*, String> new_storage_scopes_;
-
  private:
   // Thread entry
   struct ThreadEntry {
@@ -310,6 +282,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     // In the second stage we use the first 16 lanes of the first warp to reduce
     // the remaining elements, and this reduction can also be optimized by
     // shuffle_down warp-level primitives.
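+    // Constant index 0 of the reduction index dtype; reused below for
+    // buffer stores/loads and for the lane-0 predicates.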
+    PrimExpr zero_index = make_const(reduce_index->dtype, 0);
     if (IsWarpReduction(types, group_extent, reduce_extent, contiguous_reduce_extent)) {
       std::vector<PrimExpr> reduce_results;
       DataType mask_dtype = DataType::UInt(32);
@@ -322,6 +295,20 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
         }
         std::tie(reduce_results, new_alloc_bufs) = MakeWarpAllreduce(
             values, types, combiner, reduce_index, reduce_extent, group_index, mask, NullOpt, &seq);
+
+        // Broadcast the reduction result from lane 0 to all other lanes.
+        // This avoids emitting predicated stores, as all threads
+        // uniformly write the same result.
+        for (size_t i = 0; i < size; ++i) {
+          Buffer buf = Downcast<BufferLoad>(reduce_results[i])->buffer;
+          PrimExpr val = BufferLoad(buf, {zero_index});
+          ICHECK_EQ(val->dtype, types[i]);
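+          // Each group occupies `reduce_extent` consecutive lanes, so lane
+          // `reduce_extent * group_index` holds the group's reduced value.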
+          PrimExpr splat = WarpShuffle(builtin::tvm_warp_shuffle(), new_alloc_bufs.back(), val,
+                                       reduce_extent * group_index);
+          seq.push_back(BufferStore(buf, splat, {zero_index}));
+        }
       } else {
         int n_warps = reduce_extent / warp_size_;
         std::vector<Buffer> local_bufs;
@@ -352,7 +337,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
               /*value=*/reduce_results[i],
               /*indices=*/{group_index * n_warps + floordiv(reduce_index, warp_size_)}));
         }
-        PrimExpr cond = floormod(reduce_index, warp_size_) == make_const(reduce_index->dtype, 0);
+        PrimExpr cond = floormod(reduce_index, warp_size_) == zero_index;
         seq.push_back(IfThenElse(cond, SeqStmt::Flatten(write_staging_buf)));
         seq.push_back(SyncThread("shared"));
 
@@ -369,6 +354,27 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
             /*predicate=*/reduce_index < make_const(reduce_index->dtype, group_extent * n_warps),
             &seq);
         new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(), local_bufs.end());
+
+        // 5. Create shared memory buffer(s) of `group_extent` elements, storing
+        // the allreduce results so every thread can access them.
+        std::vector<Stmt> write_result;
+        write_result.reserve(size);
+        for (size_t i = 0; i < size; ++i) {
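+          // Keep the warp-stage result buffer alive; every buffer in
+          // `new_alloc_bufs` gets an Allocate built around `body` later.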
+          new_alloc_bufs.push_back(Downcast<BufferLoad>(reduce_results[i])->buffer);
+          Buffer broadcast_shared_buf = decl_buffer(
+              /*shape=*/{make_const(reduce_index->dtype, group_extent)},
+              /*dtype=*/buffers[i]->dtype, /*name=*/"red_result", /*storage_scope=*/"shared");
+          write_result.push_back(
+              BufferStore(broadcast_shared_buf, reduce_results[i], {zero_index}));
+          // Update `reduce_results` so it points to the value loaded from the shared buffer.
+          reduce_results[i] = BufferLoad(broadcast_shared_buf, {zero_index});
+        }
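+        // Only the first reduction lane of each group writes the result;
+        // the following barrier makes it visible to the whole block.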
+        seq.push_back(IfThenElse(reduce_index == zero_index, SeqStmt::Flatten(write_result)));
+        seq.push_back(SyncThread("shared"));
       }
 
       // Write back allreduce results and update existing allocations.
@@ -379,12 +381,12 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
         ICHECK_EQ(reduce_results[i]->dtype, types[i]);
         load_remap_[buffers[i]->data.get()] = reduce_results[i];
 
-        Array<PrimExpr> extents{PrimExpr(1)};
-        auto node = Allocate(buf->data, types[i], extents, pred, Evaluate(0));
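+        // The new allocation mirrors the remapped buffer's shape; its data
+        // var already carries the storage scope, so no later rewrite is needed.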
+        auto node = Allocate(buf->data, types[i], buf->shape, pred, Evaluate(0));
         alloc_remap_[buffers[i]->data.get()] = node;
         var_remap_[buffers[i]->data.get()] = buf->data;
         buf_remap_[buffers[i].get()] = buf;
-        warp_allocs_.insert(node.get());
       }
     } else {
       std::vector<Buffer> shared_bufs(size);
@@ -400,7 +400,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       // previous iteration on the same buffer.
       seq.emplace_back(SyncThread("shared"));
       for (size_t idx = 0; idx < size; ++idx) {
-        shared_bufs[idx] = decl_buffer({1}, types[idx], "red_buf" + std::to_string(idx));
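+        // Declare the reduction staging buffer in shared scope up front;
+        // the old post-pass storage-scope rewrite is no longer needed.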
+        shared_bufs[idx] = decl_buffer({1}, types[idx], "red_buf" + std::to_string(idx), "shared");
         seq.emplace_back(BufferStore(shared_bufs[idx], values[idx],
                                      {BufIndex(reduce_index, group_index, reduce_extent)}));
       }
@@ -426,9 +426,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     Stmt body = SeqStmt::Flatten(seq);
     for (Buffer buf : new_alloc_bufs) {
       body = Allocate(buf->data, buf->dtype, buf->shape, const_true(buf->dtype.lanes()), body);
-      if (buf.scope() != "shared") {
-        new_storage_scopes_[buf->data.get()] = "local";
-      }
     }
 
     return body;
@@ -457,12 +454,15 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     std::vector<Stmt> load_values;
     load_values.reserve(n_buffers);
     for (int idx = 0; idx < n_buffers; ++idx) {
-      shared_bufs.push_back(decl_buffer(shape, dtypes[idx], "red_buf" + std::to_string(idx)));
+      shared_bufs.push_back(
+          decl_buffer(shape, dtypes[idx], "red_buf" + std::to_string(idx), "local"));
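+      // "local" scope: each lane keeps its running partial value in
+      // thread-private storage between the warp shuffle steps.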
       load_values.push_back(BufferStore(shared_bufs[idx], src_values[idx], zero_indices));
 
       // Uses a local variable to store the shuffled data. Later
       // on, an allocation will be built for this local variable.
-      local_bufs.push_back(decl_buffer(shape, dtypes[idx], "t" + std::to_string(idx)));
+      local_bufs.push_back(decl_buffer(shape, dtypes[idx], "t" + std::to_string(idx), "local"));
     }
 
     if (predicate.defined()) {
@@ -474,7 +472,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     // The mask for this reducer, as this reducer may sit inside
     // a divergent control flow. Here it uses a variable to cache the current
     // active channels.
-    Buffer mask_buffer = decl_buffer(shape, mask->dtype, "mask");
+    Buffer mask_buffer = decl_buffer(shape, mask->dtype, "mask", "local");
     {
       seq->emplace_back(BufferStore(mask_buffer, mask, zero_indices));
       // Push the buffer description. Later this will have an
@@ -543,18 +541,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       }
     }
 
-    // Broadcast the reduction result from lane 0 to all other lanes.
-    // This avoids to emit predicated stores, as all threads are
-    // uniformly writing the same result.
-    for (int i = 0; i < n_buffers; ++i) {
-      Buffer buf = shared_bufs[i];
-      PrimExpr val = BufferLoad(buf, zero_indices);
-      ICHECK_EQ(val->dtype, dtypes[i]);
-      PrimExpr splat =
-          WarpShuffle(builtin::tvm_warp_shuffle(), mask_buffer, val, reduce_extent * group_index);
-      seq->push_back(BufferStore(buf, splat, zero_indices));
-    }
-
     std::vector<PrimExpr> reduce_results;
     reduce_results.reserve(n_buffers);
     for (int i = 0; i < n_buffers; ++i) {
@@ -791,8 +777,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   std::unordered_map<const VarNode*, Var> var_remap_;
   // Buffer remap
   std::unordered_map<const BufferNode*, Buffer> buf_remap_;
-  // Allocate from warp reductions
-  std::unordered_set<const void*> warp_allocs_;
   // Internal analyzer
   arith::Analyzer analyzer_;
 };
@@ -806,9 +790,7 @@ Pass LowerThreadAllreduce() {
     ICHECK(target.defined()) << "LowerThreadAllreduce: Require the target attribute";
     const TargetNode* target_node = target.as<TargetNode>();
     ThreadAllreduceBuilder thread_all_reduce(target_node);
-    auto reduce_body = thread_all_reduce(n->body);
-    n->body =
-        UpdatePointerStorageScopeAllReduce(thread_all_reduce.new_storage_scopes_)(reduce_body);
+    n->body = thread_all_reduce(n->body);
     return f;
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {});