@@ -85,33 +85,46 @@ class TECompilerImpl : public TECompilerNode {
8585 return LowerShapeFuncInternal (key)->cached_func ;
8686 }
8787
88- Map<Target, IRModule> GetLoweredFunctions () {
89- std::unordered_map<Target, IRModule, backend::TargetStrHash, backend::TargetStrEqual>
90- lowered_functions;
88+ IRModule GetLoweredFunctions () {
89+ IRModule mod;
90+ // Extract lowered functions from the cache
9191 for (const auto & it : cache_) {
9292 auto source_func = it.first ;
9393 auto lowered_func = it.second ;
94- auto target = source_func->target ;
9594
96- if (!lowered_functions.count (target)) {
97- lowered_functions[target] = IRModule (Map<GlobalVar, BaseFunc>({}));
98- }
95+ IRModule lowered_mod = lowered_func->cached_func ->funcs ;
9996
100- lowered_functions[target]->Update (lowered_func->cached_func ->funcs );
101- }
97+ // Annotate functions with their target and put them in the return module
98+ for (auto kv : lowered_mod->functions ) {
99+ const GlobalVar& var = kv.first ;
100+ const BaseFunc& func = kv.second ;
102101
102+ // Only add functions that are not external functions
103+ if (!func->GetAttr <String>(attr::kCompiler ).defined ()) {
104+ ICHECK (func->IsInstance <tir::PrimFuncNode>())
105+ << " Expected all functions that are not external to be PrimFuncs, but found "
106+ << func->GetTypeKey ();
107+ const tir::PrimFunc& prim_func = Downcast<tir::PrimFunc>(func);
108+ mod->Update (var, WithAttr (prim_func, tvm::attr::kTarget , source_func->target ));
109+ }
110+ }
111+ }
112+ // Extract lowered dynamic shape functions from the shape cache
103113 for (const auto & it : shape_func_cache_) {
104114 auto source_func = it.first ;
105115 auto lowered_func = it.second ;
106116 auto target = source_func->target ;
107-
108- if (!lowered_functions.count (target)) {
109- lowered_functions[target] = IRModule (Map<GlobalVar, BaseFunc>({}));
117+ IRModule lowered_mod = lowered_func->cached_func ->funcs ;
118+
119+ // Annotate functions with their target and put them in the return module
120+ for (auto kv : lowered_mod->functions ) {
121+ const GlobalVar& var = kv.first ;
122+ const BaseFunc& func = kv.second ;
123+ const tir::PrimFunc& prim_func = Downcast<tir::PrimFunc>(func);
124+ mod->Update (var, WithAttr (prim_func, tvm::attr::kTarget , source_func->target ));
110125 }
111-
112- lowered_functions[target]->Update (lowered_func->cached_func ->funcs );
113126 }
114- return backend::TargetStrModuleMapToTargetModuleMap (lowered_functions) ;
127+ return mod ;
115128 }
116129
117130 Array<tvm::runtime::Module> LowerExternalFunctions () {
@@ -830,9 +843,9 @@ void UpdateFunctionMetadata(Function relay_func,
830843 function_metadata.Set (prim_fn_var.value ()->name_hint , fi);
831844}
832845
833- LoweredModule LowerTE (const IRModule& module , TargetMap targets, DeviceMap device_context_map,
834- backend::StaticMemoryPlan memory_plan, const String& module_name,
835- std::function<void (Function)> process_fn) {
846+ IRModule LowerTE (const IRModule& module , TargetMap targets, DeviceMap device_context_map,
847+ backend::StaticMemoryPlan memory_plan, const String& module_name,
848+ std::function<void (Function)> process_fn) {
836849 DLOG (INFO) << " lowering module:\n " << PrettyPrint (module );
837850
838851 TECompiler compiler;
@@ -864,76 +877,23 @@ LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap devic
864877 (*te_compiler_update_weights)(weight_map);
865878 }
866879
867- LoweredModule lowered_module;
868- lowered_module.main_module = updated_module;
869- lowered_module.per_target_module = compiler->GetLoweredFunctions ();
870- lowered_module.external_mods = compiler->LowerExternalFunctions ();
871- lowered_module.main_func_info = func_info;
872- return lowered_module;
873- }
880+ // Copy the lowered functions into the return module
881+ updated_module->Update (compiler->GetLoweredFunctions ());
874882
875- IRModule LoweredModuleToIRModule (LoweredModule mod) {
876- IRModule unified_module;
883+ // Annotate the module with the external modules and function info
884+ updated_module = WithAttr (updated_module, " external_mods" , compiler->LowerExternalFunctions ());
885+ updated_module = WithAttr (updated_module, " main_func_info" , func_info);
877886
878- // Copy the main module and its typedefs
879- for (const auto & kv : mod.main_module ->functions ) {
880- unified_module->Add (kv.first , kv.second );
881- }
882- for (const auto & kv : mod.main_module ->type_definitions ) {
883- unified_module->AddTypeDef (kv.first , kv.second );
884- }
885-
886- // Annotate the per-target functions with their target and add them to the unified module
887- for (const auto & kv : mod.per_target_module ) {
888- const Target target = kv.first ;
889- const IRModule target_module = kv.second ;
890-
891- // Right now, per-target functions are TIR functions, which don't have type definitions, so
892- // there should be no type defs in the per_target_modules
893- size_t ty_def_size = target_module->type_definitions .size ();
894- ICHECK (ty_def_size == 0 )
895- << " Expected there to be no type definitions in the per_target_modules, but found "
896- << ty_def_size;
897-
898- for (const auto & kv : target_module->functions ) {
899- const GlobalVar& var = kv.first ;
900- const BaseFunc& func = kv.second ;
901- if (func->IsInstance <tir::PrimFuncNode>()) {
902- tir::PrimFunc primFunc =
903- WithAttr (Downcast<tir::PrimFunc>(std::move (func)), tvm::attr::kTarget , target);
904- unified_module->Add (var, primFunc);
905- } else if (func->IsInstance <relay::FunctionNode>()) {
906- relay::Function relayFunc =
907- WithAttr (Downcast<relay::Function>(std::move (func)), tvm::attr::kTarget , target);
908- unified_module->Add (var, relayFunc);
909- } else {
910- LOG (FATAL)
911- << " We expected to only have PrimFuncs or RelayFuncs in the target modules, but found "
912- << func->GetTypeKey ();
913- }
914- }
915- }
916-
917- IRModule ret_mod = WithAttr (unified_module, " external_mods" , mod.external_mods );
918- ret_mod = WithAttr (ret_mod, " main_func_info" , mod.main_func_info );
919- return ret_mod;
887+ return updated_module;
920888}
921889
922- LoweredModule IRModuleToLoweredModule (IRModule mod) {
923- IRModule main_mod;
924- // Copy just the TypeDefs from the IRModule to the LoweredModule's main module
925- // This is the only time we need to do this since there are no TypeDefs in TIR
926- for (const auto & kv : mod->type_definitions ) {
927- main_mod->AddTypeDef (kv.first , kv.second );
928- }
929-
930- Map<Target, IRModule> per_target_modules;
890+ Map<Target, IRModule> GetPerTargetModules (IRModule mod) {
891+ std::unordered_map<Target, IRModule, backend::TargetStrHash, backend::TargetStrEqual>
892+ per_target_modules;
931893 for (const auto & kv : mod->functions ) {
932894 const GlobalVar& var = kv.first ;
933895 const BaseFunc& func = kv.second ;
934- if (func->IsInstance <relay::FunctionNode>()) {
935- main_mod->Add (var, func);
936- } else if (func->IsInstance <tir::PrimFuncNode>()) {
896+ if (func->IsInstance <tir::PrimFuncNode>()) {
937897 // Extract target
938898 Optional<Target> target = func->GetAttr <Target>(tvm::attr::kTarget );
939899 ICHECK (target) << " Target should be set at this point" ;
@@ -943,44 +903,47 @@ LoweredModule IRModuleToLoweredModule(IRModule mod) {
943903 // Initialize the IRModule for this target and add the function
944904 IRModule target_module;
945905 target_module->Add (var, func);
946- per_target_modules. Set ( target.value (), target_module) ;
906+ per_target_modules[ target.value ()] = target_module;
947907 } else {
948908 // The IRModule for this target is initialized, so just add the function.
949909 IRModule target_module = per_target_modules.at (target.value ());
950910 target_module->Add (var, func);
951911 }
952- } else {
912+ } else if (!func-> IsInstance <relay::FunctionNode>()) {
953913 LOG (FATAL)
954914 << " The function types in the IRModule should be RelayFunction or PrimFunc, but got "
955915 << func->GetTypeKey ();
956916 }
957917 }
918+ return per_target_modules;
919+ }
958920
959- // Put the LoweredModule together
960- LoweredModule lowered_module;
961- lowered_module.main_module = main_mod;
962- lowered_module.per_target_module = per_target_modules;
963-
964- // Extract external modules and main func info, add to lowered module if they exist
965- auto external_mods = mod->GetAttr <Array<tvm::runtime::Module>>(" external_mods" );
966- if (external_mods) {
967- lowered_module.external_mods = external_mods.value ();
921+ IRModule GetMainModule (IRModule mod) {
922+ IRModule main_module;
923+ // Copy the type defs
924+ for (const auto & kv : mod->type_definitions ) {
925+ main_module->AddTypeDef (kv.first , kv.second );
968926 }
969- auto main_func_info = mod->GetAttr <backend::FunctionInfo>(" main_func_info" );
970- if (main_func_info) {
971- lowered_module.main_func_info = main_func_info.value ();
927+ // Copy all Relay functions (we don't include PrimFuncs)
928+ for (auto kv : mod->functions ) {
929+ const GlobalVar& var = kv.first ;
930+ const BaseFunc& func = kv.second ;
931+ if (func->IsInstance <tvm::relay::FunctionNode>()) {
932+ main_module->Add (var, func);
933+ }
972934 }
973- return lowered_module ;
935+ return main_module ;
974936}
975937
976938Pass LowerTEPass (TargetMap targets, DeviceMap device_context_map,
977939 backend::StaticMemoryPlan memory_plan, const String& module_name,
978940 std::function<void (Function)> process_fn) {
979941 runtime::TypedPackedFunc<IRModule (IRModule, PassContext)> pass_func = [=](IRModule module ,
980942 PassContext ctx) {
981- return LoweredModuleToIRModule (
982- LowerTE (module , targets, device_context_map, memory_plan, module_name, process_fn));
943+ return LowerTE (module , targets, device_context_map, memory_plan, module_name, process_fn);
983944 };
945+ // TODO(@electriclilies, mbs): Fold InferType() pass into LowerTEPass since it will always need to
946+ // be called afterwards
984947 return tvm::transform::CreateModulePass (pass_func, 1 , " LowerTE" , {});
985948}
986949} // namespace tec
0 commit comments