Commit a3f2564

Refactor GEMM layout and testing for improved clarity and functionality
- Updated `gemm_layouts.cc` to enhance layout generation logic for transposed and non-transposed GEMM operations.
- Improved block realization handling in `gemm_py.cc` for better assignment of global symbols.
- Streamlined buffer read/write region handling in `inject_pipeline.cc` for clarity.
- Enhanced test cases in `test_tilelang_tilelibrary_gemm.py` by adjusting function calls and adding new GEMM operation combinations.

These changes improve the clarity, functionality, and robustness of GEMM operations and their testing in the TileLang framework.
1 parent e36740d commit a3f2564

15 files changed: +599 additions, -362 deletions

src/layout/gemm_layouts.cc

Lines changed: 4 additions & 2 deletions

@@ -205,14 +205,16 @@ Fragment makeGemmFragmentB(const int block_m, const int block_n,
   ICHECK(block_k % 16 == 0);
   if (transposed) {
     auto base_layout = makeGemmFragment8x8()->Repeat({1, 2}, false, false);
-    auto warp_layout = base_layout->Replicate(block_m / warp_m)->Repeat({block_n / warp_n, 1}, true, false);
+    auto warp_layout = base_layout->Replicate(block_m / warp_m)
+                           ->Repeat({block_n / warp_n, 1}, true, false);
     auto block_layout =
         warp_layout->Repeat({warp_n / 8, block_k / 16}, false, false);
     return block_layout;
   } else {
     auto base_layout =
         makeGemmFragment8x8Transposed()->Repeat({2, 1}, false, false);
-    auto warp_layout = base_layout->Replicate(block_m / warp_m)->Repeat({1, block_n / warp_n}, true);
+    auto warp_layout = base_layout->Replicate(block_m / warp_m)
+                           ->Repeat({1, block_n / warp_n}, true);
     auto block_layout =
         warp_layout->Repeat({block_k / 16, warp_n / 8}, false, true);
     return block_layout;
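
For intuition, the Repeat factors above compose back to the full B tile. The following is a minimal Python sketch of that shape arithmetic, not the Fragment API itself: it assumes Repeat({a, b}, ...) tiles the fragment's (rows, cols) shape by (a, b) and that Replicate spreads the fragment across warps without changing its shape.

# Minimal sketch of the shape arithmetic implied by the Repeat factors above.
# Assumptions (not the real Fragment API): Repeat multiplies the 2-D shape by
# its factors; Replicate leaves the per-fragment shape unchanged.

def repeat(shape, factors):
    return (shape[0] * factors[0], shape[1] * factors[1])

def gemm_fragment_b_shape(block_n, block_k, warp_n, transposed):
    if transposed:                                     # B stored as (N, K)
        base = repeat((8, 8), (1, 2))                  # 8 x 16 base fragment
        warp = repeat(base, (block_n // warp_n, 1))
        return repeat(warp, (warp_n // 8, block_k // 16))
    else:                                              # B stored as (K, N)
        base = repeat((8, 8), (2, 1))                  # 16 x 8 base fragment
        warp = repeat(base, (1, block_n // warp_n))
        return repeat(warp, (block_k // 16, warp_n // 8))

# The factors compose back to the full B tile:
assert gemm_fragment_b_shape(128, 32, 64, transposed=True) == (128, 32)   # (block_n, block_k)
assert gemm_fragment_b_shape(128, 32, 64, transposed=False) == (32, 128)  # (block_k, block_n)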

src/op/copy.cc

Lines changed: 0 additions & 1 deletion

@@ -402,7 +402,6 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
   PassContext pass_ctx = PassContext::Current();
   bool disable_tma_lower =
       pass_ctx->GetConfig<bool>(kDisableTMALower, false).value();
-
   auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma,
                                T.layout_map, T.analyzer, T.buffer_oob);
   if (copy_inst == CopyInst::kBulkLoad || copy_inst == CopyInst::kBulkStore) {

src/op/gemm.cc

Lines changed: 7 additions & 2 deletions

@@ -244,14 +244,20 @@ GemmWarpPolicyNode::ComputeWarpPartition(int M, int N, int block_size,
   int best_m = 1;
   int best_n = 1;
   float best_balance = std::numeric_limits<float>::max();
-
   // Try all possible combinations that satisfy the constraints
   for (int m = 1; m <= max_m_warps && m <= num_warps; m++) {
     int n = num_warps / m;

     // Calculate how balanced this partition is
     float m_per_warp = static_cast<float>(M) / (m * kMPerWarp);
     float n_per_warp = static_cast<float>(N) / (n * kNPerWarp);
+    // m_per_warp and n_per_warp must be greater than 1
+    if (m_per_warp < 1 || n_per_warp < 1)
+      continue;
+    // m * n must equal num_warps
+    if (m * n != num_warps)
+      continue;
+
     float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio);

     if (balance < best_balance) {

@@ -266,7 +272,6 @@ GemmWarpPolicyNode::ComputeWarpPartition(int M, int N, int block_size,
   } else {
     ICHECK(0) << "Unknown GemmWarpPolicy";
   }
-
   // Store the computed values in the object's member variables
   this->m_warp = m_warp;
   this->n_warp = n_warp;
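
The two new checks reject partitions that would give a warp less than one tile along M or N, or that do not use every warp. Below is a minimal Python sketch of the balance search under stated assumptions: kMPerWarp, kNPerWarp, ideal_ratio, and max_m_warps come from surrounding code not shown in this hunk, so the defaults here (16, 8, 1.0, and no cap on m) are illustrative placeholders rather than TileLang's actual values.

# Standalone sketch of the warp-partition balance search above.
# k_m_per_warp, k_n_per_warp and ideal_ratio are assumed values; the C++ code
# also caps m at max_m_warps, which this sketch omits for brevity.

def compute_warp_partition(M, N, num_warps, k_m_per_warp=16, k_n_per_warp=8,
                           ideal_ratio=1.0):
    best_m, best_n = 1, 1
    best_balance = float("inf")
    for m in range(1, num_warps + 1):
        n = num_warps // m
        m_per_warp = M / (m * k_m_per_warp)
        n_per_warp = N / (n * k_n_per_warp)
        # New constraints from this commit: every warp gets at least one tile
        # in each dimension, and the m x n grid must use all warps.
        if m_per_warp < 1 or n_per_warp < 1:
            continue
        if m * n != num_warps:
            continue
        balance = abs(m_per_warp / n_per_warp - ideal_ratio)
        if balance < best_balance:
            best_balance, best_m, best_n = balance, m, n
    return best_m, best_n

# Example: a 128x128 block on 4 warps settles on a 2x2 warp grid under these
# assumed per-warp tile sizes.
print(compute_warp_partition(128, 128, 4))  # -> (2, 2)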

src/op/gemm_py.cc

Lines changed: 8 additions & 7 deletions

@@ -234,18 +234,19 @@ Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
     BlockRealize block_realize = Downcast<BlockRealize>(prim_func->body);
     auto block = block_realize->block;
     {
-      BlockNode* n = block.CopyOnWrite();
+      BlockNode *n = block.CopyOnWrite();
       n->name_hint = global_symbol.value();
     }
-    return BlockRealize(block_realize->iter_values, block_realize->predicate, block);
+    return BlockRealize(block_realize->iter_values, block_realize->predicate,
+                        block);
   }
   // warp with block realize node
   return BlockRealize(
-    /*iter_values=*/Array<PrimExpr>(),
-    /*predicate=*/const_true(),
-    /*block=*/
-    Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
-          /*name_hint=*/global_symbol.value(), prim_func->body));
+      /*iter_values=*/Array<PrimExpr>(),
+      /*predicate=*/const_true(),
+      /*block=*/
+      Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
+            /*name_hint=*/global_symbol.value(), prim_func->body));
   } else {
     LOG(FATAL) << "No lower function found for gemm_py";
   }

src/target/codegen_cuda.cc

Lines changed: 2 additions & 6 deletions

@@ -1331,16 +1331,12 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
       os << "}\n";
     } else {
       std::string smem_elem_offset = this->PrintExpr(op->args[6]);
-      // need_cast_smem_ptr_to_int_ = true;
-      // this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr,
-      //                                         local_elem_offset, smem_ptr,
-      //                                         smem_elem_offset);
      std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num);
      if (trans == 1)
        func_name += "_trans";
-      // this->stream << func_name << "(" << local_ptr "" << ", " << smem_ptr << ");\n";
      this->PrintIndent();
-      this->stream << func_name << "(" << smem_ptr << " + " << smem_elem_offset<< ", " << local_ptr << " + " << local_elem_offset << ");\n";
+      this->stream << func_name << "(" << smem_ptr << " + " << smem_elem_offset
+                   << ", " << local_ptr << " + " << local_elem_offset << ");\n";
     }
   } else if (op->op.same_as(builtin::mma_store())) {
     int m = Downcast<Integer>(op->args[0])->value;

src/transform/inject_pipeline.cc

Lines changed: 3 additions & 2 deletions

@@ -951,8 +951,9 @@ class PipelineInjector : private StmtExprMutator {

     Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));

-    Array<Array<BufferRegion>> access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
-    BlockNode* n = block.CopyOnWrite();
+    Array<Array<BufferRegion>> access =
+        GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
+    BlockNode *n = block.CopyOnWrite();
     n->reads = access[0];
     n->writes = access[1];

testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py

Lines changed: 62 additions & 45 deletions

@@ -1,4 +1,3 @@
-from asyncio import threads
 from tilelang import tvm as tvm
 import tilelang.testing

@@ -90,7 +89,9 @@ def run_gemm_ss(
         tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
         tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
     })
-    profiler = kernel.get_profiler()
+
+    print(kernel.get_kernel_source())
+    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

     def ref_program(A, B):
         import torch

@@ -109,11 +110,21 @@ def ref_program(A, B):
 def test_gemm_ss():
     # More test case can be found in kernel/test_tilelang_kernel_gemm.py
     # GEMM tests for float16
-    run_gemm_ss(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 0)
-    run_gemm_ss(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 0)
-    run_gemm_ss(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 0)
-    run_gemm_ss(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 0)
-
+    run_gemm_ss(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 2)
+    run_gemm_ss(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 2)
+    run_gemm_ss(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 2)
+    run_gemm_ss(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 2)
+    # n8 test
+    run_gemm_ss(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128)
+
+    # int8 test
+    run_gemm_ss(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_ss(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_ss(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_ss(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2)
+
+    # float8 tests
+    run_gemm_ss(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2)


 def matmul_rs(

@@ -208,7 +219,7 @@ def run_gemm_rs(
         tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
         tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
     })
-    profiler = kernel.get_profiler()
+    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

     def ref_program(A, B):
         import torch

@@ -226,8 +237,22 @@ def ref_program(A, B):

 def test_gemm_rs():
     # GEMM tests for float16
-    run_gemm_rs(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 0)
-    run_gemm_rs(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 0)
+    run_gemm_rs(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
+    run_gemm_rs(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2)
+    run_gemm_rs(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2)
+    run_gemm_rs(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2)
+
+    # n8 tests
+    run_gemm_rs(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128)
+
+    # int8 tests
+    run_gemm_rs(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_rs(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_rs(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_rs(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2)
+
+    # float8 tests
+    run_gemm_rs(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2)


 def matmul_sr(

@@ -322,7 +347,7 @@ def run_gemm_sr(
         tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
         tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
     })
-    profiler = kernel.get_profiler()
+    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

     def ref_program(A, B):
         import torch

@@ -345,6 +370,18 @@ def test_gemm_sr():
     run_gemm_sr(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2)
     run_gemm_sr(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2)

+    # n8 tests
+    run_gemm_sr(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128)
+
+    # int8 tests
+    run_gemm_sr(128, 128, 32, False, True, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_sr(128, 128, 32, False, False, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_sr(128, 128, 32, True, False, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_sr(128, 128, 32, True, True, "int8", "int8", "int32", 128, 128, 32, 2)
+
+    # float8 tests
+    run_gemm_sr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2)
+

 def matmul_rr(
     M,

@@ -442,7 +479,7 @@ def run_gemm_rr(
         tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
         tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
     })
-    profiler = kernel.get_profiler()
+    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

     def ref_program(A, B):
         import torch

@@ -465,40 +502,20 @@ def test_gemm_rr():
     run_gemm_rr(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2)
     run_gemm_rr(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2)
     run_gemm_rr(512, 1024, 768, False, True, "bfloat16", "bfloat16", "float", 128, 256, 32, 2)
+    # n8 tests
+    run_gemm_rr(128, 8, 128, False, True, "float16", "float16", "float16", 128, 8, 32, 2)
+    run_gemm_rr(128, 8, 128, False, True, "int8", "int8", "int32", 128, 8, 32, 2)
+
+    # int8 tests
+    run_gemm_rr(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_rr(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_rr(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2)
+    run_gemm_rr(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2)
+
+    # float8 tests
+    run_gemm_rr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2)


 if __name__ == "__main__":
     # tilelang.testing.main()
-    tilelang.disable_cache()
-    # test_gemm_ss()
-    # test_gemm_sr()
-    # test_gemm_rs()
-    # test_gemm_rr()
-
-    # run_gemm_sr(128, 128, 128, False, False, "float16", "float16", "float16", 128, 128, 32, 2)
-    # tilelang.testing.set_random_seed(42)
-    run_gemm_ss(128, 128, 128, False, True, "float16", "float16", "float16", 128, 128, 32, 1)
-    # print("gemm fp16 nt ss done")
-    # exit()
-
-    # run_gemm_rs(128, 128, 32, False, True, "float16", "float16", "float16", 128, 128, 32, 0)
-    # print("gemm fp16 nt rs done")
-    # run_gemm_rs(128, 128, 32, False, False, "float16", "float16", "float16", 128, 128, 32, 0)
-    # print("gemm fp16 nn rs done")
-    # run_gemm_rs(128, 128, 32, True, False, "float16", "float16", "float16", 128, 128, 32, 0)
-    # print("gemm fp16 tn rs done")
-    # run_gemm_rs(128, 128, 32, True, True, "float16", "float16", "float16", 128, 128, 32, 0)
-    # print("gemm fp16 tt rs done")
-
-    # run_gemm_rs(16, 16, 16, True, False, "float16", "float16", "float16", 16, 16, 16, 0, 32)
-
-    # run_gemm_rr(128, 128, 32, False, False, "bfloat16", "bfloat16", "float", 128, 128, 32, 0)
-    # print("gemm bf16 nn rr done")
-    # run_gemm_rr(128, 128, 32, False, True, "bfloat16", "bfloat16", "float", 128, 128, 32, 0)
-    # print("gemm bf16 nt rr done")
-    # run_gemm_rr(128, 128, 32, True, False, "bfloat16", "bfloat16", "float", 128, 128, 32, 0)
-    # print("gemm bf16 tn rr done")
-    # run_gemm_rr(128, 128, 32, True, True, "bfloat16", "bfloat16", "float", 128, 128, 32, 0)
-    # print("gemm bf16 tt rr done")
-
-
+    run_gemm_rr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2)
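
The new int8 and float8_e5m2 cases accumulate into wider output types ("int32" / "float32"). As an illustration of that convention only, here is a hedged sketch of an int8 reference check in the style these tests use; the ref_program bodies are not shown in this diff, so the transpose handling and casts below are assumptions, and ref_matmul_int8 is a hypothetical name rather than a helper from the test file.

# Hypothetical reference check for the int8 cases, for illustration only.
# Assumes trans_A/trans_B mean the operand is stored transposed, and that the
# kernel accumulates in int32 (PyTorch integer matmul runs on CPU tensors).
import torch

def ref_matmul_int8(A, B, trans_A, trans_B):
    A = A.T if trans_A else A
    B = B.T if trans_B else B
    return torch.matmul(A.to(torch.int32), B.to(torch.int32))

A = torch.randint(-128, 128, (128, 128), dtype=torch.int8)
B = torch.randint(-128, 128, (128, 128), dtype=torch.int8)
C = ref_matmul_int8(A, B, trans_A=False, trans_B=True)  # int32 result, shape (128, 128)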

tilelang/intrinsics/mma_layout.py

Lines changed: 34 additions & 4 deletions

@@ -52,31 +52,49 @@ def shared_16x16_to_mma_a_32x8_layout(i, j):
     thread_id = 4 * (i % 8) + (j % 8) // 2
     return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)

+
 def shared_16x16_to_mma_a_32x8_layout_trans(i, j):
     return shared_16x16_to_mma_a_32x8_layout(j, i)

+
 # mma.sync matrix B layout, if wanna trans, please apply map_indices
 def shared_16x16_to_mma_b_32x8_layout(i, j):
     thread_id = 4 * (i % 8) + (j % 8) // 2
     return thread_id, 4 * (i // 8) + (j // 8) * 2 + (j % 2)

+
 def shared_16x16_to_mma_b_32x8_layout_trans(i, j):
     return shared_16x16_to_mma_b_32x8_layout(j, i)

+
 shared_16x16_to_mma_32x8_layout_sr_a = shared_16x16_to_mma_a_32x8_layout
 shared_16x16_to_mma_32x8_layout_sr_b = shared_16x16_to_mma_b_32x8_layout
 shared_16x16_to_mma_32x8_layout_rs_a = shared_16x16_to_mma_a_32x8_layout_trans
 shared_16x16_to_mma_32x8_layout_rs_b = shared_16x16_to_mma_b_32x8_layout_trans


-def shared_16x32_to_mma_32x16_layout(i, j):
+def shared_16x32_to_mma_a_32x16_layout(i, j):
     thread_id = 4 * (i % 8) + (j % 16) // 4
     return thread_id, 8 * (j // 16) + (i // 8) * 4 + j % 4


-def shared_32x16_to_mma_32x16_layout(i, j):
-    thread_id = (i % 16) // 4 + 4 * (j % 8)
-    return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4
+def shared_32x16_to_mma_a_32x16_layout_trans(i, j):
+    return shared_16x32_to_mma_a_32x16_layout(j, i)
+
+
+def shared_16x32_to_mma_b_32x16_layout(i, j):
+    thread_id = 4 * (i % 8) + (j % 16) // 4
+    return thread_id, 8 * (i // 8) + (j // 16) * 4 + j % 4
+
+
+def shared_32x16_to_mma_b_32x16_layout_trans(i, j):
+    return shared_16x32_to_mma_b_32x16_layout(j, i)
+
+
+shared_16x32_to_mma_32x16_layout_sr_a = shared_16x32_to_mma_a_32x16_layout
+shared_16x32_to_mma_32x16_layout_sr_b = shared_16x32_to_mma_b_32x16_layout
+shared_16x32_to_mma_32x16_layout_rs_a = shared_32x16_to_mma_a_32x16_layout_trans
+shared_16x32_to_mma_32x16_layout_rs_b = shared_32x16_to_mma_b_32x16_layout_trans


 def mma_32x8_to_shared_16x16_layout(thread_id, local_id):

@@ -85,6 +103,18 @@ def mma_32x8_to_shared_16x16_layout(thread_id, local_id):
     return row, col


+def mma_load_a_32x16_to_shared_16x32_layout(thread_id, local_id):
+    row = 8 * (local_id % 8 // 4) + (thread_id // 4)
+    col = 16 * (local_id // 8) + (thread_id % 4) * 4 + (local_id % 4)
+    return row, col
+
+
+def mma_load_b_32x16_to_shared_16x32_layout(thread_id, local_id):
+    row = 8 * (local_id // 8) + (thread_id // 4)
+    col = 16 * (local_id % 8 // 4) + (thread_id % 4) * 4 + (local_id % 4)
+    return row, col
+
+
 def shared_16x16_to_mma_32x8_smoothlayout(i, j):
     return (i * 2 + j // 8, j % 8)
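
The 16x32 shared-to-MMA mappings and the new mma_load_* helpers are pure index arithmetic, so they can be sanity-checked in isolation. The sketch below copies the four functions from this diff verbatim; the checking loop and assertions are an addition, verifying that each forward mapping is a bijection onto 32 threads x 16 registers per fragment and that the corresponding mma_load_* function inverts it.

# Standalone check of the int8/fp8 (16x32) layout helpers added above.
# The four functions are copied from the diff; the assertions are new and
# verify the mappings are mutually inverse bijections.

def shared_16x32_to_mma_a_32x16_layout(i, j):
    thread_id = 4 * (i % 8) + (j % 16) // 4
    return thread_id, 8 * (j // 16) + (i // 8) * 4 + j % 4

def shared_16x32_to_mma_b_32x16_layout(i, j):
    thread_id = 4 * (i % 8) + (j % 16) // 4
    return thread_id, 8 * (i // 8) + (j // 16) * 4 + j % 4

def mma_load_a_32x16_to_shared_16x32_layout(thread_id, local_id):
    row = 8 * (local_id % 8 // 4) + (thread_id // 4)
    col = 16 * (local_id // 8) + (thread_id % 4) * 4 + (local_id % 4)
    return row, col

def mma_load_b_32x16_to_shared_16x32_layout(thread_id, local_id):
    row = 8 * (local_id // 8) + (thread_id // 4)
    col = 16 * (local_id % 8 // 4) + (thread_id % 4) * 4 + (local_id % 4)
    return row, col

for fwd, inv in [(shared_16x32_to_mma_a_32x16_layout, mma_load_a_32x16_to_shared_16x32_layout),
                 (shared_16x32_to_mma_b_32x16_layout, mma_load_b_32x16_to_shared_16x32_layout)]:
    seen = set()
    for i in range(16):
        for j in range(32):
            tid, lid = fwd(i, j)
            assert 0 <= tid < 32 and 0 <= lid < 16
            assert inv(tid, lid) == (i, j)   # load helper round-trips back to (i, j)
            seen.add((tid, lid))
    assert len(seen) == 16 * 32              # every (thread, register) slot is hit exactly once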
