Skip to content

Commit 83cdffb

Browse files
mikepapadimYuchenJin
authored andcommitted
[Relay] Replace compile engine with TE compiler in the VM (apache#8501)
* [VM] Add imports to new TE in VM compiler * [VM] Add comments to compile engine usages * [VM] Replace depreceated CachedFunc of compile_engine with TE_compiler * [VM] rm compiler engine compiler.cc * [VM] Replace compile engine with TECompiler in memory allocator * [VM] Add relay interface to te_compiler * [Relay] Fix linting errors * Move TEcompiler to VMCompilerContext; add global func into IRmodule when lowering in TEcompiler * add back the check * skip the check for ext func in tecompiler * skip tvm::build for external functions * trigger ci * retrigger ci * retrigger ci * remove the unnecessary loop in tecompiler Co-authored-by: YuchenJin <yuchenj@cs.washington.edu>
1 parent b3ed427 commit 83cdffb

File tree

5 files changed

+23
-20
lines changed

5 files changed

+23
-20
lines changed

src/relay/backend/build_module.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,11 @@ class RelayBuildModule : public runtime::ModuleNode {
490490

491491
auto lowered_funcs = executor_codegen_->GetIRModule();
492492

493+
// No need to build for external functions.
494+
if (lowered_funcs.find("ext_dev") != lowered_funcs.end()) {
495+
lowered_funcs.Set("ext_dev", IRModule());
496+
}
497+
493498
// Generate a placeholder function that attaches linked params as its arguments.
494499
if (target_host->GetAttr<Bool>("link-params").value_or(Bool(false))) {
495500
CHECK(pf != nullptr) << "Unable to link-params with no target_host and no llvm codegen.";

src/relay/backend/te_compiler.cc

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ class TECompilerImpl : public TECompilerNode {
195195
auto target = Target("ext_dev");
196196
auto global_var = GlobalVar(func_name);
197197
global_var->checked_type_ = key->source_func->checked_type();
198+
ir_module->Add(global_var, key->source_func);
198199
value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module);
199200
return value;
200201
}
@@ -347,12 +348,6 @@ class LowerTensorExpr : public ExprMutator {
347348
<< ext_func->prim_fn_var->name_hint;
348349

349350
Map<GlobalVar, tir::PrimFunc> prim_fns;
350-
351-
for (auto prim_fn : ext_func->funcs->functions) {
352-
CHECK(prim_fn.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
353-
prim_fns.Set(prim_fn.first, Downcast<tir::PrimFunc>(prim_fn.second));
354-
}
355-
356351
relay::Function func_with_metadata = func;
357352
func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", ext_func->prim_fn_var);
358353
func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns);

src/relay/backend/vm/compiler.cc

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545
#include <vector>
4646

4747
#include "../../../target/source/codegen_source_base.h"
48-
#include "../../backend/compile_engine.h"
4948
#include "../../op/op_common.h"
5049
#include "../../transforms/pass_utils.h"
5150
#include "../utils.h"
@@ -79,6 +78,7 @@ namespace vm {
7978
using namespace tvm::runtime;
8079
using namespace tvm::runtime::vm;
8180
using namespace relay::transform;
81+
using namespace tec;
8282

8383
// (@jroesch): VM passes, eventually declare as passes.
8484
bool IsClosure(const Function& func);
@@ -253,7 +253,6 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
253253
ExprDeviceMap expr_device_map)
254254
: last_register_(0),
255255
registers_num_(0),
256-
engine_(CompileEngine::Global()),
257256
context_(context),
258257
target_host_(target_host),
259258
expr_device_map_(std::move(expr_device_map)) {
@@ -465,7 +464,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
465464
void EmitShapeFunc(Function func, Array<Expr> inputs, Array<Expr> outputs) {
466465
// Lower shape function
467466
CCacheKey key(func, target_host_);
468-
auto cfunc = engine_->LowerShapeFunc(key);
467+
auto cfunc = context_->compiler->LowerShapeFunc(key);
469468
int op_index = -1;
470469
// pick the only function inside the context
471470
ICHECK_EQ(cfunc->funcs->functions.size(), 1);
@@ -551,7 +550,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
551550

552551
CCacheKey key(func, target);
553552
auto mangle_fn = [](String name) { return name; };
554-
auto cfunc = engine_->Lower(key, mangle_fn);
553+
auto cfunc = context_->compiler->Lower(key, mangle_fn);
555554

556555
auto op_index = -1;
557556
if (func->GetAttr<String>(attr::kCompiler).defined()) {
@@ -857,8 +856,6 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
857856
size_t last_register_;
858857
/*! \brief Total number of virtual registers allocated. */
859858
size_t registers_num_;
860-
/*! \brief Compiler engine to lower primitive functions. */
861-
CompileEngine engine_;
862859
/*! \brief Global shared meta data */
863860
VMCompilerContext* context_;
864861
/*! \brief Target devices. */
@@ -1134,8 +1131,8 @@ void VMCompiler::Codegen() {
11341131
}
11351132
}
11361133

1137-
auto compile_engine = CompileEngine::Global();
1138-
auto ext_mods = compile_engine->LowerExternalFunctions();
1134+
auto ext_mods = context_.compiler->LowerExternalFunctions();
1135+
11391136
runtime::Module lib;
11401137
if (funcs.size() > 0) {
11411138
lib = tvm::build(funcs, target_host_);
@@ -1146,7 +1143,6 @@ void VMCompiler::Codegen() {
11461143
}
11471144
lib = codegen::CreateMetadataModule(params_, lib, ext_mods, target_host_, runtime::Metadata());
11481145
exec_->SetLib(lib);
1149-
CompileEngine::Global()->Clear();
11501146
}
11511147

11521148
ExprDeviceMap VMCompiler::AnalyzeContext() const {

src/relay/backend/vm/compiler.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,9 @@
4343

4444
#include "../../../runtime/vm/naive_allocator.h"
4545
#include "../../../runtime/vm/profiler/vm.h"
46-
#include "../../backend/compile_engine.h"
4746
#include "../../transforms/pass_utils.h"
47+
#include "../te_compiler.h"
48+
#include "../te_compiler_cache.h"
4849

4950
namespace tvm {
5051
namespace relay {
@@ -75,12 +76,14 @@ struct VMCompilerContext {
7576
TagMap tag_map;
7677
// Map from global var to a unique integer
7778
GlobalMap global_map;
79+
// TEcompiler for lowering
80+
tec::TECompiler compiler;
7881
// List of constants
7982
std::vector<NDArray> constants;
8083
// Device type for constants
8184
std::vector<Index> const_device_type;
8285
// List of cached functions
83-
std::vector<CachedFunc> cached_funcs;
86+
std::vector<tec::CachedFunc> cached_funcs;
8487
// The functions that have been lowered.
8588
std::unordered_map<tir::PrimFunc, size_t, ObjectPtrHash, ObjectPtrEqual> seen_funcs;
8689
};

src/relay/transforms/memory_alloc.cc

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,16 @@
4141
#include <unordered_set>
4242
#include <vector>
4343

44-
#include "../backend/compile_engine.h"
44+
#include "../backend/te_compiler.h"
45+
#include "../backend/te_compiler_cache.h"
4546
#include "../op/memory/memory.h"
4647
#include "../op/vm/vm.h"
4748
#include "./pass_utils.h"
4849
#include "let_list.h"
4950
#include "pattern_utils.h"
5051

5152
using namespace tvm::runtime;
53+
using namespace tvm::relay::tec;
5254

5355
namespace tvm {
5456
namespace relay {
@@ -271,9 +273,11 @@ class DialectRewriter : public ExprMutator {
271273
Array<Expr> EmitShapeFunc(LetList* scope, const Function& func,
272274
const std::vector<Expr>& new_args) {
273275
Array<Expr> shape_func_ins;
274-
auto engine = CompileEngine::Global();
276+
277+
TECompiler compiler;
278+
275279
CCacheKey key(func, target_host_);
276-
auto cfunc = engine->LowerShapeFunc(key);
280+
auto cfunc = compiler->LowerShapeFunc(key);
277281
auto input_states = cfunc->shape_func_param_states;
278282

279283
Array<Integer> is_inputs;

0 commit comments

Comments
 (0)