Skip to content

Commit cb157bd

Browse files
committed
Fix vm build. (apache#35)
1 parent 402426d commit cb157bd

File tree

3 files changed

+42
-52
lines changed

3 files changed

+42
-52
lines changed

python/tvm/relax/vm.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,11 @@
1717

1818
from typing import List, Optional, Union, Dict, Tuple
1919
import tvm
20+
from tvm import relax
21+
from tvm.ir.module import IRModule
2022
from tvm.runtime import Object, Device, Module, PackedFunc
2123
from tvm._ffi.base import _LIB, check_call
24+
from tvm.tir.function import PrimFunc
2225
from . import _ffi_api
2326
from . import transform
2427
from ..rpc.base import RPC_SESS_MASK
@@ -169,5 +172,22 @@ def build(mod: tvm.IRModule,
169172
new_mod = transform.call_dps_rewrite(new_mod)
170173
new_mod = transform.vm_memory_lower(new_mod)
171174
new_mod = transform.vm_shape_lower(new_mod)
172-
ex, lib = _ffi_api.VMBuild(new_mod, target, target_host)
175+
176+
# split primfunc and relax function
177+
rx_mod, tir_mod = _split_tir_relax(new_mod)
178+
179+
lib = tvm.build(tir_mod, target, target_host)
180+
ex = _ffi_api.VMCodeGen(rx_mod)
173181
return ex, lib
182+
183+
def _split_tir_relax(mod: tvm.IRModule) -> Tuple[tvm.IRModule, tvm.IRModule]:
184+
rx_mod = IRModule({})
185+
tir_mod = IRModule({})
186+
for gv in mod.get_global_vars():
187+
if isinstance(mod[gv], PrimFunc):
188+
tir_mod[gv] = mod[gv]
189+
elif isinstance(mod[gv], relax.Function):
190+
rx_mod[gv] = mod[gv]
191+
else:
192+
raise ValueError("An IRModule should contain contain relax function and TIR primfunc.")
193+
return rx_mod, tir_mod

src/relax/backend/vm/codegen_vm.cc

Lines changed: 13 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
/*!
2121
* \file src/relax/backend/vm/codegen_vm.cc
22-
* \brief A compiler to compile an IRModule to VM executable.
22+
* \brief A codegen to generate VM executable from an IRModule with relax functions.
2323
*/
2424

2525
#include "codegen_vm.h"
@@ -64,7 +64,7 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
6464
// TODO(@yuchen): handle local functions that capture local vars outside the func
6565
// TODO(@yuchen): a renaming pass to resolve name conflicts, e.g. the input module has a
6666
// function named "local_funcN"
67-
// lift the local func to a global func and compile it normally
67+
// lift the local func to a global func and process it normally
6868
builder_->EmitFunction("local_func" + std::to_string(local_func_counter_++),
6969
func_node->params.size());
7070
}
@@ -287,49 +287,27 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
287287
const Op& load_shape_op_ = Op::Get("relax.vm.builtin.load_shape");
288288
};
289289

290-
void VMCompiler::Compile(IRModule mod, Target target, Target target_host) {
290+
void VMCodeGen::CodeGen(IRModule rx_mod) {
291291
builder_ = relax::ExecBuilderNode::Create();
292-
293-
IRModule tir_mod;
294-
IRModule rx_mod;
295-
for (auto& p : mod->functions) {
296-
auto gvar = p.first;
297-
298-
BaseFunc func = p.second;
299-
if (func.as<tir::PrimFuncNode>()) {
300-
tir_mod->Add(gvar, func);
301-
} else if (func.as<FunctionNode>()) {
302-
rx_mod->Add(gvar, func);
303-
} else {
304-
LOG(FATAL) << "Cannot handle such function node now:\n" << func;
305-
}
306-
}
307-
lib_ = tvm::build(tir_mod, target, target_host);
308-
309-
CodeGenVM compiler(builder_.operator->());
292+
CodeGenVM codegen(builder_.operator->());
310293
for (auto& p : rx_mod->functions) {
311-
compiler.VisitExpr(p.second);
294+
codegen.VisitExpr(p.second);
312295
}
313296
}
314297

315-
Executable VMCompiler::GetExec() {
298+
Executable VMCodeGen::GetExec() {
316299
return builder_->Get();
317300
}
318301

319-
runtime::Module VMCompiler::GetLib() {
320-
return lib_;
321-
}
322-
323-
Array<ObjectRef> Build(IRModule mod, Target target, Target target_host) {
324-
auto compiler = make_object<VMCompiler>();
325-
compiler->Compile(mod, target, target_host);
326-
Executable exec = compiler->GetExec();
327-
Module lib = compiler->GetLib();
328-
return Array<ObjectRef>({exec, lib});
302+
Executable CodeGen(IRModule mod) {
303+
auto codegen = make_object<VMCodeGen>();
304+
codegen->CodeGen(mod);
305+
Executable exec = codegen->GetExec();
306+
return exec;
329307
}
330308

331-
TVM_REGISTER_GLOBAL("relax.VMBuild")
332-
.set_body_typed(Build);
309+
TVM_REGISTER_GLOBAL("relax.VMCodeGen")
310+
.set_body_typed(CodeGen);
333311

334312
} // namespace relax_vm
335313
} // namespace relax

src/relax/backend/vm/codegen_vm.h

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@
1919

2020
/*!
2121
* \file src/relax/backend/vm/codegen_vm.h
22-
* \brief A compiler to compile an IRModule to VM executable.
22+
* \brief A codegen to generate VM executable from an IRModule with relax functions.
2323
*/
2424

25-
#ifndef TVM_RELAX_BACKEND_VM_COMPILER_H_
26-
#define TVM_RELAX_BACKEND_VM_COMPILER_H_
25+
#ifndef TVM_RELAX_BACKEND_CODEGEN_VM_H_
26+
#define TVM_RELAX_BACKEND_CODEGEN_VM_H_
2727

2828
#include <tvm/ir/module.h>
2929
#include <tvm/relax/vm/exec_builder.h>
@@ -40,37 +40,29 @@ using tvm::Target;
4040
using namespace tvm::runtime::relax_vm;
4141
using namespace tvm::runtime;
4242

43-
class VMCompiler : public Object {
43+
class VMCodeGen : public Object {
4444
public:
4545
/*!
4646
* \brief Compile the functions in a Module.
47-
* \param mod Input IRModule to be compiled.
47+
* \param rx_mod Input IRModule that constains relax functions.
4848
*/
49-
void Compile(IRModule mod, Target target, Target target_host);
49+
void CodeGen(IRModule rx_mod);
5050
/*!
5151
* \brief Get the compiled executable.
5252
* \return The compiled executable.
5353
*/
5454
Executable GetExec();
55-
/*!
56-
* \brief Get the compiled library.
57-
* \return The compiled lirary.
58-
*/
59-
Module GetLib();
6055

6156
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
62-
static constexpr const char* _type_key = "relax.VMCompiler";
63-
TVM_DECLARE_FINAL_OBJECT_INFO(ExecutableNode, Object);
57+
static constexpr const char* _type_key = "relax.VMCodeGen";
6458

6559
protected:
6660
/*! \brief Internal executable builder. */
6761
relax::ExecBuilder builder_;
68-
/*! \brief Built library. */
69-
runtime::Module lib_;
7062
};
7163

7264
} // namespace relax_vm
7365
} // namespace relax
7466
} // namespace tvm
7567

76-
#endif // TVM_RELAX_BACKEND_VM_COMPILER_H_
68+
#endif // TVM_RELAX_BACKEND_CODEGEN_VM_H_

0 commit comments

Comments
 (0)