@@ -476,30 +476,39 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
476476 argument_registers.push_back (reg->second );
477477 }
478478
479- // Next generate the invoke instruction.
480479 Target target;
481- if (targets_.size () == 1 ) {
482- // homogeneous execution.
483- for (auto kv : targets_) {
484- target = kv.second ;
485- }
480+
481+ if (!func->UseDefaultCompiler ()) {
482+ target = tvm::target::ext_dev ();
486483 } else {
487- // heterogeneous execution.
488- LOG (FATAL) << " Currently VM compiler doesn't support heterogeneous compilation" ;
484+ // Next generate the invoke instruction.
485+ if (targets_.size () == 1 ) {
486+ // homogeneous execution.
487+ const auto & it = targets_.begin ();
488+ target = (*it).second ;
489+ } else {
490+ // heterogeneous execution.
491+ LOG (FATAL) << " Currently VM compiler doesn't support heterogeneous compilation" ;
492+ }
489493 }
490494
491495 auto key = CCacheKeyNode::make (func, target);
492496 auto cfunc = engine_->Lower (key);
493497
494- // TODO(jroesch): support lowered funcs for multiple targets
495- CHECK_EQ (cfunc->funcs .size (), 1 );
496498 auto op_index = -1 ;
497- if (context_-> seen_funcs . find (cfunc-> funcs [ 0 ]) == context_-> seen_funcs . end ()) {
499+ if (!func-> UseDefaultCompiler ()) {
498500 op_index = context_->cached_funcs .size ();
499501 context_->cached_funcs .push_back (cfunc);
500- context_->seen_funcs [cfunc->funcs [0 ]] = op_index;
501502 } else {
502- op_index = context_->seen_funcs [cfunc->funcs [0 ]];
503+ // TODO(jroesch): support lowered funcs for multiple targets
504+ CHECK_EQ (cfunc->funcs .size (), 1 );
505+ if (context_->seen_funcs .find (cfunc->funcs [0 ]) == context_->seen_funcs .end ()) {
506+ op_index = context_->cached_funcs .size ();
507+ context_->cached_funcs .push_back (cfunc);
508+ context_->seen_funcs [cfunc->funcs [0 ]] = op_index;
509+ } else {
510+ op_index = context_->seen_funcs [cfunc->funcs [0 ]];
511+ }
503512 }
504513
505514 Emit (Instruction::InvokePacked (op_index,
@@ -950,32 +959,46 @@ void VMCompiler::LibraryCodegen() {
950959 if (cached_funcs.size () == 0 ) {
951960 return ;
952961 }
953- std::unordered_map<std::string, Array<LoweredFunc>> tgt_funcs ;
954- for (auto & cfunc : cached_funcs) {
962+ std::unordered_map<std::string, Array<LoweredFunc>> funcs ;
963+ for (auto & cfunc : cached_funcs) {
955964 std::string target_str = cfunc->target ->str ();
956- if (tgt_funcs.count (target_str) == 0 ) {
957- tgt_funcs.emplace (target_str, Array<LoweredFunc>{cfunc->funcs [0 ]});
965+ if (target_str == " ext_dev" ) {
966+ continue ;
967+ } else if (funcs.count (target_str) == 0 ) {
968+ funcs.emplace (target_str, Array<LoweredFunc>{cfunc->funcs [0 ]});
958969 } else {
959- tgt_funcs [target_str].push_back (cfunc->funcs [0 ]);
970+ funcs [target_str].push_back (cfunc->funcs [0 ]);
960971 }
961972 }
962- Map<Target, Array<LoweredFunc>> funcs;
963- for (auto &it : tgt_funcs) {
964- funcs.Set (Target::Create (it.first ), it.second );
965- }
966973
967- if (const auto *f = runtime::Registry::Get (" relay.backend.build" )) {
968- // The target is just a dummy arg because funcs already contains corresponding target
969- // therefore target won't be used in the build function
970- runtime::Module mod = (*f)(funcs, Target (), target_host_);
974+ auto compile_engine = CompileEngine::Global ();
975+ auto ext_mods = compile_engine->LowerExternalFunctions ();
976+ runtime::Module mod;
977+ if (funcs.size () > 0 ) {
978+ mod = tvm::build (funcs, target_host_, tvm::BuildConfig::Current ());
971979 CHECK (mod.operator ->());
972- exec_->lib = mod;
973980 } else {
974- LOG (FATAL) << " relay.backend.build is not registered" ;
981+ CHECK_EQ (ext_mods.size (), 1U )
982+ << " Expect to have a TVM DSOModule when multiple runtime modules exist" ;
983+ }
984+ if (!ext_mods.empty ()) {
985+ if (funcs.size () == 0 ) {
986+ mod = ext_mods[0 ];
987+ } else {
988+ // Import all external runtime modules.
989+ for (auto it : ext_mods) {
990+ mod.Import (it);
991+ }
992+ }
975993 }
994+ exec_->lib = mod;
976995 size_t primitive_index = 0 ;
977996 for (auto cfunc : cached_funcs) {
978- exec_->primitive_map .insert ({cfunc->funcs [0 ]->name , primitive_index++});
997+ if (cfunc->target ->str () == " ext_dev" ) {
998+ exec_->primitive_map .insert ({cfunc->func_name , primitive_index++});
999+ } else {
1000+ exec_->primitive_map .insert ({cfunc->funcs [0 ]->name , primitive_index++});
1001+ }
9791002 }
9801003}
9811004
0 commit comments