
Commit 91a7bb2

[TileOp] Introduce an experimental Python-defined T.gemm_v2 (#793)
* Refactor GEMM and GEMM-SP operations to enhance clarity and maintainability
  - Removed deprecated prime factorization functions from `gemm.cc` and `gemm_sp.cc`.
  - Introduced a new `GemmWarpPolicy` class to manage warp policy attributes and methods, improving encapsulation.
  - Updated reflection methods to include the new policy structure, ensuring proper registration and introspection capabilities.
  - Enhanced the `GetArchInt` function in `utils.cc` for better readability and type safety.
  - Added a new `gemm_v2` function in `gemm.py` for improved GEMM operation with additional parameters and checks.

* Refactor GEMM and frontend legalize operations for improved clarity and functionality
  - Updated `gemm_py.h` to include the correct header for GEMM operations.
  - Renamed the `FrontendLegalizer` class to `LetInliner` and updated related methods to reflect this change, enhancing code clarity.
  - Renamed the pass from `FrontendLegalize` to `LetInline` for better alignment with its purpose.
  - Updated test cases to use the new `gemm_v2` function and adjusted the testing framework for clearer output.
  - Removed the obsolete test file `test_tilelang_transform_frontend_legalize.py` to streamline the test suite.
  - Updated the `LowerAndLegalize` function to use the new `LetInline` pass, improving the overall transformation process.

* Enhance CUDA code generation and testing for GEMM operations
  - Added indentation printing in `codegen_cuda.cc` for improved assembly code formatting.
  - Updated `test_tilelang_tilelibrary_gemm.py` to include additional GEMM test cases and shared memory allocation with a specified scope.
  - Introduced new `matmul_sr` and `run_gemm_sr` functions for GEMM operations with shared and fragment memory layouts.
  - Refactored layout inference in `mma_macro_generator.py` to improve clarity and correctness in shared memory handling.
  - Extended `gemm/__init__.py` to support new GEMM operation combinations and layout inference logic.

  These changes improve the clarity, functionality, and testing coverage of GEMM operations in the TileLang framework.

* Refactor GEMM layout and testing for improved clarity and functionality
  - Updated `gemm_layouts.cc` to enhance the layout generation logic for transposed and non-transposed GEMM operations.
  - Renamed and modified functions in `test_tilelang_tilelibrary_gemm.py` to reflect changes in GEMM function signatures and improve test coverage.
  - Introduced new GEMM operation combinations in `gemm/__init__.py` to support additional layouts and configurations.
  - Enhanced layout inference in `mma_layout.py` and `mma_macro_generator.py` for better handling of shared memory layouts.

* Refactor GEMM layout and Python integration for improved functionality
  - Updated `gemm_layouts.cc` to correct the order of layout replication and repetition for transposed and non-transposed GEMM operations.
  - Enhanced `gemm_py.cc` to handle block realization more robustly, ensuring correct assignment of global symbols and block attributes.
  - Refactored `inject_pipeline.cc` to streamline buffer read/write region handling, improving clarity and maintainability.
  - Cleaned up test cases in `test_tilelang_tilelibrary_gemm.py` by removing unnecessary print statements and adjusting function calls for better test execution flow.

* Refactor GEMM layout and testing for improved clarity and functionality
  - Updated `gemm_layouts.cc` to enhance layout generation logic for transposed and non-transposed GEMM operations.
  - Improved block realization handling in `gemm_py.cc` for better assignment of global symbols.
  - Streamlined buffer read/write region handling in `inject_pipeline.cc` for clarity.
  - Enhanced test cases in `test_tilelang_tilelibrary_gemm.py` by adjusting function calls and adding new GEMM operation combinations.

* Added tfloat32 support.

* Lint fixes.

* Refactor shared memory allocation in GEMM tests
  - Removed the unnecessary scope specification in shared memory allocation for matrices A and B in `test_tilelang_tilelibrary_gemm.py`.
  - This change simplifies the allocation process and aligns with the updated GEMM function signatures.
1 parent 9fd6bb3 commit 91a7bb2

36 files changed: +2938 −247 lines changed

.clang-tidy

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ Checks: >
   -clang-analyzer-optin.cplusplus.UninitializedObject,
   -cppcoreguidelines-pro-type-static-cast-downcast,
   -performance-unnecessary-value-param,
+  -performance-enum-size,

 WarningsAsErrors: '*'

3rdparty/tvm

Submodule tvm updated from 1fc7578 to eddefbd

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -132,6 +132,7 @@ tilelang_file_glob(GLOB TILE_LANG_SRCS
 if(USE_CUDA)
   tilelang_file_glob(GLOB TILE_LANG_CUDA_SRCS
     src/runtime/*.cc
+    src/target/ptx.cc
     src/target/codegen_cuda.cc
     src/target/rt_mod_cuda.cc
   )

src/op/copy.cc

Lines changed: 0 additions & 1 deletion
@@ -402,7 +402,6 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
   PassContext pass_ctx = PassContext::Current();
   bool disable_tma_lower =
       pass_ctx->GetConfig<bool>(kDisableTMALower, false).value();
-
   auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma,
                                T.layout_map, T.analyzer, T.buffer_oob);
   if (copy_inst == CopyInst::kBulkLoad || copy_inst == CopyInst::kBulkStore) {

src/op/gemm.cc

Lines changed: 23 additions & 26 deletions
@@ -18,30 +18,6 @@ namespace tl {

 using namespace tir;

-/**
- * @brief Compute the prime factorization of an integer.
- *
- * Returns the prime factors of x in non-decreasing order by repeatedly dividing
- * out the smallest possible factor.
- *
- * @param x Integer to factorize. If x <= 1, an empty vector is returned.
- * @return std::vector<int> Prime factors of x (with multiplicity), in
- * non-decreasing order.
- */
-static std::vector<int> toPrimeFactors(int x) {
-  int i = 2;
-  std::vector<int> result;
-  while (x > 1) {
-    if (x % i == 0) {
-      x /= i;
-      result.push_back(i);
-    } else {
-      i++;
-    }
-  }
-  return result;
-}
-
 /**
  * @brief Construct a Gemm operator from serialized TL arguments and a buffer
  * map.
@@ -268,14 +244,20 @@ GemmWarpPolicyNode::ComputeWarpPartition(int M, int N, int block_size,
   int best_m = 1;
   int best_n = 1;
   float best_balance = std::numeric_limits<float>::max();
-
   // Try all possible combinations that satisfy the constraints
   for (int m = 1; m <= max_m_warps && m <= num_warps; m++) {
     int n = num_warps / m;

     // Calculate how balanced this partition is
     float m_per_warp = static_cast<float>(M) / (m * kMPerWarp);
     float n_per_warp = static_cast<float>(N) / (n * kNPerWarp);
+    // m_per_warp and n_per_warp must be at least 1
+    if (m_per_warp < 1 || n_per_warp < 1)
+      continue;
+    // m * n must equal num_warps
+    if (m * n != num_warps)
+      continue;
+
     float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio);

     if (balance < best_balance) {
@@ -290,7 +272,6 @@ GemmWarpPolicyNode::ComputeWarpPartition(int M, int N, int block_size,
   } else {
     ICHECK(0) << "Unknown GemmWarpPolicy";
   }
-
   // Store the computed values in the object's member variables
   this->m_warp = m_warp;
   this->n_warp = n_warp;
@@ -632,5 +613,21 @@ TIR_REGISTER_TL_OP(Gemm, gemm)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));

+TVM_REGISTER_OP("tl.GemmWarpPolicy")
+    .set_attr<TScriptPrinterName>("TScriptPrinterName", "GemmWarpPolicy");
+
+TVM_FFI_STATIC_INIT_BLOCK({
+  GemmNode::RegisterReflection();
+  GemmWarpPolicyNode::RegisterReflection();
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def("tl.GemmWarpPolicyComputeWarpPartition",
+                        [](GemmWarpPolicy policy, int M, int N, int block_size,
+                           Target target, bool is_wgmma) {
+                          policy->ComputeWarpPartition(M, N, block_size, target,
+                                                       is_wgmma);
+                          return;
+                        });
+});
+
 } // namespace tl
 } // namespace tvm
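The hunk above adds two feasibility guards to the warp-partition search, which picks the (m, n) warp split minimizing |m_per_warp / n_per_warp − ideal_ratio|. A minimal standalone Python sketch of the same selection loop follows; kMPerWarp=16, kNPerWarp=8, ideal_ratio = M/N, and the max_m_warps formula are illustrative assumptions here, not values read from gemm.cc:

```python
def compute_warp_partition(M, N, num_warps, k_m_per_warp=16, k_n_per_warp=8):
    """Sketch of GemmWarpPolicyNode::ComputeWarpPartition's balance search.

    Tile sizes (16x8 per warp) and ideal_ratio = M/N are assumptions for
    illustration; the C++ code derives them from the target and policy.
    """
    ideal_ratio = M / N
    max_m_warps = M // k_m_per_warp
    best, best_balance = (1, 1), float("inf")
    for m in range(1, min(max_m_warps, num_warps) + 1):
        n = num_warps // m
        # How much work each warp gets along M and N
        m_per_warp = M / (m * k_m_per_warp)
        n_per_warp = N / (n * k_n_per_warp)
        # Guards added by the diff: every warp must get at least one tile,
        # and the partition must use exactly num_warps warps.
        if m_per_warp < 1 or n_per_warp < 1:
            continue
        if m * n != num_warps:
            continue
        balance = abs(m_per_warp / n_per_warp - ideal_ratio)
        if balance < best_balance:
            best, best_balance = (m, n), balance
    return best
```

For a square 128x128 tile on 4 warps this picks the balanced 2x2 split, while a 256x64 tile on 8 warps gets a 2x4 split whose per-warp work ratio exactly matches M/N.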

src/op/gemm_py.cc

Lines changed: 279 additions & 0 deletions
@@ -0,0 +1,279 @@ (new file)

/*!
 * \file tl/op/gemm_py.cc
 * \brief Implementation of General Matrix Multiplication (GEMM) operators
 */

#include "gemm_py.h"

#include "builtin.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>

#include "../target/utils.h"
#include "tvm/ffi/string.h"

namespace tvm {
namespace tl {

using namespace tir;

/**
 * @brief Construct a GemmPy operator from serialized TL arguments and a
 * buffer map.
 *
 * This constructor deserializes operator parameters from `args` and resolves
 * buffer references via `vmap`, populating an internal GemmPyNode with:
 * - device pointers for A, B, C and their corresponding Buffer objects,
 * - transpose flags for A and B,
 * - matrix dimensions M, N, K,
 * - warp allocation policy and clear_accum flag,
 * - strides and memory offsets for A and B,
 * - optional kPack (must be 1 or 2) and optional wg_wait.
 *
 * The populated GemmPyNode is stored into the wrapper's internal `data_`.
 *
 * @param args Positional serialized arguments produced by the TL frontend;
 * the expected layout is:
 *   [Aptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool),
 *    M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool),
 *    stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int),
 *    (optional) kPack (Int), (optional) wg_wait (Int)]
 * @param vmap Mapping from access pointer vars to Buffer objects used to
 * resolve the Buffer corresponding to each pointer argument.
 *
 * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
 * fails with an ICHECK (runtime assertion). No other validation is performed
 * here.
 */
GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
  ObjectPtr<GemmPyNode> node = make_object<GemmPyNode>();

  node->Aptr = args[0];
  node->Bptr = args[1];
  node->Cptr = args[2];
  node->A = vmap[GetVarFromAccessPtr(node->Aptr)];
  node->B = vmap[GetVarFromAccessPtr(node->Bptr)];
  node->C = vmap[GetVarFromAccessPtr(node->Cptr)];
  node->trans_A = args[3].as<Bool>().value();
  node->trans_B = args[4].as<Bool>().value();
  node->M = args[5].as<IntImm>().value()->value;
  node->N = args[6].as<IntImm>().value()->value;
  node->K = args[7].as<IntImm>().value()->value;
  node->policy = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
  node->clear_accum = args[9].as<Bool>().value();
  node->stride_A = args[10].as<IntImm>().value()->value;
  node->stride_B = args[11].as<IntImm>().value()->value;
  node->offset_A = args[12].as<IntImm>().value()->value;
  node->offset_B = args[13].as<IntImm>().value()->value;
  if (args.size() > 14) {
    node->kPack = args[14].as<IntImm>().value()->value;
    if (node->kPack != 1 && node->kPack != 2) {
      ICHECK(false) << "kPack must be 1 or 2";
    }
  }
  if (args.size() > 15) {
    node->wg_wait = args[15].as<IntImm>().value()->value;
  }
  data_ = std::move(node);
}

/**
 * @brief Create a copy of this GemmPyNode as a TileOperator.
 *
 * Constructs a new GemmPyNode by copying the current node state and returns
 * it wrapped in a GemmPy TileOperator.
 *
 * @return TileOperator A GemmPy operator that owns a copy of this node.
 */
TileOperator GemmPyNode::Clone() const {
  auto op = make_object<GemmPyNode>(*this);
  return GemmPy(op);
}

GemmPyNode::GemmInst GemmPyNode::GetGemmInst(int block_size,
                                             Target target) const {
  int warp_size = TargetGetWarpSize(target);
  int num_warps = block_size / warp_size;
  bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
                     (num_warps % 4 == 0) && CheckWGMMA();
  if (allow_wgmma) {
    return GemmInst::kWGMMA;
  } else if (TargetIsCDNA(target)) {
    return GemmInst::kMFMA;
  } else if (TargetIsCuda(target)) {
    return GemmInst::kMMA;
  } else {
    ICHECK(0) << "Unsupported target for gemm: " << target->str();
  }
}

/**
 * @brief Checks whether WGMMA (warp-group MMA) can be used for this GEMM.
 *
 * Evaluates device-memory placement, data-type combinations, transpose flags,
 * and K divisibility constraints required for the Hopper WGMMA code path.
 *
 * The check returns true only when:
 * - B resides in shared memory ("shared" or "shared.dyn"); and
 * - (C, A, B) dtypes match one of the supported combinations below and K
 *   satisfies the required alignment; and
 * - for combinations that require specific orientations, A is not transposed
 *   and B is transposed.
 *
 * Supported combinations and constraints:
 * - C=float16:
 *   - A=float16, B=float16: K % 16 == 0
 *   - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and
 *     K % 32 == 0
 * - C=float32:
 *   - A=float16, B=float16: K % 16 == 0
 *   - A=bfloat16, B=bfloat16: K % 16 == 0
 *   - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0
 *   - Various float8 mixes: require (!trans_A && trans_B) and K % 32 == 0
 * - C=int32:
 *   - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B)
 *     and K % 32 == 0
 *
 * @return true if WGMMA is supported for the current buffers, dtypes, and
 * transpose/shape constraints; false otherwise.
 */
bool GemmPyNode::CheckWGMMA() const {
  if (B.scope() != "shared.dyn" && B.scope() != "shared") {
    return false;
  }

  if (C->dtype == DataType::Float(16)) {
    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
      return K % 16 == 0;
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else if (C->dtype == DataType::Float(32)) {
    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
      return K % 16 == 0;
    else if (A->dtype == DataType::BFloat(16) &&
             B->dtype == DataType::BFloat(16))
      return K % 16 == 0;
    else if (A->dtype == DataType::Float(32) && B->dtype == DataType::Float(32))
      return (!trans_A) && trans_B && K % 8 == 0;
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else if (C->dtype == DataType::Int(32)) {
    if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else {
    return false;
  }
}
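The dtype and K-divisibility rules in CheckWGMMA above reduce to a small decision table. Here is a hedged standalone Python sketch of that table; plain strings such as "fp16" and "e4m3" stand in for TVM DataType objects, and the shared-memory-scope check on B is omitted:

```python
def check_wgmma(c_dtype, a_dtype, b_dtype, K, trans_A, trans_B):
    """Approximation of GemmPyNode::CheckWGMMA's dtype/K constraints.

    Illustrative model only: dtypes are strings, and the requirement that
    B live in "shared"/"shared.dyn" scope is not modeled here.
    """
    fp8 = {"e4m3", "e5m2"}
    # Orientation required by several paths: A row-major, B transposed.
    nt = (not trans_A) and trans_B
    if c_dtype == "fp16":
        if a_dtype == b_dtype == "fp16":
            return K % 16 == 0
        if a_dtype in fp8 and b_dtype in fp8:
            return nt and K % 32 == 0
        return False
    if c_dtype == "fp32":
        if a_dtype == b_dtype == "fp16" or a_dtype == b_dtype == "bf16":
            return K % 16 == 0
        if a_dtype == b_dtype == "fp32":
            return nt and K % 8 == 0
        if a_dtype in fp8 and b_dtype in fp8:
            return nt and K % 32 == 0
        return False
    if c_dtype == "int32":
        ints = {"int8", "uint8"}
        if a_dtype in ints and b_dtype in ints:
            return nt and K % 32 == 0
        return False
    return False
```

Note that fp16-by-fp16 accumulation places no constraint on the transpose flags, while the fp32, fp8, and int8 paths all insist on the (!trans_A && trans_B) orientation.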
/**
 * @brief Parse and return the numeric GPU architecture from a Target's "arch"
 * attribute.
 *
 * Examines the target's "arch" string and, if it matches the pattern
 * "sm_<num>", returns <num> as an int. If the attribute is present but does
 * not match that pattern, returns 0.
 *
 * Preconditions: the target must have an "arch" attribute (this is checked
 * via ICHECK).
 *
 * @return int The parsed architecture number (e.g., 80 for "sm_80"), or 0 if
 * the arch string does not match "sm_<num>".
 */
static int GetArchInt(Target target) {
  int arch_int = 0;
  auto s = target->GetAttr<String>("arch");
  ICHECK(s.defined());
  std::string arch = s.value();
  if (arch.rfind("sm_", 0) == 0) {
    arch_int = std::stoi(arch.substr(3));
  } else {
    arch_int = 0;
  }
  return arch_int;
}
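GetArchInt's parsing amounts to reading the digits after an "sm_" prefix; since std::stoi stops at the first non-digit, suffixed arch strings such as "sm_90a" yield 90. A standalone Python equivalent for illustration (the function name and string-argument form are not from the file, which reads the arch from a Target):

```python
import re

def get_arch_int(arch: str) -> int:
    """Return N from an "sm_N..." arch string, or 0 when the prefix is
    absent, mirroring the C++ helper's std::stoi behavior (leading digits
    only, e.g. "sm_90a" -> 90)."""
    m = re.match(r"sm_(\d+)", arch)
    return int(m.group(1)) if m else 0
```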
Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  auto block_size = *as_const_int(T.thread_bounds->extent);
  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
  auto [warp_m, warp_n] = policy->ComputeWarpPartition(
      M, N, block_size, T.target, gemm_inst == GemmInst::kWGMMA);

  if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.lower")) {
    auto prim_func = Downcast<PrimFunc>(
        (*f)(GetRef<GemmPy>(this), T.target, T.thread_bounds, T.thread_var));
    ICHECK(prim_func->attrs.defined());
    auto global_symbol = prim_func->attrs.GetAttr<String>("global_symbol");
    ICHECK(global_symbol.defined());
    if (prim_func->body.as<BlockRealizeNode>()) {
      BlockRealize block_realize = Downcast<BlockRealize>(prim_func->body);
      auto block = block_realize->block;
      {
        BlockNode *n = block.CopyOnWrite();
        n->name_hint = global_symbol.value();
      }
      return BlockRealize(block_realize->iter_values, block_realize->predicate,
                          block);
    }
    // wrap with a block realize node
    return BlockRealize(
        /*iter_values=*/Array<PrimExpr>(),
        /*predicate=*/const_true(),
        /*block=*/
        Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
              /*name_hint=*/global_symbol.value(), prim_func->body));
  } else {
    LOG(FATAL) << "No lower function found for gemm_py";
  }
}

LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T,
                                  InferLevel level) const {
  if (completed_)
    return {};
  LayoutMap results;

  if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.infer_layout")) {
    results = Downcast<LayoutMap>(
        (*f)(GetRef<GemmPy>(this), T.target, T.thread_bounds));
  } else {
    LOG(FATAL) << "No infer layout function found for gemm_py";
  }

  completed_ = true;
  return results;
}

TIR_REGISTER_TL_OP(GemmPy, gemm_py)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TVM_FFI_STATIC_INIT_BLOCK({ GemmPyNode::RegisterReflection(); });
} // namespace tl
} // namespace tvm
