Skip to content

Commit f4a828f

Browse files
xwhzz and LeiWang1999 authored
[Enhancement][Bugfix] Fix bug in warp specialized pass and add gemm_sr fallback support for Hopper (#712)
* bug fix and support gemm_sr fallback for hopper
* Update gemm.cc

---------

Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
1 parent 1b308ba commit f4a828f

File tree

3 files changed

+30
-12
lines changed

3 files changed

+30
-12
lines changed

src/op/gemm.cc

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -241,6 +241,10 @@ std::pair<int, int> Gemm::ComputeWarpPartition(int block_size,
241241
}
242242

243243
bool Gemm::CheckWGMMA() const {
244+
if (B.scope() != "shared.dyn" && B.scope() != "shared") {
245+
return false;
246+
}
247+
244248
if (C->dtype == DataType::Float(16)) {
245249
if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
246250
return K % 16 == 0;
@@ -443,7 +447,9 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
443447
B->dtype.bits(), trans_B ? 2 : 1);
444448
results.Set(B, ABLayout);
445449
} else {
446-
ICHECK(0) << "WGMMA only support B in shared.";
450+
auto fragment =
451+
makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
452+
results.Set(B, fragment->BindThreadRange(thread_range));
447453
}
448454
} else if (TargetIsCDNA(T.target)) {
449455
auto fragment =
@@ -490,4 +496,4 @@ TIR_REGISTER_TL_OP(Gemm, gemm)
490496
Integer(CallEffectKind::kOpaque));
491497

492498
} // namespace tl
493-
} // namespace tvm
499+
} // namespace tvm

src/tl_templates/cuda/gemm_sm90.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -624,6 +624,19 @@ TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
624624
}
625625
}
626626

627+
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
628+
bool trans_B, bool clear_accum = false, int lda = 0, int ldb = 0,
629+
int offset_a = 0, int offset_b = 0, bool use_wgmma = true,
630+
int wg_wait = 0, typename A_type, typename B_type, typename C_type>
631+
TL_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
632+
static_assert(!use_wgmma, "wgmma doesn't support gemm_sr");
633+
using MMA =
634+
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
635+
trans_B, clear_accum, lda, ldb, offset_a,
636+
offset_b, A_type, B_type, C_type>;
637+
MMA::body_sr(pA, pB, accum);
638+
}
639+
627640
template <int num_mma> TL_DEVICE void wait_wgmma() {
628641
cute::warpgroup_wait<num_mma>();
629642
}

src/transform/warp_specialized_rewriter.cc

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -572,12 +572,11 @@ class WSCodeEmitter : public StmtMutator {
572572
WSCodeEmitter(bool is_emitting_producer, IterVar thread_iv,
573573
Map<Var, Buffer> buffer_data_to_buffer,
574574
const WarpSpecializedRoleMarker &marker,
575-
bool mbarrier_only = false)
575+
bool mbarrier_only = false, bool only_has_wgmma = false)
576576
: is_emitting_producer_(is_emitting_producer),
577577
buffer_data_to_buffer_(buffer_data_to_buffer), marker_(marker),
578-
thread_var_(thread_iv->var), mbarrier_only_(mbarrier_only) {}
579-
580-
bool onlyHasWgMMA() const { return only_has_wgmma_; }
578+
thread_var_(thread_iv->var), mbarrier_only_(mbarrier_only),
579+
only_has_wgmma_(only_has_wgmma) {}
581580

582581
bool hasSimtCopy() const { return has_simt_copy_; }
583582

@@ -617,8 +616,6 @@ class WSCodeEmitter : public StmtMutator {
617616

618617
auto map = ExtractSyncPattern(op->seq);
619618

620-
only_has_wgmma_ = WgMMACollector::HasWgMMA(SeqStmt(op->seq));
621-
622619
/*
623620
std::cout << "Print ExtractSyncPattern" << std::endl;
624621
for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
@@ -1212,11 +1209,12 @@ class WarpSpecializedRewriter : public StmtExprMutator {
12121209
block_realize.CopyOnWrite()->block = block;
12131210
return block_realize;
12141211
}
1212+
only_has_wgmma_ = WgMMACollector::HasWgMMA(block->body);
12151213
WSCodeEmitter producer(true, thread_iv_, buffer_data_to_buffer_, marker);
1216-
WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, marker);
1214+
WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, marker,
1215+
false, only_has_wgmma_);
12171216
Stmt producer_code = producer(block->body);
12181217
Stmt consumer_code = consumer(block->body);
1219-
bool only_has_wgmma = consumer.onlyHasWgMMA();
12201218
PrimExpr consumer_thread_extent = thread_iv_->dom->extent;
12211219
PrimExpr producer_thread_extent = thread_iv_->dom->extent;
12221220
// Need one warp-group for bulk-copy only case
@@ -1259,8 +1257,8 @@ class WarpSpecializedRewriter : public StmtExprMutator {
12591257
PrimExpr arrive_thread_count =
12601258
producer.released_barrier_.count(i)
12611259
? (producer.hasSimtCopy() ? producer_thread_extent : 1)
1262-
: (only_has_wgmma ? FloorDiv(consumer_thread_extent, 128)
1263-
: consumer_thread_extent);
1260+
: (only_has_wgmma_ ? FloorDiv(consumer_thread_extent, 128)
1261+
: consumer_thread_extent);
12641262
barrier_num_threads.push_back(arrive_thread_count);
12651263
}
12661264

@@ -1289,6 +1287,7 @@ class WarpSpecializedRewriter : public StmtExprMutator {
12891287
bool disable_warp_specialized_ = false;
12901288
bool disable_shuffle_elect_ = false;
12911289
Array<IntImm> nreg_;
1290+
bool only_has_wgmma_ = false;
12921291
};
12931292

12941293
class WarpSpecializedDetector : public IRVisitorWithAnalyzer {

0 commit comments

Comments
 (0)