Skip to content

Commit 1f59139

Browse files
authored
[Target] Fix empty target and host for autotvm task (#7791)
1 parent 7071fda commit 1f59139

File tree

4 files changed

+39
-5
lines changed

4 files changed

+39
-5
lines changed

python/tvm/autotvm/task/task.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def __getstate__(self):
185185
"config_space": self.config_space,
186186
"flop": self.flop,
187187
"target": self.target,
188-
"target_host": self.target.host,
188+
"target_host": self.target_host,
189189
"func": cloudpickle.dumps(self.func),
190190
}
191191

@@ -465,7 +465,7 @@ def create(task_name, args, target, target_host=None):
465465

466466
ret.flop = ret.config_space.flop or compute_flop(sch)
467467
ret.target = target
468-
ret.target_host = target.host
468+
ret.target_host = target_host
469469

470470
return ret
471471

python/tvm/target/target.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,9 @@ def check_and_update_host_consist(target, host=None, target_is_dict_key=True):
182182
target_is_dict_key : Bool
183183
When the type of target is dict, whether Target is the key (Otherwise the value)
184184
"""
185+
if target is None:
186+
assert host is None, "Target host is not empty when target is empty."
187+
return target, host
185188
if isinstance(target, dict) and "kind" not in target:
186189
new_target = {}
187190
for tgt, mod in target.items():

tests/python/integration/test_tuning.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
from tvm import autotvm
3232
from tvm.autotvm.tuner import RandomTuner
33+
from tvm.target import Target
3334

3435
import tvm.testing
3536

@@ -131,8 +132,7 @@ def teardown_module():
131132

132133

133134
def get_sample_task(target=tvm.target.cuda(), target_host=None):
134-
target = tvm.target.Target(target, target_host)
135-
target_host = target.host
135+
target, target_host = Target.check_and_update_host_consist(target, target_host)
136136
"""return a sample task for testing"""
137137
task = autotvm.task.create(
138138
"testing/conv2d_no_batching", args=(1, 7, 7, 512, 512, 3, 3), target=target

tests/python/unittest/test_target_target.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import sys
1919
import pytest
2020
import tvm
21-
from tvm.target import cuda, rocm, mali, intel_graphics, arm_cpu, vta, bifrost
21+
from tvm.target import cuda, rocm, mali, intel_graphics, arm_cpu, vta, bifrost, Target
2222

2323

2424
@tvm.target.generic_func
@@ -268,5 +268,36 @@ def test_target_with_host():
268268
assert tgt.host.attrs["registers_per_block"] == 32768
269269

270270

271+
def test_check_and_update_host_consist_0():
272+
target = None
273+
host = None
274+
target, host = Target.check_and_update_host_consist(target, host)
275+
276+
277+
def test_check_and_update_host_consist_1():
278+
target = None
279+
host = "llvm"
280+
with pytest.raises(AssertionError, match=r"Target host is not empty when target is empty."):
281+
target, host = Target.check_and_update_host_consist(target, host)
282+
283+
284+
def test_check_and_update_host_consist_2():
285+
target = Target("cuda")
286+
host = Target("llvm")
287+
target, host = Target.check_and_update_host_consist(target, host)
288+
assert target.kind.name == "cuda"
289+
assert target.host.kind.name == "llvm"
290+
291+
292+
def test_check_and_update_host_consist_3():
293+
target = Target(target="cuda", host="llvm")
294+
host = None
295+
target, host = Target.check_and_update_host_consist(target, host)
296+
assert target.kind.name == "cuda"
297+
assert target.host.kind.name == "llvm"
298+
assert host.kind.name == "llvm"
299+
assert target.host == host
300+
301+
271302
if __name__ == "__main__":
272303
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 commit comments

Comments
 (0)