Skip to content

Commit 00b2fd8

Browse files
committed
Refactor test for InjectSetMaxNReg pass in TileLang
- Improved readability by restructuring conditional checks and assertions in the test cases. - Enhanced clarity in the collection of `set_max_nreg` calls by simplifying the logic. - Ensured consistent formatting and spacing throughout the test functions for better maintainability.
1 parent b268e47 commit 00b2fd8

File tree

2 files changed

+38
-31
lines changed

2 files changed

+38
-31
lines changed

src/transform/annotate_warp_group_reg_alloc.cc

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,16 @@
99
#include <unordered_set>
1010
#include <vector>
1111

12-
#include "tir/transforms/ir_utils.h"
1312
#include "../op/builtin.h"
13+
#include "tir/transforms/ir_utils.h"
1414

1515
namespace tvm {
1616
namespace tl {
1717

1818
using namespace tir;
1919

2020
class SetMaxNRegCollector : public StmtExprVisitor {
21-
public:
21+
public:
2222
static Array<IntImm> Collect(const PrimFunc &f) {
2323
SetMaxNRegCollector collector;
2424
collector(f->body);
@@ -28,7 +28,7 @@ class SetMaxNRegCollector : public StmtExprVisitor {
2828
: collector.nreg_;
2929
}
3030

31-
private:
31+
private:
3232
void VisitStmt_(const EvaluateNode *op) final {
3333
if (const CallNode *call = op->value.as<CallNode>()) {
3434
if (call->op.same_as(set_max_nreg())) {
@@ -54,15 +54,15 @@ class SetMaxNRegCollector : public StmtExprVisitor {
5454
};
5555

5656
class SetMaxNRegInjector : public StmtExprMutator {
57-
public:
57+
public:
5858
static PrimFunc Inject(PrimFunc f) {
5959
auto T = SetMaxNRegInjector();
6060
T.nreg_ = SetMaxNRegCollector::Collect(f);
6161
f.CopyOnWrite()->body = T(f->body);
6262
return f;
6363
}
6464

65-
private:
65+
private:
6666
Stmt VisitStmt_(const EvaluateNode *op) final {
6767
if (const CallNode *call = op->value.as<CallNode>()) {
6868
if (call->op.same_as(set_max_nreg()) ||
@@ -97,7 +97,6 @@ class SetMaxNRegInjector : public StmtExprMutator {
9797
Optional<Stmt> consumer_body = if_then_else->else_case;
9898
ICHECK(consumer_body.defined()) << "Consumer body is undefined";
9999

100-
101100
int dec_reg = nreg_[0].as<IntImmNode>()->value;
102101
int inc_reg = nreg_[1].as<IntImmNode>()->value;
103102

@@ -107,13 +106,13 @@ class SetMaxNRegInjector : public StmtExprMutator {
107106
// Only inject if we have valid register hints and no SIMT copy
108107
// For now, we assume no SIMT copy detection is available here
109108
// TODO: Add SIMT copy detection if needed
110-
bool has_simt_copy = false; // Placeholder
109+
bool has_simt_copy = false; // Placeholder
111110

112111
if (dec_reg >= 0 && inc_reg >= 0 && !has_simt_copy) {
113112
inc_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(),
114-
{inc_reg == 0 ? 240 : inc_reg, 1}));
113+
{inc_reg == 0 ? 240 : inc_reg, 1}));
115114
dec_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(),
116-
{dec_reg == 0 ? 24 : dec_reg, 0}));
115+
{dec_reg == 0 ? 24 : dec_reg, 0}));
117116
}
118117

119118
// Inject register setting statements
@@ -127,7 +126,8 @@ class SetMaxNRegInjector : public StmtExprMutator {
127126
consumer_stmts.push_back(consumer_body.value());
128127
auto new_consumer_body = SeqStmt(consumer_stmts);
129128

130-
auto new_if_stmt = IfThenElse(if_then_else->condition, new_producer_body, new_consumer_body);
129+
auto new_if_stmt = IfThenElse(if_then_else->condition, new_producer_body,
130+
new_consumer_body);
131131
auto new_attr = AttrStmt(op->node, op->attr_key, op->value, new_if_stmt);
132132

133133
return new_attr;
@@ -153,8 +153,9 @@ tvm::transform::Pass AnnotateWarpGroupRegAlloc() {
153153

154154
TVM_FFI_STATIC_INIT_BLOCK({
155155
namespace refl = tvm::ffi::reflection;
156-
refl::GlobalDef().def("tl.transform.AnnotateWarpGroupRegAlloc", AnnotateWarpGroupRegAlloc);
156+
refl::GlobalDef().def("tl.transform.AnnotateWarpGroupRegAlloc",
157+
AnnotateWarpGroupRegAlloc);
157158
});
158159

159-
} // namespace tl
160-
} // namespace tvm
160+
} // namespace tl
161+
} // namespace tvm

testing/python/transform/test_tilelang_transform_inject_set_max_nreg.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from tilelang import tvm as tvm
22
import tilelang as tl
3-
from tilelang.utils.target import determine_target
43
import tilelang.language as T
54
import tilelang.testing
65
from tvm import tir
76

87
tilelang.disable_cache()
98

9+
1010
def test_inject_set_max_nreg():
1111
"""Test the InjectSetMaxNReg pass"""
1212

@@ -37,21 +37,26 @@ def before(A: T.Tensor((512, 512), "float16"), B: T.Tensor((512, 512), "float16"
3737
T.mbarrier_wait_parity(T.get_mbarrier(k % 3 + 3), T.bitwise_xor(k // 3 % 2, 1))
3838
if v - 128 == 0:
3939
T.tma_load(
40-
T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0),
41-
T.get_mbarrier(k % 3),
42-
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2),
40+
T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1,
41+
0, 2, 2, 0), T.get_mbarrier(k % 3),
42+
T.tvm_access_ptr(
43+
T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2),
4344
k * 32, by * 64)
44-
T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3)]))
45+
T.evaluate(
46+
tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3)]))
4547
else:
4648
# Consumer branch - should have set_max_nreg(240, 1)
4749
for k in range(16):
4850
T.mbarrier_wait_parity(T.get_mbarrier(k % 3), k // 3 % 2)
4951
T.call_extern(
5052
"handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
51-
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1),
52-
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1),
53+
T.tvm_access_ptr(
54+
T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1),
55+
T.tvm_access_ptr(
56+
T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1),
5357
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))
54-
T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)]))
58+
T.evaluate(
59+
tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)]))
5560

5661
# Apply the InjectSetMaxNReg pass
5762
func = before
@@ -64,15 +69,15 @@ def before(A: T.Tensor((512, 512), "float16"), B: T.Tensor((512, 512), "float16"
6469
set_max_nreg_calls = []
6570

6671
def collect_set_max_nreg(stmt):
67-
if isinstance(stmt, tvm.tir.Evaluate):
68-
if hasattr(stmt.value, 'op') and hasattr(stmt.value.op, 'name'):
69-
if stmt.value.op.name == "tl.set_max_nreg":
70-
set_max_nreg_calls.append((stmt.value.args[0].value, stmt.value.args[1].value))
72+
if (isinstance(stmt, tvm.tir.Evaluate) and hasattr(stmt.value, 'op') and
73+
hasattr(stmt.value.op, 'name') and stmt.value.op.name == "tl.set_max_nreg"):
74+
set_max_nreg_calls.append(stmt.value)
7175

7276
tvm.tir.stmt_functor.post_order_visit(main_func.body, collect_set_max_nreg)
7377

7478
# We should have at least 2 set_max_nreg calls (one for producer, one for consumer)
75-
assert len(set_max_nreg_calls) >= 2, f"Expected at least 2 set_max_nreg calls, got {len(set_max_nreg_calls)}"
79+
assert len(set_max_nreg_calls
80+
) >= 2, f"Expected at least 2 set_max_nreg calls, got {len(set_max_nreg_calls)}"
7681

7782
# Check that we have the expected register values
7883
reg_values = [call[0] for call in set_max_nreg_calls]
@@ -118,15 +123,16 @@ def before_no_set_max_nreg(A: T.Tensor((512, 512), "float16")):
118123
set_max_nreg_calls = []
119124

120125
def collect_set_max_nreg(stmt):
121-
if isinstance(stmt, tvm.tir.Evaluate):
122-
if hasattr(stmt.value, 'op') and hasattr(stmt.value.op, 'name'):
123-
if stmt.value.op.name == "tl.set_max_nreg":
124-
set_max_nreg_calls.append(stmt.value)
126+
if (isinstance(stmt, tvm.tir.Evaluate) and hasattr(stmt.value, 'op') and
127+
hasattr(stmt.value.op, 'name') and stmt.value.op.name == "tl.set_max_nreg"):
128+
set_max_nreg_calls.append(stmt.value)
125129

126130
tvm.tir.stmt_functor.post_order_visit(main_func.body, collect_set_max_nreg)
127131

128132
# Should have no set_max_nreg calls when no_set_max_nreg is present
129-
assert len(set_max_nreg_calls) == 0, f"Expected 0 set_max_nreg calls when no_set_max_nreg is present, got {len(set_max_nreg_calls)}"
133+
assert len(
134+
set_max_nreg_calls
135+
) == 0, f"Expected 0 set_max_nreg calls when no_set_max_nreg is present, got {len(set_max_nreg_calls)}"
130136

131137
print("InjectSetMaxNReg with no_set_max_nreg test passed!")
132138

0 commit comments

Comments
 (0)