Skip to content

Commit 02c850f

Browse files
committed
vm external codegen (apache#4544)
1 parent 3a0a606 commit 02c850f

File tree

3 files changed

+99
-56
lines changed

3 files changed

+99
-56
lines changed

src/relay/backend/vm/compiler.cc

Lines changed: 52 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -476,30 +476,39 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
476476
argument_registers.push_back(reg->second);
477477
}
478478

479-
// Next generate the invoke instruction.
480479
Target target;
481-
if (targets_.size() == 1) {
482-
// homogeneous execution.
483-
for (auto kv : targets_) {
484-
target = kv.second;
485-
}
480+
481+
if (!func->UseDefaultCompiler()) {
482+
target = tvm::target::ext_dev();
486483
} else {
487-
// heterogeneous execution.
488-
LOG(FATAL) << "Currently VM compiler doesn't support heterogeneous compilation";
484+
// Next generate the invoke instruction.
485+
if (targets_.size() == 1) {
486+
// homogeneous execution.
487+
const auto& it = targets_.begin();
488+
target = (*it).second;
489+
} else {
490+
// heterogeneous execution.
491+
LOG(FATAL) << "Currently VM compiler doesn't support heterogeneous compilation";
492+
}
489493
}
490494

491495
auto key = CCacheKeyNode::make(func, target);
492496
auto cfunc = engine_->Lower(key);
493497

494-
// TODO(jroesch): support lowered funcs for multiple targets
495-
CHECK_EQ(cfunc->funcs.size(), 1);
496498
auto op_index = -1;
497-
if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) {
499+
if (!func->UseDefaultCompiler()) {
498500
op_index = context_->cached_funcs.size();
499501
context_->cached_funcs.push_back(cfunc);
500-
context_->seen_funcs[cfunc->funcs[0]] = op_index;
501502
} else {
502-
op_index = context_->seen_funcs[cfunc->funcs[0]];
503+
// TODO(jroesch): support lowered funcs for multiple targets
504+
CHECK_EQ(cfunc->funcs.size(), 1);
505+
if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) {
506+
op_index = context_->cached_funcs.size();
507+
context_->cached_funcs.push_back(cfunc);
508+
context_->seen_funcs[cfunc->funcs[0]] = op_index;
509+
} else {
510+
op_index = context_->seen_funcs[cfunc->funcs[0]];
511+
}
503512
}
504513

505514
Emit(Instruction::InvokePacked(op_index,
@@ -950,32 +959,46 @@ void VMCompiler::LibraryCodegen() {
950959
if (cached_funcs.size() == 0) {
951960
return;
952961
}
953-
std::unordered_map<std::string, Array<LoweredFunc>> tgt_funcs;
954-
for (auto &cfunc : cached_funcs) {
962+
std::unordered_map<std::string, Array<LoweredFunc>> funcs;
963+
for (auto& cfunc : cached_funcs) {
955964
std::string target_str = cfunc->target->str();
956-
if (tgt_funcs.count(target_str) == 0) {
957-
tgt_funcs.emplace(target_str, Array<LoweredFunc>{cfunc->funcs[0]});
965+
if (target_str == "ext_dev") {
966+
continue;
967+
} else if (funcs.count(target_str) == 0) {
968+
funcs.emplace(target_str, Array<LoweredFunc>{cfunc->funcs[0]});
958969
} else {
959-
tgt_funcs[target_str].push_back(cfunc->funcs[0]);
970+
funcs[target_str].push_back(cfunc->funcs[0]);
960971
}
961972
}
962-
Map<Target, Array<LoweredFunc>> funcs;
963-
for (auto &it : tgt_funcs) {
964-
funcs.Set(Target::Create(it.first), it.second);
965-
}
966973

967-
if (const auto *f = runtime::Registry::Get("relay.backend.build")) {
968-
// The target is just a dummy arg because funcs already contains corresponding target
969-
// therefore target won't be used in the build function
970-
runtime::Module mod = (*f)(funcs, Target(), target_host_);
974+
auto compile_engine = CompileEngine::Global();
975+
auto ext_mods = compile_engine->LowerExternalFunctions();
976+
runtime::Module mod;
977+
if (funcs.size() > 0) {
978+
mod = tvm::build(funcs, target_host_, tvm::BuildConfig::Current());
971979
CHECK(mod.operator->());
972-
exec_->lib = mod;
973980
} else {
974-
LOG(FATAL) << "relay.backend.build is not registered";
981+
CHECK_EQ(ext_mods.size(), 1U)
982+
<< "Expect to have a TVM DSOModule when multiple runtime modules exist";
983+
}
984+
if (!ext_mods.empty()) {
985+
if (funcs.size() == 0) {
986+
mod = ext_mods[0];
987+
} else {
988+
// Import all external runtime modules.
989+
for (auto it : ext_mods) {
990+
mod.Import(it);
991+
}
992+
}
975993
}
994+
exec_->lib = mod;
976995
size_t primitive_index = 0;
977996
for (auto cfunc : cached_funcs) {
978-
exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
997+
if (cfunc->target->str() == "ext_dev") {
998+
exec_->primitive_map.insert({cfunc->func_name, primitive_index++});
999+
} else {
1000+
exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
1001+
}
9791002
}
9801003
}
9811004

src/runtime/vm/vm.cc

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -800,7 +800,9 @@ void VirtualMachine::LoadExecutable(const Executable* exec) {
800800
if (packed_funcs_.size() <= packed_index) {
801801
packed_funcs_.resize(packed_index + 1);
802802
}
803-
packed_funcs_[packed_index] = lib.GetFunction(packed_name);
803+
tvm::runtime::PackedFunc pf = lib.GetFunction(packed_name, true);
804+
CHECK(pf != nullptr) << "Cannot find function in module: " << packed_name;
805+
packed_funcs_[packed_index] = pf;
804806
}
805807
}
806808

tests/python/relay/test_external_codegen.py

Lines changed: 44 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -26,36 +26,54 @@
2626
from tvm import relay
2727
from tvm.contrib import util
2828

29-
def check_result(mod, map_inputs, out_shape, result, tol=1e-5):
29+
def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
30+
ctx=tvm.cpu()):
3031
if sys.platform == "win32":
3132
print("Skip test on Windows for now")
3233
return
3334

34-
with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
35-
json, lib, _ = relay.build(mod, "llvm")
36-
test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
37-
source_dir = os.path.join(test_dir, "..", "..", "..")
38-
contrib_path = os.path.join(source_dir, "src", "runtime", "contrib")
39-
40-
kwargs = {}
41-
kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path]
42-
tmp_path = util.tempdir()
43-
lib_name = 'lib.so'
44-
lib_path = tmp_path.relpath(lib_name)
45-
lib.export_library(lib_path, fcompile=False, **kwargs)
46-
lib = tvm.module.load(lib_path)
47-
48-
ctx = tvm.cpu()
49-
rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)
50-
51-
for name, data in map_inputs.items():
52-
rt_mod.set_input(name, data)
53-
54-
rt_mod.run()
55-
out = tvm.nd.empty(out_shape, ctx=ctx)
56-
out = rt_mod.get_output(0, out)
57-
58-
tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
35+
def update_lib(lib):
36+
test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
37+
source_dir = os.path.join(test_dir, "..", "..", "..")
38+
contrib_path = os.path.join(source_dir, "src", "runtime", "contrib")
39+
40+
kwargs = {}
41+
kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path]
42+
tmp_path = util.tempdir()
43+
lib_name = 'lib.so'
44+
lib_path = tmp_path.relpath(lib_name)
45+
lib.export_library(lib_path, fcompile=False, **kwargs)
46+
lib = tvm.module.load(lib_path)
47+
48+
return lib
49+
50+
def check_vm_result():
51+
with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
52+
exe = relay.vm.compile(mod, target=target)
53+
code, lib = exe.save()
54+
lib = update_lib(lib)
55+
exe = relay.vm.Executable.load_exec(code, lib)
56+
vm = relay.vm.VirtualMachine(exe)
57+
vm.init(ctx)
58+
out = vm.run(**map_inputs)
59+
tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
60+
61+
def check_graph_runtime_result():
62+
with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
63+
json, lib, _ = relay.build(mod, target=target)
64+
lib = update_lib(lib)
65+
rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)
66+
67+
for name, data in map_inputs.items():
68+
rt_mod.set_input(name, data)
69+
rt_mod.run()
70+
out = tvm.nd.empty(out_shape, ctx=ctx)
71+
out = rt_mod.get_output(0, out)
72+
73+
tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
74+
75+
check_vm_result()
76+
check_graph_runtime_result()
5977

6078

6179
def set_external_func_attr(func, compiler, ext_symbol):

0 commit comments

Comments
 (0)