Skip to content

Commit 40b964d

Browse files
committed
Refactor typed-operators lookup into use_typed_operators_
(Also contains minor clean up of output variables)
1 parent 8aa85be commit 40b964d

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

src/relay/backend/aot_executor_codegen.cc

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ class AOTExecutorCodegen : public ExprVisitor {
138138
auto sid_array = te::Var(MakeString("sid_", sid, "_value"), DataType::Handle());
139139
auto sid_value = sids_table_[sid];
140140

141-
if (target_host_->GetAttr<Bool>("typed-operators").value_or(Bool(true))) {
141+
if (use_typed_operators_) {
142142
tvm::PrimExpr set_tensor =
143143
tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
144144
{sid_array, 0, tir::builtin::kArrData, sid_value});
@@ -168,7 +168,7 @@ class AOTExecutorCodegen : public ExprVisitor {
168168
auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(),
169169
{tir::StringImm(params_by_expr_[expr])});
170170

171-
if (target_host_->GetAttr<Bool>("typed-operators").value_or(Bool(true))) {
171+
if (use_typed_operators_) {
172172
tvm::PrimExpr set_param_array =
173173
tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
174174
{param_array, 0, tir::builtin::kArrData, param_handle});
@@ -220,7 +220,7 @@ class AOTExecutorCodegen : public ExprVisitor {
220220

221221
// Use tvm_call_packed to execute the function unless we're calling directly
222222
auto calling_pattern = tvm::tir::builtin::tvm_call_cpacked();
223-
if (!target_host_->GetAttr<Bool>("typed-operators").value_or(Bool(true))) {
223+
if (!use_typed_operators_) {
224224
calling_pattern = tvm::tir::builtin::call_extern();
225225
}
226226

@@ -248,7 +248,7 @@ class AOTExecutorCodegen : public ExprVisitor {
248248
{in, 0, tir::builtin::kArrData});
249249
PrimExpr tostore = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_get(),
250250
{out, 0, tir::builtin::kArrData});
251-
if (!target_host_->GetAttr<Bool>("typed-operators").value_or(Bool(true))) {
251+
if (!use_typed_operators_) {
252252
retval_get = in;
253253
tostore = out;
254254
}
@@ -551,6 +551,8 @@ class AOTExecutorCodegen : public ExprVisitor {
551551
TargetsMap targets_;
552552
/*! \brief target host */
553553
Target target_host_;
554+
/*! \brief untyped operators flag */
555+
Bool use_typed_operators_;
554556

555557
/*!
556558
* \brief parameters (i.e. ConstantNodes found in the graph).
@@ -580,10 +582,11 @@ class AOTExecutorCodegen : public ExprVisitor {
580582

581583
public:
582584
AOTExecutorCodegen(runtime::Module* mod, const TargetsMap& targets, Target target_host)
583-
: mod_(mod), return_sid_() {
585+
: mod_(mod), use_typed_operators_(true) {
584586
compile_engine_ = CompileEngine::Global();
585587
targets_ = targets;
586588
target_host_ = target_host;
589+
use_typed_operators_ = target_host->GetAttr<Bool>("typed-operators").value_or(Bool(true));
587590
}
588591

589592
LoweredOutput Codegen(relay::Function func) {
@@ -607,8 +610,7 @@ class AOTExecutorCodegen : public ExprVisitor {
607610
// Find the return sid
608611
return_sid_ = AotReturnSidVisitor(storage_device_map_).FindReturnSid(func);
609612
for (unsigned int output_index = 0; output_index < return_sid_.size(); output_index++) {
610-
auto output_var = tir::Var("output", DataType::Handle());
611-
main_signature_.push_back(output_var);
613+
main_signature_.push_back(tir::Var("output", DataType::Handle()));
612614
}
613615

614616
VisitExpr(func->body);

0 commit comments

Comments
 (0)