Skip to content

Commit 060d9d2

Browse files
authored
[AOT] Introduce checks for return values from operators (#10424)
This matches the lowering of `call_cpacked` which checks only for an operator return of `0` in the main flow: https://github.com/apache/tvm/blob/bd14a4d36e0d364ef9bd34b2ee96cc09ce64d4b3/src/target/source/codegen_c_host.cc#L207-L231 This replaces: ```c (void)tvmgen_default_fused_add(x_buffer_var, y_buffer_var, output_buffer_var); ``` with: ```c if (tvmgen_default_fused_add(x_buffer_var, y_buffer_var, output_buffer_var) != 0 ) return -1; ``` when AOT generates the C output.
1 parent f9d3918 commit 060d9d2

File tree

8 files changed

+118
-31
lines changed

8 files changed

+118
-31
lines changed

apps/microtvm/ethosu/src/tvm_ethosu_runtime.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ int32_t TVMEthosULaunch(tvm_device_ethos_u_t* context, void* cms_data, size_t cm
3434
return 0;
3535
}
3636

37-
int32_t TVMDeviceEthosUActivate(tvm_device_ethos_u_t* context) {}
38-
int32_t TVMDeviceEthosUOpen(tvm_device_ethos_u_t* context) {}
39-
int32_t TVMDeviceEthosUClose(tvm_device_ethos_u_t* context) {}
40-
int32_t TVMDeviceEthosUDeactivate(tvm_device_ethos_u_t* context) {}
37+
int32_t TVMDeviceEthosUActivate(tvm_device_ethos_u_t* context) { return 0; }
38+
int32_t TVMDeviceEthosUOpen(tvm_device_ethos_u_t* context) { return 0; }
39+
int32_t TVMDeviceEthosUClose(tvm_device_ethos_u_t* context) { return 0; }
40+
int32_t TVMDeviceEthosUDeactivate(tvm_device_ethos_u_t* context) { return 0; }

include/tvm/tir/builtin.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,20 @@ TVM_DLL const Op& tvm_call_cpacked();
384384
*/
385385
TVM_DLL const Op& tvm_call_trace_packed();
386386

387+
/*!
388+
* \brief Checks the return value of another call is correct or returns a given value.
389+
*
390+
* \note This is meant to serve a specific case for AOT code generator whilst this
391+
* cannot be fully represented in TIR.
392+
*
393+
* Type tvm_check_return(expected, return_unexpected, nested_call) {
394+
* if (nested_call() != expected) {
395+
* return return_unexpected;
396+
* }
397+
* }
398+
*/
399+
TVM_DLL const Op& tvm_check_return();
400+
387401
/*!
388402
* \brief See pesudo code
389403
* Mark the content as thread local context, can get optimized

src/relay/backend/aot_executor_codegen.cc

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
279279
* \param num the number to convert
280280
* \return PrimExpr representing num
281281
*/
282-
inline PrimExpr ConstInt32(size_t num) {
282+
inline PrimExpr ConstInt32(int32_t num) {
283283
ICHECK_LE(num, std::numeric_limits<int>::max());
284284
return tir::make_const(DataType::Int(32), static_cast<int>(num));
285285
}
@@ -333,6 +333,19 @@ class AOTExecutorCodegen : public MixedModeVisitor {
333333
args->insert(args->end(), sids.begin(), sids.end());
334334
}
335335

336+
/*
337+
* Wraps a call_extern with a tvm_check_return annotation if required otherwise
338+
* returns the passed Call
339+
*/
340+
tir::Call AddCheckReturn(tir::Call existing_call) {
341+
if (use_unpacked_api_) {
342+
Array<PrimExpr> args = {ConstInt32(0), ConstInt32(-1), existing_call};
343+
return tir::Call(DataType::Int(32), tir::builtin::tvm_check_return(), args);
344+
}
345+
346+
return existing_call;
347+
}
348+
336349
/*!
337350
* brief Create a function call
338351
* \param call_lowered_props The lowered function and the arguments to call it with
@@ -343,6 +356,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
343356
std::string func_name = call_lowered_props.lowered_func->name_hint;
344357
tvm::Array<PrimExpr> args{tvm::tir::StringImm(func_name)};
345358
std::vector<tir::Stmt> create_func_call_stmts;
359+
346360
// Pack the inputs
347361
for (const Expr& arg : call_lowered_props.arguments) {
348362
if (params_by_expr_.find(arg) != params_by_expr_.end()) {
@@ -394,7 +408,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
394408
tir::Var context = device_contexts_.Get(global_var).value();
395409
args.push_back(context);
396410

397-
tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args));
411+
tir::Evaluate func_call(
412+
AddCheckReturn(tvm::tir::Call(DataType::Int(32), calling_pattern, args)));
398413
create_func_call_stmts.push_back(tir::SeqStmt({
399414
GenerateDeviceHook(context, "Open"),
400415
func_call,
@@ -407,7 +422,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
407422
create_func_call_stmts.push_back(func_call);
408423
} else {
409424
// call_extern calling convention without context
410-
tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args));
425+
tir::Evaluate func_call(
426+
AddCheckReturn(tvm::tir::Call(DataType::Int(32), calling_pattern, args)));
411427
create_func_call_stmts.push_back(func_call);
412428
}
413429

@@ -482,8 +498,9 @@ class AOTExecutorCodegen : public MixedModeVisitor {
482498
Array<String> sections = {"Device", device_name, hook};
483499
String device_hook_name = ToCFunctionStyle(PrefixName(sections));
484500

485-
tir::Evaluate device_hook(tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(),
486-
{tvm::tir::StringImm(device_hook_name), context}));
501+
tir::Evaluate device_hook(
502+
AddCheckReturn(tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(),
503+
{tvm::tir::StringImm(device_hook_name), context})));
487504
device_hooks.push_back(device_hook);
488505
}
489506
return tir::SeqStmt(device_hooks);
@@ -503,8 +520,9 @@ class AOTExecutorCodegen : public MixedModeVisitor {
503520
Array<String> sections = {"Device", device_name, hook};
504521
String device_hook = ToCFunctionStyle(PrefixName(sections));
505522

506-
return tir::Evaluate(tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(),
507-
{tvm::tir::StringImm(device_hook), context}));
523+
return tir::Evaluate(
524+
AddCheckReturn(tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(),
525+
{tvm::tir::StringImm(device_hook), context})));
508526
}
509527

510528
/*!

src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ int32_t TVMEthosULaunch(tvm_device_ethos_u_t* context, void* cms_data, size_t cm
3434
return 0;
3535
}
3636

37-
int32_t TVMDeviceEthosUActivate(tvm_device_ethos_u_t* context) {}
38-
int32_t TVMDeviceEthosUOpen(tvm_device_ethos_u_t* context) {}
39-
int32_t TVMDeviceEthosUClose(tvm_device_ethos_u_t* context) {}
40-
int32_t TVMDeviceEthosUDeactivate(tvm_device_ethos_u_t* context) {}
37+
int32_t TVMDeviceEthosUActivate(tvm_device_ethos_u_t* context) { return 0; }
38+
int32_t TVMDeviceEthosUOpen(tvm_device_ethos_u_t* context) { return 0; }
39+
int32_t TVMDeviceEthosUClose(tvm_device_ethos_u_t* context) { return 0; }
40+
int32_t TVMDeviceEthosUDeactivate(tvm_device_ethos_u_t* context) { return 0; }

src/target/source/codegen_c.cc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,15 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
528528
if (auto* ptr_op = op->op.as<OpNode>()) {
529529
auto call_op = GetRef<Op>(ptr_op);
530530

531-
if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) {
531+
if (op->op.same_as(builtin::tvm_check_return())) {
532+
const CallNode* call = op->args[2].as<CallNode>();
533+
os << "if (";
534+
VisitExpr_(call, os);
535+
os << " != ";
536+
PrintExpr(op->args[0], os);
537+
os << " ) return ";
538+
PrintExpr(op->args[1], os);
539+
} else if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) {
532540
ICHECK_GE(op->args.size(), 1U);
533541
auto func = Downcast<StringImm>(op->args[0]);
534542
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), func->value, op->args, true, os);
@@ -971,7 +979,7 @@ void CodeGenC::VisitStmt_(const EvaluateNode* op) {
971979
std::string vid = this->PrintExpr(op->value);
972980
if (vid != "") {
973981
this->PrintIndent();
974-
this->stream << "(void)" << vid << ";\n";
982+
this->stream << vid << ";\n";
975983
}
976984
}
977985

src/tir/op/builtin.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,10 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_call_cpacked)
184184
TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed)
185185
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
186186

187+
TIR_DEFINE_BUILTIN_FUNC(tvm_check_return)
188+
.set_num_inputs(3)
189+
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
190+
187191
TIR_DEFINE_BUILTIN_FUNC(tvm_thread_context)
188192
.set_num_inputs(1)
189193
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

tests/python/relay/aot/test_c_device_api.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -137,27 +137,37 @@ def test_device_api_hooks_unpacked_api(device_api_main_func):
137137
# Activate Device
138138
assert (
139139
str(main_func.body[0])
140-
== "tir.call_extern(" + '"TVMDeviceEthosUActivate",' + " device_context_ethos_u)\n"
140+
== "tir.tvm_check_return(0, -1, tir.call_extern("
141+
+ '"TVMDeviceEthosUActivate",'
142+
+ " device_context_ethos_u))\n"
141143
)
142144
# Open Device
143145
assert (
144146
str(main_func.body[1][0][0][0])
145-
== "tir.call_extern(" + '"TVMDeviceEthosUOpen",' + " device_context_ethos_u)\n"
147+
== "tir.tvm_check_return(0, -1, tir.call_extern("
148+
+ '"TVMDeviceEthosUOpen",'
149+
+ " device_context_ethos_u))\n"
146150
)
147151
# Device Call
148152
assert (
149153
str(main_func.body[1][0][0][1])
150-
== 'tir.call_extern("tvmgen_default_ethos_u_main_0", x_int8_buffer_var, output_buffer_var, device_context_ethos_u)\n'
154+
== "tir.tvm_check_return(0, -1, tir.call_extern("
155+
+ '"tvmgen_default_ethos_u_main_0",'
156+
+ " x_int8_buffer_var, output_buffer_var, device_context_ethos_u))\n"
151157
)
152158
# Close Device
153159
assert (
154160
str(main_func.body[1][0][0][2])
155-
== "tir.call_extern(" + '"TVMDeviceEthosUClose",' + " device_context_ethos_u)\n"
161+
== "tir.tvm_check_return(0, -1, tir.call_extern("
162+
+ '"TVMDeviceEthosUClose",'
163+
+ " device_context_ethos_u))\n"
156164
)
157165
# Deactivate Device
158166
assert (
159167
str(str(main_func.body[2]))
160-
== "tir.call_extern(" + '"TVMDeviceEthosUDeactivate",' + " device_context_ethos_u)\n"
168+
== "tir.tvm_check_return(0, -1, tir.call_extern("
169+
+ '"TVMDeviceEthosUDeactivate",'
170+
+ " device_context_ethos_u))\n"
161171
)
162172

163173

@@ -171,18 +181,18 @@ def test_device_api_hooks_packed_api(device_api_main_func):
171181
# Activate Device
172182
assert (
173183
str(main_func.body[0][0].value)
174-
== "@tir.call_extern("
184+
== "@tir.tvm_check_return(0, -1, tir.call_extern("
175185
+ '"TVMDeviceEthosUActivate",'
176186
+ " device_context_ethos_u: handle,"
177-
+ " dtype=int32)"
187+
+ " dtype=int32))"
178188
)
179189
# Open Device
180190
assert (
181191
str(main_func.body[1].body.body[0][0][0].value)
182-
== "@tir.call_extern("
192+
== "@tir.tvm_check_return(0, -1, tir.call_extern("
183193
+ '"TVMDeviceEthosUOpen",'
184194
+ " device_context_ethos_u: handle,"
185-
+ " dtype=int32)"
195+
+ " dtype=int32))"
186196
)
187197
# Device Call
188198
assert (
@@ -196,18 +206,18 @@ def test_device_api_hooks_packed_api(device_api_main_func):
196206
# Close Device
197207
assert (
198208
str(main_func.body[1].body.body[0][0][2].value)
199-
== "@tir.call_extern("
209+
== "@tir.tvm_check_return(0, -1, tir.call_extern("
200210
+ '"TVMDeviceEthosUClose",'
201211
+ " device_context_ethos_u: handle,"
202-
+ " dtype=int32)"
212+
+ " dtype=int32))"
203213
)
204214
# Deactivate Device
205215
assert (
206216
str(main_func.body[2][0].value)
207-
== "@tir.call_extern("
217+
== "@tir.tvm_check_return(0, -1, tir.call_extern("
208218
+ '"TVMDeviceEthosUDeactivate",'
209219
+ " device_context_ethos_u: handle,"
210-
+ " dtype=int32)"
220+
+ " dtype=int32))"
211221
)
212222

213223

@@ -217,7 +227,9 @@ def test_without_device_api_unpacked_api(non_device_api_main_func):
217227
main_func = non_device_api_main_func(interface_api="c", use_unpacked_api=True)
218228
assert (
219229
str(main_func.body)
220-
== 'tir.call_extern("tvmgen_default_fused_multiply", x_buffer_var, y_buffer_var, output_buffer_var)\n'
230+
== "tir.tvm_check_return(0, -1, tir.call_extern("
231+
+ '"tvmgen_default_fused_multiply",'
232+
+ " x_buffer_var, y_buffer_var, output_buffer_var))\n"
221233
)
222234

223235

tests/python/relay/aot/test_crt_aot.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -920,5 +920,36 @@ def test_workspace_calculation_cmsis_nn():
920920
assert mlf_memory_map["main"][0]["workspace_size_bytes"] == 9904
921921

922922

923+
def test_aot_codegen_checks_returns():
924+
"""This test checks whether AoT lowering creates calls that check the return value correctly"""
925+
x = relay.var("x", shape=(1, 10))
926+
y = relay.var("y", shape=(1, 10))
927+
z = relay.add(x, y)
928+
func = relay.Function([x, y], z)
929+
930+
compiled_test_mods = compile_models(
931+
models=AOTTestModel(module=IRModule.from_expr(func), inputs=None, outputs=None),
932+
interface_api="c",
933+
use_unpacked_api=True,
934+
)
935+
source = compiled_test_mods[0].executor_factory.lib.imported_modules[0].get_source()
936+
937+
main_ir_module = compiled_test_mods[0].executor_factory.lowered_ir_mods.items()[0][1]
938+
main_func = main_ir_module["__tvm_main__"]
939+
940+
# Check operator call is wrapped properly
941+
assert (
942+
str(main_func.body[1])
943+
== "tir.tvm_check_return(0, -1, tir.call_extern("
944+
+ '"tvmgen_default_fused_add",'
945+
+ " x_buffer_var, y_buffer_var, output_buffer_var))\n"
946+
)
947+
# TODO(Mousius) - Create a better place for C codegen tests
948+
assert (
949+
"if (tvmgen_default_fused_add(x_buffer_var, y_buffer_var, output_buffer_var) != 0 ) return -1;"
950+
in source
951+
)
952+
953+
923954
if __name__ == "__main__":
924955
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 commit comments

Comments
 (0)