Skip to content

Commit e36740d

Browse files
committed
Refactor GEMM layout and Python integration for improved functionality
- Updated `gemm_layouts.cc` to correct the order of layout replication and repetition for transposed and non-transposed GEMM operations. - Enhanced `gemm_py.cc` to handle block realization more robustly, ensuring correct assignment of global symbols and block attributes. - Refactored `inject_pipeline.cc` to streamline buffer read/write region handling, improving clarity and maintainability. - Cleaned up test cases in `test_tilelang_tilelibrary_gemm.py` by removing unnecessary print statements and adjusting function calls for better test execution flow. These changes enhance the clarity, functionality, and robustness of GEMM operations and their testing in the TileLang framework.
1 parent 1ab46ef commit e36740d

File tree

5 files changed

+29
-28
lines changed

5 files changed

+29
-28
lines changed

src/layout/gemm_layouts.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,14 +205,14 @@ Fragment makeGemmFragmentB(const int block_m, const int block_n,
205205
ICHECK(block_k % 16 == 0);
206206
if (transposed) {
207207
auto base_layout = makeGemmFragment8x8()->Repeat({1, 2}, false, false);
208-
auto warp_layout = base_layout->Repeat({block_n / warp_n, 1}, true, false)->Replicate(block_m / warp_m);
208+
auto warp_layout = base_layout->Replicate(block_m / warp_m)->Repeat({block_n / warp_n, 1}, true, false);
209209
auto block_layout =
210210
warp_layout->Repeat({warp_n / 8, block_k / 16}, false, false);
211211
return block_layout;
212212
} else {
213213
auto base_layout =
214214
makeGemmFragment8x8Transposed()->Repeat({2, 1}, false, false);
215-
auto warp_layout = base_layout->Repeat({1, block_n / warp_n}, true)->Replicate(block_m / warp_m);
215+
auto warp_layout = base_layout->Replicate(block_m / warp_m)->Repeat({1, block_n / warp_n}, true);
216216
auto block_layout =
217217
warp_layout->Repeat({block_k / 16, warp_n / 8}, false, true);
218218
return block_layout;

src/op/gemm_py.cc

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -227,16 +227,25 @@ Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
227227
if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.lower")) {
228228
auto prim_func = Downcast<PrimFunc>(
229229
(*f)(GetRef<GemmPy>(this), T.target, T.thread_bounds, T.thread_var));
230-
BlockRealize block_realize = Downcast<BlockRealize>(prim_func->body);
231230
ICHECK(prim_func->attrs.defined());
232231
auto global_symbol = prim_func->attrs.GetAttr<String>("global_symbol");
233232
ICHECK(global_symbol.defined());
234-
auto block = block_realize->block;
235-
{
236-
BlockNode* n = block.CopyOnWrite();
237-
n->name_hint = global_symbol.value();
233+
if (prim_func->body.as<BlockRealizeNode>()) {
234+
BlockRealize block_realize = Downcast<BlockRealize>(prim_func->body);
235+
auto block = block_realize->block;
236+
{
237+
BlockNode* n = block.CopyOnWrite();
238+
n->name_hint = global_symbol.value();
239+
}
240+
return BlockRealize(block_realize->iter_values, block_realize->predicate, block);
238241
}
239-
return BlockRealize(block_realize->iter_values, block_realize->predicate, block);
242+
// warp with block realize node
243+
return BlockRealize(
244+
/*iter_values=*/Array<PrimExpr>(),
245+
/*predicate=*/const_true(),
246+
/*block=*/
247+
Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
248+
/*name_hint=*/global_symbol.value(), prim_func->body));
240249
} else {
241250
LOG(FATAL) << "No lower function found for gemm_py";
242251
}

src/transform/inject_pipeline.cc

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,6 @@ class PipelineRewriter : public StmtExprMutator {
248248
buffer_remap_.Set(buffer, RewriteAllocBuffer(buffer, num_versions));
249249
}
250250
}
251-
252251
ordered_stmts_.resize(pipeline_info_.size());
253252
for (const auto &[block, anno] : pipeline_info_) {
254253
ordered_stmts_.Set(anno.order, block);
@@ -676,11 +675,6 @@ class PipelineRewriter : public StmtExprMutator {
676675
new_block = Downcast<Block>(Substitute(
677676
new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));
678677

679-
Array<Array<BufferRegion>> access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
680-
BlockNode* n = new_block.CopyOnWrite();
681-
n->reads = access[0];
682-
n->writes = access[1];
683-
684678
if (pipeline_info_[block].async) {
685679
auto &local_state = async_states_local[stage];
686680
local_state.producer_head = normalized_access_index;
@@ -957,6 +951,11 @@ class PipelineInjector : private StmtExprMutator {
957951

958952
Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
959953

954+
Array<Array<BufferRegion>> access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
955+
BlockNode* n = block.CopyOnWrite();
956+
n->reads = access[0];
957+
n->writes = access[1];
958+
960959
for (const auto &buffer : op->alloc_buffers) {
961960
buffer_data_to_buffer_.erase(buffer->data);
962961
}

testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,6 @@ def run_gemm_ss(
9090
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
9191
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
9292
})
93-
print(kernel.get_kernel_source())
9493
profiler = kernel.get_profiler()
9594

9695
def ref_program(A, B):
@@ -209,7 +208,6 @@ def run_gemm_rs(
209208
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
210209
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
211210
})
212-
print(kernel.get_kernel_source())
213211
profiler = kernel.get_profiler()
214212

215213
def ref_program(A, B):
@@ -280,10 +278,7 @@ def main(
280278
else:
281279
T.copy(B[k * block_K, bx * block_N], B_shared)
282280
T.copy(B_shared, B_frag)
283-
# for i, j in T.Parallel(block_N, block_K):
284-
# B_frag[i, j] = B_shared[j, i]
285-
# T.gemm_v2(A_shared, B_frag, C_local, trans_A, trans_B)
286-
T.gemm(A_shared, B_frag, C_local, trans_A, trans_B)
281+
T.gemm_v2(A_shared, B_frag, C_local, trans_A, trans_B)
287282
T.copy(C_local, C[by * block_M, bx * block_N])
288283

289284
return main
@@ -327,7 +322,6 @@ def run_gemm_sr(
327322
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
328323
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
329324
})
330-
print(kernel.get_kernel_source())
331325
profiler = kernel.get_profiler()
332326

333327
def ref_program(A, B):
@@ -448,7 +442,6 @@ def run_gemm_rr(
448442
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
449443
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
450444
})
451-
print(kernel.get_kernel_source())
452445
profiler = kernel.get_profiler()
453446

454447
def ref_program(A, B):
@@ -478,9 +471,13 @@ def test_gemm_rr():
478471
# tilelang.testing.main()
479472
tilelang.disable_cache()
480473
# test_gemm_ss()
481-
run_gemm_sr(128, 128, 128, False, False, "float16", "float16", "float16", 128, 128, 32, 2)
474+
# test_gemm_sr()
475+
# test_gemm_rs()
476+
# test_gemm_rr()
477+
478+
# run_gemm_sr(128, 128, 128, False, False, "float16", "float16", "float16", 128, 128, 32, 2)
482479
# tilelang.testing.set_random_seed(42)
483-
# run_gemm_ss(128, 128, 128, False, True, "float16", "float16", "float16", 128, 128, 32, 1)
480+
run_gemm_ss(128, 128, 128, False, True, "float16", "float16", "float16", 128, 128, 32, 1)
484481
# print("gemm fp16 nt ss done")
485482
# exit()
486483

tilelang/engine/phase.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,11 +140,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
140140
mod = tilelang.transform.IfStmtBinding()(mod)
141141
mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod)
142142
mod = tilelang.transform.PipelinePlanning()(mod)
143-
print("after pipeline planning")
144-
print(mod)
145143
mod = tilelang.transform.InjectSoftwarePipeline()(mod)
146-
print("after inject software pipeline")
147-
print(mod)
148144
mod = tilelang.transform.MergeIfStmt()(mod)
149145
if allow_fence_proxy(target=target):
150146
# in hopper device, wgmma is an async proxy

0 commit comments

Comments
 (0)