Skip to content

Commit 1f60529

Browse files
authored
[Hexagon] Resolve breakage in test_hexagon/test_cache_read_write (#10520)
* [Hexagon] Resolve breakage in test_hexagon/test_cache_read_write Breakage was caused by #9727, which didn't account for the new `builtin::mem_copy()` when computing the stack size in `StackSizeChecker`. * Added comment indicating need for StackSizeChecker::MakeMemCopy. * Updated unittests to run all contrib/test_hexagon at CI. * CI bump * Fix lint formatting error. * Updated fix to remove StackSizeChecker entirely. * Bugfix, verify the precheck's allocations, not own. * Bugfix, pass context information to the precheck.
1 parent 0fa3540 commit 1f60529

File tree

4 files changed

+96
-158
lines changed

4 files changed

+96
-158
lines changed

src/tir/transforms/lower_tvm_builtin.cc

Lines changed: 90 additions & 155 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@
3434
namespace tvm {
3535
namespace tir {
3636

37-
class StackSizeChecker : public StmtExprVisitor {
37+
// Calculate the statistics of packed function.
38+
// This information is needed during codegen.
39+
class BuiltinLower : public StmtExprMutator {
3840
public:
3941
struct StackSizes {
4042
// If a tvm_stack_make_shape call has no arguments, it is still
@@ -46,159 +48,86 @@ class StackSizeChecker : public StmtExprVisitor {
4648
uint64_t arg_stack{0};
4749
};
4850

49-
static StackSizes Check(Stmt stmt) {
50-
StackSizeChecker visitor;
51-
visitor.VisitStmt(stmt);
52-
return visitor.max_stack_;
53-
}
54-
55-
private:
56-
void VisitStmt_(const ForNode* op) final {
57-
if (op->kind == ForKind::kParallel) {
58-
// Parallel for loops have their own stack and allocations, so
59-
// stop the recursion here.
60-
return;
61-
} else {
62-
this->VisitStmt(op->body);
63-
}
64-
}
65-
void VisitExpr_(const CallNode* op) final {
66-
if (op->op.same_as(builtin::tvm_call_packed())) {
67-
return MakeCallPacked(op, /* use_string_lookup */ true);
68-
} else if (op->op.same_as(builtin::tvm_call_cpacked())) {
69-
return MakeCallPacked(op, /* use_string_lookup */ false);
70-
} else if (op->op.same_as(builtin::tvm_call_trace_packed())) {
71-
return MakeCallTracePacked(op);
72-
} else if (op->op.same_as(builtin::tvm_stack_make_shape())) {
73-
return MakeShape(op);
74-
} else if (op->op.same_as(builtin::tvm_stack_make_array())) {
75-
return MakeArray(op);
76-
} else {
77-
return StmtExprVisitor::VisitExpr_(op);
78-
}
79-
}
80-
// call shape
81-
void MakeShape(const CallNode* op) {
82-
// if args.size() == 0, it is still valid and represents a scalar
83-
// shape (). Therefore, -1 is used to represent "no shape
84-
// arguments exist", while 0 represents "shape arguments exist,
85-
// all of which are size 0".
86-
if (current_stack_.shape_stack == -1) {
87-
current_stack_.shape_stack = 0;
88-
}
89-
current_stack_.shape_stack += op->args.size();
90-
StmtExprVisitor::VisitExpr_(op);
91-
}
92-
// make array
93-
void MakeArray(const CallNode* op) {
94-
current_stack_.array_stack += 1;
95-
StmtExprVisitor::VisitExpr_(op);
96-
}
97-
// call packed.
98-
void MakeCallPacked(const CallNode* op, bool use_string_lookup) {
99-
StackSizes restore_stack = current_stack_;
100-
101-
size_t arg_count = op->args.size();
102-
103-
// cpacked expects a resource_handle parameter
104-
if (!use_string_lookup) {
105-
arg_count--;
106-
}
107-
108-
current_stack_.arg_stack += arg_count;
109-
// Specially handle the buffer packed intrinsic
110-
StmtExprVisitor::VisitExpr_(op);
111-
// Record the amount of stack space needed, then reset the stack
112-
// position to its previous location.
113-
UpdateMaxStack();
114-
current_stack_ = restore_stack;
115-
}
116-
117-
void MakeCallTracePacked(const CallNode* op) {
118-
StackSizes restore_stack = current_stack_;
119-
120-
size_t args_size = op->args.size();
121-
ICHECK_GT(args_size, 0);
122-
current_stack_.arg_stack += args_size;
123-
124-
StmtExprVisitor::VisitExpr_(op);
125-
// Record the amount of stack space needed, then reset the stack
126-
// position to its previous location.
127-
UpdateMaxStack();
128-
current_stack_ = restore_stack;
129-
130-
// However, the arguments to this CallNode remain on top of the
131-
// stack, so we can use more than one packed function's arguments
132-
// with the one stack.
133-
current_stack_.arg_stack = restore_stack.arg_stack + args_size - 1;
134-
}
135-
136-
void UpdateMaxStack() {
137-
max_stack_.arg_stack = std::max(current_stack_.arg_stack, max_stack_.arg_stack);
138-
max_stack_.shape_stack = std::max(current_stack_.shape_stack, max_stack_.shape_stack);
139-
max_stack_.array_stack = std::max(current_stack_.array_stack, max_stack_.array_stack);
140-
}
141-
142-
StackSizes current_stack_;
143-
StackSizes max_stack_;
144-
};
145-
146-
// Calculate the statistics of packed function.
147-
// This information is needed during codegen.
148-
class BuiltinLower : public StmtExprMutator {
149-
public:
15051
// Record stack frame for existing scope.
15152
struct AllocaScope {
15253
Buffer stack_shape;
15354
Var stack_array = Var("stack_array", DataType::Handle());
15455
Var stack_value = Var("stack_value", DataType::Handle());
15556
Buffer stack_tcode;
15657

157-
int64_t max_shape_stack{-1};
158-
uint64_t max_array_stack{0};
159-
uint64_t max_arg_stack{0};
58+
StackSizes max_sizes;
59+
StackSizes run_sizes;
16060

161-
int64_t run_shape_stack{-1};
162-
uint64_t run_array_stack{0};
163-
uint64_t run_arg_stack{0};
61+
// Raise each recorded maximum to the current running value whenever
// the running value is larger; maxima never decrease here.
void UpdateMax() {
  if (max_sizes.shape_stack < run_sizes.shape_stack) {
    max_sizes.shape_stack = run_sizes.shape_stack;
  }
  if (max_sizes.array_stack < run_sizes.array_stack) {
    max_sizes.array_stack = run_sizes.array_stack;
  }
  if (max_sizes.arg_stack < run_sizes.arg_stack) {
    max_sizes.arg_stack = run_sizes.arg_stack;
  }
}
66+
67+
void AssertMaxIsValid() const {
68+
ICHECK((max_sizes.shape_stack >= run_sizes.shape_stack) ||
69+
(max_sizes.array_stack >= run_sizes.array_stack) ||
70+
(max_sizes.arg_stack >= run_sizes.arg_stack));
71+
}
16472
};
16573

16674
Stmt Build(Stmt stmt) { return this->VisitBodyAndRealizeAlloca(stmt); }
16775

76+
// Run a throw-away BuiltinLower over `stmt` in precheck mode and
// return the maximum stack sizes it recorded in its root scope.
StackSizes GetMaxStack(Stmt stmt) {
  BuiltinLower dry_run;
  dry_run.is_precheck_ = true;
  // Hand the device context of this lowering pass to the dry run.
  dry_run.device_type_ = this->device_type_;
  dry_run.device_id_ = this->device_id_;

  // Give the dry run a single root scope with zero-extent buffers.
  // NOTE(review): presumably these only need to exist so the visitor
  // can reference them during the dry run -- confirm.
  dry_run.alloca_scope_.emplace_back();
  auto& root_scope = dry_run.alloca_scope_.back();
  root_scope.stack_tcode =
      decl_buffer({IntImm(DataType::UInt(64), 0)}, DataType::Int(32), "stack_tcode");
  root_scope.stack_shape =
      decl_buffer({IntImm(DataType::Int(64), 0)}, DataType::Int(64), "stack_shape");

  dry_run.VisitStmt(stmt);

  // The dry run must finish with exactly the root scope it started with.
  ICHECK_EQ(dry_run.alloca_scope_.size(), 1);
  return dry_run.alloca_scope_.front().max_sizes;
}
94+
16895
// Allocate stack frames, only at parallel-for or root.
16996
Stmt VisitBodyAndRealizeAlloca(Stmt stmt) {
170-
// Initial check to identify maximum stack sizes. These are used
171-
// to construct Buffer objects to hold the stack, which are then
172-
// used when mutating.
173-
auto max_sizes = StackSizeChecker::Check(stmt);
97+
// Only perform the precheck up to the point where we would add a
98+
// new scope.
99+
if (is_precheck_) {
100+
return stmt;
101+
}
174102

175103
alloca_scope_.emplace_back();
176104
auto& scope = alloca_scope_.back();
177105

178-
if (max_sizes.shape_stack != -1) {
179-
scope.stack_shape = decl_buffer({IntImm(DataType::Int(64), max_sizes.shape_stack)},
106+
// Initial check to identify maximum stack sizes. These are used
107+
// to construct Buffer objects to hold the stack, which are then
108+
// used when mutating.
109+
scope.max_sizes = GetMaxStack(stmt);
110+
111+
if (scope.max_sizes.shape_stack != -1) {
112+
scope.stack_shape = decl_buffer({IntImm(DataType::Int(64), scope.max_sizes.shape_stack)},
180113
DataType::Int(64), "stack_shape");
181-
stmt = LetStmt(scope.stack_shape->data, StackAlloca("shape", max_sizes.shape_stack), stmt);
114+
stmt =
115+
LetStmt(scope.stack_shape->data, StackAlloca("shape", scope.max_sizes.shape_stack), stmt);
182116
}
183117

184-
if (max_sizes.array_stack != 0) {
185-
stmt = LetStmt(scope.stack_array, StackAlloca("array", max_sizes.array_stack), stmt);
118+
if (scope.max_sizes.array_stack != 0) {
119+
stmt = LetStmt(scope.stack_array, StackAlloca("array", scope.max_sizes.array_stack), stmt);
186120
}
187121

188-
if (max_sizes.arg_stack != 0) {
189-
scope.stack_tcode = decl_buffer({IntImm(DataType::UInt(64), max_sizes.arg_stack)},
122+
if (scope.max_sizes.arg_stack != 0) {
123+
scope.stack_tcode = decl_buffer({IntImm(DataType::UInt(64), scope.max_sizes.arg_stack)},
190124
DataType::Int(32), "stack_tcode");
191-
stmt = LetStmt(scope.stack_value, StackAlloca("arg_value", max_sizes.arg_stack), stmt);
125+
stmt = LetStmt(scope.stack_value, StackAlloca("arg_value", scope.max_sizes.arg_stack), stmt);
192126

193-
stmt = LetStmt(scope.stack_tcode->data, StackAlloca("arg_tcode", max_sizes.arg_stack), stmt);
127+
stmt = LetStmt(scope.stack_tcode->data, StackAlloca("arg_tcode", scope.max_sizes.arg_stack),
128+
stmt);
194129
}
195130

196-
// Copy these values from the earlier search, for use in bounds
197-
// checks.
198-
scope.max_shape_stack = max_sizes.shape_stack;
199-
scope.max_array_stack = max_sizes.array_stack;
200-
scope.max_arg_stack = max_sizes.arg_stack;
201-
202131
stmt = this->VisitStmt(stmt);
203132

204133
ICHECK(!alloca_scope_.empty());
@@ -213,8 +142,8 @@ class BuiltinLower : public StmtExprMutator {
213142

214143
auto stmt = StmtExprMutator::VisitStmt(s);
215144
auto& scope = alloca_scope_.back();
216-
ICHECK_EQ(scope.run_shape_stack, -1);
217-
ICHECK_EQ(scope.run_array_stack, 0);
145+
ICHECK_EQ(scope.run_sizes.shape_stack, -1);
146+
ICHECK_EQ(scope.run_sizes.array_stack, 0);
218147

219148
auto prep_seq = std::move(prep_seq_stack_.back());
220149
prep_seq_stack_.pop_back();
@@ -364,11 +293,11 @@ class BuiltinLower : public StmtExprMutator {
364293
ICHECK(!alloca_scope_.empty());
365294
auto& scope = alloca_scope_.back();
366295
auto& prep_seq = prep_seq_stack_.back();
367-
if (scope.run_shape_stack == -1) {
368-
scope.run_shape_stack = 0;
296+
if (scope.run_sizes.shape_stack == -1) {
297+
scope.run_sizes.shape_stack = 0;
369298
}
370-
int64_t stack_begin = scope.run_shape_stack;
371-
scope.run_shape_stack += op->args.size();
299+
int64_t stack_begin = scope.run_sizes.shape_stack;
300+
scope.run_sizes.shape_stack += op->args.size();
372301
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
373302
op = expr.as<CallNode>();
374303
// no need to perform any store for a scalar shape
@@ -384,8 +313,8 @@ class BuiltinLower : public StmtExprMutator {
384313
auto& scope = alloca_scope_.back();
385314
auto& prep_seq = prep_seq_stack_.back();
386315

387-
size_t idx = scope.run_array_stack;
388-
scope.run_array_stack += 1;
316+
size_t idx = scope.run_sizes.array_stack;
317+
scope.run_sizes.array_stack += 1;
389318
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
390319
op = expr.as<CallNode>();
391320

@@ -426,9 +355,9 @@ class BuiltinLower : public StmtExprMutator {
426355
auto& scope = alloca_scope_.back();
427356
auto& prep_seq = prep_seq_stack_.back();
428357

429-
int64_t restore_shape_stack = scope.run_shape_stack;
430-
size_t restore_array_stack = scope.run_array_stack;
431-
size_t arg_stack_begin = scope.run_arg_stack;
358+
int64_t restore_shape_stack = scope.run_sizes.shape_stack;
359+
size_t restore_array_stack = scope.run_sizes.array_stack;
360+
size_t arg_stack_begin = scope.run_sizes.arg_stack;
432361

433362
size_t arg_count = op->args.size();
434363

@@ -437,7 +366,7 @@ class BuiltinLower : public StmtExprMutator {
437366
arg_count--;
438367
}
439368

440-
scope.run_arg_stack += arg_count;
369+
scope.run_sizes.arg_stack += arg_count;
441370
// Specially handle the buffer packed intrinsic
442371
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
443372
op = expr.as<CallNode>();
@@ -460,12 +389,14 @@ class BuiltinLower : public StmtExprMutator {
460389
prep_seq.emplace_back(BufferStore(scope.stack_tcode, ConstInt32(arg_tcode), {stack_index}));
461390
}
462391
// Verify stack size matches earlier value.
463-
ICHECK_LE(scope.run_arg_stack, scope.max_arg_stack);
464-
ICHECK_LE(scope.run_shape_stack, scope.max_shape_stack);
465-
ICHECK_LE(scope.run_array_stack, scope.max_array_stack);
466-
scope.run_shape_stack = restore_shape_stack;
467-
scope.run_array_stack = restore_array_stack;
468-
scope.run_arg_stack = arg_stack_begin;
392+
if (is_precheck_) {
393+
scope.UpdateMax();
394+
} else {
395+
scope.AssertMaxIsValid();
396+
}
397+
scope.run_sizes.shape_stack = restore_shape_stack;
398+
scope.run_sizes.array_stack = restore_array_stack;
399+
scope.run_sizes.arg_stack = arg_stack_begin;
469400
Array<PrimExpr> packed_args = {op->args[0], scope.stack_value, scope.stack_tcode->data,
470401
ConstInt32(arg_stack_begin),
471402
ConstInt32(arg_stack_begin + op->args.size() - 1)};
@@ -486,10 +417,10 @@ class BuiltinLower : public StmtExprMutator {
486417
auto& scope = alloca_scope_.back();
487418
auto& prep_seq = prep_seq_stack_.back();
488419

489-
int64_t restore_shape_stack = scope.run_shape_stack;
490-
size_t restore_array_stack = scope.run_array_stack;
491-
size_t arg_stack_begin = scope.run_arg_stack;
492-
scope.run_arg_stack += op->args.size();
420+
int64_t restore_shape_stack = scope.run_sizes.shape_stack;
421+
size_t restore_array_stack = scope.run_sizes.array_stack;
422+
size_t arg_stack_begin = scope.run_sizes.arg_stack;
423+
scope.run_sizes.arg_stack += op->args.size();
493424
size_t args_size = op->args.size();
494425
ICHECK_GT(args_size, 0);
495426
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
@@ -510,14 +441,16 @@ class BuiltinLower : public StmtExprMutator {
510441
prep_seq.emplace_back(BufferStore(scope.stack_tcode, ConstInt32(arg_tcode), {stack_index}));
511442
}
512443
// Verify stack size matches earlier value.
513-
ICHECK_LE(scope.run_arg_stack, scope.max_arg_stack);
514-
ICHECK_LE(scope.run_shape_stack, scope.max_shape_stack);
515-
ICHECK_LE(scope.run_array_stack, scope.max_array_stack);
516-
scope.run_shape_stack = restore_shape_stack;
517-
scope.run_array_stack = restore_array_stack;
444+
if (is_precheck_) {
445+
scope.UpdateMax();
446+
} else {
447+
scope.AssertMaxIsValid();
448+
}
449+
scope.run_sizes.shape_stack = restore_shape_stack;
450+
scope.run_sizes.array_stack = restore_array_stack;
518451
// Update the top of the stack, so we can use more than one
519452
// packed function's arguments with the one stack.
520-
scope.run_arg_stack = arg_stack_begin + args_size - 1;
453+
scope.run_sizes.arg_stack = arg_stack_begin + args_size - 1;
521454
Array<PrimExpr> packed_args = {op->args[0], scope.stack_value, scope.stack_tcode->data,
522455
ConstInt32(arg_stack_begin),
523456
ConstInt32(arg_stack_begin + op->args.size() - 1),
@@ -575,6 +508,8 @@ class BuiltinLower : public StmtExprMutator {
575508
PrimExpr device_type_;
576509
PrimExpr device_id_;
577510

511+
bool is_precheck_{false};
512+
578513
// Record all stack frames.
579514
std::vector<AllocaScope> alloca_scope_;
580515
};

tests/python/contrib/test_hexagon/test_cache_read_write.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,9 @@ def intrin_func(ins, outs):
6363

6464

6565
@requires_hexagon_toolchain
66-
def test_cache_read_write(android_serial_number, tvm_tracker_host, tvm_tracker_port):
66+
def test_cache_read_write(
67+
android_serial_number, tvm_tracker_host, tvm_tracker_port, adb_server_socket
68+
):
6769
size = 128
6870
outer_shape = (size,)
6971
factor = 16
@@ -115,6 +117,7 @@ def test_cache_read_write(android_serial_number, tvm_tracker_host, tvm_tracker_p
115117
"rpc_tracker_host": tvm_tracker_host,
116118
"rpc_tracker_port": tvm_tracker_port,
117119
"rpc_server_port": 7070,
120+
"adb_server_socket": adb_server_socket,
118121
}
119122
launcher = HexagonLauncher(serial_number=android_serial_number, rpc_info=rpc_info)
120123
launcher.upload(dso_binary_path, dso_binary)

tests/scripts/task_python_hexagon.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,4 @@ source tests/scripts/setup-pytest-env.sh
2424

2525
make cython3
2626

27-
run_pytest ctypes python-contrib-hexagon tests/python/contrib/test_hexagon/test_launcher.py
27+
run_pytest ctypes python-contrib-hexagon tests/python/contrib/test_hexagon

tests/scripts/task_python_hexagon_simulator.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,6 @@ export HEXAGON_SHARED_LINK_FLAGS="-Lbuild/hexagon_api_output -lhexagon_rpc_sim"
3535
# HEXAGON_TOOLCHAIN is already set
3636
export HEXAGON_SDK_ROOT=${HEXAGON_SDK_PATH}
3737
export ANDROID_SERIAL_NUMBER=simulator
38-
run_pytest ctypes python-contrib-hexagon-simulator tests/python/contrib/test_hexagon/test_launcher.py
38+
run_pytest ctypes python-contrib-hexagon-simulator tests/python/contrib/test_hexagon
3939

4040
kill ${TRACKER_PID}

0 commit comments

Comments
 (0)