Skip to content

Commit 7182129

Browse files
committed
[relay][vm] Separate VM runtime with executable
1 parent 15ae978 commit 7182129

File tree

20 files changed

+762
-612
lines changed

20 files changed

+762
-612
lines changed

include/tvm/runtime/vm.h

Lines changed: 96 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
#include <tvm/runtime/object.h>
2929
#include <tvm/runtime/packed_func.h>
30+
#include <tvm/runtime/registry.h>
3031
#include <memory>
3132
#include <string>
3233
#include <unordered_map>
@@ -362,15 +363,101 @@ struct VMFrame {
362363
caller_return_register(0) {}
363364
};
364365

366+
/*! \brief The executable emitted by the VM compiler.
367+
*
368+
* The executable contains information (e.g. data in different memory regions)
369+
* to create a virtual machine.
370+
*/
371+
class Executable : public ModuleNode {
372+
public:
373+
/*!
374+
* \brief Get a PackedFunc from an executable module.
375+
*
376+
* \param name the name of the function.
377+
* \param sptr_to_self The shared_ptr that points to this module node.
378+
*
379+
* \return PackedFunc or nullptr when it is not available.
380+
*/
381+
PackedFunc GetFunction(const std::string& name,
382+
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
383+
384+
/*!
385+
* \brief Get the serialized form of the `functions` in `vm_`. This is
386+
* essentially bytecode serialization.
387+
*
388+
* \return The serialized vm bytecode.
389+
*
390+
* \note The bytecode is in the following format:
391+
* func_name reg_file_size num_instructions
392+
* param1 param2 ... paramM
393+
* instruction1
394+
* instruction2
395+
* ...
396+
* instructionN
397+
*
398+
* Each instruction is printed in the following format:
399+
* opcode num_fields field1 ... fieldX # The text format.
400+
*
401+
* The field starting from # is only used for debugging. The serialized code
402+
* doesn't contain it, therefore the deserializer doens't need to handle it.
403+
*/
404+
std::string GetBytecode() const;
405+
406+
/*!
407+
* \brief Print the detailed statistics of the given code, i.e. number of
408+
* globls and constants, etc.
409+
*/
410+
std::string Stats() const;
411+
412+
/*! \brief Get the `lib` module in an executable. Users have the flexibility to call
413+
* `export_library` from the frontend to save the library to disk.
414+
*
415+
* \return The runtime module that contains the hardwre dependent code.
416+
*/
417+
runtime::Module GetLib() const { return lib; }
418+
419+
/*!
420+
* \brief Set the execution context for the executable.
421+
*
422+
* \param ctxs The list of TVMContext.
423+
*/
424+
void SetContext(const std::vector<TVMContext>& ctxs);
425+
426+
/*! \brief Get device context for params.
427+
*/
428+
TVMContext GetParamsContext() const;
429+
430+
virtual ~Executable() {}
431+
432+
const char* type_key() const final {
433+
return "VMExecutable";
434+
}
435+
436+
/*! \brief The runtime module/library that contains hardware dependent code. */
437+
runtime::Module lib;
438+
/*! \brief The global constant pool. */
439+
std::vector<Object> constants;
440+
/*! \brief A map from globals (as strings) to their index in the function map. */
441+
std::unordered_map<std::string, Index> global_map;
442+
/*! \brief A mapping from the packed function (as string) to the index that
443+
* corresponds to the position of the `packed_funcs` list in a `VirtualMachine` object.
444+
*/
445+
std::unordered_map<std::string, Index> primitive_map;
446+
/*! \brief The virtual machine's function table. */
447+
std::vector<VMFunction> functions;
448+
449+
/*! \brief The set of TVM contexts the VM is currently executing on. */
450+
std::vector<TVMContext> ctxs;
451+
};
452+
365453
/*! \brief The virtual machine.
366454
*
367455
* The virtual machine contains all the current execution state,
368-
* as well as the global view of functions, the global constant
369-
* table, the compiled operators.
456+
* as well as the executable.
370457
*
371458
* The goal is to have a single self-contained object,
372459
* enabling one to easily pass around VMs, execute them on
373-
* multiple threads, or serialized them to disk or over the
460+
* multiple threads, or serialize them to disk or over the
374461
* wire.
375462
*/
376463
class VirtualMachine : public runtime::ModuleNode {
@@ -415,16 +502,10 @@ class VirtualMachine : public runtime::ModuleNode {
415502
return "VirtualMachine";
416503
}
417504

418-
/*! \brief The runtime module/library that contains generated code. */
419-
runtime::Module lib;
420505
/*! \brief The virtual machine's packed function table. */
421506
std::vector<PackedFunc> packed_funcs;
422-
/*! \brief The virtual machine's function table. */
423-
std::vector<VMFunction> functions;
424507
/*! \brief The current stack of call frames. */
425508
std::vector<VMFrame> frames;
426-
/*! \brief The global constant pool. */
427-
std::vector<Object> constants;
428509
/*! \brief The fuction table index of the current function. */
429510
Index func_index;
430511
/*! \brief The current pointer to the code section. */
@@ -435,8 +516,8 @@ class VirtualMachine : public runtime::ModuleNode {
435516
/*! \brief The special return register. */
436517
Object return_register;
437518

438-
/*! \brief The set of TVM contexts the VM is currently executing on. */
439-
std::vector<TVMContext> ctxs;
519+
/*! \brief The executable the VM will operate on. */
520+
const Executable* exec;
440521

441522
/*! \brief Push a call frame on to the call stack. */
442523
void PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func);
@@ -478,44 +559,24 @@ class VirtualMachine : public runtime::ModuleNode {
478559
*/
479560
Object Invoke(const std::string& name, const std::vector<Object>& args);
480561

481-
VirtualMachine() : functions(), frames(), func_index(0), code(nullptr), pc(0) {}
562+
VirtualMachine() : frames(), func_index(0), code(nullptr), pc(0), exec(nullptr) {}
482563

483-
/*! \brief Initialize the virtual machine for a set of contexts.
484-
* \param contexts The set of TVM contexts.
564+
/*! \brief Initialize the virtual machine using an executable.
565+
* \param exec The executable.
485566
*/
486-
void Init(const std::vector<TVMContext>& contexts);
567+
void Init(const Executable* exec);
487568

488569
/*! \brief Run VM dispatch loop.
489570
*/
490571
void RunLoop();
491572

492-
/*! \brief Get device context for params.
493-
*/
494-
TVMContext GetParamsContext() const;
495-
496-
/*!
497-
* \brief Load parameters from the parameter bytearray.
498-
* \param params The binary file that contains parameters.
499-
*/
500-
void LoadParams(const std::string& params);
501-
502-
/*! \brief A map from globals (as strings) to their index in the function map.
503-
*/
504-
std::unordered_map<std::string, Index> global_map;
505-
506-
/*! \brief A mapping from the packed function (as string) to the index that
507-
* corresponds to the position of the `packed_funcs` list.
508-
*/
509-
std::unordered_map<std::string, Index> primitive_map;
510-
511573
private:
512574
/*! \brief Invoke a global setting up the VM state to execute.
513575
*
514576
* This does not begin execution of the VM.
515577
*/
516578
void InvokeGlobal(const VMFunction& func, const std::vector<Object>& args);
517579

518-
519580
/*! \brief The parameter name to data mapping. */
520581
std::unordered_map<std::string, Object> params_;
521582
};

python/tvm/relay/backend/deserializer.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def _create_deserializer(code, lib):
3131
Parameters
3232
----------
3333
code : bytearray
34-
The serialized virtual machine code.
34+
The serialized virtual machine bytecode.
3535
3636
lib : :py:class:`~tvm.module.Module`
3737
The serialized runtime module/library that contains the hardware
@@ -40,7 +40,7 @@ def _create_deserializer(code, lib):
4040
Returns
4141
-------
4242
ret : Deserializer
43-
The created virtual machine deserializer.
43+
The created virtual machine executable deserializer.
4444
"""
4545
if isinstance(code, (bytes, str)):
4646
code = bytearray(code)
@@ -55,12 +55,12 @@ def _create_deserializer(code, lib):
5555

5656

5757
class Deserializer:
58-
"""Relay VM deserializer.
58+
"""Relay VM executable deserializer.
5959
6060
Parameters
6161
----------
6262
code : bytearray
63-
The serialized virtual machine code.
63+
The serialized virtual machine bytecode.
6464
6565
lib : :py:class:`~tvm.module.Module`
6666
The serialized runtime module/library that contains the hardware
@@ -71,11 +71,11 @@ def __init__(self, code, lib):
7171
self._deserialize = self.mod["deserialize"]
7272

7373
def deserialize(self):
74-
"""Deserialize the serialized bytecode into a Relay VM.
74+
"""Deserialize the serialized bytecode into a Relay VM executable.
7575
7676
Returns
7777
-------
78-
ret : VirtualMachine
79-
The deserialized Relay VM.
78+
ret : Executable
79+
The deserialized Relay VM executable.
8080
"""
81-
return rly_vm.VirtualMachine(self._deserialize())
81+
return rly_vm.Executable(self._deserialize())

python/tvm/relay/backend/profiler_vm.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ def compile(mod, target=None, target_host=None, params=None):
4949
5050
Returns
5151
-------
52-
vm : VirtualMachineProfiler
53-
The profile VM runtime.
52+
exec : Executable
53+
The executable with profiling code.
5454
"""
5555
compiler = VMCompilerProfiler()
5656
target = compiler.update_target(target)
@@ -60,21 +60,24 @@ def compile(mod, target=None, target_host=None, params=None):
6060
tophub_context = compiler.tophub_context(target)
6161
with tophub_context:
6262
compiler._compile(mod, target, target_host)
63-
return VirtualMachineProfiler(compiler._get_vm())
63+
return vm.Executable(compiler._get_exec())
6464

6565
class VMCompilerProfiler(vm.VMCompiler):
6666
"""Build Relay module to run on VM runtime."""
6767
def __init__(self):
6868
super().__init__()
6969
self.mod = _vm._VMCompilerProfiler()
7070
self._compile = self.mod["compile"]
71-
self._get_vm = self.mod["get_vm"]
71+
self._get_exec = self.mod["get_executable"]
7272
self._set_params_func = self.mod["set_params"]
7373

7474
class VirtualMachineProfiler(vm.VirtualMachine):
7575
"""Relay profile VM runtime."""
7676
def __init__(self, mod):
7777
super().__init__(mod)
78+
m = mod.module if isinstance(mod, vm.Executable) else mod
79+
self.mod = _vm._VirtualMachineDebug(m)
80+
self._invoke = self.mod["invoke"]
7881
self._get_stat = self.mod["get_stat"]
7982

8083
def get_stat(self):

0 commit comments

Comments
 (0)