Skip to content

Commit 042d011

Browse files
vinx13mehrdadh
authored andcommitted
[AutoScheduler] Fix deserization of workload registry entry (apache#8662)
1 parent 805b541 commit 042d011

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

python/tvm/auto_scheduler/workload_registry.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,9 @@ def deserialize_workload_registry_entry(data):
245245
name, value = data
246246
if name not in WORKLOAD_FUNC_REGISTRY:
247247
# pylint: disable=assignment-from-no-return
248-
WORKLOAD_FUNC_REGISTRY[name] = LoadJSON(value)
248+
if not callable(value):
249+
value = LoadJSON(value)
250+
WORKLOAD_FUNC_REGISTRY[name] = value
249251

250252

251253
def save_workload_func_registry(filename):

tests/python/unittest/test_auto_scheduler_measure.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,15 @@ def test_dag_measure_local_builder_runner():
293293
assert mress[0].error_no == 0
294294

295295

296+
def test_workload_serialization():
297+
key = tvm.auto_scheduler.utils.get_func_name(matmul_auto_scheduler_test)
298+
transfer_data = workload_registry.serialize_workload_registry_entry(key)
299+
f_data = pickle.dumps(transfer_data)
300+
f_new = pickle.loads(f_data)
301+
del workload_registry.WORKLOAD_FUNC_REGISTRY[key]
302+
workload_registry.deserialize_workload_registry_entry(f_new)
303+
304+
296305
def test_measure_local_builder_rpc_runner():
297306
if not tvm.testing.device_enabled("llvm"):
298307
return
@@ -423,6 +432,7 @@ def foo():
423432
test_workload_dis_factor()
424433
test_measure_local_builder_runner()
425434
test_dag_measure_local_builder_runner()
435+
test_workload_serialization()
426436
test_measure_local_builder_rpc_runner()
427437
test_measure_target_host()
428438
test_measure_special_inputs_map_by_name_local_runner()

0 commit comments

Comments
 (0)