Skip to content

Commit 34eb868

Browse files
committed
move ctx back to vm
1 parent 9f232a9 commit 34eb868

File tree

14 files changed

+110
-136
lines changed

14 files changed

+110
-136
lines changed

include/tvm/runtime/vm.h

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -484,17 +484,6 @@ class Executable : public ModuleNode {
484484
*/
485485
runtime::Module GetLib() const { return lib; }
486486

487-
/*!
488-
* \brief Set the execution context for the executable.
489-
*
490-
* \param ctxs The list of TVMContext.
491-
*/
492-
void SetContext(const std::vector<TVMContext>& ctxs);
493-
494-
/*! \brief Get device context for params.
495-
*/
496-
TVMContext GetParamsContext() const;
497-
498487
virtual ~Executable() {}
499488

500489
const char* type_key() const final {
@@ -514,9 +503,6 @@ class Executable : public ModuleNode {
514503
std::unordered_map<std::string, Index> primitive_map;
515504
/*! \brief The virtual machine's function table. */
516505
std::vector<VMFunction> functions;
517-
518-
/*! \brief The set of TVM contexts the VM is currently executing on. */
519-
std::vector<TVMContext> ctxs;
520506
};
521507

522508
/*! \brief The virtual machine.
@@ -591,6 +577,9 @@ class VirtualMachine : public runtime::ModuleNode {
591577
/*! \brief The executable the VM will operate on. */
592578
const Executable* exec;
593579

580+
/*! \brief The set of TVM contexts the VM is currently executing on. */
581+
std::vector<TVMContext> ctxs;
582+
594583
/*! \brief Push a call frame on to the call stack. */
595584
void PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func);
596585
/*! \brief Pop a frame off the call stack.
@@ -634,15 +623,24 @@ class VirtualMachine : public runtime::ModuleNode {
634623

635624
VirtualMachine() : frames(), func_index(0), code(nullptr), pc(0), exec(nullptr) {}
636625

637-
/*! \brief Initialize the virtual machine using an executable.
626+
/*! \brief load the executable for the virtual machine.
638627
* \param exec The executable.
639628
*/
640-
void Init(const Executable* exec);
629+
void LoadExecutable(const Executable* exec);
630+
631+
/*! \brief Initialize the virtual machine for a set of contexts.
632+
* \param contexts The set of TVM contexts.
633+
*/
634+
void Init(const std::vector<TVMContext>& contexts);
641635

642636
/*! \brief Run VM dispatch loop.
643637
*/
644638
void RunLoop();
645639

640+
/*! \brief Get device context for params.
641+
*/
642+
TVMContext GetParamsContext() const;
643+
646644
private:
647645
/*! \brief Invoke a global setting up the VM state to execute.
648646
*

python/tvm/relay/backend/profiler_vm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def __init__(self, mod):
7777
super().__init__(mod)
7878
m = mod.module if isinstance(mod, vm.Executable) else mod
7979
self.mod = _vm._VirtualMachineDebug(m)
80+
self._init = self.mod["init"]
8081
self._invoke = self.mod["invoke"]
8182
self._get_stat = self.mod["get_stat"]
8283

python/tvm/relay/backend/serializer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,6 @@ def serialize(self):
9696
ctx = tvm.cpu()
9797
target = "llvm"
9898
executable = relay.vm..compile(mod, target)
99-
executable.set_context(ctx)
10099
101100
# serialize.
102101
ser = relay.serializer.Serializer(executable)
@@ -117,7 +116,7 @@ def serialize(self):
117116
des_exec = deser.deserialize()
118117
119118
# execute the deserialized executable.
120-
des_exec.set_context(ctx)
119+
des_vm.init(ctx)
121120
x_data = np.random.rand(10, 10).astype('float32')
122121
des_vm = relay.vm.VirtualMachine(des_exec)
123122
res = des_vm.run(x_data)

python/tvm/relay/backend/vm.py

Lines changed: 13 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424

2525
import tvm
2626
from tvm import autotvm
27-
from tvm import TVMContext
2827
from tvm.relay import expr as _expr
2928
from . import _vm
3029
from . import vmobj as _obj
@@ -54,37 +53,10 @@ class Executable(object):
5453
"""Relay VM executable"""
5554
def __init__(self, mod):
5655
self.mod = mod
57-
self._set_context = self.mod["set_context"]
5856
self._get_lib = self.mod["get_lib"]
5957
self._get_bytecode = self.mod["get_bytecode"]
6058
self._get_stats = self.mod["get_stats"]
6159

62-
def set_context(self, ctx):
63-
"""Initialize the context of the VM executable.
64-
65-
Parameters
66-
----------
67-
ctx : Union[:py:class:`tvm.TVMContext`, List[py:class:`tvm.TVMContext`]]
68-
The runtime context to run the code on.
69-
"""
70-
71-
if isinstance(ctx, TVMContext):
72-
ctx = [ctx]
73-
elif not isinstance(ctx, (list, tuple)):
74-
raise ValueError("ctx has to be the type of TVMContext or a list of "
75-
"TVMContext")
76-
# args[0], args[1] are used as the primary/fallback context type and id
77-
# for heterogeneous execution.
78-
args = []
79-
for cur_ctx in ctx:
80-
if not isinstance(cur_ctx, TVMContext):
81-
raise ValueError("ctx has to be the type of TVMContext or a list "
82-
"of TVMContext")
83-
args.append(cur_ctx.device_type)
84-
args.append(cur_ctx.device_id)
85-
86-
self._set_context(*args)
87-
8860
@property
8961
def lib(self):
9062
"""Get the library that contains hardware dependent code.
@@ -179,8 +151,20 @@ def __init__(self, mod):
179151
"tvm.Module, but received {}".format(type(mod)))
180152
m = mod.module if isinstance(mod, Executable) else mod
181153
self.mod = _vm._VirtualMachine(m)
154+
self._init = self.mod["init"]
182155
self._invoke = self.mod["invoke"]
183156

157+
def init(self, ctx):
158+
"""Initialize the context in the VM.
159+
160+
Parameters
161+
----------
162+
ctx : :py:class:`TVMContext`
163+
The runtime context to run the code on.
164+
"""
165+
args = [ctx.device_type, ctx.device_id]
166+
self._init(*args)
167+
184168
def invoke(self, func_name, *args):
185169
"""Invoke a function.
186170
@@ -341,8 +325,8 @@ def __init__(self, mod, ctx, target):
341325
self.ctx = ctx
342326
self.target = target
343327
self.executable = compile(mod, target)
344-
self.executable.set_context(ctx)
345328
self.vm = VirtualMachine(self.executable)
329+
self.vm.init(ctx)
346330

347331
def _make_executor(self, expr=None):
348332
main = self.mod["main"]

src/runtime/vm/deserializer.cc

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,6 @@ void Deserializer::Deserialize() {
8282

8383
// Code section.
8484
DeserializeCodeSection();
85-
86-
// Context section.
87-
DeserializeContextSection();
8885
}
8986

9087
void Deserializer::DeserializeGlobalSection() {
@@ -313,18 +310,6 @@ void Deserializer::DeserializeCodeSection() {
313310
}
314311
}
315312

316-
void Deserializer::DeserializeContextSection() {
317-
std::vector<uint64_t> ctxs;
318-
STREAM_CHECK(strm_->Read(&ctxs), "context");
319-
CHECK_EQ(ctxs.size() % 2, 0U);
320-
for (size_t i = 0; i < ctxs.size(); i += 2) {
321-
TVMContext ctx;
322-
ctx.device_type = DLDeviceType(ctxs[i]);
323-
ctx.device_id = static_cast<int>(ctxs[i + 1]);
324-
exec_->ctxs.push_back(ctx);
325-
}
326-
}
327-
328313
runtime::Module CreateDeserializer(const std::string& code, const runtime::Module lib) {
329314
std::shared_ptr<Deserializer> exec = std::make_shared<Deserializer>();
330315
exec->Init(code, lib);

src/runtime/vm/executable.cc

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -39,20 +39,7 @@ namespace vm {
3939

4040
PackedFunc Executable::GetFunction(const std::string& name,
4141
const std::shared_ptr<ModuleNode>& sptr_to_self) {
42-
if (name == "set_context") {
43-
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
44-
CHECK_EQ(args.size() % 2, 0);
45-
std::vector<TVMContext> contexts;
46-
for (int i = 0; i < args.size() / 2; ++i) {
47-
TVMContext ctx;
48-
int device_type = args[i * 2];
49-
ctx.device_type = DLDeviceType(device_type);
50-
ctx.device_id = args[i * 2 + 1];
51-
contexts.push_back(ctx);
52-
}
53-
this->SetContext(contexts);
54-
});
55-
} else if (name == "get_lib") {
42+
if (name == "get_lib") {
5643
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
5744
*rv = this->GetLib();
5845
});
@@ -70,10 +57,6 @@ PackedFunc Executable::GetFunction(const std::string& name,
7057
}
7158
}
7259

73-
inline void Executable::SetContext(const std::vector<TVMContext>& ctxs) {
74-
this->ctxs = ctxs;
75-
}
76-
7760
std::string Executable::GetBytecode() const {
7861
std::ostringstream oss;
7962

@@ -166,20 +149,6 @@ std::string Executable::Stats() const {
166149
return oss.str();
167150
}
168151

169-
TVMContext Executable::GetParamsContext() const {
170-
CHECK(!ctxs.empty()) << "context has not been set yet.";
171-
172-
// Use the fallback device if no device index is available.
173-
int fallback_device_type = static_cast<int>(ctxs[0].device_type);
174-
// TODO(wweic): For heterogeneous execution, get device information from byte
175-
176-
const auto& cit =
177-
std::find_if(ctxs.begin(), ctxs.end(), [&fallback_device_type](const TVMContext& c) {
178-
return fallback_device_type == static_cast<int>(c.device_type);
179-
});
180-
return (cit == ctxs.end() ? ctxs[0] : *cit);
181-
}
182-
183152
TVM_REGISTER_GLOBAL("relay._vm.GetNumOfGlobals")
184153
.set_body([](TVMArgs args, TVMRetValue* rv) {
185154
runtime::Module mod = args[0];
@@ -188,7 +157,6 @@ TVM_REGISTER_GLOBAL("relay._vm.GetNumOfGlobals")
188157
*rv = static_cast<int>(exec->global_map.size());
189158
});
190159

191-
192160
TVM_REGISTER_GLOBAL("relay._vm.GetGlobalFields")
193161
.set_body([](TVMArgs args, TVMRetValue* rv) {
194162
runtime::Module mod = args[0];

src/runtime/vm/profiler/vm.cc

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,26 +67,43 @@ PackedFunc VirtualMachineDebug::GetFunction(
6767
os << "Total Duration " << total_duration << " us" << std::endl;
6868
*rv = os.str();
6969
});
70+
} else if (name == "init") {
71+
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
72+
CHECK_EQ(args.size() % 2, 0);
73+
std::vector<TVMContext> contexts;
74+
for (int i = 0; i < args.size() / 2; ++i) {
75+
TVMContext ctx;
76+
int device_type = args[i * 2];
77+
ctx.device_type = DLDeviceType(device_type);
78+
ctx.device_id = args[i * 2 + 1];
79+
contexts.push_back(ctx);
80+
}
81+
this->Init(contexts);
82+
});
7083
} else {
7184
return VirtualMachine::GetFunction(name, sptr_to_self);
7285
}
7386
}
7487

75-
void VirtualMachineDebug::Init(const Executable* exec) {
76-
VirtualMachine::Init(exec);
88+
void VirtualMachineDebug::LoadExecutable(const Executable* exec) {
89+
VirtualMachine::LoadExecutable(exec);
7790
CHECK(this->exec);
7891
for (auto kv : this->exec->primitive_map) {
7992
packed_index_map[kv.second] = kv.first;
8093
op_invokes[kv.second] = 0;
8194
}
8295
}
8396

97+
void VirtualMachineDebug::Init(const std::vector<TVMContext>& ctxs) {
98+
VirtualMachine::Init(ctxs);
99+
}
100+
84101
void VirtualMachineDebug::InvokePacked(Index packed_index,
85102
const PackedFunc& func, Index arg_count,
86103
Index output_size,
87104
const std::vector<ObjectRef>& args) {
88105
CHECK(this->exec);
89-
auto ctx = this->exec->GetParamsContext();
106+
auto ctx = this->GetParamsContext();
90107
// warmup
91108
VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size,
92109
args);
@@ -108,7 +125,7 @@ void VirtualMachineDebug::InvokePacked(Index packed_index,
108125

109126
runtime::Module CreateVirtualMachineDebug(const Executable* exec) {
110127
std::shared_ptr<VirtualMachineDebug> vm = std::make_shared<VirtualMachineDebug>();
111-
vm->Init(exec);
128+
vm->LoadExecutable(exec);
112129
return runtime::Module(vm);
113130
}
114131

src/runtime/vm/profiler/vm.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,13 @@ class VirtualMachineDebug : public VirtualMachine {
4747
void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count,
4848
Index output_size, const std::vector<ObjectRef>& args) final;
4949

50-
void Init(const Executable* exec);
50+
void LoadExecutable(const Executable* exec);
5151

5252
~VirtualMachineDebug() {}
5353

5454
private:
55+
void Init(const std::vector<TVMContext>& ctxs);
56+
5557
std::unordered_map<Index, std::string> packed_index_map;
5658
std::unordered_map<Index, std::vector<double>> op_durations;
5759
std::unordered_map<Index, int> op_invokes;

src/runtime/vm/serializer.cc

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,6 @@ TVMByteArray Serializer::Serialize() {
8181
// Code section.
8282
SerializeCodeSection();
8383

84-
// Context section.
85-
SerializeContextSection();
86-
8784
TVMByteArray arr;
8885
arr.data = code_.c_str();
8986
arr.size = code_.length();
@@ -300,16 +297,6 @@ void Serializer::SerializeCodeSection() {
300297
}
301298
}
302299

303-
void Serializer::SerializeContextSection() {
304-
CHECK(!exec_->ctxs.empty());
305-
std::vector<uint64_t> serialized_ctx;
306-
for (const auto& ctx : exec_->ctxs) {
307-
serialized_ctx.push_back(static_cast<uint64_t>(ctx.device_type));
308-
serialized_ctx.push_back(static_cast<uint64_t>(ctx.device_id));
309-
}
310-
strm_->Write(serialized_ctx);
311-
}
312-
313300
runtime::Module CreateSerializer(const Executable* exec) {
314301
std::shared_ptr<Serializer> serializer = std::make_shared<Serializer>();
315302
serializer->Init(exec);

src/runtime/vm/serializer.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@
3232
* - The `primitive_map` that contains the name of individual primitive operators.
3333
* - The `functions`, e.g., the `VMFunction`. Each `VMFunction` is composed of
3434
* a list of instructions/bytecode.
35-
* - The `ctxs` that contains the device context used to execute the hardware
36-
* dependent code.
3735
*
3836
* Note that only the library is returned as a separate module. All othere parts
3937
* are stored in a single serialized code that is organized with the following
@@ -43,7 +41,6 @@
4341
* - Primitive name section, containing the function name of the primitive ops
4442
* used by the virtual machine.
4543
* - Code section, handling the VM functions and bytecode.
46-
* - Context section, saving the context information.
4744
*
4845
* The code section is again organized as follows for each VM function:
4946
* func_name, register_file_size, num_instructions (N)
@@ -136,9 +133,6 @@ class Serializer : public runtime::ModuleNode {
136133
/*! \brief Serialize the vm functions in exec_. */
137134
void SerializeCodeSection();
138135

139-
/*! \brief Serialize the context in exec_. */
140-
void SerializeContextSection();
141-
142136
/*! \brief The Relay virtual machine executable to be serialized. */
143137
const Executable* exec_;
144138

0 commit comments

Comments
 (0)