@@ -85,33 +85,51 @@ class TECompilerImpl : public TECompilerNode {
8585 return LowerShapeFuncInternal (key)->cached_func ;
8686 }
8787
88- Map<Target, IRModule> GetLoweredFunctions () {
89- std::unordered_map<Target, IRModule, backend::TargetStrHash, backend::TargetStrEqual>
90- lowered_functions;
88+ IRModule GetLoweredFunctions () {
89+ IRModule mod;
90+ // TODO(@electriclilies): This chunk of code is pretty much the same for
91+ // normal cache and shape func cache. Consider making a helper here to do that.
92+ // Additionaly, might be good to overhaul the mod->Update(mod) function (it's broken!)
9193 for (const auto & it : cache_) {
9294 auto source_func = it.first ;
95+ // TODO(@electriclilies): Does the lowered_func module only contain one function?
9396 auto lowered_func = it.second ;
94- auto target = source_func->target ;
9597
96- if (!lowered_functions.count (target)) {
97- lowered_functions[target] = IRModule (Map<GlobalVar, BaseFunc>({}));
98- }
98+ IRModule lowered_mod = lowered_func->cached_func ->funcs ;
9999
100- lowered_functions[target]->Update (lowered_func->cached_func ->funcs );
100+ // Annotate functions with their target and put them in the return module
101+ for (auto kv : lowered_mod->functions ) {
102+ const GlobalVar& var = kv.first ;
103+ const BaseFunc& func = kv.second ;
104+
105+ if (func->IsInstance <relay::FunctionNode>()) {
106+ const relay::Function relay_func = Downcast<relay::Function>(func);
107+ mod->Update (var, WithAttr (relay_func, tvm::attr::kTarget , source_func->target ));
108+ } else if (func->IsInstance <tir::PrimFuncNode>()) {
109+ const tir::PrimFunc& prim_func = Downcast<tir::PrimFunc>(func);
110+ mod->Update (var, WithAttr (prim_func, tvm::attr::kTarget , source_func->target ));
111+ } else {
112+ LOG (FATAL) << " Expected to find only relay functions and prim functions in the cache, "
113+ " but found: "
114+ << func->GetTypeKey ();
115+ }
116+ }
101117 }
102118
103119 for (const auto & it : shape_func_cache_) {
104120 auto source_func = it.first ;
105121 auto lowered_func = it.second ;
106122 auto target = source_func->target ;
123+ IRModule lowered_mod = lowered_func->cached_func ->funcs ;
107124
108- if (!lowered_functions.count (target)) {
109- lowered_functions[target] = IRModule (Map<GlobalVar, BaseFunc>({}));
125+ for (auto kv : lowered_mod->functions ) {
126+ const GlobalVar& var = kv.first ;
127+ const BaseFunc& func = kv.second ;
128+ const tir::PrimFunc& prim_func = Downcast<tir::PrimFunc>(func);
129+ mod->Update (var, WithAttr (prim_func, tvm::attr::kTarget , source_func->target ));
110130 }
111-
112- lowered_functions[target]->Update (lowered_func->cached_func ->funcs );
113131 }
114- return backend::TargetStrModuleMapToTargetModuleMap (lowered_functions) ;
132+ return mod ;
115133 }
116134
117135 Array<tvm::runtime::Module> LowerExternalFunctions () {
@@ -830,9 +848,9 @@ void UpdateFunctionMetadata(Function relay_func,
830848 function_metadata.Set (prim_fn_var.value ()->name_hint , fi);
831849}
832850
833- LoweredModule LowerTE (const IRModule& module , TargetMap targets, DeviceMap device_context_map,
834- backend::StaticMemoryPlan memory_plan, const String& module_name,
835- std::function<void (Function)> process_fn) {
851+ IRModule LowerTE (const IRModule& module , TargetMap targets, DeviceMap device_context_map,
852+ backend::StaticMemoryPlan memory_plan, const String& module_name,
853+ std::function<void (Function)> process_fn) {
836854 DLOG (INFO) << " lowering module:\n " << PrettyPrint (module );
837855
838856 TECompiler compiler;
@@ -864,76 +882,24 @@ LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap devic
864882 (*te_compiler_update_weights)(weight_map);
865883 }
866884
867- LoweredModule lowered_module;
868- lowered_module.main_module = updated_module;
869- lowered_module.per_target_module = compiler->GetLoweredFunctions ();
870- lowered_module.external_mods = compiler->LowerExternalFunctions ();
871- lowered_module.main_func_info = func_info;
872- return lowered_module;
873- }
874-
875- IRModule LoweredModuleToIRModule (LoweredModule mod) {
876- IRModule unified_module;
877-
878- // Copy the main module and its typedefs
879- for (const auto & kv : mod.main_module ->functions ) {
880- unified_module->Add (kv.first , kv.second );
881- }
882- for (const auto & kv : mod.main_module ->type_definitions ) {
883- unified_module->AddTypeDef (kv.first , kv.second );
884- }
885+ // Copy the lowered functions into the return module
886+ std::cout << " Getting lowered funcs" << std::endl;
887+ updated_module->Update (compiler->GetLoweredFunctions ());
885888
886- // Annotate the per-target functions with their target and add them to the unified module
887- for (const auto & kv : mod.per_target_module ) {
888- const Target target = kv.first ;
889- const IRModule target_module = kv.second ;
890-
891- // Right now, per-target functions are TIR functions, which don't have type definitions, so
892- // there should be no type defs in the per_target_modules
893- size_t ty_def_size = target_module->type_definitions .size ();
894- ICHECK (ty_def_size == 0 )
895- << " Expected there to be no type definitions in the per_target_modules, but found "
896- << ty_def_size;
897-
898- for (const auto & kv : target_module->functions ) {
899- const GlobalVar& var = kv.first ;
900- const BaseFunc& func = kv.second ;
901- if (func->IsInstance <tir::PrimFuncNode>()) {
902- tir::PrimFunc primFunc =
903- WithAttr (Downcast<tir::PrimFunc>(std::move (func)), tvm::attr::kTarget , target);
904- unified_module->Add (var, primFunc);
905- } else if (func->IsInstance <relay::FunctionNode>()) {
906- relay::Function relayFunc =
907- WithAttr (Downcast<relay::Function>(std::move (func)), tvm::attr::kTarget , target);
908- unified_module->Add (var, relayFunc);
909- } else {
910- LOG (FATAL)
911- << " We expected to only have PrimFuncs or RelayFuncs in the target modules, but found "
912- << func->GetTypeKey ();
913- }
914- }
915- }
889+ // Annotate the module with the external modules and function info
890+ updated_module = WithAttr (updated_module, " external_mods" , compiler->LowerExternalFunctions ());
891+ updated_module = WithAttr (updated_module, " main_func_info" , func_info);
916892
917- IRModule ret_mod = WithAttr (unified_module, " external_mods" , mod.external_mods );
918- ret_mod = WithAttr (ret_mod, " main_func_info" , mod.main_func_info );
919- return ret_mod;
893+ return updated_module;
920894}
921895
922- LoweredModule IRModuleToLoweredModule (IRModule mod) {
923- IRModule main_mod;
924- // Copy just the TypeDefs from the IRModule to the LoweredModule's main module
925- // This is the only time we need to do this since there are no TypeDefs in TIR
926- for (const auto & kv : mod->type_definitions ) {
927- main_mod->AddTypeDef (kv.first , kv.second );
928- }
929-
930- Map<Target, IRModule> per_target_modules;
896+ Map<Target, IRModule> GetPerTargetModules (IRModule mod) {
897+ std::unordered_map<Target, IRModule, backend::TargetStrHash, backend::TargetStrEqual>
898+ per_target_modules;
931899 for (const auto & kv : mod->functions ) {
932900 const GlobalVar& var = kv.first ;
933901 const BaseFunc& func = kv.second ;
934- if (func->IsInstance <relay::FunctionNode>()) {
935- main_mod->Add (var, func);
936- } else if (func->IsInstance <tir::PrimFuncNode>()) {
902+ if (func->IsInstance <tir::PrimFuncNode>()) {
937903 // Extract target
938904 Optional<Target> target = func->GetAttr <Target>(tvm::attr::kTarget );
939905 ICHECK (target) << " Target should be set at this point" ;
@@ -943,43 +909,44 @@ LoweredModule IRModuleToLoweredModule(IRModule mod) {
943909 // Initialize the IRModule for this target and add the function
944910 IRModule target_module;
945911 target_module->Add (var, func);
946- per_target_modules. Set ( target.value (), target_module) ;
912+ per_target_modules[ target.value ()] = target_module;
947913 } else {
948914 // The IRModule for this target is initialized, so just add the function.
949915 IRModule target_module = per_target_modules.at (target.value ());
950916 target_module->Add (var, func);
951917 }
952- } else {
918+ } else if (!func-> IsInstance <relay::FunctionNode>()) {
953919 LOG (FATAL)
954920 << " The function types in the IRModule should be RelayFunction or PrimFunc, but got "
955921 << func->GetTypeKey ();
956922 }
957923 }
924+ return per_target_modules;
925+ }
958926
959- // Put the LoweredModule together
960- LoweredModule lowered_module;
961- lowered_module.main_module = main_mod;
962- lowered_module.per_target_module = per_target_modules;
963-
964- // Extract external modules and main func info, add to lowered module if they exist
965- auto external_mods = mod->GetAttr <Array<tvm::runtime::Module>>(" external_mods" );
966- if (external_mods) {
967- lowered_module.external_mods = external_mods.value ();
927+ IRModule GetMainModule (IRModule mod) {
928+ IRModule main_module;
929+ // Copy the type defs
930+ for (const auto & kv : mod->type_definitions ) {
931+ main_module->AddTypeDef (kv.first , kv.second );
968932 }
969- auto main_func_info = mod->GetAttr <backend::FunctionInfo>(" main_func_info" );
970- if (main_func_info) {
971- lowered_module.main_func_info = main_func_info.value ();
933+ // Copy all Relay functions (we don't include PrimFuncs)
934+ for (auto kv : mod->functions ) {
935+ const GlobalVar& var = kv.first ;
936+ const BaseFunc& func = kv.second ;
937+ if (func->IsInstance <tvm::relay::FunctionNode>()) {
938+ main_module->Add (var, func);
939+ }
972940 }
973- return lowered_module ;
941+ return main_module ;
974942}
975943
976944Pass LowerTEPass (TargetMap targets, DeviceMap device_context_map,
977945 backend::StaticMemoryPlan memory_plan, const String& module_name,
978946 std::function<void (Function)> process_fn) {
979947 runtime::TypedPackedFunc<IRModule (IRModule, PassContext)> pass_func = [=](IRModule module ,
980948 PassContext ctx) {
981- return LoweredModuleToIRModule (
982- LowerTE (module , targets, device_context_map, memory_plan, module_name, process_fn));
949+ return LowerTE (module , targets, device_context_map, memory_plan, module_name, process_fn);
983950 };
984951 return tvm::transform::CreateModulePass (pass_func, 1 , " LowerTE" , {});
985952}
0 commit comments