Skip to content

Commit 4a78f2e

Browse files
committed
Refactor
1 parent 48c046a commit 4a78f2e

File tree

2 files changed

+44
-52
lines changed

2 files changed

+44
-52
lines changed

python/tvm/relay/backend/profiler_vm.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -83,25 +83,12 @@ def compile(self, mod, target=None, target_host=None, params=None):
8383
The profile VM runtime.
8484
"""
8585
target = _update_target(target)
86+
target_host = self.update_target_host(target, target_host)
87+
8688
if params:
8789
self.set_params(params)
8890

89-
target_host = None if target_host == "" else target_host
90-
if not target_host:
91-
for device_type, tgt in target.items():
92-
if device_type.value == tvm.nd.cpu(0).device_type:
93-
target_host = tgt
94-
break
95-
if not target_host:
96-
target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm"
97-
target_host = tvm.target.create(target_host)
98-
99-
# If current dispatch context is fallback context (the default root context),
100-
# then load pre-tuned parameters from TopHub
101-
if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
102-
tophub_context = autotvm.tophub.context(list(target.values()))
103-
else:
104-
tophub_context = autotvm.util.EmptyContext()
91+
tophub_context = self.tophub_context(target)
10592

10693
with tophub_context:
10794
self._compile(mod, target, target_host)

python/tvm/relay/backend/vm.py

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -30,26 +30,6 @@
3030
from . import vmobj as _obj
3131
from .interpreter import Executor
3232

33-
34-
def _update_target(target):
35-
target = target if target else tvm.target.current_target()
36-
if target is None:
37-
raise ValueError("Target is not set in env or passed as argument.")
38-
39-
tgts = {}
40-
if isinstance(target, (str, tvm.target.Target)):
41-
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(str(target)).device_type)
42-
tgts[dev_type] = tvm.target.create(target)
43-
elif isinstance(target, dict):
44-
for dev, tgt in target.items():
45-
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type)
46-
tgts[dev_type] = tvm.target.create(tgt)
47-
else:
48-
raise TypeError("target is expected to be str, tvm.target.Target, " +
49-
"or dict of str to str/tvm.target.Target, but received " +
50-
"{}".format(type(target)))
51-
return tgts
52-
5333
def _convert(arg, cargs):
5434
if isinstance(arg, (np.ndarray, tvm.nd.NDArray)):
5535
cargs.append(_obj.tensor_object(arg))
@@ -161,6 +141,44 @@ def set_params(self, params):
161141
inputs[name] = _expr.const(param)
162142
self._set_params_func(inputs)
163143

144+
def update_target(self, target):
145+
target = target if target else tvm.target.current_target()
146+
if target is None:
147+
raise ValueError("Target is not set in env or passed as argument.")
148+
tgts = {}
149+
if isinstance(target, (str, tvm.target.Target)):
150+
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(str(target)).device_type)
151+
tgts[dev_type] = tvm.target.create(target)
152+
elif isinstance(target, dict):
153+
for dev, tgt in target.items():
154+
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type)
155+
tgts[dev_type] = tvm.target.create(tgt)
156+
else:
157+
raise TypeError("target is expected to be str, tvm.target.Target, " +
158+
"or dict of str to str/tvm.target.Target, but received " +
159+
"{}".format(type(target)))
160+
return tgts
161+
162+
def update_target_host(self, target, target_host):
163+
target_host = None if target_host == "" else target_host
164+
if not target_host:
165+
for device_type, tgt in target.items():
166+
if device_type.value == tvm.nd.cpu(0).device_type:
167+
target_host = tgt
168+
break
169+
if not target_host:
170+
target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm"
171+
return tvm.target.create(target_host)
172+
173+
def tophub_context(self, target):
174+
# If current dispatch context is fallback context (the default root context),
175+
# then load pre-tuned parameters from TopHub
176+
if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
177+
tophub_context = autotvm.tophub.context(list(target.values()))
178+
else:
179+
tophub_context = autotvm.util.EmptyContext()
180+
return tophub_context
181+
164182
def compile(self, mod, target=None, target_host=None, params=None):
165183
"""
166184
Parameters
@@ -191,26 +209,13 @@ def compile(self, mod, target=None, target_host=None, params=None):
191209
vm : VirtualMachine
192210
The VM runtime.
193211
"""
194-
target = _update_target(target)
195-
target_host = None if target_host == "" else target_host
196-
if not target_host:
197-
for device_type, tgt in target.items():
198-
if device_type.value == tvm.nd.cpu(0).device_type:
199-
target_host = tgt
200-
break
201-
if not target_host:
202-
target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm"
203-
target_host = tvm.target.create(target_host)
212+
target = self.update_target(target)
213+
target_host = self.update_target_host(target, target_host)
204214

205215
if params:
206216
self.set_params(params)
207217

208-
# If current dispatch context is fallback context (the default root context),
209-
# then load pre-tuned parameters from TopHub
210-
if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
211-
tophub_context = autotvm.tophub.context(list(target.values()))
212-
else:
213-
tophub_context = autotvm.util.EmptyContext()
218+
tophub_context = self.tophub_context(target)
214219

215220
with tophub_context:
216221
self._compile(mod, target, target_host)

0 commit comments

Comments
 (0)