Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/op/gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,10 @@ std::pair<int, int> Gemm::ComputeWarpPartition(int block_size,
}

bool Gemm::CheckWGMMA() const {
if (B.scope() != "shared.dyn" && B.scope() != "shared") {
return false;
}

if (C->dtype == DataType::Float(16)) {
if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
return K % 16 == 0;
Expand Down Expand Up @@ -443,7 +447,9 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
B->dtype.bits(), trans_B ? 2 : 1);
results.Set(B, ABLayout);
} else {
ICHECK(0) << "WGMMA only support B in shared.";
auto fragment =
makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
results.Set(B, fragment->BindThreadRange(thread_range));
}
} else if (TargetIsCDNA(T.target)) {
auto fragment =
Expand Down Expand Up @@ -490,4 +496,4 @@ TIR_REGISTER_TL_OP(Gemm, gemm)
Integer(CallEffectKind::kOpaque));

} // namespace tl
} // namespace tvm
} // namespace tvm
13 changes: 13 additions & 0 deletions src/tl_templates/cuda/gemm_sm90.h
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,19 @@ TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
}
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum = false, int lda = 0, int ldb = 0,
int offset_a = 0, int offset_b = 0, bool use_wgmma = true,
int wg_wait = 0, typename A_type, typename B_type, typename C_type>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The wg_wait template parameter is unused in this function. Consider removing it to simplify the function signature. Since this function implements a non-WGMMA path, the warp-group wait parameter is not applicable here.

TL_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
static_assert(!use_wgmma, "wgmma doesn't support gemm_sr");
using MMA =
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, clear_accum, lda, ldb, offset_a,
offset_b, A_type, B_type, C_type>;
MMA::body_sr(pA, pB, accum);
}

template <int num_mma> TL_DEVICE void wait_wgmma() {
cute::warpgroup_wait<num_mma>();
}
Expand Down
19 changes: 9 additions & 10 deletions src/transform/warp_specialized_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -572,12 +572,11 @@ class WSCodeEmitter : public StmtMutator {
WSCodeEmitter(bool is_emitting_producer, IterVar thread_iv,
Map<Var, Buffer> buffer_data_to_buffer,
const WarpSpecializedRoleMarker &marker,
bool mbarrier_only = false)
bool mbarrier_only = false, bool only_has_wgmma = false)
: is_emitting_producer_(is_emitting_producer),
buffer_data_to_buffer_(buffer_data_to_buffer), marker_(marker),
thread_var_(thread_iv->var), mbarrier_only_(mbarrier_only) {}

bool onlyHasWgMMA() const { return only_has_wgmma_; }
thread_var_(thread_iv->var), mbarrier_only_(mbarrier_only),
only_has_wgmma_(only_has_wgmma) {}

bool hasSimtCopy() const { return has_simt_copy_; }

Expand Down Expand Up @@ -617,8 +616,6 @@ class WSCodeEmitter : public StmtMutator {

auto map = ExtractSyncPattern(op->seq);

only_has_wgmma_ = WgMMACollector::HasWgMMA(SeqStmt(op->seq));

/*
std::cout << "Print ExtractSyncPattern" << std::endl;
for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
Expand Down Expand Up @@ -1212,11 +1209,12 @@ class WarpSpecializedRewriter : public StmtExprMutator {
block_realize.CopyOnWrite()->block = block;
return block_realize;
}
only_has_wgmma_ = WgMMACollector::HasWgMMA(block->body);
WSCodeEmitter producer(true, thread_iv_, buffer_data_to_buffer_, marker);
WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, marker);
WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, marker,
false, only_has_wgmma_);
Stmt producer_code = producer(block->body);
Stmt consumer_code = consumer(block->body);
bool only_has_wgmma = consumer.onlyHasWgMMA();
PrimExpr consumer_thread_extent = thread_iv_->dom->extent;
PrimExpr producer_thread_extent = thread_iv_->dom->extent;
// Need one warp-group for bulk-copy only case
Expand Down Expand Up @@ -1259,8 +1257,8 @@ class WarpSpecializedRewriter : public StmtExprMutator {
PrimExpr arrive_thread_count =
producer.released_barrier_.count(i)
? (producer.hasSimtCopy() ? producer_thread_extent : 1)
: (only_has_wgmma ? FloorDiv(consumer_thread_extent, 128)
: consumer_thread_extent);
: (only_has_wgmma_ ? FloorDiv(consumer_thread_extent, 128)
: consumer_thread_extent);
barrier_num_threads.push_back(arrive_thread_count);
}

Expand Down Expand Up @@ -1289,6 +1287,7 @@ class WarpSpecializedRewriter : public StmtExprMutator {
bool disable_warp_specialized_ = false;
bool disable_shuffle_elect_ = false;
Array<IntImm> nreg_;
bool only_has_wgmma_ = false;
};

class WarpSpecializedDetector : public IRVisitorWithAnalyzer {
Expand Down
Loading