Skip to content

Commit a769ece

Browse files
authored
[AOT] Initial implementation of --unpacked-api (#8023)
* [AOT] Initial implementation of --no-typed-operators Based on the discussions in the AOT embedded improvements RFC, this adds a flag to the target which changes the internal operators to an unpacked API. The unpacked API spreads the input buffers across the operator function, for example: int32_t operator(void* arg0, void* arg1); As opposed to the traditional packed API: int32_t operator(void** args); Uneffected is the entrypoint function, which retains a packed API for compatibility with other parts of TVM. This is done by changing the passes taken by none entrypoint (CallingConv::kEntryPoint) functions. * Move entrypoint generation outside of main passes This removes the logic for deciding the entrypoint from the compiler passes and instead moves it into the metadata code generation. By moving the generation, we can generate a variety of entrypoints on top of the compiler output (such as the micro entrypoint discussed in the RFC). * Use buffers in make_unpacked_api tests * Enable --no-typed-operators for llvm * Change --no-typed-operators to --typed-operators=0 to match other options * Refactor typed-operators lookup into use_typed_operators_ (Also contains minor clean up of output variables) * Rename --typed-operators to --unpacked-api (Also moves the entrypoint name to a constant) * Move all properties into init list to avoid double init * Remove AutoTVM breaking default and improve clarity
1 parent 3e34e11 commit a769ece

File tree

10 files changed

+478
-63
lines changed

10 files changed

+478
-63
lines changed

include/tvm/tir/transform.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,17 @@ TVM_DLL Pass InstrumentBoundCheckers();
212212
*/
213213
TVM_DLL Pass MakePackedAPI(int num_unpacked_args);
214214

215+
/*!
216+
* \brief Transform the high-level PrimFunc to a C signature that can be used
217+
* to call the operator directly.
218+
*
219+
* The main task of this function is to create code that maps the values in the
220+
* api_args to Var that is required by body
221+
*
222+
* \return The pass.
223+
*/
224+
TVM_DLL Pass MakeUnpackedAPI();
225+
215226
/*!
216227
* \brief Remap the thread axis
217228
*

python/tvm/tir/transform/transform.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,17 @@ def MakePackedAPI(num_unpacked_params=0):
347347
return _ffi_api.MakePackedAPI(num_unpacked_params)
348348

349349

350+
def MakeUnpackedAPI():
351+
"""Transform the PrimFuncs in the module to a C API compatible with internal calls.
352+
353+
Returns
354+
-------
355+
fpass : tvm.transform.Pass
356+
The result pass
357+
"""
358+
return _ffi_api.MakeUnpackedAPI()
359+
360+
350361
def SplitHostDevice():
351362
"""Split the function into a host function and device functions.
352363

src/driver/driver_api.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,15 @@ std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const Target
200200
mixed_pass_list.push_back(tir::transform::ThreadSync("warp"));
201201
mixed_pass_list.push_back(tir::transform::InferFragment());
202202
mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());
203-
mixed_pass_list.push_back(tir::transform::MakePackedAPI(0));
203+
204+
if (target->GetAttr<Bool>("unpacked-api").value_or(Bool(false))) {
205+
mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI());
206+
} else {
207+
mixed_pass_list.push_back(tir::transform::MakePackedAPI(0));
208+
}
209+
204210
mixed_pass_list.push_back(tir::transform::SplitHostDevice());
211+
205212
auto opt_mixed = transform::Sequential(mixed_pass_list);
206213
mod_mixed = opt_mixed(std::move(mod_mixed));
207214

src/relay/backend/aot_executor_codegen.cc

Lines changed: 54 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,17 @@ class AOTExecutorCodegen : public ExprVisitor {
137137
// Pack the sid inside the TVMValue
138138
auto sid_array = te::Var(MakeString("sid_", sid, "_value"), DataType::Handle());
139139
auto sid_value = sids_table_[sid];
140-
tvm::PrimExpr set_tensor =
141-
tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
142-
{sid_array, 0, tir::builtin::kArrData, sid_value});
143-
stmts_.push_back(tir::LetStmt(sid_array, StackAlloca("array", 1), tir::Evaluate(set_tensor)));
140+
141+
if (!use_unpacked_api_) {
142+
tvm::PrimExpr set_tensor =
143+
tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
144+
{sid_array, 0, tir::builtin::kArrData, sid_value});
145+
stmts_.push_back(
146+
tir::LetStmt(sid_array, StackAlloca("array", 1), tir::Evaluate(set_tensor)));
147+
} else {
148+
stmts_.push_back(tir::LetStmt(sid_array, sid_value, tir::Evaluate(0)));
149+
}
150+
144151
sid_vars.push_back(sid_array);
145152
}
146153
return sid_vars;
@@ -161,16 +168,16 @@ class AOTExecutorCodegen : public ExprVisitor {
161168
auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(),
162169
{tir::StringImm(params_by_expr_[expr])});
163170

164-
tvm::PrimExpr set_param_array =
165-
tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
166-
{param_array, 0, tir::builtin::kArrData, param_handle});
167-
lookup_call.push_back(tir::Evaluate(set_param_array));
168-
169-
tir::Stmt lookup_body = tir::SeqStmt(lookup_call);
171+
if (!use_unpacked_api_) {
172+
tvm::PrimExpr set_param_array =
173+
tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
174+
{param_array, 0, tir::builtin::kArrData, param_handle});
175+
stmts_.push_back(
176+
tir::LetStmt(param_array, StackAlloca("arg_value", 1), tir::Evaluate(set_param_array)));
177+
} else {
178+
stmts_.push_back(tir::LetStmt(param_array, param_handle, tir::Evaluate(0)));
179+
}
170180

171-
// Allocate the DLTensors on the stack
172-
lookup_body = tir::LetStmt(param_array, StackAlloca("arg_value", 1), lookup_body);
173-
stmts_.push_back(lookup_body);
174181
return param_array;
175182
}
176183

@@ -206,15 +213,20 @@ class AOTExecutorCodegen : public ExprVisitor {
206213
}
207214

208215
auto ret_expr = Downcast<Expr>(call);
209-
210216
// Pack the return(s) value. A call node can produce multiple outputs
211217
for (const auto& var : PackSid(ret_expr)) {
212218
args.push_back(var);
213219
}
214220

215-
// Use tvm_call_packed to execute the function
216-
create_func_call_stmts.push_back(tir::Evaluate(
217-
tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_cpacked(), args)));
221+
// Use tvm_call_packed to execute the function unless we're calling directly
222+
auto calling_pattern = tvm::tir::builtin::tvm_call_cpacked();
223+
if (use_unpacked_api_) {
224+
calling_pattern = tvm::tir::builtin::call_extern();
225+
}
226+
227+
create_func_call_stmts.push_back(
228+
tir::Evaluate(tvm::tir::Call(DataType::Int(32), calling_pattern, args)));
229+
218230
tir::Stmt body = tir::SeqStmt(create_func_call_stmts);
219231
stmts_.push_back(body);
220232
}
@@ -226,16 +238,20 @@ class AOTExecutorCodegen : public ExprVisitor {
226238
* copy-on-write fashion.
227239
*/
228240
void CopyToOutput(te::Var out, te::Var in, size_t size) {
229-
auto retval_get = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_get(),
230-
{in, 0, tir::builtin::kArrData});
231-
232241
// Define intermediate DLTensor to load/store the data
233242
auto tmp0 = te::Var("tmp0", DataType::Handle());
234243
auto tmp1 = te::Var("tmp1", DataType::Handle());
235244
te::Var loop_idx("i", DataType::Int(32));
236245
auto retval_i = tir::Load(DataType::UInt(8), tmp0, loop_idx, tir::const_true());
237-
auto tostore = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_get(),
238-
{out, 0, tir::builtin::kArrData});
246+
247+
PrimExpr retval_get = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_get(),
248+
{in, 0, tir::builtin::kArrData});
249+
PrimExpr tostore = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_get(),
250+
{out, 0, tir::builtin::kArrData});
251+
if (use_unpacked_api_) {
252+
retval_get = in;
253+
tostore = out;
254+
}
239255

240256
// Copy the variable from the input to the output
241257
tir::Stmt copy = tir::For(
@@ -535,6 +551,15 @@ class AOTExecutorCodegen : public ExprVisitor {
535551
TargetsMap targets_;
536552
/*! \brief target host */
537553
Target target_host_;
554+
/*!
555+
* \brief unpacked api toggle
556+
* When set to true the code generated will use unpacked calls to functions:
557+
* func(void* arg0, void* arg1)
558+
* Rather than packed calls:
559+
* func(void* args)
560+
* Defaults to using the packed calling convention
561+
*/
562+
Bool use_unpacked_api_;
538563

539564
/*!
540565
* \brief parameters (i.e. ConstantNodes found in the graph).
@@ -564,21 +589,20 @@ class AOTExecutorCodegen : public ExprVisitor {
564589

565590
public:
566591
AOTExecutorCodegen(runtime::Module* mod, const TargetsMap& targets, Target target_host)
567-
: mod_(mod), return_sid_() {
568-
compile_engine_ = CompileEngine::Global();
569-
targets_ = targets;
570-
target_host_ = target_host;
571-
}
592+
: mod_(mod),
593+
targets_(targets),
594+
target_host_(target_host),
595+
use_unpacked_api_(target_host->GetAttr<Bool>("unpacked-api").value_or(Bool(false))),
596+
compile_engine_(CompileEngine::Global()) {}
572597

573598
LoweredOutput Codegen(relay::Function func) {
574599
// Get the module, storage map and token sizes
575600
auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
576601
storage_device_map_ = (*pf)(func);
577602

578-
int input_index = 0;
579603
for (auto input : func->params) {
580604
input_vars_.push_back(input);
581-
main_signature_.push_back(tir::Var(MakeString("input_", input_index), DataType::Handle()));
605+
main_signature_.push_back(tir::Var("input", DataType::Handle()));
582606
}
583607

584608
// Define the storage allocator ids
@@ -592,7 +616,7 @@ class AOTExecutorCodegen : public ExprVisitor {
592616
// Find the return sid
593617
return_sid_ = AotReturnSidVisitor(storage_device_map_).FindReturnSid(func);
594618
for (unsigned int output_index = 0; output_index < return_sid_.size(); output_index++) {
595-
main_signature_.push_back(tir::Var(MakeString("output_", output_index), DataType::Handle()));
619+
main_signature_.push_back(tir::Var("output", DataType::Handle()));
596620
}
597621

598622
VisitExpr(func->body);

src/target/source/source_module.cc

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -192,17 +192,59 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
192192
<< "}\n";
193193
}
194194

195+
void GenerateEntrypointForUnpackedAPI() {
196+
code_ << "TVM_DLL int32_t " << ::tvm::runtime::symbol::tvm_run_func_prefix << "(";
197+
int total_args = (metadata_->num_inputs + metadata_->num_outputs);
198+
for (int i = 0; i < total_args; ++i) {
199+
code_ << "arg" << i;
200+
if (i + 1 != total_args) {
201+
code_ << ",";
202+
}
203+
}
204+
code_ << ");\n";
205+
code_ << "static int32_t " << ::tvm::runtime::symbol::tvm_module_main;
206+
code_ << "(void* args, void* type_code, int num_args, void* out_value, void* "
207+
"out_type_code, void* resource_handle) {\n";
208+
code_ << "return " << ::tvm::runtime::symbol::tvm_run_func_prefix << "(";
209+
for (int i = 0; i < metadata_->num_inputs; ++i) {
210+
code_ << "((DLTensor*)(((TVMValue*)args)[" << i << "].v_handle))[0].data,";
211+
}
212+
for (int i = 0; i < metadata_->num_outputs; ++i) {
213+
int j = metadata_->num_inputs + i;
214+
code_ << "((DLTensor*)(((TVMValue*)args)[" << j << "].v_handle))[0].data";
215+
if (i + 1 != metadata_->num_outputs) {
216+
code_ << ",";
217+
}
218+
}
219+
code_ << ");\n";
220+
code_ << "}\n";
221+
}
222+
223+
void GenerateEntrypointForPackedAPI() {
224+
code_ << "TVM_DLL int32_t " << ::tvm::runtime::symbol::tvm_run_func_prefix;
225+
code_ << "(void* args, void* type_code, int num_args, void* out_value, void* "
226+
"out_type_code, void* resource_handle);\n";
227+
code_ << "static int32_t " << ::tvm::runtime::symbol::tvm_module_main;
228+
code_ << "(void* args, void* type_code, int num_args, void* out_value, void* "
229+
"out_type_code, void* resource_handle) {\n";
230+
code_ << "return " << ::tvm::runtime::symbol::tvm_run_func_prefix;
231+
code_ << "(args, type_code, num_args, out_value, out_type_code, resource_handle);\n";
232+
code_ << "}\n";
233+
}
234+
195235
void GenerateAOTDescriptor() {
196236
code_ << "#include \"tvm/runtime/crt/internal/aot_executor/aot_executor.h\"\n";
197237
code_ << "#include \"tvm/runtime/c_runtime_api.h\"\n";
198238
code_ << "#ifdef __cplusplus\n";
199239
code_ << "extern \"C\"\n";
200240
code_ << "#endif\n";
201-
code_ << "TVM_DLL int32_t " << ::tvm::runtime::symbol::tvm_run_func_prefix;
202-
code_ << "(void* args, void* type_code, int num_args, void* out_value, void* "
203-
"out_type_code, void* resource_handle);\n";
241+
if (target_->GetAttr<Bool>("unpacked-api").value_or(Bool(false))) {
242+
GenerateEntrypointForUnpackedAPI();
243+
} else {
244+
GenerateEntrypointForPackedAPI();
245+
}
204246
code_ << "const tvm_model_t network = {\n"
205-
<< " .run_func = &" << ::tvm::runtime::symbol::tvm_run_func_prefix << ",\n"
247+
<< " .run_func = &" << ::tvm::runtime::symbol::tvm_module_main << ",\n"
206248
<< " .num_input_tensors = " << metadata_->num_inputs << ",\n"
207249
<< " .num_output_tensors = " << metadata_->num_outputs << ", \n"
208250
<< "};\n";

src/target/target_kind.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,7 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU)
298298
.add_attr_option<Bool>("system-lib")
299299
.add_attr_option<String>("runtime")
300300
.add_attr_option<Bool>("link-params", Bool(false))
301+
.add_attr_option<Bool>("unpacked-api")
301302
.set_default_keys({"cpu"});
302303

303304
TVM_REGISTER_TARGET_KIND("c", kDLCPU)
@@ -308,6 +309,7 @@ TVM_REGISTER_TARGET_KIND("c", kDLCPU)
308309
.add_attr_option<String>("march")
309310
.add_attr_option<String>("executor")
310311
.add_attr_option<Integer>("workspace-byte-alignment")
312+
.add_attr_option<Bool>("unpacked-api")
311313
.set_default_keys({"cpu"});
312314

313315
TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA)

0 commit comments

Comments
 (0)