@@ -572,12 +572,11 @@ class WSCodeEmitter : public StmtMutator {
572572 WSCodeEmitter (bool is_emitting_producer, IterVar thread_iv,
573573 Map<Var, Buffer> buffer_data_to_buffer,
574574 const WarpSpecializedRoleMarker &marker,
575- bool mbarrier_only = false )
575+ bool mbarrier_only = false , bool only_has_wgmma = false )
576576 : is_emitting_producer_(is_emitting_producer),
577577 buffer_data_to_buffer_ (buffer_data_to_buffer), marker_(marker),
578- thread_var_(thread_iv->var), mbarrier_only_(mbarrier_only) {}
579-
580- bool onlyHasWgMMA () const { return only_has_wgmma_; }
578+ thread_var_(thread_iv->var), mbarrier_only_(mbarrier_only),
579+ only_has_wgmma_(only_has_wgmma) {}
581580
582581 bool hasSimtCopy () const { return has_simt_copy_; }
583582
@@ -617,8 +616,6 @@ class WSCodeEmitter : public StmtMutator {
617616
618617 auto map = ExtractSyncPattern (op->seq );
619618
620- only_has_wgmma_ = WgMMACollector::HasWgMMA (SeqStmt (op->seq ));
621-
622619 /*
623620 std::cout << "Print ExtractSyncPattern" << std::endl;
624621 for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
@@ -1212,11 +1209,12 @@ class WarpSpecializedRewriter : public StmtExprMutator {
12121209 block_realize.CopyOnWrite ()->block = block;
12131210 return block_realize;
12141211 }
1212+ only_has_wgmma_ = WgMMACollector::HasWgMMA (block->body );
12151213 WSCodeEmitter producer (true , thread_iv_, buffer_data_to_buffer_, marker);
1216- WSCodeEmitter consumer (false , thread_iv_, buffer_data_to_buffer_, marker);
1214+ WSCodeEmitter consumer (false , thread_iv_, buffer_data_to_buffer_, marker,
1215+ false , only_has_wgmma_);
12171216 Stmt producer_code = producer (block->body );
12181217 Stmt consumer_code = consumer (block->body );
1219- bool only_has_wgmma = consumer.onlyHasWgMMA ();
12201218 PrimExpr consumer_thread_extent = thread_iv_->dom ->extent ;
12211219 PrimExpr producer_thread_extent = thread_iv_->dom ->extent ;
12221220 // Need one warp-group for bulk-copy only case
@@ -1259,8 +1257,8 @@ class WarpSpecializedRewriter : public StmtExprMutator {
12591257 PrimExpr arrive_thread_count =
12601258 producer.released_barrier_ .count (i)
12611259 ? (producer.hasSimtCopy () ? producer_thread_extent : 1 )
1262- : (only_has_wgmma ? FloorDiv (consumer_thread_extent, 128 )
1263- : consumer_thread_extent);
1260+ : (only_has_wgmma_ ? FloorDiv (consumer_thread_extent, 128 )
1261+ : consumer_thread_extent);
12641262 barrier_num_threads.push_back (arrive_thread_count);
12651263 }
12661264
@@ -1289,6 +1287,7 @@ class WarpSpecializedRewriter : public StmtExprMutator {
12891287 bool disable_warp_specialized_ = false ;
12901288 bool disable_shuffle_elect_ = false ;
12911289 Array<IntImm> nreg_;
1290+ bool only_has_wgmma_ = false ;
12921291};
12931292
12941293class WarpSpecializedDetector : public IRVisitorWithAnalyzer {
0 commit comments