|
24 | 24 |
|
25 | 25 | import tvm |
26 | 26 | from tvm import autotvm |
27 | | -from tvm import TVMContext |
28 | 27 | from tvm.relay import expr as _expr |
29 | 28 | from . import _vm |
30 | 29 | from . import vmobj as _obj |
@@ -54,37 +53,10 @@ class Executable(object): |
54 | 53 | """Relay VM executable""" |
55 | 54 | def __init__(self, mod): |
56 | 55 | self.mod = mod |
57 | | - self._set_context = self.mod["set_context"] |
58 | 56 | self._get_lib = self.mod["get_lib"] |
59 | 57 | self._get_bytecode = self.mod["get_bytecode"] |
60 | 58 | self._get_stats = self.mod["get_stats"] |
61 | 59 |
|
62 | | - def set_context(self, ctx): |
63 | | - """Initialize the context of the VM executable. |
64 | | -
|
65 | | - Parameters |
66 | | - ---------- |
67 | | - ctx : Union[:py:class:`tvm.TVMContext`, List[py:class:`tvm.TVMContext`]] |
68 | | - The runtime context to run the code on. |
69 | | - """ |
70 | | - |
71 | | - if isinstance(ctx, TVMContext): |
72 | | - ctx = [ctx] |
73 | | - elif not isinstance(ctx, (list, tuple)): |
74 | | - raise ValueError("ctx has to be the type of TVMContext or a list of " |
75 | | - "TVMContext") |
76 | | - # args[0], args[1] are used as the primary/fallback context type and id |
77 | | - # for heterogeneous execution. |
78 | | - args = [] |
79 | | - for cur_ctx in ctx: |
80 | | - if not isinstance(cur_ctx, TVMContext): |
81 | | - raise ValueError("ctx has to be the type of TVMContext or a list " |
82 | | - "of TVMContext") |
83 | | - args.append(cur_ctx.device_type) |
84 | | - args.append(cur_ctx.device_id) |
85 | | - |
86 | | - self._set_context(*args) |
87 | | - |
88 | 60 | @property |
89 | 61 | def lib(self): |
90 | 62 | """Get the library that contains hardware dependent code. |
@@ -179,8 +151,20 @@ def __init__(self, mod): |
179 | 151 | "tvm.Module, but received {}".format(type(mod))) |
180 | 152 | m = mod.module if isinstance(mod, Executable) else mod |
181 | 153 | self.mod = _vm._VirtualMachine(m) |
| 154 | + self._init = self.mod["init"] |
182 | 155 | self._invoke = self.mod["invoke"] |
183 | 156 |
|
| 157 | + def init(self, ctx): |
| 158 | + """Initialize the context in the VM. |
| 159 | +
|
| 160 | + Parameters |
| 161 | + ---------- |
| 162 | + ctx : :py:class:`TVMContext` |
| 163 | + The runtime context to run the code on. |
| 164 | + """ |
| 165 | + args = [ctx.device_type, ctx.device_id] |
| 166 | + self._init(*args) |
| 167 | + |
184 | 168 | def invoke(self, func_name, *args): |
185 | 169 | """Invoke a function. |
186 | 170 |
|
@@ -341,8 +325,8 @@ def __init__(self, mod, ctx, target): |
341 | 325 | self.ctx = ctx |
342 | 326 | self.target = target |
343 | 327 | self.executable = compile(mod, target) |
344 | | - self.executable.set_context(ctx) |
345 | 328 | self.vm = VirtualMachine(self.executable) |
| 329 | + self.vm.init(ctx) |
346 | 330 |
|
347 | 331 | def _make_executor(self, expr=None): |
348 | 332 | main = self.mod["main"] |
|
0 commit comments