|
30 | 30 | from . import vmobj as _obj |
31 | 31 | from .interpreter import Executor |
32 | 32 |
|
33 | | - |
34 | | -def _update_target(target): |
35 | | - target = target if target else tvm.target.current_target() |
36 | | - if target is None: |
37 | | - raise ValueError("Target is not set in env or passed as argument.") |
38 | | - |
39 | | - tgts = {} |
40 | | - if isinstance(target, (str, tvm.target.Target)): |
41 | | - dev_type = tvm.expr.IntImm("int32", tvm.nd.context(str(target)).device_type) |
42 | | - tgts[dev_type] = tvm.target.create(target) |
43 | | - elif isinstance(target, dict): |
44 | | - for dev, tgt in target.items(): |
45 | | - dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type) |
46 | | - tgts[dev_type] = tvm.target.create(tgt) |
47 | | - else: |
48 | | - raise TypeError("target is expected to be str, tvm.target.Target, " + |
49 | | - "or dict of str to str/tvm.target.Target, but received " + |
50 | | - "{}".format(type(target))) |
51 | | - return tgts |
52 | | - |
53 | 33 | def _convert(arg, cargs): |
54 | 34 | if isinstance(arg, (np.ndarray, tvm.nd.NDArray)): |
55 | 35 | cargs.append(_obj.tensor_object(arg)) |
@@ -161,6 +141,44 @@ def set_params(self, params): |
161 | 141 | inputs[name] = _expr.const(param) |
162 | 142 | self._set_params_func(inputs) |
163 | 143 |
|
| 144 | + def update_target(self, target): |
| 145 | + target = target if target else tvm.target.current_target() |
| 146 | + if target is None: |
| 147 | + raise ValueError("Target is not set in env or passed as argument.") |
| 148 | + tgts = {} |
| 149 | + if isinstance(target, (str, tvm.target.Target)): |
| 150 | + dev_type = tvm.expr.IntImm("int32", tvm.nd.context(str(target)).device_type) |
| 151 | + tgts[dev_type] = tvm.target.create(target) |
| 152 | + elif isinstance(target, dict): |
| 153 | + for dev, tgt in target.items(): |
| 154 | + dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type) |
| 155 | + tgts[dev_type] = tvm.target.create(tgt) |
| 156 | + else: |
| 157 | + raise TypeError("target is expected to be str, tvm.target.Target, " + |
| 158 | + "or dict of str to str/tvm.target.Target, but received " + |
| 159 | + "{}".format(type(target))) |
| 160 | + return tgts |
| 161 | + |
| 162 | + def update_target_host(self, target, target_host): |
| 163 | + target_host = None if target_host == "" else target_host |
| 164 | + if not target_host: |
| 165 | + for device_type, tgt in target.items(): |
| 166 | + if device_type.value == tvm.nd.cpu(0).device_type: |
| 167 | + target_host = tgt |
| 168 | + break |
| 169 | + if not target_host: |
| 170 | + target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm" |
| 171 | + return tvm.target.create(target_host) |
| 172 | + |
| 173 | + def tophub_context(self, target): |
| 174 | + # If current dispatch context is fallback context (the default root context), |
| 175 | + # then load pre-tuned parameters from TopHub |
| 176 | + if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext): |
| 177 | + tophub_context = autotvm.tophub.context(list(target.values())) |
| 178 | + else: |
| 179 | + tophub_context = autotvm.util.EmptyContext() |
| 180 | + return tophub_context |
| 181 | + |
164 | 182 | def compile(self, mod, target=None, target_host=None, params=None): |
165 | 183 | """ |
166 | 184 | Parameters |
@@ -191,26 +209,13 @@ def compile(self, mod, target=None, target_host=None, params=None): |
191 | 209 | vm : VirtualMachine |
192 | 210 | The VM runtime. |
193 | 211 | """ |
194 | | - target = _update_target(target) |
195 | | - target_host = None if target_host == "" else target_host |
196 | | - if not target_host: |
197 | | - for device_type, tgt in target.items(): |
198 | | - if device_type.value == tvm.nd.cpu(0).device_type: |
199 | | - target_host = tgt |
200 | | - break |
201 | | - if not target_host: |
202 | | - target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm" |
203 | | - target_host = tvm.target.create(target_host) |
| 212 | + target = self.update_target(target) |
| 213 | + target_host = self.update_target_host(target, target_host) |
204 | 214 |
|
205 | 215 | if params: |
206 | 216 | self.set_params(params) |
207 | 217 |
|
208 | | - # If current dispatch context is fallback context (the default root context), |
209 | | - # then load pre-tuned parameters from TopHub |
210 | | - if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext): |
211 | | - tophub_context = autotvm.tophub.context(list(target.values())) |
212 | | - else: |
213 | | - tophub_context = autotvm.util.EmptyContext() |
| 218 | + tophub_context = self.tophub_context(target) |
214 | 219 |
|
215 | 220 | with tophub_context: |
216 | 221 | self._compile(mod, target, target_host) |
|
0 commit comments