Skip to content

Commit dcb01d2

Browse files
electriclilies authored and ylc committed
Remove LoweredModule (apache#8886)
* Remove LoweredModule
* Clean up some comments
* QEMU flaky tests
* Don't add external functions to the LoweredFunctions module
* QEMU flaky test
* Respond to feedback
* flaky test
1 parent 5a53d0f commit dcb01d2

File tree

7 files changed

+129
-174
lines changed

7 files changed

+129
-174
lines changed

include/tvm/runtime/container/map.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1353,7 +1353,7 @@ class Map : public ObjectRef {
13531353
* Otherwise make a new copy of the array to ensure the current handle
13541354
* hold a unique copy.
13551355
*
1356-
* \return Handle to the internal node container(which ganrantees to be unique)
1356+
* \return Handle to the internal node container(which guarantees to be unique)
13571357
*/
13581358
MapNode* CopyOnWrite() {
13591359
if (data_.get() == nullptr) {

include/tvm/target/target.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
#include <tvm/target/target_kind.h>
3232

3333
#include <string>
34-
#include <unordered_map>
3534
#include <unordered_set>
3635
#include <vector>
3736

src/relay/backend/aot_executor_codegen.cc

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -583,7 +583,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
583583
// performing the preexisting AOT executor code generation phase.
584584
IRModule mod = IRModule::FromExpr(func);
585585

586-
IRModule new_mod =
586+
IRModule lowered_mod =
587587
LowerTEPass(targets_, device_context_map, memory_plan, mod_name, [this](Function func) {
588588
// We need to maintain the constant map for external
589589
// functions so we pass this processing function which
@@ -598,9 +598,12 @@ class AOTExecutorCodegen : public MixedModeVisitor {
598598
tec::UpdateFunctionMetadata(func, this->function_metadata_);
599599
})(mod);
600600

601-
tec::LoweredModule lowered_module = tec::IRModuleToLoweredModule(new_mod);
602-
function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info);
603-
auto lowered_main = lowered_module.main_module->Lookup("main");
601+
Optional<backend::FunctionInfo> main_func_info =
602+
lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");
603+
ICHECK(main_func_info) << "The attribute \"main_func_info\" should be set at this point.";
604+
function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info.value());
605+
auto lowered_main = lowered_mod->Lookup("main");
606+
604607
auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());
605608

606609
// Post-lowering storage map for writing main func - this should be the same map as previously
@@ -662,8 +665,13 @@ class AOTExecutorCodegen : public MixedModeVisitor {
662665

663666
ret.function_metadata = std::move(function_metadata_);
664667

665-
ret.lowered_funcs = lowered_module.per_target_module;
666-
ret.external_mods = lowered_module.external_mods;
668+
Optional<Array<tvm::runtime::Module>> external_modules =
669+
lowered_mod->GetAttr<Array<tvm::runtime::Module>>("external_mods");
670+
ICHECK(external_modules) << "Attribute \"external_modules\" should be set at this point.";
671+
672+
// This is the point where we separate the functions in the module by target
673+
ret.lowered_funcs = tec::GetPerTargetModules(lowered_mod);
674+
ret.external_mods = external_modules.value();
667675

668676
if (ret.lowered_funcs.find(target_host_) != ret.lowered_funcs.end()) {
669677
ret.lowered_funcs[target_host_]->Update(mod_run);

src/relay/backend/graph_executor_codegen.cc

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
221221
device_context_map.insert({expr, dev});
222222
}
223223

224-
IRModule new_mod =
224+
IRModule lowered_mod =
225225
LowerTEPass(targets_, device_context_map, memory_plan_, mod_name_, [this](Function func) {
226226
// We need to maintain the constant map for external
227227
// functions so we pass this processing function which
@@ -236,9 +236,13 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
236236
tec::UpdateFunctionMetadata(func, this->function_metadata_);
237237
})(mod);
238238

239-
tec::LoweredModule lowered_module = tec::IRModuleToLoweredModule(new_mod);
240-
function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info);
241-
auto main_module = lowered_module.main_module;
239+
Optional<backend::FunctionInfo> main_func_info =
240+
lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");
241+
ICHECK(main_func_info) << "The attribute \"main_func_info\" should be set at this point.";
242+
function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info.value());
243+
244+
// Get only the Relay functions out of the lowered module so we can run type inference on them
245+
IRModule main_module = tec::GetMainModule(lowered_mod);
242246
main_module = relay::transform::InferType()(main_module);
243247
relay::Function main_func = Downcast<relay::Function>(main_module->Lookup("main"));
244248

@@ -270,8 +274,14 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
270274
std::make_pair(static_cast<int>(param_storage_ids_[param.first]), param.second)));
271275
}
272276
ret.function_metadata = std::move(function_metadata_);
273-
ret.lowered_funcs = lowered_module.per_target_module;
274-
ret.external_mods = lowered_module.external_mods;
277+
278+
Optional<Array<tvm::runtime::Module>> external_modules =
279+
lowered_mod->GetAttr<Array<tvm::runtime::Module>>("external_mods");
280+
ICHECK(external_modules) << "Attribute \"external_modules\" should be set at this point.";
281+
282+
// This is the point where we separate the functions in the module by target
283+
ret.lowered_funcs = tec::GetPerTargetModules(lowered_mod);
284+
ret.external_mods = external_modules.value();
275285
return ret;
276286
}
277287

src/relay/backend/interpreter.cc

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,8 @@ InterpreterState::InterpreterState(Expr current_expr, InterpreterState::Stack st
292292
class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
293293
PatternFunctor<bool(const Pattern& p, const ObjectRef& v)> {
294294
public:
295-
// TODO(mbs): Collapse mod and per_target_module once IRModule subsumes LoweredModule.
295+
// TODO(mbs, electriclilies): Collapse mod and per_target_module once IRModule subsumes
296+
// LoweredModule.
296297
Interpreter(IRModule mod, Map<Target, IRModule> per_target_module, Device device, Target target)
297298
: mod_(mod),
298299
per_target_module_(per_target_module),
@@ -902,20 +903,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
902903
* functions needed by the rewritten module.
903904
*/
904905
std::pair<IRModule, Map<Target, IRModule>> Prepare(IRModule mod, Device device, Target target) {
905-
// Run minimal transforms on module to establish invariants needed by interpreter.
906-
transform::Sequential seq({transform::SimplifyInference(),
907-
// FuseOps will mark wrapped calls to prim-ops with the 'Primitive'
908-
// attribute.
909-
transform::FuseOps(/*fuse_opt_level=*/0), transform::ToANormalForm(),
910-
// eta expand to support constructors in argument position
911-
transform::EtaExpand(
912-
/*expand_constructor=*/true, /*expand_global_var=*/false),
913-
transform::InferType()});
914-
915-
transform::PassContext pass_ctx = transform::PassContext::Current();
916-
With<transform::PassContext> ctx(pass_ctx);
917-
mod = seq(mod);
918-
906+
// Things to initialize to pass into tec::LowerTEPass
919907
// We only have one device-specific target.
920908
tec::TargetMap targets = {{device.device_type, target}};
921909

@@ -925,13 +913,25 @@ std::pair<IRModule, Map<Target, IRModule>> Prepare(IRModule mod, Device device,
925913
// No need for a memory plan.
926914
backend::StaticMemoryPlan memory_plan; /*=nullptr*/
927915

916+
// Run minimal transforms on module to establish invariants needed by interpreter.
917+
transform::Sequential seq(
918+
{transform::SimplifyInference(),
919+
// FuseOps will mark wrapped calls to prim-ops with the 'Primitive'
920+
// attribute.
921+
transform::FuseOps(/*fuse_opt_level=*/0), transform::ToANormalForm(),
922+
// eta expand to support constructors in argument position
923+
transform::EtaExpand(
924+
/*expand_constructor=*/true, /*expand_global_var=*/false),
925+
transform::InferType(),
926+
tec::LowerTEPass(targets, device_map, memory_plan, /*module_name=*/"intrp",
927+
[](Function func) { /* no-op */ })});
928+
929+
transform::PassContext pass_ctx = transform::PassContext::Current();
930+
With<transform::PassContext> ctx(pass_ctx);
931+
mod = seq(mod);
932+
928933
// Lower all primitive functions reachable from expr.
929-
// TODO(mbs): This should be just another pass in seq above, which requires LoweredModule to
930-
// be merged into IRModule.
931-
LoweredModule lowered_module =
932-
tec::LowerTE(mod, targets, device_map, memory_plan, /*module_name=*/"intrp",
933-
[](Function func) { /* no-op */ });
934-
return {lowered_module.main_module, lowered_module.per_target_module};
934+
return {tec::GetMainModule(mod), tec::GetPerTargetModules(mod)};
935935
}
936936

937937
/*! \brief Check if an expression could be changed by \p Prepare.

src/relay/backend/te_compiler.cc

Lines changed: 61 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -85,33 +85,46 @@ class TECompilerImpl : public TECompilerNode {
8585
return LowerShapeFuncInternal(key)->cached_func;
8686
}
8787

88-
Map<Target, IRModule> GetLoweredFunctions() {
89-
std::unordered_map<Target, IRModule, backend::TargetStrHash, backend::TargetStrEqual>
90-
lowered_functions;
88+
IRModule GetLoweredFunctions() {
89+
IRModule mod;
90+
// Extract lowered functions from the cache
9191
for (const auto& it : cache_) {
9292
auto source_func = it.first;
9393
auto lowered_func = it.second;
94-
auto target = source_func->target;
9594

96-
if (!lowered_functions.count(target)) {
97-
lowered_functions[target] = IRModule(Map<GlobalVar, BaseFunc>({}));
98-
}
95+
IRModule lowered_mod = lowered_func->cached_func->funcs;
9996

100-
lowered_functions[target]->Update(lowered_func->cached_func->funcs);
101-
}
97+
// Annotate functions with their target and put them in the return module
98+
for (auto kv : lowered_mod->functions) {
99+
const GlobalVar& var = kv.first;
100+
const BaseFunc& func = kv.second;
102101

102+
// Only add functions that are not external functions
103+
if (!func->GetAttr<String>(attr::kCompiler).defined()) {
104+
ICHECK(func->IsInstance<tir::PrimFuncNode>())
105+
<< "Expected all functions that are not external to be PrimFuncs, but found "
106+
<< func->GetTypeKey();
107+
const tir::PrimFunc& prim_func = Downcast<tir::PrimFunc>(func);
108+
mod->Update(var, WithAttr(prim_func, tvm::attr::kTarget, source_func->target));
109+
}
110+
}
111+
}
112+
// Extract lowered dynamic shape functions from the shape cache
103113
for (const auto& it : shape_func_cache_) {
104114
auto source_func = it.first;
105115
auto lowered_func = it.second;
106116
auto target = source_func->target;
107-
108-
if (!lowered_functions.count(target)) {
109-
lowered_functions[target] = IRModule(Map<GlobalVar, BaseFunc>({}));
117+
IRModule lowered_mod = lowered_func->cached_func->funcs;
118+
119+
// Annotate functions with their target and put them in the return module
120+
for (auto kv : lowered_mod->functions) {
121+
const GlobalVar& var = kv.first;
122+
const BaseFunc& func = kv.second;
123+
const tir::PrimFunc& prim_func = Downcast<tir::PrimFunc>(func);
124+
mod->Update(var, WithAttr(prim_func, tvm::attr::kTarget, source_func->target));
110125
}
111-
112-
lowered_functions[target]->Update(lowered_func->cached_func->funcs);
113126
}
114-
return backend::TargetStrModuleMapToTargetModuleMap(lowered_functions);
127+
return mod;
115128
}
116129

117130
Array<tvm::runtime::Module> LowerExternalFunctions() {
@@ -830,9 +843,9 @@ void UpdateFunctionMetadata(Function relay_func,
830843
function_metadata.Set(prim_fn_var.value()->name_hint, fi);
831844
}
832845

833-
LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_context_map,
834-
backend::StaticMemoryPlan memory_plan, const String& module_name,
835-
std::function<void(Function)> process_fn) {
846+
IRModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_context_map,
847+
backend::StaticMemoryPlan memory_plan, const String& module_name,
848+
std::function<void(Function)> process_fn) {
836849
DLOG(INFO) << "lowering module:\n" << PrettyPrint(module);
837850

838851
TECompiler compiler;
@@ -864,76 +877,23 @@ LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap devic
864877
(*te_compiler_update_weights)(weight_map);
865878
}
866879

867-
LoweredModule lowered_module;
868-
lowered_module.main_module = updated_module;
869-
lowered_module.per_target_module = compiler->GetLoweredFunctions();
870-
lowered_module.external_mods = compiler->LowerExternalFunctions();
871-
lowered_module.main_func_info = func_info;
872-
return lowered_module;
873-
}
880+
// Copy the lowered functions into the return module
881+
updated_module->Update(compiler->GetLoweredFunctions());
874882

875-
IRModule LoweredModuleToIRModule(LoweredModule mod) {
876-
IRModule unified_module;
883+
// Annotate the module with the external modules and function info
884+
updated_module = WithAttr(updated_module, "external_mods", compiler->LowerExternalFunctions());
885+
updated_module = WithAttr(updated_module, "main_func_info", func_info);
877886

878-
// Copy the main module and its typedefs
879-
for (const auto& kv : mod.main_module->functions) {
880-
unified_module->Add(kv.first, kv.second);
881-
}
882-
for (const auto& kv : mod.main_module->type_definitions) {
883-
unified_module->AddTypeDef(kv.first, kv.second);
884-
}
885-
886-
// Annotate the per-target functions with their target and add them to the unified module
887-
for (const auto& kv : mod.per_target_module) {
888-
const Target target = kv.first;
889-
const IRModule target_module = kv.second;
890-
891-
// Right now, per-target functions are TIR functions, which don't have type definitions, so
892-
// there should be no type defs in the per_target_modules
893-
size_t ty_def_size = target_module->type_definitions.size();
894-
ICHECK(ty_def_size == 0)
895-
<< "Expected there to be no type definitions in the per_target_modules, but found "
896-
<< ty_def_size;
897-
898-
for (const auto& kv : target_module->functions) {
899-
const GlobalVar& var = kv.first;
900-
const BaseFunc& func = kv.second;
901-
if (func->IsInstance<tir::PrimFuncNode>()) {
902-
tir::PrimFunc primFunc =
903-
WithAttr(Downcast<tir::PrimFunc>(std::move(func)), tvm::attr::kTarget, target);
904-
unified_module->Add(var, primFunc);
905-
} else if (func->IsInstance<relay::FunctionNode>()) {
906-
relay::Function relayFunc =
907-
WithAttr(Downcast<relay::Function>(std::move(func)), tvm::attr::kTarget, target);
908-
unified_module->Add(var, relayFunc);
909-
} else {
910-
LOG(FATAL)
911-
<< "We expected to only have PrimFuncs or RelayFuncs in the target modules, but found "
912-
<< func->GetTypeKey();
913-
}
914-
}
915-
}
916-
917-
IRModule ret_mod = WithAttr(unified_module, "external_mods", mod.external_mods);
918-
ret_mod = WithAttr(ret_mod, "main_func_info", mod.main_func_info);
919-
return ret_mod;
887+
return updated_module;
920888
}
921889

922-
LoweredModule IRModuleToLoweredModule(IRModule mod) {
923-
IRModule main_mod;
924-
// Copy just the TypeDefs from the IRModule to the LoweredModule's main module
925-
// This is the only time we need to do this since there are no TypeDefs in TIR
926-
for (const auto& kv : mod->type_definitions) {
927-
main_mod->AddTypeDef(kv.first, kv.second);
928-
}
929-
930-
Map<Target, IRModule> per_target_modules;
890+
Map<Target, IRModule> GetPerTargetModules(IRModule mod) {
891+
std::unordered_map<Target, IRModule, backend::TargetStrHash, backend::TargetStrEqual>
892+
per_target_modules;
931893
for (const auto& kv : mod->functions) {
932894
const GlobalVar& var = kv.first;
933895
const BaseFunc& func = kv.second;
934-
if (func->IsInstance<relay::FunctionNode>()) {
935-
main_mod->Add(var, func);
936-
} else if (func->IsInstance<tir::PrimFuncNode>()) {
896+
if (func->IsInstance<tir::PrimFuncNode>()) {
937897
// Extract target
938898
Optional<Target> target = func->GetAttr<Target>(tvm::attr::kTarget);
939899
ICHECK(target) << "Target should be set at this point";
@@ -943,44 +903,47 @@ LoweredModule IRModuleToLoweredModule(IRModule mod) {
943903
// Initialize the IRModule for this target and add the function
944904
IRModule target_module;
945905
target_module->Add(var, func);
946-
per_target_modules.Set(target.value(), target_module);
906+
per_target_modules[target.value()] = target_module;
947907
} else {
948908
// The IRModule for this target is initialized, so just add the function.
949909
IRModule target_module = per_target_modules.at(target.value());
950910
target_module->Add(var, func);
951911
}
952-
} else {
912+
} else if (!func->IsInstance<relay::FunctionNode>()) {
953913
LOG(FATAL)
954914
<< "The function types in the IRModule should be RelayFunction or PrimFunc, but got "
955915
<< func->GetTypeKey();
956916
}
957917
}
918+
return per_target_modules;
919+
}
958920

959-
// Put the LoweredModule together
960-
LoweredModule lowered_module;
961-
lowered_module.main_module = main_mod;
962-
lowered_module.per_target_module = per_target_modules;
963-
964-
// Extract external modules and main func info, add to lowered module if they exist
965-
auto external_mods = mod->GetAttr<Array<tvm::runtime::Module>>("external_mods");
966-
if (external_mods) {
967-
lowered_module.external_mods = external_mods.value();
921+
IRModule GetMainModule(IRModule mod) {
922+
IRModule main_module;
923+
// Copy the type defs
924+
for (const auto& kv : mod->type_definitions) {
925+
main_module->AddTypeDef(kv.first, kv.second);
968926
}
969-
auto main_func_info = mod->GetAttr<backend::FunctionInfo>("main_func_info");
970-
if (main_func_info) {
971-
lowered_module.main_func_info = main_func_info.value();
927+
// Copy all Relay functions (we don't include PrimFuncs)
928+
for (auto kv : mod->functions) {
929+
const GlobalVar& var = kv.first;
930+
const BaseFunc& func = kv.second;
931+
if (func->IsInstance<tvm::relay::FunctionNode>()) {
932+
main_module->Add(var, func);
933+
}
972934
}
973-
return lowered_module;
935+
return main_module;
974936
}
975937

976938
Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
977939
backend::StaticMemoryPlan memory_plan, const String& module_name,
978940
std::function<void(Function)> process_fn) {
979941
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule module,
980942
PassContext ctx) {
981-
return LoweredModuleToIRModule(
982-
LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn));
943+
return LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn);
983944
};
945+
// TODO(@electriclilies, mbs): Fold InferType() pass into LowerTEPass since it will always need to
946+
// be called afterwards
984947
return tvm::transform::CreateModulePass(pass_func, 1, "LowerTE", {});
985948
}
986949
} // namespace tec

0 commit comments

Comments
 (0)