@@ -138,7 +138,7 @@ class AOTExecutorCodegen : public ExprVisitor {
138
138
auto sid_array = te::Var (MakeString (" sid_" , sid, " _value" ), DataType::Handle ());
139
139
auto sid_value = sids_table_[sid];
140
140
141
- if (target_host_-> GetAttr <Bool>( " typed-operators " ). value_or ( Bool ( true )) ) {
141
+ if (use_typed_operators_ ) {
142
142
tvm::PrimExpr set_tensor =
143
143
tvm::tir::Call (DataType::Handle (), tvm::tir::builtin::tvm_struct_set (),
144
144
{sid_array, 0 , tir::builtin::kArrData , sid_value});
@@ -168,7 +168,7 @@ class AOTExecutorCodegen : public ExprVisitor {
168
168
auto param_handle = tvm::tir::Call (DataType::Handle (), tvm::tir::builtin::lookup_param (),
169
169
{tir::StringImm (params_by_expr_[expr])});
170
170
171
- if (target_host_-> GetAttr <Bool>( " typed-operators " ). value_or ( Bool ( true )) ) {
171
+ if (use_typed_operators_ ) {
172
172
tvm::PrimExpr set_param_array =
173
173
tvm::tir::Call (DataType::Handle (), tvm::tir::builtin::tvm_struct_set (),
174
174
{param_array, 0 , tir::builtin::kArrData , param_handle});
@@ -220,7 +220,7 @@ class AOTExecutorCodegen : public ExprVisitor {
220
220
221
221
// Use tvm_call_packed to execute the function unless we're calling directly
222
222
auto calling_pattern = tvm::tir::builtin::tvm_call_cpacked ();
223
- if (!target_host_-> GetAttr <Bool>( " typed-operators " ). value_or ( Bool ( true )) ) {
223
+ if (!use_typed_operators_ ) {
224
224
calling_pattern = tvm::tir::builtin::call_extern ();
225
225
}
226
226
@@ -248,7 +248,7 @@ class AOTExecutorCodegen : public ExprVisitor {
248
248
{in, 0 , tir::builtin::kArrData });
249
249
PrimExpr tostore = tvm::tir::Call (DataType::Handle (), tvm::tir::builtin::tvm_struct_get (),
250
250
{out, 0 , tir::builtin::kArrData });
251
- if (!target_host_-> GetAttr <Bool>( " typed-operators " ). value_or ( Bool ( true )) ) {
251
+ if (!use_typed_operators_ ) {
252
252
retval_get = in;
253
253
tostore = out;
254
254
}
@@ -551,6 +551,8 @@ class AOTExecutorCodegen : public ExprVisitor {
551
551
TargetsMap targets_;
552
552
/* ! \brief target host */
553
553
Target target_host_;
554
+ /* ! \brief untyped operators flag */
555
+ Bool use_typed_operators_;
554
556
555
557
/* !
556
558
* \brief parameters (i.e. ConstantNodes found in the graph).
@@ -580,10 +582,11 @@ class AOTExecutorCodegen : public ExprVisitor {
580
582
581
583
public:
582
584
AOTExecutorCodegen (runtime::Module* mod, const TargetsMap& targets, Target target_host)
583
- : mod_(mod), return_sid_( ) {
585
+ : mod_(mod), use_typed_operators_( true ) {
584
586
compile_engine_ = CompileEngine::Global ();
585
587
targets_ = targets;
586
588
target_host_ = target_host;
589
+ use_typed_operators_ = target_host->GetAttr <Bool>(" typed-operators" ).value_or (Bool (true ));
587
590
}
588
591
589
592
LoweredOutput Codegen (relay::Function func) {
@@ -607,8 +610,7 @@ class AOTExecutorCodegen : public ExprVisitor {
607
610
// Find the return sid
608
611
return_sid_ = AotReturnSidVisitor (storage_device_map_).FindReturnSid (func);
609
612
for (unsigned int output_index = 0 ; output_index < return_sid_.size (); output_index++) {
610
- auto output_var = tir::Var (" output" , DataType::Handle ());
611
- main_signature_.push_back (output_var);
613
+ main_signature_.push_back (tir::Var (" output" , DataType::Handle ()));
612
614
}
613
615
614
616
VisitExpr (func->body );
0 commit comments