Skip to content

Commit 6ed813b

Browse files
Remove LoweredModule
1 parent 2f1c845 commit 6ed813b

File tree

7 files changed

+130
-172
lines changed

7 files changed

+130
-172
lines changed

include/tvm/runtime/container/map.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1353,7 +1353,7 @@ class Map : public ObjectRef {
13531353
* Otherwise make a new copy of the array to ensure the current handle
13541354
* hold a unique copy.
13551355
*
1356-
* \return Handle to the internal node container(which ganrantees to be unique)
1356+
* \return Handle to the internal node container(which guarantees to be unique)
13571357
*/
13581358
MapNode* CopyOnWrite() {
13591359
if (data_.get() == nullptr) {

include/tvm/target/target.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
#include <tvm/target/target_kind.h>
3232

3333
#include <string>
34-
#include <unordered_map>
3534
#include <unordered_set>
3635
#include <vector>
3736

src/relay/backend/aot_executor_codegen.cc

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -583,7 +583,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
583583
// performing the preexisting AOT executor code generation phase.
584584
IRModule mod = IRModule::FromExpr(func);
585585

586-
IRModule new_mod =
586+
IRModule lowered_mod =
587587
LowerTEPass(targets_, device_context_map, memory_plan, mod_name, [this](Function func) {
588588
// We need to maintain the constant map for external
589589
// functions so we pass this processing function which
@@ -598,9 +598,12 @@ class AOTExecutorCodegen : public MixedModeVisitor {
598598
tec::UpdateFunctionMetadata(func, this->function_metadata_);
599599
})(mod);
600600

601-
tec::LoweredModule lowered_module = tec::IRModuleToLoweredModule(new_mod);
602-
function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info);
603-
auto lowered_main = lowered_module.main_module->Lookup("main");
601+
Optional<backend::FunctionInfo> main_func_info =
602+
lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");
603+
ICHECK(main_func_info) << "The attribute \"main_func_info\" should be set at this point.";
604+
function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info.value());
605+
auto lowered_main = lowered_mod->Lookup("main");
606+
604607
auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());
605608

606609
// Post-lowering storage map for writing main func - this should be the same map as previously
@@ -662,8 +665,13 @@ class AOTExecutorCodegen : public MixedModeVisitor {
662665

663666
ret.function_metadata = std::move(function_metadata_);
664667

665-
ret.lowered_funcs = lowered_module.per_target_module;
666-
ret.external_mods = lowered_module.external_mods;
668+
Optional<Array<tvm::runtime::Module>> external_modules =
669+
lowered_mod->GetAttr<Array<tvm::runtime::Module>>("external_mods");
670+
ICHECK(external_modules) << "Attribute \"external_modules\" should be set at this point.";
671+
672+
// This is the point where we separate the functions in the module by target
673+
ret.lowered_funcs = tec::GetPerTargetModules(lowered_mod);
674+
ret.external_mods = external_modules.value();
667675

668676
if (ret.lowered_funcs.find(target_host_) != ret.lowered_funcs.end()) {
669677
ret.lowered_funcs[target_host_]->Update(mod_run);

src/relay/backend/graph_executor_codegen.cc

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
221221
device_context_map.insert({expr, dev});
222222
}
223223

224-
IRModule new_mod =
224+
IRModule lowered_mod =
225225
LowerTEPass(targets_, device_context_map, memory_plan_, mod_name_, [this](Function func) {
226226
// We need to maintain the constant map for external
227227
// functions so we pass this processing function which
@@ -236,9 +236,13 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
236236
tec::UpdateFunctionMetadata(func, this->function_metadata_);
237237
})(mod);
238238

239-
tec::LoweredModule lowered_module = tec::IRModuleToLoweredModule(new_mod);
240-
function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info);
241-
auto main_module = lowered_module.main_module;
239+
Optional<backend::FunctionInfo> main_func_info =
240+
lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");
241+
ICHECK(main_func_info) << "The attribute \"main_func_info\" should be set at this point.";
242+
function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info.value());
243+
244+
// Get only the Relay functions out of the lowered module so we can run type inference on them
245+
IRModule main_module = tec::GetMainModule(lowered_mod);
242246
main_module = relay::transform::InferType()(main_module);
243247
relay::Function main_func = Downcast<relay::Function>(main_module->Lookup("main"));
244248

@@ -270,8 +274,13 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
270274
std::make_pair(static_cast<int>(param_storage_ids_[param.first]), param.second)));
271275
}
272276
ret.function_metadata = std::move(function_metadata_);
273-
ret.lowered_funcs = lowered_module.per_target_module;
274-
ret.external_mods = lowered_module.external_mods;
277+
278+
Optional<Array<tvm::runtime::Module>> external_modules =
279+
lowered_mod->GetAttr<Array<tvm::runtime::Module>>("external_mods");
280+
// This is the point where we separate the functions in the module by target
281+
282+
ret.lowered_funcs = tec::GetPerTargetModules(lowered_mod);
283+
ret.external_mods = external_modules.value();
275284
return ret;
276285
}
277286

src/relay/backend/interpreter.cc

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,8 @@ InterpreterState::InterpreterState(Expr current_expr, InterpreterState::Stack st
292292
class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
293293
PatternFunctor<bool(const Pattern& p, const ObjectRef& v)> {
294294
public:
295-
// TODO(mbs): Collapse mod and per_target_module once IRModule subsumes LoweredModule.
295+
// TODO(mbs, electriclilies): Collapse mod and per_target_module once IRModule subsumes
296+
// LoweredModule.
296297
Interpreter(IRModule mod, Map<Target, IRModule> per_target_module, Device device, Target target)
297298
: mod_(mod),
298299
per_target_module_(per_target_module),
@@ -902,20 +903,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
902903
* functions needed by the rewritten module.
903904
*/
904905
std::pair<IRModule, Map<Target, IRModule>> Prepare(IRModule mod, Device device, Target target) {
905-
// Run minimal transforms on module to establish invariants needed by interpreter.
906-
transform::Sequential seq({transform::SimplifyInference(),
907-
// FuseOps will mark wrapped calls to prim-ops with the 'Primitive'
908-
// attribute.
909-
transform::FuseOps(/*fuse_opt_level=*/0), transform::ToANormalForm(),
910-
// eta expand to support constructors in argument position
911-
transform::EtaExpand(
912-
/*expand_constructor=*/true, /*expand_global_var=*/false),
913-
transform::InferType()});
914-
915-
transform::PassContext pass_ctx = transform::PassContext::Current();
916-
With<transform::PassContext> ctx(pass_ctx);
917-
mod = seq(mod);
918-
906+
// Things to initialize to pass into tec::LowerTEPass
919907
// We only have one device-specific target.
920908
tec::TargetMap targets = {{device.device_type, target}};
921909

@@ -925,13 +913,25 @@ std::pair<IRModule, Map<Target, IRModule>> Prepare(IRModule mod, Device device,
925913
// No need for a memory plan.
926914
backend::StaticMemoryPlan memory_plan; /*=nullptr*/
927915

916+
// Run minimal transforms on module to establish invariants needed by interpreter.
917+
transform::Sequential seq(
918+
{transform::SimplifyInference(),
919+
// FuseOps will mark wrapped calls to prim-ops with the 'Primitive'
920+
// attribute.
921+
transform::FuseOps(/*fuse_opt_level=*/0), transform::ToANormalForm(),
922+
// eta expand to support constructors in argument position
923+
transform::EtaExpand(
924+
/*expand_constructor=*/true, /*expand_global_var=*/false),
925+
transform::InferType(),
926+
tec::LowerTEPass(targets, device_map, memory_plan, /*module_name=*/"intrp",
927+
[](Function func) { /* no-op */ })});
928+
929+
transform::PassContext pass_ctx = transform::PassContext::Current();
930+
With<transform::PassContext> ctx(pass_ctx);
931+
mod = seq(mod);
932+
928933
// Lower all primitive functions reachable from expr.
929-
// TODO(mbs): This should be just another pass in seq above, which requires LoweredModule to
930-
// be merged into IRModule.
931-
LoweredModule lowered_module =
932-
tec::LowerTE(mod, targets, device_map, memory_plan, /*module_name=*/"intrp",
933-
[](Function func) { /* no-op */ });
934-
return {lowered_module.main_module, lowered_module.per_target_module};
934+
return {tec::GetMainModule(mod), tec::GetPerTargetModules(mod)};
935935
}
936936

937937
/*! \brief Check if an expression could be changed by \p Prepare.

src/relay/backend/te_compiler.cc

Lines changed: 63 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -85,33 +85,51 @@ class TECompilerImpl : public TECompilerNode {
8585
return LowerShapeFuncInternal(key)->cached_func;
8686
}
8787

88-
Map<Target, IRModule> GetLoweredFunctions() {
89-
std::unordered_map<Target, IRModule, backend::TargetStrHash, backend::TargetStrEqual>
90-
lowered_functions;
88+
IRModule GetLoweredFunctions() {
89+
IRModule mod;
90+
// TODO(@electriclilies): This chunk of code is pretty much the same for
91+
// normal cache and shape func cache. Consider making a helper here to do that.
92+
// Additionaly, might be good to overhaul the mod->Update(mod) function (it's broken!)
9193
for (const auto& it : cache_) {
9294
auto source_func = it.first;
95+
// TODO(@electriclilies): Does the lowered_func module only contain one function?
9396
auto lowered_func = it.second;
94-
auto target = source_func->target;
9597

96-
if (!lowered_functions.count(target)) {
97-
lowered_functions[target] = IRModule(Map<GlobalVar, BaseFunc>({}));
98-
}
98+
IRModule lowered_mod = lowered_func->cached_func->funcs;
9999

100-
lowered_functions[target]->Update(lowered_func->cached_func->funcs);
100+
// Annotate functions with their target and put them in the return module
101+
for (auto kv : lowered_mod->functions) {
102+
const GlobalVar& var = kv.first;
103+
const BaseFunc& func = kv.second;
104+
105+
if (func->IsInstance<relay::FunctionNode>()) {
106+
const relay::Function relay_func = Downcast<relay::Function>(func);
107+
mod->Update(var, WithAttr(relay_func, tvm::attr::kTarget, source_func->target));
108+
} else if (func->IsInstance<tir::PrimFuncNode>()) {
109+
const tir::PrimFunc& prim_func = Downcast<tir::PrimFunc>(func);
110+
mod->Update(var, WithAttr(prim_func, tvm::attr::kTarget, source_func->target));
111+
} else {
112+
LOG(FATAL) << "Expected to find only relay functions and prim functions in the cache, "
113+
"but found: "
114+
<< func->GetTypeKey();
115+
}
116+
}
101117
}
102118

103119
for (const auto& it : shape_func_cache_) {
104120
auto source_func = it.first;
105121
auto lowered_func = it.second;
106122
auto target = source_func->target;
123+
IRModule lowered_mod = lowered_func->cached_func->funcs;
107124

108-
if (!lowered_functions.count(target)) {
109-
lowered_functions[target] = IRModule(Map<GlobalVar, BaseFunc>({}));
125+
for (auto kv : lowered_mod->functions) {
126+
const GlobalVar& var = kv.first;
127+
const BaseFunc& func = kv.second;
128+
const tir::PrimFunc& prim_func = Downcast<tir::PrimFunc>(func);
129+
mod->Update(var, WithAttr(prim_func, tvm::attr::kTarget, source_func->target));
110130
}
111-
112-
lowered_functions[target]->Update(lowered_func->cached_func->funcs);
113131
}
114-
return backend::TargetStrModuleMapToTargetModuleMap(lowered_functions);
132+
return mod;
115133
}
116134

117135
Array<tvm::runtime::Module> LowerExternalFunctions() {
@@ -830,9 +848,9 @@ void UpdateFunctionMetadata(Function relay_func,
830848
function_metadata.Set(prim_fn_var.value()->name_hint, fi);
831849
}
832850

833-
LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_context_map,
834-
backend::StaticMemoryPlan memory_plan, const String& module_name,
835-
std::function<void(Function)> process_fn) {
851+
IRModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_context_map,
852+
backend::StaticMemoryPlan memory_plan, const String& module_name,
853+
std::function<void(Function)> process_fn) {
836854
DLOG(INFO) << "lowering module:\n" << PrettyPrint(module);
837855

838856
TECompiler compiler;
@@ -864,76 +882,24 @@ LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap devic
864882
(*te_compiler_update_weights)(weight_map);
865883
}
866884

867-
LoweredModule lowered_module;
868-
lowered_module.main_module = updated_module;
869-
lowered_module.per_target_module = compiler->GetLoweredFunctions();
870-
lowered_module.external_mods = compiler->LowerExternalFunctions();
871-
lowered_module.main_func_info = func_info;
872-
return lowered_module;
873-
}
874-
875-
IRModule LoweredModuleToIRModule(LoweredModule mod) {
876-
IRModule unified_module;
877-
878-
// Copy the main module and its typedefs
879-
for (const auto& kv : mod.main_module->functions) {
880-
unified_module->Add(kv.first, kv.second);
881-
}
882-
for (const auto& kv : mod.main_module->type_definitions) {
883-
unified_module->AddTypeDef(kv.first, kv.second);
884-
}
885+
// Copy the lowered functions into the return module
886+
std::cout << "Getting lowered funcs" << std::endl;
887+
updated_module->Update(compiler->GetLoweredFunctions());
885888

886-
// Annotate the per-target functions with their target and add them to the unified module
887-
for (const auto& kv : mod.per_target_module) {
888-
const Target target = kv.first;
889-
const IRModule target_module = kv.second;
890-
891-
// Right now, per-target functions are TIR functions, which don't have type definitions, so
892-
// there should be no type defs in the per_target_modules
893-
size_t ty_def_size = target_module->type_definitions.size();
894-
ICHECK(ty_def_size == 0)
895-
<< "Expected there to be no type definitions in the per_target_modules, but found "
896-
<< ty_def_size;
897-
898-
for (const auto& kv : target_module->functions) {
899-
const GlobalVar& var = kv.first;
900-
const BaseFunc& func = kv.second;
901-
if (func->IsInstance<tir::PrimFuncNode>()) {
902-
tir::PrimFunc primFunc =
903-
WithAttr(Downcast<tir::PrimFunc>(std::move(func)), tvm::attr::kTarget, target);
904-
unified_module->Add(var, primFunc);
905-
} else if (func->IsInstance<relay::FunctionNode>()) {
906-
relay::Function relayFunc =
907-
WithAttr(Downcast<relay::Function>(std::move(func)), tvm::attr::kTarget, target);
908-
unified_module->Add(var, relayFunc);
909-
} else {
910-
LOG(FATAL)
911-
<< "We expected to only have PrimFuncs or RelayFuncs in the target modules, but found "
912-
<< func->GetTypeKey();
913-
}
914-
}
915-
}
889+
// Annotate the module with the external modules and function info
890+
updated_module = WithAttr(updated_module, "external_mods", compiler->LowerExternalFunctions());
891+
updated_module = WithAttr(updated_module, "main_func_info", func_info);
916892

917-
IRModule ret_mod = WithAttr(unified_module, "external_mods", mod.external_mods);
918-
ret_mod = WithAttr(ret_mod, "main_func_info", mod.main_func_info);
919-
return ret_mod;
893+
return updated_module;
920894
}
921895

922-
LoweredModule IRModuleToLoweredModule(IRModule mod) {
923-
IRModule main_mod;
924-
// Copy just the TypeDefs from the IRModule to the LoweredModule's main module
925-
// This is the only time we need to do this since there are no TypeDefs in TIR
926-
for (const auto& kv : mod->type_definitions) {
927-
main_mod->AddTypeDef(kv.first, kv.second);
928-
}
929-
930-
Map<Target, IRModule> per_target_modules;
896+
Map<Target, IRModule> GetPerTargetModules(IRModule mod) {
897+
std::unordered_map<Target, IRModule, backend::TargetStrHash, backend::TargetStrEqual>
898+
per_target_modules;
931899
for (const auto& kv : mod->functions) {
932900
const GlobalVar& var = kv.first;
933901
const BaseFunc& func = kv.second;
934-
if (func->IsInstance<relay::FunctionNode>()) {
935-
main_mod->Add(var, func);
936-
} else if (func->IsInstance<tir::PrimFuncNode>()) {
902+
if (func->IsInstance<tir::PrimFuncNode>()) {
937903
// Extract target
938904
Optional<Target> target = func->GetAttr<Target>(tvm::attr::kTarget);
939905
ICHECK(target) << "Target should be set at this point";
@@ -943,43 +909,44 @@ LoweredModule IRModuleToLoweredModule(IRModule mod) {
943909
// Initialize the IRModule for this target and add the function
944910
IRModule target_module;
945911
target_module->Add(var, func);
946-
per_target_modules.Set(target.value(), target_module);
912+
per_target_modules[target.value()] = target_module;
947913
} else {
948914
// The IRModule for this target is initialized, so just add the function.
949915
IRModule target_module = per_target_modules.at(target.value());
950916
target_module->Add(var, func);
951917
}
952-
} else {
918+
} else if (!func->IsInstance<relay::FunctionNode>()) {
953919
LOG(FATAL)
954920
<< "The function types in the IRModule should be RelayFunction or PrimFunc, but got "
955921
<< func->GetTypeKey();
956922
}
957923
}
924+
return per_target_modules;
925+
}
958926

959-
// Put the LoweredModule together
960-
LoweredModule lowered_module;
961-
lowered_module.main_module = main_mod;
962-
lowered_module.per_target_module = per_target_modules;
963-
964-
// Extract external modules and main func info, add to lowered module if they exist
965-
auto external_mods = mod->GetAttr<Array<tvm::runtime::Module>>("external_mods");
966-
if (external_mods) {
967-
lowered_module.external_mods = external_mods.value();
927+
IRModule GetMainModule(IRModule mod) {
928+
IRModule main_module;
929+
// Copy the type defs
930+
for (const auto& kv : mod->type_definitions) {
931+
main_module->AddTypeDef(kv.first, kv.second);
968932
}
969-
auto main_func_info = mod->GetAttr<backend::FunctionInfo>("main_func_info");
970-
if (main_func_info) {
971-
lowered_module.main_func_info = main_func_info.value();
933+
// Copy all Relay functions (we don't include PrimFuncs)
934+
for (auto kv : mod->functions) {
935+
const GlobalVar& var = kv.first;
936+
const BaseFunc& func = kv.second;
937+
if (func->IsInstance<tvm::relay::FunctionNode>()) {
938+
main_module->Add(var, func);
939+
}
972940
}
973-
return lowered_module;
941+
return main_module;
974942
}
975943

976944
Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
977945
backend::StaticMemoryPlan memory_plan, const String& module_name,
978946
std::function<void(Function)> process_fn) {
979947
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule module,
980948
PassContext ctx) {
981-
return LoweredModuleToIRModule(
982-
LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn));
949+
return LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn);
983950
};
984951
return tvm::transform::CreateModulePass(pass_func, 1, "LowerTE", {});
985952
}

0 commit comments

Comments
 (0)