Skip to content

Commit 03af3e7

Browse files
committed
Add kernel selection option for GEMM in environment settings
- Introduced the `TILELANG_USE_GEMM_V1` environment variable to allow users to select between the GEMM v1 and v2 implementations.
- Updated the `gemm` function to default to v2 but switch to v1 if the environment variable is set to a truthy value.
- Added a `use_gemm_v1` method to the `Environment` class to facilitate this selection based on the environment variable.
1 parent 7089b00 commit 03af3e7

25 files changed

+1318
-716
lines changed

examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import tilelang
22
import tilelang.language as T
33

4-
4+
tilelang.disable_cache()
55
# add decorator @tilelang.jit if you want to return a torch function
66
# @tilelang.jit
77
@tilelang.jit(out_idx=[2])
@@ -56,6 +56,8 @@ def main(M=16384, N=16384, K=16384):
5656
block_N = 128
5757
block_K = 64
5858
jit_kernel = matmul(M, N, K, block_M, block_N, block_K)
59+
60+
print(jit_kernel.get_kernel_source())
5961

6062
import torch
6163

src/op/gemm.cc

Lines changed: 315 additions & 226 deletions
Large diffs are not rendered by default.

src/op/gemm.h

Lines changed: 38 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@ class GemmWarpPolicyNode : public Object {
4040
.def_ro("n_warp", &GemmWarpPolicyNode::n_warp);
4141
}
4242

43-
std::pair<int, int> ComputeWarpPartition(int M, int N, int block_size,
43+
std::pair<int, int> computeWarpPartition(int M, int N, int block_size,
4444
Target target,
4545
GemmInst gemm_inst) const;
4646

@@ -84,47 +84,47 @@ class GemmWarpPolicy : public ObjectRef {
8484

8585
class GemmNode : public TileOperatorNode {
8686
public:
87-
bool CheckWGMMA() const;
88-
tir::Buffer A, B, C;
89-
// pointer to the A, B, C
90-
PrimExpr Aptr, Bptr, Cptr;
91-
bool trans_A, trans_B;
92-
int M, N, K;
93-
int stride_A, stride_B;
94-
int offset_A, offset_B;
95-
PrimExpr clear_accum = const_false();
87+
bool checkWgmma() const;
88+
tir::Buffer a_, b_, c_;
89+
// BufferRegion for A, B and C
90+
BufferRegion aRegion_, bRegion_, cRegion_;
91+
bool transA_, transB_;
92+
int m_, n_, k_;
93+
int strideA_, strideB_;
94+
int offsetA_, offsetB_;
95+
PrimExpr clearAccum_ = const_false();
9696
// k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
9797
// only will be enabled under cdna mfma instructions
98-
int kPack = 1;
99-
int wg_wait = 0;
100-
PrimExpr mbarptr;
101-
std::optional<tir::Buffer> mbar; // mbar is optional, only used for TCGEN5MMA
102-
Array<PrimExpr> C_coords;
103-
mutable GemmWarpPolicy policy;
98+
int kPack_ = 1;
99+
int wgWait_ = 0;
100+
PrimExpr mbarPtr_;
101+
std::optional<tir::Buffer> mbar_; // mbar is optional, only used for TCGEN5MMA
102+
Array<PrimExpr> cCoords_;
103+
mutable GemmWarpPolicy policy_;
104104
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Gemm", GemmNode, TileOperatorNode);
105105

106106
static void RegisterReflection() {
107107
namespace refl = tvm::ffi::reflection;
108108
refl::ObjectDef<GemmNode>()
109-
.def_ro("A", &GemmNode::A)
110-
.def_ro("B", &GemmNode::B)
111-
.def_ro("C", &GemmNode::C)
112-
.def_ro("Aptr", &GemmNode::Aptr)
113-
.def_ro("Bptr", &GemmNode::Bptr)
114-
.def_ro("Cptr", &GemmNode::Cptr)
115-
.def_ro("trans_A", &GemmNode::trans_A)
116-
.def_ro("trans_B", &GemmNode::trans_B)
117-
.def_ro("M", &GemmNode::M)
118-
.def_ro("N", &GemmNode::N)
119-
.def_ro("K", &GemmNode::K)
120-
.def_ro("stride_A", &GemmNode::stride_A)
121-
.def_ro("stride_B", &GemmNode::stride_B)
122-
.def_ro("offset_A", &GemmNode::offset_A)
123-
.def_ro("offset_B", &GemmNode::offset_B)
124-
.def_ro("clear_accum", &GemmNode::clear_accum)
125-
.def_ro("kPack", &GemmNode::kPack)
126-
.def_ro("wg_wait", &GemmNode::wg_wait)
127-
.def_ro("policy", &GemmNode::policy);
109+
.def_ro("a", &GemmNode::a_)
110+
.def_ro("b", &GemmNode::b_)
111+
.def_ro("c", &GemmNode::c_)
112+
.def_ro("aRegion", &GemmNode::aRegion_)
113+
.def_ro("bRegion", &GemmNode::bRegion_)
114+
.def_ro("cRegion", &GemmNode::cRegion_)
115+
.def_ro("transA", &GemmNode::transA_)
116+
.def_ro("transB", &GemmNode::transB_)
117+
.def_ro("m", &GemmNode::m_)
118+
.def_ro("n", &GemmNode::n_)
119+
.def_ro("k", &GemmNode::k_)
120+
.def_ro("strideA", &GemmNode::strideA_)
121+
.def_ro("strideB", &GemmNode::strideB_)
122+
.def_ro("offsetA", &GemmNode::offsetA_)
123+
.def_ro("offsetB", &GemmNode::offsetB_)
124+
.def_ro("clearAccum", &GemmNode::clearAccum_)
125+
.def_ro("kPack", &GemmNode::kPack_)
126+
.def_ro("wgWait", &GemmNode::wgWait_)
127+
.def_ro("policy", &GemmNode::policy_);
128128
}
129129

130130
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
@@ -134,9 +134,9 @@ class GemmNode : public TileOperatorNode {
134134
TileOperator Clone() const;
135135

136136
private:
137-
GemmInst GetGemmInst(int block_size, Target target) const;
138-
bool AllowTCGEN5MMA(Target target) const;
139-
bool AllowWGMMA(int block_size, Target target) const;
137+
GemmInst getGemmInst(int block_size, Target target) const;
138+
bool allowTcgen5Mma(Target target) const;
139+
bool allowWgmma(int block_size, Target target) const;
140140

141141
mutable bool completed_ = false;
142142
};

0 commit comments

Comments
 (0)